
Commit cfc30ed

WIP test against gds-api-spec
ref GDSA-144
1 parent 84191e4 commit cfc30ed

2 files changed: +221 -0 lines changed
Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
import json
from pathlib import Path
from typing import Any, List, Optional

from pydantic import BaseModel, ConfigDict, Field


class TypeInfo(BaseModel):
    """Represents type information for a parameter or return field."""

    model_config = ConfigDict(populate_by_name=True)

    typeName: str = Field(alias="typeName")
    optional: bool = Field(alias="optional")


class Parameter(BaseModel):
    """Represents a procedure parameter."""

    model_config = ConfigDict(populate_by_name=True)

    name: str
    type: TypeInfo
    defaultValue: Optional[Any] = None


class ReturnField(BaseModel):
    """Represents a return field of a procedure mode."""

    model_config = ConfigDict(populate_by_name=True)

    name: str
    type: TypeInfo


class Mode(BaseModel):
    """Represents an execution mode (stream, stats, mutate, write) of a procedure."""

    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    mode: str
    parameters: List[Parameter]
    returnFields: List[ReturnField]


class Procedure(BaseModel):
    """Represents a GDS procedure with its parameters and modes."""

    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    name: str
    parameters: List[Parameter]
    modes: List[Mode]

    def parameters_for_mode(self, mode_name: str) -> List[Parameter]:
        """Get the effective parameters for a specific mode: the procedure-level parameters plus the mode-specific ones."""
        result = self.parameters.copy()
        for mode in self.modes:
            if mode.mode == mode_name:
                result.extend(mode.parameters)
                return result
        raise ValueError(
            f"Mode '{mode_name}' not found in procedure '{self.name}'. Available modes: {[m.mode for m in self.modes]}."
        )


def resolve_spec_from_file(file_path: Path) -> list[Procedure]:
    """
    Load and parse a gds-api-spec.json file.

    Args:
        file_path: Path to the gds-api-spec.json file.

    Returns:
        The parsed API specification as a list of all procedures.
    """
    with open(file_path, "r") as f:
        data = json.load(f)

    # The JSON file is a list of procedures at the root level
    procedures = [Procedure(**proc) for proc in data]

    return procedures
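
For reference, a minimal usage sketch of the models above (not part of the committed file): the procedure name, type names and default value below are made up purely for illustration, and the import path assumes the module location used by the test file further down.

from graphdatascience.tests.integrationV2.procedure_surface.session.gds_api_spec import Procedure

procedure = Procedure(
    **{
        "name": "centrality.pagerank",
        "parameters": [
            {"name": "dampingFactor", "type": {"typeName": "Float", "optional": True}, "defaultValue": 0.85}
        ],
        "modes": [
            {
                "mode": "stream",
                "parameters": [{"name": "concurrency", "type": {"typeName": "Integer", "optional": True}}],
                "returnFields": [{"name": "nodeId", "type": {"typeName": "Integer", "optional": False}}],
            }
        ],
    }
)

# procedure-level parameters plus the mode-specific ones
assert [p.name for p in procedure.parameters_for_mode("stream")] == ["dampingFactor", "concurrency"]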
Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
from pathlib import Path
import re
from collections import defaultdict
from typing import Any

import pytest

from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
from graphdatascience.session.session_v2_endpoints import SessionV2Endpoints
from graphdatascience.tests.integrationV2.procedure_surface.session.gds_api_spec import (
    Procedure,
    resolve_spec_from_file,
)

# endpoints that are known to be missing from the gds.v2 surface
MISSING_ENDPOINTS: set[str] = set()

# mapping of the snake-cased version of endpoint parts to the actual attribute names in SessionV2Endpoints
ENDPOINT_MAPPINGS = {
    # centrality algos
    "betweenness": "betweenness_centrality",
    "celf": "influence_maximization_celf",
    "closeness": "closeness_centrality",
    "degree": "degree_centrality",
    "eigenvector": "eigenvector_centrality",
    "harmonic": "harmonic_centrality",
    # community algos
    "cliquecounting": "clique_counting",
    "k1coloring": "k1_coloring",
    "kcore": "k_core_decomposition",
    "maxkcut": "max_k_cut",
    # embedding algos
    "fastrp": "fast_rp",
    "graphSage": "graphsage",
    "hashgnn": "hash_gnn",
    # pathfinding algos
    "source_target": "shortest_path",
    "single_source": "all_shortest_path",
    "delta_stepping": "delta",
    "kspanning_tree": "k_spanning_tree",
    "prizesteiner_tree": "prize_steiner_tree",
    "spanning_tree": "spanning_tree",
    "steiner_tree": "steiner_tree",
}


@pytest.fixture
def endpoints(arrow_client: AuthenticatedArrowClient) -> SessionV2Endpoints:
    return SessionV2Endpoints(arrow_client, db_client=None, show_progress=False)


def to_snake(camel: str) -> str:
    # adjusted version of pydantic.alias_generators.to_snake (without digit handling)

    # Handle the sequence of uppercase letters followed by a lowercase letter
    snake = re.sub(r"([A-Z]+)([A-Z][a-z])", lambda m: f"{m.group(1)}_{m.group(2)}", camel)
    # Insert an underscore between a lowercase letter and an uppercase letter
    snake = re.sub(r"([a-z])([A-Z])", lambda m: f"{m.group(1)}_{m.group(2)}", snake)
    # Replace hyphens with underscores to handle kebab-case
    snake = snake.replace("-", "_")
    return snake.lower()
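
# examples: to_snake("graphSage") -> "graph_sage", to_snake("delta-stepping") -> "delta_stepping",
# to_snake("HITS") -> "hits"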


def resolve_callable_object(endpoints: SessionV2Endpoints, endpoint: str) -> Any | None:
    """Check whether an algorithm is available through the gds.v2 interface and return its callable."""

    endpoint_parts = endpoint.split(".")
    endpoint_parts = [to_snake(part) for part in endpoint_parts]
    # endpoint_parts = [ENDPOINT_MAPPINGS.get(part, part) for part in endpoint_parts]

    callable_object = endpoints
    for endpoint_part in endpoint_parts:
        # Walk down the attribute chain towards the algorithm endpoint
        if not hasattr(callable_object, endpoint_part):
            return None

        callable_object = getattr(callable_object, endpoint_part)

    if not callable(callable_object):
        raise ValueError(f"Resolved object {callable_object} for endpoint {endpoint} is not callable")

    return callable_object


# TODO how to fetch the json? it is not published anywhere yet (could be published as part of the release?)
def test_api_spec_coverage(endpoints: SessionV2Endpoints) -> None:
    # Load the API spec (a list of all procedures)
    api_spec = resolve_spec_from_file(Path("/Users/florentin/repos/graph-data-science-client/gds-api-spec.json"))

    algo_prefixes = ["pathfinding", "centrality", "community", "similarity", "embedding"]
    # Filter to only v2 algorithm procedures (exclude graph, model, catalog operations)
    algorithm_actions: set[str] = {
        proc.name for proc in api_spec if any(proc.name.startswith(prefix) for prefix in algo_prefixes)
    }

    missing_endpoints: set[str] = set()
    available_endpoints: set[str] = set()

    algos_per_category = defaultdict(list)
    for action in algorithm_actions:
        category, algo_parts = action.split(".", maxsplit=1)
        algos_per_category[category].append(algo_parts)

    for category, algos in algos_per_category.items():
        for algo in algos:
            callable_object = resolve_callable_object(
                endpoints,
                algo,
            )
            if not callable_object:
                missing_endpoints.add(f"{category}.{algo}")
            else:
                available_endpoints.add(f"{category}.{algo}")
                # TODO verify parameters and return fields against the gds-api spec
                continue

    # Print summary
    print("\nGDS API Spec Coverage Summary:")
    print(f"Total algorithm actions found: {len(algorithm_actions)}")
    print(f"Available through gds.v2: {len(available_endpoints)}")

    # check if any previously missing algos are now available
    newly_available_endpoints = available_endpoints.intersection(MISSING_ENDPOINTS)
    assert not newly_available_endpoints, (
        f"Endpoints now available, please remove them from MISSING_ENDPOINTS: {sorted(newly_available_endpoints)}"
    )

    # check missing endpoints against the known missing algos
    missing_endpoints = missing_endpoints.difference(MISSING_ENDPOINTS)
    assert not missing_endpoints, f"{len(missing_endpoints)} unexpectedly missing endpoints: {sorted(missing_endpoints)}"


def get_api_spec() -> list[Procedure]:
    # WIP: the spec file is not published anywhere yet, so load it from a local checkout for now
    return resolve_spec_from_file(Path("/Users/florentin/repos/graph-data-science-client/gds-api-spec.json"))
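
A possible direction for the "verify against gds-api spec" TODO above, sketched here rather than implemented in the commit: compare a mode's spec parameters against the signature of the resolved callable. The snake-casing via pydantic's alias generator and the helper's name are assumptions, not part of the committed code.

import inspect
from typing import Any

from pydantic.alias_generators import to_snake as pydantic_to_snake

from graphdatascience.tests.integrationV2.procedure_surface.session.gds_api_spec import Procedure


def assert_mode_parameters_in_signature(procedure: Procedure, mode_name: str, endpoint_callable: Any) -> None:
    # parameters the spec expects for this mode, converted to the client's snake_case convention
    spec_parameters = {pydantic_to_snake(p.name) for p in procedure.parameters_for_mode(mode_name)}
    # keyword arguments actually accepted by the resolved gds.v2 callable
    signature_parameters = set(inspect.signature(endpoint_callable).parameters)

    missing = spec_parameters - signature_parameters
    assert not missing, (
        f"{procedure.name} ({mode_name}): spec parameters missing from the client signature: {sorted(missing)}"
    )

Wiring this into test_api_spec_coverage would also require resolving the matching Procedure and mode for each available endpoint, which the loop above does not do yet.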
