@@ -2310,7 +2310,7 @@ def sub(x, y):
     return x + (-y)
 
 
-class MulSS(Op):
+class SparseSparseMultiply(Op):
     # mul(sparse, sparse)
     # See the doc of mul() for more detail
     __props__ = ()
@@ -2343,12 +2343,15 @@ def infer_shape(self, fgraph, node, shapes):
         return [shapes[0]]
 
 
-mul_s_s = MulSS()
+mul_s_s = SparseSparseMultiply()
 
 
-class MulSD(Op):
+class SparseDenseMultiply(Op):
     # mul(sparse, dense)
     # See the doc of mul() for more detail
+
+    # We make a needless copy of indices and indptr here; they should be reused.
+    # However, PyTensor doesn't support one output -> multiple views...
     __props__ = ()
 
     def make_node(self, x, y):
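The new comment above refers to the fact that only the output's data buffer really needs to be fresh; the indices/indptr structure arrays could in principle be shared with the input. A minimal SciPy-only sketch of what such reuse looks like (illustrative only, not part of this change):

    import numpy as np
    import scipy.sparse as sp

    x = sp.random(4, 5, density=0.3, format="csr")

    # Build a result that shares `indices`/`indptr` with `x`; only `data` is new.
    z = sp.csr_matrix((x.data * 2.0, x.indices, x.indptr), shape=x.shape, copy=False)
    assert np.shares_memory(z.indices, x.indices)
    assert np.shares_memory(z.indptr, x.indptr)
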
@@ -2364,64 +2367,42 @@ def make_node(self, x, y):
         # Broadcasting of the sparse matrix is not supported.
         # We support nd == 0 used by grad of SpSum()
         assert y.type.ndim in (0, 2)
-        out = SparseTensorType(dtype=dtype, format=x.type.format)()
+        out = SparseTensorType(dtype=dtype, format=x.type.format, shape=x.type.shape)()
         return Apply(self, [x, y], [out])
 
     def perform(self, node, inputs, outputs):
         (x, y) = inputs
         (out,) = outputs
+        out_dtype = node.outputs[0].dtype
         assert _is_sparse(x) and _is_dense(y)
-        if len(y.shape) == 0:
-            out_dtype = node.outputs[0].dtype
-            if x.dtype == out_dtype:
-                z = x.copy()
-            else:
-                z = x.astype(out_dtype)
-            out[0] = z
-            out[0].data *= y
-        elif len(y.shape) == 1:
-            raise NotImplementedError()  # RowScale / ColScale
-        elif len(y.shape) == 2:
+
+        if x.dtype == out_dtype:
+            z = x.copy()
+        else:
+            z = x.astype(out_dtype)
+        out[0] = z
+        z_data = z.data
+
+        if y.ndim == 0:
+            z_data *= y
+        else:  # y.ndim == 2
             # if we have enough memory to fit y, maybe we can fit x.asarray()
             # too?
             # TODO: change runtime from O(M*N) to O(nonzeros)
             M, N = x.shape
             assert x.shape == y.shape
-            out_dtype = node.outputs[0].dtype
-
+            indices = x.indices
+            indptr = x.indptr
             if x.format == "csc":
-                indices = x.indices
-                indptr = x.indptr
-                if x.dtype == out_dtype:
-                    z = x.copy()
-                else:
-                    z = x.astype(out_dtype)
-                z_data = z.data
-
                 for j in range(0, N):
                     for i_idx in range(indptr[j], indptr[j + 1]):
                         i = indices[i_idx]
                         z_data[i_idx] *= y[i, j]
-                out[0] = z
             elif x.format == "csr":
-                indices = x.indices
-                indptr = x.indptr
-                if x.dtype == out_dtype:
-                    z = x.copy()
-                else:
-                    z = x.astype(out_dtype)
-                z_data = z.data
-
                 for i in range(0, M):
                     for j_idx in range(indptr[i], indptr[i + 1]):
                         j = indices[j_idx]
                         z_data[j_idx] *= y[i, j]
-                out[0] = z
-            else:
-                warn(
-                    "This implementation of MulSD is deficient: {x.format}",
-                )
-                out[0] = type(x)(x.toarray() * y)
 
     def grad(self, inputs, gout):
         (x, y) = inputs
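For reference, the CSR branch of the rewritten perform() can be exercised in isolation with plain SciPy. This sketch (illustrative, not the Op itself) runs the same loop over the stored entries and checks it against the dense product:

    import numpy as np
    import scipy.sparse as sp

    x = sp.random(3, 4, density=0.5, format="csr", dtype=np.float64)
    y = np.arange(12.0).reshape(3, 4)

    z = x.copy()
    z_data, indices, indptr = z.data, z.indices, z.indptr
    M, N = x.shape
    for i in range(M):                                  # each CSR row
        for j_idx in range(indptr[i], indptr[i + 1]):   # stored entries of row i
            j = indices[j_idx]                          # column of that entry
            z_data[j_idx] *= y[i, j]                    # scale by the dense value

    np.testing.assert_allclose(z.toarray(), x.toarray() * y)
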
@@ -2434,12 +2415,14 @@ def infer_shape(self, fgraph, node, shapes):
         return [shapes[0]]
 
 
-mul_s_d = MulSD()
+mul_s_d = SparseDenseMultiply()
 
 
-class MulSV(Op):
+class SparseDenseVectorMultiply(Op):
     """Element-wise multiplication of sparse matrix by a broadcasted dense vector element wise.
 
+    TODO: Merge with the SparseDenseMultiply Op
+
     Notes
     -----
     The grad implemented is regular, i.e. not structured.
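In NumPy terms, the "broadcasted dense vector" is a 1-D operand whose length matches the number of columns, so it scales each row of the sparse matrix elementwise. A small SciPy illustration of the intended result (not the Op's implementation):

    import numpy as np
    import scipy.sparse as sp

    x = sp.random(3, 4, density=0.5, format="csc")
    v = np.array([1.0, 2.0, 3.0, 4.0])   # length == number of columns

    # Standard NumPy broadcasting: v multiplies every row elementwise.
    expected = x.toarray() * v
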
@@ -2500,7 +2483,7 @@ def infer_shape(self, fgraph, node, ins_shapes):
         return [ins_shapes[0]]
 
 
-mul_s_v = MulSV()
+mul_s_v = SparseDenseVectorMultiply()
 
 
 def mul(x, y):
@@ -2539,16 +2522,17 @@ def mul(x, y):
         # mul_s_s is not implemented if the types differ
         if y.dtype == "float64" and x.dtype == "float32":
             x = x.astype("float64")
-
         return mul_s_s(x, y)
-    elif x_is_sparse_variable and not y_is_sparse_variable:
+    elif x_is_sparse_variable or y_is_sparse_variable:
+        if y_is_sparse_variable:
+            x, y = y, x
         # mul is unimplemented if the dtypes differ
         if y.dtype == "float64" and x.dtype == "float32":
             x = x.astype("float64")
-
-        return mul_s_d(x, y)
-    elif y_is_sparse_variable and not x_is_sparse_variable:
-        return mul_s_d(y, x)
+        if y.ndim == 1:
+            return mul_s_v(x, y)
+        else:
+            return mul_s_d(x, y)
     else:
         raise NotImplementedError()
 
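A rough usage sketch of the reworked dispatch, assuming the renamed Ops above and the existing pytensor.sparse helpers (mul, csr_matrix): argument order no longer matters for sparse-by-dense, and a 1-D dense operand now routes to mul_s_v instead of raising:

    import pytensor.tensor as pt
    from pytensor import sparse

    x = sparse.csr_matrix("x", dtype="float64")   # sparse operand
    m = pt.matrix("m", dtype="float64")           # dense 2-D operand
    v = pt.vector("v", dtype="float64")           # dense 1-D operand

    s_d = sparse.mul(x, m)   # -> mul_s_d (SparseDenseMultiply)
    d_s = sparse.mul(m, x)   # same Op: arguments are swapped inside mul()
    s_v = sparse.mul(x, v)   # 1-D dense operand -> mul_s_v (SparseDenseVectorMultiply)

    print(s_d.owner.op, d_s.owner.op, s_v.owner.op)
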