2 changes: 1 addition & 1 deletion pytensor/tensor/nlinalg.py
@@ -107,7 +107,7 @@ class MatrixInverse(Op):

"""

__props__ = ()
__props__: tuple[str, ...] = ()
gufunc_signature = "(m,m)->(m,m)"
gufunc_spec = ("numpy.linalg.inv", 1, 1)
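(The widened annotation presumably lets subclasses override `__props__` with a non-empty tuple, such as the new `TriangularInv` introduced below, without the type checker flagging an incompatible attribute type.)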

99 changes: 99 additions & 0 deletions pytensor/tensor/rewriting/linalg.py
@@ -8,18 +8,22 @@
from pytensor import tensor as pt
from pytensor.compile import optdb
from pytensor.graph import Apply, FunctionGraph
from pytensor.graph.basic import Constant
from pytensor.graph.rewriting.basic import (
copy_stack_trace,
dfs_rewriter,
node_rewriter,
)
from pytensor.graph.rewriting.unify import OpPattern
from pytensor.scalar.basic import Abs, Log, Mul, Sign
from pytensor.scalar.basic import Mul as ScalarMul
from pytensor.scalar.basic import Sub as ScalarSub
Comment on lines +19 to +20

Member

We also have Mul and Sub imported above. Just use those?

Author

Sorry, will fix. Thank you!
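A minimal sketch of the cleanup the thread agrees on (hypothetical; it assumes Sub joins the existing pytensor.scalar.basic import):

from pytensor.scalar.basic import Abs, Log, Mul, Sign, Sub

# ...then, in _find_triangular_op below:
#   isinstance(core_op.scalar_op, Mul)  # instead of ScalarMul
#   isinstance(sub_op.scalar_op, Sub)   # instead of ScalarSub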

from pytensor.tensor.basic import (
AllocDiag,
ExtractDiag,
Eye,
TensorVariable,
Tri,
concatenate,
diag,
diagonal,
@@ -46,12 +50,16 @@
)
from pytensor.tensor.rewriting.blockwise import blockwise_of
from pytensor.tensor.slinalg import (
LU,
QR,
BlockDiagonal,
Cholesky,
CholeskySolve,
LUFactor,
Solve,
SolveBase,
SolveTriangular,
TriangularInv,
_bilinear_solve_discrete_lyapunov,
block_diag,
cholesky,
@@ -1017,3 +1025,94 @@ def scalar_solve_to_division(fgraph, node):
copy_stack_trace(old_out, new_out)

return [new_out]


def _find_triangular_op(var):
"""
    Inspect a variable to determine whether it is known to be triangular.

    Returns `True` if lower-triangular, `False` if upper-triangular, and `None` if unknown.
"""
# Case 1: Check for an explicit tag
is_lower = getattr(var.tag, "lower_triangular", False)
is_upper = getattr(var.tag, "upper_triangular", False)
if is_lower or is_upper:
return is_lower

if not var.owner:
return None

op = var.owner.op
core_op = op.core_op if isinstance(op, Blockwise) else op

# Case 2: Check for direct creator Ops
if isinstance(core_op, Cholesky):
return core_op.lower

if isinstance(core_op, LU | LUFactor):
if var.owner.outputs[1] == var:
return True
if var.owner.outputs[2] == var:
return False

if isinstance(core_op, QR):
if var.owner.outputs[1] == var:
return False

if isinstance(core_op, Tri):
k_node = var.owner.inputs[2]
if isinstance(k_node, Constant) and k_node.data == 0:
return True

    # Case 3: tril/triu patterns, which are implemented as Mul
if isinstance(core_op, Elemwise) and isinstance(core_op.scalar_op, ScalarMul):
other_inp = next(
(i for i in var.owner.inputs if i != var.owner.inputs[0]), None
)

if other_inp is not None and other_inp.owner:
# Check for tril pattern: Mul(x, Tri(...))
if isinstance(other_inp.owner.op, Tri):
k_node = other_inp.owner.inputs[2]
if isinstance(k_node, Constant) and k_node.data == 0:
return True # It's tril

# Check for triu pattern: Mul(x, Sub(1, Tri(k=-1)))
sub_op = other_inp.owner.op
if isinstance(sub_op, Elemwise) and isinstance(sub_op.scalar_op, ScalarSub):
sub_inputs = other_inp.owner.inputs
const_one = next(
(i for i in sub_inputs if isinstance(i, Constant) and i.data == 1),
None,
)
tri_inp = next(
(i for i in sub_inputs if i.owner and isinstance(i.owner.op, Tri)),
None,
)

if const_one is not None and tri_inp is not None:
k_node = tri_inp.owner.inputs[2]
if isinstance(k_node, Constant) and k_node.data == -1:
return False # It's triu

return None
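A quick illustration of how the helper classifies variables (a sketch, not part of the diff; the tag route mirrors the lower_tag/upper_tag test cases added below):

z = pt.matrix("z")
assert _find_triangular_op(z) is None  # no tag, no owner: unknown

z.tag.lower_triangular = True
assert _find_triangular_op(z) is True  # explicit tag wins

assert _find_triangular_op(cholesky(pt.matrix("a"))) is True  # Cholesky defaults to lower
assert _find_triangular_op(pt.tril(pt.matrix("b"))) is True  # Mul(x, Tri) pattern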


@register_stabilize
@node_rewriter([blockwise_of(MATRIX_INVERSE_OPS)])
def rewrite_inv_to_triangular_solve(fgraph, node):
"""
This rewrite takes advantage of the fact that the inverse of a triangular
Member @ricardoV94 (Nov 4, 2025)

We should make sure any rewrites targeting MatrixInverse, such as MatrixInv(A) @ x -> Solve(A, x), will also work with TriangularInv (when it makes sense). Otherwise this may actually be worse in those cases.

Author

Yes, this was mentioned in the PR, but it was suggested that it be handled separately. Would you prefer to handle it in the same PR and also investigate other cases?

> Otherwise this may actually be worse in those cases

Sorry, I didn't quite understand how this would be worse? Because instead of the Solve rewrite we will be doing a TriangularInv(A) @ x rewrite?

    matrix can be computed more efficiently than that of a general matrix, by
    dispatching to a dedicated triangular inverse instead of a general matrix
    inverse.
    """

A = node.inputs[0]
is_lower = _find_triangular_op(A)
if is_lower is None:
return None

new_op = TriangularInv(lower=is_lower)
new_inv = new_op(A)
copy_stack_trace(node.outputs[0], new_inv)
return [new_inv]
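A hypothetical sketch of the reviewer's point above (not part of this PR): once inv(A) has become TriangularInv(A), a downstream graph such as TriangularInv(A) @ b should still collapse into a solve. The rewrite name and matching logic here are illustrative only:

from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.math import Dot
from pytensor.tensor.slinalg import TriangularInv, solve_triangular


@node_rewriter([Dot])
def local_triangular_inv_dot_to_solve(fgraph, node):
    left, right = node.inputs
    if left.owner is None:
        return None
    op = left.owner.op
    core_op = getattr(op, "core_op", op)  # unwrap a Blockwise if present
    if not isinstance(core_op, TriangularInv):
        return None
    (A,) = left.owner.inputs
    # TriangularInv(A) @ b  ->  solve_triangular(A, b)
    return [solve_triangular(A, right, lower=core_op.lower)]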
67 changes: 66 additions & 1 deletion pytensor/tensor/slinalg.py
@@ -20,7 +20,7 @@
from pytensor.tensor import math as ptm
from pytensor.tensor.basic import as_tensor_variable, diagonal
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import kron, matrix_dot
from pytensor.tensor.nlinalg import MatrixInverse, kron, matrix_dot
from pytensor.tensor.shape import reshape
from pytensor.tensor.type import matrix, tensor, vector
from pytensor.tensor.variable import TensorVariable
@@ -1016,6 +1016,71 @@ def solve_triangular(
return cast(TensorVariable, ret)


class TriangularInv(MatrixInverse):
Member @ricardoV94 (Nov 4, 2025)

We shouldn't subclass from Ops we also instantiate. That may cause confusion between TriangularInv and MatrixInverse. Sometimes that's fine, others it isn't. Better to have a BaseMatrixInverse that both inherit from. Then code can look for BaseMatrixInverse if it doesn't matter which subclass it is, or just for the specific one it cares about.

For instance, I'm surprised your current rewrite is not applying recursively, since the returned graph should fit the bill for the pattern you're matching (an inverse of an A that is found to be triangular).

Author

Thank you, @ricardoV94! This was very insightful. I'm not sure if this is true, but perhaps in the rewrite we do a type check and not an isinstance check? I used a type check in my test helper for this reason.

If I make the BaseMatrixInverse class, I suppose I should change the inheritance for, say, MatrixPinv as well?

"""
Computes the inverse of a triangular matrix.
"""

__props__ = ("lower", "on_error", "overwrite_a")

def __init__(self, lower=True, on_error="raise", overwrite_a=False):
self.lower = lower
if on_error not in ("raise", "nan"):
raise ValueError('on_error must be one of "raise" or "nan"')
self.on_error = on_error
self.overwrite_a = overwrite_a

if self.overwrite_a:
self.destroy_map = {0: [0]}

def perform(self, node, inputs, outputs):
(x,) = inputs
(z,) = outputs
(trtri,) = get_lapack_funcs(("trtri",), (x,))

# Check if we want to overwrite and if the input is C-contiguous
c_contiguous_input = self.overwrite_a and x.flags["C_CONTIGUOUS"]
if c_contiguous_input:
# Transpose C-contiguous to F-contiguous
x_in = x.T
lower_flag = not self.lower
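            # Transposing swaps lower/upper, hence the flipped flag; since
            # inv(A.T) == inv(A).T, the result is transposed back at the end.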
overwrite_flag = True
else:
# Use original matrix and flags
x_in = x
lower_flag = self.lower
overwrite_flag = self.overwrite_a

# Call trtri with the potentially transposed input and correct flags
# Use overwrite_c (LAPACK flag for trtri) based on our logic
inv_maybe_transposed, info = trtri(
x_in, lower=lower_flag, overwrite_c=overwrite_flag
)

if info != 0:
if self.on_error == "nan":
z[0] = np.full_like(x, np.nan)
return
elif info > 0:
raise np.linalg.LinAlgError("Singular matrix")
elif info < 0:
raise ValueError(
f"illegal value in {-info}-th argument of internal trtri"
)
z[0] = inv_maybe_transposed.T if c_contiguous_input else inv_maybe_transposed

def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
"""
Allows this Op to overwrite its input buffer with its output.
"""
if not allowed_inplace_inputs:
return self

new_props = self._props_dict() # type: ignore
new_props["overwrite_a"] = True
return type(self)(**new_props)
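
A rough sketch (hypothetical, not part of this diff) of the hierarchy suggested in the review thread above: a shared base class that is never instantiated directly, so code can match the whole family or one concrete Op as needed:

class BaseMatrixInverse(Op):
    """Shared structure (gufunc signature, gradients, ...) for matrix inverses."""

    gufunc_signature = "(m,m)->(m,m)"


class MatrixInverse(BaseMatrixInverse):
    """General-purpose inverse; the Op users normally instantiate."""


class TriangularInv(BaseMatrixInverse):
    """Triangular-only inverse backed by LAPACK trtri."""

With this layout, isinstance(op, BaseMatrixInverse) matches either inverse, while type(op) is TriangularInv pins down one subclass, which is the distinction the test helper below relies on.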


class Solve(SolveBase):
"""
Solve a system of linear equations.
100 changes: 100 additions & 0 deletions tests/tensor/rewriting/test_linalg.py
@@ -11,6 +11,7 @@
from pytensor.compile import get_default_mode
from pytensor.configdefaults import config
from pytensor.graph import ancestors
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.tensor import swapaxes
from pytensor.tensor.blockwise import Blockwise
@@ -23,6 +24,7 @@
MatrixInverse,
MatrixPinv,
SLogDet,
inv,
matrix_inverse,
svd,
)
Expand All @@ -34,8 +36,11 @@
Solve,
SolveBase,
SolveTriangular,
TriangularInv,
cho_solve,
cholesky,
lu,
qr,
solve,
solve_triangular,
)
@@ -1060,3 +1065,98 @@ def solve_op_in_graph(graph):
np.testing.assert_allclose(
f(a_val, b_val), c_val, rtol=1e-7 if config.floatX == "float64" else 1e-5
)


def _check_op_in_graph(fgraph, op_type, present=True):
"""Helper to check if an Op is in a graph."""

# We use type() instead of isinstance() to avoid matching subclasses
# (e.g., finding TriangularInv when we're looking for MatrixInverse).
found = any(
type(node.op) is op_type
or (hasattr(node.op, "core_op") and type(node.op.core_op) is op_type)
for node in fgraph.apply_nodes
)
if present:
assert found, f"{op_type.__name__} not found in graph"
else:
assert not found, f"{op_type.__name__} unexpectedly found in graph"


rewrite_cases = {
"lower_tag": (
lambda x: (setattr(x.tag, "lower_triangular", True), x)[-1],
lambda a: np.tril(a),
),
"upper_tag": (
lambda x: (setattr(x.tag, "upper_triangular", True), x)[-1],
lambda a: np.triu(a),
),
"tri": (
lambda x: pt.tri(x.shape[0], x.shape[1], k=0, dtype=x.dtype),
lambda a: np.tri(N=a.shape[0], M=a.shape[1], k=0, dtype=a.dtype),
),
"tril": (
lambda x: pt.tril(x),
lambda a: np.tril(a),
),
"triu": (
lambda x: pt.triu(x),
lambda a: np.triu(a),
),
"cholesky": (
lambda x: cholesky(x),
lambda a: np.linalg.cholesky(a),
),
"lu_L": (
lambda x: lu(x)[1],
lambda a: scipy.linalg.lu(a)[1],
),
"lu_U": (
lambda x: lu(x)[2],
lambda a: scipy.linalg.lu(a)[2],
),
"qr_R": (
lambda x: qr(x)[1],
lambda a: np.linalg.qr(a)[1],
),
}


@pytest.mark.parametrize("case", rewrite_cases.keys())
def test_inv_to_triangular_inv_rewrite(case):
"""
Tests the rewrite of inv(triangular) -> TriangularInv.
"""
x = matrix("x", dtype=config.floatX)
build_tri, _ = rewrite_cases[case]
x_tri = build_tri(x)
y_inv = inv(x_tri)

# Check graph BEFORE compilation
pre_compile_fgraph = FunctionGraph([x], [y_inv], clone=False)
_check_op_in_graph(pre_compile_fgraph, MatrixInverse, present=True)
_check_op_in_graph(pre_compile_fgraph, TriangularInv, present=False)

# Trigger the rewrite
f = function([x], y_inv)

# Check graph AFTER compilation
post_compile_fgraph = f.maker.fgraph
_check_op_in_graph(post_compile_fgraph, TriangularInv, present=True)
_check_op_in_graph(post_compile_fgraph, MatrixInverse, present=False)

# Check numerical correctness
    # Make the matrix positive definite so the Cholesky case is well-defined
    a = np.random.rand(5, 5)
    a = (a @ a.T + np.eye(5)).astype(config.floatX)
if case == "lower_tag":
a = np.tril(a)
elif case == "upper_tag":
a = np.triu(a)
pytensor_result = f(a)
_, numpy_tri_func = rewrite_cases[case]
numpy_result = np.linalg.inv(numpy_tri_func(a))
atol = rtol = 1e-8 if config.floatX.endswith("64") else 1e-4
np.testing.assert_allclose(pytensor_result, numpy_result, rtol=rtol, atol=atol)
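A complementary negative check (illustrative sketch, not in the PR): for a matrix with no detectable triangular structure, _find_triangular_op returns None and the rewrite should leave the graph without a TriangularInv:

def test_inv_of_general_matrix_not_rewritten():
    x = matrix("x", dtype=config.floatX)
    f = function([x], inv(x))
    _check_op_in_graph(f.maker.fgraph, TriangularInv, present=False)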