294 changes: 258 additions & 36 deletions kernels/quantized/cpu/op_quantize.cpp
@@ -11,6 +11,10 @@
#include <cinttypes>
#include <cmath>

#if defined(__aarch64__) || defined(__ARM_NEON__)
#include <arm_neon.h>
#endif

/**
* For an input tensor, use the scale and zero_point arguments to quantize it.
*/
@@ -105,6 +109,143 @@ T quantize_val(
return static_cast<T>(qvalue);
}

#if defined(__aarch64__) || defined(__ARM_NEON__)

// Traits for type-specific NEON operations
template <typename T>
struct NeonQuantizeTraits;

template <>
struct NeonQuantizeTraits<uint8_t> {
// Narrow int16x8 to uint8x8 with saturation (unsigned)
static inline uint8x8_t narrow_and_saturate(int16x8_t v) {
return vqmovun_s16(v);
}

// Store uint8x8 to memory
static inline void store(uint8_t* ptr, uint8x8_t v) {
vst1_u8(ptr, v);
}

// Scalar clamping for uint8
static inline uint8_t clamp_scalar(int32_t val) {
return static_cast<uint8_t>(std::min(255, std::max(0, val)));
}
};

template <>
struct NeonQuantizeTraits<int8_t> {
// Narrow int16x8 to int8x8 with saturation (signed)
static inline int8x8_t narrow_and_saturate(int16x8_t v) {
return vqmovn_s16(v);
}

// Store int8x8 to memory
static inline void store(int8_t* ptr, int8x8_t v) {
vst1_s8(ptr, v);
}

// Scalar clamping for int8
static inline int8_t clamp_scalar(int32_t val) {
return static_cast<int8_t>(std::min(127, std::max(-128, val)));
}
};

// Unified ARM NEON optimized quantization for contiguous blocks
// Processes N elements with a single scale/zero_point pair
// Used for both per-tensor (entire tensor) and per-channel (one block per
// channel)
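// Conceptually each element is quantized as
//   out[i] = clamp(round(in[i] * inv_scale) + zero_point, quant_min, quant_max)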
template <typename T>
void quantize_arm(
const float* __restrict__ in,
T* __restrict__ out,
const int64_t N,
const float inv_scale,
const int32_t zero_point,
const int32_t quant_min,
const int32_t quant_max) {
using Traits = NeonQuantizeTraits<T>;
const float32x4_t vinv_scale = vdupq_n_f32(inv_scale);

#if defined(__aarch64__)
// ARMv8: Use vcvtnq_s32_f32 for rounding
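// (round to nearest, ties to even, which matches std::nearbyint in the scalar
// tail under the default rounding mode)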
const int16x8_t vzero_point = vdupq_n_s16(static_cast<int16_t>(zero_point));
const int16x8_t vquant_min = vdupq_n_s16(static_cast<int16_t>(quant_min));
const int16x8_t vquant_max = vdupq_n_s16(static_cast<int16_t>(quant_max));

int64_t i = 0;
// Process 8 elements at a time
for (; i + 8 <= N; i += 8) {
const float32x4_t vin0123 = vld1q_f32(in + i);
const float32x4_t vin4567 = vld1q_f32(in + i + 4);

// Multiply by inv_scale and round
const int32x4_t v0123_rounded =
vcvtnq_s32_f32(vmulq_f32(vin0123, vinv_scale));
const int32x4_t v4567_rounded =
vcvtnq_s32_f32(vmulq_f32(vin4567, vinv_scale));

// Combine to int16 and add zero_point
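// vqmovn_s32 saturates the low four int32 lanes to int16; vqmovn_high_s32
// narrows the remaining four lanes into the upper half of the same vector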
int16x8_t v01234567_packed = vqaddq_s16(
vqmovn_high_s32(vqmovn_s32(v0123_rounded), v4567_rounded), vzero_point);

// Clamp to quant_min/quant_max
v01234567_packed = vmaxq_s16(v01234567_packed, vquant_min);
v01234567_packed = vminq_s16(v01234567_packed, vquant_max);

// Convert to T (int8/uint8) with saturation using type-specific operation
const auto vout01234567 = Traits::narrow_and_saturate(v01234567_packed);
Traits::store(out + i, vout01234567);
}

// Handle remaining elements with proper quant_min/quant_max clamping
for (; i < N; ++i) {
float val = in[i] * inv_scale;
int32_t qval = static_cast<int32_t>(std::nearbyint(val)) + zero_point;
qval = std::max(quant_min, std::min(quant_max, qval));
out[i] = static_cast<T>(qval);
}

#else
// ARMv7: Use magic float rounding
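// Adding 12582912.0f (0x4B400000, i.e. 1.5 * 2^23) to a float x with
// |x| < 2^22 rounds x to the nearest integer during the addition and leaves
// 0x4B400000 + round(x) in the bit pattern of the sum, so reinterpreting the
// sum as int32 and adding voffset (zero_point - 0x4B400000) yields
// round(x) + zero_point.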
const int32x4_t voffset = vdupq_n_s32(zero_point - 0x4B400000);
const float32x4_t vmagic_float = vdupq_n_f32(12582912.0f);

int64_t i = 0;
// Process 8 elements at a time
for (; i + 8 <= N; i += 8) {
const float32x4_t vin0123 = vld1q_f32(in + i);
const float32x4_t vin4567 = vld1q_f32(in + i + 4);

const int32x4_t vraw0123 = vaddq_s32(
voffset,
vreinterpretq_s32_f32(
vaddq_f32(vmagic_float, vmulq_f32(vin0123, vinv_scale))));
const int32x4_t vraw4567 = vaddq_s32(
voffset,
vreinterpretq_s32_f32(
vaddq_f32(vmagic_float, vmulq_f32(vin4567, vinv_scale))));

const int16x8_t vraw01234567 =
vcombine_s16(vqmovn_s32(vraw0123), vqmovn_s32(vraw4567));

// Convert to T (int8/uint8) with saturation using type-specific operation
const auto vout01234567 = Traits::narrow_and_saturate(vraw01234567);
Traits::store(out + i, vout01234567);
}

// Handle remaining elements with proper quant_min/quant_max clamping
for (; i < N; ++i) {
float val = in[i] * inv_scale;
int32_t qval = static_cast<int32_t>(std::nearbyint(val)) + zero_point;
qval = std::max(quant_min, std::min(quant_max, qval));
out[i] = static_cast<T>(qval);
}
#endif
}

#endif // defined(__aarch64__) || defined(__ARM_NEON__)

Tensor& quantize_per_tensor_out(
const Tensor& input,
double scale,
@@ -120,19 +261,44 @@ Tensor& quantize_per_tensor_out(

check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out);

// calculate the quantized input
#define QUANTIZE_IMPL(IN_CTYPE, OUT_CTYPE, out_dtype) \
case ScalarType::out_dtype: { \
/* Hoist these function calls out of our inner loop because they might not \
* get inlined without LTO, particularly in ATen mode. */ \
auto* out_data_ptr = out.mutable_data_ptr<OUT_CTYPE>(); \
const auto* input_data_ptr = input.const_data_ptr<IN_CTYPE>(); \
const auto input_numel = input.numel(); \
for (size_t i = 0; i < input_numel; i++) { \
IN_CTYPE value = input_data_ptr[i]; \
out_data_ptr[i] = quantize_val<OUT_CTYPE, IN_CTYPE>( \
scale, zero_point, value, quant_min, quant_max); \
} \
// Try ARM NEON optimized path for float->int8/uint8 quantization
#if defined(__aarch64__) || defined(__ARM_NEON__)
if (input.scalar_type() == ScalarType::Float) {
if (dtype == ScalarType::Byte) {
quantize_arm<uint8_t>(
input.const_data_ptr<float>(),
out.mutable_data_ptr<uint8_t>(),
input.numel(),
1.0f / static_cast<float>(scale),
static_cast<int32_t>(zero_point),
static_cast<int32_t>(quant_min),
static_cast<int32_t>(quant_max));
return out;
} else if (dtype == ScalarType::Char) {
quantize_arm<int8_t>(
input.const_data_ptr<float>(),
out.mutable_data_ptr<int8_t>(),
input.numel(),
1.0f / static_cast<float>(scale),
static_cast<int32_t>(zero_point),
static_cast<int32_t>(quant_min),
static_cast<int32_t>(quant_max));
return out;
}
}
#endif

// Fallback scalar implementation for all other cases
#define QUANTIZE_IMPL(IN_CTYPE, OUT_CTYPE, out_dtype) \
case ScalarType::out_dtype: { \
auto* out_data_ptr = out.mutable_data_ptr<OUT_CTYPE>(); \
const auto* input_data_ptr = input.const_data_ptr<IN_CTYPE>(); \
const auto input_numel = input.numel(); \
for (size_t i = 0; i < input_numel; i++) { \
IN_CTYPE value = input_data_ptr[i]; \
out_data_ptr[i] = quantize_val<OUT_CTYPE, IN_CTYPE>( \
scale, zero_point, value, quant_min, quant_max); \
} \
} break;
#define CALCULATE_FLOAT_TYPE(IN_CTYPE, in_dtype) \
case ScalarType::in_dtype: \
Expand Down Expand Up @@ -284,29 +450,85 @@ Tensor& quantize_per_channel_out(
const double* scale_data = scale.const_data_ptr<double>();
const int64_t* zero_point_data = zero_point.const_data_ptr<int64_t>();

// High-performance single loop with direct channel calculation
#define QUANTIZE_IMPL(CTYPE_IN, CTYPE_OUT, out_dtype) \
case ScalarType::out_dtype: { \
auto* out_data_ptr = out.mutable_data_ptr<CTYPE_OUT>(); \
const auto* input_data_ptr = input.const_data_ptr<CTYPE_IN>(); \
const int64_t input_numel = input.numel(); \
const int64_t axis_size = input.size(axis); \
/* Calculate the stride pattern for efficient channel index calculation */ \
int64_t axis_block_size = 1; \
for (int64_t i = axis + 1; i < input.dim(); i++) { \
axis_block_size *= input.size(i); \
} \
/* Single loop over all elements */ \
for (int64_t i = 0; i < input_numel; i++) { \
/* Calculate which channel this element belongs to */ \
int64_t channel_idx = (i / axis_block_size) % axis_size; \
/* Get quantization parameters for this channel */ \
double _scale = scale_data[channel_idx]; \
int64_t _zero_point = zero_point_data[channel_idx]; \
/* Apply quantization */ \
out_data_ptr[i] = quantize_val<CTYPE_OUT, CTYPE_IN>( \
_scale, _zero_point, input_data_ptr[i], quant_min, quant_max); \
} \
// Calculate the block size for each channel
int64_t axis_block_size = 1;
for (int64_t i = axis + 1; i < input.dim(); i++) {
axis_block_size *= input.size(i);
}
const int64_t axis_size = input.size(axis);
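// For a contiguous input, every run of axis_block_size elements shares one
// (scale, zero_point) pair and the channel index cycles every axis_size
// blocks, e.g. a contiguous [N, C, H, W] tensor quantized along axis == 1
// has axis_block_size == H * W and N * C such blocks.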

// Try ARM NEON optimized path for float->int8/uint8 quantization
#if defined(__aarch64__) || defined(__ARM_NEON__)
if (input.scalar_type() == ScalarType::Float) {
const int64_t num_blocks = input.numel() / axis_block_size;

if (dtype == ScalarType::Byte) {
auto* out_data_ptr = out.mutable_data_ptr<uint8_t>();
const auto* input_data_ptr = input.const_data_ptr<float>();

// Process each contiguous block (which shares the same scale/zero_point)
for (int64_t block = 0; block < num_blocks; ++block) {
int64_t channel_idx = block % axis_size;
float inv_scale = 1.0f / static_cast<float>(scale_data[channel_idx]);
int32_t zp = static_cast<int32_t>(zero_point_data[channel_idx]);

const float* in_ptr = input_data_ptr + block * axis_block_size;
uint8_t* out_ptr = out_data_ptr + block * axis_block_size;

quantize_arm<uint8_t>(
in_ptr,
out_ptr,
axis_block_size,
inv_scale,
zp,
static_cast<int32_t>(quant_min),
static_cast<int32_t>(quant_max));
}
return out;
} else if (dtype == ScalarType::Char) {
auto* out_data_ptr = out.mutable_data_ptr<int8_t>();
const auto* input_data_ptr = input.const_data_ptr<float>();

// Process each contiguous block (which shares the same scale/zero_point)
for (int64_t block = 0; block < num_blocks; ++block) {
int64_t channel_idx = block % axis_size;
float inv_scale = 1.0f / static_cast<float>(scale_data[channel_idx]);
int32_t zp = static_cast<int32_t>(zero_point_data[channel_idx]);

const float* in_ptr = input_data_ptr + block * axis_block_size;
int8_t* out_ptr = out_data_ptr + block * axis_block_size;

quantize_arm<int8_t>(
in_ptr,
out_ptr,
axis_block_size,
inv_scale,
zp,
static_cast<int32_t>(quant_min),
static_cast<int32_t>(quant_max));
}
return out;
}
}
#endif

// Fallback scalar implementation
#define QUANTIZE_IMPL(CTYPE_IN, CTYPE_OUT, out_dtype) \
case ScalarType::out_dtype: { \
auto* out_data_ptr = out.mutable_data_ptr<CTYPE_OUT>(); \
const auto* input_data_ptr = input.const_data_ptr<CTYPE_IN>(); \
const int64_t input_numel = input.numel(); \
/* Single loop over all elements */ \
for (int64_t i = 0; i < input_numel; i++) { \
/* Calculate which channel this element belongs to */ \
int64_t channel_idx = (i / axis_block_size) % axis_size; \
/* Get quantization parameters for this channel */ \
double _scale = scale_data[channel_idx]; \
int64_t _zero_point = zero_point_data[channel_idx]; \
/* Apply quantization */ \
out_data_ptr[i] = quantize_val<CTYPE_OUT, CTYPE_IN>( \
_scale, _zero_point, input_data_ptr[i], quant_min, quant_max); \
} \
} break;

#define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype) \