diff --git a/pyrit/common/display_response.py b/pyrit/common/display_response.py index 6a97af39cc..1d1d84cc23 100644 --- a/pyrit/common/display_response.py +++ b/pyrit/common/display_response.py @@ -1,56 +1,65 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -import io -import logging - -from PIL import Image - -from pyrit.common.notebook_utils import is_in_ipython_session -from pyrit.memory import CentralMemory -from pyrit.models import AzureBlobStorageIO, DiskStorageIO, MessagePiece - -logger = logging.getLogger(__name__) - - -async def display_image_response(response_piece: MessagePiece) -> None: - """ - Display response images if running in notebook environment. - - Args: - response_piece (MessagePiece): The response piece to display. - - Raises: - RuntimeError: If storage IO is not initialized. - """ - memory = CentralMemory.get_memory_instance() - if ( - response_piece.response_error == "none" - and response_piece.converted_value_data_type == "image_path" - and is_in_ipython_session() - ): - image_location = response_piece.converted_value - - try: - if memory.results_storage_io is None: - raise RuntimeError("Storage IO not initialized") - image_bytes = await memory.results_storage_io.read_file(image_location) - except Exception as e: - if isinstance(memory.results_storage_io, AzureBlobStorageIO): - try: - # Fallback to reading from disk if the storage IO fails - image_bytes = await DiskStorageIO().read_file(image_location) - except Exception as exc: - logger.error(f"Failed to read image from {image_location}. Full exception: {str(exc)}") - return - else: - logger.error(f"Failed to read image from {image_location}. Full exception: {str(e)}") - return - - image_stream = io.BytesIO(image_bytes) - image = Image.open(image_stream) - - # Jupyter built-in display function only works in notebooks. - display(image) # type: ignore[name-defined] # noqa: F821 - if response_piece.response_error == "blocked": - logger.info("---\nContent blocked, cannot show a response.\n---") +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import io +import logging + +from PIL import Image, ImageEnhance + +from pyrit.common.notebook_utils import is_in_ipython_session +from pyrit.memory import CentralMemory +from pyrit.models import AzureBlobStorageIO, DiskStorageIO, MessagePiece + +logger = logging.getLogger(__name__) + + +async def display_image_response(response_piece: MessagePiece, safe_outputs: bool = False) -> None: + """ + Display response images if running in notebook environment. + + Args: + response_piece (MessagePiece): The response piece to display. + safe_outputs (bool): Whether to sanitize image outputs before displaying them. + + Raises: + RuntimeError: If storage IO is not initialized. + """ + memory = CentralMemory.get_memory_instance() + if ( + response_piece.response_error == "none" + and response_piece.converted_value_data_type == "image_path" + and is_in_ipython_session() + ): + image_location = response_piece.converted_value + + try: + if memory.results_storage_io is None: + raise RuntimeError("Storage IO not initialized") + image_bytes = await memory.results_storage_io.read_file(image_location) + except Exception as e: + if isinstance(memory.results_storage_io, AzureBlobStorageIO): + try: + # Fallback to reading from disk if the storage IO fails + image_bytes = await DiskStorageIO().read_file(image_location) + except Exception as exc: + logger.error(f"Failed to read image from {image_location}. Full exception: {str(exc)}") + return + else: + logger.error(f"Failed to read image from {image_location}. Full exception: {str(e)}") + return + + image_stream = io.BytesIO(image_bytes) + image: Image.Image = Image.open(image_stream) + + if safe_outputs: + new_width = int(image.width * 0.5) + new_height = int(image.height * 0.5) + image = image.resize((new_width, new_height), Image.Resampling.LANCZOS) + + image = ImageEnhance.Color(image).enhance(0.0) + image = image.rotate(90.0, expand=True, fillcolor=(255, 255, 255)) + + # Jupyter built-in display function only works in notebooks. + display(image) # type: ignore[name-defined] # noqa: F821 + if response_piece.response_error == "blocked": + logger.info("---\nContent blocked, cannot show a response.\n---") diff --git a/pyrit/executor/attack/printer/console_printer.py b/pyrit/executor/attack/printer/console_printer.py index ff1cce42f9..c1117fbf29 100644 --- a/pyrit/executor/attack/printer/console_printer.py +++ b/pyrit/executor/attack/printer/console_printer.py @@ -23,7 +23,9 @@ class ConsoleAttackResultPrinter(AttackResultPrinter): for consoles that don't support ANSI characters. """ - def __init__(self, *, width: int = 100, indent_size: int = 2, enable_colors: bool = True): + def __init__( + self, *, width: int = 100, indent_size: int = 2, enable_colors: bool = True, safe_outputs: bool = False + ): """ Initialize the console printer. @@ -34,6 +36,8 @@ def __init__(self, *, width: int = 100, indent_size: int = 2, enable_colors: boo Defaults to 2. enable_colors (bool): Whether to enable ANSI color output. When False, all output will be plain text without colors. Defaults to True. + safe_outputs (bool): Whether to sanitize image outputs before displaying them. + Defaults to False. Raises: ValueError: If width <= 0 or indent_size < 0. @@ -42,6 +46,7 @@ def __init__(self, *, width: int = 100, indent_size: int = 2, enable_colors: boo self._width = width self._indent = " " * indent_size self._enable_colors = enable_colors + self._safe_outputs = safe_outputs def _print_colored(self, text: str, *colors: str) -> None: """ @@ -227,7 +232,7 @@ async def print_messages_async( self._print_wrapped_text(piece.converted_value, Fore.YELLOW) # Display images if present - await display_image_response(piece) + await display_image_response(response_piece=piece, safe_outputs=self._safe_outputs) # Print scores with better formatting (only if scores are requested) if include_scores: diff --git a/pyrit/executor/attack/printer/markdown_printer.py b/pyrit/executor/attack/printer/markdown_printer.py index 5946ce985c..51d58fe469 100644 --- a/pyrit/executor/attack/printer/markdown_printer.py +++ b/pyrit/executor/attack/printer/markdown_printer.py @@ -3,6 +3,7 @@ import os from datetime import datetime, timezone +from pathlib import Path from pyrit.executor.attack.printer.attack_result_printer import AttackResultPrinter from pyrit.memory import CentralMemory @@ -18,7 +19,7 @@ class MarkdownAttackResultPrinter(AttackResultPrinter): markdown formatting that should be properly rendered. """ - def __init__(self, *, display_inline: bool = True): + def __init__(self, *, display_inline: bool = True, output_file_path: Path | None = None): """ Initialize the markdown printer. @@ -26,9 +27,12 @@ def __init__(self, *, display_inline: bool = True): display_inline (bool): If True, uses IPython.display to render markdown inline in Jupyter notebooks. If False, prints markdown strings. Defaults to True. + output_file_path (Path | None): If set, markdown output is appended to this + file instead of being displayed or printed. Defaults to None. """ self._memory = CentralMemory.get_memory_instance() self._display_inline = display_inline + self._output_file_path = output_file_path def _render_markdown(self, markdown_lines: list[str]) -> None: """ @@ -42,6 +46,12 @@ def _render_markdown(self, markdown_lines: list[str]) -> None: """ full_markdown = "\n".join(markdown_lines) + if self._output_file_path: + os.makedirs(os.path.dirname(self._output_file_path), exist_ok=True) + with open(self._output_file_path, "a", encoding="utf-8") as f: + f.write(full_markdown + "\n") + return + if self._display_inline: try: from IPython.display import Markdown, display @@ -351,7 +361,13 @@ def _format_image_content(self, *, image_path: str) -> list[str]: Returns: List[str]: List of markdown lines for the image. """ - relative_path = os.path.relpath(image_path) + # If output to file, set image path relative to output path + start_path = os.path.dirname(self._output_file_path) if self._output_file_path else "." + try: + relative_path = os.path.relpath(path=image_path, start=start_path) + except ValueError: + # os.path.relpath raises ValueError on Windows when paths are on different drives + relative_path = image_path posix_path = relative_path.replace("\\", "/") return [f"![Image]({posix_path})\n"] diff --git a/tests/unit/common/test_display_response.py b/tests/unit/common/test_display_response.py index 23f1fea9d2..7301492b83 100644 --- a/tests/unit/common/test_display_response.py +++ b/tests/unit/common/test_display_response.py @@ -2,13 +2,28 @@ # Licensed under the MIT license. import logging +from io import BytesIO from unittest.mock import AsyncMock, MagicMock, patch import pytest +from PIL import Image from pyrit.common.display_response import display_image_response +@pytest.fixture +def sample_image_bytes(): + """Sample RGB image for testing with configurable format and size.""" + + def _create_image(format="PNG", size=(200, 200)): # noqa: A002 + img = Image.new("RGB", size, color=(125, 125, 125)) + img_bytes = BytesIO() + img.save(img_bytes, format=format) + return img_bytes.getvalue() + + return _create_image + + @pytest.fixture() def _mock_central_memory(): mock_memory = MagicMock() @@ -66,6 +81,29 @@ async def test_display_image_reads_and_displays(mock_display, mock_image, mock_i mock_display.assert_called_once_with(mock_img_obj) +@pytest.mark.asyncio +@patch("pyrit.common.display_response.is_in_ipython_session", return_value=True) +@patch("pyrit.common.display_response.display", create=True) +async def test_display_image_applies_safe_outputs(mock_display, mock_ipython, _mock_central_memory, sample_image_bytes): + original_size = (200, 100) + image_bytes = sample_image_bytes(format="PNG", size=original_size) + _mock_central_memory.results_storage_io.read_file = AsyncMock(return_value=image_bytes) + + piece = MagicMock() + piece.response_error = "none" + piece.converted_value_data_type = "image_path" + piece.converted_value = "path/to/img.png" + + await display_image_response(piece, safe_outputs=True) + + displayed_image = mock_display.call_args[0][0] + expected_size = (50, 100) + assert displayed_image.size == expected_size + + pixels = list(displayed_image.get_flattened_data()) + assert all(r == g == b for r, g, b in pixels) + + @pytest.mark.asyncio @patch("pyrit.common.display_response.is_in_ipython_session", return_value=True) async def test_display_image_logs_error_on_read_failure(mock_ipython, _mock_central_memory, caplog): diff --git a/tests/unit/executor/attack/core/test_markdown_printer.py b/tests/unit/executor/attack/printer/test_markdown_printer.py similarity index 74% rename from tests/unit/executor/attack/core/test_markdown_printer.py rename to tests/unit/executor/attack/printer/test_markdown_printer.py index fc0ff0adbf..e9fcc816ff 100644 --- a/tests/unit/executor/attack/core/test_markdown_printer.py +++ b/tests/unit/executor/attack/printer/test_markdown_printer.py @@ -3,6 +3,7 @@ import os import uuid +from pathlib import Path from unittest.mock import MagicMock, patch import pytest @@ -36,6 +37,11 @@ def markdown_printer(patch_central_database): return MarkdownAttackResultPrinter(display_inline=False) +@pytest.fixture +def markdown_printer_to_file(patch_central_database, tmp_path): + return MarkdownAttackResultPrinter(output_file_path=Path(tmp_path) / "output" / "output.md") + + @pytest.fixture def sample_boolean_score(): return Score( @@ -150,6 +156,36 @@ def test_format_image_content(markdown_printer): assert "image.png" in formatted[0] +def test_format_image_content_relative_to_output_file(markdown_printer_to_file, tmp_path): + """Test that image path is relative to output file dir when outputting to file.""" + image_path = os.path.join(str(tmp_path), "images", "screenshot.png") + + # When outputting to file, path should be relative to the output file's directory + formatted_file = markdown_printer_to_file._format_image_content(image_path=image_path) + expected_file_rel = "../images/screenshot.png" + assert f"![Image]({expected_file_rel})" in formatted_file[0] + + +def test_format_image_content_relative_to_cwd(markdown_printer): + """Test that image path is relative to cwd when outputting to console.""" + # Use a path under the current working directory to avoid cross-drive issues + image_path = os.path.join(os.getcwd(), "images", "screenshot.png") + + formatted_console = markdown_printer._format_image_content(image_path=image_path) + expected_console_rel = os.path.relpath(image_path).replace("\\", "/") + assert f"![Image]({expected_console_rel})" in formatted_console[0] + + +def test_format_image_content_relpath_error_fallback(markdown_printer_to_file): + """Test that image path falls back to absolute path when a relative path cannot be computed.""" + image_path = "C:\\other_drive\\images\\screenshot.png" + + with patch("pyrit.executor.attack.printer.markdown_printer.os.path.relpath", side_effect=ValueError): + formatted = markdown_printer_to_file._format_image_content(image_path=image_path) + + assert "![Image](C:/other_drive/images/screenshot.png)" in formatted[0] + + def test_format_audio_content(markdown_printer): """Test audio content formatting.""" audio_path = "test.wav" @@ -247,3 +283,40 @@ async def test_print_summary_async(markdown_printer, sample_attack_result, capsy assert "Test objective" in captured.out assert "TestAttack" in captured.out assert "test-conv-123" in captured.out + + +@pytest.mark.asyncio +async def test_output_file_path_appends_to_file(markdown_printer_to_file, sample_attack_result, capsys): + """Test that output_file_path writes markdown to file and produces no stdout.""" + await markdown_printer_to_file.print_result_async(sample_attack_result) + + content = markdown_printer_to_file._output_file_path.read_text(encoding="utf-8") + assert "Attack Result: SUCCESS" in content + assert "## Attack Summary" in content + + captured = capsys.readouterr() + assert captured.out == "" + + +@pytest.mark.asyncio +async def test_output_file_path_appends_multiple_calls(markdown_printer_to_file, sample_attack_result): + """Test that calling print twice appends both reports to the same file.""" + await markdown_printer_to_file.print_result_async(sample_attack_result) + await markdown_printer_to_file.print_result_async(sample_attack_result) + + content = markdown_printer_to_file._output_file_path.read_text(encoding="utf-8") + assert content.count("Attack Result: SUCCESS") == 2 + + +@pytest.mark.asyncio +async def test_output_file_path_none_does_not_write( + markdown_printer, sample_attack_result, mock_memory, tmp_path, capsys +): + """Test that default output_file_path=None prints to stdout and writes no file.""" + await markdown_printer.print_result_async(sample_attack_result) + + captured = capsys.readouterr() + assert "Attack Result: SUCCESS" in captured.out + + # No file should have been created in tmp_path + assert list(tmp_path.iterdir()) == []