1111Array class and helper functions.
1212"""
1313
14- from .algorithm import sum , count
15- from .arith import cast
1614import inspect
1715import os
1816from .library import *
2725
2826_display_dims_limit = None
2927
30-
3128def set_display_dims_limit (* dims ):
3229 """
3330 Sets the dimension limit after which array's data won't get
@@ -47,7 +44,6 @@ def set_display_dims_limit(*dims):
4744 global _display_dims_limit
4845 _display_dims_limit = dims
4946
50-
5147def get_display_dims_limit ():
5248 """
5349 Gets the dimension limit after which array's data won't get
@@ -71,7 +67,6 @@ def get_display_dims_limit():
7167 """
7268 return _display_dims_limit
7369
74-
7570def _in_display_dims_limit (dims ):
7671 if _is_running_in_py_charm :
7772 return False
@@ -85,7 +80,6 @@ def _in_display_dims_limit(dims):
8580 return False
8681 return True
8782
88-
8983def _create_array (buf , numdims , idims , dtype , is_device ):
9084 out_arr = c_void_ptr_t (0 )
9185 c_dims = dim4 (idims [0 ], idims [1 ], idims [2 ], idims [3 ])
@@ -97,7 +91,6 @@ def _create_array(buf, numdims, idims, dtype, is_device):
9791 numdims , c_pointer (c_dims ), dtype .value ))
9892 return out_arr
9993
100-
10194def _create_strided_array (buf , numdims , idims , dtype , is_device , offset , strides ):
10295 out_arr = c_void_ptr_t (0 )
10396 c_dims = dim4 (idims [0 ], idims [1 ], idims [2 ], idims [3 ])
@@ -119,15 +112,16 @@ def _create_strided_array(buf, numdims, idims, dtype, is_device, offset, strides
119112 location .value ))
120113 return out_arr
121114
122-
123115def _create_empty_array (numdims , idims , dtype ):
124116 out_arr = c_void_ptr_t (0 )
117+
118+ if numdims == 0 : return out_arr
119+
125120 c_dims = dim4 (idims [0 ], idims [1 ], idims [2 ], idims [3 ])
126121 safe_call (backend .get ().af_create_handle (c_pointer (out_arr ),
127122 numdims , c_pointer (c_dims ), dtype .value ))
128123 return out_arr
129124
130-
131125def constant_array (val , d0 , d1 = None , d2 = None , d3 = None , dtype = Dtype .f32 ):
132126 """
133127 Internal function to create a C array. Should not be used externall.
@@ -182,7 +176,6 @@ def _binary_func(lhs, rhs, c_func):
182176
183177 return out
184178
185-
186179def _binary_funcr (lhs , rhs , c_func ):
187180 out = Array ()
188181 other = lhs
@@ -199,10 +192,9 @@ def _binary_funcr(lhs, rhs, c_func):
199192
200193 return out
201194
202-
203195def _ctype_to_lists (ctype_arr , dim , shape , offset = 0 ):
204196 if (dim == 0 ):
205- return list (ctype_arr [offset : offset + shape [0 ]])
197+ return list (ctype_arr [offset : offset + shape [0 ]])
206198 else :
207199 dim_len = shape [dim ]
208200 res = [[]] * dim_len
@@ -211,7 +203,6 @@ def _ctype_to_lists(ctype_arr, dim, shape, offset=0):
211203 offset += shape [0 ]
212204 return res
213205
214-
215206def _slice_to_length (key , dim ):
216207 tkey = [key .start , key .stop , key .step ]
217208
@@ -230,7 +221,6 @@ def _slice_to_length(key, dim):
230221
231222 return int (((tkey [1 ] - tkey [0 ] - 1 ) / tkey [2 ]) + 1 )
232223
233-
234224def _get_info (dims , buf_len ):
235225 elements = 1
236226 numdims = 0
@@ -260,7 +250,6 @@ def _get_indices(key):
260250
261251 return inds
262252
263-
264253def _get_assign_dims (key , idims ):
265254
266255 dims = [1 ]* 4
@@ -307,7 +296,6 @@ def _get_assign_dims(key, idims):
307296 else :
308297 raise IndexError ("Invalid type while assigning to arrayfire.array" )
309298
310-
311299def transpose (a , conj = False ):
312300 """
313301 Perform the transpose on an input.
@@ -330,7 +318,6 @@ def transpose(a, conj=False):
330318 safe_call (backend .get ().af_transpose (c_pointer (out .arr ), a .arr , conj ))
331319 return out
332320
333-
334321def transpose_inplace (a , conj = False ):
335322 """
336323 Perform inplace transpose on an input.
@@ -351,7 +338,6 @@ def transpose_inplace(a, conj=False):
351338 """
352339 safe_call (backend .get ().af_transpose_inplace (a .arr , conj ))
353340
354-
355341class Array (BaseArray ):
356342
357343 """
@@ -461,8 +447,8 @@ def __init__(self, src=None, dims=None, dtype=None, is_device=False, offset=None
461447
462448 super (Array , self ).__init__ ()
463449
464- buf = None
465- buf_len = 0
450+ buf = None
451+ buf_len = 0
466452
467453 if dtype is not None :
468454 if isinstance (dtype , str ):
@@ -472,7 +458,7 @@ def __init__(self, src=None, dims=None, dtype=None, is_device=False, offset=None
472458 else :
473459 type_char = None
474460
475- _type_char = 'f'
461+ _type_char = 'f'
476462
477463 if src is not None :
478464
@@ -483,12 +469,12 @@ def __init__(self, src=None, dims=None, dtype=None, is_device=False, offset=None
483469 host = __import__ ("array" )
484470
485471 if isinstance (src , host .array ):
486- buf , buf_len = src .buffer_info ()
472+ buf ,buf_len = src .buffer_info ()
487473 _type_char = src .typecode
488474 numdims , idims = _get_info (dims , buf_len )
489475 elif isinstance (src , list ):
490476 tmp = host .array ('f' , src )
491- buf , buf_len = tmp .buffer_info ()
477+ buf ,buf_len = tmp .buffer_info ()
492478 _type_char = tmp .typecode
493479 numdims , idims = _get_info (dims , buf_len )
494480 elif isinstance (src , int ) or isinstance (src , c_void_ptr_t ):
@@ -512,7 +498,7 @@ def __init__(self, src=None, dims=None, dtype=None, is_device=False, offset=None
512498 raise TypeError ("src is an object of unsupported class" )
513499
514500 if (type_char is not None and
515- type_char != _type_char ):
501+ type_char != _type_char ):
516502 raise TypeError ("Can not create array of requested type from input data type" )
517503 if (offset is None and strides is None ):
518504 self .arr = _create_array (buf , numdims , idims , to_dtype [_type_char ], is_device )
@@ -634,8 +620,8 @@ def strides(self):
634620 s2 = c_dim_t (0 )
635621 s3 = c_dim_t (0 )
636622 safe_call (backend .get ().af_get_strides (c_pointer (s0 ), c_pointer (s1 ),
637- c_pointer (s2 ), c_pointer (s3 ), self .arr ))
638- strides = (s0 .value , s1 .value , s2 .value , s3 .value )
623+ c_pointer (s2 ), c_pointer (s3 ), self .arr ))
624+ strides = (s0 .value ,s1 .value ,s2 .value ,s3 .value )
639625 return strides [:self .numdims ()]
640626
641627 def elements (self ):
@@ -694,8 +680,8 @@ def dims(self):
694680 d2 = c_dim_t (0 )
695681 d3 = c_dim_t (0 )
696682 safe_call (backend .get ().af_get_dims (c_pointer (d0 ), c_pointer (d1 ),
697- c_pointer (d2 ), c_pointer (d3 ), self .arr ))
698- dims = (d0 .value , d1 .value , d2 .value , d3 .value )
683+ c_pointer (d2 ), c_pointer (d3 ), self .arr ))
684+ dims = (d0 .value ,d1 .value ,d2 .value ,d3 .value )
699685 return dims [:self .numdims ()]
700686
701687 @property
@@ -920,7 +906,7 @@ def __itruediv__(self, other):
920906 """
921907 Perform self /= other.
922908 """
923- self = _binary_func (self , other , backend .get ().af_div )
909+ self = _binary_func (self , other , backend .get ().af_div )
924910 return self
925911
926912 def __rtruediv__ (self , other ):
@@ -939,7 +925,7 @@ def __idiv__(self, other):
939925 """
940926 Perform other / self.
941927 """
942- self = _binary_func (self , other , backend .get ().af_div )
928+ self = _binary_func (self , other , backend .get ().af_div )
943929 return self
944930
945931 def __rdiv__ (self , other ):
@@ -958,7 +944,7 @@ def __imod__(self, other):
958944 """
959945 Perform self %= other.
960946 """
961- self = _binary_func (self , other , backend .get ().af_mod )
947+ self = _binary_func (self , other , backend .get ().af_mod )
962948 return self
963949
964950 def __rmod__ (self , other ):
@@ -977,7 +963,7 @@ def __ipow__(self, other):
977963 """
978964 Perform self **= other.
979965 """
980- self = _binary_func (self , other , backend .get ().af_pow )
966+ self = _binary_func (self , other , backend .get ().af_pow )
981967 return self
982968
983969 def __rpow__ (self , other ):
@@ -1120,15 +1106,15 @@ def logical_and(self, other):
11201106 Return self && other.
11211107 """
11221108 out = Array ()
1123- safe_call (backend .get ().af_and (c_pointer (out .arr ), self .arr , other .arr )) # TODO: bcast var?
1109+ safe_call (backend .get ().af_and (c_pointer (out .arr ), self .arr , other .arr )) # TODO: bcast var?
11241110 return out
11251111
11261112 def logical_or (self , other ):
11271113 """
11281114 Return self || other.
11291115 """
11301116 out = Array ()
1131- safe_call (backend .get ().af_or (c_pointer (out .arr ), self .arr , other .arr )) # TODO: bcast var?
1117+ safe_call (backend .get ().af_or (c_pointer (out .arr ), self .arr , other .arr )) # TODO: bcast var?
11321118 return out
11331119
11341120 def __nonzero__ (self ):
@@ -1158,11 +1144,12 @@ def __getitem__(self, key):
11581144 inds = _get_indices (key )
11591145
11601146 safe_call (backend .get ().af_index_gen (c_pointer (out .arr ),
1161- self .arr , c_dim_t (n_dims ), inds .pointer ))
1147+ self .arr , c_dim_t (n_dims ), inds .pointer ))
11621148 return out
11631149 except RuntimeError as e :
11641150 raise IndexError (str (e ))
11651151
1152+
11661153 def __setitem__ (self , key , val ):
11671154 """
11681155 Perform self[key] = val
@@ -1188,14 +1175,14 @@ def __setitem__(self, key, val):
11881175 n_dims = 1
11891176 other_arr = constant_array (val , int (num ), dtype = self .type ())
11901177 else :
1191- other_arr = constant_array (val , tdims [0 ], tdims [1 ], tdims [2 ], tdims [3 ], self .type ())
1178+ other_arr = constant_array (val , tdims [0 ] , tdims [1 ], tdims [2 ], tdims [3 ], self .type ())
11921179 del_other = True
11931180 else :
11941181 other_arr = val .arr
11951182 del_other = False
11961183
11971184 out_arr = c_void_ptr_t (0 )
1198- inds = _get_indices (key )
1185+ inds = _get_indices (key )
11991186
12001187 safe_call (backend .get ().af_assign_gen (c_pointer (out_arr ),
12011188 self .arr , c_dim_t (n_dims ), inds .pointer ,
@@ -1414,7 +1401,6 @@ def to_ndarray(self, output=None):
14141401 safe_call (backend .get ().af_get_data_ptr (c_void_ptr_t (output .ctypes .data ), tmp .arr ))
14151402 return output
14161403
1417-
14181404def display (a , precision = 4 ):
14191405 """
14201406 Displays the contents of an array.
@@ -1440,7 +1426,6 @@ def display(a, precision=4):
14401426 safe_call (backend .get ().af_print_array_gen (name .encode ('utf-8' ),
14411427 a .arr , c_int_t (precision )))
14421428
1443-
14441429def save_array (key , a , filename , append = False ):
14451430 """
14461431 Save an array to disk.
@@ -1472,7 +1457,6 @@ def save_array(key, a, filename, append=False):
14721457 append ))
14731458 return index .value
14741459
1475-
14761460def read_array (filename , index = None , key = None ):
14771461 """
14781462 Read an array from disk.
@@ -1506,3 +1490,6 @@ def read_array(filename, index=None, key=None):
15061490 key .encode ('utf-8' )))
15071491
15081492 return out
1493+
1494+ from .algorithm import (sum , count )
1495+ from .arith import cast
0 commit comments