Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion pyrit/backend/mappers/attack_mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,11 @@ def attack_result_to_summary(
"""
message_count = stats.message_count
last_preview = stats.last_message_preview
labels = dict(stats.labels) if stats.labels else {}

# Merge attack-result labels with conversation-level labels.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Confused about this comment. Aren't attack-result labels conversation-level?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

today, the labels on "attack summary" (a GUI-related concept) come from "conversation stats" (another GUI-related concept) ... conversation stats are an aggregation of message pieces, along with their labels, which in turn end up on the attack summary.

This change here is saying "attack summary's labels must come not only from conversation stats (i.e. message pieces) but also from attack results" ...

Once we remove labels from message_piece, this will also be removed and the only kind of label that ends up on attack summary will be the attack result ones.

This is making me think, also related to the other comment, that maybe it's just a good idea to do both at the same time, instead of having an overlap period where both objects (attack results and message pieces) have labels — it will avoid confusions like this. I think I'll do this after introducing a DB migration solution in #1631.

# Conversation labels take precedence on key collision.
labels = dict(ar.labels) if ar.labels else {}
labels.update(stats.labels or {})
created_str = ar.metadata.get("created_at")
updated_str = ar.metadata.get("updated_at")
created_at = datetime.fromisoformat(created_str) if created_str else datetime.now(timezone.utc)
Expand Down
1 change: 1 addition & 0 deletions pyrit/backend/services/attack_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ async def create_attack_async(self, *, request: CreateAttackRequest) -> CreateAt
"created_at": now.isoformat(),
"updated_at": now.isoformat(),
},
labels=labels,
)

# Store in memory
Expand Down
1 change: 1 addition & 0 deletions pyrit/executor/attack/multi_turn/chunked_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ async def _perform_async(self, *, context: ChunkedRequestAttackContext) -> Attac
outcome_reason=outcome_reason,
executed_turns=context.executed_turns,
metadata={"combined_chunks": combined_value, "chunk_count": len(context.chunk_responses)},
labels=context.memory_labels,
)

def _determine_attack_outcome(
Expand Down
1 change: 1 addition & 0 deletions pyrit/executor/attack/multi_turn/crescendo.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ async def _perform_async(self, *, context: CrescendoAttackContext) -> CrescendoA
last_response=context.last_response.get_piece() if context.last_response else None,
last_score=context.last_score,
related_conversations=context.related_conversations,
labels=context.memory_labels,
)
# setting metadata for backtrack count
result.backtrack_count = context.backtrack_count
Expand Down
1 change: 1 addition & 0 deletions pyrit/executor/attack/multi_turn/multi_prompt_sending.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ async def _perform_async(self, *, context: MultiTurnAttackContext[Any]) -> Attac
outcome=outcome,
outcome_reason=outcome_reason,
executed_turns=context.executed_turns,
labels=context.memory_labels,
)

def _determine_attack_outcome(
Expand Down
1 change: 1 addition & 0 deletions pyrit/executor/attack/multi_turn/red_teaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ async def _perform_async(self, *, context: MultiTurnAttackContext[Any]) -> Attac
last_response=context.last_response.get_piece() if context.last_response else None,
last_score=context.last_score,
related_conversations=context.related_conversations,
labels=context.memory_labels,
)

async def _teardown_async(self, *, context: MultiTurnAttackContext[Any]) -> None:
Expand Down
1 change: 1 addition & 0 deletions pyrit/executor/attack/multi_turn/tree_of_attacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2084,6 +2084,7 @@ def _create_attack_result(
last_response=last_response,
last_score=context.best_objective_score,
related_conversations=context.related_conversations,
labels=context.memory_labels,
)

# Set attack-specific metadata using properties
Expand Down
1 change: 1 addition & 0 deletions pyrit/executor/attack/single_turn/prompt_sending.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ async def _perform_async(self, *, context: SingleTurnAttackContext[Any]) -> Atta
outcome=outcome,
outcome_reason=outcome_reason,
executed_turns=1,
labels=context.memory_labels,
)

def _determine_attack_outcome(
Expand Down
1 change: 1 addition & 0 deletions pyrit/executor/attack/single_turn/skeleton_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,4 +181,5 @@ def _create_skeleton_key_failure_result(self, *, context: SingleTurnAttackContex
outcome=AttackOutcome.FAILURE,
outcome_reason="Skeleton key prompt was filtered or failed",
executed_turns=1,
labels=context.memory_labels,
)
1 change: 1 addition & 0 deletions pyrit/executor/benchmark/fairness_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ async def _perform_async(self, *, context: FairnessBiasBenchmarkContext) -> Atta
atomic_attack_identifier=build_atomic_attack_identifier(
attack_identifier=ComponentIdentifier.of(self),
),
labels=context.memory_labels,
)

return last_attack_result
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
add labels to Attack Results.

Revision ID: 108a72344872
Revises: e373726d391b
Create Date: 2026-04-27 13:47:20.711347
"""

from collections.abc import Sequence

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
# Unique ID of this migration.
revision: str = "108a72344872"
# The migration this one applies on top of.
down_revision: str | None = "e373726d391b"
# No named branches or explicit dependencies for this revision.
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
    """Add the nullable ``labels`` JSON column to ``AttackResultEntries``.

    The column is nullable so that pre-existing rows require no backfill.
    """
    op.add_column("AttackResultEntries", sa.Column("labels", sa.JSON(), nullable=True))


def downgrade() -> None:
    """Remove the ``labels`` column from ``AttackResultEntries``.

    Reverses :func:`upgrade`; any stored label data is discarded.
    """
    op.drop_column("AttackResultEntries", "labels")
118 changes: 91 additions & 27 deletions pyrit/memory/azure_sql_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, Union, cast

from sqlalchemy import and_, create_engine, event, exists, text
from sqlalchemy import and_, create_engine, event, exists, or_, text
from sqlalchemy.engine.base import Engine
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import InstrumentedAttribute, joinedload, sessionmaker
Expand Down Expand Up @@ -224,27 +224,61 @@ def _add_embeddings_to_memory(self, *, embedding_data: Sequence[EmbeddingDataEnt
"""
self._insert_entries(entries=embedding_data)

def _get_message_pieces_memory_label_conditions(self, *, memory_labels: dict[str, str]) -> list[TextClause]:
def _get_message_pieces_memory_label_conditions(self, *, memory_labels: dict[str, str]) -> list[Any]:
"""
Generate SQL conditions for filtering message pieces by memory labels.

Uses JSON_VALUE() function specific to SQL Azure to query label fields in JSON format.

Matches if labels are on the PromptMemoryEntry itself OR on any
AttackResultEntry that shares the same conversation_id.

Args:
memory_labels (dict[str, str]): Dictionary of label key-value pairs to filter by.

Returns:
list: List containing a single SQLAlchemy text condition with bound parameters.
"""
json_validation = "ISJSON(labels) = 1"
json_conditions = " AND ".join([f"JSON_VALUE(labels, '$.{key}') = :{key}" for key in memory_labels])
# Combine both conditions
conditions = f"{json_validation} AND {json_conditions}"
list: List containing a single SQLAlchemy OR condition with bound parameters.
"""
# Build conditions for direct PME label match
pme_label_parts: list[str] = []
pme_bindparams: dict[str, str] = {}
# Build conditions for AR label match (via exists subquery)
are_label_parts: list[str] = []
are_bindparams: dict[str, str] = {}

for key, value in memory_labels.items():
pme_param = f"pme_ml_{key}"
pme_label_parts.append(f"JSON_VALUE(labels, '$.{key}') = :{pme_param}")
pme_bindparams[pme_param] = str(value)

are_param = f"are_ml_{key}"
are_label_parts.append(f"JSON_VALUE(\"AttackResultEntries\".labels, '$.{key}') = :{are_param}")
are_bindparams[are_param] = str(value)

# Direct PME label match
combined_pme = " AND ".join(pme_label_parts)
pme_match = and_(
PromptMemoryEntry.labels.isnot(None),
cast(
"ColumnElement[bool]",
text(f"ISJSON(labels) = 1 AND {combined_pme}").bindparams(**pme_bindparams),
),
)

# Create SQL condition using SQLAlchemy's text() with bindparams
# for safe parameter passing, preventing SQL injection
condition = text(conditions).bindparams(**{key: str(value) for key, value in memory_labels.items()})
return [condition]
# AR label match via exists subquery
combined_are = " AND ".join(are_label_parts)
are_match = exists().where(
and_(
AttackResultEntry.conversation_id == PromptMemoryEntry.conversation_id,
AttackResultEntry.labels.isnot(None),
cast(
"ColumnElement[bool]",
text(f'ISJSON("AttackResultEntries".labels) = 1 AND {combined_are}').bindparams(**are_bindparams),
),
)
)

return [or_(pme_match, are_match)]

def _get_metadata_conditions(self, *, prompt_metadata: dict[str, Union[str, int]]) -> list[TextClause]:
"""
Expand Down Expand Up @@ -450,36 +484,66 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str | Sequence
"""
Azure SQL implementation for filtering AttackResults by labels.

Matches if labels are on any associated PromptMemoryEntry OR directly
on the AttackResultEntry itself.

Uses JSON_VALUE() with parameterized IN clauses. See
``MemoryInterface._get_attack_result_label_condition`` for semantics.

Returns:
Any: SQLAlchemy exists subquery condition with bound parameters.
Any: SQLAlchemy condition with bound parameters.
"""
label_conditions: list[str] = []
bindparams_dict: dict[str, str] = {}
# Build conditions for PromptMemoryEntry labels (via exists subquery)
pme_label_conditions: list[str] = []
pme_bindparams: dict[str, str] = {}
# Build conditions for AttackResultEntry labels (direct match)
are_label_conditions: list[str] = []
are_bindparams: dict[str, str] = {}

for key, raw_value in labels.items():
values = [raw_value] if isinstance(raw_value, str) else list(raw_value)
if not values:
continue
placeholders = []
pme_placeholders = []
are_placeholders = []
for idx, v in enumerate(values):
param_name = f"label_{key}_{idx}"
placeholders.append(f":{param_name}")
bindparams_dict[param_name] = str(v)
label_conditions.append(f"JSON_VALUE(labels, '$.{key}') IN ({', '.join(placeholders)})")

base = [
pme_param = f"pme_label_{key}_{idx}"
pme_placeholders.append(f":{pme_param}")
pme_bindparams[pme_param] = str(v)
are_param = f"are_label_{key}_{idx}"
are_placeholders.append(f":{are_param}")
are_bindparams[are_param] = str(v)
pme_label_conditions.append(f"JSON_VALUE(labels, '$.{key}') IN ({', '.join(pme_placeholders)})")
are_label_conditions.append(f"JSON_VALUE(labels, '$.{key}') IN ({', '.join(are_placeholders)})")

# PromptMemoryEntry subquery
pme_base: list[Any] = [
PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id,
PromptMemoryEntry.labels.isnot(None),
]
if label_conditions:
combined = " AND ".join(label_conditions)
base.append(
cast("ColumnElement[bool]", text(f"ISJSON(labels) = 1 AND {combined}").bindparams(**bindparams_dict))
if pme_label_conditions:
combined_pme = " AND ".join(pme_label_conditions)
pme_base.append(
cast(
"ColumnElement[bool]",
text(f"ISJSON(labels) = 1 AND {combined_pme}").bindparams(**pme_bindparams),
)
)
pme_match = exists().where(and_(*pme_base))

# Direct AttackResultEntry label match
are_parts: list[Any] = [AttackResultEntry.labels.isnot(None)]
if are_label_conditions:
combined_are = " AND ".join(are_label_conditions)
are_parts.append(
cast(
"ColumnElement[bool]",
text(f"ISJSON(labels) = 1 AND {combined_are}").bindparams(**are_bindparams),
)
)
are_match = and_(*are_parts)

return exists().where(and_(*base))
return or_(pme_match, are_match)

def get_unique_attack_class_names(self) -> list[str]:
"""
Expand Down
7 changes: 5 additions & 2 deletions pyrit/memory/memory_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,8 +434,11 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories
@abc.abstractmethod
def _get_attack_result_label_condition(self, *, labels: dict[str, str | Sequence[str]]) -> Any:
"""
Return a database-specific condition for filtering AttackResults by labels
in the associated PromptMemoryEntry records.
Return a database-specific condition for filtering AttackResults by labels.
Matches if the labels are present on **either** an associated
PromptMemoryEntry (via conversation_id) **or** directly on the
AttackResultEntry itself.
Semantics: entries are AND-combined across label names; within a single
entry, a string value is an equality match and a sequence value is an
Expand Down
4 changes: 4 additions & 0 deletions pyrit/memory/memory_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,7 @@ class AttackResultEntry(Base):
outcome (AttackOutcome): The outcome of the attack, indicating success, failure, or undetermined.
outcome_reason (str): Optional reason for the outcome, providing additional context.
attack_metadata (dict[str, Any]): Metadata can be included as key-value pairs to provide extra context.
labels (dict[str, str]): Optional labels associated with the attack result entry.
pruned_conversation_ids (List[str]): List of conversation IDs that were pruned from the attack.
adversarial_chat_conversation_ids (List[str]): List of conversation IDs used for adversarial chat.
timestamp (DateTime): The timestamp of the attack result entry.
Expand Down Expand Up @@ -728,6 +729,7 @@ class AttackResultEntry(Base):
)
outcome_reason = mapped_column(String, nullable=True)
attack_metadata: Mapped[dict[str, Union[str, int, float, bool]]] = mapped_column(JSON, nullable=True)
labels: Mapped[dict[str, str]] = mapped_column(JSON, nullable=True)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason for adding on AR and NOT removing them from MP?

pruned_conversation_ids: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True)
adversarial_chat_conversation_ids: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True)
timestamp = mapped_column(DateTime, nullable=False)
Expand Down Expand Up @@ -783,6 +785,7 @@ def __init__(self, *, entry: AttackResult):
self.outcome = entry.outcome.value
self.outcome_reason = entry.outcome_reason
self.attack_metadata = self.filter_json_serializable_metadata(entry.metadata)
self.labels = entry.labels or {}

# Persist conversation references by type
self.pruned_conversation_ids = [
Expand Down Expand Up @@ -894,6 +897,7 @@ def get_attack_result(self) -> AttackResult:
outcome_reason=self.outcome_reason,
related_conversations=related_conversations,
metadata=self.attack_metadata or {},
labels=self.labels or {},
)


Expand Down
Loading
Loading