Skip to content

Commit c7af980

Browse files
committed
fix mypy error, fix tests tol, move tests
1 parent 4c5e21d commit c7af980

File tree

3 files changed

+46
-47
lines changed

3 files changed

+46
-47
lines changed

pytensor/tensor/nlinalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ class MatrixInverse(Op):
107107
108108
"""
109109

110-
__props__ = ()
110+
__props__: tuple[str, ...] = ()
111111
gufunc_signature = "(m,m)->(m,m)"
112112
gufunc_spec = ("numpy.linalg.inv", 1, 1)
113113

tests/tensor/rewriting/test_linalg.py

Lines changed: 2 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,49 +1068,6 @@ def solve_op_in_graph(graph):
10681068
)
10691069

10701070

1071-
@pytest.mark.parametrize("lower", [True, False])
1072-
def test_triangular_inv_op(lower):
1073-
"""Tests the TriangularInv Op directly."""
1074-
x = matrix("x", dtype=config.floatX)
1075-
f = function([x], TriangularInv(lower=lower)(x))
1076-
1077-
if lower:
1078-
a = np.tril(np.random.rand(5, 5) + 0.1).astype(config.floatX)
1079-
else:
1080-
a = np.triu(np.random.rand(5, 5) + 0.1).astype(config.floatX)
1081-
1082-
a_inv = f(a)
1083-
expected_inv = np.linalg.inv(a)
1084-
1085-
# Clean the NumPy result before comparing.
1086-
if lower:
1087-
expected_inv = np.tril(expected_inv)
1088-
else:
1089-
expected_inv = np.triu(expected_inv)
1090-
1091-
# The inverse of a triangular matrix is also triangular.
1092-
# We should check the full matrix, not just a part of it.
1093-
assert_allclose(
1094-
a_inv, expected_inv, rtol=1e-7 if config.floatX == "float64" else 1e-5
1095-
)
1096-
1097-
1098-
def test_triangular_inv_op_nan_on_error():
1099-
"""
1100-
Tests the `on_error='nan'` functionality of the TriangularInv Op.
1101-
"""
1102-
x = matrix("x", dtype=config.floatX)
1103-
f_nan = function([x], TriangularInv(on_error="nan")(x))
1104-
1105-
# Create a singular triangular matrix (zero on the diagonal)
1106-
a_singular = np.tril(np.random.rand(5, 5))
1107-
a_singular[2, 2] = 0
1108-
a_singular = a_singular.astype(config.floatX)
1109-
1110-
res = f_nan(a_singular)
1111-
assert np.all(np.isnan(res))
1112-
1113-
11141071
def _check_op_in_graph(fgraph, op_type, present=True):
11151072
"""Helper to check if an Op is in a graph."""
11161073

@@ -1186,6 +1143,5 @@ def test_inv_to_triangular_inv_rewrite(case):
11861143
pytensor_result = f(a)
11871144
_, numpy_tri_func = rewrite_cases[case]
11881145
numpy_result = np.linalg.inv(numpy_tri_func(a))
1189-
assert_allclose(
1190-
pytensor_result, numpy_result, rtol=1e-7 if config.floatX == "float64" else 1e-5
1191-
)
1146+
atol = rtol = 1e-8 if config.floatX.endswith("64") else 1e-4
1147+
np.testing.assert_allclose(pytensor_result, numpy_result, rtol=rtol, atol=atol)

tests/tensor/test_slinalg.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
Solve,
2020
SolveBase,
2121
SolveTriangular,
22+
TriangularInv,
2223
block_diag,
2324
cho_solve,
2425
cholesky,
@@ -1235,3 +1236,45 @@ def _test_fn(x, case=2, mode="reduced"):
12351236
utt.verify_grad(
12361237
partial(_test_fn, case=gradient_test_case, mode=mode), [a], rng=np.random
12371238
)
1239+
1240+
1241+
@pytest.mark.parametrize("lower", [True, False])
1242+
def test_triangular_inv_op(lower):
1243+
"""Tests the TriangularInv Op directly."""
1244+
x = matrix("x", dtype=config.floatX)
1245+
f = function([x], TriangularInv(lower=lower)(x))
1246+
1247+
if lower:
1248+
a = np.tril(np.random.rand(5, 5) + 0.1).astype(config.floatX)
1249+
else:
1250+
a = np.triu(np.random.rand(5, 5) + 0.1).astype(config.floatX)
1251+
1252+
a_inv = f(a)
1253+
expected_inv = np.linalg.inv(a)
1254+
1255+
# Clean the NumPy result before comparing.
1256+
if lower:
1257+
expected_inv = np.tril(expected_inv)
1258+
else:
1259+
expected_inv = np.triu(expected_inv)
1260+
1261+
# The inverse of a triangular matrix is also triangular.
1262+
# We should check the full matrix, not just a part of it.
1263+
atol = rtol = 1e-8 if config.floatX.endswith("64") else 1e-4
1264+
np.testing.assert_allclose(a_inv, expected_inv, rtol=rtol, atol=atol)
1265+
1266+
1267+
def test_triangular_inv_op_nan_on_error():
1268+
"""
1269+
Tests the `on_error='nan'` functionality of the TriangularInv Op.
1270+
"""
1271+
x = matrix("x", dtype=config.floatX)
1272+
f_nan = function([x], TriangularInv(on_error="nan")(x))
1273+
1274+
# Create a singular triangular matrix (zero on the diagonal)
1275+
a_singular = np.tril(np.random.rand(5, 5))
1276+
a_singular[2, 2] = 0
1277+
a_singular = a_singular.astype(config.floatX)
1278+
1279+
res = f_nan(a_singular)
1280+
assert np.all(np.isnan(res))

0 commit comments

Comments
 (0)