Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions litellm/integrations/SlackAlerting/budget_alert_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,14 @@ def get_id(self, user_info: CallInfo) -> str:
return user_info.token or "default_id"


class EndUserBudgetAlert(BaseBudgetAlertType):
def get_event_message(self) -> str:
return "Customer Budget: "

def get_id(self, user_info: CallInfo) -> str:
return user_info.customer_id or "default_id"


def get_budget_alert_type(
type: Literal[
"token_budget",
Expand All @@ -93,6 +101,7 @@ def get_budget_alert_type(
"proxy_budget",
"projected_limit_exceeded",
"project_budget",
"end_user_budget",
],
) -> BaseBudgetAlertType:
"""Factory function to get the appropriate budget alert type class"""
Expand All @@ -107,6 +116,7 @@ def get_budget_alert_type(
"token_budget": TokenBudgetAlert(),
"projected_limit_exceeded": ProjectedLimitExceededAlert(),
"project_budget": ProjectBudgetAlert(),
"end_user_budget": EndUserBudgetAlert(),
}

if type in alert_types:
Expand Down
1 change: 1 addition & 0 deletions litellm/integrations/SlackAlerting/slack_alerting.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,7 @@ async def budget_alerts(
"proxy_budget",
"projected_limit_exceeded",
"project_budget",
"end_user_budget",
],
user_info: CallInfo,
):
Expand Down
79 changes: 59 additions & 20 deletions litellm/proxy/auth/auth_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,17 +358,13 @@ async def common_checks(
)

# 5. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
if (
end_user_object is not None
and end_user_object.litellm_budget_table is not None
):
end_user_budget = end_user_object.litellm_budget_table.max_budget
if end_user_budget is not None and end_user_object.spend > end_user_budget:
raise litellm.BudgetExceededError(
current_cost=end_user_object.spend,
max_budget=end_user_budget,
message=f"ExceededBudget: End User={end_user_object.user_id} over budget. Spend={end_user_object.spend}, Budget={end_user_budget}",
)
if end_user_object is not None:
await _check_end_user_budget(
end_user_obj=end_user_object,
route=route,
valid_token=valid_token,
proxy_logging_obj=proxy_logging_obj,
)

# 6. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
if (
Expand Down Expand Up @@ -772,34 +768,83 @@ async def _apply_default_budget_to_end_user(
return end_user_obj


def _check_end_user_budget(
async def _check_end_user_budget(
end_user_obj: LiteLLM_EndUserTable,
route: str,
valid_token: Optional[UserAPIKeyAuth] = None,
proxy_logging_obj: Optional[ProxyLogging] = None,
) -> None:
"""
Check if end user is within their budget limit.

Fires a budget alert (via Slack/webhook) when the max or soft budget is
crossed, following the same pattern used by ``_team_max_budget_check`` and
``_team_soft_budget_check``.

Args:
end_user_obj: The end user object to check
route: The request route
valid_token: Optional token for the current request
proxy_logging_obj: Optional proxy logging object for budget alerts

Raises:
litellm.BudgetExceededError: If end user has exceeded their budget
litellm.BudgetExceededError: If end user has exceeded their max budget
"""
if route in LiteLLMRoutes.info_routes.value:
return

if end_user_obj.litellm_budget_table is None:
return

end_user_budget = end_user_obj.litellm_budget_table.max_budget
budget_table = end_user_obj.litellm_budget_table
end_user_budget = budget_table.max_budget
end_user_soft_budget = budget_table.soft_budget

# Max budget check — alert + block
if end_user_budget is not None and end_user_obj.spend > end_user_budget:
if proxy_logging_obj is not None and valid_token is not None:
call_info = CallInfo(
token=valid_token.token,
spend=end_user_obj.spend,
max_budget=end_user_budget,
soft_budget=end_user_soft_budget,
customer_id=end_user_obj.user_id,
event_group=Litellm_EntityType.END_USER,
)
asyncio.create_task(
proxy_logging_obj.budget_alerts(
type="end_user_budget",
user_info=call_info,
)
)
raise litellm.BudgetExceededError(
current_cost=end_user_obj.spend,
max_budget=end_user_budget,
message=f"ExceededBudget: End User={end_user_obj.user_id} over budget. Spend={end_user_obj.spend}, Budget={end_user_budget}",
)

# Soft budget check — alert only, does not block
if (
end_user_soft_budget is not None
and end_user_obj.spend >= end_user_soft_budget
and proxy_logging_obj is not None
and valid_token is not None
):
call_info = CallInfo(
token=valid_token.token,
spend=end_user_obj.spend,
max_budget=end_user_budget,
soft_budget=end_user_soft_budget,
customer_id=end_user_obj.user_id,
event_group=Litellm_EntityType.END_USER,
)
asyncio.create_task(
proxy_logging_obj.budget_alerts(
type="end_user_budget",
user_info=call_info,
)
)


@log_db_metrics
async def get_end_user_object(
Expand Down Expand Up @@ -848,9 +893,6 @@ async def get_end_user_object(
parent_otel_span=parent_otel_span,
)

# Check budget limits
_check_end_user_budget(end_user_obj=return_obj, route=route)

return return_obj

# Fetch from database
Expand Down Expand Up @@ -879,9 +921,6 @@ async def get_end_user_object(
key="end_user_id:{}".format(end_user_id), value=_response.dict()
)

# Check budget limits
_check_end_user_budget(end_user_obj=_response, route=route)

return _response

except Exception as e:
Expand Down
1 change: 1 addition & 0 deletions litellm/proxy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1476,6 +1476,7 @@ async def budget_alerts(
"proxy_budget",
"projected_limit_exceeded",
"project_budget",
"end_user_budget",
],
user_info: CallInfo,
):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from litellm.integrations.SlackAlerting.budget_alert_types import SoftBudgetAlert
from litellm.integrations.SlackAlerting.budget_alert_types import (
EndUserBudgetAlert,
SoftBudgetAlert,
)
from litellm.proxy._types import CallInfo, Litellm_EntityType


Expand Down Expand Up @@ -36,6 +39,46 @@ def test_get_id_with_empty_token(self):
token="",
event_group=Litellm_EntityType.KEY,
)


result = alert.get_id(user_info)
assert result == "default_id"


class TestEndUserBudgetAlert:
def test_get_event_message(self):
"""Test that get_event_message returns the correct customer budget message"""
alert = EndUserBudgetAlert()
assert alert.get_event_message() == "Customer Budget: "

def test_get_id_with_customer_id(self):
"""Test that get_id returns user_info.customer_id when customer_id is provided"""
alert = EndUserBudgetAlert()
user_info = CallInfo(
spend=50.0,
customer_id="customer_123",
event_group=Litellm_EntityType.END_USER,
)
result = alert.get_id(user_info)
assert result == "customer_123"

def test_get_id_without_customer_id(self):
"""Test that get_id returns 'default_id' when customer_id is None"""
alert = EndUserBudgetAlert()
user_info = CallInfo(
spend=50.0,
customer_id=None,
event_group=Litellm_EntityType.END_USER,
)
result = alert.get_id(user_info)
assert result == "default_id"

def test_get_id_with_empty_customer_id(self):
"""Test that get_id returns 'default_id' when customer_id is empty string"""
alert = EndUserBudgetAlert()
user_info = CallInfo(
spend=50.0,
customer_id="",
event_group=Litellm_EntityType.END_USER,
)
result = alert.get_id(user_info)
assert result == "default_id"
Loading
Loading