diff --git a/pyrit/backend/main.py b/pyrit/backend/main.py index f580a6a0f..a1a9cad0b 100644 --- a/pyrit/backend/main.py +++ b/pyrit/backend/main.py @@ -18,7 +18,7 @@ import pyrit from pyrit.backend.middleware import RequestIdMiddleware, SecurityHeadersMiddleware, register_error_handlers from pyrit.backend.middleware.auth import EntraAuthMiddleware -from pyrit.backend.routes import attacks, auth, converters, health, labels, media, targets, version +from pyrit.backend.routes import attacks, auth, converters, health, labels, media, scenarios, targets, version from pyrit.memory import CentralMemory # Check for development mode from environment variable @@ -85,6 +85,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: app.include_router(attacks.router, prefix="/api", tags=["attacks"]) app.include_router(targets.router, prefix="/api", tags=["targets"]) app.include_router(converters.router, prefix="/api", tags=["converters"]) +app.include_router(scenarios.router, prefix="/api", tags=["scenarios"]) app.include_router(labels.router, prefix="/api", tags=["labels"]) app.include_router(health.router, prefix="/api", tags=["health"]) app.include_router(auth.router, prefix="/api", tags=["auth"]) diff --git a/pyrit/backend/models/__init__.py b/pyrit/backend/models/__init__.py index e40844933..d606d89eb 100644 --- a/pyrit/backend/models/__init__.py +++ b/pyrit/backend/models/__init__.py @@ -47,6 +47,10 @@ CreateConverterResponse, PreviewStep, ) +from pyrit.backend.models.scenarios import ( + ScenarioListResponse, + ScenarioSummary, +) from pyrit.backend.models.targets import ( CreateTargetRequest, TargetInstance, @@ -91,6 +95,9 @@ "CreateConverterRequest", "CreateConverterResponse", "PreviewStep", + # Scenarios + "ScenarioListResponse", + "ScenarioSummary", # Targets "CreateTargetRequest", "TargetInstance", diff --git a/pyrit/backend/models/scenarios.py b/pyrit/backend/models/scenarios.py new file mode 100644 index 000000000..f1d134ce7 --- /dev/null +++ b/pyrit/backend/models/scenarios.py @@ -0,0 +1,37 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Scenario API response models. + +Scenarios are multi-attack security testing campaigns. These models represent +the metadata about available scenarios (listing), not scenario execution results. +""" + +from typing import Optional + +from pydantic import BaseModel, Field + +from pyrit.backend.models.common import PaginationInfo + + +class ScenarioSummary(BaseModel): + """Summary of a registered scenario.""" + + scenario_name: str = Field(..., description="Registry key (e.g., 'foundry.red_team_agent')") + class_name: str = Field(..., description="Python class name (e.g., 'RedTeamAgentScenario')") + description: str = Field(..., description="Human-readable description of the scenario") + default_strategy: str = Field(..., description="Default strategy name used when none specified") + aggregate_strategies: list[str] = Field( + ..., description="Aggregate strategies that combine multiple attack approaches" + ) + all_strategies: list[str] = Field(..., description="All available concrete strategy names") + default_datasets: list[str] = Field(..., description="Default dataset names used by the scenario") + max_dataset_size: Optional[int] = Field(None, description="Maximum items per dataset (None means unlimited)") + + +class ScenarioListResponse(BaseModel): + """Response for listing scenarios.""" + + items: list[ScenarioSummary] = Field(..., description="List of scenario summaries") + pagination: PaginationInfo = Field(..., description="Pagination metadata") diff --git a/pyrit/backend/routes/__init__.py b/pyrit/backend/routes/__init__.py index 09283645e..ca412238e 100644 --- a/pyrit/backend/routes/__init__.py +++ b/pyrit/backend/routes/__init__.py @@ -5,7 +5,7 @@ API route handlers. """ -from pyrit.backend.routes import attacks, converters, health, labels, media, targets, version +from pyrit.backend.routes import attacks, converters, health, labels, media, scenarios, targets, version __all__ = [ "attacks", @@ -13,6 +13,7 @@ "health", "labels", "media", + "scenarios", "targets", "version", ] diff --git a/pyrit/backend/routes/scenarios.py b/pyrit/backend/routes/scenarios.py new file mode 100644 index 000000000..9cd3e2ef4 --- /dev/null +++ b/pyrit/backend/routes/scenarios.py @@ -0,0 +1,68 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Scenario API routes. + +Provides endpoints for listing available scenarios and their metadata. +""" + +from typing import Optional + +from fastapi import APIRouter, HTTPException, Query, status + +from pyrit.backend.models.common import ProblemDetail +from pyrit.backend.models.scenarios import ScenarioListResponse, ScenarioSummary +from pyrit.backend.services.scenario_service import get_scenario_service + +router = APIRouter(prefix="/scenarios", tags=["scenarios"]) + + +@router.get( + "", + response_model=ScenarioListResponse, +) +async def list_scenarios( + limit: int = Query(50, ge=1, le=200, description="Maximum items per page"), + cursor: Optional[str] = Query(None, description="Pagination cursor (scenario_name to start after)"), +) -> ScenarioListResponse: + """ + List all available scenarios. + + Returns scenario metadata including strategies, datasets, and defaults. + Use GET /api/scenarios/{scenario_name} for full details on a specific scenario. + + Returns: + ScenarioListResponse: Paginated list of scenario summaries. + """ + service = get_scenario_service() + return await service.list_scenarios_async(limit=limit, cursor=cursor) + + +@router.get( + "/{scenario_name:path}", + response_model=ScenarioSummary, + responses={ + 404: {"model": ProblemDetail, "description": "Scenario not found"}, + }, +) +async def get_scenario(scenario_name: str) -> ScenarioSummary: + """ + Get details for a specific scenario. + + Args: + scenario_name: Registry name of the scenario (e.g., 'foundry.red_team_agent'). + + Returns: + ScenarioSummary: Full scenario metadata. + """ + service = get_scenario_service() + + scenario = await service.get_scenario_async(scenario_name=scenario_name) + if not scenario: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Scenario '{scenario_name}' not found", + ) + + return scenario diff --git a/pyrit/backend/services/__init__.py b/pyrit/backend/services/__init__.py index fe7ac6c90..29807150a 100644 --- a/pyrit/backend/services/__init__.py +++ b/pyrit/backend/services/__init__.py @@ -15,6 +15,10 @@ ConverterService, get_converter_service, ) +from pyrit.backend.services.scenario_service import ( + ScenarioService, + get_scenario_service, +) from pyrit.backend.services.target_service import ( TargetService, get_target_service, @@ -25,6 +29,8 @@ "get_attack_service", "ConverterService", "get_converter_service", + "ScenarioService", + "get_scenario_service", "TargetService", "get_target_service", ] diff --git a/pyrit/backend/services/scenario_service.py b/pyrit/backend/services/scenario_service.py new file mode 100644 index 000000000..ffcf52045 --- /dev/null +++ b/pyrit/backend/services/scenario_service.py @@ -0,0 +1,133 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Scenario service for listing available scenarios. + +Provides read-only access to the ScenarioRegistry, exposing scenario metadata +through the REST API. +""" + +from functools import lru_cache +from typing import Optional + +from pyrit.backend.models.common import PaginationInfo +from pyrit.backend.models.scenarios import ScenarioListResponse, ScenarioSummary +from pyrit.registry import ScenarioMetadata, ScenarioRegistry + + +def _metadata_to_summary(metadata: ScenarioMetadata) -> ScenarioSummary: + """ + Convert a ScenarioMetadata dataclass to a ScenarioSummary Pydantic model. + + Args: + metadata: The registry metadata for a scenario. + + Returns: + ScenarioSummary Pydantic model. + """ + return ScenarioSummary( + scenario_name=metadata.registry_name, + class_name=metadata.class_name, + description=metadata.class_description, + default_strategy=metadata.default_strategy, + aggregate_strategies=list(metadata.aggregate_strategies), + all_strategies=list(metadata.all_strategies), + default_datasets=list(metadata.default_datasets), + max_dataset_size=metadata.max_dataset_size, + ) + + +class ScenarioService: + """ + Service for listing available scenarios. + + Uses ScenarioRegistry as the source of truth for scenario metadata. + """ + + def __init__(self) -> None: + """Initialize the scenario service.""" + self._registry = ScenarioRegistry.get_registry_singleton() + + async def list_scenarios_async( + self, + *, + limit: int = 50, + cursor: Optional[str] = None, + ) -> ScenarioListResponse: + """ + List all available scenarios with pagination. + + Args: + limit: Maximum items to return per page. + cursor: Pagination cursor (scenario_name to start after). + + Returns: + ScenarioListResponse with paginated scenario summaries. + """ + all_metadata = self._registry.list_metadata() + all_summaries = [_metadata_to_summary(m) for m in all_metadata] + + page, has_more = self._paginate(items=all_summaries, cursor=cursor, limit=limit) + next_cursor = page[-1].scenario_name if has_more and page else None + + return ScenarioListResponse( + items=page, + pagination=PaginationInfo(limit=limit, has_more=has_more, next_cursor=next_cursor, prev_cursor=cursor), + ) + + async def get_scenario_async(self, *, scenario_name: str) -> Optional[ScenarioSummary]: + """ + Get a single scenario by registry name. + + Args: + scenario_name: The registry key of the scenario (e.g., 'foundry.red_team_agent'). + + Returns: + ScenarioSummary if found, None otherwise. + """ + all_metadata = self._registry.list_metadata() + for metadata in all_metadata: + if metadata.registry_name == scenario_name: + return _metadata_to_summary(metadata) + return None + + @staticmethod + def _paginate( + *, + items: list[ScenarioSummary], + cursor: Optional[str], + limit: int, + ) -> tuple[list[ScenarioSummary], bool]: + """ + Apply cursor-based pagination. + + Args: + items: Full list of items. + cursor: Scenario name to start after. + limit: Maximum items per page. + + Returns: + Tuple of (paginated items, has_more flag). + """ + start_idx = 0 + if cursor: + for i, item in enumerate(items): + if item.scenario_name == cursor: + start_idx = i + 1 + break + + page = items[start_idx : start_idx + limit] + has_more = len(items) > start_idx + limit + return page, has_more + + +@lru_cache(maxsize=1) +def get_scenario_service() -> ScenarioService: + """ + Get the global scenario service instance. + + Returns: + The singleton ScenarioService instance. + """ + return ScenarioService() diff --git a/tests/unit/backend/test_scenario_service.py b/tests/unit/backend/test_scenario_service.py new file mode 100644 index 000000000..eccc762e4 --- /dev/null +++ b/tests/unit/backend/test_scenario_service.py @@ -0,0 +1,341 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for backend scenario service and routes. +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import status +from fastapi.testclient import TestClient + +from pyrit.backend.main import app +from pyrit.backend.models.common import PaginationInfo +from pyrit.backend.models.scenarios import ScenarioListResponse, ScenarioSummary +from pyrit.backend.services.scenario_service import ScenarioService, get_scenario_service +from pyrit.registry import ScenarioMetadata + + +@pytest.fixture +def client() -> TestClient: + """Create a test client for the FastAPI app.""" + return TestClient(app) + + +@pytest.fixture(autouse=True) +def clear_service_cache(): + """Clear the scenario service singleton cache between tests.""" + get_scenario_service.cache_clear() + yield + get_scenario_service.cache_clear() + + +def _make_scenario_metadata( + *, + registry_name: str = "test.scenario", + class_name: str = "TestScenario", + description: str = "A test scenario", + default_strategy: str = "default", + all_strategies: tuple[str, ...] = ("prompt_sending", "role_play"), + aggregate_strategies: tuple[str, ...] = ("all", "default"), + default_datasets: tuple[str, ...] = ("test_dataset",), + max_dataset_size: int | None = None, +) -> ScenarioMetadata: + """Create a ScenarioMetadata instance for testing.""" + return ScenarioMetadata( + registry_name=registry_name, + class_name=class_name, + class_module="pyrit.scenario.scenarios.test", + class_description=description, + default_strategy=default_strategy, + all_strategies=all_strategies, + aggregate_strategies=aggregate_strategies, + default_datasets=default_datasets, + max_dataset_size=max_dataset_size, + ) + + +# ============================================================================ +# ScenarioService Unit Tests +# ============================================================================ + + +class TestScenarioServiceListScenarios: + """Tests for ScenarioService.list_scenarios_async.""" + + @pytest.mark.asyncio + async def test_list_scenarios_returns_empty_when_no_scenarios(self) -> None: + """Test that list returns empty list when no scenarios are registered.""" + with patch.object(ScenarioService, "__init__", lambda self: None): + service = ScenarioService() + service._registry = MagicMock() + service._registry.list_metadata.return_value = [] + + result = await service.list_scenarios_async() + + assert result.items == [] + assert result.pagination.has_more is False + + @pytest.mark.asyncio + async def test_list_scenarios_returns_scenarios_from_registry(self) -> None: + """Test that list returns scenarios from registry.""" + metadata = _make_scenario_metadata() + + with patch.object(ScenarioService, "__init__", lambda self: None): + service = ScenarioService() + service._registry = MagicMock() + service._registry.list_metadata.return_value = [metadata] + + result = await service.list_scenarios_async() + + assert len(result.items) == 1 + assert result.items[0].scenario_name == "test.scenario" + assert result.items[0].class_name == "TestScenario" + assert result.items[0].description == "A test scenario" + assert result.items[0].default_strategy == "default" + assert result.items[0].aggregate_strategies == ["all", "default"] + assert result.items[0].all_strategies == ["prompt_sending", "role_play"] + assert result.items[0].default_datasets == ["test_dataset"] + assert result.items[0].max_dataset_size is None + + @pytest.mark.asyncio + async def test_list_scenarios_paginates_with_limit(self) -> None: + """Test that list respects the limit parameter.""" + metadata_list = [ + _make_scenario_metadata(registry_name=f"test.scenario_{i}", class_name=f"Scenario{i}") for i in range(5) + ] + + with patch.object(ScenarioService, "__init__", lambda self: None): + service = ScenarioService() + service._registry = MagicMock() + service._registry.list_metadata.return_value = metadata_list + + result = await service.list_scenarios_async(limit=3) + + assert len(result.items) == 3 + assert result.pagination.has_more is True + assert result.pagination.next_cursor == "test.scenario_2" + + @pytest.mark.asyncio + async def test_list_scenarios_paginates_with_cursor(self) -> None: + """Test that list uses cursor for pagination.""" + metadata_list = [ + _make_scenario_metadata(registry_name=f"test.scenario_{i}", class_name=f"Scenario{i}") for i in range(5) + ] + + with patch.object(ScenarioService, "__init__", lambda self: None): + service = ScenarioService() + service._registry = MagicMock() + service._registry.list_metadata.return_value = metadata_list + + result = await service.list_scenarios_async(limit=2, cursor="test.scenario_1") + + assert len(result.items) == 2 + assert result.items[0].scenario_name == "test.scenario_2" + assert result.items[1].scenario_name == "test.scenario_3" + assert result.pagination.has_more is True + + @pytest.mark.asyncio + async def test_list_scenarios_last_page_has_more_false(self) -> None: + """Test that last page shows has_more=False.""" + metadata_list = [ + _make_scenario_metadata(registry_name=f"test.scenario_{i}", class_name=f"Scenario{i}") for i in range(3) + ] + + with patch.object(ScenarioService, "__init__", lambda self: None): + service = ScenarioService() + service._registry = MagicMock() + service._registry.list_metadata.return_value = metadata_list + + result = await service.list_scenarios_async(limit=5) + + assert len(result.items) == 3 + assert result.pagination.has_more is False + assert result.pagination.next_cursor is None + + @pytest.mark.asyncio + async def test_list_scenarios_includes_max_dataset_size(self) -> None: + """Test that max_dataset_size is included in response.""" + metadata = _make_scenario_metadata(max_dataset_size=10) + + with patch.object(ScenarioService, "__init__", lambda self: None): + service = ScenarioService() + service._registry = MagicMock() + service._registry.list_metadata.return_value = [metadata] + + result = await service.list_scenarios_async() + + assert result.items[0].max_dataset_size == 10 + + +class TestScenarioServiceGetScenario: + """Tests for ScenarioService.get_scenario_async.""" + + @pytest.mark.asyncio + async def test_get_scenario_returns_matching_scenario(self) -> None: + """Test that get returns the matching scenario.""" + metadata = _make_scenario_metadata(registry_name="foundry.red_team_agent") + + with patch.object(ScenarioService, "__init__", lambda self: None): + service = ScenarioService() + service._registry = MagicMock() + service._registry.list_metadata.return_value = [metadata] + + result = await service.get_scenario_async(scenario_name="foundry.red_team_agent") + + assert result is not None + assert result.scenario_name == "foundry.red_team_agent" + + @pytest.mark.asyncio + async def test_get_scenario_returns_none_for_missing(self) -> None: + """Test that get returns None when scenario not found.""" + with patch.object(ScenarioService, "__init__", lambda self: None): + service = ScenarioService() + service._registry = MagicMock() + service._registry.list_metadata.return_value = [] + + result = await service.get_scenario_async(scenario_name="nonexistent") + + assert result is None + + +# ============================================================================ +# Route Tests +# ============================================================================ + + +class TestScenarioRoutes: + """Tests for scenario API routes.""" + + def test_list_scenarios_returns_200(self, client: TestClient) -> None: + """Test that GET /api/scenarios returns 200.""" + with patch("pyrit.backend.routes.scenarios.get_scenario_service") as mock_get_service: + mock_service = MagicMock() + mock_service.list_scenarios_async = AsyncMock( + return_value=ScenarioListResponse( + items=[], + pagination=PaginationInfo(limit=50, has_more=False, next_cursor=None, prev_cursor=None), + ) + ) + mock_get_service.return_value = mock_service + + response = client.get("/api/scenarios") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["items"] == [] + assert data["pagination"]["has_more"] is False + + def test_list_scenarios_with_items(self, client: TestClient) -> None: + """Test that GET /api/scenarios returns scenario data.""" + summary = ScenarioSummary( + scenario_name="foundry.red_team_agent", + class_name="RedTeamAgentScenario", + description="Red team agent testing", + default_strategy="default", + aggregate_strategies=["all", "default"], + all_strategies=["prompt_sending", "role_play"], + default_datasets=["airt_hate"], + max_dataset_size=10, + ) + + with patch("pyrit.backend.routes.scenarios.get_scenario_service") as mock_get_service: + mock_service = MagicMock() + mock_service.list_scenarios_async = AsyncMock( + return_value=ScenarioListResponse( + items=[summary], + pagination=PaginationInfo(limit=50, has_more=False, next_cursor=None, prev_cursor=None), + ) + ) + mock_get_service.return_value = mock_service + + response = client.get("/api/scenarios") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data["items"]) == 1 + item = data["items"][0] + assert item["scenario_name"] == "foundry.red_team_agent" + assert item["class_name"] == "RedTeamAgentScenario" + assert item["default_strategy"] == "default" + assert item["aggregate_strategies"] == ["all", "default"] + assert item["all_strategies"] == ["prompt_sending", "role_play"] + assert item["default_datasets"] == ["airt_hate"] + assert item["max_dataset_size"] == 10 + + def test_list_scenarios_passes_pagination_params(self, client: TestClient) -> None: + """Test that pagination params are forwarded to service.""" + with patch("pyrit.backend.routes.scenarios.get_scenario_service") as mock_get_service: + mock_service = MagicMock() + mock_service.list_scenarios_async = AsyncMock( + return_value=ScenarioListResponse( + items=[], + pagination=PaginationInfo(limit=10, has_more=False, next_cursor=None, prev_cursor=None), + ) + ) + mock_get_service.return_value = mock_service + + response = client.get("/api/scenarios?limit=10&cursor=test.scenario_1") + + assert response.status_code == status.HTTP_200_OK + mock_service.list_scenarios_async.assert_called_once_with(limit=10, cursor="test.scenario_1") + + def test_get_scenario_returns_200(self, client: TestClient) -> None: + """Test that GET /api/scenarios/{name} returns 200 when found.""" + summary = ScenarioSummary( + scenario_name="foundry.red_team_agent", + class_name="RedTeamAgentScenario", + description="Red team agent testing", + default_strategy="default", + aggregate_strategies=["all"], + all_strategies=["prompt_sending"], + default_datasets=["airt_hate"], + max_dataset_size=None, + ) + + with patch("pyrit.backend.routes.scenarios.get_scenario_service") as mock_get_service: + mock_service = MagicMock() + mock_service.get_scenario_async = AsyncMock(return_value=summary) + mock_get_service.return_value = mock_service + + response = client.get("/api/scenarios/foundry.red_team_agent") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["scenario_name"] == "foundry.red_team_agent" + + def test_get_scenario_returns_404_when_not_found(self, client: TestClient) -> None: + """Test that GET /api/scenarios/{name} returns 404 when not found.""" + with patch("pyrit.backend.routes.scenarios.get_scenario_service") as mock_get_service: + mock_service = MagicMock() + mock_service.get_scenario_async = AsyncMock(return_value=None) + mock_get_service.return_value = mock_service + + response = client.get("/api/scenarios/nonexistent") + + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_get_scenario_with_dotted_name(self, client: TestClient) -> None: + """Test that dotted scenario names (e.g., 'foundry.red_team_agent') work in path.""" + summary = ScenarioSummary( + scenario_name="garak.encoding", + class_name="EncodingScenario", + description="Encoding scenario", + default_strategy="all", + aggregate_strategies=["all"], + all_strategies=["base64", "rot13"], + default_datasets=[], + max_dataset_size=None, + ) + + with patch("pyrit.backend.routes.scenarios.get_scenario_service") as mock_get_service: + mock_service = MagicMock() + mock_service.get_scenario_async = AsyncMock(return_value=summary) + mock_get_service.return_value = mock_service + + response = client.get("/api/scenarios/garak.encoding") + + assert response.status_code == status.HTTP_200_OK + mock_service.get_scenario_async.assert_called_once_with(scenario_name="garak.encoding")