diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index fbe97b9a68..fc7c908755 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -15,8 +15,9 @@ node_rewriter, ) from pytensor.raise_op import Assert -from pytensor.scalar import Add, ScalarConstant, ScalarType +from pytensor.scalar import Add, ScalarConstant from pytensor.scalar import constant as scalar_constant +from pytensor.scalar import switch as scalar_switch from pytensor.tensor.basic import ( Alloc, ExtractDiag, @@ -28,6 +29,7 @@ cast, concatenate, expand_dims, + fill, full, get_scalar_constant_value, get_underlying_scalar_constant_value, @@ -162,15 +164,18 @@ def transform_take(a, indices, axis): return transform_take(a, indices.flatten(), axis).reshape(shape, ndim=ndim) +none_slice = slice(None) + + def is_full_slice(x): """Determine if `x` is a ``slice(None)`` or a symbolic equivalent.""" if isinstance(x, slice): - return x == slice(None) + return x == none_slice if isinstance(x, Variable) and isinstance(x.type, SliceType): if x.owner is None: if isinstance(x, Constant): - return x.data == slice(None) + return x.data == none_slice else: # Root slice variable return False @@ -281,7 +286,7 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node): @register_canonicalize @register_specialize @register_stabilize -@node_rewriter([Subtensor]) +@node_rewriter([Subtensor, IncSubtensor]) def local_useless_slice(fgraph, node): """ Remove useless slice(None) of the form: @@ -297,69 +302,113 @@ def local_useless_slice(fgraph, node): where x is a vector of length 7 """ - idxs = get_idx_list(node.inputs, node.op.idx_list) - x = node.inputs[0] + is_subtensor = isinstance(node.op, Subtensor) + if is_subtensor: + x, *indices = node.inputs + else: # IncSubtensor + x, y, *indices = node.inputs - if not idxs: - return [node.inputs[0]] + idx_list = node.op.idx_list - new_idxs = list(idxs) - change_flag = False - last_useful_idx = -1 - for dim, s in enumerate(new_idxs): - if not isinstance(s, slice): - last_useful_idx = dim - continue + if idx_list: + new_idxs = list(indices_from_subtensor(indices, idx_list)) + change_flag = False + last_useful_idx = -1 + for dim, s in enumerate(new_idxs): + if not isinstance(s, slice): + last_useful_idx = dim + continue - if s == slice(None): - continue + if s == none_slice: + continue + + step = s.step + + if step is None: + positive_step = True + elif isinstance(step, Constant): + step_value = step.data + positive_step = step.data > 0 + if step_value == 1: + change_flag = True + step = None + else: + # We can only canonicalize start and stop if we know the sign of step + last_useful_idx = dim + continue - step = s.step + start = s.start + stop = s.stop - if step is None: - positive_step = True - elif isinstance(step, Constant): - step_value = step.data - positive_step = step.data > 0 - if step_value == 1: + # TODO: start and stop can also be useless (e.g., arange(5)[-7:] or arange(5)[:7]) + if ( + start is not None + and isinstance(start, Constant) + and start.data == (0 if positive_step else -1) + ): change_flag = True - step = None - else: - # We can only canonicalize start and stop if we know the sign of step - last_useful_idx = dim - continue + start = None + + if ( + stop is not None + and x.type.shape[dim] is not None + and isinstance(stop, Constant) + and stop.data + == (x.type.shape[dim] if positive_step else -x.type.shape[dim] - 1) + ): + change_flag = True + stop = None + + if start is not None or 
stop is not None or step is not None:
+                last_useful_idx = dim
 
-            start = s.start
-            stop = s.stop
+            new_idxs[dim] = slice(start, stop, step)
 
-            if start is not None and get_scalar_constant_value(
-                start, only_process_constants=True, raise_not_constant=False
-            ) == (0 if positive_step else -1):
+        if (last_useful_idx + 1) < len(new_idxs):
+            new_idxs = new_idxs[: last_useful_idx + 1]
             change_flag = True
-            start = None
 
-        if (
-            stop is not None
-            and x.type.shape[dim] is not None
-            and get_scalar_constant_value(
-                stop, only_process_constants=True, raise_not_constant=False
-            )
-            == (x.type.shape[dim] if positive_step else -x.type.shape[dim] - 1)
+    else:  # No indices
+        new_idxs = ()
+        change_flag = True
+
+    if is_subtensor:
+        if not change_flag:
+            return None
+
+        new_out = x[tuple(new_idxs)] if new_idxs else x
+
+    else:
+        # Check if we're left only with reversed slices
+        # TODO: Or {0, -1} indices on length 1 dimensions
+        # Otherwise keep it, as we wouldn't reduce the number of ops
+        if all(
+            isinstance(idx, slice)
+            and idx.step is not None
+            and idx.start is None
+            and idx.stop is None
+            and (isinstance(idx.step, Constant) and idx.step.data == -1)
+            for idx in new_idxs
         ):
+            # In that case reverse y instead
+            # arange(5)[::-1] = y == y[::-1]
+            # arange(5)[::-1] += y == arange(5) += y[::-1]
+            # TODO: or expand_dims
+            # zeros((5, 1, 5))[:, 0, :] += y[:, :] == zeros((5, 1, 5)) += y[:, None, :]
+            y = y[tuple(new_idxs)]
+            new_idxs = []
             change_flag = True
-            stop = None
 
-        if start is not None or stop is not None or step is not None:
-            last_useful_idx = dim
+        if not change_flag:
+            return None
 
-            new_idxs[dim] = slice(start, stop, step)
+        if node.op.set_instead_of_inc:
+            new_out = x[tuple(new_idxs)].set(y) if new_idxs else fill(x, y)
+        else:
+            new_out = x[tuple(new_idxs)].inc(y) if new_idxs else x + y
 
-    if change_flag or ((last_useful_idx + 1) < len(idxs)):
-        new_idxs = tuple(new_idxs[: last_useful_idx + 1])
-        out = x[new_idxs] if new_idxs else x
-        # Copy over previous output stacktrace
-        copy_stack_trace(node.outputs, out)
-        return [out]
+    copy_stack_trace(node.outputs, new_out)
+    return [new_out]
 
 
 @register_canonicalize
@@ -443,8 +492,8 @@ def local_subtensor_merge(fgraph, node):
         return [out]
 
 
-@register_specialize
 @register_canonicalize
+@register_specialize
 @node_rewriter([Subtensor])
 def local_subtensor_remove_broadcastable_index(fgraph, node):
     """
@@ -454,41 +503,34 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
 
     a[0,:,-1,:] -> a.dimshuffle(1,3), when
         a.broadcastable = (True, False, True, False)
+
+    # TODO: Merge this with local_useless_slice
    """
-    if isinstance(node.op, Subtensor):
-        idx = node.op.idx_list
-    else:
-        return
+    x, *index_vars = node.inputs
 
-    remove_dim = []
-    node_inputs_idx = 1
-    for dim, elem in enumerate(idx):
-        if isinstance(elem, ScalarType):
-            # The idx is a ScalarType, ie a Type. This means the actual index
-            # is contained in node.inputs[1]
-            dim_index = node.inputs[node_inputs_idx]
-            if isinstance(dim_index, ScalarConstant):
-                dim_index = dim_index.value
-            if dim_index in (0, -1) and node.inputs[0].broadcastable[dim]:
-                remove_dim.append(dim)
-                node_inputs_idx += 1
-            else:
-                return
-        elif isinstance(elem, slice):
-            if elem != slice(None):
-                return
-        elif isinstance(elem, int | np.integer):
-            if elem in (0, -1) and node.inputs[0].broadcastable[dim]:
-                remove_dim.append(dim)
-        else:
-            raise TypeError("case not expected")
+    # Check that we only have full slices, or {0, -1} indices on dimensions of length 1
+    index_vars_iter = iter(index_vars)
+    dims_to_remove = []
+    for dim, (idx_type, dim_is_len1) in enumerate(
+        zip(node.op.idx_list, x.type.broadcastable)
+    ):
+        if isinstance(idx_type, slice):
+            if idx_type == none_slice:
+                continue
 
-    if len(remove_dim) == 0:
-        return
-    else:
-        all_dim = range(node.inputs[0].ndim)
-        remain_dim = [x for x in all_dim if x not in remove_dim]
-        return [node.inputs[0].dimshuffle(tuple(remain_dim))]
+        elif dim_is_len1:
+            dim_idx = next(index_vars_iter)
+
+            # TODO: Add a shape_unsafe variant where we don't care about the
+            # value of the index if the dimension is broadcastable.
+            # This would only apply to symbolic indices; we would still not mask invalid constant indices
+            if isinstance(dim_idx, ScalarConstant) and dim_idx.data in (0, -1):
+                dims_to_remove.append(dim)
+                continue
+
+        # Constraints not met, we have either a meaningful slice or index, or an invalid index
+        return None
+
+    return [x.squeeze(dims_to_remove)]
 
 
 @register_specialize
@@ -499,102 +541,34 @@ def local_subtensor_inc_subtensor(fgraph, node):
 
     Subtensor(SetSubtensor(x, y, idx), idx) -> y
 
     """
-    if isinstance(node.op, Subtensor):
-        x = node.inputs[0]
-        if not (x.owner and isinstance(x.owner.op, IncSubtensor)):
-            return
-        if not x.owner.op.set_instead_of_inc:
-            return
-
-        if x.owner.inputs[2:] == node.inputs[1:] and tuple(
-            x.owner.op.idx_list
-        ) == tuple(node.op.idx_list):
-            out = node.outputs[0]
-            y = x.owner.inputs[1]
-            # If the dtypes differ, cast y into x.dtype
-            if x.dtype != y.dtype:
-                y = y.astype(x.dtype)
-            if (
-                out.type.dtype == y.type.dtype
-                and out.type.broadcastable == y.type.broadcastable
-            ):
-                # if x[idx] and y have the same type, directly return y
-                return [y]
-            else:
-                # The difference is related to broadcasting pattern
-                assert out.broadcastable != y.broadcastable
-                # We have to alloc y to the shape of x[idx]
-                x_subtensor = node.op(x.owner.inputs[0], *x.owner.inputs[2:])
-                return [alloc(y, *x_subtensor.shape)]
-        else:
-            return
-
-
-@register_useless
-@register_canonicalize
-@register_specialize
-@node_rewriter([IncSubtensor])
-def local_useless_inc_subtensor(fgraph, node):
-    r"""Remove redundant `IncSubtensor`\s.
-
-    More specifically, ``set_subtensor(x[indices], y)`` is replaced by
-    ``y[indices]`` when ``indices`` are full `slice`\s and ``y``'s shape is
-    equal to ``x[indices]``, and ``inc_subtensor(x[indices], y)`` is replaced
-    by ``y[indices]`` when ``x[indices]`` is some array of ``0``\s, ``indices``
-    are full slices, and the shapes are equal.
-    """
-    if not isinstance(node.op, IncSubtensor):
+    x = node.inputs[0]
+    if not (x.owner and isinstance(x.owner.op, IncSubtensor)):
         return
-
-    if not hasattr(fgraph, "shape_feature"):
+    if not x.owner.op.set_instead_of_inc:
         return
 
-    x, y, *index_inputs = node.inputs
-
-    if node.op.set_instead_of_inc is False:
-        # This is an increment operation, so the array being incremented must
-        # consist of all zeros in order for the entire operation to be useless
-        try:
-            c = get_underlying_scalar_constant_value(x)
-            if c != 0:
-                return
-        except NotScalarConstantError:
-            return
-
-    idx_cst = indices_from_subtensor(list(index_inputs), node.op.idx_list)
-
-    # Check that all indices are full slices with only reversals and no step
-    # sizes
-    # TODO: It seems like there should be a basic `IncSubtensor`
-    # canonicalization that removes these redundant slices.
-    if all(
-        isinstance(e, slice)
-        and e.start is None
-        and e.stop is None
-        and (
-            e.step is None
-            or get_scalar_constant_value(
-                e.step, only_process_constants=True, raise_not_constant=False
-            )
-            == -1
-        )
-        for e in idx_cst
+    if x.owner.inputs[2:] == node.inputs[1:] and tuple(x.owner.op.idx_list) == tuple(
+        node.op.idx_list
     ):
-        # `IncSubtensor` broadcasts `x` on `y` based on run-time shapes, so we
-        # must check that they are the same
-        if not fgraph.shape_feature.same_shape(x, y):
-            return
-
-        # There are no reversals, so we don't need a replacement.
-        if all(e.step is None for e in node.op.idx_list):
-            # They are exactly the same shapes, so we can remove this `IncSubtensor`
+        out = node.outputs[0]
+        y = x.owner.inputs[1]
+        # If the dtypes differ, cast y into x.dtype
+        if x.dtype != y.dtype:
+            y = y.astype(x.dtype)
+        if (
+            out.type.dtype == y.type.dtype
+            and out.type.broadcastable == y.type.broadcastable
+        ):
+            # if x[idx] and y have the same type, directly return y
             return [y]
-
-        new_node = Subtensor(node.op.idx_list).make_node(y, *index_inputs)
-        new_out = new_node.outputs[0]
-        copy_stack_trace(node.outputs, new_out)
-
-        return [new_out]
+        else:
+            # The difference is related to broadcasting pattern
+            assert out.broadcastable != y.broadcastable
+            # We have to alloc y to the shape of x[idx]
+            x_subtensor = node.op(x.owner.inputs[0], *x.owner.inputs[2:])
+            return [alloc(y, *x_subtensor.shape)]
+    else:
+        return
 
 
 @register_canonicalize
@@ -610,8 +584,7 @@ def local_set_to_inc_subtensor(fgraph, node):
 
     """
     if (
-        isinstance(node.op, AdvancedIncSubtensor1)
-        and node.op.set_instead_of_inc
+        node.op.set_instead_of_inc
         and node.inputs[1].owner
         and isinstance(node.inputs[1].owner.op, Elemwise)
        and isinstance(node.inputs[1].owner.op.scalar_op, Add)
@@ -646,6 +619,8 @@ def local_set_to_inc_subtensor(fgraph, node):
 @node_rewriter([Subtensor])
 def local_useless_subtensor(fgraph, node):
     """Remove `Subtensor` if it takes the full input."""
+    # TODO: This is largely a duplicate of local_useless_slice, except it also checks whether a symbolic stop matches shape_of
+    # from the ShapeFeature. This is unlikely to happen, but could anyway be added to `local_useless_slice`.
     if not node.op.idx_list:
         return [node.inputs[0]]
 
@@ -1193,6 +1168,89 @@ def local_setsubtensor_of_constants(fgraph, node):
     return False
 
 
+@node_rewriter([IncSubtensor])
+def local_useless_nested_set_subtensor(fgraph, node):
+    """Rewrite nested set_subtensor on the same base buffer.
+
+    alloc(x, outer_shape)[outer_idx].set(
+        alloc(x, inner_shape)[inner_idx].set(y)
+    ) -> alloc(x, outer_shape)[merged_idx].set(y)
+
+    """
+    outer_alloc, inner_inc, *outer_index_vars = node.inputs
+
+    if not (
+        outer_alloc.owner is not None
+        and isinstance(outer_alloc.owner.op, Alloc)
+        and (
+            node.op.set_instead_of_inc
+            # inc on zeros is the same as set
+            or outer_alloc.owner.op.value_is_scalar_zero(outer_alloc.owner.inputs[0])
+        )
+        and inner_inc.owner is not None
+        and isinstance(inner_inc.owner.op, IncSubtensor)
+    ):
+        return None
+
+    buffer_val = outer_alloc.owner.inputs[0]
+
+    inner_alloc, y, *inner_index_vars = inner_inc.owner.inputs
+
+    if not (
+        inner_alloc.owner is not None
+        and isinstance(inner_alloc.owner.op, Alloc)
+        and inner_alloc.owner.inputs[0] is buffer_val
+        and (
+            inner_inc.owner.op.set_instead_of_inc
+            or inner_alloc.owner.op.value_is_scalar_zero(inner_alloc.owner.inputs[0])
+        )
+    ):
+        return None
+
+    # Compute merged indices so that the single write is equivalent
+    # This hardcodes some common cases that show up, e.g., in the gradient of scan
+    outer_indices = indices_from_subtensor(outer_index_vars, node.op.idx_list)
+    inner_indices = indices_from_subtensor(
+        inner_index_vars, inner_inc.owner.op.idx_list
+    )
+
+    # TODO: Support more cases
+    if (
+        len(outer_indices) == 1
+        and isinstance((outer_slice := outer_indices[0]), slice)
+        and outer_slice.stop is None
+        and outer_slice.step is None
+        and len(inner_indices) == 1
+        and isinstance((inner_index := inner_indices[0]), Variable)
+    ):
+        if inner_alloc.type.broadcastable[0]:
+            # This is a useless inner set_subtensor. We let `local_useless_slice` handle it instead
+            return None
+
+        # If inner_index is non-negative we add it to the start of the outer slice.
+        # Otherwise use it as is.
+        # Alloc(...)[5:].set(Alloc(...)[2].set(y)) -> Alloc(...)[5 + 2].set(y)
+        # Alloc(...)[5:].set(Alloc(...)[-2].set(y)) -> Alloc(...)[-2].set(y)
+        # Alloc(...)[-5:].set(Alloc(...)[2].set(y)) -> Alloc(...)[-5 + 2].set(y)
+        # Alloc(...)[-5:].set(Alloc(...)[-2].set(y)) -> Alloc(...)[-2].set(y)
+        if outer_slice.start is None:
+            new_index = inner_index
+        else:
+            new_index = scalar_switch(
+                inner_index < 0,
+                inner_index,
+                outer_slice.start + inner_index,
+            )
+        if inner_inc.owner.op.set_instead_of_inc:
+            new_out = outer_alloc[new_index].set(y)
+        else:
+            # The guards above only allow an inner increment on a zero buffer,
+            # where inc and set coincide
+            new_out = outer_alloc[new_index].inc(y)
+        copy_stack_trace(node.outputs, new_out)
+        return [new_out]
+
+
 @register_canonicalize
 @register_specialize
 @node_rewriter([AdvancedSubtensor1])
@@ -1275,100 +1327,99 @@ def local_useless_inc_subtensor_alloc(fgraph, node):
     intermediate `alloc` where possible.
 
     """
-    if isinstance(node.op, IncSubtensor | AdvancedIncSubtensor | AdvancedIncSubtensor1):
-        x = node.inputs[0]
-        y = node.inputs[1]
-        i = node.inputs[2:]
+    x = node.inputs[0]
+    y = node.inputs[1]
+    i = node.inputs[2:]
 
-        if y.owner is not None and isinstance(y.owner.op, Alloc):
-            # `z` is the input of the Alloc op, i.e. at.alloc(z, <shape>)
-            z = y.owner.inputs[0]
+    if y.owner is not None and isinstance(y.owner.op, Alloc):
+        # `z` is the input of the Alloc op, i.e. at.alloc(z, <shape>)
+        z = y.owner.inputs[0]
 
-            try:
-                shape_feature = fgraph.shape_feature
-            except AttributeError:
-                # The shape feature may not be available in some mode, but we
-                # need it for this optimization, so don't continue.
- return False + try: + shape_feature = fgraph.shape_feature + except AttributeError: + # The shape feature may not be available in some mode, but we + # need it for this optimization, so don't continue. + return False - shape_of = shape_feature.shape_of - same_shape = shape_feature.same_shape - - # Get the subtensor of `x` indexed by `i` in order to compare - # shapes later. - if isinstance(node.op, IncSubtensor): - xi = Subtensor(node.op.idx_list)(x, *i) - elif isinstance(node.op, AdvancedIncSubtensor): - xi = advanced_subtensor(x, *i) - elif isinstance(node.op, AdvancedIncSubtensor1): - xi = advanced_subtensor1(x, *i) - else: - raise Exception("Should never happen!") - - reason = "local_useless_incsubtensor_alloc" - - # Add `xi` to the shape feature `fgraph`. This is important for - # shape inference later because the variable must be part of the - # function graph in order to call `same_shape` on it. - if xi not in shape_of: - shape_feature.on_import(fgraph, xi.owner, f"{reason}: add `xi`") - - # `xi` may have more dimensions than `y` since the subtensor ops - # do automatic broadcasting of the increment internally. Thus, we - # need to make the leading implicitly broadcasted dimensions - # explicit for shape comparison later. - if xi.ndim > y.ndim: - y = shape_padleft(y, xi.ndim - y.ndim) - if y not in shape_of: - shape_feature.on_import(fgraph, y.owner, f"{reason}: add `y`") - - # Build `z_broad` explicitly to include extra implicit dimensions. - z_broad = (True,) * (xi.ndim - z.ndim) + z.broadcastable - - cond = [ - # The shapes of `y` and `xi` must either agree or `y` may - # also have shape equal to 1 which may be treated as a - # broadcastable dimension by the subtensor op. - or_(eq(y.shape[k], 1), eq(y.shape[k], xi.shape[k])) - # Loop over all dimensions. - for k in range(xi.ndim) - # We need to check the above shapes, if - # * the pre-alloc increment `z` is broadcastable in - # dimension `k` (if it isn't, then the shapes of `z` and - # `y` are the same by the definition of the `Alloc` op in - # this dimension and replacing `y` by `z` will not hide a - # shape error), and - # * `xi` and `y` do not have the same shape in dimension - # `k` or we cannot infer the shape statically (if the - # shapes of `xi` and `y` are not the same, then replacing - # `y` by `z` will hide the shape error of `y`), and - # * the shape of `y` is not equal to 1 or we cannot infer - # the shape statically (if the shape of `y` is equal to - # 1, then `y` is broadcasted by the inc_subtensor op - # internally, so the shapes of `xi` and `y` do not need - # to match in dimension `k`; else we need to check at - # runtime that the shape of `y` is either 1 or the same - # as `xi` or otherwise replacing `y` by `z` will hide a - # shape error). - if ( - z_broad[k] - and not same_shape(xi, y, dim_x=k, dim_y=k) - and shape_of[y][k] != 1 - ) - ] + shape_of = shape_feature.shape_of + same_shape = shape_feature.same_shape + + # Get the subtensor of `x` indexed by `i` in order to compare + # shapes later. + if isinstance(node.op, IncSubtensor): + xi = Subtensor(node.op.idx_list)(x, *i) + elif isinstance(node.op, AdvancedIncSubtensor): + xi = advanced_subtensor(x, *i) + elif isinstance(node.op, AdvancedIncSubtensor1): + xi = advanced_subtensor1(x, *i) + else: + raise Exception("Should never happen!") + + reason = "local_useless_incsubtensor_alloc" + + # Add `xi` to the shape feature `fgraph`. 
This is important for + # shape inference later because the variable must be part of the + # function graph in order to call `same_shape` on it. + if xi not in shape_of: + shape_feature.on_import(fgraph, xi.owner, f"{reason}: add `xi`") + + # `xi` may have more dimensions than `y` since the subtensor ops + # do automatic broadcasting of the increment internally. Thus, we + # need to make the leading implicitly broadcasted dimensions + # explicit for shape comparison later. + if xi.ndim > y.ndim: + y = shape_padleft(y, xi.ndim - y.ndim) + if y not in shape_of: + shape_feature.on_import(fgraph, y.owner, f"{reason}: add `y`") + + # Build `z_broad` explicitly to include extra implicit dimensions. + z_broad = (True,) * (xi.ndim - z.ndim) + z.broadcastable + + cond = [ + # The shapes of `y` and `xi` must either agree or `y` may + # also have shape equal to 1 which may be treated as a + # broadcastable dimension by the subtensor op. + or_(eq(y.shape[k], 1), eq(y.shape[k], xi.shape[k])) + # Loop over all dimensions. + for k in range(xi.ndim) + # We need to check the above shapes, if + # * the pre-alloc increment `z` is broadcastable in + # dimension `k` (if it isn't, then the shapes of `z` and + # `y` are the same by the definition of the `Alloc` op in + # this dimension and replacing `y` by `z` will not hide a + # shape error), and + # * `xi` and `y` do not have the same shape in dimension + # `k` or we cannot infer the shape statically (if the + # shapes of `xi` and `y` are not the same, then replacing + # `y` by `z` will hide the shape error of `y`), and + # * the shape of `y` is not equal to 1 or we cannot infer + # the shape statically (if the shape of `y` is equal to + # 1, then `y` is broadcasted by the inc_subtensor op + # internally, so the shapes of `xi` and `y` do not need + # to match in dimension `k`; else we need to check at + # runtime that the shape of `y` is either 1 or the same + # as `xi` or otherwise replacing `y` by `z` will hide a + # shape error). + if ( + z_broad[k] + and not same_shape(xi, y, dim_x=k, dim_y=k) + and shape_of[y][k] != 1 + ) + ] - if len(cond) > 0: - msg = "`x[i]` and `y` do not have the same shape." - z = Assert(msg)(z, *cond) + if len(cond) > 0: + msg = "`x[i]` and `y` do not have the same shape." + z = Assert(msg)(z, *cond) - r = node.op(x, z, *i) - # Copy over stacktrace from previous output, since - # we don't expect problems when removing the intermediate - # alloc operation and so we still want to point at the line - # of the inc_subtensor operation. - copy_stack_trace(node.outputs, r) + r = node.op(x, z, *i) + # Copy over stacktrace from previous output, since + # we don't expect problems when removing the intermediate + # alloc operation and so we still want to point at the line + # of the inc_subtensor operation. 
+        copy_stack_trace(node.outputs, r)
 
-            return [r]
+        return [r]
 
 
 @register_specialize
@@ -1591,7 +1642,7 @@ def local_blockwise_of_subtensor(fgraph, node):
         [idx.squeeze() for idx in idxs], node.op.core_op.idx_list
     )
     # Add empty slices for the batch dims
-    none_slices = (slice(None),) * node.op.batch_ndim(node)
+    none_slices = (none_slice,) * node.op.batch_ndim(node)
     return [x[(*none_slices, *core_idxs)]]
 
 
@@ -1710,7 +1761,7 @@ def local_blockwise_inc_subtensor(fgraph, node):
     else:
         # In the case we don't have batch indices,
         # we can use slice(None) to broadcast the core indices to each new batch dimension of x / y
-        batch_slices = [slice(None)] * batch_ndim
+        batch_slices = [none_slice] * batch_ndim
         new_idxs = (*batch_slices, *core_idxs)
 
         x_view = x[new_idxs]
diff --git a/tests/tensor/rewriting/test_subtensor.py b/tests/tensor/rewriting/test_subtensor.py
index 91a1f96e81..04de62980c 100644
--- a/tests/tensor/rewriting/test_subtensor.py
+++ b/tests/tensor/rewriting/test_subtensor.py
@@ -2,6 +2,7 @@
 
 import numpy as np
 import pytest
+from tests.unittest_tools import assert_equal_computations
 
 import pytensor
 import pytensor.scalar as ps
@@ -11,17 +12,19 @@
 from pytensor.compile.mode import Mode, get_default_mode, get_mode
 from pytensor.compile.ops import DeepCopyOp
 from pytensor.configdefaults import config
-from pytensor.graph import rewrite_graph, vectorize_graph
+from pytensor.graph import FunctionGraph, rewrite_graph, vectorize_graph
 from pytensor.graph.basic import Constant, Variable, equal_computations
-from pytensor.graph.rewriting.basic import check_stack_trace
+from pytensor.graph.rewriting.basic import check_stack_trace, dfs_rewriter
 from pytensor.graph.traversal import ancestors
 from pytensor.raise_op import Assert
 from pytensor.tensor.basic import Alloc, _convert_to_int8
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.math import Dot, dot, exp, sqr
+from pytensor.tensor.rewriting.basic import constant_folding
 from pytensor.tensor.rewriting.subtensor import (
     local_replace_AdvancedSubtensor,
+    local_useless_nested_set_subtensor,
 )
 from pytensor.tensor.shape import (
     SpecifyShape,
@@ -2113,3 +2116,32 @@ def test_local_convert_negative_indices():
     # TODO: If Subtensor decides to raise on make_node, this test can be removed
     rewritten_out = rewrite_graph(x[:, :, -2])
     assert equal_computations([rewritten_out], [x[:, :, -2]])
+
+
+@pytest.mark.parametrize("inner_broadcasts", [False, True])
+def test_local_useless_nested_set_subtensor(inner_broadcasts):
+    y = scalar("y")
+    zero = pt.constant(0.0)
+    if inner_broadcasts:
+        # The inner alloc broadcasts along the sliced dimension
+        out = pt.alloc(zero, 5, 3)[1:].inc(pt.alloc(zero, 1, 3)[-1].inc(y))
+    else:
+        out = pt.alloc(zero, 5, 3)[1:].inc(pt.alloc(zero, 4, 3)[-1].inc(y))
+
+    fgraph = FunctionGraph(outputs=[out], copy_inputs=False)
+    rewrite = dfs_rewriter(local_useless_nested_set_subtensor, constant_folding)
+    rewrite.rewrite(fgraph)
+    res = fgraph.outputs[0]
+
+    if inner_broadcasts:
+        # The rewrite must decline: the inner value broadcasts along the sliced
+        # dimension, so merging the indices would change the result
+        assert_equal_computations([res], [out])
+    else:
+        expected = pt.alloc(zero, 5, 3)[-1].inc(y)
+        assert_equal_computations([res], [expected], original=[out])
+        no_opt_mode = Mode("py", optimizer=None)
+        np.testing.assert_allclose(
+            out.eval({y: 3.0}, mode=no_opt_mode),
+            expected.eval({y: 3.0}, mode=no_opt_mode),
+        )
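Reviewer note (not part of the patch): the reversed-slice equivalences that the new IncSubtensor branch of local_useless_slice relies on (the "arange(5)[::-1] = y == y[::-1]" comments) can be sanity-checked directly in NumPy. This is a minimal sketch, assuming plain in-place writes stand in for set_subtensor/inc_subtensor:

import numpy as np

x = np.arange(5.0)
y = 10.0 * np.arange(5.0)

# set_subtensor(x[::-1], y) is the same as y[::-1]
a = x.copy()
a[::-1] = y
np.testing.assert_allclose(a, y[::-1])

# inc_subtensor(x[::-1], y) is the same as x + y[::-1]
b = x.copy()
b[::-1] += y
np.testing.assert_allclose(b, x + y[::-1])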
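Similarly, a quick NumPy check (again, not part of the patch) of the index-merging rule that local_useless_nested_set_subtensor expresses with scalar_switch: a negative inner index already addresses the same row of the outer buffer, while a non-negative one must be offset by the outer slice's start:

import numpy as np

buf = np.zeros((5, 3))
y = np.ones(3)

for start in (1, -4):  # outer view buf[start:]; both select rows 1..4 here
    for idx in (2, -1):  # inner index into that view
        # Reference: write through the view, as the nested graph does
        expected = buf.copy()
        expected[start:][idx] += y

        # Merged index, mirroring the rewrite: keep negative indices,
        # offset non-negative ones by the outer slice's start
        merged = idx if idx < 0 else start + idx
        actual = buf.copy()
        actual[merged] += y

        np.testing.assert_allclose(actual, expected)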