Skip to content

Commit 913e012

Browse files
committed
Implement basic sparse Ops in Numba
1 parent 37a240e commit 913e012

File tree

3 files changed

+316
-43
lines changed

3 files changed

+316
-43
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,12 @@
1515
from pytensor.graph.fg import FunctionGraph
1616
from pytensor.graph.type import Type
1717
from pytensor.ifelse import IfElse
18-
from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
1918
from pytensor.link.utils import (
2019
fgraph_to_python,
2120
)
2221
from pytensor.scalar.basic import ScalarType
2322
from pytensor.sparse import SparseTensorType
24-
from pytensor.tensor.type import TensorType
23+
from pytensor.tensor.type import DenseTensorType, TensorType
2524

2625

2726
def numba_njit(*args, fastmath=None, **kwargs):
@@ -81,7 +80,7 @@ def get_numba_type(
8180
Return Numba scalars for zero dimensional :class:`TensorType`\s.
8281
"""
8382

84-
if isinstance(pytensor_type, TensorType):
83+
if isinstance(pytensor_type, DenseTensorType):
8584
dtype = pytensor_type.numpy_dtype
8685
numba_dtype = numba.from_dtype(dtype)
8786
if force_scalar or (
@@ -94,12 +93,14 @@ def get_numba_type(
9493
numba_dtype = numba.from_dtype(dtype)
9594
return numba_dtype
9695
elif isinstance(pytensor_type, SparseTensorType):
96+
from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
97+
9798
dtype = pytensor_type.numpy_dtype
98-
numba_dtype = numba.from_dtype(dtype)
99+
# numba_dtype = numba.from_dtype(dtype)
99100
if pytensor_type.format == "csr":
100-
return CSRMatrixType(numba_dtype)
101+
return CSRMatrixType()
101102
if pytensor_type.format == "csc":
102-
return CSCMatrixType(numba_dtype)
103+
return CSCMatrixType()
103104

104105
raise NotImplementedError()
105106
else:
@@ -339,6 +340,7 @@ def identity(x):
339340

340341
@numba_funcify.register(DeepCopyOp)
341342
def numba_funcify_DeepCopyOp(op, node, **kwargs):
343+
# FIXME: SparseTensorType will match on this condition, but `np.copy` doesn't work with them
342344
if isinstance(node.inputs[0].type, TensorType):
343345

344346
@numba_njit

pytensor/link/numba/dispatch/sparse.py

Lines changed: 229 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,14 @@
1717
unbox,
1818
)
1919

20+
from pytensor.link.numba.dispatch.basic import numba_funcify, numba_njit
21+
from pytensor.sparse import (
22+
CSM,
23+
CSMProperties,
24+
SparseDenseMultiply,
25+
SparseDenseVectorMultiply,
26+
)
27+
2028

2129
class CSMatrixType(types.Type):
2230
"""A Numba `Type` modeled after the base class `scipy.sparse.compressed._cs_matrix`."""
@@ -27,9 +35,12 @@ class CSMatrixType(types.Type):
2735
def instance_class(data, indices, indptr, shape):
2836
raise NotImplementedError()
2937

30-
def __init__(self, dtype):
31-
self.dtype = dtype
32-
self.data = types.Array(dtype, 1, "A")
38+
def __init__(self):
39+
# TODO: Accept dtype again
40+
# Actually accept data type, so that in can have a layout other than "A"
41+
self.dtype = types.float64
42+
# TODO: Most times data/indices/indptr are C-contiguous, allow setting those
43+
self.data = types.Array(self.dtype, 1, "A")
3344
self.indices = types.Array(types.int32, 1, "A")
3445
self.indptr = types.Array(types.int32, 1, "A")
3546
self.shape = types.UniTuple(types.int64, 2)
@@ -64,14 +75,14 @@ def instance_class(data, indices, indptr, shape):
6475

6576
@typeof_impl.register(sp.sparse.csc_matrix)
6677
def typeof_csc_matrix(val, c):
67-
data = typeof_impl(val.data, c)
68-
return CSCMatrixType(data.dtype)
78+
# data = typeof_impl(val.data, c)
79+
return CSCMatrixType()
6980

7081

7182
@typeof_impl.register(sp.sparse.csr_matrix)
7283
def typeof_csr_matrix(val, c):
73-
data = typeof_impl(val.data, c)
74-
return CSRMatrixType(data.dtype)
84+
# data = typeof_impl(val.data, c)
85+
return CSRMatrixType()
7586

7687

7788
@register_model(CSRMatrixType)
@@ -136,6 +147,7 @@ def box_matrix(typ, val, c):
136147
indptr_obj = c.box(typ.indptr, struct_ptr.indptr)
137148
shape_obj = c.box(typ.shape, struct_ptr.shape)
138149

150+
# Why incref here, just to decref later?
139151
c.pyapi.incref(data_obj)
140152
c.pyapi.incref(indices_obj)
141153
c.pyapi.incref(indptr_obj)
@@ -154,53 +166,233 @@ def box_matrix(typ, val, c):
154166
return obj
155167

156168

169+
def _intrinsic_cs_codegen(context, builder, sig, args):
170+
matrix_type = sig.return_type
171+
struct = cgutils.create_struct_proxy(matrix_type)(context, builder)
172+
data, indices, indptr, shape = args
173+
struct.data = data
174+
struct.indices = indices
175+
struct.indptr = indptr
176+
struct.shape = shape
177+
# TODO: Check why do we use use impl_ret_borrowed, whereas numba numpy array uses impl_ret_new_ref
178+
# Is it because we create a struct_proxy. What is that even?
179+
return impl_ret_borrowed(
180+
context,
181+
builder,
182+
matrix_type,
183+
struct._getvalue(),
184+
)
185+
186+
187+
@intrinsic
188+
def csr_matrix_from_components(typingctx, data, indices, indptr, shape):
189+
# TODO: put dtype back in
190+
sig = CSRMatrixType()(data, indices, indptr, shape)
191+
return sig, _intrinsic_cs_codegen
192+
193+
194+
@intrinsic
195+
def csc_matrix_from_components(typingctx, data, indices, indptr, shape):
196+
sig = CSCMatrixType()(data, indices, indptr, shape)
197+
return sig, _intrinsic_cs_codegen
198+
199+
200+
@overload(sp.sparse.csr_matrix)
201+
def overload_csr_matrix(arg1, shape, dtype=None):
202+
if not isinstance(arg1, types.Tuple) or len(arg1) != 3:
203+
return None
204+
if isinstance(shape, types.NoneType):
205+
return None
206+
207+
def impl(arg1, shape, dtype=None):
208+
data, indices, indptr = arg1
209+
return csr_matrix_from_components(data, indices, indptr, shape)
210+
211+
return impl
212+
213+
214+
@overload(sp.sparse.csc_matrix)
215+
def overload_csc_matrix(arg1, shape, dtype=None):
216+
if not isinstance(arg1, types.Tuple) or len(arg1) != 3:
217+
return None
218+
if isinstance(shape, types.NoneType):
219+
return None
220+
221+
def impl(arg1, shape, dtype=None):
222+
data, indices, indptr = arg1
223+
return csc_matrix_from_components(data, indices, indptr, shape)
224+
225+
return impl
226+
227+
157228
@overload(np.shape)
158229
def overload_sparse_shape(x):
159230
if isinstance(x, CSMatrixType):
160231
return lambda x: x.shape
161232

162233

163234
@overload_attribute(CSMatrixType, "ndim")
164-
def overload_sparse_ndim(inst):
165-
if not isinstance(inst, CSMatrixType):
235+
def overload_sparse_ndim(matrix):
236+
if not isinstance(matrix, CSMatrixType):
166237
return
167238

168-
def ndim(inst):
239+
def ndim(matrix):
169240
return 2
170241

171242
return ndim
172243

173244

174-
@intrinsic
175-
def _sparse_copy(typingctx, inst, data, indices, indptr, shape):
176-
def _construct(context, builder, sig, args):
177-
typ = sig.return_type
178-
struct = cgutils.create_struct_proxy(typ)(context, builder)
179-
_, data, indices, indptr, shape = args
180-
struct.data = data
181-
struct.indices = indices
182-
struct.indptr = indptr
183-
struct.shape = shape
184-
return impl_ret_borrowed(
185-
context,
186-
builder,
187-
sig.return_type,
188-
struct._getvalue(),
245+
@overload_method(CSMatrixType, "copy")
246+
def overload_sparse_copy(matrix):
247+
match matrix:
248+
case CSRMatrixType():
249+
builder = csr_matrix_from_components
250+
case CSCMatrixType():
251+
builder = csc_matrix_from_components
252+
case _:
253+
return
254+
255+
def copy(matrix):
256+
return builder(
257+
matrix.data.copy(),
258+
matrix.indices.copy(),
259+
matrix.indptr.copy(),
260+
matrix.shape,
189261
)
190262

191-
sig = inst(inst, inst.data, inst.indices, inst.indptr, inst.shape)
192-
193-
return sig, _construct
194-
263+
return copy
195264

196-
@overload_method(CSMatrixType, "copy")
197-
def overload_sparse_copy(inst):
198-
if not isinstance(inst, CSMatrixType):
199-
return
200265

201-
def copy(inst):
202-
return _sparse_copy(
203-
inst, inst.data.copy(), inst.indices.copy(), inst.indptr.copy(), inst.shape
266+
@overload_method(CSMatrixType, "astype")
267+
def overload_sparse_astype(matrix, dtype):
268+
match matrix:
269+
case CSRMatrixType():
270+
builder = csr_matrix_from_components
271+
case CSCMatrixType():
272+
builder = csc_matrix_from_components
273+
case _:
274+
return
275+
276+
def astype(matrix, dtype):
277+
return builder(
278+
matrix.data.astype(dtype),
279+
matrix.indices.copy(),
280+
matrix.indptr.copy(),
281+
matrix.shape,
204282
)
205283

206-
return copy
284+
return astype
285+
286+
287+
@numba_funcify.register(CSMProperties)
288+
def numba_funcify_CSMProperties(op, **kwargs):
289+
@numba_njit
290+
def csm_properties(x):
291+
# Reconsider this int32/int64. Scipy/base PyTensor use int32 for indices/indptr.
292+
# But this seems to be legacy mistake and devs would choose int64 nowadays, and may move there.
293+
return x.data, x.indices, x.indptr, np.asarray(x.shape, dtype="int64")
294+
295+
return csm_properties
296+
297+
298+
@numba_funcify.register(CSM)
299+
def numba_funcify_CSM(op, **kwargs):
300+
format = op.format
301+
302+
@numba_njit
303+
def csm_constructor(data, indices, indptr, shape):
304+
constructor_arg = (data, indices, indptr)
305+
shape_arg = (shape[0], shape[1])
306+
if format == "csr":
307+
return sp.sparse.csr_matrix(constructor_arg, shape=shape_arg)
308+
else:
309+
return sp.sparse.csc_matrix(constructor_arg, shape=shape_arg)
310+
311+
return csm_constructor
312+
313+
314+
@numba_funcify.register(SparseDenseMultiply)
315+
@numba_funcify.register(SparseDenseVectorMultiply)
316+
def numba_funcify_SparseDenseMultiply(op, node, **kwargs):
317+
x, y = node.inputs
318+
[z] = node.outputs
319+
out_dtype = z.type.dtype
320+
format = z.type.format
321+
same_dtype = x.type.dtype == out_dtype
322+
323+
if y.ndim == 0:
324+
325+
@numba_njit
326+
def sparse_multiply_scalar(x, y):
327+
if same_dtype:
328+
z = x.copy()
329+
else:
330+
z = x.astype(out_dtype)
331+
# Numba doesn't know how to handle in-place mutation / assignment of fields
332+
# z.data *= y
333+
z_data = z.data
334+
z_data *= y
335+
return z
336+
337+
return sparse_multiply_scalar
338+
339+
elif y.ndim == 1:
340+
341+
@numba_njit
342+
def sparse_dense_multiply(x, y):
343+
assert x.shape[1] == y.shape[0]
344+
if same_dtype:
345+
z = x.copy()
346+
else:
347+
z = x.astype(out_dtype)
348+
349+
M, N = x.shape
350+
indices = x.indices
351+
indptr = x.indptr
352+
z_data = z.data
353+
if format == "csc":
354+
for j in range(0, N):
355+
for i_idx in range(indptr[j], indptr[j + 1]):
356+
z_data[i_idx] *= y[j]
357+
return z
358+
359+
else:
360+
for i in range(0, M):
361+
for j_idx in range(indptr[i], indptr[i + 1]):
362+
j = indices[j_idx]
363+
z_data[j_idx] *= y[j]
364+
365+
return z
366+
367+
return sparse_dense_multiply
368+
369+
else: # y.ndim == 2
370+
371+
@numba_njit
372+
def sparse_dense_multiply(x, y):
373+
assert x.shape == y.shape
374+
if same_dtype:
375+
z = x.copy()
376+
else:
377+
z = x.astype(out_dtype)
378+
379+
M, N = x.shape
380+
indices = x.indices
381+
indptr = x.indptr
382+
z_data = z.data
383+
if format == "csc":
384+
for j in range(0, N):
385+
for i_idx in range(indptr[j], indptr[j + 1]):
386+
i = indices[i_idx]
387+
z_data[i_idx] *= y[i, j]
388+
return z
389+
390+
else:
391+
for i in range(0, M):
392+
for j_idx in range(indptr[i], indptr[i + 1]):
393+
j = indices[j_idx]
394+
z_data[j_idx] *= y[i, j]
395+
396+
return z
397+
398+
return sparse_dense_multiply

0 commit comments

Comments
 (0)