|
19 | 19 | ExtractDiag, |
20 | 20 | Eye, |
21 | 21 | TensorVariable, |
| 22 | + Tri, |
22 | 23 | concatenate, |
23 | 24 | diag, |
24 | 25 | diagonal, |
@@ -1062,3 +1063,56 @@ def scalar_solve_to_division(fgraph, node): |
1062 | 1063 | copy_stack_trace(old_out, new_out) |
1063 | 1064 |
|
1064 | 1065 | return [new_out] |
| 1066 | + |
| 1067 | + |
| 1068 | +def _find_triangular_op(var): |
| 1069 | + """ |
| 1070 | + Inspects a variable to see if it's triangular. |
| 1071 | +
|
| 1072 | + Returns a tuple (is_lower, is_upper) if triangular, otherwise None. |
| 1073 | + """ |
| 1074 | + |
| 1075 | + is_lower = getattr(var.tag, "lower_triangular", False) |
| 1076 | + is_upper = getattr(var.tag, "upper_triangular", False) |
| 1077 | + |
| 1078 | + if is_lower or is_upper: |
| 1079 | + return (is_lower, is_upper) |
| 1080 | + |
| 1081 | + if var.owner and isinstance(var.owner.op, Tri): |
| 1082 | + # The 'k' parameter of Tri determines the diagonal. |
| 1083 | + # k=0 is the main diagonal. |
| 1084 | + k = var.owner.op.k |
| 1085 | + if k == 0: |
| 1086 | + is_lower = var.owner.op.lower |
| 1087 | + return (is_lower, not is_lower) |
| 1088 | + |
| 1089 | + if var.owner and isinstance(var.owner.op, Blockwise): |
| 1090 | + core_op = var.owner.op.core_op |
| 1091 | + if isinstance(core_op, Cholesky): |
| 1092 | + return (core_op.lower, not core_op.lower) |
| 1093 | + |
| 1094 | + return None |
| 1095 | + |
| 1096 | + |
| 1097 | +@register_canonicalize |
| 1098 | +@register_stabilize |
| 1099 | +@node_rewriter([Blockwise]) |
| 1100 | +def rewrite_inv_to_triangular_solve(fgraph, node): |
| 1101 | + """ |
| 1102 | + This rewrite takes advantage of the fact that the inverse of a triangular |
| 1103 | + matrix can be computed more efficiently than the inverse of a general |
| 1104 | + matrix by using a triangular solve instead of a general matrix inverse. |
| 1105 | + """ |
| 1106 | + core_op = node.op.core_op |
| 1107 | + if not isinstance(core_op, ALL_INVERSE_OPS): |
| 1108 | + return None |
| 1109 | + |
| 1110 | + A = node.inputs[0] |
| 1111 | + triangular_info = _find_triangular_op(A) |
| 1112 | + if triangular_info is None: |
| 1113 | + return None |
| 1114 | + |
| 1115 | + is_lower, is_upper = triangular_info |
| 1116 | + if is_lower or is_upper: |
| 1117 | + I = pt.eye(A.shape[0], dtype=A.dtype) |
| 1118 | + return [solve_triangular(A, I, lower=is_lower)] |
0 commit comments