|
| 1 | +import json |
1 | 2 | from typing import Any, List |
2 | 3 |
|
3 | 4 | from ..error.illegal_attr_checker import IllegalAttrChecker |
4 | 5 | from ..error.uncallable_namespace import UncallableNamespace |
5 | | -import json |
6 | 6 |
|
7 | 7 |
|
8 | 8 | class GNNNodeClassificationRunner(UncallableNamespace, IllegalAttrChecker): |
9 | | - def train(self, graph_name: str, model_name: str, feature_properties: List[str], target_property: str, |
10 | | - target_node_label: str = None, node_labels: List[str] = None) -> "Series[Any]": |
11 | | - configMap = { |
| 9 | + def train( |
| 10 | + self, |
| 11 | + graph_name: str, |
| 12 | + model_name: str, |
| 13 | + feature_properties: List[str], |
| 14 | + target_property: str, |
| 15 | + target_node_label: str = None, |
| 16 | + node_labels: List[str] = None, |
| 17 | + ) -> "Series[Any]": # noqa: F821 |
| 18 | + mlConfigMap = { |
12 | 19 | "featureProperties": feature_properties, |
13 | 20 | "targetProperty": target_property, |
14 | 21 | "job_type": "train", |
15 | | - "nodeProperties": feature_properties + [target_property] |
| 22 | + "nodeProperties": feature_properties + [target_property], |
16 | 23 | } |
17 | 24 |
|
18 | 25 | if target_node_label: |
19 | | - configMap["targetNodeLabel"] = target_node_label |
| 26 | + mlConfigMap["targetNodeLabel"] = target_node_label |
20 | 27 | if node_labels: |
21 | | - configMap["nodeLabels"] = node_labels |
| 28 | + mlConfigMap["nodeLabels"] = node_labels |
22 | 29 |
|
23 | | - mlTrainingConfig = json.dumps(configMap) |
| 30 | + mlTrainingConfig = json.dumps(mlConfigMap) |
24 | 31 |
|
25 | 32 | # token and uri will be injected by arrow_query_runner |
26 | 33 | self._query_runner.run_query( |
27 | | - f"CALL gds.upload.graph($graph_name, $config)", |
28 | | - params={"graph_name": graph_name, "config": { |
29 | | - "mlTrainingConfig": mlTrainingConfig, |
30 | | - "modelName": model_name |
31 | | - }} |
32 | | - ) |
33 | | - |
| 34 | + "CALL gds.upload.graph($graph_name, $config)", |
| 35 | + params={ |
| 36 | + "graph_name": graph_name, |
| 37 | + "config": {"mlTrainingConfig": mlTrainingConfig, "modelName": model_name}, |
| 38 | + }, |
| 39 | + ) |
34 | 40 |
|
35 | | - def predict(self, graph_name: str, model_name: str, feature_properties: List[str], target_node_label: str = None, node_labels: List[str] = None) -> "Series[Any]": |
36 | | - configMap = { |
| 41 | + def predict( |
| 42 | + self, |
| 43 | + graph_name: str, |
| 44 | + model_name: str, |
| 45 | + feature_properties: List[str], |
| 46 | + target_node_label: str = None, |
| 47 | + node_labels: List[str] = None, |
| 48 | + ) -> "Series[Any]": # noqa: F821 |
| 49 | + mlConfigMap = { |
37 | 50 | "featureProperties": feature_properties, |
38 | 51 | "job_type": "predict", |
| 52 | + "nodeProperties": feature_properties, |
39 | 53 | } |
40 | 54 | if target_node_label: |
41 | | - configMap["targetNodeLabel"] = target_node_label |
42 | | - mlTrainingConfig = json.dumps(configMap) |
43 | | - # TODO query available node labels |
44 | | - node_labels = ["Paper"] if not node_labels else node_labels |
| 55 | + mlConfigMap["targetNodeLabel"] = target_node_label |
| 56 | + if node_labels: |
| 57 | + mlConfigMap["nodeLabels"] = node_labels |
| 58 | + |
| 59 | + mlTrainingConfig = json.dumps(mlConfigMap) |
45 | 60 | self._query_runner.run_query( |
46 | | - f"CALL gds.upload.graph('{graph_name}', {{mlTrainingConfig: '{mlTrainingConfig}', modelName: '{model_name}', nodeLabels: {node_labels}, nodeProperties: {feature_properties}}})") |
| 61 | + "CALL gds.upload.graph($graph_name, $config)", |
| 62 | + params={ |
| 63 | + "graph_name": graph_name, |
| 64 | + "config": {"mlTrainingConfig": mlTrainingConfig, "modelName": model_name}, |
| 65 | + }, |
| 66 | + ) # type: ignore |
0 commit comments