Skip to content

Commit 8884c67

Browse files
committed
Rename SparseMultiply Ops
1 parent aa68c80 commit 8884c67

File tree

3 files changed

+41
-55
lines changed

3 files changed

+41
-55
lines changed

pytensor/sparse/basic.py

Lines changed: 33 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -2310,7 +2310,7 @@ def sub(x, y):
23102310
return x + (-y)
23112311

23122312

2313-
class MulSS(Op):
2313+
class SparseSparseMultiply(Op):
23142314
# mul(sparse, sparse)
23152315
# See the doc of mul() for more detail
23162316
__props__ = ()
@@ -2343,12 +2343,15 @@ def infer_shape(self, fgraph, node, shapes):
23432343
return [shapes[0]]
23442344

23452345

2346-
mul_s_s = MulSS()
2346+
mul_s_s = SparseSparseMultiply()
23472347

23482348

2349-
class MulSD(Op):
2349+
class SparseDenseMultiply(Op):
23502350
# mul(sparse, dense)
23512351
# See the doc of mul() for more detail
2352+
2353+
# We're doing useless copy of indices and indptr, those should be reused
2354+
# However, PyTensor doesn't support one output -> multiple views...
23522355
__props__ = ()
23532356

23542357
def make_node(self, x, y):
@@ -2364,64 +2367,42 @@ def make_node(self, x, y):
23642367
# Broadcasting of the sparse matrix is not supported.
23652368
# We support nd == 0 used by grad of SpSum()
23662369
assert y.type.ndim in (0, 2)
2367-
out = SparseTensorType(dtype=dtype, format=x.type.format)()
2370+
out = SparseTensorType(dtype=dtype, format=x.type.format, shape=x.type.shape)()
23682371
return Apply(self, [x, y], [out])
23692372

23702373
def perform(self, node, inputs, outputs):
23712374
(x, y) = inputs
23722375
(out,) = outputs
2376+
out_dtype = node.outputs[0].dtype
23732377
assert _is_sparse(x) and _is_dense(y)
2374-
if len(y.shape) == 0:
2375-
out_dtype = node.outputs[0].dtype
2376-
if x.dtype == out_dtype:
2377-
z = x.copy()
2378-
else:
2379-
z = x.astype(out_dtype)
2380-
out[0] = z
2381-
out[0].data *= y
2382-
elif len(y.shape) == 1:
2383-
raise NotImplementedError() # RowScale / ColScale
2384-
elif len(y.shape) == 2:
2378+
2379+
if x.dtype == out_dtype:
2380+
z = x.copy()
2381+
else:
2382+
z = x.astype(out_dtype)
2383+
out[0] = z
2384+
z_data = z.data
2385+
2386+
if y.ndim == 0:
2387+
z_data *= y
2388+
else: # y_ndim == 2
23852389
# if we have enough memory to fit y, maybe we can fit x.asarray()
23862390
# too?
23872391
# TODO: change runtime from O(M*N) to O(nonzeros)
23882392
M, N = x.shape
23892393
assert x.shape == y.shape
2390-
out_dtype = node.outputs[0].dtype
2391-
2394+
indices = x.indices
2395+
indptr = x.indptr
23922396
if x.format == "csc":
2393-
indices = x.indices
2394-
indptr = x.indptr
2395-
if x.dtype == out_dtype:
2396-
z = x.copy()
2397-
else:
2398-
z = x.astype(out_dtype)
2399-
z_data = z.data
2400-
24012397
for j in range(0, N):
24022398
for i_idx in range(indptr[j], indptr[j + 1]):
24032399
i = indices[i_idx]
24042400
z_data[i_idx] *= y[i, j]
2405-
out[0] = z
24062401
elif x.format == "csr":
2407-
indices = x.indices
2408-
indptr = x.indptr
2409-
if x.dtype == out_dtype:
2410-
z = x.copy()
2411-
else:
2412-
z = x.astype(out_dtype)
2413-
z_data = z.data
2414-
24152402
for i in range(0, M):
24162403
for j_idx in range(indptr[i], indptr[i + 1]):
24172404
j = indices[j_idx]
24182405
z_data[j_idx] *= y[i, j]
2419-
out[0] = z
2420-
else:
2421-
warn(
2422-
"This implementation of MulSD is deficient: {x.format}",
2423-
)
2424-
out[0] = type(x)(x.toarray() * y)
24252406

