
Rewrite mit-mot Scans as sit-sot #1687

@ricardoV94

Description

Scan always uses mit-mot for reverse-mode autodiff. This is the most general approach, as it allows an arbitrary connection pattern between the intermediate states and the cost. However, users often select only the last state, in which case the mit-mot performs a useless read/add of zeros at each step (since all but the last step are disconnected).
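To make the waste concrete, here is a pure-Python sketch of the two accumulation strategies (a conceptual illustration, not PyTensor's actual implementation): the mit-mot style reads a tape of per-step output gradients that is all zeros except for the entry of the last state, while a sit-sot style only carries the running gradient.

```python
# Conceptual sketch (not PyTensor's implementation) of reverse-mode
# accumulation for the recurrence x_{t+1} = x_t ** 2 when the cost
# only uses the last state.

def backward_mit_mot(xs, seed_tape):
    """Read a per-step output-gradient tape and add it in every step.
    When only the last state is selected, all but the first (reversed)
    entry are zero, so the adds are wasted work."""
    g = seed_tape[0]  # gradient seed for the last state
    for x, seed in zip(reversed(xs[:-1]), seed_tape[1:]):
        g = 2 * x * g + seed  # seed is 0.0 on every step here
    return g

def backward_sit_sot(xs, g_last):
    """Carry only the running gradient; no tape reads."""
    g = g_last
    for x in reversed(xs[:-1]):
        g = 2 * x * g
    return g

x0 = 0.95
xs = [x0]
for _ in range(4):  # forward pass
    xs.append(xs[-1] ** 2)

seed_tape = [1.0, 0.0, 0.0, 0.0, 0.0]  # same shape as the mit-mot input below
assert backward_mit_mot(xs, seed_tape) == backward_sit_sot(xs, 1.0)
```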

Here is an example:

import pytensor
import pytensor.tensor as pt

x0 = pt.scalar("x0")
xs, _ = pytensor.scan(lambda x: x ** 2, outputs_info=[x0], n_steps=4)
g = pt.grad(xs[-1], x0)
pytensor.function([x0], g).dprint(print_shape=True)
Print results

Sum{axes=None} [id A] shape=() 12
 └─ Subtensor{start:stop:step} [id B] shape=(?,) 11
    ├─ Scan{grad_of_scan_fn, while_loop=False, inplace=all} [id C] shape=(?,) 10
    │  ├─ 4 [id D] shape=()
    │  ├─ Subtensor{start:stop:step} [id E] shape=(?,) 9
    │  │  ├─ Scan{scan_fn, while_loop=False, inplace=all} [id F] shape=(?,) 7
    │  │  │  ├─ 3 [id G] shape=()
    │  │  │  └─ SetSubtensor{:stop} [id H] shape=(4,) 5
    │  │  │     ├─ AllocEmpty{dtype='float64'} [id I] shape=(4,) 0
    │  │  │     │  └─ 4 [id D] shape=()
    │  │  │     ├─ ExpandDims{axis=0} [id J] shape=(1,) 3
    │  │  │     │  └─ x0 [id K] shape=()
    │  │  │     └─ 1 [id L] shape=()
    │  │  ├─ 3 [id M] shape=()
    │  │  ├─ -5 [id N] shape=()
    │  │  └─ -1 [id O] shape=()
    │  └─ Subtensor{::step} [id P] shape=(?,) 8
    │     ├─ IncSubtensor{start:} [id Q] shape=(5,) 6
    │     │  ├─ Alloc [id R] shape=(5,) 2
    │     │  │  ├─ [0.] [id S] shape=(1,)
    │     │  │  └─ 5 [id T] shape=()
    │     │  ├─ IncSubtensor{i} [id U] shape=(4,) 4
    │     │  │  ├─ Alloc [id V] shape=(4,) 1
    │     │  │  │  ├─ [0.] [id S] shape=(1,)
    │     │  │  │  └─ 4 [id D] shape=()
    │     │  │  ├─ 1.0 [id W] shape=()
    │     │  │  └─ -1 [id O] shape=()
    │     │  └─ 1 [id L] shape=()
    │     └─ -1 [id O] shape=()
    ├─ 4 [id X] shape=()
    ├─ 3 [id M] shape=()
    └─ -1 [id O] shape=()
Inner graphs:
Scan{grad_of_scan_fn, while_loop=False, inplace=all} [id C]
 ← Composite{((2.0 * i1 * i2) + i0)} [id Y] shape=()
    ├─ *2-<Scalar(float64, shape=())> [id Z] shape=() -> [id P]
    ├─ *1-<Scalar(float64, shape=())> [id BA] shape=() -> [id P]
    └─ *0-<Scalar(float64, shape=())> [id BB] shape=() -> [id E]
Scan{scan_fn, while_loop=False, inplace=all} [id F]
 ← Sqr [id BC] shape=()
    └─ *0-<Scalar(float64, shape=())> [id BB] shape=() -> [id H]

The MIT-MOT input looks like this:

    │  └─ Subtensor{::step} [id P] shape=(?,) 8
    │     ├─ IncSubtensor{start:} [id Q] shape=(5,) 6
    │     │  ├─ Alloc [id R] shape=(5,) 2
    │     │  │  ├─ [0.] [id S] shape=(1,)
    │     │  │  └─ 5 [id T] shape=()
    │     │  ├─ IncSubtensor{i} [id U] shape=(4,) 4
    │     │  │  ├─ Alloc [id V] shape=(4,) 1
    │     │  │  │  ├─ [0.] [id S] shape=(1,)
    │     │  │  │  └─ 4 [id D] shape=()
    │     │  │  ├─ 1.0 [id W] shape=()
    │     │  │  └─ -1 [id O] shape=()
    │     │  └─ 1 [id L] shape=()
    │     └─ -1 [id O] shape=()

This is cleaned up a bit by #1666, but if we read it carefully (or evaluate it), we see it's just [1., 0., 0., 0., 0.]:

from pytensor.scan.op import Scan
from pytensor.graph.traversal import apply_ancestors

grad_scan = next(n for n in apply_ancestors([g]) if isinstance(n.op, Scan))
n_steps, forward_seq, _, mit_mot = grad_scan.inputs
mit_mot.eval({x0: 0.95})  # array([1., 0., 0., 0., 0.])

A SIT-SOT should be more performant, as it requires materializing and reading only the last updated state rather than the whole gradient tape (after the scan save-memory rewrite, that is):

equiv_scan_with_x0_masked, _ = pytensor.scan(
    lambda s, g_out: 2 * s * g_out, 
    sequences=[forward_seq],
    # Here we would put whatever the gradient at the last step is
    # It's one in our case
    outputs_info=[x0.ones_like()],
    n_steps=n_steps,
)
equiv_scan = equiv_scan_with_x0_masked[0].owner.inputs[0].owner.inputs[0]
equiv_scan.eval({x0: 0.95}), grad_scan.out.eval({x0: 0.95})
# (array([1.        , 1.32684086, 2.16144035, 3.90139983, 7.41265968]),
#  array([1.        , 1.32684086, 2.16144035, 3.90139983, 7.41265968]))
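As a sanity check, the gradient trajectory above can be reproduced in plain Python from the recurrence alone (no PyTensor involved):

```python
# Reproduce the gradient trajectory by unrolling g_{k+1} = 2 * x_t * g_k
# over the reversed forward states of x_{t+1} = x_t ** 2.
x0 = 0.95
xs = [x0]
for _ in range(4):           # forward pass
    xs.append(xs[-1] ** 2)

gs = [1.0]                   # gradient of the last state w.r.t. itself
for x in reversed(xs[:-1]):  # backward pass
    gs.append(2 * x * gs[-1])

print([round(g, 8) for g in gs])
# [1.0, 1.32684086, 2.16144035, 3.90139983, 7.41265968]
```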
