Skip to content

Commit c539eed

Browse files
committed
Update caching logic
1 parent cd0660f commit c539eed

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

src/openai/_models.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
override,
2020
runtime_checkable,
2121
)
22+
from weakref import WeakKeyDictionary
2223

2324
import pydantic
2425
from pydantic.fields import FieldInfo
@@ -77,6 +78,8 @@
7778

7879
ReprArgs = Sequence[Tuple[Optional[str], Any]]
7980

81+
_DISCRIMINATOR_CACHE: "WeakKeyDictionary[type, DiscriminatorDetails]" = WeakKeyDictionary()
82+
8083

8184
@runtime_checkable
8285
class _ConfigProtocol(Protocol):
@@ -593,11 +596,6 @@ def construct_type(*, value: object, type_: object, metadata: Optional[List[Any]
593596
return value
594597

595598

596-
@runtime_checkable
597-
class CachedDiscriminatorType(Protocol):
598-
__discriminator__: DiscriminatorDetails
599-
600-
601599
class DiscriminatorDetails:
602600
field_name: str
603601
"""The name of the discriminator field in the variant class, e.g.
@@ -640,8 +638,9 @@ def __init__(
640638

641639

642640
def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None:
643-
if isinstance(union, CachedDiscriminatorType):
644-
return union.__discriminator__
641+
cached_discriminator = _DISCRIMINATOR_CACHE.get(union)
642+
if cached_discriminator is not None:
643+
return cached_discriminator
645644

646645
discriminator_field_name: str | None = None
647646

@@ -694,7 +693,7 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any,
694693
discriminator_field=discriminator_field_name,
695694
discriminator_alias=discriminator_alias,
696695
)
697-
cast(CachedDiscriminatorType, Annotated[union, details])
696+
_DISCRIMINATOR_CACHE[union] = details
698697
return details
699698

700699

0 commit comments

Comments
 (0)