Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,32 @@
MATRIX_INVERSE_OPS = (MatrixInverse, MatrixPinv)


@register_canonicalize
@node_rewriter([BlockDiagonal])
def fuse_blockdiagonal(fgraph, node):
    """Fuse nested BlockDiagonal ops into a single BlockDiagonal.

    Rewrites ``BlockDiagonal(BlockDiagonal(a, b), c)`` into
    ``BlockDiagonal(a, b, c)``, preserving the left-to-right block order.
    Arbitrarily deep nesting is flattened in a single application, so the
    rewrite does not depend on repeated canonicalize passes to converge.

    Returns
    -------
    list[TensorVariable] | None
        A one-element list with the fused output, or ``None`` when no input
        is itself a BlockDiagonal (no change needed).
    """
    # NOTE: node_rewriter([BlockDiagonal]) already guarantees node.op is a
    # BlockDiagonal, so no isinstance re-check is required here.
    new_inputs = []
    changed = False

    # Iterative depth-first flatten. A stack (pushed in reverse) yields the
    # blocks in their original left-to-right order, which is what determines
    # the layout of the resulting block-diagonal matrix.
    stack = list(reversed(node.inputs))
    while stack:
        inp = stack.pop()
        if inp.owner is not None and isinstance(inp.owner.op, BlockDiagonal):
            # Splice the nested op's blocks in place of this input.
            stack.extend(reversed(inp.owner.inputs))
            changed = True
        else:
            new_inputs.append(inp)

    if not changed:
        return None

    fused_op = BlockDiagonal(len(new_inputs))
    return [fused_op(*new_inputs)]


def is_matrix_transpose(x: TensorVariable) -> bool:
"""Check if a variable corresponds to a transpose of the last two axes"""
node = x.owner
Expand Down
65 changes: 65 additions & 0 deletions tests/tensor/rewriting/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,71 @@
from tests.test_rop import break_op


def test_nested_blockdiag_fusion():
    """One level of BlockDiagonal nesting collapses into a single fused op."""
    x = pt.tensor("x", shape=(3, 3))
    y = pt.tensor("y", shape=(3, 3))
    z = pt.tensor("z", shape=(3, 3))

    inner = BlockDiagonal(2)(x, y)
    outer = BlockDiagonal(2)(inner, z)

    # Sanity-check the unrewritten graph: it must contain exactly the two
    # nested BlockDiagonal applications we just built.
    initial_count = 0
    for var in ancestors([outer]):
        owner = getattr(var, "owner", None)
        if owner and isinstance(owner.op, BlockDiagonal):
            initial_count += 1
    assert initial_count == 2, "Setup failed: expected 2 nested BlockDiagonal ops"

    # Compiling triggers the canonicalization rewrites.
    f = pytensor.function([x, y, z], outer)
    fgraph = f.maker.fgraph

    fused_nodes = [
        apply for apply in fgraph.apply_nodes if isinstance(apply.op, BlockDiagonal)
    ]
    assert len(fused_nodes) == 1, "Nested BlockDiagonal ops were not fused"

    (fused_node,) = fused_nodes
    fused_op = fused_node.op
    assert fused_op.n_inputs == 3, f"Expected n_inputs=3, got {fused_op.n_inputs}"

    # Three 3x3 blocks stack into a 9x9 block-diagonal result.
    out_shape = fgraph.outputs[0].type.shape
    assert out_shape == (9, 9), f"Unexpected fused output shape: {out_shape}"


def test_deeply_nested_blockdiag_fusion():
    """Three levels of BlockDiagonal nesting fuse down to a single op."""
    x = pt.tensor("x", shape=(3, 3))
    y = pt.tensor("y", shape=(3, 3))
    z = pt.tensor("z", shape=(3, 3))
    w = pt.tensor("w", shape=(3, 3))

    # Build BlockDiagonal(BlockDiagonal(BlockDiagonal(x, y), z), w).
    nested = BlockDiagonal(2)(x, y)
    nested = BlockDiagonal(2)(nested, z)
    outer = BlockDiagonal(2)(nested, w)

    # Compiling runs the rewrites.
    f = pytensor.function([x, y, z, w], outer)
    fgraph = f.maker.fgraph

    fused_nodes = [
        apply for apply in fgraph.apply_nodes if isinstance(apply.op, BlockDiagonal)
    ]
    assert len(fused_nodes) == 1, (
        f"Expected 1 fused BlockDiagonal, got {len(fused_nodes)}"
    )

    (fused_node,) = fused_nodes
    fused_op = fused_node.op
    assert fused_op.n_inputs == 4, (
        f"Expected n_inputs=4 after fusion, got {fused_op.n_inputs}"
    )

    out_shape = fgraph.outputs[0].type.shape
    expected_shape = (12, 12)  # 4 blocks of (3x3)
    assert out_shape == expected_shape, (
        f"Unexpected fused output shape: expected {expected_shape}, got {out_shape}"
    )


def test_matrix_inverse_rop_lop():
rtol = 1e-7 if config.floatX == "float64" else 1e-5
mx = matrix("mx")
Expand Down