Skip to content

Commit ee703ce

Browse files
ishaan-jaffcursoragentgreptile-apps[bot]
authored
fix(jwt): OIDC discovery URLs, roles array handling, dot-notation error hints (#22336)
* fix(jwt): support OIDC discovery URLs, handle roles array, improve error hints Three fixes for Azure AD JWT auth: 1. OIDC discovery URL support - JWT_PUBLIC_KEY_URL can now be set to .well-known/openid-configuration endpoints. The proxy fetches the discovery doc, extracts jwks_uri, and caches it. 2. Handle roles claim as array - when team_id_jwt_field points to a list (e.g. AAD's "roles": ["team1"]), auto-unwrap the first element instead of crashing with 'unhashable type: list'. 3. Better error hint for dot-notation indexing - when team_id_jwt_field is set to "roles.0" or "roles[0]", the 401 error now explains to use "roles" instead and that LiteLLM auto-unwraps lists. * Add integration demo script for JWT auth fixes (OIDC discovery, array roles, dot-notation hints) Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com> * Add demo_servers.py for manual JWT auth testing with mock JWKS/OIDC endpoints Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com> * Add demo screenshots for PR comment Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com> * Add integration test results with screenshots for PR review Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com> * address greptile review feedback (greploop iteration 1) - fix: add HTTP status code check in _resolve_jwks_url before parsing JSON - fix: remove misleading bracket-notation hint from debug log (get_nested_value does not support it) * Update tests/test_litellm/proxy/auth/test_handle_jwt.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> * remove demo scripts and assets --------- Co-authored-by: Cursor Agent <cursoragent@cursor.com> Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com> Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
1 parent 8ab5428 commit ee703ce

File tree

2 files changed

+351
-2
lines changed

2 files changed

+351
-2
lines changed

litellm/proxy/auth/handle_jwt.py

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import fnmatch
1010
import os
11+
import re
1112
from typing import Any, List, Literal, Optional, Set, Tuple, cast
1213

1314
from cryptography import x509
@@ -235,7 +236,17 @@ def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str
235236
return self.litellm_jwtauth.team_id_default
236237
else:
237238
return default_value
238-
# At this point, team_id is not the sentinel, so it should be a string
239+
# AAD and other IdPs often send roles/groups as a list of strings.
240+
# team_id_jwt_field is singular, so take the first element when a list
241+
# is returned. This avoids "unhashable type: 'list'" errors downstream.
242+
if isinstance(team_id, list):
243+
if not team_id:
244+
return default_value
245+
verbose_proxy_logger.debug(
246+
f"JWT Auth: team_id_jwt_field '{self.litellm_jwtauth.team_id_jwt_field}' "
247+
f"returned a list {team_id}; using first element '{team_id[0]}' automatically."
248+
)
249+
team_id = team_id[0]
239250
return team_id # type: ignore[return-value]
240251
elif self.litellm_jwtauth.team_id_default is not None:
241252
team_id = self.litellm_jwtauth.team_id_default
@@ -453,6 +464,52 @@ def get_scopes(self, token: dict) -> List[str]:
453464
scopes = []
454465
return scopes
455466

467+
async def _resolve_jwks_url(self, url: str) -> str:
468+
"""
469+
If url points to an OIDC discovery document (*.well-known/openid-configuration),
470+
fetch it and return the jwks_uri contained within. Otherwise return url unchanged.
471+
This lets JWT_PUBLIC_KEY_URL be set to a well-known discovery endpoint instead of
472+
requiring operators to manually find the JWKS URL.
473+
"""
474+
if ".well-known/openid-configuration" not in url:
475+
return url
476+
477+
cache_key = f"litellm_oidc_discovery_{url}"
478+
cached_jwks_uri = await self.user_api_key_cache.async_get_cache(cache_key)
479+
if cached_jwks_uri is not None:
480+
return cached_jwks_uri
481+
482+
verbose_proxy_logger.debug(
483+
f"JWT Auth: Fetching OIDC discovery document from {url}"
484+
)
485+
response = await self.http_handler.get(url)
486+
if response.status_code != 200:
487+
raise Exception(
488+
f"JWT Auth: OIDC discovery endpoint {url} returned status {response.status_code}: {response.text}"
489+
)
490+
try:
491+
discovery = response.json()
492+
except Exception as e:
493+
raise Exception(
494+
f"JWT Auth: Failed to parse OIDC discovery document at {url}: {e}"
495+
)
496+
497+
jwks_uri = discovery.get("jwks_uri")
498+
if not jwks_uri:
499+
raise Exception(
500+
f"JWT Auth: OIDC discovery document at {url} does not contain a 'jwks_uri' field."
501+
)
502+
503+
verbose_proxy_logger.debug(
504+
f"JWT Auth: Resolved OIDC discovery {url} -> jwks_uri={jwks_uri}"
505+
)
506+
await self.user_api_key_cache.async_set_cache(
507+
key=cache_key,
508+
value=jwks_uri,
509+
ttl=self.litellm_jwtauth.public_key_ttl,
510+
)
511+
return jwks_uri
512+
456513
async def get_public_key(self, kid: Optional[str]) -> dict:
457514
keys_url = os.getenv("JWT_PUBLIC_KEY_URL")
458515

@@ -462,6 +519,7 @@ async def get_public_key(self, kid: Optional[str]) -> dict:
462519
keys_url_list = [url.strip() for url in keys_url.split(",")]
463520

464521
for key_url in keys_url_list:
522+
key_url = await self._resolve_jwks_url(key_url)
465523
cache_key = f"litellm_jwt_auth_keys_{key_url}"
466524

467525
cached_keys = await self.user_api_key_cache.async_get_cache(cache_key)
@@ -913,8 +971,30 @@ async def find_and_validate_specific_team_id(
913971
if jwt_handler.is_required_team_id() is True:
914972
team_id_field = jwt_handler.litellm_jwtauth.team_id_jwt_field
915973
team_alias_field = jwt_handler.litellm_jwtauth.team_alias_jwt_field
974+
hint = ""
975+
if team_id_field:
976+
# "roles.0" — dot-notation numeric indexing is not supported
977+
if "." in team_id_field:
978+
parts = team_id_field.rsplit(".", 1)
979+
if parts[-1].isdigit():
980+
base_field = parts[0]
981+
hint = (
982+
f" Hint: dot-notation array indexing (e.g. '{team_id_field}') is not "
983+
f"supported. Use '{base_field}' instead — LiteLLM automatically "
984+
f"uses the first element when the field value is a list."
985+
)
986+
# "roles[0]" — bracket-notation indexing is also not supported in get_nested_value
987+
elif "[" in team_id_field and team_id_field.endswith("]"):
988+
m = re.match(r"^(\w+)\[(\d+)\]$", team_id_field)
989+
if m:
990+
base_field = m.group(1)
991+
hint = (
992+
f" Hint: array indexing (e.g. '{team_id_field}') is not supported "
993+
f"in team_id_jwt_field. Use '{base_field}' instead — LiteLLM "
994+
f"automatically uses the first element when the field value is a list."
995+
)
916996
raise Exception(
917-
f"No team found in token. Checked team_id field '{team_id_field}' and team_alias field '{team_alias_field}'"
997+
f"No team found in token. Checked team_id field '{team_id_field}' and team_alias field '{team_alias_field}'.{hint}"
918998
)
919999

9201000
return individual_team_id, team_object

tests/test_litellm/proxy/auth/test_handle_jwt.py

Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1485,4 +1485,273 @@ async def test_get_objects_resolves_org_by_name():
14851485
)
14861486

14871487

1488+
# ---------------------------------------------------------------------------
1489+
# Fix 1: OIDC discovery URL resolution
1490+
# ---------------------------------------------------------------------------
1491+
1492+
1493+
@pytest.mark.asyncio
1494+
async def test_resolve_jwks_url_passthrough_for_direct_jwks_url():
1495+
"""Non-discovery URLs are returned unchanged."""
1496+
from unittest.mock import AsyncMock, MagicMock
1497+
1498+
from litellm.caching.dual_cache import DualCache
1499+
1500+
handler = JWTHandler()
1501+
handler.update_environment(
1502+
prisma_client=None,
1503+
user_api_key_cache=DualCache(),
1504+
litellm_jwtauth=LiteLLM_JWTAuth(),
1505+
)
1506+
url = "https://login.microsoftonline.com/common/discovery/keys"
1507+
result = await handler._resolve_jwks_url(url)
1508+
assert result == url
1509+
1510+
1511+
@pytest.mark.asyncio
1512+
async def test_resolve_jwks_url_resolves_oidc_discovery_document():
1513+
"""
1514+
A .well-known/openid-configuration URL should be fetched and its
1515+
jwks_uri returned.
1516+
"""
1517+
from unittest.mock import AsyncMock, MagicMock, patch
1518+
1519+
from litellm.caching.dual_cache import DualCache
1520+
1521+
handler = JWTHandler()
1522+
cache = DualCache()
1523+
handler.update_environment(
1524+
prisma_client=None,
1525+
user_api_key_cache=cache,
1526+
litellm_jwtauth=LiteLLM_JWTAuth(),
1527+
)
1528+
1529+
discovery_url = "https://login.microsoftonline.com/tenant/.well-known/openid-configuration"
1530+
jwks_url = "https://login.microsoftonline.com/tenant/discovery/keys"
1531+
1532+
mock_response = MagicMock()
1533+
mock_response.status_code = 200
1534+
mock_response.json.return_value = {"jwks_uri": jwks_url, "issuer": "https://..."}
1535+
1536+
with patch.object(handler.http_handler, "get", new_callable=AsyncMock, return_value=mock_response) as mock_get:
1537+
result = await handler._resolve_jwks_url(discovery_url)
1538+
1539+
assert result == jwks_url
1540+
mock_get.assert_called_once_with(discovery_url)
1541+
1542+
1543+
@pytest.mark.asyncio
1544+
async def test_resolve_jwks_url_caches_resolved_jwks_uri():
1545+
"""Resolved jwks_uri is cached — second call does not hit the network."""
1546+
from unittest.mock import AsyncMock, MagicMock, patch
1547+
1548+
from litellm.caching.dual_cache import DualCache
1549+
1550+
handler = JWTHandler()
1551+
cache = DualCache()
1552+
handler.update_environment(
1553+
prisma_client=None,
1554+
user_api_key_cache=cache,
1555+
litellm_jwtauth=LiteLLM_JWTAuth(),
1556+
)
1557+
1558+
discovery_url = "https://login.microsoftonline.com/tenant/.well-known/openid-configuration"
1559+
jwks_url = "https://login.microsoftonline.com/tenant/discovery/keys"
1560+
1561+
mock_response = MagicMock()
1562+
mock_response.json.return_value = {"jwks_uri": jwks_url}
1563+
1564+
with patch.object(handler.http_handler, "get", new_callable=AsyncMock, return_value=mock_response) as mock_get:
1565+
first = await handler._resolve_jwks_url(discovery_url)
1566+
second = await handler._resolve_jwks_url(discovery_url)
1567+
1568+
assert first == jwks_url
1569+
assert second == jwks_url
1570+
# Network should only be hit once
1571+
assert mock_get.call_count == 1
1572+
1573+
1574+
@pytest.mark.asyncio
1575+
async def test_resolve_jwks_url_raises_if_no_jwks_uri_in_discovery_doc():
1576+
"""Raise a helpful error if the discovery document has no jwks_uri."""
1577+
from unittest.mock import AsyncMock, MagicMock, patch
1578+
1579+
from litellm.caching.dual_cache import DualCache
1580+
1581+
handler = JWTHandler()
1582+
handler.update_environment(
1583+
prisma_client=None,
1584+
user_api_key_cache=DualCache(),
1585+
litellm_jwtauth=LiteLLM_JWTAuth(),
1586+
)
1587+
1588+
discovery_url = "https://example.com/.well-known/openid-configuration"
1589+
mock_response = MagicMock()
1590+
mock_response.json.return_value = {"issuer": "https://example.com"} # no jwks_uri
1591+
1592+
with patch.object(handler.http_handler, "get", new_callable=AsyncMock, return_value=mock_response):
1593+
with pytest.raises(Exception, match="jwks_uri"):
1594+
await handler._resolve_jwks_url(discovery_url)
1595+
1596+
1597+
# ---------------------------------------------------------------------------
1598+
# Fix 2: handle array values in team_id_jwt_field (e.g. AAD "roles" claim)
1599+
# ---------------------------------------------------------------------------
1600+
1601+
1602+
def _make_jwt_handler(team_id_jwt_field: str) -> JWTHandler:
1603+
from litellm.caching.dual_cache import DualCache
1604+
1605+
handler = JWTHandler()
1606+
handler.update_environment(
1607+
prisma_client=None,
1608+
user_api_key_cache=DualCache(),
1609+
litellm_jwtauth=LiteLLM_JWTAuth(team_id_jwt_field=team_id_jwt_field),
1610+
)
1611+
return handler
1612+
1613+
1614+
def test_get_team_id_returns_first_element_when_roles_is_list():
1615+
"""
1616+
AAD sends roles as a list. get_team_id() must return the first string
1617+
element rather than the raw list (which would later crash with
1618+
'unhashable type: list').
1619+
"""
1620+
handler = _make_jwt_handler("roles")
1621+
token = {"oid": "user-oid", "roles": ["team1"]}
1622+
result = handler.get_team_id(token=token, default_value=None)
1623+
assert result == "team1"
1624+
1625+
1626+
def test_get_team_id_returns_first_element_from_multi_value_roles_list():
1627+
"""When roles has multiple entries, the first one is used."""
1628+
handler = _make_jwt_handler("roles")
1629+
token = {"roles": ["team2", "team1"]}
1630+
result = handler.get_team_id(token=token, default_value=None)
1631+
assert result == "team2"
1632+
1633+
1634+
def test_get_team_id_returns_default_when_roles_list_is_empty():
1635+
"""Empty list should fall back to default_value."""
1636+
handler = _make_jwt_handler("roles")
1637+
token = {"roles": []}
1638+
result = handler.get_team_id(token=token, default_value="fallback")
1639+
assert result == "fallback"
1640+
1641+
1642+
def test_get_team_id_still_works_with_string_value():
1643+
"""String values (non-array) continue to work as before."""
1644+
handler = _make_jwt_handler("appid")
1645+
token = {"appid": "my-team-id"}
1646+
result = handler.get_team_id(token=token, default_value=None)
1647+
assert result == "my-team-id"
1648+
1649+
1650+
def test_get_team_id_list_result_is_hashable():
1651+
"""
1652+
The value returned by get_team_id() must be hashable so it can be
1653+
added to a set (the operation that previously crashed).
1654+
"""
1655+
handler = _make_jwt_handler("roles")
1656+
token = {"roles": ["team1"]}
1657+
result = handler.get_team_id(token=token, default_value=None)
1658+
# This must not raise TypeError
1659+
s: set = set()
1660+
s.add(result)
1661+
assert "team1" in s
1662+
1663+
1664+
# ---------------------------------------------------------------------------
1665+
# Fix 3: helpful error message for dot-notation array indexing (roles.0)
1666+
# ---------------------------------------------------------------------------
1667+
1668+
1669+
@pytest.mark.asyncio
1670+
async def test_find_and_validate_specific_team_id_hints_bracket_notation():
1671+
"""
1672+
When team_id_jwt_field is set to 'roles.0' (unsupported dot-notation for
1673+
array indexing) and no team is found, the exception message should suggest
1674+
using 'roles' instead (and explain LiteLLM auto-unwraps list values).
1675+
"""
1676+
from unittest.mock import MagicMock
1677+
1678+
from litellm.caching.dual_cache import DualCache
1679+
1680+
handler = _make_jwt_handler("roles.0")
1681+
# token has roles as a list — dot-notation won't find anything
1682+
token = {"roles": ["team1"]}
1683+
1684+
with pytest.raises(Exception) as exc_info:
1685+
await JWTAuthManager.find_and_validate_specific_team_id(
1686+
jwt_handler=handler,
1687+
jwt_valid_token=token,
1688+
prisma_client=None,
1689+
user_api_key_cache=DualCache(),
1690+
parent_otel_span=None,
1691+
proxy_logging_obj=MagicMock(),
1692+
)
1693+
1694+
error_msg = str(exc_info.value)
1695+
# Should mention the bad field name and suggest the fix
1696+
assert "roles.0" in error_msg, f"Expected field name in: {error_msg}"
1697+
assert "roles" in error_msg and "list" in error_msg, (
1698+
f"Expected hint about using 'roles' instead: {error_msg}"
1699+
)
1700+
1701+
1702+
@pytest.mark.asyncio
1703+
async def test_find_and_validate_specific_team_id_hints_bracket_index_notation():
1704+
"""
1705+
When team_id_jwt_field is set to 'roles[0]' (bracket indexing, also unsupported
1706+
in get_nested_value) the error message should suggest using 'roles' instead.
1707+
"""
1708+
from unittest.mock import MagicMock
1709+
1710+
from litellm.caching.dual_cache import DualCache
1711+
1712+
handler = _make_jwt_handler("roles[0]")
1713+
token = {"roles": ["team1"]}
1714+
1715+
with pytest.raises(Exception) as exc_info:
1716+
await JWTAuthManager.find_and_validate_specific_team_id(
1717+
jwt_handler=handler,
1718+
jwt_valid_token=token,
1719+
prisma_client=None,
1720+
user_api_key_cache=DualCache(),
1721+
parent_otel_span=None,
1722+
proxy_logging_obj=MagicMock(),
1723+
)
1724+
1725+
error_msg = str(exc_info.value)
1726+
assert "roles[0]" in error_msg, f"Expected field name in: {error_msg}"
1727+
assert "roles" in error_msg and "list" in error_msg, (
1728+
f"Expected hint about using 'roles' instead: {error_msg}"
1729+
)
1730+
1731+
1732+
@pytest.mark.asyncio
1733+
async def test_find_and_validate_specific_team_id_no_hint_for_valid_field():
1734+
"""
1735+
When team_id_jwt_field is a normal field name (no dot-notation) the
1736+
error message should not contain a spurious bracket-notation hint.
1737+
"""
1738+
from unittest.mock import AsyncMock, MagicMock
1739+
1740+
from litellm.caching.dual_cache import DualCache
1741+
1742+
handler = _make_jwt_handler("appid")
1743+
token = {} # no appid — triggers the "no team found" path
1744+
1745+
with pytest.raises(Exception) as exc_info:
1746+
await JWTAuthManager.find_and_validate_specific_team_id(
1747+
jwt_handler=handler,
1748+
jwt_valid_token=token,
1749+
prisma_client=None,
1750+
user_api_key_cache=DualCache(),
1751+
parent_otel_span=None,
1752+
proxy_logging_obj=MagicMock(),
1753+
)
1754+
1755+
error_msg = str(exc_info.value)
1756+
assert "Hint" not in error_msg
14881757

0 commit comments

Comments
 (0)