1717 unbox ,
1818)
1919
20+ from pytensor .link .numba .dispatch .basic import numba_funcify , numba_njit
21+ from pytensor .sparse import (
22+ CSM ,
23+ CSMProperties ,
24+ SparseDenseMultiply ,
25+ SparseDenseVectorMultiply ,
26+ )
27+
2028
2129class CSMatrixType (types .Type ):
2230 """A Numba `Type` modeled after the base class `scipy.sparse.compressed._cs_matrix`."""
@@ -27,9 +35,12 @@ class CSMatrixType(types.Type):
2735 def instance_class (data , indices , indptr , shape ):
2836 raise NotImplementedError ()
2937
30- def __init__ (self , dtype ):
31- self .dtype = dtype
32- self .data = types .Array (dtype , 1 , "A" )
38+ def __init__ (self ):
39+ # TODO: Accept dtype again
40+ # Actually accept data type, so that in can have a layout other than "A"
41+ self .dtype = types .float64
42+ # TODO: Most times data/indices/indptr are C-contiguous, allow setting those
43+ self .data = types .Array (self .dtype , 1 , "A" )
3344 self .indices = types .Array (types .int32 , 1 , "A" )
3445 self .indptr = types .Array (types .int32 , 1 , "A" )
3546 self .shape = types .UniTuple (types .int64 , 2 )
@@ -64,14 +75,14 @@ def instance_class(data, indices, indptr, shape):
6475
6576@typeof_impl .register (sp .sparse .csc_matrix )
6677def typeof_csc_matrix (val , c ):
67- data = typeof_impl (val .data , c )
68- return CSCMatrixType (data . dtype )
78+ # data = typeof_impl(val.data, c)
79+ return CSCMatrixType ()
6980
7081
7182@typeof_impl .register (sp .sparse .csr_matrix )
7283def typeof_csr_matrix (val , c ):
73- data = typeof_impl (val .data , c )
74- return CSRMatrixType (data . dtype )
84+ # data = typeof_impl(val.data, c)
85+ return CSRMatrixType ()
7586
7687
7788@register_model (CSRMatrixType )
@@ -136,6 +147,7 @@ def box_matrix(typ, val, c):
136147 indptr_obj = c .box (typ .indptr , struct_ptr .indptr )
137148 shape_obj = c .box (typ .shape , struct_ptr .shape )
138149
150+ # Why incref here, just to decref later?
139151 c .pyapi .incref (data_obj )
140152 c .pyapi .incref (indices_obj )
141153 c .pyapi .incref (indptr_obj )
@@ -154,53 +166,233 @@ def box_matrix(typ, val, c):
154166 return obj
155167
156168
169+ def _intrinsic_cs_codegen (context , builder , sig , args ):
170+ matrix_type = sig .return_type
171+ struct = cgutils .create_struct_proxy (matrix_type )(context , builder )
172+ data , indices , indptr , shape = args
173+ struct .data = data
174+ struct .indices = indices
175+ struct .indptr = indptr
176+ struct .shape = shape
177+ # TODO: Check why do we use use impl_ret_borrowed, whereas numba numpy array uses impl_ret_new_ref
178+ # Is it because we create a struct_proxy. What is that even?
179+ return impl_ret_borrowed (
180+ context ,
181+ builder ,
182+ matrix_type ,
183+ struct ._getvalue (),
184+ )
185+
186+
187+ @intrinsic
188+ def csr_matrix_from_components (typingctx , data , indices , indptr , shape ):
189+ # TODO: put dtype back in
190+ sig = CSRMatrixType ()(data , indices , indptr , shape )
191+ return sig , _intrinsic_cs_codegen
192+
193+
194+ @intrinsic
195+ def csc_matrix_from_components (typingctx , data , indices , indptr , shape ):
196+ sig = CSCMatrixType ()(data , indices , indptr , shape )
197+ return sig , _intrinsic_cs_codegen
198+
199+
200+ @overload (sp .sparse .csr_matrix )
201+ def overload_csr_matrix (arg1 , shape , dtype = None ):
202+ if not isinstance (arg1 , types .Tuple ) or len (arg1 ) != 3 :
203+ return None
204+ if isinstance (shape , types .NoneType ):
205+ return None
206+
207+ def impl (arg1 , shape , dtype = None ):
208+ data , indices , indptr = arg1
209+ return csr_matrix_from_components (data , indices , indptr , shape )
210+
211+ return impl
212+
213+
214+ @overload (sp .sparse .csc_matrix )
215+ def overload_csc_matrix (arg1 , shape , dtype = None ):
216+ if not isinstance (arg1 , types .Tuple ) or len (arg1 ) != 3 :
217+ return None
218+ if isinstance (shape , types .NoneType ):
219+ return None
220+
221+ def impl (arg1 , shape , dtype = None ):
222+ data , indices , indptr = arg1
223+ return csc_matrix_from_components (data , indices , indptr , shape )
224+
225+ return impl
226+
227+
157228@overload (np .shape )
158229def overload_sparse_shape (x ):
159230 if isinstance (x , CSMatrixType ):
160231 return lambda x : x .shape
161232
162233
163234@overload_attribute (CSMatrixType , "ndim" )
164- def overload_sparse_ndim (inst ):
165- if not isinstance (inst , CSMatrixType ):
235+ def overload_sparse_ndim (matrix ):
236+ if not isinstance (matrix , CSMatrixType ):
166237 return
167238
168- def ndim (inst ):
239+ def ndim (matrix ):
169240 return 2
170241
171242 return ndim
172243
173244
174- @intrinsic
175- def _sparse_copy (typingctx , inst , data , indices , indptr , shape ):
176- def _construct (context , builder , sig , args ):
177- typ = sig .return_type
178- struct = cgutils .create_struct_proxy (typ )(context , builder )
179- _ , data , indices , indptr , shape = args
180- struct .data = data
181- struct .indices = indices
182- struct .indptr = indptr
183- struct .shape = shape
184- return impl_ret_borrowed (
185- context ,
186- builder ,
187- sig .return_type ,
188- struct ._getvalue (),
245+ @overload_method (CSMatrixType , "copy" )
246+ def overload_sparse_copy (matrix ):
247+ match matrix :
248+ case CSRMatrixType ():
249+ builder = csr_matrix_from_components
250+ case CSCMatrixType ():
251+ builder = csc_matrix_from_components
252+ case _:
253+ return
254+
255+ def copy (matrix ):
256+ return builder (
257+ matrix .data .copy (),
258+ matrix .indices .copy (),
259+ matrix .indptr .copy (),
260+ matrix .shape ,
189261 )
190262
191- sig = inst (inst , inst .data , inst .indices , inst .indptr , inst .shape )
192-
193- return sig , _construct
194-
263+ return copy
195264
196- @overload_method (CSMatrixType , "copy" )
197- def overload_sparse_copy (inst ):
198- if not isinstance (inst , CSMatrixType ):
199- return
200265
201- def copy (inst ):
202- return _sparse_copy (
203- inst , inst .data .copy (), inst .indices .copy (), inst .indptr .copy (), inst .shape
266+ @overload_method (CSMatrixType , "astype" )
267+ def overload_sparse_astype (matrix , dtype ):
268+ match matrix :
269+ case CSRMatrixType ():
270+ builder = csr_matrix_from_components
271+ case CSCMatrixType ():
272+ builder = csc_matrix_from_components
273+ case _:
274+ return
275+
276+ def astype (matrix , dtype ):
277+ return builder (
278+ matrix .data .astype (dtype ),
279+ matrix .indices .copy (),
280+ matrix .indptr .copy (),
281+ matrix .shape ,
204282 )
205283
206- return copy
284+ return astype
285+
286+
287+ @numba_funcify .register (CSMProperties )
288+ def numba_funcify_CSMProperties (op , ** kwargs ):
289+ @numba_njit
290+ def csm_properties (x ):
291+ # Reconsider this int32/int64. Scipy/base PyTensor use int32 for indices/indptr.
292+ # But this seems to be legacy mistake and devs would choose int64 nowadays, and may move there.
293+ return x .data , x .indices , x .indptr , np .asarray (x .shape , dtype = "int64" )
294+
295+ return csm_properties
296+
297+
298+ @numba_funcify .register (CSM )
299+ def numba_funcify_CSM (op , ** kwargs ):
300+ format = op .format
301+
302+ @numba_njit
303+ def csm_constructor (data , indices , indptr , shape ):
304+ constructor_arg = (data , indices , indptr )
305+ shape_arg = (shape [0 ], shape [1 ])
306+ if format == "csr" :
307+ return sp .sparse .csr_matrix (constructor_arg , shape = shape_arg )
308+ else :
309+ return sp .sparse .csc_matrix (constructor_arg , shape = shape_arg )
310+
311+ return csm_constructor
312+
313+
314+ @numba_funcify .register (SparseDenseMultiply )
315+ @numba_funcify .register (SparseDenseVectorMultiply )
316+ def numba_funcify_SparseDenseMultiply (op , node , ** kwargs ):
317+ x , y = node .inputs
318+ [z ] = node .outputs
319+ out_dtype = z .type .dtype
320+ format = z .type .format
321+ same_dtype = x .type .dtype == out_dtype
322+
323+ if y .ndim == 0 :
324+
325+ @numba_njit
326+ def sparse_multiply_scalar (x , y ):
327+ if same_dtype :
328+ z = x .copy ()
329+ else :
330+ z = x .astype (out_dtype )
331+ # Numba doesn't know how to handle in-place mutation / assignment of fields
332+ # z.data *= y
333+ z_data = z .data
334+ z_data *= y
335+ return z
336+
337+ return sparse_multiply_scalar
338+
339+ elif y .ndim == 1 :
340+
341+ @numba_njit
342+ def sparse_dense_multiply (x , y ):
343+ assert x .shape [1 ] == y .shape [0 ]
344+ if same_dtype :
345+ z = x .copy ()
346+ else :
347+ z = x .astype (out_dtype )
348+
349+ M , N = x .shape
350+ indices = x .indices
351+ indptr = x .indptr
352+ z_data = z .data
353+ if format == "csc" :
354+ for j in range (0 , N ):
355+ for i_idx in range (indptr [j ], indptr [j + 1 ]):
356+ z_data [i_idx ] *= y [j ]
357+ return z
358+
359+ else :
360+ for i in range (0 , M ):
361+ for j_idx in range (indptr [i ], indptr [i + 1 ]):
362+ j = indices [j_idx ]
363+ z_data [j_idx ] *= y [j ]
364+
365+ return z
366+
367+ return sparse_dense_multiply
368+
369+ else : # y.ndim == 2
370+
371+ @numba_njit
372+ def sparse_dense_multiply (x , y ):
373+ assert x .shape == y .shape
374+ if same_dtype :
375+ z = x .copy ()
376+ else :
377+ z = x .astype (out_dtype )
378+
379+ M , N = x .shape
380+ indices = x .indices
381+ indptr = x .indptr
382+ z_data = z .data
383+ if format == "csc" :
384+ for j in range (0 , N ):
385+ for i_idx in range (indptr [j ], indptr [j + 1 ]):
386+ i = indices [i_idx ]
387+ z_data [i_idx ] *= y [i , j ]
388+ return z
389+
390+ else :
391+ for i in range (0 , M ):
392+ for j_idx in range (indptr [i ], indptr [i + 1 ]):
393+ j = indices [j_idx ]
394+ z_data [j_idx ] *= y [i , j ]
395+
396+ return z
397+
398+ return sparse_dense_multiply
0 commit comments