Skip to content

Commit 84191e4

Browse files
committed
Map k spanning tree
1 parent 9274e63 commit 84191e4

File tree

7 files changed

+357
-9
lines changed

7 files changed

+357
-9
lines changed
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from __future__ import annotations
2+
3+
from abc import ABC, abstractmethod
4+
from typing import Any
5+
6+
from graphdatascience.procedure_surface.api.base_result import BaseResult
7+
from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
8+
9+
10+
class KSpanningTreeWriteResult(BaseResult):
    """Summary returned by the ``write`` mode of the k-spanning tree algorithm.

    Field meanings follow the result columns documented for
    ``gds.kSpanningTree.write``; all ``*_millis`` values are durations in
    milliseconds.
    """

    # Number of nodes visited by the algorithm.
    effective_node_count: int

    # Time spent writing the result back to the database.
    write_millis: int

    # Time spent on post-processing after the computation.
    post_processing_millis: int

    # Time spent preparing the data before the computation.
    pre_processing_millis: int

    # Time spent running the algorithm itself.
    compute_millis: int

    # The configuration the algorithm was run with.
    configuration: dict[str, Any]
17+
18+
19+
class KSpanningTreeEndpoints(ABC):
    """Abstract interface for the k-spanning tree algorithm endpoints.

    Concrete subclasses decide how the algorithm is executed; this class
    only fixes the signature and contract of the supported modes.
    """

    @abstractmethod
    def write(
        self,
        G: GraphV2,
        k: int,
        write_property: str,
        source_node: int,
        relationship_weight_property: str | None = None,
        objective: str = "minimum",
        relationship_types: list[str] | None = None,
        node_labels: list[str] | None = None,
        sudo: bool = False,
        log_progress: bool = True,
        username: str | None = None,
        concurrency: int | None = None,
        job_id: str | None = None,
        write_concurrency: int | None = None,
    ) -> KSpanningTreeWriteResult:
        """
        Runs the k-Spanning tree algorithm and writes the result back to the Neo4j database.

        The algorithm does not span the whole graph: it returns a single tree
        rooted at `source_node` that contains `k` nodes.

        Parameters
        ----------
        G : GraphV2
            The graph to run the algorithm on.
        k : int
            The size of the returned tree, i.e. the number of nodes it contains.
            A tree with ``k`` nodes has ``k - 1`` relationships.
        write_property : str
            The name of the node property in the Neo4j database to which the
            spanning tree is written.
        source_node : int
            The id of the node at which the tree is rooted.
        relationship_weight_property : str, optional
            The name of the relationship property to use as weights.
        objective : str, default="minimum"
            Whether to look for a minimum or a maximum weight k-spanning tree.
            Either "minimum" or "maximum".
        relationship_types : list[str], optional
            Filter to only use relationships of specific types.
        node_labels : list[str], optional
            Filter to only use nodes with specific labels.
        sudo : bool, default=False
            Whether to run with elevated privileges.
        log_progress : bool, default=True
            Whether to log progress during execution.
        username : str, optional
            The username to use for logging.
        concurrency : int, optional
            The number of threads to use for parallel computation.
        job_id : str, optional
            An optional job ID for tracking the operation.
        write_concurrency : int, optional
            The number of threads to use for writing results.

        Returns
        -------
        KSpanningTreeWriteResult
            Result containing statistics and timing information.
        """
        ...
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from __future__ import annotations
2+
3+
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
4+
from graphdatascience.arrow_client.v2.remote_write_back_client import RemoteWriteBackClient
5+
from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
6+
from graphdatascience.procedure_surface.api.pathfinding.k_spanning_tree_endpoints import (
7+
KSpanningTreeEndpoints,
8+
KSpanningTreeWriteResult,
9+
)
10+
from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpointsHelper
11+
12+
13+
class KSpanningTreeArrowEndpoints(KSpanningTreeEndpoints):
    """k-spanning tree endpoints that run the algorithm as an Arrow job.

    All work is delegated to a ``NodePropertyEndpointsHelper`` built from the
    given Arrow client and optional write-back client.
    """

    def __init__(
        self,
        arrow_client: AuthenticatedArrowClient,
        write_back_client: RemoteWriteBackClient | None = None,
        show_progress: bool = False,
    ):
        helper = NodePropertyEndpointsHelper(
            arrow_client,
            write_back_client=write_back_client,
            show_progress=show_progress,
        )
        self._endpoints_helper = helper

    def write(
        self,
        G: GraphV2,
        k: int,
        write_property: str,
        source_node: int,
        relationship_weight_property: str | None = None,
        objective: str = "minimum",
        relationship_types: list[str] | None = None,
        node_labels: list[str] | None = None,
        sudo: bool = False,
        log_progress: bool = True,
        username: str | None = None,
        concurrency: int | None = None,
        job_id: str | None = None,
        write_concurrency: int | None = None,
    ) -> KSpanningTreeWriteResult:
        """Run the ``v2/pathfinding.kSpanningTree`` job and write its result back.

        The algorithm options are turned into a job configuration by the
        helper, the job is run and written via ``run_job_and_write``, and the
        returned mapping is unpacked into a ``KSpanningTreeWriteResult``.
        """
        helper = self._endpoints_helper

        # Keyword order is kept as-is; the names are the camelCase keys the
        # helper receives.
        algorithm_options = dict(
            k=k,
            sourceNode=source_node,
            relationshipWeightProperty=relationship_weight_property,
            objective=objective,
            relationshipTypes=relationship_types,
            nodeLabels=node_labels,
            sudo=sudo,
            logProgress=log_progress,
            username=username,
            concurrency=concurrency,
            jobId=job_id,
            writeConcurrency=write_concurrency,
        )
        job_config = helper.create_base_config(G, **algorithm_options)

        overwrites = {write_property: write_property}

        # NOTE(review): `concurrency` is part of the job config above but is
        # passed as None to the write step here - confirm this is intentional.
        write_summary = helper.run_job_and_write(
            "v2/pathfinding.kSpanningTree",
            G,
            job_config,
            property_overwrites=overwrites,
            write_concurrency=write_concurrency,
            concurrency=None,
        )

        return KSpanningTreeWriteResult(**write_summary)
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from __future__ import annotations
2+
3+
from graphdatascience.call_parameters import CallParameters
4+
from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
5+
from graphdatascience.procedure_surface.api.pathfinding.k_spanning_tree_endpoints import (
6+
KSpanningTreeEndpoints,
7+
KSpanningTreeWriteResult,
8+
)
9+
from graphdatascience.procedure_surface.utils.config_converter import ConfigConverter
10+
from graphdatascience.query_runner.query_runner import QueryRunner
11+
12+
13+
class KSpanningTreeCypherEndpoints(KSpanningTreeEndpoints):
    """k-spanning tree endpoints that call the GDS procedure through a query runner."""

    def __init__(self, query_runner: QueryRunner):
        self._query_runner = query_runner

    def write(
        self,
        G: GraphV2,
        k: int,
        write_property: str,
        source_node: int,
        relationship_weight_property: str | None = None,
        objective: str = "minimum",
        relationship_types: list[str] | None = None,
        node_labels: list[str] | None = None,
        sudo: bool = False,
        log_progress: bool = True,
        username: str | None = None,
        concurrency: int | None = None,
        job_id: str | None = None,
        write_concurrency: int | None = None,
    ) -> KSpanningTreeWriteResult:
        """Call ``gds.kSpanningTree.write`` and wrap its result.

        The arguments are converted into a GDS configuration, the procedure
        is called with the graph name and that configuration, and the
        squeezed result is unpacked into a ``KSpanningTreeWriteResult``.
        """
        algorithm_options = dict(
            k=k,
            writeProperty=write_property,
            sourceNode=source_node,
            relationshipWeightProperty=relationship_weight_property,
            objective=objective,
            relationshipTypes=relationship_types,
            nodeLabels=node_labels,
            sudo=sudo,
            logProgress=log_progress,
            username=username,
            concurrency=concurrency,
            jobId=job_id,
            writeConcurrency=write_concurrency,
        )
        gds_config = ConfigConverter.convert_to_gds_config(**algorithm_options)

        call_params = CallParameters(graph_name=G.name(), config=gds_config)
        call_params.ensure_job_id_in_config()

        procedure_output = self._query_runner.call_procedure(
            "gds.kSpanningTree.write",
            params=call_params,
            logging=log_progress,
        )
        summary_row = procedure_output.squeeze()

        return KSpanningTreeWriteResult(**summary_row)

