Skip to content

Commit 900c72f

Browse files
Jerzy Góragoraje
authored andcommitted
feature: add-websocket-support
1 parent f7675e9 commit 900c72f

File tree

11 files changed

+180
-82
lines changed

11 files changed

+180
-82
lines changed

README.md

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ pip install fastapi-controllers
3434
```python
3535
import uvicorn
3636
from fastapi import FastAPI, Response, status
37+
from fastapi.websockets import WebSocket
3738

3839
from fastapi_controllers import Controller, get
3940

@@ -43,6 +44,13 @@ class ExampleController(Controller):
4344
async def get_example(self) -> Response:
4445
return Response(status_code=status.HTTP_200_OK)
4546

47+
@websocket("/ws")
48+
async def ws_example(websocket: WebSocket) -> None:
49+
await websocket.accept()
50+
while True:
51+
data = await websocket.receive_text()
52+
await websocket.send_text(f"Received: {data}")
53+
4654

4755
if __name__ == "__main__":
4856
app = FastAPI()
@@ -54,12 +62,12 @@ FastAPI's `APIRouter` is created and populated with API routes by the `Controlle
5462

5563
## Seamless integration
5664

57-
The router-related parameters as well as those of HTTP request-specific decorators are expected to be the same as those used by `fastapi.APIRouter` and `fastapi.APIRouter.<request_method>`. Validation of the provided parameters is performed during initialization via the `inspect` module. This ensures compatibility with the FastAPI framework and prevents the introduction of a new, unnecessary naming convention.
65+
The router-related parameters as well as those of HTTP request-specific and websocket decorators are expected to be the same as those used by `fastapi.APIRouter`, `fastapi.APIRouter.<request_method>` and `fastapi.APIRouter.websocket`. Validation of the provided parameters is performed during initialization via the `inspect` module. This ensures compatibility with the FastAPI framework and prevents the introduction of a new, unnecessary naming convention.
5866

59-
### Supported HTTP request methods
67+
### Available decorators
6068

6169
```python
62-
from fastapi_controllers import delete, get, head, options, patch, post, put, trace
70+
from fastapi_controllers import delete, get, head, options, patch, post, put, trace, websocket
6371
```
6472

6573
## Use class variables to customize your APIRouter

fastapi_controllers/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from fastapi_controllers.controllers import Controller
2-
from fastapi_controllers.routing import delete, get, head, options, patch, post, put, trace
2+
from fastapi_controllers.routing import delete, get, head, options, patch, post, put, trace, websocket
33

44
__all__ = [
55
"Controller",
@@ -11,4 +11,5 @@
1111
"post",
1212
"put",
1313
"trace",
14+
"websocket",
1415
]

fastapi_controllers/controllers.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import inspect
22
from enum import Enum
3-
from types import SimpleNamespace
43
from typing import Any, Dict, List, Optional, Sequence, Union
54

65
from fastapi import APIRouter, params
76

7+
from fastapi_controllers.definitions import HTTPRouteDefinition, RouteData, WebsocketRouteDefinition
88
from fastapi_controllers.helpers import _replace_signature, _validate_against_apirouter_signature
99

1010

@@ -33,12 +33,21 @@ def create_router(cls) -> APIRouter:
3333
router = APIRouter(**cls.__router_params__) # type: ignore
3434
for _, func in inspect.getmembers(cls, predicate=inspect.isfunction):
3535
_replace_signature(cls, func)
36-
api_route_data: Optional[SimpleNamespace] = getattr(func, "__api_route_data__", None)
37-
if api_route_data:
38-
router.add_api_route(
39-
api_route_data.args[0],
40-
func,
41-
*api_route_data.args[1:],
42-
**api_route_data.kwargs,
43-
)
36+
route_data: Optional[RouteData] = getattr(func, "__route_data__", None)
37+
if route_data:
38+
if isinstance(route_data.route_definition, HTTPRouteDefinition):
39+
router.add_api_route(
40+
route_data.route_args[0],
41+
func,
42+
*route_data.route_args[1:],
43+
methods=[route_data.route_definition.request_method],
44+
**route_data.route_kwargs,
45+
)
46+
if isinstance(route_data.route_definition, WebsocketRouteDefinition):
47+
router.add_api_websocket_route(
48+
route_data.route_args[0],
49+
func,
50+
*route_data.route_args[1:],
51+
**route_data.route_kwargs,
52+
)
4453
return router

