@ricardoV94 ricardoV94 commented Oct 29, 2025

This PR allows explicit untraceable entries in outputs_info in Scan, and deprecates (for developers only, for now) the whole shared updates shenanigans.

This way one can pass an RNG as an explicit outputs_info, and get the final value explicitly. The reason this is untraceable and has special logic is that we can't concatenate each intermediate state in a numpy array.

Well, we could probably use a numpy object array or a TypedList, and may want to in the future, but for now I just want to start deprecating the whole updates complexity in Scan.

Note that internally the functionality was already there. This PR is using exactly the same old shared_outs machinery (now renamed everywhere to untraced_sit_sot), which allows an output to be carried without trying to place or read it from an array with more dimensions.

We should actually always use this machinery when only the last state is needed for tensor variables as well. But that's for another day.
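To make the traced versus untraced distinction concrete, here is a minimal pure-Python sketch. It is not PyTensor's implementation: `toy_scan` and `step` are hypothetical names, and Python's `random.Random` stands in for a generator whose intermediate states cannot be stacked into an array. One difference to keep in mind is that `random.Random` is mutated in place, whereas PyTensor returns a new RNG variable at every step.

```python
import random


def toy_scan(step, traced_init, untraced_init, n_steps):
    """Toy loop: stack every traced state, keep only the latest untraced state."""
    trace = []  # plays the role of the stacked output buffer
    traced, untraced = traced_init, untraced_init
    for _ in range(n_steps):
        traced, untraced = step(traced, untraced)
        trace.append(traced)  # traced state: one entry per step
        # untraced state: simply carried forward, never written to the buffer
    return trace, untraced


def step(prev_x, prev_rng):
    # Draw around the previous value; the generator advances as a side effect.
    next_x = prev_rng.gauss(prev_x, 1.0)
    return next_x, prev_rng


xs, rng_final = toy_scan(step, 0.0, random.Random(0), n_steps=10)
print(len(xs), type(rng_final).__name__)
```

The caller gets the full history of `xs` but only the final generator, which mirrors what the PR describes for an RNG passed through outputs_info.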

The following code was impossible to write before: any RNG that we wished to update in a scan had to be a shared variable.

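    import pytensor.tensor as pt
    from pytensor.scan.basic import scan
    from pytensor.tensor.random.type import RandomGeneratorType, random_generator_type
    from pytensor.tensor.type import TensorType

    # Symbolic (non-shared) RNG input
    rng_init = random_generator_type("rng")
    # A random variable node outputs (next_rng, draw)
    rng_x0, x0 = pt.random.normal(0, rng=rng_init).owner.outputs

    def step(prev_x, prev_rng):
        next_rng, next_x = pt.random.normal(prev_x, rng=prev_rng).owner.outputs
        return next_x, next_rng

    [xs, rng_final], updates = scan(
        fn=step,
        outputs_info=[x0, rng_x0],
        n_steps=10,
    )
    # xs is traced (stacked over steps); only the final RNG state is returned
    assert isinstance(xs.type, TensorType)
    assert isinstance(rng_final.type, RandomGeneratorType)
    # No shared variables are involved, so there are no updates
    assert not updates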

As we did for OpFromGraph, there should be no concept of SharedVariables in regular Ops, just outside of a PyTensor function. This PR moves in that direction.

ricardoV94 commented Oct 29, 2025

As discussed in #1706, I think the API for the example I posted in this PR should be:

    rng_init = random_generator_type("rng")
    rng_x0, x0 = rng_init.normal()

    def step(prev_x, prev_rng):
        next_rng, next_x = prev_rng.normal(prev_x)
        return next_x, next_rng

    [xs, rng_final], updates = scan(
        fn=step,
        outputs_info=[x0, rng_x0],
        n_steps=10,
    )
    assert isinstance(xs.type, TensorType)
    assert isinstance(rng_final.type, RandomGeneratorType)
    assert not updates

And eventually we'll remove the whole updates thing from pytensor.scan.
This is outside the scope of this PR, but this functionality is kind of a prerequisite for the new API to play nicely.

@ricardoV94 ricardoV94 force-pushed the scan_rng_outputs_info branch from 484a9c8 to 64a08e7 Compare October 30, 2025 15:32
@ricardoV94 ricardoV94 force-pushed the scan_rng_outputs_info branch from 64a08e7 to 0a23cac Compare October 30, 2025 15:33
@ricardoV94 ricardoV94 marked this pull request as ready for review October 30, 2025 15:34
ricardoV94 (Member Author) commented

I went ahead and started deprecating the updates API. For now you have to opt in by passing return_updates=False, in which case scan will not return the updates dict (and raises if it wouldn't be empty). The warning is a DeprecationWarning, so users won't see it by default. It's meant to start changing things under the hood in the ecosystem.
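The transitional return convention described above can be sketched with a small stand-alone function. This is only an illustration under assumptions: `scan_like` is a hypothetical name, the warning message is made up, and the thread does not say which exception type scan raises, so `ValueError` is a placeholder.

```python
import warnings


def scan_like(outputs, updates, return_updates=True):
    """Hypothetical wrapper mimicking the transitional return convention."""
    if return_updates:
        # Old behaviour: still works, but nudges developers towards the new API.
        warnings.warn(
            "Returning updates is deprecated; pass return_updates=False",
            DeprecationWarning,
            stacklevel=2,
        )
        return outputs, updates
    if updates:
        # Assumption: the real exception type is not stated in the thread.
        raise ValueError("non-empty updates cannot be dropped silently")
    return outputs


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # DeprecationWarning is usually filtered out
    old_style = scan_like(["xs"], {})

new_style = scan_like(["xs"], {}, return_updates=False)
```

With `return_updates=False` the caller unpacks only the outputs, so code written against the new convention keeps working once the updates dict is removed entirely.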

I envision we move to a FutureWarning and finally remove it. Before we do that, I think we need to implement #1707 to offer a viable alternative API to RandomStreams. It doesn't make sense to ask users to retrieve hidden updates with owner.outputs.

@ricardoV94 ricardoV94 force-pushed the scan_rng_outputs_info branch 4 times, most recently from f042578 to 7d3c8aa Compare October 31, 2025 11:46
@ricardoV94 ricardoV94 changed the title Allow non-shared untraced SIT-SOT in Scan Start deprecating shared updates functionality in Scan Oct 31, 2025
@ricardoV94 ricardoV94 force-pushed the scan_rng_outputs_info branch from 7d3c8aa to 58bd365 Compare October 31, 2025 12:11
codecov bot commented Oct 31, 2025

Codecov Report

❌ Patch coverage is 81.85841% with 41 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.71%. Comparing base (1f9a67b) to head (f8aa58b).
⚠️ Report is 1 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| pytensor/scan/op.py | 67.50% | 19 Missing and 7 partials ⚠️ |
| pytensor/scan/basic.py | 94.25% | 2 Missing and 3 partials ⚠️ |
| pytensor/scan/views.py | 55.55% | 2 Missing and 2 partials ⚠️ |
| pytensor/scan/checkpoints.py | 57.14% | 2 Missing and 1 partial ⚠️ |
| pytensor/scan/rewriting.py | 72.72% | 3 Missing ⚠️ |

❌ Your patch check has failed because the patch coverage (81.85%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

@@            Coverage Diff             @@
##             main    #1704      +/-   ##
==========================================
+ Coverage   81.70%   81.71%   +0.01%     
==========================================
  Files         246      246              
  Lines       53632    53668      +36     
  Branches     9438     9442       +4     
==========================================
+ Hits        43820    43855      +35     
- Misses       7330     7335       +5     
+ Partials     2482     2478       -4     

| Files with missing lines | Coverage Δ |
|---|---|
| pytensor/compile/function/pfunc.py | 83.41% <100.00%> (+0.08%) ⬆️ |
| pytensor/gradient.py | 77.99% <100.00%> (-0.03%) ⬇️ |
| pytensor/link/jax/dispatch/scan.py | 97.18% <100.00%> (ø) |
| ...sor/link/numba/dispatch/linalg/decomposition/lu.py | 66.66% <100.00%> (ø) |
| pytensor/link/numba/dispatch/scan.py | 96.02% <100.00%> (ø) |
| pytensor/scan/utils.py | 87.58% <100.00%> (ø) |
| pytensor/tensor/pad.py | 97.14% <100.00%> (ø) |
| pytensor/scan/checkpoints.py | 74.00% <57.14%> (-1.52%) ⬇️ |
| pytensor/scan/rewriting.py | 82.97% <72.72%> (ø) |
| pytensor/scan/views.py | 77.77% <55.55%> (-7.94%) ⬇️ |

... and 2 more

... and 2 files with indirect coverage changes

            array of row swaps, such that L[perm] @ U = A.
        """
    -   return linalg.lu(
    +   return linalg.lu(  # type: ignore[no-any-return]

Member

was this failing before?

ricardoV94 (Member Author) commented Nov 1, 2025

It's because I removed the scipy stubs (first commit), because they broke the run_mypy output. I'll open an issue to track.

- diff-cover
- mypy
- types-setuptools
- scipy-stubs
Member

why?

Member Author

broke run_mypy when there are errors

Copilot AI left a comment

Pull Request Overview

This PR introduces a significant API change to PyTensor's scan function to deprecate the two-return-value pattern (outputs, updates) in favor of returning only outputs when updates are empty. The changes also rename internal "shared" variables to "untraced_sit_sot" for better clarity. Key changes include:

  • Adding return_updates parameter to scan and related functions with default True for backward compatibility
  • Renaming internal n_shared_outs to n_untraced_sit_sot_outs throughout the codebase
  • Supporting non-stacking output types (like RNG variables) via the new untraced mechanism
  • Updating all tests to use return_updates=False where appropriate
  • Adding deprecation warnings for the old API and transitional logic

Reviewed Changes

Copilot reviewed 22 out of 22 changed files in this pull request and generated 5 comments.

Show a summary per file
| File | Description |
|---|---|
| pytensor/scan/basic.py | Core API changes: adds return_updates parameter, manages deprecation warnings, implements untraced sit_sot logic |
| pytensor/scan/op.py | Renames n_shared_outs → n_untraced_sit_sot_outs, adds deprecated property aliases |
| pytensor/scan/utils.py | Updates utility functions to use new naming convention |
| pytensor/scan/rewriting.py | Updates optimization passes to use new naming |
| pytensor/scan/views.py | Adds return_updates parameter to map/reduce/foldl/foldr |
| pytensor/scan/checkpoints.py | Adds return_updates parameter to scan_checkpoints |
| tests/scan/test_basic.py | Extensive test updates using return_updates=False |
| tests/scan/test_views.py | Parametrizes tests for both API modes |
| tests/scan/test_rewriting.py | Updates all scan calls to use return_updates=False |
| tests/scan/test_checkpoints.py | Parametrizes TestScanCheckpoint class |
| tests/tensor/test_blockwise.py | Minor test updates |
| tests/tensor/linalg/test_rewriting.py | Minor test updates |
| tests/link/numba/test_scan.py | Updates for numba backend compatibility |
| tests/link/jax/test_scan.py | Updates for JAX backend compatibility |
| pytensor/tensor/pad.py | Updates scan call |
| pytensor/gradient.py | Updates hessian function scan call |
| pytensor/compile/function/pfunc.py | Guards givens check with if statement |
| environment*.yml | Removes scipy-stubs dependency |
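
The "deprecated property aliases" mentioned for pytensor/scan/op.py follow a common renaming pattern, sketched below. This is not the code from the PR: `ScanInfoSketch` is a hypothetical class, and whether the real aliases emit a warning (and with what message) is an assumption. Only the two attribute names come from the summary above.

```python
import warnings


class ScanInfoSketch:
    """Hypothetical stand-in for a Scan metadata container (not PyTensor's class)."""

    def __init__(self, n_untraced_sit_sot_outs):
        self.n_untraced_sit_sot_outs = n_untraced_sit_sot_outs

    @property
    def n_shared_outs(self):
        # Old name kept as a read-only alias that points callers to the new one.
        warnings.warn(
            "n_shared_outs is deprecated, use n_untraced_sit_sot_outs",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.n_untraced_sit_sot_outs


info = ScanInfoSketch(n_untraced_sit_sot_outs=1)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # make the DeprecationWarning observable
    old_value = info.n_shared_outs
```

An alias like this lets downstream code that still reads the old attribute keep working during the transition while signalling the rename.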

@ricardoV94 ricardoV94 force-pushed the scan_rng_outputs_info branch from e7886cc to dd82ce7 Compare November 1, 2025 14:28
@ricardoV94 ricardoV94 force-pushed the scan_rng_outputs_info branch from dd82ce7 to f8aa58b Compare November 4, 2025 08:55