diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 04aed345e..affa6b572 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -391,12 +391,23 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No await self._handle_unsupported_request(request, send) def _check_accept_headers(self, request: Request) -> tuple[bool, bool]: - """Check if the request accepts the required media types.""" + """Check if the request accepts the required media types. + + Supports wildcard media types per RFC 7231, section 5.3.2: + - */* matches any media type + - application/* matches any application/ subtype + - text/* matches any text/ subtype + """ accept_header = request.headers.get("accept", "") - accept_types = [media_type.strip() for media_type in accept_header.split(",")] + accept_types = [media_type.strip().split(";")[0].strip() for media_type in accept_header.split(",")] - has_json = any(media_type.startswith(CONTENT_TYPE_JSON) for media_type in accept_types) - has_sse = any(media_type.startswith(CONTENT_TYPE_SSE) for media_type in accept_types) + has_wildcard = "*/*" in accept_types + has_json = has_wildcard or any( + media_type.startswith(CONTENT_TYPE_JSON) or media_type == "application/*" for media_type in accept_types + ) + has_sse = has_wildcard or any( + media_type.startswith(CONTENT_TYPE_SSE) or media_type == "text/*" for media_type in accept_types + ) return has_json, has_sse diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 42b1a3698..4596d6631 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -572,8 +572,10 @@ def json_server_url(json_server_port: int) -> str: # Basic request validation tests def test_accept_header_validation(basic_server: None, basic_server_url: str): """Test that Accept header is properly validated.""" - # Test without Accept header - response = requests.post( + # Test without Accept header (suppress requests library default Accept: */*) + session = requests.Session() + session.headers.update({"Accept": None}) # type: ignore[arg-type] + response = session.post( f"{basic_server_url}/mcp", headers={"Content-Type": "application/json"}, json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, @@ -582,6 +584,52 @@ def test_accept_header_validation(basic_server: None, basic_server_url: str): assert "Not Acceptable" in response.text +@pytest.mark.parametrize( + "accept_header", + [ + "*/*", + "application/*, text/*", + "text/*, application/json", + "application/json, text/*", + "*/*;q=0.8", + "application/*;q=0.9, text/*;q=0.8", + ], +) +def test_accept_header_wildcard(basic_server: None, basic_server_url: str, accept_header: str): + """Test that wildcard Accept headers are accepted per RFC 7231.""" + response = requests.post( + f"{basic_server_url}/mcp", + headers={ + "Accept": accept_header, + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + + +@pytest.mark.parametrize( + "accept_header", + [ + "text/html", + "application/*", + "text/*", + ], +) +def test_accept_header_incompatible(basic_server: None, basic_server_url: str, accept_header: str): + """Test that incompatible Accept headers are rejected for SSE mode.""" + response = requests.post( + f"{basic_server_url}/mcp", + headers={ + "Accept": accept_header, + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text + + def test_content_type_validation(basic_server: None, basic_server_url: str): """Test that Content-Type header is properly validated.""" # Test with incorrect Content-Type @@ -826,7 +874,10 @@ def test_json_response_accept_json_only(json_response_server: None, json_server_ def test_json_response_missing_accept_header(json_response_server: None, json_server_url: str): """Test that json_response servers reject requests without Accept header.""" mcp_url = f"{json_server_url}/mcp" - response = requests.post( + # Suppress requests library default Accept: */* header + session = requests.Session() + session.headers.update({"Accept": None}) # type: ignore[arg-type] + response = session.post( mcp_url, headers={ "Content-Type": "application/json", @@ -853,6 +904,29 @@ def test_json_response_incorrect_accept_header(json_response_server: None, json_ assert "Not Acceptable" in response.text +@pytest.mark.parametrize( + "accept_header", + [ + "*/*", + "application/*", + "application/*;q=0.9", + ], +) +def test_json_response_wildcard_accept_header(json_response_server: None, json_server_url: str, accept_header: str): + """Test that json_response servers accept wildcard Accept headers per RFC 7231.""" + mcp_url = f"{json_server_url}/mcp" + response = requests.post( + mcp_url, + headers={ + "Accept": accept_header, + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + assert response.headers.get("Content-Type") == "application/json" + + def test_get_sse_stream(basic_server: None, basic_server_url: str): """Test establishing an SSE stream via GET request.""" # First, we need to initialize a session @@ -941,8 +1015,10 @@ def test_get_validation(basic_server: None, basic_server_url: str): assert init_data is not None negotiated_version = init_data["result"]["protocolVersion"] - # Test without Accept header - response = requests.get( + # Test without Accept header (suppress requests library default Accept: */*) + session = requests.Session() + session.headers.update({"Accept": None}) # type: ignore[arg-type] + response = session.get( mcp_url, headers={ MCP_SESSION_ID_HEADER: session_id,