Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions constants/agent.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
# Stop agent loop when LLM cost reaches this fraction of revenue (credit_cost_usd)
COST_CAP_RATIO = 0.8
MAX_ITERATIONS = 30
MAX_PLANNING_ITERATIONS = 20
9 changes: 2 additions & 7 deletions scripts/git/pre_commit_hook.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/bin/bash
# Git pre-commit hook. No stashing - runs on the working directory as-is.
# Git pre-commit hook. Does NOT auto-stage — user handles staging.
# Install: ln -sf ../../scripts/git/pre_commit_hook.sh .git/hooks/pre-commit
set -uo pipefail

Expand Down Expand Up @@ -35,17 +35,12 @@ python3 schemas/supabase/generate_types.py && git add schemas/supabase/
# Get staged Python files (excluding deleted, .venv, schemas)
STAGED_PY_FILES=$(git diff --cached --name-only --diff-filter=d -- '*.py' | grep -v '^\.\?venv/' | grep -v '^schemas/')

# Format and auto-fix staged Python files
# Format staged Python files (user re-stages if needed)
if [ -n "$STAGED_PY_FILES" ]; then
# shellcheck disable=SC2086
black $STAGED_PY_FILES
# shellcheck disable=SC2086
git add $STAGED_PY_FILES

# shellcheck disable=SC2086
ruff check --fix $STAGED_PY_FILES
# shellcheck disable=SC2086
git add $STAGED_PY_FILES
fi

# Markdownlint for staged .md files
Expand Down
38 changes: 18 additions & 20 deletions services/chat_with_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@


