Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions litellm/llms/openai/chat/gpt_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@
Support for gpt model family
"""

import hashlib
import re
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Coroutine,
Dict,
Iterator,
List,
Literal,
Expand All @@ -20,6 +23,10 @@
import httpx

import litellm

# Pre-compiled pattern for validating OpenAI-compatible tool call IDs.
# Strict providers (e.g. Mistral) reject IDs that don't match this.
_VALID_TOOL_CALL_ID_RE = re.compile(r"^[a-zA-Z0-9]{1,64}$")
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import (
_extract_reasoning_content,
Expand Down Expand Up @@ -431,6 +438,7 @@ def transform_request(
dict: The transformed request. Sent as the body of the API call.
"""
messages = self._transform_messages(messages=messages, model=model)
messages = self._normalize_tool_call_ids(messages)
messages, tools = self.remove_cache_control_flag_from_messages_and_tools(
model=model, messages=messages, tools=optional_params.get("tools", [])
)
Expand All @@ -456,6 +464,7 @@ async def async_transform_request(
transformed_messages = await self._transform_messages(
messages=messages, model=model, is_async=True
)
transformed_messages = self._normalize_tool_call_ids(transformed_messages)
transformed_messages, tools = (
self.remove_cache_control_flag_from_messages_and_tools(
model=model,
Expand All @@ -477,6 +486,74 @@ async def async_transform_request(
model, messages, optional_params, litellm_params, headers
)

def _normalize_tool_call_ids(
self, messages: List[AllMessageValues]
) -> List[AllMessageValues]:
"""
Normalize non-compliant tool call IDs in message history to satisfy strict
OpenAI-compatible providers (e.g. Mistral) that require IDs matching
^[a-zA-Z0-9]{1,64}$.

Some providers (e.g. MiniMax via OpenRouter) return IDs like
``call_function_jlv0n7uyomle_1`` which contain underscores and are longer
than 9 characters. When those IDs are forwarded in a subsequent request to
a strict provider the request is rejected with a 400 BadRequestError.

This method scans the message list once, collects every non-compliant ID,
builds a stable deterministic remapping using an MD5 hash (truncated to 9
characters), and then rewrites:

* ``tool_calls[].id`` on assistant messages
* ``tool_call_id`` on tool/function result messages

IDs that already satisfy ``^[a-zA-Z0-9]{1,64}$`` are left unchanged so
there is zero overhead for well-behaved providers.
"""
def _remap_id(original: str, mapping: Dict[str, str]) -> str:
if original not in mapping:
digest = hashlib.md5(original.encode()).hexdigest()[:9]
mapping[original] = digest
return mapping[original]

# First pass: collect all non-compliant IDs that appear in assistant messages.
id_mapping: Dict[str, str] = {}
for message in messages:
role = message.get("role")
if role == "assistant":
tool_calls = message.get("tool_calls") # type: ignore[union-attr]
if tool_calls:
for tc in tool_calls:
tc_id = (
tc.get("id") if isinstance(tc, dict) else getattr(tc, "id", None)
)
if tc_id and not _VALID_TOOL_CALL_ID_RE.match(tc_id):
_remap_id(tc_id, id_mapping)

if not id_mapping:
return messages

# Second pass: rewrite IDs in-place.
for message in messages:
role = message.get("role")
if role == "assistant":
tool_calls = message.get("tool_calls") # type: ignore[union-attr]
if tool_calls:
for tc in tool_calls:
if isinstance(tc, dict):
tc_id = tc.get("id")
if tc_id and tc_id in id_mapping:
tc["id"] = id_mapping[tc_id]
else:
tc_id = getattr(tc, "id", None)
if tc_id and tc_id in id_mapping:
tc.id = id_mapping[tc_id] # type: ignore[union-attr]
elif role in ("tool", "function"):
tc_id = message.get("tool_call_id") # type: ignore[union-attr]
if tc_id and tc_id in id_mapping:
message["tool_call_id"] = id_mapping[tc_id] # type: ignore[index]

return messages

def _passed_in_tools(self, optional_params: dict) -> bool:
return optional_params.get("tools", None) is not None

Expand Down
131 changes: 131 additions & 0 deletions tests/test_litellm/llms/openai/chat/test_openai_gpt_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,3 +324,134 @@ def test_prompt_cache_params_passed_through(self):
)
assert optional_params.get("prompt_cache_key") == "my-cache-key"
assert optional_params.get("prompt_cache_retention") == "24h"


class TestNormalizeToolCallIds:
    """
    Tests for OpenAIGPTConfig._normalize_tool_call_ids().

    Regression tests for: https://github.com/BerriAI/litellm/issues/22317
    Some OpenAI-compatible providers (e.g. MiniMax via OpenRouter) return tool
    call IDs like ``call_function_jlv0n7uyomle_1`` which contain underscores
    and therefore do not match the ``^[a-zA-Z0-9]{1,64}$`` format that strict
    providers (e.g. Mistral) require. When those IDs are forwarded in a
    subsequent request to a strict provider the request is rejected with a
    400 BadRequestError.
    """

    def setup_method(self):
        self.config = OpenAIGPTConfig()

    @staticmethod
    def _conversation(tool_call_id: str) -> list:
        """Build a minimal assistant/tool exchange using *tool_call_id*."""
        return [
            {
                "role": "assistant",
                "content": None,
                "tool_calls": [
                    {
                        "id": tool_call_id,
                        "type": "function",
                        "function": {"name": "f", "arguments": "{}"},
                    }
                ],
            },
            {"role": "tool", "tool_call_id": tool_call_id, "content": "result"},
        ]

    def test_compliant_ids_unchanged(self):
        """IDs that already satisfy ^[a-zA-Z0-9]{1,64}$ must not be modified."""
        messages = self._conversation("abc123xyz")
        result = self.config._normalize_tool_call_ids(messages)
        assert result[0]["tool_calls"][0]["id"] == "abc123xyz"
        assert result[1]["tool_call_id"] == "abc123xyz"

    def test_non_compliant_id_is_remapped(self):
        """IDs with underscores/hyphens or wrong length must be remapped."""
        import re

        bad_id = "call_function_jlv0n7uyomle_1"
        result = self.config._normalize_tool_call_ids(self._conversation(bad_id))
        new_id = result[0]["tool_calls"][0]["id"]
        assert new_id != bad_id
        # Must satisfy the strict format
        assert re.match(r"^[a-zA-Z0-9]{1,64}$", new_id)
        # Assistant and tool messages must use the same new ID
        assert result[1]["tool_call_id"] == new_id

    def test_remapping_is_deterministic(self):
        """The same original ID must always produce the same normalised ID."""
        bad_id = "call_function_jlv0n7uyomle_1"
        id_first = self.config._normalize_tool_call_ids(
            self._conversation(bad_id)
        )[0]["tool_calls"][0]["id"]
        id_second = self.config._normalize_tool_call_ids(
            self._conversation(bad_id)
        )[0]["tool_calls"][0]["id"]
        assert id_first == id_second

    def test_multiple_tool_calls_each_remapped_consistently(self):
        """Each distinct non-compliant ID gets its own deterministic mapping."""
        bad_id_1 = "call_function_aaaa_1"
        bad_id_2 = "call_function_bbbb_2"
        messages = [
            {
                "role": "assistant",
                "content": None,
                "tool_calls": [
                    {
                        "id": bad_id_1,
                        "type": "function",
                        "function": {"name": "f1", "arguments": "{}"},
                    },
                    {
                        "id": bad_id_2,
                        "type": "function",
                        "function": {"name": "f2", "arguments": "{}"},
                    },
                ],
            },
            {"role": "tool", "tool_call_id": bad_id_1, "content": "result1"},
            {"role": "tool", "tool_call_id": bad_id_2, "content": "result2"},
        ]
        result = self.config._normalize_tool_call_ids(messages)
        new_id_1 = result[0]["tool_calls"][0]["id"]
        new_id_2 = result[0]["tool_calls"][1]["id"]
        assert new_id_1 != bad_id_1
        assert new_id_2 != bad_id_2
        assert new_id_1 != new_id_2
        assert result[1]["tool_call_id"] == new_id_1
        assert result[2]["tool_call_id"] == new_id_2

    def test_messages_without_tool_calls_unaffected(self):
        """Plain user/assistant messages must pass through unchanged."""
        messages = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "world"},
        ]
        result = self.config._normalize_tool_call_ids(messages)
        assert result == messages

    def test_transform_request_normalizes_ids(self):
        """transform_request must normalise non-compliant IDs end-to-end."""
        import re

        bad_id = "call_function_jlv0n7uyomle_1"
        messages = [
            {
                "role": "assistant",
                "content": None,
                "tool_calls": [
                    {
                        "id": bad_id,
                        "type": "function",
                        "function": {"name": "get_weather", "arguments": "{}"},
                    }
                ],
            },
            {"role": "tool", "tool_call_id": bad_id, "content": "sunny"},
        ]
        result = self.config.transform_request(
            model="gpt-4o",
            messages=messages,
            optional_params={},
            litellm_params={},
            headers={},
        )
        out_messages = result["messages"]
        new_id = out_messages[0]["tool_calls"][0]["id"]
        assert re.match(r"^[a-zA-Z0-9]{1,64}$", new_id)
        assert out_messages[1]["tool_call_id"] == new_id
Loading