
Commit 792bd04

Authored by Jesse Grabowski (jessegrabowski)
Refactor Expm and Eig, add jax dispatch for expm (#1668)
* `linalg.eig` always returns complex dtype
* Update Eig dispatch for Numba, Jax, and Pytorch backends
* Clean up `pytensor.linalg.expm` and related tests
* Add JAX dispatch for expm
* Implement L_op instead of grad in `Eigh`

Co-authored-by: Jesse Grabowski <jesse.grabowski@readyx.com>
1 parent: 27c21cd

9 files changed: +209 additions, -150 deletions

pytensor/link/jax/dispatch/slinalg.py

Lines changed: 9 additions & 0 deletions
@@ -10,6 +10,7 @@
     Cholesky,
     CholeskySolve,
     Eigvalsh,
+    Expm,
     LUFactor,
     PivotToPermutations,
     Solve,
@@ -179,3 +180,11 @@ def qr(x, mode=mode):
         return jax.scipy.linalg.qr(x, mode=mode)
 
     return qr
+
+
+@jax_funcify.register(Expm)
+def jax_funcify_Expm(op, **kwargs):
+    def expm(x):
+        return jax.scipy.linalg.expm(x)
+
+    return expm
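
A minimal usage sketch of the new dispatch (illustrative, not part of this commit; assumes jax and jaxlib are installed). Compiling in JAX mode routes `expm` through `jax.scipy.linalg.expm`:

# Hypothetical usage sketch, not from this commit.
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import expm

A = pt.tensor("A", shape=(3, 3))
f = pytensor.function([A], expm(A), mode="JAX")

A_val = np.eye(3).astype(pytensor.config.floatX)
print(f(A_val))  # expm of the identity is e times the identity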

pytensor/link/numba/dispatch/nlinalg.py

Lines changed: 3 additions & 6 deletions
@@ -76,15 +76,12 @@ def slogdet(x):
 
 @numba_funcify.register(Eig)
 def numba_funcify_Eig(op, node, **kwargs):
-    out_dtype_1 = node.outputs[0].type.numpy_dtype
-    out_dtype_2 = node.outputs[1].type.numpy_dtype
-
-    inputs_cast = int_to_float_fn(node.inputs, out_dtype_1)
+    w_dtype = node.outputs[0].type.numpy_dtype
+    inputs_cast = int_to_float_fn(node.inputs, w_dtype)
 
     @numba_basic.numba_njit
     def eig(x):
-        out = np.linalg.eig(inputs_cast(x))
-        return (out[0].astype(out_dtype_1), out[1].astype(out_dtype_2))
+        return np.linalg.eig(inputs_cast(x))
 
     return eig
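
The dtype handling that used to live in this kernel now happens in `Eig.make_node` and `Eig.perform`, so the njit function can return the output of `np.linalg.eig` directly. A hypothetical end-to-end sketch (not from this commit; assumes Numba is installed):

# Hypothetical sketch, not from this commit.
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import eig

x = pt.dmatrix("x")
w, v = eig(x)
f = pytensor.function([x], [w, v], mode="NUMBA")

w_val, v_val = f(np.diag([1.0, 2.0, 3.0]))
print(np.sort(w_val.real))  # eigenvalues of the diagonal input: [1. 2. 3.]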

pytensor/tensor/nlinalg.py

Lines changed: 63 additions & 15 deletions
@@ -16,7 +16,15 @@
 from pytensor.tensor import math as ptm
 from pytensor.tensor.basic import as_tensor_variable, diagonal
 from pytensor.tensor.blockwise import Blockwise
-from pytensor.tensor.type import Variable, dvector, lscalar, matrix, scalar, vector
+from pytensor.tensor.type import (
+    Variable,
+    dvector,
+    lscalar,
+    matrix,
+    scalar,
+    tensor,
+    vector,
+)
 
 
 class MatrixPinv(Op):
@@ -297,37 +305,78 @@ def slogdet(x: TensorLike) -> tuple[ptb.TensorVariable, ptb.TensorVariable]:
 class Eig(Op):
     """
     Compute the eigenvalues and right eigenvectors of a square array.
-
     """
 
     __props__: tuple[str, ...] = ()
-    gufunc_signature = "(m,m)->(m),(m,m)"
     gufunc_spec = ("numpy.linalg.eig", 1, 2)
+    gufunc_signature = "(m,m)->(m),(m,m)"
 
     def make_node(self, x):
         x = as_tensor_variable(x)
         assert x.ndim == 2
-        w = vector(dtype=x.dtype)
-        v = matrix(dtype=x.dtype)
+
+        M, N = x.type.shape
+
+        if M is not None and N is not None and M != N:
+            raise ValueError(
+                f"Input to Eig must be a square matrix, got static shape: ({M}, {N})"
+            )
+
+        dtype = np.promote_types(x.dtype, np.complex64)
+
+        w = tensor(dtype=dtype, shape=(M,))
+        v = tensor(dtype=dtype, shape=(M, N))
+
         return Apply(self, [x], [w, v])
 
     def perform(self, node, inputs, outputs):
         (x,) = inputs
-        (w, v) = outputs
-        w[0], v[0] = (z.astype(x.dtype) for z in np.linalg.eig(x))
+        dtype = np.promote_types(x.dtype, np.complex64)
+
+        w, v = np.linalg.eig(x)
+
+        # If the imaginary part of the eigenvalues is zero, numpy automatically casts them to real. We require
+        # a statically known return dtype, so we have to cast back to complex to avoid dtype mismatch.
+        outputs[0][0] = w.astype(dtype, copy=False)
+        outputs[1][0] = v.astype(dtype, copy=False)
 
     def infer_shape(self, fgraph, node, shapes):
-        n = shapes[0][0]
+        (x_shapes,) = shapes
+        n, _ = x_shapes
+
        return [(n,), (n, n)]
 
+    def L_op(self, inputs, outputs, output_grads):
+        raise NotImplementedError(
+            "Gradients for Eig is not implemented because it always returns complex values, "
+            "for which autodiff is not yet supported in PyTensor (PRs welcome :) ).\n"
+            "If you know that your input has strictly real-valued eigenvalues (e.g. it is a "
+            "symmetric matrix), use pt.linalg.eigh instead."
+        )
+
 
-eig = Blockwise(Eig())
+def eig(x: TensorLike):
+    """
+    Return the eigenvalues and right eigenvectors of a square array.
+
+    Note that regardless of the input dtype, the eigenvalues and eigenvectors are returned as complex numbers. As a
+    result, the gradient of this operation is not implemented (because PyTensor does not support autodiff for complex
+    values yet).
+
+    If you know that your input has strictly real-valued eigenvalues (e.g. it is a symmetric matrix), use
+    `pytensor.tensor.linalg.eigh` instead.
+
+    Parameters
+    ----------
+    x: TensorLike
+        Square matrix, or array of such matrices
+    """
+    return Blockwise(Eig())(x)
 
 
 class Eigh(Eig):
     """
     Return the eigenvalues and eigenvectors of a Hermitian or symmetric matrix.
-
     """
 
     __props__ = ("UPLO",)
@@ -354,7 +403,7 @@ def perform(self, node, inputs, outputs):
         (w, v) = outputs
         w[0], v[0] = np.linalg.eigh(x, self.UPLO)
 
-    def grad(self, inputs, g_outputs):
+    def L_op(self, inputs, outputs, output_grads):
         r"""The gradient function should return
 
         .. math:: \sum_n\left(W_n\frac{\partial\,w_n}
@@ -378,10 +427,9 @@ def grad(self, inputs, g_outputs):
 
         """
         (x,) = inputs
-        w, v = self(x)
-        # Replace gradients wrt disconnected variables with
-        # zeros. This is a work-around for issue #1063.
-        gw, gv = _zero_disconnected([w, v], g_outputs)
+        w, v = outputs
+        gw, gv = _zero_disconnected([w, v], output_grads)
+
         return [EighGrad(self.UPLO)(x, w, v, gw, gv)]
pytensor/tensor/slinalg.py

Lines changed: 29 additions & 52 deletions
@@ -6,7 +6,6 @@
 
 import numpy as np
 import scipy.linalg as scipy_linalg
-from numpy.exceptions import ComplexWarning
 from scipy.linalg import LinAlgError, LinAlgWarning, get_lapack_funcs
 
 import pytensor
@@ -1304,82 +1303,60 @@ def eigvalsh(a, b, lower=True):
 class Expm(Op):
     """
     Compute the matrix exponential of a square array.
-
     """
 
     __props__ = ()
+    gufunc_signature = "(m,m)->(m,m)"
 
     def make_node(self, A):
         A = as_tensor_variable(A)
         assert A.ndim == 2
-        expm = matrix(dtype=A.dtype)
-        return Apply(
-            self,
-            [
-                A,
-            ],
-            [
-                expm,
-            ],
-        )
+
+        expm = matrix(dtype=A.dtype, shape=A.type.shape)
+
+        return Apply(self, [A], [expm])
 
     def perform(self, node, inputs, outputs):
         (A,) = inputs
         (expm,) = outputs
         expm[0] = scipy_linalg.expm(A)
 
-    def grad(self, inputs, outputs):
+    def L_op(self, inputs, outputs, output_grads):
+        # Kalbfleisch and Lawless, J. Am. Stat. Assoc. 80 (1985) Equation 3.4
+        # Kind of... You need to do some algebra from there to arrive at
+        # this expression.
         (A,) = inputs
-        (g_out,) = outputs
-        return [ExpmGrad()(A, g_out)]
+        (_,) = outputs  # Outputs not used; included for signature consistency only
+        (A_bar,) = output_grads
 
-    def infer_shape(self, fgraph, node, shapes):
-        return [shapes[0]]
+        w, V = pt.linalg.eig(A)
 
+        exp_w = pt.exp(w)
+        numer = pt.sub.outer(exp_w, exp_w)
+        denom = pt.sub.outer(w, w)
 
-class ExpmGrad(Op):
-    """
-    Gradient of the matrix exponential of a square array.
+        # When w_i ≈ w_j, we have a removable singularity in the expression for X, because
+        # lim b->a (e^a - e^b) / (a - b) = e^a (derivation left for the motivated reader)
+        X = pt.where(pt.abs(denom) < 1e-8, exp_w, numer / denom)
 
-    """
+        diag_idx = pt.arange(w.shape[0])
+        X = X[..., diag_idx, diag_idx].set(exp_w)
 
-    __props__ = ()
+        inner = solve(V, A_bar.T @ V).T
+        result = solve(V.T, inner * X) @ V.T
 
-    def make_node(self, A, gw):
-        A = as_tensor_variable(A)
-        assert A.ndim == 2
-        out = matrix(dtype=A.dtype)
-        return Apply(
-            self,
-            [A, gw],
-            [
-                out,
-            ],
-        )
+        # At this point, result is always a complex dtype. If the input was real, the output should be
+        # real as well (and all the imaginary parts are numerical noise)
+        if A.dtype not in ("complex64", "complex128"):
+            return [result.real]
+
+        return [result]
 
     def infer_shape(self, fgraph, node, shapes):
         return [shapes[0]]
 
-    def perform(self, node, inputs, outputs):
-        # Kalbfleisch and Lawless, J. Am. Stat. Assoc. 80 (1985) Equation 3.4
-        # Kind of... You need to do some algebra from there to arrive at
-        # this expression.
-        (A, gA) = inputs
-        (out,) = outputs
-        w, V = scipy_linalg.eig(A, right=True)
-        U = scipy_linalg.inv(V).T
-
-        exp_w = np.exp(w)
-        X = np.subtract.outer(exp_w, exp_w) / np.subtract.outer(w, w)
-        np.fill_diagonal(X, exp_w)
-        Y = U.dot(V.T.dot(gA).dot(U) * X).dot(V.T)
-
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore", ComplexWarning)
-            out[0] = Y.astype(A.dtype)
-
 
-expm = Expm()
+expm = Blockwise(Expm())
 
 
 class SolveContinuousLyapunov(Op):
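
The new `Expm.L_op` builds the same eigenbasis gradient expression that the deleted `ExpmGrad.perform` computed with NumPy, only symbolically, so the gradient is now an ordinary PyTensor graph. A standalone NumPy sanity check of that formula against central finite differences, offered as a sketch (not part of this commit):

# Standalone NumPy check, not from this commit.
import numpy as np
from scipy.linalg import eig, expm, inv

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4))
A_bar = rng.normal(size=(4, 4))  # output cotangent, i.e. dL/d expm(A)

# Eigenbasis formula (the same expression the removed ExpmGrad.perform used)
w, V = eig(A)
U = inv(V).T
exp_w = np.exp(w)
with np.errstate(divide="ignore", invalid="ignore"):
    X = np.subtract.outer(exp_w, exp_w) / np.subtract.outer(w, w)
np.fill_diagonal(X, exp_w)  # removable singularity: lim (e^a - e^b)/(a - b) = e^a
analytic = (U @ ((V.T @ A_bar @ U) * X) @ V.T).real

# Central finite differences of L(A) = sum(A_bar * expm(A))
eps = 1e-6
numeric = np.zeros_like(A)
for i in range(4):
    for j in range(4):
        E = np.zeros_like(A)
        E[i, j] = eps
        numeric[i, j] = (np.sum(A_bar * expm(A + E)) - np.sum(A_bar * expm(A - E))) / (2 * eps)

print(np.allclose(analytic, numeric, atol=1e-5))  # expected: True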

tests/link/jax/test_slinalg.py

Lines changed: 9 additions & 0 deletions
@@ -361,3 +361,12 @@ def test_jax_cho_solve(b_shape, lower):
     out = pt_slinalg.cho_solve((c, lower), b, b_ndim=len(b_shape))
 
     compare_jax_and_py([A, b], [out], [A_val, b_val])
+
+
+def test_jax_expm():
+    rng = np.random.default_rng(utt.fetch_seed())
+    A = pt.tensor(name="A", shape=(5, 5))
+    A_val = rng.normal(size=(5, 5)).astype(config.floatX)
+    out = pt_slinalg.expm(A)
+
+    compare_jax_and_py([A], [out], [A_val])

tests/link/numba/test_nlinalg.py

Lines changed: 19 additions & 39 deletions
@@ -4,6 +4,7 @@
 import pytest
 
 import pytensor.tensor as pt
+from pytensor import config
 from pytensor.tensor import nlinalg
 from tests.link.numba.test_basic import compare_numba_and_py
 
@@ -51,45 +52,24 @@ def test_Det_SLogDet(op, dtype):
     )
 
 
-@pytest.mark.parametrize(
-    "x, exc",
-    [
-        (
-            (
-                pt.dmatrix(),
-                (lambda x: x.T.dot(x))(x),
-            ),
-            None,
-        ),
-        (
-            (
-                pt.dmatrix(),
-                (lambda x: x.T.dot(x))(y),
-            ),
-            None,
-        ),
-        (
-            (
-                pt.lmatrix(),
-                (lambda x: x.T.dot(x))(
-                    rng.integers(1, 10, size=(3, 3)).astype("int64")
-                ),
-            ),
-            None,
-        ),
-    ],
-)
-def test_Eig(x, exc):
-    x, test_x = x
-    g = nlinalg.Eig()(x)
-
-    cm = contextlib.suppress() if exc is None else pytest.warns(exc)
-    with cm:
-        compare_numba_and_py(
-            [x],
-            g,
-            [test_x],
-        )
+@pytest.mark.parametrize("input_dtype", ["float", "int"])
+@pytest.mark.parametrize("symmetric", [True, False], ids=["symmetric", "general"])
+def test_Eig(input_dtype, symmetric):
+    x = pt.dmatrix("x")
+    if input_dtype == "float":
+        x_val = rng.normal(size=(3, 3)).astype(config.floatX)
+    else:
+        x_val = rng.integers(1, 10, size=(3, 3)).astype("int64")
+
+    if symmetric:
+        x_val = x_val + x_val.T
+
+    g = nlinalg.eig(x)
+    compare_numba_and_py(
+        graph_inputs=[x],
+        graph_outputs=g,
+        test_inputs=[x_val],
+    )
 
 
 @pytest.mark.parametrize(
