Implement pack/unpack helpers #1578
base: main
Conversation
Force-pushed 36422cd to cf633e7
The pack -> type -> unpack -> replace pattern might be common enough to merit its own helper. PyMC has tools for doing this, for example, in

One other thing I forgot to mention is that this will all fail on inputs with shape 0, since that will ruin the
Force-pushed da89b9d to 9ead211
Codecov Report
❌ Patch coverage is 94.02%.
❌ Your patch check has failed because the patch coverage (94.02%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files
@@ Coverage Diff @@
## main #1578 +/- ##
==========================================
+ Coverage 81.70% 81.74% +0.03%
==========================================
Files 246 246
Lines 53632 53763 +131
Branches 9438 9462 +24
==========================================
+ Hits 43820 43946 +126
- Misses 7330 7333 +3
- Partials 2482 2484 +2
2 and 3. I would really like to have these, it's what I needed for the batched_dot_to_core rewrites. This isn't a simple case of vectorize because the dims I want to pack are both on the left and right of other dims.
I am inclined to make this a core op and not just a helper. It obviates most uses of reshape and it's much easier to reason about, not having to worry about pesky -1 or whether the reshape shape comes from the original input shapes or not. That would pretty much address #883.

We could use OFG and/or specialize to reshape/split later. It also need not be done in this PR. It's an implementation detail as far as the user is concerned.
Force-pushed 860a7ab to fcbd0af
Force-pushed fcbd0af to 5788333
I pushed a commit that adds a "feature complete" implementation. Basically, you pack by selecting which axes you don't want to ravel. Axes should be None, int, or tuple[int]. If None, every input is fully raveled and the results are concatenated:

```python
x = pt.tensor("x", shape=())
y = pt.tensor("y", shape=(5,))
z = pt.tensor("z", shape=(3, 3))

packed_tensor, packed_shapes = pt.pack(x, y, z, axes=None)
packed_tensor.type.shape  # (15,)
```

Once you pass in integers, your inputs need to have the same shape on the dimensions of concatenation. All dimensions that are in a "hole" of the provided axes are raveled and joined. What's a hole? You can have explicit or implicit holes. An "explicit" hole is a gap in the integers you provide. For example:

```python
x = pt.tensor("x", shape=(5, 3))
y = pt.tensor("y", shape=(5, 2, 4, 3))
z = pt.tensor("z", shape=(5, 6, 3))

packed_tensor, packed_shapes = pt.pack(x, y, z, axes=[0, -1])
packed_tensor.type.shape  # (5, 13, 3)
```

An "implicit" hole arises when an input doesn't have all of the requested axes:

```python
x = pt.tensor("x", shape=(2, 6))
y = pt.tensor("y", shape=(2, 3, 7, 6))
z = pt.tensor("z", shape=(2, 10, 6))

packed_tensor, packed_shapes = pt.pack(x, y, z, axes=[0, 3])
packed_tensor.type.shape  # (2, 32, 6)
```

Minimum size is also enforced -- we couldn't have passed in

I could imagine this being more strict --
```python
pack_op = Pack(
    inputs=tensors,
    outputs=[packed_output_tensor, *packed_output_shapes],
    name="Pack{axes=" + str(axes) + "}",
```
Why give name instead of just defining the __str__ of the Pack?
Because I looked at the docs for OpFromGraph and saw there was a name field I could pass
```python
if len(set(axes)) != len(axes):
    raise ValueError("axes must have no duplicates")
if axes is not None and len(axes) == 0:
```
No-ops should be supported in general. It makes writing code easier because you don't have to think about edge cases. Empty axes are supported in most Ops that allow a variable number of axes (like sum).
What should axes = [] do?
Tell me what axes does, and I might know
```python
self.axes = tuple(axes) if isinstance(axes, list) else axes
self.op_name = "Pack{axes=" + str(self.axes) + "}"


def _analyze_axes_list(self) -> tuple[int, int, int, int | None]:
```
This is pretty gnarly
technically you asked for it
How so?
```python
"Wrapper for the Pack Op"


def pack(
```
I can't tell how this function works from the docstrings. What happens when I pass inputs with different dimensions, and a single/list of axes?
```
axes: int or sequence of int, optional
    Axes to be concatenated. All other axes will be raveled (packed) and joined. If None, all axes will be raveled
    and joined.
    Axes to be preserved. All other axes will be raveled (packed), and the output is the result of concatenating
```
IIUC this is changing the usual meaning of axes a bit. axes=None is usually the same as an exhaustive list of all axes: a.sum(None) == a.sum(tuple(range(a.ndim))), but here it is flipped?
axis=[] would be the same as None in this case, nothing is preserved.
```
)  # shape (2, 17)

`axes` can also be negative, in which case the axes are counted from the end of the tensor shape. For example,
if `axes=[-1]`, then the last dimension of each tensor is preserved, and all other dimensions are raveled:
```
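As a rough illustration of the negative-axes behavior described in that docstring excerpt (the shapes here are hypothetical, and the return values follow this PR's proposed `pt.pack` API):

```python
import pytensor.tensor as pt

x = pt.tensor("x", shape=(5, 3))
y = pt.tensor("y", shape=(7, 2, 3))

# axes=[-1]: keep the last dimension of each input, ravel and join everything else
packed, packed_shapes = pt.pack(x, y, axes=[-1])

# x contributes 5 raveled rows and y contributes 7 * 2 = 14,
# so the expected static shape is (19, 3)
packed.type.shape  # (19, 3)
```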
The axes specification seems troublesome for vectorization. If axis were those to ravel, then vectorization of this input is very much like axis for other Ops with axes, just shift them to the right by the number of new batch dims.
Nvm it's not too bad, we need to add the new batch axis to whatever was defined originally. I'm a bit worried about negative axis at the Op level still
What if we make 1) axis be the ones to ravel, and 2) you have to specify axis per input? You could still pass a single number / list of ints, but then it's only valid if all inputs have the same ndim?
Ok, 1) is not intuitive, but 2) may still have some merit? It's strictly more powerful and maybe less magical to specify the axes for each input. It also means we can normalize them to be positive and, I think, simplify the code analysis? In your use cases this would be terrible UX?
BTW I'm not bashing the idea. On the contrary, I quite like it. I just lost some of the context on the PR and I'm being lazy about getting it back. One thing that would be nice to prove the API would be refactoring the batched dot rewrite (dot to batched matmul or whatever it's called) to use pack. This is my motivating case where I wanted this functionality.
The main use-case I have in mind for this is in optimize/pymc, where we get parameters in arbitrary shapes, but we want to pack them into a single vector and do a replacement of the original variables with elements of that vector (see the

Having to specify the number of dims ahead of time in that case doesn't work, because we don't know what the user will give us.
Why do you need to specify dims before you get to see the inputs?
```python
new_outputs = unpack(new_input, packed_shapes)

loss = pytensor.graph.graph_replace(loss, dict(zip([x, y, z], new_outputs)))
fn = pytensor.function([new_input, x, y, z], loss, mode="FAST_COMPILE")
```
I guess that for this example to make sense you want to rewrite the graph after the replace to get rid of the dependency on x, y, z since shapes are static
```python
    *tensors: TensorVariable, axes: int | Sequence[int] | None = None
) -> tuple[TensorVariable, list[tuple[TensorVariable]]]:
    """
    Given a list of tensors of varying shapes and dimensions, ravels and concatenates them into a single 1d vector.
```
Suggested change:
```diff
- Given a list of tensors of varying shapes and dimensions, ravels and concatenates them into a single 1d vector.
+ Concatenate a list of tensors along a subset of consecutive raveled dimensions.
```
Description
Adds `pt.pack` and `pt.unpack` helpers, roughly conforming to the `einops` functions of the same name. These helpers are for situations where we have a ragged list of inputs that need to be raveled into a single flat list for some intermediate step. This occurs in places like optimization.
Example usage:
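A minimal sketch of the intended usage (the return values and static shapes follow the `pt.pack`/`pt.unpack` semantics discussed in the review thread above):

```python
import pytensor.tensor as pt

# Three inputs with different ndim and shapes
x = pt.tensor("x", shape=())
y = pt.tensor("y", shape=(5,))
z = pt.tensor("z", shape=(3, 3))

# Ravel everything and join into one flat vector; packed_shapes records the
# original shapes so the inputs can be recovered later
packed, packed_shapes = pt.pack(x, y, z, axes=None)
packed.type.shape  # (15,)

# unpack splits the flat vector and reshapes each piece back to its original shape
x2, y2, z2 = pt.unpack(packed, packed_shapes)
```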
Unpack simply undoes the computation, although there's no rewrite to ensure `pt.unpack(*pt.pack(*inputs))` is the identity function. The use-case I foresee is creating a replacement for a function of the inputs we're packing, for example:
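A sketch of that replacement workflow, modeled on the test excerpt shown earlier in the review (the `loss` graph and variable names here are placeholders):

```python
import pytensor
import pytensor.tensor as pt

x = pt.tensor("x", shape=())
y = pt.tensor("y", shape=(5,))
z = pt.tensor("z", shape=(3, 3))

# Some function of the original, ragged inputs
loss = (x + y.sum() + z.sum()) ** 2

# Pack the inputs into a single flat vector and create a stand-in input for it
packed, packed_shapes = pt.pack(x, y, z, axes=None)
new_input = packed.type("new_input")

# Recover tensors with the original shapes from the flat vector, then replace
# the original variables in the loss graph
new_outputs = pt.unpack(new_input, packed_shapes)
loss = pytensor.graph.graph_replace(loss, dict(zip([x, y, z], new_outputs)))

# Because the packed shapes are static, the rewritten loss depends only on new_input
fn = pytensor.function([new_input], loss, mode="FAST_COMPILE")
```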
Note that the final compiled function depends only on `new_input`, but only because the shapes of the 3 packed variables were statically known. This leads to my design choices section:

- `pack` will eagerly return a list of integer shapes as `packed_shapes` if possible. If not possible, they will be symbolic shapes. This is maybe an anti-pattern -- we might prefer a rewrite to handle this later, but it seemed easy enough to do eagerly (see the sketch after this list).
- `pt.vectorize`.
- The `einops` API has arguments to support packing/unpacking on arbitrary subsets of dimensions. I didn't do this, because I couldn't think of a use-case that a user couldn't get himself using `DimShuffle` and `vectorize`.
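A rough sketch of that first design choice (the variable names are hypothetical; the behavior is as described in the bullet above):

```python
import pytensor.tensor as pt

# Static input shapes: packed_shapes can be returned eagerly as plain integers
x = pt.tensor("x", shape=(5,))
y = pt.tensor("y", shape=(3, 3))
_, static_shapes = pt.pack(x, y, axes=None)    # e.g. [(5,), (3, 3)]

# Unknown input shapes: packed_shapes fall back to symbolic shape graphs
a = pt.tensor("a", shape=(None,))
b = pt.tensor("b", shape=(None, None))
_, symbolic_shapes = pt.pack(a, b, axes=None)  # shapes only known at runtime
```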
Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1578.org.readthedocs.build/en/1578/