Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion environment-osx-arm64.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ dependencies:
- diff-cover
- mypy
- types-setuptools
- scipy-stubs
- pytest
- pytest-cov
- pytest-xdist
Expand Down
1 change: 0 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ dependencies:
- diff-cover
- mypy
- types-setuptools
- scipy-stubs
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

broke run_mypy when there are errors

- pytest
- pytest-cov
- pytest-xdist
Expand Down
41 changes: 21 additions & 20 deletions pytensor/compile/function/pfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,26 +546,27 @@ def construct_pfunc_ins_and_outs(
"variables in the inputs list."
)

# Check that we are not using `givens` to replace input variables, because
# this typically does nothing, contrary to what one may expect.
in_var_set = set(in_variables)
try:
givens_pairs = list(givens.items())
except AttributeError:
givens_pairs = givens
for x, y in givens_pairs:
if x in in_var_set:
raise RuntimeError(
f"You are trying to replace variable '{x}' through the "
"`givens` parameter, but this variable is an input to your "
"function. Replacing inputs is currently forbidden because it "
"has no effect. One way to modify an input `x` to a function "
"evaluating f(x) is to define a new input `y` and use "
"`pytensor.function([y], f(x), givens={x: g(y)})`. Another "
"solution consists in using `pytensor.clone_replace`, e.g. like this: "
"`pytensor.function([x], "
"pytensor.clone_replace(f(x), replace={x: g(x)}))`."
)
if givens:
# Check that we are not using `givens` to replace input variables, because
# this typically does nothing, contrary to what one may expect.
in_var_set = set(in_variables)
try:
givens_pairs = list(givens.items())
except AttributeError:
givens_pairs = givens
for x, y in givens_pairs:
if x in in_var_set:
raise RuntimeError(
f"You are trying to replace variable '{x}' through the "
"`givens` parameter, but this variable is an input to your "
"function. Replacing inputs is currently forbidden because it "
"has no effect. One way to modify an input `x` to a function "
"evaluating f(x) is to define a new input `y` and use "
"`pytensor.function([y], f(x), givens={x: g(y)})`. Another "
"solution consists in using `pytensor.clone_replace`, e.g. like this: "
"`pytensor.function([x], "
"pytensor.clone_replace(f(x), replace={x: g(x)}))`."
)

if not fgraph:
# Extend the outputs with the updates on input variables so they are
Expand Down
6 changes: 2 additions & 4 deletions pytensor/gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -2188,7 +2188,7 @@ def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
# It is possible that the inputs are disconnected from expr,
# even if they are connected to cost.
# This should not be an error.
hess, updates = pytensor.scan(
hess = pytensor.scan(
lambda i, y, x: grad(
y[i],
x,
Expand All @@ -2197,9 +2197,7 @@ def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
),
sequences=pytensor.tensor.arange(expr.shape[0]),
non_sequences=[expr, input],
)
assert not updates, (
"Scan has returned a list of updates; this should not happen."
return_updates=False,
)
hessians.append(hess)
return as_list_or_tuple(using_list, using_tuple, hessians)
Expand Down
24 changes: 12 additions & 12 deletions pytensor/link/jax/dispatch/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,23 +60,23 @@ def scan(*outer_inputs):
mit_mot_init,
mit_sot_init,
sit_sot_init,
op.outer_shared(outer_inputs),
op.outer_untraced_sit_sot(outer_inputs),
op.outer_non_seqs(outer_inputs),
) # JAX `init`

def jax_args_to_inner_func_args(carry, x):
"""Convert JAX scan arguments into format expected by scan_inner_func.

scan(carry, x) -> scan_inner_func(seqs, MIT-SOT, SIT-SOT, shared, non_seqs)
scan(carry, x) -> scan_inner_func(seqs, MIT-SOT, SIT-SOT, untraced SIT-SOT, non_seqs)
"""

# `carry` contains all inner taps, shared terms, and non_seqs
# `carry` contains all inner taps and non_seqs
(
i,
inner_mit_mot,
inner_mit_sot,
inner_sit_sot,
inner_shared,
inner_untraced_sit_sot,
inner_non_seqs,
) = carry

Expand Down Expand Up @@ -108,7 +108,7 @@ def jax_args_to_inner_func_args(carry, x):
*mit_mot_flatten,
*mit_sot_flatten,
*inner_sit_sot,
*inner_shared,
*inner_untraced_sit_sot,
*inner_non_seqs,
)

Expand All @@ -118,22 +118,22 @@ def inner_func_outs_to_jax_outs(
):
"""Convert inner_scan_func outputs into format expected by JAX scan.

old_carry + (MIT-SOT_outs, SIT-SOT_outs, NIT-SOT_outs, shared_outs) -> (new_carry, ys)
old_carry + (MIT-SOT_outs, SIT-SOT_outs, NIT-SOT_outs, untraced_SIT-SOT_outs) -> (new_carry, ys)
"""
(
i,
old_mit_mot,
old_mit_sot,
_old_sit_sot,
_old_shared,
_old_untraced_sit_sot,
inner_non_seqs,
) = old_carry

new_mit_mot_vals = op.inner_mitmot_outs_grouped(inner_scan_outs)
new_mit_sot_vals = op.inner_mitsot_outs(inner_scan_outs)
new_sit_sot = op.inner_sitsot_outs(inner_scan_outs)
new_nit_sot = op.inner_nitsot_outs(inner_scan_outs)
new_shared = op.inner_shared_outs(inner_scan_outs)
new_untraced_sit_sot = op.inner_untraced_sit_sot_outs(inner_scan_outs)

