diff --git a/backends/test/suite/flows/arm.py b/backends/test/suite/flows/arm.py index a690e4681f8..db3bb2cfd9e 100644 --- a/backends/test/suite/flows/arm.py +++ b/backends/test/suite/flows/arm.py @@ -5,19 +5,22 @@ # Create flows for Arm Backends used to test operator and model suits +from collections.abc import Callable + from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec from executorch.backends.arm.quantizer import get_symmetric_quantization_config from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.arm_tester import ArmTester -from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec from executorch.backends.arm.util._factory import create_quantizer from executorch.backends.test.suite.flow import TestFlow from executorch.backends.xnnpack.test.tester.tester import Quantize def _create_arm_flow( - name, - compile_spec: ArmCompileSpec, + name: str, + compile_spec_factory: Callable[[], ArmCompileSpec], + support_serialize: bool = True, + quantize: bool = True, symmetric_io_quantization: bool = False, per_channel_quantization: bool = True, use_portable_ops: bool = True, @@ -25,24 +28,23 @@ def _create_arm_flow( ) -> TestFlow: def _create_arm_tester(*args, **kwargs) -> ArmTester: - kwargs["compile_spec"] = compile_spec + spec = compile_spec_factory() + kwargs["compile_spec"] = spec return ArmTester( *args, **kwargs, use_portable_ops=use_portable_ops, timeout=timeout ) - support_serialize = not isinstance(compile_spec, TosaCompileSpec) - quantize = compile_spec.tosa_spec.support_integer() - - if quantize is True: + if quantize: def create_quantize_stage() -> Quantize: - quantizer = create_quantizer(compile_spec) + spec = compile_spec_factory() + quantizer = create_quantizer(spec) quantization_config = get_symmetric_quantization_config( is_per_channel=per_channel_quantization ) if symmetric_io_quantization: quantizer.set_io(quantization_config) - return Quantize(quantizer, quantization_config) + return Quantize(quantizer, quantization_config) # type: ignore return TestFlow( name, @@ -50,23 +52,29 @@ def create_quantize_stage() -> Quantize: tester_factory=_create_arm_tester, supports_serialize=support_serialize, quantize=quantize, - quantize_stage_factory=(create_quantize_stage if quantize is True else False), + quantize_stage_factory=(create_quantize_stage if quantize else False), # type: ignore ) ARM_TOSA_FP_FLOW = _create_arm_flow( "arm_tosa_fp", - common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"), + lambda: common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"), + support_serialize=False, + quantize=False, ) ARM_TOSA_INT_FLOW = _create_arm_flow( "arm_tosa_int", - common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+INT"), + lambda: common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+INT"), + support_serialize=False, + quantize=True, ) ARM_ETHOS_U55_FLOW = _create_arm_flow( "arm_ethos_u55", - common.get_u55_compile_spec(), + lambda: common.get_u55_compile_spec(), + quantize=True, ) ARM_ETHOS_U85_FLOW = _create_arm_flow( "arm_ethos_u85", - common.get_u85_compile_spec(), + lambda: common.get_u85_compile_spec(), + quantize=True, )