@@ -213,13 +213,13 @@ void dequant_and_gemm(
     const int64_t v_stride_n,
     float* o_data,
     const int64_t o_stride_m,
-    const float beta) {
-  std::vector<float> dequantized_v_data(v_data.m * v_data.n);
+    const float beta,
+    float* buf_qdq_ptr) {
   dequantize_per_channel_optimized(
       static_cast<const int8_t*>(v_data.data),
       static_cast<const float*>(v_data.scales),
       static_cast<const int8_t*>(v_data.zero_points),
-      dequantized_v_data.data(),
+      buf_qdq_ptr,
       -128,
       127,
       1,
@@ -237,7 +237,7 @@ void dequant_and_gemm(
       m,
       k,
       static_cast<float>(1),
-      dequantized_v_data.data(),
+      buf_qdq_ptr,
       v_data.n,
       qk_data,
       qk_stride_m,
@@ -257,7 +257,8 @@ void _qk_at_v_gemm(
     const int64_t v_stride_n,
     accum_t* o_data,
     const int64_t o_stride_m,
-    const accum_t beta) {
+    const accum_t beta,
+    accum_t* buf_qdq_ptr) {
   if (v_data.dtype == ScalarType::Char) {
     if constexpr (std::is_same<accum_t, float>::value) {
       if (m > 4) {
@@ -273,7 +274,8 @@ void _qk_at_v_gemm(
             v_stride_n,
             o_data,
             o_stride_m,
-            beta);
+            beta,
+            buf_qdq_ptr);
       } else {
         // For smaller batch sizes, use quantized gemm
         int a_stride_m_tmp, b_stride_n_tmp;
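The `m > 4` threshold picks between the two paths visible here: larger tiles amortize a one-shot dequantization of V (into the scratch buffer) followed by a float GEMM, while small tiles run directly on int8 data. A toy contrast of the two paths, with a scalar dot product standing in for the GEMM and symmetric quantization assumed (the real path also handles zero points):

```cpp
#include <cstdint>
#include <cstdio>

// Float path: dequantize v first, then accumulate in float, as
// dequant_and_gemm does via the scratch buffer.
static float dot_dequant_first(const float* a, const int8_t* v, float scale, int n) {
  float s = 0.f;
  for (int i = 0; i < n; ++i) s += a[i] * (scale * static_cast<float>(v[i]));
  return s;
}

// Quantized path: keep v in int8 and fold the scale in once at the end;
// this wins when the tile is too small to amortize a full dequant pass.
static float dot_quantized(const float* a, const int8_t* v, float scale, int n) {
  float s = 0.f;
  for (int i = 0; i < n; ++i) s += a[i] * static_cast<float>(v[i]);
  return scale * s;
}

int main() {
  const float a[3] = {1.f, 2.f, 3.f};
  const int8_t v[3] = {4, -5, 6};
  std::printf("%f %f\n",
              dot_dequant_first(a, v, 0.25f, 3),
              dot_quantized(a, v, 0.25f, 3)); // identical results: 3.0
  return 0;
}
```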
@@ -773,6 +775,15 @@ void cpu_flash_attention(
   // at::Tensor buf_reduced = at::empty(
   //     {num_thread, qSplitSize, is_reduced_type ? kvSplitSize : 0},
   //     query.options());
+  int64_t size_per_thread_qdq_vec = qSplitSize * kvSplitSize * headSize;
+  // Align size_per_thread_qdq_vec to a multiple of kAlignment elements, for
+  // cache-friendly reads, by padding with the right number of per-thread elements.
+  constexpr int64_t kAlignment = 32;
+  size_per_thread_qdq_vec = (size_per_thread_qdq_vec + kAlignment - 1) & ~(kAlignment - 1);
+  int64_t size_per_thread_qdq_bytes = size_per_thread_qdq_vec * query.element_size();
+  int64_t size_qdq_bytes = size_per_thread_qdq_bytes * num_thread;
+  std::vector<char> scratch_for_quant_dequant_vec(size_qdq_bytes);
+  accum_t* scratch_for_quant_dequant = reinterpret_cast<accum_t*>(scratch_for_quant_dequant_vec.data());

   // Data ptrs
   const scalar_t* q_data = query.const_data_ptr<scalar_t>();
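The padding uses the standard round-up-to-a-power-of-two trick, `(x + a - 1) & ~(a - 1)`, which works because `a - 1` is a mask of the low bits when `a` is a power of two. A quick self-contained check (`round_up` is illustrative, not from the patch):

```cpp
#include <cassert>
#include <cstdint>

// Round x up to the next multiple of a power-of-two alignment a.
static int64_t round_up(int64_t x, int64_t a) {
  assert((a & (a - 1)) == 0); // a must be a power of two
  return (x + a - 1) & ~(a - 1);
}

int main() {
  constexpr int64_t kAlignment = 32; // elements, as in the patch
  assert(round_up(0, kAlignment) == 0);
  assert(round_up(1, kAlignment) == 32);
  assert(round_up(32, kAlignment) == 32);
  assert(round_up(33, kAlignment) == 64);
  // 32 floats is 128 bytes, so each thread's slice starts on a cache-line
  // boundary provided the base pointer itself is suitably aligned.
  return 0;
}
```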
@@ -797,6 +808,7 @@ void cpu_flash_attention(
     scalar_t* qk_reduced_data = is_reduced_type
         ? buf_reduced_data + ompIdx * qSplitSize * kvSplitSize
         : nullptr;
+    accum_t* buf_qdq_ptr = scratch_for_quant_dequant + ompIdx * size_per_thread_qdq_vec;

     for (int64_t z = begin; z < end; z++) {
       int64_t m = k * qSplitSize;
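Each OpenMP thread takes a disjoint slice of the single flat allocation, indexed by `ompIdx`, so no per-thread allocation or synchronization is needed. A minimal sketch of that layout, assuming OpenMP (compile with `-fopenmp`; sizes are illustrative rather than the padded `qSplitSize * kvSplitSize * headSize` of the patch):

```cpp
#include <omp.h>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int num_thread = omp_get_max_threads();
  const int64_t size_per_thread = 64; // illustrative slice size
  std::vector<float> scratch(static_cast<size_t>(num_thread) * size_per_thread);
  float* base = scratch.data();

#pragma omp parallel
  {
    // Thread i owns [i * size_per_thread, (i + 1) * size_per_thread).
    const int ompIdx = omp_get_thread_num();
    float* buf_qdq_ptr = base + ompIdx * size_per_thread;
    buf_qdq_ptr[0] = static_cast<float>(ompIdx); // no writes cross slices
  }
  std::printf("thread 0 wrote %f\n", base[0]);
  return 0;
}
```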
@@ -1053,7 +1065,8 @@ void cpu_flash_attention(
             vStrideN,
             dst_data,
             headSize,
-            n == 0 ? static_cast<accum_t>(0) : static_cast<accum_t>(1));
+            n == 0 ? static_cast<accum_t>(0) : static_cast<accum_t>(1),
+            buf_qdq_ptr);
       }
       // dst <- dst / sum[row]
       // reorder MHA output with strides
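The ternary here is the GEMM `beta` parameter in `C = alpha * A * B + beta * C`: the first kv block (`n == 0`) uses `beta = 0` to overwrite whatever is in `dst_data`, and later blocks use `beta = 1` to accumulate their partial products. A 1x1 toy showing why that works even when `dst` starts as garbage (names illustrative):

```cpp
#include <cstdio>

// Minimal c = alpha * a * b + beta * c, just to show the role of beta
// when looping over kv blocks.
static void gemm1(float alpha, float a, float b, float beta, float* c) {
  *c = alpha * a * b + beta * *c;
}

int main() {
  float dst = -123.f;              // garbage; beta = 0 on block 0 overwrites it
  const float qk[2] = {1.f, 2.f};  // per-block attention weights (illustrative)
  const float v[2] = {10.f, 20.f}; // per-block values (illustrative)
  for (int n = 0; n < 2; ++n) {
    gemm1(1.f, qk[n], v[n], n == 0 ? 0.f : 1.f, &dst);
  }
  std::printf("%f\n", dst); // 1*10 + 2*20 = 50
  return 0;
}
```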