22from itertools import product
33from typing import Iterator , List , Optional , Tuple , Union
44
5- from . typing import Scalar , Shape
5+ from ndindex import iter_indices as _iter_indices
66
7- __all__ = ["normalise_axis" , "ndindex" , "axis_ndindex" , "axes_ndindex" , "reshape" ]
7+ from .typing import AtomicIndex , Index , Scalar , Shape
8+
9+ __all__ = [
10+ "broadcast_shapes" ,
11+ "normalise_axis" ,
12+ "ndindex" ,
13+ "axis_ndindex" ,
14+ "axes_ndindex" ,
15+ "reshape" ,
16+ "fmt_idx" ,
17+ ]
18+
19+
20+ class BroadcastError (ValueError ):
21+ """Shapes do not broadcast with eachother"""
22+
23+
24+ def _broadcast_shapes (shape1 : Shape , shape2 : Shape ) -> Shape :
25+ """Broadcasts `shape1` and `shape2`"""
26+ N1 = len (shape1 )
27+ N2 = len (shape2 )
28+ N = max (N1 , N2 )
29+ shape = [None for _ in range (N )]
30+ i = N - 1
31+ while i >= 0 :
32+ n1 = N1 - N + i
33+ if N1 - N + i >= 0 :
34+ d1 = shape1 [n1 ]
35+ else :
36+ d1 = 1
37+ n2 = N2 - N + i
38+ if N2 - N + i >= 0 :
39+ d2 = shape2 [n2 ]
40+ else :
41+ d2 = 1
42+
43+ if d1 == 1 :
44+ shape [i ] = d2
45+ elif d2 == 1 :
46+ shape [i ] = d1
47+ elif d1 == d2 :
48+ shape [i ] = d1
49+ else :
50+ raise BroadcastError ()
51+
52+ i = i - 1
53+
54+ return tuple (shape )
55+
56+
57+ def broadcast_shapes (* shapes : Shape ):
58+ if len (shapes ) == 0 :
59+ raise ValueError ("shapes=[] must be non-empty" )
60+ elif len (shapes ) == 1 :
61+ return shapes [0 ]
62+ result = _broadcast_shapes (shapes [0 ], shapes [1 ])
63+ for i in range (2 , len (shapes )):
64+ result = _broadcast_shapes (result , shapes [i ])
65+ return result
866
967
1068def normalise_axis (
@@ -17,13 +75,21 @@ def normalise_axis(
1775 return axes
1876
1977
20- def ndindex (shape ):
21- """Iterator of n-D indices to an array
78+ def ndindex (shape : Shape ) -> Iterator [Index ]:
79+ """Yield every index of a shape"""
80+ return (indices [0 ] for indices in iter_indices (shape ))
81+
2282
23- Yields tuples of integers to index every element of an array of shape
24- `shape`. Same as np.ndindex().
25- """
26- return product (* [range (i ) for i in shape ])
83+ def iter_indices (
84+ * shapes : Shape , skip_axes : Tuple [int , ...] = ()
85+ ) -> Iterator [Tuple [Index , ...]]:
86+ """Wrapper for ndindex.iter_indices()"""
87+ # Prevent iterations if any shape has 0-sides
88+ for shape in shapes :
89+ if 0 in shape :
90+ return
91+ for indices in _iter_indices (* shapes , skip_axes = skip_axes ):
92+ yield tuple (i .raw for i in indices ) # type: ignore
2793
2894
2995def axis_ndindex (
@@ -60,7 +126,7 @@ def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[List[Shape]]:
60126 yield list (indices )
61127
62128
63- def reshape (flat_seq : List [Scalar ], shape : Shape ) -> Union [Scalar , List [ Scalar ] ]:
129+ def reshape (flat_seq : List [Scalar ], shape : Shape ) -> Union [Scalar , List ]:
64130 """Reshape a flat sequence"""
65131 if any (s == 0 for s in shape ):
66132 raise ValueError (
@@ -75,3 +141,33 @@ def reshape(flat_seq: List[Scalar], shape: Shape) -> Union[Scalar, List[Scalar]]
75141 size = len (flat_seq )
76142 n = math .prod (shape [1 :])
77143 return [reshape (flat_seq [i * n : (i + 1 ) * n ], shape [1 :]) for i in range (size // n )]
144+
145+
146+ def fmt_i (i : AtomicIndex ) -> str :
147+ if isinstance (i , int ):
148+ return str (i )
149+ elif isinstance (i , slice ):
150+ res = ""
151+ if i .start is not None :
152+ res += str (i .start )
153+ res += ":"
154+ if i .stop is not None :
155+ res += str (i .stop )
156+ if i .step is not None :
157+ res += f":{ i .step } "
158+ return res
159+ else :
160+ return "..."
161+
162+
163+ def fmt_idx (sym : str , idx : Index ) -> str :
164+ if idx == ():
165+ return sym
166+ res = f"{ sym } ["
167+ _idx = idx if isinstance (idx , tuple ) else (idx ,)
168+ if len (_idx ) == 1 :
169+ res += fmt_i (_idx [0 ])
170+ else :
171+ res += ", " .join (fmt_i (i ) for i in _idx )
172+ res += "]"
173+ return res
0 commit comments