Skip to content

Commit 3082ed5

Browse files
Rename sparse functions to match numpy array API (#1663)
* Rename `mul` -> `multiply` * Rename `sub` -> `subtract` * Space... the final frontier
1 parent 5547eb0 commit 3082ed5

File tree

2 files changed

+52
-23
lines changed

2 files changed

+52
-23
lines changed

pytensor/sparse/basic.py

Lines changed: 45 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2268,36 +2268,49 @@ def add(x, y):
22682268
raise NotImplementedError()
22692269

22702270

2271-
def sub(x, y):
2271+
def subtract(
2272+
x: SparseVariable | TensorVariable, y: SparseVariable | TensorVariable
2273+
) -> SparseVariable:
22722274
"""
22732275
Subtract two matrices, at least one of which is sparse.
22742276
2275-
This method will provide the right op according
2276-
to the inputs.
2277+
This method will provide the right op according to the inputs.
22772278
22782279
Parameters
22792280
----------
2280-
x
2281+
x : SparseVariable or TensorVariable
22812282
A matrix variable.
2282-
y
2283+
y : SparseVariable or TensorVariable
22832284
A matrix variable.
22842285
22852286
Returns
22862287
-------
2287-
A sparse matrix
2288-
`x` - `y`
2288+
result: SparseVariable
2289+
Result of `x - y`, as a sparse matrix.
22892290
22902291
Notes
22912292
-----
22922293
At least one of `x` and `y` must be a sparse matrix.
22932294
2294-
The grad will be structured only when one of the variable will be a dense
2295-
matrix.
2296-
2295+
The grad will be structured only when one of the variable will be a dense matrix.
22972296
"""
22982297
return x + (-y)
22992298

23002299

2300+
def sub(x, y):
2301+
warn(
2302+
"pytensor.sparse.sub is deprecated and will be removed in a future version. Use "
2303+
"pytensor.sparse.subtract instead.",
2304+
category=DeprecationWarning,
2305+
stacklevel=2,
2306+
)
2307+
2308+
return subtract(x, y)
2309+
2310+
2311+
sub.__doc__ = subtract.__doc__
2312+
2313+
23012314
class MulSS(Op):
23022315
# mul(sparse, sparse)
23032316
# See the doc of mul() for more detail
@@ -2491,29 +2504,31 @@ def infer_shape(self, fgraph, node, ins_shapes):
24912504
mul_s_v = MulSV()
24922505

24932506

2494-
def mul(x, y):
2507+
def multiply(
2508+
x: SparseTensorType | TensorType, y: SparseTensorType | TensorType
2509+
) -> SparseVariable:
24952510
"""
24962511
Multiply elementwise two matrices, at least one of which is sparse.
24972512
24982513
This method will provide the right op according to the inputs.
24992514
25002515
Parameters
25012516
----------
2502-
x
2517+
x : SparseVariable
25032518
A matrix variable.
2504-
y
2519+
y : SparseVariable
25052520
A matrix variable.
25062521
25072522
Returns
25082523
-------
2509-
A sparse matrix
2510-
`x` * `y`
2524+
result: SparseVariable
2525+
The elementwise multiplication of `x` and `y`.
25112526
25122527
Notes
25132528
-----
25142529
At least one of `x` and `y` must be a sparse matrix.
2515-
The grad is regular, i.e. not structured.
25162530
2531+
The gradient is regular, i.e. not structured.
25172532
"""
25182533

25192534
x = as_sparse_or_tensor_variable(x)
@@ -2541,6 +2556,20 @@ def mul(x, y):
25412556
raise NotImplementedError()
25422557

25432558

2559+
def mul(x, y):
2560+
warn(
2561+
"pytensor.sparse.mul is deprecated and will be removed in a future version. Use "
2562+
"pytensor.sparse.multiply instead.",
2563+
category=DeprecationWarning,
2564+
stacklevel=2,
2565+
)
2566+
2567+
return multiply(x, y)
2568+
2569+
2570+
mul.__doc__ = multiply.__doc__
2571+
2572+
25442573
class __ComparisonOpSS(Op):
25452574
"""
25462575
Used as a superclass for all comparisons between two sparses matrices.

tests/sparse/test_basic.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@
6565
gt,
6666
le,
6767
lt,
68-
mul,
6968
mul_s_v,
69+
multiply,
7070
sampling_dot,
7171
sp_ones_like,
7272
square_diagonal,
@@ -724,21 +724,21 @@ def test_AddDS(self):
724724

725725
def test_MulSS(self):
726726
self._testSS(
727-
mul,
727+
multiply,
728728
np.array([[1.0, 0], [3, 0], [0, 6]]),
729729
np.array([[1.0, 2], [3, 0], [0, 6]]),
730730
)
731731

732732
def test_MulSD(self):
733733
self._testSD(
734-
mul,
734+
multiply,
735735
np.array([[1.0, 0], [3, 0], [0, 6]]),
736736
np.array([[1.0, 2], [3, 0], [0, 6]]),
737737
)
738738

739739
def test_MulDS(self):
740740
self._testDS(
741-
mul,
741+
multiply,
742742
np.array([[1.0, 0], [3, 0], [0, 6]]),
743743
np.array([[1.0, 2], [3, 0], [0, 6]]),
744744
)
@@ -783,7 +783,7 @@ def _testSS(
783783
assert np.all(val.todense() == array1 + array2)
784784
if dtype1.startswith("float") and dtype2.startswith("float"):
785785
verify_grad_sparse(op, [a, b], structured=False)
786-
elif op is mul:
786+
elif op is multiply:
787787
assert np.all(val.todense() == array1 * array2)
788788
if dtype1.startswith("float") and dtype2.startswith("float"):
789789
verify_grad_sparse(op, [a, b], structured=False)
@@ -833,7 +833,7 @@ def _testSD(
833833
continue
834834
if dtype1.startswith("float") and dtype2.startswith("float"):
835835
verify_grad_sparse(op, [a, b], structured=True)
836-
elif op is mul:
836+
elif op is multiply:
837837
assert _is_sparse_variable(apb)
838838
assert np.all(val.todense() == b.multiply(array1))
839839
assert np.all(
@@ -887,7 +887,7 @@ def _testDS(
887887
b = b.data
888888
if dtype1.startswith("float") and dtype2.startswith("float"):
889889
verify_grad_sparse(op, [a, b], structured=True)
890-
elif op is mul:
890+
elif op is multiply:
891891
assert _is_sparse_variable(apb)
892892
ans = np.array([[1, 0], [9, 0], [0, 36]])
893893
assert np.all(val.todense() == (a.multiply(array2)))

0 commit comments

Comments
 (0)