Skip to content

Commit b9af952

Browse files
committed
Optimize partial trace definition
1 parent b3b6861 commit b9af952

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

pytensor/link/jax/dispatch/scan.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -209,18 +209,28 @@ def get_partial_traces(traces):
209209
):
210210
if init_state is not None:
211211
# MIT-SOT and SIT-SOT: The final output should be as long as the input buffer
212-
trace = jnp.atleast_1d(trace)
213-
init_state = jnp.expand_dims(
214-
init_state, range(trace.ndim - init_state.ndim)
215-
)
216-
full_trace = jnp.concatenate([init_state, trace], axis=0)
217212
buffer_size = buffer.shape[0]
213+
if trace.shape[0] > buffer_size:
214+
# Trace is longer than buffer, keep just the last `buffer.shape[0]` entries
215+
partial_trace = trace[-buffer_size:]
216+
else:
217+
# Trace is shorter than buffer, this happens when we keep the initial_state
218+
if init_state.ndim < buffer.ndim:
219+
init_state = init_state[None]
220+
if (
221+
n_init_needed := buffer_size - trace.shape[0]
222+
) < init_state.shape[0]:
223+
# We may not need to keep all the initial states
224+
init_state = init_state[-n_init_needed:]
225+
partial_trace = jnp.concatenate([init_state, trace], axis=0)
218226
else:
219227
# NIT-SOT: Buffer is just the number of entries that should be returned
220-
full_trace = jnp.atleast_1d(trace)
221228
buffer_size = buffer
229+
partial_trace = (
230+
trace[-buffer_size:] if trace.shape[0] > buffer else trace
231+
)
222232

223-
partial_trace = full_trace[-buffer_size:]
233+
assert partial_trace.shape[0] == buffer_size
224234
partial_traces.append(partial_trace)
225235

226236
return partial_traces

0 commit comments

Comments
 (0)