Skip to content

Commit 602eb04

Browse files
Move shared HStack and VStack methods to Stack class (#1662)
1 parent f772066 commit 602eb04

File tree

1 file changed

+9
-13
lines changed

1 file changed

+9
-13
lines changed

pytensor/sparse/basic.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2794,7 +2794,7 @@ def comparison(self, x, y):
27942794
ge = __ComparisonSwitch(greater_equal_s_s, greater_equal_s_d, less_equal_s_d)
27952795

27962796

2797-
class HStack(Op):
2797+
class Stack(Op):
27982798
__props__ = ("format", "dtype")
27992799

28002800
def __init__(self, format=None, dtype=None):
@@ -2819,6 +2819,11 @@ def make_node(self, *mat):
28192819
self, var, [SparseTensorType(dtype=self.dtype, format=self.format)()]
28202820
)
28212821

2822+
def __str__(self):
2823+
return f"{self.__class__.__name__}({self.format},{self.dtype})"
2824+
2825+
2826+
class HStack(Stack):
28222827
def perform(self, node, block, outputs):
28232828
(out,) = outputs
28242829
for b in block:
@@ -2853,15 +2858,9 @@ def choose(continuous, derivative):
28532858
return [choose(c, d) for c, d in zip(is_continuous, derivative, strict=True)]
28542859

28552860
def infer_shape(self, fgraph, node, ins_shapes):
2856-
def _get(l):
2857-
return l[1]
2858-
2859-
d = sum(map(_get, ins_shapes))
2861+
d = sum(shape[1] for shape in ins_shapes)
28602862
return [(ins_shapes[0][0], d)]
28612863

2862-
def __str__(self):
2863-
return f"{self.__class__.__name__}({self.format},{self.dtype})"
2864-
28652864

28662865
def hstack(blocks, format=None, dtype=None):
28672866
"""
@@ -2897,7 +2896,7 @@ def hstack(blocks, format=None, dtype=None):
28972896
return HStack(format=format, dtype=dtype)(*blocks)
28982897

28992898

2900-
class VStack(HStack):
2899+
class VStack(Stack):
29012900
def perform(self, node, block, outputs):
29022901
(out,) = outputs
29032902
for b in block:
@@ -2932,10 +2931,7 @@ def choose(continuous, derivative):
29322931
return [choose(c, d) for c, d in zip(is_continuous, derivative, strict=True)]
29332932

29342933
def infer_shape(self, fgraph, node, ins_shapes):
2935-
def _get(l):
2936-
return l[0]
2937-
2938-
d = sum(map(_get, ins_shapes))
2934+
d = sum(shape[0] for shape in ins_shapes)
29392935
return [(d, ins_shapes[0][1])]
29402936

29412937

0 commit comments

Comments
 (0)