@@ -1485,4 +1485,273 @@ async def test_get_objects_resolves_org_by_name():
14851485 )
14861486
14871487
1488+ # ---------------------------------------------------------------------------
1489+ # Fix 1: OIDC discovery URL resolution
1490+ # ---------------------------------------------------------------------------
1491+
1492+
1493+ @pytest .mark .asyncio
1494+ async def test_resolve_jwks_url_passthrough_for_direct_jwks_url ():
1495+ """Non-discovery URLs are returned unchanged."""
1496+ from unittest .mock import AsyncMock , MagicMock
1497+
1498+ from litellm .caching .dual_cache import DualCache
1499+
1500+ handler = JWTHandler ()
1501+ handler .update_environment (
1502+ prisma_client = None ,
1503+ user_api_key_cache = DualCache (),
1504+ litellm_jwtauth = LiteLLM_JWTAuth (),
1505+ )
1506+ url = "https://login.microsoftonline.com/common/discovery/keys"
1507+ result = await handler ._resolve_jwks_url (url )
1508+ assert result == url
1509+
1510+
1511+ @pytest .mark .asyncio
1512+ async def test_resolve_jwks_url_resolves_oidc_discovery_document ():
1513+ """
1514+ A .well-known/openid-configuration URL should be fetched and its
1515+ jwks_uri returned.
1516+ """
1517+ from unittest .mock import AsyncMock , MagicMock , patch
1518+
1519+ from litellm .caching .dual_cache import DualCache
1520+
1521+ handler = JWTHandler ()
1522+ cache = DualCache ()
1523+ handler .update_environment (
1524+ prisma_client = None ,
1525+ user_api_key_cache = cache ,
1526+ litellm_jwtauth = LiteLLM_JWTAuth (),
1527+ )
1528+
1529+ discovery_url = "https://login.microsoftonline.com/tenant/.well-known/openid-configuration"
1530+ jwks_url = "https://login.microsoftonline.com/tenant/discovery/keys"
1531+
1532+ mock_response = MagicMock ()
1533+ mock_response .status_code = 200
1534+ mock_response .json .return_value = {"jwks_uri" : jwks_url , "issuer" : "https://..." }
1535+
1536+ with patch .object (handler .http_handler , "get" , new_callable = AsyncMock , return_value = mock_response ) as mock_get :
1537+ result = await handler ._resolve_jwks_url (discovery_url )
1538+
1539+ assert result == jwks_url
1540+ mock_get .assert_called_once_with (discovery_url )
1541+
1542+
1543+ @pytest .mark .asyncio
1544+ async def test_resolve_jwks_url_caches_resolved_jwks_uri ():
1545+ """Resolved jwks_uri is cached — second call does not hit the network."""
1546+ from unittest .mock import AsyncMock , MagicMock , patch
1547+
1548+ from litellm .caching .dual_cache import DualCache
1549+
1550+ handler = JWTHandler ()
1551+ cache = DualCache ()
1552+ handler .update_environment (
1553+ prisma_client = None ,
1554+ user_api_key_cache = cache ,
1555+ litellm_jwtauth = LiteLLM_JWTAuth (),
1556+ )
1557+
1558+ discovery_url = "https://login.microsoftonline.com/tenant/.well-known/openid-configuration"
1559+ jwks_url = "https://login.microsoftonline.com/tenant/discovery/keys"
1560+
1561+ mock_response = MagicMock ()
1562+ mock_response .json .return_value = {"jwks_uri" : jwks_url }
1563+
1564+ with patch .object (handler .http_handler , "get" , new_callable = AsyncMock , return_value = mock_response ) as mock_get :
1565+ first = await handler ._resolve_jwks_url (discovery_url )
1566+ second = await handler ._resolve_jwks_url (discovery_url )
1567+
1568+ assert first == jwks_url
1569+ assert second == jwks_url
1570+ # Network should only be hit once
1571+ assert mock_get .call_count == 1
1572+
1573+
1574+ @pytest .mark .asyncio
1575+ async def test_resolve_jwks_url_raises_if_no_jwks_uri_in_discovery_doc ():
1576+ """Raise a helpful error if the discovery document has no jwks_uri."""
1577+ from unittest .mock import AsyncMock , MagicMock , patch
1578+
1579+ from litellm .caching .dual_cache import DualCache
1580+
1581+ handler = JWTHandler ()
1582+ handler .update_environment (
1583+ prisma_client = None ,
1584+ user_api_key_cache = DualCache (),
1585+ litellm_jwtauth = LiteLLM_JWTAuth (),
1586+ )
1587+
1588+ discovery_url = "https://example.com/.well-known/openid-configuration"
1589+ mock_response = MagicMock ()
1590+ mock_response .json .return_value = {"issuer" : "https://example.com" } # no jwks_uri
1591+
1592+ with patch .object (handler .http_handler , "get" , new_callable = AsyncMock , return_value = mock_response ):
1593+ with pytest .raises (Exception , match = "jwks_uri" ):
1594+ await handler ._resolve_jwks_url (discovery_url )
1595+
1596+
1597+ # ---------------------------------------------------------------------------
1598+ # Fix 2: handle array values in team_id_jwt_field (e.g. AAD "roles" claim)
1599+ # ---------------------------------------------------------------------------
1600+
1601+
1602+ def _make_jwt_handler (team_id_jwt_field : str ) -> JWTHandler :
1603+ from litellm .caching .dual_cache import DualCache
1604+
1605+ handler = JWTHandler ()
1606+ handler .update_environment (
1607+ prisma_client = None ,
1608+ user_api_key_cache = DualCache (),
1609+ litellm_jwtauth = LiteLLM_JWTAuth (team_id_jwt_field = team_id_jwt_field ),
1610+ )
1611+ return handler
1612+
1613+
1614+ def test_get_team_id_returns_first_element_when_roles_is_list ():
1615+ """
1616+ AAD sends roles as a list. get_team_id() must return the first string
1617+ element rather than the raw list (which would later crash with
1618+ 'unhashable type: list').
1619+ """
1620+ handler = _make_jwt_handler ("roles" )
1621+ token = {"oid" : "user-oid" , "roles" : ["team1" ]}
1622+ result = handler .get_team_id (token = token , default_value = None )
1623+ assert result == "team1"
1624+
1625+
1626+ def test_get_team_id_returns_first_element_from_multi_value_roles_list ():
1627+ """When roles has multiple entries, the first one is used."""
1628+ handler = _make_jwt_handler ("roles" )
1629+ token = {"roles" : ["team2" , "team1" ]}
1630+ result = handler .get_team_id (token = token , default_value = None )
1631+ assert result == "team2"
1632+
1633+
1634+ def test_get_team_id_returns_default_when_roles_list_is_empty ():
1635+ """Empty list should fall back to default_value."""
1636+ handler = _make_jwt_handler ("roles" )
1637+ token = {"roles" : []}
1638+ result = handler .get_team_id (token = token , default_value = "fallback" )
1639+ assert result == "fallback"
1640+
1641+
1642+ def test_get_team_id_still_works_with_string_value ():
1643+ """String values (non-array) continue to work as before."""
1644+ handler = _make_jwt_handler ("appid" )
1645+ token = {"appid" : "my-team-id" }
1646+ result = handler .get_team_id (token = token , default_value = None )
1647+ assert result == "my-team-id"
1648+
1649+
1650+ def test_get_team_id_list_result_is_hashable ():
1651+ """
1652+ The value returned by get_team_id() must be hashable so it can be
1653+ added to a set (the operation that previously crashed).
1654+ """
1655+ handler = _make_jwt_handler ("roles" )
1656+ token = {"roles" : ["team1" ]}
1657+ result = handler .get_team_id (token = token , default_value = None )
1658+ # This must not raise TypeError
1659+ s : set = set ()
1660+ s .add (result )
1661+ assert "team1" in s
1662+
1663+
1664+ # ---------------------------------------------------------------------------
1665+ # Fix 3: helpful error message for dot-notation array indexing (roles.0)
1666+ # ---------------------------------------------------------------------------
1667+
1668+
1669+ @pytest .mark .asyncio
1670+ async def test_find_and_validate_specific_team_id_hints_bracket_notation ():
1671+ """
1672+ When team_id_jwt_field is set to 'roles.0' (unsupported dot-notation for
1673+ array indexing) and no team is found, the exception message should suggest
1674+ using 'roles' instead (and explain LiteLLM auto-unwraps list values).
1675+ """
1676+ from unittest .mock import MagicMock
1677+
1678+ from litellm .caching .dual_cache import DualCache
1679+
1680+ handler = _make_jwt_handler ("roles.0" )
1681+ # token has roles as a list — dot-notation won't find anything
1682+ token = {"roles" : ["team1" ]}
1683+
1684+ with pytest .raises (Exception ) as exc_info :
1685+ await JWTAuthManager .find_and_validate_specific_team_id (
1686+ jwt_handler = handler ,
1687+ jwt_valid_token = token ,
1688+ prisma_client = None ,
1689+ user_api_key_cache = DualCache (),
1690+ parent_otel_span = None ,
1691+ proxy_logging_obj = MagicMock (),
1692+ )
1693+
1694+ error_msg = str (exc_info .value )
1695+ # Should mention the bad field name and suggest the fix
1696+ assert "roles.0" in error_msg , f"Expected field name in: { error_msg } "
1697+ assert "roles" in error_msg and "list" in error_msg , (
1698+ f"Expected hint about using 'roles' instead: { error_msg } "
1699+ )
1700+
1701+
1702+ @pytest .mark .asyncio
1703+ async def test_find_and_validate_specific_team_id_hints_bracket_index_notation ():
1704+ """
1705+ When team_id_jwt_field is set to 'roles[0]' (bracket indexing, also unsupported
1706+ in get_nested_value) the error message should suggest using 'roles' instead.
1707+ """
1708+ from unittest .mock import MagicMock
1709+
1710+ from litellm .caching .dual_cache import DualCache
1711+
1712+ handler = _make_jwt_handler ("roles[0]" )
1713+ token = {"roles" : ["team1" ]}
1714+
1715+ with pytest .raises (Exception ) as exc_info :
1716+ await JWTAuthManager .find_and_validate_specific_team_id (
1717+ jwt_handler = handler ,
1718+ jwt_valid_token = token ,
1719+ prisma_client = None ,
1720+ user_api_key_cache = DualCache (),
1721+ parent_otel_span = None ,
1722+ proxy_logging_obj = MagicMock (),
1723+ )
1724+
1725+ error_msg = str (exc_info .value )
1726+ assert "roles[0]" in error_msg , f"Expected field name in: { error_msg } "
1727+ assert "roles" in error_msg and "list" in error_msg , (
1728+ f"Expected hint about using 'roles' instead: { error_msg } "
1729+ )
1730+
1731+
1732+ @pytest .mark .asyncio
1733+ async def test_find_and_validate_specific_team_id_no_hint_for_valid_field ():
1734+ """
1735+ When team_id_jwt_field is a normal field name (no dot-notation) the
1736+ error message should not contain a spurious bracket-notation hint.
1737+ """
1738+ from unittest .mock import AsyncMock , MagicMock
1739+
1740+ from litellm .caching .dual_cache import DualCache
1741+
1742+ handler = _make_jwt_handler ("appid" )
1743+ token = {} # no appid — triggers the "no team found" path
1744+
1745+ with pytest .raises (Exception ) as exc_info :
1746+ await JWTAuthManager .find_and_validate_specific_team_id (
1747+ jwt_handler = handler ,
1748+ jwt_valid_token = token ,
1749+ prisma_client = None ,
1750+ user_api_key_cache = DualCache (),
1751+ parent_otel_span = None ,
1752+ proxy_logging_obj = MagicMock (),
1753+ )
1754+
1755+ error_msg = str (exc_info .value )
1756+ assert "Hint" not in error_msg
14881757
0 commit comments