diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 0d4217a786..4a8404550e 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -15,13 +15,12 @@ from pytensor.graph.fg import FunctionGraph from pytensor.graph.type import Type from pytensor.ifelse import IfElse -from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType from pytensor.link.utils import ( fgraph_to_python, ) from pytensor.scalar.basic import ScalarType from pytensor.sparse import SparseTensorType -from pytensor.tensor.type import TensorType +from pytensor.tensor.type import DenseTensorType, TensorType def numba_njit(*args, fastmath=None, **kwargs): @@ -81,7 +80,7 @@ def get_numba_type( Return Numba scalars for zero dimensional :class:`TensorType`\s. """ - if isinstance(pytensor_type, TensorType): + if isinstance(pytensor_type, DenseTensorType): dtype = pytensor_type.numpy_dtype numba_dtype = numba.from_dtype(dtype) if force_scalar or ( @@ -94,12 +93,14 @@ def get_numba_type( numba_dtype = numba.from_dtype(dtype) return numba_dtype elif isinstance(pytensor_type, SparseTensorType): + from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType + dtype = pytensor_type.numpy_dtype - numba_dtype = numba.from_dtype(dtype) + # numba_dtype = numba.from_dtype(dtype) if pytensor_type.format == "csr": - return CSRMatrixType(numba_dtype) + return CSRMatrixType() if pytensor_type.format == "csc": - return CSCMatrixType(numba_dtype) + return CSCMatrixType() raise NotImplementedError() else: @@ -339,6 +340,7 @@ def identity(x): @numba_funcify.register(DeepCopyOp) def numba_funcify_DeepCopyOp(op, node, **kwargs): + # FIXME: SparseTensorType will match on this condition, but `np.copy` doesn't work with them if isinstance(node.inputs[0].type, TensorType): @numba_njit diff --git a/pytensor/link/numba/dispatch/sparse.py b/pytensor/link/numba/dispatch/sparse.py index e25083e92d..992f3f2084 100644 --- a/pytensor/link/numba/dispatch/sparse.py +++ b/pytensor/link/numba/dispatch/sparse.py @@ -17,6 +17,14 @@ unbox, ) +from pytensor.link.numba.dispatch.basic import numba_funcify, numba_njit +from pytensor.sparse import ( + CSM, + CSMProperties, + SparseDenseMultiply, + SparseDenseVectorMultiply, +) + class CSMatrixType(types.Type): """A Numba `Type` modeled after the base class `scipy.sparse.compressed._cs_matrix`.""" @@ -27,9 +35,12 @@ class CSMatrixType(types.Type): def instance_class(data, indices, indptr, shape): raise NotImplementedError() - def __init__(self, dtype): - self.dtype = dtype - self.data = types.Array(dtype, 1, "A") + def __init__(self): + # TODO: Accept dtype again + # Actually accept data type, so that in can have a layout other than "A" + self.dtype = types.float64 + # TODO: Most times data/indices/indptr are C-contiguous, allow setting those + self.data = types.Array(self.dtype, 1, "A") self.indices = types.Array(types.int32, 1, "A") self.indptr = types.Array(types.int32, 1, "A") self.shape = types.UniTuple(types.int64, 2) @@ -64,14 +75,14 @@ def instance_class(data, indices, indptr, shape): @typeof_impl.register(sp.sparse.csc_matrix) def typeof_csc_matrix(val, c): - data = typeof_impl(val.data, c) - return CSCMatrixType(data.dtype) + # data = typeof_impl(val.data, c) + return CSCMatrixType() @typeof_impl.register(sp.sparse.csr_matrix) def typeof_csr_matrix(val, c): - data = typeof_impl(val.data, c) - return CSRMatrixType(data.dtype) + # data = typeof_impl(val.data, c) + return CSRMatrixType() @register_model(CSRMatrixType) @@ -136,6 +147,7 @@ def box_matrix(typ, val, c): indptr_obj = c.box(typ.indptr, struct_ptr.indptr) shape_obj = c.box(typ.shape, struct_ptr.shape) + # Why incref here, just to decref later? c.pyapi.incref(data_obj) c.pyapi.incref(indices_obj) c.pyapi.incref(indptr_obj) @@ -154,6 +166,65 @@ def box_matrix(typ, val, c): return obj +def _intrinsic_cs_codegen(context, builder, sig, args): + matrix_type = sig.return_type + struct = cgutils.create_struct_proxy(matrix_type)(context, builder) + data, indices, indptr, shape = args + struct.data = data + struct.indices = indices + struct.indptr = indptr + struct.shape = shape + # TODO: Check why do we use use impl_ret_borrowed, whereas numba numpy array uses impl_ret_new_ref + # Is it because we create a struct_proxy. What is that even? + return impl_ret_borrowed( + context, + builder, + matrix_type, + struct._getvalue(), + ) + + +@intrinsic +def csr_matrix_from_components(typingctx, data, indices, indptr, shape): + # TODO: put dtype back in + sig = CSRMatrixType()(data, indices, indptr, shape) + return sig, _intrinsic_cs_codegen + + +@intrinsic +def csc_matrix_from_components(typingctx, data, indices, indptr, shape): + sig = CSCMatrixType()(data, indices, indptr, shape) + return sig, _intrinsic_cs_codegen + + +@overload(sp.sparse.csr_matrix) +def overload_csr_matrix(arg1, shape, dtype=None): + if not isinstance(arg1, types.Tuple) or len(arg1) != 3: + return None + if isinstance(shape, types.NoneType): + return None + + def impl(arg1, shape, dtype=None): + data, indices, indptr = arg1 + return csr_matrix_from_components(data, indices, indptr, shape) + + return impl + + +@overload(sp.sparse.csc_matrix) +def overload_csc_matrix(arg1, shape, dtype=None): + if not isinstance(arg1, types.Tuple) or len(arg1) != 3: + return None + if isinstance(shape, types.NoneType): + return None + + def impl(arg1, shape, dtype=None): + data, indices, indptr = arg1 + return csc_matrix_from_components(data, indices, indptr, shape) + + return impl + + @overload(np.shape) def overload_sparse_shape(x): if isinstance(x, CSMatrixType): @@ -161,46 +232,167 @@ def overload_sparse_shape(x): @overload_attribute(CSMatrixType, "ndim") -def overload_sparse_ndim(inst): - if not isinstance(inst, CSMatrixType): +def overload_sparse_ndim(matrix): + if not isinstance(matrix, CSMatrixType): return - def ndim(inst): + def ndim(matrix): return 2 return ndim -@intrinsic -def _sparse_copy(typingctx, inst, data, indices, indptr, shape): - def _construct(context, builder, sig, args): - typ = sig.return_type - struct = cgutils.create_struct_proxy(typ)(context, builder) - _, data, indices, indptr, shape = args - struct.data = data - struct.indices = indices - struct.indptr = indptr - struct.shape = shape - return impl_ret_borrowed( - context, - builder, - sig.return_type, - struct._getvalue(), +@overload_method(CSMatrixType, "copy") +def overload_sparse_copy(matrix): + match matrix: + case CSRMatrixType(): + builder = csr_matrix_from_components + case CSCMatrixType(): + builder = csc_matrix_from_components + case _: + return + + def copy(matrix): + return builder( + matrix.data.copy(), + matrix.indices.copy(), + matrix.indptr.copy(), + matrix.shape, ) - sig = inst(inst, inst.data, inst.indices, inst.indptr, inst.shape) - - return sig, _construct - + return copy -@overload_method(CSMatrixType, "copy") -def overload_sparse_copy(inst): - if not isinstance(inst, CSMatrixType): - return - def copy(inst): - return _sparse_copy( - inst, inst.data.copy(), inst.indices.copy(), inst.indptr.copy(), inst.shape +@overload_method(CSMatrixType, "astype") +def overload_sparse_astype(matrix, dtype): + match matrix: + case CSRMatrixType(): + builder = csr_matrix_from_components + case CSCMatrixType(): + builder = csc_matrix_from_components + case _: + return + + def astype(matrix, dtype): + return builder( + matrix.data.astype(dtype), + matrix.indices.copy(), + matrix.indptr.copy(), + matrix.shape, ) - return copy + return astype + + +@numba_funcify.register(CSMProperties) +def numba_funcify_CSMProperties(op, **kwargs): + @numba_njit + def csm_properties(x): + # Reconsider this int32/int64. Scipy/base PyTensor use int32 for indices/indptr. + # But this seems to be legacy mistake and devs would choose int64 nowadays, and may move there. + return x.data, x.indices, x.indptr, np.asarray(x.shape, dtype="int64") + + return csm_properties + + +@numba_funcify.register(CSM) +def numba_funcify_CSM(op, **kwargs): + format = op.format + + @numba_njit + def csm_constructor(data, indices, indptr, shape): + constructor_arg = (data, indices, indptr) + shape_arg = (shape[0], shape[1]) + if format == "csr": + return sp.sparse.csr_matrix(constructor_arg, shape=shape_arg) + else: + return sp.sparse.csc_matrix(constructor_arg, shape=shape_arg) + + return csm_constructor + + +@numba_funcify.register(SparseDenseMultiply) +@numba_funcify.register(SparseDenseVectorMultiply) +def numba_funcify_SparseDenseMultiply(op, node, **kwargs): + x, y = node.inputs + [z] = node.outputs + out_dtype = z.type.dtype + format = z.type.format + same_dtype = x.type.dtype == out_dtype + + if y.ndim == 0: + + @numba_njit + def sparse_multiply_scalar(x, y): + if same_dtype: + z = x.copy() + else: + z = x.astype(out_dtype) + # Numba doesn't know how to handle in-place mutation / assignment of fields + # z.data *= y + z_data = z.data + z_data *= y + return z + + return sparse_multiply_scalar + + elif y.ndim == 1: + + @numba_njit + def sparse_dense_multiply(x, y): + assert x.shape[1] == y.shape[0] + if same_dtype: + z = x.copy() + else: + z = x.astype(out_dtype) + + M, N = x.shape + indices = x.indices + indptr = x.indptr + z_data = z.data + if format == "csc": + for j in range(0, N): + for i_idx in range(indptr[j], indptr[j + 1]): + z_data[i_idx] *= y[j] + return z + + else: + for i in range(0, M): + for j_idx in range(indptr[i], indptr[i + 1]): + j = indices[j_idx] + z_data[j_idx] *= y[j] + + return z + + return sparse_dense_multiply + + else: # y.ndim == 2 + + @numba_njit + def sparse_dense_multiply(x, y): + assert x.shape == y.shape + if same_dtype: + z = x.copy() + else: + z = x.astype(out_dtype) + + M, N = x.shape + indices = x.indices + indptr = x.indptr + z_data = z.data + if format == "csc": + for j in range(0, N): + for i_idx in range(indptr[j], indptr[j + 1]): + i = indices[i_idx] + z_data[i_idx] *= y[i, j] + return z + + else: + for i in range(0, M): + for j_idx in range(indptr[i], indptr[i + 1]): + j = indices[j_idx] + z_data[j_idx] *= y[i, j] + + return z + + return sparse_dense_multiply diff --git a/pytensor/sparse/basic.py b/pytensor/sparse/basic.py index 23d3a4e2a0..f1aa4be226 100644 --- a/pytensor/sparse/basic.py +++ b/pytensor/sparse/basic.py @@ -500,19 +500,19 @@ def unique_value(self): # for more dtypes, call SparseTensorType(format, dtype) -def matrix(format, name=None, dtype=None): +def matrix(format, name=None, dtype=None, shape=None): if dtype is None: dtype = config.floatX - type = SparseTensorType(format=format, dtype=dtype) + type = SparseTensorType(format=format, dtype=dtype, shape=shape) return type(name) -def csc_matrix(name=None, dtype=None): - return matrix("csc", name, dtype) +def csc_matrix(name=None, dtype=None, shape=None): + return matrix("csc", name=name, dtype=dtype, shape=shape) -def csr_matrix(name=None, dtype=None): - return matrix("csr", name, dtype) +def csr_matrix(name=None, dtype=None, shape=None): + return matrix("csr", name=name, dtype=dtype, shape=shape) def bsr_matrix(name=None, dtype=None): @@ -727,10 +727,22 @@ def make_node(self, data, indices, indptr, shape): if shape.type.ndim != 1 or shape.type.dtype not in discrete_dtypes: raise TypeError("n_rows must be integer type", shape, shape.type) + static_shape = (None, None) + if ( + shape.owner is not None + and isinstance(shape.owner.op, CSMProperties) + and shape.owner.outputs[3] is shape + ): + static_shape = shape.owner.inputs[0].type.shape + return Apply( self, [data, indices, indptr, shape], - [SparseTensorType(dtype=data.type.dtype, format=self.format)()], + [ + SparseTensorType( + dtype=data.type.dtype, format=self.format, shape=static_shape + )() + ], ) def perform(self, node, inputs, outputs): @@ -2298,7 +2310,7 @@ def sub(x, y): return x + (-y) -class MulSS(Op): +class SparseSparseMultiply(Op): # mul(sparse, sparse) # See the doc of mul() for more detail __props__ = () @@ -2331,12 +2343,15 @@ def infer_shape(self, fgraph, node, shapes): return [shapes[0]] -mul_s_s = MulSS() +mul_s_s = SparseSparseMultiply() -class MulSD(Op): +class SparseDenseMultiply(Op): # mul(sparse, dense) # See the doc of mul() for more detail + + # We're doing useless copy of indices and indptr, those should be reused + # However, PyTensor doesn't support one output -> multiple views... __props__ = () def make_node(self, x, y): @@ -2352,64 +2367,42 @@ def make_node(self, x, y): # Broadcasting of the sparse matrix is not supported. # We support nd == 0 used by grad of SpSum() assert y.type.ndim in (0, 2) - out = SparseTensorType(dtype=dtype, format=x.type.format)() + out = SparseTensorType(dtype=dtype, format=x.type.format, shape=x.type.shape)() return Apply(self, [x, y], [out]) def perform(self, node, inputs, outputs): (x, y) = inputs (out,) = outputs + out_dtype = node.outputs[0].dtype assert _is_sparse(x) and _is_dense(y) - if len(y.shape) == 0: - out_dtype = node.outputs[0].dtype - if x.dtype == out_dtype: - z = x.copy() - else: - z = x.astype(out_dtype) - out[0] = z - out[0].data *= y - elif len(y.shape) == 1: - raise NotImplementedError() # RowScale / ColScale - elif len(y.shape) == 2: + + if x.dtype == out_dtype: + z = x.copy() + else: + z = x.astype(out_dtype) + out[0] = z + z_data = z.data + + if y.ndim == 0: + z_data *= y + else: # y_ndim == 2 # if we have enough memory to fit y, maybe we can fit x.asarray() # too? # TODO: change runtime from O(M*N) to O(nonzeros) M, N = x.shape assert x.shape == y.shape - out_dtype = node.outputs[0].dtype - + indices = x.indices + indptr = x.indptr if x.format == "csc": - indices = x.indices - indptr = x.indptr - if x.dtype == out_dtype: - z = x.copy() - else: - z = x.astype(out_dtype) - z_data = z.data - for j in range(0, N): for i_idx in range(indptr[j], indptr[j + 1]): i = indices[i_idx] z_data[i_idx] *= y[i, j] - out[0] = z elif x.format == "csr": - indices = x.indices - indptr = x.indptr - if x.dtype == out_dtype: - z = x.copy() - else: - z = x.astype(out_dtype) - z_data = z.data - for i in range(0, M): for j_idx in range(indptr[i], indptr[i + 1]): j = indices[j_idx] z_data[j_idx] *= y[i, j] - out[0] = z - else: - warn( - "This implementation of MulSD is deficient: {x.format}", - ) - out[0] = type(x)(x.toarray() * y) def grad(self, inputs, gout): (x, y) = inputs @@ -2422,12 +2415,14 @@ def infer_shape(self, fgraph, node, shapes): return [shapes[0]] -mul_s_d = MulSD() +mul_s_d = SparseDenseMultiply() -class MulSV(Op): +class SparseDenseVectorMultiply(Op): """Element-wise multiplication of sparse matrix by a broadcasted dense vector element wise. + TODO: Merge with the SparseDenseMultiply Op + Notes ----- The grad implemented is regular, i.e. not structured. @@ -2488,7 +2483,7 @@ def infer_shape(self, fgraph, node, ins_shapes): return [ins_shapes[0]] -mul_s_v = MulSV() +mul_s_v = SparseDenseVectorMultiply() def mul(x, y): @@ -2527,16 +2522,17 @@ def mul(x, y): # mul_s_s is not implemented if the types differ if y.dtype == "float64" and x.dtype == "float32": x = x.astype("float64") - return mul_s_s(x, y) - elif x_is_sparse_variable and not y_is_sparse_variable: + elif x_is_sparse_variable or y_is_sparse_variable: + if y_is_sparse_variable: + x, y = y, x # mul is unimplemented if the dtypes differ if y.dtype == "float64" and x.dtype == "float32": x = x.astype("float64") - - return mul_s_d(x, y) - elif y_is_sparse_variable and not x_is_sparse_variable: - return mul_s_d(y, x) + if y.ndim == 1: + return mul_s_v(x, y) + else: + return mul_s_d(x, y) else: raise NotImplementedError() diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py index 5ae92006e2..3dbf3d2fc3 100644 --- a/pytensor/tensor/type.py +++ b/pytensor/tensor/type.py @@ -71,7 +71,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape): def __init__( self, dtype: str | npt.DTypeLike, - shape: Iterable[bool | int | None] | None = None, + shape: Iterable[bool | int | None] | int | None = None, name: str | None = None, broadcastable: Iterable[bool] | None = None, ): @@ -99,7 +99,7 @@ def __init__( ) shape = broadcastable - if str(dtype) == "floatX": + if dtype == "floatX": self.dtype = config.floatX else: try: @@ -118,6 +118,8 @@ def parse_bcast_and_shape(s): f"TensorType broadcastable/shape must be a boolean, integer or None, got {type(s)} {s}" ) + if isinstance(shape, int): + shape = (shape,) self.shape = tuple(parse_bcast_and_shape(s) for s in shape) self.dtype_specs() # error checking is done there self.name = name diff --git a/tests/link/numba/test_sparse.py b/tests/link/numba/test_sparse.py index 3d91ca13a8..f49e72044d 100644 --- a/tests/link/numba/test_sparse.py +++ b/tests/link/numba/test_sparse.py @@ -1,7 +1,13 @@ +from functools import partial + import numpy as np import pytest +import scipy import scipy as sp +import pytensor.sparse as ps +import pytensor.tensor as pt + numba = pytest.importorskip("numba") @@ -13,6 +19,23 @@ from tests.link.numba.test_basic import compare_numba_and_py +def sparse_assert_fn(a, b): + a_is_sparse = sp.sparse.issparse(a) + assert a_is_sparse == sp.sparse.issparse(b) + if a_is_sparse: + assert a.format == b.format + assert a.dtype == b.dtype + assert a.shape == b.shape + np.testing.assert_allclose(a.data, b.data, strict=True) + np.testing.assert_allclose(a.indices, b.indices, strict=True) + np.testing.assert_allclose(a.indptr, b.indptr, strict=True) + else: + np.testing.assert_allclose(a, b, strict=True) + + +compare_numba_and_py_sparse = partial(compare_numba_and_py, assert_fn=sparse_assert_fn) + + pytestmark = pytest.mark.filterwarnings("error") @@ -77,14 +100,12 @@ def test_fn(x): def test_sparse_copy(): @numba.njit def test_fn(x): - y = x.copy() - return ( - y is not x and np.all(x.data == y.data) and np.all(x.indices == y.indices) - ) + return x.copy() - x_val = sp.sparse.csr_matrix(np.eye(100)) + x = sp.sparse.csr_matrix(np.eye(100)) - assert test_fn(x_val) + y = test_fn(x) + assert y is not x and np.all(x.data == y.data) and np.all(x.indices == y.indices) def test_sparse_objmode(): @@ -101,3 +122,59 @@ def test_sparse_objmode(): match="Numba will use object mode to run SparseDot's perform method", ): compare_numba_and_py([x, y], out, [x_val, y_val]) + + +def test_overload_csr_matrix_constructor(): + @numba.njit + def csr_matrix_constructor(data, indices, indptr): + return sp.sparse.csr_matrix((data, indices, indptr), shape=(3, 3)) + + inp = sp.sparse.random(3, 3, density=0.5, format="csr") + + # Test with pure scipy csr_matrix constructor + out = sp.sparse.csr_matrix((inp.data, inp.indices, inp.indptr), copy=False) + # CSR_matrix does a useless slice on data and indices to trim away useless zeros + # which means these attributes are views of the original arrays. + assert out.data is not inp.data + assert not out.data.flags.owndata + + assert out.indices is not inp.indices + assert not out.indices.flags.owndata + + assert out.indptr is inp.indptr + assert out.indptr.flags.owndata + + # Test ours + out_pt = csr_matrix_constructor(inp.data, inp.indices, inp.indptr) + # Should work the same as Scipy's constructor, because it's ultimately used + assert isinstance(out_pt, scipy.sparse.csr_matrix) + assert out_pt.data is not inp.data + assert not out_pt.data.flags.owndata + assert (out_pt.data == inp.data).all() + + assert out_pt.indices is not inp.indices + assert not out_pt.indices.flags.owndata + assert (out_pt.indices == inp.indices).all() + + assert out_pt.indptr is inp.indptr + assert out_pt.indptr.flags.owndata + assert (out_pt.indptr == inp.indptr).all() + + +@pytest.mark.parametrize("format", ["csr", "csc"]) +@pytest.mark.parametrize("y_ndim", [0, 1, 2]) +def test_simple_graph(y_ndim, format): + ps_matrix = ps.csr_matrix if format == "csr" else ps.csc_matrix + x = ps_matrix("x", shape=(3, 3)) + y = pt.tensor("y", shape=(3,) * y_ndim) + z = ps.sin(x * y) + + rng = np.random.default_rng((155, y_ndim, format == "csr")) + x_test = sp.sparse.random(3, 3, density=0.5, format=format, random_state=rng) + y_test = rng.normal(size=(3,) * y_ndim) + + compare_numba_and_py_sparse( + [x, y], + z, + [x_test, y_test], + ) diff --git a/tests/sparse/test_basic.py b/tests/sparse/test_basic.py index 27376fa770..abfd2f76af 100644 --- a/tests/sparse/test_basic.py +++ b/tests/sparse/test_basic.py @@ -32,12 +32,12 @@ EnsureSortedIndices, GetItemScalar, HStack, - MulSD, - MulSS, Neg, Remove0, SamplingDot, + SparseDenseMultiply, SparseFromDense, + SparseSparseMultiply, SparseTensorType, SquareDiagonal, StructuredDot, @@ -514,7 +514,7 @@ def test_mul_ss(self): sp.sparse.csr_matrix(random_lil((10, 40), config.floatX, 3)), ] * 2, - MulSS, + SparseSparseMultiply, ) def test_mul_sd(self): @@ -527,7 +527,7 @@ def test_mul_sd(self): sp.sparse.csr_matrix(random_lil((10, 40), config.floatX, 3)), np.random.standard_normal((10, 40)).astype(config.floatX), ], - MulSD, + SparseDenseMultiply, excluding=["local_mul_s_d"], ) diff --git a/tests/sparse/test_rewriting.py b/tests/sparse/test_rewriting.py index 280d9dbf70..4634fa744b 100644 --- a/tests/sparse/test_rewriting.py +++ b/tests/sparse/test_rewriting.py @@ -77,7 +77,8 @@ def test_local_mul_s_d(): f = pytensor.function(inputs, sparse.mul_s_d(*inputs), mode=mode) assert not any( - isinstance(node.op, sparse.MulSD) for node in f.maker.fgraph.toposort() + isinstance(node.op, sparse.SparseDenseMultiply) + for node in f.maker.fgraph.toposort() ) @@ -94,7 +95,8 @@ def test_local_mul_s_v(): f = pytensor.function(inputs, sparse.mul_s_v(*inputs), mode=mode) assert not any( - isinstance(node.op, sparse.MulSV) for node in f.maker.fgraph.toposort() + isinstance(node.op, sparse.SparseDenseVectorMultiply) + for node in f.maker.fgraph.toposort() )