Skip to content

Commit 33c9ad8

Browse files
committed
ENH: new function fill_diagonal
1 parent 39f5889 commit 33c9ad8

File tree

4 files changed

+149
-1
lines changed

4 files changed

+149
-1
lines changed

src/array_api_extra/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
atleast_nd,
66
cov,
77
expand_dims,
8+
fill_diagonal,
89
isclose,
910
isin,
1011
nan_to_num,
@@ -39,6 +40,7 @@
3940
"create_diagonal",
4041
"default_dtype",
4142
"expand_dims",
43+
"fill_diagonal",
4244
"isclose",
4345
"isin",
4446
"kron",

src/array_api_extra/_delegation.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -964,3 +964,56 @@ def isin(
964964
return xp.isin(a, b, assume_unique=assume_unique, invert=invert)
965965

966966
return _funcs.isin(a, b, assume_unique=assume_unique, invert=invert, xp=xp)
967+
968+
969+
def fill_diagonal(
970+
a: Array,
971+
val: Array | int | float,
972+
*,
973+
wrap: bool = False,
974+
xp: ModuleType | None = None,
975+
) -> None | Array:
976+
"""
977+
Fill the main diagonal of the given array `a` of any dimensionality >= 2.
978+
979+
For an array `a` with ``a.ndim >= 2``, the diagonal is the list of
980+
values ``a[i, ..., i]`` with indices ``i`` all identical. This function
981+
modifies the input array in-place without returning a value. However
982+
specifically for JAX, a copy of the array `a` with the diagonal elements
983+
overwritten is returned. This is because it is not possible to modify JAX's
984+
immutable arrays in-place.
985+
986+
Parameters
987+
----------
988+
a : Array
989+
Input array whose diagonal is to be filled. It should be at least 2-D.
990+
991+
val : Array | int | float
992+
Value(s) to write on the diagonal. If `val` is a scalar, the value is
993+
written along the diagonal. If `val` is an Array, the flattened `val`
994+
is written along the diagonal.
995+
996+
wrap : bool, optional
997+
Only applicable for NumPy and Cupy. For tall matrices the
998+
diagonal is "wrapped" after N columns. Default: False.
999+
1000+
xp : array_namespace, optional
1001+
The standard-compatible namespace for `a` and `val`. Default: infer.
1002+
1003+
Returns
1004+
-------
1005+
Array | None
1006+
For JAX a copy of the original array `a` is returned. For all other
1007+
cases the array `a` is modified in-place so None is returned.
1008+
"""
1009+
if xp is None:
1010+
xp = array_namespace(a, val)
1011+
1012+
if is_jax_namespace(xp):
1013+
return xp.fill_diagonal(a, val, inplace=False)
1014+
1015+
if is_numpy_namespace(xp) or is_cupy_namespace(xp):
1016+
xp.fill_diagonal(a, val, wrap=wrap)
1017+
1018+
_funcs.fill_diagonal(a, val, xp=xp)
1019+
return None

src/array_api_extra/_lib/_funcs.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,12 @@
88

99
from ._at import at
1010
from ._utils import _compat, _helpers
11-
from ._utils._compat import array_namespace, is_dask_namespace, is_jax_array
11+
from ._utils._compat import (
12+
array_namespace,
13+
is_dask_namespace,
14+
is_jax_array,
15+
is_torch_namespace,
16+
)
1217
from ._utils._helpers import (
1318
asarrays,
1419
capabilities,
@@ -786,3 +791,36 @@ def isin( # numpydoc ignore=PR01,RT01
786791
_helpers.in1d(a, b, assume_unique=assume_unique, invert=invert, xp=xp),
787792
original_a_shape,
788793
)
794+
795+
796+
def fill_diagonal( # numpydoc ignore=PR01,RT01
797+
a: Array,
798+
val: Array | int | float,
799+
*,
800+
xp: ModuleType,
801+
) -> None:
802+
"""See docstring in `array_api_extra._delegation.py`."""
803+
if a.ndim < 2:
804+
msg = f"array `a` must be at least 2-d. Got array with shape {tuple(a.shape)}"
805+
raise ValueError(msg)
806+
807+
a, val = asarrays(a, val, xp=xp)
808+
min_rows_columns = min(x or 0 for x in a.shape)
809+
if is_torch_namespace(xp):
810+
val_size = math.prod(x or 0 for x in val.shape)
811+
else:
812+
val_size = val.size or 0
813+
if val.ndim > 0 and val_size != min_rows_columns:
814+
msg = (
815+
"`val` needs to be a scalar or an array of the same size as the "
816+
f"diagonal of `a` ({min_rows_columns}). Got {val.shape[0]}"
817+
)
818+
raise ValueError(msg)
819+
820+
if val.ndim == 0:
821+
for i in range(min_rows_columns):
822+
a[i, i] = val
823+
else:
824+
val = cast(Array, xp.reshape(val, (-1,)))
825+
for i in range(min_rows_columns):
826+
a[i, i] = val[i]

tests/test_funcs.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
create_diagonal,
2222
default_dtype,
2323
expand_dims,
24+
fill_diagonal,
2425
isclose,
2526
isin,
2627
kron,
@@ -1548,3 +1549,57 @@ def test_kind(self, xp: ModuleType, library: Backend):
15481549
expected = xp.asarray([False, True, False, True])
15491550
res = isin(a, b, kind="sort")
15501551
xp_assert_equal(res, expected)
1552+
1553+
1554+
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="item assignment not supported")
1555+
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="numpy read only arrays")
1556+
class TestFillDiagonal:
1557+
def test_simple(self, xp: ModuleType):
1558+
a = xp.zeros((3, 3), dtype=xp.int64)
1559+
val = 5
1560+
expected = xp.asarray([[5, 0, 0], [0, 5, 0], [0, 0, 5]], dtype=xp.int64)
1561+
if is_jax_namespace(xp):
1562+
a = cast(Array, fill_diagonal(a, val))
1563+
else:
1564+
_ = fill_diagonal(a, val)
1565+
xp_assert_equal(a, expected)
1566+
1567+
def test_val_1d(self, xp: ModuleType):
1568+
a = xp.zeros((3, 3), dtype=xp.int64)
1569+
val = xp.asarray([1, 2, 3], dtype=xp.int64)
1570+
expected = xp.asarray([[1, 0, 0], [0, 2, 0], [0, 0, 3]], dtype=xp.int64)
1571+
if is_jax_namespace(xp):
1572+
a = cast(Array, fill_diagonal(a, val))
1573+
else:
1574+
_ = fill_diagonal(a, val)
1575+
xp_assert_equal(a, expected)
1576+
1577+
@pytest.mark.parametrize(
1578+
("a_shape", "expected"),
1579+
[
1580+
((4, 4), [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]]),
1581+
(
1582+
(5, 4),
1583+
[[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4], [0, 0, 0, 0]],
1584+
),
1585+
],
1586+
)
1587+
def test_val_2d(self, xp: ModuleType, a_shape: tuple[int, int], expected: Array):
1588+
a = xp.zeros(a_shape, dtype=xp.int64)
1589+
val = xp.asarray([[1, 2], [3, 4]], dtype=xp.int64)
1590+
expected = xp.asarray(expected, dtype=xp.int64)
1591+
if is_jax_namespace(xp):
1592+
a = cast(Array, fill_diagonal(a, val))
1593+
else:
1594+
_ = fill_diagonal(a, val)
1595+
xp_assert_equal(a, expected)
1596+
1597+
@pytest.mark.parametrize("val_scalar", [True, False])
1598+
def test_device(self, xp: ModuleType, device: Device, val_scalar: bool):
1599+
a = xp.zeros((3, 3), dtype=xp.int64, device=device)
1600+
val = 5 if val_scalar else xp.asarray([1, 2, 3], dtype=xp.int64, device=device)
1601+
if is_jax_namespace(xp):
1602+
a = cast(Array, fill_diagonal(a, val))
1603+
else:
1604+
_ = fill_diagonal(a, val)
1605+
assert get_device(a) == device

0 commit comments

Comments
 (0)