Skip to content

Commit 18df18e

Browse files
committed
backport: Implement OneOf Input Objects via @OneOf directive
Replicates graphql/graphql-js@29144f7
1 parent a4d66f5 commit 18df18e

25 files changed

+717
-10
lines changed

.mypy.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[mypy]
2-
python_version = 3.7
2+
python_version = 3.9
33
check_untyped_defs = True
44
no_implicit_optional = True
55
strict_optional = True

src/graphql/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,7 @@
256256
GraphQLSkipDirective,
257257
GraphQLDeprecatedDirective,
258258
GraphQLSpecifiedByDirective,
259+
GraphQLOneOfDirective,
259260
# "Enum" of Type Kinds
260261
TypeKind,
261262
# Constant Deprecation Reason
@@ -488,6 +489,7 @@
488489
"GraphQLSkipDirective",
489490
"GraphQLDeprecatedDirective",
490491
"GraphQLSpecifiedByDirective",
492+
"GraphQLOneOfDirective",
491493
"TypeKind",
492494
"DEFAULT_DEPRECATION_REASON",
493495
"introspection_types",

src/graphql/execution/values.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,11 @@ def coerce_variable_values(
122122
continue
123123

124124
def on_input_value_error(
125-
path: List[Union[str, int]], invalid_value: Any, error: GraphQLError
125+
path: List[Union[str, int]],
126+
invalid_value: Any,
127+
error: GraphQLError,
128+
var_name: str = var_name,
129+
var_def_node: VariableDefinitionNode = var_def_node,
126130
) -> None:
127131
invalid_str = inspect(invalid_value)
128132
prefix = f"Variable '${var_name}' got invalid value {invalid_str}"
@@ -196,7 +200,8 @@ def get_argument_values(
196200
value_node,
197201
)
198202
continue # pragma: no cover
199-
is_null = variable_values[variable_name] is None
203+
variable_value = variable_values[variable_name]
204+
is_null = variable_value is None or variable_value is Undefined
200205

201206
if is_null and is_non_null_type(arg_type):
202207
raise GraphQLError(

src/graphql/type/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@
132132
GraphQLSkipDirective,
133133
GraphQLDeprecatedDirective,
134134
GraphQLSpecifiedByDirective,
135+
GraphQLOneOfDirective,
135136
# Keyword Args
136137
GraphQLDirectiveKwargs,
137138
# Constant Deprecation Reason
@@ -276,6 +277,7 @@
276277
"GraphQLSkipDirective",
277278
"GraphQLDeprecatedDirective",
278279
"GraphQLSpecifiedByDirective",
280+
"GraphQLOneOfDirective",
279281
"GraphQLDirectiveKwargs",
280282
"DEFAULT_DEPRECATION_REASON",
281283
"is_specified_scalar_type",

src/graphql/type/definition.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1373,6 +1373,7 @@ def __copy__(self) -> "GraphQLEnumValue": # pragma: no cover
13731373
class GraphQLInputObjectTypeKwargs(GraphQLNamedTypeKwargs, total=False):
13741374
fields: GraphQLInputFieldMap
13751375
out_type: Optional[GraphQLInputFieldOutType]
1376+
is_one_of: bool
13761377

13771378

13781379
class GraphQLInputObjectType(GraphQLNamedType):
@@ -1402,6 +1403,7 @@ class GeoPoint(GraphQLInputObjectType):
14021403

14031404
ast_node: Optional[InputObjectTypeDefinitionNode]
14041405
extension_ast_nodes: Tuple[InputObjectTypeExtensionNode, ...]
1406+
is_one_of: bool
14051407

14061408
def __init__(
14071409
self,
@@ -1412,6 +1414,7 @@ def __init__(
14121414
extensions: Optional[Dict[str, Any]] = None,
14131415
ast_node: Optional[InputObjectTypeDefinitionNode] = None,
14141416
extension_ast_nodes: Optional[Collection[InputObjectTypeExtensionNode]] = None,
1417+
is_one_of: bool = False,
14151418
) -> None:
14161419
super().__init__(
14171420
name=name,
@@ -1437,6 +1440,7 @@ def __init__(
14371440
self._fields = fields
14381441
if out_type is not None:
14391442
self.out_type = out_type # type: ignore
1443+
self.is_one_of = is_one_of
14401444

14411445
@staticmethod
14421446
def out_type(value: Dict[str, Any]) -> Any:
@@ -1456,6 +1460,7 @@ def to_kwargs(self) -> GraphQLInputObjectTypeKwargs:
14561460
if self.out_type is GraphQLInputObjectType.out_type
14571461
else self.out_type
14581462
),
1463+
is_one_of=self.is_one_of,
14591464
)
14601465

14611466
def __copy__(self) -> "GraphQLInputObjectType": # pragma: no cover

src/graphql/type/directives.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
"GraphQLSkipDirective",
2323
"GraphQLDeprecatedDirective",
2424
"GraphQLSpecifiedByDirective",
25+
"GraphQLOneOfDirective",
2526
"DirectiveLocation",
2627
"DEFAULT_DEPRECATION_REASON",
2728
]
@@ -237,12 +238,21 @@ def assert_directive(directive: Any) -> GraphQLDirective:
237238
description="Exposes a URL that specifies the behavior of this scalar.",
238239
)
239240

241+
# Used to declare an Input Object as a OneOf Input Objects.
242+
GraphQLOneOfDirective = GraphQLDirective(
243+
name="oneOf",
244+
locations=[DirectiveLocation.INPUT_OBJECT],
245+
args={},
246+
description="Indicates an Input Object is a OneOf Input Object.",
247+
)
248+
240249

241250
specified_directives: Tuple[GraphQLDirective, ...] = (
242251
GraphQLIncludeDirective,
243252
GraphQLSkipDirective,
244253
GraphQLDeprecatedDirective,
245254
GraphQLSpecifiedByDirective,
255+
GraphQLOneOfDirective,
246256
)
247257
"""A tuple with all directives from the GraphQL specification"""
248258

src/graphql/type/introspection.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,7 @@ def args(directive, _info, includeDeprecated=False):
282282
resolve=TypeResolvers.input_fields,
283283
),
284284
"ofType": GraphQLField(__Type, resolve=TypeResolvers.of_type),
285+
"isOneOf": GraphQLField(GraphQLBoolean, resolve=TypeResolvers.is_one_of),
285286
},
286287
)
287288

