Skip to content

Commit 14fec15

Browse files
committed
review comments: overwrite_a test + tri rewrite test
1 parent 208316d commit 14fec15

File tree

4 files changed

+92
-61
lines changed

4 files changed

+92
-61
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,13 +1031,13 @@ def _find_triangular_op(var):
10311031
"""
10321032
Inspects a variable to see if it's triangular.
10331033
1034-
Returns a tuple (is_lower, is_upper) if triangular, otherwise None.
1034+
Returns `True` if lower-triangular, `False` if upper-triangular, otherwise `None`.
10351035
"""
10361036
# Case 1: Check for an explicit tag
10371037
is_lower = getattr(var.tag, "lower_triangular", False)
10381038
is_upper = getattr(var.tag, "upper_triangular", False)
10391039
if is_lower or is_upper:
1040-
return (is_lower, is_upper)
1040+
return is_lower
10411041

10421042
if not var.owner:
10431043
return None
@@ -1047,7 +1047,7 @@ def _find_triangular_op(var):
10471047

10481048
# Case 2: Check for direct creator Ops
10491049
if isinstance(core_op, Cholesky):
1050-
return (core_op.lower, not core_op.lower)
1050+
return core_op.lower
10511051

10521052
if isinstance(core_op, LU | LUFactor):
10531053
if var.owner.outputs[1] == var:
@@ -1060,11 +1060,10 @@ def _find_triangular_op(var):
10601060
return (False, True)
10611061

10621062
# A bare pt.tri (constant inputs) is usually constant-folded before this
# rewrite runs, but handle the symbolic-input case here anyway.
1063-
# if isinstance(core_op, Tri):
1064-
# k_node = var.owner.inputs[2]
1065-
# if isinstance(k_node, Constant) and k_node.data == 0:
1066-
# print('re-writing ... ')
1067-
# return (True, False)
1063+
if isinstance(core_op, Tri):
1064+
k_node = var.owner.inputs[2]
1065+
if isinstance(k_node, Constant) and k_node.data == 0:
1066+
return True
10681067

10691068
# Case 3: tril/triu patterns which are implemented as Mul
10701069
if isinstance(core_op, Elemwise) and isinstance(core_op.scalar_op, ScalarMul):
@@ -1077,7 +1076,7 @@ def _find_triangular_op(var):
10771076
if isinstance(other_inp.owner.op, Tri):
10781077
k_node = other_inp.owner.inputs[2]
10791078
if isinstance(k_node, Constant) and k_node.data == 0:
1080-
return (True, False) # It's tril
1079+
return True # It's tril
10811080

10821081
# Check for triu pattern: Mul(x, Sub(1, Tri(k=-1)))
10831082
sub_op = other_inp.owner.op
@@ -1095,7 +1094,7 @@ def _find_triangular_op(var):
10951094
if const_one is not None and tri_inp is not None:
10961095
k_node = tri_inp.owner.inputs[2]
10971096
if isinstance(k_node, Constant) and k_node.data == -1:
1098-
return (False, True) # It's triu
1097+
return False # It's triu
10991098

11001099
return None
11011100

