Rewrite concatenate([x, x]) as repeat(x, 2) #1714
base: main
Conversation
Let's try with
Btw, it would be nice to get rid of this Join (and Split) symbolic axis if you would like to work on that after this PR. Relevant issue: #1528
pytensor/tensor/rewriting/basic.py
Outdated
```python
# (e.g., x[None] has a guaranteed 1 at that axis)
shp = first.type.shape  # tuple of ints/None
if shp is None or axis_val >= len(shp) or shp[axis_val] != 1:
    return None
```
Why? I don't think this condition is needed.
Thanks for the feedback! I wanted to clarify the suggestion to remove the `shape[axis_val] != 1` condition.
When I remove this check, the tests fail because `join` and `repeat` behave differently when the size along the join axis is not 1.
Example:

```python
x = vector("x")  # e.g., [1.0, 2.0]

# Case 1: without ExpandDims (size is NOT 1 along axis 0)
join(0, x, x)    # → [1.0, 2.0, 1.0, 2.0] (concatenates)
repeat(x, 2, 0)  # → [1.0, 1.0, 2.0, 2.0] (repeats each element)
# ❌ These are NOT equivalent!

# Case 2: with ExpandDims (size IS 1 along axis 0)
join(0, x[None], x[None])  # → [[1.0, 2.0], [1.0, 2.0]]
repeat(x[None], 2, 0)      # → [[1.0, 2.0], [1.0, 2.0]]
# ✅ These ARE equivalent!
```
The optimization is only mathematically correct when the size along the join axis is 1 (e.g., after ExpandDims). Without this check, test_local_join_to_repeat fails.
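The two cases above can be reproduced directly in NumPy (a minimal sketch using `np.concatenate`/`np.repeat` as stand-ins for the symbolic `join`/`repeat` Ops):

```python
import numpy as np

x = np.array([1.0, 2.0])

# Case 1: size along axis 0 is NOT 1 -- the results differ
joined = np.concatenate([x, x])     # [1., 2., 1., 2.]
repeated = np.repeat(x, 2, axis=0)  # [1., 1., 2., 2.]
assert not np.array_equal(joined, repeated)

# Case 2: size along axis 0 IS 1 (after expand_dims) -- the results match
xe = x[None]  # shape (1, 2)
joined2 = np.concatenate([xe, xe], axis=0)
repeated2 = np.repeat(xe, 2, axis=0)
assert np.array_equal(joined2, repeated2)
```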
Could you clarify what you meant by "not needed"? Should we:
- Keep the check as-is (only optimize when `shape[axis_val] == 1`)
- Remove only the defensive parts (`shp is None or axis_val >= len(shp)`)
- Apply the optimization more broadly (and update the tests accordingly)
Thanks for your guidance!
Add optimization for Join → Repeat when concatenating identical tensors
Description
This PR introduces a graph rewrite in `pytensor/tensor/rewriting/basic.py` that replaces a redundant Join operation with an equivalent, more efficient Repeat operation when all concatenated tensors are identical.
Example:

```python
join(0, x, x, x)  # → repeat(x, 3, axis=0)
```
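As discussed in the review thread, the rewrite is only valid when the join axis has size 1. A minimal NumPy sketch of the equivalence it relies on (`x` here is a hypothetical example array, not from the PR):

```python
import numpy as np

# Hypothetical input whose join axis (axis 0) has size 1,
# e.g. the result of an ExpandDims such as x[None].
x = np.array([[1.0, 2.0, 3.0, 4.0]])  # shape (1, 4)

a = np.concatenate([x, x, x], axis=0)  # what Join computes
b = np.repeat(x, 3, axis=0)            # the rewritten Repeat
assert np.array_equal(a, b)            # identical results, shape (3, 4)
```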
Key additions:
Related Issue
Checklist
Type of change