graphdatascience/session/session_v2_endpoints.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from graphdatascience.procedure_surface.api.node_embedding.hashgnn_endpoints import HashGNNEndpoints
3434
from graphdatascience.procedure_surface.api.node_embedding.node2vec_endpoints import Node2VecEndpoints
3535
from graphdatascience.procedure_surface.api.pathfinding.all_shortest_path_endpoints import AllShortestPathEndpoints
36+
from graphdatascience.procedure_surface.api.pathfinding.k_spanning_tree_endpoints import KSpanningTreeEndpoints
3637
from graphdatascience.procedure_surface.api.pathfinding.prize_steiner_tree_endpoints import PrizeSteinerTreeEndpoints
3738
from graphdatascience.procedure_surface.api.pathfinding.shortest_path_endpoints import ShortestPathEndpoints
3839
from graphdatascience.procedure_surface.api.pathfinding.spanning_tree_endpoints import SpanningTreeEndpoints
@@ -90,6 +91,9 @@
9091
from graphdatascience.procedure_surface.arrow.pathfinding.all_shortest_path_arrow_endpoints import (
9192
AllShortestPathArrowEndpoints,
9293
)
94+
from graphdatascience.procedure_surface.arrow.pathfinding.k_spanning_tree_arrow_endpoints import (
95+
KSpanningTreeArrowEndpoints,
96+
)
9397
from graphdatascience.procedure_surface.arrow.pathfinding.prize_steiner_tree_arrow_endpoints import (
9498
PrizeSteinerTreeArrowEndpoints,
9599
)
@@ -218,6 +222,12 @@ def kmeans(self) -> KMeansEndpoints:
218222
def knn(self) -> KnnEndpoints:
219223
return KnnArrowEndpoints(self._arrow_client, self._write_back_client, show_progress=self._show_progress)
220224

