Skip to content

Commit a0bfe9f

Browse files
committed
fix mypy error, fix tests tol, move tests
1 parent 7d6d8f8 commit a0bfe9f

File tree

3 files changed

+46
-47
lines changed

3 files changed

+46
-47
lines changed

pytensor/tensor/nlinalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ class MatrixInverse(Op):
9999
100100
"""
101101

102-
__props__ = ()
102+
__props__: tuple[str, ...] = ()
103103
gufunc_signature = "(m,m)->(m,m)"
104104
gufunc_spec = ("numpy.linalg.inv", 1, 1)
105105

tests/tensor/rewriting/test_linalg.py

Lines changed: 2 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,49 +1068,6 @@ def solve_op_in_graph(graph):
10681068
)
10691069

10701070

1071-
@pytest.mark.parametrize("lower", [True, False])
1072-
def test_triangular_inv_op(lower):
1073-
"""Tests the TriangularInv Op directly."""
1074-
x = matrix("x", dtype=config.floatX)
1075-
f = function([x], TriangularInv(lower=lower)(x))
1076-
1077-
if lower:
1078-
a = np.tril(np.random.rand(5, 5) + 0.1).astype(config.floatX)
1079-
else:
1080-
a = np.triu(np.random.rand(5, 5) + 0.1).astype(config.floatX)
1081-
1082-
a_inv = f(a)
1083-
expected_inv = np.linalg.inv(a)
1084-
1085-
# Clean the NumPy result before comparing.
1086-
if lower:
1087-
expected_inv = np.tril(expected_inv)
1088-
else:
1089-
expected_inv = np.triu(expected_inv)
1090-
1091-
# The inverse of a triangular matrix is also triangular.
1092-
# We should check the full matrix, not just a part of it.
1093-
assert_allclose(
1094-
a_inv, expected_inv, rtol=1e-7 if config.floatX == "float64" else 1e-5
1095-
)
1096-
1097-
1098-
def test_triangular_inv_op_nan_on_error():
1099-
"""
1100-
Tests the `on_error='nan'` functionality of the TriangularInv Op.
1101-
"""
1102-
x = matrix("x", dtype=config.floatX)
1103-
f_nan = function([x], TriangularInv(on_error="nan")(x))
1104-
1105-
# Create a singular triangular matrix (zero on the diagonal)
1106-
a_singular = np.tril(np.random.rand(5, 5))
1107-
a_singular[2, 2] = 0
1108-
a_singular = a_singular.astype(config.floatX)
1109-
1110-
res = f_nan(a_singular)
1111-
assert np.all(np.isnan(res))
1112-
1113-
11141071
def _check_op_in_graph(fgraph, op_type, present=True):
11151072
"""Helper to check if an Op is in a graph."""
11161073

@@ -1186,6 +1143,5 @@ def test_inv_to_triangular_inv_rewrite(case):
11861143
pytensor_result = f(a)
11871144
_, numpy_tri_func = rewrite_cases[case]
11881145
numpy_result = np.linalg.inv(numpy_tri_func(a))
1189-
assert_allclose(
1190-
pytensor_result, numpy_result, rtol=1e-7 if config.floatX == "float64" else 1e-5
1191-
)
1146+
atol = rtol = 1e-8 if config.floatX.endswith("64") else 1e-4
1147+
np.testing.assert_allclose(pytensor_result, numpy_result, rtol=rtol, atol=atol)

tests/tensor/test_slinalg.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
Solve,
2020
SolveBase,
2121
SolveTriangular,
22+
TriangularInv,
2223
block_diag,
2324
cho_solve,
2425
cholesky,
@@ -1192,3 +1193,45 @@ def _test_fn(x, case=2, mode="reduced"):
11921193
utt.verify_grad(
11931194
partial(_test_fn, case=gradient_test_case, mode=mode), [a], rng=np.random
11941195
)
1196+
1197+
1198+
@pytest.mark.parametrize("lower", [True, False])
1199+
def test_triangular_inv_op(lower):
1200+
"""Tests the TriangularInv Op directly."""
1201+
x = matrix("x", dtype=config.floatX)
1202+
f = function([x], TriangularInv(lower=lower)(x))
1203+
1204+
if lower:
1205+
a = np.tril(np.random.rand(5, 5) + 0.1).astype(config.floatX)
1206+
else:
1207+
a = np.triu(np.random.rand(5, 5) + 0.1).astype(config.floatX)
1208+
1209+
a_inv = f(a)
1210+
expected_inv = np.linalg.inv(a)
1211+
1212+
# Clean the NumPy result before comparing.
1213+
if lower:
1214+
expected_inv = np.tril(expected_inv)
1215+
else:
1216+
expected_inv = np.triu(expected_inv)
1217+
1218+
# The inverse of a triangular matrix is also triangular.
1219+
# We should check the full matrix, not just a part of it.
1220+
atol = rtol = 1e-8 if config.floatX.endswith("64") else 1e-4
1221+
np.testing.assert_allclose(a_inv, expected_inv, rtol=rtol, atol=atol)
1222+
1223+
1224+
def test_triangular_inv_op_nan_on_error():
1225+
"""
1226+
Tests the `on_error='nan'` functionality of the TriangularInv Op.
1227+
"""
1228+
x = matrix("x", dtype=config.floatX)
1229+
f_nan = function([x], TriangularInv(on_error="nan")(x))
1230+
1231+
# Create a singular triangular matrix (zero on the diagonal)
1232+
a_singular = np.tril(np.random.rand(5, 5))
1233+
a_singular[2, 2] = 0
1234+
a_singular = a_singular.astype(config.floatX)
1235+
1236+
res = f_nan(a_singular)
1237+
assert np.all(np.isnan(res))

0 commit comments

Comments
 (0)