Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 0 additions & 14 deletions backends/vulkan/_passes/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -63,19 +63,6 @@ runtime.python_library(
],
)

runtime.python_library(
name = "remove_local_scalar_dense",
srcs = ["remove_local_scalar_dense_ops.py"],
visibility = [
"//executorch/backends/...",
],
deps = [
"//caffe2:torch",
"//executorch/exir:pass_base",
"//executorch/exir/dialects:lib",
],
)

runtime.python_library(
name = "remove_redundant_ops",
srcs = ["remove_redundant_ops.py"],
Expand Down Expand Up @@ -161,7 +148,6 @@ runtime.python_library(
":fuse_quantized_ops",
":insert_prepack_nodes",
":remove_asserts",
":remove_local_scalar_dense",
":remove_redundant_ops",
":replace_qdq",
":squeeze_unsqueeze_inputs",
Expand Down
4 changes: 0 additions & 4 deletions backends/vulkan/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@
remove_asserts,
RemoveAssertsTransform,
)
from executorch.backends.vulkan._passes.remove_local_scalar_dense_ops import (
RemoveLocalScalarDenseOpsTransform,
)
from executorch.backends.vulkan._passes.remove_redundant_ops import (
RemoveRedundantOpsTransform,
)
Expand All @@ -35,7 +32,6 @@
"insert_prepack_nodes",
"remove_asserts",
"RemoveAssertsTransform",
"RemoveLocalScalarDenseOpsTransform",
"RemoveRedundantOpsTransform",
"ReplaceQDQPass",
"SqueezeUnsqueezeInputs",
Expand Down
110 changes: 0 additions & 110 deletions backends/vulkan/_passes/remove_local_scalar_dense_ops.py

This file was deleted.

17 changes: 17 additions & 0 deletions backends/vulkan/custom_ops_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import executorch.backends.vulkan.patterns as vk_patterns
import torch.library

from torch._subclasses.fake_tensor import FakeTensor

namespace = "et_vk"
lib = torch.library.Library(namespace, "DEF")

Expand Down Expand Up @@ -614,3 +616,18 @@ def add_q8ta_q8ta_q8to_impl(
)
lib.impl(name, add_q8ta_q8ta_q8to_impl, "CompositeExplicitAutograd")
add_q8ta_q8ta_q8to_op = getattr(getattr(torch.ops, namespace), name)

#############################
## select_as_symint ##
#############################


def select_as_symint_impl(x: torch.Tensor, dim: int, index: int):
assert isinstance(x, FakeTensor)
return x.fake_mode.shape_env.create_unbacked_symint()


name = "select_as_symint"
lib.define(f"{name}(Tensor x, int dim, int index) -> SymInt")
lib.impl(name, select_as_symint_impl, "Meta")
select_as_symint_op = getattr(getattr(torch.ops, namespace), name)
40 changes: 0 additions & 40 deletions backends/vulkan/partitioner/vulkan_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,36 +184,6 @@ def is_linear_permute(self, node: torch.fx.Node) -> Tuple[bool, bool]:

return False, False

def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> Tuple[bool, bool]:
"""
Scalar tensors are usually converted to scalar values in the graph via`
scalar_tensor[0].item()` in Python, which translates to a chain of
`local_scalar_dense(torch.select.int(scalar_tensor, 0, 0))` in the graph.
This function marks the entire chain as supported by the Vulkan delegate.

Later, within vulkan_preprocess there will be a graph transform which replaces
the chain with passing in the scalar tensor directly.

Similar to the `is_linear_permute` function, this function has 2 return values.
"""
if node.target == exir_ops.edge.aten.select_copy.int:
if len(node.users) != 1:
return False, False
# pyre-ignore
if node.args[0].meta["val"].numel() != 1:
return False, False

local_scalar_dense = list(node.users.keys())[0]
if local_scalar_dense.target != torch.ops.aten._local_scalar_dense.default:
return False, False

return self.is_in_local_scalar_dense_chain(local_scalar_dense)

if node.target == torch.ops.aten._local_scalar_dense.default:
return True, all(self.node_is_compatible(user)[0] for user in node.users)

return False, False

def log_skip(self, node: torch.fx.Node, reason: str) -> None:
if node.op == "call_function":
logger.info(
Expand Down Expand Up @@ -261,16 +231,6 @@ def _is_node_supported(self, node: torch.fx.Node) -> bool: # noqa: C901
self.log_skip(node, "permute node of non compatible linear node")
return False

(
is_in_local_scalar_dense_chain,
dst_node_is_compatible,
) = self.is_in_local_scalar_dense_chain(node)
if is_in_local_scalar_dense_chain and dst_node_is_compatible:
return True
elif is_in_local_scalar_dense_chain:
self.log_skip(node, "local scalar dense of incompatible op node")
return False

features = None
if target not in vulkan_supported_ops:
# For some ops, i.e. custom ops the name is registered instead of the
Expand Down
1 change: 1 addition & 0 deletions backends/vulkan/patterns/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ runtime.python_library(
"quantized_linear.py",
"quantized_convolution.py",
"quantized_binary.py",
"select_as_symint.py",
],
visibility = [
"//executorch/backends/...",
Expand Down
2 changes: 2 additions & 0 deletions backends/vulkan/patterns/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import executorch.backends.vulkan.patterns.rope # noqa

import executorch.backends.vulkan.patterns.select_as_symint # noqa

import torch

from executorch.backends.vulkan.patterns.pattern_registry import (
Expand Down
104 changes: 104 additions & 0 deletions backends/vulkan/patterns/select_as_symint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch

from executorch.backends.vulkan.patterns.pattern_registry import (
PatternMatch,
register_pattern_detector,
register_pattern_replacement,
)

from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops


class SelectAsSymIntMatch(PatternMatch):
def __init__(self, local_scalar_dense_node: torch.fx.Node) -> None:
self.anchor_node = local_scalar_dense_node
self.match_found = False

# Check if the input to local_scalar_dense is a select_copy node
if len(local_scalar_dense_node.args) < 1:
return

select_node = local_scalar_dense_node.args[0]
if not isinstance(select_node, torch.fx.Node):
return

if (
select_node.op != "call_function"
or select_node.target != exir_ops.edge.aten.select_copy.int
):
return

# select_copy.int has signature: select_copy(Tensor self, int dim, int index)
if len(select_node.args) < 3:
return

self.select_node = select_node

self.tensor_node = select_node.args[0]
self.dim_node = select_node.args[1]
self.index_node = select_node.args[2]

self.all_nodes = [
self.anchor_node,
self.select_node,
self.tensor_node,
self.dim_node,
self.index_node,
]

self.match_found = True


@register_pattern_detector("select_as_symint")
def find_select_as_symint_patterns(
node: torch.fx.Node,
) -> Optional[SelectAsSymIntMatch]:
if node.target != torch.ops.aten._local_scalar_dense.default:
return None

matched_pattern = SelectAsSymIntMatch(node)
if matched_pattern.match_found:
return matched_pattern

return None


##
## Pattern Replacement
##


@register_pattern_replacement("select_as_symint")
def replace_select_local_scalar_dense_with_select_as_symint(
ep: ExportedProgram,
graph_module: torch.fx.GraphModule,
match: SelectAsSymIntMatch,
):
with graph_module.graph.inserting_before(match.anchor_node):
new_node = graph_module.graph.create_node(
"call_function",
exir_ops.edge.et_vk.select_as_symint.default,
args=(
match.tensor_node,
match.dim_node,
match.index_node,
),
)

new_node.meta["val"] = match.anchor_node.meta["val"]
match.anchor_node.replace_all_uses_with(new_node)

# # Remove both the local_scalar_dense and select_copy nodes
# graph_module.graph.erase_node(match.anchor_node)
# # Only erase select_node if it has no other users
# if len(match.select_node.users) == 0:
# graph_module.graph.erase_node(match.select_node)
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/VulkanBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
}
}

if (should_propagate_resize) {
if (should_propagate_resize || compute_graph->has_data_dependent_shapes()) {
compute_graph->propagate_resize();
}

Expand Down
14 changes: 14 additions & 0 deletions backends/vulkan/runtime/api/containers/StagingBuffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,20 @@ class StagingBuffer final {
inline void set_staging_zeros() {
memset(data(), 0, nbytes());
}

template <typename T>
T select_element_at_dim(
const std::vector<int64_t>& sizes,
const int64_t dim,
const int64_t index) {
int64_t stride = 1;
for (size_t i = dim + 1; i < sizes.size(); ++i) {
stride *= sizes[i];
}
const int64_t offset = index * stride;
const T* typed_data = reinterpret_cast<const T*>(data());
return typed_data[offset];
}
};

} // namespace api
Expand Down
Loading
Loading