
Commit d06d47b

[Executorch] Add simd path for op quantize
Pull Request resolved: #15608

The reason this doesn't directly use the Vectorize class is that the equivalent APIs don't exist in the Vectorize class.

Differential Revision: [D84962236](https://our.internmc.facebook.com/intern/diff/D84962236/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments; please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D84962236/)!

ghstack-source-id: 321455124
1 parent 94ff0fe commit d06d47b


2 files changed: +794 −37 lines


kernels/quantized/cpu/op_quantize.cpp

Lines changed: 258 additions & 36 deletions
@@ -11,6 +11,10 @@
 #include <cinttypes>
 #include <cmath>
 
+#if defined(__aarch64__) || defined(__ARM_NEON__)
+#include <arm_neon.h>
+#endif
+
 /**
  * For an input tensor, use the scale and zero_point arguments to quantize it.
  */
@@ -105,6 +109,143 @@ T quantize_val(
   return static_cast<T>(qvalue);
 }
 
+#if defined(__aarch64__) || defined(__ARM_NEON__)
+
+// Traits for type-specific NEON operations
+template <typename T>
+struct NeonQuantizeTraits;
+
+template <>
+struct NeonQuantizeTraits<uint8_t> {
+  // Narrow int16x8 to uint8x8 with saturation (unsigned)
+  static inline uint8x8_t narrow_and_saturate(int16x8_t v) {
+    return vqmovun_s16(v);
+  }
+
+  // Store uint8x8 to memory
+  static inline void store(uint8_t* ptr, uint8x8_t v) {
+    vst1_u8(ptr, v);
+  }
+
+  // Scalar clamping for uint8
+  static inline uint8_t clamp_scalar(int32_t val) {
+    return static_cast<uint8_t>(std::min(255, std::max(0, val)));
+  }
+};
+
+template <>
+struct NeonQuantizeTraits<int8_t> {
+  // Narrow int16x8 to int8x8 with saturation (signed)
+  static inline int8x8_t narrow_and_saturate(int16x8_t v) {
+    return vqmovn_s16(v);
+  }
+
+  // Store int8x8 to memory
+  static inline void store(int8_t* ptr, int8x8_t v) {
+    vst1_s8(ptr, v);
+  }
+
+  // Scalar clamping for int8
+  static inline int8_t clamp_scalar(int32_t val) {
+    return static_cast<int8_t>(std::min(127, std::max(-128, val)));
+  }
+};
+
+// Unified ARM NEON optimized quantization for contiguous blocks.
+// Processes N elements with a single scale/zero_point pair.
+// Used for both per-tensor (entire tensor) and per-channel (one block per
+// channel).
+template <typename T>
+void quantize_arm(
+    const float* __restrict__ in,
+    T* __restrict__ out,
+    const int64_t N,
+    const float inv_scale,
+    const int32_t zero_point,
+    const int32_t quant_min,
+    const int32_t quant_max) {
+  using Traits = NeonQuantizeTraits<T>;
+  const float32x4_t vinv_scale = vdupq_n_f32(inv_scale);
+
+#if defined(__aarch64__)
+  // ARMv8: Use vcvtnq_s32_f32 for rounding
+  const int16x8_t vzero_point = vdupq_n_s16(static_cast<int16_t>(zero_point));
+  const int16x8_t vquant_min = vdupq_n_s16(static_cast<int16_t>(quant_min));
+  const int16x8_t vquant_max = vdupq_n_s16(static_cast<int16_t>(quant_max));
+
+  int64_t i = 0;
+  // Process 8 elements at a time
+  for (; i + 8 <= N; i += 8) {
+    const float32x4_t vin0123 = vld1q_f32(in + i);
+    const float32x4_t vin4567 = vld1q_f32(in + i + 4);
+
+    // Multiply by inv_scale and round
+    const int32x4_t v0123_rounded =
+        vcvtnq_s32_f32(vmulq_f32(vin0123, vinv_scale));
+    const int32x4_t v4567_rounded =
+        vcvtnq_s32_f32(vmulq_f32(vin4567, vinv_scale));
+
+    // Combine to int16 and add zero_point
+    int16x8_t v01234567_packed = vqaddq_s16(
+        vqmovn_high_s32(vqmovn_s32(v0123_rounded), v4567_rounded), vzero_point);
+
+    // Clamp to quant_min/quant_max
+    v01234567_packed = vmaxq_s16(v01234567_packed, vquant_min);
+    v01234567_packed = vminq_s16(v01234567_packed, vquant_max);
+
+    // Convert to T (int8/uint8) with saturation using type-specific operation
+    const auto vout01234567 = Traits::narrow_and_saturate(v01234567_packed);
+    Traits::store(out + i, vout01234567);
+  }
+
+  // Handle remaining elements with proper quant_min/quant_max clamping
+  for (; i < N; ++i) {
+    float val = in[i] * inv_scale;
+    int32_t qval = static_cast<int32_t>(std::nearbyint(val)) + zero_point;
+    qval = std::max(quant_min, std::min(quant_max, qval));
+    out[i] = static_cast<T>(qval);
+  }
+
+#else
+  // ARMv7: Use magic float rounding
+  const int32x4_t voffset = vdupq_n_s32(zero_point - 0x4B400000);
+  const float32x4_t vmagic_float = vdupq_n_f32(12582912.0f);
+
+  int64_t i = 0;
+  // Process 8 elements at a time
+  for (; i + 8 <= N; i += 8) {
+    const float32x4_t vin0123 = vld1q_f32(in + i);
+    const float32x4_t vin4567 = vld1q_f32(in + i + 4);
+
+    const int32x4_t vraw0123 = vaddq_s32(
+        voffset,
+        vreinterpretq_s32_f32(
+            vaddq_f32(vmagic_float, vmulq_f32(vin0123, vinv_scale))));
+    const int32x4_t vraw4567 = vaddq_s32(
+        voffset,
+        vreinterpretq_s32_f32(
+            vaddq_f32(vmagic_float, vmulq_f32(vin4567, vinv_scale))));
+
+    const int16x8_t vraw01234567 =
+        vcombine_s16(vqmovn_s32(vraw0123), vqmovn_s32(vraw4567));
+
+    // Convert to T (int8/uint8) with saturation using type-specific operation
+    const auto vout01234567 = Traits::narrow_and_saturate(vraw01234567);
+    Traits::store(out + i, vout01234567);
+  }
+
+  // Handle remaining elements with proper quant_min/quant_max clamping
+  for (; i < N; ++i) {
+    float val = in[i] * inv_scale;
+    int32_t qval = static_cast<int32_t>(std::nearbyint(val)) + zero_point;
+    qval = std::max(quant_min, std::min(quant_max, qval));
+    out[i] = static_cast<T>(qval);
+  }
+#endif
+}
+
+#endif // defined(__aarch64__) || defined(__ARM_NEON__)
+
 Tensor& quantize_per_tensor_out(
     const Tensor& input,
     double scale,
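
The ARMv7 branch above relies on the classic "magic float" rounding trick rather than a dedicated round-to-nearest instruction. A minimal standalone scalar sketch (illustrative only, not from the diff; the function name is hypothetical) shows why the constants 12582912.0f (1.5 * 2^23, bit pattern 0x4B400000) and the voffset of zero_point - 0x4B400000 recover round(x * inv_scale) + zero_point:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Adding 1.5 * 2^23 to a float of magnitude below ~2^22 forces the FPU to
// round it to the nearest integer (round-to-nearest-even), and that integer
// lands in the low mantissa bits of the sum. Reinterpreting the sum as int32
// and subtracting the bit pattern 0x4B400000 recovers round(x); folding
// zero_point into the subtracted offset, as the NEON code does with voffset,
// gives round(x) + zero_point in a single integer add.
int32_t magic_round_plus_zero_point(float x, int32_t zero_point) {
  const float magic = 12582912.0f;                 // 1.5 * 2^23
  const int32_t offset = zero_point - 0x4B400000;  // mirrors voffset
  const float biased = x + magic;                  // rounds x to an integer
  int32_t bits;
  std::memcpy(&bits, &biased, sizeof(bits));       // scalar stand-in for vreinterpretq
  return bits + offset;                            // round(x) + zero_point
}

int main() {
  // 2.5 rounds to 2 under round-to-nearest-even; with zero_point = 10 this prints 12.
  std::printf("%d\n", magic_round_plus_zero_point(2.5f, 10));
  return 0;
}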
@@ -120,19 +261,44 @@ Tensor& quantize_per_tensor_out(
 
   check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out);
 
-  // calculate the quantized input
-#define QUANTIZE_IMPL(IN_CTYPE, OUT_CTYPE, out_dtype) \
-  case ScalarType::out_dtype: { \
-    /* Hoist these function calls out of our inner loop because they might not \
-     * get inlined without LTO, particularly in ATen mode. */ \
-    auto* out_data_ptr = out.mutable_data_ptr<OUT_CTYPE>(); \
-    const auto* input_data_ptr = input.const_data_ptr<IN_CTYPE>(); \
-    const auto input_numel = input.numel(); \
-    for (size_t i = 0; i < input_numel; i++) { \
-      IN_CTYPE value = input_data_ptr[i]; \
-      out_data_ptr[i] = quantize_val<OUT_CTYPE, IN_CTYPE>( \
-          scale, zero_point, value, quant_min, quant_max); \
-    } \
+  // Try ARM NEON optimized path for float->int8/uint8 quantization
+#if defined(__aarch64__) || defined(__ARM_NEON__)
+  if (input.scalar_type() == ScalarType::Float) {
+    if (dtype == ScalarType::Byte) {
+      quantize_arm<uint8_t>(
+          input.const_data_ptr<float>(),
+          out.mutable_data_ptr<uint8_t>(),
+          input.numel(),
+          1.0f / static_cast<float>(scale),
+          static_cast<int32_t>(zero_point),
+          static_cast<int32_t>(quant_min),
+          static_cast<int32_t>(quant_max));
+      return out;
+    } else if (dtype == ScalarType::Char) {
+      quantize_arm<int8_t>(
+          input.const_data_ptr<float>(),
+          out.mutable_data_ptr<int8_t>(),
+          input.numel(),
+          1.0f / static_cast<float>(scale),
+          static_cast<int32_t>(zero_point),
+          static_cast<int32_t>(quant_min),
+          static_cast<int32_t>(quant_max));
+      return out;
+    }
+  }
+#endif
+
+  // Fallback scalar implementation for all other cases
+#define QUANTIZE_IMPL(IN_CTYPE, OUT_CTYPE, out_dtype) \
+  case ScalarType::out_dtype: { \
+    auto* out_data_ptr = out.mutable_data_ptr<OUT_CTYPE>(); \
+    const auto* input_data_ptr = input.const_data_ptr<IN_CTYPE>(); \
+    const auto input_numel = input.numel(); \
+    for (size_t i = 0; i < input_numel; i++) { \
+      IN_CTYPE value = input_data_ptr[i]; \
+      out_data_ptr[i] = quantize_val<OUT_CTYPE, IN_CTYPE>( \
+          scale, zero_point, value, quant_min, quant_max); \
+    } \
   } break;
 #define CALCULATE_FLOAT_TYPE(IN_CTYPE, in_dtype) \
   case ScalarType::in_dtype: \
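
Both fast-path branches above fold the double-precision scale into a single float reciprocal (1.0f / static_cast<float>(scale)) before calling quantize_arm. A standalone scalar sketch, with hypothetical values and an illustrative function name, of what each lane and the tail loop then compute:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Mirrors the scalar tail loop of quantize_arm: multiply by the precomputed
// reciprocal, round to nearest, add zero_point, then clamp to [quant_min, quant_max].
int8_t quantize_one(float x, double scale, int32_t zero_point,
                    int32_t quant_min, int32_t quant_max) {
  const float inv_scale = 1.0f / static_cast<float>(scale);  // as in the dispatch
  int32_t q = static_cast<int32_t>(std::nearbyint(x * inv_scale)) + zero_point;
  q = std::max(quant_min, std::min(quant_max, q));
  return static_cast<int8_t>(q);
}

int main() {
  // x = 3.1, scale = 0.25, zero_point = 5: round(3.1 * 4) + 5 = 12 + 5 = 17
  std::printf("%d\n", quantize_one(3.1f, 0.25, 5, -128, 127));
  return 0;
}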
@@ -284,29 +450,85 @@ Tensor& quantize_per_channel_out(
   const double* scale_data = scale.const_data_ptr<double>();
   const int64_t* zero_point_data = zero_point.const_data_ptr<int64_t>();
 
-  // High-performance single loop with direct channel calculation
-#define QUANTIZE_IMPL(CTYPE_IN, CTYPE_OUT, out_dtype) \
-  case ScalarType::out_dtype: { \
-    auto* out_data_ptr = out.mutable_data_ptr<CTYPE_OUT>(); \
-    const auto* input_data_ptr = input.const_data_ptr<CTYPE_IN>(); \
-    const int64_t input_numel = input.numel(); \
-    const int64_t axis_size = input.size(axis); \
-    /* Calculate the stride pattern for efficient channel index calculation */ \
-    int64_t axis_block_size = 1; \
-    for (int64_t i = axis + 1; i < input.dim(); i++) { \
-      axis_block_size *= input.size(i); \
-    } \
-    /* Single loop over all elements */ \
-    for (int64_t i = 0; i < input_numel; i++) { \
-      /* Calculate which channel this element belongs to */ \
-      int64_t channel_idx = (i / axis_block_size) % axis_size; \
-      /* Get quantization parameters for this channel */ \
-      double _scale = scale_data[channel_idx]; \
-      int64_t _zero_point = zero_point_data[channel_idx]; \
-      /* Apply quantization */ \
-      out_data_ptr[i] = quantize_val<CTYPE_OUT, CTYPE_IN>( \
-          _scale, _zero_point, input_data_ptr[i], quant_min, quant_max); \
-    } \
+  // Calculate the block size for each channel
+  int64_t axis_block_size = 1;
+  for (int64_t i = axis + 1; i < input.dim(); i++) {
+    axis_block_size *= input.size(i);
+  }
+  const int64_t axis_size = input.size(axis);
+
+  // Try ARM NEON optimized path for float->int8/uint8 quantization
+#if defined(__aarch64__) || defined(__ARM_NEON__)
+  if (input.scalar_type() == ScalarType::Float) {
+    const int64_t num_blocks = input.numel() / axis_block_size;
+
+    if (dtype == ScalarType::Byte) {
+      auto* out_data_ptr = out.mutable_data_ptr<uint8_t>();
+      const auto* input_data_ptr = input.const_data_ptr<float>();
+
+      // Process each contiguous block (which shares the same scale/zero_point)
+      for (int64_t block = 0; block < num_blocks; ++block) {
+        int64_t channel_idx = block % axis_size;
+        float inv_scale = 1.0f / static_cast<float>(scale_data[channel_idx]);
+        int32_t zp = static_cast<int32_t>(zero_point_data[channel_idx]);
+
+        const float* in_ptr = input_data_ptr + block * axis_block_size;
+        uint8_t* out_ptr = out_data_ptr + block * axis_block_size;
+
+        quantize_arm<uint8_t>(
+            in_ptr,
+            out_ptr,
+            axis_block_size,
+            inv_scale,
+            zp,
+            static_cast<int32_t>(quant_min),
+            static_cast<int32_t>(quant_max));
+      }
+      return out;
+    } else if (dtype == ScalarType::Char) {
+      auto* out_data_ptr = out.mutable_data_ptr<int8_t>();
+      const auto* input_data_ptr = input.const_data_ptr<float>();
+
+      // Process each contiguous block (which shares the same scale/zero_point)
+      for (int64_t block = 0; block < num_blocks; ++block) {
+        int64_t channel_idx = block % axis_size;
+        float inv_scale = 1.0f / static_cast<float>(scale_data[channel_idx]);
+        int32_t zp = static_cast<int32_t>(zero_point_data[channel_idx]);
+
+        const float* in_ptr = input_data_ptr + block * axis_block_size;
+        int8_t* out_ptr = out_data_ptr + block * axis_block_size;
+
+        quantize_arm<int8_t>(
+            in_ptr,
+            out_ptr,
+            axis_block_size,
+            inv_scale,
+            zp,
+            static_cast<int32_t>(quant_min),
+            static_cast<int32_t>(quant_max));
+      }
+      return out;
+    }
+  }
+#endif
+
+  // Fallback scalar implementation
+#define QUANTIZE_IMPL(CTYPE_IN, CTYPE_OUT, out_dtype) \
+  case ScalarType::out_dtype: { \
+    auto* out_data_ptr = out.mutable_data_ptr<CTYPE_OUT>(); \
+    const auto* input_data_ptr = input.const_data_ptr<CTYPE_IN>(); \
+    const int64_t input_numel = input.numel(); \
+    /* Single loop over all elements */ \
+    for (int64_t i = 0; i < input_numel; i++) { \
+      /* Calculate which channel this element belongs to */ \
+      int64_t channel_idx = (i / axis_block_size) % axis_size; \
+      /* Get quantization parameters for this channel */ \
+      double _scale = scale_data[channel_idx]; \
+      int64_t _zero_point = zero_point_data[channel_idx]; \
+      /* Apply quantization */ \
+      out_data_ptr[i] = quantize_val<CTYPE_OUT, CTYPE_IN>( \
+          _scale, _zero_point, input_data_ptr[i], quant_min, quant_max); \
+    } \
   } break;
 
 #define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype) \
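
In the per-channel hunk above, the fast path walks the tensor as num_blocks contiguous blocks of axis_block_size elements, where axis_block_size is the product of the dimensions after axis and block % axis_size selects that block's scale/zero_point. A small standalone sketch, using a hypothetical [2, 3, 4, 5] tensor quantized along axis = 1, illustrates the mapping:

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical shape [2, 3, 4, 5], quantized per channel along axis = 1.
  const int64_t sizes[] = {2, 3, 4, 5};
  const int64_t dim = 4;
  const int64_t axis = 1;

  // axis_block_size = product of dimensions after axis (here 4 * 5 = 20).
  int64_t axis_block_size = 1;
  for (int64_t i = axis + 1; i < dim; i++) {
    axis_block_size *= sizes[i];
  }
  const int64_t axis_size = sizes[axis];  // 3 channels

  int64_t numel = 1;
  for (int64_t i = 0; i < dim; i++) {
    numel *= sizes[i];
  }
  const int64_t num_blocks = numel / axis_block_size;  // 120 / 20 = 6 blocks

  // Each block of 20 contiguous floats shares one scale/zero_point pair, so the
  // kernel is invoked once per block with that channel's parameters; channels
  // cycle 0, 1, 2, 0, 1, 2 across the 6 blocks.
  for (int64_t block = 0; block < num_blocks; ++block) {
    const int64_t channel_idx = block % axis_size;
    std::printf("block %lld -> channel %lld, elements [%lld, %lld)\n",
                (long long)block, (long long)channel_idx,
                (long long)(block * axis_block_size),
                (long long)((block + 1) * axis_block_size));
  }
  return 0;
}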
