diff --git a/mypyc/irbuild/for_helpers.py b/mypyc/irbuild/for_helpers.py
index 715f5432cd13..4b3645d909cc 100644
--- a/mypyc/irbuild/for_helpers.py
+++ b/mypyc/irbuild/for_helpers.py
@@ -7,7 +7,7 @@
 
 from __future__ import annotations
 
-from typing import Callable, ClassVar
+from typing import Any, Callable, ClassVar, cast
 
 from mypy.nodes import (
     ARG_POS,
@@ -1184,8 +1184,10 @@ def init(self, indexes: list[Lvalue], exprs: list[Expression]) -> None:
             )
             self.gens.append(gen)
 
+        self.conditions, self.cond_blocks = self.__sort_conditions()
+
     def gen_condition(self) -> None:
-        for i, gen in enumerate(self.gens):
+        for i, gen in enumerate(self.conditions):
             gen.gen_condition()
             if i < len(self.gens) - 1:
                 self.builder.activate_block(self.cond_blocks[i])
@@ -1202,6 +1204,114 @@ def gen_cleanup(self) -> None:
         for gen in self.gens:
             gen.gen_cleanup()
 
+    def __sort_conditions(self) -> tuple[list[ForGenerator], list[BasicBlock]]:
+        """Reorder the zipped generators so the cheapest exhaustion checks run first.
+
+        We don't necessarily need to check the gens in order, we just need
+        to know which gen ends first. Some gens are quicker to check than
+        others, so the specialized ForHelpers are checked before any generic
+        ForIterable.
+
+        Returns the reordered generators and the reordered condition blocks;
+        self.body_block is always the last block.
+        """
+        gens = self.gens
+        cond_blocks = self.cond_blocks[:]
+        cond_blocks.remove(self.body_block)
+        # NOTE(review): init() is not fully visible in this hunk. If self.cond_blocks
+        # holds exactly one block per generator with self.body_block last, then
+        # cond_blocks is now one shorter than gens and every zip(gens, cond_blocks)
+        # below silently drops the final generator -- confirm before merging.
+
+        def check_type(obj: Any, typ: type[ForGenerator]) -> bool:
+            # ForEnumerate gen_condition is as fast as its underlying generator's
+            return isinstance(obj, typ) or (
+                isinstance(obj, ForEnumerate) and isinstance(obj.main_gen, typ)
+            )
+
+        # these are slowest, they invoke Python's iteration protocol
+        for_iterable = [
+            (g, block) for g, block in zip(gens, cond_blocks) if check_type(g, ForIterable)
+        ]
+
+        # These aren't the slowest but they're slow: we need to pack an RTuple,
+        # then get an item and do a comparison
+        for_dict = [
+            (g, block) for g, block in zip(gens, cond_blocks) if check_type(g, ForDictionaryCommon)
+        ]
+
+        # These are faster than ForIterable but not as fast as others (faster than ForDict?)
+        for_native = [
+            (g, block) for g, block in zip(gens, cond_blocks) if check_type(g, ForNativeGenerator)
+        ]
+
+        # forward involves in the best case one Py_ssize_t comparison, else one length
+        # check + the comparison
+        # reverse is slightly slower than forward, with one extra check
+        for_sequence_reverse_with_len_check = [
+            (g, block)
+            for g, block in zip(gens, cond_blocks)
+            if check_type(g, ForSequence)
+            and (
+                for_seq := cast(ForSequence, g.main_gen if isinstance(g, ForEnumerate) else g)
+            ).reverse
+            and for_seq.length_reg is not None
+        ]
+        for_sequence_reverse_no_len_check = [
+            (g, block)
+            for g, block in zip(gens, cond_blocks)
+            if check_type(g, ForSequence)
+            and (
+                for_seq := cast(ForSequence, g.main_gen if isinstance(g, ForEnumerate) else g)
+            ).reverse
+            and for_seq.length_reg is None
+        ]
+        for_sequence_forward_with_len_check = [
+            (g, block)
+            for g, block in zip(gens, cond_blocks)
+            if check_type(g, ForSequence)
+            and not (
+                for_seq := cast(ForSequence, g.main_gen if isinstance(g, ForEnumerate) else g)
+            ).reverse
+            and for_seq.length_reg is not None
+        ]
+        for_sequence_forward_no_len_check = [
+            (g, block)
+            for g, block in zip(gens, cond_blocks)
+            if check_type(g, ForSequence)
+            and not (
+                for_seq := cast(ForSequence, g.main_gen if isinstance(g, ForEnumerate) else g)
+            ).reverse
+            and for_seq.length_reg is None
+        ]
+
+        # these are really fast, just a C int equality check
+        for_range = [(g, block) for g, block in zip(gens, cond_blocks) if isinstance(g, ForRange)]
+
+        ordered = (
+            for_range
+            + for_sequence_forward_no_len_check
+            + for_sequence_reverse_no_len_check
+            + for_sequence_forward_with_len_check
+            + for_sequence_reverse_with_len_check
+            + for_native
+            + for_dict
+        )
+
+        # this is a failsafe for ForHelper classes which might have been added after
+        # this commit but not added to this function's code
+        leftovers = [
+            g_and_block
+            for g_and_block in zip(gens, cond_blocks)
+            if g_and_block not in ordered + for_iterable
+        ]
+
+        gens_and_blocks = ordered + leftovers + for_iterable
+        conditions = [g for (g, block) in gens_and_blocks]
+        cond_blocks = [block for (g, block) in gens_and_blocks] + [self.body_block]
+
+        return conditions, cond_blocks
+
 
 def get_expr_length(builder: IRBuilder, expr: Expression) -> int | None:
     if isinstance(expr, (StrExpr, BytesExpr)):