Skip to content

Commit 7d6d8f8

Browse files
committed
fix tests & mypy issues
1 parent 1740676 commit 7d6d8f8

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

pytensor/tensor/slinalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -991,7 +991,7 @@ def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
991991
if not allowed_inplace_inputs:
992992
return self
993993

994-
new_props = self._props_dict()
994+
new_props = self._props_dict() # type: ignore
995995
new_props["overwrite_a"] = True
996996
return type(self)(**new_props)
997997

tests/tensor/rewriting/test_linalg.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,13 +1071,13 @@ def solve_op_in_graph(graph):
10711071
@pytest.mark.parametrize("lower", [True, False])
10721072
def test_triangular_inv_op(lower):
10731073
"""Tests the TriangularInv Op directly."""
1074-
x = matrix("x")
1074+
x = matrix("x", dtype=config.floatX)
10751075
f = function([x], TriangularInv(lower=lower)(x))
10761076

10771077
if lower:
1078-
a = np.tril(np.random.rand(5, 5) + 0.1)
1078+
a = np.tril(np.random.rand(5, 5) + 0.1).astype(config.floatX)
10791079
else:
1080-
a = np.triu(np.random.rand(5, 5) + 0.1)
1080+
a = np.triu(np.random.rand(5, 5) + 0.1).astype(config.floatX)
10811081

10821082
a_inv = f(a)
10831083
expected_inv = np.linalg.inv(a)
@@ -1099,12 +1099,13 @@ def test_triangular_inv_op_nan_on_error():
10991099
"""
11001100
Tests the `on_error='nan'` functionality of the TriangularInv Op.
11011101
"""
1102-
x = matrix("x")
1102+
x = matrix("x", dtype=config.floatX)
11031103
f_nan = function([x], TriangularInv(on_error="nan")(x))
11041104

11051105
# Create a singular triangular matrix (zero on the diagonal)
11061106
a_singular = np.tril(np.random.rand(5, 5))
11071107
a_singular[2, 2] = 0
1108+
a_singular = a_singular.astype(config.floatX)
11081109

11091110
res = f_nan(a_singular)
11101111
assert np.all(np.isnan(res))
@@ -1159,7 +1160,7 @@ def test_inv_to_triangular_inv_rewrite(case):
11591160
"""
11601161
Tests the rewrite of inv(triangular) -> TriangularInv.
11611162
"""
1162-
x = matrix("x")
1163+
x = matrix("x", dtype=config.floatX)
11631164
build_tri, _ = rewrite_cases[case]
11641165
x_tri = build_tri(x)
11651166
y_inv = inv(x_tri)
@@ -1179,7 +1180,9 @@ def test_inv_to_triangular_inv_rewrite(case):
11791180

11801181
# Check numerical correctness
11811182
a = np.random.rand(5, 5)
1182-
a = np.dot(a, a.T) + np.eye(5) # Make positive definite for Cholesky
1183+
a = (np.dot(a, a.T) + np.eye(5)).astype(
1184+
config.floatX
1185+
) # Make positive definite for Cholesky
11831186
pytensor_result = f(a)
11841187
_, numpy_tri_func = rewrite_cases[case]
11851188
numpy_result = np.linalg.inv(numpy_tri_func(a))

0 commit comments

Comments
 (0)