Skip to content

Commit 47c9332

Browse files
committed
Add prize_steiner_tree to v2 endpoints
1 parent 79c64ab commit 47c9332

File tree

9 files changed

+1036
-3
lines changed

9 files changed

+1036
-3
lines changed
Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
from __future__ import annotations
2+
3+
from abc import ABC, abstractmethod
4+
from typing import Any
5+
6+
from pandas import DataFrame
7+
8+
from graphdatascience.procedure_surface.api.base_result import BaseResult
9+
from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
10+
from graphdatascience.procedure_surface.api.estimation_result import EstimationResult
11+
12+
13+
class PrizeSteinerTreeMutateResult(BaseResult):
14+
relationships_written: int
15+
mutate_millis: int
16+
effective_node_count: int
17+
sum_of_prizes: float
18+
total_weight: float
19+
pre_processing_millis: int
20+
compute_millis: int
21+
configuration: dict[str, Any]
22+
23+
24+
class PrizeSteinerTreeWriteResult(BaseResult):
25+
relationships_written: int
26+
write_millis: int
27+
effective_node_count: int
28+
sum_of_prizes: float
29+
total_weight: float
30+
pre_processing_millis: int
31+
compute_millis: int
32+
configuration: dict[str, Any]
33+
34+
35+
class PrizeSteinerTreeStatsResult(BaseResult):
36+
effective_node_count: int
37+
sum_of_prizes: float
38+
total_weight: float
39+
pre_processing_millis: int
40+
compute_millis: int
41+
configuration: dict[str, Any]
42+
43+
44+
class PrizeSteinerTreeEndpoints(ABC):
45+
@abstractmethod
46+
def stream(
47+
self,
48+
G: GraphV2,
49+
prize_property: str,
50+
relationship_weight_property: str | None = None,
51+
relationship_types: list[str] | None = None,
52+
node_labels: list[str] | None = None,
53+
sudo: bool = False,
54+
log_progress: bool = True,
55+
username: str | None = None,
56+
concurrency: int | None = None,
57+
job_id: str | None = None,
58+
) -> DataFrame:
59+
"""
60+
Runs the Prize Steiner tree algorithm and returns the result as a DataFrame.
61+
62+
Parameters
63+
----------
64+
G : GraphV2
65+
The graph to run the algorithm on.
66+
prize_property : str
67+
The name of the node property containing prize values.
68+
relationship_weight_property : str, optional
69+
The name of the relationship property to use as weights.
70+
relationship_types : list[str], optional
71+
Filter to only use relationships of specific types.
72+
node_labels : list[str], optional
73+
Filter to only use nodes with specific labels.
74+
sudo : bool, default=False
75+
Whether to run with elevated privileges.
76+
log_progress : bool, default=True
77+
Whether to log progress during execution.
78+
username : str, optional
79+
The username to use for logging.
80+
concurrency : int, optional
81+
The number of threads to use for parallel computation.
82+
job_id : str, optional
83+
An optional job ID for tracking the operation.
84+
85+
Returns
86+
-------
87+
DataFrame
88+
A DataFrame containing the tree edges with columns: nodeId, parentId, weight.
89+
"""
90+
...
91+
92+
@abstractmethod
93+
def stats(
94+
self,
95+
G: GraphV2,
96+
prize_property: str,
97+
relationship_weight_property: str | None = None,
98+
relationship_types: list[str] | None = None,
99+
node_labels: list[str] | None = None,
100+
sudo: bool = False,
101+
log_progress: bool = True,
102+
username: str | None = None,
103+
concurrency: int | None = None,
104+
job_id: str | None = None,
105+
) -> PrizeSteinerTreeStatsResult:
106+
"""
107+
Runs the Prize Steiner tree algorithm in stats mode, returning statistics without modifying the graph.
108+
109+
Parameters
110+
----------
111+
G : GraphV2
112+
The graph to run the algorithm on.
113+
prize_property : str
114+
The name of the node property containing prize values.
115+
relationship_weight_property : str, optional
116+
The name of the relationship property to use as weights.
117+
relationship_types : list[str], optional
118+
Filter to only use relationships of specific types.
119+
node_labels : list[str], optional
120+
Filter to only use nodes with specific labels.
121+
sudo : bool, default=False
122+
Whether to run with elevated privileges.
123+
log_progress : bool, default=True
124+
Whether to log progress during execution.
125+
username : str, optional
126+
The username to use for logging.
127+
concurrency : int, optional
128+
The number of threads to use for parallel computation.
129+
job_id : str, optional
130+
An optional job ID for tracking the operation.
131+
132+
Returns
133+
-------
134+
PrizeSteinerTreeStatsResult
135+
Statistics about the computed Prize Steiner tree.
136+
"""
137+
...
138+
139+
@abstractmethod
140+
def mutate(
141+
self,
142+
G: GraphV2,
143+
mutate_relationship_type: str,
144+
mutate_property: str,
145+
prize_property: str,
146+
relationship_weight_property: str | None = None,
147+
relationship_types: list[str] | None = None,
148+
node_labels: list[str] | None = None,
149+
sudo: bool = False,
150+
log_progress: bool = True,
151+
username: str | None = None,
152+
concurrency: int | None = None,
153+
job_id: str | None = None,
154+
) -> PrizeSteinerTreeMutateResult:
155+
"""
156+
Runs the Prize Steiner tree algorithm and adds the result as new relationships to the in-memory graph.
157+
158+
Parameters
159+
----------
160+
G : GraphV2
161+
The graph to run the algorithm on.
162+
mutate_relationship_type : str
163+
The relationship type to use for the new relationships.
164+
mutate_property : str
165+
The property name to store the edge weight.
166+
prize_property : str
167+
The name of the node property containing prize values.
168+
relationship_weight_property : str, optional
169+
The name of the relationship property to use as weights.
170+
relationship_types : list[str], optional
171+
Filter to only use relationships of specific types.
172+
node_labels : list[str], optional
173+
Filter to only use nodes with specific labels.
174+
sudo : bool, default=False
175+
Whether to run with elevated privileges.
176+
log_progress : bool, default=True
177+
Whether to log progress during execution.
178+
username : str, optional
179+
The username to use for logging.
180+
concurrency : int, optional
181+
The number of threads to use for parallel computation.
182+
job_id : str, optional
183+
An optional job ID for tracking the operation.
184+
185+
Returns
186+
-------
187+
PrizeSteinerTreeMutateResult
188+
Result containing statistics and timing information.
189+
"""
190+
...
191+
192+
@abstractmethod
193+
def write(
194+
self,
195+
G: GraphV2,
196+
write_relationship_type: str,
197+
write_property: str,
198+
prize_property: str,
199+
relationship_weight_property: str | None = None,
200+
relationship_types: list[str] | None = None,
201+
node_labels: list[str] | None = None,
202+
sudo: bool = False,
203+
log_progress: bool = True,
204+
username: str | None = None,
205+
concurrency: int | None = None,
206+
job_id: str | None = None,
207+
write_concurrency: int | None = None,
208+
) -> PrizeSteinerTreeWriteResult:
209+
"""
210+
Runs the Prize Steiner tree algorithm and writes the result back to the Neo4j database.
211+
212+
Parameters
213+
----------
214+
G : GraphV2
215+
The graph to run the algorithm on.
216+
write_relationship_type : str
217+
The relationship type to use for the new relationships.
218+
write_property : str
219+
The property name to store the edge weight.
220+
prize_property : str
221+
The name of the node property containing prize values.
222+
relationship_weight_property : str, optional
223+
The name of the relationship property to use as weights.
224+
relationship_types : list[str], optional
225+
Filter to only use relationships of specific types.
226+
node_labels : list[str], optional
227+
Filter to only use nodes with specific labels.
228+
sudo : bool, default=False
229+
Whether to run with elevated privileges.
230+
log_progress : bool, default=True
231+
Whether to log progress during execution.
232+
username : str, optional
233+
The username to use for logging.
234+
concurrency : int, optional
235+
The number of threads to use for parallel computation.
236+
job_id : str, optional
237+
An optional job ID for tracking the operation.
238+
write_concurrency : int, optional
239+
The number of threads to use for writing results.
240+
241+
Returns
242+
-------
243+
PrizeSteinerTreeWriteResult
244+
Result containing statistics and timing information.
245+
"""
246+
...
247+
248+
@abstractmethod
249+
def estimate(
250+
self,
251+
G: GraphV2 | dict[str, Any],
252+
prize_property: str,
253+
relationship_weight_property: str | None = None,
254+
relationship_types: list[str] | None = None,
255+
node_labels: list[str] | None = None,
256+
sudo: bool = False,
257+
username: str | None = None,
258+
concurrency: int | None = None,
259+
) -> EstimationResult:
260+
"""
261+
Estimates the memory requirements for running the Prize Steiner tree algorithm.
262+
263+
Parameters
264+
----------
265+
G : GraphV2 | dict[str, Any]
266+
The graph to estimate for, or a dictionary with nodeCount and relationshipCount.
267+
prize_property : str
268+
The name of the node property containing prize values.
269+
relationship_weight_property : str, optional
270+
The name of the relationship property to use as weights.
271+
relationship_types : list[str], optional
272+
Filter to only use relationships of specific types.
273+
node_labels : list[str], optional
274+
Filter to only use nodes with specific labels.
275+
sudo : bool, default=False
276+
Whether to run with elevated privileges.
277+
username : str, optional
278+
The username to use for logging.
279+
concurrency : int, optional
280+
The number of threads to use for parallel computation.
281+
282+
Returns
283+
-------
284+
EstimationResult
285+
Memory estimation results including required bytes and percentages.
286+
"""
287+
...

