Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions pyrit/backend/mappers/attack_mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
Score,
TargetInfo,
)
from pyrit.common.deprecation import print_deprecation_message
from pyrit.models import AttackResult, ChatMessageRole, PromptDataType
from pyrit.models import Message as PyritMessage
from pyrit.models import MessagePiece as PyritMessagePiece
Expand Down Expand Up @@ -409,7 +410,7 @@ def request_piece_to_pyrit_message_piece(
role: ChatMessageRole,
conversation_id: str,
sequence: int,
labels: Optional[dict[str, str]] = None,
labels: Optional[dict[str, str]] = None, # deprecated
) -> PyritMessagePiece:
"""
Convert a single request piece DTO to a PyRIT MessagePiece domain object.
Expand All @@ -420,10 +421,17 @@ def request_piece_to_pyrit_message_piece(
conversation_id: The conversation/attack ID.
sequence: The message sequence number.
labels: Optional labels to attach to the piece.
Deprecated: This parameter will be removed in a release 0.16.0.

Returns:
PyritMessagePiece domain object.
"""
if labels is not None:
print_deprecation_message(
old_item="request_piece_to_pyrit_message_piece(..., labels=...)",
new_item="request_piece_to_pyrit_message_piece(...)",
removed_in="0.16.0",
)
metadata: Optional[dict[str, str | int]] = None
if piece.prompt_metadata:
metadata = dict(piece.prompt_metadata)
Expand All @@ -439,7 +447,7 @@ def request_piece_to_pyrit_message_piece(
conversation_id=conversation_id,
sequence=sequence,
prompt_metadata=metadata,
labels=labels or {},
labels=labels or {}, # deprecated
original_prompt_id=original_prompt_id,
)

Expand All @@ -449,7 +457,7 @@ def request_to_pyrit_message(
request: AddMessageRequest,
conversation_id: str,
sequence: int,
labels: Optional[dict[str, str]] = None,
labels: Optional[dict[str, str]] = None, # deprecated
) -> PyritMessage:
"""
Build a PyRIT Message from an AddMessageRequest DTO.
Expand All @@ -459,17 +467,24 @@ def request_to_pyrit_message(
conversation_id: The conversation/attack ID.
sequence: The message sequence number.
labels: Optional labels to attach to each piece.
Deprecated: This parameter will be removed in a release 0.16.0.

Returns:
PyritMessage ready to send to the target.
"""
if labels is not None:
print_deprecation_message(
old_item="request_to_pyrit_message(..., labels=...)",
new_item="request_to_pyrit_message(...)",
removed_in="0.16.0",
)
pieces = [
request_piece_to_pyrit_message_piece(
piece=p,
role=request.role,
conversation_id=conversation_id,
sequence=sequence,
labels=labels,
labels=labels, # deprecated
)
for p in request.pieces
]
Expand Down
20 changes: 10 additions & 10 deletions pyrit/backend/services/attack_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ async def create_attack_async(self, *, request: CreateAttackRequest) -> CreateAt
await self._store_prepended_messages(
conversation_id=conversation_id,
prepended=request.prepended_conversation,
labels=labels,
labels=labels, # deprecated
)

