11import warnings
2- from collections .abc import Collection , Iterable
2+ from collections .abc import Collection , Iterable , Sequence
33from textwrap import dedent
44
55import numpy as np
@@ -2012,7 +2012,7 @@ def concat_with_broadcast(tensor_list, axis=0):
20122012
20132013
20142014def pack (
2015- * tensors : TensorVariable ,
2015+ * tensors : TensorVariable , axes : int | Sequence [ int ] | None = None
20162016) -> tuple [TensorVariable , list [tuple [TensorVariable ]]]:
20172017 """
20182018 Given a list of tensors of varying shapes and dimensions, ravels and concatenates them into a single 1d vector.
@@ -2021,6 +2021,9 @@ def pack(
20212021 ----------
20222022 tensors: TensorVariable
20232023 Tensors to be packed into a single vector.
2024+ axes: int or sequence of int, optional
2025+ Axes to be concatenated. All other axes will be raveled (packed) and joined. If None, all axes will be raveled
2026+ and joined.
20242027
20252028 Returns
20262029 -------
@@ -2032,13 +2035,11 @@ def pack(
20322035 if not tensors :
20332036 raise ValueError ("Cannot pack an empty list of tensors." )
20342037
2035- # Get the shapes of the input tensors
20362038 packed_shapes = [
20372039 t .type .shape if not any (s is None for s in t .type .shape ) else t .shape
20382040 for t in tensors
20392041 ]
20402042
2041- # Flatten each tensor and concatenate them into a single 1D vector
20422043 flat_tensor = join (0 , * [t .ravel () for t in tensors ])
20432044
20442045 return flat_tensor , packed_shapes
0 commit comments