Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 103 additions & 2 deletions mypyc/irbuild/for_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from __future__ import annotations

from typing import Callable, ClassVar
from typing import Any, Callable, ClassVar, cast

from mypy.nodes import (
ARG_POS,
Expand Down Expand Up @@ -1184,8 +1184,10 @@ def init(self, indexes: list[Lvalue], exprs: list[Expression]) -> None:
)
self.gens.append(gen)

self.conditions, self.cond_blocks = self.__sort_conditions()

def gen_condition(self) -> None:
for i, gen in enumerate(self.gens):
for i, gen in enumerate(self.conditions):
gen.gen_condition()
if i < len(self.gens) - 1:
self.builder.activate_block(self.cond_blocks[i])
Expand All @@ -1202,6 +1204,105 @@ def gen_cleanup(self) -> None:
for gen in self.gens:
gen.gen_cleanup()

def __sort_conditions(self) -> tuple[list[ForGenerator], list[BasicBlock]]:
# We don't necessarily need to check the gens in order,
# we just need to know which gen ends first. Some gens
# are quicker to check than others, so we will check the
# specialized ForHelpers before we check any generic
# ForIterable

gens = self.gens
cond_blocks = self.cond_blocks[:]
cond_blocks.remove(self.body_block)

def check_type(obj: Any, typ: type[ForGenerator]) -> bool:
# ForEnumerate gen_condition is as fast as it's underlying generator's
return (
isinstance(obj, typ)
or isinstance(obj, ForEnumerate)
and isinstance(obj.main_gen, typ)
)

# these are slowest, they invoke Python's iteration protocol
for_iterable = [
(g, block) for g, block in zip(gens, cond_blocks) if check_type(g, ForSequence)
]

# These aren't the slowest but they're slow, we need to pack an RTuple and then get and item and do a comparison
for_dict = [
(g, block) for g, block in zip(gens, cond_blocks) if check_type(g, ForDictionaryCommon)
]

# These are faster than ForIterable but not as fast as others (faster than ForDict?)
for_native = [
(g, block) for g, block in zip(gens, cond_blocks) if check_type(g, ForNativeGenerator)
]

# forward involves in the best case one pyssize_t comparison, else one length check + the comparison
# reverse is slightly slower than forward, with one extra check
for_sequence_reverse_with_len_check = [
(g, block)
for g, block in zip(gens, cond_blocks)
if check_type(g, ForSequence)
and (
for_seq := cast(ForSequence, g.main_gen if isinstance(g, ForEnumerate) else g)
).reverse
and for_seq.length_reg is not None
]
for_sequence_reverse_no_len_check = [
(g, block)
for g, block in zip(gens, cond_blocks)
if check_type(g, ForSequence)
and (
for_seq := cast(ForSequence, g.main_gen if isinstance(g, ForEnumerate) else g)
).reverse
and for_seq.length_reg is None
]
for_sequence_forward_with_len_check = [
(g, block)
for g, block in zip(gens, cond_blocks)
if check_type(g, ForSequence)
and not (
for_seq := cast(ForSequence, g.main_gen if isinstance(g, ForEnumerate) else g)
).reverse
and for_seq.length_reg is not None
]
for_sequence_forward_no_len_check = [
(g, block)
for g, block in zip(gens, cond_blocks)
if check_type(g, ForSequence)
and not (
for_seq := cast(ForSequence, g.main_gen if isinstance(g, ForEnumerate) else g)
).reverse
and for_seq.length_reg is None
]

# these are really fast, just a C int equality check
for_range = [(g, block) for g, block in zip(gens, cond_blocks) if isinstance(g, ForRange)]

ordered = (
for_range
+ for_sequence_forward_no_len_check
+ for_sequence_reverse_no_len_check
+ for_sequence_forward_with_len_check
+ for_sequence_reverse_with_len_check
+ for_native
+ for_dict
)

# this is a failsafe for ForHelper classes which might have been added after this commit but not added to this function's code
leftovers = [
g_and_block
for g_and_block in zip(gens, cond_blocks)
if g_and_block not in ordered + for_iterable
]

gens_and_blocks = ordered + leftovers + for_iterable
conditions = [g for (g, block) in gens_and_blocks]
cond_blocks = [block for (g, block) in gens_and_blocks] + [self.body_block]

return conditions, cond_blocks


def get_expr_length(builder: IRBuilder, expr: Expression) -> int | None:
if isinstance(expr, (StrExpr, BytesExpr)):
Expand Down
Loading