From 97b86a0570a0d47be00d7aa3afcf5c0599413de8 Mon Sep 17 00:00:00 2001
From: ssjia
Date: Thu, 6 Nov 2025 14:27:20 -0500
Subject: [PATCH] [ET-VK][ez] Fuse update_cache + custom_sdpa into sdpa_with_kv_cache

Pull Request resolved: https://github.com/pytorch/executorch/pull/15618

SDPA used to be handled by a single custom op, `sdpa_with_kv_cache`, but it
was eventually split (D62301837) into separate update_cache and custom_sdpa
ops. However, having a single fused op is useful for Vulkan, since it gives
the backend more control over how the cache tensors are stored and
represented. Essentially, it makes the cache tensors easier to manage and
opens up opportunities for future optimizations.

This diff introduces a fusion pass that does two things:

1. Combine update_cache and custom_sdpa back into a single sdpa_with_kv_cache
   op.
2. Ensure that all references to the cache_pos symint go through the same
   node, which avoids re-executing the select_at_dim_as_symint op for every
   use.

ghstack-source-id: 321258710
@exported-using-ghexport

Differential Revision: [D86340339](https://our.internmc.facebook.com/intern/diff/D86340339/)
---
 backends/vulkan/patterns/TARGETS     |   1 +
 backends/vulkan/patterns/__init__.py |   2 +
 backends/vulkan/patterns/sdpa.py     | 155 +++++++++++++++++++++++++++
 backends/vulkan/utils.py             |  12 +++
 4 files changed, 170 insertions(+)
 create mode 100644 backends/vulkan/patterns/sdpa.py

diff --git a/backends/vulkan/patterns/TARGETS b/backends/vulkan/patterns/TARGETS
index ddc9cd77c04..3baf7c9e251 100644
--- a/backends/vulkan/patterns/TARGETS
+++ b/backends/vulkan/patterns/TARGETS
@@ -12,6 +12,7 @@ runtime.python_library(
         "quantized_linear.py",
         "quantized_convolution.py",
         "quantized_binary.py",
+        "sdpa.py",
         "select_as_symint.py",
     ],
     visibility = [
diff --git a/backends/vulkan/patterns/__init__.py b/backends/vulkan/patterns/__init__.py
index 9239416dc2d..9b875def944 100644
--- a/backends/vulkan/patterns/__init__.py
+++ b/backends/vulkan/patterns/__init__.py
@@ -14,6 +14,8 @@
 
 import executorch.backends.vulkan.patterns.rope  # noqa
 
+import executorch.backends.vulkan.patterns.sdpa  # noqa
+
 import executorch.backends.vulkan.patterns.select_as_symint  # noqa
 
 import torch
diff --git a/backends/vulkan/patterns/sdpa.py b/backends/vulkan/patterns/sdpa.py
new file mode 100644
index 00000000000..f67799f9b76
--- /dev/null
+++ b/backends/vulkan/patterns/sdpa.py
@@ -0,0 +1,155 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, Optional
+
+import executorch.backends.vulkan.utils as utils
+
+import torch
+
+from executorch.backends.vulkan.patterns.pattern_registry import (
+    PatternMatch,
+    register_pattern_detector,
+    register_pattern_replacement,
+)
+
+from executorch.exir import ExportedProgram
+
+
+def is_update_cache_node(node: Any) -> bool:
+    return utils.node_has_target(node, "llama::update_cache")
+
+
+def is_custom_sdpa_node(node: Any) -> bool:
+    return utils.node_has_target(node, "llama::custom_sdpa")
+
+
+def is_sdpa_with_kv_cache_node(node: Any) -> bool:
+    return utils.node_has_target(node, "llama::sdpa_with_kv_cache")
+
+
+class CausalSDPAMatch(PatternMatch):
+    def __init__(self, custom_sdpa_node: torch.fx.Node) -> None:
+        self.anchor_node = custom_sdpa_node
+        self.match_found = False
+        self.all_nodes = [self.anchor_node]
+
+        # llama.custom_sdpa has signature:
+        # custom_sdpa(query, key_cache, value_cache, start_pos, attn_mask, dropout_p, is_causal, scale) -> output
+        if len(custom_sdpa_node.args) < 7:
+            return
+
+        self.query_node = custom_sdpa_node.args[0]
+        self.key_cache_node = custom_sdpa_node.args[1]
+        self.value_cache_node = custom_sdpa_node.args[2]
+        self.start_pos_node = custom_sdpa_node.args[3]
+        self.attn_mask_node = custom_sdpa_node.args[4]
+        self.dropout_p_node = custom_sdpa_node.args[5]
+        self.is_causal_node = custom_sdpa_node.args[6]
+        if len(custom_sdpa_node.args) > 7:
+            self.scale_node = custom_sdpa_node.args[7]
+        else:
+            self.scale_node = None
+
+        # Find the update_cache node that writes the key cache
+        self.update_key_cache_node = None
+        for user in self.key_cache_node.users:
+            if is_update_cache_node(user):
+                self.update_key_cache_node = user
+                break
+
+        self.key_projection_node = None
+        if self.update_key_cache_node is not None:
+            self.key_projection_node = self.update_key_cache_node.args[0]
+
+        # Find the update_cache node that writes the value cache
+        self.update_value_cache_node = None
+        for user in self.value_cache_node.users:
+            if is_update_cache_node(user):
+                self.update_value_cache_node = user
+                break
+
+        self.value_projection_node = None
+        if self.update_value_cache_node is not None:
+            self.value_projection_node = self.update_value_cache_node.args[0]
+
+        # There are additional optional arguments, but they don't need to be
+        # captured since the fused op doesn't use them
+
+        self.match_found = True
+
+
+@register_pattern_detector("causal_sdpa")
+def find_causal_sdpa_patterns(
+    node: torch.fx.Node,
+) -> Optional[CausalSDPAMatch]:
+    if not is_custom_sdpa_node(node):
+        return None
+
+    matched_pattern = CausalSDPAMatch(node)
+    if matched_pattern.match_found:
+        return matched_pattern
+
+    return None
+
+
+##
+## Pattern Replacement
+##
+
+
+def find_singleton_start_pos_node(graph_module: torch.fx.GraphModule):
+    for node in graph_module.graph.nodes:
+        if is_update_cache_node(node):
+            return node.args[2]
+
+        if is_sdpa_with_kv_cache_node(node):
+            return node.args[5]
+
+    raise RuntimeError(
+        "Could not find an instance of llama::update_cache or llama::sdpa_with_kv_cache"
+    )
+
+
+@register_pattern_replacement("causal_sdpa")
+def replace_custom_sdpa_with_causal_sdpa(
+    ep: ExportedProgram,
+    graph_module: torch.fx.GraphModule,
+    match: CausalSDPAMatch,
+):
+    assert match.update_key_cache_node is not None
+    assert match.key_projection_node is not None
+    assert match.update_value_cache_node is not None
+    assert match.value_projection_node is not None
+
+    singleton_start_pos_node = find_singleton_start_pos_node(graph_module)
+
+    with graph_module.graph.inserting_before(match.anchor_node):
+        new_node = graph_module.graph.create_node(
+            "call_function",
+            torch.ops.llama.sdpa_with_kv_cache.default,
+            args=(
+                match.query_node,
+                match.key_projection_node,
+                match.value_projection_node,
+                match.key_cache_node,
+                match.value_cache_node,
+                singleton_start_pos_node,
+                1,
+                match.attn_mask_node,
+                match.dropout_p_node,
+                match.is_causal_node,
+                match.scale_node,
+            ),
+        )
+
+    new_node.meta["val"] = match.anchor_node.meta["val"]
+    match.anchor_node.replace_all_uses_with(new_node)
+
+    # Manually erase the update_cache nodes; DCE will not remove them because
+    # they modify their inputs (specifically, the cache args are mutated)
+    graph_module.graph.erase_node(match.update_key_cache_node)
+    graph_module.graph.erase_node(match.update_value_cache_node)
diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py
index 00147dab2c3..9c527cbc36a 100644
--- a/backends/vulkan/utils.py
+++ b/backends/vulkan/utils.py
@@ -373,6 +373,18 @@ def find_quant_user(node: torch.fx.Node) -> Optional[torch.fx.Node]:
     return None
 
 
+def node_has_target(node: Any, target: str) -> bool:
+    if not hasattr(node, "target"):
+        return False
+
+    if isinstance(node.target, str):
+        return node.target == target
+    elif hasattr(node.target, "name"):
+        return node.target.name() == target
+
+    return False
+
+
 ##
 ## Memory Layout, Storage Type Determination
 ##
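
Reviewer note (illustrative sketch, not part of the patch): the replacement
function above relies on standard torch.fx rewrite mechanics, i.e. creating
the fused node via graph.inserting_before + create_node, copying meta["val"],
rerouting users with replace_all_uses_with, and manually erasing the mutating
update_cache nodes. The sketch below shows the same mechanics on a toy
add + relu fusion; ToyModel and fused_add_relu are hypothetical names invented
for this example and do not appear in this diff.

# Toy demonstration of the FX rewrite mechanics used by the pass above.
# fused_add_relu stands in for llama::sdpa_with_kv_cache; nothing here is
# code from this diff.
import torch
import torch.fx


def fused_add_relu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Stand-in "fused" op that computes relu(x + y) as a single graph node.
    return torch.relu(torch.add(x, y))


class ToyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.relu(torch.add(x, y))


gm = torch.fx.symbolic_trace(ToyModel())

for node in list(gm.graph.nodes):
    # Anchor on the relu node, mirroring how CausalSDPAMatch anchors on custom_sdpa.
    if node.op == "call_function" and node.target is torch.relu:
        add_node = node.args[0]
        if not (
            isinstance(add_node, torch.fx.Node)
            and add_node.op == "call_function"
            and add_node.target is torch.add
        ):
            continue

        # Create the fused node before the anchor, wired to the pattern's
        # original inputs (analogous to wiring sdpa_with_kv_cache to the
        # query/key/value projections and the caches).
        with gm.graph.inserting_before(node):
            fused = gm.graph.call_function(fused_add_relu, args=add_node.args)

        # Propagate metadata (the real pass copies meta["val"]) and reroute users.
        fused.meta = dict(node.meta)
        node.replace_all_uses_with(fused)

        # Manually erase the now-dead pattern nodes (the real pass must do this
        # for update_cache because those nodes mutate their inputs).
        gm.graph.erase_node(node)
        gm.graph.erase_node(add_node)

gm.graph.lint()
gm.recompile()

# Sanity check: the rewritten module matches the original computation.
a, b = torch.randn(2, 3), torch.randn(2, 3)
torch.testing.assert_close(gm(a, b), torch.relu(a + b))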