graphdatascience/procedure_surface/api/pathfinding/steiner_tree_endpoints.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,35 @@ def stats(
118118
"""
119119
Runs the Steiner tree algorithm in stats mode, returning statistics without modifying the graph.
120120
121+
Parameters
122+
----------
123+
G : GraphV2
124+
The graph to run the algorithm on.
125+
source_node : int
126+
The source node (root) for the Steiner tree.
127+
target_nodes : list[int]
128+
The list of target nodes (terminals) that must be connected.
129+
relationship_weight_property : str, optional
130+
The name of the relationship property to use as weights.
131+
delta : float, default=2.0
132+
The delta parameter for the shortest path computation used internally.
133+
apply_rerouting : bool, default=False
134+
Whether to apply rerouting optimization to improve the tree.
135+
relationship_types : list[str], optional
136+
Filter to only use relationships of specific types.
137+
node_labels : list[str], optional
138+
Filter to only use nodes with specific labels.
139+
sudo : bool, default=False
140+
Whether to run with elevated privileges.
141+
log_progress : bool, default=True
142+
Whether to log progress during execution.
143+
username : str, optional
144+
The username to use for logging.
145+
concurrency : int, optional
146+
The number of threads to use for parallel computation.
147+
job_id : str, optional
148+
An optional job ID for tracking the operation.
149+
121150
Returns
122151
-------
123152
SteinerTreeStatsResult
@@ -149,10 +178,36 @@ def mutate(
149178
150179
Parameters
151180
----------
181+
G : GraphV2
182+
The graph to run the algorithm on.
152183
mutate_relationship_type : str
153184
The relationship type to use for the new relationships.
154185
mutate_property : str
155186
The property name to store the edge weight.
187+
source_node : int
188+
The source node (root) for the Steiner tree.
189+
target_nodes : list[int]
190+
The list of target nodes (terminals) that must be connected.
191+
relationship_weight_property : str, optional
192+
The name of the relationship property to use as weights.
193+
delta : float, default=2.0
194+
The delta parameter for the shortest path computation used internally.
195+
apply_rerouting : bool, default=False
196+
Whether to apply rerouting optimization to improve the tree.
197+
relationship_types : list[str], optional
198+
Filter to only use relationships of specific types.
199+
node_labels : list[str], optional
200+
Filter to only use nodes with specific labels.
201+
sudo : bool, default=False
202+
Whether to run with elevated privileges.
203+
log_progress : bool, default=True
204+
Whether to log progress during execution.
205+
username : str, optional
206+
The username to use for logging.
207+
concurrency : int, optional
208+
The number of threads to use for parallel computation.
209+
job_id : str, optional
210+
An optional job ID for tracking the operation.
156211
157212
Returns
158213
-------
@@ -186,10 +241,36 @@ def write(
186241
187242
Parameters
188243
----------
244+
G : GraphV2
245+
The graph to run the algorithm on.
189246
write_relationship_type : str
190247
The relationship type to use for the new relationships.
191248
write_property : str
192249
The property name to store the edge weight.
250+
source_node : int
251+
The source node (root) for the Steiner tree.
252+
target_nodes : list[int]
253+
The list of target nodes (terminals) that must be connected.
254+
relationship_weight_property : str, optional
255+
The name of the relationship property to use as weights.
256+
delta : float, default=2.0
257+
The delta parameter for the shortest path computation used internally.
258+
apply_rerouting : bool, default=False
259+
Whether to apply rerouting optimization to improve the tree.
260+
relationship_types : list[str], optional
261+
Filter to only use relationships of specific types.
262+
node_labels : list[str], optional
263+
Filter to only use nodes with specific labels.
264+
sudo : bool, default=False
265+
Whether to run with elevated privileges.
266+
log_progress : bool, default=True
267+
Whether to log progress during execution.
268+
username : str, optional
269+
The username to use for logging.
270+
concurrency : int, optional
271+
The number of threads to use for parallel computation.
272+
job_id : str, optional
273+
An optional job ID for tracking the operation.
193274
write_concurrency : int, optional
194275
The number of threads to use for writing results.
195276

0 commit comments

Comments
 (0)