diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py
index 17a3ce9165..3960a396cf 100644
--- a/pytensor/tensor/rewriting/linalg.py
+++ b/pytensor/tensor/rewriting/linalg.py
@@ -65,6 +65,36 @@ MATRIX_INVERSE_OPS = (MatrixInverse, MatrixPinv)
 MATRIX_INVERSE_OPS = (MatrixInverse, MatrixPinv)
 
 
+@register_canonicalize
+@node_rewriter([BlockDiagonal])
+def fuse_blockdiagonal(fgraph, node):
+    """Fuse nested BlockDiagonal ops into a single BlockDiagonal.
+
+    Rewrites ``block_diag(block_diag(a, b), c)`` into ``block_diag(a, b, c)``.
+    Because the rewriter is applied repeatedly during canonicalization, one
+    level of flattening per application suffices for arbitrarily deep nesting.
+    """
+    # ``node_rewriter([BlockDiagonal])`` already restricts which nodes are
+    # tracked, but keep the guard explicit and cheap.
+    if not isinstance(node.op, BlockDiagonal):
+        return None
+
+    new_inputs = []
+    changed = False
+    for inp in node.inputs:
+        # Splice a nested BlockDiagonal's own blocks directly into this one.
+        if inp.owner is not None and isinstance(inp.owner.op, BlockDiagonal):
+            new_inputs.extend(inp.owner.inputs)
+            changed = True
+        else:
+            new_inputs.append(inp)
+
+    if not changed:
+        return None
+
+    return [BlockDiagonal(len(new_inputs))(*new_inputs)]
+
+
 def is_matrix_transpose(x: TensorVariable) -> bool:
     """Check if a variable corresponds to a transpose of the last two axes"""
     node = x.owner
diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py
index 515120e446..cd098bed25 100644
--- a/tests/tensor/rewriting/test_linalg.py
+++ b/tests/tensor/rewriting/test_linalg.py
@@ -44,6 +44,75 @@ from tests.test_rop import break_op
 from tests.test_rop import break_op
 
 
+def test_nested_blockdiag_fusion():
+    x = pt.tensor("x", shape=(3, 3))
+    y = pt.tensor("y", shape=(3, 3))
+    z = pt.tensor("z", shape=(3, 3))
+
+    inner = BlockDiagonal(2)(x, y)
+    outer = BlockDiagonal(2)(inner, z)
+
+    # Sanity-check the unrewritten graph: two nested BlockDiagonal applies.
+    assert isinstance(outer.owner.op, BlockDiagonal)
+    assert isinstance(outer.owner.inputs[0].owner.op, BlockDiagonal)
+
+    f = pytensor.function([x, y, z], outer)
+    fgraph = f.maker.fgraph
+
+    fused_nodes = [
+        node for node in fgraph.apply_nodes if isinstance(node.op, BlockDiagonal)
+    ]
+    assert len(fused_nodes) == 1, "Nested BlockDiagonal ops were not fused"
+
+    fused_op = fused_nodes[0].op
+    assert fused_op.n_inputs == 3, f"Expected n_inputs=3, got {fused_op.n_inputs}"
+
+    out_shape = fgraph.outputs[0].type.shape
+    assert out_shape == (9, 9), f"Unexpected fused output shape: {out_shape}"
+
+    # Fusion must not change the computed value.
+    rng = np.random.default_rng(seed=42)
+    vals = [rng.normal(size=(3, 3)).astype(config.floatX) for _ in range(3)]
+    expected = np.zeros((9, 9), dtype=config.floatX)
+    for i, v in enumerate(vals):
+        expected[3 * i : 3 * (i + 1), 3 * i : 3 * (i + 1)] = v
+    np.testing.assert_allclose(f(*vals), expected)
+
+
+def test_deeply_nested_blockdiag_fusion():
+    x = pt.tensor("x", shape=(3, 3))
+    y = pt.tensor("y", shape=(3, 3))
+    z = pt.tensor("z", shape=(3, 3))
+    w = pt.tensor("w", shape=(3, 3))
+
+    inner1 = BlockDiagonal(2)(x, y)
+    inner2 = BlockDiagonal(2)(inner1, z)
+    outer = BlockDiagonal(2)(inner2, w)
+
+    f = pytensor.function([x, y, z, w], outer)
+    fgraph = f.maker.fgraph
+
+    fused_nodes = [
+        node for node in fgraph.apply_nodes if isinstance(node.op, BlockDiagonal)
+    ]
+
+    assert len(fused_nodes) == 1, (
+        f"Expected 1 fused BlockDiagonal, got {len(fused_nodes)}"
+    )
+
+    fused_op = fused_nodes[0].op
+
+    assert fused_op.n_inputs == 4, (
+        f"Expected n_inputs=4 after fusion, got {fused_op.n_inputs}"
+    )
+
+    out_shape = fgraph.outputs[0].type.shape
+    expected_shape = (12, 12)  # 4 blocks of (3x3)
+    assert out_shape == expected_shape, (
+        f"Unexpected fused output shape: expected {expected_shape}, got {out_shape}"
+    )
+
+
 def test_matrix_inverse_rop_lop():
     rtol = 1e-7 if config.floatX == "float64" else 1e-5
     mx = matrix("mx")