1
1
from abc import ABC
2
2
from typing import Any , Dict , Tuple
3
+ from uuid import uuid4
4
+ from warnings import warn
3
5
4
6
from pandas import DataFrame , Series
5
7
8
+ from ..call_parameters import CallParameters
6
9
from ..error .illegal_attr_checker import IllegalAttrChecker
10
+ from ..graph .graph_entity_ops_runner import (
11
+ GraphElementPropertyRunner ,
12
+ GraphNodePropertiesRunner ,
13
+ GraphRelationshipsRunner ,
14
+ )
7
15
from ..graph .graph_object import Graph
8
16
from ..graph .graph_type_check import graph_type_check
9
17
from ..model .graphsage_model import GraphSageModel
10
- from graphdatascience . call_parameters import CallParameters
18
+ from .. query_runner . arrow_query_runner import ArrowQueryRunner
11
19
12
20
13
21
class AlgoProcRunner (IllegalAttrChecker , ABC ):
@@ -24,8 +32,71 @@ def estimate(self, G: Graph, **config: Any) -> "Series[Any]":
24
32
25
33
26
34
class StreamModeRunner (AlgoProcRunner ):
27
- def __call__ (self , G : Graph , ** config : Any ) -> DataFrame :
28
- return self ._run_procedure (G , config )
35
+ def __call__ (self , G : Graph , stream_with_arrow : bool = False , ** config : Any ) -> DataFrame :
36
+ if stream_with_arrow :
37
+ if not isinstance (self ._query_runner , ArrowQueryRunner ):
38
+ raise ValueError ("The `stream_with_arrow` option requires GDS EE with the Arrow server enabled" )
39
+
40
+ try :
41
+ return self ._stream_with_arrow (G , config )
42
+ except Exception as e :
43
+ warn (
44
+ "Falling back to streaming with Neo4j driver, "
45
+ f"since failed to stream with Arrow with reason: { str (e )} "
46
+ )
47
+
48
+ del config ["mutateProperty" ]
49
+ del config ["mutateRelationshipType" ]
50
+
51
+ self ._namespace = self ._namespace .replace ("mutate" , "stream" )
52
+ return self ._run_procedure (G , config )
53
+ else :
54
+ return self ._run_procedure (G , config )
55
+
56
+ def _stream_with_arrow (self , G : Graph , config : Dict [str , Any ]) -> DataFrame :
57
+ self ._namespace = self ._namespace .replace ("stream" , "mutate" )
58
+
59
+ mutate_property = str (uuid4 ())
60
+ config ["mutateProperty" ] = mutate_property
61
+ try :
62
+ self ._run_procedure (G , config )
63
+
64
+ elem_prop_runner = GraphElementPropertyRunner (
65
+ self ._query_runner , "gds.graph.nodeProperty" , self ._server_version
66
+ )
67
+ if "concurrency" in config :
68
+ result = elem_prop_runner .stream (G , mutate_property , concurrency = config ["concurrency" ])
69
+ else :
70
+ result = elem_prop_runner .stream (G , mutate_property )
71
+
72
+ node_prop_runner = GraphNodePropertiesRunner (
73
+ self ._query_runner , "gds.graph.nodeProperties" , self ._server_version
74
+ )
75
+ node_prop_runner .drop (G , [mutate_property ])
76
+
77
+ return result
78
+ except Exception as e :
79
+ if "No value specified for the mandatory configuration parameter `mutateRelationshipType`" not in str (e ):
80
+ raise e
81
+
82
+ mutate_relationship_type = str (uuid4 ())
83
+ config ["mutateRelationshipType" ] = mutate_relationship_type
84
+ self ._run_procedure (G , config )
85
+
86
+ elem_prop_runner = GraphElementPropertyRunner (
87
+ self ._query_runner , "gds.graph.relationshipProperty" , self ._server_version
88
+ )
89
+ if "concurrency" in config :
90
+ result = elem_prop_runner .stream (
91
+ G , mutate_property , [mutate_relationship_type ], concurrency = config ["concurrency" ]
92
+ )
93
+ else :
94
+ result = elem_prop_runner .stream (G , mutate_property , [mutate_relationship_type ])
95
+
96
+ rel_runner = GraphRelationshipsRunner (self ._query_runner , "gds.graph.relationships" , self ._server_version )
97
+ rel_runner .drop (G , mutate_relationship_type )
98
+
99
+ return result .drop ("relationshipType" , axis = 1 )
29
100
30
101
31
102
class StandardModeRunner (AlgoProcRunner ):
0 commit comments