Skip to content

Commit 7d7fcef

Browse files
committed
add triangular rewrite
1 parent 1dc982c commit 7d7fcef

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
ExtractDiag,
2020
Eye,
2121
TensorVariable,
22+
Tri,
2223
concatenate,
2324
diag,
2425
diagonal,
@@ -1062,3 +1063,56 @@ def scalar_solve_to_division(fgraph, node):
10621063
copy_stack_trace(old_out, new_out)
10631064

10641065
return [new_out]
1066+
1067+
1068+
def _find_triangular_op(var):
1069+
"""
1070+
Inspects a variable to see if it's triangular.
1071+
1072+
Returns a tuple (is_lower, is_upper) if triangular, otherwise None.
1073+
"""
1074+
1075+
is_lower = getattr(var.tag, "lower_triangular", False)
1076+
is_upper = getattr(var.tag, "upper_triangular", False)
1077+
1078+
if is_lower or is_upper:
1079+
return (is_lower, is_upper)
1080+
1081+
if var.owner and isinstance(var.owner.op, Tri):
1082+
# The 'k' parameter of Tri determines the diagonal.
1083+
# k=0 is the main diagonal.
1084+
k = var.owner.op.k
1085+
if k == 0:
1086+
is_lower = var.owner.op.lower
1087+
return (is_lower, not is_lower)
1088+
1089+
if var.owner and isinstance(var.owner.op, Blockwise):
1090+
core_op = var.owner.op.core_op
1091+
if isinstance(core_op, Cholesky):
1092+
return (core_op.lower, not core_op.lower)
1093+
1094+
return None
1095+
1096+
1097+
@register_canonicalize
1098+
@register_stabilize
1099+
@node_rewriter([Blockwise])
1100+
def rewrite_inv_to_triangular_solve(fgraph, node):
1101+
"""
1102+
This rewrite takes advantage of the fact that the inverse of a triangular
1103+
matrix can be computed more efficiently than the inverse of a general
1104+
matrix by using a triangular solve instead of a general matrix inverse.
1105+
"""
1106+
core_op = node.op.core_op
1107+
if not isinstance(core_op, ALL_INVERSE_OPS):
1108+
return None
1109+
1110+
A = node.inputs[0]
1111+
triangular_info = _find_triangular_op(A)
1112+
if triangular_info is None:
1113+
return None
1114+
1115+
is_lower, is_upper = triangular_info
1116+
if is_lower or is_upper:
1117+
I = pt.eye(A.shape[0], dtype=A.dtype)
1118+
return [solve_triangular(A, I, lower=is_lower)]

0 commit comments

Comments
 (0)