1+ import enum
12import warnings
23from functools import partial
34
4- import six
55from promise import Promise , is_thenable
66from sqlalchemy .orm .query import Query
77
88from graphene import NonNull
99from graphene .relay import Connection , ConnectionField
10- from graphene .relay .connection import PageInfo
11- from graphql_relay .connection .arrayconnection import connection_from_list_slice
10+ from graphene .relay .connection import connection_adapter , page_info_adapter
11+ from graphql_relay .connection .arrayconnection import \
12+ connection_from_array_slice
1213
1314from .batching import get_batch_resolver
14- from .utils import get_query
15+ from .utils import EnumValue , get_query
1516
1617
1718class UnsortedSQLAlchemyConnectionField (ConnectionField ):
1819 @property
1920 def type (self ):
2021 from .types import SQLAlchemyObjectType
2122
22- _type = super (ConnectionField , self ).type
23- nullable_type = get_nullable_type (_type )
23+ type_ = super (ConnectionField , self ).type
24+ nullable_type = get_nullable_type (type_ )
2425 if issubclass (nullable_type , Connection ):
25- return _type
26+ return type_
2627 assert issubclass (nullable_type , SQLAlchemyObjectType ), (
2728 "SQLALchemyConnectionField only accepts SQLAlchemyObjectType types, not {}"
2829 ).format (nullable_type .__name__ )
@@ -31,7 +32,7 @@ def type(self):
3132 ), "The type {} doesn't have a connection" .format (
3233 nullable_type .__name__
3334 )
34- assert _type == nullable_type , (
35+ assert type_ == nullable_type , (
3536 "Passing a SQLAlchemyObjectType instance is deprecated. "
3637 "Pass the connection type instead accessible via SQLAlchemyObjectType.connection"
3738 )
@@ -53,15 +54,19 @@ def resolve_connection(cls, connection_type, model, info, args, resolved):
5354 _len = resolved .count ()
5455 else :
5556 _len = len (resolved )
56- connection = connection_from_list_slice (
57- resolved ,
58- args ,
57+
58+ def adjusted_connection_adapter (edges , pageInfo ):
59+ return connection_adapter (connection_type , edges , pageInfo )
60+
61+ connection = connection_from_array_slice (
62+ array_slice = resolved ,
63+ args = args ,
5964 slice_start = 0 ,
60- list_length = _len ,
61- list_slice_length = _len ,
62- connection_type = connection_type ,
63- pageinfo_type = PageInfo ,
65+ array_length = _len ,
66+ array_slice_length = _len ,
67+ connection_type = adjusted_connection_adapter ,
6468 edge_type = connection_type .Edge ,
69+ page_info_type = page_info_adapter ,
6570 )
6671 connection .iterable = resolved
6772 connection .length = _len
@@ -77,7 +82,7 @@ def connection_resolver(cls, resolver, connection_type, model, root, info, **arg
7782
7883 return on_resolve (resolved )
7984
80- def get_resolver (self , parent_resolver ):
85+ def wrap_resolve (self , parent_resolver ):
8186 return partial (
8287 self .connection_resolver ,
8388 parent_resolver ,
@@ -88,8 +93,8 @@ def get_resolver(self, parent_resolver):
8893
8994# TODO Rename this to SortableSQLAlchemyConnectionField
9095class SQLAlchemyConnectionField (UnsortedSQLAlchemyConnectionField ):
91- def __init__ (self , type , * args , ** kwargs ):
92- nullable_type = get_nullable_type (type )
96+ def __init__ (self , type_ , * args , ** kwargs ):
97+ nullable_type = get_nullable_type (type_ )
9398 if "sort" not in kwargs and issubclass (nullable_type , Connection ):
9499 # Let super class raise if type is not a Connection
95100 try :
@@ -103,16 +108,25 @@ def __init__(self, type, *args, **kwargs):
103108 )
104109 elif "sort" in kwargs and kwargs ["sort" ] is None :
105110 del kwargs ["sort" ]
106- super (SQLAlchemyConnectionField , self ).__init__ (type , * args , ** kwargs )
111+ super (SQLAlchemyConnectionField , self ).__init__ (type_ , * args , ** kwargs )
107112
108113 @classmethod
109114 def get_query (cls , model , info , sort = None , ** args ):
110115 query = get_query (model , info .context )
111116 if sort is not None :
112- if isinstance (sort , six .string_types ):
113- query = query .order_by (sort .value )
114- else :
115- query = query .order_by (* (col .value for col in sort ))
117+ if not isinstance (sort , list ):
118+ sort = [sort ]
119+ sort_args = []
120+ # ensure consistent handling of graphene Enums, enum values and
121+ # plain strings
122+ for item in sort :
123+ if isinstance (item , enum .Enum ):
124+ sort_args .append (item .value .value )
125+ elif isinstance (item , EnumValue ):
126+ sort_args .append (item .value )
127+ else :
128+ sort_args .append (item )
129+ query = query .order_by (* sort_args )
116130 return query
117131
118132
@@ -123,7 +137,7 @@ class BatchSQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField):
123137 Use at your own risk.
124138 """
125139
126- def get_resolver (self , parent_resolver ):
140+ def wrap_resolve (self , parent_resolver ):
127141 return partial (
128142 self .connection_resolver ,
129143 self .resolver ,
@@ -148,13 +162,13 @@ def default_connection_field_factory(relationship, registry, **field_kwargs):
148162__connectionFactory = UnsortedSQLAlchemyConnectionField
149163
150164
151- def createConnectionField (_type , ** field_kwargs ):
165+ def createConnectionField (type_ , ** field_kwargs ):
152166 warnings .warn (
153167 'createConnectionField is deprecated and will be removed in the next '
154168 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.' ,
155169 DeprecationWarning ,
156170 )
157- return __connectionFactory (_type , ** field_kwargs )
171+ return __connectionFactory (type_ , ** field_kwargs )
158172
159173
160174def registerConnectionFieldFactory (factoryMethod ):
0 commit comments