Skip to content

Commit 2e22d34

Browse files
Use split
1 parent 9568a83 commit 2e22d34

File tree

1 file changed

+9
-11
lines changed

1 file changed

+9
-11
lines changed

pytensor/tensor/extra_ops.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from pytensor.scalar import upcast
2626
from pytensor.tensor import TensorLike, as_tensor_variable
2727
from pytensor.tensor import basic as ptb
28-
from pytensor.tensor.basic import alloc, arange, join, second
28+
from pytensor.tensor.basic import alloc, join, second, split
2929
from pytensor.tensor.exceptions import NotScalarConstantError
3030
from pytensor.tensor.math import abs as pt_abs
3131
from pytensor.tensor.math import all as pt_all
@@ -2065,17 +2065,15 @@ def unpack(
20652065
if not packed_shapes:
20662066
raise ValueError("Cannot unpack an empty list of shapes.")
20672067

2068-
start = 0
2069-
unpacked_tensors = []
2070-
for shape in packed_shapes:
2071-
size = prod(shape, no_zeros_in_input=True)
2072-
end = start + size
2073-
unpacked_tensors.append(
2074-
take(flat_tensor, arange(start, end, dtype="int"), axis=0).reshape(shape)
2075-
)
2076-
start = end
2068+
n_splits = len(packed_shapes)
2069+
split_size = [
2070+
prod(shape, no_zeros_in_input=True).astype(int) for shape in packed_shapes
2071+
]
2072+
unpacked_tensors = split(flat_tensor, splits_size=split_size, n_splits=n_splits)
20772073

2078-
return tuple(unpacked_tensors)
2074+
return tuple(
2075+
[x.reshape(shape) for x, shape in zip(unpacked_tensors, packed_shapes)]
2076+
)
20792077

20802078

20812079
__all__ = [

0 commit comments

Comments
 (0)