11import warnings
2- from collections .abc import Collection , Iterable
2+ from collections .abc import Collection , Iterable , Sequence
33
44import numpy as np
55
@@ -2075,7 +2075,7 @@ def concat_with_broadcast(tensor_list, axis=0):
20752075
20762076
20772077def pack (
2078- * tensors : TensorVariable ,
2078+ * tensors : TensorVariable , axes : int | Sequence [ int ] | None = None
20792079) -> tuple [TensorVariable , list [tuple [TensorVariable ]]]:
20802080 """
20812081 Given a list of tensors of varying shapes and dimensions, ravels and concatenates them into a single 1d vector.
@@ -2084,6 +2084,9 @@ def pack(
20842084 ----------
20852085 tensors: TensorVariable
20862086 Tensors to be packed into a single vector.
2087+ axes: int or sequence of int, optional
2088+ Axes to be concatenated. All other axes will be raveled (packed) and joined. If None, all axes will be raveled
2089+ and joined.
20872090
20882091 Returns
20892092 -------
@@ -2095,13 +2098,11 @@ def pack(
20952098 if not tensors :
20962099 raise ValueError ("Cannot pack an empty list of tensors." )
20972100
2098- # Get the shapes of the input tensors
20992101 packed_shapes = [
21002102 t .type .shape if not any (s is None for s in t .type .shape ) else t .shape
21012103 for t in tensors
21022104 ]
21032105
2104- # Flatten each tensor and concatenate them into a single 1D vector
21052106 flat_tensor = join (0 , * [t .ravel () for t in tensors ])
21062107
21072108 return flat_tensor , packed_shapes
0 commit comments