Skip to content

Conversation

@tchan102
Copy link

@tchan102 tchan102 commented Nov 4, 2025

Add optimization for Join → Repeat when concatenating identical tensors

Description

This PR introduces a graph rewrite optimization in pytensor/tensor/rewriting/basic.py that replaces redundant Join operations with an equivalent and more efficient Repeat operation when all concatenated tensors are identical.

Example:
join(0, x, x, x) → repeat(x, 3, axis=0)

Key additions:

  • Implemented new rewrite function local_join_to_repeat registered under both @register_canonicalize and @register_specialize.
  • Added corresponding test test_local_join_to_repeat to verify correctness, performance, and behavior for vectors and matrices.

Related Issue

Checklist

Type of change

  • [ x] New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@ricardoV94 ricardoV94 added graph rewriting enhancement New feature or request labels Nov 4, 2025
@ricardoV94
Copy link
Member

Let's try with @register_canonicalize only

@ricardoV94
Copy link
Member

Btw would be nice to get rid of this join (and split) symbolic axis if you would like to work on that after this PR. relevant issue: #1528

# (e.g., x[None] has a guaranteed 1 at that axis)
shp = first.type.shape # tuple of ints/None
if shp is None or axis_val >= len(shp) or shp[axis_val] != 1:
return None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why? I don't think this is needed condition

Copy link
Author

@tchan102 tchan102 Nov 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the feedback! I wanted to clarify about removing the shape[axis_val] != 1 condition.

When I removed this check, the tests fail because join and repeat behave differently when the size along the join axis is not 1:

Example:
x = vector("x") # e.g., [1.0, 2.0]

Case 1: Without ExpandDims (size is NOT 1 along axis 0)
join(0, x, x) # → [1.0, 2.0, 1.0, 2.0] (concatenates)
repeat(x, 2, 0) # → [1.0, 1.0, 2.0, 2.0] (repeats each element)
❌ These are NOT equivalent!

Case 2: With ExpandDims (size IS 1 along axis 0)
join(0, x[None], x[None]) # → [[1.0, 2.0], [1.0, 2.0]]
repeat(x[None], 2, 0) # → [[1.0, 2.0], [1.0, 2.0]]
✅ These ARE equivalent!

The optimization seems to only be mathematically correct when the size along the join axis is 1 (e.g., after ExpandDims). Without this check, test_local_join_to_repeat fails.

Could you clarify what you meant by "not needed"? Should we:

  1. Keep the check as-is (only optimize when shape[axis_val] == 1)
  2. Remove only the defensive parts (shp is None or axis_val >= len(shp))
  3. Apply the optimization more broadly (and update the tests accordingly)

Thanks for your guidance!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request graph rewriting

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Rewrite concatenate([x, x]) as repeat(x, 2)

2 participants