225+
@property
def k_spanning_tree(self) -> KSpanningTreeEndpoints:
    """Return Arrow-backed k-spanning tree endpoints.

    The endpoints are built from this object's Arrow client, write-back
    client and progress setting.
    """
    endpoints = KSpanningTreeArrowEndpoints(
        self._arrow_client,
        self._write_back_client,
        show_progress=self._show_progress,
    )
    return endpoints
230+
221231
@property
222232
def label_propagation(self) -> LabelPropagationEndpoints:
223233
return LabelPropagationArrowEndpoints(
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
from typing import Generator
2+
3+
import pytest
4+
5+
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
6+
from graphdatascience.arrow_client.v2.remote_write_back_client import RemoteWriteBackClient
7+
from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
8+
from graphdatascience.procedure_surface.arrow.pathfinding.k_spanning_tree_arrow_endpoints import (
9+
KSpanningTreeArrowEndpoints,
10+
)
11+
from graphdatascience.query_runner.query_runner import QueryRunner
12+
from graphdatascience.tests.integrationV2.procedure_surface.arrow.graph_creation_helper import (
13+
create_graph_from_db,
14+
)
15+
16+
# Cypher statement that creates the test data: six `Node` nodes (ids 0-5)
# connected by six `LINK` relationships, each with `cost: 1.0`.
# NOTE(review): the bare `NN+` lines inside the literal look like diff
# line-number residue from the page this was copied from - confirm against
# the real file before running the statement.
graph = """
17+
CREATE
18+
(a: Node {id: 0}),
19+
(b: Node {id: 1}),
20+
(c: Node {id: 2}),
21+
(d: Node {id: 3}),
22+
(e: Node {id: 4}),
23+
(f: Node {id: 5}),
24+
(a)-[:LINK {cost: 1.0}]->(b),
25+
(a)-[:LINK {cost: 1.0}]->(c),
26+
(b)-[:LINK {cost: 1.0}]->(d),
27+
(c)-[:LINK {cost: 1.0}]->(e),
28+
(d)-[:LINK {cost: 1.0}]->(f),
29+
(e)-[:LINK {cost: 1.0}]->(f)
30+
"""
31+
32+
33+
@pytest.fixture
def db_graph(arrow_client: AuthenticatedArrowClient, query_runner: QueryRunner) -> Generator[GraphV2, None, None]:
    """Yield graph ``g`` created from the module-level ``graph`` statement.

    The graph is projected with ``gds.graph.project.remote`` and all
    relationship types are requested as undirected. The graph is yielded
    from inside the ``create_graph_from_db`` context manager.
    """
    remote_projection = """
        MATCH (source)-[r]->(target)
        WITH gds.graph.project.remote(source, target, {
        sourceNodeProperties: properties(source),
        targetNodeProperties: properties(target),
        relationshipProperties: properties(r)
        }) as g
        RETURN g
    """

    with create_graph_from_db(
        arrow_client,
        query_runner,
        "g",
        graph,
        remote_projection,
        undirected_relationship_types=["*"],
    ) as projected:
        yield projected
52+
53+
54+
@pytest.mark.db_integration
def test_k_spanning_tree_write(
    arrow_client: AuthenticatedArrowClient, query_runner: QueryRunner, db_graph: GraphV2
) -> None:
    """Write a k=3 tree rooted at node 0 and check the reported statistics."""
    write_back_client = RemoteWriteBackClient(arrow_client, query_runner)
    endpoints = KSpanningTreeArrowEndpoints(arrow_client, write_back_client=write_back_client)

    result = endpoints.write(
        G=db_graph,
        k=3,
        write_property="weight",
        source_node=0,
        relationship_weight_property="cost",
    )

    # With k=3 the test expects exactly three nodes to be reported.
    assert result.effective_node_count == 3

    # Every timing must be a non-negative number.
    timings = (
        result.write_millis,
        result.compute_millis,
        result.pre_processing_millis,
        result.post_processing_millis,
    )
    for elapsed in timings:
        assert elapsed >= 0
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from typing import Generator
2+
3+
import pytest
4+
5+
from graphdatascience import QueryRunner
6+
from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
7+
from graphdatascience.procedure_surface.cypher.pathfinding.k_spanning_tree_cypher_endpoints import (
8+
KSpanningTreeCypherEndpoints,
9+
)
10+
from graphdatascience.tests.integrationV2.procedure_surface.cypher.cypher_graph_helper import create_graph
11+
from graphdatascience.tests.integrationV2.procedure_surface.node_lookup_helper import find_node_by_name
12+
13+
14+
@pytest.fixture
def sample_graph(query_runner: QueryRunner) -> Generator[GraphV2, None, None]:
    """Yield graph ``g``: six named nodes joined by six ``LINK`` relationships.

    The data is created in the database and projected with
    ``gds.graph.project``, with all relationship types undirected. The graph
    is yielded from inside the ``create_graph`` context manager.
    """
    creation_cypher = """
        CREATE
        (a: Node {name: 'A'}),
        (b: Node {name: 'B'}),
        (c: Node {name: 'C'}),
        (d: Node {name: 'D'}),
        (e: Node {name: 'E'}),
        (f: Node {name: 'F'}),
        (a)-[:LINK {cost: 1.0}]->(b),
        (a)-[:LINK {cost: 1.0}]->(c),
        (b)-[:LINK {cost: 1.0}]->(d),
        (c)-[:LINK {cost: 1.0}]->(e),
        (d)-[:LINK {cost: 1.0}]->(f),
        (e)-[:LINK {cost: 1.0}]->(f)
    """

    projection_cypher = """
        MATCH (source)-[r]->(target)
        WITH gds.graph.project('g', source, target, {
        relationshipProperties: properties(r)
        }, {undirectedRelationshipTypes: ['*']}) AS G
        RETURN G
    """

    with create_graph(query_runner, "g", creation_cypher, projection_cypher) as projected:
        yield projected
47+
48+
49+
@pytest.fixture
def k_spanning_tree_endpoints(query_runner: QueryRunner) -> Generator[KSpanningTreeCypherEndpoints, None, None]:
    """Yield Cypher-based k-spanning tree endpoints built on the shared query runner."""
    endpoints = KSpanningTreeCypherEndpoints(query_runner)
    yield endpoints
52+
53+
54+
def test_k_spanning_tree_write(
    k_spanning_tree_endpoints: KSpanningTreeCypherEndpoints, sample_graph: GraphV2, query_runner: QueryRunner
) -> None:
    """Write a k=3 tree rooted at node 'A' and check the reported statistics."""
    root_node = find_node_by_name(query_runner, "A")

    result = k_spanning_tree_endpoints.write(
        G=sample_graph,
        k=3,
        write_property="weight",
        source_node=root_node,
        relationship_weight_property="cost",
    )

    # With k=3 the test expects exactly three nodes to be reported.
    assert result.effective_node_count == 3

    # Every timing must be a non-negative number.
    timings = (
        result.write_millis,
        result.compute_millis,
        result.pre_processing_millis,
        result.post_processing_millis,
    )
    for elapsed in timings:
        assert elapsed >= 0

0 commit comments

Comments
 (0)