Skip to content

Commit e7886cc

Browse files
committed
Start using new API in tests that don't involve shared updates
1 parent 207735f commit e7886cc

File tree

6 files changed

+505
-266
lines changed

6 files changed

+505
-266
lines changed

tests/link/jax/test_scan.py

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,11 @@
2323
@pytest.mark.parametrize("view", [None, (-1,), slice(-2, None, None)])
2424
def test_scan_sit_sot(view):
2525
x0 = pt.scalar("x0", dtype="float64")
26-
xs, _ = scan(
26+
xs = scan(
2727
lambda xtm1: xtm1 + 1,
2828
outputs_info=[x0],
2929
n_steps=10,
30+
return_updates=False,
3031
)
3132
if view:
3233
xs = xs[view]
@@ -37,10 +38,11 @@ def test_scan_sit_sot(view):
3738
@pytest.mark.parametrize("view", [None, (-1,), slice(-4, -1, None)])
3839
def test_scan_mit_sot(view):
3940
x0 = pt.vector("x0", dtype="float64", shape=(3,))
40-
xs, _ = scan(
41+
xs = scan(
4142
lambda xtm3, xtm1: xtm3 + xtm1 + 1,
4243
outputs_info=[{"initial": x0, "taps": [-3, -1]}],
4344
n_steps=10,
45+
return_updates=False,
4446
)
4547
if view:
4648
xs = xs[view]
@@ -57,13 +59,14 @@ def test_scan_multiple_mit_sot(view_x, view_y):
5759
def step(xtm3, xtm1, ytm4, ytm2):
5860
return xtm3 + ytm4 + 1, xtm1 + ytm2 + 2
5961

60-
[xs, ys], _ = scan(
62+
[xs, ys] = scan(
6163
fn=step,
6264
outputs_info=[
6365
{"initial": x0, "taps": [-3, -1]},
6466
{"initial": y0, "taps": [-4, -2]},
6567
],
6668
n_steps=10,
69+
return_updates=False,
6770
)
6871
if view_x:
6972
xs = xs[view_x]
@@ -80,10 +83,8 @@ def test_scan_nit_sot(view):
8083

8184
xs = pt.vector("x0", dtype="float64", shape=(10,))
8285

83-
ys, _ = scan(
84-
lambda x: pt.exp(x),
85-
outputs_info=[None],
86-
sequences=[xs],
86+
ys = scan(
87+
lambda x: pt.exp(x), outputs_info=[None], sequences=[xs], return_updates=False
8788
)
8889
if view:
8990
ys = ys[view]
@@ -106,11 +107,12 @@ def step(xtm1, ytm3, ytm1, rho):
106107
rho = pt.scalar("rho", dtype="float64")
107108
x0 = pt.vector("xs", shape=(2,))
108109
y0 = pt.vector("ys", shape=(3,))
109-
[outs, _], _ = scan(
110+
[outs, _] = scan(
110111
step,
111112
outputs_info=[x0, {"initial": y0, "taps": [-3, -1]}],
112113
non_sequences=[rho],
113114
n_steps=10,
115+
return_updates=False,
114116
)
115117
grads = pt.grad(outs.sum(), wrt=[x0, y0, rho])
116118
compare_jax_and_py(
@@ -191,10 +193,11 @@ def update_fn(rng):
191193

192194
@pytest.mark.xfail(raises=NotImplementedError)
193195
def test_scan_while():
194-
xs, _ = scan(
196+
xs = scan(
195197
lambda x: (x + 1, until(x < 10)),
196198
outputs_info=[pt.zeros(())],
197199
n_steps=100,
200+
return_updates=False,
198201
)
199202

200203
compare_jax_and_py([], [xs], [])
@@ -210,7 +213,7 @@ def input_step_fn(y_tm1, y_tm3, a):
210213
res.name = "y_t"
211214
return res
212215

213-
y_scan_pt, _ = scan(
216+
y_scan_pt = scan(
214217
fn=input_step_fn,
215218
outputs_info=[
216219
{
@@ -223,6 +226,7 @@ def input_step_fn(y_tm1, y_tm3, a):
223226
non_sequences=[a_pt],
224227
n_steps=10,
225228
name="y_scan",
229+
return_updates=False,
226230
)
227231
y_scan_pt.name = "y"
228232
y_scan_pt.owner.inputs[0].name = "y_all"
@@ -241,11 +245,12 @@ def test_nd_scan_sit_sot(x0_func, A_func):
241245
k = 3
242246

243247
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
244-
xs, _ = scan(
248+
xs = scan(
245249
lambda X, A: A @ X,
246250
non_sequences=[A],
247251
outputs_info=[x0],
248252
n_steps=n_steps,
253+
return_updates=False,
249254
)
250255

251256
x0_val = (
@@ -267,11 +272,12 @@ def test_nd_scan_sit_sot_with_seq():
267272
A = pt.matrix("A", shape=(k, k))
268273

269274
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
270-
xs, _ = scan(
275+
xs = scan(
271276
lambda X, A: A @ X,
272277
non_sequences=[A],
273278
sequences=[x],
274279
n_steps=n_steps,
280+
return_updates=False,
275281
)
276282

277283
x_val = np.arange(n_steps * k, dtype=config.floatX).reshape(n_steps, k)
@@ -287,11 +293,12 @@ def test_nd_scan_mit_sot():
287293
B = pt.matrix("B", shape=(3, 3))
288294

289295
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
290-
xs, _ = scan(
296+
xs = scan(
291297
lambda xtm3, xtm1, A, B: A @ xtm3 + B @ xtm1,
292298
outputs_info=[{"initial": x0, "taps": [-3, -1]}],
293299
non_sequences=[A, B],
294300
n_steps=10,
301+
return_updates=False,
295302
)
296303

297304
x0_val = np.arange(9, dtype=config.floatX).reshape(3, 3)
@@ -310,12 +317,13 @@ def step(x, A):
310317
return A @ x, x.sum()
311318

312319
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
313-
xs, _ = scan(
320+
xs = scan(
314321
step,
315322
outputs_info=[x0, None],
316323
non_sequences=[A],
317324
n_steps=10,
318325
mode=get_mode("JAX"),
326+
return_updates=False,
319327
)
320328

321329
x0_val = np.arange(3, dtype=config.floatX)
@@ -329,7 +337,13 @@ def test_default_mode_excludes_incompatible_rewrites():
329337
# See issue #426
330338
A = matrix("A")
331339
B = matrix("B")
332-
out, _ = scan(lambda a, b: a @ b, outputs_info=[A], non_sequences=[B], n_steps=2)
340+
out = scan(
341+
lambda a, b: a @ b,
342+
outputs_info=[A],
343+
non_sequences=[B],
344+
n_steps=2,
345+
return_updates=False,
346+
)
333347
compare_jax_and_py([A, B], [out], [np.eye(3), np.eye(3)], jax_mode="JAX")
334348

335349

@@ -353,8 +367,11 @@ def _(op, **kwargs):
353367

354368
x = pt.tensor("x", shape=(None, 3))
355369

356-
out, _ = scan(
357-
lambda x: inc_without_static_shape(x), outputs_info=[None], sequences=[x]
370+
out = scan(
371+
lambda x: inc_without_static_shape(x),
372+
outputs_info=[None],
373+
sequences=[x],
374+
return_updates=False,
358375
)
359376
f = function([x], out, mode=get_mode("JAX").excluding("scan"))
360377
assert sum(isinstance(node.op, Scan) for node in f.maker.fgraph.apply_nodes) == 1
@@ -364,10 +381,11 @@ def _(op, **kwargs):
364381
np.testing.assert_allclose(f(np.zeros((0, 3))), np.empty((0, 3)))
365382

366383
# With known static shape we should always manage, regardless of the internal implementation
367-
out2, _ = scan(
384+
out2 = scan(
368385
lambda x: pt.specify_shape(inc_without_static_shape(x), x.shape),
369386
outputs_info=[None],
370387
sequences=[x],
388+
return_updates=False,
371389
)
372390
f2 = function([x], out2, mode=get_mode("JAX").excluding("scan"))
373391
np.testing.assert_allclose(f2([[1, 2, 3]]), np.array([[2, 3, 4]]))
@@ -418,11 +436,12 @@ def seir_one_step(ct0, dt0, st0, et0, it0, beta, gamma, delta):
418436
it1 = it0 + ct0 - dt0
419437
return st1, et1, it1, logp_c1, logp_d1
420438

421-
(st, et, it, logp_c_all, logp_d_all), _ = scan(
439+
(st, et, it, logp_c_all, logp_d_all) = scan(
422440
fn=seir_one_step,
423441
sequences=[C_t, D_t],
424442
outputs_info=[st0, et0, it0, None, None],
425443
non_sequences=[beta, gamma, delta],
444+
return_updates=False,
426445
)
427446
st.name = "S_t"
428447
et.name = "E_t"
@@ -511,11 +530,12 @@ def cycle_step(A0, A1, A2, A1_hat, _norm, step_num):
511530
max_iter = 100
512531
tol = 1e-7
513532

514-
(*_, A1_hat, norm, _n_steps), _ = scan(
533+
(*_, A1_hat, norm, _n_steps) = scan(
515534
step,
516535
outputs_info=[A, B, C, B, norm, step_num],
517536
non_sequences=[tol],
518537
n_steps=max_iter,
538+
return_updates=False,
519539
)
520540
A1_hat = A1_hat[-1]
521541

tests/link/numba/test_scan.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -206,11 +206,12 @@ def seir_one_step(ct0, dt0, st0, et0, it0, logp_c, logp_d, beta, gamma, delta):
206206
it1 = it0 + ct0 - dt0
207207
return st1, et1, it1, logp_c1, logp_d1
208208

209-
(st, et, it, logp_c_all, logp_d_all), _ = scan(
209+
(st, et, it, logp_c_all, logp_d_all) = scan(
210210
fn=seir_one_step,
211211
sequences=[pt_C, pt_D],
212212
outputs_info=[st0, et0, it0, logp_c, logp_d],
213213
non_sequences=[beta, gamma, delta],
214+
return_updates=False,
214215
)
215216
st.name = "S_t"
216217
et.name = "E_t"
@@ -268,7 +269,7 @@ def input_step_fn(b, b2, c, x_tm1, y_tm1, y_tm3, a):
268269
y_t.name = "y_t"
269270
return x_t, y_t, pt.fill((10,), z_t)
270271

271-
scan_res, _ = scan(
272+
scan_res = scan(
272273
fn=input_step_fn,
273274
sequences=[
274275
{
@@ -297,6 +298,7 @@ def input_step_fn(b, b2, c, x_tm1, y_tm1, y_tm3, a):
297298
n_steps=5,
298299
name="yz_scan",
299300
strict=True,
301+
return_updates=False,
300302
)
301303

302304
test_input_vals = [
@@ -312,11 +314,12 @@ def power_of_2(previous_power, max_value):
312314
return previous_power * 2, until(previous_power * 2 > max_value)
313315

314316
max_value = pt.scalar()
315-
values, _ = scan(
317+
values = scan(
316318
power_of_2,
317319
outputs_info=pt.constant(1.0),
318320
non_sequences=max_value,
319321
n_steps=1024,
322+
return_updates=False,
320323
)
321324

322325
test_input_vals = [
@@ -331,20 +334,25 @@ def test_scan_multiple_none_output():
331334
def power_step(prior_result, x):
332335
return prior_result * x, prior_result * x * x, prior_result * x * x * x
333336

334-
result, _ = scan(
337+
result = scan(
335338
power_step,
336339
non_sequences=[A],
337340
outputs_info=[pt.ones_like(A), None, None],
338341
n_steps=3,
342+
return_updates=False,
339343
)
340344
test_input_vals = (np.array([1.0, 2.0]),)
341345
compare_numba_and_py([A], result, test_input_vals)
342346

343347

344348
def test_grad_sitsot():
345349
def get_sum_of_grad(inp):
346-
scan_outputs, _updates = scan(
347-
fn=lambda x: x * 2, outputs_info=[inp], n_steps=5, mode="NUMBA"
350+
scan_outputs = scan(
351+
fn=lambda x: x * 2,
352+
outputs_info=[inp],
353+
n_steps=5,
354+
mode="NUMBA",
355+
return_updates=False,
348356
)
349357
return grad(scan_outputs.sum(), inp).sum()
350358

@@ -362,8 +370,11 @@ def test_mitmots_basic():
362370
def inner_fct(seq, state_old, state_current):
363371
return state_old * 2 + state_current + seq
364372

365-
out, _ = scan(
366-
inner_fct, sequences=seq, outputs_info={"initial": init_x, "taps": [-2, -1]}
373+
out = scan(
374+
inner_fct,
375+
sequences=seq,
376+
outputs_info={"initial": init_x, "taps": [-2, -1]},
377+
return_updates=False,
367378
)
368379

369380
g_outs = grad(out.sum(), [seq, init_x])
@@ -383,10 +394,11 @@ def inner_fct(seq, state_old, state_current):
383394
def test_inner_graph_optimized():
384395
"""Test that inner graph of Scan is optimized"""
385396
xs = vector("xs")
386-
seq, _ = scan(
397+
seq = scan(
387398
fn=lambda x: log(1 + x),
388399
sequences=[xs],
389400
mode=get_mode("NUMBA"),
401+
return_updates=False,
390402
)
391403

392404
# Disable scan pushout, in which case the whole scan is replaced by an Elemwise
@@ -421,13 +433,14 @@ def step(seq1, seq2, mitsot1, mitsot2, sitsot1):
421433
sitsot2 = (sitsot1 + mitsot3) / np.sqrt(2)
422434
return mitsot3, sitsot2
423435

424-
outs, _ = scan(
436+
outs = scan(
425437
fn=step,
426438
sequences=[seq1, seq2],
427439
outputs_info=[
428440
dict(initial=mitsot_init, taps=[-2, -1]),
429441
dict(initial=sitsot_init, taps=[-1]),
430442
],
443+
return_updates=False,
431444
)
432445

433446
rng = np.random.default_rng(474)
@@ -468,7 +481,7 @@ def step(ztm3, ztm1, xtm1, ytm1, ytm2, a):
468481
y = ytm1 + 1 + ytm2 + a
469482
return z, x, z + x + y, y
470483

471-
[zs, xs, ws, ys], _ = scan(
484+
[zs, xs, ws, ys] = scan(
472485
fn=step,
473486
outputs_info=[
474487
dict(initial=z0, taps=[-3, -1]),
@@ -478,6 +491,7 @@ def step(ztm3, ztm1, xtm1, ytm1, ytm2, a):
478491
],
479492
non_sequences=[a],
480493
n_steps=n_steps,
494+
return_updates=False,
481495
)
482496
numba_fn, _ = compare_numba_and_py(
483497
[n_steps] * (not n_steps_constant) + [a, x0, y0, z0],
@@ -529,10 +543,11 @@ def step(ztm3, ztm1, xtm1, ytm1, ytm2, a):
529543
class TestScanSITSOTBuffer:
530544
def buffer_tester(self, n_steps, op_size, buffer_size, benchmark=None):
531545
x0 = pt.vector(shape=(op_size,), dtype="float64")
532-
xs, _ = pytensor.scan(
546+
xs = pytensor.scan(
533547
fn=lambda xtm1: (xtm1 + 1),
534548
outputs_info=[x0],
535549
n_steps=n_steps - 1, # 1- makes it easier to align/misalign
550+
return_updates=False,
536551
)
537552
if buffer_size == "unit":
538553
xs_kept = xs[-1] # Only last state is used
@@ -588,12 +603,13 @@ def f_pow2(x_tm2, x_tm1):
588603

589604
init_x = pt.vector("init_x", shape=(2,))
590605
n_steps = pt.iscalar("n_steps")
591-
output, _ = scan(
606+
output = scan(
592607
f_pow2,
593608
sequences=[],
594609
outputs_info=[{"initial": init_x, "taps": [-2, -1]}],
595610
non_sequences=[],
596611
n_steps=n_steps_val if constant_n_steps else n_steps,
612+
return_updates=False,
597613
)
598614

599615
init_x_val = np.array([1.0, 2.0], dtype=init_x.type.dtype)

0 commit comments

Comments
 (0)