fastapi_controllers/definitions.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from dataclasses import dataclass
2+
from enum import Enum
3+
from typing import Any, Dict, Tuple
4+
5+
6+
class HTTPRequestMethod(str, Enum):
7+
DELETE = "DELETE"
8+
GET = "GET"
9+
HEAD = "HEAD"
10+
OPTIONS = "OPTIONS"
11+
PATCH = "PATCH"
12+
POST = "POST"
13+
PUT = "PUT"
14+
TRACE = "TRACE"
15+
16+
17+
class RouteDefinition:
18+
def __init__(self, *, binds: str) -> None:
19+
self.binds = binds
20+
21+
22+
class HTTPRouteDefinition(RouteDefinition):
23+
def __init__(self, *, binds: str, request_method: HTTPRequestMethod) -> None:
24+
super().__init__(binds=binds)
25+
self.request_method = request_method
26+
27+
28+
class WebsocketRouteDefinition(RouteDefinition):
29+
...
30+
31+
32+
@dataclass
33+
class RouteData:
34+
route_definition: RouteDefinition
35+
route_args: Tuple[Any, ...]
36+
route_kwargs: Dict[str, Any]
37+
38+
39+
@dataclass
40+
class Route:
41+
delete = HTTPRouteDefinition(binds="delete", request_method=HTTPRequestMethod.DELETE)
42+
get = HTTPRouteDefinition(binds="get", request_method=HTTPRequestMethod.GET)
43+
head = HTTPRouteDefinition(binds="head", request_method=HTTPRequestMethod.HEAD)
44+
options = HTTPRouteDefinition(binds="options", request_method=HTTPRequestMethod.OPTIONS)
45+
patch = HTTPRouteDefinition(binds="patch", request_method=HTTPRequestMethod.PATCH)
46+
post = HTTPRouteDefinition(binds="post", request_method=HTTPRequestMethod.POST)
47+
put = HTTPRouteDefinition(binds="put", request_method=HTTPRequestMethod.PUT)
48+
trace = HTTPRouteDefinition(binds="trace", request_method=HTTPRequestMethod.TRACE)
49+
websocket = WebsocketRouteDefinition(binds="websocket")

fastapi_controllers/routing.py

Lines changed: 41 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,68 +1,77 @@
1-
from enum import Enum
2-
from types import SimpleNamespace
3-
from typing import Any, Callable
1+
from typing import Any, Callable, Dict, Tuple
42

3+
from fastapi_controllers.definitions import Route, RouteData, RouteDefinition
54
from fastapi_controllers.helpers import _validate_against_apirouter_signature
65

76

8-
class _HTTPRequestMethod(str, Enum):
9-
DELETE = "DELETE"
10-
GET = "GET"
11-
HEAD = "HEAD"
12-
OPTIONS = "OPTIONS"
13-
PATCH = "PATCH"
14-
POST = "POST"
15-
PUT = "PUT"
16-
TRACE = "TRACE"
17-
18-
197
class _RouteDecorator:
20-
method: str
8+
_route_definition: RouteDefinition
219

22-
def __init_subclass__(cls, method: _HTTPRequestMethod) -> None:
10+
def __init_subclass__(cls, route_definition: RouteDefinition) -> None:
2311
super().__init_subclass__()
24-
cls.method = method
12+
cls._route_definition = route_definition
2513

2614
def __init__(self, *args: Any, **kwargs: Any) -> None:
27-
self.args = args
28-
if self.args and callable(self.args[0]):
29-
raise TypeError("You must provide a path for the route.")
30-
self.kwargs = kwargs
31-
_validate_against_apirouter_signature(self.method.lower(), args=args, kwargs=kwargs)
15+
self.route_args = args
16+
self.route_kwargs = kwargs
17+
_validate_against_apirouter_signature(self._route_definition.binds, args=args, kwargs=kwargs)
3218

3319
def __call__(self, func: Callable[..., Any]) -> Callable[..., Any]:
34-
self.kwargs["methods"] = [self.method]
35-
func.__api_route_data__ = SimpleNamespace(args=self.args, kwargs=self.kwargs) # type: ignore
20+
func.__route_data__ = RouteData( # type: ignore
21+
route_definition=self._route_definition,
22+
route_args=self.route_args,
23+
route_kwargs=self.route_kwargs,
24+
)
3625
return func
3726

27+
@property
28+
def route_args(self) -> Tuple[Any, ...]:
29+
return self._route_args
30+
31+
@route_args.setter
32+
def route_args(self, value: Tuple[Any, ...]) -> None:
33+
self._route_args = value
34+
35+
@property
36+
def route_kwargs(self) -> Dict[str, Any]:
37+
return self._route_kwargs
38+
39+
@route_kwargs.setter
40+
def route_kwargs(self, value: Dict[str, Any]) -> None:
41+
self._route_kwargs = value
42+
43+
44+
class delete(_RouteDecorator, route_definition=Route.delete):
45+
...
46+
3847

39-
class delete(_RouteDecorator, method=_HTTPRequestMethod.DELETE):
48+
class get(_RouteDecorator, route_definition=Route.get):
4049
...
4150

4251

43-
class get(_RouteDecorator, method=_HTTPRequestMethod.GET):
52+
class head(_RouteDecorator, route_definition=Route.head):
4453
...
4554

4655

47-
class head(_RouteDecorator, method=_HTTPRequestMethod.HEAD):
56+
class options(_RouteDecorator, route_definition=Route.options):
4857
...
4958

5059

51-
class options(_RouteDecorator, method=_HTTPRequestMethod.OPTIONS):
60+
class patch(_RouteDecorator, route_definition=Route.patch):
5261
...
5362

5463

55-
class patch(_RouteDecorator, method=_HTTPRequestMethod.PATCH):
64+
class post(_RouteDecorator, route_definition=Route.post):
5665
...
5766

