
Conversation

@jessegrabowski (Member) commented Aug 10, 2025

Description

Adds pt.pack and pt.unpack helpers, roughly conforming to the einops functions of the same name.

These helpers are for situations where we have a ragged list of inputs that need to be raveled into a single flat vector for some intermediate step. This comes up in places like optimization.

Example usage:

x = pt.tensor("x", shape=shapes[0])
y = pt.tensor("y", shape=shapes[1])
z = pt.tensor("z", shape=shapes[2])

flat_params, packed_shapes = pt.pack(x, y, z)

Unpack simply undoes the computation, although there's no rewrite to ensure pt.unpack(*pt.pack(*inputs)) is the identity function:

x, y, z = pt.unpack(flat_params, packed_shapes)

The use-case I foresee is creating a replacement input for a function of the variables we're packing, for example:

import pytensor

loss = (x + y.sum() + z.sum()) ** 2

flat_packed, packed_shapes = pt.pack(x, y, z)
new_input = flat_packed.type()
new_outputs = pt.unpack(new_input, packed_shapes)

loss = pytensor.graph.graph_replace(loss, dict(zip([x, y, z], new_outputs)))
fn = pytensor.function([new_input], loss)
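
With the illustrative static shapes above, new_input has a known flat length, so the compiled function can be evaluated on a single flat vector (hypothetical values, just to show the call):

import numpy as np

flat_value = np.zeros(new_input.type.shape)  # one entry per packed element
fn(flat_value)  # evaluates the loss at the packed point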

Note that the final compiled function depends only on new_input; this is only possible because the shapes of the 3 packed variables were statically known. This leads to my design choices:

  1. I decided to work with the static shapes directly if they are available. This means that pack will eagerly return a list of integer shapes as packed_shapes if possible; otherwise they will be symbolic shapes (see the sketch after this list). This is maybe an anti-pattern -- we might prefer a rewrite to handle it later, but it seemed easy enough to do eagerly.
  2. I didn't add support for batch dims. This is left to the user to handle with pt.vectorize.
  3. The einops API has arguments to support packing/unpacking on arbitrary subsets of dimensions. I didn't do this, because I couldn't think of a use-case that a user couldn't cover with DimShuffle and vectorize.
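
For example (a rough sketch of point 1, with fully static input shapes), packed_shapes comes back as plain tuples of ints rather than symbolic shape graphs:

x = pt.tensor("x", shape=(3,))
y = pt.tensor("y", shape=(4, 2))
flat, packed_shapes = pt.pack(x, y)
packed_shapes  # [(3,), (4, 2)] -- plain integers, since the shapes are static
flat.type.shape  # (11,)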

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1578.org.readthedocs.build/en/1578/

@jessegrabowski added the enhancement label on Aug 10, 2025
@jessegrabowski (Member Author) commented:

The pack -> type -> unpack -> replace pattern might be common enough to merit its own helper. PyMC has tools for doing this (for example RaveledArray and DictToArrayBijector) that could be replaced with appropriate symbolic operations.
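
A minimal sketch of such a helper (hypothetical name, not part of this PR) could look like:

def make_flat_replacement(outputs, params):
    # Pack the params, create a fresh flat input of the same type, and replace
    # the original params in `outputs` with the corresponding unpacked pieces.
    flat, shapes = pt.pack(*params)
    flat_input = flat.type()
    replacements = pt.unpack(flat_input, shapes)
    new_outputs = pytensor.graph.graph_replace(outputs, dict(zip(params, replacements)))
    return flat_input, new_outputs

The example above would then reduce to new_input, new_loss = make_flat_replacement(loss, [x, y, z]).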

One other thing I forgot to mention is that this will all fail on inputs with a zero-sized dimension, since that will ruin the prod(shape) used to get the shape of the flat output. Not sure what to do in that case.

@jessegrabowski force-pushed the pack-unpack branch 2 times, most recently from da89b9d to 9ead211 on August 10, 2025 08:31
@codecov bot commented Aug 10, 2025

Codecov Report

❌ Patch coverage is 94.02985% with 8 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.74%. Comparing base (1f9a67b) to head (0b86851).

Files with missing lines | Patch % | Lines
pytensor/tensor/extra_ops.py | 94.02% | 4 Missing and 4 partials ⚠️

❌ Your patch check has failed because the patch coverage (94.02%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #1578      +/-   ##
==========================================
+ Coverage   81.70%   81.74%   +0.03%     
==========================================
  Files         246      246              
  Lines       53632    53763     +131     
  Branches     9438     9462      +24     
==========================================
+ Hits        43820    43946     +126     
- Misses       7330     7333       +3     
- Partials     2482     2484       +2     
Files with missing lines | Coverage Δ
pytensor/tensor/extra_ops.py | 88.88% <94.02%> (+0.95%) ⬆️

... and 2 files with indirect coverage changes


@ricardoV94 (Member) commented Aug 10, 2025

  1. Better to return the same types either way; the static-shape-to-constant conversion is already introduced during rewrites.

2 and 3. I would really like to have these; it's what I needed for the batched_dot_to_core rewrites. This isn't a simple case of vectorize, because the dims I want to pack are both on the left and right of other dims.

@ricardoV94 (Member) commented:

I am inclined to make this a core Op and not just a helper. It obviates most uses of reshape and it's much easier to reason about, without having to worry about pesky -1s or whether the reshape shape comes from the original input shapes or not.

That would pretty much address #883

We could use OFG and/or specialize to reshape/split later. It also need not be done in this PR. It's an implementation detail as far as the user is concerned.
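
E.g. for the flat case in the description, unpack could specialize to something like this when the shapes are static (just a sketch):

import numpy as np

sizes = [int(np.prod(shape)) for shape in packed_shapes]
pieces = pt.split(flat_params, splits_size=sizes, n_splits=len(sizes), axis=0)
outputs = [piece.reshape(shape) for piece, shape in zip(pieces, packed_shapes)]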

@jessegrabowski force-pushed the pack-unpack branch 2 times, most recently from 860a7ab to fcbd0af on September 20, 2025 16:12
@jessegrabowski (Member Author) commented Nov 2, 2025

> I am inclined to make this a core Op and not just a helper. It obviates most uses of reshape and it's much easier to reason about, not having to worry about pesky -1 or whether the reshape shape comes from the original input shapes or not.

I pushed a commit that adds a "feature complete" Pack Op. It can do everything that einops.pack can do, and more. I'd say it's a bit on the overly complex side, but that's on brand for me. The API I cooked up could maybe be simplified.

Basically, you pack by selecting which axes you don't want to ravel. Axes should be None, int, or tuple[int].

If None, it's the same as packed_tensor = pt.join(0, *[var.ravel() for var in input_vars]):

x = pt.tensor("x", shape=())
y = pt.tensor("y", shape=(5,))
z = pt.tensor("z", shape=(3, 3))
packed_tensor, packed_shapes = pt.pack(x, y, z, axes=None)
packed_tensor.type.shape # (15,)

Once you pass in integers, your inputs need to have the same shape on the dimensions of concatenation. All dimensions that fall in a "hole" of the provided axes are raveled and joined. What's a hole? You can have explicit or implicit holes. An "explicit" hole is a gap in the integers you provide. For example, [0, -1] has an explicit hole: all dimensions except the first and last will be raveled and joined.

x = pt.tensor("x", shape=(5, 3))
y = pt.tensor("y", shape=(5, 2, 4, 3))
z = pt.tensor("z", shape=(5, 6, 3))
packed_tensor, packed_shapes = pt.pack(x, y, z, axes=[0,-1])
packed_tensor.type.shape # (5, 15, 3)

axes=[0,3] also looks like an explicit hole -- dimensions 1 and 2 will be raveled and joined. But there's also an implicit hole, because there could be dimensions beyond 3. That makes 2 holes, which is an invalid pack. To resolve this, we assume that 3 is the last axis -- you need to pass at least one tensor with ndim == 4 (since we want to concatenate on axis=3), and no input can have more. Under those conditions, axes=[0, 3] is treated the same as axes=[0, -1]:

x = pt.tensor('x', shape=(2, 6))
y = pt.tensor('y', shape=(2, 3, 7, 6))
z = pt.tensor('z', shape=(2, 10, 6))
packed_tensor, packed_shapes = pt.pack(x, y, z, axes=[0,3])
packed_tensor.type.shape # (2, 32, 6)

A minimum ndim is also enforced -- we couldn't have passed in x = pt.tensor('x', shape=(2,)), for example, because we need at least 2 dimensions to concatenate, since we asked for 2 pack axes.
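
For example, this should raise (the exact exception type and message are an implementation detail):

x = pt.tensor("x", shape=(2,))  # only 1 dim, but 2 pack axes requested
y = pt.tensor("y", shape=(2, 3, 7, 6))
z = pt.tensor("z", shape=(2, 10, 6))
pt.pack(x, y, z, axes=[0, 3])  # error: x has fewer dims than the number of pack axes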

I could imagine this being stricter -- axes=[0, 3] could also imply that all inputs have ndim == 4. The reason I didn't do it this way is that I wanted to match the feel of einops. axes=[0, 3] feels like it should correspond to something like i * j, with at most 2 dimensions inside the ellipsis.

Pack could probably also be an OpFromGraph; there's nothing special going on in the perform method. That would be better, because then we get the gradients for free.

pack_op = Pack(
    inputs=tensors,
    outputs=[packed_output_tensor, *packed_output_shapes],
    name="Pack{axes=" + str(axes) + "}",

@ricardoV94 (Member):

Why give name instead of just defining the __str__ of the Pack?

@jessegrabowski (Member Author):

Because I looked at the docs for OpFromGraph and saw there was a name field I could pass


if len(set(axes)) != len(axes):
    raise ValueError("axes must have no duplicates")
if axes is not None and len(axes) == 0:

@ricardoV94 (Member):

No-ops should be supported in general. It makes writing code easier because you don't have to think about edge cases. Empty axes are supported in most Ops that allow a variable number of axes (like sum).

@jessegrabowski (Member Author):

What should axes = [] do?

@ricardoV94 (Member):

Tell me what axes does, and I might know

        self.axes = tuple(axes) if isinstance(axes, list) else axes
        self.op_name = "Pack{axes=" + str(self.axes) + "}"

    def _analyze_axes_list(self) -> tuple[int, int, int, int | None]:

@ricardoV94 (Member):

This is pretty gnarly

@jessegrabowski (Member Author):

technically you asked for it

@ricardoV94 (Member):

How so?

"Wrapper for the Pack Op"


def pack(

@ricardoV94 (Member):

I can't tell how this function works from the docstrings. What happens when I pass inputs with different dimensions, and a single axis vs a list of axes?

axes: int or sequence of int, optional
Axes to be concatenated. All other axes will be raveled (packed) and joined. If None, all axes will be raveled
and joined.
Axes to be preserved. All other axes will be raveled (packed), and the output is the result of concatenating

@ricardoV94 (Member):

IIUC this is changing a bit the usual meaning of axes. axes=None is usually the same as an exhaustive list of all axes, a.sum(None) == a.sum(tuple(range(a.ndim))), but here it is flipped?

@ricardoV94 (Member):

axis=[] would be the same as None in this case, nothing is preserved.

) # shape (2, 17)
`axes` can also be negative, in which case the axes are counted from the end of the tensor shape. For example,
if `axes=[-1]`, then the last dimension of each tensor is preserved, and all other dimensions are raveled:

@ricardoV94 (Member):

The axes specification seems troublesome for vectorization. If axes were the ones to ravel, then vectorizing this Op would be very much like handling axes for other Ops with axes: just shift them to the right by the number of new batch dims.

@ricardoV94 (Member):

Nvm, it's not too bad; we need to add the new batch axis to whatever was defined originally. I'm still a bit worried about negative axes at the Op level, though.

@ricardoV94 (Member) commented:

What if we 1) make axes be the ones to ravel and 2) require axes to be specified per input? You could still pass a single number / list of ints, but then it's only valid if all inputs have the same ndim.
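
Roughly something like this (hypothetical signature, just to illustrate option 2 -- equivalent to axes=[0, -1] in the example above):

packed_tensor, packed_shapes = pt.pack(x, y, z, axes=[[0, 1], [0, 3], [0, 2]])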

@ricardoV94 (Member) commented Nov 2, 2025

Ok, 1) is not intuitive, but 2) may still have some merit? It's strictly more powerful, and maybe less magical, to specify the axes for each input. It also means we can normalize them to be positive, which I think would simplify the code analysis?

In your use cases this would be terrible UX?

@ricardoV94 (Member) commented:

BTW I'm not bashing on the idea. On the contrary I quite like it. I just lost some of the context on the PR and I'm being lazy about getting it back.

One thing that would be nice, to prove out the API, would be to refactor the batched dot rewrite (dot to batched matmul, or whatever it's called) to use pack. This is my motivating case where I wanted this functionality.

@jessegrabowski (Member Author) commented:

The main use-case I have in mind for this is in optimize/pymc where we get parameters in arbitrary shapes, but we want to pack them into a single vector and do a replacement of the original variables with elements of that vector (see the test_make_replacements_with_pack_unpack test for what I have in mind).

Having to specify the number of dims ahead of time in that case doesn't work, because we don't know what the user will give us.

@ricardoV94 (Member) commented:

Why do you need to specify dims before you get to see the inputs?

new_outputs = unpack(new_input, packed_shapes)

loss = pytensor.graph.graph_replace(loss, dict(zip([x, y, z], new_outputs)))
fn = pytensor.function([new_input, x, y, z], loss, mode="FAST_COMPILE")

@ricardoV94 (Member):

I guess that for this example to make sense you want to rewrite the graph after the replace to get rid of the dependency on x, y, z since shapes are static

*tensors: TensorVariable, axes: int | Sequence[int] | None = None
) -> tuple[TensorVariable, list[tuple[TensorVariable]]]:
"""
Given a list of tensors of varying shapes and dimensions, ravels and concatenates them into a single 1d vector.

@ricardoV94 (Member) commented Nov 3, 2025:

Suggested change:
- Given a list of tensors of varying shapes and dimensions, ravels and concatenates them into a single 1d vector.
+ Concatenate a list of tensors along a subset of consecutive raveled dimensions.


Labels: enhancement (New feature or request)

Projects: None yet

Development: Successfully merging this pull request may close these issues: Implement pack/unpack Ops

2 participants