Skip to content
121 changes: 65 additions & 56 deletions pyrit/common/display_response.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,65 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import io
import logging

from PIL import Image

from pyrit.common.notebook_utils import is_in_ipython_session
from pyrit.memory import CentralMemory
from pyrit.models import AzureBlobStorageIO, DiskStorageIO, MessagePiece

logger = logging.getLogger(__name__)


async def display_image_response(response_piece: MessagePiece) -> None:
"""
Display response images if running in notebook environment.

Args:
response_piece (MessagePiece): The response piece to display.

Raises:
RuntimeError: If storage IO is not initialized.
"""
memory = CentralMemory.get_memory_instance()
if (
response_piece.response_error == "none"
and response_piece.converted_value_data_type == "image_path"
and is_in_ipython_session()
):
image_location = response_piece.converted_value

try:
if memory.results_storage_io is None:
raise RuntimeError("Storage IO not initialized")
image_bytes = await memory.results_storage_io.read_file(image_location)
except Exception as e:
if isinstance(memory.results_storage_io, AzureBlobStorageIO):
try:
# Fallback to reading from disk if the storage IO fails
image_bytes = await DiskStorageIO().read_file(image_location)
except Exception as exc:
logger.error(f"Failed to read image from {image_location}. Full exception: {str(exc)}")
return
else:
logger.error(f"Failed to read image from {image_location}. Full exception: {str(e)}")
return

image_stream = io.BytesIO(image_bytes)
image = Image.open(image_stream)

# Jupyter built-in display function only works in notebooks.
display(image) # type: ignore[name-defined] # noqa: F821
if response_piece.response_error == "blocked":
logger.info("---\nContent blocked, cannot show a response.\n---")
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import io
import logging

from PIL import Image, ImageEnhance

from pyrit.common.notebook_utils import is_in_ipython_session
from pyrit.memory import CentralMemory
from pyrit.models import AzureBlobStorageIO, DiskStorageIO, MessagePiece

logger = logging.getLogger(__name__)


async def display_image_response(response_piece: MessagePiece, safe_outputs: bool = False) -> None:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"safe" implies that we hide only unsafe ones. Maybe hide_outputs or blur_outputs or similar would be preferable? Also, this only applies to images, not all outputs.

I also don't want to turn this on in our example notebooks because the entire point is to illustrate how the package works. We can show it in one of them, of course.

"""
Display response images if running in notebook environment.

Args:
response_piece (MessagePiece): The response piece to display.
safe_outputs (bool): Whether to sanitize image outputs before displaying them.

Raises:
RuntimeError: If storage IO is not initialized.
"""
memory = CentralMemory.get_memory_instance()
if (
response_piece.response_error == "none"
and response_piece.converted_value_data_type == "image_path"
and is_in_ipython_session()
):
image_location = response_piece.converted_value

try:
if memory.results_storage_io is None:
raise RuntimeError("Storage IO not initialized")
image_bytes = await memory.results_storage_io.read_file(image_location)
except Exception as e:
if isinstance(memory.results_storage_io, AzureBlobStorageIO):
try:
# Fallback to reading from disk if the storage IO fails
image_bytes = await DiskStorageIO().read_file(image_location)
except Exception as exc:
logger.error(f"Failed to read image from {image_location}. Full exception: {str(exc)}")
return
else:
logger.error(f"Failed to read image from {image_location}. Full exception: {str(e)}")
return

image_stream = io.BytesIO(image_bytes)
image: Image.Image = Image.open(image_stream)

if safe_outputs:
new_width = int(image.width * 0.5)
new_height = int(image.height * 0.5)
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)

image = ImageEnhance.Color(image).enhance(0.0)
image = image.rotate(90.0, expand=True, fillcolor=(255, 255, 255))
Comment on lines +54 to +60
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the rotation and the other changes? Blurring seems more appropriate?


# Jupyter built-in display function only works in notebooks.
display(image) # type: ignore[name-defined] # noqa: F821
if response_piece.response_error == "blocked":
logger.info("---\nContent blocked, cannot show a response.\n---")
9 changes: 7 additions & 2 deletions pyrit/executor/attack/printer/console_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ class ConsoleAttackResultPrinter(AttackResultPrinter):
for consoles that don't support ANSI characters.
"""

def __init__(self, *, width: int = 100, indent_size: int = 2, enable_colors: bool = True):
def __init__(
self, *, width: int = 100, indent_size: int = 2, enable_colors: bool = True, safe_outputs: bool = False
):
"""
Initialize the console printer.