# New carry for next step
# Update MIT-MOT buffer at positions indicated by output taps
Expand All @@ -150,14 +150,14 @@ def inner_func_outs_to_jax_outs(
old_mit_sot, new_mit_sot_vals, strict=True
)
]
# For SIT-SOT, and shared just pass along the new value
# For SIT-SOT just pass along the new value
# Non-sequences remain unchanged
new_carry = (
i + 1,
new_mit_mot,
new_mit_sot,
new_sit_sot,
new_shared,
new_untraced_sit_sot,
inner_non_seqs,
)

Expand All @@ -183,7 +183,7 @@ def jax_inner_func(carry, x):
final_mit_mot,
_final_mit_sot,
_final_sit_sot,
final_shared,
final_untraced_sit_sot,
_final_non_seqs,
),
traces,
Expand Down Expand Up @@ -238,7 +238,7 @@ def get_partial_traces(traces):
scan_outs_final = [
*final_mit_mot,
*get_partial_traces(traces),
*final_shared,
*final_untraced_sit_sot,
]

if len(scan_outs_final) == 1:
Expand Down
6 changes: 3 additions & 3 deletions pytensor/link/numba/dispatch/linalg/decomposition/lu.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def _lu_1(
Called when permute_l is True and p_indices is False, and returns a tuple of (perm, L, U), where perm an integer
array of row swaps, such that L[perm] @ U = A.
"""
return linalg.lu(
return linalg.lu( # type: ignore[no-any-return]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was this failing before?

Copy link
Member Author

@ricardoV94 ricardoV94 Nov 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's because I removed the scipy stubs (first commit), because they broke the run_mypy output. I'll open an issue to track.

a,
permute_l=permute_l,
check_finite=check_finite,
Expand All @@ -70,7 +70,7 @@ def _lu_2(
Called when permute_l is False and p_indices is True, and returns a tuple of (PL, U), where PL is the
permuted L matrix, PL = P @ L.
"""
return linalg.lu(
return linalg.lu( # type: ignore[no-any-return]
a,
permute_l=permute_l,
check_finite=check_finite,
Expand All @@ -92,7 +92,7 @@ def _lu_3(
Called when permute_l is False and p_indices is False, and returns a tuple of (P, L, U), where P is the permutation
matrix, P @ L @ U = A.
"""
return linalg.lu(
return linalg.lu( # type: ignore[no-any-return]
a,
permute_l=permute_l,
check_finite=check_finite,
Expand Down
18 changes: 9 additions & 9 deletions pytensor/link/numba/dispatch/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,19 +108,19 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
outer_in_mit_sot_names = op.outer_mitsot(outer_in_names)
outer_in_sit_sot_names = op.outer_sitsot(outer_in_names)
outer_in_nit_sot_names = op.outer_nitsot(outer_in_names)
outer_in_shared_names = op.outer_shared(outer_in_names)
outer_in_untraced_sit_sot_names = op.outer_untraced_sit_sot(outer_in_names)
outer_in_non_seqs_names = op.outer_non_seqs(outer_in_names)

# These are all the outer-input names that have produce outputs/have output
# taps (i.e. they have inner-outputs and corresponding outer-outputs).
# Outer-outputs are ordered as follows:
# mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots + shared-outputs
# mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots + untraced-sit-sot-outputs
outer_in_outtap_names = (
outer_in_mit_mot_names
+ outer_in_mit_sot_names
+ outer_in_sit_sot_names
+ outer_in_nit_sot_names
+ outer_in_shared_names
+ outer_in_untraced_sit_sot_names
)

# We create distinct variables for/references to the storage arrays for
Expand All @@ -138,16 +138,18 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
for outer_in_name in outer_in_nit_sot_names:
outer_in_to_storage_name[outer_in_name] = f"{outer_in_name}_nitsot_storage"

for outer_in_name in outer_in_shared_names:
outer_in_to_storage_name[outer_in_name] = f"{outer_in_name}_shared_storage"
for outer_in_name in outer_in_untraced_sit_sot_names:
outer_in_to_storage_name[outer_in_name] = (
f"{outer_in_name}_untraced_sit_sot_storage"
)

outer_output_names = list(outer_in_to_storage_name.values())
assert len(outer_output_names) == len(node.outputs)

# Construct the inner-input expressions (e.g. indexed storage expressions)
# Inner-inputs are ordered as follows:
# sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs +
# shared-inputs + non-sequences.
# untraced-sit-sot-inputs + non-sequences.
temp_scalar_storage_alloc_stmts: list[str] = []
inner_in_exprs_scalar: list[str] = []
inner_in_exprs: list[str] = []
Expand Down Expand Up @@ -204,11 +206,9 @@ def add_inner_in_expr(

# Inner-outputs consist of:
# mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots +
# shared-outputs [+ while-condition]
# untraced-sit-sot-outputs [+ while-condition]
inner_output_names = [f"inner_out_{i}" for i in range(len(op.inner_outputs))]

# inner_out_shared_names = op.inner_shared_outs(inner_output_names)

# The assignment statements that copy inner-outputs into the outer-outputs
# storage
inner_out_to_outer_in_stmts: list[str] = []
Expand Down
Loading