Commit d0e8131

[ET-VK] Implementation of to_dim_order_copy
Differential Revision: D86340341
Pull Request resolved: #15619
1 parent 0d33060 · commit d0e8131

File tree: 9 files changed (+214, -61 lines)

backends/vulkan/_passes/remove_redundant_ops.py

Lines changed: 19 additions & 17 deletions
@@ -31,35 +31,37 @@ class RemoveRedundantOpsTransform(ExportPass):
         exir_ops.edge.aten.lift_fresh_copy.default,
         exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
         exir_ops.edge.dim_order_ops._clone_dim_order.default,
+        exir_ops.edge.aten.expand_copy.default,
     }

     def __init__(self) -> None:
         super(RemoveRedundantOpsTransform, self).__init__()

     def _should_remove(self, node: torch.fx.Node) -> bool:
-        if node.target in self.redundant_ops:
-            return True
-
-        # Only remove to_copy if dtype does not change. Otherwise, memory format changes
-        # will be handled internally by the backend.
-        if (
-            node.target == exir_ops.edge.aten._to_copy.default
-            or node.target == torch.ops.aten._to_copy.default
-        ):
-            src_dtype = node.meta["val"].dtype
-            # pyre-ignore
-            dst_dtype = node.args[0].meta["val"].dtype
-            return src_dtype == dst_dtype
-
-        return False
+        if node.target not in self.redundant_ops:
+            return False
+
+        orig_node = node.args[0]
+        assert isinstance(orig_node, torch.fx.Node)
+
+        src_dtype = orig_node.meta["val"].dtype
+        dst_dtype = node.meta["val"].dtype
+
+        # Do not remove if the op is converting the dtype.
+        if src_dtype != dst_dtype:
+            return False
+
+        src_shape = orig_node.meta["val"].shape
+        dst_shape = node.meta["val"].shape
+
+        return src_shape == dst_shape

     def _remove(self, graph_module: torch.fx.GraphModule) -> None:
         for node in graph_module.graph.nodes:
             if not self._should_remove(node):
                 continue

-            with graph_module.graph.inserting_after(node):
-                node.replace_all_uses_with(node.args[0])
+            node.replace_all_uses_with(node.args[0])

         graph_module.graph.eliminate_dead_code()
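
For intuition, the rewritten _should_remove admits only copy-style nodes that change neither dtype nor shape. A minimal sketch of the same predicate over plain tensors (illustrative only; the pass itself reads FX node "val" metadata):

import torch

def is_redundant_copy(src: torch.Tensor, dst: torch.Tensor) -> bool:
    # Mirrors the new criterion: a copy op is removable only when it
    # changes neither dtype nor shape; otherwise the backend must run it.
    if src.dtype != dst.dtype:
        return False
    return src.shape == dst.shape

x = torch.arange(6, dtype=torch.int32)
print(is_redundant_copy(x, x.contiguous()))       # True: pure dim-order copy
print(is_redundant_copy(x, x.to(torch.float32)))  # False: dtype conversion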

backends/vulkan/op_registry.py

Lines changed: 2 additions & 24 deletions
@@ -7,17 +7,12 @@
 # pyre-unsafe

 import operator
-
 from typing import Any, Callable, Dict, List, Optional, Union

 import executorch.backends.vulkan.custom_ops_lib  # noqa
-
 import executorch.backends.vulkan.utils as utils
-
 import torch
-
 from executorch.exir.dialects._ops import ops as exir_ops
-
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
 from torch._subclasses.fake_tensor import FakeTensor

@@ -129,6 +124,7 @@ def update_features_impl(op: OpKey):
     # Symbolic integer ops
     torch.ops.aten.sym_size.int,
     operator.add,
+    operator.sub,
     operator.lt,
     operator.gt,
     operator.ge,
@@ -297,27 +293,9 @@ def check_to_copy_node(node: torch.fx.Node) -> bool:

 @update_features(exir_ops.edge.dim_order_ops._to_dim_order_copy.default)
 def register_to_copy_dim_order_op():
-    # Currently there is no "real" implementation for to_dim_order_copy, but it can be
-    # removed as long as the operator is not changing the dtype, i.e. the operator call
-    # is modifying the dim order only. Therefore, check that the input and output dtypes
-    # are the same, if so the operator is safe to remove.
-    def check_dim_order_copy_node(node: torch.fx.Node) -> bool:
-        in_arg = node.args[0]
-        if not isinstance(in_arg, torch.fx.Node):
-            return False
-
-        in_tensor = in_arg.meta.get("val", None)
-        out_tensor = node.meta.get("val", None)
-
-        if in_tensor.dtype != out_tensor.dtype:
-            return False
-
-        return True
-
     return OpFeatures(
-        inputs_storage=utils.ANY_STORAGE,
+        inputs_storage=utils.ANY_BUFFER,
         supports_resize=True,
-        are_node_inputs_supported_fn=check_dim_order_copy_node,
     )
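
In short: now that the backend has a real implementation of _to_dim_order_copy (including dtype conversion, via the new view_convert_buffer shader below), the dtype-equality gate on node inputs is no longer needed, and the op is registered for buffer storage only (ANY_STORAGE becomes ANY_BUFFER) to match the buffer-based shaders that implement it.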

backends/vulkan/runtime/graph/ops/glsl/view_buffer.glsl

Lines changed: 13 additions & 8 deletions
@@ -18,6 +18,8 @@ ${layout_declare_ubo(B, "BufferMetadata", "inp")}

 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

+${layout_declare_spec_const(C, "int", "all_contiguous", "0")}
+
 /*
  * The insight behind the view operation is that the contiguous index of each
  * tensor element in the input and output tensors are the same.
@@ -28,17 +30,20 @@ void main() {
     return;
   }

-  TensorIndex outp_tidx;
-  linear_idx_to_tensor_idx(outp, outp_bufi, outp_tidx);
+  uint inp_bufi = outp_bufi;
+  if (all_contiguous == 0) {
+    TensorIndex outp_tidx;
+    linear_idx_to_tensor_idx(outp, outp_bufi, outp_tidx);

-  // To map the output to the input, find the input element that has the same
-  // contiguous index as the output element.
-  const uint contig_idx = tensor_idx_to_contiguous_idx(outp, outp_tidx);
+    // To map the output to the input, find the input element that has the same
+    // contiguous index as the output element.
+    const uint contig_idx = tensor_idx_to_contiguous_idx(outp, outp_tidx);

-  TensorIndex inp_tidx;
-  contiguous_idx_to_tensor_idx(inp, contig_idx, inp_tidx);
+    TensorIndex inp_tidx;
+    contiguous_idx_to_tensor_idx(inp, contig_idx, inp_tidx);

-  const uint inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx);
+    inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx);
+  }

   t_outp[outp_bufi] = t_inp[inp_bufi];
 }
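
The "same contiguous index" insight from the shader comment, worked through in Python (a hand-rolled illustration; the helper names mimic, but are not, the GLSL functions):

def contiguous_strides(sizes):
    # Row-major strides of a densely packed tensor.
    strides = [1] * len(sizes)
    for i in reversed(range(len(sizes) - 1)):
        strides[i] = strides[i + 1] * sizes[i + 1]
    return strides

def tensor_idx_to_contiguous_idx(tidx, sizes):
    return sum(i * s for i, s in zip(tidx, contiguous_strides(sizes)))

def contiguous_idx_to_tensor_idx(cidx, sizes):
    tidx = []
    for s in contiguous_strides(sizes):
        tidx.append(cidx // s)
        cidx %= s
    return tuple(tidx)

# Viewing a (2, 6) tensor as (3, 4): output element (1, 2) has contiguous
# index 1*4 + 2 = 6, which lands on input element (1, 0).
cidx = tensor_idx_to_contiguous_idx((1, 2), (3, 4))
print(cidx, contiguous_idx_to_tensor_idx(cidx, (2, 6)))  # 6 (1, 0)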

backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.glsl
(new file; path inferred from the sibling view_buffer.glsl and the shader name)

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+#define IN_T ${buffer_scalar_type(IN_DTYPE)}
+#define OUT_T ${buffer_scalar_type(OUT_DTYPE)}
+
+${define_required_extensions(IN_DTYPE)}
+${define_required_extensions(OUT_DTYPE)}
+
+layout(std430) buffer;
+
+#include "indexing.glslh"
+
+${layout_declare_buffer(B, "w", "t_outp", OUT_DTYPE)}
+${layout_declare_buffer(B, "r", "t_inp", IN_DTYPE)}
+
+${layout_declare_ubo(B, "BufferMetadata", "outp")}
+${layout_declare_ubo(B, "BufferMetadata", "inp")}
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+${layout_declare_spec_const(C, "int", "all_contiguous", "0")}
+
+/*
+ * The insight behind the view_convert operation is that the contiguous index
+ * of each tensor element in the input and output tensors is the same, but the
+ * data types may be different and need conversion.
+ */
+void main() {
+  const uint outp_bufi = gl_GlobalInvocationID.x;
+  if (outp_bufi >= numel(outp)) {
+    return;
+  }
+
+  uint inp_bufi = outp_bufi;
+
+  if (all_contiguous == 0) {
+    TensorIndex outp_tidx;
+    linear_idx_to_tensor_idx(outp, outp_bufi, outp_tidx);
+
+    // To map the output to the input, find the input element that has the same
+    // contiguous index as the output element.
+    const uint contig_idx = tensor_idx_to_contiguous_idx(outp, outp_tidx);
+
+    TensorIndex inp_tidx;
+    contiguous_idx_to_tensor_idx(inp, contig_idx, inp_tidx);
+
+    inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx);
+  }
+
+  // Convert data type from input to output
+  t_outp[outp_bufi] = OUT_T(t_inp[inp_bufi]);
+}
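
The all_contiguous specialization constant short-circuits the index arithmetic above: for a densely packed tensor, an element's linear buffer index already equals its contiguous index, so the input-to-output mapping is the identity. A quick sanity check of that claim (illustrative Python, not from the commit):

# For a contiguous layout, round-tripping a linear index through
# tensor-index space returns the same index, which is why the shaders
# can copy straight through when all_contiguous != 0.
def contiguous_strides(sizes):
    strides = [1] * len(sizes)
    for i in reversed(range(len(sizes) - 1)):
        strides[i] = strides[i + 1] * sizes[i + 1]
    return strides

sizes = (2, 3, 4)
strides = contiguous_strides(sizes)
for linear in range(2 * 3 * 4):
    rem, tidx = linear, []
    for s in strides:
        tidx.append(rem // s)
        rem %= s
    assert sum(i * s for i, s in zip(tidx, strides)) == linear
print("identity mapping holds for contiguous layouts")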

backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.yaml
(new file; path inferred from the shader it parameterizes)

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+view_convert_buffer:
+  parameter_names_with_default_values:
+    IN_DTYPE: float
+    OUT_DTYPE: float
+    STORAGE: buffer
+  generate_variant_forall:
+    combination:
+      parameter_names: [IN_DTYPE, OUT_DTYPE]
+      combos:
+        - parameter_values: [int32, float]
+        - parameter_values: [int32, half]
+        - parameter_values: [uint8, float]
+        - parameter_values: [uint8, half]
+        - parameter_values: [uint8, int32]
+  shader_variants:
+    - NAME: view_convert_buffer
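
These combos pair with the kernel-name construction in View.cpp below, which appends a dtype suffix for the input and then the output. Assuming the usual "_<dtype>" suffix convention, the generated variants would be named as follows (an illustration, not codegen output):

combos = [
    ("int32", "float"),
    ("int32", "half"),
    ("uint8", "float"),
    ("uint8", "half"),
    ("uint8", "int32"),
]
for in_dtype, out_dtype in combos:
    # e.g. view_convert_buffer_int32_float
    print(f"view_convert_buffer_{in_dtype}_{out_dtype}")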

backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp

Lines changed: 12 additions & 0 deletions
@@ -67,6 +67,18 @@ void resize_unsqueeze_node(

   std::vector<int64_t> out_sizes = graph->sizes_of(in);

+  std::vector<int64_t> unsqueezed_dims;
+
+  if (graph->val_is_int_list(dims_ref)) {
+    const IntListPtr dims = graph->get_int_list(dims_ref);
+    for (int64_t d : *dims) {
+      unsqueezed_dims.push_back(d);
+    }
+  } else {
+    const int64_t dim = graph->extract_scalar<int64_t>(dims_ref);
+    unsqueezed_dims.push_back(dim);
+  }
+
   // Insert singleton dimensions at the specified positions
   for (auto dim : dims_vec) {
     int64_t d = dim;
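
The resize hook now tolerates dims_ref arriving as either an int list or a single scalar. The same normalization in Python terms (illustrative only):

def normalize_dims(dims_ref):
    # Accept a single dim or a list of dims, as resize_unsqueeze_node now does.
    if isinstance(dims_ref, (list, tuple)):
        return [int(d) for d in dims_ref]
    return [int(dims_ref)]

print(normalize_dims(2))       # [2]
print(normalize_dims([0, 3]))  # [0, 3]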

backends/vulkan/runtime/graph/ops/impl/View.cpp

Lines changed: 79 additions & 0 deletions
@@ -60,6 +60,16 @@ void resize_view_node(
   }
 }

+void resize_to_dim_order_copy_node(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& extra_args) {
+  const ValueRef out = args.at(0).refs.at(0);
+  const ValueRef in = args.at(1).refs.at(0);
+  const std::vector<int64_t> in_sizes = graph->sizes_of(in);
+  graph->virtual_resize(out, in_sizes);
+}
+
 void add_view_node(
     ComputeGraph& graph,
     ValueRef in,
@@ -98,6 +108,11 @@ void add_view_copy_buffer_node(
   std::string kernel_name = "view_buffer";
   add_dtype_suffix(kernel_name, graph.dtype_of(out));

+  bool all_contiguous = graph.is_contiguous_buffer_tensor(in) &&
+      graph.is_contiguous_buffer_tensor(out);
+
+  int32_t all_contiguous_int = all_contiguous ? 1 : 0;
+
   graph.execute_nodes().emplace_back(new DynamicDispatchNode(
       graph,
       VK_KERNEL_FROM_STR(kernel_name),
@@ -110,7 +125,41 @@ void add_view_copy_buffer_node(
       // Push Constants
       {},
       // Specialization Constants
+      {all_contiguous_int},
+      // Resize Args
+      resize_args,
+      // Resizing Logic
+      resize_fn));
+}
+
+void add_view_copy_convert_buffer_node(
+    ComputeGraph& graph,
+    ValueRef in,
+    ValueRef out,
+    const std::vector<ValueRef>& resize_args,
+    const ExecuteNode::ResizeFunction& resize_fn) {
+  std::string kernel_name = "view_convert_buffer";
+  add_dtype_suffix(kernel_name, graph.dtype_of(in));
+  add_dtype_suffix(kernel_name, graph.dtype_of(out));
+
+  bool all_contiguous = graph.is_contiguous_buffer_tensor(in) &&
+      graph.is_contiguous_buffer_tensor(out);
+
+  int32_t all_contiguous_int = all_contiguous ? 1 : 0;
+
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      default_pick_global_wg_size,
+      default_pick_local_wg_size,
+      // Inputs and Outputs
+      {{out, vkapi::kWrite}, {in, vkapi::kRead}},
+      // Parameter Buffers
+      {graph.buffer_meta_ubo(out), graph.buffer_meta_ubo(in)},
+      // Push Constants
       {},
+      // Specialization Constants
+      {all_contiguous_int},
       // Resize Args
       resize_args,
       // Resizing Logic
@@ -132,8 +181,38 @@ void view(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   return add_view_node(graph, in, sizes, out);
 }

+void to_dim_order_copy(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  int args_idx = 0;
+  const ValueRef in = args.at(args_idx++);
+  const ValueRef dtype = args.at(args_idx++);
+  (void)dtype;
+  const ValueRef layout = args.at(args_idx++);
+  (void)layout;
+  const ValueRef device = args.at(args_idx++);
+  (void)device;
+  const ValueRef pin_memory = args.at(args_idx++);
+  (void)pin_memory;
+  const ValueRef non_blocking = args.at(args_idx++);
+  (void)non_blocking;
+  const ValueRef dim_order = args.at(args_idx++);
+  (void)dim_order;
+
+  const ValueRef out = args.at(args_idx++);
+
+  VK_CHECK_COND(graph.is_buffer_storage(in) && graph.is_buffer_storage(out));
+
+  if (graph.dtype_of(in) == graph.dtype_of(out)) {
+    return add_view_copy_buffer_node(
+        graph, in, out, {}, resize_to_dim_order_copy_node);
+  }
+
+  return add_view_copy_convert_buffer_node(
+      graph, in, out, {}, resize_to_dim_order_copy_node);
+}
+
 REGISTER_OPERATORS {
   VK_REGISTER_OP(aten.view_copy.default, view);
+  VK_REGISTER_OP(dim_order_ops._to_dim_order_copy.default, to_dim_order_copy);
 }

 } // namespace vkcompute
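
At the graph level, _to_dim_order_copy typically originates from a Tensor.to() call that survives RemoveRedundantOpsTransform because it changes dtype. A minimal module that would exercise the new conversion path (a sketch; export and lowering details are elided, and the module name is made up):

import torch

class CastOnly(torch.nn.Module):
    # With matching dtypes the copy is removed by the pass; with differing
    # dtypes, to_dim_order_copy dispatches view_convert_buffer to do the cast.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.to(torch.float32)

print(CastOnly()(torch.arange(6, dtype=torch.int32)).dtype)  # torch.float32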

backends/vulkan/runtime/graph/ops/impl/View.h

Lines changed: 13 additions & 0 deletions
@@ -24,6 +24,19 @@ void add_view_copy_buffer_node(
     const std::vector<ValueRef>& resize_args,
     const ExecuteNode::ResizeFunction& resize_fn);

+/*
+ * Dispatches the view_convert_buffer compute shader. This can be used to
+ * implement ops that preserve the "contiguous" indexes of elements between the
+ * input and output while converting between different data types, such as
+ * view_copy with dtype conversion.
+ */
+void add_view_copy_convert_buffer_node(
+    ComputeGraph& graph,
+    ValueRef in,
+    ValueRef out,
+    const std::vector<ValueRef>& resize_args,
+    const ExecuteNode::ResizeFunction& resize_fn);
+
 void add_view_node(
     ComputeGraph& graph,
     ValueRef in,