Expand All @@ -34,6 +36,8 @@ def __init__(self, *, width: int = 100, indent_size: int = 2, enable_colors: boo
Defaults to 2.
enable_colors (bool): Whether to enable ANSI color output. When False,
all output will be plain text without colors. Defaults to True.
safe_outputs (bool): Whether to sanitize image outputs before displaying them.
Defaults to False.

Raises:
ValueError: If width <= 0 or indent_size < 0.
Expand All @@ -42,6 +46,7 @@ def __init__(self, *, width: int = 100, indent_size: int = 2, enable_colors: boo
self._width = width
self._indent = " " * indent_size
self._enable_colors = enable_colors
self._safe_outputs = safe_outputs

def _print_colored(self, text: str, *colors: str) -> None:
"""
Expand Down Expand Up @@ -227,7 +232,7 @@ async def print_messages_async(
self._print_wrapped_text(piece.converted_value, Fore.YELLOW)

# Display images if present
await display_image_response(piece)
await display_image_response(response_piece=piece, safe_outputs=self._safe_outputs)

# Print scores with better formatting (only if scores are requested)
if include_scores:
Expand Down
20 changes: 18 additions & 2 deletions pyrit/executor/attack/printer/markdown_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import os
from datetime import datetime, timezone
from pathlib import Path

from pyrit.executor.attack.printer.attack_result_printer import AttackResultPrinter
from pyrit.memory import CentralMemory
Expand All @@ -18,17 +19,20 @@ class MarkdownAttackResultPrinter(AttackResultPrinter):
markdown formatting that should be properly rendered.
"""

def __init__(self, *, display_inline: bool = True):
def __init__(self, *, display_inline: bool = True, output_file_path: Path | None = None):
"""
Initialize the markdown printer.

Args:
display_inline (bool): If True, uses IPython.display to render markdown
inline in Jupyter notebooks. If False, prints markdown strings.
Defaults to True.
output_file_path (Path | None): If set, markdown output is appended to this
file instead of being displayed or printed. Defaults to None.
"""
self._memory = CentralMemory.get_memory_instance()
self._display_inline = display_inline
self._output_file_path = output_file_path

def _render_markdown(self, markdown_lines: list[str]) -> None:
"""
Expand All @@ -42,6 +46,12 @@ def _render_markdown(self, markdown_lines: list[str]) -> None:
"""
full_markdown = "\n".join(markdown_lines)

if self._output_file_path:
os.makedirs(os.path.dirname(self._output_file_path), exist_ok=True)
with open(self._output_file_path, "a", encoding="utf-8") as f:
f.write(full_markdown + "\n")
return

if self._display_inline:
try:
from IPython.display import Markdown, display
Expand Down Expand Up @@ -351,7 +361,13 @@ def _format_image_content(self, *, image_path: str) -> list[str]:
Returns:
List[str]: List of markdown lines for the image.
"""
relative_path = os.path.relpath(image_path)
# If output to file, set image path relative to output path
start_path = os.path.dirname(self._output_file_path) if self._output_file_path else "."
try:
relative_path = os.path.relpath(path=image_path, start=start_path)
except ValueError:
# os.path.relpath raises ValueError on Windows when paths are on different drives
relative_path = image_path
posix_path = relative_path.replace("\\", "/")
return [f"![Image]({posix_path})\n"]

Expand Down
38 changes: 38 additions & 0 deletions tests/unit/common/test_display_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,28 @@
# Licensed under the MIT license.

import logging
from io import BytesIO
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from PIL import Image

from pyrit.common.display_response import display_image_response


@pytest.fixture
def sample_image_bytes():
"""Sample RGB image for testing with configurable format and size."""

def _create_image(format="PNG", size=(200, 200)): # noqa: A002
img = Image.new("RGB", size, color=(125, 125, 125))
img_bytes = BytesIO()
img.save(img_bytes, format=format)
return img_bytes.getvalue()

return _create_image


@pytest.fixture()
def _mock_central_memory():
mock_memory = MagicMock()
Expand Down Expand Up @@ -66,6 +81,29 @@ async def test_display_image_reads_and_displays(mock_display, mock_image, mock_i
mock_display.assert_called_once_with(mock_img_obj)


@pytest.mark.asyncio
@patch("pyrit.common.display_response.is_in_ipython_session", return_value=True)
@patch("pyrit.common.display_response.display", create=True)
async def test_display_image_applies_safe_outputs(mock_display, mock_ipython, _mock_central_memory, sample_image_bytes):
original_size = (200, 100)
image_bytes = sample_image_bytes(format="PNG", size=original_size)
_mock_central_memory.results_storage_io.read_file = AsyncMock(return_value=image_bytes)

piece = MagicMock()
piece.response_error = "none"
piece.converted_value_data_type = "image_path"
piece.converted_value = "path/to/img.png"

await display_image_response(piece, safe_outputs=True)

displayed_image = mock_display.call_args[0][0]
expected_size = (50, 100)
assert displayed_image.size == expected_size

pixels = list(displayed_image.get_flattened_data())
assert all(r == g == b for r, g, b in pixels)


@pytest.mark.asyncio
@patch("pyrit.common.display_response.is_in_ipython_session", return_value=True)
async def test_display_image_logs_error_on_read_failure(mock_ipython, _mock_central_memory, caplog):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import os
import uuid
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest
Expand Down Expand Up @@ -36,6 +37,11 @@ def markdown_printer(patch_central_database):
return MarkdownAttackResultPrinter(display_inline=False)


@pytest.fixture
def markdown_printer_to_file(patch_central_database, tmp_path):
return MarkdownAttackResultPrinter(output_file_path=Path(tmp_path) / "output" / "output.md")


@pytest.fixture
def sample_boolean_score():
return Score(
Expand Down Expand Up @@ -150,6 +156,36 @@ def test_format_image_content(markdown_printer):
assert "image.png" in formatted[0]


def test_format_image_content_relative_to_output_file(markdown_printer_to_file, tmp_path):
"""Test that image path is relative to output file dir when outputting to file."""
image_path = os.path.join(str(tmp_path), "images", "screenshot.png")

# When outputting to file, path should be relative to the output file's directory
formatted_file = markdown_printer_to_file._format_image_content(image_path=image_path)
expected_file_rel = "../images/screenshot.png"
assert f"![Image]({expected_file_rel})" in formatted_file[0]


def test_format_image_content_relative_to_cwd(markdown_printer):
"""Test that image path is relative to cwd when outputting to console."""
# Use a path under the current working directory to avoid cross-drive issues
image_path = os.path.join(os.getcwd(), "images", "screenshot.png")

formatted_console = markdown_printer._format_image_content(image_path=image_path)
expected_console_rel = os.path.relpath(image_path).replace("\\", "/")
assert f"![Image]({expected_console_rel})" in formatted_console[0]


def test_format_image_content_relpath_error_fallback(markdown_printer_to_file):
"""Test that image path falls back to absolute path when a relative path cannot be computed."""
image_path = "C:\\other_drive\\images\\screenshot.png"

with patch("pyrit.executor.attack.printer.markdown_printer.os.path.relpath", side_effect=ValueError):
formatted = markdown_printer_to_file._format_image_content(image_path=image_path)

assert "![Image](C:/other_drive/images/screenshot.png)" in formatted[0]


def test_format_audio_content(markdown_printer):
"""Test audio content formatting."""
audio_path = "test.wav"
Expand Down Expand Up @@ -247,3 +283,40 @@ async def test_print_summary_async(markdown_printer, sample_attack_result, capsy
assert "Test objective" in captured.out
assert "TestAttack" in captured.out
assert "test-conv-123" in captured.out


@pytest.mark.asyncio
async def test_output_file_path_appends_to_file(markdown_printer_to_file, sample_attack_result, capsys):
"""Test that output_file_path writes markdown to file and produces no stdout."""
await markdown_printer_to_file.print_result_async(sample_attack_result)

content = markdown_printer_to_file._output_file_path.read_text(encoding="utf-8")
assert "Attack Result: SUCCESS" in content
assert "## Attack Summary" in content

captured = capsys.readouterr()
assert captured.out == ""


@pytest.mark.asyncio
async def test_output_file_path_appends_multiple_calls(markdown_printer_to_file, sample_attack_result):
"""Test that calling print twice appends both reports to the same file."""
await markdown_printer_to_file.print_result_async(sample_attack_result)
await markdown_printer_to_file.print_result_async(sample_attack_result)

content = markdown_printer_to_file._output_file_path.read_text(encoding="utf-8")
assert content.count("Attack Result: SUCCESS") == 2


@pytest.mark.asyncio
async def test_output_file_path_none_does_not_write(
markdown_printer, sample_attack_result, mock_memory, tmp_path, capsys
):
"""Test that default output_file_path=None prints to stdout and writes no file."""
await markdown_printer.print_result_async(sample_attack_result)

captured = capsys.readouterr()
assert "Attack Result: SUCCESS" in captured.out

# No file should have been created in tmp_path
assert list(tmp_path.iterdir()) == []
Loading