46 changes: 46 additions & 0 deletions pytensor/tensor/rewriting/basic.py
@@ -910,6 +910,52 @@ def local_join_make_vector(fgraph, node):
return [ret]


@register_canonicalize
@node_rewriter([Join])
def local_join_to_repeat(fgraph, node):
"""Join(axis, x, x, x, ...) -> repeat(x, n, axis)

    When the same tensor is concatenated multiple times along an axis where it
    has size 1, replace the join with a single repeat operation, which is more
    efficient.

Examples
--------
concatenate([x[None], x[None], x[None]], axis=0) -> repeat(x[None], 3, axis=0)
"""
# Extract axis and the tensors being joined
axis, *tensors = node.inputs

# Optimization only applies when axis is constant
if not isinstance(axis, Constant):
return None

    # Extract the axis as a Python integer from the constant
    axis_val = int(axis.data)

# Need at least 2 tensors to consider optimization
if len(tensors) <= 1:
return

    # Check that every joined input is literally the same variable
    if not all(t is tensors[0] for t in tensors[1:]):
        return

    # Only optimize when the tensor is statically known to have size 1 along the
    # join axis; an unknown static shape (None) also bails out here
    first_tensor = tensors[0]
    if first_tensor.type.shape[axis_val] != 1:
        return

    # Replace the join with a single repeat along the join axis
    from pytensor.tensor.extra_ops import repeat

    result = repeat(first_tensor, len(tensors), axis=axis_val)

# Preserve debugging information
copy_stack_trace(node.outputs[0], result)

return [result]


@register_specialize
@register_canonicalize
@register_useless
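The identity behind the rewrite is easy to sanity-check outside PyTensor: concatenating n copies of a tensor along an axis where it has size 1 gives the same array as repeating it n times along that axis. A minimal NumPy sketch of that equivalence (the array a is a made-up example, not part of the PR):

import numpy as np

a = np.array([1.0, 2.0, 3.0])  # shape (3,)
x = a[None]                    # shape (1, 3): size 1 along axis 0

joined = np.concatenate([x, x, x], axis=0)  # what Join(0, x, x, x) computes
repeated = np.repeat(x, 3, axis=0)          # what the rewrite produces instead

assert np.array_equal(joined, repeated)     # both are shape (3, 3)
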
100 changes: 100 additions & 0 deletions tests/tensor/rewriting/test_basic.py
@@ -1247,6 +1247,106 @@ def test_local_join_1():
assert f.maker.fgraph.outputs[0].dtype == config.floatX


def test_local_join_to_repeat():
"""Test that Join(axis, x, x, ...) gets rewritten to repeat(x, n, axis)

This optimization applies when joining the same tensor multiple times
along an axis where it has size 1 (e.g., after ExpandDims).
"""

# Test with vector expanded to (1, n) - concatenate along axis 0
x = vector("x")
x_expanded = x[None] # Shape: (1, n)
s = join(0, x_expanded, x_expanded, x_expanded) # Shape: (3, n)
f = function([x], s, mode=rewrite_mode)

# Check numerical correctness
test_val = np.array([1.0, 2.0, 3.0], dtype=config.floatX)
result = f(test_val)
expected = np.array(
[[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], dtype=config.floatX
)
assert np.allclose(result, expected)

# Check that Join was replaced with Alloc (repeat with scalar repeats becomes Alloc)
ops = f.maker.fgraph.toposort()
assert len([n for n in ops if isinstance(n.op, Join)]) == 0
assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1

# Test with matrix - add dimension and concatenate along new axis
a = matrix("a") # Shape: (m, n)
a_expanded = a[None, :, :] # Shape: (1, m, n)
s = join(0, a_expanded, a_expanded, a_expanded, a_expanded) # Shape: (4, m, n)
f = function([a], s, mode=rewrite_mode)

test_mat = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
result = f(test_mat)
expected = np.array([test_mat, test_mat, test_mat, test_mat])
assert np.allclose(result, expected)

# Check optimization applied
ops = f.maker.fgraph.toposort()
assert len([n for n in ops if isinstance(n.op, Join)]) == 0
assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1

# Test with matrix - expand along axis 1 and concatenate
a_expanded_ax1 = a[:, None, :] # Shape: (m, 1, n)
s = join(1, a_expanded_ax1, a_expanded_ax1) # Shape: (m, 2, n)
f = function([a], s, mode=rewrite_mode)

result = f(test_mat)
expected = np.array([[[1.0, 2.0], [1.0, 2.0]], [[3.0, 4.0], [3.0, 4.0]]])
assert np.allclose(result, expected)

# Check optimization applied
ops = f.maker.fgraph.toposort()
assert len([n for n in ops if isinstance(n.op, Join)]) == 0
assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1

# Test that it does NOT apply when tensors are different
y = vector("y")
s = join(0, x[None], y[None])
f = function([x, y], s, mode=rewrite_mode)

test_vec1 = np.array([1.0, 2.0], dtype=config.floatX)
test_vec2 = np.array([3.0, 4.0], dtype=config.floatX)
result = f(test_vec1, test_vec2)
expected = np.array([[1.0, 2.0], [3.0, 4.0]])
assert np.allclose(result, expected)

# Join should still be present (not optimized)
ops = f.maker.fgraph.toposort()
assert len([n for n in ops if isinstance(n.op, Join)]) == 1

    # Test that it does NOT apply when the tensor is not statically known to have
    # size 1 along the join axis (a regular concatenation without ExpandDims)
    s = join(0, x, x, x)  # Shape: (3n,)
f = function([x], s, mode=rewrite_mode)

test_val = np.array([1.0, 2.0], dtype=config.floatX)
result = f(test_val)
expected = np.array([1.0, 2.0, 1.0, 2.0, 1.0, 2.0], dtype=config.floatX)
assert np.allclose(result, expected)

# Join should still be present (optimization doesn't apply)
ops = f.maker.fgraph.toposort()
assert len([n for n in ops if isinstance(n.op, Join)]) == 1

# Test with 5 repetitions to ensure it works with larger counts
s = join(0, x[None], x[None], x[None], x[None], x[None])
f = function([x], s, mode=rewrite_mode)

test_val = np.array([1.0, 2.0], dtype=config.floatX)
result = f(test_val)
expected = np.array([[1.0, 2.0]] * 5, dtype=config.floatX)
assert np.allclose(result, expected)

# Check optimization applied
ops = f.maker.fgraph.toposort()
assert len([n for n in ops if isinstance(n.op, Join)]) == 0
assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1


def test_local_join_empty():
# Vector case
empty_vec = np.asarray([], dtype=config.floatX)
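For completeness, one way to see what the rewritten graph actually compiles to (the tests above expect the Join to disappear and an Alloc to show up) is to debug-print a compiled function. A minimal sketch, assuming the default compilation mode applies the canonicalize rewrites this pass is registered under:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
s = pt.join(0, x[None], x[None], x[None])

f = pytensor.function([x], s)
pytensor.dprint(f)  # with the rewrite applied, no Join node should remain in the printout

print(f(np.array([1.0, 2.0, 3.0], dtype=pytensor.config.floatX)))
# expected output: [[1. 2. 3.] [1. 2. 3.] [1. 2. 3.]]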