Commit fffeb4c

Author: ssjia
[ET-VK][ez] Ensure that attn_weight buffers do not exceed GPU buffer numel limit
Pull Request resolved: #15651

Title says it all! To give a concrete example, Llama3.2-1B-Instruct will have attn weights with size `{1, 32, max_seq_len, max_context_len}`. Usually `max_seq_len == max_context_len`, and if `max_context_len = 2048`, then the attention weight tensors will have sizes `{1, 32, 2048, 2048}`, which contain 134217728 elements. The `maxStorageBufferRange` for Adreno 750 is also 134217728 (2^27), so using a context length of 2048 will produce incorrect results on Adreno 750.

In practice, it is unlikely that the prompt sequence length will be equal to the context length, so the solution is to adjust the `max_seq_len` dim of the attention weight tensors downward to ensure that the GPU buffer numel limit is not hit.

ghstack-source-id: 321555042
@exported-using-ghexport

Differential Revision: [D86443407](https://our.internmc.facebook.com/intern/diff/D86443407/)
Parent: ef3e85a
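
For intuition, here is a minimal standalone sketch of the arithmetic described in the commit message. The constants mirror the Llama3.2-1B-Instruct / Adreno 750 example and are hard-coded assumptions, not values queried from a device.

```cpp
#include <cstdint>
#include <iostream>

int main() {
  // Assumed example values from the commit message: attn weights are
  // {1, 32, max_seq_len, max_context_len} for Llama3.2-1B-Instruct.
  const int64_t num_q_heads = 32;
  const int64_t max_context_len = 2048;
  int64_t max_seq_len = 2048;
  // Buffer numel limit cited for Adreno 750: 2^27 = 134217728.
  const int64_t max_buffer_numel = int64_t{1} << 27;

  // Same adjustment as the SDPA.cpp change below: shrink max_seq_len so the
  // attention weight tensor stays below the buffer numel limit.
  if (num_q_heads * max_seq_len * max_context_len >= max_buffer_numel) {
    max_seq_len = max_buffer_numel / (num_q_heads * max_context_len);
    if (max_seq_len % 4 != 0) {
      max_seq_len = (max_seq_len / 4) * 4;
    } else {
      max_seq_len -= 4;
    }
  }

  // Prints 2044; 32 * 2044 * 2048 = 133955584 < 134217728.
  std::cout << "adjusted max_seq_len = " << max_seq_len << std::endl;
  return 0;
}
```

When the division lands exactly on the limit (as in this example), the `else` branch subtracts 4 so that the resulting numel ends up strictly below it.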

File tree

2 files changed: +32, -7 lines


backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 4 additions & 0 deletions
```diff
@@ -639,6 +639,10 @@ class ComputeGraph final {
 
   bool device_name_contains(const char* substr);
 
+  int64_t max_buffer_numel() {
+    return static_cast<int64_t>(context_->adapter_ptr()->max_buffer_numel());
+  }
+
   //
   // Graph Building
   //
```
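
The new accessor lets operator implementations check candidate tensor sizes against the device limit before committing to buffer storage. As a rough illustration (not code from this commit), a caller could test sizes like so; the helper name and the size vectors are hypothetical:

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical helper (not part of this diff): returns true when a tensor
// with the given sizes would reach a buffer numel limit such as the value
// returned by ComputeGraph::max_buffer_numel().
bool exceeds_buffer_numel_limit(
    const std::vector<int64_t>& sizes,
    int64_t max_buffer_numel) {
  const int64_t numel = std::accumulate(
      sizes.begin(), sizes.end(), int64_t{1}, std::multiplies<int64_t>());
  return numel >= max_buffer_numel;
}

int main() {
  // {1, 32, 2048, 2048} has exactly 2^27 elements, the Adreno 750 limit
  // cited in the commit message, so it is flagged.
  assert(exceeds_buffer_numel_limit({1, 32, 2048, 2048}, int64_t{1} << 27));
  // {1, 32, 2044, 2048}, the adjusted shape, stays under the limit.
  assert(!exceeds_buffer_numel_limit({1, 32, 2044, 2048}, int64_t{1} << 27));
  return 0;
}
```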

backends/vulkan/runtime/graph/ops/impl/SDPA.cpp

Lines changed: 28 additions & 7 deletions
```diff
@@ -471,10 +471,31 @@ void sdpa_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   VK_CHECK_COND(graph.val_is_none(attn_mask));
 
   const int64_t num_q_heads = graph.size_at<int64_t>(-2, q_projected);
-  const int64_t max_seq_len = graph.size_at<int64_t>(-3, q_projected);
-
+  int64_t max_seq_len = graph.size_at<int64_t>(-3, q_projected);
   const int64_t max_context_len = graph.size_at<int32_t>(-3, k_cache);
 
+  const utils::StorageType attn_weights_storage =
+      graph.storage_type_of(q_projected);
+
+  // If using buffer storage for attn weights, we need to ensure that the buffer
+  // numel limit is not exceeded. If needed, manually adjust max_seq_len based
+  // on the buffer numel limit.
+  if (attn_weights_storage == utils::kBuffer) {
+    const int64_t max_buffer_numel = graph.max_buffer_numel();
+    if (num_q_heads * max_seq_len * max_context_len >= max_buffer_numel) {
+      // Compute the maximum possible value for max_seq_len that will hit
+      // the buffer numel limit.
+      max_seq_len = max_buffer_numel / (num_q_heads * max_context_len);
+      // Adjust down to the nearest multiple of 4 to make sure the limit is
+      // not hit.
+      if (max_seq_len % 4 != 0) {
+        max_seq_len = (max_seq_len / 4) * 4;
+      } else {
+        max_seq_len -= 4;
+      }
+    }
+  }
+
   std::vector<int64_t> attn_weight_full_sizes = {
       1, // batch
       num_q_heads,
@@ -485,14 +506,14 @@ void sdpa_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
       &graph,
       attn_weight_full_sizes,
       graph.dtype_of(q_projected),
-      graph.storage_type_of(q_projected),
+      attn_weights_storage,
       utils::kWidthPacked);
 
   TmpTensor attn_weights_softmax(
       &graph,
       attn_weight_full_sizes,
       graph.dtype_of(q_projected),
-      graph.storage_type_of(q_projected),
+      attn_weights_storage,
       utils::kWidthPacked);
 
   add_sdpa_compute_attn_weights_node(
@@ -528,9 +549,9 @@ void sdpa_with_kv_cache_impl(
 
   utils::StorageType cache_storage = graph.storage_type_of(q_projected);
   const ValueRef k_cache =
-      prepack_standard(graph, k_cache_data, cache_storage, utils::kWidthPacked);
+      graph.add_tensor_like(k_cache_data, cache_storage, utils::kWidthPacked);
   const ValueRef v_cache =
-      prepack_standard(graph, v_cache_data, cache_storage, utils::kWidthPacked);
+      graph.add_tensor_like(v_cache_data, cache_storage, utils::kWidthPacked);
 
   update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1});
   update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1});
@@ -573,7 +594,7 @@ void compute_attn_weight_with_kv_cache_impl(
 
   (void)sequence_len;
 
-  utils::StorageType cache_storage = graph.storage_type_of(q_projected);
+  const utils::StorageType cache_storage = graph.storage_type_of(q_projected);
   const ValueRef k_cache =
       graph.add_tensor_like(k_cache_data, cache_storage, utils::kWidthPacked);
   const ValueRef v_cache =
```
