diff --git a/backends/vulkan/_passes/remove_redundant_ops.py b/backends/vulkan/_passes/remove_redundant_ops.py
index 8e602dd17b4..25bdd34de70 100644
--- a/backends/vulkan/_passes/remove_redundant_ops.py
+++ b/backends/vulkan/_passes/remove_redundant_ops.py
@@ -31,35 +31,37 @@ class RemoveRedundantOpsTransform(ExportPass):
         exir_ops.edge.aten.lift_fresh_copy.default,
         exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
         exir_ops.edge.dim_order_ops._clone_dim_order.default,
+        exir_ops.edge.aten.expand_copy.default,
     }
 
     def __init__(self) -> None:
         super(RemoveRedundantOpsTransform, self).__init__()
 
     def _should_remove(self, node: torch.fx.Node) -> bool:
-        if node.target in self.redundant_ops:
-            return True
-
-        # Only remove to_copy if dtype does not change. Otherwise, memory format changes
-        # will be handled internally by the backend.
-        if (
-            node.target == exir_ops.edge.aten._to_copy.default
-            or node.target == torch.ops.aten._to_copy.default
-        ):
-            src_dtype = node.meta["val"].dtype
-            # pyre-ignore
-            dst_dtype = node.args[0].meta["val"].dtype
-            return src_dtype == dst_dtype
-
-        return False
+        if node.target not in self.redundant_ops:
+            return False
+
+        orig_node = node.args[0]
+        assert isinstance(orig_node, torch.fx.Node)
+
+        src_dtype = orig_node.meta["val"].dtype
+        dst_dtype = node.meta["val"].dtype
+
+        # Do not remove if the op is converting the dtype.
+        if src_dtype != dst_dtype:
+            return False
+
+        src_shape = orig_node.meta["val"].shape
+        dst_shape = node.meta["val"].shape
+
+        return src_shape == dst_shape
 
     def _remove(self, graph_module: torch.fx.GraphModule) -> None:
         for node in graph_module.graph.nodes:
             if not self._should_remove(node):
                 continue
 
-            with graph_module.graph.inserting_after(node):
-                node.replace_all_uses_with(node.args[0])
+            node.replace_all_uses_with(node.args[0])
 
         graph_module.graph.eliminate_dead_code()
diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index da127f72528..e487491dfbb 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -7,17 +7,12 @@
 # pyre-unsafe
 
 import operator
-
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import executorch.backends.vulkan.custom_ops_lib  # noqa
-
 import executorch.backends.vulkan.utils as utils
-
 import torch
-
 from executorch.exir.dialects._ops import ops as exir_ops
-
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
 
 from torch._subclasses.fake_tensor import FakeTensor
@@ -129,6 +124,7 @@ def update_features_impl(op: OpKey):
         # Symbolic integer ops
         torch.ops.aten.sym_size.int,
         operator.add,
+        operator.sub,
         operator.lt,
         operator.gt,
         operator.ge,
@@ -297,27 +293,9 @@ def check_to_copy_node(node: torch.fx.Node) -> bool:
 
 @update_features(exir_ops.edge.dim_order_ops._to_dim_order_copy.default)
 def register_to_copy_dim_order_op():
-    # Currently there is no "real" implementation for to_dim_order_copy, but it can be
-    # removed as long as the operator is not changing the dtype, i.e. the operator call
-    # is modifying the dim order only. Therefore, check that the input and output dtypes
-    # are the same, if so the operator is safe to remove.
-    def check_dim_order_copy_node(node: torch.fx.Node) -> bool:
-        in_arg = node.args[0]
-        if not isinstance(in_arg, torch.fx.Node):
-            return False
-
-        in_tensor = in_arg.meta.get("val", None)
-        out_tensor = node.meta.get("val", None)
-
-        if in_tensor.dtype != out_tensor.dtype:
-            return False
-
-        return True
-
     return OpFeatures(
-        inputs_storage=utils.ANY_STORAGE,
+        inputs_storage=utils.ANY_BUFFER,
         supports_resize=True,
-        are_node_inputs_supported_fn=check_dim_order_copy_node,
     )
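Note on the two Python changes above: `_should_remove` now treats membership in `redundant_ops` as necessary but not sufficient. Removal also requires that the node changes neither dtype nor shape relative to its input, so dtype-converting `_to_dim_order_copy` nodes survive the pass and reach the backend, which is why the per-node `check_dim_order_copy_node` filter in op_registry.py is no longer needed. A minimal sketch of that contract, using plain tensors in place of the `node.meta["val"]` fake tensors (illustrative only, not part of the change):

import torch

def is_noop_copy(src: torch.Tensor, dst: torch.Tensor) -> bool:
    # Mirrors the dtype and shape checks _should_remove performs on node.meta["val"].
    return src.dtype == dst.dtype and src.shape == dst.shape

x = torch.randn(2, 3)
assert is_noop_copy(x, x.clone())              # removable: pure copy
assert not is_noop_copy(x, x.to(torch.half))   # kept: dtype conversion
assert not is_noop_copy(x, x.expand(2, 2, 3))  # kept: expand_copy changes shape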
diff --git a/backends/vulkan/runtime/graph/ops/glsl/view_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/view_buffer.glsl
index 2c02803a9b1..96b9aa85a1f 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/view_buffer.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/view_buffer.glsl
@@ -18,6 +18,8 @@
 ${layout_declare_ubo(B, "BufferMetadata", "inp")}
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
+${layout_declare_spec_const(C, "int", "all_contiguous", "0")}
+
 /*
  * The insight behind the view operation is that the contiguous index of each
  * tensor element in the input and output tensors are the same.
@@ -28,17 +30,20 @@ void main() {
     return;
   }
 
-  TensorIndex outp_tidx;
-  linear_idx_to_tensor_idx(outp, outp_bufi, outp_tidx);
+  uint inp_bufi = outp_bufi;
+  if (all_contiguous == 0) {
+    TensorIndex outp_tidx;
+    linear_idx_to_tensor_idx(outp, outp_bufi, outp_tidx);
 
-  // To map the output to the input, find the input element that has the same
-  // contiguous index as the output element.
-  const uint contig_idx = tensor_idx_to_contiguous_idx(outp, outp_tidx);
+    // To map the output to the input, find the input element that has the same
+    // contiguous index as the output element.
+    const uint contig_idx = tensor_idx_to_contiguous_idx(outp, outp_tidx);
 
-  TensorIndex inp_tidx;
-  contiguous_idx_to_tensor_idx(inp, contig_idx, inp_tidx);
+    TensorIndex inp_tidx;
+    contiguous_idx_to_tensor_idx(inp, contig_idx, inp_tidx);
 
-  const uint inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx);
+    inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx);
+  }
 
   t_outp[outp_bufi] = t_inp[inp_bufi];
 }
diff --git a/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.glsl
new file mode 100644
index 00000000000..a926c9fea11
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.glsl
@@ -0,0 +1,54 @@
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+#define IN_T ${buffer_scalar_type(IN_DTYPE)}
+#define OUT_T ${buffer_scalar_type(OUT_DTYPE)}
+
+${define_required_extensions(IN_DTYPE)}
+${define_required_extensions(OUT_DTYPE)}
+
+layout(std430) buffer;
+
+#include "indexing.glslh"
+
+${layout_declare_buffer(B, "w", "t_outp", OUT_DTYPE)}
+${layout_declare_buffer(B, "r", "t_inp", IN_DTYPE)}
+
+${layout_declare_ubo(B, "BufferMetadata", "outp")}
+${layout_declare_ubo(B, "BufferMetadata", "inp")}
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+${layout_declare_spec_const(C, "int", "all_contiguous", "0")}
+
+/*
+ * The insight behind the view_convert operation is that the contiguous index
+ * of each tensor element is the same in the input and output tensors, but the
+ * data types may be different and need conversion.
+ */
+void main() {
+  const uint outp_bufi = gl_GlobalInvocationID.x;
+  if (outp_bufi >= numel(outp)) {
+    return;
+  }
+
+  uint inp_bufi = outp_bufi;
+
+  if (all_contiguous == 0) {
+    TensorIndex outp_tidx;
+    linear_idx_to_tensor_idx(outp, outp_bufi, outp_tidx);
+
+    // To map the output to the input, find the input element that has the same
+    // contiguous index as the output element.
+    const uint contig_idx = tensor_idx_to_contiguous_idx(outp, outp_tidx);
+
+    TensorIndex inp_tidx;
+    contiguous_idx_to_tensor_idx(inp, contig_idx, inp_tidx);
+
+    inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx);
+  }
+
+  // Convert data type from input to output
+  t_outp[outp_bufi] = OUT_T(t_inp[inp_bufi]);
+}
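Both shaders share the same index arithmetic: convert the output buffer index to a tensor index, linearize it in contiguous order, then invert that linearization in the input tensor. When the `all_contiguous` specialization constant is set, the mapping is the identity and the walk is skipped entirely. A Python sketch of the round trip, assuming row-major semantics for the `indexing.glslh` helpers (an assumption; the actual helpers operate on BufferMetadata):

def linear_idx_to_tensor_idx(sizes, strides, bufi):
    # Recover the tensor index from a buffer index, given (possibly permuted) strides.
    return [(bufi // s) % n for n, s in zip(sizes, strides)]

def tensor_idx_to_contiguous_idx(sizes, idx):
    # Row-major linearization of a tensor index.
    ci = 0
    for n, i in zip(sizes, idx):
        ci = ci * n + i
    return ci

def contiguous_idx_to_tensor_idx(sizes, ci):
    # Inverse of the row-major linearization.
    idx = []
    for n in reversed(sizes):
        idx.append(ci % n)
        ci //= n
    return idx[::-1]

def tensor_idx_to_linear_idx(strides, idx):
    return sum(s * i for s, i in zip(strides, idx))

# A contiguous 2x3 output viewing a contiguous 3x2 input: the mapping collapses
# to the identity, which is exactly what the all_contiguous fast path exploits.
out_sizes, out_strides = [2, 3], [3, 1]
in_sizes, in_strides = [3, 2], [2, 1]
for outp_bufi in range(6):
    tidx = linear_idx_to_tensor_idx(out_sizes, out_strides, outp_bufi)
    contig = tensor_idx_to_contiguous_idx(out_sizes, tidx)
    inp_tidx = contiguous_idx_to_tensor_idx(in_sizes, contig)
    assert tensor_idx_to_linear_idx(in_strides, inp_tidx) == outp_bufi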
+ */ +void main() { + const uint outp_bufi = gl_GlobalInvocationID.x; + if (outp_bufi >= numel(outp)) { + return; + } + + uint inp_bufi = outp_bufi; + + if (all_contiguous == 0) { + TensorIndex outp_tidx; + linear_idx_to_tensor_idx(outp, outp_bufi, outp_tidx); + + // To map the output to the input, find the input element that has the same + // contiguous index as the output element. + const uint contig_idx = tensor_idx_to_contiguous_idx(outp, outp_tidx); + + TensorIndex inp_tidx; + contiguous_idx_to_tensor_idx(inp, contig_idx, inp_tidx); + + inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx); + } + + // Convert data type from input to output + t_outp[outp_bufi] = OUT_T(t_inp[inp_bufi]); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.yaml new file mode 100644 index 00000000000..11d56cad4a9 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.yaml @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +view_convert_buffer: + parameter_names_with_default_values: + IN_DTYPE: float + OUT_DTYPE: float + STORAGE: buffer + generate_variant_forall: + combination: + parameter_names: [IN_DTYPE, OUT_DTYPE] + combos: + - parameter_values: [int32, float] + - parameter_values: [int32, half] + - parameter_values: [uint8, float] + - parameter_values: [uint8, half] + - parameter_values: [uint8, int32] + shader_variants: + - NAME: view_convert_buffer diff --git a/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp b/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp index 36a8ee4c3b1..602fe1ef129 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp @@ -67,6 +67,18 @@ void resize_unsqueeze_node( std::vector out_sizes = graph->sizes_of(in); + std::vector unsqueezed_dims; + + if (graph->val_is_int_list(dims_ref)) { + const IntListPtr dims = graph->get_int_list(dims_ref); + for (int64_t d : *dims) { + unsqueezed_dims.push_back(d); + } + } else { + const int64_t dim = graph->extract_scalar(dims_ref); + unsqueezed_dims.push_back(dim); + } + // Insert singleton dimensions at the specified positions for (auto dim : dims_vec) { int64_t d = dim; diff --git a/backends/vulkan/runtime/graph/ops/impl/View.cpp b/backends/vulkan/runtime/graph/ops/impl/View.cpp index 8701a6246b0..5e2c898573a 100644 --- a/backends/vulkan/runtime/graph/ops/impl/View.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/View.cpp @@ -60,6 +60,16 @@ void resize_view_node( } } +void resize_to_dim_order_copy_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + const ValueRef out = args.at(0).refs.at(0); + const ValueRef in = args.at(1).refs.at(0); + const std::vector in_sizes = graph->sizes_of(in); + graph->virtual_resize(out, in_sizes); +} + void add_view_node( ComputeGraph& graph, ValueRef in, @@ -98,6 +108,11 @@ void add_view_copy_buffer_node( std::string kernel_name = "view_buffer"; add_dtype_suffix(kernel_name, graph.dtype_of(out)); + bool all_contiguous = graph.is_contiguous_buffer_tensor(in) && + graph.is_contiguous_buffer_tensor(out); + + int32_t all_contiguous_int = all_contiguous ? 
diff --git a/backends/vulkan/runtime/graph/ops/impl/View.cpp b/backends/vulkan/runtime/graph/ops/impl/View.cpp
index 8701a6246b0..5e2c898573a 100644
--- a/backends/vulkan/runtime/graph/ops/impl/View.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/View.cpp
@@ -60,6 +60,16 @@ void resize_view_node(
   }
 }
 
+void resize_to_dim_order_copy_node(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& extra_args) {
+  const ValueRef out = args.at(0).refs.at(0);
+  const ValueRef in = args.at(1).refs.at(0);
+  const std::vector<int64_t> in_sizes = graph->sizes_of(in);
+  graph->virtual_resize(out, in_sizes);
+}
+
 void add_view_node(
     ComputeGraph& graph,
     ValueRef in,
@@ -98,6 +108,11 @@ void add_view_copy_buffer_node(
   std::string kernel_name = "view_buffer";
   add_dtype_suffix(kernel_name, graph.dtype_of(out));
 
+  bool all_contiguous = graph.is_contiguous_buffer_tensor(in) &&
+      graph.is_contiguous_buffer_tensor(out);
+
+  int32_t all_contiguous_int = all_contiguous ? 1 : 0;
+
   graph.execute_nodes().emplace_back(new DynamicDispatchNode(
       graph,
       VK_KERNEL_FROM_STR(kernel_name),
@@ -110,7 +125,41 @@ void add_view_copy_buffer_node(
       // Push Constants
       {},
       // Specialization Constants
+      {all_contiguous_int},
+      // Resize Args
+      resize_args,
+      // Resizing Logic
+      resize_fn));
+}
+
+void add_view_copy_convert_buffer_node(
+    ComputeGraph& graph,
+    ValueRef in,
+    ValueRef out,
+    const std::vector<ValueRef>& resize_args,
+    const ExecuteNode::ResizeFunction& resize_fn) {
+  std::string kernel_name = "view_convert_buffer";
+  add_dtype_suffix(kernel_name, graph.dtype_of(in));
+  add_dtype_suffix(kernel_name, graph.dtype_of(out));
+
+  bool all_contiguous = graph.is_contiguous_buffer_tensor(in) &&
+      graph.is_contiguous_buffer_tensor(out);
+
+  int32_t all_contiguous_int = all_contiguous ? 1 : 0;
+
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      default_pick_global_wg_size,
+      default_pick_local_wg_size,
+      // Inputs and Outputs
+      {{out, vkapi::kWrite}, {in, vkapi::kRead}},
+      // Parameter Buffers
+      {graph.buffer_meta_ubo(out), graph.buffer_meta_ubo(in)},
+      // Push Constants
       {},
+      // Specialization Constants
+      {all_contiguous_int},
       // Resize Args
       resize_args,
       // Resizing Logic
@@ -132,8 +181,38 @@ void view(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   return add_view_node(graph, in, sizes, out);
 }
 
+void to_dim_order_copy(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  int args_idx = 0;
+  const ValueRef in = args.at(args_idx++);
+  const ValueRef dtype = args.at(args_idx++);
+  (void)dtype;
+  const ValueRef layout = args.at(args_idx++);
+  (void)layout;
+  const ValueRef device = args.at(args_idx++);
+  (void)device;
+  const ValueRef pin_memory = args.at(args_idx++);
+  (void)pin_memory;
+  const ValueRef non_blocking = args.at(args_idx++);
+  (void)non_blocking;
+  const ValueRef dim_order = args.at(args_idx++);
+  (void)dim_order;
+
+  const ValueRef out = args.at(args_idx++);
+
+  VK_CHECK_COND(graph.is_buffer_storage(in) && graph.is_buffer_storage(out));
+
+  if (graph.dtype_of(in) == graph.dtype_of(out)) {
+    return add_view_copy_buffer_node(
+        graph, in, out, {}, resize_to_dim_order_copy_node);
+  }
+
+  return add_view_copy_convert_buffer_node(
+      graph, in, out, {}, resize_to_dim_order_copy_node);
+}
+
 REGISTER_OPERATORS {
   VK_REGISTER_OP(aten.view_copy.default, view);
+  VK_REGISTER_OP(dim_order_ops._to_dim_order_copy.default, to_dim_order_copy);
 }
 
 } // namespace vkcompute
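`to_dim_order_copy` above unpacks the full `_to_dim_order_copy` argument list but only acts on the input and output tensors: when the dtypes match it reuses the plain `view_buffer` path, otherwise it dispatches the new convert kernel, and both share `resize_to_dim_order_copy_node` so the output tracks the input's sizes on resize. A sketch of the kernel-selection rule; the suffix formatting is illustrative, not the exact output of `add_dtype_suffix`:

def pick_view_kernel(in_dtype: str, out_dtype: str) -> str:
    # Same dtype: a pure index-remapping copy suffices.
    if in_dtype == out_dtype:
        return f"view_buffer_{out_dtype}"
    # Different dtypes: the convert variant casts each element on the way out.
    return f"view_convert_buffer_{in_dtype}_{out_dtype}"

assert pick_view_kernel("float", "float") == "view_buffer_float"
assert pick_view_kernel("int32", "float") == "view_convert_buffer_int32_float"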
+ */ +void add_view_copy_convert_buffer_node( + ComputeGraph& graph, + ValueRef in, + ValueRef out, + const std::vector& resize_args, + const ExecuteNode::ResizeFunction& resize_fn); + void add_view_node( ComputeGraph& graph, ValueRef in, diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index f92cea64767..f38c510a8b1 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -11,20 +11,14 @@ from typing import Tuple import executorch.backends.vulkan.test.utils as test_utils - import torch - from executorch.backends.transforms.convert_dtype_pass import I64toI32 - from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner - from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend - from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( get_symmetric_quantization_config, XNNPACKQuantizer, ) - from executorch.exir import ( EdgeCompileConfig, EdgeProgramManager, @@ -36,11 +30,8 @@ ) from executorch.extension.pytree import tree_flatten from torch.export import Dim, export, ExportedProgram - from torchao.quantization.granularity import PerGroup - from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e - from torchao.quantization.pt2e.quantizer import Quantizer from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_ from torchao.utils import unwrap_tensor_subclass @@ -69,9 +60,6 @@ def lower_module( edge_program = to_edge_transform_and_lower( program, compile_config=edge_compile_config, - transform_passes=[ - I64toI32(edge_compile_config._skip_dim_order), - ], partitioner=[VulkanPartitioner(compile_options)], )