2 changes: 1 addition & 1 deletion pytensor/tensor/elemwise.py
@@ -1375,7 +1375,7 @@ def clone(
axis=axis,
dtype=dtype,
acc_dtype=acc_dtype,
upcast_discrete_output=None,
upcast_discrete_output=upcast_discrete_output,
**kwargs,
)

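The hunk above fixes `CAReduce.clone` so it forwards the caller's `upcast_discrete_output` instead of hardcoding `None`. A minimal hedged sketch of the restored behavior, assuming `CAReduce` accepts and stores this flag as the surrounding code suggests:

```python
# Sketch only: assumes CAReduce.__init__ takes `upcast_discrete_output`
# and that clone() forwards keyword arguments as in the hunk above.
import pytensor.scalar as ps
from pytensor.tensor.elemwise import CAReduce

op = CAReduce(ps.add, axis=(0,), upcast_discrete_output=False)
cloned = op.clone(axis=(1,))

# Before the fix, clone() passed upcast_discrete_output=None, so the cloned
# Op silently reverted to the default upcasting behavior; now it is preserved.
assert cloned.upcast_discrete_output == op.upcast_discrete_output
```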
103 changes: 44 additions & 59 deletions pytensor/tensor/rewriting/math.py
@@ -31,7 +31,6 @@
get_underlying_scalar_constant_value,
moveaxis,
ones_like,
register_infer_shape,
split,
switch,
zeros,
@@ -1848,13 +1847,16 @@ def local_reduce_chain(fgraph, node) -> list[TensorVariable] | None:


@register_canonicalize
@register_uncanonicalize # Needed for MaxAndArgmax -> CAReduce
@register_specialize
@node_rewriter([CAReduce])
def local_reduce_join(fgraph, node):
"""
CAReduce{scalar.op}(Join(axis=x, a, b), axis=x) -> Elemwise{scalar.op}(a, b)
reduce = CAReduce({scalar.op}, axis=None | (x, ...))
reduce(Join(axis=x, a, b)) -> Elemwise{scalar.op}(reduce(a), reduce(b))

When a, b have a dim length of 1 along the join axis
    For now limited to cases where:
- Join has as many inputs as the Elemwise can handle
- The reduction covers the join axis

"""
if not (node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Join)):
@@ -1870,42 +1872,30 @@ def local_reduce_join(fgraph, node):
return None
if n_joined_inputs > 2 and not isinstance(node.op.scalar_op, ps.Add | ps.Mul):
# We don't rewrite if a single Elemwise cannot take all inputs at once
# TODO: Make all CAReduce scalar ops support variadic inputs
return None

if not isinstance(join_axis_tensor, Constant):
return None
join_axis = join_axis_tensor.data

# Check whether reduction happens on joined axis
# Check whether reduction covers join axis
    # TODO: Even if it doesn't, we can apply the rewrite; it just requires another join after
reduce_op = node.op
reduce_axis = reduce_op.axis
if reduce_axis is None:
if joined_out.type.ndim > 1:
return None
elif reduce_axis != (join_axis,):
if reduce_axis is not None and not (
isinstance(join_axis_tensor, Constant) and join_axis_tensor.data in reduce_axis
):
return None

# Check all inputs are broadcastable along the join axis and squeeze those dims away
new_inputs = []
for inp in joined_inputs:
if not inp.type.broadcastable[join_axis]:
return None
# Most times inputs to join have an expand_dims, we eagerly clean up those here
new_input = apply_local_dimshuffle_lift(fgraph, inp.squeeze(join_axis))
new_inputs.append(new_input)

ret = Elemwise(node.op.scalar_op)(*new_inputs)
reduced_inputs = [reduce_op(inp) for inp in joined_inputs]
ret = Elemwise(node.op.scalar_op)(*reduced_inputs)

if ret.dtype != node.outputs[0].dtype:
# The reduction do something about the dtype.
# The reduction does something about the dtype.
return None

return [ret]
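To make the generalized pattern concrete, here is a hedged end-to-end sketch using only public API (the rewrite is registered under canonicalize/specialize, so a default compile should trigger it): summing a `Join` along the join axis compiles to an elementwise add of the per-input reductions, with no `Join` left in the graph.

```python
# Sketch: sum(join(axis, a, b), axis=join_axis) -> add(sum(a), sum(b))
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.basic import Join

a = pt.vector("a")
b = pt.vector("b")
out = pt.join(0, a, b).sum()  # the reduction covers the join axis

f = pytensor.function([a, b], out)
np.testing.assert_allclose(f([1.0, 2.0], [3.0, 4.0]), 10.0)

# The compiled graph should no longer contain a Join node.
assert not any(isinstance(node.op, Join) for node in f.maker.fgraph.toposort())
```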


@register_infer_shape
@register_canonicalize("fast_compile", "local_cut_useless_reduce")
@register_useless("local_cut_useless_reduce")
@register_canonicalize("fast_compile")
@register_useless
@node_rewriter([CAReduce])
def local_useless_reduce(fgraph, node):
"""Sum(a, axis=[]) -> a"""
@@ -1922,40 +1912,35 @@ def local_useless_reduce(fgraph, node):
def local_reduce_broadcastable(fgraph, node):
"""Remove reduction over broadcastable dimensions."""
(reduced,) = node.inputs
reduced_broadcastable = reduced.type.broadcastable
odtype = node.outputs[0].dtype
if node.op.axis is None:
if all(reduced.broadcastable):
return [reduced.dimshuffle().astype(odtype)]
reduce_axis = node.op.axis
if reduce_axis is None:
reduce_axis = tuple(range(reduced.type.ndim))

cuttable = [a for a in reduce_axis if reduced_broadcastable[a]]

if not cuttable:
return None

new_reduce_axis = []
counter = 0
for i in range(reduced.type.ndim):
if i not in cuttable:
if i in reduce_axis:
new_reduce_axis.append(counter)
counter += 1

new_reduced = reduced.squeeze(cuttable)
# Not rare that we can get rid of useless squeeze(expand_dims). Call eagerly here
new_reduced = apply_local_dimshuffle_lift(fgraph, new_reduced)

if new_reduce_axis:
new_op = node.op.clone(axis=new_reduce_axis)
return [new_op(new_reduced)]
else:
axis = list(node.op.axis)
cuttable = [a for a in axis if reduced.broadcastable[a]]
if cuttable:
# -- we can remove some axes of summation.
new_axis = []
pattern = []
ii = 0
for p in range(reduced.ndim):
if p not in cuttable:
if p in axis:
new_axis.append(ii)
pattern.append(p)
ii += 1
new_reduced = reduced.dimshuffle(*pattern)
if new_axis:
if type(node.op) is CAReduce:
# This case handles `CAReduce` instances
# (e.g. generated by `scalar_elemwise`), and not the
# scalar `Op`-specific subclasses
# TODO FIXME: This highlights a major design flaw in
# `CAReduce` (or at least our use of it), and it needs
# to be fixed
new_op = node.op.__class__(node.op.scalar_op, axis=new_axis)
else:
new_op = node.op.__class__(axis=new_axis)
return [new_op(new_reduced)]
else:
# -- in this case we can remove the reduction completely
return [new_reduced.astype(odtype)]
# -- in this case we can remove the reduction completely
return [new_reduced.astype(odtype)]


@register_specialize
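The reworked `local_reduce_broadcastable` above can be exercised through the generic rewrite helper that the test suite also imports. A hedged sketch (shapes, axes, and the `pt.tensor` call are illustrative assumptions, not taken from this PR):

```python
# Sketch: a reduction over a statically length-1 axis should be replaced by a
# squeeze plus a smaller reduction (or no reduction at all).
import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph

x = pt.tensor("x", shape=(1, None, 3))  # axis 0 is broadcastable
out = x.sum(axis=(0, 1))                # axis 0 is useless to reduce over

rewritten = rewrite_graph(out)
# Expected: the graph squeezes axis 0 and sums only over the remaining
# (formerly axis-1) dimension, keeping the original output dtype.
```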
60 changes: 37 additions & 23 deletions tests/tensor/rewriting/test_math.py
@@ -28,11 +28,11 @@
)
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.rewriting.utils import is_same_graph, rewrite_graph
from pytensor.graph.traversal import ancestors
from pytensor.graph.traversal import ancestors, apply_ancestors
from pytensor.printing import debugprint
from pytensor.scalar import PolyGamma, Psi, TriGamma
from pytensor.tensor import inplace
from pytensor.tensor.basic import Alloc, constant, join, second, switch
from pytensor.tensor.basic import Alloc, Join, constant, join, second, switch
from pytensor.tensor.blas import Dot22, Gemv
from pytensor.tensor.blas_c import CGemv
from pytensor.tensor.blockwise import Blockwise
@@ -3413,14 +3413,20 @@ def test_local_reduce_broadcast_some_1(self):

class TestReduceJoin:
def setup_method(self):
self.mode = get_default_mode().including(
"canonicalize", "specialize", "uncanonicalize"
)
self.mode = get_default_mode().including("canonicalize", "specialize")

@pytest.mark.parametrize(
"op, nin", [(pt_sum, 3), (pt_max, 2), (pt_min, 2), (prod, 3)]
"reduce_op, nin", [(pt_sum, 3), (pt_max, 2), (pt_min, 2), (prod, 3)]
)
def test_local_reduce_join(self, op, nin):
@pytest.mark.parametrize(
"reduce_axis",
(
0,
pytest.param(1, marks=pytest.mark.xfail(reason="Not implemented yet")),
None,
),
)
def test_local_reduce_join(self, reduce_op, nin, reduce_axis):
vx = matrix()
vy = matrix()
vz = matrix()
@@ -3431,18 +3437,25 @@ def test_local_reduce_join(self, op, nin):
inputs = (vx, vy, vz)[:nin]
test_values = (x, y, z)[:nin]

out = op(inputs, axis=0)
out = reduce_op(inputs, axis=reduce_axis)
assert sum(isinstance(node.op, Join) for node in apply_ancestors([out])) == 1
f = function(inputs, out, mode=self.mode)

np.testing.assert_allclose(
f(*test_values), getattr(np, op.__name__)(test_values, axis=0)
f(*test_values),
getattr(np, reduce_op.__name__)(test_values, axis=reduce_axis),
)
assert (
sum(
isinstance(node.op, Join)
for node in apply_ancestors(f.maker.fgraph.outputs)
)
== 0
)
topo = f.maker.fgraph.toposort()
assert len(topo) <= 2
assert isinstance(topo[-1].op, Elemwise)

def test_type(self):
# Test different axis for the join and the reduction
# We must force the dtype, of otherwise, this tests will fail
        # We must force the dtype, otherwise this test will fail
# on 32 bit systems
A = shared(np.array([1, 2, 3, 4, 5], dtype="int64"))

@@ -3451,29 +3464,28 @@ def test_type(self):
topo = f.maker.fgraph.toposort()
assert isinstance(topo[-1].op, Elemwise)

# Test a case that was bugged in a old PyTensor bug
# Test a case that was bugged in an old PyTensor version
f = function([], pt_sum(pt.stack([A, A]), axis=1), mode=self.mode)

np.testing.assert_allclose(f(), [15, 15])
topo = f.maker.fgraph.toposort()
assert not isinstance(topo[-1].op, Elemwise)
assert sum(isinstance(node.op, Join) for node in topo) == 1

# This case could be rewritten
# This case can be rewritten
A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1))
f = function([], pt_sum(pt.concatenate((A, A), axis=1), axis=1), mode=self.mode)
np.testing.assert_allclose(f(), [2, 4, 6, 8, 10])
topo = f.maker.fgraph.toposort()
assert not isinstance(topo[-1].op, Elemwise)
assert sum(isinstance(node.op, Join) for node in topo) == 0

A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1))
f = function([], pt_sum(pt.concatenate((A, A), axis=1), axis=0), mode=self.mode)
np.testing.assert_allclose(f(), [15, 15])
topo = f.maker.fgraph.toposort()
assert not isinstance(topo[-1].op, Elemwise)
assert sum(isinstance(node.op, Join) for node in topo) == 1

def test_not_supported_axis_none(self):
# Test that the rewrite does not crash in one case where it
# is not applied. Reported at
def test_bug_regression(self):
# Test that the rewrite does not crash in one case where it used to
# https://groups.google.com/d/topic/theano-users/EDgyCU00fFA/discussion
vx = matrix()
vy = matrix()
@@ -3484,9 +3496,10 @@ def test_not_supported_axis_none(self):

out = pt_sum([vx, vy, vz], axis=None)
f = function([vx, vy, vz], out, mode=self.mode)
assert sum(isinstance(node.op, Join) for node in f.maker.fgraph.toposort()) == 0
np.testing.assert_allclose(f(x, y, z), np.sum([x, y, z]))

def test_not_supported_unequal_shapes(self):
def test_unequal_shapes(self):
# Not the same shape along the join axis
vx = matrix(shape=(1, 3))
vy = matrix(shape=(2, 3))
@@ -3495,6 +3508,7 @@ def test_not_supported_unequal_shapes(self):
out = pt_sum(join(0, vx, vy), axis=0)

f = function([vx, vy], out, mode=self.mode)
assert sum(isinstance(node.op, Join) for node in f.maker.fgraph.toposort()) == 0
np.testing.assert_allclose(
f(x, y), np.sum(np.concatenate([x, y], axis=0), axis=0)
)
@@ -3514,7 +3528,7 @@ def test_non_ds_inputs(self):

fg = FunctionGraph([x], [out], clone=False)
[rewritten_out] = local_reduce_join.transform(fg, out.owner)
expected_out = add(exp(x), log(x))
expected_out = add(exp(x[None]).sum(axis=0), log(x[None]).sum(axis=0))
assert equal_computations([rewritten_out], [expected_out])

