Skip to content

Commit c9519b5

Browse files
committed
add triangular rewrite
1 parent 1697264 commit c9519b5

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
ExtractDiag,
2121
Eye,
2222
TensorVariable,
23+
Tri,
2324
concatenate,
2425
diag,
2526
diagonal,
@@ -1017,3 +1018,56 @@ def scalar_solve_to_division(fgraph, node):
10171018
copy_stack_trace(old_out, new_out)
10181019

10191020
return [new_out]
1021+
1022+
1023+
def _find_triangular_op(var):
    """
    Inspect a variable to determine whether it is known to be triangular.

    The following sources of information are consulted, in order:

    1. ``lower_triangular`` / ``upper_triangular`` flags set on ``var.tag``.
    2. ``var`` is produced by a ``Tri`` op whose diagonal offset ``k`` is the
       constant ``0``.
    3. ``var`` is produced by a ``Cholesky`` op, either directly or wrapped
       in a ``Blockwise``.

    Parameters
    ----------
    var
        The graph variable to inspect.

    Returns
    -------
    tuple of (bool, bool) or None
        ``(is_lower, is_upper)`` when the variable is recognised as
        triangular, otherwise ``None``.
    """
    is_lower = getattr(var.tag, "lower_triangular", False)
    is_upper = getattr(var.tag, "upper_triangular", False)
    if is_lower or is_upper:
        return (is_lower, is_upper)

    owner = var.owner
    if owner is None:
        # Root variables (inputs, constants) carry no op to inspect.
        return None

    op = owner.op

    if isinstance(op, Tri):
        # `Tri` is parameterised only by its dtype: N, M and the diagonal
        # offset k are *inputs* of the node, not attributes of the op, so k
        # has to be read from the node and must be a compile-time constant.
        # Imported locally so the block does not depend on the file's
        # top-level import list.
        from pytensor.graph.basic import Constant

        k = owner.inputs[2]
        if isinstance(k, Constant) and k.data == 0:
            # Like `numpy.tri`, `Tri` fills the main diagonal and everything
            # below it, so the result is always lower-triangular.
            return (True, False)

    # A Cholesky factor is triangular whether or not the op is batched.
    core_op = op.core_op if isinstance(op, Blockwise) else op
    if isinstance(core_op, Cholesky):
        return (core_op.lower, not core_op.lower)

    return None
1050+
1051+
1052+
@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
def rewrite_inv_to_triangular_solve(fgraph, node):
    """
    Replace the inverse of a triangular matrix with a triangular solve.

    The inverse of a triangular matrix can be computed more efficiently than
    the inverse of a general matrix by solving ``A @ X = I`` with a triangular
    solver instead of running a general matrix inverse.

    Parameters
    ----------
    fgraph
        The function graph being rewritten (unused here).
    node
        A ``Blockwise`` node; the rewrite only fires when its core op is one
        of ``ALL_INVERSE_OPS`` and its input is recognised as triangular.

    Returns
    -------
    list or None
        A single-element list holding the replacement output, or ``None``
        when the rewrite does not apply.
    """
    core_op = node.op.core_op
    if not isinstance(core_op, ALL_INVERSE_OPS):
        return None

    A = node.inputs[0]
    triangular_info = _find_triangular_op(A)
    if triangular_info is None:
        return None

    is_lower, is_upper = triangular_info
    if not (is_lower or is_upper):
        return None

    # Use the trailing axis for the identity size: under Blockwise the
    # leading axes of A are batch dimensions, so A.shape[0] is only the
    # matrix size in the unbatched case.
    identity = pt.eye(A.shape[-1], dtype=A.dtype)

    # A matrix flagged as both lower and upper is diagonal; either setting
    # of `lower` is valid for it, so following `is_lower` is sufficient.
    new_out = solve_triangular(A, identity, lower=bool(is_lower))

    # Keep error messages pointing at the user's original inverse call,
    # as the other rewrites in this file do.
    copy_stack_trace(node.outputs[0], new_out)
    return [new_out]

0 commit comments

Comments
 (0)