@@ -1111,13 +1110,11 @@ def rewrite_inv_to_triangular_solve(fgraph, node):
11111110
"""
11121111

11131112
A = node.inputs[0]
1114-
triangular_info = _find_triangular_op(A)
1115-
if triangular_info is None:
1113+
is_lower = _find_triangular_op(A)
1114+
if is_lower is None:
11161115
return None
11171116

1118-
is_lower, is_upper = triangular_info
1119-
if is_lower or is_upper:
1120-
new_op = TriangularInv(lower=is_lower)
1121-
new_inv = new_op(A)
1122-
copy_stack_trace(node.outputs[0], new_inv)
1123-
return [new_inv]
1117+
new_op = TriangularInv(lower=is_lower)
1118+
new_inv = new_op(A)
1119+
copy_stack_trace(node.outputs[0], new_inv)
1120+
return [new_inv]

pytensor/tensor/slinalg.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,8 +1036,27 @@ def __init__(self, lower=True, on_error="raise", overwrite_a=False):
10361036
def perform(self, node, inputs, outputs):
10371037
(x,) = inputs
10381038
(z,) = outputs
1039-
(dtrtri,) = get_lapack_funcs(("trtri",), (x,))
1040-
inv, info = dtrtri(x, lower=self.lower, overwrite_c=True)
1039+
(trtri,) = get_lapack_funcs(("trtri",), (x,))
1040+
1041+
# Check if we want to overwrite and if the input is C-contiguous
1042+
c_contiguous_input = self.overwrite_a and x.flags["C_CONTIGUOUS"]
1043+
if c_contiguous_input:
1044+
# Transpose C-contiguous to F-contiguous
1045+
x_in = x.T
1046+
lower_flag = not self.lower
1047+
overwrite_flag = True
1048+
else:
1049+
# Use original matrix and flags
1050+
x_in = x
1051+
lower_flag = self.lower
1052+
overwrite_flag = self.overwrite_a
1053+
1054+
# Call trtri with the potentially transposed input and correct flags
1055+
# Use overwrite_c (LAPACK flag for trtri) based on our logic
1056+
inv_maybe_transposed, info = trtri(
1057+
x_in, lower=lower_flag, overwrite_c=overwrite_flag
1058+
)
1059+
10411060
if info != 0:
10421061
if self.on_error == "nan":
10431062
z[0] = np.full_like(x, np.nan)
@@ -1048,7 +1067,7 @@ def perform(self, node, inputs, outputs):
10481067
raise ValueError(
10491068
f"illegal value in {-info}-th argument of internal trtri"
10501069
)
1051-
z[0] = inv
1070+
z[0] = inv_maybe_transposed.T if c_contiguous_input else inv_maybe_transposed
10521071

10531072
def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
10541073
"""

tests/tensor/rewriting/test_linalg.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from pytensor.graph.fg import FunctionGraph
1515
from pytensor.graph.rewriting.utils import rewrite_graph
1616
from pytensor.tensor import swapaxes
17-
from pytensor.tensor.basic import tril, triu
17+
from pytensor.tensor.basic import triu
1818
from pytensor.tensor.blockwise import Blockwise
1919
from pytensor.tensor.elemwise import DimShuffle
2020
from pytensor.tensor.math import dot, matmul
@@ -1085,9 +1085,9 @@ def _check_op_in_graph(fgraph, op_type, present=True):
10851085

10861086

