|
20 | 20 | ExtractDiag, |
21 | 21 | Eye, |
22 | 22 | TensorVariable, |
| 23 | + Tri, |
23 | 24 | concatenate, |
24 | 25 | diag, |
25 | 26 | diagonal, |
@@ -1015,3 +1016,56 @@ def scalar_solve_to_division(fgraph, node): |
1015 | 1016 | copy_stack_trace(old_out, new_out) |
1016 | 1017 |
|
1017 | 1018 | return [new_out] |
| 1019 | + |
| 1020 | + |
| 1021 | +def _find_triangular_op(var): |
| 1022 | + """ |
| 1023 | + Inspects a variable to see if it's triangular. |
| 1024 | +
|
| 1025 | + Returns a tuple (is_lower, is_upper) if triangular, otherwise None. |
| 1026 | + """ |
| 1027 | + |
| 1028 | + is_lower = getattr(var.tag, "lower_triangular", False) |
| 1029 | + is_upper = getattr(var.tag, "upper_triangular", False) |
| 1030 | + |
| 1031 | + if is_lower or is_upper: |
| 1032 | + return (is_lower, is_upper) |
| 1033 | + |
| 1034 | + if var.owner and isinstance(var.owner.op, Tri): |
| 1035 | + # The 'k' parameter of Tri determines the diagonal. |
| 1036 | + # k=0 is the main diagonal. |
| 1037 | + k = var.owner.op.k |
| 1038 | + if k == 0: |
| 1039 | + is_lower = var.owner.op.lower |
| 1040 | + return (is_lower, not is_lower) |
| 1041 | + |
| 1042 | + if var.owner and isinstance(var.owner.op, Blockwise): |
| 1043 | + core_op = var.owner.op.core_op |
| 1044 | + if isinstance(core_op, Cholesky): |
| 1045 | + return (core_op.lower, not core_op.lower) |
| 1046 | + |
| 1047 | + return None |
| 1048 | + |
| 1049 | + |
| 1050 | +@register_canonicalize |
| 1051 | +@register_stabilize |
| 1052 | +@node_rewriter([Blockwise]) |
| 1053 | +def rewrite_inv_to_triangular_solve(fgraph, node): |
| 1054 | + """ |
| 1055 | + This rewrite takes advantage of the fact that the inverse of a triangular |
| 1056 | + matrix can be computed more efficiently than the inverse of a general |
| 1057 | + matrix by using a triangular solve instead of a general matrix inverse. |
| 1058 | + """ |
| 1059 | + core_op = node.op.core_op |
| 1060 | + if not isinstance(core_op, ALL_INVERSE_OPS): |
| 1061 | + return None |
| 1062 | + |
| 1063 | + A = node.inputs[0] |
| 1064 | + triangular_info = _find_triangular_op(A) |
| 1065 | + if triangular_info is None: |
| 1066 | + return None |
| 1067 | + |
| 1068 | + is_lower, is_upper = triangular_info |
| 1069 | + if is_lower or is_upper: |
| 1070 | + I = pt.eye(A.shape[0], dtype=A.dtype) |
| 1071 | + return [solve_triangular(A, I, lower=is_lower)] |
0 commit comments