Start deprecating shared updates functionality in Scan #1704
As discussed in #1706 I think the API for the example I posted in this issue should be:

    rng_init = random_generator_type("rng")
    rng_x0, x0 = rng_init.normal()

    def step(prev_x, prev_rng):
        next_rng, next_x = prev_rng.normal(prev_x)
        return next_x, next_rng

    [xs, rng_final], updates = scan(
        fn=step,
        outputs_info=[x0, rng_x0],
        n_steps=10,
    )
    assert isinstance(xs.type, TensorType)
    assert isinstance(rng_final.type, RandomGeneratorType)
    assert not updates

And eventually we'll remove the whole
Partially reverts d894350
return_steps has not been a thing for 14 years
I went ahead and started deprecating the updates API. For now you have to pass I envision we move to a FutureWarning and finally remove. Before we do that I think we need to implement #1707 to offer a viable alternative API to RandomStreams. It doesn't make sense to ask users to retrieve hidden updates with
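The staged deprecation described in this comment can be sketched in plain Python. This is a toy stand-in rather than PyTensor's implementation: `toy_scan` and its body are hypothetical, and only the `return_updates` name is taken from this PR.

```python
import warnings


def toy_scan(fn, init, n_steps, return_updates=True):
    """Hypothetical stand-in for scan: repeatedly applies `fn` to a carried value."""
    outputs = []
    state = init
    for _ in range(n_steps):
        state = fn(state)
        outputs.append(state)

    if return_updates:
        # Old calling convention: warn, then return the legacy (outputs, updates) pair.
        warnings.warn(
            "Returning updates is deprecated; pass return_updates=False.",
            DeprecationWarning,  # a later stage could switch this to FutureWarning
            stacklevel=2,
        )
        return outputs, {}
    # New calling convention: outputs only.
    return outputs


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # make sure the warning is recorded
    legacy_outputs, legacy_updates = toy_scan(lambda x: x * 2, 1, 3)
    new_outputs = toy_scan(lambda x: x * 2, 1, 3, return_updates=False)
```

By default Python ignores `DeprecationWarning` unless it is triggered by code running in `__main__`, while `FutureWarning` is shown to end users, which is why moving from one to the other widens the audience of the warning.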
Codecov Report
❌ Your patch check has failed because the patch coverage (81.85%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1704      +/-   ##
==========================================
+ Coverage   81.70%   81.71%   +0.01%
==========================================
  Files         246      246
  Lines       53632    53668      +36
  Branches     9438     9442       +4
==========================================
+ Hits        43820    43855      +35
- Misses       7330     7335       +5
+ Partials     2482     2478       -4
      array of row swaps, such that L[perm] @ U = A.
      """
-     return linalg.lu(
+     return linalg.lu(  # type: ignore[no-any-return]
was this failing before?
It's because I removed the scipy stubs (first commit), because they broke the run_mypy output. I'll open an issue to track.
    - diff-cover
    - mypy
    - types-setuptools
-   - scipy-stubs
why?
broke run_mypy when there are errors
Pull Request Overview
This PR introduces a significant API change to PyTensor's scan function to deprecate the two-return-value pattern (outputs, updates) in favor of returning only outputs when updates are empty. The changes also rename internal "shared" variables to "untraced_sit_sot" for better clarity.
Key changes include:
- Adding `return_updates` parameter to `scan` and related functions with default `True` for backward compatibility
- Renaming internal `n_shared_outs` to `n_untraced_sit_sot_outs` throughout the codebase
- Supporting non-stacking output types (like RNG variables) via the new untraced mechanism
- Updating all tests to use `return_updates=False` where appropriate
- Adding deprecation warnings for the old API and transitional logic
Reviewed Changes
Copilot reviewed 22 out of 22 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| pytensor/scan/basic.py | Core API changes: adds return_updates parameter, manages deprecation warnings, implements untraced sit_sot logic |
| pytensor/scan/op.py | Renames n_shared_outs → n_untraced_sit_sot_outs, adds deprecated property aliases |
| pytensor/scan/utils.py | Updates utility functions to use new naming convention |
| pytensor/scan/rewriting.py | Updates optimization passes to use new naming |
| pytensor/scan/views.py | Adds return_updates parameter to map/reduce/foldl/foldr |
| pytensor/scan/checkpoints.py | Adds return_updates parameter to scan_checkpoints |
| tests/scan/test_basic.py | Extensive test updates using return_updates=False |
| tests/scan/test_views.py | Parametrizes tests for both API modes |
| tests/scan/test_rewriting.py | Updates all scan calls to use return_updates=False |
| tests/scan/test_checkpoints.py | Parametrizes TestScanCheckpoint class |
| tests/tensor/test_blockwise.py | Minor test updates |
| tests/tensor/linalg/test_rewriting.py | Minor test updates |
| tests/link/numba/test_scan.py | Updates for numba backend compatibility |
| tests/link/jax/test_scan.py | Updates for JAX backend compatibility |
| pytensor/tensor/pad.py | Updates scan call |
| pytensor/gradient.py | Updates hessian function scan call |
| pytensor/compile/function/pfunc.py | Guards givens check with if statement |
| environment*.yml | Removes scipy-stubs dependency |
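The "deprecated property aliases" mentioned for pytensor/scan/op.py could look roughly like the sketch below. The class `ToyScanInfo` is hypothetical and exists only for illustration; the two attribute names are the ones listed in the summary, and the actual implementation in the PR may differ.

```python
import warnings


class ToyScanInfo:
    """Hypothetical container; only the two attribute names come from the PR summary."""

    def __init__(self, n_untraced_sit_sot_outs):
        self.n_untraced_sit_sot_outs = n_untraced_sit_sot_outs

    @property
    def n_shared_outs(self):
        # Old name kept as a read-only alias that forwards to the new attribute.
        warnings.warn(
            "n_shared_outs is deprecated; use n_untraced_sit_sot_outs.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.n_untraced_sit_sot_outs


info = ToyScanInfo(n_untraced_sit_sot_outs=2)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # make sure the warning is recorded
    old_value = info.n_shared_outs
```

An alias of this kind lets downstream code that still reads the old name keep working while pointing it at the replacement.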
Using DeprecationWarning to keep it visible only for devs for now
This PR allows for explicit untraceable entries in outputs_info in Scan, and deprecates (for developers for now) the whole shared updates shenanigans
This way one can pass an RNG as an explicit outputs_info, and get the final value explicitly. The reason this is untraceable and has special logic is that we can't concatenate each intermediate state in a numpy array.
Well we could probably use a numpy object array or a TypedList, and may want to in the future, but for now I just want to start deprecating the whole updates complexity in Scan.
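The difference between a traced and an untraced carried state can be illustrated with a pure Python toy. Everything here is an assumption made for illustration: `toy_scan` is hypothetical, a Python list stands in for the stacked array of intermediate states, and `random.Random` stands in for a PyTensor RNG variable.

```python
import random


def toy_scan(step, traced_init, untraced_init, n_steps):
    """Hypothetical scan: records every traced state, keeps only the last untraced one."""
    trace = []
    traced, untraced = traced_init, untraced_init
    for _ in range(n_steps):
        traced, untraced = step(traced, untraced)
        trace.append(traced)  # traced: every intermediate value is recorded
        # untraced: the previous value is simply replaced, nothing is stored
    return trace, untraced


def step(prev_x, rng):
    # random.Random mutates in place; returning it mirrors the explicit
    # "previous RNG in, next RNG out" style discussed in this PR.
    next_x = prev_x + rng.gauss(0.0, 1.0)
    return next_x, rng


xs, rng_final = toy_scan(step, 0.0, random.Random(123), n_steps=10)
```

Here `xs` holds all ten steps of the random walk, while `rng_final` is only the generator state after the last step, which is the explicit final value the description refers to.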
Note that internally the functionality was already there. This PR is using exactly the same old `shared_outs` machinery (now renamed everywhere to `untraced_sit_sot`), which allows an output to be carried without trying to place it in or read it from an array with more dimensions.
We should actually always use this machinery when only the last state is needed for tensor variables as well. But that's for another day.
The following code was impossible to write before: any RNG that we wished to update in a scan had to be a shared variable.
As we did for OpFromGraph, there should be no concept of SharedVariables in regular Ops, just outside of a PyTensor function. This PR moves in that direction.