Skip to content

Commit bcb1a99

Browse files
authored
Merge pull request #171 from dispatchrun/flask-tests
Flask integration tests
2 parents fc85617 + e0a1929 commit bcb1a99

File tree

3 files changed

+206
-0
lines changed

3 files changed

+206
-0
lines changed

README.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ Python package to develop applications with the Dispatch platform.
2828
- [Running Dispatch Applications](#running-dispatch-applications)
2929
- [Writing Transactional Applications with Dispatch](#writing-transactional-applications-with-dispatch)
3030
- [Integration with FastAPI](#integration-with-fastapi)
31+
- [Integration with Flask](#integration-with-flask)
3132
- [Configuration](#configuration)
3233
- [Serialization](#serialization)
3334
- [Examples](#examples)
@@ -198,6 +199,22 @@ In this example, GET requests on the HTTP server dispatch calls to the
198199
`publish` function. The function runs concurrently to the rest of the
199200
program, driven by the Dispatch SDK.
200201

202+
### Integration with Flask
203+
204+
Dispatch can also be integrated with web applications built on [Flask][flask].
205+
206+
The API is nearly identical to FastAPI above, instead use:
207+
208+
```python
209+
from flask import Flask
210+
from dispatch.flask import Dispatch
211+
212+
app = Flask(__name__)
213+
dispatch = Dispatch(app)
214+
```
215+
216+
[flask]: https://flask.palletsprojects.com/en/3.0.x/
217+
201218
### Configuration
202219

203220
The Dispatch CLI automatically configures the SDK, so manual configuration is

src/dispatch/test/flask.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from typing import Mapping
2+
3+
import werkzeug.test
4+
from flask import Flask
5+
6+
from dispatch.test.http import HttpClient, HttpResponse
7+
8+
9+
def http_client(app: Flask) -> HttpClient:
10+
"""Build a client for a Flask app."""
11+
return Client(app.test_client())
12+
13+
14+
class Client(HttpClient):
15+
def __init__(self, client: werkzeug.test.Client):
16+
self.client = client
17+
18+
def get(self, url: str, headers: Mapping[str, str] = {}) -> HttpResponse:
19+
response = self.client.get(url, headers=headers.items())
20+
return Response(response)
21+
22+
def post(
23+
self, url: str, body: bytes, headers: Mapping[str, str] = {}
24+
) -> HttpResponse:
25+
response = self.client.post(url, data=body, headers=headers.items())
26+
return Response(response)
27+
28+
def url_for(self, path: str) -> str:
29+
return "http://localhost" + path
30+
31+
32+
class Response(HttpResponse):
33+
def __init__(self, response):
34+
self.response = response
35+
36+
@property
37+
def status_code(self):
38+
return self.response.status_code
39+
40+
@property
41+
def body(self):
42+
return self.response.data
43+
44+
def raise_for_status(self):
45+
if self.response.status_code // 100 != 2:
46+
raise RuntimeError(f"HTTP status code {self.response.status_code}")

tests/test_flask.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
import base64
2+
import os
3+
import pickle
4+
import struct
5+
import unittest
6+
from typing import Any, Optional
7+
from unittest import mock
8+
9+
import google.protobuf.any_pb2
10+
import google.protobuf.wrappers_pb2
11+
from cryptography.hazmat.primitives.asymmetric.ed25519 import (
12+
Ed25519PrivateKey,
13+
Ed25519PublicKey,
14+
)
15+
from flask import Flask
16+
17+
import dispatch
18+
from dispatch.experimental.durable.registry import clear_functions
19+
from dispatch.flask import Dispatch
20+
from dispatch.function import Arguments, Error, Function, Input, Output
21+
from dispatch.proto import _any_unpickle as any_unpickle
22+
from dispatch.sdk.v1 import call_pb2 as call_pb
23+
from dispatch.sdk.v1 import function_pb2 as function_pb
24+
from dispatch.signature import (
25+
parse_verification_key,
26+
private_key_from_pem,
27+
public_key_from_pem,
28+
)
29+
from dispatch.status import Status
30+
from dispatch.test import DispatchServer, DispatchService, EndpointClient
31+
from dispatch.test.flask import http_client
32+
33+
34+
def create_dispatch_instance(app: Flask, endpoint: str):
35+
return Dispatch(
36+
app,
37+
endpoint=endpoint,
38+
api_key="0000000000000000",
39+
api_url="http://127.0.0.1:10000",
40+
)
41+
42+
43+
def create_endpoint_client(app: Flask, signing_key: Optional[Ed25519PrivateKey] = None):
44+
return EndpointClient(http_client(app), signing_key)
45+
46+
47+
class TestFlask(unittest.TestCase):
48+
def test_flask(self):
49+
app = Flask(__name__)
50+
dispatch = create_dispatch_instance(app, endpoint="http://127.0.0.1:9999/")
51+
52+
@dispatch.primitive_function
53+
def my_function(input: Input) -> Output:
54+
return Output.value(
55+
f"You told me: '{input.input}' ({len(input.input)} characters)"
56+
)
57+
58+
client = create_endpoint_client(app)
59+
pickled = pickle.dumps("Hello World!")
60+
input_any = google.protobuf.any_pb2.Any()
61+
input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=pickled))
62+
63+
req = function_pb.RunRequest(
64+
function=my_function.name,
65+
input=input_any,
66+
)
67+
68+
resp = client.run(req)
69+
70+
self.assertIsInstance(resp, function_pb.RunResponse)
71+
72+
resp.exit.result.output.Unpack(
73+
output_bytes := google.protobuf.wrappers_pb2.BytesValue()
74+
)
75+
output = pickle.loads(output_bytes.value)
76+
77+
self.assertEqual(output, "You told me: 'Hello World!' (12 characters)")
78+
79+
80+
signing_key = private_key_from_pem(
81+
"""
82+
-----BEGIN PRIVATE KEY-----
83+
MC4CAQAwBQYDK2VwBCIEIJ+DYvh6SEqVTm50DFtMDoQikTmiCqirVv9mWG9qfSnF
84+
-----END PRIVATE KEY-----
85+
"""
86+
)
87+
88+
verification_key = public_key_from_pem(
89+
"""
90+
-----BEGIN PUBLIC KEY-----
91+
MCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=
92+
-----END PUBLIC KEY-----
93+
"""
94+
)
95+
96+
97+
class TestFlaskE2E(unittest.TestCase):
98+
def setUp(self):
99+
self.endpoint_app = Flask(__name__)
100+
endpoint_client = create_endpoint_client(self.endpoint_app, signing_key)
101+
102+
api_key = "0000000000000000"
103+
self.dispatch_service = DispatchService(
104+
endpoint_client, api_key, collect_roundtrips=True
105+
)
106+
self.dispatch_server = DispatchServer(self.dispatch_service)
107+
self.dispatch_client = dispatch.Client(
108+
api_key, api_url=self.dispatch_server.url
109+
)
110+
111+
self.dispatch = Dispatch(
112+
self.endpoint_app,
113+
endpoint="http://function-service", # unused
114+
verification_key=verification_key,
115+
api_key=api_key,
116+
api_url=self.dispatch_server.url,
117+
)
118+
119+
self.dispatch_server.start()
120+
121+
def tearDown(self):
122+
self.dispatch_server.stop()
123+
124+
def test_simple_end_to_end(self):
125+
# The Flask server.
126+
@self.dispatch.function
127+
def my_function(name: str) -> str:
128+
return f"Hello world: {name}"
129+
130+
call = my_function.build_call(52)
131+
self.assertEqual(call.function.split(".")[-1], "my_function")
132+
133+
# The client.
134+
[dispatch_id] = self.dispatch_client.dispatch([my_function.build_call(52)])
135+
136+
# Simulate execution for testing purposes.
137+
self.dispatch_service.dispatch_calls()
138+
139+
# Validate results.
140+
roundtrips = self.dispatch_service.roundtrips[dispatch_id]
141+
self.assertEqual(len(roundtrips), 1)
142+
_, response = roundtrips[0]
143+
self.assertEqual(any_unpickle(response.exit.result.output), "Hello world: 52")

0 commit comments

Comments
 (0)