@@ -209,18 +209,28 @@ def get_partial_traces(traces):
209209 ):
210210 if init_state is not None :
211211 # MIT-SOT and SIT-SOT: The final output should be as long as the input buffer
212- trace = jnp .atleast_1d (trace )
213- init_state = jnp .expand_dims (
214- init_state , range (trace .ndim - init_state .ndim )
215- )
216- full_trace = jnp .concatenate ([init_state , trace ], axis = 0 )
217212 buffer_size = buffer .shape [0 ]
213+ if trace .shape [0 ] > buffer_size :
214+ # Trace is longer than buffer, keep just the last `buffer.shape[0]` entries
215+ partial_trace = trace [- buffer_size :]
216+ else :
217+ # Trace is shorter than buffer, this happens when we keep the initial_state
218+ if init_state .ndim < buffer .ndim :
219+ init_state = init_state [None ]
220+ if (
221+ n_init_needed := buffer_size - trace .shape [0 ]
222+ ) < init_state .shape [0 ]:
223+ # We may not need to keep all the initial states
224+ init_state = init_state [- n_init_needed :]
225+ partial_trace = jnp .concatenate ([init_state , trace ], axis = 0 )
218226 else :
219227 # NIT-SOT: Buffer is just the number of entries that should be returned
220- full_trace = jnp .atleast_1d (trace )
221228 buffer_size = buffer
229+ partial_trace = (
230+ trace [- buffer_size :] if trace .shape [0 ] > buffer else trace
231+ )
222232
223- partial_trace = full_trace [ - buffer_size :]
233+ assert partial_trace . shape [ 0 ] == buffer_size
224234 partial_traces .append (partial_trace )
225235
226236 return partial_traces
0 commit comments