|
2 | 2 |
|
3 | 3 | import os |
4 | 4 | import inspect |
5 | | -from typing import TYPE_CHECKING, Any, Annotated, Type, Tuple, Union, Generic, TypeVar, Callable, Optional, cast |
| 5 | +from typing import TYPE_CHECKING, Any, Type, Tuple, Union, Generic, TypeVar, Callable, Optional, cast |
| 6 | +from weakref import WeakKeyDictionary |
6 | 7 | from datetime import date, datetime |
7 | 8 | from typing_extensions import ( |
8 | 9 | List, |
|
77 | 78 |
|
78 | 79 | ReprArgs = Sequence[Tuple[Optional[str], Any]] |
79 | 80 |
|
| 81 | +_DISCRIMINATOR_CACHE: "WeakKeyDictionary[type, DiscriminatorDetails]" = WeakKeyDictionary() |
| 82 | + |
80 | 83 |
|
81 | 84 | @runtime_checkable |
82 | 85 | class _ConfigProtocol(Protocol): |
@@ -593,11 +596,6 @@ def construct_type(*, value: object, type_: object, metadata: Optional[List[Any] |
593 | 596 | return value |
594 | 597 |
|
595 | 598 |
|
596 | | -@runtime_checkable |
597 | | -class CachedDiscriminatorType(Protocol): |
598 | | - __discriminator__: DiscriminatorDetails |
599 | | - |
600 | | - |
601 | 599 | class DiscriminatorDetails: |
602 | 600 | field_name: str |
603 | 601 | """The name of the discriminator field in the variant class, e.g. |
@@ -640,8 +638,9 @@ def __init__( |
640 | 638 |
|
641 | 639 |
|
642 | 640 | def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None: |
643 | | - if isinstance(union, CachedDiscriminatorType): |
644 | | - return union.__discriminator__ |
| 641 | + cached_discriminator = _DISCRIMINATOR_CACHE.get(union) |
| 642 | + if cached_discriminator is not None: |
| 643 | + return cached_discriminator |
645 | 644 |
|
646 | 645 | discriminator_field_name: str | None = None |
647 | 646 |
|
@@ -694,7 +693,7 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, |
694 | 693 | discriminator_field=discriminator_field_name, |
695 | 694 | discriminator_alias=discriminator_alias, |
696 | 695 | ) |
697 | | - cast(CachedDiscriminatorType, Annotated[union, details]) |
| 696 | + _DISCRIMINATOR_CACHE[union] = details |
698 | 697 | return details |
699 | 698 |
|
700 | 699 |
|
|
0 commit comments