1+ import json
12from typing import Any , List
23
34from ..error .illegal_attr_checker import IllegalAttrChecker
45from ..error .uncallable_namespace import UncallableNamespace
5- import json
66
77
88class GNNNodeClassificationRunner (UncallableNamespace , IllegalAttrChecker ):
9- def train (self , graph_name : str , model_name : str , feature_properties : List [str ], target_property : str ,
10- target_node_label : str = None , node_labels : List [str ] = None ) -> "Series[Any]" :
9+ def train (
10+ self ,
11+ graph_name : str ,
12+ model_name : str ,
13+ feature_properties : List [str ],
14+ target_property : str ,
15+ target_node_label : str = None ,
16+ node_labels : List [str ] = None ,
17+ ) -> "Series[Any]" :
1118 configMap = {
1219 "featureProperties" : feature_properties ,
1320 "targetProperty" : target_property ,
1421 "job_type" : "train" ,
1522 }
23+
1624 node_properties = feature_properties + [target_property ]
1725 if target_node_label :
1826 configMap ["targetNodeLabel" ] = target_node_label
1927 mlTrainingConfig = json .dumps (configMap )
2028 # TODO query available node labels
2129 node_labels = ["Paper" ] if not node_labels else node_labels
22- self ._query_runner .run_query (f"CALL gds.upload.graph('{ graph_name } ', {{mlTrainingConfig: '{ mlTrainingConfig } ', modelName: '{ model_name } ', nodeLabels: { node_labels } , nodeProperties: { node_properties } }})" )
23-
30+ self ._query_runner .run_query (
31+ f"CALL gds.upload.graph('{ graph_name } ', {{mlTrainingConfig: '{ mlTrainingConfig } ', modelName: '{ model_name } ', nodeLabels: { node_labels } , nodeProperties: { node_properties } }})"
32+ )
2433
25- def predict (self , graph_name : str , model_name : str , feature_properties : List [str ], target_node_label : str = None , node_labels : List [str ] = None ) -> "Series[Any]" :
34+ def predict (
35+ self ,
36+ graph_name : str ,
37+ model_name : str ,
38+ feature_properties : List [str ],
39+ target_node_label : str = None ,
40+ node_labels : List [str ] = None ,
41+ ) -> "Series[Any]" :
2642 configMap = {
2743 "featureProperties" : feature_properties ,
2844 "job_type" : "predict" ,
@@ -33,4 +49,5 @@ def predict(self, graph_name: str, model_name: str, feature_properties: List[str
3349 # TODO query available node labels
3450 node_labels = ["Paper" ] if not node_labels else node_labels
3551 self ._query_runner .run_query (
36- f"CALL gds.upload.graph('{ graph_name } ', {{mlTrainingConfig: '{ mlTrainingConfig } ', modelName: '{ model_name } ', nodeLabels: { node_labels } , nodeProperties: { feature_properties } }})" )
52+ f"CALL gds.upload.graph('{ graph_name } ', {{mlTrainingConfig: '{ mlTrainingConfig } ', modelName: '{ model_name } ', nodeLabels: { node_labels } , nodeProperties: { feature_properties } }})"
53+ )
0 commit comments