|
25 | 25 | from pytensor.scalar import upcast |
26 | 26 | from pytensor.tensor import TensorLike, as_tensor_variable |
27 | 27 | from pytensor.tensor import basic as ptb |
28 | | -from pytensor.tensor.basic import alloc, arange, join, second |
| 28 | +from pytensor.tensor.basic import alloc, join, second, split |
29 | 29 | from pytensor.tensor.exceptions import NotScalarConstantError |
30 | 30 | from pytensor.tensor.math import abs as pt_abs |
31 | 31 | from pytensor.tensor.math import all as pt_all |
@@ -2065,17 +2065,15 @@ def unpack( |
2065 | 2065 | if not packed_shapes: |
2066 | 2066 | raise ValueError("Cannot unpack an empty list of shapes.") |
2067 | 2067 |
|
2068 | | - start = 0 |
2069 | | - unpacked_tensors = [] |
2070 | | - for shape in packed_shapes: |
2071 | | - size = prod(shape, no_zeros_in_input=True) |
2072 | | - end = start + size |
2073 | | - unpacked_tensors.append( |
2074 | | - take(flat_tensor, arange(start, end, dtype="int"), axis=0).reshape(shape) |
2075 | | - ) |
2076 | | - start = end |
| 2068 | + n_splits = len(packed_shapes) |
| 2069 | + split_size = [ |
| 2070 | + prod(shape, no_zeros_in_input=True).astype(int) for shape in packed_shapes |
| 2071 | + ] |
| 2072 | + unpacked_tensors = split(flat_tensor, splits_size=split_size, n_splits=n_splits) |
2077 | 2073 |
|
2078 | | - return tuple(unpacked_tensors) |
| 2074 | + return tuple( |
| 2075 | + [x.reshape(shape) for x, shape in zip(unpacked_tensors, packed_shapes)] |
| 2076 | + ) |
2079 | 2077 |
|
2080 | 2078 |
|
2081 | 2079 | __all__ = [ |
|
0 commit comments