diff --git a/.github/workflows/README.md b/.github/workflows/README.md new file mode 100644 index 00000000..c53d04c9 --- /dev/null +++ b/.github/workflows/README.md @@ -0,0 +1,101 @@ +# Running GitHub Actions Locally + +This project uses GitHub Actions for CI/CD. You can run these workflows locally using [act](https://github.com/nektos/act), which simulates GitHub Actions in a Docker environment. + +## Workflow Overview + +### Main Workflow: `tests.yml` ⭐ **RECOMMENDED** +**Primary workflow for all tests and linting** +- Runs on: `push` and `pull_request` to `main` and `develop` +- Jobs: + - `backend-lint` - Lints Python backend code + - `simulator-lint` - Lints Python simulator code + - `frontend-lint` - Lints TypeScript/JavaScript frontend code + - `backend-tests` - Runs backend unit and integration tests (Python 3.11, 3.12, 3.13) + - `frontend-tests` - Runs frontend type checking + - `e2e-tests` - Runs end-to-end tests with Playwright (includes proper dependency installation) + - `build` - Builds the frontend application + +**Use this workflow for:** +- ✅ Full CI/CD pipeline +- ✅ Comprehensive testing +- ✅ Pull request validation +- ✅ **E2E tests (properly configured with dependencies)** + +### Secondary Workflow: `e2e-tests.yml` +**Standalone E2E test workflow (legacy/alternative)** +- Runs on: `push`, `pull_request` to `main` and `develop`, and `workflow_dispatch` +- Single job: `e2e` - Runs only E2E tests + +**Use this workflow for:** +- Quick E2E test runs +- Manual E2E testing via workflow_dispatch +- Focused E2E test debugging + +**Note:** The main `tests.yml` workflow is **recommended** for most use cases as it includes all tests, linting, and properly configured E2E tests with dependency installation. + +## Prerequisites + +1. **Docker**: Ensure Docker is installed and running + ```bash + docker --version + ``` + +2. 
**Install act**: + - **macOS**: `brew install act` + - **Linux**: Download from [act releases](https://github.com/nektos/act/releases) + - **Windows**: Use WSL or download from releases + +## Usage + +### List all workflows +```bash +act -l +``` + +### Run all workflows +```bash +act +``` + +### Run a specific workflow +```bash +act -W .github/workflows/tests.yml +act -W .github/workflows/e2e-tests.yml +``` + +### Run a specific job +```bash +act -j backend-tests +act -j frontend-tests +act -j e2e-tests +``` + +### Run on specific event +```bash +act push +act pull_request +``` + +### Use secrets (if needed) +Create a `.secrets` file in the repository root: +``` +SECRET_NAME=secret_value +``` + +Then run: +```bash +act --secret-file .secrets +``` + +## Limitations + +- Services (PostgreSQL, Redis) are automatically set up by act +- Some actions may behave differently locally vs. on GitHub +- Large workflows may take longer locally + +## Troubleshooting + +- If Docker images fail to pull, use `act --pull=false` +- For verbose output: `act -v` +- To use a specific platform: `act --container-architecture linux/amd64` diff --git a/.github/workflows/backend-tests.yml b/.github/workflows/backend-tests.yml new file mode 100644 index 00000000..786d2e2c --- /dev/null +++ b/.github/workflows/backend-tests.yml @@ -0,0 +1,67 @@ +name: Backend Tests + +on: + push: + branches: [main, develop] + paths: + - 'backend/**' + - '.github/workflows/backend-tests.yml' + pull_request: + branches: [main, develop] + paths: + - 'backend/**' + - '.github/workflows/backend-tests.yml' + +jobs: + test: + runs-on: ubuntu-latest + services: + postgres: + image: postgres:15 + env: + POSTGRES_PASSWORD: postgres + POSTGRES_DB: postgres + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 5432:5432 + redis: + image: redis:8 + options: >- + --health-cmd "redis-cli ping" + --health-interval 10s + --health-timeout 5s + --health-retries 5 
+ ports: + - 6379:6379 + + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.13' + - name: Install dependencies + working-directory: backend + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + - name: Run unit tests + working-directory: backend + run: | + pytest tests/unit -v --cov=api --cov-report=xml + env: + DATABASE_URL: postgresql+asyncpg://postgres:postgres@localhost:5432/postgres + REDIS_URL: redis://localhost:6379 + - name: Run integration tests + working-directory: backend + run: | + pytest tests/integration -v + env: + DATABASE_URL: postgresql+asyncpg://postgres:postgres@localhost:5432/postgres + REDIS_URL: redis://localhost:6379 + + diff --git a/.github/workflows/e2e-tests.yml b/.github/workflows/e2e-tests.yml new file mode 100644 index 00000000..ce763aa9 --- /dev/null +++ b/.github/workflows/e2e-tests.yml @@ -0,0 +1,114 @@ +name: E2E Tests + +on: + push: + branches: [main, develop] + pull_request: + branches: [main, develop] + workflow_dispatch: + +jobs: + e2e: + runs-on: ubuntu-latest + services: + postgres: + image: postgres:15 + env: + POSTGRES_PASSWORD: postgres + POSTGRES_DB: postgres + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 5432:5432 + redis: + image: redis:8 + options: >- + --health-cmd "redis-cli ping" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 6379:6379 + + steps: + - uses: actions/checkout@v4 + - name: Set up Node.js + uses: actions/setup-node@v4 + with: + node-version: '20' + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.13' + - name: Install backend dependencies + working-directory: backend + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + - name: Install frontend dependencies + working-directory: web + run: npm ci + - name: Install E2E test dependencies + 
working-directory: tests + run: npm ci + - name: Install Playwright browsers + working-directory: tests + run: npx playwright install --with-deps chromium + continue-on-error: true + - name: Run database migrations + working-directory: backend + run: | + alembic upgrade head + env: + DATABASE_URL: postgresql+asyncpg://postgres:postgres@localhost:5432/postgres + - name: Start backend + working-directory: backend + run: | + uvicorn main:app --host 0.0.0.0 --port 8000 & + env: + DATABASE_URL: postgresql+asyncpg://postgres:postgres@localhost:5432/postgres + REDIS_URL: redis://localhost:6379 + ISSUER_URI: http://localhost:8080/realms/tasks + PUBLIC_ISSUER_URI: http://localhost:8080/realms/tasks + CLIENT_ID: tasks-backend + CLIENT_SECRET: tasks-secret + ENV: test + - name: Build frontend + working-directory: web + run: npm run build + - name: Start frontend + working-directory: web + run: | + npm start & + env: + NEXT_PUBLIC_API_URL: http://localhost:8000/graphql + - name: Wait for backend + run: | + echo "Waiting for backend to start..." + timeout 120 bash -c 'until curl -f -s http://localhost:8000/health > /dev/null 2>&1; do sleep 2; done' + echo "Backend is ready!" + + - name: Wait for frontend + run: | + echo "Waiting for frontend to start..." + timeout 120 bash -c 'until curl -f -s http://localhost:3000 > /dev/null 2>&1; do sleep 2; done' + echo "Frontend is ready!" 
+ - name: Run E2E tests + working-directory: tests + run: | + npx playwright test + env: + E2E_BASE_URL: http://localhost:3000 + CI: true + - name: Upload test results + if: always() + uses: actions/upload-artifact@v4 + with: + name: playwright-report + path: playwright-report/ + retention-days: 30 + + diff --git a/.github/workflows/frontend-tests.yml b/.github/workflows/frontend-tests.yml new file mode 100644 index 00000000..8788a47f --- /dev/null +++ b/.github/workflows/frontend-tests.yml @@ -0,0 +1,31 @@ +name: Frontend Tests + +on: + push: + branches: [main, develop] + paths: + - 'web/**' + - '.github/workflows/frontend-tests.yml' + pull_request: + branches: [main, develop] + paths: + - 'web/**' + - '.github/workflows/frontend-tests.yml' + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Node.js + uses: actions/setup-node@v4 + with: + node-version: '20' + - name: Install dependencies + working-directory: web + run: npm ci + - name: Run linter + working-directory: web + run: npm run lint + + diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 00000000..68c8266a --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,402 @@ +name: Tests + +on: + push: + branches: [main, develop] + pull_request: + branches: [main, develop] + +jobs: + backend-lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.13" + + - name: Install dependencies + working-directory: backend + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + + - name: Run linting + working-directory: backend + run: | + pip install ruff + ruff check . 
--output-format=concise + + simulator-lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.13" + + - name: Install dependencies + working-directory: simulator + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + + - name: Run linting + working-directory: simulator + run: | + pip install ruff + ruff check . --output-format=concise + + frontend-lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Node.js + uses: actions/setup-node@v4 + with: + node-version: "20" + cache: "npm" + cache-dependency-path: web/package-lock.json + + - name: Install dependencies + working-directory: web + run: npm ci + + - name: Run linter + working-directory: web + run: npm run lint + + backend-tests: + needs: [backend-lint] + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.11", "3.12", "3.13"] + + services: + postgres: + image: postgres:15 + env: + POSTGRES_USER: test + POSTGRES_PASSWORD: test + POSTGRES_DB: test + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 5432:5432 + + redis: + image: redis:7-alpine + options: >- + --health-cmd "redis-cli ping" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 6379:6379 + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Cache pip packages + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('backend/requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pip- + + - name: Install dependencies + working-directory: backend + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + + - name: Run unit tests + working-directory: backend + run: | + pytest 
tests/unit -v --cov=api --cov=database --cov-report=xml --cov-report=term + + - name: Run integration tests + working-directory: backend + run: | + pytest tests/integration -v + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v4 + with: + file: ./backend/coverage.xml + flags: backend + name: backend-coverage + + frontend-tests: + needs: [frontend-lint] + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Set up Node.js + uses: actions/setup-node@v4 + with: + node-version: "20" + cache: "npm" + cache-dependency-path: web/package-lock.json + + - name: Install dependencies + working-directory: web + run: npm ci + + - name: Type check + working-directory: web + run: npx tsc --noEmit + + e2e-tests: + needs: [backend-tests, frontend-tests] + runs-on: ubuntu-latest + + services: + postgres: + image: postgres:15 + env: + POSTGRES_USER: test + POSTGRES_PASSWORD: test + POSTGRES_DB: test + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 5432:5432 + + redis: + image: redis:7-alpine + options: >- + --health-cmd "redis-cli ping" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 6379:6379 + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.13" + + - name: Set up Node.js + uses: actions/setup-node@v4 + with: + node-version: "20" + cache: "npm" + cache-dependency-path: web/package-lock.json + + - name: Install backend dependencies + working-directory: backend + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + + - name: Install frontend dependencies + working-directory: web + run: npm ci + + - name: Install E2E test dependencies + working-directory: tests + run: npm ci + + - name: Install Playwright browsers + working-directory: tests + run: npx playwright install --with-deps chromium + continue-on-error: true + + - name: Run 
database migrations + working-directory: backend + run: | + alembic upgrade head + env: + DATABASE_URL: postgresql+asyncpg://test:test@localhost:5432/test + + - name: Start backend server + working-directory: backend + run: | + uvicorn main:app --host 0.0.0.0 --port 8000 > /tmp/backend.log 2>&1 & + BACKEND_PID=$! + echo $BACKEND_PID > /tmp/backend.pid + echo "Backend started with PID: $BACKEND_PID" + sleep 3 + env: + DATABASE_URL: postgresql+asyncpg://test:test@localhost:5432/test + REDIS_URL: redis://localhost:6379 + ISSUER_URI: http://localhost:8080/realms/tasks + PUBLIC_ISSUER_URI: http://localhost:8080/realms/tasks + CLIENT_ID: tasks-backend + CLIENT_SECRET: tasks-secret + ENV: test + INFLUXDB_URL: http://localhost:8086 + INFLUXDB_TOKEN: test-token + INFLUXDB_ORG: test + INFLUXDB_BUCKET: test + + - name: Build frontend + working-directory: web + run: npm run build + env: + NEXT_PUBLIC_API_URL: http://localhost:8000/graphql + + - name: Start frontend server + working-directory: web + run: | + npm start > /tmp/frontend.log 2>&1 & + FRONTEND_PID=$! + echo $FRONTEND_PID > /tmp/frontend.pid + echo "Frontend started with PID: $FRONTEND_PID" + sleep 3 + env: + NEXT_PUBLIC_API_URL: http://localhost:8000/graphql + + - name: Wait for backend + run: | + echo "Waiting for backend to start..." + sleep 10 + for i in {1..60}; do + if curl -f -s http://localhost:8000/health > /dev/null 2>&1; then + echo "Backend is ready!" + exit 0 + fi + if [ $i -le 10 ]; then + echo "Attempt $i/60: Backend not ready yet..." + fi + sleep 2 + done + echo "Backend failed to start after 120 seconds" + echo "=== Backend Log ===" + cat /tmp/backend.log || echo "No backend log found" + echo "=== Checking if process is running ===" + ps aux | grep uvicorn || echo "No uvicorn process found" + exit 1 + + - name: Wait for frontend + run: | + echo "Waiting for frontend to start..." + sleep 5 + for i in {1..60}; do + if curl -f -s http://localhost:3000 > /dev/null 2>&1; then + echo "Frontend is ready!" 
+ exit 0 + fi + echo "Attempt $i/60: Frontend not ready yet..." + sleep 2 + done + echo "Frontend failed to start after 120 seconds" + echo "=== Frontend Log ===" + cat /tmp/frontend.log || echo "No frontend log found" + exit 1 + + - name: Verify servers are running + run: | + echo "=== Verifying servers ===" + if curl -f -s http://localhost:8000/health > /dev/null 2>&1; then + echo "✓ Backend is running" + else + echo "✗ Backend is not running" + echo "Backend log:" + tail -20 /tmp/backend.log || echo "No backend log" + exit 1 + fi + if curl -f -s http://localhost:3000 > /dev/null 2>&1; then + echo "✓ Frontend is running" + else + echo "✗ Frontend is not running" + echo "Frontend log:" + tail -20 /tmp/frontend.log || echo "No frontend log" + exit 1 + fi + echo "=== Server processes ===" + ps aux | grep -E "(uvicorn|node)" | grep -v grep || echo "No server processes found" + + - name: Run E2E tests + working-directory: tests + env: + E2E_BASE_URL: http://localhost:3000 + CI: true + run: | + echo "E2E_BASE_URL is set to: $E2E_BASE_URL" + echo "Testing connection to frontend..." 
+ curl -f -s http://localhost:3000 > /dev/null && echo "Frontend is accessible" || echo "Frontend is not accessible" + npx playwright test + + - name: Upload Playwright report + if: always() + uses: actions/upload-artifact@v4 + with: + name: playwright-report + path: tests/playwright-report/ + retention-days: 30 + + - name: Upload server logs + if: failure() + uses: actions/upload-artifact@v4 + with: + name: server-logs + path: | + /tmp/backend.log + /tmp/frontend.log + retention-days: 7 + + build: + needs: [backend-tests, frontend-tests, e2e-tests, simulator-lint] + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.13" + + - name: Set up Node.js + uses: actions/setup-node@v4 + with: + node-version: "20" + cache: "npm" + cache-dependency-path: web/package-lock.json + + - name: Install backend dependencies + working-directory: backend + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + + - name: Install frontend dependencies + working-directory: web + run: npm ci + + - name: Build frontend + working-directory: web + run: npm run build + + - name: Upload build artifacts + if: success() + uses: actions/upload-artifact@v4 + with: + name: frontend-build + path: web/build + retention-days: 7 diff --git a/.gitignore b/.gitignore index 87dedb74..fd70cdb7 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,6 @@ # misc .DS_Store *.pem +coverage.xml +.coverage +htmlcov/ diff --git a/E2E_TESTING.md b/E2E_TESTING.md new file mode 100644 index 00000000..6938a89c --- /dev/null +++ b/E2E_TESTING.md @@ -0,0 +1,61 @@ +# E2E Testing Guide + +## Running E2E Tests Locally + +### Prerequisites + +1. **Start Docker services** (PostgreSQL and Redis): + ```bash + docker-compose -f docker-compose.dev.yml up -d postgres redis + ``` + +2. 
**Start the backend server**: + ```bash + cd backend + source test_env/bin/activate # or your virtual environment + DATABASE_URL="postgresql+asyncpg://postgres:postgres@localhost:5432/postgres" \ + REDIS_URL="redis://localhost:6379" \ + ENV="test" \ + uvicorn main:app --host 0.0.0.0 --port 8000 + ``` + +3. **Start the frontend server**: + ```bash + cd web + npm run build # if not already built + NEXT_PUBLIC_API_URL="http://localhost:8000/graphql" npm start + ``` + +4. **Wait for servers to be ready**: + - Backend: `http://localhost:8000/health` should return `{"status": "ok"}` + - Frontend: `http://localhost:3000` should return HTTP 200 + +### Running Tests + +```bash +cd tests +E2E_BASE_URL="http://localhost:3000" CI=true npx playwright test +``` + +## NixOS Limitation + +**Note for NixOS users**: Playwright's Chromium browser cannot run directly on NixOS due to dynamic linking limitations. The error message will indicate: +``` +NixOS cannot run dynamically linked executables intended for generic +linux environments out of the box. +``` + +**Solutions for NixOS**: +1. Use GitHub Actions to run E2E tests (recommended) +2. Use a Docker container to run the tests +3. Configure NixOS with proper FHS (Filesystem Hierarchy Standard) support + +The tests are designed to work correctly on GitHub Actions (Ubuntu runners). + +## GitHub Actions + +E2E tests run automatically on: +- Push to `main` or `develop` branches +- Pull requests to `main` or `develop` branches + +The main workflow (`tests.yml`) includes proper E2E test setup with dependency installation. 
diff --git a/README.md b/README.md index c37331f4..7f1adaa3 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,8 @@ # helpwave tasks +[![Tests](https://github.com/helpwave/tasks/actions/workflows/tests.yml/badge.svg)](https://github.com/helpwave/tasks/actions/workflows/tests.yml) +[![License](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE) + **helpwave tasks** is a modern, open-source task and ward-management platform tailored for healthcare - designed to bring clarity, efficiency and structure to hospitals, wards and clinical workflows. ## Quick Start @@ -117,6 +120,47 @@ Once the development environment is running: - **keycloak/** - Keycloak realm configuration - **scaffold/** - Initial data for hospital structure +## Testing + +### Running Tests Locally + +**Backend Tests:** +```bash +cd backend +python -m pytest tests/unit -v +python -m pytest tests/integration -v +``` + +**Frontend Linting:** +```bash +cd web +npm run lint +``` + +**E2E Tests:** +```bash +cd tests +npm install +npx playwright test +``` + +### Running GitHub Actions Locally + +You can run GitHub Actions workflows locally using [act](https://github.com/nektos/act). See [.github/workflows/README.md](.github/workflows/README.md) for detailed instructions. + +Quick start: +```bash +# Install act (requires Docker) +brew install act # macOS +# or download from https://github.com/nektos/act/releases + +# Run all workflows +act + +# Run specific job +act -j backend-tests +``` + ## Docker Images All components are containerized and available on GitHub Container Registry: diff --git a/TESTING.md b/TESTING.md new file mode 100644 index 00000000..254a1feb --- /dev/null +++ b/TESTING.md @@ -0,0 +1,153 @@ +# Testing Guide + +This document describes how to run tests for all components of the tasks project. 
+ +## Quick Test Commands + +### Backend Tests +```bash +cd backend +python -m venv venv +source venv/bin/activate # On Windows: venv\Scripts\activate +pip install -r requirements.txt +pytest tests/unit -v +pytest tests/integration -v +pytest tests/ -v # Run all tests +``` + +### Frontend Linting +```bash +cd web +npm install +npm run lint +``` + +### E2E Tests +```bash +cd tests +npm install +npx playwright install chromium +npx playwright test +``` + +### All Linting +```bash +# Backend +cd backend +ruff check . + +# Simulator +cd simulator +ruff check . + +# Frontend +cd web +npm run lint +``` + +## Running GitHub Actions Locally + +Use [act](https://github.com/nektos/act) to run GitHub Actions workflows locally. + +### Installation + +**macOS:** +```bash +brew install act +``` + +**Linux:** +```bash +# Download from https://github.com/nektos/act/releases +# Or use package manager if available +``` + +**Prerequisites:** +- Docker must be installed and running + +### Usage + +```bash +# List all workflows +act -l + +# Run all workflows +act + +# Run specific workflow file +act -W .github/workflows/tests.yml + +# Run specific job +act -j backend-tests +act -j frontend-tests +act -j e2e-tests + +# Run on specific event +act push +act pull_request +``` + +### Troubleshooting + +- If Docker images fail: `act --pull=false` +- For verbose output: `act -v` +- To use secrets: Create `.secrets` file and use `act --secret-file .secrets` + +See [.github/workflows/README.md](.github/workflows/README.md) for more details. + +## CI/CD Pipeline + +The GitHub Actions workflow (`.github/workflows/tests.yml`) runs: + +1. **Backend Tests** - Unit and integration tests across Python 3.11, 3.12, 3.13 +2. **Frontend Tests** - TypeScript type checking and ESLint +3. 
**E2E Tests** - Playwright end-to-end tests + +All tests run automatically on: +- Push to `main` or `develop` branches +- Pull requests to `main` or `develop` branches + +## Test Structure + +``` +backend/ + tests/ + unit/ # Unit tests for services and utilities + integration/ # Integration tests for resolvers + conftest.py # Shared test fixtures + +tests/ + e2e/ # End-to-end Playwright tests + *.spec.ts # Test specifications + playwright.config.ts + package.json # E2E test dependencies +``` + +## Fixing Common Issues + +### npm ci Error (EUSAGE) + +If you see `npm ci` errors about lock file sync: +```bash +cd web # or tests +rm -rf node_modules package-lock.json +npm install +``` + +### Playwright Browser Not Found + +```bash +cd tests +npx playwright install chromium +``` + +### Python Module Not Found + +Ensure you're in a virtual environment with dependencies installed: +```bash +cd backend +python -m venv venv +source venv/bin/activate +pip install -r requirements.txt +``` + diff --git a/backend/api/audit.py b/backend/api/audit.py index 15bc1d08..aa0c37a6 100644 --- a/backend/api/audit.py +++ b/backend/api/audit.py @@ -7,7 +7,13 @@ from typing import Any, Callable import strawberry -from config import INFLUXDB_BUCKET, INFLUXDB_ORG, INFLUXDB_TOKEN, INFLUXDB_URL, LOGGER +from config import ( + INFLUXDB_BUCKET, + INFLUXDB_ORG, + INFLUXDB_TOKEN, + INFLUXDB_URL, + LOGGER, +) from influxdb_client import InfluxDBClient, Point from influxdb_client.client.write_api import SYNCHRONOUS @@ -21,7 +27,9 @@ class AuditLogger: @classmethod def _get_client(cls) -> InfluxDBClient | None: if not INFLUXDB_TOKEN: - logger.warning("InfluxDB token not configured, skipping audit logging") + logger.warning( + "InfluxDB token not configured, skipping audit logging" + ) return None if cls._client is None: try: @@ -31,8 +39,12 @@ def _get_client(cls) -> InfluxDBClient | None: token=INFLUXDB_TOKEN, org=INFLUXDB_ORG, ) - cls._write_api = cls._client.write_api(write_options=SYNCHRONOUS) - 
logger.info(f"Successfully connected to InfluxDB (org: {INFLUXDB_ORG}, bucket: {INFLUXDB_BUCKET})") + cls._write_api = cls._client.write_api( + write_options=SYNCHRONOUS + ) + logger.info( + f"Successfully connected to InfluxDB (org: {INFLUXDB_ORG}, bucket: {INFLUXDB_BUCKET})" + ) except Exception as e: logger.error(f"Failed to connect to InfluxDB: {e}") return None @@ -48,7 +60,9 @@ def log_activity( ) -> None: client = cls._get_client() if not client or not cls._write_api: - logger.debug(f"Skipping InfluxDB log for activity {activity_name} (case_id: {case_id}) - client not available") + logger.debug( + f"Skipping InfluxDB log for activity {activity_name} (case_id: {case_id}) - client not available" + ) return try: @@ -66,11 +80,17 @@ def log_activity( if context: point = point.field("context", json.dumps(context)) - logger.info(f"Writing to InfluxDB: activity={activity_name}, case_id={case_id}, user_id={user_id}, bucket={INFLUXDB_BUCKET}") + logger.info( + f"Writing to InfluxDB: activity={activity_name}, case_id={case_id}, user_id={user_id}, bucket={INFLUXDB_BUCKET}" + ) cls._write_api.write(bucket=INFLUXDB_BUCKET, record=point) - logger.debug(f"Successfully wrote to InfluxDB: activity={activity_name}, case_id={case_id}") + logger.debug( + f"Successfully wrote to InfluxDB: activity={activity_name}, case_id={case_id}" + ) except Exception as e: - logger.error(f"Failed to write to InfluxDB: activity={activity_name}, case_id={case_id}, error={e}") + logger.error( + f"Failed to write to InfluxDB: activity={activity_name}, case_id={case_id}, error={e}" + ) @classmethod def calculate_checksum(cls, data: dict[str, Any] | Any) -> str: @@ -91,16 +111,22 @@ def decorator(func: Callable) -> Callable: @wraps(func) async def wrapper(*args: Any, **kwargs: Any) -> Any: - logger.debug(f"Audit decorator called for function: {func.__name__}, args count: {len(args)}, kwargs: {list(kwargs.keys())}, param_names: {param_names}") + logger.debug( + f"Audit decorator called for function: 
{func.__name__}, args count: {len(args)}, kwargs: {list(kwargs.keys())}, param_names: {param_names}" + ) info = None if "info" in kwargs: info = kwargs["info"] - logger.debug(f"Found Info object in kwargs for {func.__name__}") + logger.debug( + f"Found Info object in kwargs for {func.__name__}" + ) elif info_param_index is not None and len(args) > info_param_index: info = args[info_param_index] - logger.debug(f"Found Info object in args[{info_param_index}] for {func.__name__}") + logger.debug( + f"Found Info object in args[{info_param_index}] for {func.__name__}" + ) else: for i, arg in enumerate(args): if arg is None: @@ -108,24 +134,32 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: try: if hasattr(arg, "context"): info = arg - logger.debug(f"Found Info object in args[{i}] for {func.__name__} by context attribute") + logger.debug( + f"Found Info object in args[{i}] for {func.__name__} by context attribute" + ) break elif hasattr(arg, "__class__"): type_name = type(arg).__name__ if "Info" in type_name: info = arg - logger.debug(f"Found Info object in args[{i}] for {func.__name__} by type name: {type_name}") + logger.debug( + f"Found Info object in args[{i}] for {func.__name__} by type name: {type_name}" + ) break except Exception as e: logger.debug(f"Error checking arg[{i}]: {e}") continue if not info: - logger.warning(f"Audit decorator: No Info object found for {func.__name__}, skipping audit log. Args: {[type(a).__name__ if a is not None else 'None' for a in args]}, Kwargs keys: {list(kwargs.keys())}, info_param_index: {info_param_index}") + logger.warning( + f"Audit decorator: No Info object found for {func.__name__}, skipping audit log. 
Args: {[type(a).__name__ if a is not None else 'None' for a in args]}, Kwargs keys: {list(kwargs.keys())}, info_param_index: {info_param_index}" + ) return await func(*args, **kwargs) if not hasattr(info, "context"): - logger.warning(f"Audit decorator: Info object found but no context attribute for {func.__name__}, skipping audit log") + logger.warning( + f"Audit decorator: Info object found but no context attribute for {func.__name__}, skipping audit log" + ) return await func(*args, **kwargs) context = info.context @@ -137,9 +171,13 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: if "id" in kwargs: entity_id = str(kwargs["id"]) - logger.debug(f"Found id in kwargs: {entity_id} for {func.__name__}") + logger.debug( + f"Found id in kwargs: {entity_id} for {func.__name__}" + ) elif len(args) > 1: - logger.debug(f"Checking args[1] for {func.__name__}: {type(args[1])}, value: {args[1]}") + logger.debug( + f"Checking args[1] for {func.__name__}: {type(args[1])}, value: {args[1]}" + ) if isinstance(args[1], str): entity_id = args[1] elif hasattr(args[1], "__str__"): @@ -147,7 +185,9 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: if entity_id: case_id = entity_id - logger.debug(f"Set case_id from entity_id: {case_id} for {func.__name__}") + logger.debug( + f"Set case_id from entity_id: {case_id} for {func.__name__}" + ) activity = activity_name or func.__name__ @@ -173,7 +213,9 @@ def serialize_payload(obj: Any) -> Any: if serialized is not None: result[k] = serialized except Exception as e: - logger.debug(f"Error serializing field {k}: {e}") + logger.debug( + f"Error serializing field {k}: {e}" + ) result[k] = str(v) return result elif hasattr(obj, "__annotations__"): @@ -183,15 +225,23 @@ def serialize_payload(obj: Any) -> Any: if not attr_name.startswith("_"): try: attr_value = getattr(obj, attr_name, None) - if attr_value is not strawberry.UNSET and attr_value is not None: + if ( + attr_value is not strawberry.UNSET + and attr_value is not None + ): 
serialized = serialize_payload(attr_value) if serialized is not None: result[attr_name] = serialized except Exception as e: - logger.debug(f"Error serializing annotation {attr_name}: {e}") + logger.debug( + f"Error serializing annotation {attr_name}: {e}" + ) try: attr_value = getattr(obj, attr_name, None) - if attr_value is not strawberry.UNSET and attr_value is not None: + if ( + attr_value is not strawberry.UNSET + and attr_value is not None + ): result[attr_name] = str(attr_value) except Exception: pass @@ -216,18 +266,26 @@ def serialize_payload(obj: Any) -> Any: audit_context["payload"] = payload payload_json = json.dumps(payload, default=str) - logger.debug(f"Calling function {func.__name__} with activity={activity}, case_id={case_id}, payload keys: {list(payload.keys())}, payload size: {len(payload_json)} bytes") + logger.debug( + f"Calling function {func.__name__} with activity={activity}, case_id={case_id}, payload keys: {list(payload.keys())}, payload size: {len(payload_json)} bytes" + ) result = await func(*args, **kwargs) if hasattr(result, "id"): case_id = str(result.id) - logger.debug(f"Set case_id from result.id: {case_id} for {func.__name__}") + logger.debug( + f"Set case_id from result.id: {case_id} for {func.__name__}" + ) elif isinstance(result, dict) and "id" in result: case_id = str(result["id"]) - logger.debug(f"Set case_id from result dict: {case_id} for {func.__name__}") + logger.debug( + f"Set case_id from result dict: {case_id} for {func.__name__}" + ) if case_id: - logger.info(f"Audit decorator: Logging activity {activity} for case_id {case_id} (user: {user_id}), payload included: {len(payload) > 0}") + logger.info( + f"Audit decorator: Logging activity {activity} for case_id {case_id} (user: {user_id}), payload included: {len(payload) > 0}" + ) AuditLogger.log_activity( case_id=case_id, activity_name=activity, @@ -235,7 +293,9 @@ def serialize_payload(obj: Any) -> Any: context=audit_context, ) else: - logger.warning(f"Audit decorator: 
No case_id found for {func.__name__}, skipping audit log. Result type: {type(result)}, has id attr: {hasattr(result, 'id') if result else 'N/A'}") + logger.warning( + f"Audit decorator: No case_id found for {func.__name__}, skipping audit log. Result type: {type(result)}, has id attr: {hasattr(result, 'id') if result else 'N/A'}" + ) return result diff --git a/backend/api/context.py b/backend/api/context.py index 7a6206a4..b736cc4a 100644 --- a/backend/api/context.py +++ b/backend/api/context.py @@ -39,7 +39,6 @@ async def get_context( email = user_payload.get("email") picture = user_payload.get("picture") - # Extract organizations from OIDC token (can be array or single value) organizations_raw = user_payload.get("organization") organizations = None if organizations_raw: diff --git a/backend/api/inputs.py b/backend/api/inputs.py index 550ec81a..03c7827b 100644 --- a/backend/api/inputs.py +++ b/backend/api/inputs.py @@ -70,8 +70,12 @@ class CreatePatientInput: assigned_location_id: strawberry.ID | None = None assigned_location_ids: list[strawberry.ID] | None = None clinic_id: strawberry.ID # Required: location node from kind CLINIC - position_id: strawberry.ID | None = None # Optional: location node from type hospital, practice, clinic, ward, bed or room - team_ids: list[strawberry.ID] | None = None # Array: location nodes from type clinic, team, practice, hospital + position_id: strawberry.ID | None = ( + None # Optional: location node from type hospital, practice, clinic, ward, bed or room + ) + team_ids: list[strawberry.ID] | None = ( + None # Array: location nodes from type clinic, team, practice, hospital + ) properties: list[PropertyValueInput] | None = None state: PatientState | None = None diff --git a/backend/api/resolvers/base.py b/backend/api/resolvers/base.py new file mode 100644 index 00000000..bcb98fe8 --- /dev/null +++ b/backend/api/resolvers/base.py @@ -0,0 +1,128 @@ +from collections.abc import AsyncGenerator +from typing import Generic, TypeVar + 
+import strawberry +from api.context import Info +from api.services.base import BaseRepository +from api.services.notifications import ( + notify_entity_created, + notify_entity_deleted, + notify_entity_update, +) +from api.services.subscription import create_redis_subscription +from sqlalchemy.ext.asyncio import AsyncSession + +ModelType = TypeVar("ModelType") + + +class BaseQueryResolver(Generic[ModelType]): + def __init__(self, model: type[ModelType]): + self.model = model + + def get_repository(self, db: AsyncSession) -> BaseRepository[ModelType]: + return BaseRepository(db, self.model) + + async def _get_by_id_impl( + self, info: Info, id: strawberry.ID + ) -> ModelType | None: + repo = self.get_repository(info.context.db) + return await repo.get_by_id(id) + + async def _get_all_impl(self, info: Info) -> list[ModelType]: + repo = self.get_repository(info.context.db) + return await repo.get_all() + + @strawberry.field + async def get_by_id( + self, info: Info, id: strawberry.ID + ) -> ModelType | None: + return await self._get_by_id_impl(info, id) + + @strawberry.field + async def get_all(self, info: Info) -> list[ModelType]: + return await self._get_all_impl(info) + + +class BaseMutationResolver(Generic[ModelType]): + @staticmethod + def get_repository(db: AsyncSession, model: type[ModelType]) -> BaseRepository[ModelType]: + return BaseRepository(db, model) + + @staticmethod + async def delete_entity( + info: Info, + entity: ModelType, + model: type[ModelType], + entity_name: str, + related_entity_type: str | None = None, + related_entity_id: str | None = None, + ) -> None: + repo = BaseRepository(info.context.db, model) + entity_id = entity.id + await repo.delete(entity) + await notify_entity_deleted( + entity_name, entity_id, related_entity_type, related_entity_id + ) + + @staticmethod + async def create_and_notify( + info: Info, + entity: ModelType, + model: type[ModelType], + entity_name: str, + related_entity_type: str | None = None, + related_entity_id: 
str | None = None, + ) -> ModelType: + repo = BaseRepository(info.context.db, model) + await repo.create(entity) + await notify_entity_created(entity_name, entity.id) + if related_entity_type and related_entity_id: + await notify_entity_update(related_entity_type, related_entity_id) + return entity + + @staticmethod + async def update_and_notify( + info: Info, + entity: ModelType, + model: type[ModelType], + entity_name: str, + related_entity_type: str | None = None, + related_entity_id: str | None = None, + ) -> ModelType: + repo = BaseRepository(info.context.db, model) + await repo.update(entity) + await notify_entity_update( + entity_name, entity.id, related_entity_type, related_entity_id + ) + return entity + + +class BaseSubscriptionResolver: + @staticmethod + async def entity_created( + info: Info, entity_name: str + ) -> AsyncGenerator[strawberry.ID, None]: + async for entity_id in create_redis_subscription( + f"{entity_name}_created" + ): + yield entity_id + + @staticmethod + async def entity_updated( + info: Info, + entity_name: str, + entity_id: strawberry.ID | None = None, + ) -> AsyncGenerator[strawberry.ID, None]: + async for updated_id in create_redis_subscription( + f"{entity_name}_updated", entity_id + ): + yield updated_id + + @staticmethod + async def entity_deleted( + info: Info, entity_name: str + ) -> AsyncGenerator[strawberry.ID, None]: + async for entity_id in create_redis_subscription( + f"{entity_name}_deleted" + ): + yield entity_id diff --git a/backend/api/resolvers/patient.py b/backend/api/resolvers/patient.py index dc9985cc..8d3118da 100644 --- a/backend/api/resolvers/patient.py +++ b/backend/api/resolvers/patient.py @@ -3,44 +3,16 @@ import strawberry from api.audit import audit_log from api.context import Info -from api.inputs import CreatePatientInput, PatientState, Sex, UpdatePatientInput +from api.inputs import CreatePatientInput, PatientState, UpdatePatientInput +from api.resolvers.base import BaseMutationResolver, 
BaseSubscriptionResolver +from api.services.checksum import validate_checksum +from api.services.location import LocationService +from api.services.property import PropertyService from api.types.patient import PatientType from database import models -from database.session import publish_to_redis, redis_client from sqlalchemy import select from sqlalchemy.orm import aliased, selectinload -from .utils import process_properties - - -def validate_location_kind(location: models.LocationNode, expected_kind: str, field_name: str) -> None: - """Validate that a location has the expected kind.""" - if location.kind.upper() != expected_kind.upper(): - raise Exception( - f"{field_name} must be a location of kind {expected_kind}, " - f"but got {location.kind}" - ) - - -def validate_position_kind(location: models.LocationNode, field_name: str) -> None: - """Validate that a location is a valid position type.""" - allowed_kinds = {"HOSPITAL", "PRACTICE", "CLINIC", "WARD", "BED", "ROOM"} - if location.kind.upper() not in allowed_kinds: - raise Exception( - f"{field_name} must be a location of kind HOSPITAL, PRACTICE, CLINIC, " - f"WARD, BED, or ROOM, but got {location.kind}" - ) - - -def validate_team_kind(location: models.LocationNode, field_name: str) -> None: - """Validate that a location is a valid team type.""" - allowed_kinds = {"CLINIC", "TEAM", "PRACTICE", "HOSPITAL"} - if location.kind.upper() not in allowed_kinds: - raise Exception( - f"{field_name} must be a location of kind CLINIC, TEAM, PRACTICE, " - f"or HOSPITAL, but got {location.kind}" - ) - @strawberry.type class PatientQuery: @@ -78,7 +50,9 @@ async def patients( state_values = [s.value for s in states] query = query.where(models.Patient.state.in_(state_values)) else: - query = query.where(models.Patient.state == PatientState.ADMITTED.value) + query = query.where( + models.Patient.state == PatientState.ADMITTED.value + ) if location_node_id: cte = ( select(models.LocationNode.id) @@ -137,7 +111,15 @@ async def 
recent_patients( @strawberry.type -class PatientMutation: +class PatientMutation(BaseMutationResolver[models.Patient]): + @staticmethod + def _get_property_service(db) -> PropertyService: + return PropertyService(db) + + @staticmethod + def _get_location_service(db) -> LocationService: + return LocationService(db) + @strawberry.mutation @audit_log("create_patient") async def create_patient( @@ -146,44 +128,21 @@ async def create_patient( data: CreatePatientInput, ) -> PatientType: db = info.context.db - initial_state = data.state.value if data.state else PatientState.WAIT.value - - clinic_result = await db.execute( - select(models.LocationNode).where( - models.LocationNode.id == data.clinic_id, - ), + location_service = PatientMutation._get_location_service(db) + initial_state = ( + data.state.value if data.state else PatientState.WAIT.value ) - clinic = clinic_result.scalars().first() - if not clinic: - raise Exception(f"Clinic location with id {data.clinic_id} not found") - validate_location_kind(clinic, "CLINIC", "clinic_id") - position = None + await location_service.validate_and_get_clinic(data.clinic_id) + if data.position_id: - position_result = await db.execute( - select(models.LocationNode).where( - models.LocationNode.id == data.position_id, - ), - ) - position = position_result.scalars().first() - if not position: - raise Exception(f"Position location with id {data.position_id} not found") - validate_position_kind(position, "position_id") + await location_service.validate_and_get_position(data.position_id) teams = [] if data.team_ids: - teams_result = await db.execute( - select(models.LocationNode).where( - models.LocationNode.id.in_(data.team_ids), - ), + teams = await location_service.validate_and_get_teams( + data.team_ids ) - teams = list(teams_result.scalars().all()) - if len(teams) != len(data.team_ids): - found_ids = {t.id for t in teams} - missing_ids = set(data.team_ids) - found_ids - raise Exception(f"Team locations with ids {missing_ids} not 
found") - for team in teams: - validate_team_kind(team, "team_ids") new_patient = models.Patient( firstname=data.firstname, @@ -200,36 +159,28 @@ async def create_patient( new_patient.teams = teams if data.assigned_location_ids: - result = await db.execute( - select(models.LocationNode).where( - models.LocationNode.id.in_(data.assigned_location_ids), - ), + locations = await location_service.get_locations_by_ids( + data.assigned_location_ids ) - locations = result.scalars().all() - new_patient.assigned_locations = list(locations) + new_patient.assigned_locations = locations elif data.assigned_location_id: - result = await db.execute( - select(models.LocationNode).where( - models.LocationNode.id == data.assigned_location_id, - ), + location = await location_service.get_location_by_id( + data.assigned_location_id ) - location = result.scalars().first() - if location: - new_patient.assigned_locations = [location] + new_patient.assigned_locations = [location] if location else [] if data.properties: - await process_properties( - db, - new_patient, - data.properties, - "patient", + property_service = PatientMutation._get_property_service(db) + await property_service.process_properties( + new_patient, data.properties, "patient" ) - db.add(new_patient) - await db.commit() - + repo = BaseMutationResolver.get_repository(db, models.Patient) + await repo.create(new_patient) await db.refresh(new_patient, ["assigned_locations", "teams"]) - await publish_to_redis("patient_created", new_patient.id) + await BaseMutationResolver.create_and_notify( + info, new_patient, models.Patient, "patient" + ) return new_patient @strawberry.mutation @@ -254,22 +205,7 @@ async def update_patient( raise Exception("Patient not found") if data.checksum: - patient_type = PatientType( - id=patient.id, - firstname=patient.firstname, - lastname=patient.lastname, - birthdate=patient.birthdate, - sex=Sex(patient.sex), - state=PatientState(patient.state), - 
assigned_location_id=patient.assigned_location_id, - clinic_id=patient.clinic_id, - position_id=patient.position_id, - ) - current_checksum = patient_type.checksum - if data.checksum != current_checksum: - raise Exception( - f"CONFLICT: Patient data has been modified. Expected checksum: {current_checksum}, Got: {data.checksum}" - ) + validate_checksum(patient, data.checksum, "Patient") if data.firstname is not None: patient.firstname = data.firstname @@ -280,96 +216,71 @@ async def update_patient( if data.sex is not None: patient.sex = data.sex.value + location_service = PatientMutation._get_location_service(db) + if data.clinic_id is not None: - clinic_result = await db.execute( - select(models.LocationNode).where( - models.LocationNode.id == data.clinic_id, - ), - ) - clinic = clinic_result.scalars().first() - if not clinic: - raise Exception(f"Clinic location with id {data.clinic_id} not found") - validate_location_kind(clinic, "CLINIC", "clinic_id") + await location_service.validate_and_get_clinic(data.clinic_id) patient.clinic_id = data.clinic_id if data.position_id is not strawberry.UNSET: if data.position_id is None: patient.position_id = None else: - position_result = await db.execute( - select(models.LocationNode).where( - models.LocationNode.id == data.position_id, - ), + await location_service.validate_and_get_position( + data.position_id ) - position = position_result.scalars().first() - if not position: - raise Exception(f"Position location with id {data.position_id} not found") - validate_position_kind(position, "position_id") patient.position_id = data.position_id if data.team_ids is not strawberry.UNSET: if data.team_ids is None or len(data.team_ids) == 0: patient.teams = [] else: - teams_result = await db.execute( - select(models.LocationNode).where( - models.LocationNode.id.in_(data.team_ids), - ), + patient.teams = await location_service.validate_and_get_teams( + data.team_ids ) - teams = list(teams_result.scalars().all()) - if len(teams) != 
len(data.team_ids): - found_ids = {t.id for t in teams} - missing_ids = set(data.team_ids) - found_ids - raise Exception(f"Team locations with ids {missing_ids} not found") - for team in teams: - validate_team_kind(team, "team_ids") - patient.teams = teams if data.assigned_location_ids is not None: - result = await db.execute( - select(models.LocationNode).where( - models.LocationNode.id.in_(data.assigned_location_ids), - ), + locations = await location_service.get_locations_by_ids( + data.assigned_location_ids ) - locations = result.scalars().all() - patient.assigned_locations = list(locations) + patient.assigned_locations = locations elif data.assigned_location_id is not None: - result = await db.execute( - select(models.LocationNode).where( - models.LocationNode.id == data.assigned_location_id, - ), + location = await location_service.get_location_by_id( + data.assigned_location_id ) - location = result.scalars().first() - if location: - patient.assigned_locations = [location] - else: - patient.assigned_locations = [] + patient.assigned_locations = [location] if location else [] if data.properties: - await process_properties(db, patient, data.properties, "patient") + property_service = PatientMutation._get_property_service(db) + await property_service.process_properties( + patient, data.properties, "patient" + ) - await db.commit() + await BaseMutationResolver.update_and_notify( + info, patient, models.Patient, "patient" + ) await db.refresh(patient, ["assigned_locations", "teams"]) - await publish_to_redis("patient_updated", patient.id) return patient @strawberry.mutation @audit_log("delete_patient") async def delete_patient(self, info: Info, id: strawberry.ID) -> bool: - db = info.context.db - result = await db.execute( - select(models.Patient).where(models.Patient.id == id), - ) - patient = result.scalars().first() + repo = BaseMutationResolver.get_repository(info.context.db, models.Patient) + patient = await repo.get_by_id(id) if not patient: return False - 
await db.delete(patient) - await db.commit() + await BaseMutationResolver.delete_entity( + info, patient, models.Patient, "patient" + ) return True - @strawberry.mutation - @audit_log("admit_patient") - async def admit_patient(self, info: Info, id: strawberry.ID) -> PatientType: + @staticmethod + async def _update_patient_state( + info: Info, + id: strawberry.ID, + state: PatientState, + ) -> PatientType: + from api.services.notifications import notify_entity_update db = info.context.db result = await db.execute( select(models.Patient) @@ -382,95 +293,53 @@ async def admit_patient(self, info: Info, id: strawberry.ID) -> PatientType: patient = result.scalars().first() if not patient: raise Exception("Patient not found") - patient.state = PatientState.ADMITTED.value - await db.commit() + patient.state = state.value + await BaseMutationResolver.update_and_notify( + info, patient, models.Patient, "patient" + ) await db.refresh(patient, ["assigned_locations"]) - await publish_to_redis("patient_updated", patient.id) - await publish_to_redis("patient_state_changed", patient.id) + await notify_entity_update("patient_state_changed", patient.id) return patient + @strawberry.mutation + @audit_log("admit_patient") + async def admit_patient( + self, info: Info, id: strawberry.ID + ) -> PatientType: + return await PatientMutation._update_patient_state( + info, id, PatientState.ADMITTED + ) + @strawberry.mutation @audit_log("discharge_patient") - async def discharge_patient(self, info: Info, id: strawberry.ID) -> PatientType: - db = info.context.db - result = await db.execute( - select(models.Patient) - .where(models.Patient.id == id) - .options( - selectinload(models.Patient.assigned_locations), - selectinload(models.Patient.teams), - ), + async def discharge_patient( + self, info: Info, id: strawberry.ID + ) -> PatientType: + return await PatientMutation._update_patient_state( + info, id, PatientState.DISCHARGED ) - patient = result.scalars().first() - if not patient: - raise 
Exception("Patient not found") - patient.state = PatientState.DISCHARGED.value - await db.commit() - await db.refresh(patient, ["assigned_locations"]) - await publish_to_redis("patient_updated", patient.id) - await publish_to_redis("patient_state_changed", patient.id) - return patient @strawberry.mutation @audit_log("mark_patient_dead") - async def mark_patient_dead(self, info: Info, id: strawberry.ID) -> PatientType: - db = info.context.db - result = await db.execute( - select(models.Patient) - .where(models.Patient.id == id) - .options( - selectinload(models.Patient.assigned_locations), - selectinload(models.Patient.teams), - ), - ) - patient = result.scalars().first() - if not patient: - raise Exception("Patient not found") - patient.state = PatientState.DEAD.value - await db.commit() - await db.refresh(patient, ["assigned_locations"]) - await publish_to_redis("patient_updated", patient.id) - await publish_to_redis("patient_state_changed", patient.id) - return patient + async def mark_patient_dead( + self, info: Info, id: strawberry.ID + ) -> PatientType: + return await PatientMutation._update_patient_state(info, id, PatientState.DEAD) @strawberry.mutation @audit_log("wait_patient") async def wait_patient(self, info: Info, id: strawberry.ID) -> PatientType: - db = info.context.db - result = await db.execute( - select(models.Patient) - .where(models.Patient.id == id) - .options( - selectinload(models.Patient.assigned_locations), - selectinload(models.Patient.teams), - ), - ) - patient = result.scalars().first() - if not patient: - raise Exception("Patient not found") - patient.state = PatientState.WAIT.value - await db.commit() - await db.refresh(patient, ["assigned_locations"]) - await publish_to_redis("patient_updated", patient.id) - await publish_to_redis("patient_state_changed", patient.id) - return patient + return await PatientMutation._update_patient_state(info, id, PatientState.WAIT) @strawberry.type -class PatientSubscription: +class 
PatientSubscription(BaseSubscriptionResolver): @strawberry.subscription async def patient_created( - self, - info: Info, + self, info: Info ) -> AsyncGenerator[strawberry.ID, None]: - pubsub = redis_client.pubsub() - await pubsub.subscribe("patient_created") - try: - async for message in pubsub.listen(): - if message["type"] == "message": - yield message["data"] - finally: - await pubsub.close() + async for patient_id in BaseSubscriptionResolver.entity_created(info, "patient"): + yield patient_id @strawberry.subscription async def patient_updated( @@ -478,16 +347,8 @@ async def patient_updated( info: Info, patient_id: strawberry.ID | None = None, ) -> AsyncGenerator[strawberry.ID, None]: - pubsub = redis_client.pubsub() - await pubsub.subscribe("patient_updated") - try: - async for message in pubsub.listen(): - if message["type"] == "message": - patient_id_str = message["data"] - if patient_id is None or patient_id_str == patient_id: - yield patient_id_str - finally: - await pubsub.close() + async for updated_id in BaseSubscriptionResolver.entity_updated(info, "patient", patient_id): + yield updated_id @strawberry.subscription async def patient_state_changed( @@ -495,13 +356,9 @@ async def patient_state_changed( info: Info, patient_id: strawberry.ID | None = None, ) -> AsyncGenerator[strawberry.ID, None]: - pubsub = redis_client.pubsub() - await pubsub.subscribe("patient_state_changed") - try: - async for message in pubsub.listen(): - if message["type"] == "message": - patient_id_str = message["data"] - if patient_id is None or patient_id_str == patient_id: - yield patient_id_str - finally: - await pubsub.close() + from api.services.subscription import create_redis_subscription + + async for updated_id in create_redis_subscription( + "patient_state_changed", patient_id + ): + yield updated_id diff --git a/backend/api/resolvers/property.py b/backend/api/resolvers/property.py index 12f90b31..8f63af1c 100644 --- a/backend/api/resolvers/property.py +++ 
b/backend/api/resolvers/property.py @@ -4,6 +4,7 @@ CreatePropertyDefinitionInput, UpdatePropertyDefinitionInput, ) +from api.resolvers.base import BaseMutationResolver from api.types.property import PropertyDefinitionType from database import models from sqlalchemy import select @@ -23,7 +24,11 @@ async def property_definitions( @strawberry.type -class PropertyDefinitionMutation: +class PropertyDefinitionMutation( + BaseMutationResolver[models.PropertyDefinition] +): + pass + @strawberry.mutation async def create_property_definition( self, @@ -41,10 +46,9 @@ async def create_property_definition( is_active=data.is_active, allowed_entities=entities_str, ) - info.context.db.add(defn) - await info.context.db.commit() - await info.context.db.refresh(defn) - return defn + return await BaseMutationResolver.create_and_notify( + info, defn, models.PropertyDefinition, "property_definition" + ) @strawberry.mutation async def update_property_definition( @@ -54,14 +58,10 @@ async def update_property_definition( data: UpdatePropertyDefinitionInput, ) -> PropertyDefinitionType: db = info.context.db - result = await db.execute( - select(models.PropertyDefinition).where( - models.PropertyDefinition.id == id, - ), + repo = BaseMutationResolver.get_repository(db, models.PropertyDefinition) + defn = await repo.get_by_id_or_raise( + id, "Property Definition not found" ) - defn = result.scalars().first() - if not defn: - raise Exception("Property Definition not found") if data.name is not None: defn.name = data.name @@ -76,9 +76,9 @@ async def update_property_definition( [e.value for e in data.allowed_entities], ) - await db.commit() - await db.refresh(defn) - return defn + return await BaseMutationResolver.update_and_notify( + info, defn, models.PropertyDefinition, "property_definition" + ) @strawberry.mutation async def delete_property_definition( @@ -87,15 +87,12 @@ async def delete_property_definition( id: strawberry.ID, ) -> bool: db = info.context.db - result = await db.execute( 
- select(models.PropertyDefinition).where( - models.PropertyDefinition.id == id, - ), - ) - defn = result.scalars().first() + repo = BaseMutationResolver.get_repository(db, models.PropertyDefinition) + defn = await repo.get_by_id(id) if not defn: return False - await db.delete(defn) - await db.commit() + await BaseMutationResolver.delete_entity( + info, defn, models.PropertyDefinition, "property_definition" + ) return True diff --git a/backend/api/resolvers/task.py b/backend/api/resolvers/task.py index fe638e17..435f328b 100644 --- a/backend/api/resolvers/task.py +++ b/backend/api/resolvers/task.py @@ -1,27 +1,25 @@ from collections.abc import AsyncGenerator -from datetime import timezone - import strawberry from api.audit import audit_log from api.context import Info from api.inputs import CreateTaskInput, UpdateTaskInput +from api.resolvers.base import BaseMutationResolver, BaseSubscriptionResolver +from api.services.base import BaseRepository +from api.services.checksum import validate_checksum +from api.services.datetime import normalize_datetime_to_utc +from api.services.property import PropertyService from api.types.task import TaskType from database import models -from database.session import publish_to_redis, redis_client from sqlalchemy import desc, select -from .utils import process_properties - @strawberry.type class TaskQuery: @strawberry.field async def task(self, info: Info, id: strawberry.ID) -> TaskType | None: - result = await info.context.db.execute( - select(models.Task).where(models.Task.id == id), - ) - return result.scalars().first() + repo = BaseRepository(info.context.db, models.Task) + return await repo.get_by_id(id) @strawberry.field async def tasks( @@ -54,36 +52,36 @@ async def recent_tasks( @strawberry.type -class TaskMutation: +class TaskMutation(BaseMutationResolver[models.Task]): + @staticmethod + def _get_property_service(db) -> PropertyService: + return PropertyService(db) + @strawberry.mutation @audit_log("create_task") async def 
create_task(self, info: Info, data: CreateTaskInput) -> TaskType: - due_date = data.due_date - if due_date and due_date.tzinfo is not None: - due_date = due_date.astimezone(timezone.utc).replace(tzinfo=None) - new_task = models.Task( title=data.title, description=data.description, patient_id=data.patient_id, assignee_id=data.assignee_id, - due_date=due_date, + due_date=normalize_datetime_to_utc(data.due_date), ) - info.context.db.add(new_task) + if data.properties: - await process_properties( - info.context.db, - new_task, - data.properties, - "task", + property_service = TaskMutation._get_property_service(info.context.db) + await property_service.process_properties( + new_task, data.properties, "task" ) - await info.context.db.commit() - await info.context.db.refresh(new_task) - await publish_to_redis("task_created", new_task.id) - if new_task.patient_id: - await publish_to_redis("patient_updated", new_task.patient_id) - return new_task + return await BaseMutationResolver.create_and_notify( + info, + new_task, + models.Task, + "task", + "patient" if new_task.patient_id else None, + new_task.patient_id if new_task.patient_id else None, + ) @strawberry.mutation @audit_log("update_task") @@ -94,30 +92,11 @@ async def update_task( data: UpdateTaskInput, ) -> TaskType: db = info.context.db - result = await db.execute( - select(models.Task).where(models.Task.id == id), - ) - task = result.scalars().first() - if not task: - raise Exception("Task not found") + repo = BaseMutationResolver.get_repository(db, models.Task) + task = await repo.get_by_id_or_raise(id, "Task not found") if data.checksum: - task_type = TaskType( - id=task.id, - title=task.title, - description=task.description, - done=task.done, - due_date=task.due_date, - creation_date=task.creation_date, - update_date=task.update_date, - assignee_id=task.assignee_id, - patient_id=task.patient_id, - ) - current_checksum = task_type.checksum - if data.checksum != current_checksum: - raise Exception( - f"CONFLICT: 
Task data has been modified. Expected checksum: {current_checksum}, Got: {data.checksum}" - ) + validate_checksum(task, data.checksum, "Task") if data.title is not None: task.title = data.title @@ -127,24 +106,40 @@ async def update_task( task.done = data.done if data.due_date is not strawberry.UNSET: - if data.due_date is not None: - if data.due_date.tzinfo is not None: - task.due_date = data.due_date.astimezone(timezone.utc).replace( - tzinfo=None, - ) - else: - task.due_date = data.due_date - else: - task.due_date = None + task.due_date = ( + normalize_datetime_to_utc(data.due_date) + if data.due_date + else None + ) if data.properties: - await process_properties(db, task, data.properties, "task") + property_service = TaskMutation._get_property_service(db) + await property_service.process_properties( + task, data.properties, "task" + ) + + return await BaseMutationResolver.update_and_notify( + info, + task, + models.Task, + "task", + "patient", + task.patient_id, + ) - await db.commit() - await db.refresh(task) - await publish_to_redis("task_updated", task.id) - if task.patient_id: - await publish_to_redis("patient_updated", task.patient_id) + @staticmethod + async def _update_task_field( + info: Info, + id: strawberry.ID, + field_updater, + ) -> TaskType: + db = info.context.db + repo = BaseMutationResolver.get_repository(db, models.Task) + task = await repo.get_by_id_or_raise(id, "Task not found") + field_updater(task) + await BaseMutationResolver.update_and_notify( + info, task, models.Task, "task", "patient", task.patient_id + ) return task @strawberry.mutation @@ -155,115 +150,63 @@ async def assign_task( id: strawberry.ID, user_id: strawberry.ID, ) -> TaskType: - db = info.context.db - result = await db.execute( - select(models.Task).where(models.Task.id == id), + return await TaskMutation._update_task_field( + info, + id, + lambda task: setattr(task, "assignee_id", user_id), ) - task = result.scalars().first() - if not task: - raise Exception("Task not 
found") - - task.assignee_id = user_id - await db.commit() - await db.refresh(task) - await publish_to_redis("task_updated", task.id) - if task.patient_id: - await publish_to_redis("patient_updated", task.patient_id) - return task @strawberry.mutation @audit_log("unassign_task") async def unassign_task(self, info: Info, id: strawberry.ID) -> TaskType: - db = info.context.db - result = await db.execute( - select(models.Task).where(models.Task.id == id), + return await TaskMutation._update_task_field( + info, + id, + lambda task: setattr(task, "assignee_id", None), ) - task = result.scalars().first() - if not task: - raise Exception("Task not found") - - task.assignee_id = None - await db.commit() - await db.refresh(task) - await publish_to_redis("task_updated", task.id) - if task.patient_id: - await publish_to_redis("patient_updated", task.patient_id) - return task @strawberry.mutation @audit_log("complete_task") async def complete_task(self, info: Info, id: strawberry.ID) -> TaskType: - db = info.context.db - result = await db.execute( - select(models.Task).where(models.Task.id == id), + return await TaskMutation._update_task_field( + info, + id, + lambda task: setattr(task, "done", True), ) - task = result.scalars().first() - if not task: - raise Exception("Task not found") - - task.done = True - await db.commit() - await db.refresh(task) - await publish_to_redis("task_updated", task.id) - if task.patient_id: - await publish_to_redis("patient_updated", task.patient_id) - return task @strawberry.mutation @audit_log("reopen_task") async def reopen_task(self, info: Info, id: strawberry.ID) -> TaskType: - db = info.context.db - result = await db.execute( - select(models.Task).where(models.Task.id == id), + return await TaskMutation._update_task_field( + info, + id, + lambda task: setattr(task, "done", False), ) - task = result.scalars().first() - if not task: - raise Exception("Task not found") - - task.done = False - await db.commit() - await db.refresh(task) - await 
publish_to_redis("task_updated", task.id) - if task.patient_id: - await publish_to_redis("patient_updated", task.patient_id) - return task @strawberry.mutation @audit_log("delete_task") async def delete_task(self, info: Info, id: strawberry.ID) -> bool: db = info.context.db - result = await db.execute( - select(models.Task).where(models.Task.id == id), - ) - task = result.scalars().first() + repo = BaseMutationResolver.get_repository(db, models.Task) + task = await repo.get_by_id(id) if not task: return False - task_id = task.id patient_id = task.patient_id - await db.delete(task) - await db.commit() - await publish_to_redis("task_deleted", task_id) - if patient_id: - await publish_to_redis("patient_updated", patient_id) + await BaseMutationResolver.delete_entity( + info, task, models.Task, "task", "patient", patient_id + ) return True @strawberry.type -class TaskSubscription: +class TaskSubscription(BaseSubscriptionResolver): @strawberry.subscription async def task_created( - self, - info: Info, + self, info: Info ) -> AsyncGenerator[strawberry.ID, None]: - pubsub = redis_client.pubsub() - await pubsub.subscribe("task_created") - try: - async for message in pubsub.listen(): - if message["type"] == "message": - yield message["data"] - finally: - await pubsub.close() + async for task_id in BaseSubscriptionResolver.entity_created(info, "task"): + yield task_id @strawberry.subscription async def task_updated( @@ -271,27 +214,12 @@ async def task_updated( info: Info, task_id: strawberry.ID | None = None, ) -> AsyncGenerator[strawberry.ID, None]: - pubsub = redis_client.pubsub() - await pubsub.subscribe("task_updated") - try: - async for message in pubsub.listen(): - if message["type"] == "message": - task_id_str = message["data"] - if task_id is None or task_id_str == task_id: - yield task_id_str - finally: - await pubsub.close() + async for updated_id in BaseSubscriptionResolver.entity_updated(info, "task", task_id): + yield updated_id @strawberry.subscription async 
def task_deleted( - self, - info: Info, + self, info: Info ) -> AsyncGenerator[strawberry.ID, None]: - pubsub = redis_client.pubsub() - await pubsub.subscribe("task_deleted") - try: - async for message in pubsub.listen(): - if message["type"] == "message": - yield message["data"] - finally: - await pubsub.close() + async for task_id in BaseSubscriptionResolver.entity_deleted(info, "task"): + yield task_id diff --git a/backend/api/resolvers/utils.py b/backend/api/resolvers/utils.py deleted file mode 100644 index af1314ae..00000000 --- a/backend/api/resolvers/utils.py +++ /dev/null @@ -1,36 +0,0 @@ -from api.inputs import PropertyValueInput -from database import models -from sqlalchemy.ext.asyncio import AsyncSession - - -async def process_properties( - db: AsyncSession, - entity, - props_data: list[PropertyValueInput], - entity_kind: str, -): - if not props_data: - return - for p_in in props_data: - ms_val = ( - ",".join(p_in.multi_select_values) - if p_in.multi_select_values - else None - ) - prop_val = models.PropertyValue( - definition_id=p_in.definition_id, - text_value=p_in.text_value, - number_value=p_in.number_value, - boolean_value=p_in.boolean_value, - date_value=p_in.date_value, - date_time_value=p_in.date_time_value, - select_value=p_in.select_value, - multi_select_values=ms_val, - ) - - if entity_kind == "patient": - prop_val.patient_id = entity.id - elif entity_kind == "task": - prop_val.task_id = entity.id - - db.add(prop_val) diff --git a/backend/api/services/__init__.py b/backend/api/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/api/services/base.py b/backend/api/services/base.py new file mode 100644 index 00000000..7838c8de --- /dev/null +++ b/backend/api/services/base.py @@ -0,0 +1,54 @@ +from typing import Generic, TypeVar + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +ModelType = TypeVar("ModelType") + + +class BaseRepository(Generic[ModelType]): + def __init__(self, 
db: AsyncSession, model: type[ModelType]): + self.db = db + self.model = model + + async def get_by_id(self, id: str) -> ModelType | None: + result = await self.db.execute( + select(self.model).where(self.model.id == id), + ) + return result.scalars().first() + + async def get_by_id_or_raise( + self, id: str, error_message: str = "Entity not found" + ) -> ModelType: + entity = await self.get_by_id(id) + if not entity: + raise Exception(error_message) + return entity + + async def get_all(self) -> list[ModelType]: + result = await self.db.execute(select(self.model)) + return list(result.scalars().all()) + + async def create(self, entity: ModelType) -> ModelType: + self.db.add(entity) + await self.db.commit() + await self.db.refresh(entity) + return entity + + async def update(self, entity: ModelType) -> ModelType: + await self.db.commit() + await self.db.refresh(entity) + return entity + + async def delete(self, entity: ModelType) -> None: + await self.db.delete(entity) + await self.db.commit() + + +class BaseService: + def __init__(self, db: AsyncSession): + self.db = db + + async def commit_and_refresh(self, entity) -> None: + await self.db.commit() + await self.db.refresh(entity) diff --git a/backend/api/services/checksum.py b/backend/api/services/checksum.py new file mode 100644 index 00000000..732e9ca4 --- /dev/null +++ b/backend/api/services/checksum.py @@ -0,0 +1,20 @@ +from typing import Any + +from api.types.base import calculate_checksum_for_instance + + +def validate_checksum( + entity: Any, + provided_checksum: str, + entity_name: str = "Entity", +) -> None: + if not provided_checksum: + return + + current_checksum = calculate_checksum_for_instance(entity) + + if provided_checksum != current_checksum: + raise Exception( + f"CONFLICT: {entity_name} data has been modified. 
" + f"Expected checksum: {current_checksum}, Got: {provided_checksum}" + ) diff --git a/backend/api/services/datetime.py b/backend/api/services/datetime.py new file mode 100644 index 00000000..fff6b288 --- /dev/null +++ b/backend/api/services/datetime.py @@ -0,0 +1,9 @@ +from datetime import timezone + + +def normalize_datetime_to_utc(dt) -> None | object: + if dt is None: + return None + if dt.tzinfo is not None: + return dt.astimezone(timezone.utc).replace(tzinfo=None) + return dt diff --git a/backend/api/services/location.py b/backend/api/services/location.py new file mode 100644 index 00000000..4918dc60 --- /dev/null +++ b/backend/api/services/location.py @@ -0,0 +1,85 @@ +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from database import models +from .validation import LocationValidator + + +class LocationService: + def __init__(self, db: AsyncSession): + self.db = db + self.validator = LocationValidator() + + async def get_location_by_id( + self, location_id: str + ) -> models.LocationNode | None: + result = await self.db.execute( + select(models.LocationNode).where( + models.LocationNode.id == location_id + ), + ) + return result.scalars().first() + + async def get_location_by_id_or_raise( + self, + location_id: str, + error_message: str | None = None, + ) -> models.LocationNode: + location = await self.get_location_by_id(location_id) + if not location: + raise Exception( + error_message or f"Location with id {location_id} not found" + ) + return location + + async def get_locations_by_ids( + self, + location_ids: list[str], + ) -> list[models.LocationNode]: + if not location_ids: + return [] + result = await self.db.execute( + select(models.LocationNode).where( + models.LocationNode.id.in_(location_ids) + ), + ) + return list(result.scalars().all()) + + async def validate_and_get_clinic( + self, + clinic_id: str, + ) -> models.LocationNode: + clinic = await self.get_location_by_id_or_raise( + clinic_id, + f"Clinic 
location with id {clinic_id} not found", + ) + self.validator.validate_kind(clinic, "CLINIC", "clinic_id") + return clinic + + async def validate_and_get_position( + self, + position_id: str | None, + ) -> models.LocationNode | None: + if not position_id: + return None + position = await self.get_location_by_id_or_raise( + position_id, + f"Position location with id {position_id} not found", + ) + self.validator.validate_position_kind(position, "position_id") + return position + + async def validate_and_get_teams( + self, + team_ids: list[str], + ) -> list[models.LocationNode]: + if not team_ids: + return [] + teams = await self.get_locations_by_ids(team_ids) + if len(teams) != len(team_ids): + found_ids = {t.id for t in teams} + missing_ids = set(team_ids) - found_ids + raise Exception(f"Team locations with ids {missing_ids} not found") + for team in teams: + self.validator.validate_team_kind(team, "team_ids") + return teams diff --git a/backend/api/services/notifications.py b/backend/api/services/notifications.py new file mode 100644 index 00000000..77bbb92c --- /dev/null +++ b/backend/api/services/notifications.py @@ -0,0 +1,31 @@ +from database.session import publish_to_redis + + +async def notify_entity_update( + entity_type: str, + entity_id: str, + related_entity_type: str | None = None, + related_entity_id: str | None = None, +) -> None: + await publish_to_redis(f"{entity_type}_updated", str(entity_id)) + if related_entity_type and related_entity_id: + await publish_to_redis( + f"{related_entity_type}_updated", str(related_entity_id) + ) + + +async def notify_entity_created(entity_type: str, entity_id: str) -> None: + await publish_to_redis(f"{entity_type}_created", str(entity_id)) + + +async def notify_entity_deleted( + entity_type: str, + entity_id: str, + related_entity_type: str | None = None, + related_entity_id: str | None = None, +) -> None: + await publish_to_redis(f"{entity_type}_deleted", str(entity_id)) + if related_entity_type and 
related_entity_id: + await publish_to_redis( + f"{related_entity_type}_updated", str(related_entity_id) + ) diff --git a/backend/api/services/property.py b/backend/api/services/property.py new file mode 100644 index 00000000..c4fb58ae --- /dev/null +++ b/backend/api/services/property.py @@ -0,0 +1,42 @@ +from api.inputs import PropertyValueInput +from database import models +from sqlalchemy.ext.asyncio import AsyncSession + + +class PropertyService: + def __init__(self, db: AsyncSession): + self.db = db + + async def process_properties( + self, + entity: models.Patient | models.Task, + props_data: list[PropertyValueInput], + entity_kind: str, + ) -> None: + if not props_data: + return + + for prop_input in props_data: + multi_select_value = ( + ",".join(prop_input.multi_select_values) + if prop_input.multi_select_values + else None + ) + + prop_value = models.PropertyValue( + definition_id=prop_input.definition_id, + text_value=prop_input.text_value, + number_value=prop_input.number_value, + boolean_value=prop_input.boolean_value, + date_value=prop_input.date_value, + date_time_value=prop_input.date_time_value, + select_value=prop_input.select_value, + multi_select_values=multi_select_value, + ) + + if entity_kind == "patient": + prop_value.patient_id = entity.id + elif entity_kind == "task": + prop_value.task_id = entity.id + + self.db.add(prop_value) diff --git a/backend/api/services/subscription.py b/backend/api/services/subscription.py new file mode 100644 index 00000000..28382e22 --- /dev/null +++ b/backend/api/services/subscription.py @@ -0,0 +1,19 @@ +from collections.abc import AsyncGenerator + +from database.session import redis_client + + +async def create_redis_subscription( + channel: str, + filter_id: str | None = None, +) -> AsyncGenerator[str, None]: + pubsub = redis_client.pubsub() + await pubsub.subscribe(channel) + try: + async for message in pubsub.listen(): + if message["type"] == "message": + message_id = message["data"] + if filter_id is None 
or message_id == filter_id: + yield message_id + finally: + await pubsub.close() diff --git a/backend/api/services/validation.py b/backend/api/services/validation.py new file mode 100644 index 00000000..5e9d1c01 --- /dev/null +++ b/backend/api/services/validation.py @@ -0,0 +1,42 @@ +from database import models + + +class LocationValidator: + @staticmethod + def validate_kind( + location: models.LocationNode, expected_kind: str, field_name: str + ) -> None: + if location.kind.upper() != expected_kind.upper(): + raise Exception( + f"{field_name} must be a location of kind {expected_kind}, " + f"but got {location.kind}" + ) + + @staticmethod + def validate_position_kind( + location: models.LocationNode, field_name: str + ) -> None: + allowed_kinds = { + "HOSPITAL", + "PRACTICE", + "CLINIC", + "WARD", + "BED", + "ROOM", + } + if location.kind.upper() not in allowed_kinds: + raise Exception( + f"{field_name} must be a location of kind HOSPITAL, PRACTICE, CLINIC, " + f"WARD, BED, or ROOM, but got {location.kind}" + ) + + @staticmethod + def validate_team_kind( + location: models.LocationNode, field_name: str + ) -> None: + allowed_kinds = {"CLINIC", "TEAM", "PRACTICE", "HOSPITAL"} + if location.kind.upper() not in allowed_kinds: + raise Exception( + f"{field_name} must be a location of kind CLINIC, TEAM, PRACTICE, " + f"or HOSPITAL, but got {location.kind}" + ) diff --git a/backend/api/types/base.py b/backend/api/types/base.py index 8dc852d6..c2693d41 100644 --- a/backend/api/types/base.py +++ b/backend/api/types/base.py @@ -41,7 +41,10 @@ def _is_safe_to_serialize(value: Any) -> bool: return True if hasattr(value, "__class__"): class_name = value.__class__.__name__ - if any(x in class_name for x in ["InstrumentedList", "AppenderQuery", "Query", "Session"]): + if any( + x in class_name + for x in ["InstrumentedList", "AppenderQuery", "Query", "Session"] + ): return False if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)): return False diff --git 
a/backend/database/migrations/versions/add_patient_location_mapping.py b/backend/database/migrations/versions/add_patient_location_mapping.py index 67a07504..1199072c 100644 --- a/backend/database/migrations/versions/add_patient_location_mapping.py +++ b/backend/database/migrations/versions/add_patient_location_mapping.py @@ -3,8 +3,8 @@ Revision ID: add_patient_location_mapping Revises: add_patient_locations Create Date: 2025-01-15 00:00:00.000000 - """ + from typing import Sequence, Union from alembic import op @@ -12,43 +12,50 @@ # revision identifiers, used by Alembic. -revision: str = 'add_patient_location_mapping' -down_revision: Union[str, Sequence[str], None] = 'add_user_email_and_organizations' +revision: str = "add_patient_location_mapping" +down_revision: Union[str, Sequence[str], None] = ( + "add_user_email_and_organizations" +) branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: """Upgrade schema.""" - op.add_column('patients', sa.Column('clinic_id', sa.String(), nullable=True)) - - op.add_column('patients', sa.Column('position_id', sa.String(), nullable=True)) - + op.add_column( + "patients", sa.Column("clinic_id", sa.String(), nullable=True) + ) + + op.add_column( + "patients", sa.Column("position_id", sa.String(), nullable=True) + ) + op.create_foreign_key( - 'fk_patients_clinic_id', - 'patients', - 'location_nodes', - ['clinic_id'], - ['id'] + "fk_patients_clinic_id", + "patients", + "location_nodes", + ["clinic_id"], + ["id"], ) op.create_foreign_key( - 'fk_patients_position_id', - 'patients', - 'location_nodes', - ['position_id'], - ['id'] + "fk_patients_position_id", + "patients", + "location_nodes", + ["position_id"], + ["id"], ) - + op.create_table( - 'patient_teams', - sa.Column('patient_id', sa.String(), nullable=False), - sa.Column('location_id', sa.String(), nullable=False), - sa.ForeignKeyConstraint(['patient_id'], ['patients.id']), - 
sa.ForeignKeyConstraint(['location_id'], ['location_nodes.id']), - sa.PrimaryKeyConstraint('patient_id', 'location_id') + "patient_teams", + sa.Column("patient_id", sa.String(), nullable=False), + sa.Column("location_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint(["patient_id"], ["patients.id"]), + sa.ForeignKeyConstraint(["location_id"], ["location_nodes.id"]), + sa.PrimaryKeyConstraint("patient_id", "location_id"), ) - - op.execute(""" + + op.execute( + """ UPDATE patients p SET clinic_id = ( SELECT location_id @@ -61,9 +68,11 @@ def upgrade() -> None: WHERE EXISTS ( SELECT 1 FROM patient_locations pl2 WHERE pl2.patient_id = p.id ) - """) - - op.execute(""" + """ + ) + + op.execute( + """ UPDATE patients p SET clinic_id = ( SELECT location_id @@ -75,31 +84,37 @@ def upgrade() -> None: AND EXISTS ( SELECT 1 FROM patient_locations pl2 WHERE pl2.patient_id = p.id ) - """) - - op.execute(""" + """ + ) + + op.execute( + """ UPDATE patients SET clinic_id = assigned_location_id WHERE clinic_id IS NULL AND assigned_location_id IS NOT NULL - """) - - op.execute(""" + """ + ) + + op.execute( + """ UPDATE patients p SET clinic_id = ( SELECT id FROM location_nodes WHERE kind = 'CLINIC' LIMIT 1 ) WHERE clinic_id IS NULL - """) - - op.alter_column('patients', 'clinic_id', nullable=False) + """ + ) + + op.alter_column("patients", "clinic_id", nullable=False) def downgrade() -> None: """Downgrade schema.""" - op.drop_table('patient_teams') - op.drop_constraint('fk_patients_position_id', 'patients', type_='foreignkey') - op.drop_constraint('fk_patients_clinic_id', 'patients', type_='foreignkey') - op.drop_column('patients', 'position_id') - op.drop_column('patients', 'clinic_id') - + op.drop_table("patient_teams") + op.drop_constraint( + "fk_patients_position_id", "patients", type_="foreignkey" + ) + op.drop_constraint("fk_patients_clinic_id", "patients", type_="foreignkey") + op.drop_column("patients", "position_id") + op.drop_column("patients", "clinic_id") diff 
--git a/backend/database/migrations/versions/add_patient_locations_table.py b/backend/database/migrations/versions/add_patient_locations_table.py index 4972a911..b14c8a96 100644 --- a/backend/database/migrations/versions/add_patient_locations_table.py +++ b/backend/database/migrations/versions/add_patient_locations_table.py @@ -1,10 +1,10 @@ -"""Add patient_locations many-to-many table +"""Add patient_locations many-to-many table. Revision ID: add_patient_locations Revises: baace9e34585 Create Date: 2025-12-13 00:00:00.000000 - """ + from typing import Sequence, Union from alembic import op @@ -12,8 +12,8 @@ # revision identifiers, used by Alembic. -revision: str = 'add_patient_locations' -down_revision: Union[str, Sequence[str], None] = 'baace9e34585' +revision: str = "add_patient_locations" +down_revision: Union[str, Sequence[str], None] = "baace9e34585" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -22,23 +22,31 @@ def upgrade() -> None: """Upgrade schema.""" # Create the association table op.create_table( - 'patient_locations', - sa.Column('patient_id', sa.String(), nullable=False), - sa.Column('location_id', sa.String(), nullable=False), - sa.ForeignKeyConstraint(['patient_id'], ['patients.id'], ), - sa.ForeignKeyConstraint(['location_id'], ['location_nodes.id'], ), - sa.PrimaryKeyConstraint('patient_id', 'location_id') + "patient_locations", + sa.Column("patient_id", sa.String(), nullable=False), + sa.Column("location_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint( + ["patient_id"], + ["patients.id"], + ), + sa.ForeignKeyConstraint( + ["location_id"], + ["location_nodes.id"], + ), + sa.PrimaryKeyConstraint("patient_id", "location_id"), ) - + # Migrate existing single location assignments to the new table - op.execute(""" + op.execute( + """ INSERT INTO patient_locations (patient_id, location_id) SELECT id, assigned_location_id FROM patients WHERE assigned_location_id IS NOT NULL - 
""") + """ + ) def downgrade() -> None: """Downgrade schema.""" - op.drop_table('patient_locations') + op.drop_table("patient_locations") diff --git a/backend/database/migrations/versions/add_patient_state.py b/backend/database/migrations/versions/add_patient_state.py index 1062b0a0..f7160114 100644 --- a/backend/database/migrations/versions/add_patient_state.py +++ b/backend/database/migrations/versions/add_patient_state.py @@ -3,14 +3,18 @@ from alembic import op import sqlalchemy as sa -revision: str = 'add_patient_state' -down_revision: Union[str, Sequence[str], None] = 'add_patient_locations' +revision: str = "add_patient_state" +down_revision: Union[str, Sequence[str], None] = "add_patient_locations" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None + def upgrade() -> None: - op.add_column('patients', sa.Column('state', sa.String(), nullable=False, server_default='WAIT')) + op.add_column( + "patients", + sa.Column("state", sa.String(), nullable=False, server_default="WAIT"), + ) -def downgrade() -> None: - op.drop_column('patients', 'state') +def downgrade() -> None: + op.drop_column("patients", "state") diff --git a/backend/database/migrations/versions/add_user_email_and_organizations.py b/backend/database/migrations/versions/add_user_email_and_organizations.py index 21a460d7..f8ed02f5 100644 --- a/backend/database/migrations/versions/add_user_email_and_organizations.py +++ b/backend/database/migrations/versions/add_user_email_and_organizations.py @@ -1,10 +1,10 @@ -"""Add email and organizations fields to users table +"""Add email and organizations fields to users table. Revision ID: add_user_email_and_organizations Revises: add_patient_state Create Date: 2025-12-13 12:00:00.000000 - """ + from typing import Sequence, Union from alembic import op @@ -12,20 +12,21 @@ # revision identifiers, used by Alembic. 
-revision: str = 'add_user_email_and_organizations' -down_revision: Union[str, Sequence[str], None] = 'add_patient_state' +revision: str = "add_user_email_and_organizations" +down_revision: Union[str, Sequence[str], None] = "add_patient_state" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: """Upgrade schema.""" - op.add_column('users', sa.Column('email', sa.String(), nullable=True)) - op.add_column('users', sa.Column('organizations', sa.String(), nullable=True)) + op.add_column("users", sa.Column("email", sa.String(), nullable=True)) + op.add_column( + "users", sa.Column("organizations", sa.String(), nullable=True) + ) def downgrade() -> None: """Downgrade schema.""" - op.drop_column('users', 'organizations') - op.drop_column('users', 'email') - + op.drop_column("users", "organizations") + op.drop_column("users", "email") diff --git a/backend/database/migrations/versions/baace9e34585_.py b/backend/database/migrations/versions/baace9e34585_.py index 98d79391..66732193 100644 --- a/backend/database/migrations/versions/baace9e34585_.py +++ b/backend/database/migrations/versions/baace9e34585_.py @@ -1,10 +1,10 @@ -"""empty message +"""Empty message. Revision ID: baace9e34585 -Revises: +Revises: Create Date: 2025-12-12 21:19:42.928171 - """ + from typing import Sequence, Union from alembic import op @@ -12,7 +12,7 @@ # revision identifiers, used by Alembic. -revision: str = 'baace9e34585' +revision: str = "baace9e34585" down_revision: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -21,80 +21,114 @@ def upgrade() -> None: """Upgrade schema.""" # ### commands auto generated by Alembic - please adjust! 
### - op.create_table('location_nodes', - sa.Column('id', sa.String(), nullable=False), - sa.Column('title', sa.String(), nullable=False), - sa.Column('kind', sa.String(), nullable=False), - sa.Column('parent_id', sa.String(), nullable=True), - sa.ForeignKeyConstraint(['parent_id'], ['location_nodes.id'], ), - sa.PrimaryKeyConstraint('id') + op.create_table( + "location_nodes", + sa.Column("id", sa.String(), nullable=False), + sa.Column("title", sa.String(), nullable=False), + sa.Column("kind", sa.String(), nullable=False), + sa.Column("parent_id", sa.String(), nullable=True), + sa.ForeignKeyConstraint( + ["parent_id"], + ["location_nodes.id"], + ), + sa.PrimaryKeyConstraint("id"), ) - op.create_table('property_definitions', - sa.Column('id', sa.String(), nullable=False), - sa.Column('name', sa.String(), nullable=False), - sa.Column('description', sa.String(), nullable=True), - sa.Column('field_type', sa.String(), nullable=False), - sa.Column('options', sa.String(), nullable=True), - sa.Column('is_active', sa.Boolean(), nullable=False), - sa.Column('allowed_entities', sa.String(), nullable=False), - sa.PrimaryKeyConstraint('id') + op.create_table( + "property_definitions", + sa.Column("id", sa.String(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=True), + sa.Column("field_type", sa.String(), nullable=False), + sa.Column("options", sa.String(), nullable=True), + sa.Column("is_active", sa.Boolean(), nullable=False), + sa.Column("allowed_entities", sa.String(), nullable=False), + sa.PrimaryKeyConstraint("id"), ) - op.create_table('users', - sa.Column('id', sa.String(), nullable=False), - sa.Column('username', sa.String(), nullable=False), - sa.Column('firstname', sa.String(), nullable=True), - sa.Column('lastname', sa.String(), nullable=True), - sa.Column('title', sa.String(), nullable=True), - sa.Column('avatar_url', sa.String(), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( 
+ "users", + sa.Column("id", sa.String(), nullable=False), + sa.Column("username", sa.String(), nullable=False), + sa.Column("firstname", sa.String(), nullable=True), + sa.Column("lastname", sa.String(), nullable=True), + sa.Column("title", sa.String(), nullable=True), + sa.Column("avatar_url", sa.String(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) - op.create_table('patients', - sa.Column('id', sa.String(), nullable=False), - sa.Column('firstname', sa.String(), nullable=False), - sa.Column('lastname', sa.String(), nullable=False), - sa.Column('birthdate', sa.Date(), nullable=False), - sa.Column('sex', sa.String(), nullable=False), - sa.Column('assigned_location_id', sa.String(), nullable=True), - sa.ForeignKeyConstraint(['assigned_location_id'], ['location_nodes.id'], ), - sa.PrimaryKeyConstraint('id') + op.create_table( + "patients", + sa.Column("id", sa.String(), nullable=False), + sa.Column("firstname", sa.String(), nullable=False), + sa.Column("lastname", sa.String(), nullable=False), + sa.Column("birthdate", sa.Date(), nullable=False), + sa.Column("sex", sa.String(), nullable=False), + sa.Column("assigned_location_id", sa.String(), nullable=True), + sa.ForeignKeyConstraint( + ["assigned_location_id"], + ["location_nodes.id"], + ), + sa.PrimaryKeyConstraint("id"), ) - op.create_table('tasks', - sa.Column('id', sa.String(), nullable=False), - sa.Column('title', sa.String(), nullable=False), - sa.Column('description', sa.String(), nullable=True), - sa.Column('done', sa.Boolean(), nullable=False), - sa.Column('due_date', sa.DateTime(), nullable=True), - sa.Column('creation_date', sa.DateTime(), nullable=False), - sa.Column('update_date', sa.DateTime(), nullable=True), - sa.Column('assignee_id', sa.String(), nullable=True), - sa.Column('patient_id', sa.String(), nullable=False), - sa.ForeignKeyConstraint(['assignee_id'], ['users.id'], ), - sa.ForeignKeyConstraint(['patient_id'], ['patients.id'], ), - sa.PrimaryKeyConstraint('id') + op.create_table( + 
"tasks", + sa.Column("id", sa.String(), nullable=False), + sa.Column("title", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=True), + sa.Column("done", sa.Boolean(), nullable=False), + sa.Column("due_date", sa.DateTime(), nullable=True), + sa.Column("creation_date", sa.DateTime(), nullable=False), + sa.Column("update_date", sa.DateTime(), nullable=True), + sa.Column("assignee_id", sa.String(), nullable=True), + sa.Column("patient_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint( + ["assignee_id"], + ["users.id"], + ), + sa.ForeignKeyConstraint( + ["patient_id"], + ["patients.id"], + ), + sa.PrimaryKeyConstraint("id"), ) - op.create_table('property_values', - sa.Column('id', sa.String(), nullable=False), - sa.Column('definition_id', sa.String(), nullable=False), - sa.Column('patient_id', sa.String(), nullable=True), - sa.Column('task_id', sa.String(), nullable=True), - sa.Column('text_value', sa.String(), nullable=True), - sa.Column('number_value', sa.Float(), nullable=True), - sa.Column('boolean_value', sa.Boolean(), nullable=True), - sa.Column('date_value', sa.Date(), nullable=True), - sa.Column('date_time_value', sa.DateTime(), nullable=True), - sa.Column('select_value', sa.String(), nullable=True), - sa.Column('multi_select_values', sa.String(), nullable=True), - sa.ForeignKeyConstraint(['definition_id'], ['property_definitions.id'], ), - sa.ForeignKeyConstraint(['patient_id'], ['patients.id'], ), - sa.ForeignKeyConstraint(['task_id'], ['tasks.id'], ), - sa.PrimaryKeyConstraint('id') + op.create_table( + "property_values", + sa.Column("id", sa.String(), nullable=False), + sa.Column("definition_id", sa.String(), nullable=False), + sa.Column("patient_id", sa.String(), nullable=True), + sa.Column("task_id", sa.String(), nullable=True), + sa.Column("text_value", sa.String(), nullable=True), + sa.Column("number_value", sa.Float(), nullable=True), + sa.Column("boolean_value", sa.Boolean(), nullable=True), + 
sa.Column("date_value", sa.Date(), nullable=True), + sa.Column("date_time_value", sa.DateTime(), nullable=True), + sa.Column("select_value", sa.String(), nullable=True), + sa.Column("multi_select_values", sa.String(), nullable=True), + sa.ForeignKeyConstraint( + ["definition_id"], + ["property_definitions.id"], + ), + sa.ForeignKeyConstraint( + ["patient_id"], + ["patients.id"], + ), + sa.ForeignKeyConstraint( + ["task_id"], + ["tasks.id"], + ), + sa.PrimaryKeyConstraint("id"), ) - op.create_table('task_dependencies', - sa.Column('previous_task_id', sa.String(), nullable=False), - sa.Column('next_task_id', sa.String(), nullable=False), - sa.ForeignKeyConstraint(['next_task_id'], ['tasks.id'], ), - sa.ForeignKeyConstraint(['previous_task_id'], ['tasks.id'], ), - sa.PrimaryKeyConstraint('previous_task_id', 'next_task_id') + op.create_table( + "task_dependencies", + sa.Column("previous_task_id", sa.String(), nullable=False), + sa.Column("next_task_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint( + ["next_task_id"], + ["tasks.id"], + ), + sa.ForeignKeyConstraint( + ["previous_task_id"], + ["tasks.id"], + ), + sa.PrimaryKeyConstraint("previous_task_id", "next_task_id"), ) # ### end Alembic commands ### @@ -102,11 +136,11 @@ def upgrade() -> None: def downgrade() -> None: """Downgrade schema.""" # ### commands auto generated by Alembic - please adjust! 
### - op.drop_table('task_dependencies') - op.drop_table('property_values') - op.drop_table('tasks') - op.drop_table('patients') - op.drop_table('users') - op.drop_table('property_definitions') - op.drop_table('location_nodes') + op.drop_table("task_dependencies") + op.drop_table("property_values") + op.drop_table("tasks") + op.drop_table("patients") + op.drop_table("users") + op.drop_table("property_definitions") + op.drop_table("location_nodes") # ### end Alembic commands ### diff --git a/backend/database/session.py b/backend/database/session.py index 1b339f27..09252e1c 100644 --- a/backend/database/session.py +++ b/backend/database/session.py @@ -20,11 +20,17 @@ async def publish_to_redis(channel: str, message: str) -> None: try: - logger.info(f"Publishing to Redis: channel={channel}, message={message}") + logger.info( + f"Publishing to Redis: channel={channel}, message={message}" + ) await redis_client.publish(channel, message) - logger.debug(f"Successfully published to Redis: channel={channel}, message={message}") + logger.debug( + f"Successfully published to Redis: channel={channel}, message={message}" + ) except Exception as e: - logger.error(f"Failed to publish to Redis: channel={channel}, message={message}, error={e}") + logger.error( + f"Failed to publish to Redis: channel={channel}, message={message}, error={e}" + ) async def get_db_session() -> AsyncGenerator[AsyncSession, None]: diff --git a/backend/main.py b/backend/main.py index 1533f287..d8890d92 100644 --- a/backend/main.py +++ b/backend/main.py @@ -23,6 +23,7 @@ async def lifespan(app: FastAPI): yield logger.info("Shutting down application...") + schema = Schema( query=Query, mutation=Mutation, @@ -52,6 +53,12 @@ async def lifespan(app: FastAPI): app.include_router(auth.router) app.include_router(graphql_app, prefix="/graphql") + +@app.get("/health") +async def health_check(): + return {"status": "ok"} + + app.add_middleware( CORSMiddleware, allow_origins=ALLOWED_ORIGINS, diff --git 
a/backend/pytest.ini b/backend/pytest.ini new file mode 100644 index 00000000..7fd55620 --- /dev/null +++ b/backend/pytest.ini @@ -0,0 +1,9 @@ +[pytest] +asyncio_mode = auto +asyncio_default_fixture_loop_scope = function +testpaths = tests +python_files = test_*.py +python_classes = Test* +python_functions = test_* + + diff --git a/backend/requirements.txt b/backend/requirements.txt index 2ae83829..45bbcfd0 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -9,3 +9,8 @@ sqlalchemy==2.0.45 strawberry-graphql[fastapi]==0.287.3 uvicorn[standard]==0.38.0 influxdb_client==1.49.0 +pytest==8.3.4 +pytest-asyncio==0.24.0 +pytest-cov==5.0.0 +aiosqlite==0.20.0 +httpx==0.27.2 diff --git a/backend/scaffold.py b/backend/scaffold.py index 3eb06a69..2855a9c4 100644 --- a/backend/scaffold.py +++ b/backend/scaffold.py @@ -19,11 +19,15 @@ async def load_scaffold_data() -> None: scaffold_path = Path(SCAFFOLD_DIRECTORY) if not scaffold_path.exists(): - logger.warning(f"Scaffold directory {SCAFFOLD_DIRECTORY} does not exist, skipping") + logger.warning( + f"Scaffold directory {SCAFFOLD_DIRECTORY} does not exist, skipping" + ) return if not scaffold_path.is_dir(): - logger.warning(f"Scaffold path {SCAFFOLD_DIRECTORY} is not a directory, skipping") + logger.warning( + f"Scaffold path {SCAFFOLD_DIRECTORY} is not a directory, skipping" + ) return async with async_session() as session: @@ -31,16 +35,22 @@ async def load_scaffold_data() -> None: existing_location = result.scalar_one_or_none() if existing_location: - logger.info("Location nodes already exist in database, skipping scaffold loading") + logger.info( + "Location nodes already exist in database, skipping scaffold loading" + ) return json_files = list(scaffold_path.glob("*.json")) if not json_files: - logger.info(f"No JSON files found in {SCAFFOLD_DIRECTORY}, skipping scaffold loading") + logger.info( + f"No JSON files found in {SCAFFOLD_DIRECTORY}, skipping scaffold loading" + ) return - logger.info(f"Loading 
scaffold data from {len(json_files)} JSON file(s) in {SCAFFOLD_DIRECTORY}") + logger.info( + f"Loading scaffold data from {len(json_files)} JSON file(s) in {SCAFFOLD_DIRECTORY}" + ) for json_file in json_files: try: @@ -53,15 +63,21 @@ async def load_scaffold_data() -> None: elif isinstance(data, dict): await _create_location_tree(session, data, None) else: - logger.warning(f"Invalid JSON structure in {json_file}, expected list or object") + logger.warning( + f"Invalid JSON structure in {json_file}, expected list or object" + ) await session.commit() - logger.info(f"Successfully loaded scaffold data from {json_file}") + logger.info( + f"Successfully loaded scaffold data from {json_file}" + ) except json.JSONDecodeError as e: logger.error(f"Failed to parse JSON file {json_file}: {e}") await session.rollback() except Exception as e: - logger.error(f"Error loading scaffold data from {json_file}: {e}") + logger.error( + f"Error loading scaffold data from {json_file}: {e}" + ) await session.rollback() diff --git a/backend/tests/__init__.py b/backend/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py new file mode 100644 index 00000000..48776790 --- /dev/null +++ b/backend/tests/conftest.py @@ -0,0 +1,101 @@ +import pytest +from sqlalchemy.ext.asyncio import ( + AsyncSession, + create_async_engine, + async_sessionmaker, +) +from sqlalchemy.pool import StaticPool + +from database.models.base import Base +from database.models.location import LocationNode +from database.models.patient import Patient +from database.models.task import Task +from database.models.user import User +from api.inputs import Sex, PatientState + + +@pytest.fixture +async def db_session(): + engine = create_async_engine( + "sqlite+aiosqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + async_session = async_sessionmaker(engine, expire_on_commit=False) + + async with engine.begin() 
as conn: + await conn.run_sync(Base.metadata.create_all) + + async with async_session() as session: + yield session + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + + await engine.dispose() + + +@pytest.fixture +async def sample_user(db_session: AsyncSession) -> User: + user = User( + id="user-1", + username="testuser", + firstname="Test", + lastname="User", + title="Dr.", + ) + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + return user + + +@pytest.fixture +async def sample_location(db_session: AsyncSession) -> LocationNode: + location = LocationNode( + id="location-1", + title="Test Clinic", + kind="CLINIC", + ) + db_session.add(location) + await db_session.commit() + await db_session.refresh(location) + return location + + +@pytest.fixture +async def sample_patient( + db_session: AsyncSession, sample_location: LocationNode +) -> Patient: + from datetime import date + + patient = Patient( + id="patient-1", + firstname="John", + lastname="Doe", + birthdate=date(1990, 1, 1), + sex=Sex.MALE.value, + state=PatientState.ADMITTED.value, + clinic_id=sample_location.id, + ) + db_session.add(patient) + await db_session.commit() + await db_session.refresh(patient) + return patient + + +@pytest.fixture +async def sample_task( + db_session: AsyncSession, sample_patient: Patient +) -> Task: + task = Task( + id="task-1", + title="Test Task", + description="Test Description", + patient_id=sample_patient.id, + done=False, + ) + db_session.add(task) + await db_session.commit() + await db_session.refresh(task) + return task diff --git a/backend/tests/integration/__init__.py b/backend/tests/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/tests/integration/test_patient_resolver.py b/backend/tests/integration/test_patient_resolver.py new file mode 100644 index 00000000..be752666 --- /dev/null +++ b/backend/tests/integration/test_patient_resolver.py @@ -0,0 +1,77 @@ 
+import pytest +from api.context import Context +from api.resolvers.patient import PatientQuery, PatientMutation +from api.inputs import Sex, PatientState + + +class MockInfo: + def __init__(self, db): + self.context = Context(db=db) + + +@pytest.mark.asyncio +async def test_patient_query_get_patient(db_session, sample_patient): + info = MockInfo(db_session) + query = PatientQuery() + result = await query.patient(info, sample_patient.id) + assert result is not None + assert result.id == sample_patient.id + assert result.firstname == sample_patient.firstname + + +@pytest.mark.asyncio +async def test_patient_query_patients(db_session, sample_patient): + info = MockInfo(db_session) + query = PatientQuery() + results = await query.patients(info) + assert len(results) >= 1 + assert any(p.id == sample_patient.id for p in results) + + +@pytest.mark.asyncio +async def test_patient_mutation_create_patient(db_session, sample_location): + from api.inputs import CreatePatientInput + from datetime import date + + info = MockInfo(db_session) + mutation = PatientMutation() + input_data = CreatePatientInput( + firstname="Jane", + lastname="Doe", + birthdate=date(1990, 1, 1), + sex=Sex.FEMALE, + clinic_id=sample_location.id, + ) + result = await mutation.create_patient(info, input_data) + assert result.id is not None + assert result.firstname == "Jane" + assert result.lastname == "Doe" + assert result.clinic_id == sample_location.id + + +@pytest.mark.asyncio +async def test_patient_mutation_update_patient(db_session, sample_patient): + from api.inputs import UpdatePatientInput + + info = MockInfo(db_session) + mutation = PatientMutation() + input_data = UpdatePatientInput(firstname="Updated Name") + result = await mutation.update_patient(info, sample_patient.id, input_data) + assert result.firstname == "Updated Name" + assert result.id == sample_patient.id + + +@pytest.mark.asyncio +async def test_patient_mutation_admit_patient(db_session, sample_patient): + info = 
MockInfo(db_session) + mutation = PatientMutation() + result = await mutation.admit_patient(info, sample_patient.id) + assert result.state == PatientState.ADMITTED.value + + +@pytest.mark.asyncio +async def test_patient_mutation_discharge_patient(db_session, sample_patient): + info = MockInfo(db_session) + mutation = PatientMutation() + result = await mutation.discharge_patient(info, sample_patient.id) + assert result.state == PatientState.DISCHARGED.value diff --git a/backend/tests/integration/test_task_resolver.py b/backend/tests/integration/test_task_resolver.py new file mode 100644 index 00000000..3cf68c75 --- /dev/null +++ b/backend/tests/integration/test_task_resolver.py @@ -0,0 +1,87 @@ +import pytest +from api.context import Context +from api.resolvers.task import TaskQuery, TaskMutation +from database.models.task import Task + + +class MockInfo: + def __init__(self, db): + self.context = Context(db=db) + + +@pytest.mark.asyncio +async def test_task_query_get_task(db_session, sample_task): + info = MockInfo(db_session) + query = TaskQuery() + result = await query.task(info, sample_task.id) + assert result is not None + assert result.id == sample_task.id + assert result.title == sample_task.title + + +@pytest.mark.asyncio +async def test_task_query_tasks_by_patient(db_session, sample_patient): + info = MockInfo(db_session) + task1 = Task(title="Task 1", patient_id=sample_patient.id) + task2 = Task(title="Task 2", patient_id=sample_patient.id) + db_session.add(task1) + db_session.add(task2) + await db_session.commit() + + query = TaskQuery() + results = await query.tasks(info, patient_id=sample_patient.id) + assert len(results) >= 2 + task_titles = {t.title for t in results} + assert "Task 1" in task_titles + assert "Task 2" in task_titles + + +@pytest.mark.asyncio +async def test_task_mutation_create_task(db_session, sample_patient): + from api.inputs import CreateTaskInput + + info = MockInfo(db_session) + mutation = TaskMutation() + input_data = 
CreateTaskInput( + title="New Task", + description="Description", + patient_id=sample_patient.id, + ) + result = await mutation.create_task(info, input_data) + assert result.id is not None + assert result.title == "New Task" + assert result.patient_id == sample_patient.id + + +@pytest.mark.asyncio +async def test_task_mutation_update_task(db_session, sample_task): + from api.inputs import UpdateTaskInput + + info = MockInfo(db_session) + mutation = TaskMutation() + input_data = UpdateTaskInput(title="Updated Title") + result = await mutation.update_task(info, sample_task.id, input_data) + assert result.title == "Updated Title" + assert result.id == sample_task.id + + +@pytest.mark.asyncio +async def test_task_mutation_complete_task(db_session, sample_task): + info = MockInfo(db_session) + mutation = TaskMutation() + result = await mutation.complete_task(info, sample_task.id) + assert result.done is True + assert result.id == sample_task.id + + +@pytest.mark.asyncio +async def test_task_mutation_delete_task(db_session, sample_task): + info = MockInfo(db_session) + mutation = TaskMutation() + task_id = sample_task.id + result = await mutation.delete_task(info, task_id) + assert result is True + + query = TaskQuery() + task = await query.task(info, task_id) + assert task is None diff --git a/backend/tests/unit/__init__.py b/backend/tests/unit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/tests/unit/test_base_repository.py b/backend/tests/unit/test_base_repository.py new file mode 100644 index 00000000..eb1f24a1 --- /dev/null +++ b/backend/tests/unit/test_base_repository.py @@ -0,0 +1,71 @@ +import pytest +from api.services.base import BaseRepository +from database.models.task import Task + + +@pytest.mark.asyncio +async def test_get_by_id(db_session, sample_task): + repo = BaseRepository(db_session, Task) + result = await repo.get_by_id(sample_task.id) + assert result is not None + assert result.id == sample_task.id + assert 
result.title == "Test Task" + + +@pytest.mark.asyncio +async def test_get_by_id_not_found(db_session): + repo = BaseRepository(db_session, Task) + result = await repo.get_by_id("non-existent") + assert result is None + + +@pytest.mark.asyncio +async def test_get_by_id_or_raise(db_session, sample_task): + repo = BaseRepository(db_session, Task) + result = await repo.get_by_id_or_raise(sample_task.id) + assert result.id == sample_task.id + + +@pytest.mark.asyncio +async def test_get_by_id_or_raise_not_found(db_session): + repo = BaseRepository(db_session, Task) + with pytest.raises(Exception, match="Entity not found"): + await repo.get_by_id_or_raise("non-existent") + + +@pytest.mark.asyncio +async def test_get_all(db_session, sample_task): + repo = BaseRepository(db_session, Task) + results = await repo.get_all() + assert len(results) >= 1 + assert any(t.id == sample_task.id for t in results) + + +@pytest.mark.asyncio +async def test_create(db_session): + repo = BaseRepository(db_session, Task) + new_task = Task( + title="New Task", + description="New Description", + patient_id="patient-1", + ) + result = await repo.create(new_task) + assert result.id is not None + assert result.title == "New Task" + + +@pytest.mark.asyncio +async def test_update(db_session, sample_task): + repo = BaseRepository(db_session, Task) + sample_task.title = "Updated Title" + result = await repo.update(sample_task) + assert result.title == "Updated Title" + + +@pytest.mark.asyncio +async def test_delete(db_session, sample_task): + repo = BaseRepository(db_session, Task) + task_id = sample_task.id + await repo.delete(sample_task) + result = await repo.get_by_id(task_id) + assert result is None diff --git a/backend/tests/unit/test_base_resolvers.py b/backend/tests/unit/test_base_resolvers.py new file mode 100644 index 00000000..f90616d2 --- /dev/null +++ b/backend/tests/unit/test_base_resolvers.py @@ -0,0 +1,69 @@ +import pytest +from api.context import Context +from api.resolvers.base 
import BaseMutationResolver, BaseQueryResolver +from database.models.task import Task + + +class MockInfo: + def __init__(self, db): + self.context = Context(db=db) + + +@pytest.mark.asyncio +async def test_base_query_resolver_get_by_id(db_session, sample_task): + resolver = BaseQueryResolver(Task) + info = MockInfo(db_session) + result = await resolver._get_by_id_impl(info, sample_task.id) + assert result is not None + assert result.id == sample_task.id + + +@pytest.mark.asyncio +async def test_base_query_resolver_get_all(db_session, sample_task): + resolver = BaseQueryResolver(Task) + info = MockInfo(db_session) + results = await resolver._get_all_impl(info) + assert len(results) >= 1 + assert any(t.id == sample_task.id for t in results) + + +@pytest.mark.asyncio +async def test_base_mutation_resolver_create_and_notify( + db_session, sample_patient +): + info = MockInfo(db_session) + new_task = Task( + title="New Task", + description="Description", + patient_id=sample_patient.id, + ) + result = await BaseMutationResolver.create_and_notify( + info, new_task, Task, "task" + ) + assert result.id is not None + assert result.title == "New Task" + assert result.patient_id == sample_patient.id + + +@pytest.mark.asyncio +async def test_base_mutation_resolver_update_and_notify( + db_session, sample_task +): + info = MockInfo(db_session) + sample_task.title = "Updated Title" + result = await BaseMutationResolver.update_and_notify( + info, sample_task, Task, "task" + ) + assert result.title == "Updated Title" + + +@pytest.mark.asyncio +async def test_base_mutation_resolver_delete_entity(db_session, sample_task): + info = MockInfo(db_session) + await BaseMutationResolver.delete_entity( + info, sample_task, Task, "task" + ) + + repo = BaseMutationResolver.get_repository(db_session, Task) + result = await repo.get_by_id(sample_task.id) + assert result is None diff --git a/backend/tests/unit/test_checksum.py b/backend/tests/unit/test_checksum.py new file mode 100644 index 
00000000..cae3f8da --- /dev/null +++ b/backend/tests/unit/test_checksum.py @@ -0,0 +1,21 @@ +import pytest +from api.services.checksum import validate_checksum + + +@pytest.mark.asyncio +async def test_validate_checksum_valid(db_session, sample_task): + from api.types.base import calculate_checksum_for_instance + + checksum = calculate_checksum_for_instance(sample_task) + validate_checksum(sample_task, checksum, "Task") + + +@pytest.mark.asyncio +async def test_validate_checksum_invalid(db_session, sample_task): + with pytest.raises(Exception, match="CONFLICT"): + validate_checksum(sample_task, "invalid-checksum", "Task") + + +@pytest.mark.asyncio +async def test_validate_checksum_none(db_session, sample_task): + validate_checksum(sample_task, None, "Task") diff --git a/backend/tests/unit/test_location_service.py b/backend/tests/unit/test_location_service.py new file mode 100644 index 00000000..511a24f8 --- /dev/null +++ b/backend/tests/unit/test_location_service.py @@ -0,0 +1,100 @@ +import pytest +from api.services.location import LocationService +from database.models.location import LocationNode + + +@pytest.mark.asyncio +async def test_get_location_by_id(db_session, sample_location): + service = LocationService(db_session) + result = await service.get_location_by_id(sample_location.id) + assert result is not None + assert result.id == sample_location.id + + +@pytest.mark.asyncio +async def test_get_location_by_id_not_found(db_session): + service = LocationService(db_session) + result = await service.get_location_by_id("non-existent") + assert result is None + + +@pytest.mark.asyncio +async def test_get_location_by_id_or_raise(db_session, sample_location): + service = LocationService(db_session) + result = await service.get_location_by_id_or_raise(sample_location.id) + assert result.id == sample_location.id + + +@pytest.mark.asyncio +async def test_get_location_by_id_or_raise_not_found(db_session): + service = LocationService(db_session) + with 
pytest.raises(Exception, match="not found"): + await service.get_location_by_id_or_raise("non-existent") + + +@pytest.mark.asyncio +async def test_get_locations_by_ids(db_session, sample_location): + service = LocationService(db_session) + location2 = LocationNode(id="location-2", title="Location 2", kind="WARD") + db_session.add(location2) + await db_session.commit() + + results = await service.get_locations_by_ids( + [sample_location.id, "location-2"] + ) + assert len(results) == 2 + + +@pytest.mark.asyncio +async def test_validate_and_get_clinic(db_session, sample_location): + service = LocationService(db_session) + result = await service.validate_and_get_clinic(sample_location.id) + assert result.id == sample_location.id + assert result.kind == "CLINIC" + + +@pytest.mark.asyncio +async def test_validate_and_get_clinic_wrong_kind(db_session): + service = LocationService(db_session) + ward = LocationNode(id="ward-1", title="Ward", kind="WARD") + db_session.add(ward) + await db_session.commit() + + with pytest.raises(Exception, match="must be a location of kind CLINIC"): + await service.validate_and_get_clinic("ward-1") + + +@pytest.mark.asyncio +async def test_validate_and_get_position(db_session): + service = LocationService(db_session) + ward = LocationNode(id="ward-1", title="Ward", kind="WARD") + db_session.add(ward) + await db_session.commit() + + result = await service.validate_and_get_position("ward-1") + assert result is not None + assert result.kind == "WARD" + + +@pytest.mark.asyncio +async def test_validate_and_get_teams(db_session): + service = LocationService(db_session) + team1 = LocationNode(id="team-1", title="Team 1", kind="TEAM") + team2 = LocationNode(id="team-2", title="Team 2", kind="TEAM") + db_session.add(team1) + db_session.add(team2) + await db_session.commit() + + results = await service.validate_and_get_teams(["team-1", "team-2"]) + assert len(results) == 2 + + +@pytest.mark.asyncio +async def 
test_validate_and_get_teams_missing(db_session): + service = LocationService(db_session) + team1 = LocationNode(id="team-1", title="Team 1", kind="TEAM") + db_session.add(team1) + await db_session.commit() + + with pytest.raises(Exception, match="not found"): + await service.validate_and_get_teams(["team-1", "non-existent"]) diff --git a/backend/tests/unit/test_property_service.py b/backend/tests/unit/test_property_service.py new file mode 100644 index 00000000..23a33f97 --- /dev/null +++ b/backend/tests/unit/test_property_service.py @@ -0,0 +1,94 @@ +import pytest +from api.inputs import PropertyValueInput +from api.services.property import PropertyService + + +@pytest.mark.asyncio +async def test_process_properties_for_patient(db_session, sample_patient): + service = PropertyService(db_session) + props = [ + PropertyValueInput( + definition_id="def-1", + text_value="Test Value", + ), + ] + await service.process_properties(sample_patient, props, "patient") + await db_session.commit() + + from database.models.property import PropertyValue + from sqlalchemy import select + + result = await db_session.execute( + select(PropertyValue).where( + PropertyValue.patient_id == sample_patient.id + ) + ) + prop_values = result.scalars().all() + assert len(prop_values) == 1 + assert prop_values[0].text_value == "Test Value" + + +@pytest.mark.asyncio +async def test_process_properties_for_task(db_session, sample_task): + service = PropertyService(db_session) + props = [ + PropertyValueInput( + definition_id="def-1", + number_value=42, + ), + ] + await service.process_properties(sample_task, props, "task") + await db_session.commit() + + from database.models.property import PropertyValue + from sqlalchemy import select + + result = await db_session.execute( + select(PropertyValue).where(PropertyValue.task_id == sample_task.id) + ) + prop_values = result.scalars().all() + assert len(prop_values) == 1 + assert prop_values[0].number_value == 42 + + +@pytest.mark.asyncio +async def 
test_process_properties_empty_list(db_session, sample_patient): + service = PropertyService(db_session) + await service.process_properties(sample_patient, [], "patient") + await db_session.commit() + + from database.models.property import PropertyValue + from sqlalchemy import select + + result = await db_session.execute( + select(PropertyValue).where( + PropertyValue.patient_id == sample_patient.id + ) + ) + prop_values = result.scalars().all() + assert len(prop_values) == 0 + + +@pytest.mark.asyncio +async def test_process_properties_multi_select(db_session, sample_patient): + service = PropertyService(db_session) + props = [ + PropertyValueInput( + definition_id="def-1", + multi_select_values=["option1", "option2", "option3"], + ), + ] + await service.process_properties(sample_patient, props, "patient") + await db_session.commit() + + from database.models.property import PropertyValue + from sqlalchemy import select + + result = await db_session.execute( + select(PropertyValue).where( + PropertyValue.patient_id == sample_patient.id + ) + ) + prop_values = result.scalars().all() + assert len(prop_values) == 1 + assert prop_values[0].multi_select_values == "option1,option2,option3" diff --git a/shell.nix b/shell.nix index 744a8ec8..688ffd66 100644 --- a/shell.nix +++ b/shell.nix @@ -42,6 +42,7 @@ pkgs.mkShell { netcat pkgs.gcc hadolint + pkgs.act ]; venvDir = "./backend/venv"; @@ -197,7 +198,12 @@ pkgs.mkShell { fi } + run-act() { + echo ">>> Running GitHub Actions locally with act..." + ${pkgs.act}/bin/act "$@" + } + echo ">>> Environment ready." 
- echo "Commands: run-dev-backend, run-dev-web, run-dev-all, run-alembic, psql-dev, redis-cli-dev, clean-dev, start-docker, stop-docker, run-simulator, lint-dockerfiles" + echo "Commands: run-dev-backend, run-dev-web, run-dev-all, run-alembic, psql-dev, redis-cli-dev, clean-dev, start-docker, stop-docker, run-simulator, lint-dockerfiles, run-act" ''; } diff --git a/simulator/graphql_client.py b/simulator/graphql_client.py index a7cd6d09..3825a29f 100644 --- a/simulator/graphql_client.py +++ b/simulator/graphql_client.py @@ -22,7 +22,10 @@ def _is_authentication_error(self, response: requests.Response) -> bool: if "errors" in error_data: for error in error_data["errors"]: message = error.get("message", "").lower() - if "unauthenticated" in message or "not authenticated" in message: + if ( + "unauthenticated" in message + or "not authenticated" in message + ): return True except Exception: pass diff --git a/simulator/patient_manager.py b/simulator/patient_manager.py index 143dfbd0..fd8d3a2e 100644 --- a/simulator/patient_manager.py +++ b/simulator/patient_manager.py @@ -8,7 +8,9 @@ class PatientManager: - def __init__(self, client: GraphQLClient, location_manager: LocationManager): + def __init__( + self, client: GraphQLClient, location_manager: LocationManager + ): self.client = client self.location_manager = location_manager self.patient_ids: List[str] = [] @@ -86,7 +88,9 @@ def ensure_diagnosis_property(self) -> None: else: self._log_errors("ensure_diagnosis_property", response) - def create_patient(self, admit_directly: bool = False) -> Tuple[Optional[str], Optional[str]]: + def create_patient( + self, admit_directly: bool = False + ) -> Tuple[Optional[str], Optional[str]]: if not self.location_manager.clinics: self.location_manager.load_locations() if not self.location_manager.clinics: @@ -147,7 +151,9 @@ def create_patient(self, admit_directly: bool = False) -> Tuple[Optional[str], O self.patient_ids.append(pid) state_msg = "admitted" if admit_directly else "in 
waiting room" - logger.info(f"Created patient {first} {last} ({state_msg}) - Diagnosis: {diagnosis}") + logger.info( + f"Created patient {first} {last} ({state_msg}) - Diagnosis: {diagnosis}" + ) self.ensure_diagnosis_property() if self.diagnosis_property_id: @@ -175,7 +181,7 @@ def _add_diagnosis_property(self, patient_id: str, diagnosis: str) -> None: "textValue": diagnosis, } ] - } + }, } response = self.client.query(mutation, variables) @@ -201,7 +207,9 @@ def admit_patient(self, patient_id: Optional[str] = None) -> bool: if "errors" in response: for error in response["errors"]: if "Patient not found" in error.get("message", ""): - logger.warning(f"Patient {patient_id} not found. Removing from list.") + logger.warning( + f"Patient {patient_id} not found. Removing from list." + ) if patient_id in self.patient_ids: self.patient_ids.remove(patient_id) return False @@ -233,7 +241,9 @@ def discharge_patient(self, patient_id: Optional[str] = None) -> bool: if "errors" in response: for error in response["errors"]: if "Patient not found" in error.get("message", ""): - logger.warning(f"Patient {patient_id} not found. Removing from list.") + logger.warning( + f"Patient {patient_id} not found. Removing from list." + ) if patient_id in self.patient_ids: self.patient_ids.remove(patient_id) return False @@ -277,7 +287,9 @@ def move_patient(self, patient_id: Optional[str] = None) -> bool: if "errors" in response: for error in response["errors"]: if "Patient not found" in error.get("message", ""): - logger.warning(f"Patient {patient_id} not found. Removing from list.") + logger.warning( + f"Patient {patient_id} not found. Removing from list." 
+ ) if patient_id in self.patient_ids: self.patient_ids.remove(patient_id) return False @@ -292,7 +304,9 @@ def move_patient(self, patient_id: Optional[str] = None) -> bool: return True return False - def update_patient_position(self, patient_id: Optional[str] = None) -> bool: + def update_patient_position( + self, patient_id: Optional[str] = None + ) -> bool: if not patient_id: if not self.patient_ids: return False @@ -321,7 +335,9 @@ def update_patient_position(self, patient_id: Optional[str] = None) -> bool: if "errors" in response: for error in response["errors"]: if "Patient not found" in error.get("message", ""): - logger.warning(f"Patient {patient_id} not found. Removing from list.") + logger.warning( + f"Patient {patient_id} not found. Removing from list." + ) if patient_id in self.patient_ids: self.patient_ids.remove(patient_id) return False @@ -332,6 +348,8 @@ def update_patient_position(self, patient_id: Optional[str] = None) -> bool: if data and data.get("updatePatient"): position = data["updatePatient"].get("position") position_name = position["title"] if position else "none" - logger.info(f"Updated patient {patient_id} position to {position_name}") + logger.info( + f"Updated patient {patient_id} position to {position_name}" + ) return True return False diff --git a/simulator/simulator.py b/simulator/simulator.py index 5a6d9668..0fc80699 100644 --- a/simulator/simulator.py +++ b/simulator/simulator.py @@ -12,7 +12,9 @@ class ClinicSimulator: def __init__(self): self.client = GraphQLClient() self.location_manager = LocationManager(self.client) - self.patient_manager = PatientManager(self.client, self.location_manager) + self.patient_manager = PatientManager( + self.client, self.location_manager + ) self.task_manager = TaskManager(self.client) self.user_id: Optional[str] = None @@ -56,7 +58,9 @@ def run(self) -> None: logger.info("Creating initial patients...") while len(self.patient_manager.patient_ids) < 5: admit_directly = random.random() < 0.4 - 
patient_id, diagnosis = self.patient_manager.create_patient(admit_directly=admit_directly) + patient_id, diagnosis = self.patient_manager.create_patient( + admit_directly=admit_directly + ) if patient_id and diagnosis: self.task_manager.create_treatment_tasks(patient_id, diagnosis) @@ -88,8 +92,8 @@ def run(self) -> None: logger.info("Simulation stopped by user.") break except Exception as e: - logger.error(f"Error in simulation loop: {e}") - time.sleep(5) + logger.error(f"Error in simulation loop: {e}", exc_info=True) + raise def _action_create_task(self) -> None: if not self.patient_manager.patient_ids: @@ -102,7 +106,9 @@ def _action_update_task(self) -> None: def _action_create_patient(self) -> None: admit_directly = random.random() < 0.3 - patient_id, diagnosis = self.patient_manager.create_patient(admit_directly=admit_directly) + patient_id, diagnosis = self.patient_manager.create_patient( + admit_directly=admit_directly + ) if patient_id and diagnosis: self.task_manager.create_treatment_tasks(patient_id, diagnosis) diff --git a/simulator/task_manager.py b/simulator/task_manager.py index c72e4bf9..723f9a5e 100644 --- a/simulator/task_manager.py +++ b/simulator/task_manager.py @@ -78,7 +78,9 @@ def create_task( elif self.current_user_id and random.random() > 0.3: assignee_id = self.current_user_id - due_date = (datetime.now() + timedelta(hours=random.randint(1, 48))).isoformat() + due_date = ( + datetime.now() + timedelta(hours=random.randint(1, 48)) + ).isoformat() mutation = """ mutation CreateTask($title: String!, $patientId: ID!, $assigneeId: ID, $dueDate: DateTime) { @@ -116,18 +118,24 @@ def create_task( assignee_info = task.get("assignee") assignee_msg = "" if assignee_info: - assignee_msg = f" assigned to {assignee_info.get('username', 'user')}" + assignee_msg = ( + f" assigned to {assignee_info.get('username', 'user')}" + ) else: assignee_msg = " (unassigned)" due_date_str = task.get("dueDate", due_date) - logger.info(f"Created task 
'{title}'{assignee_msg} due {due_date_str}") + logger.info( + f"Created task '{title}'{assignee_msg} due {due_date_str}" + ) return tid else: self._log_errors("create_task", response) return None - def create_treatment_tasks(self, patient_id: str, diagnosis: str) -> List[str]: + def create_treatment_tasks( + self, patient_id: str, diagnosis: str + ) -> List[str]: treatments = TreatmentPlanner.get_treatments_for_diagnosis(diagnosis) task_ids = [] @@ -157,7 +165,9 @@ def update_task(self) -> bool: done } } - """ % ("completeTask" if complete else "reopenTask") + """ % ( + "completeTask" if complete else "reopenTask" + ) response = self.client.query(mutation, {"id": tid}) diff --git a/tests/.gitignore b/tests/.gitignore new file mode 100644 index 00000000..6199b727 --- /dev/null +++ b/tests/.gitignore @@ -0,0 +1,26 @@ +# dependencies +node_modules +.pnp +.pnp.js + +# debug +npm-debug.log* +yarn-debug.log* +yarn-error.log* +.pnpm-debug.log* + +# build +build +tsconfig.tsbuildinfo + +# jetbrains +.idea +.fleet + +# misc +.DS_Store +*.pem + + +# local env files +.env*.local diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/e2e/__init__.py b/tests/e2e/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/e2e/auth.spec.ts b/tests/e2e/auth.spec.ts new file mode 100644 index 00000000..7779f7cf --- /dev/null +++ b/tests/e2e/auth.spec.ts @@ -0,0 +1,68 @@ +import { test, expect } from '@playwright/test'; + +const baseURL = process.env.E2E_BASE_URL || 'http://localhost:3000'; + +test.describe('Authentication', () => { + test('should display login page', async ({ page, baseURL: configBaseURL }) => { + const url = configBaseURL || baseURL; + await page.goto(url); + await expect(page).toHaveTitle(/tasks/i); + }); + + test('should handle authentication flow', async ({ page, baseURL: configBaseURL }) => { + const url = configBaseURL || baseURL; + await page.goto(url); + await 
page.waitForLoadState('networkidle'); + + const body = page.locator('body'); + await expect(body).toBeVisible(); + }); + + test('should handle unauthenticated access gracefully', async ({ page, baseURL: configBaseURL }) => { + const url = configBaseURL || baseURL; + await page.goto(url); + await page.waitForLoadState('networkidle'); + + const errors: string[] = []; + page.on('pageerror', (error) => { + errors.push(error.message); + }); + + await page.waitForTimeout(1000); + expect(errors.length).toBeLessThanOrEqual(0); + }); + + test('should have proper page metadata', async ({ page, baseURL: configBaseURL }) => { + const url = configBaseURL || baseURL; + await page.goto(url); + await page.waitForLoadState('networkidle'); + + const title = await page.title(); + expect(title).toBeTruthy(); + }); + + test('should load without console errors', async ({ page, baseURL: configBaseURL }) => { + const consoleErrors: string[] = []; + page.on('console', (msg) => { + if (msg.type() === 'error') { + consoleErrors.push(msg.text()); + } + }); + + const url = configBaseURL || baseURL; + await page.goto(url); + await page.waitForLoadState('networkidle'); + + const criticalErrors = consoleErrors.filter( + (error) => + !error.includes('favicon') && + !error.includes('404') && + !error.includes('Failed to load resource') && + !error.includes('net::ERR_') && + !error.toLowerCase().includes('chunk') + ); + expect(criticalErrors.length).toBeLessThanOrEqual(0); + }); +}); + + diff --git a/tests/e2e/navigation.spec.ts b/tests/e2e/navigation.spec.ts new file mode 100644 index 00000000..faf210bc --- /dev/null +++ b/tests/e2e/navigation.spec.ts @@ -0,0 +1,47 @@ +import { test, expect } from '@playwright/test'; + +const baseURL = process.env.E2E_BASE_URL || 'http://localhost:3000'; + +test.describe('Navigation', () => { + test.beforeEach(async ({ page, baseURL: configBaseURL }) => { + const url = configBaseURL || baseURL; + await page.goto(url); + await page.waitForLoadState('networkidle'); + 
});
+
+  test('should navigate between pages', async ({ page, baseURL: configBaseURL }) => {
+    const baseUrl = configBaseURL || baseURL;
+    const pages = ['/tasks', '/patients', '/properties'];
+
+    for (const path of pages) {
+      const fullUrl = `${baseUrl}${path}`;
+      await page.goto(fullUrl);
+      await page.waitForLoadState('networkidle');
+      await expect(page.locator('body')).toBeVisible();
+    }
+  });
+
+  test('should handle 404 page', async ({ page, baseURL: configBaseURL }) => {
+    const baseUrl = configBaseURL || baseURL;
+    await page.goto(`${baseUrl}/non-existent-page`);
+    await page.waitForLoadState('networkidle');
+
+    const body = page.locator('body');
+    await expect(body).toBeVisible();
+  });
+
+  test('should maintain state during navigation', async ({ page, baseURL: configBaseURL }) => {
+    const baseUrl = configBaseURL || baseURL;
+    await page.goto(baseUrl);
+    await page.waitForLoadState('networkidle');
+
+    await page.goto(`${baseUrl}/tasks`);
+    await page.waitForLoadState('networkidle');
+
+    await page.goBack();
+    await page.waitForLoadState('networkidle');
+
+    await expect(page).toHaveURL(/.*\/$/);
+  });
+});
+
diff --git a/tests/e2e/playwright.config.ts b/tests/e2e/playwright.config.ts
new file mode 100644
index 00000000..7c23228e
--- /dev/null
+++ b/tests/e2e/playwright.config.ts
@@ -0,0 +1,37 @@
+import { defineConfig, devices } from '@playwright/test';
+
+const baseURL = process.env.E2E_BASE_URL || 'http://localhost:3000';
+
+if (!baseURL || baseURL.trim() === '') {
+  throw new Error('E2E_BASE_URL must be set to a valid URL');
+}
+
+export default defineConfig({
+  testDir: '.', // NOTE(review): testDir resolves relative to this config file, which already lives in tests/e2e/; './tests/e2e' would point at tests/e2e/tests/e2e — confirm against how CI invokes playwright
+  fullyParallel: true,
+  forbidOnly: !!process.env.CI,
+  retries: process.env.CI ? 2 : 0,
+  workers: process.env.CI ?
1 : undefined,
+  reporter: 'html',
+  use: {
+    baseURL: baseURL.trim(),
+    trace: 'on-first-retry',
+  },
+  timeout: 30000,
+  projects: [
+    {
+      name: 'chromium',
+      use: {
+        ...devices['Desktop Chrome'],
+        baseURL: baseURL.trim(),
+      },
+    },
+  ],
+  webServer: process.env.CI ? undefined : {
+    command: 'cd ../../web && npm run dev', // NOTE(review): webServer commands run with cwd = the config directory (tests/e2e/); the dev server lives at <repo>/web, so 'cd web' would fail
+    url: baseURL.trim(),
+    reuseExistingServer: true,
+  },
+});
+
+
diff --git a/tests/e2e/tasks.spec.ts b/tests/e2e/tasks.spec.ts
new file mode 100644
index 00000000..c5440e29
--- /dev/null
+++ b/tests/e2e/tasks.spec.ts
@@ -0,0 +1,55 @@
+import { test, expect } from '@playwright/test';
+
+const baseURL = process.env.E2E_BASE_URL || 'http://localhost:3000';
+
+test.describe('Tasks', () => {
+  test.beforeEach(async ({ page, baseURL: configBaseURL }) => {
+    const url = configBaseURL || baseURL;
+    await page.goto(url);
+    await page.waitForLoadState('networkidle');
+  });
+
+  test('should display tasks page', async ({ page, baseURL: configBaseURL }) => {
+    const baseUrl = configBaseURL || baseURL;
+    await page.goto(`${baseUrl}/tasks`);
+    await page.waitForLoadState('networkidle');
+    await expect(page.locator('body')).toBeVisible();
+  });
+
+  test('should navigate to tasks page from home', async ({ page, baseURL: configBaseURL }) => {
+    const url = configBaseURL || baseURL;
+    await page.goto(url);
+    await page.waitForLoadState('networkidle');
+
+    const tasksLink = page.locator('a[href*="/tasks"]').first();
+    if (await tasksLink.isVisible()) {
+      await tasksLink.click();
+      await page.waitForLoadState('networkidle');
+      await expect(page).toHaveURL(/.*tasks/);
+    }
+  });
+
+  test('should handle page load without errors', async ({ page, baseURL: configBaseURL }) => {
+    const errors: string[] = [];
+    page.on('pageerror', (error) => {
+      errors.push(error.message);
+    });
+
+    const baseUrl = configBaseURL || baseURL;
+    await page.goto(`${baseUrl}/tasks`);
+    await page.waitForLoadState('networkidle');
+
+    expect(errors).toHaveLength(0);
+  });
+
+  test('should have 
accessible page structure', async ({ page, baseURL: configBaseURL }) => { + const baseUrl = configBaseURL || baseURL; + await page.goto(`${baseUrl}/tasks`); + await page.waitForLoadState('networkidle'); + + const mainContent = page.locator('main, [role="main"], body'); + await expect(mainContent.first()).toBeVisible(); + }); +}); + + diff --git a/tests/package-lock.json b/tests/package-lock.json new file mode 100644 index 00000000..a8403b95 --- /dev/null +++ b/tests/package-lock.json @@ -0,0 +1,78 @@ +{ + "name": "tasks-e2e-tests", + "version": "1.0.0", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "name": "tasks-e2e-tests", + "version": "1.0.0", + "devDependencies": { + "@playwright/test": "^1.48.0" + } + }, + "node_modules/@playwright/test": { + "version": "1.57.0", + "resolved": "https://registry.npmjs.org/@playwright/test/-/test-1.57.0.tgz", + "integrity": "sha512-6TyEnHgd6SArQO8UO2OMTxshln3QMWBtPGrOCgs3wVEmQmwyuNtB10IZMfmYDE0riwNR1cu4q+pPcxMVtaG3TA==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "playwright": "1.57.0" + }, + "bin": { + "playwright": "cli.js" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/fsevents": { + "version": "2.3.2", + "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.2.tgz", + "integrity": "sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA==", + "dev": true, + "hasInstallScript": true, + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": "^8.16.0 || ^10.6.0 || >=11.0.0" + } + }, + "node_modules/playwright": { + "version": "1.57.0", + "resolved": "https://registry.npmjs.org/playwright/-/playwright-1.57.0.tgz", + "integrity": "sha512-ilYQj1s8sr2ppEJ2YVadYBN0Mb3mdo9J0wQ+UuDhzYqURwSoW4n1Xs5vs7ORwgDGmyEh33tRMeS8KhdkMoLXQw==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "playwright-core": "1.57.0" + }, + "bin": { + "playwright": "cli.js" + }, + "engines": { + "node": 
">=18" + }, + "optionalDependencies": { + "fsevents": "2.3.2" + } + }, + "node_modules/playwright-core": { + "version": "1.57.0", + "resolved": "https://registry.npmjs.org/playwright-core/-/playwright-core-1.57.0.tgz", + "integrity": "sha512-agTcKlMw/mjBWOnD6kFZttAAGHgi/Nw0CZ2o6JqWSbMlI219lAFLZZCyqByTsvVAJq5XA5H8cA6PrvBRpBWEuQ==", + "dev": true, + "license": "Apache-2.0", + "bin": { + "playwright-core": "cli.js" + }, + "engines": { + "node": ">=18" + } + } + } +} diff --git a/tests/package.json b/tests/package.json new file mode 100644 index 00000000..c6503314 --- /dev/null +++ b/tests/package.json @@ -0,0 +1,13 @@ +{ + "name": "tasks-e2e-tests", + "version": "1.0.0", + "private": true, + "scripts": { + "test": "playwright test", + "test:ui": "playwright test --ui" + }, + "devDependencies": { + "@playwright/test": "^1.48.0" + } +} + diff --git a/tests/test-results/.last-run.json b/tests/test-results/.last-run.json new file mode 100644 index 00000000..4674647c --- /dev/null +++ b/tests/test-results/.last-run.json @@ -0,0 +1,17 @@ +{ + "status": "failed", + "failedTests": [ + "9563bb343517399bdb94-b09cc788647708a872fb", + "9563bb343517399bdb94-5a2dacc3bcad4643592d", + "9563bb343517399bdb94-c594e0a8656a86e7b661", + "9563bb343517399bdb94-4e29b303eacdb49eb504", + "9563bb343517399bdb94-c9e47df030b0a0590df3", + "3e99b8f6c3cad4665a83-aed8a12dd1c4db3d6226", + "3e99b8f6c3cad4665a83-3cad26c0f189ba76b775", + "3e99b8f6c3cad4665a83-548e59e1498f72b6c889", + "a5bfa21ef28a555ef7bf-c261be3b40dfaa1657b5", + "a5bfa21ef28a555ef7bf-4146f522add738f3954c", + "a5bfa21ef28a555ef7bf-8960c93e209089280ebd", + "a5bfa21ef28a555ef7bf-8d0cdcf618720c3d054f" + ] +} \ No newline at end of file diff --git a/web/package-lock.json b/web/package-lock.json index ed73c1de..c01a7a2c 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -38,6 +38,7 @@ "@graphql-codegen/typescript-operations": "5.0.6", "@graphql-codegen/typescript-react-query": "6.1.1", "@helpwave/eslint-config": "0.0.11", + 
"@playwright/test": "^1.57.0", "@types/node": "20.17.10", "@types/react": "18.3.17", "@types/react-dom": "18.3.5", @@ -3908,6 +3909,22 @@ "url": "https://opencollective.com/parcel" } }, + "node_modules/@playwright/test": { + "version": "1.57.0", + "resolved": "https://registry.npmjs.org/@playwright/test/-/test-1.57.0.tgz", + "integrity": "sha512-6TyEnHgd6SArQO8UO2OMTxshln3QMWBtPGrOCgs3wVEmQmwyuNtB10IZMfmYDE0riwNR1cu4q+pPcxMVtaG3TA==", + "devOptional": true, + "license": "Apache-2.0", + "dependencies": { + "playwright": "1.57.0" + }, + "bin": { + "playwright": "cli.js" + }, + "engines": { + "node": ">=18" + } + }, "node_modules/@repeaterjs/repeater": { "version": "3.0.6", "resolved": "https://registry.npmjs.org/@repeaterjs/repeater/-/repeater-3.0.6.tgz", @@ -7128,6 +7145,21 @@ "dev": true, "license": "ISC" }, + "node_modules/fsevents": { + "version": "2.3.2", + "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.2.tgz", + "integrity": "sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA==", + "dev": true, + "hasInstallScript": true, + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": "^8.16.0 || ^10.6.0 || >=11.0.0" + } + }, "node_modules/function-bind": { "version": "1.1.2", "resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.2.tgz", @@ -10177,6 +10209,38 @@ "url": "https://github.com/sponsors/jonschlinkert" } }, + "node_modules/playwright": { + "version": "1.57.0", + "resolved": "https://registry.npmjs.org/playwright/-/playwright-1.57.0.tgz", + "integrity": "sha512-ilYQj1s8sr2ppEJ2YVadYBN0Mb3mdo9J0wQ+UuDhzYqURwSoW4n1Xs5vs7ORwgDGmyEh33tRMeS8KhdkMoLXQw==", + "devOptional": true, + "license": "Apache-2.0", + "dependencies": { + "playwright-core": "1.57.0" + }, + "bin": { + "playwright": "cli.js" + }, + "engines": { + "node": ">=18" + }, + "optionalDependencies": { + "fsevents": "2.3.2" + } + }, + "node_modules/playwright-core": { + "version": "1.57.0", + 
"resolved": "https://registry.npmjs.org/playwright-core/-/playwright-core-1.57.0.tgz", + "integrity": "sha512-agTcKlMw/mjBWOnD6kFZttAAGHgi/Nw0CZ2o6JqWSbMlI219lAFLZZCyqByTsvVAJq5XA5H8cA6PrvBRpBWEuQ==", + "devOptional": true, + "license": "Apache-2.0", + "bin": { + "playwright-core": "cli.js" + }, + "engines": { + "node": ">=18" + } + }, "node_modules/possible-typed-array-names": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/possible-typed-array-names/-/possible-typed-array-names-1.1.0.tgz", diff --git a/web/package.json b/web/package.json index b16c0cd4..b2b1eee4 100644 --- a/web/package.json +++ b/web/package.json @@ -42,6 +42,7 @@ "@graphql-codegen/typescript-operations": "5.0.6", "@graphql-codegen/typescript-react-query": "6.1.1", "@helpwave/eslint-config": "0.0.11", + "@playwright/test": "^1.57.0", "@types/node": "20.17.10", "@types/react": "18.3.17", "@types/react-dom": "18.3.5", diff --git a/web/public/env-config.js b/web/public/env-config.js new file mode 100644 index 00000000..edcf23e1 --- /dev/null +++ b/web/public/env-config.js @@ -0,0 +1,2 @@ +window.__ENV = {} +