@dataclass
class AgentResult:
class AgentResult: # pylint: disable=too-many-instance-attributes
messages: list[MessageParam]
token_input: int
token_output: int
cost_usd: float
is_completed: bool
completion_reason: str
p: int
Expand Down Expand Up @@ -67,12 +68,7 @@ async def chat_with_agent(
logger.info("Using model: %s", current_model)

try:
(
response_message,
tool_calls,
token_input,
token_output,
) = chat_with_model(
llm_result = chat_with_model(
messages=messages,
system_content=system_message,
tools=tools,
Expand Down Expand Up @@ -125,24 +121,25 @@ async def chat_with_agent(
)

# Return if no tool calls (agent returned text without calling a tool)
if not tool_calls:
logger.info("No tools were called. Response: %s", response_message)
messages.append(response_message)
if not llm_result.tool_calls:
logger.info("No tools were called. Response: %s", llm_result.assistant_message)
messages.append(llm_result.assistant_message)
return AgentResult(
messages=messages,
token_input=token_input,
token_output=token_output,
token_input=llm_result.token_input,
token_output=llm_result.token_output,
cost_usd=llm_result.cost_usd,
is_completed=False,
completion_reason="",
p=p,
is_planned=False,
)

# Append assistant message before processing tool calls
messages.append(response_message)
messages.append(llm_result.assistant_message)

# Extract text from the assistant message for completion context
content = response_message["content"]
content = llm_result.assistant_message["content"]
if isinstance(content, str):
assistant_text = content
else:
Expand All @@ -156,11 +153,11 @@ async def chat_with_agent(
tool_result_blocks: list[ToolResultBlockParam] = []
log_msgs: list[str] = []
is_completed = False
num_tool_calls = len(tool_calls)
num_tool_calls = len(llm_result.tool_calls)
logger.info("Processing %d tool call(s)", num_tool_calls)

# pylint: disable-next=too-many-nested-blocks
for i, tc in enumerate(tool_calls, start=1):
for i, tc in enumerate(llm_result.tool_calls, start=1):
tool_use_id = tc.id
tool_name = tc.name
tool_args = tc.args
Expand Down Expand Up @@ -505,7 +502,7 @@ async def chat_with_agent(
if log_msgs:
update_comment(
body=create_progress_bar(
p=p + 5 * len(tool_calls), msg="\n".join(log_messages)
p=p + 5 * len(llm_result.tool_calls), msg="\n".join(log_messages)
),
base_args=base_args,
)
Expand All @@ -520,10 +517,11 @@ async def chat_with_agent(

return AgentResult(
messages=messages,
token_input=token_input,
token_output=token_output,
token_input=llm_result.token_input,
token_output=llm_result.token_output,
cost_usd=llm_result.cost_usd,
is_completed=is_completed,
completion_reason=assistant_text,
p=p + 5 * len(tool_calls),
p=p + 5 * len(llm_result.tool_calls),
is_planned=False,
)
30 changes: 11 additions & 19 deletions services/claude/chat_with_claude.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,26 @@
# Standard imports
from dataclasses import dataclass
import time

# Third party imports
from anthropic import AuthenticationError
from anthropic._exceptions import OverloadedError
from anthropic.types import MessageParam, ToolUnionParam, ToolUseBlock

# Local imports
from constants.claude import CONTEXT_WINDOW, MAX_OUTPUT_TOKENS
from constants.models import ClaudeModelId
from services.claude.client import claude
from services.claude.strip_strict_from_tools import strip_strict_from_tools
from services.claude.exceptions import (
ClaudeAuthenticationError,
ClaudeOverloadedError,
)
from services.claude.remove_outdated_file_edit_attempts import (
remove_outdated_file_edit_attempts,
)
from services.claude.strip_strict_from_tools import strip_strict_from_tools
from services.claude.trim_messages import trim_messages_to_token_limit
from services.llm_result import LlmResult, ToolCall
from services.supabase.llm_requests.insert_llm_request import insert_llm_request
from utils.error.handle_exceptions import handle_exceptions


@dataclass
class ToolCall:
id: str
name: str
args: dict | None


@handle_exceptions(raise_on_error=True)
def chat_with_claude(
messages: list[MessageParam],
Expand Down Expand Up @@ -126,7 +116,7 @@ def chat_with_claude(
# Combine system message with user messages for logging
system_msg: MessageParam = {"role": "user", "content": system_content}
full_messages = [system_msg, *messages]
insert_llm_request(
llm_record = insert_llm_request(
usage_id=usage_id,
provider="claude",
model_id=model_id,
Expand All @@ -137,10 +127,12 @@ def chat_with_claude(
response_time_ms=response_time_ms,
created_by=created_by,
)

return (
assistant_message,
tool_calls,
token_input,
token_output,
cost_usd = llm_record["total_cost_usd"] if llm_record else 0.0

return LlmResult(
assistant_message=assistant_message,
tool_calls=tool_calls,
token_input=token_input,
token_output=token_output,
cost_usd=cost_usd,
)
25 changes: 12 additions & 13 deletions services/claude/test_chat_with_claude.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def test_chat_with_claude_success(mock_claude, mock_insert_llm_request):
mock_response = Mock()
mock_response.content = [Mock(type="text", text="Hello! How can I help you?")]
mock_response.usage = Mock(output_tokens=15)
mock_insert_llm_request.return_value = {"total_cost_usd": 0.05}

mock_claude.messages.create.return_value = mock_response
mock_claude.messages.count_tokens.return_value = Mock(input_tokens=20)
Expand All @@ -31,13 +32,13 @@ def test_chat_with_claude_success(mock_claude, mock_insert_llm_request):
created_by="4:test-user",
)

assistant_message, tool_calls, token_input, token_output = result
assert assistant_message["role"] == "assistant"
content = cast(list, assistant_message["content"])
assert result.assistant_message["role"] == "assistant"
content = cast(list, result.assistant_message["content"])
assert content[0]["text"] == "Hello! How can I help you?"
assert not tool_calls
assert token_input == 20
assert token_output == 15
assert not result.tool_calls
assert result.token_input == 20
assert result.token_output == 15
assert result.cost_usd == 0.05

mock_insert_llm_request.assert_called_once()
call_args = mock_insert_llm_request.call_args[1]
Expand Down Expand Up @@ -88,11 +89,10 @@ def test_chat_with_claude_with_tool_use(mock_claude, mock_insert_llm_request):
created_by="4:test-user",
)

_, tool_calls, _, _ = result
assert len(tool_calls) == 1
assert tool_calls[0].id == "tool_123"
assert tool_calls[0].name == "test_function"
assert tool_calls[0].args == {"param": "value"}
assert len(result.tool_calls) == 1
assert result.tool_calls[0].id == "tool_123"
assert result.tool_calls[0].name == "test_function"
assert result.tool_calls[0].args == {"param": "value"}

mock_insert_llm_request.assert_called_once()

Expand All @@ -116,8 +116,7 @@ def test_chat_with_claude_no_usage_response(mock_claude, mock_insert_llm_request
created_by="4:test-user",
)

_, _, _, token_output = result
assert token_output == 0 # output tokens should be 0 when no usage info
assert result.token_output == 0 # output tokens should be 0 when no usage info
mock_insert_llm_request.assert_called_once()


Expand Down
23 changes: 9 additions & 14 deletions services/google_ai/chat_with_google.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Standard imports
import time
import uuid
from dataclasses import dataclass

# Third-party imports
from anthropic.types import MessageParam, ToolUnionParam
Expand All @@ -12,18 +11,12 @@
from services.google_ai.client import get_google_ai_client
from services.google_ai.convert_messages import convert_messages_to_google
from services.google_ai.convert_tools import convert_tools_to_google
from services.llm_result import LlmResult, ToolCall
from services.supabase.llm_requests.insert_llm_request import insert_llm_request
from utils.error.handle_exceptions import handle_exceptions
from utils.logging.logging_config import logger


@dataclass
class ToolCall:
id: str
name: str
args: dict | None


@handle_exceptions(raise_on_error=True)
def chat_with_google(
messages: list[MessageParam],
Expand Down Expand Up @@ -105,7 +98,7 @@ def chat_with_google(
# Log to Supabase
system_msg: MessageParam = {"role": "user", "content": system_content}
full_messages = [system_msg, *messages]
insert_llm_request(
llm_record = insert_llm_request(
usage_id=usage_id,
provider="google",
model_id=model_id,
Expand All @@ -116,6 +109,7 @@ def chat_with_google(
response_time_ms=response_time_ms,
created_by=created_by,
)
cost_usd = llm_record["total_cost_usd"] if llm_record else 0.0

logger.info(
"Google AI response: model=%s, input_tokens=%d, output_tokens=%d, tool_calls=%d",
Expand All @@ -125,9 +119,10 @@ def chat_with_google(
len(tool_calls),
)

return (
assistant_message,
tool_calls,
token_input,
token_output,
return LlmResult(
assistant_message=assistant_message,
tool_calls=tool_calls,
token_input=token_input,
token_output=token_output,
cost_usd=cost_usd,
)
Loading