return CreateAttackResponse(
Expand Down Expand Up @@ -614,14 +614,14 @@ async def add_message_async(self, *, attack_result_id: str, request: AddMessageR
target_registry_name=target_registry_name,
request=request,
sequence=sequence,
labels=attack_labels,
labels=attack_labels, # deprecated
)
else:
await self._store_message_only_async(
conversation_id=msg_conversation_id,
request=request,
sequence=sequence,
labels=attack_labels,
labels=attack_labels, # deprecated
)

await self._update_attack_after_message_async(attack_result_id=attack_result_id, ar=ar, request=request)
Expand Down Expand Up @@ -852,7 +852,7 @@ def _duplicate_conversation_up_to(
# Apply optional overrides to the fresh pieces before persisting
for piece in all_pieces:
if labels_override is not None:
piece.labels = dict(labels_override)
piece.labels = dict(labels_override) # deprecated
if remap_assistant_to_simulated and piece.api_role == "assistant":
piece._role = "simulated_assistant"

Expand Down Expand Up @@ -943,7 +943,7 @@ async def _store_prepended_messages(
self,
conversation_id: str,
prepended: list[Any],
labels: Optional[dict[str, str]] = None,
labels: Optional[dict[str, str]] = None, # deprecated
) -> None:
"""Store prepended conversation messages in memory."""
for seq, msg in enumerate(prepended):
Expand All @@ -953,7 +953,7 @@ async def _store_prepended_messages(
role=msg.role,
conversation_id=conversation_id,
sequence=seq,
labels=labels,
labels=labels, # deprecated
)
self._memory.add_message_pieces_to_memory(message_pieces=[piece])

Expand All @@ -964,7 +964,7 @@ async def _send_and_store_message_async(
target_registry_name: str,
request: AddMessageRequest,
sequence: int,
labels: Optional[dict[str, str]] = None,
labels: Optional[dict[str, str]] = None, # deprecated
) -> None:
"""Send message to target via normalizer and store response."""
target_obj = get_target_service().get_target_object(target_registry_name=target_registry_name)
Expand All @@ -979,7 +979,7 @@ async def _send_and_store_message_async(
request=request,
conversation_id=conversation_id,
sequence=sequence,
labels=labels,
labels=labels, # deprecated
)

converter_configs = self._get_converter_configs(request)
Expand All @@ -1000,7 +1000,7 @@ async def _store_message_only_async(
conversation_id: str,
request: AddMessageRequest,
sequence: int,
labels: Optional[dict[str, str]] = None,
labels: Optional[dict[str, str]] = None, # deprecated
) -> None:
"""Store message without sending (send=False)."""
await self._persist_base64_pieces_async(request)
Expand All @@ -1010,7 +1010,7 @@ async def _store_message_only_async(
role=request.role,
conversation_id=conversation_id,
sequence=sequence,
labels=labels,
labels=labels, # deprecated
)
self._memory.add_message_pieces_to_memory(message_pieces=[piece])

Expand Down
23 changes: 19 additions & 4 deletions pyrit/executor/attack/component/conversation_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional

from pyrit.common.deprecation import print_deprecation_message
from pyrit.common.utils import combine_dict
from pyrit.executor.attack.component.prepended_conversation_config import (
PrependedConversationConfig,
Expand Down Expand Up @@ -57,7 +58,7 @@ def get_adversarial_chat_messages(
adversarial_chat_conversation_id: str,
attack_identifier: ComponentIdentifier,
adversarial_chat_target_identifier: ComponentIdentifier,
labels: Optional[dict[str, str]] = None,
labels: Optional[dict[str, str]] = None, # deprecated
) -> list[Message]:
"""
Transform prepended conversation messages for adversarial chat with swapped roles.
Expand All @@ -76,10 +77,17 @@ def get_adversarial_chat_messages(
attack_identifier (ComponentIdentifier): Attack identifier to associate with messages.
adversarial_chat_target_identifier (ComponentIdentifier): Target identifier for the adversarial chat.
labels: Optional labels to associate with the messages.
Deprecated: This parameter will be removed in a release 0.16.0.

Returns:
List of transformed messages with swapped roles and new IDs.
"""
if labels is not None:
print_deprecation_message(
Comment thread
behnam-o marked this conversation as resolved.
old_item="get_adversarial_chat_messages(..., labels=...)",
new_item="get_adversarial_chat_messages(...)",
removed_in="0.16.0",
)
if not prepended_conversation:
return []

Expand Down Expand Up @@ -110,7 +118,7 @@ def get_adversarial_chat_messages(
conversation_id=adversarial_chat_conversation_id,
attack_identifier=attack_identifier,
prompt_target_identifier=adversarial_chat_target_identifier,
labels=labels,
labels=labels, # deprecated
)

result.append(adversarial_piece.to_message())
Expand Down Expand Up @@ -245,7 +253,7 @@ def set_system_prompt(
target: PromptChatTarget,
conversation_id: str,
system_prompt: str,
labels: Optional[dict[str, str]] = None,
labels: Optional[dict[str, str]] = None, # deprecated
) -> None:
"""
Set or update the system prompt for a conversation.
Expand All @@ -255,12 +263,19 @@ def set_system_prompt(
conversation_id: Unique identifier for the conversation.
system_prompt: The system prompt text.
labels: Optional labels to associate with the system prompt.
Deprecated: This parameter will be removed in a release 0.16.0.
"""
if labels is not None:
print_deprecation_message(
old_item="set_system_prompt(..., labels=...)",
new_item="set_system_prompt(...)",
removed_in="0.16.0",
)
target.set_system_prompt(
system_prompt=system_prompt,
conversation_id=conversation_id,
attack_identifier=self._attack_identifier,
labels=labels,
labels=labels, # deprecated
)

async def initialize_context_async(
Expand Down
2 changes: 1 addition & 1 deletion pyrit/executor/attack/multi_turn/crescendo.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ async def _setup_async(self, *, context: CrescendoAttackContext) -> None:
system_prompt=system_prompt,
conversation_id=context.session.adversarial_chat_conversation_id,
attack_identifier=self.get_identifier(),
labels=context.memory_labels,
labels=context.memory_labels, # deprecated
)

# Initialize backtrack count in context
Expand Down
2 changes: 1 addition & 1 deletion pyrit/executor/attack/multi_turn/red_teaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ async def _setup_async(self, *, context: MultiTurnAttackContext[Any]) -> None:
system_prompt=adversarial_system_prompt,
conversation_id=context.session.adversarial_chat_conversation_id,
attack_identifier=self.get_identifier(),
labels=context.memory_labels,
labels=context.memory_labels, # deprecated
)

async def _perform_async(self, *, context: MultiTurnAttackContext[Any]) -> AttackResult:
Expand Down
2 changes: 1 addition & 1 deletion pyrit/executor/attack/multi_turn/tree_of_attacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,7 +982,7 @@ async def _generate_first_turn_prompt_async(self, objective: str) -> str:
system_prompt=system_prompt,
conversation_id=self.adversarial_chat_conversation_id,
attack_identifier=self._attack_id,
labels=self._memory_labels,
labels=self._memory_labels, # deprecated
)

logger.debug(f"Node {self.node_id}: Using initial seed prompt for first turn")
Expand Down
4 changes: 2 additions & 2 deletions pyrit/executor/promptgen/anecdoctor.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ async def _setup_async(self, *, context: AnecdoctorContext) -> None:
system_prompt=system_prompt,
conversation_id=context.conversation_id,
attack_identifier=self.get_identifier(),
labels=context.memory_labels,
labels=context.memory_labels, # deprecated
)

