Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,12 +391,23 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No
await self._handle_unsupported_request(request, send)

def _check_accept_headers(self, request: Request) -> tuple[bool, bool]:
"""Check if the request accepts the required media types."""
"""Check if the request accepts the required media types.

Supports wildcard media types per RFC 7231, section 5.3.2:
- */* matches any media type
- application/* matches any application/ subtype
- text/* matches any text/ subtype
"""
accept_header = request.headers.get("accept", "")
accept_types = [media_type.strip() for media_type in accept_header.split(",")]
accept_types = [media_type.strip().split(";")[0].strip() for media_type in accept_header.split(",")]

has_json = any(media_type.startswith(CONTENT_TYPE_JSON) for media_type in accept_types)
has_sse = any(media_type.startswith(CONTENT_TYPE_SSE) for media_type in accept_types)
has_wildcard = "*/*" in accept_types
has_json = has_wildcard or any(
media_type.startswith(CONTENT_TYPE_JSON) or media_type == "application/*" for media_type in accept_types
)
has_sse = has_wildcard or any(
media_type.startswith(CONTENT_TYPE_SSE) or media_type == "text/*" for media_type in accept_types
)

return has_json, has_sse

Expand Down
86 changes: 81 additions & 5 deletions tests/shared/test_streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,8 +572,10 @@ def json_server_url(json_server_port: int) -> str:
# Basic request validation tests
def test_accept_header_validation(basic_server: None, basic_server_url: str):
"""Test that Accept header is properly validated."""
# Test without Accept header
response = requests.post(
# Test without Accept header (suppress requests library default Accept: */*)
session = requests.Session()
session.headers.update({"Accept": None}) # type: ignore[arg-type]
response = session.post(
f"{basic_server_url}/mcp",
headers={"Content-Type": "application/json"},
json={"jsonrpc": "2.0", "method": "initialize", "id": 1},
Expand All @@ -582,6 +584,52 @@ def test_accept_header_validation(basic_server: None, basic_server_url: str):
assert "Not Acceptable" in response.text


@pytest.mark.parametrize(
"accept_header",
[
"*/*",
"application/*, text/*",
"text/*, application/json",
"application/json, text/*",
"*/*;q=0.8",
"application/*;q=0.9, text/*;q=0.8",
],
)
def test_accept_header_wildcard(basic_server: None, basic_server_url: str, accept_header: str):
"""Test that wildcard Accept headers are accepted per RFC 7231."""
response = requests.post(
f"{basic_server_url}/mcp",
headers={
"Accept": accept_header,
"Content-Type": "application/json",
},
json=INIT_REQUEST,
)
assert response.status_code == 200


@pytest.mark.parametrize(
"accept_header",
[
"text/html",
"application/*",
"text/*",
],
)
def test_accept_header_incompatible(basic_server: None, basic_server_url: str, accept_header: str):
"""Test that incompatible Accept headers are rejected for SSE mode."""
response = requests.post(
f"{basic_server_url}/mcp",
headers={
"Accept": accept_header,
"Content-Type": "application/json",
},
json=INIT_REQUEST,
)
assert response.status_code == 406
assert "Not Acceptable" in response.text


def test_content_type_validation(basic_server: None, basic_server_url: str):
"""Test that Content-Type header is properly validated."""
# Test with incorrect Content-Type
Expand Down Expand Up @@ -826,7 +874,10 @@ def test_json_response_accept_json_only(json_response_server: None, json_server_
def test_json_response_missing_accept_header(json_response_server: None, json_server_url: str):
"""Test that json_response servers reject requests without Accept header."""
mcp_url = f"{json_server_url}/mcp"
response = requests.post(
# Suppress requests library default Accept: */* header
session = requests.Session()
session.headers.update({"Accept": None}) # type: ignore[arg-type]
response = session.post(
mcp_url,
headers={
"Content-Type": "application/json",
Expand All @@ -853,6 +904,29 @@ def test_json_response_incorrect_accept_header(json_response_server: None, json_
assert "Not Acceptable" in response.text


@pytest.mark.parametrize(
"accept_header",
[
"*/*",
"application/*",
"application/*;q=0.9",
],
)
def test_json_response_wildcard_accept_header(json_response_server: None, json_server_url: str, accept_header: str):
"""Test that json_response servers accept wildcard Accept headers per RFC 7231."""
mcp_url = f"{json_server_url}/mcp"
response = requests.post(
mcp_url,
headers={
"Accept": accept_header,
"Content-Type": "application/json",
},
json=INIT_REQUEST,
)
assert response.status_code == 200
assert response.headers.get("Content-Type") == "application/json"


def test_get_sse_stream(basic_server: None, basic_server_url: str):
"""Test establishing an SSE stream via GET request."""
# First, we need to initialize a session
Expand Down Expand Up @@ -941,8 +1015,10 @@ def test_get_validation(basic_server: None, basic_server_url: str):
assert init_data is not None
negotiated_version = init_data["result"]["protocolVersion"]

# Test without Accept header
response = requests.get(
# Test without Accept header (suppress requests library default Accept: */*)
session = requests.Session()
session.headers.update({"Accept": None}) # type: ignore[arg-type]
response = session.get(
mcp_url,
headers={
MCP_SESSION_ID_HEADER: session_id,
Expand Down