36 changes: 19 additions & 17 deletions backends/vulkan/_passes/remove_redundant_ops.py
@@ -31,35 +31,37 @@ class RemoveRedundantOpsTransform(ExportPass):
exir_ops.edge.aten.lift_fresh_copy.default,
exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
exir_ops.edge.dim_order_ops._clone_dim_order.default,
exir_ops.edge.aten.expand_copy.default,
}

def __init__(self) -> None:
super(RemoveRedundantOpsTransform, self).__init__()

def _should_remove(self, node: torch.fx.Node) -> bool:
if node.target in self.redundant_ops:
return True

# Only remove to_copy if dtype does not change. Otherwise, memory format changes
# will be handled internally by the backend.
if (
node.target == exir_ops.edge.aten._to_copy.default
or node.target == torch.ops.aten._to_copy.default
):
src_dtype = node.meta["val"].dtype
# pyre-ignore
dst_dtype = node.args[0].meta["val"].dtype
return src_dtype == dst_dtype

return False
if node.target not in self.redundant_ops:
return False

orig_node = node.args[0]
assert isinstance(orig_node, torch.fx.Node)

src_dtype = orig_node.meta["val"].dtype
dst_dtype = node.meta["val"].dtype

# Do not remove if the op is converting the dtype.
if src_dtype != dst_dtype:
return False

src_shape = orig_node.meta["val"].shape
dst_shape = node.meta["val"].shape

return src_shape == dst_shape

def _remove(self, graph_module: torch.fx.GraphModule) -> None:
for node in graph_module.graph.nodes:
if not self._should_remove(node):
continue

with graph_module.graph.inserting_after(node):
node.replace_all_uses_with(node.args[0])
node.replace_all_uses_with(node.args[0])

graph_module.graph.eliminate_dead_code()
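Note: a minimal standalone sketch (not part of this PR) of how the updated rule behaves. It assumes `torch.export` populates `node.meta["val"]` with FakeTensors; the module and the choice of `aten.clone` are illustrative:

```python
import torch

class M(torch.nn.Module):
    def forward(self, x):
        return torch.clone(x) * 2.0  # clone preserves dtype and shape

ep = torch.export.export(M(), (torch.randn(2, 3),))
gm = ep.graph_module

for node in list(gm.graph.nodes):
    # Mirror of _should_remove: drop the node only if it changes
    # neither dtype nor shape relative to its input.
    if node.target == torch.ops.aten.clone.default:
        src = node.args[0].meta["val"]
        dst = node.meta["val"]
        if src.dtype == dst.dtype and src.shape == dst.shape:
            node.replace_all_uses_with(node.args[0])

gm.graph.eliminate_dead_code()
gm.recompile()
```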

26 changes: 2 additions & 24 deletions backends/vulkan/op_registry.py
@@ -7,17 +7,12 @@
# pyre-unsafe

import operator

from typing import Any, Callable, Dict, List, Optional, Union

import executorch.backends.vulkan.custom_ops_lib # noqa

import executorch.backends.vulkan.utils as utils

import torch

from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.dialects.edge._ops import EdgeOpOverload
from torch._subclasses.fake_tensor import FakeTensor

@@ -129,6 +124,7 @@ def update_features_impl(op: OpKey):
# Symbolic integer ops
torch.ops.aten.sym_size.int,
operator.add,
operator.sub,
operator.lt,
operator.gt,
operator.ge,
@@ -297,27 +293,9 @@ def check_to_copy_node(node: torch.fx.Node) -> bool:

@update_features(exir_ops.edge.dim_order_ops._to_dim_order_copy.default)
def register_to_copy_dim_order_op():
# Currently there is no "real" implementation for to_dim_order_copy, but it can be
# removed as long as the operator is not changing the dtype, i.e. the operator call
# is modifying the dim order only. Therefore, check that the input and output dtypes
# are the same, if so the operator is safe to remove.
def check_dim_order_copy_node(node: torch.fx.Node) -> bool:
in_arg = node.args[0]
if not isinstance(in_arg, torch.fx.Node):
return False

in_tensor = in_arg.meta.get("val", None)
out_tensor = node.meta.get("val", None)

if in_tensor.dtype != out_tensor.dtype:
return False

return True

return OpFeatures(
inputs_storage=utils.ANY_STORAGE,
inputs_storage=utils.ANY_BUFFER,
supports_resize=True,
are_node_inputs_supported_fn=check_dim_order_copy_node,
)
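Note: with the dtype check removed from registration, a dtype-converting `_to_dim_order_copy` now reaches the backend instead of being rejected up front. A hypothetical repro (not from this PR; the lowering path described in the comments is my assumption about `to_edge()` behavior):

```python
import torch

class M(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.to(torch.float32) + 1.0

# After to_edge(), x.to(...) should surface as
# exir_ops.edge.dim_order_ops._to_dim_order_copy.default. Since the input
# dtype (uint8) differs from float32, RemoveRedundantOpsTransform keeps the
# node and the Vulkan backend dispatches the new view_convert_buffer shader.
ep = torch.export.export(M(), (torch.randint(0, 255, (2, 3), dtype=torch.uint8),))
```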


21 changes: 13 additions & 8 deletions backends/vulkan/runtime/graph/ops/glsl/view_buffer.glsl
@@ -18,6 +18,8 @@ ${layout_declare_ubo(B, "BufferMetadata", "inp")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

${layout_declare_spec_const(C, "int", "all_contiguous", "0")}

/*
* The insight behind the view operation is that the contiguous index of each
* tensor element is the same in the input and output tensors.
@@ -28,17 +30,20 @@ void main() {
return;
}

TensorIndex outp_tidx;
linear_idx_to_tensor_idx(outp, outp_bufi, outp_tidx);
uint inp_bufi = outp_bufi;
if (all_contiguous == 0) {
TensorIndex outp_tidx;
linear_idx_to_tensor_idx(outp, outp_bufi, outp_tidx);

// To map the output to the input, find the input element that has the same
// contiguous index as the output element.
const uint contig_idx = tensor_idx_to_contiguous_idx(outp, outp_tidx);
// To map the output to the input, find the input element that has the same
// contiguous index as the output element.
const uint contig_idx = tensor_idx_to_contiguous_idx(outp, outp_tidx);

TensorIndex inp_tidx;
contiguous_idx_to_tensor_idx(inp, contig_idx, inp_tidx);
TensorIndex inp_tidx;
contiguous_idx_to_tensor_idx(inp, contig_idx, inp_tidx);

const uint inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx);
inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx);
}

t_outp[outp_bufi] = t_inp[inp_bufi];
}
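The index mapping the shader performs when either tensor is non-contiguous can be modeled in a few lines of Python (illustration only; the helper names mirror those in indexing.glslh but are reimplemented here):

```python
def tensor_idx_to_contiguous_idx(idx, sizes):
    # Row-major linearization of a tensor index (last dim fastest).
    contig = 0
    for i, s in zip(idx, sizes):
        contig = contig * s + i
    return contig

def contiguous_idx_to_tensor_idx(contig, sizes):
    # Inverse of the above.
    idx = []
    for s in reversed(sizes):
        idx.append(contig % s)
        contig //= s
    return list(reversed(idx))

# Roundtrip sanity check
sizes, idx = [2, 3, 4], [1, 2, 3]
c = tensor_idx_to_contiguous_idx(idx, sizes)
assert contiguous_idx_to_tensor_idx(c, sizes) == idx
```

When both buffers are contiguous, the contiguous index equals the linear buffer index on both sides, so the whole mapping collapses to `inp_bufi = outp_bufi`; the `all_contiguous` specialization constant lets the driver compile that branch out.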
54 changes: 54 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.glsl
@@ -0,0 +1,54 @@
#version 450 core

#define PRECISION ${PRECISION}

#define IN_T ${buffer_scalar_type(IN_DTYPE)}
#define OUT_T ${buffer_scalar_type(OUT_DTYPE)}

${define_required_extensions(IN_DTYPE)}
${define_required_extensions(OUT_DTYPE)}

layout(std430) buffer;

#include "indexing.glslh"

${layout_declare_buffer(B, "w", "t_outp", OUT_DTYPE)}
${layout_declare_buffer(B, "r", "t_inp", IN_DTYPE)}

${layout_declare_ubo(B, "BufferMetadata", "outp")}
${layout_declare_ubo(B, "BufferMetadata", "inp")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

${layout_declare_spec_const(C, "int", "all_contiguous", "0")}

/*
* The insight behind the view_convert operation is that the contiguous index of
* each tensor element is the same in the input and output tensors, but the data
* types may differ and require conversion.
*/
void main() {
const uint outp_bufi = gl_GlobalInvocationID.x;
if (outp_bufi >= numel(outp)) {
return;
}

uint inp_bufi = outp_bufi;

if (all_contiguous == 0) {
TensorIndex outp_tidx;
linear_idx_to_tensor_idx(outp, outp_bufi, outp_tidx);

// To map the output to the input, find the input element that has the same
// contiguous index as the output element.
const uint contig_idx = tensor_idx_to_contiguous_idx(outp, outp_tidx);

TensorIndex inp_tidx;
contiguous_idx_to_tensor_idx(inp, contig_idx, inp_tidx);

inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx);
}

// Convert data type from input to output
t_outp[outp_bufi] = OUT_T(t_inp[inp_bufi]);
}
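The only difference from view_buffer is the elementwise cast on the final line; per element it is equivalent to the following (illustration only, shown for the uint8 to float variant):

```python
import numpy as np

inp = np.array([0, 128, 255], dtype=np.uint8)
out = inp.astype(np.float32)  # OUT_T(t_inp[inp_bufi]) applied per element
```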
22 changes: 22 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.yaml
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

view_convert_buffer:
parameter_names_with_default_values:
IN_DTYPE: float
OUT_DTYPE: float
STORAGE: buffer
generate_variant_forall:
combination:
parameter_names: [IN_DTYPE, OUT_DTYPE]
combos:
- parameter_values: [int32, float]
- parameter_values: [int32, half]
- parameter_values: [uint8, float]
- parameter_values: [uint8, half]
- parameter_values: [uint8, int32]
shader_variants:
- NAME: view_convert_buffer
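Assuming the shader codegen appends one dtype suffix per listed parameter (consistent with the two `add_dtype_suffix` calls in View.cpp), these combos should expand to variant names like the following; this is an inference, not taken from the PR:

```python
combos = [("int32", "float"), ("int32", "half"),
          ("uint8", "float"), ("uint8", "half"), ("uint8", "int32")]
for in_dt, out_dt in combos:
    print(f"view_convert_buffer_{in_dt}_{out_dt}")
```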
12 changes: 12 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp
@@ -67,6 +67,18 @@ void resize_unsqueeze_node(

std::vector<int64_t> out_sizes = graph->sizes_of(in);

std::vector<int64_t> unsqueezed_dims;

if (graph->val_is_int_list(dims_ref)) {
const IntListPtr dims = graph->get_int_list(dims_ref);
for (int64_t d : *dims) {
unsqueezed_dims.push_back(d);
}
} else {
const int64_t dim = graph->extract_scalar<int64_t>(dims_ref);
unsqueezed_dims.push_back(dim);
}

// Insert singleton dimensions at the specified positions
for (auto dim : dims_vec) {
int64_t d = dim;
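A Python model (illustration only) of the resize rule, covering both the int-list and scalar forms of `dims` that the new branch handles:

```python
def unsqueeze_sizes(in_sizes, dims):
    if isinstance(dims, int):
        dims = [dims]
    out = list(in_sizes)
    for d in sorted(dims):
        if d < 0:
            d += len(out) + 1
        out.insert(d, 1)  # insert a singleton dimension at position d
    return out

assert unsqueeze_sizes([2, 3], 0) == [1, 2, 3]
assert unsqueeze_sizes([2, 3], [0, 2]) == [1, 2, 1, 3]
```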
79 changes: 79 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/View.cpp
@@ -60,6 +60,16 @@ void resize_view_node(
}
}

void resize_to_dim_order_copy_node(
ComputeGraph* graph,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& extra_args) {
const ValueRef out = args.at(0).refs.at(0);
const ValueRef in = args.at(1).refs.at(0);
const std::vector<int64_t> in_sizes = graph->sizes_of(in);
graph->virtual_resize(out, in_sizes);
}

void add_view_node(
ComputeGraph& graph,
ValueRef in,
@@ -98,6 +108,11 @@ void add_view_copy_buffer_node(
std::string kernel_name = "view_buffer";
add_dtype_suffix(kernel_name, graph.dtype_of(out));

bool all_contiguous = graph.is_contiguous_buffer_tensor(in) &&
graph.is_contiguous_buffer_tensor(out);

int32_t all_contiguous_int = all_contiguous ? 1 : 0;

graph.execute_nodes().emplace_back(new DynamicDispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
@@ -110,7 +125,41 @@ void add_view_copy_buffer_node(
// Push Constants
{},
// Specialization Constants
{all_contiguous_int},
// Resize Args
resize_args,
// Resizing Logic
resize_fn));
}

void add_view_copy_convert_buffer_node(
ComputeGraph& graph,
ValueRef in,
ValueRef out,
const std::vector<ValueRef>& resize_args,
const ExecuteNode::ResizeFunction& resize_fn) {
std::string kernel_name = "view_convert_buffer";
add_dtype_suffix(kernel_name, graph.dtype_of(in));
add_dtype_suffix(kernel_name, graph.dtype_of(out));

bool all_contiguous = graph.is_contiguous_buffer_tensor(in) &&
graph.is_contiguous_buffer_tensor(out);

int32_t all_contiguous_int = all_contiguous ? 1 : 0;

graph.execute_nodes().emplace_back(new DynamicDispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
default_pick_global_wg_size,
default_pick_local_wg_size,
// Inputs and Outputs
{{out, vkapi::kWrite}, {in, vkapi::kRead}},
// Parameter Buffers
{graph.buffer_meta_ubo(out), graph.buffer_meta_ubo(in)},
// Push Constants
{},
// Specialization Constants
{all_contiguous_int},
// Resize Args
resize_args,
// Resizing Logic
@@ -132,8 +181,38 @@ void view(ComputeGraph& graph, const std::vector<ValueRef>& args) {
return add_view_node(graph, in, sizes, out);
}

void to_dim_order_copy(ComputeGraph& graph, const std::vector<ValueRef>& args) {
int args_idx = 0;
const ValueRef in = args.at(args_idx++);
const ValueRef dtype = args.at(args_idx++);
(void)dtype;
const ValueRef layout = args.at(args_idx++);
(void)layout;
const ValueRef device = args.at(args_idx++);
(void)device;
const ValueRef pin_memory = args.at(args_idx++);
(void)pin_memory;
const ValueRef non_blocking = args.at(args_idx++);
(void)non_blocking;
const ValueRef dim_order = args.at(args_idx++);
(void)dim_order;

const ValueRef out = args.at(args_idx++);

VK_CHECK_COND(graph.is_buffer_storage(in) && graph.is_buffer_storage(out));

if (graph.dtype_of(in) == graph.dtype_of(out)) {
return add_view_copy_buffer_node(
graph, in, out, {}, resize_to_dim_order_copy_node);
}

return add_view_copy_convert_buffer_node(
graph, in, out, {}, resize_to_dim_order_copy_node);
}
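The dispatch above reduces to a simple rule; a short Python sketch (illustrative only, the real code also appends dtype suffixes to the kernel name via `add_dtype_suffix`):

```python
def pick_kernel(in_dtype, out_dtype):
    # Same dtype: reuse the plain view shader; otherwise use the converting one.
    return "view_buffer" if in_dtype == out_dtype else "view_convert_buffer"
```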

REGISTER_OPERATORS {
VK_REGISTER_OP(aten.view_copy.default, view);
VK_REGISTER_OP(dim_order_ops._to_dim_order_copy.default, to_dim_order_copy);
}

} // namespace vkcompute
13 changes: 13 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/View.h
@@ -24,6 +24,19 @@ void add_view_copy_buffer_node(
const std::vector<ValueRef>& resize_args,
const ExecuteNode::ResizeFunction& resize_fn);

/*
* Dispatches the view_convert_buffer compute shader. This can be used to
* implement ops that preserve the "contiguous" index of each element between
* the input and output while converting between data types, such as view_copy
* with dtype conversion.
*/
void add_view_copy_convert_buffer_node(
ComputeGraph& graph,
ValueRef in,
ValueRef out,
const std::vector<ValueRef>& resize_args,
const ExecuteNode::ResizeFunction& resize_fn);

void add_view_node(
ComputeGraph& graph,
ValueRef in,