22
33from typing import TYPE_CHECKING
44if TYPE_CHECKING :
5- import paddle
6- array = paddle .Tensor
7- from paddle import dtype as Dtype
5+ import torch
6+ array = torch .Tensor
7+ from torch import dtype as Dtype
88 from typing import Optional , Union , Tuple , Literal
99 inf = float ('inf' )
1010
1111from ._aliases import _fix_promotion , sum
1212
13- from paddle .linalg import * # noqa: F403
13+ from torch .linalg import * # noqa: F403
1414
15- # paddle .linalg doesn't define __all__
16- # from paddle .linalg import __all__ as linalg_all
17- from paddle import linalg as paddle_linalg
18- linalg_all = [i for i in dir (paddle_linalg ) if not i .startswith ('_' )]
15+ # torch .linalg doesn't define __all__
16+ # from torch .linalg import __all__ as linalg_all
17+ from torch import linalg as torch_linalg
18+ linalg_all = [i for i in dir (torch_linalg ) if not i .startswith ('_' )]
1919
20- # outer is implemented in paddle but aren't in the linalg namespace
21- from paddle import outer
20+ # outer is implemented in torch but aren't in the linalg namespace
21+ from torch import outer
2222# These functions are in both the main and linalg namespaces
2323from ._aliases import matmul , matrix_transpose , tensordot
2424
25- # Note: paddle.linalg.cross does not default to axis=-1 (it defaults to the
25+ # Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
26+ # first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
2627
27- # paddle.cross also does not support broadcasting when it would add new
28+ # torch.cross also does not support broadcasting when it would add new
29+ # dimensions https://github.com/pytorch/pytorch/issues/39656
2830def cross (x1 : array , x2 : array , / , * , axis : int = - 1 ) -> array :
2931 x1 , x2 = _fix_promotion (x1 , x2 , only_scalar = False )
3032 if not (- min (x1 .ndim , x2 .ndim ) <= axis < max (x1 .ndim , x2 .ndim )):
3133 raise ValueError (f"axis { axis } out of bounds for cross product of arrays with shapes { x1 .shape } and { x2 .shape } " )
3234 if not (x1 .shape [axis ] == x2 .shape [axis ] == 3 ):
3335 raise ValueError (f"cross product axis must have size 3, got { x1 .shape [axis ]} and { x2 .shape [axis ]} " )
34- x1 , x2 = paddle .broadcast_tensors (x1 , x2 )
35- return paddle_linalg .cross (x1 , x2 , axis = axis )
36+ x1 , x2 = torch .broadcast_tensors (x1 , x2 )
37+ return torch_linalg .cross (x1 , x2 , dim = axis )
3638
3739def vecdot (x1 : array , x2 : array , / , * , axis : int = - 1 , ** kwargs ) -> array :
3840 from ._aliases import isdtype
3941
4042 x1 , x2 = _fix_promotion (x1 , x2 , only_scalar = False )
4143
42- # paddle .linalg.vecdot incorrectly allows broadcasting along the contracted dimension
44+ # torch .linalg.vecdot incorrectly allows broadcasting along the contracted dimension
4345 if x1 .shape [axis ] != x2 .shape [axis ]:
4446 raise ValueError ("x1 and x2 must have the same size along the given axis" )
4547
46- # paddle .linalg.vecdot doesn't support integer dtypes
48+ # torch .linalg.vecdot doesn't support integer dtypes
4749 if isdtype (x1 .dtype , 'integral' ) or isdtype (x2 .dtype , 'integral' ):
4850 if kwargs :
4951 raise RuntimeError ("vecdot kwargs not supported for integral dtypes" )
5052
51- x1_ = paddle .moveaxis (x1 , axis , - 1 )
52- x2_ = paddle .moveaxis (x2 , axis , - 1 )
53- x1_ , x2_ = paddle .broadcast_tensors (x1_ , x2_ )
53+ x1_ = torch .moveaxis (x1 , axis , - 1 )
54+ x2_ = torch .moveaxis (x2 , axis , - 1 )
55+ x1_ , x2_ = torch .broadcast_tensors (x1_ , x2_ )
5456
5557 res = x1_ [..., None , :] @ x2_ [..., None ]
5658 return res [..., 0 , 0 ]
57- return paddle .linalg .vecdot (x1 , x2 , axis = axis , ** kwargs )
59+ return torch .linalg .vecdot (x1 , x2 , dim = axis , ** kwargs )
5860
5961def solve (x1 : array , x2 : array , / , ** kwargs ) -> array :
6062 x1 , x2 = _fix_promotion (x1 , x2 , only_scalar = False )
61- # paddle tries to emulate NumPy 1 solve behavior by using batched 1-D solve
63+ # Torch tries to emulate NumPy 1 solve behavior by using batched 1-D solve
6264 # whenever
6365 # 1. x1.ndim - 1 == x2.ndim
6466 # 2. x1.shape[:-1] == x2.shape
6567 #
6668 # See linalg_solve_is_vector_rhs in
6769 # aten/src/ATen/native/LinearAlgebraUtils.h and
68- # paddle_META_FUNC (_linalg_solve_ex) in
69- # aten/src/ATen/native/BatchLinearAlgebra.cpp in the Pypaddle source code.
70+ # TORCH_META_FUNC (_linalg_solve_ex) in
71+ # aten/src/ATen/native/BatchLinearAlgebra.cpp in the PyTorch source code.
7072 #
7173 # The easiest way to work around this is to prepend a size 1 dimension to
7274 # x2, since x2 is already one dimension less than x1.
7375 #
74- # See https://github.com/pypaddle/pypaddle /issues/52915
76+ # See https://github.com/pytorch/pytorch /issues/52915
7577 if x2 .ndim != 1 and x1 .ndim - 1 == x2 .ndim and x1 .shape [:- 1 ] == x2 .shape :
7678 x2 = x2 [None ]
77- return paddle .linalg .solve (x1 , x2 , ** kwargs )
79+ return torch .linalg .solve (x1 , x2 , ** kwargs )
7880
79- # paddle .trace doesn't support the offset argument and doesn't support stacking
81+ # torch .trace doesn't support the offset argument and doesn't support stacking
8082def trace (x : array , / , * , offset : int = 0 , dtype : Optional [Dtype ] = None ) -> array :
8183 # Use our wrapped sum to make sure it does upcasting correctly
82- return sum (paddle .diagonal (x , offset = offset , dim1 = - 2 , dim2 = - 1 ), axis = - 1 , dtype = dtype )
84+ return sum (torch .diagonal (x , offset = offset , dim1 = - 2 , dim2 = - 1 ), axis = - 1 , dtype = dtype )
8385
8486def vector_norm (
8587 x : array ,
@@ -90,30 +92,30 @@ def vector_norm(
9092 ord : Union [int , float , Literal [inf , - inf ]] = 2 ,
9193 ** kwargs ,
9294) -> array :
93- # paddle .vector_norm incorrectly treats axis=() the same as axis=None
95+ # torch .vector_norm incorrectly treats axis=() the same as axis=None
9496 if axis == ():
9597 out = kwargs .get ('out' )
9698 if out is None :
9799 dtype = None
98- if x .dtype == paddle .complex64 :
99- dtype = paddle .float32
100- elif x .dtype == paddle .complex128 :
101- dtype = paddle .float64
100+ if x .dtype == torch .complex64 :
101+ dtype = torch .float32
102+ elif x .dtype == torch .complex128 :
103+ dtype = torch .float64
102104
103- out = paddle .zeros_like (x , dtype = dtype )
105+ out = torch .zeros_like (x , dtype = dtype )
104106
105107 # The norm of a single scalar works out to abs(x) in every case except
106- # for p =0, which is x != 0.
108+ # for ord =0, which is x != 0.
107109 if ord == 0 :
108110 out [:] = (x != 0 )
109111 else :
110- out [:] = paddle .abs (x )
112+ out [:] = torch .abs (x )
111113 return out
112- return paddle .linalg .vector_norm (x , p = ord , axis = axis , keepdim = keepdims , ** kwargs )
114+ return torch .linalg .vector_norm (x , ord = ord , axis = axis , keepdim = keepdims , ** kwargs )
113115
114116__all__ = linalg_all + ['outer' , 'matmul' , 'matrix_transpose' , 'tensordot' ,
115117 'cross' , 'vecdot' , 'solve' , 'trace' , 'vector_norm' ]
116118
117- _all_ignore = ['paddle_linalg ' , 'sum' ]
119+ _all_ignore = ['torch_linalg ' , 'sum' ]
118120
119121del linalg_all
0 commit comments