Commit e3b4dba

[Executorch] Add multithreading for op_quantize
Pull Request resolved: #15609

As the title.

Differential Revision: [D84962233](https://our.internmc.facebook.com/intern/diff/D84962233/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments; please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D84962233/)!

ghstack-source-id: 321455129
1 parent: d06d47b

2 files changed: +91 −34 lines

kernels/quantized/cpu/op_quantize.cpp

Lines changed: 88 additions & 34 deletions
@@ -7,6 +7,7 @@
  */

 #include <executorch/runtime/kernel/kernel_includes.h>
+#include <executorch/runtime/kernel/thread_parallel_interface.h>
 #include <algorithm>
 #include <cinttypes>
 #include <cmath>
@@ -461,51 +462,104 @@ Tensor& quantize_per_channel_out(
 #if defined(__aarch64__) || defined(__ARM_NEON__)
   if (input.scalar_type() == ScalarType::Float) {
     const int64_t num_blocks = input.numel() / axis_block_size;
+    const int64_t total_elements = input.numel();
+    constexpr int64_t MIN_ELEMENTS_FOR_PARALLEL = 512;
+    const bool use_parallel = (total_elements >= MIN_ELEMENTS_FOR_PARALLEL);

     if (dtype == ScalarType::Byte) {
       auto* out_data_ptr = out.mutable_data_ptr<uint8_t>();
       const auto* input_data_ptr = input.const_data_ptr<float>();

-      // Process each contiguous block (which shares the same scale/zero_point)
-      for (int64_t block = 0; block < num_blocks; ++block) {
-        int64_t channel_idx = block % axis_size;
-        float inv_scale = 1.0f / static_cast<float>(scale_data[channel_idx]);
-        int32_t zp = static_cast<int32_t>(zero_point_data[channel_idx]);
-
-        const float* in_ptr = input_data_ptr + block * axis_block_size;
-        uint8_t* out_ptr = out_data_ptr + block * axis_block_size;
-
-        quantize_arm<uint8_t>(
-            in_ptr,
-            out_ptr,
-            axis_block_size,
-            inv_scale,
-            zp,
-            static_cast<int32_t>(quant_min),
-            static_cast<int32_t>(quant_max));
+      if (use_parallel) {
+        ::executorch::extension::parallel_for(
+            0, num_blocks, 1, [&](const int64_t begin, const int64_t end) {
+              for (int64_t block = begin; block < end; ++block) {
+                int64_t channel_idx = block % axis_size;
+                float inv_scale =
+                    1.0f / static_cast<float>(scale_data[channel_idx]);
+                int32_t zp = static_cast<int32_t>(zero_point_data[channel_idx]);
+
+                const float* in_ptr = input_data_ptr + block * axis_block_size;
+                uint8_t* out_ptr = out_data_ptr + block * axis_block_size;
+
+                quantize_arm<uint8_t>(
+                    in_ptr,
+                    out_ptr,
+                    axis_block_size,
+                    inv_scale,
+                    zp,
+                    static_cast<int32_t>(quant_min),
+                    static_cast<int32_t>(quant_max));
+              }
+            });
+      } else {
+        // Process each contiguous block (which shares the same
+        // scale/zero_point)
+        for (int64_t block = 0; block < num_blocks; ++block) {
+          int64_t channel_idx = block % axis_size;
+          float inv_scale = 1.0f / static_cast<float>(scale_data[channel_idx]);
+          int32_t zp = static_cast<int32_t>(zero_point_data[channel_idx]);
+
+          const float* in_ptr = input_data_ptr + block * axis_block_size;
+          uint8_t* out_ptr = out_data_ptr + block * axis_block_size;
+
+          quantize_arm<uint8_t>(
+              in_ptr,
+              out_ptr,
+              axis_block_size,
+              inv_scale,
+              zp,
+              static_cast<int32_t>(quant_min),
+              static_cast<int32_t>(quant_max));
+        }
       }
       return out;
     } else if (dtype == ScalarType::Char) {
       auto* out_data_ptr = out.mutable_data_ptr<int8_t>();
       const auto* input_data_ptr = input.const_data_ptr<float>();

-      // Process each contiguous block (which shares the same scale/zero_point)
-      for (int64_t block = 0; block < num_blocks; ++block) {
-        int64_t channel_idx = block % axis_size;
-        float inv_scale = 1.0f / static_cast<float>(scale_data[channel_idx]);
-        int32_t zp = static_cast<int32_t>(zero_point_data[channel_idx]);
-
-        const float* in_ptr = input_data_ptr + block * axis_block_size;
-        int8_t* out_ptr = out_data_ptr + block * axis_block_size;
-
-        quantize_arm<int8_t>(
-            in_ptr,
-            out_ptr,
-            axis_block_size,
-            inv_scale,
-            zp,
-            static_cast<int32_t>(quant_min),
-            static_cast<int32_t>(quant_max));
+      if (use_parallel) {
+        ::executorch::extension::parallel_for(
+            0, num_blocks, 1, [&](const int64_t begin, const int64_t end) {
+              for (int64_t block = begin; block < end; ++block) {
+                int64_t channel_idx = block % axis_size;
+                float inv_scale =
+                    1.0f / static_cast<float>(scale_data[channel_idx]);
+                int32_t zp = static_cast<int32_t>(zero_point_data[channel_idx]);
+
+                const float* in_ptr = input_data_ptr + block * axis_block_size;
+                int8_t* out_ptr = out_data_ptr + block * axis_block_size;
+
+                quantize_arm<int8_t>(
+                    in_ptr,
+                    out_ptr,
+                    axis_block_size,
+                    inv_scale,
+                    zp,
+                    static_cast<int32_t>(quant_min),
+                    static_cast<int32_t>(quant_max));
+              }
+            });
+      } else {
+        // Process each contiguous block (which shares the same
+        // scale/zero_point)
+        for (int64_t block = 0; block < num_blocks; ++block) {
+          int64_t channel_idx = block % axis_size;
+          float inv_scale = 1.0f / static_cast<float>(scale_data[channel_idx]);
+          int32_t zp = static_cast<int32_t>(zero_point_data[channel_idx]);
+
+          const float* in_ptr = input_data_ptr + block * axis_block_size;
+          int8_t* out_ptr = out_data_ptr + block * axis_block_size;
+
+          quantize_arm<int8_t>(
+              in_ptr,
+              out_ptr,
+              axis_block_size,
+              inv_scale,
+              zp,
+              static_cast<int32_t>(quant_min),
+              static_cast<int32_t>(quant_max));
+        }
       }
       return out;
     }
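Both dtype branches apply the same dispatch pattern: a size threshold chooses between a threadpool-backed `parallel_for` over channel blocks and the original serial loop. The following is a minimal sketch of that pattern, assuming only the `parallel_for` declaration from `thread_parallel_interface.h` used in the diff; `quantize_block` and `quantize_all_blocks` are hypothetical names standing in for `quantize_arm<T>` and the surrounding kernel code:

```cpp
#include <cstdint>

#include <executorch/runtime/kernel/thread_parallel_interface.h>

// Hypothetical stand-in for quantize_arm<T>: quantizes one contiguous block
// that shares a single scale/zero_point pair.
void quantize_block(const float* in, uint8_t* out, int64_t n);

void quantize_all_blocks(
    const float* in, uint8_t* out, int64_t num_blocks, int64_t block_size) {
  // Same threshold as the diff: below 512 elements, threadpool dispatch
  // overhead is likely to outweigh any parallel speedup.
  constexpr int64_t kMinElementsForParallel = 512;
  if (num_blocks * block_size >= kMinElementsForParallel) {
    // Each worker receives a [begin, end) range of block indices. Blocks
    // write to disjoint output slices, so the lambda needs no locking.
    ::executorch::extension::parallel_for(
        0, num_blocks, /*grain_size=*/1,
        [&](const int64_t begin, const int64_t end) {
          for (int64_t b = begin; b < end; ++b) {
            quantize_block(
                in + b * block_size, out + b * block_size, block_size);
          }
        });
  } else {
    // Small tensors stay on the single-threaded fast path.
    for (int64_t b = 0; b < num_blocks; ++b) {
      quantize_block(in + b * block_size, out + b * block_size, block_size);
    }
  }
}
```

With a grain size of 1, the scheduling unit is a single channel block; correctness relies only on the blocks occupying disjoint output ranges.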

kernels/quantized/cpu/targets.bzl

Lines changed: 3 additions & 0 deletions
@@ -52,6 +52,9 @@ _QUANT_OPS = (
     ),
     op_target(
         name = "op_quantize",
+        deps = [
+            "//executorch/extension/threadpool:threadpool",
+        ],
     ),
 )
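The added dependency supplies the threadpool-backed implementation behind the `thread_parallel_interface.h` header newly included by op_quantize.cpp; without it, the `parallel_for` call above would presumably fail to resolve in the Buck build.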
