Skip to content

Commit 5a660c6

Browse files
committed
improve test coverage, fix return types
1 parent c40159e commit 5a660c6

File tree

3 files changed

+27
-9
lines changed

3 files changed

+27
-9
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,15 +1051,14 @@ def _find_triangular_op(var):
10511051

10521052
if isinstance(core_op, LU | LUFactor):
10531053
if var.owner.outputs[1] == var:
1054-
return (True, False)
1054+
return True
10551055
if var.owner.outputs[2] == var:
1056-
return (False, True)
1056+
return False
10571057

10581058
if isinstance(core_op, QR):
10591059
if var.owner.outputs[1] == var:
1060-
return (False, True)
1060+
return False
10611061

1062-
# pt.tri will get constant folded so no point re-writing ?
10631062
if isinstance(core_op, Tri):
10641063
k_node = var.owner.inputs[2]
10651064
if isinstance(k_node, Constant) and k_node.data == 0:

tests/tensor/rewriting/test_linalg.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from pytensor.graph.fg import FunctionGraph
1515
from pytensor.graph.rewriting.utils import rewrite_graph
1616
from pytensor.tensor import swapaxes
17-
from pytensor.tensor.basic import triu
1817
from pytensor.tensor.blockwise import Blockwise
1918
from pytensor.tensor.elemwise import DimShuffle
2019
from pytensor.tensor.math import dot, matmul
@@ -1089,8 +1088,12 @@ def _check_op_in_graph(fgraph, op_type, present=True):
10891088
lambda x: pt.tri(x.shape[0], x.shape[1], k=0, dtype=x.dtype),
10901089
lambda a: np.tri(N=a.shape[0], M=a.shape[1], k=0, dtype=a.dtype),
10911090
),
1091+
"tril": (
1092+
lambda x: pt.tril(x),
1093+
lambda a: np.tril(a),
1094+
),
10921095
"triu": (
1093-
lambda x: triu(x),
1096+
lambda x: pt.triu(x),
10941097
lambda a: np.triu(a),
10951098
),
10961099
"cholesky": (

tests/tensor/test_slinalg.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1263,10 +1263,26 @@ def test_triangular_inv_op(self, lower):
12631263
atol = rtol = 1e-8 if config.floatX.endswith("64") else 1e-4
12641264
np.testing.assert_allclose(a_inv, expected_inv, rtol=rtol, atol=atol)
12651265

1266+
def test_triangular_inv_op_bad_on_error(self):
1267+
"""Tests that a bad `on_error` value raises a ValueError."""
1268+
with pytest.raises(ValueError, match="on_error must be one of"):
1269+
TriangularInv(on_error="foo")
1270+
1271+
def test_triangular_inv_op_raise_on_error(self):
1272+
"""Tests the default `on_error='raise'` functionality."""
1273+
x = matrix("x", dtype=config.floatX)
1274+
f_raise = function([x], TriangularInv()(x))
1275+
1276+
# Create a singular triangular matrix (zero on the diagonal)
1277+
a_singular = np.tril(np.random.rand(5, 5))
1278+
a_singular[2, 2] = 0
1279+
a_singular = a_singular.astype(config.floatX)
1280+
1281+
with pytest.raises(np.linalg.LinAlgError, match="Singular matrix"):
1282+
f_raise(a_singular)
1283+
12661284
def test_triangular_inv_op_nan_on_error(self):
1267-
"""
1268-
Tests the `on_error='nan'` functionality of the TriangularInv Op.
1269-
"""
1285+
"""Tests the `on_error='nan'` functionality of the TriangularInv Op."""
12701286
x = matrix("x", dtype=config.floatX)
12711287
f_nan = function([x], TriangularInv(on_error="nan")(x))
12721288

0 commit comments

Comments (0)