Commit c6404c0

address review comments;
add other conditions to trigger the rewrite; enhance the TriangularInv Op; add tests
1 parent 97dde20 commit c6404c0

File tree: 3 files changed, +207 -41 lines

pytensor/tensor/rewriting/linalg.py

Lines changed: 68 additions & 7 deletions
@@ -8,18 +8,22 @@
 from pytensor import tensor as pt
 from pytensor.compile import optdb
 from pytensor.graph import Apply, FunctionGraph
+from pytensor.graph.basic import Constant
 from pytensor.graph.rewriting.basic import (
     copy_stack_trace,
     dfs_rewriter,
     node_rewriter,
 )
 from pytensor.graph.rewriting.unify import OpPattern
 from pytensor.scalar.basic import Abs, Log, Mul, Sign
+from pytensor.scalar.basic import Mul as ScalarMul
+from pytensor.scalar.basic import Sub as ScalarSub
 from pytensor.tensor.basic import (
     AllocDiag,
     ExtractDiag,
     Eye,
     TensorVariable,
+    Tri,
     concatenate,
     diag,
     diagonal,
@@ -46,9 +50,12 @@
 )
 from pytensor.tensor.rewriting.blockwise import blockwise_of
 from pytensor.tensor.slinalg import (
+    LU,
+    QR,
     BlockDiagonal,
     Cholesky,
     CholeskySolve,
+    LUFactor,
     Solve,
     SolveBase,
     SolveTriangular,
@@ -1026,17 +1033,69 @@ def _find_triangular_op(var):
 
     Returns a tuple (is_lower, is_upper) if triangular, otherwise None.
     """
-
+    # Case 1: Check for an explicit tag
     is_lower = getattr(var.tag, "lower_triangular", False)
     is_upper = getattr(var.tag, "upper_triangular", False)
-
     if is_lower or is_upper:
         return (is_lower, is_upper)
 
-    if var.owner and isinstance(var.owner.op, Blockwise):
-        core_op = var.owner.op.core_op
-        if isinstance(core_op, Cholesky):
-            return (core_op.lower, not core_op.lower)
+    if not var.owner:
+        return None
+
+    op = var.owner.op
+    core_op = op.core_op if isinstance(op, Blockwise) else op
+
+    # Case 2: Check for direct creator Ops
+    if isinstance(core_op, Cholesky):
+        return (core_op.lower, not core_op.lower)
+
+    if isinstance(core_op, LU | LUFactor):
+        if var.owner.outputs[1] == var:
+            return (True, False)
+        if var.owner.outputs[2] == var:
+            return (False, True)
+
+    if isinstance(core_op, QR):
+        if var.owner.outputs[1] == var:
+            return (False, True)
+
+    # pt.tri will get constant folded so no point re-writing ?
+    # if isinstance(core_op, Tri):
+    #     k_node = var.owner.inputs[2]
+    #     if isinstance(k_node, Constant) and k_node.data == 0:
+    #         print('re-writing ... ')
+    #         return (True, False)
+
+    # Case 3: tril/triu patterns which are implemented as Mul
+    if isinstance(core_op, Elemwise) and isinstance(core_op.scalar_op, ScalarMul):
+        other_inp = next(
+            (i for i in var.owner.inputs if i != var.owner.inputs[0]), None
+        )
+
+        if other_inp is not None and other_inp.owner:
+            # Check for tril pattern: Mul(x, Tri(...))
+            if isinstance(other_inp.owner.op, Tri):
+                k_node = other_inp.owner.inputs[2]
+                if isinstance(k_node, Constant) and k_node.data == 0:
+                    return (True, False)  # It's tril
+
+            # Check for triu pattern: Mul(x, Sub(1, Tri(k=-1)))
+            sub_op = other_inp.owner.op
+            if isinstance(sub_op, Elemwise) and isinstance(sub_op.scalar_op, ScalarSub):
+                sub_inputs = other_inp.owner.inputs
+                const_one = next(
+                    (i for i in sub_inputs if isinstance(i, Constant) and i.data == 1),
+                    None,
+                )
+                tri_inp = next(
+                    (i for i in sub_inputs if i.owner and isinstance(i.owner.op, Tri)),
+                    None,
+                )
+
+                if const_one is not None and tri_inp is not None:
+                    k_node = tri_inp.owner.inputs[2]
+                    if isinstance(k_node, Constant) and k_node.data == -1:
+                        return (False, True)  # It's triu
 
     return None
 
@@ -1059,4 +1118,6 @@ def rewrite_inv_to_triangular_solve(fgraph, node):
     is_lower, is_upper = triangular_info
     if is_lower or is_upper:
         new_op = TriangularInv(lower=is_lower)
-        return [new_op(A)]
+        new_inv = new_op(A)
+        copy_stack_trace(node.outputs[0], new_inv)
+        return [new_inv]
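
For orientation, a minimal sketch of what this rewrite now catches, assuming the public `pytensor.tensor` entry points (`tril` at the top level, `inv` in `pytensor.tensor.nlinalg`); the same applies to `triu`, `cholesky`, and the `lu`/`qr` factors matched above:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import inv

x = pt.matrix("x")
# tril(x) is built as Mul(x, Tri(...)), which _find_triangular_op now
# recognizes, so this inverse should compile to TriangularInv.
f = pytensor.function([x], inv(pt.tril(x)))

a = np.tril(np.random.rand(4, 4) + 0.1).astype(x.dtype)
np.testing.assert_allclose(f(a), np.linalg.inv(a), rtol=1e-4)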

pytensor/tensor/slinalg.py

Lines changed: 30 additions & 8 deletions
@@ -1021,24 +1021,46 @@ class TriangularInv(MatrixInverse):
     Computes the inverse of a triangular matrix.
     """
 
-    __props__ = ("lower",)
+    __props__ = ("lower", "on_error", "overwrite_a")
 
-    def __init__(self, lower=True):
+    def __init__(self, lower=True, on_error="raise", overwrite_a=False):
         self.lower = lower
+        if on_error not in ("raise", "nan"):
+            raise ValueError('on_error must be one of "raise" or "nan"')
+        self.on_error = on_error
+        self.overwrite_a = overwrite_a
+
+        if self.overwrite_a:
+            self.destroy_map = {0: [0]}
 
     def perform(self, node, inputs, outputs):
         (x,) = inputs
         (z,) = outputs
         (dtrtri,) = get_lapack_funcs(("trtri",), (x,))
         inv, info = dtrtri(x, lower=self.lower, overwrite_c=True)
-        if info > 0:
-            raise np.linalg.LinAlgError("Singular matrix")
-        elif info < 0:
-            raise ValueError(
-                "illegal value in %d-th argument of internal trtri" % -info
-            )
+        if info != 0:
+            if self.on_error == "nan":
+                z[0] = np.full_like(x, np.nan)
+                return
+            elif info > 0:
+                raise np.linalg.LinAlgError("Singular matrix")
+            elif info < 0:
+                raise ValueError(
+                    f"illegal value in {-info}-th argument of internal trtri"
+                )
         z[0] = inv
 
+    def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
+        """
+        Allows this Op to overwrite its input buffer with its output.
+        """
+        if not allowed_inplace_inputs:
+            return self
+
+        new_props = self._props_dict()
+        new_props["overwrite_a"] = True
+        return type(self)(**new_props)
+
 
 class Solve(SolveBase):
     """

tests/tensor/rewriting/test_linalg.py

Lines changed: 109 additions & 26 deletions
@@ -11,8 +11,10 @@
 from pytensor.compile import get_default_mode
 from pytensor.configdefaults import config
 from pytensor.graph import ancestors
+from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.utils import rewrite_graph
 from pytensor.tensor import swapaxes
+from pytensor.tensor.basic import tril, triu
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.math import dot, matmul
@@ -38,6 +40,8 @@
     TriangularInv,
     cho_solve,
     cholesky,
+    lu,
+    qr,
     solve,
     solve_triangular,
 )
@@ -1064,42 +1068,121 @@ def solve_op_in_graph(graph):
 )
 
 
-def test_triangular_inv_op():
+@pytest.mark.parametrize("lower", [True, False])
+def test_triangular_inv_op(lower):
+    """Tests the TriangularInv Op directly."""
     x = matrix("x")
-    f_lower = function([x], Blockwise(TriangularInv(lower=True))(x))
-    f_upper = function([x], Blockwise(TriangularInv(lower=False))(x))
+    f = function([x], TriangularInv(lower=lower)(x))
 
-    # Test lower
-    a = np.tril(np.random.rand(5, 5) + 0.1)
-    a_inv = f_lower(a)
-    expected_inv = np.linalg.inv(a)
-    np.testing.assert_allclose(
-        np.tril(a_inv), np.tril(expected_inv), rtol=1e-5, atol=1e-7
-    )
+    if lower:
+        a = np.tril(np.random.rand(5, 5) + 0.1)
+    else:
+        a = np.triu(np.random.rand(5, 5) + 0.1)
 
-    # Test upper
-    a = np.triu(np.random.rand(5, 5) + 0.1)
-    a_inv = f_upper(a)
+    a_inv = f(a)
     expected_inv = np.linalg.inv(a)
-    np.testing.assert_allclose(
-        np.triu(a_inv), np.triu(expected_inv), rtol=1e-5, atol=1e-7
+
+    # Clean the NumPy result before comparing.
+    if lower:
+        expected_inv = np.tril(expected_inv)
+    else:
+        expected_inv = np.triu(expected_inv)
+
+    # The inverse of a triangular matrix is also triangular.
+    # We should check the full matrix, not just a part of it.
+    assert_allclose(
+        a_inv, expected_inv, rtol=1e-7 if config.floatX == "float64" else 1e-5
     )
 
 
-def test_inv_to_triangular_inv_rewrite():
+def test_triangular_inv_op_nan_on_error():
+    """
+    Tests the `on_error='nan'` functionality of the TriangularInv Op.
+    """
     x = matrix("x")
+    f_nan = function([x], TriangularInv(on_error="nan")(x))
+
+    # Create a singular triangular matrix (zero on the diagonal)
+    a_singular = np.tril(np.random.rand(5, 5))
+    a_singular[2, 2] = 0
 
-    x_chol = cholesky(x)
-    y_chol = inv(x_chol)
-    f_chol = function([x], y_chol)
-    assert any(
-        isinstance(node.op, TriangularInv)
-        or (hasattr(node.op, "core_op") and isinstance(node.op.core_op, TriangularInv))
-        for node in f_chol.maker.fgraph.apply_nodes
+    res = f_nan(a_singular)
+    assert np.all(np.isnan(res))
+
+
+def _check_op_in_graph(fgraph, op_type, present=True):
+    """Helper to check if an Op is in a graph."""
+
+    # We use type() instead of isinstance() to avoid matching subclasses
+    # (e.g., finding TriangularInv when we're looking for MatrixInverse).
+    found = any(
+        type(node.op) is op_type
+        or (hasattr(node.op, "core_op") and type(node.op.core_op) is op_type)
+        for node in fgraph.apply_nodes
     )
+    if present:
+        assert found, f"{op_type.__name__} not found in graph"
+    else:
+        assert not found, f"{op_type.__name__} unexpectedly found in graph"
+
+
+rewrite_cases = {
+    "tril": (
+        lambda x: tril(x),
+        lambda a: np.tril(a),
+    ),
+    "triu": (
+        lambda x: triu(x),
+        lambda a: np.triu(a),
+    ),
+    "cholesky": (
+        lambda x: cholesky(x),
+        lambda a: np.linalg.cholesky(a),
+    ),
+    "lu_L": (
+        lambda x: lu(x)[1],
+        lambda a: scipy.linalg.lu(a)[1],
+    ),
+    "lu_U": (
+        lambda x: lu(x)[2],
+        lambda a: scipy.linalg.lu(a)[2],
+    ),
+    "qr_R": (
+        lambda x: qr(x)[1],
+        lambda a: np.linalg.qr(a)[1],
+    ),
+}
+
+
+@pytest.mark.parametrize("case", rewrite_cases.keys())
+def test_inv_to_triangular_inv_rewrite(case):
+    """
+    Tests the rewrite of inv(triangular) -> TriangularInv.
+    """
+    x = matrix("x")
+    build_tri, _ = rewrite_cases[case]
+    x_tri = build_tri(x)
+    y_inv = inv(x_tri)
 
+    # Check graph BEFORE compilation
+    pre_compile_fgraph = FunctionGraph([x], [y_inv], clone=False)
+    _check_op_in_graph(pre_compile_fgraph, MatrixInverse, present=True)
+    _check_op_in_graph(pre_compile_fgraph, TriangularInv, present=False)
+
+    # Trigger the rewrite
+    f = function([x], y_inv)
+
+    # Check graph AFTER compilation
+    post_compile_fgraph = f.maker.fgraph
+    _check_op_in_graph(post_compile_fgraph, TriangularInv, present=True)
+    _check_op_in_graph(post_compile_fgraph, MatrixInverse, present=False)
+
+    # Check numerical correctness
     a = np.random.rand(5, 5)
-    a = np.dot(a, a.T) + np.eye(5) * 0.1  # ensure positive definite
-    np.testing.assert_allclose(
-        f_chol(a), np.linalg.inv(np.linalg.cholesky(a)), rtol=1e-5, atol=1e-7
+    a = np.dot(a, a.T) + np.eye(5)  # Make positive definite for Cholesky
+    pytensor_result = f(a)
+    _, numpy_tri_func = rewrite_cases[case]
+    numpy_result = np.linalg.inv(numpy_tri_func(a))
+    assert_allclose(
+        pytensor_result, numpy_result, rtol=1e-7 if config.floatX == "float64" else 1e-5
     )
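
The new tests can be selected with `pytest tests/tensor/rewriting/test_linalg.py -k triangular_inv`. For a quick interactive check that the rewrite fired, inspecting the compiled graph directly is enough; a sketch mirroring `_check_op_in_graph` above, with import locations assumed from the test file:

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import inv
from pytensor.tensor.slinalg import TriangularInv, cholesky

x = pt.matrix("x")
f = pytensor.function([x], inv(cholesky(x)))

# After rewriting, the compiled graph should contain TriangularInv
# (possibly wrapped in Blockwise, hence the core_op fallback).
assert any(
    isinstance(getattr(node.op, "core_op", node.op), TriangularInv)
    for node in f.maker.fgraph.apply_nodes
)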
