From 5a5c0bc043e15725d116a7f66a7f325b6beeb8af Mon Sep 17 00:00:00 2001 From: Tat Chan Date: Tue, 4 Nov 2025 00:45:08 -0800 Subject: [PATCH 1/5] Rewrite concatenate([x, x]) as repeat(x, 2) --- pytensor/tensor/rewriting/basic.py | 36 ++++++++++++- tests/tensor/rewriting/test_basic.py | 79 ++++++++++++++++++++++++++++ 2 files changed, 114 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index c9ade02a00..5970f12da4 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -82,7 +82,7 @@ ) from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError -from pytensor.tensor.extra_ops import broadcast_arrays +from pytensor.tensor.extra_ops import broadcast_arrays, repeat from pytensor.tensor.math import Sum, add, eq, variadic_add from pytensor.tensor.shape import Shape_i, shape_padleft from pytensor.tensor.type import DenseTensorType, TensorType @@ -909,6 +909,40 @@ def local_join_make_vector(fgraph, node): copy_stack_trace(node.outputs, ret) return [ret] +@register_specialize +@register_canonicalize +@node_rewriter([Join]) +def local_join_to_repeat(fgraph, node): + """Join(axis, x, x, x, ...) -> repeat(x, n, axis) + + When the same tensor is concatenated multiple times, + replace with a single repeat operation which is more efficient. + + Examples + -------- + concatenate([x, x, x], axis=0) -> repeat(x, 3, axis=0) + """ + if not isinstance(node.op, Join): + return + + # Extract axis and the tensors being joined + axis, *tensors = node.inputs + + # Need at least 2 tensors to consider optimization + if len(tensors) <= 1: + return + + # Check if all tensors are identical + if not all(t == tensors[0] for t in tensors[1:]): + return + + # Replace with repeat operation + result = repeat(tensors[0], len(tensors), axis) + + # Preserve debugging information + copy_stack_trace(node.outputs[0], result) + + return [result] @register_specialize @register_canonicalize diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index d9eb2ad7ad..40ab200a2c 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -35,6 +35,7 @@ tile, ) from pytensor.tensor.elemwise import DimShuffle, Elemwise +from pytensor.tensor.extra_ops import Repeat from pytensor.tensor.math import ( add, bitwise_and, @@ -1247,6 +1248,84 @@ def test_local_join_1(): assert f.maker.fgraph.outputs[0].dtype == config.floatX +def test_local_join_to_repeat(): + """Test that Join(axis, x, x, ...) 
gets rewritten to repeat(x, n, axis)""" + + # Test with vector - concatenate same vector 3 times along axis 0 + x = vector("x") + s = join(0, x, x, x) + f = function([x], s, mode=rewrite_mode) + + # Check numerical correctness + test_val = np.array([1.0, 2.0, 3.0], dtype=config.floatX) + result = f(test_val) + expected = np.array([1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0], dtype=config.floatX) + assert np.allclose(result, expected) + + # Check that Join was replaced with Repeat + ops = f.maker.fgraph.toposort() + assert len([n for n in ops if isinstance(n.op, Join)]) == 0 + assert len([n for n in ops if isinstance(n.op, Repeat)]) == 1 + + # Test with matrix - concatenate same matrix along axis 0 + a = matrix("a") + s = join(0, a, a, a, a) + f = function([a], s, mode=rewrite_mode) + + test_mat = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) + result = f(test_mat) + expected = np.vstack([test_mat, test_mat, test_mat, test_mat]) + assert np.allclose(result, expected) + + # Check optimization applied + ops = f.maker.fgraph.toposort() + assert len([n for n in ops if isinstance(n.op, Join)]) == 0 + assert len([n for n in ops if isinstance(n.op, Repeat)]) == 1 + + # Test with matrix - concatenate along axis 1 + s = join(1, a, a) + f = function([a], s, mode=rewrite_mode) + + result = f(test_mat) + expected = np.hstack([test_mat, test_mat]) + assert np.allclose(result, expected) + + # Check optimization applied + ops = f.maker.fgraph.toposort() + assert len([n for n in ops if isinstance(n.op, Join)]) == 0 + assert len([n for n in ops if isinstance(n.op, Repeat)]) == 1 + + # Test that it does NOT apply when tensors are different + b = matrix("b") + s = join(0, a, b) + f = function([a, b], s, mode=rewrite_mode) + + test_mat1 = np.array([[1.0, 2.0]], dtype=config.floatX) + test_mat2 = np.array([[3.0, 4.0]], dtype=config.floatX) + result = f(test_mat1, test_mat2) + expected = np.vstack([test_mat1, test_mat2]) + assert np.allclose(result, expected) + + # Join should still be present (not optimized to Repeat) + ops = f.maker.fgraph.toposort() + assert len([n for n in ops if isinstance(n.op, Join)]) == 1 + assert len([n for n in ops if isinstance(n.op, Repeat)]) == 0 + + # Test with 5 repetitions to ensure it works with larger counts + s = join(0, x, x, x, x, x) + f = function([x], s, mode=rewrite_mode) + + test_val = np.array([1.0, 2.0], dtype=config.floatX) + result = f(test_val) + expected = np.tile(test_val, 5) + assert np.allclose(result, expected) + + # Check optimization applied + ops = f.maker.fgraph.toposort() + assert len([n for n in ops if isinstance(n.op, Join)]) == 0 + assert len([n for n in ops if isinstance(n.op, Repeat)]) == 1 + + def test_local_join_empty(): # Vector case empty_vec = np.asarray([], dtype=config.floatX) From 44aa137d40c3cd6fde8c61b25818b8b3d00a0562 Mon Sep 17 00:00:00 2001 From: Tat Chan Date: Tue, 4 Nov 2025 00:55:59 -0800 Subject: [PATCH 2/5] fixed format --- pytensor/tensor/rewriting/basic.py | 2 ++ tests/tensor/rewriting/test_basic.py | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 5970f12da4..555ef72464 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -909,6 +909,7 @@ def local_join_make_vector(fgraph, node): copy_stack_trace(node.outputs, ret) return [ret] + @register_specialize @register_canonicalize @node_rewriter([Join]) @@ -944,6 +945,7 @@ def local_join_to_repeat(fgraph, node): return [result] + 
@register_specialize @register_canonicalize @register_useless diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index 40ab200a2c..4879b816cf 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -1259,7 +1259,9 @@ def test_local_join_to_repeat(): # Check numerical correctness test_val = np.array([1.0, 2.0, 3.0], dtype=config.floatX) result = f(test_val) - expected = np.array([1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0], dtype=config.floatX) + expected = np.array( + [1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0], dtype=config.floatX + ) assert np.allclose(result, expected) # Check that Join was replaced with Repeat From 58dc0276a23aceae5d4edf4b45445abae3e66d58 Mon Sep 17 00:00:00 2001 From: Tat Chan Date: Tue, 4 Nov 2025 01:11:02 -0800 Subject: [PATCH 3/5] remove register_specialize and not instance check --- pytensor/tensor/rewriting/basic.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 555ef72464..4a90af02eb 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -910,7 +910,6 @@ def local_join_make_vector(fgraph, node): return [ret] -@register_specialize @register_canonicalize @node_rewriter([Join]) def local_join_to_repeat(fgraph, node): @@ -923,9 +922,6 @@ def local_join_to_repeat(fgraph, node): -------- concatenate([x, x, x], axis=0) -> repeat(x, 3, axis=0) """ - if not isinstance(node.op, Join): - return - # Extract axis and the tensors being joined axis, *tensors = node.inputs From 9348020f065815c257242d878dabe8649583e83c Mon Sep 17 00:00:00 2001 From: Tat Chan Date: Wed, 5 Nov 2025 01:28:27 -0800 Subject: [PATCH 4/5] Handle symbolic axis and fix test assertions for Alloc --- pytensor/tensor/rewriting/basic.py | 48 +++++++++++++---- tests/tensor/rewriting/test_basic.py | 77 +++++++++++++++++----------- 2 files changed, 85 insertions(+), 40 deletions(-) diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 4a90af02eb..0297c04b07 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -30,7 +30,7 @@ from pytensor import compile, config from pytensor.compile.ops import ViewOp from pytensor.graph import FunctionGraph, Op -from pytensor.graph.basic import Constant +from pytensor.graph.basic import Constant, equal_computations from pytensor.graph.rewriting.basic import ( NodeProcessingGraphRewriter, NodeRewriter, @@ -82,7 +82,7 @@ ) from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError -from pytensor.tensor.extra_ops import broadcast_arrays, repeat +from pytensor.tensor.extra_ops import broadcast_arrays from pytensor.tensor.math import Sum, add, eq, variadic_add from pytensor.tensor.shape import Shape_i, shape_padleft from pytensor.tensor.type import DenseTensorType, TensorType @@ -915,26 +915,52 @@ def local_join_make_vector(fgraph, node): def local_join_to_repeat(fgraph, node): """Join(axis, x, x, x, ...) -> repeat(x, n, axis) - When the same tensor is concatenated multiple times, - replace with a single repeat operation which is more efficient. + When the same tensor is concatenated multiple times along an axis + where it has size 1, replace with a repeat operation which is more efficient. 
Examples -------- - concatenate([x, x, x], axis=0) -> repeat(x, 3, axis=0) + concatenate([x[None], x[None], x[None]], axis=0) -> repeat(x[None], 3, axis=0) """ # Extract axis and the tensors being joined - axis, *tensors = node.inputs + axis_sym, *tensors = node.inputs # Need at least 2 tensors to consider optimization if len(tensors) <= 1: - return + return None - # Check if all tensors are identical - if not all(t == tensors[0] for t in tensors[1:]): - return + # Extract (and normalize) axis as Python int + try: + axis_val = int(get_scalar_constant_value(axis_sym, only_process_constants=True)) + except NotScalarConstantError: + return None + + # Get first tensor and check if ndim is known + first = tensors[0] + ndim = first.ndim + if ndim is None: + return None + + # Normalize negative axes (e.g., -1 -> ndim-1) + axis_val = axis_val % ndim + + # All inputs must be structurally the same tensor + # Use equal_computations to check structural equality, not symbolic == + for t in tensors[1:]: + if not equal_computations([t], [first]): + return None + + # Only apply when size along join axis is statically 1 + # (e.g., x[None] has a guaranteed 1 at that axis) + shp = first.type.shape # tuple of ints/None + if shp is None or axis_val >= len(shp) or shp[axis_val] != 1: + return None # Replace with repeat operation - result = repeat(tensors[0], len(tensors), axis) + from pytensor.tensor.extra_ops import repeat + + n = len(tensors) + result = repeat(first, n, axis=axis_val) # Preserve debugging information copy_stack_trace(node.outputs[0], result) diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index 4879b816cf..110656b9c6 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -35,7 +35,6 @@ tile, ) from pytensor.tensor.elemwise import DimShuffle, Elemwise -from pytensor.tensor.extra_ops import Repeat from pytensor.tensor.math import ( add, bitwise_and, @@ -1249,83 +1248,103 @@ def test_local_join_1(): def test_local_join_to_repeat(): - """Test that Join(axis, x, x, ...) gets rewritten to repeat(x, n, axis)""" + """Test that Join(axis, x, x, ...) gets rewritten to repeat(x, n, axis) - # Test with vector - concatenate same vector 3 times along axis 0 + This optimization applies when joining the same tensor multiple times + along an axis where it has size 1 (e.g., after ExpandDims). 
+ """ + + # Test with vector expanded to (1, n) - concatenate along axis 0 x = vector("x") - s = join(0, x, x, x) + x_expanded = x[None] # Shape: (1, n) + s = join(0, x_expanded, x_expanded, x_expanded) # Shape: (3, n) f = function([x], s, mode=rewrite_mode) # Check numerical correctness test_val = np.array([1.0, 2.0, 3.0], dtype=config.floatX) result = f(test_val) expected = np.array( - [1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0], dtype=config.floatX + [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], dtype=config.floatX ) assert np.allclose(result, expected) - # Check that Join was replaced with Repeat + # Check that Join was replaced with Alloc (repeat with scalar repeats becomes Alloc) ops = f.maker.fgraph.toposort() assert len([n for n in ops if isinstance(n.op, Join)]) == 0 - assert len([n for n in ops if isinstance(n.op, Repeat)]) == 1 + assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1 - # Test with matrix - concatenate same matrix along axis 0 - a = matrix("a") - s = join(0, a, a, a, a) + # Test with matrix - add dimension and concatenate along new axis + a = matrix("a") # Shape: (m, n) + a_expanded = a[None, :, :] # Shape: (1, m, n) + s = join(0, a_expanded, a_expanded, a_expanded, a_expanded) # Shape: (4, m, n) f = function([a], s, mode=rewrite_mode) test_mat = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) result = f(test_mat) - expected = np.vstack([test_mat, test_mat, test_mat, test_mat]) + expected = np.array([test_mat, test_mat, test_mat, test_mat]) assert np.allclose(result, expected) # Check optimization applied ops = f.maker.fgraph.toposort() assert len([n for n in ops if isinstance(n.op, Join)]) == 0 - assert len([n for n in ops if isinstance(n.op, Repeat)]) == 1 + assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1 - # Test with matrix - concatenate along axis 1 - s = join(1, a, a) + # Test with matrix - expand along axis 1 and concatenate + a_expanded_ax1 = a[:, None, :] # Shape: (m, 1, n) + s = join(1, a_expanded_ax1, a_expanded_ax1) # Shape: (m, 2, n) f = function([a], s, mode=rewrite_mode) result = f(test_mat) - expected = np.hstack([test_mat, test_mat]) + expected = np.array([[[1.0, 2.0], [1.0, 2.0]], [[3.0, 4.0], [3.0, 4.0]]]) assert np.allclose(result, expected) # Check optimization applied ops = f.maker.fgraph.toposort() assert len([n for n in ops if isinstance(n.op, Join)]) == 0 - assert len([n for n in ops if isinstance(n.op, Repeat)]) == 1 + assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1 # Test that it does NOT apply when tensors are different - b = matrix("b") - s = join(0, a, b) - f = function([a, b], s, mode=rewrite_mode) - - test_mat1 = np.array([[1.0, 2.0]], dtype=config.floatX) - test_mat2 = np.array([[3.0, 4.0]], dtype=config.floatX) - result = f(test_mat1, test_mat2) - expected = np.vstack([test_mat1, test_mat2]) + y = vector("y") + s = join(0, x[None], y[None]) + f = function([x, y], s, mode=rewrite_mode) + + test_vec1 = np.array([1.0, 2.0], dtype=config.floatX) + test_vec2 = np.array([3.0, 4.0], dtype=config.floatX) + result = f(test_vec1, test_vec2) + expected = np.array([[1.0, 2.0], [3.0, 4.0]]) + assert np.allclose(result, expected) + + # Join should still be present (not optimized) + ops = f.maker.fgraph.toposort() + assert len([n for n in ops if isinstance(n.op, Join)]) == 1 + + # Test that it does NOT apply when tensor doesn't have size 1 along join axis + # (regular concatenation without ExpandDims) + s = join(0, x, x, x) # Shape: (3n,) not using ExpandDims + f = function([x], s, 
mode=rewrite_mode) + + test_val = np.array([1.0, 2.0], dtype=config.floatX) + result = f(test_val) + expected = np.array([1.0, 2.0, 1.0, 2.0, 1.0, 2.0], dtype=config.floatX) assert np.allclose(result, expected) - # Join should still be present (not optimized to Repeat) + # Join should still be present (optimization doesn't apply) ops = f.maker.fgraph.toposort() assert len([n for n in ops if isinstance(n.op, Join)]) == 1 - assert len([n for n in ops if isinstance(n.op, Repeat)]) == 0 # Test with 5 repetitions to ensure it works with larger counts - s = join(0, x, x, x, x, x) + s = join(0, x[None], x[None], x[None], x[None], x[None]) f = function([x], s, mode=rewrite_mode) test_val = np.array([1.0, 2.0], dtype=config.floatX) result = f(test_val) - expected = np.tile(test_val, 5) + expected = np.array([[1.0, 2.0]] * 5, dtype=config.floatX) assert np.allclose(result, expected) # Check optimization applied ops = f.maker.fgraph.toposort() assert len([n for n in ops if isinstance(n.op, Join)]) == 0 - assert len([n for n in ops if isinstance(n.op, Repeat)]) == 1 + assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1 def test_local_join_empty(): From 63dbcf8a80a355c89628f0564aed9f9445ed6be1 Mon Sep 17 00:00:00 2001 From: Tat Chan Date: Wed, 5 Nov 2025 12:41:44 -0800 Subject: [PATCH 5/5] removed equal computation --- pytensor/tensor/rewriting/basic.py | 46 +++++++++++------------------- 1 file changed, 17 insertions(+), 29 deletions(-) diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 0297c04b07..76d1ac8aba 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -30,7 +30,7 @@ from pytensor import compile, config from pytensor.compile.ops import ViewOp from pytensor.graph import FunctionGraph, Op -from pytensor.graph.basic import Constant, equal_computations +from pytensor.graph.basic import Constant from pytensor.graph.rewriting.basic import ( NodeProcessingGraphRewriter, NodeRewriter, @@ -923,44 +923,32 @@ def local_join_to_repeat(fgraph, node): concatenate([x[None], x[None], x[None]], axis=0) -> repeat(x[None], 3, axis=0) """ # Extract axis and the tensors being joined - axis_sym, *tensors = node.inputs + axis, *tensors = node.inputs - # Need at least 2 tensors to consider optimization - if len(tensors) <= 1: + # Optimization only applies when axis is constant + if not isinstance(axis, Constant): return None - # Extract (and normalize) axis as Python int - try: - axis_val = int(get_scalar_constant_value(axis_sym, only_process_constants=True)) - except NotScalarConstantError: - return None + # Extract the Python integer from the constant + axis_val = axis.data - # Get first tensor and check if ndim is known - first = tensors[0] - ndim = first.ndim - if ndim is None: - return None - - # Normalize negative axes (e.g., -1 -> ndim-1) - axis_val = axis_val % ndim + # Need at least 2 tensors to consider optimization + if len(tensors) <= 1: + return - # All inputs must be structurally the same tensor - # Use equal_computations to check structural equality, not symbolic == - for t in tensors[1:]: - if not equal_computations([t], [first]): - return None + # Check if all tensors are identical + if not all(t == tensors[0] for t in tensors[1:]): + return - # Only apply when size along join axis is statically 1 - # (e.g., x[None] has a guaranteed 1 at that axis) - shp = first.type.shape # tuple of ints/None - if shp is None or axis_val >= len(shp) or shp[axis_val] != 1: - return None + # Only optimize if the tensor has size 1 
along the join axis + first_tensor = tensors[0] + if first_tensor.type.shape[axis_val] != 1: + return # Replace with repeat operation from pytensor.tensor.extra_ops import repeat - n = len(tensors) - result = repeat(first, n, axis=axis_val) + result = repeat(first_tensor, len(tensors), axis_val) # Preserve debugging information copy_stack_trace(node.outputs[0], result)
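
For reference, a minimal usage sketch of what this rewrite is meant to do from
the user's side. It is not part of the patches above; it assumes a PyTensor
build with this series applied, and the pytensor/numpy imports and the
dprint-based graph inspection at the end are only illustrative:

    import numpy as np

    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    # x[None] has a static size of 1 along axis 0, which is the pattern
    # local_join_to_repeat targets.
    y = pt.join(0, x[None], x[None], x[None])  # shape (3, n)

    f = pytensor.function([x], y)
    print(f(np.arange(3, dtype=pytensor.config.floatX)))
    # [[0. 1. 2.]
    #  [0. 1. 2.]
    #  [0. 1. 2.]]

    # With the rewrite applied there should be no Join node left in the
    # compiled graph; a Repeat with a constant repeat count is typically
    # lowered further to an Alloc, which is what the updated tests assert.
    pytensor.dprint(f)

Note that after patches 4 and 5 the rewrite only fires when the joined tensor
has a static size of 1 along the join axis and the axis is a constant, so a
plain join(0, x, x, x) on a vector is deliberately left as a Join, as the
updated test also checks.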