24262407
def grad(self, inputs, gout):
24272408
(x, y) = inputs
@@ -2434,12 +2415,14 @@ def infer_shape(self, fgraph, node, shapes):
24342415
return [shapes[0]]
24352416

24362417

2437-
mul_s_d = MulSD()
2418+
mul_s_d = SparseDenseMultiply()
24382419

24392420

2440-
class MulSV(Op):
2421+
class SparseDenseVectorMultiply(Op):
24412422
"""Element-wise multiplication of sparse matrix by a broadcasted dense vector element wise.
24422423
2424+
TODO: Merge with the SparseDenseMultiply Op
2425+
24432426
Notes
24442427
-----
24452428
The grad implemented is regular, i.e. not structured.
@@ -2500,7 +2483,7 @@ def infer_shape(self, fgraph, node, ins_shapes):
25002483
return [ins_shapes[0]]
25012484

25022485

2503-
mul_s_v = MulSV()
2486+
mul_s_v = SparseDenseVectorMultiply()
25042487

25052488

25062489
def mul(x, y):
@@ -2539,16 +2522,17 @@ def mul(x, y):
25392522
# mul_s_s is not implemented if the types differ
25402523
if y.dtype == "float64" and x.dtype == "float32":
25412524
x = x.astype("float64")
2542-
25432525
return mul_s_s(x, y)
2544-
elif x_is_sparse_variable and not y_is_sparse_variable:
2526+
elif x_is_sparse_variable or y_is_sparse_variable:
2527+
if y_is_sparse_variable:
2528+
x, y = y, x
25452529
# mul is unimplemented if the dtypes differ
25462530
if y.dtype == "float64" and x.dtype == "float32":
25472531
x = x.astype("float64")
2548-
2549-
return mul_s_d(x, y)
2550-
elif y_is_sparse_variable and not x_is_sparse_variable:
2551-
return mul_s_d(y, x)
2532+
if y.ndim == 1:
2533+
return mul_s_v(x, y)
2534+
else:
2535+
return mul_s_d(x, y)
25522536
else:
25532537
raise NotImplementedError()
25542538

tests/sparse/test_basic.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,12 @@
3232
EnsureSortedIndices,
3333
GetItemScalar,
3434
HStack,
35-
MulSD,
36-
MulSS,
3735
Neg,
3836
Remove0,
3937
SamplingDot,
38+
SparseDenseMultiply,
4039
SparseFromDense,
40+
SparseSparseMultiply,
4141
SparseTensorType,
4242
SquareDiagonal,
4343
StructuredDot,
@@ -514,7 +514,7 @@ def test_mul_ss(self):
514514
sp.sparse.csr_matrix(random_lil((10, 40), config.floatX, 3)),
515515
]
516516
* 2,
517-
MulSS,
517+
SparseSparseMultiply,
518518
)
519519

520520
def test_mul_sd(self):
@@ -527,7 +527,7 @@ def test_mul_sd(self):
527527
sp.sparse.csr_matrix(random_lil((10, 40), config.floatX, 3)),
528528
np.random.standard_normal((10, 40)).astype(config.floatX),
529529
],
530-
MulSD,
530+
SparseDenseMultiply,
531531
excluding=["local_mul_s_d"],
532532
)
533533

tests/sparse/test_rewriting.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ def test_local_mul_s_d():
7777
f = pytensor.function(inputs, sparse.mul_s_d(*inputs), mode=mode)
7878

7979
assert not any(
80-
isinstance(node.op, sparse.MulSD) for node in f.maker.fgraph.toposort()
80+
isinstance(node.op, sparse.SparseDenseMultiply)
81+
for node in f.maker.fgraph.toposort()
8182
)
8283

8384

@@ -94,7 +95,8 @@ def test_local_mul_s_v():
9495
f = pytensor.function(inputs, sparse.mul_s_v(*inputs), mode=mode)
9596

9697
assert not any(
97-
isinstance(node.op, sparse.MulSV) for node in f.maker.fgraph.toposort()
98+
isinstance(node.op, sparse.SparseDenseVectorMultiply)
99+
for node in f.maker.fgraph.toposort()
98100
)
99101

100102

0 commit comments

Comments
 (0)