Skip to content

Commit 34b3eb8

Browse files
committed
add triangular rewrite
1 parent 96122d1 commit 34b3eb8

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
ExtractDiag,
2121
Eye,
2222
TensorVariable,
23+
Tri,
2324
concatenate,
2425
diag,
2526
diagonal,
@@ -1015,3 +1016,56 @@ def scalar_solve_to_division(fgraph, node):
10151016
copy_stack_trace(old_out, new_out)
10161017

10171018
return [new_out]
1019+
1020+
1021+
def _find_triangular_op(var):
    """
    Inspect a variable to determine whether it is known to be triangular.

    Three sources of information are consulted, in this order:

    1. The ``lower_triangular`` / ``upper_triangular`` flags on ``var.tag``.
    2. Whether ``var`` is produced by a ``Tri`` Op whose diagonal offset ``k``
       is the constant 0.
    3. Whether ``var`` is produced by a ``Blockwise``-wrapped ``Cholesky``.

    Parameters
    ----------
    var
        The graph variable to inspect. It must expose ``tag`` and ``owner``.

    Returns
    -------
    tuple or None
        ``(is_lower, is_upper)`` when ``var`` is recognised as triangular,
        otherwise ``None``.
    """
    # Explicit user-provided tags take precedence over graph inspection.
    is_lower = getattr(var.tag, "lower_triangular", False)
    is_upper = getattr(var.tag, "upper_triangular", False)
    if is_lower or is_upper:
        return (is_lower, is_upper)

    owner = var.owner
    if not owner:
        # Root variables (inputs/constants) carry no Op to inspect.
        return None

    op = owner.op

    if isinstance(op, Tri):
        # NOTE(review): this relies on PyTensor's ``Tri`` Op taking
        # ``(N, M, k)`` as graph *inputs* (its only Op property being
        # ``dtype``) and always producing ones at and below the k-th
        # diagonal, i.e. a lower-triangular matrix when k == 0. The previous
        # code read ``op.k`` / ``op.lower``, which raises AttributeError if
        # those attributes do not exist -- confirm against the installed
        # PyTensor version.
        # Imported locally so the module-level import block is left untouched.
        from pytensor.graph.basic import Constant

        k_var = owner.inputs[-1]
        if isinstance(k_var, Constant):
            k_data = k_var.data
            # Only a scalar offset that is statically known to be 0 (the main
            # diagonal) is accepted.
            if getattr(k_data, "ndim", 0) == 0 and k_data == 0:
                return (True, False)
        return None

    if isinstance(op, Blockwise) and isinstance(op.core_op, Cholesky):
        # The Cholesky Op records which triangle it returns in ``lower``.
        lower = op.core_op.lower
        return (lower, not lower)

    return None
1048+
1049+
1050+
@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
def rewrite_inv_to_triangular_solve(fgraph, node):
    """
    Replace the inverse of a triangular matrix with a triangular solve.

    When the input of an inverse Op is known to be triangular (as reported by
    ``_find_triangular_op``), the inverse is rewritten as
    ``solve_triangular(A, eye)``, which takes advantage of the triangular
    structure instead of performing a general matrix inverse.

    Parameters
    ----------
    fgraph
        The function graph being rewritten (unused here).
    node
        A ``Blockwise`` node; only nodes whose ``core_op`` is one of
        ``ALL_INVERSE_OPS`` are rewritten.

    Returns
    -------
    list or None
        A single-element list with the replacement output, or ``None`` when
        the rewrite does not apply.
    """
    # NOTE(review): if ALL_INVERSE_OPS also contains pseudo-inverse Ops, this
    # rewrite implicitly assumes A is non-singular -- confirm that is intended.
    if not isinstance(node.op.core_op, ALL_INVERSE_OPS):
        return None

    A = node.inputs[0]
    triangular_info = _find_triangular_op(A)
    if triangular_info is None:
        return None

    is_lower, is_upper = triangular_info
    if not (is_lower or is_upper):
        return None

    # Blockwise inputs may carry leading batch dimensions, so the size of the
    # identity must come from the last (core matrix) axis, not from axis 0.
    identity = pt.eye(A.shape[-1], dtype=A.dtype)
    new_out = solve_triangular(A, identity, lower=bool(is_lower))

    # Keep error messages pointing at the user's original inverse call.
    copy_stack_trace(node.outputs[0], new_out)

    return [new_out]

0 commit comments

Comments
 (0)