@@ -7,6 +7,7 @@
  */
 
 #include <executorch/runtime/kernel/kernel_includes.h>
+#include <executorch/runtime/kernel/thread_parallel_interface.h>
 #include <algorithm>
 #include <cinttypes>
 #include <cmath>
@@ -461,51 +462,104 @@ Tensor& quantize_per_channel_out( |
 #if defined(__aarch64__) || defined(__ARM_NEON__)
   if (input.scalar_type() == ScalarType::Float) {
     const int64_t num_blocks = input.numel() / axis_block_size;
+    const int64_t total_elements = input.numel();
+    constexpr int64_t MIN_ELEMENTS_FOR_PARALLEL = 512;
+    const bool use_parallel = (total_elements >= MIN_ELEMENTS_FOR_PARALLEL);
 
     if (dtype == ScalarType::Byte) {
       auto* out_data_ptr = out.mutable_data_ptr<uint8_t>();
       const auto* input_data_ptr = input.const_data_ptr<float>();
 
-      // Process each contiguous block (which shares the same scale/zero_point)
-      for (int64_t block = 0; block < num_blocks; ++block) {
-        int64_t channel_idx = block % axis_size;
-        float inv_scale = 1.0f / static_cast<float>(scale_data[channel_idx]);
-        int32_t zp = static_cast<int32_t>(zero_point_data[channel_idx]);
-
-        const float* in_ptr = input_data_ptr + block * axis_block_size;
-        uint8_t* out_ptr = out_data_ptr + block * axis_block_size;
-
-        quantize_arm<uint8_t>(
-            in_ptr,
-            out_ptr,
-            axis_block_size,
-            inv_scale,
-            zp,
-            static_cast<int32_t>(quant_min),
-            static_cast<int32_t>(quant_max));
+      if (use_parallel) {
+        ::executorch::extension::parallel_for(
+            0, num_blocks, 1, [&](const int64_t begin, const int64_t end) {
+              for (int64_t block = begin; block < end; ++block) {
+                int64_t channel_idx = block % axis_size;
+                float inv_scale =
+                    1.0f / static_cast<float>(scale_data[channel_idx]);
+                int32_t zp = static_cast<int32_t>(zero_point_data[channel_idx]);
+
+                const float* in_ptr = input_data_ptr + block * axis_block_size;
+                uint8_t* out_ptr = out_data_ptr + block * axis_block_size;
+
+                quantize_arm<uint8_t>(
+                    in_ptr,
+                    out_ptr,
+                    axis_block_size,
+                    inv_scale,
+                    zp,
+                    static_cast<int32_t>(quant_min),
+                    static_cast<int32_t>(quant_max));
+              }
+            });
+      } else {
+        // Process each contiguous block (which shares the same
+        // scale/zero_point)
+        for (int64_t block = 0; block < num_blocks; ++block) {
+          int64_t channel_idx = block % axis_size;
+          float inv_scale = 1.0f / static_cast<float>(scale_data[channel_idx]);
+          int32_t zp = static_cast<int32_t>(zero_point_data[channel_idx]);
+
+          const float* in_ptr = input_data_ptr + block * axis_block_size;
+          uint8_t* out_ptr = out_data_ptr + block * axis_block_size;
+
+          quantize_arm<uint8_t>(
+              in_ptr,
+              out_ptr,
+              axis_block_size,
+              inv_scale,
+              zp,
+              static_cast<int32_t>(quant_min),
+              static_cast<int32_t>(quant_max));
+        }
       }
       return out;
     } else if (dtype == ScalarType::Char) {
       auto* out_data_ptr = out.mutable_data_ptr<int8_t>();
       const auto* input_data_ptr = input.const_data_ptr<float>();
 
-      // Process each contiguous block (which shares the same scale/zero_point)
-      for (int64_t block = 0; block < num_blocks; ++block) {
-        int64_t channel_idx = block % axis_size;
-        float inv_scale = 1.0f / static_cast<float>(scale_data[channel_idx]);
-        int32_t zp = static_cast<int32_t>(zero_point_data[channel_idx]);
-
-        const float* in_ptr = input_data_ptr + block * axis_block_size;
-        int8_t* out_ptr = out_data_ptr + block * axis_block_size;
-
-        quantize_arm<int8_t>(
-            in_ptr,
-            out_ptr,
-            axis_block_size,
-            inv_scale,
-            zp,
-            static_cast<int32_t>(quant_min),
-            static_cast<int32_t>(quant_max));
+      if (use_parallel) {
+        ::executorch::extension::parallel_for(
+            0, num_blocks, 1, [&](const int64_t begin, const int64_t end) {
+              for (int64_t block = begin; block < end; ++block) {
+                int64_t channel_idx = block % axis_size;
+                float inv_scale =
+                    1.0f / static_cast<float>(scale_data[channel_idx]);
+                int32_t zp = static_cast<int32_t>(zero_point_data[channel_idx]);
+
+                const float* in_ptr = input_data_ptr + block * axis_block_size;
+                int8_t* out_ptr = out_data_ptr + block * axis_block_size;
+
+                quantize_arm<int8_t>(
+                    in_ptr,
+                    out_ptr,
+                    axis_block_size,
+                    inv_scale,
+                    zp,
+                    static_cast<int32_t>(quant_min),
+                    static_cast<int32_t>(quant_max));
+              }
+            });
+      } else {
+        // Process each contiguous block (which shares the same
+        // scale/zero_point)
+        for (int64_t block = 0; block < num_blocks; ++block) {
+          int64_t channel_idx = block % axis_size;
+          float inv_scale = 1.0f / static_cast<float>(scale_data[channel_idx]);
+          int32_t zp = static_cast<int32_t>(zero_point_data[channel_idx]);
+
+          const float* in_ptr = input_data_ptr + block * axis_block_size;
+          int8_t* out_ptr = out_data_ptr + block * axis_block_size;
+
+          quantize_arm<int8_t>(
+              in_ptr,
+              out_ptr,
+              axis_block_size,
+              inv_scale,
+              zp,
+              static_cast<int32_t>(quant_min),
+              static_cast<int32_t>(quant_max));
+        }
       }
       return out;
     }
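
For readers skimming the patch: it keeps the existing NEON path but splits the per-block loop across threads once the tensor has at least 512 elements, with each contiguous block of `axis_block_size` values sharing the scale/zero_point of channel `block % axis_size`. Below is a minimal, self-contained sketch of that same pattern outside ExecuTorch; `quantize_block`, the std::thread range splitting, and the example sizes are illustrative stand-ins for `quantize_arm<T>()` and `::executorch::extension::parallel_for`, not APIs from this codebase.

```cpp
// Standalone illustration (not ExecuTorch code): per-channel quantization of
// a contiguous float buffer. Each block of axis_block_size elements shares
// one (scale, zero_point) pair, selected by block % axis_size.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <thread>
#include <vector>

// Hypothetical scalar reference for one block: q = clamp(round(x / scale) +
// zero_point, qmin, qmax). The patch calls the NEON helper quantize_arm<T>().
void quantize_block(
    const float* in,
    uint8_t* out,
    int64_t n,
    float inv_scale,
    int32_t zp,
    int32_t qmin,
    int32_t qmax) {
  for (int64_t i = 0; i < n; ++i) {
    int32_t q = static_cast<int32_t>(std::nearbyint(in[i] * inv_scale)) + zp;
    out[i] = static_cast<uint8_t>(std::min(std::max(q, qmin), qmax));
  }
}

int main() {
  const int64_t axis_size = 4; // channels along the quantized axis
  const int64_t axis_block_size = 256; // contiguous elements per block
  const int64_t num_blocks = 8; // input.numel() / axis_block_size
  std::vector<float> input(num_blocks * axis_block_size, 0.5f);
  std::vector<uint8_t> output(input.size());
  std::vector<double> scales = {0.1, 0.2, 0.05, 0.3};
  std::vector<int64_t> zero_points = {0, 10, 128, 5};

  // Same heuristic as the patch: only parallelize when the tensor is large
  // enough for the thread launch to pay off (512 is the patch's threshold).
  const bool use_parallel = input.size() >= 512u;

  auto run_range = [&](int64_t begin, int64_t end) {
    for (int64_t block = begin; block < end; ++block) {
      int64_t channel_idx = block % axis_size;
      float inv_scale = 1.0f / static_cast<float>(scales[channel_idx]);
      int32_t zp = static_cast<int32_t>(zero_points[channel_idx]);
      quantize_block(
          input.data() + block * axis_block_size,
          output.data() + block * axis_block_size,
          axis_block_size,
          inv_scale,
          zp,
          /*qmin=*/0,
          /*qmax=*/255);
    }
  };

  if (use_parallel) {
    // Crude stand-in for parallel_for: split the block range across threads.
    int64_t nthreads =
        std::max<int64_t>(1, std::thread::hardware_concurrency());
    nthreads = std::min(nthreads, num_blocks);
    const int64_t chunk = (num_blocks + nthreads - 1) / nthreads;
    std::vector<std::thread> threads;
    for (int64_t t = 0; t < nthreads; ++t) {
      const int64_t begin = t * chunk;
      const int64_t end = std::min(num_blocks, begin + chunk);
      if (begin < end) {
        threads.emplace_back(run_range, begin, end);
      }
    }
    for (auto& th : threads) {
      th.join();
    }
  } else {
    run_range(0, num_blocks);
  }
  return 0;
}
```

Each block writes only its own output slice and only reads the shared scale/zero_point tables, so the range split needs no synchronization; the element-count threshold simply avoids paying thread-launch overhead on small tensors.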