Skip to content

Commit 87c4442

Browse files
committed
add delegate func create_diagonal
1 parent ebe9a5b commit 87c4442

File tree

3 files changed

+118
-97
lines changed

3 files changed

+118
-97
lines changed

src/array_api_extra/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
argpartition,
55
atleast_nd,
66
cov,
7+
create_diagonal,
78
expand_dims,
89
isclose,
910
isin,
@@ -17,7 +18,6 @@
1718
from ._lib._funcs import (
1819
apply_where,
1920
broadcast_shapes,
20-
create_diagonal,
2121
default_dtype,
2222
kron,
2323
nunique,

src/array_api_extra/_delegation.py

Lines changed: 114 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
from ._lib._utils._typing import Array, DType
2020

2121
__all__ = [
22+
"atleast_nd",
2223
"cov",
24+
"create_diagonal",
2325
"expand_dims",
2426
"isclose",
2527
"nan_to_num",
@@ -29,6 +31,55 @@
2931
]
3032

3133

34+
def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array:
35+
"""
36+
Recursively expand the dimension of an array to at least `ndim`.
37+
38+
Parameters
39+
----------
40+
x : array
41+
Input array.
42+
ndim : int
43+
The minimum number of dimensions for the result.
44+
xp : array_namespace, optional
45+
The standard-compatible namespace for `x`. Default: infer.
46+
47+
Returns
48+
-------
49+
array
50+
An array with ``res.ndim`` >= `ndim`.
51+
If ``x.ndim`` >= `ndim`, `x` is returned.
52+
If ``x.ndim`` < `ndim`, `x` is expanded by prepending new axes
53+
until ``res.ndim`` equals `ndim`.
54+
55+
Examples
56+
--------
57+
>>> import array_api_strict as xp
58+
>>> import array_api_extra as xpx
59+
>>> x = xp.asarray([1])
60+
>>> xpx.atleast_nd(x, ndim=3, xp=xp)
61+
Array([[[1]]], dtype=array_api_strict.int64)
62+
63+
>>> x = xp.asarray([[[1, 2],
64+
... [3, 4]]])
65+
>>> xpx.atleast_nd(x, ndim=1, xp=xp) is x
66+
True
67+
"""
68+
if xp is None:
69+
xp = array_namespace(x)
70+
71+
if 1 <= ndim <= 3 and (
72+
is_numpy_namespace(xp)
73+
or is_jax_namespace(xp)
74+
or is_dask_namespace(xp)
75+
or is_cupy_namespace(xp)
76+
or is_torch_namespace(xp)
77+
):
78+
return getattr(xp, f"atleast_{ndim}d")(x)
79+
80+
return _funcs.atleast_nd(x, ndim=ndim, xp=xp)
81+
82+
3283
def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
3384
"""
3485
Estimate a covariance matrix.
@@ -109,6 +160,69 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
109160
return _funcs.cov(m, xp=xp)
110161

111162

163+
def create_diagonal(
164+
x: Array, /, *, offset: int = 0, xp: ModuleType | None = None
165+
) -> Array:
166+
"""
167+
Construct a diagonal array.
168+
169+
Parameters
170+
----------
171+
x : array
172+
An array having shape ``(*batch_dims, k)``.
173+
offset : int, optional
174+
Offset from the leading diagonal (default is ``0``).
175+
Use positive ints for diagonals above the leading diagonal,
176+
and negative ints for diagonals below the leading diagonal.
177+
xp : array_namespace, optional
178+
The standard-compatible namespace for `x`. Default: infer.
179+
180+
Returns
181+
-------
182+
array
183+
An array having shape ``(*batch_dims, k+abs(offset), k+abs(offset))`` with `x`
184+
on the diagonal (offset by `offset`).
185+
186+
Examples
187+
--------
188+
>>> import array_api_strict as xp
189+
>>> import array_api_extra as xpx
190+
>>> x = xp.asarray([2, 4, 8])
191+
192+
>>> xpx.create_diagonal(x, xp=xp)
193+
Array([[2, 0, 0],
194+
[0, 4, 0],
195+
[0, 0, 8]], dtype=array_api_strict.int64)
196+
197+
>>> xpx.create_diagonal(x, offset=-2, xp=xp)
198+
Array([[0, 0, 0, 0, 0],
199+
[0, 0, 0, 0, 0],
200+
[2, 0, 0, 0, 0],
201+
[0, 4, 0, 0, 0],
202+
[0, 0, 8, 0, 0]], dtype=array_api_strict.int64)
203+
"""
204+
if xp is None:
205+
xp = array_namespace(x)
206+
207+
if x.ndim == 0:
208+
err_msg = "`x` must be at least 1-dimensional."
209+
raise ValueError(err_msg)
210+
211+
if is_torch_namespace(xp):
212+
return xp.diag_embed(
213+
atleast_nd(x, ndim=1, xp=xp), offset=offset, dim1=-2, dim2=-1
214+
)
215+
216+
if (is_dask_namespace(xp) or is_cupy_namespace(xp)) and x.ndim < 2:
217+
return xp.diag(x, k=offset)
218+
219+
if (is_jax_namespace(xp) or is_numpy_namespace(xp)) and x.ndim < 3:
220+
batch_dim, n = eager_shape(x)[:-1], eager_shape(x, -1)[0] + abs(offset)
221+
return xp.reshape(xp.diag(x, k=offset), (*batch_dim, n, n))
222+
223+
return _funcs.create_diagonal(x, offset=offset, xp=xp)
224+
225+
112226
def expand_dims(
113227
a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType | None = None
114228
) -> Array:
@@ -197,55 +311,6 @@ def expand_dims(
197311
return _funcs.expand_dims(a, axis=axis, xp=xp)
198312

199313

200-
def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array:
201-
"""
202-
Recursively expand the dimension of an array to at least `ndim`.
203-
204-
Parameters
205-
----------
206-
x : array
207-
Input array.
208-
ndim : int
209-
The minimum number of dimensions for the result.
210-
xp : array_namespace, optional
211-
The standard-compatible namespace for `x`. Default: infer.
212-
213-
Returns
214-
-------
215-
array
216-
An array with ``res.ndim`` >= `ndim`.
217-
If ``x.ndim`` >= `ndim`, `x` is returned.
218-
If ``x.ndim`` < `ndim`, `x` is expanded by prepending new axes
219-
until ``res.ndim`` equals `ndim`.
220-
221-
Examples
222-
--------
223-
>>> import array_api_strict as xp
224-
>>> import array_api_extra as xpx
225-
>>> x = xp.asarray([1])
226-
>>> xpx.atleast_nd(x, ndim=3, xp=xp)
227-
Array([[[1]]], dtype=array_api_strict.int64)
228-
229-
>>> x = xp.asarray([[[1, 2],
230-
... [3, 4]]])
231-
>>> xpx.atleast_nd(x, ndim=1, xp=xp) is x
232-
True
233-
"""
234-
if xp is None:
235-
xp = array_namespace(x)
236-
237-
if 1 <= ndim <= 3 and (
238-
is_numpy_namespace(xp)
239-
or is_jax_namespace(xp)
240-
or is_dask_namespace(xp)
241-
or is_cupy_namespace(xp)
242-
or is_torch_namespace(xp)
243-
):
244-
return getattr(xp, f"atleast_{ndim}d")(x)
245-
246-
return _funcs.atleast_nd(x, ndim=ndim, xp=xp)
247-
248-
249314
def isclose(
250315
a: Array | complex,
251316
b: Array | complex,

src/array_api_extra/_lib/_funcs.py

Lines changed: 3 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -295,53 +295,9 @@ def one_hot(
295295

296296

297297
def create_diagonal(
298-
x: Array, /, *, offset: int = 0, xp: ModuleType | None = None
299-
) -> Array:
300-
"""
301-
Construct a diagonal array.
302-
303-
Parameters
304-
----------
305-
x : array
306-
An array having shape ``(*batch_dims, k)``.
307-
offset : int, optional
308-
Offset from the leading diagonal (default is ``0``).
309-
Use positive ints for diagonals above the leading diagonal,
310-
and negative ints for diagonals below the leading diagonal.
311-
xp : array_namespace, optional
312-
The standard-compatible namespace for `x`. Default: infer.
313-
314-
Returns
315-
-------
316-
array
317-
An array having shape ``(*batch_dims, k+abs(offset), k+abs(offset))`` with `x`
318-
on the diagonal (offset by `offset`).
319-
320-
Examples
321-
--------
322-
>>> import array_api_strict as xp
323-
>>> import array_api_extra as xpx
324-
>>> x = xp.asarray([2, 4, 8])
325-
326-
>>> xpx.create_diagonal(x, xp=xp)
327-
Array([[2, 0, 0],
328-
[0, 4, 0],
329-
[0, 0, 8]], dtype=array_api_strict.int64)
330-
331-
>>> xpx.create_diagonal(x, offset=-2, xp=xp)
332-
Array([[0, 0, 0, 0, 0],
333-
[0, 0, 0, 0, 0],
334-
[2, 0, 0, 0, 0],
335-
[0, 4, 0, 0, 0],
336-
[0, 0, 8, 0, 0]], dtype=array_api_strict.int64)
337-
"""
338-
if xp is None:
339-
xp = array_namespace(x)
340-
341-
if x.ndim == 0:
342-
err_msg = "`x` must be at least 1-dimensional."
343-
raise ValueError(err_msg)
344-
298+
x: Array, /, *, offset: int = 0, xp: ModuleType
299+
) -> Array: # numpydoc ignore=PR01,RT01
300+
"""See docstring in array_api_extra._delegation."""
345301
x_shape = eager_shape(x)
346302
batch_dims = x_shape[:-1]
347303
n = x_shape[-1] + abs(offset)

0 commit comments

Comments
 (0)