Skip to content

Commit 98bc247

Browse files
authored
Merge pull request #22143 from shivaaang/fix/llm-client-cache-unawaited-coroutine
fix(caching): store task references in LLMClientCache._remove_key
2 parents 3e60ca3 + fb72979 commit 98bc247

File tree

2 files changed

+184
-0
lines changed

2 files changed

+184
-0
lines changed

litellm/caching/llm_caching_handler.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,37 @@
33
"""
44

55
import asyncio
6+
from typing import Set
67

78
from .in_memory_cache import InMemoryCache
89

910

1011
class LLMClientCache(InMemoryCache):
12+
# Background tasks must be stored to prevent garbage collection, which would
13+
# trigger "coroutine was never awaited" warnings. See:
14+
# https://docs.python.org/3/library/asyncio-task.html#creating-tasks
15+
# Intentionally shared across all instances as a global task registry.
16+
_background_tasks: Set[asyncio.Task] = set()
17+
18+
def _remove_key(self, key: str) -> None:
19+
"""Close async clients before evicting them to prevent connection pool leaks."""
20+
value = self.cache_dict.get(key)
21+
super()._remove_key(key)
22+
if value is not None:
23+
close_fn = getattr(value, "aclose", None) or getattr(value, "close", None)
24+
if close_fn and asyncio.iscoroutinefunction(close_fn):
25+
try:
26+
task = asyncio.get_running_loop().create_task(close_fn())
27+
self._background_tasks.add(task)
28+
task.add_done_callback(self._background_tasks.discard)
29+
except RuntimeError:
30+
pass
31+
elif close_fn and callable(close_fn):
32+
try:
33+
close_fn()
34+
except Exception:
35+
pass
36+
1137
def update_cache_key_with_event_loop(self, key):
1238
"""
1339
Add the event loop to the cache key, to prevent event loop closed errors.
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
import asyncio
2+
import os
3+
import sys
4+
import warnings
5+
6+
import pytest
7+
8+
sys.path.insert(
9+
0, os.path.abspath("../../..")
10+
) # Adds the parent directory to the system path
11+
12+
from litellm.caching.llm_caching_handler import LLMClientCache
13+
14+
15+
class MockAsyncClient:
16+
"""Mock async HTTP client with an async close method."""
17+
18+
def __init__(self):
19+
self.closed = False
20+
21+
async def close(self):
22+
self.closed = True
23+
24+
25+
class MockSyncClient:
26+
"""Mock sync HTTP client with a sync close method."""
27+
28+
def __init__(self):
29+
self.closed = False
30+
31+
def close(self):
32+
self.closed = True
33+
34+
35+
@pytest.mark.asyncio
36+
async def test_remove_key_no_unawaited_coroutine_warning():
37+
"""
38+
Test that evicting an async client from LLMClientCache does not produce
39+
'coroutine was never awaited' warnings.
40+
41+
Regression test for https://github.com/BerriAI/litellm/issues/22128
42+
"""
43+
cache = LLMClientCache(max_size_in_memory=2)
44+
45+
mock_client = MockAsyncClient()
46+
cache.cache_dict["test-key"] = mock_client
47+
cache.ttl_dict["test-key"] = 0 # expired
48+
49+
with warnings.catch_warnings(record=True) as caught_warnings:
50+
warnings.simplefilter("always")
51+
cache._remove_key("test-key")
52+
# Let the event loop process the close task
53+
await asyncio.sleep(0.1)
54+
55+
coroutine_warnings = [
56+
w for w in caught_warnings if "coroutine" in str(w.message).lower()
57+
]
58+
assert (
59+
len(coroutine_warnings) == 0
60+
), f"Got unawaited coroutine warnings: {coroutine_warnings}"
61+
62+
63+
@pytest.mark.asyncio
64+
async def test_remove_key_closes_async_client():
65+
"""
66+
Test that evicting an async client from the cache properly closes it.
67+
"""
68+
cache = LLMClientCache(max_size_in_memory=2)
69+
70+
mock_client = MockAsyncClient()
71+
cache.cache_dict["test-key"] = mock_client
72+
cache.ttl_dict["test-key"] = 0
73+
74+
cache._remove_key("test-key")
75+
# Let the event loop process the close task
76+
await asyncio.sleep(0.1)
77+
78+
assert mock_client.closed is True
79+
assert "test-key" not in cache.cache_dict
80+
assert "test-key" not in cache.ttl_dict
81+
82+
83+
def test_remove_key_closes_sync_client():
84+
"""
85+
Test that evicting a sync client from the cache properly closes it.
86+
"""
87+
cache = LLMClientCache(max_size_in_memory=2)
88+
89+
mock_client = MockSyncClient()
90+
cache.cache_dict["test-key"] = mock_client
91+
cache.ttl_dict["test-key"] = 0
92+
93+
cache._remove_key("test-key")
94+
95+
assert mock_client.closed is True
96+
assert "test-key" not in cache.cache_dict
97+
98+
99+
@pytest.mark.asyncio
100+
async def test_eviction_closes_async_clients():
101+
"""
102+
Test that cache eviction (when cache is full) properly closes async clients
103+
without producing warnings.
104+
"""
105+
cache = LLMClientCache(max_size_in_memory=2, default_ttl=1)
106+
107+
clients = []
108+
for i in range(2):
109+
client = MockAsyncClient()
110+
clients.append(client)
111+
cache.set_cache(f"key-{i}", client)
112+
113+
with warnings.catch_warnings(record=True) as caught_warnings:
114+
warnings.simplefilter("always")
115+
# This should trigger eviction of one of the existing entries
116+
cache.set_cache("key-new", "new-value")
117+
await asyncio.sleep(0.1)
118+
119+
coroutine_warnings = [
120+
w for w in caught_warnings if "coroutine" in str(w.message).lower()
121+
]
122+
assert (
123+
len(coroutine_warnings) == 0
124+
), f"Got unawaited coroutine warnings: {coroutine_warnings}"
125+
126+
127+
def test_remove_key_no_event_loop():
128+
"""
129+
Test that _remove_key doesn't raise when there's no running event loop
130+
(falls through to the RuntimeError except branch).
131+
"""
132+
cache = LLMClientCache(max_size_in_memory=2)
133+
134+
mock_client = MockAsyncClient()
135+
cache.cache_dict["test-key"] = mock_client
136+
cache.ttl_dict["test-key"] = 0
137+
138+
# Should not raise even though there's no running event loop
139+
cache._remove_key("test-key")
140+
assert "test-key" not in cache.cache_dict
141+
142+
143+
@pytest.mark.asyncio
144+
async def test_background_tasks_cleaned_up_after_completion():
145+
"""
146+
Test that completed close tasks are removed from the _background_tasks set.
147+
"""
148+
cache = LLMClientCache(max_size_in_memory=2)
149+
150+
mock_client = MockAsyncClient()
151+
cache.cache_dict["test-key"] = mock_client
152+
cache.ttl_dict["test-key"] = 0
153+
154+
cache._remove_key("test-key")
155+
# Let the task complete
156+
await asyncio.sleep(0.1)
157+
158+
assert len(cache._background_tasks) == 0

0 commit comments

Comments
 (0)