From ae61ab48e405a957a6b8b164c1a629d403935271 Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Wed, 5 Nov 2025 10:45:11 -0800
Subject: [PATCH] Reduce allocation overhead in quantized sdpa

For small models, dequantizing portions of the v cache causes extra
allocation overhead. A better way to handle this is probably to
dequantize the entire v cache outside the model.

There isn't a significant perf advantage from this yet, but subsequent
diffs will use a caching allocator, where this refactor helps.

Differential Revision: [D85532077](https://our.internmc.facebook.com/intern/diff/D85532077/)

[ghstack-poisoned]
---
 extension/llm/custom_ops/op_sdpa_impl.h | 27 ++++++++++++++++++-------
 1 file changed, 20 insertions(+), 7 deletions(-)

diff --git a/extension/llm/custom_ops/op_sdpa_impl.h b/extension/llm/custom_ops/op_sdpa_impl.h
index e0a81c4650c..07ce16dd048 100644
--- a/extension/llm/custom_ops/op_sdpa_impl.h
+++ b/extension/llm/custom_ops/op_sdpa_impl.h
@@ -213,13 +213,13 @@ void dequant_and_gemm(
     const int64_t v_stride_n,
     float* o_data,
     const int64_t o_stride_m,
-    const float beta) {
-  std::vector<float> dequantized_v_data(v_data.m * v_data.n);
+    const float beta,
+    float* buf_qdq_ptr) {
   dequantize_per_channel_optimized(
       static_cast<const int8_t*>(v_data.data),
       static_cast<const float*>(v_data.scales),
       static_cast<const int8_t*>(v_data.zero_points),
-      dequantized_v_data.data(),
+      buf_qdq_ptr,
       -128,
       127,
       1,
@@ -237,7 +237,7 @@ void dequant_and_gemm(
       m,
       k,
       static_cast<float>(1),
-      dequantized_v_data.data(),
+      buf_qdq_ptr,
       v_data.n,
       qk_data,
       qk_stride_m,
@@ -257,7 +257,8 @@ void _qk_at_v_gemm(
     const int64_t v_stride_n,
     accum_t* o_data,
     const int64_t o_stride_m,
-    const accum_t beta) {
+    const accum_t beta,
+    accum_t* buf_qdq_ptr) {
   if (v_data.dtype == ScalarType::Char) {
     if constexpr (std::is_same<accum_t, float>::value) {
       if (m > 4) {
@@ -273,7 +274,8 @@ void _qk_at_v_gemm(
             v_stride_n,
             o_data,
             o_stride_m,
-            beta);
+            beta,
+            buf_qdq_ptr);
       } else {
         // For smaller batch sizes, use quantized gemm
         int a_stride_m_tmp, b_stride_n_tmp;
@@ -773,6 +775,15 @@ void cpu_flash_attention(
   // at::Tensor buf_reduced = at::empty(
   //     {num_thread, qSplitSize, is_reduced_type ? kvSplitSize : 0},
   //     query.options());
+  int64_t size_per_thread_qdq_vec = qSplitSize * kvSplitSize * headSize;
+  // Let's align size_per_thread_qdq_vec to 64 bytes, for coalesced cache reads,
+  // by padding with the right number of per-thread elements.
+  constexpr int64_t kAlignment = 32;
+  size_per_thread_qdq_vec = (size_per_thread_qdq_vec + kAlignment - 1) & ~(kAlignment - 1);
+  int64_t size_per_thread_qdq_bytes = size_per_thread_qdq_vec * query.element_size();
+  int64_t size_qdq_bytes = size_per_thread_qdq_bytes * num_thread;
+  std::vector<char> scratch_for_quant_dequant_vec(size_qdq_bytes);
+  accum_t* scratch_for_quant_dequant = reinterpret_cast<accum_t*>(scratch_for_quant_dequant_vec.data());

   // Data ptrs
   const scalar_t* q_data = query.const_data_ptr<scalar_t>();
@@ -797,6 +808,7 @@ void cpu_flash_attention(
     scalar_t* qk_reduced_data = is_reduced_type
         ? buf_reduced_data + ompIdx * qSplitSize * kvSplitSize
         : nullptr;
+    accum_t* buf_qdq_ptr = scratch_for_quant_dequant + ompIdx * size_per_thread_qdq_vec;

     for (int64_t z = begin; z < end; z++) {
       int64_t m = k * qSplitSize;
@@ -1053,7 +1065,8 @@ void cpu_flash_attention(
               vStrideN,
               dst_data,
               headSize,
-              n == 0 ? static_cast<accum_t>(0) : static_cast<accum_t>(1));
+              n == 0 ? static_cast<accum_t>(0) : static_cast<accum_t>(1),
+              buf_qdq_ptr);
         }
         // dst <- dst / sum[row]
         // reorder MHA output with strides
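The core of the change, for readers skimming the diff: the per-call
std::vector allocation inside dequant_and_gemm is replaced by one scratch
buffer that cpu_flash_attention allocates up front and hands out in
per-thread slices (indexed by ompIdx). Below is a minimal standalone sketch
of that allocate-once, slice-per-thread pattern, using the same alignment
round-up but otherwise hypothetical names and sizes; it is an illustration,
not ExecuTorch code.

```cpp
// Sketch only: allocate one scratch buffer and give each worker thread its
// own aligned slice, instead of allocating a fresh std::vector per call.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Stand-ins for qSplitSize * kvSplitSize * headSize and the thread count.
  const int64_t per_block_elems = 32 * 512 * 64;
  const int64_t num_threads = 8;

  // Round the per-thread element count up to a multiple of kAlignment so
  // every thread's slice starts on an aligned boundary.
  constexpr int64_t kAlignment = 32;
  const int64_t per_thread_elems =
      (per_block_elems + kAlignment - 1) & ~(kAlignment - 1);

  // One allocation for all threads, made once outside the hot loop.
  std::vector<float> scratch(per_thread_elems * num_threads);

  for (int64_t tid = 0; tid < num_threads; ++tid) {
    // Each worker dequantizes into its own disjoint slice of the buffer.
    float* buf_qdq_ptr = scratch.data() + tid * per_thread_elems;
    buf_qdq_ptr[0] = static_cast<float>(tid); // stand-in for dequant + gemm
  }

  std::printf("per-thread slice: %lld floats\n",
              static_cast<long long>(per_thread_elems));
  return 0;
}
```

This is also why the commit message mentions a caching allocator: once the
scratch space is a single block requested once per attention call, a later
diff can serve that request from a reusable pool instead of the heap.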