Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions pytensor/link/numba/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.type import Type
from pytensor.ifelse import IfElse
from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
from pytensor.link.utils import (
fgraph_to_python,
)
from pytensor.scalar.basic import ScalarType
from pytensor.sparse import SparseTensorType
from pytensor.tensor.type import TensorType
from pytensor.tensor.type import DenseTensorType, TensorType


def numba_njit(*args, fastmath=None, **kwargs):
Expand Down Expand Up @@ -81,7 +80,7 @@ def get_numba_type(
Return Numba scalars for zero dimensional :class:`TensorType`\s.
"""

if isinstance(pytensor_type, TensorType):
if isinstance(pytensor_type, DenseTensorType):
dtype = pytensor_type.numpy_dtype
numba_dtype = numba.from_dtype(dtype)
if force_scalar or (
Expand All @@ -94,12 +93,14 @@ def get_numba_type(
numba_dtype = numba.from_dtype(dtype)
return numba_dtype
elif isinstance(pytensor_type, SparseTensorType):
from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType

dtype = pytensor_type.numpy_dtype
numba_dtype = numba.from_dtype(dtype)
# numba_dtype = numba.from_dtype(dtype)
if pytensor_type.format == "csr":
return CSRMatrixType(numba_dtype)
return CSRMatrixType()
if pytensor_type.format == "csc":
return CSCMatrixType(numba_dtype)
return CSCMatrixType()

raise NotImplementedError()
else:
Expand Down Expand Up @@ -339,6 +340,7 @@ def identity(x):

@numba_funcify.register(DeepCopyOp)
def numba_funcify_DeepCopyOp(op, node, **kwargs):
# FIXME: SparseTensorType will match on this condition, but `np.copy` doesn't work with them
if isinstance(node.inputs[0].type, TensorType):

@numba_njit
Expand Down
266 changes: 229 additions & 37 deletions pytensor/link/numba/dispatch/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@
unbox,
)

from pytensor.link.numba.dispatch.basic import numba_funcify, numba_njit
from pytensor.sparse import (
CSM,
CSMProperties,
SparseDenseMultiply,
SparseDenseVectorMultiply,
)


class CSMatrixType(types.Type):
"""A Numba `Type` modeled after the base class `scipy.sparse.compressed._cs_matrix`."""
Expand All @@ -27,9 +35,12 @@ class CSMatrixType(types.Type):
def instance_class(data, indices, indptr, shape):
raise NotImplementedError()

def __init__(self, dtype):
self.dtype = dtype
self.data = types.Array(dtype, 1, "A")
def __init__(self):
# TODO: Accept dtype again
# Actually accept the data array type, so that it can have a layout other than "A"
self.dtype = types.float64
# TODO: Most times data/indices/indptr are C-contiguous, allow setting those
self.data = types.Array(self.dtype, 1, "A")
self.indices = types.Array(types.int32, 1, "A")
self.indptr = types.Array(types.int32, 1, "A")
self.shape = types.UniTuple(types.int64, 2)
Expand Down Expand Up @@ -64,14 +75,14 @@ def instance_class(data, indices, indptr, shape):

@typeof_impl.register(sp.sparse.csc_matrix)
def typeof_csc_matrix(val, c):
    """Map a SciPy ``csc_matrix`` value to its Numba type.

    Parameters
    ----------
    val
        The ``scipy.sparse.csc_matrix`` instance being typed.
    c
        The typeof context supplied by Numba (currently unused).

    Returns
    -------
    CSCMatrixType
        The Numba type used for CSC matrices.
    """
    # TODO: derive the dtype from `val.data` (via `typeof_impl(val.data, c)`)
    # again once `CSCMatrixType` accepts a dtype; its constructor currently
    # takes no arguments, so passing one here would fail.
    return CSCMatrixType()


@typeof_impl.register(sp.sparse.csr_matrix)
def typeof_csr_matrix(val, c):
    """Map a SciPy ``csr_matrix`` value to its Numba type.

    Parameters
    ----------
    val
        The ``scipy.sparse.csr_matrix`` instance being typed.
    c
        The typeof context supplied by Numba (currently unused).

    Returns
    -------
    CSRMatrixType
        The Numba type used for CSR matrices.
    """
    # TODO: derive the dtype from `val.data` (via `typeof_impl(val.data, c)`)
    # again once `CSRMatrixType` accepts a dtype; its constructor currently
    # takes no arguments, so passing one here would fail.
    return CSRMatrixType()


@register_model(CSRMatrixType)
Expand Down Expand Up @@ -136,6 +147,7 @@ def box_matrix(typ, val, c):
indptr_obj = c.box(typ.indptr, struct_ptr.indptr)
shape_obj = c.box(typ.shape, struct_ptr.shape)

# Why incref here, just to decref later?
c.pyapi.incref(data_obj)
c.pyapi.incref(indices_obj)
c.pyapi.incref(indptr_obj)
Expand All @@ -154,53 +166,233 @@ def box_matrix(typ, val, c):
return obj


def _intrinsic_cs_codegen(context, builder, sig, args):
    """Shared code generator for the ``*_matrix_from_components`` intrinsics.

    Fills a struct proxy of the signature's return type with the four
    components found in ``args`` (data, indices, indptr, shape) and returns
    the resulting struct value.
    """
    data, indices, indptr, shape = args
    matrix_type = sig.return_type

    proxy = cgutils.create_struct_proxy(matrix_type)(context, builder)
    components = (
        ("data", data),
        ("indices", indices),
        ("indptr", indptr),
        ("shape", shape),
    )
    for field_name, value in components:
        setattr(proxy, field_name, value)

    # TODO: Check why we use impl_ret_borrowed here, whereas numba's numpy
    # array constructors use impl_ret_new_ref. Is it because we go through
    # a struct proxy?
    # NOTE(review): presumably because the members are the caller's existing
    # arrays rather than fresh allocations -- confirm against the numba docs.
    matrix_value = proxy._getvalue()
    return impl_ret_borrowed(context, builder, matrix_type, matrix_value)


@intrinsic
def csr_matrix_from_components(typingctx, data, indices, indptr, shape):
    """Intrinsic that assembles a CSR matrix from its four components.

    Returns the ``(signature, codegen)`` pair expected from an intrinsic: the
    signature yields a ``CSRMatrixType`` from the given argument types, and
    the code generator is shared with the CSC variant.
    """
    # TODO: put dtype back in
    return_type = CSRMatrixType()
    signature = return_type(data, indices, indptr, shape)
    return signature, _intrinsic_cs_codegen


@intrinsic
def csc_matrix_from_components(typingctx, data, indices, indptr, shape):
    """Intrinsic that assembles a CSC matrix from its four components.

    Returns the ``(signature, codegen)`` pair expected from an intrinsic: the
    signature yields a ``CSCMatrixType`` from the given argument types, and
    the code generator is shared with the CSR variant.
    """
    return_type = CSCMatrixType()
    signature = return_type(data, indices, indptr, shape)
    return signature, _intrinsic_cs_codegen


@overload(sp.sparse.csr_matrix)
def overload_csr_matrix(arg1, shape, dtype=None):
    """Overload ``sp.sparse.csr_matrix((data, indices, indptr), shape)``.

    Only the three-element tuple form with an explicit shape is handled; for
    any other argument types ``None`` is returned so that no implementation
    is provided. ``dtype`` is accepted but not used by the implementation.
    """
    is_component_triple = isinstance(arg1, types.Tuple) and len(arg1) == 3
    if not is_component_triple or isinstance(shape, types.NoneType):
        return None

    def impl(arg1, shape, dtype=None):
        values, col_indices, row_pointers = arg1
        return csr_matrix_from_components(values, col_indices, row_pointers, shape)

    return impl


@overload(sp.sparse.csc_matrix)
def overload_csc_matrix(arg1, shape, dtype=None):
    """Overload ``sp.sparse.csc_matrix((data, indices, indptr), shape)``.

    Only the three-element tuple form with an explicit shape is handled; for
    any other argument types ``None`` is returned so that no implementation
    is provided. ``dtype`` is accepted but not used by the implementation.
    """
    is_component_triple = isinstance(arg1, types.Tuple) and len(arg1) == 3
    if not is_component_triple or isinstance(shape, types.NoneType):
        return None

    def impl(arg1, shape, dtype=None):
        values, row_indices, col_pointers = arg1
        return csc_matrix_from_components(values, row_indices, col_pointers, shape)

    return impl


@overload(np.shape)
def overload_sparse_shape(x):
    """Overload ``np.shape`` for compressed sparse matrix types.

    Returns an implementation that reads the matrix's ``shape`` member, or
    ``None`` when ``x`` is not a ``CSMatrixType``.
    """
    if not isinstance(x, CSMatrixType):
        return None

    def shape(x):
        return x.shape

    return shape


@overload_attribute(CSMatrixType, "ndim")
def overload_sparse_ndim(matrix):
    """Provide the ``ndim`` attribute for compressed sparse matrix types.

    Returns an implementation that always yields ``2``, or ``None`` when
    ``matrix`` is not a ``CSMatrixType``.
    """
    if not isinstance(matrix, CSMatrixType):
        return None

    def ndim(matrix):
        return 2

    return ndim


@intrinsic
def _sparse_copy(typingctx, inst, data, indices, indptr, shape):
def _construct(context, builder, sig, args):
typ = sig.return_type
struct = cgutils.create_struct_proxy(typ)(context, builder)
_, data, indices, indptr, shape = args
struct.data = data
struct.indices = indices
struct.indptr = indptr
struct.shape = shape
return impl_ret_borrowed(
context,
builder,
sig.return_type,
struct._getvalue(),
@overload_method(CSMatrixType, "copy")
def overload_sparse_copy(matrix):
match matrix:
case CSRMatrixType():
builder = csr_matrix_from_components
case CSCMatrixType():
builder = csc_matrix_from_components
case _:
return

def copy(matrix):
return builder(
matrix.data.copy(),
matrix.indices.copy(),
matrix.indptr.copy(),
matrix.shape,
)

sig = inst(inst, inst.data, inst.indices, inst.indptr, inst.shape)

return sig, _construct

return copy

@overload_method(CSMatrixType, "copy")
def overload_sparse_copy(inst):
if not isinstance(inst, CSMatrixType):
return

def copy(inst):
return _sparse_copy(
inst, inst.data.copy(), inst.indices.copy(), inst.indptr.copy(), inst.shape
@overload_method(CSMatrixType, "astype")
def overload_sparse_astype(matrix, dtype):
match matrix:
case CSRMatrixType():
builder = csr_matrix_from_components
case CSCMatrixType():
builder = csc_matrix_from_components
case _:
return

def astype(matrix, dtype):
return builder(
matrix.data.astype(dtype),
matrix.indices.copy(),
matrix.indptr.copy(),
matrix.shape,
)

return copy
return astype


@numba_funcify.register(CSMProperties)
def numba_funcify_CSMProperties(op, **kwargs):
    """Build the Numba implementation of ``CSMProperties``.

    The returned function splits a sparse matrix into its data, indices and
    indptr arrays plus its shape as an int64 array.
    """

    @numba_njit
    def csm_properties(x):
        # Reconsider this int32/int64 choice. SciPy and base PyTensor use int32
        # for indices/indptr, but that looks like a legacy decision; developers
        # would likely pick int64 nowadays and may move there.
        shape = np.asarray(x.shape, dtype="int64")
        return x.data, x.indices, x.indptr, shape

    return csm_properties


@numba_funcify.register(CSM)
def numba_funcify_CSM(op, **kwargs):
    """Build the Numba implementation of ``CSM``.

    The returned function assembles a CSR or CSC matrix, depending on
    ``op.format``, from data, indices, indptr and shape.
    """
    sparse_format = op.format

    @numba_njit
    def csm_constructor(data, indices, indptr, shape):
        components = (data, indices, indptr)
        # Index explicitly so the shape is handed over as a 2-tuple.
        n_rows, n_cols = shape[0], shape[1]
        if sparse_format == "csr":
            return sp.sparse.csr_matrix(components, shape=(n_rows, n_cols))
        return sp.sparse.csc_matrix(components, shape=(n_rows, n_cols))

    return csm_constructor


@numba_funcify.register(SparseDenseMultiply)
@numba_funcify.register(SparseDenseVectorMultiply)
def numba_funcify_SparseDenseMultiply(op, node, **kwargs):
    """Build the Numba implementation of sparse-times-dense multiplication.

    The first node input is the sparse matrix and the second the dense
    operand. Depending on the dense operand's number of dimensions, the
    returned function scales every stored value by a scalar (0-d), by the
    vector entry of its column (1-d), or by the matching matrix entry (2-d).
    The result is a fresh sparse matrix; the sparse input is not modified.
    """
    sparse_in, dense_in = node.inputs
    (out,) = node.outputs
    out_dtype = out.type.dtype
    is_csc = out.type.format == "csc"
    same_dtype = sparse_in.type.dtype == out_dtype
    dense_ndim = dense_in.ndim

    if dense_ndim == 0:

        @numba_njit
        def sparse_multiply_scalar(x, y):
            if same_dtype:
                z = x.copy()
            else:
                z = x.astype(out_dtype)
            # Numba cannot mutate a struct field in place (`z.data *= y`),
            # so go through a local alias of the data array instead.
            out_data = z.data
            out_data *= y
            return z

        return sparse_multiply_scalar

    if dense_ndim == 1:

        @numba_njit
        def sparse_dense_multiply(x, y):
            assert x.shape[1] == y.shape[0]
            if same_dtype:
                z = x.copy()
            else:
                z = x.astype(out_dtype)

            n_rows, n_cols = x.shape
            pointers = x.indptr
            out_data = z.data
            if is_csc:
                # Column `col` owns the slice pointers[col]:pointers[col + 1].
                for col in range(n_cols):
                    for k in range(pointers[col], pointers[col + 1]):
                        out_data[k] *= y[col]
            else:
                # Row-compressed: the column of each stored value is looked up.
                col_of = x.indices
                for row in range(n_rows):
                    for k in range(pointers[row], pointers[row + 1]):
                        out_data[k] *= y[col_of[k]]

            return z

        return sparse_dense_multiply

    # dense_ndim == 2

    @numba_njit
    def sparse_dense_multiply(x, y):
        assert x.shape == y.shape
        if same_dtype:
            z = x.copy()
        else:
            z = x.astype(out_dtype)

        n_rows, n_cols = x.shape
        minor_index = x.indices
        pointers = x.indptr
        out_data = z.data
        if is_csc:
            # Stored indices are row numbers; the outer loop walks columns.
            for col in range(n_cols):
                for k in range(pointers[col], pointers[col + 1]):
                    out_data[k] *= y[minor_index[k], col]
        else:
            # Stored indices are column numbers; the outer loop walks rows.
            for row in range(n_rows):
                for k in range(pointers[row], pointers[row + 1]):
                    out_data[k] *= y[row, minor_index[k]]

        return z

    return sparse_dense_multiply
Loading
Loading