From 7876b5e821b7c4fc931d0aa9274b87feff16ca68 Mon Sep 17 00:00:00 2001 From: Hiroshi Nishio Date: Thu, 16 Apr 2026 20:17:24 -0700 Subject: [PATCH] Add cost cap to stop agent loop at 80% of revenue; refactor LLM returns to LlmResult dataclass MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add COST_CAP_RATIO constant and get_total_cost_for_pr to query cumulative LLM cost across all PR invocations - Integrate cost cap check into should_bail (silent bail — no customer-facing comment or Slack) - Replace 4-element tuples from chat_with_claude/chat_with_google with shared LlmResult dataclass - Rename get_credit_cost → get_credit_price to distinguish revenue (customer price) from LLM cost - Fix pre-commit hook: remove git add after formatting to preserve partial staging --- constants/agent.py | 2 + scripts/git/pre_commit_hook.sh | 9 +- services/chat_with_agent.py | 38 ++- services/claude/chat_with_claude.py | 30 +- services/claude/test_chat_with_claude.py | 25 +- services/google_ai/chat_with_google.py | 23 +- services/google_ai/test_chat_with_google.py | 74 +++-- services/llm_result.py | 19 ++ ...get_credit_cost.py => get_credit_price.py} | 2 +- services/supabase/credits/insert_credit.py | 4 +- .../supabase/credits/test_get_credit_cost.py | 24 -- .../supabase/credits/test_get_credit_price.py | 24 ++ .../supabase/credits/test_insert_credit.py | 8 +- .../llm_requests/get_total_cost_for_pr.py | 44 +++ .../llm_requests/insert_llm_request.py | 3 +- .../test_get_total_cost_for_pr.py | 74 +++++ .../llm_requests/test_insert_llm_request.py | 29 +- services/test_chat_with_agent.py | 301 ++++++++++++------ services/test_chat_with_model.py | 21 +- services/test_llm_result.py | 41 +++ services/webhook/check_suite_handler.py | 6 +- services/webhook/new_pr_handler.py | 7 +- services/webhook/review_run_handler.py | 6 +- services/webhook/test_check_suite_handler.py | 10 + services/webhook/test_new_pr_handler.py | 34 ++ 
services/webhook/test_review_run_handler.py | 11 +- services/webhook/test_setup_handler.py | 1 + services/webhook/test_webhook_handler.py | 30 +- services/webhook/utils/should_bail.py | 14 + services/webhook/utils/test_should_bail.py | 112 +++++++ 30 files changed, 760 insertions(+), 266 deletions(-) create mode 100644 services/llm_result.py rename services/supabase/credits/{get_credit_cost.py => get_credit_price.py} (89%) delete mode 100644 services/supabase/credits/test_get_credit_cost.py create mode 100644 services/supabase/credits/test_get_credit_price.py create mode 100644 services/supabase/llm_requests/get_total_cost_for_pr.py create mode 100644 services/supabase/llm_requests/test_get_total_cost_for_pr.py create mode 100644 services/test_llm_result.py diff --git a/constants/agent.py b/constants/agent.py index 01b6c5ee6..5fbf375ef 100644 --- a/constants/agent.py +++ b/constants/agent.py @@ -1,2 +1,4 @@ +# Stop agent loop when LLM cost reaches this fraction of revenue (credit_cost_usd) +COST_CAP_RATIO = 0.8 MAX_ITERATIONS = 30 MAX_PLANNING_ITERATIONS = 20 diff --git a/scripts/git/pre_commit_hook.sh b/scripts/git/pre_commit_hook.sh index 824db2481..b96652e4d 100755 --- a/scripts/git/pre_commit_hook.sh +++ b/scripts/git/pre_commit_hook.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Git pre-commit hook. No stashing - runs on the working directory as-is. +# Git pre-commit hook. Does NOT auto-stage — user handles staging. 
# Install: ln -sf ../../scripts/git/pre_commit_hook.sh .git/hooks/pre-commit set -uo pipefail @@ -35,17 +35,12 @@ python3 schemas/supabase/generate_types.py && git add schemas/supabase/ # Get staged Python files (excluding deleted, .venv, schemas) STAGED_PY_FILES=$(git diff --cached --name-only --diff-filter=d -- '*.py' | grep -v '^\.\?venv/' | grep -v '^schemas/') -# Format and auto-fix staged Python files +# Format staged Python files (user re-stages if needed) if [ -n "$STAGED_PY_FILES" ]; then # shellcheck disable=SC2086 black $STAGED_PY_FILES - # shellcheck disable=SC2086 - git add $STAGED_PY_FILES - # shellcheck disable=SC2086 ruff check --fix $STAGED_PY_FILES - # shellcheck disable=SC2086 - git add $STAGED_PY_FILES fi # Markdownlint for staged .md files diff --git a/services/chat_with_agent.py b/services/chat_with_agent.py index 4934f3cf3..56a74218a 100644 --- a/services/chat_with_agent.py +++ b/services/chat_with_agent.py @@ -32,10 +32,11 @@ @dataclass -class AgentResult: +class AgentResult: # pylint: disable=too-many-instance-attributes messages: list[MessageParam] token_input: int token_output: int + cost_usd: float is_completed: bool completion_reason: str p: int @@ -67,12 +68,7 @@ async def chat_with_agent( logger.info("Using model: %s", current_model) try: - ( - response_message, - tool_calls, - token_input, - token_output, - ) = chat_with_model( + llm_result = chat_with_model( messages=messages, system_content=system_message, tools=tools, @@ -125,13 +121,14 @@ async def chat_with_agent( ) # Return if no tool calls (agent returned text without calling a tool) - if not tool_calls: - logger.info("No tools were called. Response: %s", response_message) - messages.append(response_message) + if not llm_result.tool_calls: + logger.info("No tools were called. 
Response: %s", llm_result.assistant_message) + messages.append(llm_result.assistant_message) return AgentResult( messages=messages, - token_input=token_input, - token_output=token_output, + token_input=llm_result.token_input, + token_output=llm_result.token_output, + cost_usd=llm_result.cost_usd, is_completed=False, completion_reason="", p=p, @@ -139,10 +136,10 @@ async def chat_with_agent( ) # Append assistant message before processing tool calls - messages.append(response_message) + messages.append(llm_result.assistant_message) # Extract text from the assistant message for completion context - content = response_message["content"] + content = llm_result.assistant_message["content"] if isinstance(content, str): assistant_text = content else: @@ -156,11 +153,11 @@ async def chat_with_agent( tool_result_blocks: list[ToolResultBlockParam] = [] log_msgs: list[str] = [] is_completed = False - num_tool_calls = len(tool_calls) + num_tool_calls = len(llm_result.tool_calls) logger.info("Processing %d tool call(s)", num_tool_calls) # pylint: disable-next=too-many-nested-blocks - for i, tc in enumerate(tool_calls, start=1): + for i, tc in enumerate(llm_result.tool_calls, start=1): tool_use_id = tc.id tool_name = tc.name tool_args = tc.args @@ -505,7 +502,7 @@ async def chat_with_agent( if log_msgs: update_comment( body=create_progress_bar( - p=p + 5 * len(tool_calls), msg="\n".join(log_messages) + p=p + 5 * len(llm_result.tool_calls), msg="\n".join(log_messages) ), base_args=base_args, ) @@ -520,10 +517,11 @@ async def chat_with_agent( return AgentResult( messages=messages, - token_input=token_input, - token_output=token_output, + token_input=llm_result.token_input, + token_output=llm_result.token_output, + cost_usd=llm_result.cost_usd, is_completed=is_completed, completion_reason=assistant_text, - p=p + 5 * len(tool_calls), + p=p + 5 * len(llm_result.tool_calls), is_planned=False, ) diff --git a/services/claude/chat_with_claude.py b/services/claude/chat_with_claude.py index 
343ba276c..d204b6b67 100644 --- a/services/claude/chat_with_claude.py +++ b/services/claude/chat_with_claude.py @@ -1,17 +1,12 @@ -# Standard imports -from dataclasses import dataclass import time -# Third party imports from anthropic import AuthenticationError from anthropic._exceptions import OverloadedError from anthropic.types import MessageParam, ToolUnionParam, ToolUseBlock -# Local imports from constants.claude import CONTEXT_WINDOW, MAX_OUTPUT_TOKENS from constants.models import ClaudeModelId from services.claude.client import claude -from services.claude.strip_strict_from_tools import strip_strict_from_tools from services.claude.exceptions import ( ClaudeAuthenticationError, ClaudeOverloadedError, @@ -19,18 +14,13 @@ from services.claude.remove_outdated_file_edit_attempts import ( remove_outdated_file_edit_attempts, ) +from services.claude.strip_strict_from_tools import strip_strict_from_tools from services.claude.trim_messages import trim_messages_to_token_limit +from services.llm_result import LlmResult, ToolCall from services.supabase.llm_requests.insert_llm_request import insert_llm_request from utils.error.handle_exceptions import handle_exceptions -@dataclass -class ToolCall: - id: str - name: str - args: dict | None - - @handle_exceptions(raise_on_error=True) def chat_with_claude( messages: list[MessageParam], @@ -126,7 +116,7 @@ def chat_with_claude( # Combine system message with user messages for logging system_msg: MessageParam = {"role": "user", "content": system_content} full_messages = [system_msg, *messages] - insert_llm_request( + llm_record = insert_llm_request( usage_id=usage_id, provider="claude", model_id=model_id, @@ -137,10 +127,12 @@ def chat_with_claude( response_time_ms=response_time_ms, created_by=created_by, ) - - return ( - assistant_message, - tool_calls, - token_input, - token_output, + cost_usd = llm_record["total_cost_usd"] if llm_record else 0.0 + + return LlmResult( + assistant_message=assistant_message, + 
tool_calls=tool_calls, + token_input=token_input, + token_output=token_output, + cost_usd=cost_usd, ) diff --git a/services/claude/test_chat_with_claude.py b/services/claude/test_chat_with_claude.py index 3425838fc..8caf65b8f 100644 --- a/services/claude/test_chat_with_claude.py +++ b/services/claude/test_chat_with_claude.py @@ -13,6 +13,7 @@ def test_chat_with_claude_success(mock_claude, mock_insert_llm_request): mock_response = Mock() mock_response.content = [Mock(type="text", text="Hello! How can I help you?")] mock_response.usage = Mock(output_tokens=15) + mock_insert_llm_request.return_value = {"total_cost_usd": 0.05} mock_claude.messages.create.return_value = mock_response mock_claude.messages.count_tokens.return_value = Mock(input_tokens=20) @@ -31,13 +32,13 @@ def test_chat_with_claude_success(mock_claude, mock_insert_llm_request): created_by="4:test-user", ) - assistant_message, tool_calls, token_input, token_output = result - assert assistant_message["role"] == "assistant" - content = cast(list, assistant_message["content"]) + assert result.assistant_message["role"] == "assistant" + content = cast(list, result.assistant_message["content"]) assert content[0]["text"] == "Hello! How can I help you?" 
- assert not tool_calls - assert token_input == 20 - assert token_output == 15 + assert not result.tool_calls + assert result.token_input == 20 + assert result.token_output == 15 + assert result.cost_usd == 0.05 mock_insert_llm_request.assert_called_once() call_args = mock_insert_llm_request.call_args[1] @@ -88,11 +89,10 @@ def test_chat_with_claude_with_tool_use(mock_claude, mock_insert_llm_request): created_by="4:test-user", ) - _, tool_calls, _, _ = result - assert len(tool_calls) == 1 - assert tool_calls[0].id == "tool_123" - assert tool_calls[0].name == "test_function" - assert tool_calls[0].args == {"param": "value"} + assert len(result.tool_calls) == 1 + assert result.tool_calls[0].id == "tool_123" + assert result.tool_calls[0].name == "test_function" + assert result.tool_calls[0].args == {"param": "value"} mock_insert_llm_request.assert_called_once() @@ -116,8 +116,7 @@ def test_chat_with_claude_no_usage_response(mock_claude, mock_insert_llm_request created_by="4:test-user", ) - _, _, _, token_output = result - assert token_output == 0 # output tokens should be 0 when no usage info + assert result.token_output == 0 # output tokens should be 0 when no usage info mock_insert_llm_request.assert_called_once() diff --git a/services/google_ai/chat_with_google.py b/services/google_ai/chat_with_google.py index 51d40628e..2f355cead 100644 --- a/services/google_ai/chat_with_google.py +++ b/services/google_ai/chat_with_google.py @@ -1,7 +1,6 @@ # Standard imports import time import uuid -from dataclasses import dataclass # Third-party imports from anthropic.types import MessageParam, ToolUnionParam @@ -12,18 +11,12 @@ from services.google_ai.client import get_google_ai_client from services.google_ai.convert_messages import convert_messages_to_google from services.google_ai.convert_tools import convert_tools_to_google +from services.llm_result import LlmResult, ToolCall from services.supabase.llm_requests.insert_llm_request import insert_llm_request from 
utils.error.handle_exceptions import handle_exceptions from utils.logging.logging_config import logger -@dataclass -class ToolCall: - id: str - name: str - args: dict | None - - @handle_exceptions(raise_on_error=True) def chat_with_google( messages: list[MessageParam], @@ -105,7 +98,7 @@ def chat_with_google( # Log to Supabase system_msg: MessageParam = {"role": "user", "content": system_content} full_messages = [system_msg, *messages] - insert_llm_request( + llm_record = insert_llm_request( usage_id=usage_id, provider="google", model_id=model_id, @@ -116,6 +109,7 @@ def chat_with_google( response_time_ms=response_time_ms, created_by=created_by, ) + cost_usd = llm_record["total_cost_usd"] if llm_record else 0.0 logger.info( "Google AI response: model=%s, input_tokens=%d, output_tokens=%d, tool_calls=%d", @@ -125,9 +119,10 @@ def chat_with_google( len(tool_calls), ) - return ( - assistant_message, - tool_calls, - token_input, - token_output, + return LlmResult( + assistant_message=assistant_message, + tool_calls=tool_calls, + token_input=token_input, + token_output=token_output, + cost_usd=cost_usd, ) diff --git a/services/google_ai/test_chat_with_google.py b/services/google_ai/test_chat_with_google.py index 9ca2d7097..5dcb81d5b 100644 --- a/services/google_ai/test_chat_with_google.py +++ b/services/google_ai/test_chat_with_google.py @@ -63,6 +63,7 @@ def _mock_tool_call_response( @patch("services.google_ai.chat_with_google.get_google_ai_client") def test_text_response(mock_get_client, mock_insert): """Text-only response returns correct assistant message and token counts.""" + mock_insert.return_value = {"total_cost_usd": 0.0} mock_client = Mock() mock_client.models.generate_content.return_value = _mock_text_response( "Hello! 
How can I help?", prompt_tokens=20, candidates_tokens=15 @@ -81,14 +82,14 @@ def test_text_response(mock_get_client, mock_insert): created_by="4:test-user", ) - assistant_message, tool_calls, token_input, token_output = result - assert assistant_message == { + assert result.assistant_message == { "role": "assistant", "content": [{"type": "text", "text": "Hello! How can I help?"}], } - assert not tool_calls - assert token_input == 20 - assert token_output == 15 + assert not result.tool_calls + assert result.token_input == 20 + assert result.token_output == 15 + assert result.cost_usd == 0.0 mock_insert.assert_called_once() call_kwargs = mock_insert.call_args[1] @@ -103,6 +104,7 @@ def test_text_response(mock_get_client, mock_insert): @patch("services.google_ai.chat_with_google.get_google_ai_client") def test_tool_call_response(mock_get_client, mock_insert): """Response with text + function_call returns tool_calls and correct message.""" + mock_insert.return_value = {"total_cost_usd": 0.05} mock_client = Mock() mock_client.models.generate_content.return_value = _mock_tool_call_response( text="I'll read that file.", @@ -135,8 +137,7 @@ def test_tool_call_response(mock_get_client, mock_insert): created_by="4:test-user", ) - assistant_message, tool_calls, token_input, token_output = result - assert assistant_message == { + assert result.assistant_message == { "role": "assistant", "content": [ {"type": "text", "text": "I'll read that file."}, @@ -148,18 +149,19 @@ def test_tool_call_response(mock_get_client, mock_insert): }, ], } - assert len(tool_calls) == 1 - assert tool_calls[0].id == "toolu_abc123" - assert tool_calls[0].name == "get_remote_file_content" - assert tool_calls[0].args == {"file_path": "README.md"} - assert token_input == 30 - assert token_output == 25 + assert len(result.tool_calls) == 1 + assert result.tool_calls[0].id == "toolu_abc123" + assert result.tool_calls[0].name == "get_remote_file_content" + assert result.tool_calls[0].args == {"file_path": 
"README.md"} + assert result.token_input == 30 + assert result.token_output == 25 @patch("services.google_ai.chat_with_google.insert_llm_request") @patch("services.google_ai.chat_with_google.get_google_ai_client") def test_no_usage_metadata(mock_get_client, mock_insert): """When usage_metadata is None, token counts default to 0.""" + mock_insert.return_value = {"total_cost_usd": 0.0} response = _mock_text_response("Response") response.usage_metadata = None mock_client = Mock() @@ -175,15 +177,15 @@ def test_no_usage_metadata(mock_get_client, mock_insert): created_by="4:test-user", ) - _, _, token_input, token_output = result - assert token_input == 0 - assert token_output == 0 + assert result.token_input == 0 + assert result.token_output == 0 @patch("services.google_ai.chat_with_google.insert_llm_request") @patch("services.google_ai.chat_with_google.get_google_ai_client") def test_empty_candidates(mock_get_client, mock_insert): """When candidates list is empty, returns empty content.""" + mock_insert.return_value = {"total_cost_usd": 0.0} response = Mock() response.candidates = [] response.usage_metadata = Mock(prompt_token_count=10, candidates_token_count=0) @@ -200,10 +202,9 @@ def test_empty_candidates(mock_get_client, mock_insert): created_by="4:test-user", ) - assistant_message, tool_calls, _, _ = result # No parts → content_list is empty → falls back to empty content_text - assert assistant_message == {"role": "assistant", "content": ""} - assert not tool_calls + assert result.assistant_message == {"role": "assistant", "content": ""} + assert not result.tool_calls @patch("services.google_ai.chat_with_google.insert_llm_request") @@ -247,12 +248,11 @@ def test_function_call_without_id_generates_one(mock_get_client, mock_insert): created_by="4:test-user", ) - _, tool_calls, _, _ = result - assert len(tool_calls) == 1 - assert tool_calls[0].id.startswith("toolu_") - assert len(tool_calls[0].id) == 30 # "toolu_" + 24 hex chars - assert tool_calls[0].name == 
"run_command" - assert tool_calls[0].args == {"command": "ls"} + assert len(result.tool_calls) == 1 + assert result.tool_calls[0].id.startswith("toolu_") + assert len(result.tool_calls[0].id) == 30 # "toolu_" + 24 hex chars + assert result.tool_calls[0].name == "run_command" + assert result.tool_calls[0].args == {"command": "ls"} # --- Sociable integration tests: real Google AI API calls --- @@ -286,12 +286,11 @@ def test_integration_text_response(mock_insert): created_by="4:integration-test", ) - assistant_message, tool_calls, token_input, token_output = result - assert assistant_message["role"] == "assistant" - assert isinstance(assistant_message["content"], (str, list)) - assert not tool_calls - assert token_input > 0 - assert token_output > 0 + assert result.assistant_message["role"] == "assistant" + assert isinstance(result.assistant_message["content"], (str, list)) + assert not result.tool_calls + assert result.token_input > 0 + assert result.token_output > 0 mock_insert.assert_called_once() call_kwargs = mock_insert.call_args[1] @@ -321,17 +320,16 @@ def test_integration_tool_call_with_real_tools(mock_insert): created_by="4:integration-test", ) - assistant_message, tool_calls, token_input, token_output = result - assert assistant_message["role"] == "assistant" - assert token_input > 0 - assert token_output > 0 + assert result.assistant_message["role"] == "assistant" + assert result.token_input > 0 + assert result.token_output > 0 # Model should call a file-reading tool - assert len(tool_calls) >= 1 - tool_names = [tc.name for tc in tool_calls] + assert len(result.tool_calls) >= 1 + tool_names = [tc.name for tc in result.tool_calls] assert any( name in tool_names for name in ("get_local_file_content", "query_file") ), f"Expected a file-reading tool call, got: {tool_names}" # Each tool call has a valid id - for tc in tool_calls: + for tc in result.tool_calls: assert tc.id assert tc.name diff --git a/services/llm_result.py b/services/llm_result.py new file mode 
100644 index 000000000..641229eef --- /dev/null +++ b/services/llm_result.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass + +from anthropic.types import MessageParam + + +@dataclass +class ToolCall: + id: str + name: str + args: dict | None + + +@dataclass +class LlmResult: + assistant_message: MessageParam + tool_calls: list[ToolCall] + token_input: int + token_output: int + cost_usd: float diff --git a/services/supabase/credits/get_credit_cost.py b/services/supabase/credits/get_credit_price.py similarity index 89% rename from services/supabase/credits/get_credit_cost.py rename to services/supabase/credits/get_credit_price.py index 682d70756..012c8eeea 100644 --- a/services/supabase/credits/get_credit_cost.py +++ b/services/supabase/credits/get_credit_price.py @@ -3,7 +3,7 @@ @handle_exceptions(default_return_value=MAX_CREDIT_COST_USD, raise_on_error=False) -def get_credit_cost(model_id: ModelId | None): +def get_credit_price(model_id: ModelId | None): if not model_id: return MAX_CREDIT_COST_USD entry = MODEL_REGISTRY.get(model_id) diff --git a/services/supabase/credits/insert_credit.py b/services/supabase/credits/insert_credit.py index 6ab7136af..355243188 100644 --- a/services/supabase/credits/insert_credit.py +++ b/services/supabase/credits/insert_credit.py @@ -1,7 +1,7 @@ # Local imports from constants.models import CREDIT_GRANT_AMOUNT_USD, ModelId from schemas.supabase.types import CreditTransactionType -from services.supabase.credits.get_credit_cost import get_credit_cost +from services.supabase.credits.get_credit_price import get_credit_price from services.supabase.client import supabase from utils.error.handle_exceptions import handle_exceptions @@ -14,7 +14,7 @@ def insert_credit( model_id: ModelId | None = None, ): if transaction_type == "usage": - amount_usd = -get_credit_cost(model_id) + amount_usd = -get_credit_price(model_id) elif transaction_type == "grant": amount_usd = CREDIT_GRANT_AMOUNT_USD else: diff --git 
a/services/supabase/credits/test_get_credit_cost.py b/services/supabase/credits/test_get_credit_cost.py deleted file mode 100644 index 886caa5b0..000000000 --- a/services/supabase/credits/test_get_credit_cost.py +++ /dev/null @@ -1,24 +0,0 @@ -from constants.models import ( - MAX_CREDIT_COST_USD, - MODEL_REGISTRY, - ClaudeModelId, - GoogleModelId, -) -from services.supabase.credits.get_credit_cost import get_credit_cost - - -def test_returns_cost_for_known_model(): - for model_id, info in MODEL_REGISTRY.items(): - assert get_credit_cost(model_id) == info["credit_cost_usd"] - - -def test_returns_max_cost_for_none(): - assert get_credit_cost(None) == MAX_CREDIT_COST_USD - - -def test_opus_costs_8(): - assert get_credit_cost(ClaudeModelId.OPUS_4_6) == 8 - - -def test_gemma_costs_2(): - assert get_credit_cost(GoogleModelId.GEMMA_4_31B) == 2 diff --git a/services/supabase/credits/test_get_credit_price.py b/services/supabase/credits/test_get_credit_price.py new file mode 100644 index 000000000..391f16787 --- /dev/null +++ b/services/supabase/credits/test_get_credit_price.py @@ -0,0 +1,24 @@ +from constants.models import ( + MAX_CREDIT_COST_USD, + MODEL_REGISTRY, + ClaudeModelId, + GoogleModelId, +) +from services.supabase.credits.get_credit_price import get_credit_price + + +def test_returns_cost_for_known_model(): + for model_id, info in MODEL_REGISTRY.items(): + assert get_credit_price(model_id) == info["credit_cost_usd"] + + +def test_returns_max_cost_for_none(): + assert get_credit_price(None) == MAX_CREDIT_COST_USD + + +def test_opus_costs_8(): + assert get_credit_price(ClaudeModelId.OPUS_4_6) == 8 + + +def test_gemma_costs_2(): + assert get_credit_price(GoogleModelId.GEMMA_4_31B) == 2 diff --git a/services/supabase/credits/test_insert_credit.py b/services/supabase/credits/test_insert_credit.py index 442344f4e..3e7bc2d4f 100644 --- a/services/supabase/credits/test_insert_credit.py +++ b/services/supabase/credits/test_insert_credit.py @@ -53,7 +53,7 @@ def 
test_usage_inserts_correct_data_with_usage_id(mock_supabase, mock_query_chai usage_id = 456 with patch( - "services.supabase.credits.insert_credit.get_credit_cost", + "services.supabase.credits.insert_credit.get_credit_price", return_value=MAX_CREDIT_COST_USD, ): result = insert_credit( @@ -78,7 +78,7 @@ def test_usage_with_model_id_uses_model_specific_cost(mock_supabase, mock_query_ model_id = GoogleModelId.GEMMA_4_31B with patch( - "services.supabase.credits.insert_credit.get_credit_cost", + "services.supabase.credits.insert_credit.get_credit_price", return_value=2, ): result = insert_credit( @@ -124,7 +124,7 @@ def test_insert_exception_returns_none(mock_supabase, mock_query_chain): mock_query_chain["table"].insert.side_effect = Exception("Insert error") with patch( - "services.supabase.credits.insert_credit.get_credit_cost", + "services.supabase.credits.insert_credit.get_credit_price", return_value=MAX_CREDIT_COST_USD, ): result = insert_credit(owner_id=555555, transaction_type="usage") @@ -168,7 +168,7 @@ def test_various_usage_ids(mock_supabase, mock_query_chain, usage_id): owner_id = 123456 with patch( - "services.supabase.credits.insert_credit.get_credit_cost", + "services.supabase.credits.insert_credit.get_credit_price", return_value=MAX_CREDIT_COST_USD, ): result = insert_credit( diff --git a/services/supabase/llm_requests/get_total_cost_for_pr.py b/services/supabase/llm_requests/get_total_cost_for_pr.py new file mode 100644 index 000000000..264875889 --- /dev/null +++ b/services/supabase/llm_requests/get_total_cost_for_pr.py @@ -0,0 +1,44 @@ +from services.supabase.client import supabase +from utils.error.handle_exceptions import handle_exceptions +from utils.logging.logging_config import logger + + +@handle_exceptions(default_return_value=0.0, raise_on_error=False) +def get_total_cost_for_pr(owner_name: str, repo_name: str, pr_number: int): + # Get all usage IDs for this PR + usage_result = ( + supabase.table("usage") + .select("id") + .eq("owner_name", 
owner_name) + .eq("repo_name", repo_name) + .eq("pr_number", pr_number) + .execute() + ) + if not usage_result.data: + logger.info("No usage records for %s/%s#%d", owner_name, repo_name, pr_number) + return 0.0 + + usage_ids = [row["id"] for row in usage_result.data] + + # Sum LLM costs across all invocations for this PR + cost_result = ( + supabase.table("llm_requests") + .select("total_cost_usd") + .in_("usage_id", usage_ids) + .execute() + ) + total = ( + sum(row["total_cost_usd"] for row in cost_result.data) + if cost_result.data + else 0.0 + ) + logger.info( + "Total LLM cost for %s/%s#%d: $%.4f (%d usage records, %d llm_requests)", + owner_name, + repo_name, + pr_number, + total, + len(usage_ids), + len(cost_result.data) if cost_result.data else 0, + ) + return total diff --git a/services/supabase/llm_requests/insert_llm_request.py b/services/supabase/llm_requests/insert_llm_request.py index 1f60bf63d..ab150ff0f 100644 --- a/services/supabase/llm_requests/insert_llm_request.py +++ b/services/supabase/llm_requests/insert_llm_request.py @@ -2,6 +2,7 @@ from anthropic.types import MessageParam +from schemas.supabase.types import LlmRequests from services.supabase.client import supabase from services.supabase.llm_requests.calculate_costs import calculate_costs from utils.error.handle_exceptions import handle_exceptions @@ -55,4 +56,4 @@ def insert_llm_request( } result = supabase.table("llm_requests").insert(data).execute() - return result.data[0] if result.data else None + return LlmRequests(**result.data[0]) if result.data else None diff --git a/services/supabase/llm_requests/test_get_total_cost_for_pr.py b/services/supabase/llm_requests/test_get_total_cost_for_pr.py new file mode 100644 index 000000000..e1f395778 --- /dev/null +++ b/services/supabase/llm_requests/test_get_total_cost_for_pr.py @@ -0,0 +1,74 @@ +# pyright: reportUnusedVariable=false +from unittest.mock import MagicMock, patch + +from services.supabase.llm_requests.get_total_cost_for_pr import 
get_total_cost_for_pr + +MOCK_SUPABASE = "services.supabase.llm_requests.get_total_cost_for_pr.supabase" + + +@patch(MOCK_SUPABASE) +def test_returns_zero_when_no_usage_records(mock_supabase): + usage_query = MagicMock() + usage_query.execute.return_value = MagicMock(data=[]) + mock_supabase.table.return_value.select.return_value.eq.return_value.eq.return_value.eq.return_value = ( + usage_query + ) + + result = get_total_cost_for_pr("owner", "repo", 1) + assert result == 0.0 + + +@patch(MOCK_SUPABASE) +def test_sums_costs_across_multiple_usage_ids(mock_supabase): + # First call: usage table returns 2 usage IDs + usage_query = MagicMock() + usage_query.execute.return_value = MagicMock(data=[{"id": 100}, {"id": 200}]) + + # Second call: llm_requests table returns costs + cost_query = MagicMock() + cost_query.execute.return_value = MagicMock( + data=[ + {"total_cost_usd": 1.50}, + {"total_cost_usd": 2.30}, + {"total_cost_usd": 0.80}, + ] + ) + + def table_side_effect(name): + mock_table = MagicMock() + if name == "usage": + mock_table.select.return_value.eq.return_value.eq.return_value.eq.return_value = ( + usage_query + ) + elif name == "llm_requests": + mock_table.select.return_value.in_.return_value = cost_query + return mock_table + + mock_supabase.table.side_effect = table_side_effect + + result = get_total_cost_for_pr("owner", "repo", 42) + assert result == 4.60 + + +@patch(MOCK_SUPABASE) +def test_returns_zero_when_no_llm_requests(mock_supabase): + usage_query = MagicMock() + usage_query.execute.return_value = MagicMock(data=[{"id": 100}]) + + cost_query = MagicMock() + cost_query.execute.return_value = MagicMock(data=[]) + + def table_side_effect(name): + mock_table = MagicMock() + if name == "usage": + mock_table.select.return_value.eq.return_value.eq.return_value.eq.return_value = ( + usage_query + ) + elif name == "llm_requests": + mock_table.select.return_value.in_.return_value = cost_query + return mock_table + + mock_supabase.table.side_effect = 
table_side_effect + + result = get_total_cost_for_pr("owner", "repo", 42) + assert result == 0.0 diff --git a/services/supabase/llm_requests/test_insert_llm_request.py b/services/supabase/llm_requests/test_insert_llm_request.py index c91c8569b..f66ce86c5 100644 --- a/services/supabase/llm_requests/test_insert_llm_request.py +++ b/services/supabase/llm_requests/test_insert_llm_request.py @@ -1,3 +1,4 @@ +import datetime import json from unittest.mock import Mock, patch @@ -6,13 +7,35 @@ from constants.models import ClaudeModelId from services.supabase.llm_requests.insert_llm_request import insert_llm_request +MOCK_DB_ROW = { + "id": 1, + "usage_id": 123, + "provider": "claude", + "model_id": ClaudeModelId.SONNET_4_6, + "input_content": json.dumps([{"role": "user", "content": "test"}]), + "input_length": 35, + "input_tokens": 10, + "input_cost_usd": 0.001, + "output_content": json.dumps({"role": "assistant", "content": "response"}), + "output_length": 42, + "output_tokens": 5, + "output_cost_usd": 0.005, + "total_cost_usd": 0.006, + "response_time_ms": 1000, + "error_message": None, + "created_at": datetime.datetime(2026, 4, 16), + "created_by": "test", + "updated_at": datetime.datetime(2026, 4, 16), + "updated_by": "test", +} + @patch("services.supabase.llm_requests.insert_llm_request.supabase") @patch("services.supabase.llm_requests.insert_llm_request.calculate_costs") def test_insert_llm_request_success(mock_calculate_costs, mock_supabase): mock_calculate_costs.return_value = (0.001, 0.005) mock_result = Mock() - mock_result.data = [{"id": 1}] + mock_result.data = [MOCK_DB_ROW] mock_supabase.table.return_value.insert.return_value.execute.return_value = ( mock_result ) @@ -32,7 +55,9 @@ def test_insert_llm_request_success(mock_calculate_costs, mock_supabase): created_by="test", ) - assert result == {"id": 1} + assert result is not None + assert result["id"] == 1 + assert result["total_cost_usd"] == 0.006 mock_calculate_costs.assert_called_once_with( "claude", 
ClaudeModelId.SONNET_4_6, 10, 5 ) diff --git a/services/test_chat_with_agent.py b/services/test_chat_with_agent.py index 8b2cf528c..230fdc37e 100644 --- a/services/test_chat_with_agent.py +++ b/services/test_chat_with_agent.py @@ -5,7 +5,7 @@ import pytest from constants.models import ClaudeModelId, GoogleModelId, ModelId from services.chat_with_agent import chat_with_agent -from services.claude.chat_with_claude import ToolCall +from services.llm_result import LlmResult, ToolCall from services.claude.exceptions import ClaudeOverloadedError from services.claude.tools.file_modify_result import FileMoveResult, FileWriteResult @@ -15,11 +15,12 @@ async def test_chat_with_agent_passes_usage_id_to_claude( mock_chat_with_model, create_test_base_args ): - mock_chat_with_model.return_value = ( - {"role": "assistant", "content": "response"}, - [], - 15, - 10, + mock_chat_with_model.return_value = LlmResult( + assistant_message={"role": "assistant", "content": "response"}, + tool_calls=[], + token_input=15, + token_output=10, + cost_usd=0.05, ) base_args = create_test_base_args(model_id=GoogleModelId.GEMMA_4_31B) @@ -43,11 +44,12 @@ async def test_chat_with_agent_passes_usage_id_to_claude( async def test_chat_with_agent_returns_token_counts( mock_chat_with_model, create_test_base_args ): - mock_chat_with_model.return_value = ( - {"role": "assistant", "content": "response"}, - [], - 25, - 15, + mock_chat_with_model.return_value = LlmResult( + assistant_message={"role": "assistant", "content": "response"}, + tool_calls=[], + token_input=25, + token_output=15, + cost_usd=0.0, ) base_args = create_test_base_args(model_id=GoogleModelId.GEMINI_2_5_FLASH) @@ -63,6 +65,61 @@ async def test_chat_with_agent_returns_token_counts( assert result.token_input == 25 assert result.token_output == 15 + assert result.cost_usd == 0.0 + + +@pytest.mark.asyncio +@patch("services.chat_with_agent.chat_with_model") +async def test_cost_usd_computed_for_claude_model( + mock_chat_with_model, 
create_test_base_args +): + mock_chat_with_model.return_value = LlmResult( + assistant_message={"role": "assistant", "content": "response"}, + tool_calls=[], + token_input=30_000, + token_output=500, + cost_usd=0.1625, + ) + + base_args = create_test_base_args(model_id=ClaudeModelId.OPUS_4_6) + + result = await chat_with_agent( + messages=[{"role": "user", "content": "test"}], + system_message="test system message", + base_args=base_args, + tools=[], + usage_id=789, + model_id=ClaudeModelId.OPUS_4_6, + ) + + assert result.cost_usd == 0.1625 + + +@pytest.mark.asyncio +@patch("services.chat_with_agent.chat_with_model") +async def test_cost_usd_computed_for_google_model( + mock_chat_with_model, create_test_base_args +): + mock_chat_with_model.return_value = LlmResult( + assistant_message={"role": "assistant", "content": "response"}, + tool_calls=[], + token_input=100_000, + token_output=2_000, + cost_usd=0.0162, + ) + + base_args = create_test_base_args(model_id=GoogleModelId.GEMINI_2_5_FLASH) + + result = await chat_with_agent( + messages=[{"role": "user", "content": "test"}], + system_message="test system message", + base_args=base_args, + tools=[], + usage_id=789, + model_id=GoogleModelId.GEMINI_2_5_FLASH, + ) + + assert result.cost_usd == 0.0162 @pytest.mark.asyncio @@ -72,8 +129,8 @@ async def test_get_local_file_content_start_line_end_line_logging( mock_update_comment, mock_chat_with_model, create_test_base_args ): """Test that start_line and end_line parameters are properly logged in chat_with_agent.""" - mock_chat_with_model.return_value = ( - { + mock_chat_with_model.return_value = LlmResult( + assistant_message={ "role": "assistant", "content": [ { @@ -84,15 +141,16 @@ async def test_get_local_file_content_start_line_end_line_logging( } ], }, - [ + tool_calls=[ ToolCall( id="test_id", name="get_local_file_content", args={"file_path": "test.py", "start_line": 10, "end_line": 20}, ) ], - 15, - 10, + token_input=15, + token_output=10, + cost_usd=0.05, ) 
base_args = create_test_base_args(model_id=ClaudeModelId.SONNET_4_6) @@ -131,8 +189,8 @@ async def test_delete_file_logging( mock_update_comment, mock_chat_with_model, create_test_base_args ): """Test that delete_file function calls are properly logged in chat_with_agent.""" - mock_chat_with_model.return_value = ( - { + mock_chat_with_model.return_value = LlmResult( + assistant_message={ "role": "assistant", "content": [ { @@ -143,13 +201,14 @@ async def test_delete_file_logging( } ], }, - [ + tool_calls=[ ToolCall( id="test_id", name="delete_file", args={"file_path": "test_file.py"} ) ], - 15, - 10, + token_input=15, + token_output=10, + cost_usd=0.05, ) base_args = create_test_base_args(model_id=ClaudeModelId.OPUS_4_6) @@ -190,8 +249,8 @@ async def test_move_file_logging( mock_update_comment, mock_chat_with_model, create_test_base_args ): """Test that move_file function calls are properly logged in chat_with_agent.""" - mock_chat_with_model.return_value = ( - { + mock_chat_with_model.return_value = LlmResult( + assistant_message={ "role": "assistant", "content": [ { @@ -205,15 +264,16 @@ async def test_move_file_logging( } ], }, - [ + tool_calls=[ ToolCall( id="test_id", name="move_file", args={"old_file_path": "old_file.py", "new_file_path": "new_file.py"}, ) ], - 15, - 10, + token_input=15, + token_output=10, + cost_usd=0.05, ) base_args = create_test_base_args(model_id=GoogleModelId.GEMMA_4_31B) @@ -252,8 +312,8 @@ async def test_move_file_logging( async def test_write_and_commit_file_handles_new_content_arg_name( mock_chat_with_model, create_test_base_args ): - mock_chat_with_model.return_value = ( - { + mock_chat_with_model.return_value = LlmResult( + assistant_message={ "role": "assistant", "content": [ { @@ -267,15 +327,16 @@ async def test_write_and_commit_file_handles_new_content_arg_name( } ], }, - [ + tool_calls=[ ToolCall( id="test_id", name="write_and_commit_file", args={"file_path": "test.py", "new_content": "updated content"}, ) ], - 15, - 10, + 
token_input=15, + token_output=10, + cost_usd=0.05, ) base_args = create_test_base_args(model_id=GoogleModelId.GEMINI_2_5_FLASH) @@ -307,8 +368,8 @@ async def test_write_and_commit_file_handles_new_content_arg_name( async def test_unavailable_tool_sends_slack_notification( mock_slack_notify, mock_chat_with_model, create_test_base_args ): - mock_chat_with_model.return_value = ( - { + mock_chat_with_model.return_value = LlmResult( + assistant_message={ "role": "assistant", "content": [ { @@ -319,9 +380,10 @@ async def test_unavailable_tool_sends_slack_notification( } ], }, - [ToolCall(id="test_id", name="bash", args={"command": "ls -la"})], - 15, - 10, + tool_calls=[ToolCall(id="test_id", name="bash", args={"command": "ls -la"})], + token_input=15, + token_output=10, + cost_usd=0.05, ) base_args = create_test_base_args(model_id=ClaudeModelId.SONNET_4_6) @@ -354,8 +416,8 @@ async def test_verify_task_is_complete_with_pr_changes_returns_is_completed_true _mock_update_comment, mock_get_pr_files, mock_chat_with_model, create_test_base_args ): mock_get_pr_files.return_value = [{"filename": "test.py", "status": "modified"}] - mock_chat_with_model.return_value = ( - { + mock_chat_with_model.return_value = LlmResult( + assistant_message={ "role": "assistant", "content": [ { @@ -366,9 +428,10 @@ async def test_verify_task_is_complete_with_pr_changes_returns_is_completed_true } ], }, - [ToolCall(id="test_id", name="verify_task_is_complete", args={})], - 15, - 10, + tool_calls=[ToolCall(id="test_id", name="verify_task_is_complete", args={})], + token_input=15, + token_output=10, + cost_usd=0.05, ) base_args = create_test_base_args( @@ -403,8 +466,8 @@ async def test_verify_task_is_complete_without_pr_changes_returns_is_completed_f mock_get_pr_files, _mock_update_comment, mock_chat_with_model, create_test_base_args ): mock_get_pr_files.return_value = [] - mock_chat_with_model.return_value = ( - { + mock_chat_with_model.return_value = LlmResult( + assistant_message={ "role": 
"assistant", "content": [ { @@ -415,9 +478,10 @@ async def test_verify_task_is_complete_without_pr_changes_returns_is_completed_f } ], }, - [ToolCall(id="test_id", name="verify_task_is_complete", args={})], - 15, - 10, + tool_calls=[ToolCall(id="test_id", name="verify_task_is_complete", args={})], + token_input=15, + token_output=10, + cost_usd=0.05, ) base_args = create_test_base_args( @@ -455,8 +519,8 @@ async def test_verify_task_is_complete_with_none_args_still_executes( isinstance(None, dict) is False, so the tool was silently skipped and returned None. Gemma then entered a dead loop returning empty responses for 20 iterations.""" mock_get_pr_files.return_value = [{"filename": "test.py", "status": "modified"}] - mock_chat_with_model.return_value = ( - { + mock_chat_with_model.return_value = LlmResult( + assistant_message={ "role": "assistant", "content": [ { @@ -467,9 +531,10 @@ async def test_verify_task_is_complete_with_none_args_still_executes( } ], }, - [ToolCall(id="test_id", name="verify_task_is_complete", args=None)], - 15, - 10, + tool_calls=[ToolCall(id="test_id", name="verify_task_is_complete", args=None)], + token_input=15, + token_output=10, + cost_usd=0.05, ) base_args = create_test_base_args( @@ -499,8 +564,8 @@ async def test_verify_task_is_complete_with_none_args_still_executes( async def test_regular_tool_returns_is_completed_false( mock_chat_with_model, create_test_base_args ): - mock_chat_with_model.return_value = ( - { + mock_chat_with_model.return_value = LlmResult( + assistant_message={ "role": "assistant", "content": [ { @@ -511,15 +576,16 @@ async def test_regular_tool_returns_is_completed_false( } ], }, - [ + tool_calls=[ ToolCall( id="test_id", name="get_local_file_content", args={"file_path": "test.py"}, ) ], - 15, - 10, + token_input=15, + token_output=10, + cost_usd=0.05, ) base_args = create_test_base_args(model_id=GoogleModelId.GEMINI_2_5_FLASH) @@ -547,11 +613,12 @@ async def test_regular_tool_returns_is_completed_false( async 
def test_no_tool_call_returns_is_completed_false( mock_chat_with_model, create_test_base_args ): - mock_chat_with_model.return_value = ( - {"role": "assistant", "content": "I'm thinking about it..."}, - [], - 15, - 10, + mock_chat_with_model.return_value = LlmResult( + assistant_message={"role": "assistant", "content": "I'm thinking about it..."}, + tool_calls=[], + token_input=15, + token_output=10, + cost_usd=0.05, ) base_args = create_test_base_args(model_id=ClaudeModelId.SONNET_4_6) @@ -576,8 +643,8 @@ async def test_file_write_result_success_includes_formatted_content( _mock_update_comment, mock_chat_with_model, create_test_base_args ): """Test that FileWriteResult with success=True includes formatted content with line numbers.""" - mock_chat_with_model.return_value = ( - { + mock_chat_with_model.return_value = LlmResult( + assistant_message={ "role": "assistant", "content": [ { @@ -588,15 +655,16 @@ async def test_file_write_result_success_includes_formatted_content( } ], }, - [ + tool_calls=[ ToolCall( id="test_id", name="apply_diff_to_file", args={"file_path": "test.py", "diff": "some diff"}, ) ], - 15, - 10, + token_input=15, + token_output=10, + cost_usd=0.05, ) base_args = create_test_base_args(model_id=ClaudeModelId.OPUS_4_6) @@ -638,8 +706,8 @@ async def test_apply_diff_no_changes_logs_tool_result_message( mock_update_comment, mock_chat_with_model, create_test_base_args ): """Test that apply_diff_to_file with no changes uses tool_result.message instead of hardcoded 'Committed changes'.""" - mock_chat_with_model.return_value = ( - { + mock_chat_with_model.return_value = LlmResult( + assistant_message={ "role": "assistant", "content": [ { @@ -650,15 +718,16 @@ async def test_apply_diff_no_changes_logs_tool_result_message( } ], }, - [ + tool_calls=[ ToolCall( id="test_id", name="apply_diff_to_file", args={"file_path": "test.py", "diff": "some diff"}, ) ], - 15, - 10, + token_input=15, + token_output=10, + cost_usd=0.05, ) base_args = 
create_test_base_args(model_id=GoogleModelId.GEMMA_4_31B) @@ -710,8 +779,8 @@ async def test_file_write_result_failure_returns_message_only( _mock_update_comment, mock_chat_with_model, create_test_base_args ): """Test that FileWriteResult with success=False returns only the message.""" - mock_chat_with_model.return_value = ( - { + mock_chat_with_model.return_value = LlmResult( + assistant_message={ "role": "assistant", "content": [ { @@ -722,15 +791,16 @@ async def test_file_write_result_failure_returns_message_only( } ], }, - [ + tool_calls=[ ToolCall( id="test_id", name="apply_diff_to_file", args={"file_path": "test.py", "diff": "bad diff"}, ) ], - 15, - 10, + token_input=15, + token_output=10, + cost_usd=0.05, ) base_args = create_test_base_args(model_id=GoogleModelId.GEMINI_2_5_FLASH) @@ -769,8 +839,8 @@ async def test_file_move_result_returns_message( _mock_update_comment, mock_chat_with_model, create_test_base_args ): """Test that FileMoveResult returns the message.""" - mock_chat_with_model.return_value = ( - { + mock_chat_with_model.return_value = LlmResult( + assistant_message={ "role": "assistant", "content": [ { @@ -784,15 +854,16 @@ async def test_file_move_result_returns_message( } ], }, - [ + tool_calls=[ ToolCall( id="test_id", name="move_file", args={"old_file_path": "old.py", "new_file_path": "new.py"}, ) ], - 15, - 10, + token_input=15, + token_output=10, + cost_usd=0.05, ) base_args = create_test_base_args(model_id=ClaudeModelId.SONNET_4_6) @@ -832,8 +903,8 @@ async def test_full_file_read_calls_replace_with_is_full_file_read_true( mock_replace, _mock_update_comment, mock_chat_with_model, create_test_base_args ): """Test that reading a full file calls replace_old_file_content with is_full_file_read=True.""" - mock_chat_with_model.return_value = ( - { + mock_chat_with_model.return_value = LlmResult( + assistant_message={ "role": "assistant", "content": [ { @@ -844,15 +915,16 @@ async def 
test_full_file_read_calls_replace_with_is_full_file_read_true( } ], }, - [ + tool_calls=[ ToolCall( id="test_id", name="get_local_file_content", args={"file_path": "src/main.py"}, ) ], - 15, - 10, + token_input=15, + token_output=10, + cost_usd=0.05, ) base_args = create_test_base_args(model_id=ClaudeModelId.OPUS_4_6) @@ -886,8 +958,8 @@ async def test_partial_file_read_calls_replace_with_is_full_file_read_false( mock_replace, _mock_update_comment, mock_chat_with_model, create_test_base_args ): """Test that reading a partial file calls replace_old_file_content with is_full_file_read=False.""" - mock_chat_with_model.return_value = ( - { + mock_chat_with_model.return_value = LlmResult( + assistant_message={ "role": "assistant", "content": [ { @@ -902,15 +974,16 @@ async def test_partial_file_read_calls_replace_with_is_full_file_read_false( } ], }, - [ + tool_calls=[ ToolCall( id="test_id", name="get_local_file_content", args={"file_path": "src/main.py", "start_line": 10, "end_line": 20}, ) ], - 15, - 10, + token_input=15, + token_output=10, + cost_usd=0.05, ) base_args = create_test_base_args(model_id=GoogleModelId.GEMMA_4_31B) @@ -943,8 +1016,8 @@ async def test_multiple_parallel_tool_calls( _mock_update_comment, mock_chat_with_model, create_test_base_args ): """Test that multiple tool_use blocks are all executed and results returned in one message.""" - mock_chat_with_model.return_value = ( - { + mock_chat_with_model.return_value = LlmResult( + assistant_message={ "role": "assistant", "content": [ { @@ -967,7 +1040,7 @@ async def test_multiple_parallel_tool_calls( }, ], }, - [ + tool_calls=[ ToolCall( id="tool_1", name="get_local_file_content", args={"file_path": "a.py"} ), @@ -978,8 +1051,9 @@ async def test_multiple_parallel_tool_calls( id="tool_3", name="get_local_file_content", args={"file_path": "c.py"} ), ], - 15, - 10, + token_input=15, + token_output=10, + cost_usd=0.05, ) base_args = create_test_base_args(model_id=GoogleModelId.GEMINI_2_5_FLASH) @@ -1027,8 
+1101,8 @@ async def test_gitauto_md_edit_always_allowed( mock_chat_with_model, create_test_base_args, ): - mock_chat_with_model.return_value = ( - { + mock_chat_with_model.return_value = LlmResult( + assistant_message={ "role": "assistant", "content": [ { @@ -1042,7 +1116,7 @@ async def test_gitauto_md_edit_always_allowed( } ], }, - [ + tool_calls=[ ToolCall( id="test_id", name="write_and_commit_file", @@ -1052,8 +1126,9 @@ async def test_gitauto_md_edit_always_allowed( }, ) ], - 15, - 10, + token_input=15, + token_output=10, + cost_usd=0.05, ) base_args = create_test_base_args(model_id=ClaudeModelId.SONNET_4_6) @@ -1102,7 +1177,13 @@ def side_effect(**kwargs): assert kwargs["model_id"] == ClaudeModelId.OPUS_4_6 raise RuntimeError("Opus 4.6 down") assert kwargs["model_id"] == ClaudeModelId.OPUS_4_5 - return ({"role": "assistant", "content": "ok"}, [], 10, 5) + return LlmResult( + assistant_message={"role": "assistant", "content": "ok"}, + tool_calls=[], + token_input=10, + token_output=5, + cost_usd=0.05, + ) mock_chat_with_model.side_effect = side_effect base_args = create_test_base_args(model_id=ClaudeModelId.OPUS_4_6) @@ -1131,7 +1212,13 @@ def side_effect(**kwargs): models_tried.append(kwargs["model_id"]) if len(models_tried) < 3: raise RuntimeError("model down") - return ({"role": "assistant", "content": "ok"}, [], 10, 5) + return LlmResult( + assistant_message={"role": "assistant", "content": "ok"}, + tool_calls=[], + token_input=10, + token_output=5, + cost_usd=0.05, + ) mock_chat_with_model.side_effect = side_effect base_args = create_test_base_args(model_id=ClaudeModelId.SONNET_4_6) @@ -1201,7 +1288,13 @@ def side_effect(**kwargs): raise ClaudeOverloadedError("529") # 4th call: Opus 4.5 succeeds assert kwargs["model_id"] == ClaudeModelId.OPUS_4_5 - return ({"role": "assistant", "content": "ok"}, [], 10, 5) + return LlmResult( + assistant_message={"role": "assistant", "content": "ok"}, + tool_calls=[], + token_input=10, + token_output=5, + cost_usd=0.05, + 
) mock_chat_with_model.side_effect = side_effect base_args = create_test_base_args(model_id=GoogleModelId.GEMINI_2_5_FLASH) diff --git a/services/test_chat_with_model.py b/services/test_chat_with_model.py index 423e02a33..42445d602 100644 --- a/services/test_chat_with_model.py +++ b/services/test_chat_with_model.py @@ -2,12 +2,23 @@ from constants.models import ClaudeModelId, GoogleModelId from services.chat_with_model import chat_with_model +from services.llm_result import LlmResult + + +def _make_llm_result(): + return LlmResult( + assistant_message={"role": "assistant", "content": "hi"}, + tool_calls=[], + token_input=100, + token_output=50, + cost_usd=0.01, + ) @patch("services.chat_with_model.chat_with_claude") def test_routes_to_anthropic_for_opus(mock_claude: MagicMock): """Opus model should route to chat_with_claude.""" - mock_claude.return_value = ({"role": "assistant", "content": "hi"}, [], 100, 50) + mock_claude.return_value = _make_llm_result() result = chat_with_model( messages=[{"role": "user", "content": "test"}], @@ -19,13 +30,13 @@ def test_routes_to_anthropic_for_opus(mock_claude: MagicMock): ) mock_claude.assert_called_once() - assert result[0]["role"] == "assistant" + assert result.assistant_message["role"] == "assistant" @patch("services.chat_with_model.chat_with_claude") def test_routes_to_anthropic_for_sonnet(mock_claude: MagicMock): """Sonnet model should route to chat_with_claude.""" - mock_claude.return_value = ({"role": "assistant", "content": "hi"}, [], 100, 50) + mock_claude.return_value = _make_llm_result() chat_with_model( messages=[{"role": "user", "content": "test"}], @@ -42,7 +53,7 @@ def test_routes_to_anthropic_for_sonnet(mock_claude: MagicMock): @patch("services.chat_with_model.chat_with_google") def test_routes_to_google_for_gemma(mock_google: MagicMock): """Gemma model should route to chat_with_google.""" - mock_google.return_value = ({"role": "assistant", "content": "hi"}, [], 100, 50) + mock_google.return_value = 
_make_llm_result() chat_with_model( messages=[{"role": "user", "content": "test"}], @@ -59,7 +70,7 @@ def test_routes_to_google_for_gemma(mock_google: MagicMock): @patch("services.chat_with_model.chat_with_google") def test_routes_to_google_for_gemini_flash(mock_google: MagicMock): """Gemini Flash model should route to chat_with_google.""" - mock_google.return_value = ({"role": "assistant", "content": "hi"}, [], 100, 50) + mock_google.return_value = _make_llm_result() chat_with_model( messages=[{"role": "user", "content": "test"}], diff --git a/services/test_llm_result.py b/services/test_llm_result.py new file mode 100644 index 000000000..5418c7566 --- /dev/null +++ b/services/test_llm_result.py @@ -0,0 +1,41 @@ +from services.llm_result import LlmResult, ToolCall + + +def test_tool_call_fields(): + tc = ToolCall(id="call_1", name="my_tool", args={"key": "value"}) + assert tc.id == "call_1" + assert tc.name == "my_tool" + assert tc.args == {"key": "value"} + + +def test_tool_call_none_args(): + tc = ToolCall(id="call_2", name="my_tool", args=None) + assert tc.args is None + + +def test_llm_result_fields(): + result = LlmResult( + assistant_message={"role": "assistant", "content": "hello"}, + tool_calls=[], + token_input=100, + token_output=50, + cost_usd=0.025, + ) + assert result.assistant_message == {"role": "assistant", "content": "hello"} + assert not result.tool_calls + assert result.token_input == 100 + assert result.token_output == 50 + assert result.cost_usd == 0.025 + + +def test_llm_result_with_tool_calls(): + tc = ToolCall(id="call_1", name="read_file", args={"path": "/tmp/foo"}) + result = LlmResult( + assistant_message={"role": "assistant", "content": ""}, + tool_calls=[tc], + token_input=200, + token_output=100, + cost_usd=0.05, + ) + assert len(result.tool_calls) == 1 + assert result.tool_calls[0].name == "read_file" diff --git a/services/webhook/check_suite_handler.py b/services/webhook/check_suite_handler.py index ab3c9d293..6200341ea 100644 --- 
a/services/webhook/check_suite_handler.py +++ b/services/webhook/check_suite_handler.py @@ -10,7 +10,7 @@ # Local imports from config import EMAIL_LINK, GITHUB_APP_USER_NAME, PRODUCT_ID, UTF8 -from constants.agent import MAX_ITERATIONS +from constants.agent import COST_CAP_RATIO, MAX_ITERATIONS from constants.general import MAX_GITAUTO_COMMITS_PER_PR, MAX_INFRA_RETRIES from constants.messages import PERMISSION_DENIED_MESSAGE, CHECK_RUN_FAILED_MESSAGE from services.agents.verify_task_is_complete import verify_task_is_complete @@ -55,6 +55,7 @@ from services.slack.slack_notify import slack_notify from services.supabase.check_suites.insert_check_suite import insert_check_suite from services.supabase.credits.check_purchase_exists import check_purchase_exists +from services.supabase.credits.get_credit_price import get_credit_price from services.supabase.circleci_tokens.get_circleci_token import get_circleci_token from services.supabase.codecov_tokens.get_codecov_token import get_codecov_token from services.supabase.create_user_request import create_user_request @@ -711,12 +712,15 @@ async def handle_check_suite( trigger=trigger, repo_settings=repo_settings, clone_dir=clone_dir ) + cost_cap_usd = get_credit_price(model_id) * COST_CAP_RATIO + for _iteration in range(MAX_ITERATIONS): if should_bail( current_time=current_time, phase="execution", base_args=base_args, slack_thread_ts=thread_ts, + cost_cap_usd=cost_cap_usd, ): break diff --git a/services/webhook/new_pr_handler.py b/services/webhook/new_pr_handler.py index 07ee2018c..6ec6d78ec 100644 --- a/services/webhook/new_pr_handler.py +++ b/services/webhook/new_pr_handler.py @@ -8,7 +8,7 @@ from anthropic.types import MessageParam # Local imports -from constants.agent import MAX_ITERATIONS +from constants.agent import COST_CAP_RATIO, MAX_ITERATIONS from constants.messages import SETTINGS_LINKS from constants.triggers import NewPrTrigger from services.agents.verify_task_is_complete import verify_task_is_complete @@ -49,6 
+49,7 @@ from services.supabase.coverages.get_coverages import get_coverages from services.supabase.create_user_request import create_user_request from services.supabase.credits.check_purchase_exists import check_purchase_exists +from services.supabase.credits.get_credit_price import get_credit_price from services.supabase.credits.insert_credit import insert_credit from services.supabase.email_sends.insert_email_send import insert_email_send from services.supabase.email_sends.update_email_send import update_email_send @@ -569,6 +570,8 @@ async def handle_new_pr( total_token_output = 0 is_completed = False completion_reason = "" + revenue_usd = get_credit_price(model_id) + cost_cap_usd = revenue_usd * COST_CAP_RATIO system_message = create_system_message( trigger=trigger, repo_settings=repo_settings, clone_dir=clone_dir @@ -580,6 +583,7 @@ async def handle_new_pr( phase="pr processing", base_args=base_args, slack_thread_ts=None, + cost_cap_usd=cost_cap_usd, ): break @@ -600,6 +604,7 @@ async def handle_new_pr( p = result.p total_token_input += result.token_input total_token_output += result.token_output + if is_completed: logger.info( "Agent signaled completion via verify_task_is_complete, breaking loop" diff --git a/services/webhook/review_run_handler.py b/services/webhook/review_run_handler.py index ee0d69ea4..dd828421b 100644 --- a/services/webhook/review_run_handler.py +++ b/services/webhook/review_run_handler.py @@ -9,7 +9,7 @@ # Local imports from config import GITHUB_APP_USER_NAME, PRODUCT_ID -from constants.agent import MAX_ITERATIONS +from constants.agent import COST_CAP_RATIO, MAX_ITERATIONS from constants.triggers import ReviewTrigger from services.github.types.webhook.review_run_payload import ReviewRunPayload from services.agents.verify_task_is_complete import verify_task_is_complete @@ -32,6 +32,7 @@ from services.github.comments.update_comment import update_comment from services.slack.slack_notify import slack_notify from 
services.supabase.credits.check_purchase_exists import check_purchase_exists +from services.supabase.credits.get_credit_price import get_credit_price from services.git.create_empty_commit import create_empty_commit from services.git.get_reference import get_reference from services.github.pulls.get_pull_request import get_pull_request @@ -426,12 +427,15 @@ async def handle_review_run( trigger=trigger, repo_settings=repo_settings, clone_dir=clone_dir ) + cost_cap_usd = get_credit_price(model_id) * COST_CAP_RATIO + for _iteration in range(MAX_ITERATIONS): if should_bail( current_time=current_time, phase="execution", base_args=base_args, slack_thread_ts=thread_ts, + cost_cap_usd=cost_cap_usd, ): break diff --git a/services/webhook/test_check_suite_handler.py b/services/webhook/test_check_suite_handler.py index 3127f1ef8..c26b16be0 100644 --- a/services/webhook/test_check_suite_handler.py +++ b/services/webhook/test_check_suite_handler.py @@ -446,6 +446,7 @@ async def test_handle_check_suite_full_workflow( completion_reason="", p=50, is_planned=False, + cost_usd=0.0, ), AgentResult( messages=[], @@ -455,6 +456,7 @@ async def test_handle_check_suite_full_workflow( completion_reason="", p=75, is_planned=False, + cost_usd=0.0, ), ] @@ -1086,6 +1088,7 @@ async def test_check_run_handler_token_accumulation( completion_reason="", p=90, is_planned=False, + cost_usd=0.0, ), AgentResult( messages=[{"role": "user", "content": "test"}], @@ -1095,6 +1098,7 @@ async def test_check_run_handler_token_accumulation( completion_reason="", p=95, is_planned=False, + cost_usd=0.0, ), ] @@ -1347,6 +1351,7 @@ async def test_handle_check_suite_codecov_failure( completion_reason="", p=50, is_planned=False, + cost_usd=0.0, ), AgentResult( messages=[], @@ -1356,6 +1361,7 @@ async def test_handle_check_suite_codecov_failure( completion_reason="", p=75, is_planned=False, + cost_usd=0.0, ), ] @@ -1472,6 +1478,7 @@ async def test_handle_check_suite_codecov_no_token( completion_reason="", p=50, 
is_planned=False, + cost_usd=0.0, ), AgentResult( messages=[], @@ -1481,6 +1488,7 @@ async def test_handle_check_suite_codecov_no_token( completion_reason="", p=75, is_planned=False, + cost_usd=0.0, ), ] @@ -1593,6 +1601,7 @@ async def test_handle_check_suite_max_iterations_forces_verification( completion_reason="", p=50, is_planned=False, + cost_usd=0.0, ), AgentResult( messages=[], @@ -1602,6 +1611,7 @@ async def test_handle_check_suite_max_iterations_forces_verification( completion_reason="", p=75, is_planned=False, + cost_usd=0.0, ), ] diff --git a/services/webhook/test_new_pr_handler.py b/services/webhook/test_new_pr_handler.py index 9c4ca12be..7c0c32511 100644 --- a/services/webhook/test_new_pr_handler.py +++ b/services/webhook/test_new_pr_handler.py @@ -239,6 +239,7 @@ async def test_image_urls_processing( completion_reason="", p=0, is_planned=False, + cost_usd=0.0, ) mock_get_pr_files.return_value = [] @@ -322,6 +323,7 @@ async def test_image_unsupported_format_skipped( completion_reason="", p=0, is_planned=False, + cost_usd=0.0, ) mock_get_pr_files.return_value = [] @@ -403,6 +405,7 @@ async def test_image_base64_fetch_failed( completion_reason="", p=0, is_planned=False, + cost_usd=0.0, ) mock_get_pr_files.return_value = [] @@ -642,6 +645,7 @@ async def test_retry_loop_exhausted_not_explored_but_committed( completion_reason="", p=10, is_planned=False, + cost_usd=0.0, ), AgentResult( messages=[], @@ -651,6 +655,7 @@ async def test_retry_loop_exhausted_not_explored_but_committed( completion_reason="", p=20, is_planned=False, + cost_usd=0.0, ), AgentResult( messages=[], @@ -660,6 +665,7 @@ async def test_retry_loop_exhausted_not_explored_but_committed( completion_reason="", p=30, is_planned=False, + cost_usd=0.0, ), AgentResult( messages=[], @@ -669,6 +675,7 @@ async def test_retry_loop_exhausted_not_explored_but_committed( completion_reason="", p=40, is_planned=False, + cost_usd=0.0, ), AgentResult( messages=[], @@ -678,6 +685,7 @@ async def 
test_retry_loop_exhausted_not_explored_but_committed( completion_reason="", p=50, is_planned=False, + cost_usd=0.0, ), AgentResult( messages=[], @@ -687,6 +695,7 @@ async def test_retry_loop_exhausted_not_explored_but_committed( completion_reason="", p=60, is_planned=False, + cost_usd=0.0, ), AgentResult( messages=[], @@ -696,6 +705,7 @@ async def test_retry_loop_exhausted_not_explored_but_committed( completion_reason="", p=70, is_planned=False, + cost_usd=0.0, ), AgentResult( messages=[], @@ -705,6 +715,7 @@ async def test_retry_loop_exhausted_not_explored_but_committed( completion_reason="", p=80, is_planned=False, + cost_usd=0.0, ), AgentResult( messages=[], @@ -714,6 +725,7 @@ async def test_retry_loop_exhausted_not_explored_but_committed( completion_reason="", p=90, is_planned=False, + cost_usd=0.0, ), AgentResult( messages=[], @@ -723,6 +735,7 @@ async def test_retry_loop_exhausted_not_explored_but_committed( completion_reason="", p=95, is_planned=False, + cost_usd=0.0, ), ] mock_verify_task_is_complete.return_value = { @@ -816,6 +829,7 @@ async def test_retry_loop_exhausted_explored_but_not_committed( completion_reason="", p=10, is_planned=False, + cost_usd=0.0, ), AgentResult( messages=[], @@ -825,6 +839,7 @@ async def test_retry_loop_exhausted_explored_but_not_committed( completion_reason="", p=20, is_planned=False, + cost_usd=0.0, ), AgentResult( messages=[], @@ -834,6 +849,7 @@ async def test_retry_loop_exhausted_explored_but_not_committed( completion_reason="", p=30, is_planned=False, + cost_usd=0.0, ), AgentResult( messages=[], @@ -843,6 +859,7 @@ async def test_retry_loop_exhausted_explored_but_not_committed( completion_reason="", p=40, is_planned=False, + cost_usd=0.0, ), AgentResult( messages=[], @@ -852,6 +869,7 @@ async def test_retry_loop_exhausted_explored_but_not_committed( completion_reason="", p=50, is_planned=False, + cost_usd=0.0, ), AgentResult( messages=[], @@ -861,6 +879,7 @@ async def 
test_retry_loop_exhausted_explored_but_not_committed( completion_reason="", p=60, is_planned=False, + cost_usd=0.0, ), AgentResult( messages=[], @@ -870,6 +889,7 @@ async def test_retry_loop_exhausted_explored_but_not_committed( completion_reason="", p=70, is_planned=False, + cost_usd=0.0, ), AgentResult( messages=[], @@ -879,6 +899,7 @@ async def test_retry_loop_exhausted_explored_but_not_committed( completion_reason="", p=80, is_planned=False, + cost_usd=0.0, ), AgentResult( messages=[], @@ -888,6 +909,7 @@ async def test_retry_loop_exhausted_explored_but_not_committed( completion_reason="", p=90, is_planned=False, + cost_usd=0.0, ), ] mock_verify_task_is_complete.return_value = { @@ -978,6 +1000,7 @@ async def test_retry_counter_reset_on_successful_loop( completion_reason="", p=10, is_planned=False, + cost_usd=0.0, ), AgentResult( messages=[], @@ -987,6 +1010,7 @@ async def test_retry_counter_reset_on_successful_loop( completion_reason="", p=20, is_planned=False, + cost_usd=0.0, ), AgentResult( messages=[], @@ -996,6 +1020,7 @@ async def test_retry_counter_reset_on_successful_loop( completion_reason="", p=100, is_planned=False, + cost_usd=0.0, ), ] @@ -1083,6 +1108,7 @@ async def test_non_test_file_skipped_in_header_merge( completion_reason="", p=50, is_planned=False, + cost_usd=0.0, ) mock_get_pr_files.return_value = [{"filename": "src/main.py"}] mock_is_test_file.return_value = False @@ -1175,6 +1201,7 @@ async def test_test_file_header_merge( completion_reason="", p=50, is_planned=False, + cost_usd=0.0, ) mock_get_pr_files.return_value = [{"filename": "tests/test_example.py"}] mock_is_test_file.return_value = True @@ -1270,6 +1297,7 @@ async def test_test_file_header_merge_no_content( completion_reason="", p=50, is_planned=False, + cost_usd=0.0, ) mock_get_pr_files.return_value = [{"filename": "tests/test_example.py"}] mock_is_test_file.return_value = True @@ -1372,6 +1400,7 @@ async def test_test_file_header_merge_no_change( completion_reason="", p=50, 
is_planned=False, + cost_usd=0.0, ) mock_get_pr_files.return_value = [{"filename": "tests/test_example.py"}] mock_is_test_file.return_value = True @@ -1474,6 +1503,7 @@ async def test_credits_depleted_email_sent( completion_reason="", p=50, is_planned=False, + cost_usd=0.0, ) mock_get_pr_files.return_value = [{"filename": "test.py", "status": "modified"}] mock_get_owner.return_value = {"id": 456, "credit_balance_usd": 0} @@ -1605,6 +1635,7 @@ async def test_new_pr_handler_token_accumulation( completion_reason="", p=90, is_planned=False, + cost_usd=0.0, ), AgentResult( messages=[ @@ -1617,6 +1648,7 @@ async def test_new_pr_handler_token_accumulation( completion_reason="", p=95, is_planned=False, + cost_usd=0.0, ), ] mock_get_pull_request_files.return_value = [ @@ -1770,6 +1802,7 @@ async def test_few_test_files_include_contents_in_prompt( completion_reason="", p=0, is_planned=False, + cost_usd=0.0, ) mock_update_comment.return_value = None mock_update_usage.return_value = None @@ -1916,6 +1949,7 @@ async def test_many_test_files_include_paths_only_in_prompt( completion_reason="", p=0, is_planned=False, + cost_usd=0.0, ) mock_update_comment.return_value = None mock_update_usage.return_value = None diff --git a/services/webhook/test_review_run_handler.py b/services/webhook/test_review_run_handler.py index c163c8776..4a6edcdf0 100644 --- a/services/webhook/test_review_run_handler.py +++ b/services/webhook/test_review_run_handler.py @@ -6,14 +6,14 @@ import pytest from config import PRODUCT_ID - -FIXTURES_DIR = Path(__file__).parent / "fixtures" from services.agents.verify_task_is_complete import VerifyTaskIsCompleteResult from services.agents.verify_task_is_ready import VerifyTaskIsReadyResult from services.chat_with_agent import AgentResult from services.github.pulls.get_review_thread_comments import ReviewThreadResult from services.webhook.review_run_handler import handle_review_run +FIXTURES_DIR = Path(__file__).parent / "fixtures" + @pytest.fixture def 
mock_review_comment_payload(): @@ -163,6 +163,7 @@ async def test_review_run_handler_accumulates_tokens_correctly( completion_reason="", p=40, is_planned=False, + cost_usd=0.0, ), ] @@ -297,6 +298,7 @@ async def test_review_run_handler_max_iterations_forces_verification( completion_reason="", p=40, is_planned=False, + cost_usd=0.0, ), AgentResult( messages=[{"role": "user", "content": "review"}], @@ -306,6 +308,7 @@ async def test_review_run_handler_max_iterations_forces_verification( completion_reason="", p=60, is_planned=False, + cost_usd=0.0, ), ] @@ -542,6 +545,7 @@ async def test_bot_first_review_comment_is_processed( completion_reason="Removed the unused variable as suggested.", p=40, is_planned=False, + cost_usd=0.0, ) await handle_review_run(mock_bot_review_comment_payload, trigger="pr_file_review") @@ -736,6 +740,7 @@ async def test_human_review_comment_always_processed( completion_reason="Fixed the logic as requested.", p=40, is_planned=False, + cost_usd=0.0, ) await handle_review_run(mock_review_comment_payload, trigger="pr_file_review") @@ -874,6 +879,7 @@ async def test_pr_comment_uses_create_comment_not_reply( completion_reason="Completed the task as requested.", p=40, is_planned=False, + cost_usd=0.0, ) await handle_review_run(mock_pr_comment_payload, trigger="pr_comment") @@ -1164,6 +1170,7 @@ async def test_bot_pr_comment_mentioning_pr_file_is_processed( completion_reason="Addressed the lint error.", p=40, is_planned=False, + cost_usd=0.0, ) await handle_review_run(mock_bot_pr_comment_payload, trigger="pr_comment") diff --git a/services/webhook/test_setup_handler.py b/services/webhook/test_setup_handler.py index aeb9ded24..ee00d56c5 100644 --- a/services/webhook/test_setup_handler.py +++ b/services/webhook/test_setup_handler.py @@ -24,6 +24,7 @@ def _make_agent_result(is_completed=False): completion_reason="", p=0, is_planned=False, + cost_usd=0.0, ) diff --git a/services/webhook/test_webhook_handler.py b/services/webhook/test_webhook_handler.py 
index e12c62616..f3aa249a8 100644 --- a/services/webhook/test_webhook_handler.py +++ b/services/webhook/test_webhook_handler.py @@ -231,7 +231,10 @@ async def test_handle_webhook_event_pull_request_labeled_dashboard( payload = { "action": "labeled", "label": {"name": "gitauto"}, - "pull_request": {"number": 42, "head": {"ref": "gitauto/dashboard-20250101-120000-Ab12"}}, + "pull_request": { + "number": 42, + "head": {"ref": "gitauto/dashboard-20250101-120000-Ab12"}, + }, "sender": {"login": "test-user", "id": 12345}, } @@ -252,7 +255,10 @@ async def test_handle_webhook_event_pull_request_labeled_schedule( payload = { "action": "labeled", "label": {"name": "gitauto"}, - "pull_request": {"number": 42, "head": {"ref": "gitauto/schedule-20250101-120000-Ab12"}}, + "pull_request": { + "number": 42, + "head": {"ref": "gitauto/schedule-20250101-120000-Ab12"}, + }, "sender": {"login": "test-user", "id": 12345}, } @@ -273,7 +279,10 @@ async def test_handle_webhook_event_pull_request_labeled_non_gitauto_label_ignor payload = { "action": "labeled", "label": {"name": "dependencies"}, - "pull_request": {"number": 99, "head": {"ref": "dependabot/npm_and_yarn/ajv-6.14.0"}}, + "pull_request": { + "number": 99, + "head": {"ref": "dependabot/npm_and_yarn/ajv-6.14.0"}, + }, "sender": {"login": "dependabot[bot]", "id": 49699333}, } @@ -290,7 +299,10 @@ async def test_handle_webhook_event_pull_request_labeled_bot_sender_ignored( payload = { "action": "labeled", "label": {"name": "gitauto"}, - "pull_request": {"number": 99, "head": {"ref": "dependabot/npm_and_yarn/ajv-6.14.0"}}, + "pull_request": { + "number": 99, + "head": {"ref": "dependabot/npm_and_yarn/ajv-6.14.0"}, + }, "sender": {"login": "dependabot[bot]", "id": 49699333}, } @@ -307,7 +319,10 @@ async def test_handle_webhook_event_pull_request_labeled_gitauto_bot_allowed( payload = { "action": "labeled", "label": {"name": "gitauto"}, - "pull_request": {"number": 42, "head": {"ref": "gitauto/schedule-20250101-120000-Ab12"}}, + 
"pull_request": { + "number": 42, + "head": {"ref": "gitauto/schedule-20250101-120000-Ab12"}, + }, "sender": {"login": "gitauto[bot]", "id": 160085510}, } @@ -395,6 +410,7 @@ async def test_handle_webhook_event_pull_request_closed_non_gitauto_branch(self) payload = { "action": "closed", "pull_request": { + "number": 456, "merged_at": "2023-01-01T00:00:00Z", "head": {"ref": "feature/some-branch"}, }, @@ -529,7 +545,7 @@ async def test_handle_webhook_event_pull_request_review_comment_created( self, mock_handle_review_run ): """Test handling of pull request review comment created event.""" - payload = {"action": "created"} + payload = {"action": "created", "pull_request": {"number": 456}} await handle_webhook_event( event_name="pull_request_review_comment", payload=payload @@ -544,7 +560,7 @@ async def test_handle_webhook_event_pull_request_review_comment_edited( self, mock_handle_review_run ): """Test handling of pull request review comment edited event.""" - payload = {"action": "edited"} + payload = {"action": "edited", "pull_request": {"number": 456}} await handle_webhook_event( event_name="pull_request_review_comment", payload=payload diff --git a/services/webhook/utils/should_bail.py b/services/webhook/utils/should_bail.py index 2a689acb4..85a6a2e18 100644 --- a/services/webhook/utils/should_bail.py +++ b/services/webhook/utils/should_bail.py @@ -3,6 +3,7 @@ from services.github.comments.update_comment import update_comment from services.github.pulls.is_pull_request_open import is_pull_request_open from services.slack.slack_notify import slack_notify +from services.supabase.llm_requests.get_total_cost_for_pr import get_total_cost_for_pr from services.types.base_args import BaseArgs from utils.error.handle_exceptions import handle_exceptions from utils.logging.logging_config import logger @@ -19,6 +20,7 @@ def should_bail( phase: str, base_args: BaseArgs, slack_thread_ts: str | None, + cost_cap_usd: float, ): """Check if the loop should stop. 
Handles logging, comment updates, and slack.""" owner = base_args["owner"] @@ -40,6 +42,18 @@ def should_bail( msg = get_oom_message(used_mb, phase) logger.error(msg) + # Cost cap: bail silently (log only, no comment to user) + if not msg and pr_number: + total_cost = get_total_cost_for_pr(owner, repo, pr_number) + if total_cost >= cost_cap_usd: + logger.warning( + "Cost cap reached: $%.2f >= $%.2f cap. Stopping %s silently.", + total_cost, + cost_cap_usd, + phase, + ) + return True + if not msg: if pr_number and not is_pull_request_open( owner=owner, repo=repo, pr_number=pr_number, token=token diff --git a/services/webhook/utils/test_should_bail.py b/services/webhook/utils/test_should_bail.py index 0f931a85b..7b2a82964 100644 --- a/services/webhook/utils/test_should_bail.py +++ b/services/webhook/utils/test_should_bail.py @@ -3,6 +3,7 @@ from services.webhook.utils.should_bail import should_bail +MOCK_COST = "services.webhook.utils.should_bail.get_total_cost_for_pr" MOCK_OOM_OK = "services.webhook.utils.should_bail.is_lambda_oom_approaching" MOCK_TIMEOUT = "services.webhook.utils.should_bail.is_lambda_timeout_approaching" @@ -10,11 +11,13 @@ @patch("services.webhook.utils.should_bail.update_comment") @patch("services.webhook.utils.should_bail.check_branch_exists", return_value=True) @patch("services.webhook.utils.should_bail.is_pull_request_open", return_value=True) +@patch(MOCK_COST, return_value=0.0) @patch(MOCK_OOM_OK, return_value=(False, 500.0)) @patch(MOCK_TIMEOUT, return_value=(False, 60.0)) def test_returns_false_when_all_checks_pass( _mock_timeout, _mock_oom, + _mock_cost, _mock_pr_open, _mock_branch, _mock_update, @@ -26,6 +29,7 @@ def test_returns_false_when_all_checks_pass( "phase": "execution", "base_args": base_args, "slack_thread_ts": None, + "cost_cap_usd": 6.40, } assert should_bail(**bail_kwargs) is False @@ -42,6 +46,7 @@ def test_returns_true_on_timeout( "phase": "execution", "base_args": base_args, "slack_thread_ts": None, + "cost_cap_usd": 
6.40, } assert should_bail(**bail_kwargs) is True @@ -49,11 +54,13 @@ def test_returns_true_on_timeout( @patch("services.webhook.utils.should_bail.update_comment") @patch("services.webhook.utils.should_bail.check_branch_exists", return_value=True) @patch("services.webhook.utils.should_bail.is_pull_request_open", return_value=False) +@patch(MOCK_COST, return_value=0.0) @patch(MOCK_OOM_OK, return_value=(False, 500.0)) @patch(MOCK_TIMEOUT, return_value=(False, 60.0)) def test_returns_true_when_pr_closed( _mock_timeout, _mock_oom, + _mock_cost, _mock_pr_open, _mock_branch, _mock_update, @@ -65,6 +72,7 @@ def test_returns_true_when_pr_closed( "phase": "execution", "base_args": base_args, "slack_thread_ts": None, + "cost_cap_usd": 6.40, } assert should_bail(**bail_kwargs) is True @@ -72,11 +80,13 @@ def test_returns_true_when_pr_closed( @patch("services.webhook.utils.should_bail.update_comment") @patch("services.webhook.utils.should_bail.check_branch_exists", return_value=False) @patch("services.webhook.utils.should_bail.is_pull_request_open", return_value=True) +@patch(MOCK_COST, return_value=0.0) @patch(MOCK_OOM_OK, return_value=(False, 500.0)) @patch(MOCK_TIMEOUT, return_value=(False, 60.0)) def test_returns_true_when_branch_deleted( _mock_timeout, _mock_oom, + _mock_cost, _mock_pr_open, _mock_branch, _mock_update, @@ -88,6 +98,7 @@ def test_returns_true_when_branch_deleted( "phase": "execution", "base_args": base_args, "slack_thread_ts": None, + "cost_cap_usd": 6.40, } assert should_bail(**bail_kwargs) is True @@ -111,6 +122,7 @@ def test_timeout_checked_first( "phase": "execution", "base_args": base_args, "slack_thread_ts": None, + "cost_cap_usd": 6.40, } assert should_bail(**bail_kwargs) is True mock_pr_open.assert_not_called() @@ -130,6 +142,7 @@ def test_calls_update_comment_on_bail( "phase": "execution", "base_args": base_args, "slack_thread_ts": None, + "cost_cap_usd": 6.40, } should_bail(**bail_kwargs) mock_update.assert_called_once() @@ -148,6 +161,7 @@ def 
test_calls_slack_when_thread_ts_provided( phase="execution", base_args=base_args, slack_thread_ts="ts123", + cost_cap_usd=6.40, ) mock_slack.assert_called_once() @@ -165,6 +179,7 @@ def test_skips_slack_when_thread_ts_is_none( "phase": "execution", "base_args": base_args, "slack_thread_ts": None, + "cost_cap_usd": 6.40, } should_bail(**bail_kwargs) mock_slack.assert_not_called() @@ -185,6 +200,7 @@ def test_skips_pr_check_when_pr_number_not_set( phase="execution", base_args=args_without_pr_number, slack_thread_ts=None, + cost_cap_usd=6.40, ) assert result is False mock_pr_open.assert_not_called() @@ -205,6 +221,7 @@ def test_returns_true_on_oom( "phase": "execution", "base_args": base_args, "slack_thread_ts": None, + "cost_cap_usd": 6.40, } assert should_bail(**bail_kwargs) is True @@ -228,6 +245,7 @@ def test_oom_skips_pr_and_branch_checks( "phase": "execution", "base_args": base_args, "slack_thread_ts": None, + "cost_cap_usd": 6.40, } assert should_bail(**bail_kwargs) is True mock_pr_open.assert_not_called() @@ -246,8 +264,102 @@ def test_timeout_takes_priority_over_oom( "phase": "execution", "base_args": base_args, "slack_thread_ts": None, + "cost_cap_usd": 6.40, } result = should_bail(**bail_kwargs) assert result is True # OOM check should not be called when timeout is already approaching mock_oom.assert_not_called() + + +# --- Cost cap tests --- + + +@patch("services.webhook.utils.should_bail.update_comment") +@patch(MOCK_COST, return_value=7.00) +@patch(MOCK_OOM_OK, return_value=(False, 500.0)) +@patch(MOCK_TIMEOUT, return_value=(False, 60.0)) +def test_returns_true_on_cost_cap( + _mock_timeout, _mock_oom, _mock_cost, mock_update, create_test_base_args +): + base_args = create_test_base_args(pr_number=42, new_branch="feature-branch") + bail_kwargs = { + "current_time": 1000.0, + "phase": "execution", + "base_args": base_args, + "slack_thread_ts": None, + "cost_cap_usd": 6.40, + } + assert should_bail(**bail_kwargs) is True + # Cost cap bails silently — no 
update_comment to customer + mock_update.assert_not_called() + + +@patch("services.webhook.utils.should_bail.slack_notify") +@patch("services.webhook.utils.should_bail.update_comment") +@patch(MOCK_COST, return_value=7.00) +@patch(MOCK_OOM_OK, return_value=(False, 500.0)) +@patch(MOCK_TIMEOUT, return_value=(False, 60.0)) +def test_cost_cap_skips_slack( + _mock_timeout, + _mock_oom, + _mock_cost, + _mock_update, + mock_slack, + create_test_base_args, +): + base_args = create_test_base_args(pr_number=42, new_branch="feature-branch") + should_bail( + current_time=1000.0, + phase="execution", + base_args=base_args, + slack_thread_ts="ts123", + cost_cap_usd=6.40, + ) + # Cost cap bails silently — no slack notification + mock_slack.assert_not_called() + + +@patch("services.webhook.utils.should_bail.check_branch_exists") +@patch("services.webhook.utils.should_bail.is_pull_request_open") +@patch(MOCK_COST, return_value=7.00) +@patch(MOCK_OOM_OK, return_value=(False, 500.0)) +@patch(MOCK_TIMEOUT, return_value=(False, 60.0)) +def test_cost_cap_skips_pr_and_branch_checks( + _mock_timeout, + _mock_oom, + _mock_cost, + mock_pr_open, + mock_branch, + create_test_base_args, +): + base_args = create_test_base_args(pr_number=42, new_branch="feature-branch") + bail_kwargs = { + "current_time": 1000.0, + "phase": "execution", + "base_args": base_args, + "slack_thread_ts": None, + "cost_cap_usd": 6.40, + } + assert should_bail(**bail_kwargs) is True + mock_pr_open.assert_not_called() + mock_branch.assert_not_called() + + +@patch(MOCK_COST) +@patch(MOCK_OOM_OK, return_value=(False, 500.0)) +@patch(MOCK_TIMEOUT, return_value=(False, 60.0)) +def test_cost_cap_skipped_when_no_pr_number( + _mock_timeout, _mock_oom, mock_cost, create_test_base_args +): + args_without_pr = create_test_base_args(new_branch="feature-branch") + del args_without_pr["pr_number"] + should_bail( + current_time=1000.0, + phase="execution", + base_args=args_without_pr, + slack_thread_ts=None, + cost_cap_usd=6.40, + ) 
+ # No pr_number → cost check not reached + mock_cost.assert_not_called()