@@ -11,8 +11,10 @@
11 | 11 | from pytensor.compile import get_default_mode |
12 | 12 | from pytensor.configdefaults import config |
13 | 13 | from pytensor.graph import ancestors |
| 14 | +from pytensor.graph.fg import FunctionGraph |
14 | 15 | from pytensor.graph.rewriting.utils import rewrite_graph |
15 | 16 | from pytensor.tensor import swapaxes |
| 17 | +from pytensor.tensor.basic import tril, triu |
16 | 18 | from pytensor.tensor.blockwise import Blockwise |
17 | 19 | from pytensor.tensor.elemwise import DimShuffle |
18 | 20 | from pytensor.tensor.math import dot, matmul |
@@ -38,6 +40,8 @@
38 | 40 | TriangularInv, |
39 | 41 | cho_solve, |
40 | 42 | cholesky, |
| 43 | + lu, |
| 44 | + qr, |
41 | 45 | solve, |
42 | 46 | solve_triangular, |
43 | 47 | ) |
@@ -1064,42 +1068,121 @@ def solve_op_in_graph(graph): |
1064 | 1068 | ) |
1065 | 1069 |
1066 | 1070 |
1067 | | -def test_triangular_inv_op(): |
| 1071 | +@pytest.mark.parametrize("lower", [True, False]) |
| 1072 | +def test_triangular_inv_op(lower): |
| 1073 | + """Tests the TriangularInv Op directly.""" |
1068 | 1074 | x = matrix("x") |
1069 | | - f_lower = function([x], Blockwise(TriangularInv(lower=True))(x)) |
1070 | | - f_upper = function([x], Blockwise(TriangularInv(lower=False))(x)) |
| 1075 | + f = function([x], TriangularInv(lower=lower)(x)) |
1071 | 1076 |
1072 | | - # Test lower |
1073 | | - a = np.tril(np.random.rand(5, 5) + 0.1) |
1074 | | - a_inv = f_lower(a) |
1075 | | - expected_inv = np.linalg.inv(a) |
1076 | | - np.testing.assert_allclose( |
1077 | | - np.tril(a_inv), np.tril(expected_inv), rtol=1e-5, atol=1e-7 |
1078 | | - ) |
| 1077 | + if lower: |
| 1078 | + a = np.tril(np.random.rand(5, 5) + 0.1) |
| 1079 | + else: |
| 1080 | + a = np.triu(np.random.rand(5, 5) + 0.1) |
1079 | 1081 |
1080 | | - # Test upper |
1081 | | - a = np.triu(np.random.rand(5, 5) + 0.1) |
1082 | | - a_inv = f_upper(a) |
| 1082 | + a_inv = f(a) |
1083 | 1083 | expected_inv = np.linalg.inv(a) |
1084 | | - np.testing.assert_allclose( |
1085 | | - np.triu(a_inv), np.triu(expected_inv), rtol=1e-5, atol=1e-7 |
| 1084 | + |
| 1085 | + # Zero out the opposite triangle, where np.linalg.inv may leave tiny nonzero entries.
| 1086 | + if lower: |
| 1087 | + expected_inv = np.tril(expected_inv) |
| 1088 | + else: |
| 1089 | + expected_inv = np.triu(expected_inv) |
| 1090 | + |
| 1091 | + # The inverse of a triangular matrix is itself triangular, so compare the
| 1092 | + # full matrices: this also verifies the Op zeroes the opposite triangle.
| 1093 | + np.testing.assert_allclose(
| 1094 | + a_inv, expected_inv, rtol=1e-7 if config.floatX == "float64" else 1e-5 |
1086 | 1095 | ) |
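| | + # Note: scipy.linalg.solve_triangular(a, np.eye(5), lower=lower) would be
| | + # an equivalent reference that avoids forming a full inverse.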
1087 | 1096 |
1088 | 1097 |
1089 | | -def test_inv_to_triangular_inv_rewrite(): |
| 1098 | +def test_triangular_inv_op_nan_on_error(): |
| 1099 | + """ |
| 1100 | + Tests the `on_error='nan'` functionality of the TriangularInv Op. |
| 1101 | + """ |
1090 | 1102 | x = matrix("x") |
| 1103 | + f_nan = function([x], TriangularInv(on_error="nan")(x)) |
| 1104 | + |
| 1105 | + # Create a singular triangular matrix (zero on the diagonal) |
| 1106 | + a_singular = np.tril(np.random.rand(5, 5)) |
| 1107 | + a_singular[2, 2] = 0 |
1091 | 1108 |
1092 | | - x_chol = cholesky(x) |
1093 | | - y_chol = inv(x_chol) |
1094 | | - f_chol = function([x], y_chol) |
1095 | | - assert any( |
1096 | | - isinstance(node.op, TriangularInv) |
1097 | | - or (hasattr(node.op, "core_op") and isinstance(node.op.core_op, TriangularInv)) |
1098 | | - for node in f_chol.maker.fgraph.apply_nodes |
| 1109 | + res = f_nan(a_singular) |
| 1110 | + assert np.all(np.isnan(res)) |
| 1111 | + |
| 1112 | + |
| 1113 | +def _check_op_in_graph(fgraph, op_type, present=True): |
| 1114 | + """Helper to check if an Op is in a graph.""" |
| 1115 | + |
| 1116 | + # We use type() instead of isinstance() to avoid matching subclasses |
| 1117 | + # (e.g., finding TriangularInv when we're looking for MatrixInverse). |
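| | + # For instance, if TriangularInv subclasses MatrixInverse, isinstance()
| | + # would report a false positive in the present=False checks below.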
| 1118 | + found = any( |
| 1119 | + type(node.op) is op_type |
| 1120 | + or (hasattr(node.op, "core_op") and type(node.op.core_op) is op_type) |
| 1121 | + for node in fgraph.apply_nodes |
1099 | 1122 | ) |
| 1123 | + if present: |
| 1124 | + assert found, f"{op_type.__name__} not found in graph" |
| 1125 | + else: |
| 1126 | + assert not found, f"{op_type.__name__} unexpectedly found in graph" |
| 1127 | + |
| 1128 | + |
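| | +# Each entry pairs a PyTensor builder that produces a triangular matrix
| | +# with the NumPy/SciPy reference used to compute the expected inverse.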
| 1129 | +rewrite_cases = { |
| 1130 | + "tril": ( |
| 1131 | + lambda x: tril(x), |
| 1132 | + lambda a: np.tril(a), |
| 1133 | + ), |
| 1134 | + "triu": ( |
| 1135 | + lambda x: triu(x), |
| 1136 | + lambda a: np.triu(a), |
| 1137 | + ), |
| 1138 | + "cholesky": ( |
| 1139 | + lambda x: cholesky(x), |
| 1140 | + lambda a: np.linalg.cholesky(a), |
| 1141 | + ), |
| 1142 | + "lu_L": ( |
| 1143 | + lambda x: lu(x)[1], |
| 1144 | + lambda a: scipy.linalg.lu(a)[1], |
| 1145 | + ), |
| 1146 | + "lu_U": ( |
| 1147 | + lambda x: lu(x)[2], |
| 1148 | + lambda a: scipy.linalg.lu(a)[2], |
| 1149 | + ), |
| 1150 | + "qr_R": ( |
| 1151 | + lambda x: qr(x)[1], |
| 1152 | + lambda a: np.linalg.qr(a)[1], |
| 1153 | + ), |
| 1154 | +} |
| 1155 | + |
| 1156 | + |
| 1157 | +@pytest.mark.parametrize("case", rewrite_cases.keys()) |
| 1158 | +def test_inv_to_triangular_inv_rewrite(case): |
| 1159 | + """ |
| 1160 | + Tests the rewrite of inv(triangular) -> TriangularInv. |
| 1161 | + """ |
| 1162 | + x = matrix("x") |
| 1163 | + build_tri, numpy_tri_func = rewrite_cases[case]
| 1164 | + x_tri = build_tri(x) |
| 1165 | + y_inv = inv(x_tri) |
1100 | 1166 |
| 1167 | + # Check graph BEFORE compilation |
| 1168 | + pre_compile_fgraph = FunctionGraph([x], [y_inv], clone=False) |
| 1169 | + _check_op_in_graph(pre_compile_fgraph, MatrixInverse, present=True) |
| 1170 | + _check_op_in_graph(pre_compile_fgraph, TriangularInv, present=False) |
| 1171 | + |
| 1172 | + # Trigger the rewrite |
| 1173 | + f = function([x], y_inv) |
| 1174 | + |
| 1175 | + # Check graph AFTER compilation |
| 1176 | + post_compile_fgraph = f.maker.fgraph |
| 1177 | + _check_op_in_graph(post_compile_fgraph, TriangularInv, present=True) |
| 1178 | + _check_op_in_graph(post_compile_fgraph, MatrixInverse, present=False) |
| 1179 | + |
| 1180 | + # Check numerical correctness |
1101 | 1181 | a = np.random.rand(5, 5) |
1102 | | - a = np.dot(a, a.T) + np.eye(5) * 0.1 # ensure positive definite |
1103 | | - np.testing.assert_allclose( |
1104 | | - f_chol(a), np.linalg.inv(np.linalg.cholesky(a)), rtol=1e-5, atol=1e-7 |
| 1182 | + a = np.dot(a, a.T) + np.eye(5) # Make positive definite for Cholesky |
| 1183 | + pytensor_result = f(a) |
| 1185 | + numpy_result = np.linalg.inv(numpy_tri_func(a)) |
| 1186 | + np.testing.assert_allclose(
| 1187 | + pytensor_result, numpy_result, rtol=1e-7 if config.floatX == "float64" else 1e-5 |
1105 | 1188 | ) |
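
For reference, the rewrite can also be exercised without compiling a function, by applying `rewrite_graph` (imported at the top of this file) to a `FunctionGraph`. A minimal sketch, assuming the `inv(triangular) -> TriangularInv` rewrite is registered in one of the standard rewrite sets passed below, and that `TriangularInv` lives in `pytensor.tensor.slinalg` as the import hunk suggests:

```python
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.tensor import matrix
from pytensor.tensor.basic import tril
from pytensor.tensor.nlinalg import inv
from pytensor.tensor.slinalg import TriangularInv  # Op added by this PR

x = matrix("x")
fg = FunctionGraph([x], [inv(tril(x))], clone=False)

# Apply the standard rewrite sets; the set in which the rewrite is
# registered is assumed to be among them.
fg = rewrite_graph(fg, include=("canonicalize", "stabilize", "specialize"))

# The MatrixInverse node should now be a TriangularInv,
# possibly wrapped in Blockwise.
assert any(
    isinstance(node.op, TriangularInv)
    or isinstance(getattr(node.op, "core_op", None), TriangularInv)
    for node in fg.apply_nodes
)
```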