Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1124,7 +1124,7 @@ async def _ws_connect(
headers=resp.headers,
)

if resp.headers.get(hdrs.CONNECTION, "").lower() != "upgrade":
if not resp._upgraded:
raise WSServerHandshakeError(
resp.request_info,
resp.history,
Expand Down
2 changes: 2 additions & 0 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ class ClientResponse(HeadersMixin):
_headers: HeadersDictProxy = None # type: ignore[assignment]
_history: tuple["ClientResponse", ...] = ()
_raw_headers: RawHeaders = None # type: ignore[assignment]
_upgraded: bool = False # parser saw a Connection: upgrade token

_connection: "Connection | None" = None # current connection
_cookies: SimpleCookie | None = None
Expand Down Expand Up @@ -490,6 +491,7 @@ async def start(self, connection: "Connection") -> "ClientResponse":
# headers
self._headers = message.headers
self._raw_headers = message.raw_headers
self._upgraded = message.upgrade

# payload
self.content = payload
Expand Down
2 changes: 1 addition & 1 deletion requirements/constraints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ cryptography==48.0.0
# via trustme
cython==3.2.5
# via -r requirements/cython.in
distlib==0.4.1
distlib==0.4.2
# via virtualenv
docutils==0.21.2
# via
Expand Down
2 changes: 1 addition & 1 deletion requirements/dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ coverage==7.14.1
# pytest-cov
cryptography==48.0.0
# via trustme
distlib==0.4.1
distlib==0.4.2
# via virtualenv
docutils==0.21.2
# via
Expand Down
2 changes: 1 addition & 1 deletion requirements/lint.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ click==8.4.1
# via slotscheck
cryptography==48.0.0
# via trustme
distlib==0.4.1
distlib==0.4.2
# via virtualenv
exceptiongroup==1.3.1
# via pytest
Expand Down
40 changes: 40 additions & 0 deletions tests/test_client_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,3 +387,43 @@ async def test_abort_without_transport() -> None:
# Should not raise and should still clean up
assert proto._exception is None
mock_drop_timeout.assert_not_called()


@pytest.mark.parametrize(
("connection", "expected"),
[(b"upgrade, keep-alive", True), (b"keep-alive", False)],
)
async def test_response_start_records_upgrade(
connection: bytes, expected: bool
) -> None:
"""ClientResponse.start() preserves the parser's Connection upgrade flag."""
loop = asyncio.get_running_loop()
proto = ResponseHandler(loop=loop)
proto.connection_made(mock.Mock())
conn = mock.Mock(protocol=proto)
proto.set_response_params(read_until_eof=True)
proto.data_received(
b"HTTP/1.1 101 Switching Protocols\r\n"
b"Upgrade: websocket\r\n"
b"Connection: " + connection + b"\r\n\r\n"
)

url = URL("http://ws-upgrade.org")
response = ClientResponse(
"get",
url,
writer=mock.Mock(),
continue100=None,
timer=TimerNoop(),
traces=[],
loop=loop,
session=mock.Mock(),
request_headers=CIMultiDict[str](),
original_url=url,
stream_writer=mock.create_autospec(
AbstractStreamWriter, spec_set=True, instance=True
),
)
await response.start(conn)
assert response._upgraded is expected
response.close()
4 changes: 4 additions & 0 deletions tests/test_client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ async def test_ws_connect(ws_key: str, key_data: bytes) -> None:
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "chat",
}
resp._upgraded = True
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand Down Expand Up @@ -255,13 +256,16 @@ async def test_ws_connect_err_upgrade(ws_key: str, key_data: bytes) -> None:


async def test_ws_connect_err_conn(ws_key: str, key_data: bytes) -> None:
# The parser did not see a Connection: upgrade token (resp._upgraded is
# False), so the handshake must be rejected.
resp = mock.Mock()
resp.status = 101
resp.headers = {
hdrs.UPGRADE: "websocket",
hdrs.CONNECTION: "close",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp._upgraded = False
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand Down
28 changes: 28 additions & 0 deletions tests/test_http_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,6 +847,34 @@ def test_upgrade_header_non_ascii(parser: HttpRequestParser) -> None:
assert not upgrade


@pytest.mark.parametrize(
("connection", "expected"),
[
("upgrade", True),
("upgrade, keep-alive", True), # other tokens alongside upgrade
("keep-alive, upgrade", True), # upgrade not first
("Upgrade, Keep-Alive", True), # case-insensitive
("keep-alive", False), # no upgrade token
("keep-alive, notupgrade", False), # substring is not a token
],
)
def test_response_upgrade_token_in_connection_list(
response: HttpResponseParser, connection: str, expected: bool
) -> None:
# RFC 9110 §7.6.1: Connection is a comma-separated token list, so the parser
# must set msg.upgrade for a 101 response whenever "upgrade" appears as a
# token, regardless of position, case, or neighbouring tokens.
text = (
b"HTTP/1.1 101 Switching Protocols\r\n"
b"Upgrade: websocket\r\n"
b"Connection: " + connection.encode() + b"\r\n\r\n"
)
messages, upgrade, tail = response.feed_data(text)
msg = messages[0][0]
assert msg.upgrade == expected
assert upgrade == expected


def test_request_te_chunked_with_content_length(parser: HttpRequestParser) -> None:
text = (
b"GET /test HTTP/1.1\r\n"
Expand Down
Loading