|
3 | 3 | import pytest |
4 | 4 | from pydantic import BaseModel |
5 | 5 |
|
6 | | -from agents import Agent, ModelBehaviorError, Runner, UserError |
| 6 | +from agents import ( |
| 7 | + Agent, |
| 8 | + ModelBehaviorError, |
| 9 | + RunContextWrapper, |
| 10 | + Runner, |
| 11 | + UserError, |
| 12 | + default_tool_error_function, |
| 13 | +) |
| 14 | +from agents.exceptions import AgentsException |
7 | 15 |
|
8 | 16 | from ..fake_model import FakeModel |
9 | 17 | from ..test_responses import get_function_tool_call, get_text_message |
@@ -195,3 +203,60 @@ async def test_runner_calls_mcp_tool_with_args(streaming: bool): |
195 | 203 | assert server.tool_results == [f"result_test_tool_2_{json_args}"] |
196 | 204 |
|
197 | 205 | await server.cleanup() |
| 206 | + |
| 207 | + |
class CrashingFakeMCPServer(FakeMCPServer):
    """Fake MCP server whose ``call_tool`` always raises.

    Used to exercise the Runner's error path when an MCP tool invocation
    fails unexpectedly.
    """

    async def call_tool(
        self,
        tool_name: str,
        arguments: dict[str, object] | None,
        meta: dict[str, object] | None = None,
    ):
        # Simulate an unexpected server-side failure for any tool call.
        raise Exception("Crash!")
| 217 | + |
@pytest.mark.asyncio
@pytest.mark.parametrize("streaming", [False, True])
async def test_runner_emits_mcp_error_tool_call_output_item(streaming: bool):
    """Runner should emit tool_call_output_item with failure output when MCP tool raises."""
    # A server whose call_tool always raises, wired to a single crashing tool.
    server = CrashingFakeMCPServer()
    server.add_tool("crashing_tool", {})

    fake_model = FakeModel()
    agent = Agent(
        name="test",
        model=fake_model,
        mcp_servers=[server],
    )

    # Turn 1: the model calls the crashing tool. Turn 2: it finishes with "done".
    fake_model.add_multiple_turn_outputs(
        [
            [get_text_message("a_message"), get_function_tool_call("crashing_tool", "{}")],
            [get_text_message("done")],
        ]
    )

    if streaming:
        result = Runner.run_streamed(agent, input="user_message")
        # Drain the stream so the run completes before inspecting new_items.
        async for _ in result.stream_events():
            pass
    else:
        result = await Runner.run(agent, input="user_message")

    # The tool failure must not abort the run; the model still finishes.
    assert result.final_output == "done"

    tool_output_items = [
        item for item in result.new_items if item.type == "tool_call_output_item"
    ]
    assert tool_output_items, "Expected tool_call_output_item for MCP failure"

    # The emitted output must match what the default error function produces
    # for the wrapped MCP invocation error.
    expected_error_message = default_tool_error_function(
        RunContextWrapper(context=None),
        AgentsException(
            "Error invoking MCP tool crashing_tool on server 'fake_mcp_server': Crash!"
        ),
    )
    assert tool_output_items[0].output == expected_error_message
0 commit comments