Skip to content

Commit 58c0286

Browse files
Remove unnecessary comments
1 parent 79d9662 commit 58c0286

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

pytensor/tensor/extra_ops.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import warnings
2-
from collections.abc import Collection, Iterable
2+
from collections.abc import Collection, Iterable, Sequence
33
from textwrap import dedent
44

55
import numpy as np
@@ -2012,7 +2012,7 @@ def concat_with_broadcast(tensor_list, axis=0):
20122012

20132013

20142014
def pack(
2015-
*tensors: TensorVariable,
2015+
*tensors: TensorVariable, axes: int | Sequence[int] | None = None
20162016
) -> tuple[TensorVariable, list[tuple[TensorVariable]]]:
20172017
"""
20182018
Given a list of tensors of varying shapes and dimensions, ravels and concatenates them into a single 1d vector.
@@ -2021,6 +2021,9 @@ def pack(
20212021
----------
20222022
tensors: TensorVariable
20232023
Tensors to be packed into a single vector.
2024+
axes: int or sequence of int, optional
2025+
Axes to be concatenated. All other axes will be raveled (packed) and joined. If None, all axes will be raveled
2026+
and joined.
20242027
20252028
Returns
20262029
-------
@@ -2032,13 +2035,11 @@ def pack(
20322035
if not tensors:
20332036
raise ValueError("Cannot pack an empty list of tensors.")
20342037

2035-
# Get the shapes of the input tensors
20362038
packed_shapes = [
20372039
t.type.shape if not any(s is None for s in t.type.shape) else t.shape
20382040
for t in tensors
20392041
]
20402042

2041-
# Flatten each tensor and concatenate them into a single 1D vector
20422043
flat_tensor = join(0, *[t.ravel() for t in tensors])
20432044

20442045
return flat_tensor, packed_shapes

0 commit comments

Comments
 (0)