 import ctypes
 import ctypes.util
+import functools
+import weakref

 import mlir.execution_engine
 import mlir.passmanager
 import numpy as np
 import scipy.sparse as sps

-from ._core import DEBUG, MLIR_C_RUNNER_UTILS, SCRIPT_PATH, ctx
-from ._dtypes import DType, Float64, Index
-from ._memref import MemrefF64_1D, MemrefIdx_1D
+from ._common import fn_cache
+from ._core import CWD, DEBUG, MLIR_C_RUNNER_UTILS, ctx
+from ._dtypes import DType, Index, asdtype
+from ._memref import make_memref_ctype, ranked_memref_from_np
+
+
+def _hold_self_ref_in_ret(fn):
+    @functools.wraps(fn)
+    def wrapped(self, *a, **kw):
+        ptr = ctypes.py_object(self)
+        ctypes.pythonapi.Py_IncRef(ptr)
+        ret = fn(self, *a, **kw)
+
+        def finalizer(ptr):
+            ctypes.pythonapi.Py_DecRef(ptr)
+
+        weakref.finalize(ret, finalizer, ptr)
+        return ret
+
+    return wrapped


 class Tensor:
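Note on `_hold_self_ref_in_ret` (added above): further down in this commit the `.copy()` calls are dropped from `disassemble`, so the arrays returned by `Tensor.to_scipy_sparse()` appear to share memory owned by the `Tensor`, whose `__del__` frees the underlying storage. The decorator pins the `Tensor` with `Py_IncRef` before calling the wrapped method and registers a `weakref.finalize` callback on the return value that drops that reference again, so the owner cannot be collected while a returned view is still reachable. A minimal sketch of the same pattern, with hypothetical names (`Owner`, `view`) outside this diff:

    import ctypes
    import weakref

    import numpy as np


    class Owner:  # hypothetical stand-in for Tensor
        def view(self) -> np.ndarray:
            ret = np.arange(3)               # stands in for a memref-backed view
            ptr = ctypes.py_object(self)
            ctypes.pythonapi.Py_IncRef(ptr)  # pin the owner while the view exists

            def finalizer(ptr):
                ctypes.pythonapi.Py_DecRef(ptr)

            weakref.finalize(ret, finalizer, ptr)  # unpin when the view is collected
            return ret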
@@ -26,21 +45,21 @@ def __init__(self, obj, module, tensor_type, disassemble_fn, values_dtype, index
     def __del__(self):
         self.module.invoke("free_tensor", ctypes.pointer(self.obj))

+    @_hold_self_ref_in_ret
     def to_scipy_sparse(self):
         """
         Returns scipy.sparse or ndarray
         """
-        return self.disassemble_fn(self.module, self.obj)
+        return self.disassemble_fn(self.module, self.obj, self.values_dtype)


 class DenseFormat:
-    modules = {}
-
+    @fn_cache
     def get_module(shape: tuple[int], values_dtype: DType, index_dtype: DType):
         with ir.Location.unknown(ctx):
             module = ir.Module.create()
-            values_dtype = values_dtype.get()
-            index_dtype = index_dtype.get()
+            values_dtype = values_dtype.get_mlir_type()
+            index_dtype = index_dtype.get_mlir_type()
             index_width = getattr(index_dtype, "width", 0)
             levels = (sparse_tensor.LevelType.dense, sparse_tensor.LevelType.dense)
             ordering = ir.AffineMap.get_permutation([0, 1])
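The `@fn_cache` decorator comes from `._common`, which this diff does not show; it replaces the ad-hoc `modules = {}` class dicts and the `hash(obj.shape)` lookup removed from `asarray` below. What the call sites rely on is plain memoization of `get_module` over its hashable arguments, which also makes the cache key include the dtypes rather than only the shape. A rough, hypothetical stand-in (the real helper may well do more, e.g. manage weak references):

    import functools


    def fn_cache(fn):
        # hypothetical stand-in for ._common.fn_cache: memoize on all arguments
        return functools.cache(fn)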
@@ -78,18 +97,19 @@ def free_tensor(tensor_shaped):
             disassemble.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
             free_tensor.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
             if DEBUG:
-                (SCRIPT_PATH / "dense_module.mlir").write_text(str(module))
+                (CWD / "dense_module.mlir").write_text(str(module))
             pm = mlir.passmanager.PassManager.parse("builtin.module(sparsifier{create-sparse-deallocs=1})")
             pm.run(module.operation)
             if DEBUG:
-                (SCRIPT_PATH / "dense_module_opt.mlir").write_text(str(module))
+                (CWD / "dense_module_opt.mlir").write_text(str(module))

         module = mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])
         return (module, dense_shaped)

     @classmethod
     def assemble(cls, module, arr: np.ndarray) -> ctypes.c_void_p:
-        data = MemrefF64_1D.from_numpy(arr.flatten())
+        assert arr.ndim == 2
+        data = ranked_memref_from_np(arr.flatten())
         out = ctypes.c_void_p()
         module.invoke(
             "assemble",
@@ -99,18 +119,18 @@ def assemble(cls, module, arr: np.ndarray) -> ctypes.c_void_p:
         return out

     @classmethod
-    def disassemble(cls, module: ir.Module, ptr: ctypes.c_void_p) -> np.ndarray:
+    def disassemble(cls, module: ir.Module, ptr: ctypes.c_void_p, dtype: type[DType]) -> np.ndarray:
         class Dense(ctypes.Structure):
             _fields_ = [
-                ("data", MemrefF64_1D),
+                ("data", make_memref_ctype(dtype, 1)),
                 ("data_len", np.ctypeslib.c_intp),
                 ("shape_x", np.ctypeslib.c_intp),
                 ("shape_y", np.ctypeslib.c_intp),
             ]

             def to_np(self) -> np.ndarray:
                 data = self.data.to_numpy()[: self.data_len]
-                return data.copy().reshape((self.shape_x, self.shape_y))
+                return data.reshape((self.shape_x, self.shape_y))

         arr = Dense()
         module.invoke(
@@ -122,18 +142,17 @@ def to_np(self) -> np.ndarray:
 
 
 class COOFormat:
-    modules = {}
     # TODO: implement
+    ...


 class CSRFormat:
-    modules = {}
-
-    def get_module(shape: tuple[int], values_dtype: DType, index_dtype: DType):
+    @fn_cache
+    def get_module(shape: tuple[int], values_dtype: type[DType], index_dtype: type[DType]):
         with ir.Location.unknown(ctx):
             module = ir.Module.create()
-            values_dtype = values_dtype.get()
-            index_dtype = index_dtype.get()
+            values_dtype = values_dtype.get_mlir_type()
+            index_dtype = index_dtype.get_mlir_type()
             index_width = getattr(index_dtype, "width", 0)
             levels = (sparse_tensor.LevelType.dense, sparse_tensor.LevelType.compressed)
             ordering = ir.AffineMap.get_permutation([0, 1])
@@ -175,11 +194,11 @@ def free_tensor(tensor_shaped):
             disassemble.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
             free_tensor.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
             if DEBUG:
-                (SCRIPT_PATH / "scr_module.mlir").write_text(str(module))
+                (CWD / "csr_module.mlir").write_text(str(module))
             pm = mlir.passmanager.PassManager.parse("builtin.module(sparsifier{create-sparse-deallocs=1})")
             pm.run(module.operation)
             if DEBUG:
-                (SCRIPT_PATH / "csr_module_opt.mlir").write_text(str(module))
+                (CWD / "csr_module_opt.mlir").write_text(str(module))

         module = mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])
         return (module, csr_shaped)
@@ -189,20 +208,20 @@ def assemble(cls, module: ir.Module, arr: sps.csr_array) -> ctypes.c_void_p:
         out = ctypes.c_void_p()
         module.invoke(
             "assemble",
-            ctypes.pointer(ctypes.pointer(MemrefIdx_1D.from_numpy(arr.indptr))),
-            ctypes.pointer(ctypes.pointer(MemrefIdx_1D.from_numpy(arr.indices))),
-            ctypes.pointer(ctypes.pointer(MemrefF64_1D.from_numpy(arr.data))),
+            ctypes.pointer(ctypes.pointer(ranked_memref_from_np(arr.indptr))),
+            ctypes.pointer(ctypes.pointer(ranked_memref_from_np(arr.indices))),
+            ctypes.pointer(ctypes.pointer(ranked_memref_from_np(arr.data))),
             ctypes.pointer(out),
         )
         return out

     @classmethod
-    def disassemble(cls, module: ir.Module, ptr: ctypes.c_void_p) -> sps.csr_array:
+    def disassemble(cls, module: ir.Module, ptr: ctypes.c_void_p, dtype: type[DType]) -> sps.csr_array:
         class Csr(ctypes.Structure):
             _fields_ = [
-                ("data", MemrefF64_1D),
-                ("pos", MemrefIdx_1D),
-                ("crd", MemrefIdx_1D),
+                ("data", make_memref_ctype(dtype, 1)),
+                ("pos", make_memref_ctype(Index, 1)),
+                ("crd", make_memref_ctype(Index, 1)),
                 ("data_len", np.ctypeslib.c_intp),
                 ("pos_len", np.ctypeslib.c_intp),
                 ("crd_len", np.ctypeslib.c_intp),
@@ -214,7 +233,7 @@ def to_sps(self) -> sps.csr_array:
                 pos = self.pos.to_numpy()[: self.pos_len]
                 crd = self.crd.to_numpy()[: self.crd_len]
                 data = self.data.to_numpy()[: self.data_len]
-                return sps.csr_array((data.copy(), crd.copy(), pos.copy()), shape=(self.shape_x, self.shape_y))
+                return sps.csr_array((data, crd, pos), shape=(self.shape_x, self.shape_y))

         arr = Csr()
         module.invoke(
@@ -235,23 +254,21 @@ def _is_numpy_obj(x) -> bool:
 
 def asarray(obj) -> Tensor:
     # TODO: discover obj's dtype
-    values_dtype = Float64
-    index_dtype = Index
+    values_dtype = asdtype(obj.dtype)

     # TODO: support other scipy formats
     if _is_scipy_sparse_obj(obj):
         format_class = CSRFormat
+        # This can be int32 or int64
+        index_dtype = asdtype(obj.indptr.dtype)
     elif _is_numpy_obj(obj):
         format_class = DenseFormat
+        index_dtype = Index
     else:
         raise Exception(f"{type(obj)} not supported.")

     # TODO: support proper caching
-    if hash(obj.shape) in format_class.modules:
-        module, tensor_type = format_class.modules[hash(obj.shape)]
-    else:
-        module, tensor_type = format_class.get_module(obj.shape, values_dtype, index_dtype)
-        format_class.modules[hash(obj.shape)] = module, tensor_type
+    module, tensor_type = format_class.get_module(obj.shape, values_dtype, index_dtype)

     assembled_obj = format_class.assemble(module, obj)
     return Tensor(assembled_obj, module, tensor_type, format_class.disassemble, values_dtype, index_dtype)
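Taken together, a round trip through the updated path might look like the sketch below. This is an illustration, not part of the commit: it assumes `asarray` is importable from this module and that `asdtype` accepts the input's value and index dtypes.

    import numpy as np
    import scipy.sparse as sps

    a = sps.csr_array(np.array([[0.0, 1.0], [2.0, 0.0]]))

    t = asarray(a)           # dtypes inferred: asdtype(a.dtype) and asdtype(a.indptr.dtype)
    b = t.to_scipy_sparse()  # csr_array built on the tensor's buffers, no copies
    np.testing.assert_array_equal(a.toarray(), b.toarray())

Because `to_scipy_sparse` is wrapped in `_hold_self_ref_in_ret`, `b` keeps `t` (and hence the MLIR-allocated buffers) alive even if `t` itself goes out of scope.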