Skip to content

Commit 685768a

Browse files
authored
[PR #9698/422ba0c backport][3.11] Add flow control tests for WebSocket (#9701)
Co-authored-by: pre-commit-ci[bot]
1 parent edd44a2 commit 685768a

File tree

1 file changed

+67
-3
lines changed

1 file changed

+67
-3
lines changed

tests/test_websocket_parser.py

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
websocket_mask,
1919
)
2020
from aiohttp._websocket.models import WS_DEFLATE_TRAILING
21+
from aiohttp.base_protocol import BaseProtocol
2122
from aiohttp.http import WebSocketError, WSCloseCode, WSMessage, WSMsgType
2223
from aiohttp.http_websocket import WebSocketReader
2324

@@ -83,12 +84,36 @@ def build_close_frame(code=1000, message=b"", noheader=False):
8384

8485

8586
@pytest.fixture()
86-
def out(loop):
87-
return aiohttp.DataQueue(loop)
87+
def protocol(loop: asyncio.AbstractEventLoop) -> BaseProtocol:
88+
transport = mock.Mock(spec_set=asyncio.Transport)
89+
protocol = BaseProtocol(loop)
90+
protocol.connection_made(transport)
91+
return protocol
8892

8993

9094
@pytest.fixture()
91-
def parser(out: aiohttp.DataQueue[WSMessage]) -> PatchableWebSocketReader:
95+
def out(
96+
loop: asyncio.AbstractEventLoop, protocol: BaseProtocol
97+
) -> aiohttp.FlowControlDataQueue[WSMessage]:
98+
return aiohttp.FlowControlDataQueue(protocol, 2**16, loop=loop)
99+
100+
101+
@pytest.fixture()
102+
def out_low_limit(
103+
loop: asyncio.AbstractEventLoop, protocol: BaseProtocol
104+
) -> aiohttp.FlowControlDataQueue[WSMessage]:
105+
return aiohttp.FlowControlDataQueue(protocol, 16, loop=loop)
106+
107+
108+
@pytest.fixture()
109+
def parser_low_limit(
110+
out_low_limit: aiohttp.FlowControlDataQueue[WSMessage],
111+
) -> PatchableWebSocketReader:
112+
return PatchableWebSocketReader(out_low_limit, 4 * 1024 * 1024)
113+
114+
115+
@pytest.fixture()
116+
def parser(out: aiohttp.FlowControlDataQueue[WSMessage]) -> PatchableWebSocketReader:
92117
return PatchableWebSocketReader(out, 4 * 1024 * 1024)
93118

94119

@@ -569,3 +594,42 @@ def test_pickle(self) -> None:
569594
assert err2.code == WSCloseCode.PROTOCOL_ERROR
570595
assert str(err2) == "Something invalid"
571596
assert err2.foo == "bar"
597+
598+
599+
def test_flow_control_binary(
600+
protocol: BaseProtocol,
601+
out_low_limit: aiohttp.FlowControlDataQueue[WSMessage],
602+
parser_low_limit: WebSocketReader,
603+
) -> None:
604+
large_payload = b"b" * (1 + 16 * 2)
605+
large_payload_len = len(large_payload)
606+
with mock.patch.object(parser_low_limit, "parse_frame", autospec=True) as m:
607+
m.return_value = [(1, WSMsgType.BINARY, large_payload, False)]
608+
609+
parser_low_limit.feed_data(b"")
610+
611+
res = out_low_limit._buffer[0]
612+
assert res == (WSMessage(WSMsgType.BINARY, large_payload, ""), large_payload_len)
613+
assert protocol._reading_paused is True
614+
615+
616+
@pytest.mark.xfail(
617+
reason="Flow control is currently broken on master branch; see #9686"
618+
)
619+
def test_flow_control_multi_byte_text(
620+
protocol: BaseProtocol,
621+
out_low_limit: aiohttp.FlowControlDataQueue[WSMessage],
622+
parser_low_limit: WebSocketReader,
623+
) -> None:
624+
large_payload_text = "𒀁" * (1 + 16 * 2)
625+
large_payload = large_payload_text.encode("utf-8")
626+
large_payload_len = len(large_payload)
627+
628+
with mock.patch.object(parser_low_limit, "parse_frame", autospec=True) as m:
629+
m.return_value = [(1, WSMsgType.TEXT, large_payload, False)]
630+
631+
parser_low_limit.feed_data(b"")
632+
633+
res = out_low_limit._buffer[0]
634+
assert res == (WSMessage(WSMsgType.TEXT, large_payload_text, ""), large_payload_len)
635+
assert protocol._reading_paused is True

0 commit comments

Comments
 (0)