Skip to content

Commit 860a7ab

Browse files
Remove unnecessary comments
1 parent d5e161b commit 860a7ab

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

pytensor/tensor/extra_ops.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import warnings
2-
from collections.abc import Collection, Iterable
2+
from collections.abc import Collection, Iterable, Sequence
33

44
import numpy as np
55

@@ -2075,7 +2075,7 @@ def concat_with_broadcast(tensor_list, axis=0):
20752075

20762076

20772077
def pack(
2078-
*tensors: TensorVariable,
2078+
*tensors: TensorVariable, axes: int | Sequence[int] | None = None
20792079
) -> tuple[TensorVariable, list[tuple[TensorVariable]]]:
20802080
"""
20812081
Given a list of tensors of varying shapes and dimensions, ravels and concatenates them into a single 1d vector.
@@ -2084,6 +2084,9 @@ def pack(
20842084
----------
20852085
tensors: TensorVariable
20862086
Tensors to be packed into a single vector.
2087+
axes: int or sequence of int, optional
2088+
Axes to be concatenated. All other axes will be raveled (packed) and joined. If None, all axes will be raveled
2089+
and joined.
20872090
20882091
Returns
20892092
-------
@@ -2095,13 +2098,11 @@ def pack(
20952098
if not tensors:
20962099
raise ValueError("Cannot pack an empty list of tensors.")
20972100

2098-
# Get the shapes of the input tensors
20992101
packed_shapes = [
21002102
t.type.shape if not any(s is None for s in t.type.shape) else t.shape
21012103
for t in tensors
21022104
]
21032105

2104-
# Flatten each tensor and concatenate them into a single 1D vector
21052106
flat_tensor = join(0, *[t.ravel() for t in tensors])
21062107

21072108
return flat_tensor, packed_shapes

0 commit comments

Comments
 (0)