diff --git a/aiohttp/client.py b/aiohttp/client.py index 4eb4e9454e2..59f9673af82 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -1124,7 +1124,7 @@ async def _ws_connect( headers=resp.headers, ) - if resp.headers.get(hdrs.CONNECTION, "").lower() != "upgrade": + if not resp._upgraded: raise WSServerHandshakeError( resp.request_info, resp.history, diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index 8a060105146..6c5501e61b1 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -197,6 +197,7 @@ class ClientResponse(HeadersMixin): _headers: HeadersDictProxy = None # type: ignore[assignment] _history: tuple["ClientResponse", ...] = () _raw_headers: RawHeaders = None # type: ignore[assignment] + _upgraded: bool = False # parser saw a Connection: upgrade token _connection: "Connection | None" = None # current connection _cookies: SimpleCookie | None = None @@ -490,6 +491,7 @@ async def start(self, connection: "Connection") -> "ClientResponse": # headers self._headers = message.headers self._raw_headers = message.raw_headers + self._upgraded = message.upgrade # payload self.content = payload diff --git a/requirements/constraints.txt b/requirements/constraints.txt index bf6706c7e15..0006a5c94fc 100644 --- a/requirements/constraints.txt +++ b/requirements/constraints.txt @@ -71,7 +71,7 @@ cryptography==48.0.0 # via trustme cython==3.2.5 # via -r requirements/cython.in -distlib==0.4.1 +distlib==0.4.2 # via virtualenv docutils==0.21.2 # via diff --git a/requirements/dev.txt b/requirements/dev.txt index 9cb104b1ac0..a54cfb9ad9f 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -69,7 +69,7 @@ coverage==7.14.1 # pytest-cov cryptography==48.0.0 # via trustme -distlib==0.4.1 +distlib==0.4.2 # via virtualenv docutils==0.21.2 # via diff --git a/requirements/lint.txt b/requirements/lint.txt index 2d0369bcfc0..d4d482d03b2 100644 --- a/requirements/lint.txt +++ b/requirements/lint.txt @@ -36,7 +36,7 @@ click==8.4.1 # via slotscheck cryptography==48.0.0 # via trustme -distlib==0.4.1 +distlib==0.4.2 # via virtualenv exceptiongroup==1.3.1 # via pytest diff --git a/tests/test_client_proto.py b/tests/test_client_proto.py index 4bcf860ac43..42e79978bf8 100644 --- a/tests/test_client_proto.py +++ b/tests/test_client_proto.py @@ -387,3 +387,43 @@ async def test_abort_without_transport() -> None: # Should not raise and should still clean up assert proto._exception is None mock_drop_timeout.assert_not_called() + + +@pytest.mark.parametrize( + ("connection", "expected"), + [(b"upgrade, keep-alive", True), (b"keep-alive", False)], +) +async def test_response_start_records_upgrade( + connection: bytes, expected: bool +) -> None: + """ClientResponse.start() preserves the parser's Connection upgrade flag.""" + loop = asyncio.get_running_loop() + proto = ResponseHandler(loop=loop) + proto.connection_made(mock.Mock()) + conn = mock.Mock(protocol=proto) + proto.set_response_params(read_until_eof=True) + proto.data_received( + b"HTTP/1.1 101 Switching Protocols\r\n" + b"Upgrade: websocket\r\n" + b"Connection: " + connection + b"\r\n\r\n" + ) + + url = URL("http://ws-upgrade.org") + response = ClientResponse( + "get", + url, + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=loop, + session=mock.Mock(), + request_headers=CIMultiDict[str](), + original_url=url, + stream_writer=mock.create_autospec( + AbstractStreamWriter, spec_set=True, instance=True + ), + ) + await response.start(conn) + assert response._upgraded is expected + response.close() diff --git a/tests/test_client_ws.py b/tests/test_client_ws.py index 2c85734ca3f..44c71a5ec37 100644 --- a/tests/test_client_ws.py +++ b/tests/test_client_ws.py @@ -30,6 +30,7 @@ async def test_ws_connect(ws_key: str, key_data: bytes) -> None: hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, hdrs.SEC_WEBSOCKET_PROTOCOL: "chat", } + resp._upgraded = True resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: @@ -255,6 +256,8 @@ async def test_ws_connect_err_upgrade(ws_key: str, key_data: bytes) -> None: async def test_ws_connect_err_conn(ws_key: str, key_data: bytes) -> None: + # The parser did not see a Connection: upgrade token (resp._upgraded is + # False), so the handshake must be rejected. resp = mock.Mock() resp.status = 101 resp.headers = { @@ -262,6 +265,7 @@ async def test_ws_connect_err_conn(ws_key: str, key_data: bytes) -> None: hdrs.CONNECTION: "close", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } + resp._upgraded = False with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data diff --git a/tests/test_http_parser.py b/tests/test_http_parser.py index 8b6d5f52094..dd28b194650 100644 --- a/tests/test_http_parser.py +++ b/tests/test_http_parser.py @@ -847,6 +847,34 @@ def test_upgrade_header_non_ascii(parser: HttpRequestParser) -> None: assert not upgrade +@pytest.mark.parametrize( + ("connection", "expected"), + [ + ("upgrade", True), + ("upgrade, keep-alive", True), # other tokens alongside upgrade + ("keep-alive, upgrade", True), # upgrade not first + ("Upgrade, Keep-Alive", True), # case-insensitive + ("keep-alive", False), # no upgrade token + ("keep-alive, notupgrade", False), # substring is not a token + ], +) +def test_response_upgrade_token_in_connection_list( + response: HttpResponseParser, connection: str, expected: bool +) -> None: + # RFC 9110 ยง7.6.1: Connection is a comma-separated token list, so the parser + # must set msg.upgrade for a 101 response whenever "upgrade" appears as a + # token, regardless of position, case, or neighbouring tokens. + text = ( + b"HTTP/1.1 101 Switching Protocols\r\n" + b"Upgrade: websocket\r\n" + b"Connection: " + connection.encode() + b"\r\n\r\n" + ) + messages, upgrade, tail = response.feed_data(text) + msg = messages[0][0] + assert msg.upgrade == expected + assert upgrade == expected + + def test_request_te_chunked_with_content_length(parser: HttpRequestParser) -> None: text = ( b"GET /test HTTP/1.1\r\n"