Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions litellm/responses/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,19 @@ async def aresponses_api_with_mcp(
"litellm_metadata", {}
).get("user_api_key_auth")

# Extract MCP auth headers from request (for dynamic auth when fetching tools)
mcp_auth_header: Optional[str] = None
mcp_server_auth_headers: Optional[Dict[str, Dict[str, str]]] = None
secret_fields = kwargs.get("secret_fields")
if secret_fields and isinstance(secret_fields, dict):
from litellm.responses.utils import ResponsesAPIRequestUtils

mcp_auth_header, mcp_server_auth_headers, _, _ = (
ResponsesAPIRequestUtils.extract_mcp_headers_from_request(
secret_fields=secret_fields, tools=tools
)
)

# Get original MCP tools (for events) and OpenAI tools (for LLM) by reusing existing methods
(
original_mcp_tools,
Expand All @@ -185,6 +198,8 @@ async def aresponses_api_with_mcp(
user_api_key_auth=user_api_key_auth,
mcp_tools_with_litellm_proxy=mcp_tools_with_litellm_proxy,
litellm_trace_id=kwargs.get("litellm_trace_id"),
mcp_auth_header=mcp_auth_header,
mcp_server_auth_headers=mcp_server_auth_headers,
)
openai_tools = LiteLLM_Proxy_MCP_Handler._transform_mcp_tools_to_openai(
original_mcp_tools
Expand Down Expand Up @@ -370,6 +385,8 @@ async def aresponses_api_with_mcp(
) = await LiteLLM_Proxy_MCP_Handler._process_mcp_tools_without_openai_transform(
user_api_key_auth=user_api_key_auth,
mcp_tools_with_litellm_proxy=mcp_tools_with_litellm_proxy,
mcp_auth_header=mcp_auth_header,
mcp_server_auth_headers=mcp_server_auth_headers,
)
final_response = (
LiteLLM_Proxy_MCP_Handler._add_mcp_output_elements_to_response(
Expand Down
26 changes: 14 additions & 12 deletions litellm/responses/mcp/chat_completions_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,27 @@ async def acompletion_with_mcp( # noqa: PLR0915
(kwargs.get("metadata", {}) or {}).get("user_api_key_auth")
)

# Process MCP tools
# Extract MCP auth headers before fetching tools (needed for dynamic auth)
(
mcp_auth_header,
mcp_server_auth_headers,
oauth2_headers,
raw_headers,
) = ResponsesAPIRequestUtils.extract_mcp_headers_from_request(
secret_fields=kwargs.get("secret_fields"),
tools=tools,
)

# Process MCP tools (pass auth headers for dynamic auth)
(
deduplicated_mcp_tools,
tool_server_map,
) = await LiteLLM_Proxy_MCP_Handler._process_mcp_tools_without_openai_transform(
user_api_key_auth=user_api_key_auth,
mcp_tools_with_litellm_proxy=mcp_tools_with_litellm_proxy,
litellm_trace_id=kwargs.get("litellm_trace_id"),
mcp_auth_header=mcp_auth_header,
mcp_server_auth_headers=mcp_server_auth_headers,
)

openai_tools = LiteLLM_Proxy_MCP_Handler._transform_mcp_tools_to_openai(
Expand All @@ -143,17 +156,6 @@ async def acompletion_with_mcp( # noqa: PLR0915
mcp_tools_with_litellm_proxy=mcp_tools_with_litellm_proxy
)

# Extract MCP auth headers
(
mcp_auth_header,
mcp_server_auth_headers,
oauth2_headers,
raw_headers,
) = ResponsesAPIRequestUtils.extract_mcp_headers_from_request(
secret_fields=kwargs.get("secret_fields"),
tools=tools,
)

# Prepare call parameters
# Remove keys that shouldn't be passed to acompletion
clean_kwargs = {k: v for k, v in kwargs.items() if k not in ["acompletion"]}
Expand Down
23 changes: 17 additions & 6 deletions litellm/responses/mcp/litellm_proxy_mcp_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,17 @@ async def _get_mcp_tools_from_manager(
user_api_key_auth: Any,
mcp_tools_with_litellm_proxy: Optional[Iterable[ToolParam]],
litellm_trace_id: Optional[str] = None,
mcp_auth_header: Optional[str] = None,
mcp_server_auth_headers: Optional[Dict[str, Dict[str, str]]] = None,
) -> tuple[List[MCPTool], List[str]]:
"""
Get available tools from the MCP server manager.

Args:
user_api_key_auth: User authentication info for access control
mcp_tools_with_litellm_proxy: ToolParam objects with server_url starting with "litellm_proxy"
mcp_auth_header: Optional deprecated auth header for MCP servers
mcp_server_auth_headers: Optional server-specific auth headers (e.g. from x-mcp-{alias}-*)

Returns:
List of MCP tools
Expand All @@ -126,20 +130,21 @@ async def _get_mcp_tools_from_manager(
server_url = (
_tool.get("server_url", "") if isinstance(_tool, dict) else ""
)
if isinstance(server_url, str) and server_url.startswith(
LITELLM_PROXY_MCP_SERVER_URL_PREFIX
):
mcp_servers.append(server_url.split("/")[-1])
if isinstance(server_url, str) and server_url.startswith(
LITELLM_PROXY_MCP_SERVER_URL_PREFIX
):
mcp_servers.append(server_url.split("/")[-1])

tools = await _get_tools_from_mcp_servers(
user_api_key_auth=user_api_key_auth,
mcp_auth_header=None,
mcp_auth_header=mcp_auth_header,
mcp_servers=mcp_servers,
mcp_server_auth_headers=None,
mcp_server_auth_headers=mcp_server_auth_headers,
log_list_tools_to_spendlogs=True,
list_tools_log_source="responses",
litellm_trace_id=litellm_trace_id,
)

allowed_mcp_server_ids = (
await global_mcp_server_manager.get_allowed_mcp_servers(user_api_key_auth)
)
Expand Down Expand Up @@ -278,6 +283,8 @@ async def _process_mcp_tools_without_openai_transform(
user_api_key_auth: Any,
mcp_tools_with_litellm_proxy: List[ToolParam],
litellm_trace_id: Optional[str] = None,
mcp_auth_header: Optional[str] = None,
mcp_server_auth_headers: Optional[Dict[str, Dict[str, str]]] = None,
) -> tuple[List[Any], dict[str, str]]:
"""
Process MCP tools through filtering and deduplication pipeline without OpenAI transformation.
Expand All @@ -286,6 +293,8 @@ async def _process_mcp_tools_without_openai_transform(
Args:
user_api_key_auth: User authentication info for access control
mcp_tools_with_litellm_proxy: ToolParam objects with server_url starting with "litellm_proxy"
mcp_auth_header: Optional deprecated auth header for MCP servers
mcp_server_auth_headers: Optional server-specific auth headers (e.g. from x-mcp-{alias}-*)

Returns:
List of filtered and deduplicated MCP tools in their original format
Expand All @@ -301,6 +310,8 @@ async def _process_mcp_tools_without_openai_transform(
user_api_key_auth=user_api_key_auth,
mcp_tools_with_litellm_proxy=mcp_tools_with_litellm_proxy,
litellm_trace_id=litellm_trace_id,
mcp_auth_header=mcp_auth_header,
mcp_server_auth_headers=mcp_server_auth_headers,
)

# Step 2: Filter tools based on allowed_tools parameter
Expand Down
71 changes: 71 additions & 0 deletions tests/mcp_tests/test_aresponses_api_with_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
import pytest
from typing import List, Any, cast
from unittest.mock import AsyncMock, patch

sys.path.insert(0, os.path.abspath("../../.."))

Expand Down Expand Up @@ -254,6 +255,76 @@ async def test_aresponses_api_with_mcp_mock_integration():
print(f"Other tools parsed: {len(other_parsed)}")


@pytest.mark.asyncio
async def test_aresponses_api_with_mcp_passes_mcp_server_auth_headers_to_process_tools():
"""
Test that MCP auth headers from secret_fields (e.g. x-mcp-linear_config-authorization)
are passed to _process_mcp_tools_without_openai_transform when using the responses API.
"""
from litellm.responses.main import aresponses_api_with_mcp

captured_process_kwargs = {}

async def mock_process(**kwargs):
captured_process_kwargs.update(kwargs)
return ([], {})

mock_response = ResponsesAPIResponse(
**{
"id": "resp_test",
"object": "response",
"created_at": 1234567890,
"status": "completed",
"error": None,
"incomplete_details": None,
"instructions": None,
"max_output_tokens": None,
"model": "gpt-4o",
"output": [{"type": "message", "id": "msg_1", "status": "completed", "role": "assistant", "content": []}],
"parallel_tool_calls": True,
"previous_response_id": None,
"reasoning": {"effort": None, "summary": None},
"store": True,
"temperature": 1.0,
"text": {"format": {"type": "text"}},
"tool_choice": "auto",
"tools": [],
"top_p": 1.0,
"truncation": "disabled",
"usage": {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2},
"user": None,
"metadata": {},
}
)

mcp_tools = [{"type": "mcp", "server_url": "litellm_proxy"}]
secret_fields = {
"raw_headers": {"x-mcp-linear_config-authorization": "Bearer linear-token"},
}

with patch.object(
LiteLLM_Proxy_MCP_Handler,
"_process_mcp_tools_without_openai_transform",
mock_process,
), patch(
"litellm.responses.main.aresponses",
new_callable=AsyncMock,
return_value=mock_response,
):
await aresponses_api_with_mcp(
input=[{"role": "user", "type": "message", "content": "hi"}],
model="gpt-4o",
tools=mcp_tools,
secret_fields=secret_fields,
)

assert "mcp_server_auth_headers" in captured_process_kwargs
mcp_server_auth_headers = captured_process_kwargs["mcp_server_auth_headers"]
assert mcp_server_auth_headers is not None
assert "linear_config" in mcp_server_auth_headers
assert mcp_server_auth_headers["linear_config"]["Authorization"] == "Bearer linear-token"


@pytest.mark.asyncio
async def test_mcp_allowed_tools_filtering():
"""
Expand Down
66 changes: 66 additions & 0 deletions tests/test_litellm/responses/mcp/test_chat_completions_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,72 @@ def mock_extract(**kwargs):
assert captured_secret_fields["value"] == {"api_key": "value"}


@pytest.mark.asyncio
async def test_acompletion_with_mcp_passes_mcp_server_auth_headers_to_process_tools(
monkeypatch,
):
"""
Test that MCP auth headers extracted from secret_fields (e.g. x-mcp-linear_config-authorization)
are passed to _process_mcp_tools_without_openai_transform for dynamic auth when fetching tools.
"""
tools = [{"type": "mcp", "server_url": "litellm_proxy"}]
mock_acompletion = AsyncMock(return_value="ok")

captured_process_kwargs = {}

async def mock_process(**kwargs):
captured_process_kwargs.update(kwargs)
return ([], {})

monkeypatch.setattr(
LiteLLM_Proxy_MCP_Handler,
"_should_use_litellm_mcp_gateway",
staticmethod(lambda t: True),
)
monkeypatch.setattr(
LiteLLM_Proxy_MCP_Handler,
"_parse_mcp_tools",
staticmethod(lambda t: (t, [])),
)
monkeypatch.setattr(
LiteLLM_Proxy_MCP_Handler,
"_process_mcp_tools_without_openai_transform",
mock_process,
)
monkeypatch.setattr(
LiteLLM_Proxy_MCP_Handler,
"_transform_mcp_tools_to_openai",
staticmethod(lambda *_, **__: ["openai-tool"]),
)
monkeypatch.setattr(
LiteLLM_Proxy_MCP_Handler,
"_should_auto_execute_tools",
staticmethod(lambda **_: False),
)

# secret_fields with raw_headers containing MCP auth - extract_mcp_headers_from_request
# will parse these and pass to _process_mcp_tools_without_openai_transform
secret_fields = {
"raw_headers": {
"x-mcp-linear_config-authorization": "Bearer linear-token",
},
}

with patch("litellm.acompletion", mock_acompletion):
await acompletion_with_mcp(
model="test-model",
messages=[],
tools=tools,
secret_fields=secret_fields,
)

assert "mcp_server_auth_headers" in captured_process_kwargs
mcp_server_auth_headers = captured_process_kwargs["mcp_server_auth_headers"]
assert mcp_server_auth_headers is not None
assert "linear_config" in mcp_server_auth_headers
assert mcp_server_auth_headers["linear_config"]["Authorization"] == "Bearer linear-token"


@pytest.mark.asyncio
async def test_acompletion_with_mcp_auto_exec_performs_follow_up(monkeypatch):
from litellm.utils import CustomStreamWrapper
Expand Down
Loading