Skip to content

Commit 80d5caf

Browse files
committed
Add option for for streaming with Arrow with gds.<algo>.stream
1 parent b3e0e9e commit 80d5caf

File tree

2 files changed

+92
-3
lines changed

2 files changed

+92
-3
lines changed

graphdatascience/algo/algo_proc_runner.py

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,21 @@
11
from abc import ABC
22
from typing import Any, Dict, Tuple
3+
from uuid import uuid4
4+
from warnings import warn
35

46
from pandas import DataFrame, Series
57

8+
from ..call_parameters import CallParameters
69
from ..error.illegal_attr_checker import IllegalAttrChecker
10+
from ..graph.graph_entity_ops_runner import (
11+
GraphElementPropertyRunner,
12+
GraphNodePropertiesRunner,
13+
GraphRelationshipsRunner,
14+
)
715
from ..graph.graph_object import Graph
816
from ..graph.graph_type_check import graph_type_check
917
from ..model.graphsage_model import GraphSageModel
10-
from graphdatascience.call_parameters import CallParameters
18+
from ..query_runner.arrow_query_runner import ArrowQueryRunner
1119

1220

1321
class AlgoProcRunner(IllegalAttrChecker, ABC):
@@ -24,8 +32,71 @@ def estimate(self, G: Graph, **config: Any) -> "Series[Any]":
2432

2533

2634
class StreamModeRunner(AlgoProcRunner):
27-
def __call__(self, G: Graph, **config: Any) -> DataFrame:
28-
return self._run_procedure(G, config)
35+
def __call__(self, G: Graph, stream_with_arrow: bool = False, **config: Any) -> DataFrame:
36+
if stream_with_arrow:
37+
if not isinstance(self._query_runner, ArrowQueryRunner):
38+
raise ValueError("The `stream_with_arrow` option requires GDS EE with the Arrow server enabled")
39+
40+
try:
41+
return self._stream_with_arrow(G, config)
42+
except Exception as e:
43+
warn(
44+
"Falling back to streaming with Neo4j driver, "
45+
f"since failed to stream with Arrow with reason: {str(e)}"
46+
)
47+
48+
del config["mutateProperty"]
49+
del config["mutateRelationshipType"]
50+
51+
self._namespace = self._namespace.replace("mutate", "stream")
52+
return self._run_procedure(G, config)
53+
else:
54+
return self._run_procedure(G, config)
55+
56+
def _stream_with_arrow(self, G: Graph, config: Dict[str, Any]) -> DataFrame:
57+
self._namespace = self._namespace.replace("stream", "mutate")
58+
59+
mutate_property = str(uuid4())
60+
config["mutateProperty"] = mutate_property
61+
try:
62+
self._run_procedure(G, config)
63+
64+
elem_prop_runner = GraphElementPropertyRunner(
65+
self._query_runner, "gds.graph.nodeProperty", self._server_version
66+
)
67+
if "concurrency" in config:
68+
result = elem_prop_runner.stream(G, mutate_property, concurrency=config["concurrency"])
69+
else:
70+
result = elem_prop_runner.stream(G, mutate_property)
71+
72+
node_prop_runner = GraphNodePropertiesRunner(
73+
self._query_runner, "gds.graph.nodeProperties", self._server_version
74+
)
75+
node_prop_runner.drop(G, [mutate_property])
76+
77+
return result
78+
except Exception as e:
79+
if "No value specified for the mandatory configuration parameter `mutateRelationshipType`" not in str(e):
80+
raise e
81+
82+
mutate_relationship_type = str(uuid4())
83+
config["mutateRelationshipType"] = mutate_relationship_type
84+
self._run_procedure(G, config)
85+
86+
elem_prop_runner = GraphElementPropertyRunner(
87+
self._query_runner, "gds.graph.relationshipProperty", self._server_version
88+
)
89+
if "concurrency" in config:
90+
result = elem_prop_runner.stream(
91+
G, mutate_property, [mutate_relationship_type], concurrency=config["concurrency"]
92+
)
93+
else:
94+
result = elem_prop_runner.stream(G, mutate_property, [mutate_relationship_type])
95+
96+
rel_runner = GraphRelationshipsRunner(self._query_runner, "gds.graph.relationships", self._server_version)
97+
rel_runner.drop(G, mutate_relationship_type)
98+
99+
return result.drop("relationshipType", axis=1)
29100

30101

31102
class StandardModeRunner(AlgoProcRunner):

graphdatascience/tests/integration/test_simple_algo.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,3 +113,21 @@ def test_fastRP_write_estimate(gds: GraphDataScience) -> None:
113113
result = gds.fastRP.write.estimate(G, writeProperty="embedding", embeddingDimension=4, randomSeed=42)
114114

115115
assert result["requiredMemory"]
116+
117+
118+
def test_nodeSimilarity_stream_with_arrow(gds: GraphDataScience) -> None:
119+
G, _ = gds.graph.project(GRAPH_NAME, "*", "*")
120+
121+
result = gds.nodeSimilarity.stream(G, similarityCutoff=0, stream_with_arrow=True)
122+
123+
assert len(result) == 2
124+
assert result["propertyValue"][0] == 0.5
125+
126+
127+
def test_fastRP_stream_with_arrow(gds: GraphDataScience) -> None:
128+
G, _ = gds.graph.project(GRAPH_NAME, "*", "*")
129+
130+
result = gds.fastRP.stream(G, stream_with_arrow=True, embeddingDimension=32)
131+
132+
assert len(result) == 3
133+
assert len(result["propertyValue"][0]) == 32

0 commit comments

Comments
 (0)