2323@pytest .mark .parametrize ("view" , [None , (- 1 ,), slice (- 2 , None , None )])
2424def test_scan_sit_sot (view ):
2525 x0 = pt .scalar ("x0" , dtype = "float64" )
26- xs , _ = scan (
26+ xs = scan (
2727 lambda xtm1 : xtm1 + 1 ,
2828 outputs_info = [x0 ],
2929 n_steps = 10 ,
30+ return_updates = False ,
3031 )
3132 if view :
3233 xs = xs [view ]
@@ -37,10 +38,11 @@ def test_scan_sit_sot(view):
3738@pytest .mark .parametrize ("view" , [None , (- 1 ,), slice (- 4 , - 1 , None )])
3839def test_scan_mit_sot (view ):
3940 x0 = pt .vector ("x0" , dtype = "float64" , shape = (3 ,))
40- xs , _ = scan (
41+ xs = scan (
4142 lambda xtm3 , xtm1 : xtm3 + xtm1 + 1 ,
4243 outputs_info = [{"initial" : x0 , "taps" : [- 3 , - 1 ]}],
4344 n_steps = 10 ,
45+ return_updates = False ,
4446 )
4547 if view :
4648 xs = xs [view ]
@@ -57,13 +59,14 @@ def test_scan_multiple_mit_sot(view_x, view_y):
5759 def step (xtm3 , xtm1 , ytm4 , ytm2 ):
5860 return xtm3 + ytm4 + 1 , xtm1 + ytm2 + 2
5961
60- [xs , ys ], _ = scan (
62+ [xs , ys ] = scan (
6163 fn = step ,
6264 outputs_info = [
6365 {"initial" : x0 , "taps" : [- 3 , - 1 ]},
6466 {"initial" : y0 , "taps" : [- 4 , - 2 ]},
6567 ],
6668 n_steps = 10 ,
69+ return_updates = False ,
6770 )
6871 if view_x :
6972 xs = xs [view_x ]
@@ -80,10 +83,8 @@ def test_scan_nit_sot(view):
8083
8184 xs = pt .vector ("x0" , dtype = "float64" , shape = (10 ,))
8285
83- ys , _ = scan (
84- lambda x : pt .exp (x ),
85- outputs_info = [None ],
86- sequences = [xs ],
86+ ys = scan (
87+ lambda x : pt .exp (x ), outputs_info = [None ], sequences = [xs ], return_updates = False
8788 )
8889 if view :
8990 ys = ys [view ]
@@ -106,11 +107,12 @@ def step(xtm1, ytm3, ytm1, rho):
106107 rho = pt .scalar ("rho" , dtype = "float64" )
107108 x0 = pt .vector ("xs" , shape = (2 ,))
108109 y0 = pt .vector ("ys" , shape = (3 ,))
109- [outs , _ ], _ = scan (
110+ [outs , _ ] = scan (
110111 step ,
111112 outputs_info = [x0 , {"initial" : y0 , "taps" : [- 3 , - 1 ]}],
112113 non_sequences = [rho ],
113114 n_steps = 10 ,
115+ return_updates = False ,
114116 )
115117 grads = pt .grad (outs .sum (), wrt = [x0 , y0 , rho ])
116118 compare_jax_and_py (
@@ -191,10 +193,11 @@ def update_fn(rng):
191193
192194@pytest .mark .xfail (raises = NotImplementedError )
193195def test_scan_while ():
194- xs , _ = scan (
196+ xs = scan (
195197 lambda x : (x + 1 , until (x < 10 )),
196198 outputs_info = [pt .zeros (())],
197199 n_steps = 100 ,
200+ return_updates = False ,
198201 )
199202
200203 compare_jax_and_py ([], [xs ], [])
@@ -210,7 +213,7 @@ def input_step_fn(y_tm1, y_tm3, a):
210213 res .name = "y_t"
211214 return res
212215
213- y_scan_pt , _ = scan (
216+ y_scan_pt = scan (
214217 fn = input_step_fn ,
215218 outputs_info = [
216219 {
@@ -223,6 +226,7 @@ def input_step_fn(y_tm1, y_tm3, a):
223226 non_sequences = [a_pt ],
224227 n_steps = 10 ,
225228 name = "y_scan" ,
229+ return_updates = False ,
226230 )
227231 y_scan_pt .name = "y"
228232 y_scan_pt .owner .inputs [0 ].name = "y_all"
@@ -241,11 +245,12 @@ def test_nd_scan_sit_sot(x0_func, A_func):
241245 k = 3
242246
243247 # Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
244- xs , _ = scan (
248+ xs = scan (
245249 lambda X , A : A @ X ,
246250 non_sequences = [A ],
247251 outputs_info = [x0 ],
248252 n_steps = n_steps ,
253+ return_updates = False ,
249254 )
250255
251256 x0_val = (
@@ -267,11 +272,12 @@ def test_nd_scan_sit_sot_with_seq():
267272 A = pt .matrix ("A" , shape = (k , k ))
268273
269274 # Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
270- xs , _ = scan (
275+ xs = scan (
271276 lambda X , A : A @ X ,
272277 non_sequences = [A ],
273278 sequences = [x ],
274279 n_steps = n_steps ,
280+ return_updates = False ,
275281 )
276282
277283 x_val = np .arange (n_steps * k , dtype = config .floatX ).reshape (n_steps , k )
@@ -287,11 +293,12 @@ def test_nd_scan_mit_sot():
287293 B = pt .matrix ("B" , shape = (3 , 3 ))
288294
289295 # Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
290- xs , _ = scan (
296+ xs = scan (
291297 lambda xtm3 , xtm1 , A , B : A @ xtm3 + B @ xtm1 ,
292298 outputs_info = [{"initial" : x0 , "taps" : [- 3 , - 1 ]}],
293299 non_sequences = [A , B ],
294300 n_steps = 10 ,
301+ return_updates = False ,
295302 )
296303
297304 x0_val = np .arange (9 , dtype = config .floatX ).reshape (3 , 3 )
@@ -310,12 +317,13 @@ def step(x, A):
310317 return A @ x , x .sum ()
311318
312319 # Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
313- xs , _ = scan (
320+ xs = scan (
314321 step ,
315322 outputs_info = [x0 , None ],
316323 non_sequences = [A ],
317324 n_steps = 10 ,
318325 mode = get_mode ("JAX" ),
326+ return_updates = False ,
319327 )
320328
321329 x0_val = np .arange (3 , dtype = config .floatX )
@@ -329,7 +337,13 @@ def test_default_mode_excludes_incompatible_rewrites():
329337 # See issue #426
330338 A = matrix ("A" )
331339 B = matrix ("B" )
332- out , _ = scan (lambda a , b : a @ b , outputs_info = [A ], non_sequences = [B ], n_steps = 2 )
340+ out = scan (
341+ lambda a , b : a @ b ,
342+ outputs_info = [A ],
343+ non_sequences = [B ],
344+ n_steps = 2 ,
345+ return_updates = False ,
346+ )
333347 compare_jax_and_py ([A , B ], [out ], [np .eye (3 ), np .eye (3 )], jax_mode = "JAX" )
334348
335349
@@ -353,8 +367,11 @@ def _(op, **kwargs):
353367
354368 x = pt .tensor ("x" , shape = (None , 3 ))
355369
356- out , _ = scan (
357- lambda x : inc_without_static_shape (x ), outputs_info = [None ], sequences = [x ]
370+ out = scan (
371+ lambda x : inc_without_static_shape (x ),
372+ outputs_info = [None ],
373+ sequences = [x ],
374+ return_updates = False ,
358375 )
359376 f = function ([x ], out , mode = get_mode ("JAX" ).excluding ("scan" ))
360377 assert sum (isinstance (node .op , Scan ) for node in f .maker .fgraph .apply_nodes ) == 1
@@ -364,10 +381,11 @@ def _(op, **kwargs):
364381 np .testing .assert_allclose (f (np .zeros ((0 , 3 ))), np .empty ((0 , 3 )))
365382
366383 # With known static shape we should always manage, regardless of the internal implementation
367- out2 , _ = scan (
384+ out2 = scan (
368385 lambda x : pt .specify_shape (inc_without_static_shape (x ), x .shape ),
369386 outputs_info = [None ],
370387 sequences = [x ],
388+ return_updates = False ,
371389 )
372390 f2 = function ([x ], out2 , mode = get_mode ("JAX" ).excluding ("scan" ))
373391 np .testing .assert_allclose (f2 ([[1 , 2 , 3 ]]), np .array ([[2 , 3 , 4 ]]))
@@ -418,11 +436,12 @@ def seir_one_step(ct0, dt0, st0, et0, it0, beta, gamma, delta):
418436 it1 = it0 + ct0 - dt0
419437 return st1 , et1 , it1 , logp_c1 , logp_d1
420438
421- (st , et , it , logp_c_all , logp_d_all ), _ = scan (
439+ (st , et , it , logp_c_all , logp_d_all ) = scan (
422440 fn = seir_one_step ,
423441 sequences = [C_t , D_t ],
424442 outputs_info = [st0 , et0 , it0 , None , None ],
425443 non_sequences = [beta , gamma , delta ],
444+ return_updates = False ,
426445 )
427446 st .name = "S_t"
428447 et .name = "E_t"
@@ -511,11 +530,12 @@ def cycle_step(A0, A1, A2, A1_hat, _norm, step_num):
511530 max_iter = 100
512531 tol = 1e-7
513532
514- (* _ , A1_hat , norm , _n_steps ), _ = scan (
533+ (* _ , A1_hat , norm , _n_steps ) = scan (
515534 step ,
516535 outputs_info = [A , B , C , B , norm , step_num ],
517536 non_sequences = [tol ],
518537 n_steps = max_iter ,
538+ return_updates = False ,
519539 )
520540 A1_hat = A1_hat [- 1 ]
521541
0 commit comments