|
18 | 18 | websocket_mask,
|
19 | 19 | )
|
20 | 20 | from aiohttp._websocket.models import WS_DEFLATE_TRAILING
|
| 21 | +from aiohttp.base_protocol import BaseProtocol |
21 | 22 | from aiohttp.http import WebSocketError, WSCloseCode, WSMessage, WSMsgType
|
22 | 23 | from aiohttp.http_websocket import WebSocketReader
|
23 | 24 |
|
@@ -83,12 +84,36 @@ def build_close_frame(code=1000, message=b"", noheader=False):
|
83 | 84 |
|
84 | 85 |
|
85 | 86 | @pytest.fixture()
|
86 |
| -def out(loop): |
87 |
| - return aiohttp.DataQueue(loop) |
| 87 | +def protocol(loop: asyncio.AbstractEventLoop) -> BaseProtocol: |
| 88 | + transport = mock.Mock(spec_set=asyncio.Transport) |
| 89 | + protocol = BaseProtocol(loop) |
| 90 | + protocol.connection_made(transport) |
| 91 | + return protocol |
88 | 92 |
|
89 | 93 |
|
90 | 94 | @pytest.fixture()
|
91 |
| -def parser(out: aiohttp.DataQueue[WSMessage]) -> PatchableWebSocketReader: |
| 95 | +def out( |
| 96 | + loop: asyncio.AbstractEventLoop, protocol: BaseProtocol |
| 97 | +) -> aiohttp.FlowControlDataQueue[WSMessage]: |
| 98 | + return aiohttp.FlowControlDataQueue(protocol, 2**16, loop=loop) |
| 99 | + |
| 100 | + |
| 101 | +@pytest.fixture() |
| 102 | +def out_low_limit( |
| 103 | + loop: asyncio.AbstractEventLoop, protocol: BaseProtocol |
| 104 | +) -> aiohttp.FlowControlDataQueue[WSMessage]: |
| 105 | + return aiohttp.FlowControlDataQueue(protocol, 16, loop=loop) |
| 106 | + |
| 107 | + |
| 108 | +@pytest.fixture() |
| 109 | +def parser_low_limit( |
| 110 | + out_low_limit: aiohttp.FlowControlDataQueue[WSMessage], |
| 111 | +) -> PatchableWebSocketReader: |
| 112 | + return PatchableWebSocketReader(out_low_limit, 4 * 1024 * 1024) |
| 113 | + |
| 114 | + |
| 115 | +@pytest.fixture() |
| 116 | +def parser(out: aiohttp.FlowControlDataQueue[WSMessage]) -> PatchableWebSocketReader: |
92 | 117 | return PatchableWebSocketReader(out, 4 * 1024 * 1024)
|
93 | 118 |
|
94 | 119 |
|
@@ -569,3 +594,42 @@ def test_pickle(self) -> None:
|
569 | 594 | assert err2.code == WSCloseCode.PROTOCOL_ERROR
|
570 | 595 | assert str(err2) == "Something invalid"
|
571 | 596 | assert err2.foo == "bar"
|
| 597 | + |
| 598 | + |
| 599 | +def test_flow_control_binary( |
| 600 | + protocol: BaseProtocol, |
| 601 | + out_low_limit: aiohttp.FlowControlDataQueue[WSMessage], |
| 602 | + parser_low_limit: WebSocketReader, |
| 603 | +) -> None: |
| 604 | + large_payload = b"b" * (1 + 16 * 2) |
| 605 | + large_payload_len = len(large_payload) |
| 606 | + with mock.patch.object(parser_low_limit, "parse_frame", autospec=True) as m: |
| 607 | + m.return_value = [(1, WSMsgType.BINARY, large_payload, False)] |
| 608 | + |
| 609 | + parser_low_limit.feed_data(b"") |
| 610 | + |
| 611 | + res = out_low_limit._buffer[0] |
| 612 | + assert res == (WSMessage(WSMsgType.BINARY, large_payload, ""), large_payload_len) |
| 613 | + assert protocol._reading_paused is True |
| 614 | + |
| 615 | + |
| 616 | +@pytest.mark.xfail( |
| 617 | + reason="Flow control is currently broken on master branch; see #9686" |
| 618 | +) |
| 619 | +def test_flow_control_multi_byte_text( |
| 620 | + protocol: BaseProtocol, |
| 621 | + out_low_limit: aiohttp.FlowControlDataQueue[WSMessage], |
| 622 | + parser_low_limit: WebSocketReader, |
| 623 | +) -> None: |
| 624 | + large_payload_text = "𒀁" * (1 + 16 * 2) |
| 625 | + large_payload = large_payload_text.encode("utf-8") |
| 626 | + large_payload_len = len(large_payload) |
| 627 | + |
| 628 | + with mock.patch.object(parser_low_limit, "parse_frame", autospec=True) as m: |
| 629 | + m.return_value = [(1, WSMsgType.TEXT, large_payload, False)] |
| 630 | + |
| 631 | + parser_low_limit.feed_data(b"") |
| 632 | + |
| 633 | + res = out_low_limit._buffer[0] |
| 634 | + assert res == (WSMessage(WSMsgType.TEXT, large_payload_text, ""), large_payload_len) |
| 635 | + assert protocol._reading_paused is True |
0 commit comments