10871087
rewrite_cases = {
1088-
"tril": (
1089-
lambda x: tril(x),
1090-
lambda a: np.tril(a),
1088+
"tri": (
1089+
lambda x: pt.tri(x.shape[0], x.shape[1], k=0, dtype=x.dtype),
1090+
lambda a: np.tri(N=a.shape[0], M=a.shape[1], k=0, dtype=a.dtype),
10911091
),
10921092
"triu": (
10931093
lambda x: triu(x),

tests/tensor/test_slinalg.py

Lines changed: 51 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import scipy
99
from scipy import linalg as scipy_linalg
1010

11+
import pytensor
1112
from pytensor import function, grad
1213
from pytensor import tensor as pt
1314
from pytensor.configdefaults import config
@@ -1238,43 +1239,57 @@ def _test_fn(x, case=2, mode="reduced"):
12381239
)
12391240

12401241

1241-
@pytest.mark.parametrize("lower", [True, False])
1242-
def test_triangular_inv_op(lower):
1243-
"""Tests the TriangularInv Op directly."""
1244-
x = matrix("x", dtype=config.floatX)
1245-
f = function([x], TriangularInv(lower=lower)(x))
1246-
1247-
if lower:
1248-
a = np.tril(np.random.rand(5, 5) + 0.1).astype(config.floatX)
1249-
else:
1250-
a = np.triu(np.random.rand(5, 5) + 0.1).astype(config.floatX)
1251-
1252-
a_inv = f(a)
1253-
expected_inv = np.linalg.inv(a)
1254-
1255-
# Clean the NumPy result before comparing.
1256-
if lower:
1257-
expected_inv = np.tril(expected_inv)
1258-
else:
1259-
expected_inv = np.triu(expected_inv)
1260-
1261-
# The inverse of a triangular matrix is also triangular.
1262-
# We should check the full matrix, not just a part of it.
1263-
atol = rtol = 1e-8 if config.floatX.endswith("64") else 1e-4
1264-
np.testing.assert_allclose(a_inv, expected_inv, rtol=rtol, atol=atol)
1265-
1266-
1267-
def test_triangular_inv_op_nan_on_error():
1242+
class TestTriangularInv:
12681243
"""
1269-
Tests the `on_error='nan'` functionality of the TriangularInv Op.
1244+
Tests for the `TriangularInv` `Op`.
12701245
"""
1271-
x = matrix("x", dtype=config.floatX)
1272-
f_nan = function([x], TriangularInv(on_error="nan")(x))
12731246

1274-
# Create a singular triangular matrix (zero on the diagonal)
1275-
a_singular = np.tril(np.random.rand(5, 5))
1276-
a_singular[2, 2] = 0
1277-
a_singular = a_singular.astype(config.floatX)
1247+
@pytest.mark.parametrize("lower", [True, False])
1248+
def test_triangular_inv_op(self, lower):
1249+
"""Tests the TriangularInv Op directly."""
1250+
x = matrix("x", dtype=config.floatX)
1251+
f = function([x], TriangularInv(lower=lower)(x))
1252+
1253+
if lower:
1254+
a = np.tril(np.random.rand(5, 5) + 0.1).astype(config.floatX)
1255+
else:
1256+
a = np.triu(np.random.rand(5, 5) + 0.1).astype(config.floatX)
1257+
1258+
a_inv = f(a)
1259+
expected_inv = np.linalg.inv(a)
1260+
1261+
# The inverse of a triangular matrix is also triangular.
1262+
# We should check the full matrix, not just a part of it.
1263+
atol = rtol = 1e-8 if config.floatX.endswith("64") else 1e-4
1264+
np.testing.assert_allclose(a_inv, expected_inv, rtol=rtol, atol=atol)
1265+
1266+
def test_triangular_inv_op_nan_on_error(self):
1267+
"""
1268+
Tests the `on_error='nan'` functionality of the TriangularInv Op.
1269+
"""
1270+
x = matrix("x", dtype=config.floatX)
1271+
f_nan = function([x], TriangularInv(on_error="nan")(x))
1272+
1273+
# Create a singular triangular matrix (zero on the diagonal)
1274+
a_singular = np.tril(np.random.rand(5, 5))
1275+
a_singular[2, 2] = 0
1276+
a_singular = a_singular.astype(config.floatX)
1277+
1278+
res = f_nan(a_singular)
1279+
assert np.all(np.isnan(res))
1280+
1281+
@pytest.mark.parametrize("overwrite_a", [True, False])
1282+
def test_triangular_inv_op_inplace(self, overwrite_a):
1283+
"""Tests the TriangularInv Op directly."""
1284+
x = matrix("x", dtype=config.floatX)
1285+
f = function(
1286+
[pytensor.In(x, mutable=overwrite_a)],
1287+
TriangularInv(overwrite_a=overwrite_a)(x),
1288+
accept_inplace=True,
1289+
)
1290+
1291+
a = np.tril(np.random.rand(5, 5) + 0.1).astype(config.floatX)
1292+
a_copy = a.copy()
1293+
f(a)
12781294

1279-
res = f_nan(a_singular)
1280-
assert np.all(np.isnan(res))
1295+
assert overwrite_a == (not np.allclose(a, a_copy))

0 commit comments

Comments
 (0)