Skip to content

Commit bdf5369

Browse files
committed
initial simplest version of kg train API
1 parent b071b61 commit bdf5369

File tree

2 files changed

+154
-0
lines changed

2 files changed

+154
-0
lines changed

examples/kge-transe.ipynb

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"id": "b68543d5e71ceeb2",
7+
"metadata": {},
8+
"outputs": [],
9+
"source": [
10+
"import os\n",
11+
"from graphdatascience import GraphDataScience"
12+
]
13+
},
14+
{
15+
"cell_type": "code",
16+
"execution_count": null,
17+
"id": "e685a47b61f968ef",
18+
"metadata": {},
19+
"outputs": [],
20+
"source": [
21+
"NEO4J_URI = \"bolt://localhost:7687\"\n",
22+
"NEO4J_DB = \"neo4j\""
23+
]
24+
},
25+
{
26+
"cell_type": "code",
27+
"execution_count": null,
28+
"id": "initial_id",
29+
"metadata": {},
30+
"outputs": [],
31+
"source": [
32+
"if os.environ.get(\"NEO4J_USER\") and os.environ.get(\"NEO4J_PASSWORD\"):\n",
33+
" NEO4J_AUTH = (\n",
34+
" os.environ.get(\"NEO4J_USER\"),\n",
35+
" os.environ.get(\"NEO4J_PASSWORD\"),\n",
36+
" )\n",
37+
"gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB)"
38+
]
39+
},
40+
{
41+
"cell_type": "code",
42+
"execution_count": null,
43+
"id": "a14f06aebe1ed34c",
44+
"metadata": {},
45+
"outputs": [],
46+
"source": [
47+
"_ = gds.run_cypher(\n",
48+
" \"\"\"\n",
49+
" CREATE\n",
50+
" (dan:Person {name: 'Dan'}),\n",
51+
" (annie:Person {name: 'Annie'}),\n",
52+
" (matt:Person {name: 'Matt'}),\n",
53+
" (jeff:Person {name: 'Jeff'}),\n",
54+
" (brie:Person {name: 'Brie'}),\n",
55+
" (elsa:Person {name: 'Elsa'}),\n",
56+
"\n",
57+
" (cookies:Product {name: 'Cookies'}),\n",
58+
" (tomatoes:Product {name: 'Tomatoes'}),\n",
59+
" (cucumber:Product {name: 'Cucumber'}),\n",
60+
" (celery:Product {name: 'Celery'}),\n",
61+
" (kale:Product {name: 'Kale'}),\n",
62+
" (milk:Product {name: 'Milk'}),\n",
63+
" (chocolate:Product {name: 'Chocolate'}),\n",
64+
"\n",
65+
" (dan)-[:BUYS {amount: 1.2}]->(cookies),\n",
66+
" (dan)-[:BUYS {amount: 3.2}]->(milk),\n",
67+
" (dan)-[:BUYS {amount: 2.2}]->(chocolate),\n",
68+
"\n",
69+
" (annie)-[:BUYS {amount: 1.2}]->(cucumber),\n",
70+
" (annie)-[:BUYS {amount: 3.2}]->(milk),\n",
71+
" (annie)-[:BUYS {amount: 3.2}]->(tomatoes),\n",
72+
"\n",
73+
" (matt)-[:BUYS {amount: 3}]->(tomatoes),\n",
74+
" (matt)-[:BUYS {amount: 2}]->(kale),\n",
75+
" (matt)-[:BUYS {amount: 1}]->(cucumber),\n",
76+
"\n",
77+
" (jeff)-[:BUYS {amount: 3}]->(cookies),\n",
78+
" (jeff)-[:BUYS {amount: 2}]->(milk),\n",
79+
"\n",
80+
" (brie)-[:BUYS {amount: 1}]->(tomatoes),\n",
81+
" (brie)-[:BUYS {amount: 2}]->(milk),\n",
82+
" (brie)-[:BUYS {amount: 2}]->(kale),\n",
83+
" (brie)-[:BUYS {amount: 3}]->(cucumber),\n",
84+
" (brie)-[:BUYS {amount: 0.3}]->(celery),\n",
85+
"\n",
86+
" (elsa)-[:BUYS {amount: 3}]->(chocolate),\n",
87+
" (elsa)-[:BUYS {amount: 3}]->(milk)\n",
88+
" \"\"\"\n",
89+
")\n",
90+
"node_projection = [\"Person\", \"Product\"]\n",
91+
"relationship_projection = {\"BUYS\": {\"orientation\": \"UNDIRECTED\", \"properties\": \"amount\"}}\n",
92+
"G, result = gds.graph.project(\"purchases222\", node_projection, relationship_projection)\n",
93+
"print(f\"The projection took {result['projectMillis']} ms\")\n",
94+
"print(f\"Graph '{G.name()}' node count: {G.node_count()}\")\n",
95+
"print(f\"Graph '{G.name()}' node labels: {G.node_labels()}\")"
96+
]
97+
},
98+
{
99+
"cell_type": "code",
100+
"execution_count": null,
101+
"id": "e049480efa34e8ca",
102+
"metadata": {},
103+
"outputs": [],
104+
"source": [
105+
"gds.model.transe.train(\n",
106+
" G,\n",
107+
" proportions=[0.8, 0.1, 0.1],\n",
108+
" embedding_dimension=50,\n",
109+
" batch_size=512,\n",
110+
" epochs=100,\n",
111+
" optimizer=\"Adam\",\n",
112+
" optimizer_kwargs={\"lr\": 0.01, \"weight_decay\": 5e-4},\n",
113+
" # loss\n",
114+
")"
115+
]
116+
}
117+
],
118+
"metadata": {
119+
"language_info": {
120+
"name": "python"
121+
}
122+
},
123+
"nbformat": 4,
124+
"nbformat_minor": 5
125+
}