async def _perform_async(self, *, context: AnecdoctorContext) -> AnecdoctorResult:
Expand Down Expand Up @@ -376,7 +376,7 @@ async def _extract_knowledge_graph_async(self, *, context: AnecdoctorContext) ->
system_prompt=kg_system_prompt,
conversation_id=kg_conversation_id,
attack_identifier=self.get_identifier(),
labels=self._memory_labels,
labels=self._memory_labels, # deprecated
)

# Format examples for knowledge graph extraction using few-shot format
Expand Down
2 changes: 1 addition & 1 deletion pyrit/models/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ def construct_response_from_request(
role="assistant",
original_value=resp_text,
conversation_id=request.conversation_id,
labels=request.labels,
labels=request.labels, # deprecated
prompt_target_identifier=request.prompt_target_identifier,
attack_identifier=request.attack_identifier,
original_value_data_type=response_type,
Expand Down
12 changes: 10 additions & 2 deletions pyrit/models/message_piece.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, get_args
from uuid import uuid4

from pyrit.common.deprecation import print_deprecation_message
from pyrit.identifiers.component_identifier import ComponentIdentifier
from pyrit.models.literals import ChatMessageRole, PromptDataType, PromptResponseError

Expand Down Expand Up @@ -69,6 +70,7 @@ def __init__(
Defaults to None.
sequence: The order of the conversation within a conversation_id. Defaults to -1.
labels: The labels associated with the memory entry. Several can be standardized. Defaults to None.
Deprecated: This parameter will be removed in a release 0.16.0.
prompt_metadata: The metadata associated with the prompt. This can be specific to any scenarios.
Because memory is how components talk with each other, this can be component specific.
e.g. the URI from a file uploaded to a blob store, or a document type you want to upload.
Expand Down Expand Up @@ -117,6 +119,12 @@ def __init__(
self.timestamp = timestamp.replace(tzinfo=timezone.utc)
else:
self.timestamp = timestamp
if labels is not None:
print_deprecation_message(
old_item="MessagePiece(..., labels=...)",
new_item="MessagePiece(...)",
removed_in="0.16.0",
)
self.labels = labels or {}
self.prompt_metadata = prompt_metadata or {}

Expand Down Expand Up @@ -212,7 +220,7 @@ def copy_lineage_from(self, source: MessagePiece) -> None:
source: The piece whose lineage metadata is authoritative.
"""
self.conversation_id = source.conversation_id
self.labels = dict(source.labels)
self.labels = dict(source.labels) # deprecated
self.attack_identifier = source.attack_identifier
self.prompt_target_identifier = source.prompt_target_identifier
self.prompt_metadata = dict(source.prompt_metadata)
Expand Down Expand Up @@ -327,7 +335,7 @@ def to_dict(self) -> dict[str, object]:
"conversation_id": self.conversation_id,
"sequence": self.sequence,
"timestamp": self.timestamp.isoformat() if self.timestamp else None,
"labels": self.labels,
"labels": self.labels, # deprecated
"targeted_harm_categories": self.targeted_harm_categories if self.targeted_harm_categories else None,
"prompt_metadata": self.prompt_metadata,
"converter_identifiers": [conv.to_dict() for conv in self.converter_identifiers],
Expand Down
10 changes: 9 additions & 1 deletion pyrit/prompt_normalizer/prompt_normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Any, Optional
from uuid import uuid4

from pyrit.common.deprecation import print_deprecation_message
from pyrit.exceptions import (
ComponentRole,
EmptyResponseException,
Expand Down Expand Up @@ -80,6 +81,7 @@ async def send_prompt_async(
response_converter_configurations (list[PromptConverterConfiguration], optional): Configurations for
converting the response. Defaults to an empty list.
labels (Optional[dict[str, str]], optional): Labels associated with the request. Defaults to None.
Deprecated: This parameter will be removed in a release 0.16.0.
attack_identifier (Optional[ComponentIdentifier], optional): Identifier for the attack. Defaults to
None.

Expand All @@ -90,6 +92,12 @@ async def send_prompt_async(
Exception: If an error occurs during the request processing.
ValueError: If the message pieces are not part of the same sequence.
"""
if labels is not None:
print_deprecation_message(
old_item="send_prompt_async(..., labels=...)",
new_item="send_prompt_async(...)",
removed_in="0.16.0",
)
# Validates that the MessagePieces in the Message are part of the same sequence
request_converter_configurations = request_converter_configurations or []
response_converter_configurations = response_converter_configurations or []
Expand All @@ -103,7 +111,7 @@ async def send_prompt_async(
for piece in request.message_pieces:
piece.conversation_id = conversation_id
if labels:
piece.labels = labels
piece.labels = labels # deprecated
piece.prompt_target_identifier = target.get_identifier()
if attack_identifier:
piece.attack_identifier = attack_identifier
Expand Down
Loading
Loading