diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index 33a5a6b8dc..07868d4fcd 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -1,5 +1,6 @@ import warnings -from collections.abc import Collection, Iterable +from collections.abc import Collection, Iterable, Sequence +from itertools import pairwise from textwrap import dedent import numpy as np @@ -7,6 +8,7 @@ import pytensor import pytensor.scalar.basic as ps +from pytensor.compile.builders import OpFromGraph from pytensor.gradient import ( DisconnectedType, _float_zeros_like, @@ -25,7 +27,7 @@ from pytensor.scalar import upcast from pytensor.tensor import TensorLike, as_tensor_variable from pytensor.tensor import basic as ptb -from pytensor.tensor.basic import alloc, join, second +from pytensor.tensor.basic import alloc, join, second, split from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.math import abs as pt_abs from pytensor.tensor.math import all as pt_all @@ -43,7 +45,7 @@ ) from pytensor.tensor.math import max as pt_max from pytensor.tensor.math import sum as pt_sum -from pytensor.tensor.shape import Shape_i +from pytensor.tensor.shape import Shape_i, specify_shape from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes from pytensor.tensor.utils import normalize_reduce_axis @@ -2011,6 +2013,442 @@ def concat_with_broadcast(tensor_list, axis=0): return join(axis, *bcast_tensor_inputs) +class PackHelper: + def __init__(self, axes: int | Sequence[int] | None): + self.axes = tuple(axes) if isinstance(axes, list) else axes + self.op_name = "Pack{axes=" + str(self.axes) + "}" + + def _analyze_axes_list(self) -> tuple[int, int, int, int | None]: + """ + Analyze the provided axes list to determine how many axes are before and after the interval to be raveled, as + well as the minimum and maximum number of axes that the inputs 
can have. + + The rules are: + - Axes must be strictly increasing in both the positive and negative parts of the list. + - Negative axes must come after positive axes. + - There can be at most one "hole" in the axes list, which can be either an implicit hole on an endpoint + (e.g. [0, 1]) or an explicit hole in the middle (e.g. [0, 2] or [0, -1]). + + Returns + ------- + n_axes_before: int + The number of axes before the interval to be raveled. + n_axes_after: int + The number of axes after the interval to be raveled. + min_axes: int + The minimum number of axes that the inputs must have. + max_axes: int or None + The maximum number of axes that the inputs can have, or None if there is no strict maximum. A maximum is + only introduced when it would resolve ambiguities in the interpretation of the axes list. For example, + [2, 3] would have two ravel intervals, [:2] and [4:], which is illegal, + unless 3 is interpreted as -1, which is only possible if no input has more than 4 axes. Likewise, + [-3, -1] would have two ravel intervals, [:-3] and [-2:-1], unless -3 is interpreted as 0, + which is only possible if no input has more than 3 axes. 
+ """ + axes = self.axes + if axes is None: + return 0, 0, 0, None + + if isinstance(axes, int): + axes = [axes] + + if len(set(axes)) != len(axes): + raise ValueError("axes must have no duplicates") + if axes is not None and len(axes) == 0: + raise ValueError("axes=[] is ambiguous; use None to ravel all") + + first_negative_idx = next((i for i, a in enumerate(axes) if a < 0), len(axes)) + positive_axes = list(axes[:first_negative_idx]) + negative_axes = list(axes[first_negative_idx:]) + + if not all(a < 0 for a in negative_axes): + raise ValueError("Negative axes must come after positive") + + def strictly_increasing(s): + return all(b > a for a, b in pairwise(s)) + + if (positive_axes and not strictly_increasing(positive_axes)) or ( + negative_axes and not strictly_increasing(negative_axes) + ): + raise ValueError("Axes must be strictly increasing") + + def find_gaps(s): + return [i for i, (a, b) in enumerate(pairwise(s)) if b - a > 1] + + pos_gaps = find_gaps(positive_axes) + neg_gaps = find_gaps(negative_axes) + positive_only = positive_axes and not negative_axes + negative_only = negative_axes and not positive_axes + mixed_case = positive_axes and negative_axes + + max_axes: int | None = None + + n_explicit_holes = len(pos_gaps) + len(neg_gaps) + if n_explicit_holes > 1: + raise ValueError( + "Too many holes in axes list. There can be at most one hole in the axes list, " + "including implict holes resulting from omitting the 0 or -1 axis." + ) + + if mixed_case: + if pos_gaps or neg_gaps: + raise ValueError( + "Too many holes in axes list. There can be at most one hole in the axes list, " + "including implict holes resulting from omitting the 0 or -1 axis. Because both " + "positive and negative axes are present, there is always assume to be an explit hole " + "between them." 
+ ) + n_before = len(positive_axes) + n_after = len(negative_axes) + min_axes = n_before + n_after + + if positive_only: + # There are four cases to consider when all axes are positive: + # 0. There are two implicit holes (0 is not present) and an explicit hole (e.g. [2, 4]) + # This case is always illegal, as no interpretation results in only one hole in the axes list. + # 1. There is only an implicit right hole (e.g. [0, 1]) + # This case is legal, and requires no special interpretation. It corresponds to 'i j *' in einops + # 2. There is an explicit internal hole (e.g. [0, 2]) + # This case is legal, but requires interpreting the last axis as -1, which introduces a maximum number + # of axes. It corresponds to 'i * j' in einops, and requires at least one input to have 3 dimensions, and + # no input to have more than 3 dimensions. + # 3. The axes start at an index greater than 0, but have no internal holes (e.g. [2, 3]) + # This case is legal, but requires flipping the axes to negative indexing, so that the largest axis is + # -1, followed by -2, etc. This introduces a maximum number of axes. + if pos_gaps and positive_axes[0] != 0: + raise ValueError( + "Too many holes in axes list. There can be at most one hole in the axes list, " + "including implict holes resulting from omitting the 0 or -1 axis. In this case, " + "there is an explicit internal hole as well as an implicit left hole." + ) + + elif positive_axes[0] == 0 and not pos_gaps: + # Case 1: Only right implicit hole. No ambiguities. + n_before = positive_axes[-1] + 1 + n_after = 0 + min_axes = n_before + n_after + max_axes = None + + elif pos_gaps: + # Case 2: Explicit hole in the positives, plus right implicit hole. 
+ split = pos_gaps[0] + 1 + n_before = split + n_after = len(positive_axes) - split + min_axes = n_before + n_after + + # Close the right implicit hole + max_axes = positive_axes[-1] + 1 + + else: + # Case 3: Left and right implicit holes, but the right can be closed by flipping to negative axes and + # adding a maximum number of axes. + # Compute min_axes and max_axes under Case 1 of the negative_only scenario, with a max_axes constraint. + max_axes = positive_axes[-1] + 1 + n_before = 0 + n_after = len(positive_axes) + min_axes = n_before + n_after + + if negative_only: + # The same four cases are considered when all axes are negative, but ordering is reversed. + # 0. There are two implicit holes (-1 is not present) and an explicit hole (e.g. [-4, -2]) + # This case is always illegal, as there is no interpretation that would result in having only one hole + # in the axes list. + # 1. There is only an implicit left hole (e.g. [-2, -1]) + # This case is legal, and requires no special interpretation. It corresponds to '* i j' in einops + # 2. There is an explicit internal hole (e.g. [-3, -1]) + # This case is legal, but requires interpreting the smallest axis as 0, which introduces a maximum number + # of axes. It corresponds to 'i * j' in einops, and requires at least one input to have 3 dimensions, and + # no input to have more than 3 dimensions. + # 3. The axes end at an index less than -1, but have no internal holes (e.g. [-4, -3]). Flip to positive + # axes, adding a maximum number of axes. Interpret the smallest axis as 0 to resolve ambiguity. + if neg_gaps and negative_axes[-1] != -1: + raise ValueError( + "Too many holes in axes list. There can be at most one hole in the axes list, " + "including implict holes resulting from omitting the 0 or -1 axis. In this case, " + "there is an explicit internal hole as well as an implicit right hole." + ) + elif negative_axes[-1] == -1 and not neg_gaps: + # Case 1: No ambiguities, only left implicit hole. 
+ n_before = 0 + n_after = len(negative_axes) + min_axes = n_before + n_after + max_axes = None + elif neg_gaps: + # Case 2: Explicit hole in the negatives, plus left implicit hole. + split = neg_gaps[0] + 1 + n_before = split + n_after = len(negative_axes) - split + min_axes = n_before + n_after + + # Close the left implicit hole + max_axes = abs(min(negative_axes)) + else: + # Case 3: Left and right implicit holes, but the left can be closed by flipping to positive axes and + # adding a maximum number of axes. + max_axes = abs(negative_axes[0]) + n_before = negative_axes[-1] + max_axes + 1 + n_after = 0 + min_axes = n_before + n_after + + return n_before, n_after, min_axes, max_axes + + def validate_inputs(self, tensors: list[TensorLike]): + tensors = [ptb.as_tensor_variable(t) for t in tensors] + _, _, min_axes, max_axes = self._analyze_axes_list() + + if min([t.ndim for t in tensors]) < min_axes: + raise ValueError( + f"All input tensors to {self.op_name} must have at least {min_axes} dimensions, but the minimum " + f"number of dimensions found was {min([t.ndim for t in tensors])}." + ) + + max_ndim = max([t.ndim for t in tensors]) + if ( + max_axes is not None + and max_ndim > max_axes + and not any(t.ndim == max_axes for t in tensors) + ): + raise ValueError( + f"All input tensors to {self.op_name} must have at most {max_axes} dimensions, but the maximum " + f"number of dimensions found was {max_ndim}." 
+ ) + + def infer_shape(self, tensors: list[TensorLike]) -> tuple[int | None, ...]: + tensors = [ptb.as_tensor_variable(t) for t in tensors] + n_axes_before, n_axes_after, _, _ = self._analyze_axes_list() + + def _coalesce_dim(shapes: list[int | None], axis: int) -> int | None: + unique_shapes = {s for s in shapes if s is not None} + if not unique_shapes: + return None + if len(unique_shapes) > 1: + raise ValueError( + f"Input tensors to Pack op have incompatible sizes on dimension {axis} : {shapes}" + ) + return unique_shapes.pop() + + shapes_to_pack = [ + t.type.shape[n_axes_before : t.ndim - n_axes_after] for t in tensors + ] + packed_shape = ( + None + if any( + shape is None + for packed_shape in shapes_to_pack + for shape in packed_shape + ) + else int(sum(np.prod(shapes) for shapes in shapes_to_pack)) + ) + prefix_shapes = [ + _coalesce_dim([t.type.shape[i] for t in tensors], i) + for i in range(n_axes_before) + ] + suffix_shapes = [ + _coalesce_dim( + [t.type.shape[t.ndim - n_axes_after + i] for t in tensors], + n_axes_before + i, + ) + for i in range(n_axes_after) + ] + + return (*prefix_shapes, packed_shape, *suffix_shapes) + + +class Pack(OpFromGraph): + "Wrapper for the Pack Op" + + +def pack( + *tensors: TensorVariable, axes: int | Sequence[int] | None = None +) -> tuple[TensorVariable, list[tuple[TensorVariable]]]: + """ + Given a list of tensors of varying shapes and dimensions, ravels and concatenates them into a single 1d vector. + + Parameters + ---------- + tensors: TensorVariable + Tensors to be packed. Tensors can have varying shapes and dimensions, but must have the same size along each + of the dimensions specified in the `axes` parameter. + axes: int or sequence of int, optional + Axes to be preserved. All other axes will be raveled (packed), and the output is the result of concatenating + on the new raveled dimension. If None, all axes will be raveled and joined. 
Axes can be either positive or + negative, but must be strictly increasing in both the positive and negative parts of the list. Negative axes + must come after positive axes. + + Returns + ------- + flat_tensor: TensorVariable + A new symbolic variable representing all tensor inputs concatenated along the raveled axis (1d if `axes` is None) + packed_shapes: list of tuples of TensorVariable + A list of symbolic shapes, one per input, holding the sizes of that input's raveled (packed) dimensions. + + Notes + ----- + This function is a helper for joining tensors of varying shapes into a single tensor. This is done by choosing a + list of axes to preserve, and raveling all other axes. The resulting tensors are then concatenated along the + raveled axis. The shapes of the raveled dimensions are also returned, so that they can be unpacked later. + + The `axes` parameter determines which dimensions are *not* raveled. The requested axes must exist in all input + tensors, but there are otherwise no restrictions on the shapes or dimensions of the input tensors. For example, if + `axes=[0]`, then the first dimension of each tensor is preserved, and all other dimensions are raveled: + + .. code-block:: python + + import pytensor.tensor as pt + + x = pt.tensor("x", shape=(2, 3, 4)) + y = pt.tensor("y", shape=(2, 5)) + packed_output, shapes = pack(x, y, axes=0) + # packed_output will have shape (2, 3 * 4 + 5) = (2, 17) + + Since axes = 0, the first dimension of both `x` and `y` is preserved. This first example is equivalent to a simple + reshape and concat operation: + + .. code-block:: python + + x_reshaped = x.reshape(2, -1) # shape (2, 12) + y_reshaped = y.reshape(2, -1) # shape (2, 5) + packed_output = pt.concatenate( + [x_reshaped, y_reshaped], axis=1 + ) # shape (2, 17) + + `axes` can also be negative, in which case the axes are counted from the end of the tensor shape. For example, + if `axes=[-1]`, then the last dimension of each tensor is preserved, and all other dimensions are raveled: + + .. 
code-block:: python + + import pytensor.tensor as pt + + x = pt.tensor("x", shape=(3, 4, 7)) + y = pt.tensor("y", shape=(6, 2, 1, 7)) + packed_output, shapes = pack(x, y, axes=-1) + # packed_output will have shape (3 * 4 + 6 * 2 * 1, 7) = (24, 7) + + The most important restriction of `axes` is that there can be at most one "hole" in the axes list. A hole is + defined as a missing axis in the sequence of axes. The easiest way to define a hole is by using both positive + and negative axes together. For example, `axes=[0, -1]` has a hole between the first and last axes. In this case, + the first and last dimensions of each tensor are preserved, and all other dimensions are raveled: + + .. code-block:: python + + import pytensor.tensor as pt + + x = pt.tensor("x", shape=(2, 3, 2, 3, 7)) + y = pt.tensor("y", shape=(2, 6, 7)) + packed_output, shapes = pack(x, y, axes=[0, -1]) + # packed_output will have shape (2, 3 * 2 * 3 + 6, 7) = (2, 24, 7) + + Multiple explicit holes are not allowed. For example, `axes = [0, 2, -1]` is illegal because there are two holes, + one between axes 0 and 2, and another between axes 2 and -1. + + Implicit holes are also possible when using only positive or only negative axes. `axes = [0]` already has an + implicit hole to the right of axis 0. `axes = [2, 3]` has two implicit holes, one to the left of axis 2, and another + to the right. This is illegal, since there are two holes. However, `axes = [2, 3]` can be made legal if we interpret + axis 3 as the last axis (-1), which closes the right implicit hole. The interpretation requires that at least one + input tensor has exactly 4 dimensions: + + .. code-block:: python + + import pytensor.tensor as pt + + x = pt.tensor("x", shape=(5, 2, 3, 4)) + y = pt.tensor("y", shape=(2, 3, 4)) + packed_output, shapes = pack(x, y, axes=[2, 3]) + # packed_output will have shape (5 * 2 + 2, 3, 4) = (12, 3, 4) + + Note here that `y` has only 3 dimensions, so axis 3 is interpreted as -1, the last axis. 
If no input has 4 + dimensions, or if any input has more than 4 dimensions, an error is raised in this case. + + Negative axes have similar rules regarding implicit holes. `axes = [-1]` has an implicit hole to the left of + axis -1. `axes = [-3, -2]` has two implicit holes. To arrive at a valid interpretation, we take -3 to be axis 0, + which closes the left implicit hole. This requires that at least one input tensor has exactly 3 dimensions: + + .. code-block:: python + + import pytensor.tensor as pt + + x = pt.tensor("x", shape=(2, 3, 4)) + y = pt.tensor("y", shape=(2, 3)) + packed_output, shapes = pack(x, y, axes=[-3, -2]) + # packed_output will have shape (2, 3, 4 + 1) = (2, 3, 5) + + Similarly to the previous example, if no input has 3 dimensions, or if any input has more than 3 dimensions, an + error would be raised in this example. + """ + if not tensors: + raise ValueError("Cannot pack an empty list of tensors.") + + tensors = [ptb.as_tensor(tensor) for tensor in tensors] + + pack_helper = PackHelper(axes=axes) + + reshaped_tensors = [] + tmp_shapes = [] + + n_axes_before, n_axes_after, _, _ = pack_helper._analyze_axes_list() + pack_helper.validate_inputs(tensors) + output_shape = pack_helper.infer_shape(tensors) + + for i, tensor in enumerate(tensors): + shape = tensor.shape + ndim = tensor.ndim + axis_after_packed_axes = ndim - n_axes_after + tmp_shapes.append(shape[n_axes_before:axis_after_packed_axes]) + reshaped_tensors.append( + tensor.reshape( + (*shape[:n_axes_before], -1, *shape[axis_after_packed_axes:]) + ) + ) + + packed_output_tensor = specify_shape( + ptb.join(n_axes_before, *reshaped_tensors), output_shape + ) + packed_output_shapes = [ + ptb.as_tensor_variable(packed_shape).astype("int64") + for i, packed_shape in enumerate(tmp_shapes) + ] + + pack_op = Pack( + inputs=tensors, + outputs=[packed_output_tensor, *packed_output_shapes], + name="Pack{axes=" + str(axes) + "}", + inline=True, + ) + + outputs = pack_op(*tensors) + return outputs[0], 
outputs[1:] + + +def unpack( + flat_tensor: TensorVariable, packed_shapes: list[tuple[TensorVariable | int]] +) -> tuple[TensorVariable, ...]: + """ + Unpack a flat tensor into its original shapes based on the provided packed shapes. + + Parameters + ---------- + flat_tensor: TensorVariable + A 1D tensor that contains the concatenated values of the original tensors. + packed_shapes: list of tuples of TensorVariable + A list of tuples, where each tuple contains the symbolic shape of the original tensors. + + Returns + ------- + unpacked_tensors: tuple of TensorVariable + A tuple containing the unpacked tensors with their original shapes. + """ + if not packed_shapes: + raise ValueError("Cannot unpack an empty list of shapes.") + + n_splits = len(packed_shapes) + split_size = [prod(shape).astype(int) for shape in packed_shapes] + unpacked_tensors = split(flat_tensor, splits_size=split_size, n_splits=n_splits) + + return tuple( + [x.reshape(shape) for x, shape in zip(unpacked_tensors, packed_shapes)] + ) + + __all__ = [ "bartlett", "bincount", @@ -2027,10 +2465,12 @@ def concat_with_broadcast(tensor_list, axis=0): "geomspace", "linspace", "logspace", + "pack", "ravel_multi_index", "repeat", "searchsorted", "squeeze", "unique", + "unpack", "unravel_index", ] diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py index 01de6cb517..250142bf7f 100644 --- a/tests/tensor/test_extra_ops.py +++ b/tests/tensor/test_extra_ops.py @@ -21,6 +21,7 @@ CumOp, FillDiagonal, FillDiagonalOffset, + PackHelper, RavelMultiIndex, Repeat, SearchsortedOp, @@ -38,11 +39,13 @@ diff, fill_diagonal, fill_diagonal_offset, + pack, ravel_multi_index, repeat, searchsorted, squeeze, to_one_hot, + unpack, unravel_index, ) from pytensor.tensor.type import ( @@ -1387,3 +1390,200 @@ def test_concat_with_broadcast(): a = pt.tensor("a", shape=(1, 3, 5)) b = pt.tensor("b", shape=(3, 5)) pt.concat_with_broadcast([a, b], axis=1) + + +class TestPack: + @pytest.mark.parametrize( + "axes, 
expected", + [ + ([0, 1], [2, 0, 2, None]), # 'i j *' + ([-1], [0, 1, 1, None]), # '* k' + ([0, 1, 3], [2, 1, 3, 4]), # 'i j * k' + ([-3, -1], [1, 1, 2, 3]), # '* i j' + ([2, 3], [0, 2, 2, 4]), # '* i j' + ([-3, -2], [2, 0, 2, 3]), # 'i j *' + ([0, -1], [1, 1, 2, None]), # 'i * k' + ([0, 1, 2, -1], [3, 1, 4, None]), # 'i j k * l' + ([0, 1, 4], [2, 1, 3, 5]), + ([-4, -1], [1, 1, 2, 4]), + ], + ids=[ + "basic", + "keep_last", + "ravel_middle_implicit_end", + "implicit_start", + "ravel_start", + "implicit_end", + "mix_pos_neg", + "ravel_middle_explicit_end", + "pos_internal_bigger_gap", + "neg_internal_bigger_gap", + ], + ) + def test_analyze_axes_list_valid(self, axes, expected): + helper = PackHelper(axes) + outputs = helper._analyze_axes_list() + names = ["n_before", "n_after", "min_axes", "max_axes"] + for out, exp, name in zip(outputs, expected, names, strict=True): + assert out == exp, f"Expected {exp}, got {out} for {name}" + + def test_analyze_axes_list_invalid(self): + # Two explicit holes + helper = PackHelper([0, 2, -1]) + with pytest.raises(ValueError, match="Too many holes"): + helper._analyze_axes_list() + + # Explict hole + two implicit holes + helper = PackHelper([1, 3]) + with pytest.raises(ValueError, match="Too many holes"): + helper._analyze_axes_list() + + # Two explicit holes, all positive + helper = PackHelper([0, 2, 4]) + with pytest.raises(ValueError, match="Too many holes"): + helper._analyze_axes_list() + + # Explicit hole + two implicit hole, all negative + helper = PackHelper([-4, -2]) + with pytest.raises(ValueError, match="Too many holes"): + helper._analyze_axes_list() + + # Two explicit holes + implicit hole, all negative + helper = PackHelper([-5, -3, -1]) + with pytest.raises(ValueError, match="Too many holes"): + helper._analyze_axes_list() + + # Duplicate axes + helper = PackHelper([0, 0]) + with pytest.raises(ValueError, match="axes must have no duplicates"): + helper._analyze_axes_list() + + # Not monotonic + helper = 
PackHelper([0, 2, 1]) + with pytest.raises(ValueError, match="Axes must be strictly increasing"): + helper._analyze_axes_list() + + # Negative before positive + helper = PackHelper([-1, 0]) + with pytest.raises(ValueError, match="Negative axes must come after positive"): + helper._analyze_axes_list() + + def test_pack_basic(self): + # rng = np.random.default_rng() + x = pt.tensor("x", shape=()) + y = pt.tensor("y", shape=(5,)) + z = pt.tensor("z", shape=(3, 3)) + + input_dict = { + variable: np.zeros(variable.type.shape, dtype=config.floatX) + for variable in [x, y, z] + } + + # Simple case, reduce all axes, equivalent to einops '*' + packed_tensor, packed_shapes = pack(x, y, z, axes=None) + assert packed_tensor.type.shape == (15,) + for tensor, packed_shape in zip([x, y, z], packed_shapes): + assert packed_shape.type.shape == (tensor.ndim,) + np.testing.assert_allclose(packed_shape.eval(input_dict), tensor.type.shape) + + # To preserve an axis, all inputs need at least one dimension, and the preserved axis has to agree. 
+ # x is scalar, so pack will raise: + with pytest.raises( + ValueError, + match=r"All input tensors to Pack{axes=0} must have at least 1 dimensions", + ): + pack(x, y, z, axes=0) + + # With valid x, pack should still raise, because the axis of concatenation doesn't agree across all inputs + x = pt.tensor("x", shape=(3,)) + with pytest.raises( + ValueError, + match=r"Input tensors to Pack op have incompatible sizes on dimension 0 : " + r"\[3, 5, 3\]", + ): + pack(x, y, z, axes=0) + + # Valid case, preserve first axis, equivalent to einops 'i *' + y = pt.tensor("y", shape=(3, 5)) + z = pt.tensor("z", shape=(3, 3, 3)) + packed_tensor, packed_shapes = pack(x, y, z, axes=0) + input_dict = { + variable: np.zeros(variable.type.shape, dtype=config.floatX) + for variable in [x, y, z] + } + assert packed_tensor.type.shape == (3, 15) + for tensor, packed_shape in zip([x, y, z], packed_shapes): + assert packed_shape.type.shape == (tensor.ndim - 1,) + np.testing.assert_allclose( + packed_shape.eval(input_dict), tensor.type.shape[1:] + ) + # More complex case, preserve last axis implicitly, equivalent to einops 'i * k'. 
This introduces a max + # dimension condition on the input shapes + x = pt.tensor("x", shape=(3, 2)) + y = pt.tensor("y", shape=(3, 5, 2)) + z = pt.tensor("z", shape=(3, 1, 7, 5, 2)) + + with pytest.raises( + ValueError, + match=r"All input tensors to Pack{axes=\(0, 3\)} must have at most 4 " + r"dimensions, but the maximum number of dimensions found was 5", + ): + pack(x, y, z, axes=[0, 3]) + + z = pt.tensor("z", shape=(3, 1, 7, 2)) + packed_tensor, packed_shapes = pack(x, y, z, axes=[0, 3]) + input_dict = { + variable: np.zeros(variable.type.shape, dtype=config.floatX) + for variable in [x, y, z] + } + assert packed_tensor.type.shape == (3, 13, 2) + for tensor, packed_shape in zip([x, y, z], packed_shapes): + assert packed_shape.type.shape == (tensor.ndim - 2,) + np.testing.assert_allclose( + packed_shape.eval(input_dict), tensor.type.shape[1:-1] + ) + + def test_pack_unpack_round_trip(self): + rng = np.random.default_rng() + + x = pt.tensor("x", shape=(5,)) + y = pt.tensor("y", shape=(3, 3)) + z = pt.tensor("z", shape=()) + + flat_packed, packed_shapes = pack(x, y, z, axes=None) + new_outputs = unpack(flat_packed, packed_shapes) + + fn = pytensor.function([x, y, z], new_outputs, mode="FAST_COMPILE") + + input_vals = [ + rng.normal(size=var.type.shape).astype(config.floatX) for var in [x, y, z] + ] + output_vals = fn(*input_vals) + + for input_val, output_val in zip(input_vals, output_vals, strict=True): + np.testing.assert_allclose(input_val, output_val) + + +def test_make_replacements_with_pack_unpack(): + rng = np.random.default_rng() + + x = pt.tensor("x", shape=()) + y = pt.tensor("y", shape=(5,)) + z = pt.tensor("z", shape=(3, 3)) + + loss = (x + y.sum() + z.sum()) ** 2 + + flat_packed, packed_shapes = pack(x, y, z, axes=None) + new_input = flat_packed.type() + new_outputs = unpack(new_input, packed_shapes) + + loss = pytensor.graph.graph_replace(loss, dict(zip([x, y, z], new_outputs))) + fn = pytensor.function([new_input, x, y, z], loss, 
mode="FAST_COMPILE") + + input_vals = [ + rng.normal(size=(var.type.shape)).astype(config.floatX) for var in [x, y, z] + ] + flat_inputs = np.concatenate([input.ravel() for input in input_vals], axis=0) + output_val = fn(flat_inputs, *input_vals) + + assert np.allclose(output_val, sum([input.sum() for input in input_vals]) ** 2)