diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py
index 80f9a5f7da..90897559ed 100644
--- a/pytensor/tensor/elemwise.py
+++ b/pytensor/tensor/elemwise.py
@@ -1375,7 +1375,7 @@ def clone(
             axis=axis,
             dtype=dtype,
             acc_dtype=acc_dtype,
-            upcast_discrete_output=None,
+            upcast_discrete_output=upcast_discrete_output,
             **kwargs,
         )
diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index 307700ac70..ab134cc380 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -31,7 +31,6 @@
     get_underlying_scalar_constant_value,
     moveaxis,
     ones_like,
-    register_infer_shape,
     split,
     switch,
     zeros,
@@ -1848,13 +1847,16 @@ def local_reduce_chain(fgraph, node) -> list[TensorVariable] | None:


 @register_canonicalize
-@register_uncanonicalize  # Needed for MaxAndArgmax -> CAReduce
+@register_specialize
 @node_rewriter([CAReduce])
 def local_reduce_join(fgraph, node):
     """
-    CAReduce{scalar.op}(Join(axis=x, a, b), axis=x) -> Elemwise{scalar.op}(a, b)
+    reduce = CAReduce({scalar.op}, axis=None | (x, ...))
+    reduce(Join(axis=x, a, b)) -> Elemwise{scalar.op}(reduce(a), reduce(b))

-    When a, b have a dim length of 1 along the join axis
+    For now limited to cases where:
+     - Join has as many inputs as the Elemwise can handle
+     - The reduction covers the join axis

     """
     if not (node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Join)):
@@ -1870,42 +1872,30 @@ def local_reduce_join(fgraph, node):
         return None
     if n_joined_inputs > 2 and not isinstance(node.op.scalar_op, ps.Add | ps.Mul):
         # We don't rewrite if a single Elemwise cannot take all inputs at once
+        # TODO: Make all CAReduce scalar ops support variadic inputs
         return None

-    if not isinstance(join_axis_tensor, Constant):
-        return None
-    join_axis = join_axis_tensor.data
-
-    # Check whether reduction happens on joined axis
+    # Check whether reduction covers the join axis
+    # TODO: Even if it doesn't, we can apply the rewrite; it just requires another join after
     reduce_op = node.op
     reduce_axis = reduce_op.axis
-    if reduce_axis is None:
-        if joined_out.type.ndim > 1:
-            return None
-    elif reduce_axis != (join_axis,):
+    if reduce_axis is not None and not (
+        isinstance(join_axis_tensor, Constant) and join_axis_tensor.data in reduce_axis
+    ):
         return None

-    # Check all inputs are broadcastable along the join axis and squeeze those dims away
-    new_inputs = []
-    for inp in joined_inputs:
-        if not inp.type.broadcastable[join_axis]:
-            return None
-        # Most times inputs to join have an expand_dims, we eagerly clean up those here
-        new_input = apply_local_dimshuffle_lift(fgraph, inp.squeeze(join_axis))
-        new_inputs.append(new_input)
-
-    ret = Elemwise(node.op.scalar_op)(*new_inputs)
+    reduced_inputs = [reduce_op(inp) for inp in joined_inputs]
+    ret = Elemwise(node.op.scalar_op)(*reduced_inputs)

     if ret.dtype != node.outputs[0].dtype:
-        # The reduction do something about the dtype.
+        # The reduction does something about the dtype.
         return None

     return [ret]


-@register_infer_shape
-@register_canonicalize("fast_compile", "local_cut_useless_reduce")
-@register_useless("local_cut_useless_reduce")
+@register_canonicalize("fast_compile")
+@register_useless
 @node_rewriter([CAReduce])
 def local_useless_reduce(fgraph, node):
     """Sum(a, axis=[]) -> a"""
@@ -1922,40 +1912,35 @@
 def local_reduce_broadcastable(fgraph, node):
     """Remove reduction over broadcastable dimensions."""
     (reduced,) = node.inputs
+    reduced_broadcastable = reduced.type.broadcastable
     odtype = node.outputs[0].dtype
-    if node.op.axis is None:
-        if all(reduced.broadcastable):
-            return [reduced.dimshuffle().astype(odtype)]
+    reduce_axis = node.op.axis
+    if reduce_axis is None:
+        reduce_axis = tuple(range(reduced.type.ndim))
+
+    cuttable = [a for a in reduce_axis if reduced_broadcastable[a]]
+
+    if not cuttable:
+        return None
+
+    new_reduce_axis = []
+    counter = 0
+    for i in range(reduced.type.ndim):
+        if i not in cuttable:
+            if i in reduce_axis:
+                new_reduce_axis.append(counter)
+            counter += 1
+
+    new_reduced = reduced.squeeze(cuttable)
+    # Not rare that we can get rid of useless squeeze(expand_dims). Call eagerly here
+    new_reduced = apply_local_dimshuffle_lift(fgraph, new_reduced)
+
+    if new_reduce_axis:
+        new_op = node.op.clone(axis=new_reduce_axis)
+        return [new_op(new_reduced)]
     else:
-        axis = list(node.op.axis)
-        cuttable = [a for a in axis if reduced.broadcastable[a]]
-        if cuttable:
-            # -- we can remove some axes of summation.
-            new_axis = []
-            pattern = []
-            ii = 0
-            for p in range(reduced.ndim):
-                if p not in cuttable:
-                    if p in axis:
-                        new_axis.append(ii)
-                    pattern.append(p)
-                    ii += 1
-            new_reduced = reduced.dimshuffle(*pattern)
-            if new_axis:
-                if type(node.op) is CAReduce:
-                    # This case handles `CAReduce` instances
-                    # (e.g. generated by `scalar_elemwise`), and not the
-                    # scalar `Op`-specific subclasses
-                    # TODO FIXME: This highlights a major design flaw in
-                    # `CAReduce` (or at least our use of it), and it needs
-                    # to be fixed
-                    new_op = node.op.__class__(node.op.scalar_op, axis=new_axis)
-                else:
-                    new_op = node.op.__class__(axis=new_axis)
-                return [new_op(new_reduced)]
-            else:
-                # -- in this case we can remove the reduction completely
-                return [new_reduced.astype(odtype)]
+        # -- in this case we can remove the reduction completely
+        return [new_reduced.astype(odtype)]


 @register_specialize
diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py
index 2d6e36c3ad..9d443909ff 100644
--- a/tests/tensor/rewriting/test_math.py
+++ b/tests/tensor/rewriting/test_math.py
@@ -28,11 +28,11 @@
 )
 from pytensor.graph.rewriting.db import RewriteDatabaseQuery
 from pytensor.graph.rewriting.utils import is_same_graph, rewrite_graph
-from pytensor.graph.traversal import ancestors
+from pytensor.graph.traversal import ancestors, apply_ancestors
 from pytensor.printing import debugprint
 from pytensor.scalar import PolyGamma, Psi, TriGamma
 from pytensor.tensor import inplace
-from pytensor.tensor.basic import Alloc, constant, join, second, switch
+from pytensor.tensor.basic import Alloc, Join, constant, join, second, switch
 from pytensor.tensor.blas import Dot22, Gemv
 from pytensor.tensor.blas_c import CGemv
 from pytensor.tensor.blockwise import Blockwise
@@ -3413,14 +3413,20 @@ def test_local_reduce_broadcast_some_1(self):

 class TestReduceJoin:
     def setup_method(self):
-        self.mode = get_default_mode().including(
-            "canonicalize", "specialize", "uncanonicalize"
-        )
+        self.mode = get_default_mode().including("canonicalize", "specialize")

     @pytest.mark.parametrize(
-        "op, nin", [(pt_sum, 3), (pt_max, 2), (pt_min, 2), (prod, 3)]
+        "reduce_op, nin", [(pt_sum, 3), (pt_max, 2), (pt_min, 2), (prod, 3)]
     )
-    def test_local_reduce_join(self, op, nin):
+    @pytest.mark.parametrize(
+        "reduce_axis",
+        (
+            0,
+            pytest.param(1, marks=pytest.mark.xfail(reason="Not implemented yet")),
+            None,
+        ),
+    )
+    def test_local_reduce_join(self, reduce_op, nin, reduce_axis):
         vx = matrix()
         vy = matrix()
         vz = matrix()
@@ -3431,18 +3437,25 @@ def test_local_reduce_join(self, op, nin):
         inputs = (vx, vy, vz)[:nin]
         test_values = (x, y, z)[:nin]

-        out = op(inputs, axis=0)
+        out = reduce_op(inputs, axis=reduce_axis)
+        assert sum(isinstance(node.op, Join) for node in apply_ancestors([out])) == 1
         f = function(inputs, out, mode=self.mode)
+
         np.testing.assert_allclose(
-            f(*test_values), getattr(np, op.__name__)(test_values, axis=0)
+            f(*test_values),
+            getattr(np, reduce_op.__name__)(test_values, axis=reduce_axis),
+        )
+        assert (
+            sum(
+                isinstance(node.op, Join)
+                for node in apply_ancestors(f.maker.fgraph.outputs)
+            )
+            == 0
         )
-        topo = f.maker.fgraph.toposort()
-        assert len(topo) <= 2
-        assert isinstance(topo[-1].op, Elemwise)

     def test_type(self):
         # Test different axis for the join and the reduction
-        # We must force the dtype, of otherwise, this tests will fail
+        # We must force the dtype, otherwise this test will fail
         # on 32 bit systems
         A = shared(np.array([1, 2, 3, 4, 5], dtype="int64"))

@@ -3451,29 +3464,28 @@
         topo = f.maker.fgraph.toposort()
         assert isinstance(topo[-1].op, Elemwise)

-        # Test a case that was bugged in a old PyTensor bug
+        # Test a case that was bugged in an old PyTensor version
         f = function([], pt_sum(pt.stack([A, A]), axis=1), mode=self.mode)
         np.testing.assert_allclose(f(), [15, 15])
         topo = f.maker.fgraph.toposort()
-        assert not isinstance(topo[-1].op, Elemwise)
+        assert sum(isinstance(node.op, Join) for node in topo) == 1

-        # This case could be rewritten
+        # This case can be rewritten
         A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1))
         f = function([], pt_sum(pt.concatenate((A, A), axis=1), axis=1), mode=self.mode)
         np.testing.assert_allclose(f(), [2, 4, 6, 8, 10])
         topo = f.maker.fgraph.toposort()
-        assert not isinstance(topo[-1].op, Elemwise)
+        assert sum(isinstance(node.op, Join) for node in topo) == 0

         A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1))
         f = function([], pt_sum(pt.concatenate((A, A), axis=1), axis=0), mode=self.mode)
         np.testing.assert_allclose(f(), [15, 15])
         topo = f.maker.fgraph.toposort()
-        assert not isinstance(topo[-1].op, Elemwise)
+        assert sum(isinstance(node.op, Join) for node in topo) == 1

-    def test_not_supported_axis_none(self):
-        # Test that the rewrite does not crash in one case where it
-        # is not applied. Reported at
+    def test_bug_regression(self):
+        # Test that the rewrite does not crash in one case where it used to
         # https://groups.google.com/d/topic/theano-users/EDgyCU00fFA/discussion
         vx = matrix()
         vy = matrix()

@@ -3484,9 +3496,10 @@
         out = pt_sum([vx, vy, vz], axis=None)

         f = function([vx, vy, vz], out, mode=self.mode)
+        assert sum(isinstance(node.op, Join) for node in f.maker.fgraph.toposort()) == 0
         np.testing.assert_allclose(f(x, y, z), np.sum([x, y, z]))

-    def test_not_supported_unequal_shapes(self):
+    def test_unequal_shapes(self):
         # Not the same shape along the join axis
         vx = matrix(shape=(1, 3))
         vy = matrix(shape=(2, 3))
@@ -3495,6 +3508,7 @@

         out = pt_sum(join(0, vx, vy), axis=0)
         f = function([vx, vy], out, mode=self.mode)
+        assert sum(isinstance(node.op, Join) for node in f.maker.fgraph.toposort()) == 0
         np.testing.assert_allclose(
             f(x, y), np.sum(np.concatenate([x, y], axis=0), axis=0)
         )
@@ -3514,7 +3528,7 @@ def test_non_ds_inputs(self):
         fg = FunctionGraph([x], [out], clone=False)
         [rewritten_out] = local_reduce_join.transform(fg, out.owner)

-        expected_out = add(exp(x), log(x))
+        expected_out = add(exp(x[None]).sum(axis=0), log(x[None]).sum(axis=0))
         assert equal_computations([rewritten_out], [expected_out])
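For reviewers who want to see the generalized `local_reduce_join` rewrite end to end, here is a minimal sketch (not part of the patch; variable names and shapes are illustrative). It mirrors the `test_unequal_shapes` case above: a reduction whose axes cover the `Join` axis should compile to an `Elemwise` of the per-input reductions, leaving no `Join` node in the compiled graph, assuming the rewrite is included via the canonicalize/specialize databases as registered in this diff.

```python
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_default_mode
from pytensor.tensor.basic import Join

# Sum over a join whose inputs have different lengths along the join axis.
# Before this patch the rewrite only fired when every joined input was
# broadcastable along the join axis; now it should fire whenever the
# reduction covers the join axis.
x = pt.matrix("x")
y = pt.matrix("y")
out = pt.sum(pt.join(0, x, y), axis=0)

mode = get_default_mode().including("canonicalize", "specialize")
f = pytensor.function([x, y], out, mode=mode)

# Expected compiled graph: add(x.sum(axis=0), y.sum(axis=0)) -- no Join left.
assert not any(isinstance(node.op, Join) for node in f.maker.fgraph.toposort())

xv = np.random.random((2, 3)).astype(pytensor.config.floatX)
yv = np.random.random((4, 3)).astype(pytensor.config.floatX)
np.testing.assert_allclose(
    f(xv, yv), np.concatenate([xv, yv], axis=0).sum(axis=0), rtol=1e-5
)
```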