@@ -368,6 +369,10 @@ def input_fields(type_, _info, includeDeprecated=False):
368369
def of_type(type_, _info):
369370
return getattr(type_, "of_type", None)
370371

372+
@staticmethod
373+
def is_one_of(type_, _info):
374+
return type_.is_one_of if is_input_object_type(type_) else None
375+
371376

372377
__Field: GraphQLObjectType = GraphQLObjectType(
373378
name="__Field",

src/graphql/type/validate.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
)
1313

1414
from ..error import GraphQLError
15-
from ..pyutils import inspect
15+
from ..pyutils import Undefined, inspect
1616
from ..language import (
1717
DirectiveNode,
1818
InputValueDefinitionNode,
@@ -493,6 +493,28 @@ def validate_input_fields(self, input_obj: GraphQLInputObjectType) -> None:
493493
],
494494
)
495495

496+
if input_obj.is_one_of:
497+
self.validate_one_of_input_object_field(input_obj, field_name, field)
498+
499+
def validate_one_of_input_object_field(
500+
self,
501+
type_: GraphQLInputObjectType,
502+
field_name: str,
503+
field: GraphQLInputField,
504+
) -> None:
505+
if is_non_null_type(field.type):
506+
self.report_error(
507+
f"OneOf input field {type_.name}.{field_name} must be nullable.",
508+
field.ast_node and field.ast_node.type,
509+
)
510+
511+
if field.default_value is not Undefined:
512+
self.report_error(
513+
f"OneOf input field {type_.name}.{field_name}"
514+
" cannot have a default value.",
515+
field.ast_node,
516+
)
517+
496518

497519
def get_operation_type_node(
498520
schema: GraphQLSchema, operation: OperationType

src/graphql/utilities/coerce_input_value.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,30 @@ def coerce_input_value(
127127
+ did_you_mean(suggestions)
128128
),
129129
)
130+
131+
if type_.is_one_of:
132+
keys = list(coerced_dict)
133+
if len(keys) != 1:
134+
on_error(
135+
path.as_list() if path else [],
136+
input_value,
137+
GraphQLError(
138+
"Exactly one key must be specified"
139+
f" for OneOf type '{type_.name}'.",
140+
),
141+
)
142+
else:
143+
key = keys[0]
144+
value = coerced_dict[key]
145+
if value is None:
146+
on_error(
147+
(path.as_list() if path else []) + [key],
148+
value,
149+
GraphQLError(
150+
f"Field '{key}' must be non-null.",
151+
),
152+
)
153+
130154
return type_.out_type(coerced_dict)
131155

132156
if is_leaf_type(type_):

src/graphql/utilities/extend_schema.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
GraphQLSchema,
6868
GraphQLSchemaKwargs,
6969
GraphQLSpecifiedByDirective,
70+
GraphQLOneOfDirective,
7071
GraphQLType,
7172
GraphQLUnionType,
7273
assert_schema,
@@ -605,6 +606,7 @@ def build_input_object_type(
605606
fields=lambda: build_input_field_map(all_nodes),
606607
ast_node=ast_node,
607608
extension_ast_nodes=extension_nodes,
609+
is_one_of=is_one_of(ast_node),
608610
)
609611

610612
build_type_for_kind = cast(
@@ -698,3 +700,10 @@ def get_specified_by_url(
698700

699701
specified_by_url = get_directive_values(GraphQLSpecifiedByDirective, node)
700702
return specified_by_url["url"] if specified_by_url else None
703+
704+
705+
def is_one_of(node: InputObjectTypeDefinitionNode) -> bool:
706+
"""Given an input object node, returns if the node should be OneOf."""
707+
from ..execution import get_directive_values
708+
709+
return get_directive_values(GraphQLOneOfDirective, node) is not None

0 commit comments

Comments
 (0)