Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions sentry_sdk/ai/span_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import sentry_sdk
from sentry_sdk.consts import SPANDATA
from sentry_sdk.ai.monitoring import record_token_usage
from sentry_sdk.ai.utils import (
get_first_from_sources,
set_data_normalized,
set_span_data_from_sources,
normalize_message_roles,
truncate_and_annotate_messages,
)
from sentry_sdk.scope import should_send_default_pii

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Any, Dict, List, Optional

from sentry_sdk.tracing import Span


def set_request_span_data(span, kwargs, integration, config, span_data=None):
# type: (Span, Dict[str, Any], Any, Dict[str, Any], Dict[str, Any] | None) -> None
"""Set request/static span data from a declarative config."""
for key, value in config.get("static", {}).items():
set_data_normalized(span, key, value)
if span_data:
for key, value in span_data.items():
set_data_normalized(span, key, value)

for kwarg_key, span_attr in config.get("params", {}).items():
if kwarg_key in kwargs:
value = kwargs[kwarg_key]
set_data_normalized(span, span_attr, value)

if should_send_default_pii() and integration.include_prompts:
for kwarg_key, span_attr in config.get("pii_params", {}).items():
if kwarg_key in kwargs:
value = kwargs[kwarg_key]
set_data_normalized(span, span_attr, value)


def set_request_messages(span, messages, target=None):
# type: (Span, Any, Optional[str]) -> None
"""Normalize, truncate, and set request messages on the span.

Caller is responsible for PII gating.
"""
if not messages:
return
messages = normalize_message_roles(messages)
scope = sentry_sdk.get_current_scope()
messages = truncate_and_annotate_messages(messages, span, scope)
if messages is not None:
set_data_normalized(
span, target or SPANDATA.GEN_AI_REQUEST_MESSAGES, messages, unpack=False
)


def set_response_span_data(
span, response, include_pii, response_config, response_text=None
):
# type: (Span, Any, bool, Dict[str, Any], Optional[List[str]]) -> None
"""Set response span data from a declarative config."""
set_span_data_from_sources(
span, response, response_config.get("sources", {}), require_truthy=False
)

if include_pii:
pii_sources = response_config.get("pii_sources")
if pii_sources:
set_span_data_from_sources(span, response, pii_sources, require_truthy=True)
if response_text:
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, response_text)

usage_config = response_config.get("usage")
if usage_config:
record_token_usage(
span,
input_tokens=get_first_from_sources(
response, usage_config.get("input_tokens", [])
),
output_tokens=get_first_from_sources(
response, usage_config.get("output_tokens", [])
),
total_tokens=get_first_from_sources(
response, usage_config.get("total_tokens", [])
),
)
33 changes: 31 additions & 2 deletions sentry_sdk/ai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from sentry_sdk._types import BLOB_DATA_SUBSTITUTE

if TYPE_CHECKING:
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple

from sentry_sdk.tracing import Span

Expand All @@ -30,7 +30,7 @@ class GEN_AI_ALLOWED_MESSAGE_ROLES:
GEN_AI_MESSAGE_ROLE_REVERSE_MAPPING = {
GEN_AI_ALLOWED_MESSAGE_ROLES.SYSTEM: ["system"],
GEN_AI_ALLOWED_MESSAGE_ROLES.USER: ["user", "human"],
GEN_AI_ALLOWED_MESSAGE_ROLES.ASSISTANT: ["assistant", "ai"],
GEN_AI_ALLOWED_MESSAGE_ROLES.ASSISTANT: ["assistant", "ai", "chatbot"],
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is needed because cohere sometimes uses the role chatbot in their message structure

GEN_AI_ALLOWED_MESSAGE_ROLES.TOOL: ["tool", "tool_call"],
}

Expand Down Expand Up @@ -725,3 +725,32 @@ def set_conversation_id(conversation_id: str) -> None:
"""
scope = sentry_sdk.get_current_scope()
scope.set_conversation_id(conversation_id)


def transitive_getattr(obj, *attrs):
# type: (Any, str) -> Any
current = obj
for attr in attrs:
current = getattr(current, attr, None)
if current is None:
return None
return current


def get_first_from_sources(obj, source_paths, require_truthy=False):
# type: (Any, Sequence[tuple[str, ...]], bool) -> Any
for source_path in source_paths:
value = transitive_getattr(obj, *source_path)
if not value:
continue
if not require_truthy or value:
return value
return None


def set_span_data_from_sources(span, obj, target_sources, require_truthy):
# type: (Any, Any, Mapping[str, Sequence[tuple[str, ...]]], bool) -> None
for spandata_key, source_paths in target_sources.items():
value = get_first_from_sources(obj, source_paths, require_truthy=require_truthy)
if value is not None:
set_data_normalized(span, spandata_key, value)
Loading
Loading