22
33from functools import reduce as _reduce , wraps as _wraps
44from builtins import all as _builtin_all , any as _builtin_any
5- from typing import Any , List , Optional , Sequence , Tuple , Union
5+ from typing import Any , List , Optional , Sequence , Tuple , Union , Literal
66
77import torch
88
@@ -828,6 +828,12 @@ def sign(x: Array, /) -> Array:
828828 return out
829829
830830
831+ def meshgrid (* arrays : Array , indexing : Literal ['xy' , 'ij' ] = 'xy' ) -> List [Array ]:
832+ # enforce the default of 'xy'
833+ # TODO: is the return type a list or a tuple
834+ return list (torch .meshgrid (* arrays , indexing = 'xy' ))
835+
836+
831837__all__ = ['__array_namespace_info__' , 'asarray' , 'result_type' , 'can_cast' ,
832838 'permute_dims' , 'bitwise_invert' , 'newaxis' , 'conj' , 'add' ,
833839 'atan2' , 'bitwise_and' , 'bitwise_left_shift' , 'bitwise_or' ,
@@ -844,6 +850,6 @@ def sign(x: Array, /) -> Array:
844850 'UniqueAllResult' , 'UniqueCountsResult' , 'UniqueInverseResult' ,
845851 'unique_all' , 'unique_counts' , 'unique_inverse' , 'unique_values' ,
846852 'matmul' , 'matrix_transpose' , 'vecdot' , 'tensordot' , 'isdtype' ,
847- 'take' , 'take_along_axis' , 'sign' , 'finfo' , 'iinfo' , 'repeat' ]
853+ 'take' , 'take_along_axis' , 'sign' , 'finfo' , 'iinfo' , 'repeat' , 'meshgrid' ]
848854
849855_all_ignore = ['torch' , 'get_xp' ]
0 commit comments