Commit 20b2009

Reduce allocation overhead in quantized sdpa
Pull Request resolved: #15610

For small models, dequantizing portions of the v cache causes extra allocation overhead. A better way to handle this would probably be to dequantize the entire v cache outside the model. There isn't a significant perf advantage from this yet, but subsequent diffs will use a caching allocator, where this refactor helps.

ghstack-source-id: 321455128
@exported-using-ghexport

Differential Revision: [D85532077](https://our.internmc.facebook.com/intern/diff/D85532077/)
1 parent e3b4dba commit 20b2009

File tree

1 file changed: +20 −7 lines

extension/llm/custom_ops/op_sdpa_impl.h

Lines changed: 20 additions & 7 deletions
@@ -213,13 +213,13 @@ void dequant_and_gemm(
     const int64_t v_stride_n,
     float* o_data,
     const int64_t o_stride_m,
-    const float beta) {
-  std::vector<float> dequantized_v_data(v_data.m * v_data.n);
+    const float beta,
+    float* buf_qdq_ptr) {
   dequantize_per_channel_optimized(
       static_cast<const int8_t*>(v_data.data),
       static_cast<const float*>(v_data.scales),
       static_cast<const int8_t*>(v_data.zero_points),
-      dequantized_v_data.data(),
+      buf_qdq_ptr,
       -128,
       127,
       1,
@@ -237,7 +237,7 @@ void dequant_and_gemm(
       m,
       k,
       static_cast<float>(1),
-      dequantized_v_data.data(),
+      buf_qdq_ptr,
       v_data.n,
       qk_data,
       qk_stride_m,
@@ -257,7 +257,8 @@ void _qk_at_v_gemm(
     const int64_t v_stride_n,
     accum_t* o_data,
     const int64_t o_stride_m,
-    const accum_t beta) {
+    const accum_t beta,
+    accum_t* buf_qdq_ptr) {
   if (v_data.dtype == ScalarType::Char) {
     if constexpr (std::is_same<accum_t, float>::value) {
       if (m > 4) {
@@ -273,7 +274,8 @@ void _qk_at_v_gemm(
             v_stride_n,
             o_data,
             o_stride_m,
-            beta);
+            beta,
+            buf_qdq_ptr);
       } else {
         // For smaller batch sizes, use quantized gemm
         int a_stride_m_tmp, b_stride_n_tmp;
@@ -773,6 +775,15 @@ void cpu_flash_attention(
   // at::Tensor buf_reduced = at::empty(
   //     {num_thread, qSplitSize, is_reduced_type ? kvSplitSize : 0},
   //     query.options());
+  int64_t size_per_thread_qdq_vec = qSplitSize * kvSplitSize * headSize;
+  // Lets align size_per_thread_qdq_vec to 64 bytes, for coalesced cache reads,
+  // by padding with right number of per thread elements
+  constexpr int64_t kAlignment = 32;
+  size_per_thread_qdq_vec = (size_per_thread_qdq_vec + kAlignment - 1) & (-(kAlignment - 1));
+  int64_t size_per_thread_qdq_bytes = size_per_thread_qdq_vec * query.element_size();
+  int64_t size_qdq_bytes = size_per_thread_qdq_bytes * num_thread;
+  std::vector<char> scratch_for_quant_dequant_vec(size_qdq_bytes);
+  accum_t* scratch_for_quant_dequant = reinterpret_cast<accum_t*>(scratch_for_quant_dequant_vec.data());

   // Data ptrs
   const scalar_t* q_data = query.const_data_ptr<scalar_t>();
@@ -797,6 +808,7 @@ void cpu_flash_attention(
     scalar_t* qk_reduced_data = is_reduced_type
         ? buf_reduced_data + ompIdx * qSplitSize * kvSplitSize
         : nullptr;
+    accum_t* buf_qdq_ptr = scratch_for_quant_dequant + ompIdx * size_per_thread_qdq_vec;

     for (int64_t z = begin; z < end; z++) {
       int64_t m = k * qSplitSize;
@@ -1053,7 +1065,8 @@ void cpu_flash_attention(
             vStrideN,
             dst_data,
             headSize,
-            n == 0 ? static_cast<accum_t>(0) : static_cast<accum_t>(1));
+            n == 0 ? static_cast<accum_t>(0) : static_cast<accum_t>(1),
+            buf_qdq_ptr);
       }
       // dst <- dst / sum[row]
       // reorder MHA output with strides
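For context, below is a minimal standalone sketch (not the ExecuTorch implementation) of the allocation pattern this diff moves to: scratch space for dequantized values is allocated once for all threads before the parallel region, each thread takes its own slice by thread index, and the per-call std::vector<float> inside dequant_and_gemm goes away. The sizes, the qSplitSize/kvSplitSize/headSize/num_thread names, and the round-up-to-multiple idiom (x + A - 1) & ~(A - 1) are illustrative assumptions that loosely mirror the diff, not verbatim library code.

#include <cstdint>
#include <vector>

int main() {
  // Illustrative tile sizes; the real values come from cpu_flash_attention.
  const int64_t qSplitSize = 32, kvSplitSize = 512, headSize = 64;
  const int64_t num_thread = 8;

  // Elements of scratch each thread needs for its dequantized tile.
  int64_t per_thread_elems = qSplitSize * kvSplitSize * headSize;

  // Round up to a multiple of kAlignment elements with the usual
  // (x + A - 1) & ~(A - 1) idiom so every thread's slice starts padded
  // to the same boundary.
  constexpr int64_t kAlignment = 32;
  per_thread_elems = (per_thread_elems + kAlignment - 1) & ~(kAlignment - 1);

  // One allocation for all threads, made once outside the hot loop.
  std::vector<char> scratch(per_thread_elems * num_thread * sizeof(float));
  float* base = reinterpret_cast<float*>(scratch.data());

  // Inside the parallel loop, each thread carves out its own slice and
  // passes it down (as buf_qdq_ptr) instead of allocating per call.
  for (int64_t ompIdx = 0; ompIdx < num_thread; ++ompIdx) {
    float* buf_qdq_ptr = base + ompIdx * per_thread_elems;
    (void)buf_qdq_ptr;  // the dequant + gemm call would write its tile here
  }
  return 0;
}

Because the scratch is a single upfront, fixed-size allocation rather than many short-lived vectors, it is also the shape of buffer that a caching allocator can reuse across calls, which is the follow-up the commit message refers to.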
