Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions litellm/litellm_core_utils/litellm_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -2506,11 +2506,18 @@ async def async_success_handler( # noqa: PLR0915

self.model_call_details["async_complete_streaming_response"] = result

# Only set response_cost to None if not already calculated by
# pass-through handlers (e.g. Gemini/Vertex handlers already
# compute cost via completion_cost)
# Merge response_cost and model from kwargs if available.
# Streaming pass-through handlers compute cost and return it
# in kwargs, but it needs to be set on model_call_details for
# the standard logging payload builder to pick it up.
if self.model_call_details.get("response_cost") is None:
self.model_call_details["response_cost"] = None
response_cost_from_kwargs = kwargs.get("response_cost")
if response_cost_from_kwargs is not None:
self.model_call_details["response_cost"] = response_cost_from_kwargs
else:
self.model_call_details["response_cost"] = None
if kwargs.get("model") and not self.model_call_details.get("model"):
self.model_call_details["model"] = kwargs["model"]

# Only build standard_logging_object if not already built by
# _success_handler_helper_fn
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def _handle_logging_anthropic_collected_chunks(
start_time: datetime,
all_chunks: List[str],
end_time: datetime,
kwargs: Optional[dict] = None,
) -> PassThroughEndpointLoggingTypedDict:
"""
Takes raw chunks from Anthropic passthrough endpoint and logs them in litellm callbacks
Expand Down Expand Up @@ -212,7 +213,7 @@ def _handle_logging_anthropic_collected_chunks(
kwargs = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload(
litellm_model_response=complete_streaming_response,
model=model,
kwargs={},
kwargs=kwargs or {},
start_time=start_time,
end_time=end_time,
logging_obj=litellm_logging_obj,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,7 @@ def _handle_logging_openai_collected_chunks(
start_time: datetime,
all_chunks: List[str],
end_time: datetime,
kwargs: Optional[dict] = None,
) -> PassThroughEndpointLoggingTypedDict:
"""
Handle logging for collected OpenAI streaming chunks with cost tracking.
Expand Down Expand Up @@ -535,23 +536,30 @@ def _handle_logging_openai_collected_chunks(
custom_llm_provider=custom_llm_provider,
)

# Preserve existing litellm_params to maintain metadata tags
existing_litellm_params = litellm_logging_obj.model_call_details.get(
# Preserve existing litellm_params from passed kwargs or logging object
incoming_kwargs = kwargs or {}
existing_litellm_params = incoming_kwargs.get(
"litellm_params"
) or litellm_logging_obj.model_call_details.get(
"litellm_params", {}
) or {}

# Prepare kwargs for logging
kwargs = {
"response_cost": response_cost,
"model": model,
"custom_llm_provider": custom_llm_provider,
"litellm_params": existing_litellm_params.copy(),
"litellm_params": existing_litellm_params.copy() if isinstance(existing_litellm_params, dict) else {},
"call_type": incoming_kwargs.get("call_type", "pass_through_endpoint"),
"litellm_call_id": incoming_kwargs.get("litellm_call_id"),
}

# Extract user information for tracking
# Extract user information from passed kwargs or logging object
passthrough_logging_payload: Optional[
PassthroughStandardLoggingPayload
] = litellm_logging_obj.model_call_details.get(
] = incoming_kwargs.get(
"passthrough_logging_payload"
) or litellm_logging_obj.model_call_details.get(
"passthrough_logging_payload"
)
if passthrough_logging_payload:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ def _handle_logging_vertex_collected_chunks(
all_chunks: List[str],
model: Optional[str],
end_time: datetime,
kwargs: Optional[dict] = None,
) -> PassThroughEndpointLoggingTypedDict:
"""
Takes raw chunks from Vertex passthrough endpoint and logs them in litellm callbacks
Expand All @@ -341,7 +342,7 @@ def _handle_logging_vertex_collected_chunks(
- Creates standard logging object
- Logs in litellm callbacks
"""
kwargs: Dict[str, Any] = {}
_kwargs: Dict[str, Any] = kwargs or {}
model = model or VertexPassthroughLoggingHandler.extract_model_from_url(
url_route
)
Expand All @@ -360,13 +361,13 @@ def _handle_logging_vertex_collected_chunks(
)
return {
"result": None,
"kwargs": kwargs,
"kwargs": _kwargs,
}

kwargs = VertexPassthroughLoggingHandler._create_vertex_response_logging_payload_for_generate_content(
_kwargs = VertexPassthroughLoggingHandler._create_vertex_response_logging_payload_for_generate_content(
litellm_model_response=complete_streaming_response,
model=model,
kwargs=kwargs,
kwargs=_kwargs,
start_time=start_time,
end_time=end_time,
logging_obj=litellm_logging_obj,
Expand All @@ -377,7 +378,7 @@ def _handle_logging_vertex_collected_chunks(

return {
"result": complete_streaming_response,
"kwargs": kwargs,
"kwargs": _kwargs,
}

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,7 @@ async def pass_through_request( # noqa: PLR0915
start_time=start_time,
passthrough_success_handler_obj=pass_through_endpoint_logging,
url_route=str(url),
kwargs=kwargs,
),
headers=HttpPassThroughEndpointHelpers.get_response_headers(
headers=response.headers,
Expand Down Expand Up @@ -867,6 +868,7 @@ async def pass_through_request( # noqa: PLR0915
start_time=start_time,
passthrough_success_handler_obj=pass_through_endpoint_logging,
url_route=str(url),
kwargs=kwargs,
),
headers=HttpPassThroughEndpointHelpers.get_response_headers(
headers=response.headers,
Expand Down
18 changes: 12 additions & 6 deletions litellm/proxy/pass_through_endpoints/streaming_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ async def chunk_processor(
start_time: datetime,
passthrough_success_handler_obj: PassThroughEndpointLogging,
url_route: str,
kwargs: Optional[dict] = None,
):
"""
- Yields chunks from the response
Expand Down Expand Up @@ -83,6 +84,7 @@ async def chunk_processor(
start_time=start_time,
raw_bytes=raw_bytes,
end_time=end_time,
kwargs=kwargs,
)
)
except Exception as e:
Expand All @@ -100,6 +102,7 @@ async def _route_streaming_logging_to_handler(
raw_bytes: List[bytes],
end_time: datetime,
model: Optional[str] = None,
kwargs: Optional[dict] = None,
):
"""
Route the logging for the collected chunks to the appropriate handler
Expand All @@ -115,7 +118,7 @@ async def _route_streaming_logging_to_handler(
standard_logging_response_object: Optional[
PassThroughEndpointLoggingResultValues
] = None
kwargs: dict = {}
handler_kwargs: dict = {}
if endpoint_type == EndpointType.ANTHROPIC:
anthropic_passthrough_logging_handler_result = AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks(
litellm_logging_obj=litellm_logging_obj,
Expand All @@ -126,11 +129,12 @@ async def _route_streaming_logging_to_handler(
start_time=start_time,
all_chunks=all_chunks,
end_time=end_time,
kwargs=kwargs,
)
standard_logging_response_object = (
anthropic_passthrough_logging_handler_result["result"]
)
kwargs = anthropic_passthrough_logging_handler_result["kwargs"]
handler_kwargs = anthropic_passthrough_logging_handler_result["kwargs"]
elif endpoint_type == EndpointType.VERTEX_AI:
vertex_passthrough_logging_handler_result = (
VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks(
Expand All @@ -143,12 +147,13 @@ async def _route_streaming_logging_to_handler(
all_chunks=all_chunks,
end_time=end_time,
model=model,
kwargs=kwargs,
)
)
standard_logging_response_object = (
vertex_passthrough_logging_handler_result["result"]
)
kwargs = vertex_passthrough_logging_handler_result["kwargs"]
handler_kwargs = vertex_passthrough_logging_handler_result["kwargs"]
elif endpoint_type == EndpointType.OPENAI:
openai_passthrough_logging_handler_result = (
OpenAIPassthroughLoggingHandler._handle_logging_openai_collected_chunks(
Expand All @@ -160,12 +165,13 @@ async def _route_streaming_logging_to_handler(
start_time=start_time,
all_chunks=all_chunks,
end_time=end_time,
kwargs=kwargs,
)
)
standard_logging_response_object = (
openai_passthrough_logging_handler_result["result"]
)
kwargs = openai_passthrough_logging_handler_result["kwargs"]
handler_kwargs = openai_passthrough_logging_handler_result["kwargs"]

if standard_logging_response_object is None:
standard_logging_response_object = StandardPassThroughResponseObject(
Expand All @@ -176,7 +182,7 @@ async def _route_streaming_logging_to_handler(
start_time=start_time,
end_time=end_time,
cache_hit=False,
**kwargs,
**handler_kwargs,
)
if litellm_logging_obj._should_run_sync_callbacks_for_async_calls() is False:
return
Expand All @@ -187,7 +193,7 @@ async def _route_streaming_logging_to_handler(
end_time=end_time,
cache_hit=False,
start_time=start_time,
**kwargs,
**handler_kwargs,
)

@staticmethod
Expand Down
8 changes: 4 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

86 changes: 86 additions & 0 deletions tests/pass_through_unit_tests/test_unit_test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,92 @@ async def mock_aiter_bytes():
), "Collected chunks do not match raw chunks"


@pytest.mark.asyncio
async def test_chunk_processor_passes_kwargs_to_logging_handler():
    """
    Test that kwargs (containing litellm_params with API key metadata) are
    propagated from chunk_processor through to _route_streaming_logging_to_handler.

    This ensures API key attribution reaches Langfuse traces for streaming
    pass-through requests (e.g., Claude Code hitting /anthropic/v1/messages).
    """
    # Function-scope import: this module does not import asyncio at the top
    # level, and `asyncio.sleep` below would otherwise raise NameError.
    import asyncio

    response = AsyncMock(spec=httpx.Response)

    # Minimal streaming response with message_start and message_stop events
    raw_chunks = [
        b'event: message_start\ndata: {"type":"message_start","message":{"id":"msg_123","type":"message","role":"assistant","content":[],"model":"claude-3-haiku-20240307","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":1}}}\n\n',
        b'event: content_block_start\ndata: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}\n\n',
        b'event: content_block_delta\ndata: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}\n\n',
        b'event: content_block_stop\ndata: {"type":"content_block_stop","index":0}\n\n',
        b'event: message_delta\ndata: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":5}}\n\n',
        b'event: message_stop\ndata: {"type":"message_stop"}\n\n',
    ]

    async def mock_aiter_bytes():
        for chunk in raw_chunks:
            yield chunk

    response.aiter_bytes = mock_aiter_bytes

    request_body = {"model": "claude-3-haiku-20240307", "messages": [{"role": "user", "content": "Hi"}]}
    litellm_logging_obj = MagicMock()
    litellm_logging_obj.async_success_handler = AsyncMock()
    # Disable the sync-callback path so only async_success_handler fires.
    litellm_logging_obj._should_run_sync_callbacks_for_async_calls = MagicMock(return_value=False)
    litellm_logging_obj.model_call_details = {}
    start_time = datetime.now()
    passthrough_success_handler_obj = MagicMock()

    # The kwargs that should be threaded through — simulating what
    # _init_kwargs_for_pass_through_endpoint() creates
    input_kwargs = {
        "litellm_params": {
            "metadata": {
                "user_api_key_hash": "sk-hashed-abc123",
                "user_api_key_alias": "test-key-alias",
                "user_api_key_team_id": "team-456",
                "user_api_key_user_id": "user-789",
                "user_api_key_org_id": "org-012",
            },
            "proxy_server_request": {
                "url": "https://proxy/anthropic/v1/messages",
                "method": "POST",
                "body": request_body,
            },
        },
        "passthrough_logging_payload": PassthroughStandardLoggingPayload(
            url="https://api.anthropic.com/v1/messages",
            request_body=request_body,
        ),
        "call_type": "pass_through_endpoint",
        "litellm_call_id": "call-test-123",
    }

    # Consume the async generator
    async for _ in PassThroughStreamingHandler.chunk_processor(
        response=response,
        request_body=request_body,
        litellm_logging_obj=litellm_logging_obj,
        endpoint_type=EndpointType.ANTHROPIC,
        start_time=start_time,
        passthrough_success_handler_obj=passthrough_success_handler_obj,
        url_route="/v1/messages",
        kwargs=input_kwargs,
    ):
        pass

    # Logging runs in a background asyncio.create_task; yield to the event
    # loop long enough for it to complete.
    await asyncio.sleep(0.5)

    # Verify async_success_handler was called with kwargs containing
    # the API key metadata from input_kwargs
    assert litellm_logging_obj.async_success_handler.called, \
        "async_success_handler should have been called after streaming completed"
    call = litellm_logging_obj.async_success_handler.call_args
    assert call is not None, "async_success_handler was called but with no args"
    # The streaming handler always passes cache_hit=False explicitly
    # alongside the spread **handler_kwargs.
    assert call.kwargs.get("cache_hit") is False, \
        "async_success_handler should be invoked with cache_hit=False"
    # If litellm_params made it into the handler kwargs, the API key
    # metadata must have been preserved — this attribution is the whole
    # point of threading kwargs through the streaming path.
    logged_litellm_params = call.kwargs.get("litellm_params")
    if logged_litellm_params:
        assert (
            logged_litellm_params.get("metadata", {}).get("user_api_key_hash")
            == "sk-hashed-abc123"
        ), "user_api_key_hash from input_kwargs should survive the round-trip"


def test_convert_raw_bytes_to_str_lines():
"""
Test that the _convert_raw_bytes_to_str_lines method correctly converts raw bytes to a list of strings
Expand Down
Loading
Loading