66import scipy .linalg
77
88import pytensor
9- from pytensor import In , config , function
9+ from pytensor import In , config , function , scan
1010from pytensor .compile import get_default_mode , get_mode
1111from pytensor .gradient import grad
1212from pytensor .graph import Apply , Op
13- from pytensor .graph .replace import vectorize_node
13+ from pytensor .graph .replace import vectorize_graph , vectorize_node
1414from pytensor .raise_op import assert_op
1515from pytensor .tensor import diagonal , dmatrix , log , ones_like , scalar , tensor , vector
1616from pytensor .tensor .blockwise import Blockwise , vectorize_node_fallback
@@ -162,13 +162,13 @@ def perform(self, *args, **kwargs):
162162 raise NotImplementedError ("Test Op should not be present in final graph" )
163163
164164
165- test_op = MyTestOp ()
165+ my_test_op = MyTestOp ()
166166
167167
168168def test_vectorize_node_default_signature ():
169169 vec = tensor (shape = (None ,))
170170 mat = tensor (shape = (5 , None ))
171- node = test_op .make_node (vec , mat )
171+ node = my_test_op .make_node (vec , mat )
172172
173173 vect_node = vectorize_node (node , mat , mat )
174174 assert isinstance (vect_node .op , Blockwise ) and isinstance (
@@ -179,9 +179,9 @@ def test_vectorize_node_default_signature():
179179 with pytest .raises (
180180 ValueError , match = "Signature not provided nor found in core_op MyTestOp"
181181 ):
182- Blockwise (test_op )
182+ Blockwise (my_test_op )
183183
184- vect_node = Blockwise (test_op , signature = "(m),(n)->(m),(n)" ).make_node (vec , mat )
184+ vect_node = Blockwise (my_test_op , signature = "(m),(n)->(m),(n)" ).make_node (vec , mat )
185185 assert vect_node .outputs [0 ].type .shape == (
186186 5 ,
187187 None ,
@@ -198,7 +198,7 @@ def test_blockwise_shape():
198198 inp_test = np .zeros ((5 , 4 , 3 ), dtype = config .floatX )
199199
200200 # Shape can be inferred from inputs
201- op = Blockwise (test_op , signature = "(m, n) -> (n, m)" )
201+ op = Blockwise (my_test_op , signature = "(m, n) -> (n, m)" )
202202 out = op (inp )
203203 assert out .type .shape == (5 , None , None )
204204
@@ -210,7 +210,7 @@ def test_blockwise_shape():
210210 assert tuple (shape_fn (inp_test )) == (5 , 3 , 4 )
211211
212212 # Shape can only be partially inferred from inputs
213- op = Blockwise (test_op , signature = "(m, n) -> (m, k)" )
213+ op = Blockwise (my_test_op , signature = "(m, n) -> (m, k)" )
214214 out = op (inp )
215215 assert out .type .shape == (5 , None , None )
216216
@@ -233,7 +233,7 @@ def test_blockwise_shape():
233233 inp1_test = np .zeros ((7 , 1 , 4 , 3 ), dtype = config .floatX )
234234 inp2_test = np .zeros ((1 , 5 , 4 , 3 ), dtype = config .floatX )
235235
236- op = Blockwise (test_op , signature = "(m, n), (m, n) -> (n, m), (m, k)" )
236+ op = Blockwise (my_test_op , signature = "(m, n), (m, n) -> (n, m), (m, k)" )
237237 outs = op (inp1 , inp2 )
238238 assert outs [0 ].type .shape == (7 , 5 , None , None )
239239 assert outs [1 ].type .shape == (7 , 5 , None , None )
@@ -650,3 +650,51 @@ def L_op(self, inputs, outputs, output_gradients):
650650 np .ones (12 , dtype = config .floatX ),
651651 strict = True ,
652652 )
653+
654+
def test_blockwise_grad_core_type():
    """Blockwise gradients must present core (unbatched) types to the core op's L_op."""

    class StrictCoreTypeOp(Op):
        def make_node(self, x):
            # The core op only accepts inputs whose last dim is statically 2.
            assert x.type.shape[-1] == 2
            return Apply(self, [x], [x.type()])

        def perform(self, node, inputs, output_storage):
            output_storage[0][0] = inputs[0] + 1

        def L_op(self, inputs, outputs, output_grads):
            [x] = inputs
            # The gradient must be built on the core shape, not the batched one.
            assert x.type.shape == (2,)
            return [x.zeros_like()]

    core_op = StrictCoreTypeOp()
    blocked_op = Blockwise(core_op, signature="(a)->(a)")

    x = tensor("x", shape=(5, 2), dtype="float64")
    y = blocked_op(x)
    assert y.type.shape == (5, 2)

    grad_y = grad(y.sum(), x)
    assert grad_y.type.shape == (5, 2)
    np.testing.assert_allclose(
        grad_y.eval({x: np.ones((5, 2))}),
        np.zeros((5, 2)),
    )
682+
683+
def test_scan_gradient_core_type():
    """Gradient through a vectorized Scan graph should evaluate correctly."""
    n_steps = 3
    seq = tensor("seq", shape=(n_steps, 1), dtype="float64")

    def identity_step(s):
        # Identity recurrence: each output equals the sequence element.
        return s

    out, _ = scan(identity_step, sequences=[seq], n_steps=n_steps)

    # Batch the whole graph along a new leading dimension.
    vec_seq = tensor("vec_seq", shape=(None, n_steps, 1), dtype="float64")
    vec_out = vectorize_graph(out, replace={seq: vec_seq})
    grad_sit_sot0 = grad(vec_out.sum(), vec_seq)

    batch_ones = np.ones((4, n_steps, 1))
    np.testing.assert_allclose(
        grad_sit_sot0.eval({vec_seq: batch_ones}),
        batch_ones,
    )
0 commit comments