5867

59-
class post(_RouteDecorator, method=_HTTPRequestMethod.POST):
68+
class put(_RouteDecorator, route_definition=Route.put):
6069
...
6170

6271

63-
class put(_RouteDecorator, method=_HTTPRequestMethod.PUT):
72+
class trace(_RouteDecorator, route_definition=Route.trace):
6473
...
6574

6675

67-
class trace(_RouteDecorator, method=_HTTPRequestMethod.TRACE):
76+
class websocket(_RouteDecorator, route_definition=Route.websocket):
6877
...

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "fastapi-controllers"
3-
version = "0.1.1"
3+
version = "0.2.0"
44
description = "Simple Controller implementation for FastAPI"
55
authors = ["Jerzy Góra <[email protected]>"]
66
license = "MIT"

tests/functional/conftest.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
import pytest
55
from fastapi import Depends, FastAPI, Response, status
66
from fastapi.testclient import TestClient
7+
from fastapi.websockets import WebSocket
78

8-
from fastapi_controllers import Controller, delete, get, head, options, patch, post, put, trace
9+
from fastapi_controllers import Controller, delete, get, head, options, patch, post, put, trace, websocket
910

1011

1112
def sync_dependency() -> str:
@@ -97,6 +98,12 @@ async def test_put(self) -> Response:
9798
async def test_trace(self) -> Response:
9899
return Response(status_code=status.HTTP_200_OK)
99100

101+
@websocket("/ws")
102+
async def websocket(websocket: WebSocket) -> None: # type: ignore
103+
await websocket.accept()
104+
await websocket.send_json({"msg": "Hello WebSocket"})
105+
await websocket.close()
106+
100107

101108
@pytest.fixture
102109
def sync_test_client() -> TestClient:

tests/functional/test_async.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,28 @@
11
import pytest
22
from fastapi import status
33
from fastapi.testclient import TestClient
4+
from starlette.websockets import WebSocketDisconnect
45

5-
from fastapi_controllers.routing import _HTTPRequestMethod
6+
from fastapi_controllers.definitions import HTTPRequestMethod
67

78

89
def describe_test_controller_async() -> None:
9-
@pytest.mark.parametrize("http_request_method", _HTTPRequestMethod.__members__.values())
10+
@pytest.mark.parametrize("http_request_method", HTTPRequestMethod.__members__.values())
1011
def it_responds_to_http_methods(
1112
async_test_client: TestClient,
12-
http_request_method: _HTTPRequestMethod,
13+
http_request_method: HTTPRequestMethod,
1314
) -> None:
1415
request_func = getattr(async_test_client, http_request_method.lower(), None)
1516
if request_func:
1617
response = request_func("/test-async")
1718
assert response.status_code == status.HTTP_200_OK
1819

20+
def it_supports_websockets(async_test_client: TestClient) -> None:
21+
with pytest.raises(WebSocketDisconnect):
22+
with async_test_client.websocket_connect("/ws") as websocket:
23+
data = websocket.receive_json()
24+
assert data == {"msg": "Hello WebSocket"}
25+
1926
def it_resolves_async_dependencies(async_test_client: TestClient) -> None:
2027
response = async_test_client.get("/test-async")
2128
assert response.status_code == status.HTTP_200_OK

tests/functional/test_sync.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22
from fastapi import status
33
from fastapi.testclient import TestClient
44

5-
from fastapi_controllers.routing import _HTTPRequestMethod
5+
from fastapi_controllers.definitions import HTTPRequestMethod
66

77

88
def describe_test_controller_sync() -> None:
9-
@pytest.mark.parametrize("http_request_method", _HTTPRequestMethod.__members__.values())
9+
@pytest.mark.parametrize("http_request_method", HTTPRequestMethod.__members__.values())
1010
def it_responds_to_http_methods(
1111
sync_test_client: TestClient,
12-
http_request_method: _HTTPRequestMethod,
12+
http_request_method: HTTPRequestMethod,
1313
) -> None:
1414
request_func = getattr(sync_test_client, http_request_method.lower(), None)
1515
if request_func:

tests/unit/test_controllers.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from pytest_mock import MockerFixture
66

77
from fastapi_controllers.controllers import Controller
8-
from fastapi_controllers.routing import get
8+
from fastapi_controllers.routing import get, websocket
99

1010

1111
@pytest.fixture(autouse=True)
@@ -80,6 +80,10 @@ class FakeController(Controller):
8080
def fake_method(self) -> None:
8181
...
8282

83+
@websocket("/ws")
84+
def fake_ws(self) -> None:
85+
...
86+
8387
FakeController.create_router()
8488
apirouter.assert_called_once_with(prefix="/test", dependencies=None, tags=None)
8589
apirouter.return_value.add_api_route.assert_called_once_with(
@@ -88,3 +92,7 @@ def fake_method(self) -> None:
8892
deprecated=True,
8993
methods=["GET"],
9094
)
95+
apirouter.return_value.add_api_websocket_route.assert_called_once_with(
96+
"/ws",
97+
FakeController.fake_ws,
98+
)

0 commit comments

Comments
 (0)