Skip to content

Commit 07c48f3

Browse files
committed
test tag cases
1 parent d0dbf0e commit 07c48f3

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

tests/tensor/rewriting/test_linalg.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1084,6 +1084,14 @@ def _check_op_in_graph(fgraph, op_type, present=True):
10841084

10851085

10861086
rewrite_cases = {
1087+
"lower_tag": (
1088+
lambda x: (setattr(x.tag, "lower_triangular", True), x)[-1],
1089+
lambda a: np.tril(a),
1090+
),
1091+
"upper_tag": (
1092+
lambda x: (setattr(x.tag, "upper_triangular", True), x)[-1],
1093+
lambda a: np.triu(a),
1094+
),
10871095
"tri": (
10881096
lambda x: pt.tri(x.shape[0], x.shape[1], k=0, dtype=x.dtype),
10891097
lambda a: np.tri(N=a.shape[0], M=a.shape[1], k=0, dtype=a.dtype),
@@ -1143,6 +1151,10 @@ def test_inv_to_triangular_inv_rewrite(case):
11431151
a = (np.dot(a, a.T) + np.eye(5)).astype(
11441152
config.floatX
11451153
) # Make positive definite for Cholesky
1154+
if case == "lower_tag":
1155+
a = np.tril(a)
1156+
elif case == "upper_tag":
1157+
a = np.triu(a)
11461158
pytensor_result = f(a)
11471159
_, numpy_tri_func = rewrite_cases[case]
11481160
numpy_result = np.linalg.inv(numpy_tri_func(a))

0 commit comments

Comments
 (0)