diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py
index c9ade02a00..76d1ac8aba 100644
--- a/pytensor/tensor/rewriting/basic.py
+++ b/pytensor/tensor/rewriting/basic.py
@@ -910,6 +910,52 @@ def local_join_make_vector(fgraph, node):
     return [ret]
 
 
+@register_canonicalize
+@node_rewriter([Join])
+def local_join_to_repeat(fgraph, node):
+    """Join(axis, x, x, x, ...) -> repeat(x, n, axis)
+
+    When the same tensor is concatenated multiple times along an axis
+    where it has size 1, replace it with a repeat, which is more efficient.
+
+    Examples
+    --------
+    concatenate([x[None], x[None], x[None]], axis=0) -> repeat(x[None], 3, axis=0)
+    """
+    # Extract axis and the tensors being joined
+    axis, *tensors = node.inputs
+
+    # Optimization only applies when axis is constant
+    if not isinstance(axis, Constant):
+        return None
+
+    # Extract the concrete axis value from the constant
+    axis_val = axis.data
+
+    # Need at least 2 tensors to consider optimization
+    if len(tensors) <= 1:
+        return
+
+    # Check if all tensors are identical
+    if not all(t == tensors[0] for t in tensors[1:]):
+        return
+
+    # Only optimize if the tensor has size 1 along the join axis
+    first_tensor = tensors[0]
+    if first_tensor.type.shape[axis_val] != 1:
+        return
+
+    # Replace with repeat operation
+    from pytensor.tensor.extra_ops import repeat
+
+    result = repeat(first_tensor, len(tensors), axis_val)
+
+    # Preserve debugging information
+    copy_stack_trace(node.outputs[0], result)
+
+    return [result]
+
+
 @register_specialize
 @register_canonicalize
 @register_useless
diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py
index d9eb2ad7ad..110656b9c6 100644
--- a/tests/tensor/rewriting/test_basic.py
+++ b/tests/tensor/rewriting/test_basic.py
@@ -1247,6 +1247,106 @@ def test_local_join_1():
     assert f.maker.fgraph.outputs[0].dtype == config.floatX
 
 
+def test_local_join_to_repeat():
+    """Test that Join(axis, x, x, ...) gets rewritten to repeat(x, n, axis).
+
+    This optimization applies when joining the same tensor multiple times
+    along an axis where it has size 1 (e.g., after ExpandDims).
+    """
+
+    # Test with vector expanded to (1, n) - concatenate along axis 0
+    x = vector("x")
+    x_expanded = x[None]  # Shape: (1, n)
+    s = join(0, x_expanded, x_expanded, x_expanded)  # Shape: (3, n)
+    f = function([x], s, mode=rewrite_mode)
+
+    # Check numerical correctness
+    test_val = np.array([1.0, 2.0, 3.0], dtype=config.floatX)
+    result = f(test_val)
+    expected = np.array(
+        [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], dtype=config.floatX
+    )
+    assert np.allclose(result, expected)
+
+    # Check that Join was replaced with Alloc (repeat with scalar repeats becomes Alloc)
+    ops = f.maker.fgraph.toposort()
+    assert len([n for n in ops if isinstance(n.op, Join)]) == 0
+    assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1
+
+    # Test with matrix - add dimension and concatenate along new axis
+    a = matrix("a")  # Shape: (m, n)
+    a_expanded = a[None, :, :]  # Shape: (1, m, n)
+    s = join(0, a_expanded, a_expanded, a_expanded, a_expanded)  # Shape: (4, m, n)
+    f = function([a], s, mode=rewrite_mode)
+
+    test_mat = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
+    result = f(test_mat)
+    expected = np.array([test_mat, test_mat, test_mat, test_mat])
+    assert np.allclose(result, expected)
+
+    # Check optimization applied
+    ops = f.maker.fgraph.toposort()
+    assert len([n for n in ops if isinstance(n.op, Join)]) == 0
+    assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1
+
+    # Test with matrix - expand along axis 1 and concatenate
+    a_expanded_ax1 = a[:, None, :]  # Shape: (m, 1, n)
+    s = join(1, a_expanded_ax1, a_expanded_ax1)  # Shape: (m, 2, n)
+    f = function([a], s, mode=rewrite_mode)
+
+    result = f(test_mat)
+    expected = np.array([[[1.0, 2.0], [1.0, 2.0]], [[3.0, 4.0], [3.0, 4.0]]])
+    assert np.allclose(result, expected)
+
+    # Check optimization applied
+    ops = f.maker.fgraph.toposort()
+    assert len([n for n in ops if isinstance(n.op, Join)]) == 0
+    assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1
+
+    # Test that it does NOT apply when tensors are different
+    y = vector("y")
+    s = join(0, x[None], y[None])
+    f = function([x, y], s, mode=rewrite_mode)
+
+    test_vec1 = np.array([1.0, 2.0], dtype=config.floatX)
+    test_vec2 = np.array([3.0, 4.0], dtype=config.floatX)
+    result = f(test_vec1, test_vec2)
+    expected = np.array([[1.0, 2.0], [3.0, 4.0]])
+    assert np.allclose(result, expected)
+
+    # Join should still be present (not optimized)
+    ops = f.maker.fgraph.toposort()
+    assert len([n for n in ops if isinstance(n.op, Join)]) == 1
+
+    # Test that it does NOT apply when tensor doesn't have size 1 along join axis
+    # (regular concatenation without ExpandDims)
+    s = join(0, x, x, x)  # Shape: (3n,) not using ExpandDims
+    f = function([x], s, mode=rewrite_mode)
+
+    test_val = np.array([1.0, 2.0], dtype=config.floatX)
+    result = f(test_val)
+    expected = np.array([1.0, 2.0, 1.0, 2.0, 1.0, 2.0], dtype=config.floatX)
+    assert np.allclose(result, expected)
+
+    # Join should still be present (optimization doesn't apply)
+    ops = f.maker.fgraph.toposort()
+    assert len([n for n in ops if isinstance(n.op, Join)]) == 1
+
+    # Test with 5 repetitions to ensure it works with larger counts
+    s = join(0, x[None], x[None], x[None], x[None], x[None])
+    f = function([x], s, mode=rewrite_mode)
+
+    test_val = np.array([1.0, 2.0], dtype=config.floatX)
+    result = f(test_val)
+    expected = np.array([[1.0, 2.0]] * 5, dtype=config.floatX)
+    assert np.allclose(result, expected)
+
+    # Check optimization applied
+    ops = f.maker.fgraph.toposort()
+    assert len([n for n in ops if isinstance(n.op, Join)]) == 0
+    assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1
+
+
 def test_local_join_empty():
     # Vector case
     empty_vec = np.asarray([], dtype=config.floatX)
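
Not part of the patch: a minimal sketch of how the rewrite shows up from the user's side, assuming it runs as part of the default canonicalization pass once merged. It only uses the public pt.join / pytensor.function / pytensor.dprint APIs; the graph-structure expectation in the comments mirrors the assertions in the test above.

import numpy as np

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
# The same size-1-along-axis tensor joined three times ...
out = pt.join(0, x[None], x[None], x[None])

f = pytensor.function([x], out)

# ... should compile without a Join node: the test above expects the
# repeat (with scalar repeats) to be lowered to an Alloc.
pytensor.dprint(f)

# Numerical behaviour is unchanged either way.
np.testing.assert_allclose(f([1.0, 2.0]), [[1.0, 2.0], [1.0, 2.0], [1.0, 2.0]])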