|
19 | 19 | from ._lib._utils._typing import Array, DType |
20 | 20 |
|
21 | 21 | __all__ = [ |
| 22 | + "atleast_nd", |
22 | 23 | "cov", |
| 24 | + "create_diagonal", |
23 | 25 | "expand_dims", |
24 | 26 | "isclose", |
25 | 27 | "nan_to_num", |
|
29 | 31 | ] |
30 | 32 |
|
31 | 33 |
|
| 34 | +def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array: |
| 35 | + """ |
| 36 | + Recursively expand the dimension of an array to at least `ndim`. |
| 37 | +
|
| 38 | + Parameters |
| 39 | + ---------- |
| 40 | + x : array |
| 41 | + Input array. |
| 42 | + ndim : int |
| 43 | + The minimum number of dimensions for the result. |
| 44 | + xp : array_namespace, optional |
| 45 | + The standard-compatible namespace for `x`. Default: infer. |
| 46 | +
|
| 47 | + Returns |
| 48 | + ------- |
| 49 | + array |
| 50 | + An array with ``res.ndim`` >= `ndim`. |
| 51 | + If ``x.ndim`` >= `ndim`, `x` is returned. |
| 52 | + If ``x.ndim`` < `ndim`, `x` is expanded by prepending new axes |
| 53 | + until ``res.ndim`` equals `ndim`. |
| 54 | +
|
| 55 | + Examples |
| 56 | + -------- |
| 57 | + >>> import array_api_strict as xp |
| 58 | + >>> import array_api_extra as xpx |
| 59 | + >>> x = xp.asarray([1]) |
| 60 | + >>> xpx.atleast_nd(x, ndim=3, xp=xp) |
| 61 | + Array([[[1]]], dtype=array_api_strict.int64) |
| 62 | +
|
| 63 | + >>> x = xp.asarray([[[1, 2], |
| 64 | + ... [3, 4]]]) |
| 65 | + >>> xpx.atleast_nd(x, ndim=1, xp=xp) is x |
| 66 | + True |
| 67 | + """ |
| 68 | + if xp is None: |
| 69 | + xp = array_namespace(x) |
| 70 | + |
| 71 | + if 1 <= ndim <= 3 and ( |
| 72 | + is_numpy_namespace(xp) |
| 73 | + or is_jax_namespace(xp) |
| 74 | + or is_dask_namespace(xp) |
| 75 | + or is_cupy_namespace(xp) |
| 76 | + or is_torch_namespace(xp) |
| 77 | + ): |
| 78 | + return getattr(xp, f"atleast_{ndim}d")(x) |
| 79 | + |
| 80 | + return _funcs.atleast_nd(x, ndim=ndim, xp=xp) |
| 81 | + |
| 82 | + |
32 | 83 | def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array: |
33 | 84 | """ |
34 | 85 | Estimate a covariance matrix. |
@@ -109,6 +160,69 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array: |
109 | 160 | return _funcs.cov(m, xp=xp) |
110 | 161 |
|
111 | 162 |
|
| 163 | +def create_diagonal( |
| 164 | + x: Array, /, *, offset: int = 0, xp: ModuleType | None = None |
| 165 | +) -> Array: |
| 166 | + """ |
| 167 | + Construct a diagonal array. |
| 168 | +
|
| 169 | + Parameters |
| 170 | + ---------- |
| 171 | + x : array |
| 172 | + An array having shape ``(*batch_dims, k)``. |
| 173 | + offset : int, optional |
| 174 | + Offset from the leading diagonal (default is ``0``). |
| 175 | + Use positive ints for diagonals above the leading diagonal, |
| 176 | + and negative ints for diagonals below the leading diagonal. |
| 177 | + xp : array_namespace, optional |
| 178 | + The standard-compatible namespace for `x`. Default: infer. |
| 179 | +
|
| 180 | + Returns |
| 181 | + ------- |
| 182 | + array |
| 183 | + An array having shape ``(*batch_dims, k+abs(offset), k+abs(offset))`` with `x` |
| 184 | + on the diagonal (offset by `offset`). |
| 185 | +
|
| 186 | + Examples |
| 187 | + -------- |
| 188 | + >>> import array_api_strict as xp |
| 189 | + >>> import array_api_extra as xpx |
| 190 | + >>> x = xp.asarray([2, 4, 8]) |
| 191 | +
|
| 192 | + >>> xpx.create_diagonal(x, xp=xp) |
| 193 | + Array([[2, 0, 0], |
| 194 | + [0, 4, 0], |
| 195 | + [0, 0, 8]], dtype=array_api_strict.int64) |
| 196 | +
|
| 197 | + >>> xpx.create_diagonal(x, offset=-2, xp=xp) |
| 198 | + Array([[0, 0, 0, 0, 0], |
| 199 | + [0, 0, 0, 0, 0], |
| 200 | + [2, 0, 0, 0, 0], |
| 201 | + [0, 4, 0, 0, 0], |
| 202 | + [0, 0, 8, 0, 0]], dtype=array_api_strict.int64) |
| 203 | + """ |
| 204 | + if xp is None: |
| 205 | + xp = array_namespace(x) |
| 206 | + |
| 207 | + if x.ndim == 0: |
| 208 | + err_msg = "`x` must be at least 1-dimensional." |
| 209 | + raise ValueError(err_msg) |
| 210 | + |
| 211 | + if is_torch_namespace(xp): |
| 212 | + return xp.diag_embed( |
| 213 | + atleast_nd(x, ndim=1, xp=xp), offset=offset, dim1=-2, dim2=-1 |
| 214 | + ) |
| 215 | + |
| 216 | + if (is_dask_namespace(xp) or is_cupy_namespace(xp)) and x.ndim < 2: |
| 217 | + return xp.diag(x, k=offset) |
| 218 | + |
| 219 | + if (is_jax_namespace(xp) or is_numpy_namespace(xp)) and x.ndim < 3: |
| 220 | + batch_dim, n = eager_shape(x)[:-1], eager_shape(x, -1)[0] + abs(offset) |
| 221 | + return xp.reshape(xp.diag(x, k=offset), (*batch_dim, n, n)) |
| 222 | + |
| 223 | + return _funcs.create_diagonal(x, offset=offset, xp=xp) |
| 224 | + |
| 225 | + |
112 | 226 | def expand_dims( |
113 | 227 | a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType | None = None |
114 | 228 | ) -> Array: |
@@ -197,55 +311,6 @@ def expand_dims( |
197 | 311 | return _funcs.expand_dims(a, axis=axis, xp=xp) |
198 | 312 |
|
199 | 313 |
|
200 | | -def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array: |
201 | | - """ |
202 | | - Recursively expand the dimension of an array to at least `ndim`. |
203 | | -
|
204 | | - Parameters |
205 | | - ---------- |
206 | | - x : array |
207 | | - Input array. |
208 | | - ndim : int |
209 | | - The minimum number of dimensions for the result. |
210 | | - xp : array_namespace, optional |
211 | | - The standard-compatible namespace for `x`. Default: infer. |
212 | | -
|
213 | | - Returns |
214 | | - ------- |
215 | | - array |
216 | | - An array with ``res.ndim`` >= `ndim`. |
217 | | - If ``x.ndim`` >= `ndim`, `x` is returned. |
218 | | - If ``x.ndim`` < `ndim`, `x` is expanded by prepending new axes |
219 | | - until ``res.ndim`` equals `ndim`. |
220 | | -
|
221 | | - Examples |
222 | | - -------- |
223 | | - >>> import array_api_strict as xp |
224 | | - >>> import array_api_extra as xpx |
225 | | - >>> x = xp.asarray([1]) |
226 | | - >>> xpx.atleast_nd(x, ndim=3, xp=xp) |
227 | | - Array([[[1]]], dtype=array_api_strict.int64) |
228 | | -
|
229 | | - >>> x = xp.asarray([[[1, 2], |
230 | | - ... [3, 4]]]) |
231 | | - >>> xpx.atleast_nd(x, ndim=1, xp=xp) is x |
232 | | - True |
233 | | - """ |
234 | | - if xp is None: |
235 | | - xp = array_namespace(x) |
236 | | - |
237 | | - if 1 <= ndim <= 3 and ( |
238 | | - is_numpy_namespace(xp) |
239 | | - or is_jax_namespace(xp) |
240 | | - or is_dask_namespace(xp) |
241 | | - or is_cupy_namespace(xp) |
242 | | - or is_torch_namespace(xp) |
243 | | - ): |
244 | | - return getattr(xp, f"atleast_{ndim}d")(x) |
245 | | - |
246 | | - return _funcs.atleast_nd(x, ndim=ndim, xp=xp) |
247 | | - |
248 | | - |
249 | 314 | def isclose( |
250 | 315 | a: Array | complex, |
251 | 316 | b: Array | complex, |
|
0 commit comments