1616from pytensor .tensor import math as ptm
1717from pytensor .tensor .basic import as_tensor_variable , diagonal
1818from pytensor .tensor .blockwise import Blockwise
19- from pytensor .tensor .type import Variable , dvector , lscalar , matrix , scalar , vector
19+ from pytensor .tensor .type import (
20+ Variable ,
21+ dvector ,
22+ lscalar ,
23+ matrix ,
24+ scalar ,
25+ tensor ,
26+ vector ,
27+ )
2028
2129
2230class MatrixPinv (Op ):
@@ -297,37 +305,78 @@ def slogdet(x: TensorLike) -> tuple[ptb.TensorVariable, ptb.TensorVariable]:
class Eig(Op):
    """
    Compute the eigenvalues and right eigenvectors of a square array.

    The outputs always have a complex dtype (the smallest complex type able to
    hold the input dtype), even when every eigenvalue happens to be real, so
    that the output dtype is statically known.
    """

    __props__: tuple[str, ...] = ()
    gufunc_spec = ("numpy.linalg.eig", 1, 2)
    gufunc_signature = "(m,m)->(m),(m,m)"

    def make_node(self, x):
        """Build the Apply node computing ``(eigenvalues, eigenvectors)`` of ``x``.

        Parameters
        ----------
        x
            A square matrix (2-dimensional tensor variable).

        Raises
        ------
        ValueError
            If the static shape of ``x`` is known to be non-square.
        """
        x = as_tensor_variable(x)
        assert x.ndim == 2

        M, N = x.type.shape

        if M is not None and N is not None and M != N:
            raise ValueError(
                f"Input to Eig must be a square matrix, got static shape: ({M}, {N})"
            )

        # The input is square, so whichever static dimension is known (if any)
        # determines the core dimension of both outputs. Falling back from M
        # to N gives tighter static shapes when only one dim is known
        # (e.g. input shape (None, 5) -> outputs (5,) and (5, 5)).
        core_dim = M if M is not None else N

        # Eigendecomposition of a real matrix can yield complex values, so the
        # outputs are always complex: promote_types picks complex64 for
        # float32 inputs and complex128 for float64 (and wider) inputs.
        dtype = np.promote_types(x.dtype, np.complex64)

        w = tensor(dtype=dtype, shape=(core_dim,))
        v = tensor(dtype=dtype, shape=(core_dim, core_dim))

        return Apply(self, [x], [w, v])

    def perform(self, node, inputs, outputs):
        """Compute the eigendecomposition with numpy and store the outputs."""
        (x,) = inputs
        dtype = np.promote_types(x.dtype, np.complex64)

        w, v = np.linalg.eig(x)

        # If the imaginary part of the eigenvalues is zero, numpy automatically
        # casts them to real. We require a statically known return dtype, so we
        # have to cast back to complex to avoid dtype mismatch.
        outputs[0][0] = w.astype(dtype, copy=False)
        outputs[1][0] = v.astype(dtype, copy=False)

    def infer_shape(self, fgraph, node, shapes):
        """Eigenvalues have shape (n,), eigenvectors (n, n), for an (n, n) input."""
        (x_shapes,) = shapes
        n, _ = x_shapes

        return [(n,), (n, n)]

    def L_op(self, inputs, outputs, output_grads):
        # Outputs are always complex (see make_node), and PyTensor cannot yet
        # differentiate through complex-valued graphs.
        raise NotImplementedError(
            "Gradients for Eig is not implemented because it always returns complex values, "
            "for which autodiff is not yet supported in PyTensor (PRs welcome :) ).\n"
            "If you know that your input has strictly real-valued eigenvalues (e.g. it is a "
            "symmetric matrix), use pt.linalg.eigh instead."
        )
356+
323357
def eig(x: TensorLike):
    """
    Return the eigenvalues and right eigenvectors of a square array.

    Regardless of the input dtype, the eigenvalues and eigenvectors come back
    as complex numbers. Because PyTensor does not yet support autodiff through
    complex values, this operation has no gradient.

    If you know the eigenvalues of your input are strictly real (e.g. it is a
    symmetric matrix), use `pytensor.tensor.linalg.eigh` instead.

    Parameters
    ----------
    x: TensorLike
        Square matrix, or array of such matrices
    """
    eig_op = Blockwise(Eig())
    return eig_op(x)
325375
326376
327377class Eigh (Eig ):
328378 """
329379 Return the eigenvalues and eigenvectors of a Hermitian or symmetric matrix.
330-
331380 """
332381
333382 __props__ = ("UPLO" ,)
@@ -354,7 +403,7 @@ def perform(self, node, inputs, outputs):
354403 (w , v ) = outputs
355404 w [0 ], v [0 ] = np .linalg .eigh (x , self .UPLO )
356405
357- def grad (self , inputs , g_outputs ):
406+ def L_op (self , inputs , outputs , output_grads ):
358407 r"""The gradient function should return
359408
360409 .. math:: \sum_n\left(W_n\frac{\partial\,w_n}
@@ -378,10 +427,9 @@ def grad(self, inputs, g_outputs):
378427
379428 """
380429 (x ,) = inputs
381- w , v = self (x )
382- # Replace gradients wrt disconnected variables with
383- # zeros. This is a work-around for issue #1063.
384- gw , gv = _zero_disconnected ([w , v ], g_outputs )
430+ w , v = outputs
431+ gw , gv = _zero_disconnected ([w , v ], output_grads )
432+
385433 return [EighGrad (self .UPLO )(x , w , v , gw , gv )]
386434
387435
0 commit comments