Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -649,8 +649,13 @@ async def _get_allowed_mcp_servers_for_key(
)
)

# Combine both lists
all_servers = direct_mcp_servers + access_group_servers
# servers referenced in tool permissions should also be accessible
tool_perm_servers = list(
(key_object_permission.mcp_tool_permissions or {}).keys()
)

# Combine all lists
all_servers = direct_mcp_servers + access_group_servers + tool_perm_servers
return list(set(all_servers))
except Exception as e:
verbose_logger.warning(
Expand Down Expand Up @@ -686,8 +691,13 @@ async def _get_allowed_mcp_servers_for_team(
)
)

# Combine both lists
all_servers = direct_mcp_servers + access_group_servers
# servers referenced in tool permissions should also be accessible
tool_perm_servers = list(
(object_permissions.mcp_tool_permissions or {}).keys()
)

# Combine all lists
all_servers = direct_mcp_servers + access_group_servers + tool_perm_servers
return list(set(all_servers))
except Exception as e:
verbose_logger.warning(
Expand Down Expand Up @@ -737,17 +747,20 @@ async def _get_allowed_mcp_servers_for_end_user(
# Get direct MCP servers
direct_mcp_servers = end_user_obj.object_permission.mcp_servers or []



# Get MCP servers from access groups
access_group_servers = (
await MCPRequestHandler._get_mcp_servers_from_access_groups(
end_user_obj.object_permission.mcp_access_groups or []
)
)

# Combine both lists
all_servers = direct_mcp_servers + access_group_servers
# servers referenced in tool permissions should also be accessible
tool_perm_servers = list(
(end_user_obj.object_permission.mcp_tool_permissions or {}).keys()
)

# Combine all lists
all_servers = direct_mcp_servers + access_group_servers + tool_perm_servers
return list(set(all_servers))
except Exception as e:
verbose_logger.warning(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1738,3 +1738,34 @@ async def test_get_allowed_tools_for_server_agent_no_restriction(self):
user_api_key_auth=user_api_key_auth,
)
assert sorted(result) == ["tool_a", "tool_b"]


@pytest.mark.asyncio
async def test_tool_permission_servers_included_in_allowed_servers():
"""
Servers listed only in mcp_tool_permissions (not in mcp_servers)
should still be accessible.

Regression test for https://github.com/BerriAI/litellm/issues/21954
"""
perm = MagicMock()
perm.mcp_servers = []
perm.mcp_access_groups = []
perm.mcp_tool_permissions = {"server_id_123": ["tool_a", "tool_b"]}

user_api_key_auth = UserAPIKeyAuth(
api_key="test-key",
user_id="test-user",
)

with patch.object(
MCPRequestHandler, "_get_key_object_permission", return_value=perm
), patch.object(
MCPRequestHandler, "_get_mcp_servers_from_access_groups",
new_callable=AsyncMock,
return_value=[],
):
result = await MCPRequestHandler._get_allowed_mcp_servers_for_key(
user_api_key_auth=user_api_key_auth,
)
assert "server_id_123" in result
Loading