1111#include < cinttypes>
1212#include < cmath>
1313
14+ #if defined(__aarch64__) || defined(__ARM_NEON__)
15+ #include < arm_neon.h>
16+ #endif
17+
1418/* *
1519 * For an input tensor, use the scale and zero_point arguments to quantize it.
1620 */
@@ -105,6 +109,143 @@ T quantize_val(
105109 return static_cast <T>(qvalue);
106110}
107111
112+ #if defined(__aarch64__) || defined(__ARM_NEON__)
113+
114+ // Traits for type-specific NEON operations
115+ template <typename T>
116+ struct NeonQuantizeTraits ;
117+
118+ template <>
119+ struct NeonQuantizeTraits <uint8_t > {
120+ // Narrow int16x8 to uint8x8 with saturation (unsigned)
121+ static inline uint8x8_t narrow_and_saturate (int16x8_t v) {
122+ return vqmovun_s16 (v);
123+ }
124+
125+ // Store uint8x8 to memory
126+ static inline void store (uint8_t * ptr, uint8x8_t v) {
127+ vst1_u8 (ptr, v);
128+ }
129+
130+ // Scalar clamping for uint8
131+ static inline uint8_t clamp_scalar (int32_t val) {
132+ return static_cast <uint8_t >(std::min (255 , std::max (0 , val)));
133+ }
134+ };
135+
136+ template <>
137+ struct NeonQuantizeTraits <int8_t > {
138+ // Narrow int16x8 to int8x8 with saturation (signed)
139+ static inline int8x8_t narrow_and_saturate (int16x8_t v) {
140+ return vqmovn_s16 (v);
141+ }
142+
143+ // Store int8x8 to memory
144+ static inline void store (int8_t * ptr, int8x8_t v) {
145+ vst1_s8 (ptr, v);
146+ }
147+
148+ // Scalar clamping for int8
149+ static inline int8_t clamp_scalar (int32_t val) {
150+ return static_cast <int8_t >(std::min (127 , std::max (-128 , val)));
151+ }
152+ };
153+
154+ // Unified ARM NEON optimized quantization for contiguous blocks
155+ // Processes N elements with a single scale/zero_point pair
156+ // Used for both per-tensor (entire tensor) and per-channel (one block per
157+ // channel)
158+ template <typename T>
159+ void quantize_arm (
160+ const float * __restrict__ in,
161+ T* __restrict__ out,
162+ const int64_t N,
163+ const float inv_scale,
164+ const int32_t zero_point,
165+ const int32_t quant_min,
166+ const int32_t quant_max) {
167+ using Traits = NeonQuantizeTraits<T>;
168+ const float32x4_t vinv_scale = vdupq_n_f32 (inv_scale);
169+
170+ #if defined(__aarch64__)
171+ // ARMv8: Use vcvtnq_s32_f32 for rounding
172+ const int16x8_t vzero_point = vdupq_n_s16 (static_cast <int16_t >(zero_point));
173+ const int16x8_t vquant_min = vdupq_n_s16 (static_cast <int16_t >(quant_min));
174+ const int16x8_t vquant_max = vdupq_n_s16 (static_cast <int16_t >(quant_max));
175+
176+ int64_t i = 0 ;
177+ // Process 8 elements at a time
178+ for (; i + 8 <= N; i += 8 ) {
179+ const float32x4_t vin0123 = vld1q_f32 (in + i);
180+ const float32x4_t vin4567 = vld1q_f32 (in + i + 4 );
181+
182+ // Multiply by inv_scale and round
183+ const int32x4_t v0123_rounded =
184+ vcvtnq_s32_f32 (vmulq_f32 (vin0123, vinv_scale));
185+ const int32x4_t v4567_rounded =
186+ vcvtnq_s32_f32 (vmulq_f32 (vin4567, vinv_scale));
187+
188+ // Combine to int16 and add zero_point
189+ int16x8_t v01234567_packed = vqaddq_s16 (
190+ vqmovn_high_s32 (vqmovn_s32 (v0123_rounded), v4567_rounded), vzero_point);
191+
192+ // Clamp to quant_min/quant_max
193+ v01234567_packed = vmaxq_s16 (v01234567_packed, vquant_min);
194+ v01234567_packed = vminq_s16 (v01234567_packed, vquant_max);
195+
196+ // Convert to T (int8/uint8) with saturation using type-specific operation
197+ const auto vout01234567 = Traits::narrow_and_saturate (v01234567_packed);
198+ Traits::store (out + i, vout01234567);
199+ }
200+
201+ // Handle remaining elements with proper quant_min/quant_max clamping
202+ for (; i < N; ++i) {
203+ float val = in[i] * inv_scale;
204+ int32_t qval = static_cast <int32_t >(std::nearbyint (val)) + zero_point;
205+ qval = std::max (quant_min, std::min (quant_max, qval));
206+ out[i] = static_cast <T>(qval);
207+ }
208+
209+ #else
210+ // ARMv7: Use magic float rounding
211+ const int32x4_t voffset = vdupq_n_s32 (zero_point - 0x4B400000 );
212+ const float32x4_t vmagic_float = vdupq_n_f32 (12582912 .0f );
213+
214+ int64_t i = 0 ;
215+ // Process 8 elements at a time
216+ for (; i + 8 <= N; i += 8 ) {
217+ const float32x4_t vin0123 = vld1q_f32 (in + i);
218+ const float32x4_t vin4567 = vld1q_f32 (in + i + 4 );
219+
220+ const int32x4_t vraw0123 = vaddq_s32 (
221+ voffset,
222+ vreinterpretq_s32_f32 (
223+ vaddq_f32 (vmagic_float, vmulq_f32 (vin0123, vinv_scale))));
224+ const int32x4_t vraw4567 = vaddq_s32 (
225+ voffset,
226+ vreinterpretq_s32_f32 (
227+ vaddq_f32 (vmagic_float, vmulq_f32 (vin4567, vinv_scale))));
228+
229+ const int16x8_t vraw01234567 =
230+ vcombine_s16 (vqmovn_s32 (vraw0123), vqmovn_s32 (vraw4567));
231+
232+ // Convert to T (int8/uint8) with saturation using type-specific operation
233+ const auto vout01234567 = Traits::narrow_and_saturate (vraw01234567);
234+ Traits::store (out + i, vout01234567);
235+ }
236+
237+ // Handle remaining elements with proper quant_min/quant_max clamping
238+ for (; i < N; ++i) {
239+ float val = in[i] * inv_scale;
240+ int32_t qval = static_cast <int32_t >(std::nearbyint (val)) + zero_point;
241+ qval = std::max (quant_min, std::min (quant_max, qval));
242+ out[i] = static_cast <T>(qval);
243+ }
244+ #endif
245+ }
246+
247+ #endif // defined(__aarch64__) || defined(__ARM_NEON__)
248+
108249Tensor& quantize_per_tensor_out (
109250 const Tensor& input,
110251 double scale,
@@ -120,19 +261,44 @@ Tensor& quantize_per_tensor_out(
120261
121262 check_quantize_per_tensor_args (input, quant_min, quant_max, dtype, out);
122263
123- // calculate the quantized input
124- #define QUANTIZE_IMPL (IN_CTYPE, OUT_CTYPE, out_dtype ) \
125- case ScalarType::out_dtype: { \
126- /* Hoist these function calls out of our inner loop because they might not \
127- * get inlined without LTO, particularly in ATen mode. */ \
128- auto * out_data_ptr = out.mutable_data_ptr <OUT_CTYPE>(); \
129- const auto * input_data_ptr = input.const_data_ptr <IN_CTYPE>(); \
130- const auto input_numel = input.numel (); \
131- for (size_t i = 0 ; i < input_numel; i++) { \
132- IN_CTYPE value = input_data_ptr[i]; \
133- out_data_ptr[i] = quantize_val<OUT_CTYPE, IN_CTYPE>( \
134- scale, zero_point, value, quant_min, quant_max); \
135- } \
264+ // Try ARM NEON optimized path for float->int8/uint8 quantization
265+ #if defined(__aarch64__) || defined(__ARM_NEON__)
266+ if (input.scalar_type () == ScalarType::Float) {
267+ if (dtype == ScalarType::Byte) {
268+ quantize_arm<uint8_t >(
269+ input.const_data_ptr <float >(),
270+ out.mutable_data_ptr <uint8_t >(),
271+ input.numel (),
272+ 1 .0f / static_cast <float >(scale),
273+ static_cast <int32_t >(zero_point),
274+ static_cast <int32_t >(quant_min),
275+ static_cast <int32_t >(quant_max));
276+ return out;
277+ } else if (dtype == ScalarType::Char) {
278+ quantize_arm<int8_t >(
279+ input.const_data_ptr <float >(),
280+ out.mutable_data_ptr <int8_t >(),
281+ input.numel (),
282+ 1 .0f / static_cast <float >(scale),
283+ static_cast <int32_t >(zero_point),
284+ static_cast <int32_t >(quant_min),
285+ static_cast <int32_t >(quant_max));
286+ return out;
287+ }
288+ }
289+ #endif
290+
291+ // Fallback scalar implementation for all other cases
292+ #define QUANTIZE_IMPL (IN_CTYPE, OUT_CTYPE, out_dtype ) \
293+ case ScalarType::out_dtype: { \
294+ auto * out_data_ptr = out.mutable_data_ptr <OUT_CTYPE>(); \
295+ const auto * input_data_ptr = input.const_data_ptr <IN_CTYPE>(); \
296+ const auto input_numel = input.numel (); \
297+ for (size_t i = 0 ; i < input_numel; i++) { \
298+ IN_CTYPE value = input_data_ptr[i]; \
299+ out_data_ptr[i] = quantize_val<OUT_CTYPE, IN_CTYPE>( \
300+ scale, zero_point, value, quant_min, quant_max); \
301+ } \
136302 } break ;
137303#define CALCULATE_FLOAT_TYPE (IN_CTYPE, in_dtype ) \
138304 case ScalarType::in_dtype: \
@@ -284,29 +450,85 @@ Tensor& quantize_per_channel_out(
284450 const double * scale_data = scale.const_data_ptr <double >();
285451 const int64_t * zero_point_data = zero_point.const_data_ptr <int64_t >();
286452
287- // High-performance single loop with direct channel calculation
288- #define QUANTIZE_IMPL (CTYPE_IN, CTYPE_OUT, out_dtype ) \
289- case ScalarType::out_dtype: { \
290- auto * out_data_ptr = out.mutable_data_ptr <CTYPE_OUT>(); \
291- const auto * input_data_ptr = input.const_data_ptr <CTYPE_IN>(); \
292- const int64_t input_numel = input.numel (); \
293- const int64_t axis_size = input.size (axis); \
294- /* Calculate the stride pattern for efficient channel index calculation */ \
295- int64_t axis_block_size = 1 ; \
296- for (int64_t i = axis + 1 ; i < input.dim (); i++) { \
297- axis_block_size *= input.size (i); \
298- } \
299- /* Single loop over all elements */ \
300- for (int64_t i = 0 ; i < input_numel; i++) { \
301- /* Calculate which channel this element belongs to */ \
302- int64_t channel_idx = (i / axis_block_size) % axis_size; \
303- /* Get quantization parameters for this channel */ \
304- double _scale = scale_data[channel_idx]; \
305- int64_t _zero_point = zero_point_data[channel_idx]; \
306- /* Apply quantization */ \
307- out_data_ptr[i] = quantize_val<CTYPE_OUT, CTYPE_IN>( \
308- _scale, _zero_point, input_data_ptr[i], quant_min, quant_max); \
309- } \
453+ // Calculate the block size for each channel
454+ int64_t axis_block_size = 1 ;
455+ for (int64_t i = axis + 1 ; i < input.dim (); i++) {
456+ axis_block_size *= input.size (i);
457+ }
458+ const int64_t axis_size = input.size (axis);
459+
460+ // Try ARM NEON optimized path for float->int8/uint8 quantization
461+ #if defined(__aarch64__) || defined(__ARM_NEON__)
462+ if (input.scalar_type () == ScalarType::Float) {
463+ const int64_t num_blocks = input.numel () / axis_block_size;
464+
465+ if (dtype == ScalarType::Byte) {
466+ auto * out_data_ptr = out.mutable_data_ptr <uint8_t >();
467+ const auto * input_data_ptr = input.const_data_ptr <float >();
468+
469+ // Process each contiguous block (which shares the same scale/zero_point)
470+ for (int64_t block = 0 ; block < num_blocks; ++block) {
471+ int64_t channel_idx = block % axis_size;
472+ float inv_scale = 1 .0f / static_cast <float >(scale_data[channel_idx]);
473+ int32_t zp = static_cast <int32_t >(zero_point_data[channel_idx]);
474+
475+ const float * in_ptr = input_data_ptr + block * axis_block_size;
476+ uint8_t * out_ptr = out_data_ptr + block * axis_block_size;
477+
478+ quantize_arm<uint8_t >(
479+ in_ptr,
480+ out_ptr,
481+ axis_block_size,
482+ inv_scale,
483+ zp,
484+ static_cast <int32_t >(quant_min),
485+ static_cast <int32_t >(quant_max));
486+ }
487+ return out;
488+ } else if (dtype == ScalarType::Char) {
489+ auto * out_data_ptr = out.mutable_data_ptr <int8_t >();
490+ const auto * input_data_ptr = input.const_data_ptr <float >();
491+
492+ // Process each contiguous block (which shares the same scale/zero_point)
493+ for (int64_t block = 0 ; block < num_blocks; ++block) {
494+ int64_t channel_idx = block % axis_size;
495+ float inv_scale = 1 .0f / static_cast <float >(scale_data[channel_idx]);
496+ int32_t zp = static_cast <int32_t >(zero_point_data[channel_idx]);
497+
498+ const float * in_ptr = input_data_ptr + block * axis_block_size;
499+ int8_t * out_ptr = out_data_ptr + block * axis_block_size;
500+
501+ quantize_arm<int8_t >(
502+ in_ptr,
503+ out_ptr,
504+ axis_block_size,
505+ inv_scale,
506+ zp,
507+ static_cast <int32_t >(quant_min),
508+ static_cast <int32_t >(quant_max));
509+ }
510+ return out;
511+ }
512+ }
513+ #endif
514+
515+ // Fallback scalar implementation
516+ #define QUANTIZE_IMPL (CTYPE_IN, CTYPE_OUT, out_dtype ) \
517+ case ScalarType::out_dtype: { \
518+ auto * out_data_ptr = out.mutable_data_ptr <CTYPE_OUT>(); \
519+ const auto * input_data_ptr = input.const_data_ptr <CTYPE_IN>(); \
520+ const int64_t input_numel = input.numel (); \
521+ /* Single loop over all elements */ \
522+ for (int64_t i = 0 ; i < input_numel; i++) { \
523+ /* Calculate which channel this element belongs to */ \
524+ int64_t channel_idx = (i / axis_block_size) % axis_size; \
525+ /* Get quantization parameters for this channel */ \
526+ double _scale = scale_data[channel_idx]; \
527+ int64_t _zero_point = zero_point_data[channel_idx]; \
528+ /* Apply quantization */ \
529+ out_data_ptr[i] = quantize_val<CTYPE_OUT, CTYPE_IN>( \
530+ _scale, _zero_point, input_data_ptr[i], quant_min, quant_max); \
531+ } \
310532 } break ;
311533
312534#define CALCULATE_FLOAT_TYPE (CTYPE_IN, in_dtype ) \
0 commit comments