@@ -2794,7 +2794,7 @@ def comparison(self, x, y):
27942794ge = __ComparisonSwitch (greater_equal_s_s , greater_equal_s_d , less_equal_s_d )
27952795
27962796
2797- class HStack (Op ):
2797+ class Stack (Op ):
27982798 __props__ = ("format" , "dtype" )
27992799
28002800 def __init__ (self , format = None , dtype = None ):
@@ -2819,6 +2819,11 @@ def make_node(self, *mat):
28192819 self , var , [SparseTensorType (dtype = self .dtype , format = self .format )()]
28202820 )
28212821
2822+ def __str__ (self ):
2823+ return f"{ self .__class__ .__name__ } ({ self .format } ,{ self .dtype } )"
2824+
2825+
2826+ class HStack (Stack ):
28222827 def perform (self , node , block , outputs ):
28232828 (out ,) = outputs
28242829 for b in block :
@@ -2853,15 +2858,9 @@ def choose(continuous, derivative):
28532858 return [choose (c , d ) for c , d in zip (is_continuous , derivative , strict = True )]
28542859
28552860 def infer_shape (self , fgraph , node , ins_shapes ):
2856- def _get (l ):
2857- return l [1 ]
2858-
2859- d = sum (map (_get , ins_shapes ))
2861+ d = sum (shape [1 ] for shape in ins_shapes )
28602862 return [(ins_shapes [0 ][0 ], d )]
28612863
2862- def __str__ (self ):
2863- return f"{ self .__class__ .__name__ } ({ self .format } ,{ self .dtype } )"
2864-
28652864
28662865def hstack (blocks , format = None , dtype = None ):
28672866 """
@@ -2897,7 +2896,7 @@ def hstack(blocks, format=None, dtype=None):
28972896 return HStack (format = format , dtype = dtype )(* blocks )
28982897
28992898
2900- class VStack (HStack ):
2899+ class VStack (Stack ):
29012900 def perform (self , node , block , outputs ):
29022901 (out ,) = outputs
29032902 for b in block :
@@ -2932,10 +2931,7 @@ def choose(continuous, derivative):
29322931 return [choose (c , d ) for c , d in zip (is_continuous , derivative , strict = True )]
29332932
29342933 def infer_shape (self , fgraph , node , ins_shapes ):
2935- def _get (l ):
2936- return l [0 ]
2937-
2938- d = sum (map (_get , ins_shapes ))
2934+ d = sum (shape [0 ] for shape in ins_shapes )
29392935 return [(d , ins_shapes [0 ][1 ])]
29402936
29412937
0 commit comments