diff --git a/pyrit/backend/mappers/attack_mappers.py b/pyrit/backend/mappers/attack_mappers.py index 0245e2af12..c37dd77fd9 100644 --- a/pyrit/backend/mappers/attack_mappers.py +++ b/pyrit/backend/mappers/attack_mappers.py @@ -197,8 +197,11 @@ def attack_result_to_summary( """ message_count = stats.message_count last_preview = stats.last_message_preview - labels = dict(stats.labels) if stats.labels else {} + # Merge attack-result labels with conversation-level labels. + # Conversation labels take precedence on key collision. + labels = dict(ar.labels) if ar.labels else {} + labels.update(stats.labels or {}) created_str = ar.metadata.get("created_at") updated_str = ar.metadata.get("updated_at") created_at = datetime.fromisoformat(created_str) if created_str else datetime.now(timezone.utc) diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index 41b98739aa..63b8123d28 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -330,6 +330,7 @@ async def create_attack_async(self, *, request: CreateAttackRequest) -> CreateAt "created_at": now.isoformat(), "updated_at": now.isoformat(), }, + labels=labels, ) # Store in memory diff --git a/pyrit/executor/attack/multi_turn/chunked_request.py b/pyrit/executor/attack/multi_turn/chunked_request.py index 1a70c89195..ed95c5d226 100644 --- a/pyrit/executor/attack/multi_turn/chunked_request.py +++ b/pyrit/executor/attack/multi_turn/chunked_request.py @@ -325,6 +325,7 @@ async def _perform_async(self, *, context: ChunkedRequestAttackContext) -> Attac outcome_reason=outcome_reason, executed_turns=context.executed_turns, metadata={"combined_chunks": combined_value, "chunk_count": len(context.chunk_responses)}, + labels=context.memory_labels, ) def _determine_attack_outcome( diff --git a/pyrit/executor/attack/multi_turn/crescendo.py b/pyrit/executor/attack/multi_turn/crescendo.py index 4a180d5df3..f137b322f3 100644 --- a/pyrit/executor/attack/multi_turn/crescendo.py +++ b/pyrit/executor/attack/multi_turn/crescendo.py @@ -402,6 +402,7 @@ async def _perform_async(self, *, context: CrescendoAttackContext) -> CrescendoA last_response=context.last_response.get_piece() if context.last_response else None, last_score=context.last_score, related_conversations=context.related_conversations, + labels=context.memory_labels, ) # setting metadata for backtrack count result.backtrack_count = context.backtrack_count diff --git a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py index a9d4b75adc..8447737578 100644 --- a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py +++ b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py @@ -295,6 +295,7 @@ async def _perform_async(self, *, context: MultiTurnAttackContext[Any]) -> Attac outcome=outcome, outcome_reason=outcome_reason, executed_turns=context.executed_turns, + labels=context.memory_labels, ) def _determine_attack_outcome( diff --git a/pyrit/executor/attack/multi_turn/red_teaming.py b/pyrit/executor/attack/multi_turn/red_teaming.py index a8778f664a..1feec20586 100644 --- a/pyrit/executor/attack/multi_turn/red_teaming.py +++ b/pyrit/executor/attack/multi_turn/red_teaming.py @@ -322,6 +322,7 @@ async def _perform_async(self, *, context: MultiTurnAttackContext[Any]) -> Attac last_response=context.last_response.get_piece() if context.last_response else None, last_score=context.last_score, related_conversations=context.related_conversations, + labels=context.memory_labels, ) async def _teardown_async(self, *, context: MultiTurnAttackContext[Any]) -> None: diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index 7ea7f927b7..aa2293576e 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -2084,6 +2084,7 @@ def _create_attack_result( last_response=last_response, last_score=context.best_objective_score, related_conversations=context.related_conversations, + labels=context.memory_labels, ) # Set attack-specific metadata using properties diff --git a/pyrit/executor/attack/single_turn/prompt_sending.py b/pyrit/executor/attack/single_turn/prompt_sending.py index 07f1d670fa..cdb2d4b619 100644 --- a/pyrit/executor/attack/single_turn/prompt_sending.py +++ b/pyrit/executor/attack/single_turn/prompt_sending.py @@ -238,6 +238,7 @@ async def _perform_async(self, *, context: SingleTurnAttackContext[Any]) -> Atta outcome=outcome, outcome_reason=outcome_reason, executed_turns=1, + labels=context.memory_labels, ) def _determine_attack_outcome( diff --git a/pyrit/executor/attack/single_turn/skeleton_key.py b/pyrit/executor/attack/single_turn/skeleton_key.py index 683614dce5..40cc5cc302 100644 --- a/pyrit/executor/attack/single_turn/skeleton_key.py +++ b/pyrit/executor/attack/single_turn/skeleton_key.py @@ -181,4 +181,5 @@ def _create_skeleton_key_failure_result(self, *, context: SingleTurnAttackContex outcome=AttackOutcome.FAILURE, outcome_reason="Skeleton key prompt was filtered or failed", executed_turns=1, + labels=context.memory_labels, ) diff --git a/pyrit/executor/benchmark/fairness_bias.py b/pyrit/executor/benchmark/fairness_bias.py index 05bb424c17..63d33f4639 100644 --- a/pyrit/executor/benchmark/fairness_bias.py +++ b/pyrit/executor/benchmark/fairness_bias.py @@ -200,6 +200,7 @@ async def _perform_async(self, *, context: FairnessBiasBenchmarkContext) -> Atta atomic_attack_identifier=build_atomic_attack_identifier( attack_identifier=ComponentIdentifier.of(self), ), + labels=context.memory_labels, ) return last_attack_result diff --git a/pyrit/memory/alembic/versions/108a72344872_add_labels_to_attack_results.py b/pyrit/memory/alembic/versions/108a72344872_add_labels_to_attack_results.py new file mode 100644 index 0000000000..66b1574b93 --- /dev/null +++ b/pyrit/memory/alembic/versions/108a72344872_add_labels_to_attack_results.py @@ -0,0 +1,35 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +add labels to Attack Results. + +Revision ID: 108a72344872 +Revises: e373726d391b +Create Date: 2026-04-27 13:47:20.711347 +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "108a72344872" +down_revision: str | None = "e373726d391b" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Apply this schema upgrade.""" + # ### commands auto generated by Alembic and reviewed by author ### + op.add_column("AttackResultEntries", sa.Column("labels", sa.JSON(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Revert this schema upgrade.""" + # ### commands auto generated by Alembic and reviewed by author ### + op.drop_column("AttackResultEntries", "labels") + # ### end Alembic commands ### diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index a0d4f69686..6a33f63098 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -9,7 +9,7 @@ from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, Union, cast -from sqlalchemy import and_, create_engine, event, exists, text +from sqlalchemy import and_, create_engine, event, exists, or_, text from sqlalchemy.engine.base import Engine from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import InstrumentedAttribute, joinedload, sessionmaker @@ -224,27 +224,61 @@ def _add_embeddings_to_memory(self, *, embedding_data: Sequence[EmbeddingDataEnt """ self._insert_entries(entries=embedding_data) - def _get_message_pieces_memory_label_conditions(self, *, memory_labels: dict[str, str]) -> list[TextClause]: + def _get_message_pieces_memory_label_conditions(self, *, memory_labels: dict[str, str]) -> list[Any]: """ Generate SQL conditions for filtering message pieces by memory labels. Uses JSON_VALUE() function specific to SQL Azure to query label fields in JSON format. + Matches if labels are on the PromptMemoryEntry itself OR on any + AttackResultEntry that shares the same conversation_id. + Args: memory_labels (dict[str, str]): Dictionary of label key-value pairs to filter by. Returns: - list: List containing a single SQLAlchemy text condition with bound parameters. - """ - json_validation = "ISJSON(labels) = 1" - json_conditions = " AND ".join([f"JSON_VALUE(labels, '$.{key}') = :{key}" for key in memory_labels]) - # Combine both conditions - conditions = f"{json_validation} AND {json_conditions}" + list: List containing a single SQLAlchemy OR condition with bound parameters. + """ + # Build conditions for direct PME label match + pme_label_parts: list[str] = [] + pme_bindparams: dict[str, str] = {} + # Build conditions for AR label match (via exists subquery) + are_label_parts: list[str] = [] + are_bindparams: dict[str, str] = {} + + for key, value in memory_labels.items(): + pme_param = f"pme_ml_{key}" + pme_label_parts.append(f"JSON_VALUE(labels, '$.{key}') = :{pme_param}") + pme_bindparams[pme_param] = str(value) + + are_param = f"are_ml_{key}" + are_label_parts.append(f"JSON_VALUE(\"AttackResultEntries\".labels, '$.{key}') = :{are_param}") + are_bindparams[are_param] = str(value) + + # Direct PME label match + combined_pme = " AND ".join(pme_label_parts) + pme_match = and_( + PromptMemoryEntry.labels.isnot(None), + cast( + "ColumnElement[bool]", + text(f"ISJSON(labels) = 1 AND {combined_pme}").bindparams(**pme_bindparams), + ), + ) - # Create SQL condition using SQLAlchemy's text() with bindparams - # for safe parameter passing, preventing SQL injection - condition = text(conditions).bindparams(**{key: str(value) for key, value in memory_labels.items()}) - return [condition] + # AR label match via exists subquery + combined_are = " AND ".join(are_label_parts) + are_match = exists().where( + and_( + AttackResultEntry.conversation_id == PromptMemoryEntry.conversation_id, + AttackResultEntry.labels.isnot(None), + cast( + "ColumnElement[bool]", + text(f'ISJSON("AttackResultEntries".labels) = 1 AND {combined_are}').bindparams(**are_bindparams), + ), + ) + ) + + return [or_(pme_match, are_match)] def _get_metadata_conditions(self, *, prompt_metadata: dict[str, Union[str, int]]) -> list[TextClause]: """ @@ -450,36 +484,66 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str | Sequence """ Azure SQL implementation for filtering AttackResults by labels. + Matches if labels are on any associated PromptMemoryEntry OR directly + on the AttackResultEntry itself. + Uses JSON_VALUE() with parameterized IN clauses. See ``MemoryInterface._get_attack_result_label_condition`` for semantics. Returns: - Any: SQLAlchemy exists subquery condition with bound parameters. + Any: SQLAlchemy condition with bound parameters. """ - label_conditions: list[str] = [] - bindparams_dict: dict[str, str] = {} + # Build conditions for PromptMemoryEntry labels (via exists subquery) + pme_label_conditions: list[str] = [] + pme_bindparams: dict[str, str] = {} + # Build conditions for AttackResultEntry labels (direct match) + are_label_conditions: list[str] = [] + are_bindparams: dict[str, str] = {} + for key, raw_value in labels.items(): values = [raw_value] if isinstance(raw_value, str) else list(raw_value) if not values: continue - placeholders = [] + pme_placeholders = [] + are_placeholders = [] for idx, v in enumerate(values): - param_name = f"label_{key}_{idx}" - placeholders.append(f":{param_name}") - bindparams_dict[param_name] = str(v) - label_conditions.append(f"JSON_VALUE(labels, '$.{key}') IN ({', '.join(placeholders)})") - - base = [ + pme_param = f"pme_label_{key}_{idx}" + pme_placeholders.append(f":{pme_param}") + pme_bindparams[pme_param] = str(v) + are_param = f"are_label_{key}_{idx}" + are_placeholders.append(f":{are_param}") + are_bindparams[are_param] = str(v) + pme_label_conditions.append(f"JSON_VALUE(labels, '$.{key}') IN ({', '.join(pme_placeholders)})") + are_label_conditions.append(f"JSON_VALUE(labels, '$.{key}') IN ({', '.join(are_placeholders)})") + + # PromptMemoryEntry subquery + pme_base: list[Any] = [ PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id, PromptMemoryEntry.labels.isnot(None), ] - if label_conditions: - combined = " AND ".join(label_conditions) - base.append( - cast("ColumnElement[bool]", text(f"ISJSON(labels) = 1 AND {combined}").bindparams(**bindparams_dict)) + if pme_label_conditions: + combined_pme = " AND ".join(pme_label_conditions) + pme_base.append( + cast( + "ColumnElement[bool]", + text(f"ISJSON(labels) = 1 AND {combined_pme}").bindparams(**pme_bindparams), + ) + ) + pme_match = exists().where(and_(*pme_base)) + + # Direct AttackResultEntry label match + are_parts: list[Any] = [AttackResultEntry.labels.isnot(None)] + if are_label_conditions: + combined_are = " AND ".join(are_label_conditions) + are_parts.append( + cast( + "ColumnElement[bool]", + text(f"ISJSON(labels) = 1 AND {combined_are}").bindparams(**are_bindparams), + ) ) + are_match = and_(*are_parts) - return exists().where(and_(*base)) + return or_(pme_match, are_match) def get_unique_attack_class_names(self) -> list[str]: """ diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 79017f89fe..557631ed52 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -434,8 +434,11 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories @abc.abstractmethod def _get_attack_result_label_condition(self, *, labels: dict[str, str | Sequence[str]]) -> Any: """ - Return a database-specific condition for filtering AttackResults by labels - in the associated PromptMemoryEntry records. + Return a database-specific condition for filtering AttackResults by labels. + + Matches if the labels are present on **either** an associated + PromptMemoryEntry (via conversation_id) **or** directly on the + AttackResultEntry itself. Semantics: entries are AND-combined across label names; within a single entry, a string value is an equality match and a sequence value is an diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 5a11fa78c7..42a9431108 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -697,6 +697,7 @@ class AttackResultEntry(Base): outcome (AttackOutcome): The outcome of the attack, indicating success, failure, or undetermined. outcome_reason (str): Optional reason for the outcome, providing additional context. attack_metadata (dict[str, Any]): Metadata can be included as key-value pairs to provide extra context. + labels (dict[str, str]): Optional labels associated with the attack result entry. pruned_conversation_ids (List[str]): List of conversation IDs that were pruned from the attack. adversarial_chat_conversation_ids (List[str]): List of conversation IDs used for adversarial chat. timestamp (DateTime): The timestamp of the attack result entry. @@ -728,6 +729,7 @@ class AttackResultEntry(Base): ) outcome_reason = mapped_column(String, nullable=True) attack_metadata: Mapped[dict[str, Union[str, int, float, bool]]] = mapped_column(JSON, nullable=True) + labels: Mapped[dict[str, str]] = mapped_column(JSON, nullable=True) pruned_conversation_ids: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True) adversarial_chat_conversation_ids: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True) timestamp = mapped_column(DateTime, nullable=False) @@ -783,6 +785,7 @@ def __init__(self, *, entry: AttackResult): self.outcome = entry.outcome.value self.outcome_reason = entry.outcome_reason self.attack_metadata = self.filter_json_serializable_metadata(entry.metadata) + self.labels = entry.labels or {} # Persist conversation references by type self.pruned_conversation_ids = [ @@ -894,6 +897,7 @@ def get_attack_result(self) -> AttackResult: outcome_reason=self.outcome_reason, related_conversations=related_conversations, metadata=self.attack_metadata or {}, + labels=self.labels or {}, ) diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index b910f68953..a0abc6b82b 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -10,7 +10,7 @@ from pathlib import Path from typing import Any, Literal, Optional, TypeVar, Union, cast -from sqlalchemy import and_, create_engine, func, or_, text +from sqlalchemy import and_, create_engine, exists, func, or_, text from sqlalchemy.engine.base import Engine from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import InstrumentedAttribute, joinedload, sessionmaker @@ -142,21 +142,37 @@ def get_all_embeddings(self) -> Sequence[EmbeddingDataEntry]: result: Sequence[EmbeddingDataEntry] = self._query_entries(EmbeddingDataEntry) return result - def _get_message_pieces_memory_label_conditions(self, *, memory_labels: dict[str, str]) -> list[TextClause]: + def _get_message_pieces_memory_label_conditions(self, *, memory_labels: dict[str, str]) -> list[Any]: """ Generate SQLAlchemy filter conditions for filtering conversation pieces by memory labels. For SQLite, we use JSON_EXTRACT function to handle JSON fields. + Matches if labels are on the PromptMemoryEntry itself OR on any + AttackResultEntry that shares the same conversation_id. + Returns: list: A list of SQLAlchemy conditions. """ - # For SQLite, we use JSON_EXTRACT with text() and bindparams similar to Azure SQL approach - json_conditions = " AND ".join([f"JSON_EXTRACT(labels, '$.{key}') = :{key}" for key in memory_labels]) - - # Create SQL condition using SQLAlchemy's text() with bindparams - # for safe parameter passing, preventing SQL injection - condition = text(json_conditions).bindparams(**{key: str(value) for key, value in memory_labels.items()}) - return [condition] + per_key_pme_conditions = [] + per_key_are_conditions = [] + for key, value in memory_labels.items(): + pme_col = func.json_extract(PromptMemoryEntry.labels, f"$.{key}") + per_key_pme_conditions.append(pme_col == str(value)) + are_col = func.json_extract(AttackResultEntry.labels, f"$.{key}") + per_key_are_conditions.append(are_col == str(value)) + + pme_match = and_( + PromptMemoryEntry.labels.isnot(None), + *per_key_pme_conditions, + ) + are_match = exists().where( + and_( + AttackResultEntry.conversation_id == PromptMemoryEntry.conversation_id, + AttackResultEntry.labels.isnot(None), + *per_key_are_conditions, + ) + ) + return [or_(pme_match, are_match)] def _get_message_pieces_prompt_metadata_conditions( self, *, prompt_metadata: dict[str, Union[str, int]] @@ -605,32 +621,43 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str | Sequence SQLite implementation for filtering AttackResults by labels. Uses json_extract() function specific to SQLite. + Matches if labels are on any associated PromptMemoryEntry OR directly + on the AttackResultEntry itself. + Keys are AND-combined. For each key, a string value is an equality match; a sequence value is an OR-within-key match (any listed value matches). Empty sequences are no-ops (no constraint on that key). Returns: - Any: A SQLAlchemy subquery for filtering by labels. + Any: A SQLAlchemy condition for filtering by labels. """ from sqlalchemy import and_, exists, func from pyrit.memory.memory_models import AttackResultEntry, PromptMemoryEntry - per_key_conditions = [] + per_key_pme_conditions = [] + per_key_are_conditions = [] for key, raw_value in labels.items(): values = [raw_value] if isinstance(raw_value, str) else list(raw_value) if not values: continue - col = func.json_extract(PromptMemoryEntry.labels, f"$.{key}") - per_key_conditions.append(col.in_(values)) + pme_col = func.json_extract(PromptMemoryEntry.labels, f"$.{key}") + per_key_pme_conditions.append(pme_col.in_(values)) + are_col = func.json_extract(AttackResultEntry.labels, f"$.{key}") + per_key_are_conditions.append(are_col.in_(values)) - return exists().where( + pme_match = exists().where( and_( PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id, PromptMemoryEntry.labels.isnot(None), - and_(*per_key_conditions), + and_(*per_key_pme_conditions), ) ) + are_match = and_( + AttackResultEntry.labels.isnot(None), + *per_key_are_conditions, + ) + return or_(pme_match, are_match) def get_unique_attack_class_names(self) -> list[str]: """ diff --git a/pyrit/models/attack_result.py b/pyrit/models/attack_result.py index a385ac36e7..5cbdf3c93e 100644 --- a/pyrit/models/attack_result.py +++ b/pyrit/models/attack_result.py @@ -87,6 +87,9 @@ class AttackResult(StrategyResult): # Arbitrary metadata metadata: dict[str, Any] = field(default_factory=dict) + # labels associated with this attack result + labels: dict[str, str] = field(default_factory=dict) + @property def attack_identifier(self) -> Optional[ComponentIdentifier]: """ diff --git a/tests/unit/backend/test_attack_service.py b/tests/unit/backend/test_attack_service.py index 690b8a2e3f..c4183eb58f 100644 --- a/tests/unit/backend/test_attack_service.py +++ b/tests/unit/backend/test_attack_service.py @@ -97,6 +97,7 @@ def make_attack_result( "created_at": created.isoformat(), "updated_at": updated.isoformat(), }, + labels={"test_ar_label": "test_ar_value"}, ) @@ -431,7 +432,7 @@ async def test_list_attacks_includes_labels_in_summary(self, attack_service, moc result = await attack_service.list_attacks_async() assert len(result.items) == 1 - assert result.items[0].labels == {"env": "prod", "team": "red"} + assert result.items[0].labels == {"env": "prod", "team": "red", "test_ar_label": "test_ar_value"} @pytest.mark.asyncio async def test_list_attacks_filters_by_labels_directly(self, attack_service, mock_memory) -> None: diff --git a/tests/unit/backend/test_mappers.py b/tests/unit/backend/test_mappers.py index 0f483b3f10..ad6ef1a380 100644 --- a/tests/unit/backend/test_mappers.py +++ b/tests/unit/backend/test_mappers.py @@ -81,6 +81,7 @@ def _make_attack_result( "created_at": now.isoformat(), "updated_at": now.isoformat(), }, + labels={"test_ar_label": "test_ar_value"}, ) @@ -175,7 +176,7 @@ def test_labels_are_mapped(self) -> None: summary = attack_result_to_summary(ar, stats=stats) - assert summary.labels == {"env": "prod", "team": "red"} + assert summary.labels == {"env": "prod", "team": "red", "test_ar_label": "test_ar_value"} def test_labels_passed_through_without_normalization(self) -> None: """Test that labels are passed through as-is (DB stores canonical keys after migration).""" @@ -187,7 +188,24 @@ def test_labels_passed_through_without_normalization(self) -> None: summary = attack_result_to_summary(ar, stats=stats) - assert summary.labels == {"operator": "alice", "operation": "op_red", "env": "prod"} + assert summary.labels == { + "operator": "alice", + "operation": "op_red", + "env": "prod", + "test_ar_label": "test_ar_value", + } + + def test_conversation_labels_take_precedence_on_collision(self) -> None: + """Test that conversation-level labels override attack-result labels on key collision.""" + ar = _make_attack_result() + stats = ConversationStats( + message_count=1, + labels={"test_ar_label": "conversation_wins"}, + ) + + summary = attack_result_to_summary(ar, stats=stats) + + assert summary.labels["test_ar_label"] == "conversation_wins" def test_outcome_success(self) -> None: """Test that success outcome is mapped.""" @@ -249,6 +267,7 @@ def test_converters_extracted_from_identifier(self) -> None: ), outcome=AttackOutcome.UNDETERMINED, metadata={"created_at": now.isoformat(), "updated_at": now.isoformat()}, + labels={"test_label": "test_value"}, ) summary = attack_result_to_summary(ar, stats=ConversationStats(message_count=0)) diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index 45fee0476f..0d71392e3e 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -38,12 +38,18 @@ def create_message_piece(conversation_id: str, prompt_num: int, targeted_harm_ca ) -def create_attack_result(conversation_id: str, objective_num: int, outcome: AttackOutcome = AttackOutcome.SUCCESS): +def create_attack_result( + conversation_id: str, + objective_num: int, + outcome: AttackOutcome = AttackOutcome.SUCCESS, + labels: dict[str, str] | None = None, +): """Helper function to create AttackResult.""" return AttackResult( conversation_id=conversation_id, objective=f"Objective {objective_num}", outcome=outcome, + labels=labels or {}, ) @@ -782,17 +788,14 @@ def test_get_attack_results_by_harm_category_multiple(sqlite_instance: MemoryInt def test_get_attack_results_by_labels_single(sqlite_instance: MemoryInterface): """Test filtering attack results by single label.""" - # Create message pieces with labels - message_piece1 = create_message_piece("conv_1", 1, labels={"operation": "test_op", "operator": "roakey"}) - message_piece2 = create_message_piece("conv_2", 2, labels={"operation": "test_op"}) - message_piece3 = create_message_piece("conv_3", 3, labels={"operation": "other_op", "operator": "roakey"}) - - sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece1, message_piece2, message_piece3]) - - # Create attack results - attack_result1 = create_attack_result("conv_1", 1, AttackOutcome.SUCCESS) - attack_result2 = create_attack_result("conv_2", 2, AttackOutcome.FAILURE) - attack_result3 = create_attack_result("conv_3", 3, AttackOutcome.SUCCESS) + # Create attack results with labels + attack_result1 = create_attack_result( + "conv_1", 1, AttackOutcome.SUCCESS, labels={"operation": "test_op", "operator": "roakey"} + ) + attack_result2 = create_attack_result("conv_2", 2, AttackOutcome.FAILURE, labels={"operation": "test_op"}) + attack_result3 = create_attack_result( + "conv_3", 3, AttackOutcome.SUCCESS, labels={"operation": "other_op", "operator": "roakey"} + ) sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result1, attack_result2, attack_result3]) @@ -859,22 +862,26 @@ def test_get_attack_results_rejects_invalid_label_keys(sqlite_instance: MemoryIn def test_get_attack_results_by_labels_multiple(sqlite_instance: MemoryInterface): """Test filtering attack results by multiple labels (AND logic).""" - # Create message pieces with multiple labels using helper function - message_piece1 = create_message_piece( - "conv_1", 1, labels={"operation": "test_op", "operator": "roakey", "phase": "initial"} - ) - message_piece2 = create_message_piece( - "conv_2", 2, labels={"operation": "test_op", "operator": "roakey", "phase": "final"} - ) - message_piece3 = create_message_piece("conv_3", 3, labels={"operation": "test_op", "phase": "initial"}) - - sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece1, message_piece2, message_piece3]) - - # Create attack results + # Create attack results with multiple labels attack_results = [ - create_attack_result("conv_1", 1, AttackOutcome.SUCCESS), - create_attack_result("conv_2", 2, AttackOutcome.SUCCESS), - create_attack_result("conv_3", 3, AttackOutcome.FAILURE), + create_attack_result( + "conv_1", + 1, + AttackOutcome.SUCCESS, + labels={"operation": "test_op", "operator": "roakey", "phase": "initial"}, + ), + create_attack_result( + "conv_2", + 2, + AttackOutcome.SUCCESS, + labels={"operation": "test_op", "operator": "roakey", "phase": "final"}, + ), + create_attack_result( + "conv_3", + 3, + AttackOutcome.FAILURE, + labels={"operation": "test_op", "phase": "initial"}, + ), ] sqlite_instance.add_attack_results_to_memory(attack_results=attack_results) @@ -937,30 +944,18 @@ def test_get_attack_results_by_labels_or_within_key_and_across_keys(sqlite_insta def test_get_attack_results_by_harm_category_and_labels(sqlite_instance: MemoryInterface): """Test filtering attack results by both harm categories and labels.""" - # Create message pieces with both harm categories and labels using helper function - message_piece1 = create_message_piece( - "conv_1", - 1, - targeted_harm_categories=["violence", "illegal"], - labels={"operation": "test_op", "operator": "roakey"}, - ) - message_piece2 = create_message_piece( - "conv_2", 2, targeted_harm_categories=["violence"], labels={"operation": "test_op", "operator": "roakey"} - ) - message_piece3 = create_message_piece( - "conv_3", - 3, - targeted_harm_categories=["violence", "illegal"], - labels={"operation": "other_op", "operator": "bob"}, - ) + # Create message pieces with harm categories (harm categories still live on PromptMemoryEntry) + message_piece1 = create_message_piece("conv_1", 1, targeted_harm_categories=["violence", "illegal"]) + message_piece2 = create_message_piece("conv_2", 2, targeted_harm_categories=["violence"]) + message_piece3 = create_message_piece("conv_3", 3, targeted_harm_categories=["violence", "illegal"]) sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece1, message_piece2, message_piece3]) - # Create attack results + # Create attack results with labels attack_results = [ - create_attack_result("conv_1", 1, AttackOutcome.SUCCESS), - create_attack_result("conv_2", 2, AttackOutcome.SUCCESS), - create_attack_result("conv_3", 3, AttackOutcome.FAILURE), + create_attack_result("conv_1", 1, AttackOutcome.SUCCESS, labels={"operation": "test_op", "operator": "roakey"}), + create_attack_result("conv_2", 2, AttackOutcome.SUCCESS, labels={"operation": "test_op", "operator": "roakey"}), + create_attack_result("conv_3", 3, AttackOutcome.FAILURE, labels={"operation": "other_op", "operator": "bob"}), ] sqlite_instance.add_attack_results_to_memory(attack_results=attack_results) @@ -999,11 +994,8 @@ def test_get_attack_results_harm_category_no_matches(sqlite_instance: MemoryInte def test_get_attack_results_labels_no_matches(sqlite_instance: MemoryInterface): """Test filtering by labels that don't exist.""" - # Create attack result without the labels we'll search for - message_piece = create_message_piece("conv_1", 1, labels={"operation": "test_op"}) - sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece]) - - attack_result = create_attack_result("conv_1", 1, AttackOutcome.SUCCESS) + # Create attack result with labels that don't match the search + attack_result = create_attack_result("conv_1", 1, AttackOutcome.SUCCESS, labels={"operation": "test_op"}) sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result]) # Search for non-existent labels @@ -1015,11 +1007,6 @@ def test_get_attack_results_labels_query_on_empty_labels(sqlite_instance: Memory """Test querying for labels when records have no labels at all""" # Create attack results with NO labels - message_piece1 = create_message_piece("conv_1", 1) - message_piece2 = create_message_piece("conv_2", 1) - - sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece1, message_piece2]) - attack_result1 = create_attack_result("conv_1", 1, AttackOutcome.SUCCESS) attack_result2 = create_attack_result("conv_2", 2, AttackOutcome.FAILURE) @@ -1039,16 +1026,14 @@ def test_get_attack_results_labels_key_exists_value_mismatch(sqlite_instance: Me """Test querying for labels where the key exists but the value doesn't match.""" # Create attack results with specific label values - message_piece1 = create_message_piece("conv_1", 1, labels={"operation": "op_exists", "researcher": "roakey"}) - message_piece2 = create_message_piece("conv_2", 1, labels={"operation": "another_op", "researcher": "roakey"}) - message_piece3 = create_message_piece("conv_3", 1, labels={"operation": "test_op"}) - - sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece1, message_piece2, message_piece3]) - attack_results = [ - create_attack_result("conv_1", 1, AttackOutcome.SUCCESS), - create_attack_result("conv_2", 2, AttackOutcome.SUCCESS), - create_attack_result("conv_3", 3, AttackOutcome.FAILURE), + create_attack_result( + "conv_1", 1, AttackOutcome.SUCCESS, labels={"operation": "op_exists", "researcher": "roakey"} + ), + create_attack_result( + "conv_2", 2, AttackOutcome.SUCCESS, labels={"operation": "another_op", "researcher": "roakey"} + ), + create_attack_result("conv_3", 3, AttackOutcome.FAILURE, labels={"operation": "test_op"}), ] sqlite_instance.add_attack_results_to_memory(attack_results=attack_results) @@ -1091,6 +1076,27 @@ def test_get_attack_results_labels_key_exists_value_mismatch(sqlite_instance: Me assert results[0].conversation_id == "conv_1" +def test_get_attack_results_by_labels_falls_back_to_conversation_labels(sqlite_instance: MemoryInterface): + """Test that label filtering matches via PromptMemoryEntry when AttackResult has no labels.""" + + # Attack result with NO labels + attack_result = create_attack_result("conv_1", 1, AttackOutcome.SUCCESS, labels={}) + sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result]) + + # Conversation message carries the labels instead + message_piece = create_message_piece("conv_1", 1, labels={"operation": "legacy_op"}) + sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece]) + + # Should still find the attack result via the PME fallback path + results = sqlite_instance.get_attack_results(labels={"operation": "legacy_op"}) + assert len(results) == 1 + assert results[0].conversation_id == "conv_1" + + # Non-matching label should return nothing + results = sqlite_instance.get_attack_results(labels={"operation": "missing"}) + assert len(results) == 0 + + # --------------------------------------------------------------------------- # get_unique_attack_labels tests # --------------------------------------------------------------------------- diff --git a/tests/unit/memory/memory_interface/test_interface_prompts.py b/tests/unit/memory/memory_interface/test_interface_prompts.py index d5af1c41de..3bdeccb2d9 100644 --- a/tests/unit/memory/memory_interface/test_interface_prompts.py +++ b/tests/unit/memory/memory_interface/test_interface_prompts.py @@ -738,6 +738,113 @@ def test_get_message_pieces_labels(sqlite_instance: MemoryInterface): assert "harm_category" in retrieved_entry.labels +def test_get_message_pieces_labels_falls_back_to_attack_result_labels(sqlite_instance: MemoryInterface): + """PMEs without labels are returned when a matching AttackResultEntry shares the conversation_id.""" + from pyrit.memory.memory_models import AttackResultEntry + from pyrit.models import AttackOutcome, AttackResult + + conv_id = str(uuid.uuid4()) + labels = {"operation": "op1", "operator": "name1"} + + # PME with NO labels + pme = PromptMemoryEntry( + entry=MessagePiece( + role="user", + original_value="Hello from AR", + conversation_id=conv_id, + ) + ) + # AttackResultEntry with labels sharing the same conversation_id + ar = AttackResult( + conversation_id=conv_id, + objective="test", + outcome=AttackOutcome.SUCCESS, + labels=labels, + ) + are = AttackResultEntry(entry=ar) + + sqlite_instance._insert_entries(entries=[pme, are]) + + retrieved = sqlite_instance.get_message_pieces(labels=labels) + assert len(retrieved) == 1 + assert retrieved[0].original_value == "Hello from AR" + + +def test_get_message_pieces_labels_returns_pme_and_ar_label_matches(sqlite_instance: MemoryInterface): + """Both PMEs with direct labels and PMEs matched via AR labels are returned.""" + from pyrit.memory.memory_models import AttackResultEntry + from pyrit.models import AttackOutcome, AttackResult + + labels = {"operation": "op1"} + + # PME with direct labels + pme_direct = PromptMemoryEntry( + entry=MessagePiece( + role="user", + original_value="Direct label", + labels=labels, + ) + ) + # PME without labels, but associated AR has labels + conv_id = str(uuid.uuid4()) + pme_via_ar = PromptMemoryEntry( + entry=MessagePiece( + role="user", + original_value="Via AR label", + conversation_id=conv_id, + ) + ) + ar = AttackResult( + conversation_id=conv_id, + objective="test", + outcome=AttackOutcome.SUCCESS, + labels=labels, + ) + are = AttackResultEntry(entry=ar) + + # PME with no labels and no matching AR + pme_no_match = PromptMemoryEntry( + entry=MessagePiece( + role="user", + original_value="No match", + ) + ) + + sqlite_instance._insert_entries(entries=[pme_direct, pme_via_ar, are, pme_no_match]) + + retrieved = sqlite_instance.get_message_pieces(labels=labels) + assert len(retrieved) == 2 + original_values = {r.original_value for r in retrieved} + assert original_values == {"Direct label", "Via AR label"} + + +def test_get_message_pieces_labels_no_match_when_ar_labels_differ(sqlite_instance: MemoryInterface): + """PMEs are NOT returned when the AR labels don't match the query.""" + from pyrit.memory.memory_models import AttackResultEntry + from pyrit.models import AttackOutcome, AttackResult + + conv_id = str(uuid.uuid4()) + pme = PromptMemoryEntry( + entry=MessagePiece( + role="user", + original_value="Unmatched", + conversation_id=conv_id, + ) + ) + ar = AttackResult( + conversation_id=conv_id, + objective="test", + outcome=AttackOutcome.SUCCESS, + labels={"operation": "other_op"}, + ) + are = AttackResultEntry(entry=ar) + + sqlite_instance._insert_entries(entries=[pme, are]) + + retrieved = sqlite_instance.get_message_pieces(labels={"operation": "op1"}) + assert len(retrieved) == 0 + + def test_get_message_pieces_metadata(sqlite_instance: MemoryInterface): metadata: dict[str, str | int] = {"key1": "value1", "key2": "value2"} entries = [ diff --git a/tests/unit/memory/test_azure_sql_memory.py b/tests/unit/memory/test_azure_sql_memory.py index 7d1b85341e..3950d96454 100644 --- a/tests/unit/memory/test_azure_sql_memory.py +++ b/tests/unit/memory/test_azure_sql_memory.py @@ -250,6 +250,54 @@ def test_get_memories_with_attack_id(memory_interface: AzureSQLMemory): pytest.skip("Test requires Azure SQL-specific JSON functions; covered by integration tests") +def test_get_attack_result_label_condition_single_label(memory_interface: AzureSQLMemory): + """Test that _get_attack_result_label_condition builds a valid condition for a single label.""" + condition = memory_interface._get_attack_result_label_condition(labels={"operation": "test_op"}) + compiled = str(condition.compile(compile_kwargs={"literal_binds": False})) + assert "JSON_VALUE" in compiled + assert "ISJSON" in compiled + + +def test_get_attack_result_label_condition_multiple_labels(memory_interface: AzureSQLMemory): + """Test that _get_attack_result_label_condition builds a valid condition for multiple labels.""" + condition = memory_interface._get_attack_result_label_condition( + labels={"operation": "test_op", "operator": "roakey"} + ) + compiled = str(condition.compile(compile_kwargs={"literal_binds": False})) + # Both AR-direct and PME-conversation branches should appear + assert "AttackResultEntries" in compiled + assert "PromptMemoryEntries" in compiled + + +def test_get_message_pieces_memory_label_conditions_single_label(memory_interface: AzureSQLMemory): + """Test that _get_message_pieces_memory_label_conditions builds a valid OR condition.""" + conditions = memory_interface._get_message_pieces_memory_label_conditions(memory_labels={"operation": "test_op"}) + assert len(conditions) == 1 + compiled = str(conditions[0].compile(compile_kwargs={"literal_binds": False})) + assert "ISJSON" in compiled + assert "JSON_VALUE" in compiled + + +def test_get_message_pieces_memory_label_conditions_includes_ar_fallback(memory_interface: AzureSQLMemory): + """Test that the condition references both PME and AR tables for the OR fallback.""" + conditions = memory_interface._get_message_pieces_memory_label_conditions( + memory_labels={"operation": "test_op", "operator": "roakey"} + ) + compiled = str(conditions[0].compile(compile_kwargs={"literal_binds": False})) + assert "AttackResultEntries" in compiled + assert "PromptMemoryEntries" in compiled + + +def test_get_message_pieces_memory_label_conditions_bind_params(memory_interface: AzureSQLMemory): + """Test that bind parameters are created for both PME and AR branches.""" + conditions = memory_interface._get_message_pieces_memory_label_conditions(memory_labels={"operation": "test_op"}) + params = conditions[0].compile().params + # PME branch param + assert params.get("pme_ml_operation") == "test_op" + # AR branch param + assert params.get("are_ml_operation") == "test_op" + + def test_update_entries(memory_interface: AzureSQLMemory): # Insert a test entry entry = PromptMemoryEntry( @@ -381,32 +429,37 @@ def test_get_attack_result_label_condition_with_string_value(memory_interface: A """String values produce a single-placeholder IN clause with the stringified value.""" condition = memory_interface._get_attack_result_label_condition(labels={"operator": "roakey"}) params = condition.compile().params - assert params.get("label_operator_0") == "roakey" + assert params.get("pme_label_operator_0") == "roakey" + assert params.get("are_label_operator_0") == "roakey" def test_get_attack_result_label_condition_with_sequence_value(memory_interface: AzureSQLMemory): """Sequence values produce one placeholder per element.""" condition = memory_interface._get_attack_result_label_condition(labels={"operation": ["op_a", "op_b", "op_c"]}) params = condition.compile().params - assert params.get("label_operation_0") == "op_a" - assert params.get("label_operation_1") == "op_b" - assert params.get("label_operation_2") == "op_c" + assert params.get("pme_label_operation_0") == "op_a" + assert params.get("pme_label_operation_1") == "op_b" + assert params.get("pme_label_operation_2") == "op_c" + assert params.get("are_label_operation_0") == "op_a" + assert params.get("are_label_operation_1") == "op_b" + assert params.get("are_label_operation_2") == "op_c" def test_get_attack_result_label_condition_skips_empty_sequence(memory_interface: AzureSQLMemory): """Empty sequence values are skipped (no filter applied for that key).""" condition = memory_interface._get_attack_result_label_condition(labels={"operator": "roakey", "operation": []}) params = condition.compile().params - # operator gets a bind param; operation (empty) does not. - assert params.get("label_operator_0") == "roakey" - assert not any(k.startswith("label_operation_") for k in params) + # operator gets bind params; operation (empty) does not. + assert params.get("pme_label_operator_0") == "roakey" + assert params.get("are_label_operator_0") == "roakey" + assert not any("label_operation_" in k for k in params) def test_get_attack_result_label_condition_empty_labels_dict(memory_interface: AzureSQLMemory): """An empty labels dict produces a condition with no label filters bound.""" condition = memory_interface._get_attack_result_label_condition(labels={}) params = condition.compile().params - assert not any(k.startswith("label_") for k in params) + assert not any("label_" in k for k in params) @pytest.mark.parametrize( diff --git a/tests/unit/scenario/test_scenario.py b/tests/unit/scenario/test_scenario.py index 947cf6f645..95b04451e2 100644 --- a/tests/unit/scenario/test_scenario.py +++ b/tests/unit/scenario/test_scenario.py @@ -88,6 +88,7 @@ def sample_attack_results(): objective=f"objective{i}", outcome=AttackOutcome.SUCCESS, executed_turns=1, + labels={"test_label": f"value{i}"}, ) for i in range(5) ]