graphdatascience/model/model_proc_runner.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
from typing import Any, Dict, List, Optional, Tuple
23

34
from pandas import DataFrame, Series
@@ -45,6 +46,34 @@ def create(
4546
relationship_type_embeddings,
4647
)
4748

49+
@compatible_with("train", min_inclusive=ServerVersion(2, 5, 0))
@client_only_endpoint("gds.model.transe")
def train(
    self,
    G: Graph,
    proportions: List[float],
    embedding_dimension: int,
    batch_size: int,
    epochs: int,
    optimizer: str,
    optimizer_kwargs: Dict[str, Any],
    # TODO: `loss` is not configurable yet; add it here and to the config.
) -> int:
    """Serialize a TransE training configuration to a JSON file.

    The body of this method only assembles the configuration and writes it
    to ``/tmp/kge-train-config-dump.json``; it does not start any training
    itself.

    Args:
        G: Projected graph; only its name is recorded in the configuration.
        proportions: Split proportions, stored under ``"proportions"``
            (presumably train/validation/test fractions -- TODO confirm).
        embedding_dimension: Stored under ``"embedding_dimension"``.
        batch_size: Stored under ``"batch_size"``.
        epochs: Stored under ``"num_epochs"``.
        optimizer: Optimizer name, stored under ``"optimizer"``.
        optimizer_kwargs: Extra optimizer settings, stored under
            ``"optimizer_kwargs"``. Must be JSON-serializable.

    Returns:
        Always ``0``.

    Raises:
        OSError: If the configuration file cannot be opened or written.
        TypeError: If a configuration value is not JSON-serializable.
    """
    config = {
        "scoring_function": "TransE",
        "proportions": proportions,
        "embedding_dimension": embedding_dimension,
        "num_epochs": epochs,
        "graph_name": G.name(),
        "batch_size": batch_size,
        "optimizer": optimizer,
        "optimizer_kwargs": optimizer_kwargs,
    }

    # NOTE(review): the path is hard-coded and POSIX-only; kept as-is because
    # a consumer of this file may rely on the exact location -- confirm.
    config_path = "/tmp/kge-train-config-dump.json"

    # The context manager guarantees the file is flushed and closed even if
    # serialization fails part-way; the original left the handle open.
    with open(config_path, "w", encoding="utf-8") as config_file:
        json.dump(config, config_file)

    # Report only after the dump has actually succeeded.
    print("Dumped to " + config_path)
    return 0
4877

4978
class ModelProcRunner(ModelResolver):
5079
@client_only_endpoint("gds.model")

0 commit comments

Comments
 (0)