diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..fadd4fa --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,70 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + lint: + name: ruff + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + - name: Install + run: | + python -m pip install --upgrade pip + pip install -e ".[dev]" + - name: ruff check + run: ruff check . + - name: ruff format --check + run: ruff format --check . + + typecheck: + name: mypy + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + - name: Install + run: | + python -m pip install --upgrade pip + pip install -e ".[dev]" + - name: mypy + run: mypy + + test: + name: pytest (py${{ matrix.python }}, ${{ matrix.os }}) + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest] + python: ["3.11", "3.12"] + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + - name: Install (ubuntu) + if: runner.os == 'Linux' + run: | + python -m pip install --upgrade pip + pip install -e ".[dev]" + - name: Install (macos) + if: runner.os == 'macOS' + run: | + python -m pip install --upgrade pip + pip install -e ".[dev,macos]" + - name: pytest + run: pytest -ra diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f30205c --- /dev/null +++ b/.gitignore @@ -0,0 +1,77 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +*.manifest 
+*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Environments +.env +.venv +env/ +venv/ +ENV/ + +# Type-checking +.mypy_cache/ +.dmypy.json +dmypy.json +.pyre/ +.pytype/ +.ruff_cache/ + +# IDEs +.idea/ +.vscode/ +*.swp +*.swo + +# macOS +.DS_Store + +# Project-specific +benchmarks/results-*.json +benchmarks/.cache/ +*.mp4 +*.mov +*.mkv +*.webm +!tests/fixtures/**/*.mp4 +.claude/ diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..9f2c934 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,23 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +### Added +- Repo scaffolding: `pyproject.toml` with hatchling build, dual Apache-2.0/MIT license, ruff + mypy strict configs, pytest setup. +- `src/findit_keyframe/` module skeleton: `types`, `decoder`, `quality`, `sampler`, `saliency`, `cli`. +- `docs/algorithm.md` and `docs/rust-porting.md` skeletons for algorithm spec and Rust translation map. +- GitHub Actions CI: ruff, mypy, pytest on push/PR for Python 3.11 and 3.12. +- Quality module: `rgb_to_luma` (BT.601 fixed-point), `laplacian_variance`, `mean_luma`, `luma_variance`, `entropy`, `QualityGate`, `compute_quality`. +- Decoder module: `VideoDecoder` (PyAV backend, context-manager) with `decode_at` (keyframe seek + forward decode) and `decode_sequential` (linear pass over a shot list); `Strategy` enum and `pick_strategy` density heuristic. 
+- Sampler module: `compute_bins` (boundary-shrunken equal partition), `score_bin_candidates` (ordinal-rank composite), `select_from_bin`, fallback path with `Confidence.Low` / `Confidence.Degraded`, top-level `extract_for_shot` and `extract_all`. +- CLI (`findit-keyframe extract VIDEO SHOTS_JSON OUTPUT`): scenesdetect-compatible shot JSON parsing, optional `--config` for `SamplingConfig` overrides, `--saliency {none,apple}` flag (apple stub returns input-error until T6), per-keyframe baseline JPEG output via PyAV's mjpeg / image2 muxer, `manifest.json` output with quality dict and confidence string. Exit codes: 0 success, 1 input error, 2 extraction error. +- Benchmark script (`benchmarks/bench_e2e.py`): standalone CLI, optional shot JSON, configurable `--target-size`, append-only `results.md` log with date and git SHA, peak-memory normalised across Linux/macOS. +- Saliency providers (`saliency.py`): `SaliencyProvider` runtime-checkable Protocol, `NoopSaliencyProvider` (always 0.0), `AppleVisionSaliencyProvider` (macOS, wraps `VNGenerateAttentionBasedSaliencyImageRequest` via pyobjc; lazy import so non-macOS imports don't crash), `default_saliency_provider()` factory. +- Sampler now accepts an optional `saliency_provider` argument on `select_from_bin`, `extract_for_shot`, and `extract_all`; saliency mass feeds the composite score weight. +- CLI `--saliency apple` now wires `AppleVisionSaliencyProvider` end-to-end (was a stub returning input-error in P3). +- Documentation finalised: `docs/algorithm.md` reflects shipped behaviour (cell-centred sampling, ordinal-rank scoring, three-tier fallback) and adds a Saliency Provider Contract section. `docs/rust-porting.md` carries a complete type map (including `DecodedFrame`, `Strategy`, `QualityGate`, all three `SaliencyProvider` impls), the actual numpy ops used, and an Apple Vision idiom map for `objc2-vision`. 
diff --git a/LICENSE-APACHE b/LICENSE-APACHE new file mode 100644 index 0000000..9fa31ef --- /dev/null +++ b/LICENSE-APACHE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. 
Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative 
Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for describing the origin of the Work and + reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Support. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or support. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2026 Findit-AI + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/LICENSE-MIT b/LICENSE-MIT new file mode 100644 index 0000000..e017625 --- /dev/null +++ b/LICENSE-MIT @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2026 Findit-AI + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..98a9bb7 --- /dev/null +++ b/README.md @@ -0,0 +1,60 @@ +# findit-keyframe + +Per-shot keyframe extraction with stratified temporal sampling. + +> **Status**: Python reference implementation. The Rust translation is the long-term target — see [`docs/rust-porting.md`](docs/rust-porting.md). Every module is written for a 1:1 port. + +`findit-keyframe` consumes shot boundaries (e.g., from [`scenesdetect`](https://github.com/Findit-AI/scenesdetect)) and selects 1–N high-quality, **temporally distributed** frames per shot so downstream models — vision-language (Qwen3-VL-2B), embeddings (SigLIP 2), saliency (Apple Vision) — see the temporal progression of each shot, not just one representative moment. + +## Why not "one frame per shot" + +A single keyframe per shot loses temporal information. A 30-second talking-head shot and a 30-second action sequence look identical to a downstream VLM if you give it one frame. We split each shot into equal-duration **bins**, run a quality gate on candidates within each bin, and pick the best one — yielding a small, well-spaced sequence the VLM can reason about. + +## Install + +```bash +pip install -e ".[dev]" # core +pip install -e ".[dev,macos]" # + Apple Vision saliency provider +``` + +Requires Python ≥ 3.11. Core deps: `av` (PyAV) and `numpy`. No OpenCV. 
+ +## Quickstart + +```python +from pathlib import Path +from findit_keyframe import ( + SamplingConfig, ShotRange, Timebase, Timestamp, + VideoDecoder, default_saliency_provider, extract_all, +) + +shots = [ + ShotRange( + start=Timestamp(0, Timebase(1, 1000)), + end =Timestamp(5000, Timebase(1, 1000)), + ), + # ... more shots from scenesdetect ... +] +decoder = VideoDecoder.open(Path("my_video.mp4")) +keyframes = extract_all(shots, decoder, SamplingConfig(), default_saliency_provider()) +for shot_keyframes in keyframes: + for kf in shot_keyframes: + print(kf.timestamp.seconds, kf.confidence, kf.quality.laplacian_var) +``` + +## CLI + +```bash +findit-keyframe extract video.mp4 shots.json out/ --saliency apple +``` + +The shot JSON schema and manifest output schema are documented in [`docs/algorithm.md`](docs/algorithm.md). + +## License + +Dual-licensed under either of: + +- Apache License, Version 2.0 ([`LICENSE-APACHE`](LICENSE-APACHE)) +- MIT license ([`LICENSE-MIT`](LICENSE-MIT)) + +at your option. diff --git a/TASKS.md b/TASKS.md new file mode 100644 index 0000000..3d6d6d5 --- /dev/null +++ b/TASKS.md @@ -0,0 +1,417 @@ +# findit-keyframe — Task Document + +**Target repo**: `github.com/Findit-AI/findit-keyframe` (to be created) +**Language**: Python (reference implementation), later translated to Rust by teammate +**Status**: Planning phase — this document is the source of truth for scope and acceptance + +--- + +## 0. Context Recap + +- **Upstream**: `scenesdetect` (Rust, published at `github.com/Findit-AI/scenesdetect` branch `0.1.0`) emits shot boundaries as `TimeRange` (start/end timestamps with timebase). +- **Downstream**: SigLIP 2 (vision embedding), Apple Vision (saliency/tags), Qwen3-VL-2B (VLM description). +- **Problem**: For each shot, select 1–N high-quality frames that are **temporally distributed** across the shot so VLMs can understand temporal progression, not just a single representative moment. 
+- **Architectural decision**: `findit-keyframe` is a **standalone repo**, Python reference first, then translated to Rust. +- **Why Python first**: Faster algorithm iteration, easier for teammate review before commitment to Rust. + +--- + +## 1. Non-Goals (scope control) + +To prevent scope creep, we explicitly exclude: + +- ❌ Scene boundary detection (that's `scenesdetect`'s job; we consume its output). +- ❌ Training any ML selector (no labeled data, no budget). +- ❌ Video-level summarization (we're per-shot, not whole-video). +- ❌ GPU / hardware-accelerated decode in Python (Rust phase will add VideoToolbox). +- ❌ Distributed processing / cluster execution. +- ❌ Direct downstream model inference (SigLIP/Qwen). We only provide clean keyframe outputs; consumers run inference. +- ❌ Cross-shot deduplication (future optimization, not P1). + +--- + +## 2. Project-Level Goals + +| # | Goal | Why it matters | +|---|------|----------------| +| G1 | Produce a clean, reviewable Python reference implementation | Teammate can review algorithm correctness before committing to Rust | +| G2 | Every Python module has a 1:1 Rust translation path | Avoid Python-idiom traps (metaclasses, duck typing, dynamic dispatch) | +| G3 | Algorithmic correctness verifiable by deterministic fixtures | Rust translation can replay the same fixtures and match bit-for-bit (or near-bit) | +| G4 | Standalone repo, zero dependency on FindIt internals | Can be open-sourced; teammate clones independently | +| G5 | Reasonable Python performance (not a toy) | Must survive real videos (Kino Demo, Kurutta Ippeiji) for benchmarking | + +--- + +## 3. Task Breakdown + +Each task has: +- **Goal**: what the deliverable is +- **Scope**: what's included / excluded +- **Verification**: how we confirm it's done correctly + +--- + +### Task 1 — Repo scaffolding + +**Goal**: Create `findit-keyframe` repo with complete project structure, toolchain config, CI, and documentation skeleton. 
+ +**Scope (included)**: +- `pyproject.toml` with `hatchling` build backend, locked Python ≥ 3.11 +- Dependencies: `av` (PyAV), `numpy`; optional `macos` extra for `pyobjc-framework-Vision`; dev extras for `pytest`, `ruff`, `mypy` +- Directory layout: + ``` + src/findit_keyframe/{__init__.py, types.py, decoder.py, quality.py, sampler.py, saliency.py, cli.py} + tests/ + benchmarks/ + examples/ + docs/ + ``` +- `README.md`, `LICENSE` (Apache-2.0 + MIT dual to match scenesdetect), `CHANGELOG.md`, `.gitignore` +- `.github/workflows/ci.yml`: pytest, ruff, mypy on push/PR +- `ruff.toml` with strict lint rules matching the project ethos (line-length 100, import sorting, etc.) +- `mypy` in strict mode + +**Scope (excluded)**: +- No actual algorithm code yet +- No fixture videos committed +- No benchmark numbers + +**Verification**: +- [ ] `git clone` works on a fresh machine +- [ ] `pip install -e ".[dev,macos]"` succeeds on macOS +- [ ] `pytest` runs (0 tests collected is fine) +- [ ] `ruff check .` passes +- [ ] `mypy src/` passes (no code yet → trivially passes) +- [ ] CI green on first push +- [ ] `README.md` clearly states: "Python reference implementation; Rust translation target. See `docs/rust-porting.md`." + +**Estimated effort**: 0.5–1 day + +--- + +### Task 2 — Type definitions (`types.py`) + +**Goal**: Mirror `scenesdetect`'s public types in Python using `@dataclass(frozen=True)` + explicit type hints, so Rust translation is 1:1. 
+**Scope (included)**: +- `Timebase(num: int, den: int)` — rational timebase, `den > 0` invariant +- `Timestamp(pts: int, timebase: Timebase)` — with `seconds` property +- `ShotRange(start: Timestamp, end: Timestamp)` — with `duration_sec` property +- `SamplingConfig` — user-tunable parameters with defaults: + - `target_interval_sec: float = 4.0` + - `candidates_per_bin: int = 6` + - `max_frames_per_shot: int = 16` + - `boundary_shrink_pct: float = 0.02` + - `fallback_expand_pct: float = 0.20` + - `target_size: int = 384` +- `Confidence` enum (`High`, `Low`, `Degraded`) +- `QualityMetrics(laplacian_var, mean_luma, luma_variance, entropy, saliency_mass)` +- `ExtractedKeyframe(shot_id, timestamp, bucket_index, rgb: bytes, width, height, quality, confidence)` + +**Scope (excluded)**: +- No `__slots__` (interferes with dataclass in some versions; Rust doesn't need it anyway) +- No custom `__hash__` / `__eq__` unless required by algorithm (let `@dataclass(frozen=True)` generate them) +- No `Protocol` / ABC for decoder (we'll use a concrete class later; teammate can refactor) + +**Verification**: +- [ ] `mypy --strict src/findit_keyframe/types.py` passes +- [ ] All dataclasses are `frozen=True` where immutable (everything except `SamplingConfig` and `ExtractedKeyframe`) +- [ ] Unit tests confirming: + - `Timebase(1, 0)` raises (zero numerator is fine — scenesdetect allows it; zero denominator must fail) + - `Timestamp.seconds` returns correct float for common cases (`1000 @ 1/1000 == 1.0s`) + - `ShotRange.duration_sec` handles start/end in the same timebase correctly +- [ ] `docs/rust-porting.md` has a **Type Map** section showing Python ↔ Rust field correspondence for every type + +**Estimated effort**: 0.5 day + +--- + +### Task 3 — Video decoder (`decoder.py`) + +**Goal**: PyAV-based frame decoder with two strategies (Sequential demux vs per-shot seek), auto-selected by shot density. 
+ +**Scope (included)**: +- `VideoDecoder` class: + - `open(path: Path) -> VideoDecoder` — opens a video, caches metadata (fps, duration, time_base) + - `decode_at(time_sec: float) -> DecodedFrame` — seek + decode one frame at/near the given time, returns RGB frame + PTS + - `decode_sequential(shots: list[ShotRange]) -> Iterator[(shot_id, DecodedFrame)]` — single pass through the file, emits frames falling inside any shot's range +- Auto-strategy picker: `pick_strategy(shots, duration_sec) -> Strategy` based on density heuristic (density > 0.3 shots/sec OR shots > 200 → Sequential, else PerShotSeek) +- RGB conversion: decode in native format, convert to RGB24 at `target_size × target_size` via PyAV's built-in reformatter (wraps swscale) +- Frame metadata: PTS in video time_base, exposed as `Timestamp` + +**Scope (excluded)**: +- No VideoToolbox hardware decode (macOS hwaccel belongs in Rust phase) +- No variable frame rate edge cases beyond basic PTS handling (document as known limitation) +- No audio decode + +**Verification**: +- [ ] Unit test: decode a 10s synthetic video (generated via ffmpeg with known frame patterns), verify `decode_at(5.0)` returns the expected frame +- [ ] Integration test: open `Kino Demo Render.mp4`, decode first frame, verify dimensions match expected resolution (1920×1080) +- [ ] Performance sanity: 1-minute 1080p video, 100 `decode_at` calls should complete in < 5 seconds on M-series Mac (with PerShotSeek) +- [ ] Sequential strategy: decode a full 10s video (30 fps, so 300 frames), should yield 300 frames in order +- [ ] Strategy auto-selection: unit test with mock shots list, confirm density threshold picks correct strategy + +**Estimated effort**: 1.5 days + +--- + +### Task 4 — Quality metrics (`quality.py`) + +**Goal**: Pure numpy implementation of per-frame quality signals. No OpenCV. Every function translates 1:1 to Rust `ndarray` or manual loops. 
+**Scope (included)**: +- `rgb_to_luma(rgb: np.ndarray) -> np.ndarray` — BT.601 integer: `((66*R + 129*G + 25*B + 128) >> 8) + 16` (parenthesized: `>>` binds looser than `+` in Python/Rust), matches scenesdetect's fixed-point choice +- `laplacian_variance(luma: np.ndarray) -> float` — 3×3 Laplacian kernel `[[0,1,0],[1,-4,1],[0,1,0]]`, return variance of the filtered output. Explicit loop + numpy, not `scipy.signal` or `cv2`. +- `mean_luma(luma: np.ndarray) -> float` — arithmetic mean, normalized to `[0.0, 1.0]` +- `luma_variance(luma: np.ndarray) -> float` — sample variance +- `entropy(luma: np.ndarray, bins: int = 256) -> float` — Shannon entropy of the 256-bin histogram, base-2 log. Manual numpy implementation (no scipy). +- `QualityGate` class with defaults matching our earlier decision: + - Reject if `mean_luma` outside `[15/255, 240/255]` + - Reject if `luma_variance` < 5 +- `compute_quality(rgb: np.ndarray, saliency: Optional[float]) -> QualityMetrics` — composite, returns populated `QualityMetrics` dataclass + +**Scope (excluded)**: +- No OpenCV, no scipy, no scikit-image +- No learned quality models (NIMA, MUSIQ) +- No motion blur detection beyond what Laplacian variance catches +- No Apple Vision integration in this module (that's `saliency.py`) + +**Verification**: +- [ ] Unit test: synthetic all-black frame → `mean_luma ≈ 0.0`, `laplacian_var ≈ 0.0`, gate rejects +- [ ] Unit test: synthetic random noise frame → high laplacian_var, gate accepts +- [ ] Unit test: gradient image (smooth ramp) → low laplacian_var, high luma_variance, gate accepts +- [ ] Golden fixture test: pre-computed quality metrics JSON for 10 canonical frames (stored in `tests/fixtures/quality/`); Python output must match to 6 decimal places +- [ ] Performance: `compute_quality` on 384×384 RGB should complete in < 5 ms on M-series Mac +- [ ] `docs/rust-porting.md` shows the exact numpy op → ndarray op mapping for each function + +**Estimated effort**: 1 day + +--- + +### Task 5 — Sampler (`sampler.py`) — **core algorithm** + +**Goal**: 
Implement stratified temporal sampling + within-bucket quality selection with graceful degradation. This is the algorithmic heart. + +**Scope (included)**: +- `compute_bins(shot: ShotRange, config: SamplingConfig) -> list[tuple[float, float]]` — partition shot into N equal-duration bins with boundary shrinkage on first/last +- `score_candidate(quality: QualityMetrics) -> float` — weighted composite: + ``` + score = 0.6 * normalized(laplacian_var) + + 0.2 * normalized(entropy) + + 0.2 * (saliency_mass or 0.0) + ``` + where `normalized` is z-score or percentile-rank within the bin's candidate pool (document which) +- `select_from_bin(candidates: list[DecodedFrame], config: SamplingConfig) -> Optional[(DecodedFrame, QualityMetrics, Confidence)]` — apply quality gate, score, pick `argmax`; return `None` if all filtered out +- `fallback_pick(shot, bins, bin_idx, decoder, config)` — expand search to ±`fallback_expand_pct` of adjacent bins; if still nothing, force-pick highest-quality candidate with `Confidence.Degraded` +- `extract_for_shot(shot: ShotRange, decoder: VideoDecoder, config: SamplingConfig) -> list[ExtractedKeyframe]` — main entry point +- `extract_all(shots: list[ShotRange], decoder: VideoDecoder, config: SamplingConfig) -> list[list[ExtractedKeyframe]]` — process all shots; uses decoder's auto-strategy + +**Scope (excluded)**: +- No MMR-based cross-bin deduplication (keep for P3) +- No CLIP/SigLIP-based relevance scoring (keep for P3) +- No cross-shot coherence +- No caching of decoded frames across shots (memory concerns; re-decode is cheap for sparse seek) + +**Verification**: +- [ ] Unit test: synthetic 20s shot with known quality gradient (frames 0–100 sharp, 100–200 blurred, 200–300 sharp) — verify selected frames come from sharp regions in each bucket +- [ ] Unit test: N = `ceil(duration / 4.0)` with duration = 5s → N = 2; duration = 60s → N = 15; duration = 120s → N = 16 (capped) +- [ ] Unit test: black-frame shot — all candidates in first bucket 
fail hard gate → fallback picks degraded frame, `Confidence.Degraded` +- [ ] Unit test: bin boundaries — first bucket's `t0` is shifted by `boundary_shrink_pct`, last bucket's `t1` is shifted similarly +- [ ] Integration test: run on `Kino Demo Render.mp4` with scene cuts from scenesdetect, visually inspect top 5 shots — no blurry frames, no black frames, temporally distributed +- [ ] Regression fixture: JSON snapshot of `(shot_id, bin_index, timestamp, quality_score)` tuples for Kino Demo, deterministic across runs +- [ ] `docs/rust-porting.md` documents the exact algorithm in pseudocode (same as Python, but language-agnostic) + +**Estimated effort**: 2 days + +--- + +### Task 6 — Saliency adapter (`saliency.py`) + +**Goal**: Optional Apple Vision saliency integration, with a no-op stub for non-macOS / Rust-future environments. + +**Scope (included)**: +- `SaliencyProvider` protocol/interface: `compute(rgb: np.ndarray) -> float` (returns saliency mass in `[0, 1]`) +- `NoopSaliencyProvider` — always returns 0.0; used as default and on non-macOS +- `AppleVisionSaliencyProvider` (macOS only, guarded by `platform.system() == "Darwin"`) — uses `pyobjc-framework-Vision` to call `VNGenerateAttentionBasedSaliencyImageRequest`, sums the heatmap pixels +- Factory: `default_saliency_provider() -> SaliencyProvider` — returns Apple provider on macOS, Noop elsewhere + +**Scope (excluded)**: +- No other saliency models (DeepGaze, UNISAL, etc.) 
+- No batch processing (one frame at a time; Rust can batch later) + +**Verification**: +- [ ] Unit test: `NoopSaliencyProvider` always returns 0.0 +- [ ] Integration test (macOS only, skipped elsewhere): `AppleVisionSaliencyProvider` on a known image (center-white, corners-black) returns a saliency value > 0.3 +- [ ] Gracefully degrades: if `pyobjc-framework-Vision` is not installed, import of `AppleVisionSaliencyProvider` must not crash `findit_keyframe` package import +- [ ] `docs/rust-porting.md` shows how `AppleVisionSaliencyProvider` maps to `objc2-vision` + `VNGenerateAttentionBasedSaliencyImageRequest` in Rust + +**Estimated effort**: 1 day (macOS testing adds overhead) + +--- + +### Task 7 — CLI tool (`cli.py`) + +**Goal**: Command-line utility for end-to-end testing and teammate demonstration. + +**Scope (included)**: +- `findit-keyframe extract VIDEO_PATH SHOTS_JSON OUTPUT_DIR` — reads shots from JSON (scenesdetect-compatible format), extracts keyframes, writes them as JPEG files + a manifest JSON +- `--config CONFIG_JSON` — override `SamplingConfig` defaults +- `--saliency {none, apple}` — pick saliency provider +- Shot JSON schema (input): + ```json + {"shots": [{"id": 0, "start_pts": 0, "end_pts": 1000, "timebase_num": 1, "timebase_den": 1000}, ...]} + ``` +- Manifest JSON schema (output): + ```json + {"video": "path", "keyframes": [{"shot_id": 0, "bucket": 0, "file": "kf_000_000.jpg", "timestamp_sec": 1.2, "quality": {...}, "confidence": "high"}, ...]} + ``` +- Uses `argparse` (stdlib, no click/typer — cleaner Rust translation) + +**Scope (excluded)**: +- No interactive TUI +- No progress bar via tqdm (keep deps minimal; log lines are fine) +- No direct scenesdetect invocation (user runs scenesdetect separately, pipes its JSON) + +**Verification**: +- [ ] CLI help works: `findit-keyframe --help` +- [ ] End-to-end test: given a fixture video + shot JSON, produces expected number of JPEG files +- [ ] Manifest JSON is valid and round-trips through 
`json.loads` +- [ ] Exit codes: 0 on success, 1 on input errors, 2 on extraction failures + +**Estimated effort**: 1 day + +--- + +### Task 8 — Benchmarks (`benchmarks/`) + +**Goal**: Establish a performance baseline that the future Rust translation must beat (target: 5–10× speedup). + +**Scope (included)**: +- `bench_e2e.py` — run extraction on `Kino Demo Render.mp4` and `狂った一頁 編集済み.mp4`, log: + - Total wall time + - Frames decoded + - Frames per second throughput + - Memory high-water mark (via `resource.getrusage`) +- Uses `pytest-benchmark` for statistical rigor +- Output: Markdown table written to `benchmarks/results.md`, committed + +**Scope (excluded)**: +- No per-function micro-benchmarks (let Rust phase profile separately) +- No flamegraph generation + +**Verification**: +- [ ] Benchmark runs to completion on both test videos +- [ ] Results written to `benchmarks/results.md` with timestamp + git SHA +- [ ] Performance budget: Kino Demo (1m44s) should complete extraction in < 30 seconds on M-series Mac (Python, unoptimized) + +**Estimated effort**: 0.5 day + +--- + +### Task 9 — Documentation (`docs/`) + +**Goal**: Both user-facing and Rust-translation-facing docs. + +**Scope (included)**: +- `docs/algorithm.md` — algorithmic specification with pseudocode, invariants, parameter rationale. Language-agnostic. 
+- `docs/rust-porting.md` — **the most important doc**: + - Type map (Python dataclass ↔ Rust struct, field by field) + - Dependency map (Python lib ↔ Rust crate) + - Idiom map (Python pattern ↔ Rust pattern) — e.g., `dataclasses.replace` ↔ `Struct { field: new, ..old }` + - Test fixture map (Python test ↔ Rust test, same JSON fixtures) + - Known Python-only shortcuts that Rust can tighten (e.g., `np.linspace` vs explicit index loop) +- `README.md` — user-facing: quickstart, example, how to consume scenesdetect output +- `CHANGELOG.md` — standard Keep-a-Changelog format + +**Scope (excluded)**: +- No API reference site (docs.rs equivalent) — that's for Rust phase +- No tutorials beyond the quickstart + +**Verification**: +- [ ] All Python types have a row in the type map +- [ ] All numpy ops used in the code have a row in the idiom map +- [ ] README quickstart is copy-pasteable and works on a fresh clone +- [ ] Teammate can produce a Rust module skeleton from `rust-porting.md` alone, without reading Python source (sanity-check via code review) + +**Estimated effort**: 1 day (continuous, updated as tasks progress) + +--- + +## 4. Implementation Phases + +| Phase | Tasks | Deliverable | ETA | +|-------|-------|-------------|-----| +| **P1 — Foundation** | T1 + T2 + T9 (started) | Scaffold + types + type-map docs | 1.5 days | +| **P2 — MVP** | T3 + T4 + T5 (basic) | Working end-to-end on synthetic video | 4.5 days | +| **P3 — Real video** | T7 + T5 (fallback polish) + T8 | CLI + benchmarks on real videos | 2 days | +| **P4 — macOS + polish** | T6 + T9 finalize | Apple Vision saliency + complete docs | 1.5 days | +| **P5 — Handoff** | Teammate review | Comments, Rust translation begins | (parallel to P4) | + +**Total Python phase**: ~9–10 working days. +**Rust translation** (by teammate, out of our scope): ~2 weeks separately. + +--- + +## 5. 
Cross-Cutting Verification + +Across all tasks, these must hold: + +- [ ] **No OpenCV** anywhere (`grep -r "import cv2"` → 0 hits) +- [ ] **No scipy / sklearn / skimage** in `src/` (test fixtures may use them for generating ground truth, but not production code) +- [ ] **Type hints on every public function** (`mypy --strict` clean) +- [ ] **All public functions have docstrings** with: purpose, args, returns, raises +- [ ] **No global mutable state** (Rust doesn't like it, Python shouldn't need it) +- [ ] **No `from X import *`** (explicit imports only, translation-friendly) +- [ ] **Every test has a fixture file** or is runnable without network + +--- + +## 6. Risk Register + +| Risk | Likelihood | Impact | Mitigation | +|------|------------|--------|------------| +| PyAV API quirks on macOS (FFmpeg 7.x) | Medium | Medium | Pin `av>=13.0,<14.0` — PyAV 13's exception hierarchy (`av.error.FFmpegError` as a common base inheriting from both `OSError` and `ValueError`) is what the CLI's narrowed `except` clause relies on; test on real Kino Demo early | +| Apple Vision pyobjc behavior differs from objc2-vision | Low | Low | Document saliency is "signal only"; Rust may tune independently | +| Numpy precision differs from Rust ndarray (rare but real) | Low | Medium | Golden fixtures at 6-decimal precision with tolerance | +| Shot edge cases: zero-duration, overlapping, negative | Medium | High | Validate inputs strictly; unit tests for pathological shots | +| Decoder returning frames off by ±1 frame from requested PTS | High | Medium | Document ±1 frame tolerance; test at known I-frames | + +--- + +## 7. Out-of-Scope but Noted for Later + +Recording decisions that are explicitly **not P1** but likely P2+ (Rust phase or beyond): + +- **MMR-based cross-bin deduplication** — if identical shots (news anchors) have visually near-identical keyframes across bins, collapse them. Needs VLM or CLIP embedding, defer. 
+- **SigLIP medoid as L1 selection** — when selecting 1 frame for indexing (not VLM description), use SigLIP embedding medoid within quality-gated candidates. Needs ONNX runtime integration. +- **Cross-shot coherence** — if consecutive shots form a "scene" in film-theory sense, deduplicate keyframes across shot boundaries. Research needed. +- **Learned quality models** (NIMA, MUSIQ) — if Laplacian variance proves insufficient on real data, revisit. +- **Hardware decode (VideoToolbox)** — Rust phase only. + +--- + +## 8. Glossary + +- **Shot**: A contiguous run of video frames from one scene cut to the next, as emitted by scenesdetect. +- **Bucket / Bin**: An equal-duration subdivision of a shot for temporal stratification. +- **Keyframe**: A representative frame extracted from a bucket; may be one of many per shot. +- **Hard gate**: Pass/fail quality threshold that rejects a candidate outright (e.g., black frame). +- **Soft score**: Continuous quality score used to rank candidates that passed the hard gate. +- **Degraded confidence**: Output tag indicating all candidates failed hard gate but one was force-selected. +- **Medoid**: The element of a set with minimum total distance to all others; a "real" (vs interpolated) centroid. + +--- + +## 9. Working Directory Handoff + +Current working directory: `/Users/cheongzhiyan/Developer/Findit_app` + +**Next steps to switch**: + +1. Confirm this task document is complete and accurate (you and I both review). +2. Create the repo: `github.com/Findit-AI/findit-keyframe` (you handle GitHub side; provide local path when ready). +3. Clone locally to e.g. `/Users/cheongzhiyan/Developer/findit-keyframe`. +4. Copy this `TASKS.md` into the new repo root (or `docs/TASKS.md`). +5. Move working directory to the new repo for all subsequent implementation work. + +From that point forward, all tasks proceed from the new working directory. 
diff --git a/benchmarks/.gitkeep b/benchmarks/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 0000000..aea76b7 --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,35 @@ +# Benchmarks + +End-to-end performance baselines for `findit-keyframe`. The Rust port (P5+) +is expected to beat these numbers by 5-10x on identical inputs. + +## Running + +```bash +# Synthetic uniform-shot baseline on any video. +python benchmarks/bench_e2e.py --video path/to/video.mp4 + +# Replay a real shot list (e.g. produced by scenesdetect). +python benchmarks/bench_e2e.py --video path/to/video.mp4 --shots shots.json + +# Smaller output frames (faster, cheaper). +python benchmarks/bench_e2e.py --video path/to/video.mp4 --target-size 256 +``` + +Each run appends a row to [`results.md`](results.md) with the date, git +SHA, and headline numbers (wall time, throughput, peak RSS). + +## Reading the numbers + +* **Wall (s)**: total time for `extract_all`, decoder open + close included. +* **KF/s**: keyframes emitted per wall-clock second. Useful to compare + across videos of different length. +* **Mem (MB)**: peak resident-set size. Linux reports KB internally; + macOS reports bytes; the script normalises both to MB. + +## Performance budget (P3 baseline target) + +Per `TASKS.md` §3.T8: the Kino Demo render (1m44s, 1080p) should finish +extraction in under 30 seconds on an M-series Mac at the default +`target_size=384`. A regression beyond that warrants investigation +before tagging a release. diff --git a/benchmarks/bench_e2e.py b/benchmarks/bench_e2e.py new file mode 100644 index 0000000..4e58b15 --- /dev/null +++ b/benchmarks/bench_e2e.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 +"""End-to-end keyframe extraction benchmark. 
+ +Usage:: + + python benchmarks/bench_e2e.py --video PATH [--shots PATH] [--target-size N] + [--results-md PATH] [--quiet] + +If ``--shots`` is omitted, uniform 4-second shots covering the whole video +are generated. Each run prints a JSON summary on stdout (suppressed with +``--quiet``) and appends a Markdown row to ``benchmarks/results.md`` with +date, git SHA, and the headline numbers. + +The Rust port (P5+) must beat this baseline by 5-10x on the same input; +keep this script self-contained so identical fixtures can be replayed +across implementations. +""" + +from __future__ import annotations + +import argparse +import json +import resource +import subprocess +import sys +import time +from datetime import UTC, datetime +from pathlib import Path +from typing import Any + +from findit_keyframe.cli import _parse_shot_json +from findit_keyframe.decoder import VideoDecoder +from findit_keyframe.sampler import extract_all +from findit_keyframe.types import SamplingConfig, ShotRange, Timebase, Timestamp + + +def _git_sha() -> str: + try: + return subprocess.check_output( + ["git", "rev-parse", "--short", "HEAD"], + text=True, + stderr=subprocess.DEVNULL, + ).strip() + except (subprocess.CalledProcessError, FileNotFoundError): + return "unknown" + + +def _uniform_shots(duration_sec: float, interval_sec: float = 4.0) -> list[ShotRange]: + """Synthetic shot list covering ``[0, duration_sec)`` in equal slices.""" + tb = Timebase(1, 1000) + n = max(1, int(duration_sec // interval_sec)) + return [ + ShotRange( + start=Timestamp(round(i * interval_sec * 1000), tb), + end=Timestamp(round(min((i + 1) * interval_sec, duration_sec) * 1000), tb), + ) + for i in range(n) + ] + + +def _peak_memory_mb() -> float: + """Process peak RSS in MB. 
Linux reports KB; macOS reports bytes.""" + rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + divisor = 1024 * 1024 if sys.platform == "darwin" else 1024 + return rss / divisor + + +def run_benchmark( + video: Path, + shots: list[ShotRange] | None = None, + target_size: int = 384, +) -> dict[str, Any]: + """Time ``extract_all`` on ``video`` and return a result dict.""" + config = SamplingConfig(target_size=target_size) + with VideoDecoder.open(video, target_size=config.target_size) as decoder: + all_shots = shots if shots is not None else _uniform_shots(decoder.duration_sec) + t0 = time.perf_counter() + keyframes_per_shot = extract_all(all_shots, decoder, config) + wall = time.perf_counter() - t0 + duration = decoder.duration_sec + n_keyframes = sum(len(s) for s in keyframes_per_shot) + return { + "video": str(video), + "duration_sec": round(duration, 3), + "shots": len(all_shots), + "keyframes": n_keyframes, + "wall_sec": round(wall, 3), + "kf_per_sec": round(n_keyframes / wall, 1) if wall > 0 else None, + "memory_mb": round(_peak_memory_mb(), 1), + "target_size": target_size, + } + + +_RESULTS_HEADER = ( + "# findit-keyframe benchmarks\n" + "\n" + "Append-only log of `bench_e2e.py` runs. 
Each row is one run.\n" + "\n" + "| Date (UTC) | Git | Video | Duration (s) | Shots | Keyframes " + "| Wall (s) | KF/s | Mem (MB) | Target |\n" + "|------------|-----|-------|--------------|-------|-----------" + "|----------|------|----------|--------|\n" +) + + +def append_result(results_md: Path, result: dict[str, Any], git_sha: str) -> None: + if not results_md.exists(): + results_md.write_text(_RESULTS_HEADER) + timestamp = datetime.now(UTC).strftime("%Y-%m-%d %H:%M") + row = ( + f"| {timestamp} | `{git_sha}` | `{Path(result['video']).name}` " + f"| {result['duration_sec']} | {result['shots']} | {result['keyframes']} " + f"| {result['wall_sec']} | {result['kf_per_sec']} | {result['memory_mb']} " + f"| {result['target_size']} |\n" + ) + with results_md.open("a") as f: + f.write(row) + + +def main() -> int: + parser = argparse.ArgumentParser( + prog="bench_e2e", + description="End-to-end findit-keyframe extraction benchmark.", + ) + parser.add_argument("--video", type=Path, required=True, help="Source video file.") + parser.add_argument( + "--shots", + type=Path, + default=None, + help="Optional shot JSON; defaults to 4-second uniform shots.", + ) + parser.add_argument( + "--target-size", + type=int, + default=384, + help="Output frame edge length (default 384, matching SamplingConfig).", + ) + parser.add_argument( + "--results-md", + type=Path, + default=Path(__file__).parent / "results.md", + help="Markdown results file to append to (default: benchmarks/results.md).", + ) + parser.add_argument("--quiet", action="store_true", help="Suppress stdout JSON summary.") + args = parser.parse_args() + + shots = _parse_shot_json(args.shots) if args.shots is not None else None + result = run_benchmark(args.video, shots, target_size=args.target_size) + sha = _git_sha() + append_result(args.results_md, result, sha) + + if not args.quiet: + print(json.dumps({**result, "git_sha": sha}, indent=2)) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff 
--git a/docs/algorithm.md b/docs/algorithm.md new file mode 100644 index 0000000..cba5209 --- /dev/null +++ b/docs/algorithm.md @@ -0,0 +1,132 @@ +# Algorithm Specification + +> **Audience**: Implementers (Python authors, Rust translators) and reviewers. Language-agnostic. + +## 0. Overview + +Given a video and a list of shots (half-open `[start, end)` time ranges from `scenesdetect`), produce, per shot, a small ordered list of **keyframes** that: + +1. Are **temporally distributed** across the shot (no clustering at one timestamp). +2. Pass a **hard quality gate** (not black, not blown-out, not flat). +3. Maximise a **soft quality score** within their respective sub-windows. +4. Optionally weight a **saliency signal** (Apple Vision attention) when available. + +A keyframe carries: source shot id, timestamp (the actual decoded PTS), originating bin index, decoded RGB pixels, quality metrics, and a `Confidence` tag (`High` / `Low` / `Degraded`). + +## 1. Stratified Temporal Sampling + +For a shot of duration `D` seconds, target interval `I` (default 4.0 s), and per-shot cap `M` (default 16): + +``` +N = clamp(ceil(D / I), 1, M) +``` + +The shot is partitioned into `N` equal-duration **bins** over the symmetrically shrunken effective range `[start + s, end - s]` where `s = boundary_shrink_pct * D` (default 2 %). Shrinking the outer edges keeps the algorithm away from the ±1-frame uncertainty around scenesdetect cuts. + +## 2. Within-Bin Selection + +For each bin `[t0, t1)`: + +1. **Probe** `K = candidates_per_bin` (default 6) candidate timestamps at the **cell centres** of the bin: `t0 + (i + 0.5) * (t1 - t0) / K` for `i ∈ [0, K)`. Cell centres avoid sampling exactly on bin edges, where the upstream cut detector is least confident. +2. **Decode** each candidate frame through the decoder (strategy is its concern). +3. 
**Hard gate** each candidate via `QualityGate`: + - `mean_luma ∈ [15/255, 240/255]` (inclusive) + - `luma_variance ≥ 5` (sample variance, ddof = 1, on the raw 0–255 scale) +4. **Score** survivors with the composite: + ``` + score = 0.6 * rank(laplacian_var) + + 0.2 * rank(entropy) + + 0.2 * saliency_mass # 0.0 when no saliency provider is wired + ``` + `rank` is **stable ordinal rank** within the bin's surviving pool, scaled to `[0, 1]`. A single-survivor bin gets `rank = 1` for both metrics, so its score collapses to `0.8 + 0.2 * saliency_mass`. +5. **Pick** `argmax` → emit `ExtractedKeyframe(confidence=High)`. + +## 3. Fallback + +If a bin has zero survivors after step 2.3: + +1. **Expand** the search window symmetrically by `fallback_expand_pct * (t1 - t0)` (default 20 % of bin width) on each side, clamped to the shot's `[start.seconds, end.seconds]`. Re-probe `K` candidates in the expanded window, decode, gate, score. A surviving pick → `Confidence.Low`. +2. **Force-pick**: if the expanded window also yields no survivors, gather the union of the native-bin and expanded-bin candidates, score them all (gate ignored), and pick `argmax`. Emit `Confidence.Degraded`. + +The three-level confidence makes downstream filtering trivial: drop `Degraded` if you only want curated frames, or accept everything if you'd rather get something for every bin. + +## 4. Saliency Provider Contract + +A `SaliencyProvider` exposes a single method `compute(rgb: ndarray) -> float` returning a saliency mass in `[0.0, 1.0]`. Two implementations ship: + +- `NoopSaliencyProvider` — always returns `0.0`. Default everywhere. +- `AppleVisionSaliencyProvider` — wraps `VNGenerateAttentionBasedSaliencyImageRequest`. The scalar is `clamp(sum(area * confidence), 0, 1)` over the request's `salientObjects` bounding boxes; the heatmap `CVPixelBuffer` is intentionally *not* read because the bounding-box scalar carries the same "is anything attention-grabbing" signal with a much cleaner pyobjc API. 
The Rust port (`objc2-vision`) may pick either path. + +`default_saliency_provider()` returns Apple Vision on macOS when `pyobjc-framework-Vision` is installed, else `Noop`. + +## 5. Invariants + +- Bin count `N ≥ 1` for any non-zero-duration shot. +- Bins are disjoint and union-cover the (post-shrink) shot interior. +- Output keyframe count per shot is exactly `N`, monotonically non-decreasing in `bucket_index`. +- Every `ExtractedKeyframe.bucket_index` is in `[0, N)`. +- Saliency mass is in `[0.0, 1.0]`; quality fields are real-valued and finite. + +## 6. JSON Schemas + +### Input — shots + +```json +{ + "shots": [ + { + "id": 0, + "start_pts": 0, + "end_pts": 1000, + "timebase_num": 1, + "timebase_den": 1000 + } + ] +} +``` + +### Output — manifest + +```json +{ + "video": "path/to/video.mp4", + "keyframes": [ + { + "shot_id": 0, + "bucket": 0, + "file": "kf_000_000.jpg", + "timestamp_sec": 1.234, + "quality": { + "laplacian_var": 215.4, + "mean_luma": 0.41, + "luma_variance": 1820.7, + "entropy": 7.31, + "saliency_mass": 0.62 + }, + "confidence": "high" + } + ] +} +``` + +`timestamp_sec` is the actual decoded PTS, which may differ from the probe time by ±1 frame because the seek lands on the nearest preceding keyframe and decoding stops at the first frame whose PTS is at or after the target. + +## 7. Parameter Rationale + +| Parameter | Default | Rationale | +|-----------|---------|-----------| +| `target_interval_sec` | 4.0 | One frame per ~4 s gives a VLM enough temporal cadence to detect changes without saturating context. | +| `candidates_per_bin` | 6 | Empirically enough to find a sharp frame at 24-60 fps in any 2-10 s bin without runaway decode cost. | +| `max_frames_per_shot` | 16 | Hard cap to bound per-shot cost; long static shots still compress to 16. | +| `boundary_shrink_pct` | 0.02 | Avoids the ±1-frame uncertainty around a cut. | +| `fallback_expand_pct` | 0.20 | Symmetric expand into both neighbours covers most under-exposed openings. 
| +| `target_size` | 384 | Sweet spot for SigLIP 2 / Qwen3-VL inputs; resize is part of decode to amortise cost. | +| Quality gate `mean_luma` | `[15/255, 240/255]` | Reject black-frame openings and blown-out flashes; range matches BT.601 limited-range Y / 255. | +| Quality gate `luma_variance` | `≥ 5` | Reject pixel-flat frames (e.g. a single solid colour). | + +## 8. Known Limitations + +- **VFR (variable frame rate)**: PTS handling is correct, but quality scores compare across temporally-uneven samples. Not a defect for our shot lengths. +- **Decoder ±1 frame jitter**: requested timestamp may resolve to nearest I/P frame. +- **Shot-spanning duplicates**: identical anchor-shot keyframes across consecutive shots are not deduplicated. `TASKS.md` §7 ("Out-of-Scope but Noted for Later") tracks the MMR cross-bin deduplication path alongside four other deferred items (SigLIP medoid selection, cross-shot coherence, learned quality models, hardware decode); all are tagged P2+ for the Rust phase. +- **Sequential decode strategy**: `pick_strategy` returns `Sequential` for dense shot lists, but the current implementation always uses `PerShotSeek`. The Sequential optimisation lands in P3+ (Rust phase). diff --git a/docs/rust-porting.md b/docs/rust-porting.md new file mode 100644 index 0000000..87095a0 --- /dev/null +++ b/docs/rust-porting.md @@ -0,0 +1,116 @@ +# Rust Porting Guide + +> **Audience**: The teammate translating `findit-keyframe` to Rust. **Read this before reading any Python source.** If anything in this doc disagrees with the Python code, the Python code is wrong — file an issue. + +This guide intentionally re-derives the public surface so a Rust skeleton can be sketched without consulting the Python implementation. Algorithm details live in [`algorithm.md`](algorithm.md). + +## 1. 
Type Map + +| Python (`src/findit_keyframe/`) | Rust target | Notes | +|---------------------------------|-------------|-------| +| `types.Timebase(num: int, den: int)` | `pub struct Timebase { num: i32, den: NonZeroU32 }` | Mirror upstream `scenesdetect::frame::Timebase`. Reduced-form (gcd-based) equality and hash; `den > 0` invariant enforced at construction. | +| `types.Timestamp(pts: int, timebase: Timebase)` | `pub struct Timestamp { pts: i64, timebase: Timebase }` | Match upstream; `seconds() -> f64`. Comparison and hash via integer cross-multiplication so cross-timebase ordering is exact. | +| `types.ShotRange(start: Timestamp, end: Timestamp)` | `pub type ShotRange = scenesdetect::frame::TimeRange;` | Half-open. `end > start` enforced at construction. Reuse the upstream type if the Rust port pulls `scenesdetect` as a dep. | +| `types.SamplingConfig` | `#[derive(Clone, Copy)] pub struct SamplingConfig { ... }` | Plain `Copy` struct; `Default` impl mirrors Python defaults. Mutable in Python; in Rust use struct update syntax. | +| `types.Confidence` (`StrEnum`) | `#[non_exhaustive] pub enum Confidence { High, Low, Degraded }` | `value` is the lowercase string for manifest output; in Rust derive `Display` returning the same. | +| `types.QualityMetrics` | `#[derive(Clone, Copy, PartialEq)] pub struct QualityMetrics { ... }` | All fields `f32` (or `f64` for parity). No `Eq` / `Hash` because of float fields. | +| `types.ExtractedKeyframe` | `pub struct ExtractedKeyframe { ..., rgb: Vec<u8> }` | Mutable in Python; in Rust own the buffer. | +| `decoder.DecodedFrame` | `pub struct DecodedFrame<'a> { pts: Timestamp, width: u32, height: u32, rgb: Cow<'a, [u8]> }` | Borrowed when from decoder buffer; owned after copy. Width/height stored explicitly so `&[u8]` carries no shape info. 
| +| `decoder.Strategy` (`StrEnum`) | `pub enum Strategy { Sequential, PerShotSeek }` | | +| `quality.QualityGate` | `pub struct QualityGate { min_mean_luma: f32, max_mean_luma: f32, min_luma_variance: f32 }` | `Default` impl matches Python's `15/255`, `240/255`, `5.0`. `passes(&QualityMetrics) -> bool`. | +| `saliency.SaliencyProvider` (`Protocol`) | `pub trait SaliencyProvider { fn compute(&self, rgb: &RgbFrame) -> f32; }` | Object-safe trait. | +| `saliency.NoopSaliencyProvider` | `pub struct NoopSaliencyProvider;` | Returns `0.0`. | +| `saliency.AppleVisionSaliencyProvider` | `pub struct AppleVisionSaliencyProvider { ... }` (cfg `target_os = "macos"`) | Wraps `objc2-vision`'s `VNGenerateAttentionBasedSaliencyImageRequest`. | + +## 2. Dependency Map + +| Python | Rust | +|--------|------| +| `numpy` | Plain `Vec` / `&[u8]`. **Do not** use `ndarray`; upstream `scenesdetect` rejected it. SIMD via `std::simd` or hand-tuned `aarch64`/`x86`/`wasm32` modules behind `cfg`. | +| `av` (PyAV) | `ffmpeg-next` (or upstream `scenesdetect`'s decoder helpers if exposed). VideoToolbox hwaccel via `ffmpeg-sys` + `AVHWDeviceType::VideoToolbox`. | +| `pyobjc-framework-Vision` + `pyobjc-framework-Quartz` | `objc2-vision`, `objc2-quartz-core`, `objc2-core-foundation`. | +| `argparse` | `clap` (derive macros). | +| `pytest` | Cargo built-in `#[test]` + `proptest` for property tests. | +| `pytest-benchmark` | `criterion` (already used by `scenesdetect`). | +| `json` (stdlib) | `serde` + `serde_json`. | +| `dataclasses.replace` | Struct update syntax: `SamplingConfig { foo: new_foo, ..old }`. | +| PyAV `av.open(..., format="image2")` + `mjpeg` codec for JPEG output | `image::codecs::jpeg::JpegEncoder` with quality 92. | + +## 3. Idiom Map + +### Operator-level mappings + +| Python pattern | Rust pattern | +|----------------|--------------| +| `@dataclass(frozen=True)` | `#[derive(Clone, Copy, PartialEq, Eq, Hash)] pub struct ...` (drop `Eq`/`Hash` for floats). 
| | `Optional[T]` / `T \| None` | `Option<T>`. | +| `list[T]` | `Vec<T>` for owned, `&[T]` for borrowed. | +| `Iterator[(int, X)]` | `impl Iterator<Item = (usize, X)>`. | +| `Protocol` | `trait`. | +| `enum.StrEnum` | `enum` with `#[non_exhaustive]` for public; impl `Display` for the lowercase string. | +| `dataclasses.replace(c, x=v)` | `Config { x: v, ..c }`. | + +### Numpy ops we actually use + +| Python | Rust | +|--------|------| +| `rgb[..., 0].astype(np.uint32)` etc. (per-channel split for BT.601) | Manual `for i in (0..len).step_by(3)`; promote to `u32` via `as u32`. | +| `((66*r + 129*g + 25*b + 128) >> 8) + 16` | Same expression, `u32` arithmetic, `>> 8` then `+ 16`. | +| 3×3 Laplacian via slicing: `f[:-2,1:-1] + f[2:,1:-1] + f[1:-1,:-2] + f[1:-1,2:] - 4*f[1:-1,1:-1]` | Index loop over `(y, x)` for `y in 1..h-1`, `x in 1..w-1`; reads centre + 4 neighbours, accumulates `i32`. | +| `arr.var()` (population) | Two-pass: mean = `sum / n`, var = `sum((x-mean)^2) / n` (or Welford). | +| `arr.var(ddof=1)` (sample) | Same but divide by `n - 1`. | +| `np.histogram(luma, bins=256, range=(0, 256))` | `[u32; 256]` accumulator; one pass over pixels. (Upstream `scenesdetect::histogram` uses 4-wide parallel accumulators — borrow that pattern.) | +| Shannon entropy: `-sum(p * log2(p)) for p > 0` | Manual loop; `f64::log2`. | +| `np.argsort(kind="stable")` for ordinal rank | `Vec<(idx, value)>::sort_by(...)` (Rust's `sort` is stable). | +| `np.argmax(scores)` | `scores.iter().enumerate().max_by(...).map(|(i, _)| i)`. | +| `np.frombuffer(bytes, dtype=np.uint8).reshape(...)` | Just a slice view — no copy. | + +### Apple Vision idiom + +| Python | Rust (`objc2-vision`) | +|--------|----------------------| +| RGB → padded RGBX (32-bit aligned) → `CGImageCreate` | Same: `CGImageCreate` wants 32 bpp; pad to RGBA with `kCGImageAlphaNoneSkipLast`. 
| +| `VNImageRequestHandler.alloc().initWithCGImage_options_(cg, {})` | `VNImageRequestHandler::initWithCGImage_options:` | +| `VNGenerateAttentionBasedSaliencyImageRequest.alloc().init()` | `VNGenerateAttentionBasedSaliencyImageRequest::new()`. | +| `handler.performRequests_error_([req], None)` | `handler.performRequests(&[req])?` (returns `Result`). | +| `obs.salientObjects()` returns `[VNRectangleObservation]`; sum `area * confidence` | Same iteration, sum `f32`. | + +## 4. Test Fixture Map + +JSON / image fixtures under `tests/fixtures/` (where present) are the **shared ground truth** between implementations. Tests that build inputs procedurally (most of T2-T5) should be replayed in Rust with the same numeric expectations. + +| Python test | Rust test (target) | Notes | +|-------------|--------------------|-------| +| `tests/test_types.py` | `tests/types_*.rs` | Hand-derived numeric expectations; replay verbatim. | +| `tests/test_quality.py` | `tests/quality_*.rs` | All numeric assertions are derived in test docstrings; replay verbatim. | +| `tests/test_quality_golden.py` + `tests/fixtures/quality/canonical.json` | `tests/quality_golden.rs` | **Bit-precision contract.** Implement the six generators (solid / channel / h_gradient / v_gradient / checker / single_pixel — exact integer-arithmetic semantics in `_build_frame`'s docstring), load the JSON, assert each metric matches expected to 6 decimals (`tolerance = 1e-6`). | +| `tests/test_decoder.py` | `tests/decoder_*.rs` | Tiny test video encoded inside `conftest.py`; reproduce with the same encoder settings. | +| `tests/test_sampler.py` | `tests/sampler_*.rs` | Pure-function tests + integration on `varied_video` (per-frame mid-tone noise; reproduce). | +| `tests/test_cli.py` | `tests/cli_*.rs` | Manifest schema and exit codes are the contract. | +| `tests/test_saliency.py` | `tests/saliency_*.rs` | Apple Vision tests are macOS-gated by `cfg(target_os = "macos")`. 
| + +Tolerance: 6 decimal places for float comparisons unless documented otherwise. + +## 5. Python-Only Shortcuts (Tighten in Rust) + +- `np.var(ddof=1)`: numpy uses an O(n) two-pass internally; Rust should consider Welford for numerical stability on large `n`. +- `_ordinal_rank` does an `argsort` to get ranks. Rust can use `sort_by_cached_key` or compute ranks in-place during merge-sort for slightly less allocation. +- The Apple Vision provider holds Python-side strong refs to factory classes for hot-path lookup avoidance. Rust gets this for free via static linkage. +- `_select_with_fallback` re-decodes candidates in the expanded window without checking for overlap with the native probe set. Rust can dedup by PTS to skip a few decodes per fallback bin. + +## 6. Known Divergence Risk + +| Risk | Mitigation | +|------|------------| +| Numpy float reduction order vs Rust scalar order | Cross-validate at 6 decimals; document any operator that uses tree-reduction in numpy. | +| PyAV resize (swscale) vs Rust resize | Pin both to bilinear; document. | +| Apple Vision saliency mass: pyobjc returns confidence-weighted bounding boxes; `objc2-vision` returns the same observations but the iteration API differs slightly | Test on identical input; tolerate ±1 % saliency mass. | +| `>> 8` truncation in `rgb_to_luma` differs from float division | Both implementations must use integer fixed-point exactly per spec. Test `pure_green` round-trip (33023 >> 8 = 128 → Y = 144) catches this. | + +## 7. Translation Workflow + +1. Read this doc. +2. For each Python module, generate a Rust module skeleton with matching public types and stub function signatures. +3. Port one module at a time, in this order: `types` → `quality` → `decoder` → `sampler` → `saliency` → `cli`. +4. After each module, run the equivalent Rust tests; expectations come from the Python tests verbatim. +5. 
When a numerical assertion diverges, **first** check this doc's "Known Divergence Risk" table; **then** open an issue. diff --git a/examples/extract_basic.py b/examples/extract_basic.py new file mode 100644 index 0000000..a45126e --- /dev/null +++ b/examples/extract_basic.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 +"""End-to-end programmatic example for ``findit-keyframe``. + +Demonstrates the public Python API as an alternative to the +``findit-keyframe extract`` CLI. Run from the repository root:: + + python examples/extract_basic.py path/to/video.mp4 + +The script: + +1. Opens the video with :class:`VideoDecoder`. +2. Builds two synthetic back-to-back shots that cover the whole video. + In real use, shot boundaries come from + `scenesdetect `_'s output. +3. Runs :func:`extract_all` with default :class:`SamplingConfig` plus the + platform's preferred :class:`SaliencyProvider` (Apple Vision on macOS + when ``[macos]`` extras are installed; ``Noop`` elsewhere). +4. Prints the chosen keyframe's bin index, timestamp, confidence, and + Laplacian variance for each shot. + +Use this as a template for wiring ``findit-keyframe`` into a larger +pipeline that already has shot boundaries and wants in-process keyframe +extraction without going through the CLI. +""" + +from __future__ import annotations + +import sys +from pathlib import Path + +from findit_keyframe import ( + SamplingConfig, + ShotRange, + Timebase, + Timestamp, + VideoDecoder, + default_saliency_provider, + extract_all, +) + + +def _ts(seconds: float, timebase: Timebase) -> Timestamp: + """Build a Timestamp for ``seconds`` in the decoder's native timebase.""" + return Timestamp(round(seconds * timebase.den / timebase.num), timebase) + + +def _build_demo_shots(decoder: VideoDecoder) -> list[ShotRange]: + """Two back-to-back shots covering the whole video. 
+ + A real pipeline would replace this with the upstream shot list (for + example, by calling :func:`findit_keyframe.cli._parse_shot_json` on + a scenesdetect output file). + """ + duration = decoder.duration_sec + if duration <= 0.0: + raise ValueError(f"video reports unknown duration ({duration})") + midpoint = duration / 2.0 + return [ + ShotRange(start=_ts(0.0, decoder.timebase), end=_ts(midpoint, decoder.timebase)), + ShotRange(start=_ts(midpoint, decoder.timebase), end=_ts(duration, decoder.timebase)), + ] + + +def main() -> int: + if len(sys.argv) != 2: + print(f"usage: {sys.argv[0]} VIDEO_PATH", file=sys.stderr) + return 1 + + video_path = Path(sys.argv[1]) + if not video_path.is_file(): + print(f"error: {video_path} not found", file=sys.stderr) + return 1 + + config = SamplingConfig() + saliency = default_saliency_provider() + print(f"saliency provider: {type(saliency).__name__}", file=sys.stderr) + + with VideoDecoder.open(video_path, target_size=config.target_size) as decoder: + shots = _build_demo_shots(decoder) + keyframes_per_shot = extract_all(shots, decoder, config, saliency_provider=saliency) + + for shot_id, shot_keyframes in enumerate(keyframes_per_shot): + print(f"shot {shot_id}: {len(shot_keyframes)} keyframe(s)") + for kf in shot_keyframes: + print( + f" bin={kf.bucket_index} " + f"t={kf.timestamp.seconds:7.3f}s " + f"conf={kf.confidence.value:<8} " + f"laplacian={kf.quality.laplacian_var:9.1f} " + f"saliency={kf.quality.saliency_mass:.3f}" + ) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..32b4278 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,149 @@ +[build-system] +requires = ["hatchling>=1.24"] +build-backend = "hatchling.build" + +[project] +name = "findit-keyframe" +version = "0.0.0" +description = "Per-shot keyframe extraction with stratified temporal sampling. Python reference implementation; Rust translation target." 
+readme = "README.md" +requires-python = ">=3.11" +license = "Apache-2.0 OR MIT" +license-files = ["LICENSE-APACHE", "LICENSE-MIT"] +authors = [ + { name = "Findit-AI" }, +] +keywords = ["video", "keyframe", "scene-detection", "vlm", "embedding"] +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Topic :: Multimedia :: Video", + "Topic :: Scientific/Engineering :: Image Processing", + "Operating System :: OS Independent", +] +dependencies = [ + "av>=13.0,<14.0", + "numpy>=2.0,<3.0", +] + +[project.optional-dependencies] +macos = [ + "pyobjc-framework-Vision>=10.3; sys_platform == 'darwin'", + "pyobjc-framework-Quartz>=10.3; sys_platform == 'darwin'", +] +dev = [ + "pytest>=8.0", + "pytest-benchmark>=4.0", + "ruff>=0.6.0", + "mypy>=1.11", + "types-Pillow", +] + +[project.scripts] +findit-keyframe = "findit_keyframe.cli:main" + +[project.urls] +Repository = "https://github.com/Findit-AI/findit-keyframe" +Issues = "https://github.com/Findit-AI/findit-keyframe/issues" + +[tool.hatch.build.targets.wheel] +packages = ["src/findit_keyframe"] + +[tool.hatch.build.targets.sdist] +include = [ + "src/", + "tests/", + "docs/", + "README.md", + "LICENSE-APACHE", + "LICENSE-MIT", + "CHANGELOG.md", + "TASKS.md", +] + +[tool.ruff] +line-length = 100 +target-version = "py311" +src = ["src", "tests"] + +[tool.ruff.lint] +select = [ + "E", + "W", + "F", + "I", + "N", + "UP", + "B", + "C4", + "SIM", + "RUF", + "PT", + "TCH", + "PIE", + "PERF", + "FBT", + "ANN", + "S", +] +ignore = [ + "ANN401", + "S101", +] + +[tool.ruff.lint.per-file-ignores] +"tests/**/*.py" = ["ANN", "S", "FBT"] +"benchmarks/**/*.py" = ["ANN", "S", "FBT"] +"examples/**/*.py" = ["ANN", "S", "FBT"] + +[tool.ruff.lint.isort] +known-first-party = ["findit_keyframe"] +combine-as-imports = true + 
+[tool.ruff.format] +quote-style = "double" +indent-style = "space" +line-ending = "lf" + +[tool.pytest.ini_options] +minversion = "8.0" +testpaths = ["tests"] +addopts = [ + "-ra", + "--strict-markers", + "--strict-config", +] +markers = [ + "macos: tests requiring macOS (Apple Vision, etc.)", + "integration: end-to-end tests touching real video files", + "slow: long-running tests (benchmarks, full-video integration)", +] + +[tool.mypy] +python_version = "3.11" +strict = true +warn_unused_configs = true +warn_redundant_casts = true +warn_unused_ignores = true +warn_return_any = true +warn_unreachable = true +disallow_untyped_defs = true +disallow_incomplete_defs = true +check_untyped_defs = true +disallow_untyped_decorators = true +no_implicit_optional = true +strict_equality = true +extra_checks = true +files = ["src/findit_keyframe"] + +[[tool.mypy.overrides]] +module = ["av", "av.*"] +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = ["Vision", "Quartz", "objc", "Foundation"] +ignore_missing_imports = true diff --git a/src/findit_keyframe/__init__.py b/src/findit_keyframe/__init__.py new file mode 100644 index 0000000..1d434ba --- /dev/null +++ b/src/findit_keyframe/__init__.py @@ -0,0 +1,57 @@ +"""findit-keyframe: per-shot keyframe extraction with stratified temporal sampling. + +Public API surface re-exports the user-facing types and top-level functions. +Internal helpers live in their respective modules and are not part of the +stable API. 
+""" + +from __future__ import annotations + +from findit_keyframe.decoder import ( + DecodedFrame, + Strategy, + VideoDecoder, + pick_strategy, +) +from findit_keyframe.quality import QualityGate, compute_quality +from findit_keyframe.saliency import ( + AppleVisionSaliencyProvider, + NoopSaliencyProvider, + SaliencyProvider, + default_saliency_provider, +) +from findit_keyframe.sampler import extract_all, extract_for_shot +from findit_keyframe.types import ( + Confidence, + ExtractedKeyframe, + QualityMetrics, + SamplingConfig, + ShotRange, + Timebase, + Timestamp, +) + +__version__ = "0.0.0" + +__all__ = [ + "AppleVisionSaliencyProvider", + "Confidence", + "DecodedFrame", + "ExtractedKeyframe", + "NoopSaliencyProvider", + "QualityGate", + "QualityMetrics", + "SaliencyProvider", + "SamplingConfig", + "ShotRange", + "Strategy", + "Timebase", + "Timestamp", + "VideoDecoder", + "__version__", + "compute_quality", + "default_saliency_provider", + "extract_all", + "extract_for_shot", + "pick_strategy", +] diff --git a/src/findit_keyframe/cli.py b/src/findit_keyframe/cli.py new file mode 100644 index 0000000..da576f9 --- /dev/null +++ b/src/findit_keyframe/cli.py @@ -0,0 +1,308 @@ +"""argparse-based CLI for findit-keyframe. + +Subcommands: + +* ``extract VIDEO_PATH SHOTS_JSON OUTPUT_DIR`` — read the shot list, extract + one keyframe per bin per shot, write each keyframe as a baseline JPEG and + emit ``manifest.json`` describing all outputs. + +Exit codes (per ``TASKS.md`` §3 → Task 7): + +* ``0`` — success. +* ``1`` — input error (bad JSON, unknown config field, unsupported flag). +* ``2`` — extraction error (decoder open failure, JPEG write failure, etc.). 
+""" + +from __future__ import annotations + +import argparse +import json +import sys +from dataclasses import asdict +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import av +import numpy as np + +from findit_keyframe.decoder import VideoDecoder +from findit_keyframe.sampler import extract_all +from findit_keyframe.types import ( + ExtractedKeyframe, + SamplingConfig, + ShotRange, + Timebase, + Timestamp, +) + +if TYPE_CHECKING: + from findit_keyframe.saliency import SaliencyProvider + +__all__ = ["main"] + +EXIT_OK = 0 +EXIT_INPUT_ERROR = 1 +EXIT_EXTRACTION_ERROR = 2 + + +# --------------------------------------------------------------------------- # +# Parsing # +# --------------------------------------------------------------------------- # + + +def _parse_shot_json(path: Path) -> list[ShotRange]: + """Read a scenesdetect-compatible shot JSON file into a ``ShotRange`` list. + + Args: + path: Filesystem path to a JSON file with a top-level ``shots`` + array. Each entry must have keys ``start_pts``, ``end_pts``, + ``timebase_num``, ``timebase_den``. + + Returns: + A list of :class:`ShotRange`, in input order. + + Raises: + FileNotFoundError: If ``path`` doesn't exist. + json.JSONDecodeError: If the file is not valid JSON. + KeyError: If the top-level ``shots`` key or any required entry + field is missing. + ValueError: If any shot has ``end <= start`` (per :class:`ShotRange`). + """ + data: dict[str, Any] = json.loads(path.read_text()) + shots: list[ShotRange] = [] + for entry in data["shots"]: + timebase = Timebase(num=entry["timebase_num"], den=entry["timebase_den"]) + shots.append( + ShotRange( + start=Timestamp(pts=entry["start_pts"], timebase=timebase), + end=Timestamp(pts=entry["end_pts"], timebase=timebase), + ) + ) + return shots + + +def _parse_config_json(path: Path | None) -> SamplingConfig: + """Apply optional JSON overrides to a default :class:`SamplingConfig`. 
+ + Unknown fields raise :class:`ValueError` so typos surface immediately + rather than silently no-op-ing. + + Args: + path: Optional path to a JSON object whose keys are + :class:`SamplingConfig` field names. ``None`` returns defaults. + + Returns: + A fresh :class:`SamplingConfig` with overrides applied. + + Raises: + FileNotFoundError: If ``path`` is given but does not exist. + json.JSONDecodeError: If the file is not valid JSON. + ValueError: If the JSON contains a key that is not a + :class:`SamplingConfig` field. + """ + config = SamplingConfig() + if path is None: + return config + overrides: dict[str, Any] = json.loads(path.read_text()) + valid_fields = set(vars(config)) + for name, value in overrides.items(): + if name not in valid_fields: + raise ValueError(f"Unknown SamplingConfig field: {name!r}") + setattr(config, name, value) + return config + + +# --------------------------------------------------------------------------- # +# JPEG output # +# --------------------------------------------------------------------------- # + + +def _write_jpeg(path: Path, rgb_bytes: bytes, width: int, height: int) -> None: + """Encode a packed RGB24 byte buffer as a baseline JPEG via PyAV's mjpeg. + + ``yuvj420p`` (full-range YUV) is the standard JPEG sampling; the + ``image2`` muxer writes a single-frame JPEG file rather than an MJPEG + container. + + Args: + path: Output filesystem path. Parent directories must already exist. + rgb_bytes: Packed RGB24 bytes of length ``width * height * 3``. + width: Frame width in pixels. + height: Frame height in pixels. + + Raises: + av.error.FFmpegError: If FFmpeg cannot write the file (permissions, + disk full, unsupported pixel format, ...). + OSError: From the underlying file open. 
+ """ + rgb = np.frombuffer(rgb_bytes, dtype=np.uint8).reshape(height, width, 3) + container = av.open(str(path), mode="w", format="image2") + try: + stream: Any = container.add_stream("mjpeg") + stream.pix_fmt = "yuvj420p" + stream.width = width + stream.height = height + frame = av.VideoFrame.from_ndarray(rgb, format="rgb24") + for packet in stream.encode(frame): + container.mux(packet) + for packet in stream.encode(): + container.mux(packet) + finally: + container.close() + + +def _manifest_entry(kf: ExtractedKeyframe, filename: str) -> dict[str, Any]: + return { + "shot_id": kf.shot_id, + "bucket": kf.bucket_index, + "file": filename, + "timestamp_sec": kf.timestamp.seconds, + "quality": asdict(kf.quality), + "confidence": kf.confidence.value, + } + + +def _write_outputs( + video_path: Path, + output_dir: Path, + keyframes_per_shot: list[list[ExtractedKeyframe]], +) -> Path: + output_dir.mkdir(parents=True, exist_ok=True) + entries: list[dict[str, Any]] = [] + for shot_keyframes in keyframes_per_shot: + for kf in shot_keyframes: + filename = f"kf_{kf.shot_id:03d}_{kf.bucket_index:03d}.jpg" + _write_jpeg(output_dir / filename, kf.rgb, kf.width, kf.height) + entries.append(_manifest_entry(kf, filename)) + manifest_path = output_dir / "manifest.json" + manifest_path.write_text(json.dumps({"video": str(video_path), "keyframes": entries}, indent=2)) + return manifest_path + + +# --------------------------------------------------------------------------- # +# Argument parsing # +# --------------------------------------------------------------------------- # + + +class _Parser(argparse.ArgumentParser): + """Argparse subclass that exits with code 1 on parse errors (vs. 
argparse's default 2).""" + + def error(self, message: str) -> None: # type: ignore[override] + self.print_usage(sys.stderr) + self.exit(EXIT_INPUT_ERROR, f"error: {message}\n") + + +def _build_parser() -> argparse.ArgumentParser: + parser = _Parser( + prog="findit-keyframe", + description="Per-shot keyframe extraction. Consumes scenesdetect output, " + "writes one JPEG per (shot, bin) plus a manifest.", + ) + sub = parser.add_subparsers(dest="command", required=True) + + extract = sub.add_parser( + "extract", + help="Extract keyframes for a video given a shot list.", + ) + extract.add_argument("video", type=Path, help="Source video file.") + extract.add_argument("shots", type=Path, help="Shot list JSON (scenesdetect-compatible).") + extract.add_argument("output", type=Path, help="Output directory; created if missing.") + extract.add_argument( + "--config", + type=Path, + default=None, + help="Optional SamplingConfig JSON override file.", + ) + extract.add_argument( + "--saliency", + choices=["none", "apple"], + default="none", + help="Saliency provider. 'apple' uses Apple Vision (macOS, requires the .[macos] extra).", + ) + return parser + + +# --------------------------------------------------------------------------- # +# Command dispatch # +# --------------------------------------------------------------------------- # + + +def _build_saliency_provider(name: str) -> SaliencyProvider | None: + """Map a ``--saliency`` CLI choice to a provider instance. + + Args: + name: Choice from the ``--saliency`` flag. Currently ``"none"`` or + ``"apple"``; argparse ``choices`` enforces this upstream. + + Returns: + ``None`` for ``"none"`` (the sampler skips the saliency call + entirely); a fresh :class:`AppleVisionSaliencyProvider` for + ``"apple"``. + + Raises: + RuntimeError: If ``"apple"`` is requested off-Darwin or without + ``pyobjc-framework-Vision`` installed (propagated from the + provider's constructor). 
+ ValueError: If ``name`` is not one of the known choices (defensive + check; argparse normally prevents this). + """ + if name == "none": + return None + if name == "apple": + from findit_keyframe.saliency import AppleVisionSaliencyProvider + + return AppleVisionSaliencyProvider() + raise ValueError(f"unknown saliency provider: {name!r}") + + +def _extract_command(args: argparse.Namespace) -> int: + """Run the ``extract`` subcommand and return its exit code. + + Args: + args: Parsed argparse namespace from the ``extract`` subparser. + + Returns: + ``EXIT_OK`` on success, ``EXIT_INPUT_ERROR`` for bad JSON / unknown + config field / unsupported saliency, ``EXIT_EXTRACTION_ERROR`` for + decode/encode failures. + """ + try: + shots = _parse_shot_json(args.shots) + config = _parse_config_json(args.config) + saliency = _build_saliency_provider(args.saliency) + except (KeyError, ValueError, RuntimeError, json.JSONDecodeError, FileNotFoundError) as exc: + print(f"error: invalid input: {exc}", file=sys.stderr) + return EXIT_INPUT_ERROR + + try: + with VideoDecoder.open(args.video, target_size=config.target_size) as decoder: + keyframes = extract_all(shots, decoder, config, saliency_provider=saliency) + manifest_path = _write_outputs(args.video, args.output, keyframes) + except (av.error.FFmpegError, OSError, ValueError, RuntimeError) as exc: + # Programmer bugs (TypeError, AttributeError, etc.) deliberately + # propagate as crashes so they surface in tracebacks instead of + # being mapped to the extraction-error exit code. + print(f"error: extraction failed: {exc}", file=sys.stderr) + return EXIT_EXTRACTION_ERROR + + n_keyframes = sum(len(s) for s in keyframes) + # Status messages go to stderr so stdout stays available for any future + # machine-readable output (the manifest path is already written to disk + # and discoverable at args.output / "manifest.json"). 
+ print(f"wrote {n_keyframes} keyframes to {args.output}", file=sys.stderr) + print(f"manifest: {manifest_path}", file=sys.stderr) + return EXIT_OK + + +def main() -> int: + """Console-script entry point.""" + parser = _build_parser() + args = parser.parse_args() + if args.command == "extract": + return _extract_command(args) + parser.print_help() + return EXIT_INPUT_ERROR + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/findit_keyframe/decoder.py b/src/findit_keyframe/decoder.py new file mode 100644 index 0000000..15d0fca --- /dev/null +++ b/src/findit_keyframe/decoder.py @@ -0,0 +1,305 @@ +"""PyAV-backed video decoder with two strategies. + +* ``decode_at(time)`` — keyframe seek + forward decode to the target PTS. + Cheap when shots are sparse; pays a seek penalty per call. +* ``decode_sequential(shots)`` — single linear pass through the file, + yielding frames whose PTS lies inside any provided shot range. Cheap when + shots are dense; pays no seek penalty but reads the whole file once. + +``pick_strategy`` chooses between them based on shot density and count. +The thresholds are documented in ``TASKS.md`` §3 (density > 0.3 shots/s +or count > 200 → Sequential). + +The Rust port (``ffmpeg-next``) preserves the same public surface. +See ``docs/rust-porting.md`` §2. 
+""" + +from __future__ import annotations + +import enum +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Self + +import av + +from findit_keyframe.types import ShotRange, Timebase, Timestamp + +if TYPE_CHECKING: + from collections.abc import Iterator + from pathlib import Path + from types import TracebackType + + import numpy as np + import numpy.typing as npt + + +__all__ = [ + "DecodedFrame", + "Strategy", + "VideoDecoder", + "pick_strategy", +] + + +class Strategy(enum.StrEnum): + """Frame-fetch strategy for a video + shot list.""" + + Sequential = "sequential" + PerShotSeek = "per_shot_seek" + + +@dataclass(frozen=True, slots=False) +class DecodedFrame: + """A decoded RGB24 frame with its presentation timestamp. + + ``rgb`` has shape ``(height, width, 3)`` and dtype ``uint8``. ``width`` + and ``height`` are stored explicitly to mirror the Rust struct (where + ``rgb`` becomes a ``Vec`` carrying no shape info). + """ + + pts: Timestamp + width: int + height: int + rgb: npt.NDArray[np.uint8] = field(repr=False) + + def __post_init__(self) -> None: + expected = (self.height, self.width, 3) + if self.rgb.shape != expected: + raise ValueError( + f"rgb shape {self.rgb.shape} does not match (height, width, 3) = {expected}" + ) + + +def pick_strategy(shots: list[ShotRange], duration_sec: float) -> Strategy: + """Choose ``Sequential`` for dense or numerous shots, ``PerShotSeek`` otherwise. + + Empty shot lists and unknown durations (``duration_sec <= 0``) collapse to + ``PerShotSeek`` so callers don't have to special-case those paths. + + Args: + shots: Shot list whose density is being measured. + duration_sec: Total video duration in seconds. + + Returns: + ``Strategy.Sequential`` if ``len(shots) / duration_sec > 0.3`` or + ``len(shots) > 200`` (the count short-circuit catches very long + videos with many cuts where density alone would be misleading); + ``Strategy.PerShotSeek`` otherwise. 
+ """ + if not shots or duration_sec <= 0.0: + return Strategy.PerShotSeek + density = len(shots) / duration_sec + if density > 0.3 or len(shots) > 200: + return Strategy.Sequential + return Strategy.PerShotSeek + + +class VideoDecoder: + """PyAV-backed video decoder. + + Construct via ``VideoDecoder.open(path, target_size=...)``. Use as a + context manager (or call ``close()`` explicitly) to release the + underlying container. + + ``target_size`` (in pixels) controls the output frame's square edge + length; a value of ``0`` keeps the native resolution. + """ + + def __init__( + self, + container: Any, + stream: Any, + target_size: int = 0, + ) -> None: + self._container = container + self._stream = stream + self._target_size = target_size + + tb = stream.time_base + self._timebase = Timebase(num=tb.numerator, den=tb.denominator) + + if stream.duration is not None: + self._duration_sec = float(stream.duration * tb) + else: + self._duration_sec = 0.0 + + self._fps = float(stream.average_rate) if stream.average_rate else 0.0 + self._native_width = int(stream.codec_context.width) + self._native_height = int(stream.codec_context.height) + + @classmethod + def open(cls, path: Path, target_size: int = 0) -> Self: + """Open a video file for decoding. + + The container is held open until :meth:`close` is called or the + instance leaves a ``with`` block. If stream selection fails the + container is closed before re-raising. + + Args: + path: Filesystem path to a video container readable by FFmpeg. + target_size: Square output edge length in pixels for decoded + frames. Pass ``0`` (the default) to keep native resolution; + any positive value triggers ``swscale`` resize on every + decode. + + Returns: + A ready-to-use :class:`VideoDecoder` positioned at the start + of the first video stream. + + Raises: + av.error.FFmpegError: If FFmpeg cannot open or probe the file + (missing file, unknown format, corrupted header, ...). + IndexError: If the container has no video stream. 
+ """ + container = av.open(str(path)) + try: + stream = container.streams.video[0] + stream.thread_type = "AUTO" + return cls(container, stream, target_size=target_size) + except Exception: + container.close() + raise + + @property + def duration_sec(self) -> float: + """Stream duration in seconds, or ``0.0`` if FFmpeg reports unknown.""" + return self._duration_sec + + @property + def fps(self) -> float: + """Average frame rate in frames per second, or ``0.0`` when unknown.""" + return self._fps + + @property + def timebase(self) -> Timebase: + """The stream's PTS timebase, suitable for building :class:`Timestamp`.""" + return self._timebase + + @property + def width(self) -> int: + """Output frame width in pixels (after ``target_size`` resize, if any).""" + return self._target_size if self._target_size else self._native_width + + @property + def height(self) -> int: + """Output frame height in pixels (after ``target_size`` resize, if any).""" + return self._target_size if self._target_size else self._native_height + + def close(self) -> None: + """Release the underlying FFmpeg container. 
Idempotent in practice.""" + self._container.close() + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + self.close() + + # ------------------------------------------------------------------ # + # Frame conversion # + # ------------------------------------------------------------------ # + + def _to_decoded(self, frame: Any) -> DecodedFrame: + if self._target_size: + frame = frame.reformat( + width=self._target_size, + height=self._target_size, + format="rgb24", + ) + else: + frame = frame.reformat(format="rgb24") + rgb: npt.NDArray[np.uint8] = frame.to_ndarray() + return DecodedFrame( + pts=Timestamp(int(frame.pts), self._timebase), + width=int(rgb.shape[1]), + height=int(rgb.shape[0]), + rgb=rgb, + ) + + # ------------------------------------------------------------------ # + # decode_at # + # ------------------------------------------------------------------ # + + def decode_at(self, time_sec: float) -> DecodedFrame: + """Seek + decode the first frame whose PTS is at or after ``time_sec``. + + Implementation: seek backward to the keyframe at-or-before the target, + then decode forward until the first frame with ``pts >= target_pts``. + Resolves to within ±1 frame of the requested time. + + Args: + time_sec: Wall-clock seek target in seconds. May exceed + :attr:`duration_sec`; in that case no frame is found and + ``ValueError`` is raised. + + Returns: + The first decoded frame whose PTS is at or after the target, + already converted to packed RGB24 at the configured size. + + Raises: + ValueError: If decoding consumes the rest of the stream without + finding a frame at or after the target (typically because + ``time_sec`` is past end-of-stream). + av.error.FFmpegError: If the container raises a decode error. 
+ """ + target_pts = round(time_sec * self._timebase.den / self._timebase.num) + self._container.seek(target_pts, stream=self._stream, any_frame=False) + for frame in self._container.decode(self._stream): + if frame.pts is None: + continue + if frame.pts >= target_pts: + return self._to_decoded(frame) + raise ValueError(f"Could not decode any frame at or after {time_sec} s") + + # ------------------------------------------------------------------ # + # decode_sequential # + # ------------------------------------------------------------------ # + + def decode_sequential( + self, + shots: list[ShotRange], + ) -> Iterator[tuple[int, DecodedFrame]]: + """Single-pass scan; yield ``(shot_id, frame)`` for frames inside any shot. + + Internally sorts shots by start time so the cursor only moves + forward, but the yielded ``shot_id`` is the *original* index into + ``shots`` so callers can correlate with their unsorted input. + + Args: + shots: List of shots to cover. Empty input yields nothing. + Shots are assumed non-overlapping; behaviour with + overlapping ranges is unspecified. + + Yields: + ``(shot_id, frame)`` pairs in PTS order. Frames whose PTS + falls in the gap between consecutive shots are skipped. + + Raises: + av.error.FFmpegError: If the container raises a decode error + during the linear pass. 
+ """ + if not shots: + return + + sorted_shots = sorted(enumerate(shots), key=lambda item: item[1].start) + cursor = 0 + + self._container.seek(0, stream=self._stream) + for frame in self._container.decode(self._stream): + if frame.pts is None: + continue + ts = Timestamp(int(frame.pts), self._timebase) + + while cursor < len(sorted_shots) and sorted_shots[cursor][1].end <= ts: + cursor += 1 + if cursor >= len(sorted_shots): + return + + shot_id, shot = sorted_shots[cursor] + if shot.start <= ts < shot.end: + yield shot_id, self._to_decoded(frame) diff --git a/src/findit_keyframe/py.typed b/src/findit_keyframe/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/src/findit_keyframe/quality.py b/src/findit_keyframe/quality.py new file mode 100644 index 0000000..7c3251a --- /dev/null +++ b/src/findit_keyframe/quality.py @@ -0,0 +1,192 @@ +"""Per-frame quality metrics. Pure numpy, no OpenCV / scipy / skimage. + +Every function is shaped for a 1:1 Rust port: scalar arithmetic on slices, +no broadcasting magic, no in-place mutation across function boundaries. +See ``docs/rust-porting.md`` §3 for the numpy-to-ndarray idiom map. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +import numpy as np +import numpy.typing as npt + +from findit_keyframe.types import QualityMetrics + +__all__ = [ + "QualityGate", + "compute_quality", + "entropy", + "laplacian_variance", + "luma_variance", + "mean_luma", + "rgb_to_luma", +] + +LumaArray = npt.NDArray[np.uint8] +RgbArray = npt.NDArray[np.uint8] + + +def rgb_to_luma(rgb: RgbArray) -> LumaArray: + """Convert packed RGB24 to BT.601 limited-range luma (``Y ∈ [16, 235]``). + + Formula: ``((66*R + 129*G + 25*B + 128) >> 8) + 16``. The integer + fixed-point form is identical to scenesdetect's so a Y-plane round-trip + is bit-exact across crates. + + Args: + rgb: Array of shape ``(H, W, 3)`` and dtype ``uint8``. Channel + order is RGB. 
+ + Returns: + Array of shape ``(H, W)`` and dtype ``uint8`` carrying the luma plane. + + Raises: + ValueError: If ``rgb`` is not ``uint8`` or its shape is not + ``(H, W, 3)``. + """ + if rgb.dtype != np.uint8: + raise ValueError(f"rgb must be uint8, got {rgb.dtype}") + if rgb.ndim != 3 or rgb.shape[2] != 3: + raise ValueError(f"rgb must have shape (H, W, 3), got {rgb.shape}") + r = rgb[..., 0].astype(np.uint32) + g = rgb[..., 1].astype(np.uint32) + b = rgb[..., 2].astype(np.uint32) + y = ((66 * r + 129 * g + 25 * b + 128) >> 8) + 16 + return y.astype(np.uint8) + + +def laplacian_variance(luma: LumaArray) -> float: + """Variance of a 3x3 Laplacian-filtered luma image; a sharpness proxy. + + Kernel ``[[0, 1, 0], [1, -4, 1], [0, 1, 0]]``. The one-pixel border is + dropped (no padding). Variance uses the population denominator (``N``) + to match OpenCV's ``meanStdDev(Laplacian(...))``. + + Args: + luma: Single-plane luma array, dtype ``uint8``. + + Returns: + Population variance of the filtered interior pixels, as a Python + ``float``. Higher values mean sharper content. + + Raises: + ValueError: If ``luma`` is not 2-D or smaller than ``3x3``. + """ + if luma.ndim != 2: + raise ValueError(f"luma must be 2D, got shape {luma.shape}") + if luma.shape[0] < 3 or luma.shape[1] < 3: + raise ValueError(f"luma too small for 3x3 Laplacian: {luma.shape}") + f = luma.astype(np.int32) + out = f[:-2, 1:-1] + f[2:, 1:-1] + f[1:-1, :-2] + f[1:-1, 2:] - 4 * f[1:-1, 1:-1] + return float(out.var()) + + +def mean_luma(luma: LumaArray) -> float: + """Compute the arithmetic mean of luma values normalised to ``[0.0, 1.0]``. + + Args: + luma: Single-plane luma array on the 0-255 scale. + + Returns: + ``float(luma.mean()) / 255.0``. + """ + return float(luma.mean()) / 255.0 + + +def luma_variance(luma: LumaArray) -> float: + """Compute the sample variance (``ddof=1``) of luma on the raw 0-255 scale. + + Args: + luma: Single-plane luma array on the 0-255 scale. 
def entropy(luma: LumaArray, bins: int = 256) -> float:
    """Shannon entropy (bits) of a ``bins``-bin histogram of the luma plane.

    The histogram range is pinned to ``[0, 256)`` — one bin per integer
    level at the default ``bins == 256`` — so the result depends only on
    the input values.

    Args:
        luma: Single-plane luma array on the 0-255 scale.
        bins: Number of histogram bins. Defaults to ``256``.

    Returns:
        ``-Σ p_i * log2(p_i)`` over the non-zero probabilities. Range is
        ``[0, log2(bins)]``; delta distributions and empty input yield
        ``0.0``.
    """
    counts, _edges = np.histogram(luma, bins=bins, range=(0, 256))
    n_pixels = int(counts.sum())
    if n_pixels == 0:
        return 0.0
    probs = counts[counts > 0].astype(np.float64) / n_pixels
    if probs.size <= 1:
        # A single occupied bin is a delta distribution whose entropy is
        # exactly zero; bail out early so the result is a bit-clean +0.0
        # rather than the IEEE-754 -0.0 the naive formula would produce.
        return 0.0
    return float(-(probs * np.log2(probs)).sum())
def compute_quality(rgb: RgbArray, saliency: float | None = None) -> QualityMetrics:
    """Derive every per-frame quality signal from a packed RGB24 array.

    Args:
        rgb: Array of shape ``(H, W, 3)`` and dtype ``uint8``; channel
            order RGB.
        saliency: Optional saliency mass in ``[0.0, 1.0]`` for this frame.
            When ``None`` (the default) the metric is recorded as ``0.0``
            so frames extracted without a saliency provider remain
            comparable.

    Returns:
        A :class:`QualityMetrics` populated with all five signals.

    Raises:
        ValueError: Propagated from :func:`rgb_to_luma` when the input has
            the wrong shape or dtype.
    """
    luma = rgb_to_luma(rgb)
    mass = float(saliency) if saliency is not None else 0.0
    return QualityMetrics(
        laplacian_var=laplacian_variance(luma),
        mean_luma=mean_luma(luma),
        luma_variance=luma_variance(luma),
        entropy=entropy(luma),
        saliency_mass=mass,
    )
class NoopSaliencyProvider:
    """Saliency provider that reports ``0.0`` for every frame.

    Serves as the default provider and as the fallback on platforms
    without Apple Vision support.
    """

    def compute(self, rgb: npt.NDArray[np.uint8]) -> float:
        """Return a constant saliency mass of ``0.0`` regardless of input."""
        return 0.0
+ """ + if platform.system() != "Darwin": + raise RuntimeError("AppleVisionSaliencyProvider requires macOS") + try: + from Quartz import ( + CGColorSpaceCreateDeviceRGB, + CGDataProviderCreateWithData, + CGImageCreate, + kCGImageAlphaNoneSkipLast, + kCGRenderingIntentDefault, + ) + from Vision import ( + VNGenerateAttentionBasedSaliencyImageRequest, + VNImageRequestHandler, + ) + except ImportError as exc: + raise RuntimeError( + "AppleVisionSaliencyProvider requires pyobjc-framework-Vision; " + "install with: pip install -e '.[macos]'" + ) from exc + + # Hold strong references on self so each compute() call avoids module + # attribute lookups in the hot path. + self._CGImageCreate = CGImageCreate + self._CGDataProviderCreateWithData = CGDataProviderCreateWithData + self._CGColorSpaceCreateDeviceRGB = CGColorSpaceCreateDeviceRGB + self._kCGImageAlphaNoneSkipLast = kCGImageAlphaNoneSkipLast + self._kCGRenderingIntentDefault = kCGRenderingIntentDefault + self._VNRequest = VNGenerateAttentionBasedSaliencyImageRequest + self._VNHandler = VNImageRequestHandler + + def _rgb_to_cgimage(self, rgb: npt.NDArray[np.uint8]) -> Any: + """Wrap a packed RGB24 ndarray as a Core Graphics ``CGImage``. + + Quartz requires 32-bit aligned bitmap formats, so the buffer is + padded RGB → RGBX with the alpha channel marked as ignored + (``kCGImageAlphaNoneSkipLast``). + + Args: + rgb: Array of shape ``(H, W, 3)`` and dtype ``uint8``. + + Returns: + An opaque ``CGImage`` reference suitable for + ``VNImageRequestHandler.initWithCGImage_options_``. + + Raises: + ValueError: If ``rgb`` is not ``uint8`` or its shape is not + ``(H, W, 3)``. + """ + import numpy as np + + if rgb.dtype != np.uint8 or rgb.ndim != 3 or rgb.shape[2] != 3: + raise ValueError( + f"rgb must be uint8 (H, W, 3); got dtype={rgb.dtype}, shape={rgb.shape}" + ) + height, width = int(rgb.shape[0]), int(rgb.shape[1]) + # Quartz wants 32-bit aligned bitmaps; pad RGB -> RGBX with alpha ignored. 
+ rgbx = np.empty((height, width, 4), dtype=np.uint8) + rgbx[..., :3] = rgb + rgbx[..., 3] = 0 + data = bytes(rgbx) + provider = self._CGDataProviderCreateWithData(None, data, len(data), None) + color_space = self._CGColorSpaceCreateDeviceRGB() + return self._CGImageCreate( + width, + height, + 8, # bitsPerComponent + 32, # bitsPerPixel + width * 4, # bytesPerRow + color_space, + self._kCGImageAlphaNoneSkipLast, + provider, + None, # decode + False, # shouldInterpolate (positional per Objective-C signature) # noqa: FBT003 + self._kCGRenderingIntentDefault, + ) + + def compute(self, rgb: npt.NDArray[np.uint8]) -> float: + """Run an Apple Vision attention saliency request and return its scalar. + + Each ``VNRectangleObservation`` returned by the request carries a + normalised ``boundingBox`` (in ``[0, 1]`` image coordinates) and a + ``confidence`` in ``[0, 1]``. The provider sums ``area * + confidence`` over all such observations and clamps to ``[0, 1]``. + Vision request failure (rare) collapses to ``0.0`` rather than + raising, so a single bad frame doesn't abort a whole shot. + + Args: + rgb: Array of shape ``(H, W, 3)`` and dtype ``uint8``. + + Returns: + Saliency mass in ``[0.0, 1.0]``. Higher means more + attention-grabbing. + + Raises: + ValueError: If ``rgb`` has the wrong dtype or shape (propagated + from :meth:`_rgb_to_cgimage`). 
def default_saliency_provider() -> SaliencyProvider:
    """Pick the strongest saliency provider available on the current host.

    Returns:
        :class:`AppleVisionSaliencyProvider` on macOS when
        ``pyobjc-framework-Vision`` is importable;
        :class:`NoopSaliencyProvider` otherwise (including macOS hosts
        without the ``[macos]`` extra).
    """
    if platform.system() != "Darwin":
        return NoopSaliencyProvider()
    try:
        return AppleVisionSaliencyProvider()
    except RuntimeError:
        # pyobjc frameworks unavailable — degrade gracefully to the no-op.
        return NoopSaliencyProvider()
Surviving pick is tagged ``Confidence.Low``. +5. On expansion failure, force-pick the highest-score candidate even + though it failed the gate; emit ``Confidence.Degraded``. +""" + +from __future__ import annotations + +import math +from typing import TYPE_CHECKING + +import numpy as np + +from findit_keyframe.quality import QualityGate, compute_quality +from findit_keyframe.types import ( + Confidence, + ExtractedKeyframe, + QualityMetrics, + SamplingConfig, + ShotRange, +) + +if TYPE_CHECKING: + from findit_keyframe.decoder import DecodedFrame, VideoDecoder + from findit_keyframe.saliency import SaliencyProvider + + +__all__ = [ + "compute_bins", + "extract_all", + "extract_for_shot", + "score_bin_candidates", + "select_from_bin", +] + + +# --------------------------------------------------------------------------- # +# Binning # +# --------------------------------------------------------------------------- # + + +def compute_bins(shot: ShotRange, config: SamplingConfig) -> list[tuple[float, float]]: + """Partition ``shot`` into N equal bins after symmetric boundary shrinkage. + + ``N = clamp(ceil(D / target_interval_sec), 1, max_frames_per_shot)`` where + ``D = shot.duration_sec``. The shrunken effective range is + ``[start + s, end - s]`` with ``s = boundary_shrink_pct * D``; bins divide + this range evenly and are returned as ``(t0, t1)`` half-open intervals. + + Args: + shot: The shot to partition. ``shot.duration_sec > 0`` is required + by :class:`ShotRange`'s constructor; this function additionally + handles a defensive ``<= 0`` case by returning an empty list. + config: Sampling parameters; ``target_interval_sec``, + ``max_frames_per_shot``, and ``boundary_shrink_pct`` are read. + + Returns: + A list of ``(t0, t1)`` second-valued tuples, all of equal width + and contiguous (``bins[i].t1 == bins[i+1].t0``). 
+ """ + duration = shot.duration_sec + if duration <= 0.0: + return [] + n = max( + 1, + min(config.max_frames_per_shot, math.ceil(duration / config.target_interval_sec)), + ) + shrink = config.boundary_shrink_pct * duration + start = shot.start.seconds + shrink + end = shot.end.seconds - shrink + width = (end - start) / n + return [(start + i * width, start + (i + 1) * width) for i in range(n)] + + +# --------------------------------------------------------------------------- # +# Candidate sampling # +# --------------------------------------------------------------------------- # + + +def _candidate_times(t0: float, t1: float, k: int) -> list[float]: + """``k`` evenly-spaced cell-centre timestamps in ``[t0, t1]``. + + Cell centres (rather than endpoints) avoid sampling exactly on bin + boundaries, which is where the upstream cut detector is least confident. + """ + if k <= 0: + return [] + if k == 1: + return [(t0 + t1) / 2.0] + width = (t1 - t0) / k + return [t0 + (i + 0.5) * width for i in range(k)] + + +# --------------------------------------------------------------------------- # +# Scoring # +# --------------------------------------------------------------------------- # + + +def _ordinal_rank(values: list[float]) -> list[float]: + """Stable ordinal rank in ``[0, 1]`` (lowest -> 0, highest -> 1). + + Ties keep input order: the earlier-seen value gets the lower rank. This + is sufficient because ``argmax`` on the composite score is deterministic + either way and the algorithm does not depend on tie-broken averaging. + """ + n = len(values) + if n == 0: + return [] + if n == 1: + return [1.0] + order = np.asarray(values, dtype=np.float64).argsort(kind="stable") + ranks = np.empty(n, dtype=np.float64) + ranks[order] = np.arange(n) / (n - 1) + return [float(r) for r in ranks] + + +def score_bin_candidates(metrics_list: list[QualityMetrics]) -> list[float]: + """Compute the composite quality score for each candidate in a bin. 
def _compute_metrics(
    candidates: list[DecodedFrame],
    saliency_provider: SaliencyProvider | None,
) -> list[QualityMetrics]:
    # One QualityMetrics per candidate, in candidate order. The per-frame
    # saliency call is skipped entirely when no provider is configured.
    metrics: list[QualityMetrics] = []
    for candidate in candidates:
        if saliency_provider is None:
            metrics.append(compute_quality(candidate.rgb))
        else:
            mass = saliency_provider.compute(candidate.rgb)
            metrics.append(compute_quality(candidate.rgb, saliency=mass))
    return metrics
+ + Returns: + ``(frame, metrics, Confidence.High)`` for the winning candidate, or + ``None`` when every candidate fails the gate. The fallback path + wraps the confidence to ``Low`` or ``Degraded`` as appropriate. + """ + if not candidates: + return None + metrics = _compute_metrics(candidates, saliency_provider) + survivors = [(f, m) for f, m in zip(candidates, metrics, strict=True) if quality_gate.passes(m)] + if not survivors: + return None + surviving_metrics = [m for _, m in survivors] + scores = score_bin_candidates(surviving_metrics) + best = int(np.argmax(scores)) + frame, mtr = survivors[best] + return frame, mtr, Confidence.High + + +def _select_with_fallback( + bin_idx: int, + bins: list[tuple[float, float]], + shot: ShotRange, + decoder: VideoDecoder, + config: SamplingConfig, + quality_gate: QualityGate, + saliency_provider: SaliencyProvider | None = None, +) -> tuple[DecodedFrame, QualityMetrics, Confidence]: + """Try the native bin, then an expanded window, then force-pick. + + Always returns a frame; the caller never has to handle ``None``. + + Args: + bin_idx: Index of the bin to fill, in ``[0, len(bins))``. + bins: Output of :func:`compute_bins` for the same shot. + shot: The parent shot; used to clamp the expanded window. + decoder: Decoder for fetching candidate frames. + config: Sampling parameters (probe count, fallback expansion, ...). + quality_gate: Hard pass/fail predicate. + saliency_provider: Optional saliency contributor; ``None`` skips it. + + Returns: + ``(frame, metrics, confidence)`` for the bin's selected keyframe. + The confidence ladder is ``High`` (native pool) → ``Low`` (expanded + pool) → ``Degraded`` (force-picked from a gate-failing pool). 
+ """ + t0, t1 = bins[bin_idx] + bin_width = t1 - t0 + + native = [decoder.decode_at(t) for t in _candidate_times(t0, t1, config.candidates_per_bin)] + native_pick = select_from_bin(native, quality_gate, saliency_provider) + if native_pick is not None: + return native_pick + + expand = config.fallback_expand_pct * bin_width + et0 = max(shot.start.seconds, t0 - expand) + et1 = min(shot.end.seconds, t1 + expand) + expanded = [decoder.decode_at(t) for t in _candidate_times(et0, et1, config.candidates_per_bin)] + expanded_pick = select_from_bin(expanded, quality_gate, saliency_provider) + if expanded_pick is not None: + frame, metrics, _ = expanded_pick + return frame, metrics, Confidence.Low + + pool = native + expanded + metrics_pool = _compute_metrics(pool, saliency_provider) + scores = score_bin_candidates(metrics_pool) + best = int(np.argmax(scores)) + return pool[best], metrics_pool[best], Confidence.Degraded + + +# --------------------------------------------------------------------------- # +# Top-level entry points # +# --------------------------------------------------------------------------- # + + +def extract_for_shot( + shot: ShotRange, + shot_id: int, + decoder: VideoDecoder, + config: SamplingConfig, + quality_gate: QualityGate | None = None, + saliency_provider: SaliencyProvider | None = None, +) -> list[ExtractedKeyframe]: + """Extract one keyframe per bin for a single shot. + + Args: + shot: The shot to process. + shot_id: Identifier copied verbatim into every emitted + :class:`ExtractedKeyframe`. The caller chooses the numbering + scheme; :func:`extract_all` uses input-list index. + decoder: An open :class:`VideoDecoder` covering ``shot``. + config: Sampling parameters. + quality_gate: Optional override; defaults to :class:`QualityGate()`. + saliency_provider: Optional saliency contributor; ``None`` means + ``saliency_mass = 0``. 
+ + Returns: + Exactly ``len(compute_bins(shot, config))`` keyframes (one per bin), + each tagged with the bin's ``bucket_index`` and a + :class:`Confidence`. The ``rgb`` buffer is materialised as ``bytes`` + so the result is safely serialisable. + + Raises: + ValueError: Propagated from :meth:`VideoDecoder.decode_at` when a + probe falls past the end of the stream (typically a malformed + shot whose end exceeds the decoder's duration). + """ + bins = compute_bins(shot, config) + gate = quality_gate or QualityGate() + out: list[ExtractedKeyframe] = [] + for i in range(len(bins)): + frame, metrics, confidence = _select_with_fallback( + i, bins, shot, decoder, config, gate, saliency_provider + ) + out.append( + ExtractedKeyframe( + shot_id=shot_id, + timestamp=frame.pts, + bucket_index=i, + rgb=frame.rgb.tobytes(), + width=frame.width, + height=frame.height, + quality=metrics, + confidence=confidence, + ) + ) + return out + + +def extract_all( + shots: list[ShotRange], + decoder: VideoDecoder, + config: SamplingConfig, + quality_gate: QualityGate | None = None, + saliency_provider: SaliencyProvider | None = None, +) -> list[list[ExtractedKeyframe]]: + """Extract keyframes for every shot in ``shots``. + + Per-shot extraction goes through :func:`extract_for_shot`; this function + is only an iteration shell. The decoder's per-shot-seek path is the only + one currently implemented — see :class:`findit_keyframe.decoder.Strategy` + and :func:`findit_keyframe.decoder.pick_strategy` for the dense-shot + Sequential optimisation tracked for the Rust port. + + Args: + shots: Input shot list, in any order. The output preserves input order. + decoder: An open :class:`VideoDecoder` covering the same video. + config: Sampling parameters applied to every shot. + quality_gate: Optional override; defaults to :class:`QualityGate()`. + saliency_provider: Optional saliency contributor; ``None`` means + ``saliency_mass = 0`` for every keyframe. 
+ + Returns: + A list with one entry per input shot, each entry a list of + :class:`ExtractedKeyframe` (one per bin in that shot). + """ + return [ + extract_for_shot(shot, shot_id, decoder, config, quality_gate, saliency_provider) + for shot_id, shot in enumerate(shots) + ] diff --git a/src/findit_keyframe/types.py b/src/findit_keyframe/types.py new file mode 100644 index 0000000..a4a9b08 --- /dev/null +++ b/src/findit_keyframe/types.py @@ -0,0 +1,280 @@ +"""Core value types for findit-keyframe. + +These types are written for a 1:1 Rust translation. See ``docs/rust-porting.md`` +for the Python ↔ Rust field map. + +Design rules: + +* ``Timebase``, ``Timestamp``, ``ShotRange`` and ``QualityMetrics`` are + ``frozen=True`` — they are value types whose identity is their content. +* ``Timebase`` and ``Timestamp`` use *semantic* equality (1/2 == 2/4, + ``1000 @ 1/1000 == 90000 @ 1/90000``) to mirror the upstream + ``scenesdetect`` Rust crate. +* ``SamplingConfig`` and ``ExtractedKeyframe`` are intentionally mutable so + callers can tweak knobs and attach downstream metadata before serialising. +""" + +from __future__ import annotations + +import enum +import math +from dataclasses import dataclass + +__all__ = [ + "Confidence", + "ExtractedKeyframe", + "QualityMetrics", + "SamplingConfig", + "ShotRange", + "Timebase", + "Timestamp", +] + + +# --------------------------------------------------------------------------- # +# Timebase # +# --------------------------------------------------------------------------- # + + +@dataclass(frozen=True, eq=False, slots=False) +class Timebase: + """Rational timebase ``num / den`` measured in seconds-per-tick. + + Mirrors ``scenesdetect::frame::Timebase``: the denominator is strictly + positive (``> 0``) and equality is *value-based* — ``Timebase(1, 2)`` and + ``Timebase(2, 4)`` compare equal and hash identically. + + Args: + num: Numerator. May be zero (degenerate "always now" timebase) but + cannot be negative under normal use. 
+ den: Denominator. Must be strictly positive. + + Raises: + ValueError: If ``den <= 0``. + """ + + num: int + den: int + + def __post_init__(self) -> None: + if self.den <= 0: + raise ValueError(f"Timebase den must be > 0, got {self.den}") + + # Equality and hashing are reduced-form: we collapse via gcd so that + # 1/2, 2/4, 5/10 all hash and compare the same way. + def _reduced(self) -> tuple[int, int]: + g = math.gcd(abs(self.num), self.den) + if g == 0: + return (0, self.den) + return (self.num // g, self.den // g) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Timebase): + return NotImplemented + return self._reduced() == other._reduced() + + def __hash__(self) -> int: + return hash(self._reduced()) + + def __repr__(self) -> str: + return f"Timebase({self.num}/{self.den})" + + +# --------------------------------------------------------------------------- # +# Timestamp # +# --------------------------------------------------------------------------- # + + +@dataclass(frozen=True, eq=False, slots=False) +class Timestamp: + """A point in time at a given timebase. + + The wall-clock value in seconds is ``pts * timebase.num / timebase.den``. + Equality and ordering are *semantic* — comparisons across different + timebases use exact integer cross-multiplication, so + ``Timestamp(1000, 1/1000) == Timestamp(90000, 1/90000)``. 
+ """ + + pts: int + timebase: Timebase + + @property + def seconds(self) -> float: + """Wall-clock value as a 64-bit float.""" + return self.pts * self.timebase.num / self.timebase.den + + # ----- semantic comparison via cross-multiply (no float loss) ---------- # + + def _key(self) -> tuple[int, int]: + """Return ``(numerator, denominator)`` of ``pts * num / den`` reduced.""" + n = self.pts * self.timebase.num + d = self.timebase.den + g = math.gcd(abs(n), d) + if g == 0: + return (0, d) + return (n // g, d // g) + + def _cross(self, other: Timestamp) -> tuple[int, int]: + """Return ``(self_scaled, other_scaled)`` over a common denominator.""" + return ( + self.pts * self.timebase.num * other.timebase.den, + other.pts * other.timebase.num * self.timebase.den, + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Timestamp): + return NotImplemented + a, b = self._cross(other) + return a == b + + def __lt__(self, other: Timestamp) -> bool: + a, b = self._cross(other) + return a < b + + def __le__(self, other: Timestamp) -> bool: + a, b = self._cross(other) + return a <= b + + def __gt__(self, other: Timestamp) -> bool: + a, b = self._cross(other) + return a > b + + def __ge__(self, other: Timestamp) -> bool: + a, b = self._cross(other) + return a >= b + + def __hash__(self) -> int: + return hash(self._key()) + + def __repr__(self) -> str: + return f"Timestamp(pts={self.pts}, {self.timebase!r}, seconds={self.seconds:g})" + + +# --------------------------------------------------------------------------- # +# ShotRange # +# --------------------------------------------------------------------------- # + + +@dataclass(frozen=True, slots=False) +class ShotRange: + """Half-open shot interval ``[start, end)``. + + ``start`` and ``end`` may use different timebases — the + ``Timestamp`` semantic comparison handles the mixed case correctly. + + Raises: + ValueError: If ``end <= start``. 
class Confidence(enum.StrEnum):
    """Per-keyframe confidence tag, surfaced in the manifest output.

    ``StrEnum`` (Python 3.11+) makes each member a ``str`` subclass, so
    members compare equal to, and serialise as, their plain string values.

    Selection ladder:

    * ``High`` — selected from its native bin's quality-gated pool.
    * ``Low`` — selected from an expanded fallback window (adjacent bins).
    * ``Degraded`` — all candidates failed the hard gate; force-picked the
      best of a bad lot.
    """

    High = "high"
    Low = "low"
    Degraded = "degraded"
@dataclass(frozen=True, slots=False)
class QualityMetrics:
    """Per-frame quality signals computed by ``findit_keyframe.quality``.

    Frozen value type: identity is content, instances are hashable and safe
    to share across bins.

    All fields are floats:

    * ``laplacian_var`` — variance of a 3x3 Laplacian-filtered luma image,
      a sharpness/blur proxy. Higher is sharper.
    * ``mean_luma`` — mean luminance normalised to ``[0.0, 1.0]``.
    * ``luma_variance`` — sample variance of luma values (raw, on the 0-255
      integer scale before normalisation).
    * ``entropy`` — Shannon entropy in bits of the 256-bin luma histogram.
    * ``saliency_mass`` — Apple Vision attention score in ``[0.0, 1.0]``;
      ``0.0`` when no saliency provider is configured. The bundled
      :class:`findit_keyframe.saliency.AppleVisionSaliencyProvider` derives
      this as ``clamp(sum(area * confidence), 0, 1)`` over the request's
      ``salientObjects`` bounding boxes, *not* from the heatmap
      ``CVPixelBuffer`` — see that module's docstring for rationale.
    """

    laplacian_var: float
    mean_luma: float
    luma_variance: float
    entropy: float
    saliency_mass: float
+ """ + + shot_id: int + timestamp: Timestamp + bucket_index: int + rgb: bytes + width: int + height: int + quality: QualityMetrics + confidence: Confidence diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..0bb1b72 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,162 @@ +"""Shared pytest fixtures.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import av +import numpy as np +import pytest + +if TYPE_CHECKING: + from pathlib import Path + + +def _encode_ramp_video( + path: Path, + *, + n_frames: int, + fps: int, + width: int, + height: int, +) -> None: + """Encode a deterministic linear-gray-ramp video at ``path``. + + Frame ``i`` has nominal gray value ``round(i * 255 / (n_frames - 1))``. + libx264 + yuv420p is lossy enough to perturb individual pixel values, so + tests should not assert exact gray equality; the ramp is for ordinal + checks ("frame 15 is brighter than frame 0"). 
+ """ + container = av.open(str(path), mode="w") + try: + stream = container.add_stream("libx264", rate=fps) + stream.width = width + stream.height = height + stream.pix_fmt = "yuv420p" + denom = max(n_frames - 1, 1) + for i in range(n_frames): + gray = round(i * 255 / denom) + rgb = np.full((height, width, 3), gray, dtype=np.uint8) + frame = av.VideoFrame.from_ndarray(rgb, format="rgb24") + for packet in stream.encode(frame): + container.mux(packet) + for packet in stream.encode(): + container.mux(packet) + finally: + container.close() + + +@pytest.fixture(scope="session") +def tiny_video(tmp_path_factory: pytest.TempPathFactory) -> Path: + """A 1-second 30-fps 64x64 ramp video; 30 frames total.""" + path = tmp_path_factory.mktemp("videos") / "tiny.mp4" + _encode_ramp_video(path, n_frames=30, fps=30, width=64, height=64) + return path + + +def _encode_textured_video( + path: Path, + *, + n_frames: int, + fps: int, + width: int, + height: int, +) -> None: + """Encode a video where every frame is mid-tone deterministic noise. + + The luma variance survives the libx264/yuv420p roundtrip cleanly, so the + sampler tests can exercise the quality gate's *pass* path. Each frame's + seed is its index, giving frame-to-frame independence and per-frame + reproducibility. 
+ """ + container = av.open(str(path), mode="w") + try: + stream = container.add_stream("libx264", rate=fps) + stream.width = width + stream.height = height + stream.pix_fmt = "yuv420p" + stream.options = {"crf": "18", "preset": "ultrafast"} + for i in range(n_frames): + rng = np.random.default_rng(seed=i) + rgb = rng.integers(50, 201, size=(height, width, 3), dtype=np.uint8) + frame = av.VideoFrame.from_ndarray(rgb, format="rgb24") + for packet in stream.encode(frame): + container.mux(packet) + for packet in stream.encode(): + container.mux(packet) + finally: + container.close() + + +@pytest.fixture(scope="session") +def varied_video(tmp_path_factory: pytest.TempPathFactory) -> Path: + """A 1.5-second 30-fps 64x64 noise video; high luma variance per frame.""" + path = tmp_path_factory.mktemp("videos") / "varied.mp4" + _encode_textured_video(path, n_frames=45, fps=30, width=64, height=64) + return path + + +def _box_blur(rgb: np.ndarray, k: int) -> np.ndarray: + """Box-blur each channel of an ``(H, W, 3)`` RGB array with a ``k x k`` kernel. + + Edges are extended via ``mode='edge'`` padding. Pure numpy so the test + fixture does not need scipy. Output preserves dtype ``uint8``. + """ + pad = k // 2 + padded = np.pad(rgb, ((pad, pad), (pad, pad), (0, 0)), mode="edge").astype(np.uint32) + height, width = rgb.shape[:2] + accum = np.zeros((height, width, 3), dtype=np.uint32) + for dy in range(k): + for dx in range(k): + accum += padded[dy : dy + height, dx : dx + width] + return (accum // (k * k)).astype(np.uint8) + + +def _encode_quality_gradient_video( + path: Path, + *, + n_frames: int, + fps: int, + width: int, + height: int, + blur_kernel: int = 5, +) -> None: + """Encode a 3-thirds sharp / blur / sharp video for sampler quality tests. 
+ + Frame ``i`` is mid-tone deterministic noise (seed = ``i``); the middle + third (``[N/3, 2N/3)``) is additionally smoothed with a ``blur_kernel`` + box filter so it survives the libx264 round-trip with measurably lower + Laplacian variance than the sharp thirds. + """ + container = av.open(str(path), mode="w") + try: + stream = container.add_stream("libx264", rate=fps) + stream.width = width + stream.height = height + stream.pix_fmt = "yuv420p" + stream.options = {"crf": "18", "preset": "ultrafast"} + third = n_frames // 3 + for i in range(n_frames): + rng = np.random.default_rng(seed=i) + rgb = rng.integers(50, 201, size=(height, width, 3), dtype=np.uint8) + if third <= i < 2 * third: + rgb = _box_blur(rgb, blur_kernel) + frame = av.VideoFrame.from_ndarray(rgb, format="rgb24") + for packet in stream.encode(frame): + container.mux(packet) + for packet in stream.encode(): + container.mux(packet) + finally: + container.close() + + +@pytest.fixture(scope="session") +def quality_gradient_video(tmp_path_factory: pytest.TempPathFactory) -> Path: + """A 20-second 15-fps 96x96 sharp/blur/sharp video; 300 frames total. + + Used by sampler quality-gradient tests to verify the within-bin scorer + prefers sharp candidates over blurred ones in mixed-content bins. + """ + path = tmp_path_factory.mktemp("videos") / "quality_gradient.mp4" + _encode_quality_gradient_video(path, n_frames=300, fps=15, width=96, height=96) + return path diff --git a/tests/fixtures/quality/canonical.json b/tests/fixtures/quality/canonical.json new file mode 100644 index 0000000..42a1099 --- /dev/null +++ b/tests/fixtures/quality/canonical.json @@ -0,0 +1,219 @@ +{ + "version": 1, + "purpose": "Bit-precision regression contract for compute_quality. 
Generator semantics live in tests/test_quality_golden.py::_build_frame and the Rust port replays this fixture against its own implementation.", + "tolerance": 1e-06, + "frames": [ + { + "id": "solid_black", + "spec": { + "kind": "solid", + "rgb": [ + 0, + 0, + 0 + ], + "size": [ + 32, + 32 + ] + }, + "expected": { + "laplacian_var": 0.0, + "mean_luma": 0.062745, + "luma_variance": 0.0, + "entropy": 0.0, + "saliency_mass": 0.0 + } + }, + { + "id": "solid_white", + "spec": { + "kind": "solid", + "rgb": [ + 255, + 255, + 255 + ], + "size": [ + 32, + 32 + ] + }, + "expected": { + "laplacian_var": 0.0, + "mean_luma": 0.921569, + "luma_variance": 0.0, + "entropy": 0.0, + "saliency_mass": 0.0 + } + }, + { + "id": "solid_gray_128", + "spec": { + "kind": "solid", + "rgb": [ + 128, + 128, + 128 + ], + "size": [ + 32, + 32 + ] + }, + "expected": { + "laplacian_var": 0.0, + "mean_luma": 0.494118, + "luma_variance": 0.0, + "entropy": 0.0, + "saliency_mass": 0.0 + } + }, + { + "id": "channel_red", + "spec": { + "kind": "channel", + "channel": "r", + "value": 255, + "size": [ + 32, + 32 + ] + }, + "expected": { + "laplacian_var": 0.0, + "mean_luma": 0.321569, + "luma_variance": 0.0, + "entropy": 0.0, + "saliency_mass": 0.0 + } + }, + { + "id": "channel_green", + "spec": { + "kind": "channel", + "channel": "g", + "value": 255, + "size": [ + 32, + 32 + ] + }, + "expected": { + "laplacian_var": 0.0, + "mean_luma": 0.564706, + "luma_variance": 0.0, + "entropy": 0.0, + "saliency_mass": 0.0 + } + }, + { + "id": "channel_blue", + "spec": { + "kind": "channel", + "channel": "b", + "value": 255, + "size": [ + 32, + 32 + ] + }, + "expected": { + "laplacian_var": 0.0, + "mean_luma": 0.160784, + "luma_variance": 0.0, + "entropy": 0.0, + "saliency_mass": 0.0 + } + }, + { + "id": "h_gradient_0_255_64w", + "spec": { + "kind": "h_gradient", + "low": 0, + "high": 255, + "size": [ + 32, + 64 + ] + }, + "expected": { + "laplacian_var": 0.82232, + "mean_luma": 0.490809, + "luma_variance": 
4122.300928, + "entropy": 6.0, + "saliency_mass": 0.0 + } + }, + { + "id": "v_gradient_0_255_64h", + "spec": { + "kind": "v_gradient", + "low": 0, + "high": 255, + "size": [ + 64, + 32 + ] + }, + "expected": { + "laplacian_var": 0.82232, + "mean_luma": 0.490809, + "luma_variance": 4122.300928, + "entropy": 6.0, + "saliency_mass": 0.0 + } + }, + { + "id": "checker_8px_0_255", + "spec": { + "kind": "checker", + "cell": 8, + "low": 0, + "high": 255, + "size": [ + 32, + 32 + ] + }, + "expected": { + "laplacian_var": 23021.28, + "mean_luma": 0.492157, + "luma_variance": 12001.970674, + "entropy": 1.0, + "saliency_mass": 0.0 + } + }, + { + "id": "single_pixel_100_in_5x5", + "spec": { + "kind": "single_pixel", + "bg": [ + 0, + 0, + 0 + ], + "fg": [ + 100, + 100, + 100 + ], + "position": [ + 2, + 2 + ], + "size": [ + 5, + 5 + ] + }, + "expected": { + "laplacian_var": 16435.555556, + "mean_luma": 0.076235, + "luma_variance": 295.84, + "entropy": 0.242292, + "saliency_mass": 0.0 + } + } + ] +} diff --git a/tests/fixtures/regression/extract_all_synthetic.json b/tests/fixtures/regression/extract_all_synthetic.json new file mode 100644 index 0000000..a195266 --- /dev/null +++ b/tests/fixtures/regression/extract_all_synthetic.json @@ -0,0 +1,79 @@ +{ + "version": 1, + "purpose": "Deterministic regression snapshot for extract_all over a synthetic in-memory decoder. 
Stand-in for the Kino Demo regression fixture in TASKS.md T5; replace when the real asset becomes available.", + "input": { + "decoder": "in-memory _FakeDecoder, 5.0 s @ 30 fps, 64x64 noise frames", + "shots": [ + "[0.0, 0.4) s", + "[0.5, 4.5) s" + ], + "config": "SamplingConfig(target_interval_sec=1.0, target_size=64)" + }, + "entries": [ + { + "shot_id": 0, + "bucket": 0, + "timestamp_sec": 0.033333, + "confidence": "high", + "quality": { + "laplacian_var": 13288.401914, + "mean_luma": 0.485635, + "luma_variance": 633.283287, + "entropy": 6.63957, + "saliency_mass": 0.0 + } + }, + { + "shot_id": 1, + "bucket": 0, + "timestamp_sec": 1.3, + "confidence": "high", + "quality": { + "laplacian_var": 13002.335297, + "mean_luma": 0.481751, + "luma_variance": 650.178854, + "entropy": 6.654378, + "saliency_mass": 0.0 + } + }, + { + "shot_id": 1, + "bucket": 1, + "timestamp_sec": 1.766667, + "confidence": "high", + "quality": { + "laplacian_var": 12850.368781, + "mean_luma": 0.484072, + "luma_variance": 639.367643, + "entropy": 6.645861, + "saliency_mass": 0.0 + } + }, + { + "shot_id": 1, + "bucket": 2, + "timestamp_sec": 3.066667, + "confidence": "high", + "quality": { + "laplacian_var": 13034.782737, + "mean_luma": 0.484341, + "luma_variance": 643.519364, + "entropy": 6.641683, + "saliency_mass": 0.0 + } + }, + { + "shot_id": 1, + "bucket": 3, + "timestamp_sec": 3.866667, + "confidence": "high", + "quality": { + "laplacian_var": 12932.11908, + "mean_luma": 0.485531, + "luma_variance": 627.445081, + "entropy": 6.625787, + "saliency_mass": 0.0 + } + } + ] +} diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py new file mode 100644 index 0000000..839dd39 --- /dev/null +++ b/tests/test_benchmark.py @@ -0,0 +1,74 @@ +"""Smoke test for ``benchmarks/bench_e2e.py``. + +The script is meant to be run as a CLI; we exercise it via ``subprocess`` to +verify it boots, writes a row to ``results.md``, and exits 0 on a tiny +fixture video. 
Numerical thresholds belong in the script's output, not in +this test — we only check structure. +""" + +from __future__ import annotations + +import json +import subprocess +import sys +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pathlib import Path + + +_BENCH_SCRIPT = "benchmarks/bench_e2e.py" + + +def test_bench_runs_and_writes_results_md(tmp_path: Path, varied_video: Path): + results_md = tmp_path / "results.md" + proc = subprocess.run( + [ + sys.executable, + _BENCH_SCRIPT, + "--video", + str(varied_video), + "--results-md", + str(results_md), + "--target-size", + "64", + ], + capture_output=True, + text=True, + check=False, + ) + assert proc.returncode == 0, f"stderr: {proc.stderr}" + + summary = json.loads(proc.stdout) + assert summary["video"] == str(varied_video) + assert summary["wall_sec"] > 0 + assert summary["keyframes"] >= 1 + assert summary["target_size"] == 64 + + assert results_md.is_file() + content = results_md.read_text() + assert "findit-keyframe benchmarks" in content + assert varied_video.name in content + + +def test_bench_quiet_suppresses_stdout(tmp_path: Path, varied_video: Path): + results_md = tmp_path / "results.md" + proc = subprocess.run( + [ + sys.executable, + _BENCH_SCRIPT, + "--video", + str(varied_video), + "--results-md", + str(results_md), + "--target-size", + "64", + "--quiet", + ], + capture_output=True, + text=True, + check=False, + ) + assert proc.returncode == 0, f"stderr: {proc.stderr}" + assert proc.stdout.strip() == "" + assert results_md.is_file() # row still appended diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..00603f4 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,369 @@ +"""Tests for ``findit_keyframe.cli``. + +Covers: + +* JSON parsers for shot lists and config overrides. +* The ``extract`` subcommand end-to-end on the ``varied_video`` fixture + (manifest contents, JPEG validity via PyAV round-trip). 
+* Exit codes per ``TASKS.md`` §7: 0 success, 1 input error, 2 extraction error. +""" + +from __future__ import annotations + +import json +import platform +import sys +from typing import TYPE_CHECKING + +import av +import pytest + +from findit_keyframe.cli import _parse_config_json, _parse_shot_json, main + +if TYPE_CHECKING: + from pathlib import Path + +_IS_DARWIN = platform.system() == "Darwin" + + +def _write_shots_json(path: Path, shots: list[dict]) -> None: + path.write_text(json.dumps({"shots": shots})) + + +# --------------------------------------------------------------------------- # +# _parse_shot_json # +# --------------------------------------------------------------------------- # + + +class TestParseShotJson: + def test_basic(self, tmp_path: Path): + p = tmp_path / "shots.json" + _write_shots_json( + p, + [ + {"id": 0, "start_pts": 0, "end_pts": 1000, "timebase_num": 1, "timebase_den": 1000}, + { + "id": 1, + "start_pts": 1000, + "end_pts": 2000, + "timebase_num": 1, + "timebase_den": 1000, + }, + ], + ) + shots = _parse_shot_json(p) + assert len(shots) == 2 + assert shots[0].start.seconds == 0.0 + assert shots[0].end.seconds == 1.0 + assert shots[1].start.seconds == 1.0 + + def test_supports_video_timebase(self, tmp_path: Path): + p = tmp_path / "shots.json" + _write_shots_json( + p, + [{"id": 0, "start_pts": 0, "end_pts": 90000, "timebase_num": 1, "timebase_den": 90000}], + ) + shots = _parse_shot_json(p) + assert shots[0].end.seconds == 1.0 + + def test_missing_shots_key_raises(self, tmp_path: Path): + p = tmp_path / "shots.json" + p.write_text(json.dumps({"other": []})) + with pytest.raises(KeyError): + _parse_shot_json(p) + + def test_invalid_shot_range_raises(self, tmp_path: Path): + # end <= start violates ShotRange invariant. 
+ p = tmp_path / "shots.json" + _write_shots_json( + p, + [{"id": 0, "start_pts": 1000, "end_pts": 500, "timebase_num": 1, "timebase_den": 1000}], + ) + with pytest.raises(ValueError, match="end"): + _parse_shot_json(p) + + +# --------------------------------------------------------------------------- # +# _parse_config_json # +# --------------------------------------------------------------------------- # + + +class TestParseConfigJson: + def test_none_returns_defaults(self): + c = _parse_config_json(None) + assert c.target_size == 384 + assert c.candidates_per_bin == 6 + + def test_overrides_applied(self, tmp_path: Path): + p = tmp_path / "cfg.json" + p.write_text(json.dumps({"target_size": 256, "candidates_per_bin": 8})) + c = _parse_config_json(p) + assert c.target_size == 256 + assert c.candidates_per_bin == 8 + # Unspecified fields keep defaults. + assert c.max_frames_per_shot == 16 + + def test_unknown_field_raises(self, tmp_path: Path): + p = tmp_path / "cfg.json" + p.write_text(json.dumps({"foo": 1})) + with pytest.raises(ValueError, match="Unknown"): + _parse_config_json(p) + + +# --------------------------------------------------------------------------- # +# main: --help # +# --------------------------------------------------------------------------- # + + +class TestCliHelp: + def test_help_exits_zero( + self, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str] + ): + monkeypatch.setattr(sys, "argv", ["findit-keyframe", "--help"]) + with pytest.raises(SystemExit) as excinfo: + main() + assert excinfo.value.code == 0 + out = capsys.readouterr().out + assert "extract" in out + + +# --------------------------------------------------------------------------- # +# main: extract — end-to-end # +# --------------------------------------------------------------------------- # + + +class TestCliExtract: + @staticmethod + def _setup(tmp_path: Path) -> tuple[Path, Path]: + shots_path = tmp_path / "shots.json" + _write_shots_json( + shots_path, + [ 
+ {"id": 0, "start_pts": 0, "end_pts": 500, "timebase_num": 1, "timebase_den": 1000}, + { + "id": 1, + "start_pts": 600, + "end_pts": 1200, + "timebase_num": 1, + "timebase_den": 1000, + }, + ], + ) + return shots_path, tmp_path / "out" + + def test_writes_jpegs_and_manifest( + self, tmp_path: Path, varied_video: Path, monkeypatch: pytest.MonkeyPatch + ): + shots_path, out_dir = self._setup(tmp_path) + monkeypatch.setattr( + sys, + "argv", + [ + "findit-keyframe", + "extract", + str(varied_video), + str(shots_path), + str(out_dir), + ], + ) + rc = main() + assert rc == 0 + + manifest_path = out_dir / "manifest.json" + assert manifest_path.is_file() + manifest = json.loads(manifest_path.read_text()) + assert manifest["video"] == str(varied_video) + assert len(manifest["keyframes"]) >= 2 + + for entry in manifest["keyframes"]: + assert { + "shot_id", + "bucket", + "file", + "timestamp_sec", + "quality", + "confidence", + } <= entry.keys() + jpeg = out_dir / entry["file"] + assert jpeg.is_file() + assert jpeg.stat().st_size > 100 + + def test_manifest_quality_fields_present( + self, tmp_path: Path, varied_video: Path, monkeypatch: pytest.MonkeyPatch + ): + shots_path, out_dir = self._setup(tmp_path) + monkeypatch.setattr( + sys, + "argv", + ["findit-keyframe", "extract", str(varied_video), str(shots_path), str(out_dir)], + ) + assert main() == 0 + manifest = json.loads((out_dir / "manifest.json").read_text()) + q = manifest["keyframes"][0]["quality"] + assert { + "laplacian_var", + "mean_luma", + "luma_variance", + "entropy", + "saliency_mass", + } <= q.keys() + + def test_jpeg_round_trips( + self, tmp_path: Path, varied_video: Path, monkeypatch: pytest.MonkeyPatch + ): + shots_path, out_dir = self._setup(tmp_path) + monkeypatch.setattr( + sys, + "argv", + ["findit-keyframe", "extract", str(varied_video), str(shots_path), str(out_dir)], + ) + assert main() == 0 + manifest = json.loads((out_dir / "manifest.json").read_text()) + # Decode the first JPEG back through 
PyAV; succeeds iff the file is valid. + jpeg = out_dir / manifest["keyframes"][0]["file"] + with av.open(str(jpeg)) as container: + frame = next(container.decode(video=0)) + assert frame.width > 0 + assert frame.height > 0 + + def test_filename_pattern_uses_shot_and_bucket( + self, tmp_path: Path, varied_video: Path, monkeypatch: pytest.MonkeyPatch + ): + shots_path, out_dir = self._setup(tmp_path) + monkeypatch.setattr( + sys, + "argv", + ["findit-keyframe", "extract", str(varied_video), str(shots_path), str(out_dir)], + ) + assert main() == 0 + manifest = json.loads((out_dir / "manifest.json").read_text()) + for entry in manifest["keyframes"]: + expected = f"kf_{entry['shot_id']:03d}_{entry['bucket']:03d}.jpg" + assert entry["file"] == expected + + +# --------------------------------------------------------------------------- # +# Exit codes # +# --------------------------------------------------------------------------- # + + +class TestExitCodes: + def test_bad_shots_json_returns_input_error( + self, tmp_path: Path, varied_video: Path, monkeypatch: pytest.MonkeyPatch + ): + bad = tmp_path / "bad.json" + bad.write_text("not json") + monkeypatch.setattr( + sys, + "argv", + ["findit-keyframe", "extract", str(varied_video), str(bad), str(tmp_path / "out")], + ) + assert main() == 1 + + def test_unknown_config_field_returns_input_error( + self, tmp_path: Path, varied_video: Path, monkeypatch: pytest.MonkeyPatch + ): + shots = tmp_path / "shots.json" + _write_shots_json( + shots, + [{"id": 0, "start_pts": 0, "end_pts": 1000, "timebase_num": 1, "timebase_den": 1000}], + ) + cfg = tmp_path / "cfg.json" + cfg.write_text(json.dumps({"unknown_field": 42})) + monkeypatch.setattr( + sys, + "argv", + [ + "findit-keyframe", + "extract", + str(varied_video), + str(shots), + str(tmp_path / "out"), + "--config", + str(cfg), + ], + ) + assert main() == 1 + + def test_missing_video_returns_extraction_error( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ): + shots = 
tmp_path / "shots.json" + _write_shots_json( + shots, + [{"id": 0, "start_pts": 0, "end_pts": 1000, "timebase_num": 1, "timebase_den": 1000}], + ) + monkeypatch.setattr( + sys, + "argv", + [ + "findit-keyframe", + "extract", + str(tmp_path / "nope.mp4"), + str(shots), + str(tmp_path / "out"), + ], + ) + assert main() == 2 + + @pytest.mark.skipif(_IS_DARWIN, reason="Apple Vision is available on macOS") + def test_saliency_apple_off_macos_returns_input_error( + self, tmp_path: Path, varied_video: Path, monkeypatch: pytest.MonkeyPatch + ): + shots = tmp_path / "shots.json" + _write_shots_json( + shots, + [{"id": 0, "start_pts": 0, "end_pts": 1000, "timebase_num": 1, "timebase_den": 1000}], + ) + monkeypatch.setattr( + sys, + "argv", + [ + "findit-keyframe", + "extract", + str(varied_video), + str(shots), + str(tmp_path / "out"), + "--saliency", + "apple", + ], + ) + # AppleVisionSaliencyProvider.__init__ raises RuntimeError off-Darwin; + # the CLI maps that to the input-error exit code. + assert main() == 1 + + +@pytest.mark.skipif(not _IS_DARWIN, reason="Apple Vision is macOS-only") +@pytest.mark.macos +class TestCliSaliencyApple: + def test_extract_with_apple_saliency_succeeds( + self, tmp_path: Path, varied_video: Path, monkeypatch: pytest.MonkeyPatch + ): + shots = tmp_path / "shots.json" + _write_shots_json( + shots, + [ + {"id": 0, "start_pts": 0, "end_pts": 500, "timebase_num": 1, "timebase_den": 1000}, + ], + ) + out_dir = tmp_path / "out" + monkeypatch.setattr( + sys, + "argv", + [ + "findit-keyframe", + "extract", + str(varied_video), + str(shots), + str(out_dir), + "--saliency", + "apple", + ], + ) + rc = main() + assert rc == 0 + + manifest = json.loads((out_dir / "manifest.json").read_text()) + # Saliency is in [0, 1]; non-negative is the meaningful contract. 
+ for entry in manifest["keyframes"]: + assert 0.0 <= entry["quality"]["saliency_mass"] <= 1.0 diff --git a/tests/test_decoder.py b/tests/test_decoder.py new file mode 100644 index 0000000..d0cfcce --- /dev/null +++ b/tests/test_decoder.py @@ -0,0 +1,191 @@ +"""Tests for ``findit_keyframe.decoder``. + +The pure-logic pieces (``Strategy`` / ``pick_strategy``) get bare numeric +tests. The PyAV-backed pieces use a session-scoped fixture (``tiny_video``) +encoding a 1-second 30-fps ramp so we exercise real seek/decode paths. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +import pytest + +from findit_keyframe.decoder import ( + DecodedFrame, + Strategy, + VideoDecoder, + pick_strategy, +) +from findit_keyframe.types import ShotRange, Timebase, Timestamp + +if TYPE_CHECKING: + from pathlib import Path + + +def _ts(seconds: float, tb: Timebase) -> Timestamp: + """Build a Timestamp for ``seconds`` in ``tb``.""" + return Timestamp(round(seconds * tb.den / tb.num), tb) + + +def _make_back_to_back_shots(n: int, duration: float) -> list[ShotRange]: + tb = Timebase(1, 1000) + span = duration / n + return [ + ShotRange( + start=Timestamp(round(i * span * 1000), tb), + end=Timestamp(round((i + 1) * span * 1000), tb), + ) + for i in range(n) + ] + + +# --------------------------------------------------------------------------- # +# pick_strategy # +# --------------------------------------------------------------------------- # + + +class TestPickStrategy: + def test_empty_shots_returns_per_shot_seek(self): + assert pick_strategy([], 60.0) is Strategy.PerShotSeek + + def test_zero_duration_returns_per_shot_seek(self): + assert pick_strategy(_make_back_to_back_shots(10, 60.0), 0.0) is Strategy.PerShotSeek + + def test_low_density_returns_per_shot_seek(self): + # 10 shots / 60 s = 0.166 shots/s, well below 0.3 threshold. 
+ assert pick_strategy(_make_back_to_back_shots(10, 60.0), 60.0) is Strategy.PerShotSeek + + def test_high_density_returns_sequential(self): + # 30 shots / 60 s = 0.5 shots/s, above threshold. + assert pick_strategy(_make_back_to_back_shots(30, 60.0), 60.0) is Strategy.Sequential + + def test_huge_count_returns_sequential(self): + # 250 shots / 1000 s = 0.25 shots/s (below density), but count > 200. + assert pick_strategy(_make_back_to_back_shots(250, 1000.0), 1000.0) is Strategy.Sequential + + def test_density_threshold_is_strict(self): + # Threshold is `> 0.3`, exclusive. Exactly 0.3 stays at PerShotSeek. + assert pick_strategy(_make_back_to_back_shots(30, 100.0), 100.0) is Strategy.PerShotSeek + + +# --------------------------------------------------------------------------- # +# VideoDecoder — open & metadata # +# --------------------------------------------------------------------------- # + + +class TestVideoDecoderOpen: + def test_metadata(self, tiny_video: Path): + with VideoDecoder.open(tiny_video) as dec: + assert dec.fps == pytest.approx(30.0, abs=0.5) + assert dec.duration_sec == pytest.approx(1.0, abs=0.1) + assert dec.width == 64 + assert dec.height == 64 + + def test_target_size_resizes(self, tiny_video: Path): + with VideoDecoder.open(tiny_video, target_size=32) as dec: + assert dec.width == 32 + assert dec.height == 32 + f = dec.decode_at(0.0) + assert f.rgb.shape == (32, 32, 3) + assert f.rgb.dtype == np.uint8 + + def test_close_releases_container(self, tiny_video: Path): + dec = VideoDecoder.open(tiny_video) + dec.close() + # Subsequent decode must fail because the container is closed. 
+ with pytest.raises(Exception): # noqa: B017, PT011 — PyAV raises various + dec.decode_at(0.0) + + +# --------------------------------------------------------------------------- # +# VideoDecoder — decode_at # +# --------------------------------------------------------------------------- # + + +class TestDecodeAt: + def test_first_frame_returns_decoded_frame(self, tiny_video: Path): + with VideoDecoder.open(tiny_video) as dec: + f = dec.decode_at(0.0) + assert isinstance(f, DecodedFrame) + assert f.rgb.shape == (64, 64, 3) + assert f.rgb.dtype == np.uint8 + + def test_first_frame_pts_near_zero(self, tiny_video: Path): + with VideoDecoder.open(tiny_video) as dec: + f = dec.decode_at(0.0) + assert f.pts.seconds < 1.0 / 30.0 + + def test_mid_frame_pts_within_one_frame(self, tiny_video: Path): + with VideoDecoder.open(tiny_video) as dec: + target = 0.5 + f = dec.decode_at(target) + assert abs(f.pts.seconds - target) < 1.0 / 30.0 + 1e-3 + + def test_decoded_frame_dimensions_match_target_size(self, tiny_video: Path): + with VideoDecoder.open(tiny_video, target_size=32) as dec: + f = dec.decode_at(0.5) + assert f.width == 32 + assert f.height == 32 + assert f.rgb.shape == (32, 32, 3) + + def test_seek_past_end_raises(self, tiny_video: Path): + with VideoDecoder.open(tiny_video) as dec, pytest.raises(ValueError, match="Could not"): + dec.decode_at(10.0) + + +# --------------------------------------------------------------------------- # +# VideoDecoder — decode_sequential # +# --------------------------------------------------------------------------- # + + +class TestDecodeSequential: + def test_empty_shots_yields_nothing(self, tiny_video: Path): + with VideoDecoder.open(tiny_video) as dec: + assert list(dec.decode_sequential([])) == [] + + def test_full_coverage_yields_every_frame(self, tiny_video: Path): + with VideoDecoder.open(tiny_video) as dec: + shots = [ShotRange(start=_ts(0.0, dec.timebase), end=_ts(2.0, dec.timebase))] + frames = 
list(dec.decode_sequential(shots)) + # Allow ±2 frames slack for VFR-ish PTS quantisation at the boundary. + assert 28 <= len(frames) <= 32 + assert all(shot_id == 0 for shot_id, _ in frames) + # Frames are emitted in PTS order. + ptses = [f.pts.seconds for _, f in frames] + assert ptses == sorted(ptses) + + def test_disjoint_shots_yield_subsets_with_correct_ids(self, tiny_video: Path): + with VideoDecoder.open(tiny_video) as dec: + tb = dec.timebase + shots = [ + ShotRange(start=_ts(0.0, tb), end=_ts(0.3, tb)), # ~9 frames + ShotRange(start=_ts(0.6, tb), end=_ts(0.8, tb)), # ~6 frames + ] + frames = list(dec.decode_sequential(shots)) + ids = {shot_id for shot_id, _ in frames} + assert ids == {0, 1} + # No frame should fall in the [0.3, 0.6) gap. + for shot_id, f in frames: + t = f.pts.seconds + if shot_id == 0: + assert 0.0 <= t < 0.3 + 1e-3 + else: + assert 0.6 - 1e-3 <= t < 0.8 + 1e-3 + + def test_unsorted_shot_input_handled(self, tiny_video: Path): + # Original shot indices must be preserved when shots are passed out of order. + with VideoDecoder.open(tiny_video) as dec: + tb = dec.timebase + shots = [ + ShotRange(start=_ts(0.6, tb), end=_ts(0.8, tb)), # id 0, later in time + ShotRange(start=_ts(0.0, tb), end=_ts(0.3, tb)), # id 1, earlier + ] + frames = list(dec.decode_sequential(shots)) + # Frames at t < 0.3 must be tagged with shot id 1 (the earlier one). + for shot_id, f in frames: + if f.pts.seconds < 0.3: + assert shot_id == 1 + elif 0.6 <= f.pts.seconds < 0.8: + assert shot_id == 0 diff --git a/tests/test_quality.py b/tests/test_quality.py new file mode 100644 index 0000000..63d925d --- /dev/null +++ b/tests/test_quality.py @@ -0,0 +1,293 @@ +"""Tests for ``findit_keyframe.quality``. + +Every numeric expectation is computed by hand from the spec in +``TASKS.md`` §4 and ``docs/algorithm.md`` so the Rust port can replay the +same assertions against the same inputs. 
+""" + +from __future__ import annotations + +import math +import time + +import numpy as np +import pytest + +from findit_keyframe.quality import ( + QualityGate, + compute_quality, + entropy, + laplacian_variance, + luma_variance, + mean_luma, + rgb_to_luma, +) +from findit_keyframe.types import QualityMetrics + + +def _solid_rgb(value: int, h: int = 32, w: int = 32) -> np.ndarray: + return np.full((h, w, 3), value, dtype=np.uint8) + + +# --------------------------------------------------------------------------- # +# rgb_to_luma # +# --------------------------------------------------------------------------- # + + +class TestRgbToLuma: + def test_rejects_non_uint8(self): + with pytest.raises(ValueError, match="uint8"): + rgb_to_luma(np.zeros((4, 4, 3), dtype=np.float32)) + + def test_rejects_2d(self): + with pytest.raises(ValueError, match="shape"): + rgb_to_luma(np.zeros((4, 4), dtype=np.uint8)) + + def test_rejects_4_channel(self): + with pytest.raises(ValueError, match="shape"): + rgb_to_luma(np.zeros((4, 4, 4), dtype=np.uint8)) + + def test_pure_black_yields_16(self): + # BT.601 limited range: black -> Y = ((0+0+0+128) >> 8) + 16 = 16. + y = rgb_to_luma(_solid_rgb(0)) + assert y.dtype == np.uint8 + assert y.shape == (32, 32) + assert (y == 16).all() + + def test_pure_white_yields_235(self): + # Y = ((66+129+25)*255 + 128) >> 8 + 16 = 56228 >> 8 + 16 = 219 + 16 = 235. + y = rgb_to_luma(_solid_rgb(255)) + assert (y == 235).all() + + def test_pure_red(self): + # Y = ((66*255 + 128) >> 8) + 16 = 16958 >> 8 + 16 = 66 + 16 = 82. + rgb = np.zeros((4, 4, 3), dtype=np.uint8) + rgb[..., 0] = 255 + assert (rgb_to_luma(rgb) == 82).all() + + def test_pure_green(self): + # Y = ((129*255 + 128) >> 8) + 16 = 33023 >> 8 + 16 = 128 + 16 = 144. + # (33023 = 128 * 256 + 255, so the shift truncates to 128.) 
+ rgb = np.zeros((4, 4, 3), dtype=np.uint8) + rgb[..., 1] = 255 + assert (rgb_to_luma(rgb) == 144).all() + + def test_pure_blue(self): + # Y = ((25*255 + 128) >> 8) + 16 = 6503 >> 8 + 16 = 25 + 16 = 41. + rgb = np.zeros((4, 4, 3), dtype=np.uint8) + rgb[..., 2] = 255 + assert (rgb_to_luma(rgb) == 41).all() + + +# --------------------------------------------------------------------------- # +# laplacian_variance # +# --------------------------------------------------------------------------- # + + +class TestLaplacianVariance: + def test_rejects_non_2d(self): + with pytest.raises(ValueError, match="2D"): + laplacian_variance(np.zeros((4, 4, 3), dtype=np.uint8)) + + def test_rejects_too_small(self): + with pytest.raises(ValueError, match="3x3"): + laplacian_variance(np.zeros((2, 2), dtype=np.uint8)) + + def test_uniform_image_zero(self): + luma = np.full((16, 16), 128, dtype=np.uint8) + assert laplacian_variance(luma) == 0.0 + + def test_random_noise_is_high(self): + rng = np.random.default_rng(seed=42) + luma = rng.integers(0, 256, size=(64, 64), dtype=np.uint8) + assert laplacian_variance(luma) > 1000.0 + + def test_smooth_gradient_is_low(self): + # Linear horizontal ramp: the discrete second derivative is exactly 0 + # in the interior. Laplacian variance therefore collapses to ~0. + luma = np.tile(np.arange(64, dtype=np.uint8), (64, 1)) + assert laplacian_variance(luma) < 1.0 + + def test_isolated_spike_known(self): + # Single bright pixel at the centre of a 5x5 zero field. + # Filtered values across the 3x3 interior: + # corners = 0, edges = +100, centre = -400 + # mean = 0, variance (ddof=0) = (4*10000 + 160000) / 9 = 200000/9. 
+ luma = np.zeros((5, 5), dtype=np.uint8) + luma[2, 2] = 100 + assert laplacian_variance(luma) == pytest.approx(200000.0 / 9.0) + + def test_returns_finite(self): + rng = np.random.default_rng(seed=0) + luma = rng.integers(0, 256, size=(16, 16), dtype=np.uint8) + assert math.isfinite(laplacian_variance(luma)) + + +# --------------------------------------------------------------------------- # +# mean_luma # +# --------------------------------------------------------------------------- # + + +class TestMeanLuma: + def test_uniform_zero(self): + assert mean_luma(np.zeros((4, 4), dtype=np.uint8)) == 0.0 + + def test_uniform_255(self): + assert mean_luma(np.full((4, 4), 255, dtype=np.uint8)) == 1.0 + + def test_uniform_128(self): + assert mean_luma(np.full((4, 4), 128, dtype=np.uint8)) == pytest.approx(128 / 255) + + +# --------------------------------------------------------------------------- # +# luma_variance # +# --------------------------------------------------------------------------- # + + +class TestLumaVariance: + def test_uniform_zero(self): + assert luma_variance(np.full((8, 8), 128, dtype=np.uint8)) == 0.0 + + def test_two_value_known(self): + # Sample variance (ddof=1) of [0, 255, 0, 255]: + # mean = 127.5 + # sum((x - mean)^2) = 4 * 127.5^2 = 65025 + # variance = 65025 / (4 - 1) = 21675 + luma = np.array([[0, 255], [0, 255]], dtype=np.uint8) + assert luma_variance(luma) == pytest.approx(21675.0) + + +# --------------------------------------------------------------------------- # +# entropy # +# --------------------------------------------------------------------------- # + + +class TestEntropy: + def test_uniform_distribution_max(self): + # Each value 0..255 appears once -> H = log2(256) = 8. + luma = np.arange(256, dtype=np.uint8).reshape(16, 16) + assert entropy(luma) == pytest.approx(8.0) + + def test_constant_zero(self): + # Delta distribution -> H = 0. 
+ luma = np.full((16, 16), 100, dtype=np.uint8) + assert entropy(luma) == 0.0 + + def test_two_value_one_bit(self): + # Half 0, half 255 with equal counts -> H = 1. + luma = np.array([[0] * 8, [255] * 8] * 8, dtype=np.uint8) + assert entropy(luma) == pytest.approx(1.0) + + def test_custom_bins(self): + # 4 bins, uniform across them -> H = log2(4) = 2. + luma = np.array([0, 64, 128, 192] * 4, dtype=np.uint8).reshape(4, 4) + assert entropy(luma, bins=4) == pytest.approx(2.0) + + +# --------------------------------------------------------------------------- # +# QualityGate # +# --------------------------------------------------------------------------- # + + +class TestQualityGate: + def _metrics(self, **overrides): + defaults = { + "laplacian_var": 100.0, + "mean_luma": 0.5, + "luma_variance": 1000.0, + "entropy": 7.0, + "saliency_mass": 0.0, + } + return QualityMetrics(**(defaults | overrides)) + + def test_default_thresholds_match_spec(self): + gate = QualityGate() + assert gate.min_mean_luma == pytest.approx(15.0 / 255.0) + assert gate.max_mean_luma == pytest.approx(240.0 / 255.0) + assert gate.min_luma_variance == 5.0 + + def test_normal_frame_passes(self): + assert QualityGate().passes(self._metrics()) is True + + def test_too_dark_rejected(self): + assert QualityGate().passes(self._metrics(mean_luma=0.05)) is False + + def test_too_bright_rejected(self): + assert QualityGate().passes(self._metrics(mean_luma=0.99)) is False + + def test_flat_rejected(self): + assert QualityGate().passes(self._metrics(luma_variance=4.99)) is False + + def test_lower_mean_boundary_inclusive(self): + assert QualityGate().passes(self._metrics(mean_luma=15.0 / 255.0)) is True + + def test_upper_mean_boundary_inclusive(self): + assert QualityGate().passes(self._metrics(mean_luma=240.0 / 255.0)) is True + + def test_variance_boundary_inclusive(self): + assert QualityGate().passes(self._metrics(luma_variance=5.0)) is True + + +# 
--------------------------------------------------------------------------- # +# compute_quality # +# --------------------------------------------------------------------------- # + + +class TestComputeQuality: + def test_returns_quality_metrics(self): + rgb = np.full((16, 16, 3), 128, dtype=np.uint8) + assert isinstance(compute_quality(rgb), QualityMetrics) + + def test_default_saliency_is_zero(self): + rgb = np.full((16, 16, 3), 128, dtype=np.uint8) + assert compute_quality(rgb).saliency_mass == 0.0 + + def test_saliency_passed_through(self): + rgb = np.full((16, 16, 3), 128, dtype=np.uint8) + assert compute_quality(rgb, saliency=0.42).saliency_mass == pytest.approx(0.42) + + def test_black_frame_metrics(self): + # All-black RGB -> Y = 16 everywhere; everything spreading to zero. + rgb = np.zeros((32, 32, 3), dtype=np.uint8) + m = compute_quality(rgb) + assert m.mean_luma == pytest.approx(16 / 255) + assert m.luma_variance == 0.0 + assert m.laplacian_var == 0.0 + assert m.entropy == 0.0 + # Gate rejects on luma_variance < 5. + assert QualityGate().passes(m) is False + + def test_random_noise_passes_gate(self): + rng = np.random.default_rng(seed=0) + rgb = rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8) + m = compute_quality(rgb) + assert m.luma_variance > 5.0 + assert m.laplacian_var > 100.0 + assert QualityGate().passes(m) is True + + def test_smooth_gradient_passes_gate(self): + ramp = np.linspace(0, 255, 64, dtype=np.uint8) + rgb = np.stack([np.tile(ramp, (64, 1))] * 3, axis=-1) + m = compute_quality(rgb) + assert m.luma_variance > 100.0 + assert m.laplacian_var < 5.0 + assert QualityGate().passes(m) is True + + +# --------------------------------------------------------------------------- # +# Performance # +# --------------------------------------------------------------------------- # + + +@pytest.mark.slow +def test_compute_quality_performance_budget(): + """TASKS.md T4: <5 ms on M-series Mac. 
CI runners can be ~3x slower.""" + rgb = np.random.default_rng(0).integers(0, 256, size=(384, 384, 3), dtype=np.uint8) + for _ in range(3): + compute_quality(rgb) + n = 20 + t0 = time.perf_counter() + for _ in range(n): + compute_quality(rgb) + avg_ms = (time.perf_counter() - t0) / n * 1000 + assert avg_ms < 15.0, f"compute_quality avg {avg_ms:.2f} ms exceeds 15 ms budget" diff --git a/tests/test_quality_golden.py b/tests/test_quality_golden.py new file mode 100644 index 0000000..1d426f5 --- /dev/null +++ b/tests/test_quality_golden.py @@ -0,0 +1,218 @@ +"""Bit-precision regression contract for :func:`findit_keyframe.quality.compute_quality`. + +The JSON fixture at ``tests/fixtures/quality/canonical.json`` encodes both +the input frame specifications (purely deterministic, integer-arithmetic +generators — no PRNG, so a Rust port can reproduce them byte-for-byte) and +the expected :class:`QualityMetrics` rounded to 6 decimal places. + +The Rust translation must replay this fixture: load the JSON, materialise +each frame using the same generator semantics (see ``_build_frame``), run +its own ``compute_quality`` and assert each field matches expected to +``1e-6``. + +To regenerate the fixture after an intentional algorithm change, run:: + + python tests/test_quality_golden.py + +Review the diff in ``canonical.json`` carefully before committing. +""" + +from __future__ import annotations + +import json +from dataclasses import asdict +from pathlib import Path + +import numpy as np + +from findit_keyframe.quality import compute_quality + +FIXTURE_PATH = Path(__file__).parent / "fixtures" / "quality" / "canonical.json" +TOLERANCE = 1e-6 + + +# --------------------------------------------------------------------------- # +# Frame generators (pure integer arithmetic, no PRNG) # +# --------------------------------------------------------------------------- # + + +def _build_frame(spec: dict) -> np.ndarray: + """Materialise a frame from its declarative ``spec`` dict. 
+ + Generator catalogue (exhaustive — the Rust port re-implements these + same six kinds against the same JSON inputs): + + ``solid`` — ``{rgb: [R, G, B], size: [H, W]}`` + Every pixel equals ``rgb``. + + ``channel`` — ``{channel: "r" | "g" | "b", value: V, size: [H, W]}`` + Selected channel = ``V``; the other two channels = 0. + + ``h_gradient`` — ``{low: L, high: H, size: [H, W]}`` + ``pixel[y, x, c] = L + (H - L) * x // (W - 1)``. Integer division; + identical R, G, B per pixel. + + ``v_gradient`` — same as ``h_gradient`` but along the ``y`` axis. + + ``checker`` — ``{cell: K, low: L, high: H, size: [H, W]}`` + ``pixel[y, x, c] = H if ((y // K) + (x // K)) % 2 == 0 else L``; + identical R, G, B per pixel. + + ``single_pixel`` — + ``{bg: [B, B, B], fg: [F, F, F], position: [py, px], size: [H, W]}`` + Background ``bg`` everywhere, foreground ``fg`` at ``(py, px)``. + """ + height, width = spec["size"] + kind = spec["kind"] + + if kind == "solid": + r, g, b = spec["rgb"] + frame = np.empty((height, width, 3), dtype=np.uint8) + frame[..., 0] = r + frame[..., 1] = g + frame[..., 2] = b + return frame + + if kind == "channel": + channel_idx = {"r": 0, "g": 1, "b": 2}[spec["channel"]] + frame = np.zeros((height, width, 3), dtype=np.uint8) + frame[..., channel_idx] = spec["value"] + return frame + + if kind == "h_gradient": + low, high = spec["low"], spec["high"] + ramp = np.array( + [low + (high - low) * x // (width - 1) for x in range(width)], + dtype=np.uint8, + ) + frame = np.empty((height, width, 3), dtype=np.uint8) + for c in range(3): + frame[..., c] = ramp[None, :] + return frame + + if kind == "v_gradient": + low, high = spec["low"], spec["high"] + ramp = np.array( + [low + (high - low) * y // (height - 1) for y in range(height)], + dtype=np.uint8, + ) + frame = np.empty((height, width, 3), dtype=np.uint8) + for c in range(3): + frame[..., c] = ramp[:, None] + return frame + + if kind == "checker": + cell = spec["cell"] + low, high = spec["low"], spec["high"] + 
frame = np.full((height, width, 3), low, dtype=np.uint8) + for y in range(height): + for x in range(width): + if ((y // cell) + (x // cell)) % 2 == 0: + frame[y, x] = high + return frame + + if kind == "single_pixel": + bg = spec["bg"] + fg = spec["fg"] + py, px = spec["position"] + frame = np.empty((height, width, 3), dtype=np.uint8) + for c in range(3): + frame[..., c] = bg[c] + for c in range(3): + frame[py, px, c] = fg[c] + return frame + + raise ValueError(f"unknown spec kind: {kind!r}") + + +# --------------------------------------------------------------------------- # +# Test # +# --------------------------------------------------------------------------- # + + +def test_canonical_quality_metrics_match_golden_fixture(): + """Every field of every canonical frame must match expected to 6 decimals. + + A failure means ``compute_quality``'s output has shifted at the bit + level. If the shift is intentional, regenerate the fixture + (``python tests/test_quality_golden.py``) and audit the JSON diff + before committing. 
+ """ + data = json.loads(FIXTURE_PATH.read_text()) + assert data["version"] == 1, f"unsupported fixture version: {data['version']}" + assert len(data["frames"]) >= 10, "fixture must contain at least 10 canonical frames" + + failures: list[str] = [] + for entry in data["frames"]: + frame = _build_frame(entry["spec"]) + actual = compute_quality(frame) + actual_dict = asdict(actual) + for field, expected in entry["expected"].items(): + got = actual_dict[field] + if abs(got - expected) > TOLERANCE: + failures.append(f" {entry['id']}.{field}: expected {expected:.6f}, got {got:.10f}") + if failures: + raise AssertionError("Golden fixture mismatch:\n" + "\n".join(failures)) + + +# --------------------------------------------------------------------------- # +# Regenerator (run directly: ``python tests/test_quality_golden.py``) # +# --------------------------------------------------------------------------- # + + +_CANONICAL_SPECS: list[tuple[str, dict]] = [ + ("solid_black", {"kind": "solid", "rgb": [0, 0, 0], "size": [32, 32]}), + ("solid_white", {"kind": "solid", "rgb": [255, 255, 255], "size": [32, 32]}), + ("solid_gray_128", {"kind": "solid", "rgb": [128, 128, 128], "size": [32, 32]}), + ("channel_red", {"kind": "channel", "channel": "r", "value": 255, "size": [32, 32]}), + ("channel_green", {"kind": "channel", "channel": "g", "value": 255, "size": [32, 32]}), + ("channel_blue", {"kind": "channel", "channel": "b", "value": 255, "size": [32, 32]}), + ("h_gradient_0_255_64w", {"kind": "h_gradient", "low": 0, "high": 255, "size": [32, 64]}), + ("v_gradient_0_255_64h", {"kind": "v_gradient", "low": 0, "high": 255, "size": [64, 32]}), + ("checker_8px_0_255", {"kind": "checker", "cell": 8, "low": 0, "high": 255, "size": [32, 32]}), + ( + "single_pixel_100_in_5x5", + { + "kind": "single_pixel", + "bg": [0, 0, 0], + "fg": [100, 100, 100], + "position": [2, 2], + "size": [5, 5], + }, + ), +] + + +def _regenerate_fixture() -> None: + """Recompute every canonical frame's 
metrics and rewrite the JSON file. + + Intended for human-driven runs after an intentional algorithm change. + """ + frames = [] + for frame_id, spec in _CANONICAL_SPECS: + frame = _build_frame(spec) + metrics = compute_quality(frame) + frames.append( + { + "id": frame_id, + "spec": spec, + "expected": {k: round(float(v), 6) for k, v in asdict(metrics).items()}, + } + ) + data = { + "version": 1, + "purpose": ( + "Bit-precision regression contract for compute_quality. " + "Generator semantics live in tests/test_quality_golden.py::_build_frame " + "and the Rust port replays this fixture against its own implementation." + ), + "tolerance": TOLERANCE, + "frames": frames, + } + FIXTURE_PATH.parent.mkdir(parents=True, exist_ok=True) + FIXTURE_PATH.write_text(json.dumps(data, indent=2) + "\n") + print(f"Wrote {len(frames)} frames to {FIXTURE_PATH}") + + +if __name__ == "__main__": + _regenerate_fixture() diff --git a/tests/test_saliency.py b/tests/test_saliency.py new file mode 100644 index 0000000..2f7df71 --- /dev/null +++ b/tests/test_saliency.py @@ -0,0 +1,108 @@ +"""Tests for ``findit_keyframe.saliency``. + +* The module must import cleanly on systems without ``pyobjc-framework-Vision``; + proven by collecting and running this file at all. +* ``NoopSaliencyProvider`` always returns ``0.0`` and is the default fallback. +* ``default_saliency_provider`` returns ``Noop`` on non-Darwin systems. +* ``AppleVisionSaliencyProvider`` is exercised end-to-end on macOS only; + skipped on Linux (CI Linux runner). 
+""" + +from __future__ import annotations + +import platform + +import numpy as np +import pytest + +from findit_keyframe.saliency import ( + AppleVisionSaliencyProvider, + NoopSaliencyProvider, + SaliencyProvider, + default_saliency_provider, +) + +_IS_DARWIN = platform.system() == "Darwin" + + +# --------------------------------------------------------------------------- # +# NoopSaliencyProvider # +# --------------------------------------------------------------------------- # + + +class TestNoopSaliencyProvider: + def test_satisfies_protocol(self): + # SaliencyProvider is runtime-checkable so the duck-type contract + # is asserted, not just structural shape. + assert isinstance(NoopSaliencyProvider(), SaliencyProvider) + + def test_returns_zero_for_any_input(self): + provider = NoopSaliencyProvider() + assert provider.compute(np.zeros((4, 4, 3), dtype=np.uint8)) == 0.0 + assert provider.compute(np.full((16, 16, 3), 255, dtype=np.uint8)) == 0.0 + rng = np.random.default_rng(seed=0) + noise = rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8) + assert provider.compute(noise) == 0.0 + + +# --------------------------------------------------------------------------- # +# default_saliency_provider # +# --------------------------------------------------------------------------- # + + +class TestDefaultProvider: + def test_returns_noop_on_non_darwin(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(platform, "system", lambda: "Linux") + provider = default_saliency_provider() + assert isinstance(provider, NoopSaliencyProvider) + + @pytest.mark.skipif(not _IS_DARWIN, reason="Apple Vision is macOS-only") + @pytest.mark.macos + def test_returns_apple_provider_on_macos(self): + provider = default_saliency_provider() + assert isinstance(provider, AppleVisionSaliencyProvider) + + +# --------------------------------------------------------------------------- # +# AppleVisionSaliencyProvider — instantiation guards # +# 
--------------------------------------------------------------------------- # + + +class TestAppleProviderGuards: + def test_non_darwin_instantiation_raises(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(platform, "system", lambda: "Linux") + with pytest.raises(RuntimeError, match="macOS"): + AppleVisionSaliencyProvider() + + +# --------------------------------------------------------------------------- # +# AppleVisionSaliencyProvider — actual Vision call (macOS only) # +# --------------------------------------------------------------------------- # + + +@pytest.mark.skipif(not _IS_DARWIN, reason="Apple Vision is macOS-only") +@pytest.mark.macos +class TestAppleVisionCompute: + def _provider(self) -> AppleVisionSaliencyProvider: + return AppleVisionSaliencyProvider() + + def test_centre_white_frame_has_attention(self): + # 128x128 black with a centred 64x64 white patch. + rgb = np.zeros((128, 128, 3), dtype=np.uint8) + rgb[32:96, 32:96] = 255 + score = self._provider().compute(rgb) + assert 0.0 <= score <= 1.0 + # The patch is highly attention-grabbing relative to a flat field. + assert score > 0.0 + + def test_uniform_frame_has_low_attention(self): + # A perfectly flat field has nothing salient. + rgb = np.full((128, 128, 3), 128, dtype=np.uint8) + score = self._provider().compute(rgb) + assert 0.0 <= score <= 1.0 + + def test_returns_float(self): + rgb = np.zeros((64, 64, 3), dtype=np.uint8) + rgb[16:48, 16:48] = 200 + score = self._provider().compute(rgb) + assert isinstance(score, float) diff --git a/tests/test_sampler.py b/tests/test_sampler.py new file mode 100644 index 0000000..7c2a30e --- /dev/null +++ b/tests/test_sampler.py @@ -0,0 +1,396 @@ +"""Tests for ``findit_keyframe.sampler``. + +Pure-function tests (binning, scoring, selection) get bare numeric checks. +Integration tests use the ``varied_video`` and ``tiny_video`` fixtures so +the fallback path (low/degraded confidence) is exercised on a real decoder. 
+""" + +from __future__ import annotations + +from itertools import pairwise +from typing import TYPE_CHECKING + +import numpy as np +import pytest + +from findit_keyframe.decoder import DecodedFrame, VideoDecoder +from findit_keyframe.quality import QualityGate +from findit_keyframe.sampler import ( + _candidate_times, + _ordinal_rank, + compute_bins, + extract_all, + extract_for_shot, + score_bin_candidates, + select_from_bin, +) +from findit_keyframe.types import ( + Confidence, + ExtractedKeyframe, + QualityMetrics, + SamplingConfig, + ShotRange, + Timebase, + Timestamp, +) + +if TYPE_CHECKING: + from pathlib import Path + + +def _shot(start_sec: float, end_sec: float) -> ShotRange: + tb = Timebase(1, 1000) + return ShotRange( + start=Timestamp(round(start_sec * 1000), tb), + end=Timestamp(round(end_sec * 1000), tb), + ) + + +def _ts(seconds: float, tb: Timebase) -> Timestamp: + """Build a Timestamp for ``seconds`` in the given decoder timebase.""" + return Timestamp(round(seconds * tb.den / tb.num), tb) + + +def _qm(**overrides: float) -> QualityMetrics: + defaults = { + "laplacian_var": 100.0, + "mean_luma": 0.5, + "luma_variance": 1000.0, + "entropy": 7.0, + "saliency_mass": 0.0, + } + defaults.update(overrides) + return QualityMetrics(**defaults) + + +def _solid_frame(pts_sec: float, gray: int = 128, size: int = 32) -> DecodedFrame: + rgb = np.full((size, size, 3), gray, dtype=np.uint8) + return DecodedFrame( + pts=Timestamp(round(pts_sec * 1000), Timebase(1, 1000)), + width=size, + height=size, + rgb=rgb, + ) + + +def _noise_frame(pts_sec: float, seed: int, size: int = 32) -> DecodedFrame: + rng = np.random.default_rng(seed=seed) + rgb = rng.integers(50, 201, size=(size, size, 3), dtype=np.uint8) + return DecodedFrame( + pts=Timestamp(round(pts_sec * 1000), Timebase(1, 1000)), + width=size, + height=size, + rgb=rgb, + ) + + +# --------------------------------------------------------------------------- # +# compute_bins # +# 
--------------------------------------------------------------------------- # + + +class TestComputeBins: + def test_short_shot_yields_one_bin(self): + # D = 1s, I = 4s -> N = ceil(1/4) = 1. + bins = compute_bins(_shot(0.0, 1.0), SamplingConfig()) + assert len(bins) == 1 + + def test_typical_shot_yields_n_bins(self): + # D = 60s, I = 4s -> N = 15. + bins = compute_bins(_shot(0.0, 60.0), SamplingConfig()) + assert len(bins) == 15 + + def test_long_shot_capped_at_max(self): + # D = 120s, I = 4s -> N = 30, capped at max_frames_per_shot = 16. + bins = compute_bins(_shot(0.0, 120.0), SamplingConfig()) + assert len(bins) == 16 + + def test_two_bin_case_documented(self): + # D = 5s, I = 4s -> N = ceil(5/4) = 2 (per TASKS.md verification). + bins = compute_bins(_shot(0.0, 5.0), SamplingConfig()) + assert len(bins) == 2 + + def test_bins_are_contiguous_after_shrink(self): + bins = compute_bins(_shot(10.0, 70.0), SamplingConfig()) + for (_, end), (start, _) in pairwise(bins): + assert end == pytest.approx(start) + + def test_bins_are_equal_width(self): + bins = compute_bins(_shot(0.0, 60.0), SamplingConfig()) + widths = [b - a for a, b in bins] + assert all(w == pytest.approx(widths[0]) for w in widths) + + def test_first_bin_starts_after_shrink(self): + cfg = SamplingConfig() + bins = compute_bins(_shot(0.0, 60.0), cfg) + expected_first_start = 0.0 + cfg.boundary_shrink_pct * 60.0 + assert bins[0][0] == pytest.approx(expected_first_start) + + def test_last_bin_ends_before_shrink(self): + cfg = SamplingConfig() + bins = compute_bins(_shot(0.0, 60.0), cfg) + expected_last_end = 60.0 - cfg.boundary_shrink_pct * 60.0 + assert bins[-1][1] == pytest.approx(expected_last_end) + + +# --------------------------------------------------------------------------- # +# _candidate_times # +# --------------------------------------------------------------------------- # + + +class TestCandidateTimes: + def test_returns_centred_points(self): + # K = 4 in [0, 1] -> [0.125, 0.375, 0.625, 
0.875] (cell midpoints). + ts = _candidate_times(0.0, 1.0, 4) + assert ts == pytest.approx([0.125, 0.375, 0.625, 0.875]) + + def test_k_one_returns_midpoint(self): + assert _candidate_times(2.0, 4.0, 1) == [3.0] + + def test_k_zero_returns_empty(self): + assert _candidate_times(0.0, 1.0, 0) == [] + + +# --------------------------------------------------------------------------- # +# _ordinal_rank # +# --------------------------------------------------------------------------- # + + +class TestOrdinalRank: + def test_empty(self): + assert _ordinal_rank([]) == [] + + def test_single_returns_top(self): + assert _ordinal_rank([42.0]) == [1.0] + + def test_sorted_ascending(self): + # 4 elements, ranks evenly spaced 0, 1/3, 2/3, 1. + assert _ordinal_rank([1.0, 2.0, 3.0, 4.0]) == pytest.approx([0.0, 1 / 3, 2 / 3, 1.0]) + + def test_reversed(self): + assert _ordinal_rank([4.0, 3.0, 2.0, 1.0]) == pytest.approx([1.0, 2 / 3, 1 / 3, 0.0]) + + def test_stable_for_ties(self): + # Stable sort: equal values keep their input order, so the first + # occurrence gets the lower rank. + ranks = _ordinal_rank([5.0, 5.0, 5.0]) + assert ranks[0] < ranks[1] < ranks[2] + + +# --------------------------------------------------------------------------- # +# score_bin_candidates # +# --------------------------------------------------------------------------- # + + +class TestScoreBinCandidates: + def test_empty(self): + assert score_bin_candidates([]) == [] + + def test_single_full_score_no_saliency(self): + scores = score_bin_candidates([_qm(saliency_mass=0.0)]) + # 0.6 (rank=1) + 0.2 (rank=1) + 0 = 0.8. + assert scores == [pytest.approx(0.8)] + + def test_single_with_saliency(self): + scores = score_bin_candidates([_qm(saliency_mass=0.5)]) + # 0.6 + 0.2 + 0.2 * 0.5 = 0.9. 
+ assert scores == [pytest.approx(0.9)] + + def test_higher_laplacian_higher_score(self): + scores = score_bin_candidates( + [ + _qm(laplacian_var=10.0, entropy=7.0, saliency_mass=0.0), + _qm(laplacian_var=100.0, entropy=7.0, saliency_mass=0.0), + ] + ) + assert scores[1] > scores[0] + + def test_saliency_breaks_a_tie(self): + scores = score_bin_candidates( + [ + _qm(laplacian_var=50.0, entropy=7.0, saliency_mass=0.0), + _qm(laplacian_var=50.0, entropy=7.0, saliency_mass=0.5), + ] + ) + assert scores[1] > scores[0] + + +# --------------------------------------------------------------------------- # +# select_from_bin # +# --------------------------------------------------------------------------- # + + +class TestSelectFromBin: + def test_empty_returns_none(self): + assert select_from_bin([], QualityGate()) is None + + def test_all_uniform_returns_none(self): + # Solid colour frames have luma_variance = 0 < 5 -> all rejected. + cands = [_solid_frame(0.1, gray=128), _solid_frame(0.2, gray=64)] + assert select_from_bin(cands, QualityGate()) is None + + def test_mixed_picks_a_survivor(self): + cands = [_solid_frame(0.1, gray=128), _noise_frame(0.2, seed=0)] + result = select_from_bin(cands, QualityGate()) + assert result is not None + chosen, metrics, conf = result + assert chosen.pts.seconds == pytest.approx(0.2) + assert conf is Confidence.High + assert metrics.luma_variance > 5.0 + + +# --------------------------------------------------------------------------- # +# extract_for_shot — happy path on varied (high-quality) frames # +# --------------------------------------------------------------------------- # + + +class TestExtractForShot: + def test_noise_video_one_keyframe_per_bin(self, varied_video: Path): + with VideoDecoder.open(varied_video, target_size=64) as dec: + shot = ShotRange( + start=Timestamp(0, dec.timebase), + end=Timestamp(round(1.4 * dec.timebase.den), dec.timebase), + ) + keyframes = extract_for_shot(shot, 7, dec, SamplingConfig()) + # D = 1.4s, 
I = 4s -> N = ceil(1.4/4) = 1. + assert len(keyframes) == 1 + kf = keyframes[0] + assert isinstance(kf, ExtractedKeyframe) + assert kf.shot_id == 7 + assert kf.bucket_index == 0 + assert kf.confidence is Confidence.High + assert kf.width == 64 + assert kf.height == 64 + assert len(kf.rgb) == 64 * 64 * 3 + + def test_returns_n_bins_for_long_shot(self, varied_video: Path): + cfg = SamplingConfig(target_interval_sec=0.3) + with VideoDecoder.open(varied_video, target_size=64) as dec: + shot = ShotRange( + start=Timestamp(0, dec.timebase), + end=Timestamp(round(1.4 * dec.timebase.den), dec.timebase), + ) + keyframes = extract_for_shot(shot, 0, dec, cfg) + # D = 1.4s, I = 0.3s -> N = ceil(1.4/0.3) = 5. + assert len(keyframes) == 5 + assert [kf.bucket_index for kf in keyframes] == [0, 1, 2, 3, 4] + # Selected timestamps strictly increase across bins. + ts = [kf.timestamp.seconds for kf in keyframes] + assert ts == sorted(ts) + assert all(kf.confidence is Confidence.High for kf in keyframes) + + +# --------------------------------------------------------------------------- # +# extract_for_shot — fallback path on uniform (gate-failing) frames # +# --------------------------------------------------------------------------- # + + +class TestExtractForShotFallback: + def test_uniform_video_yields_degraded(self, tiny_video: Path): + # Every frame in tiny_video is a single-colour ramp, luma_variance = 0 + # for each, so the gate fails on every probe. Fallback force-picks. 
+ with VideoDecoder.open(tiny_video) as dec: + shot = ShotRange( + start=Timestamp(0, dec.timebase), + end=Timestamp(round(0.8 * dec.timebase.den), dec.timebase), + ) + keyframes = extract_for_shot(shot, 0, dec, SamplingConfig()) + assert len(keyframes) == 1 + assert keyframes[0].confidence is Confidence.Degraded + + +# --------------------------------------------------------------------------- # +# extract_all # +# --------------------------------------------------------------------------- # + + +class TestQualityGradient: + """T5 verification: sharp regions must win in mixed-content bins. + + The 20-second ``quality_gradient_video`` fixture is structured as + sharp / blur / sharp thirds at 15 fps, so the boundaries between + regions are at ``t = 100/15`` s and ``t = 200/15`` s. With + ``SamplingConfig`` defaults (``target_interval_sec = 4``, + ``boundary_shrink_pct = 0.02``), the shot is split into 5 bins of + equal width over the shrunken range ``[0.4, 19.6]``: + + * Bin 0 ``[0.40, 4.24)`` — entirely sharp + * Bin 1 ``[4.24, 8.08)`` — mostly sharp + tail blur + * Bin 2 ``[8.08, 11.92)`` — entirely blur + * Bin 3 ``[11.92, 15.76)`` — head blur + tail sharp + * Bin 4 ``[15.76, 19.60]`` — entirely sharp + + Bins 1 and 3 are the load-bearing tests: the within-bin scorer must + pick a sharp candidate timestamp, not a blurred one. + """ + + SHARP_END_1 = 100 / 15 # ~6.667 s + BLUR_END = 200 / 15 # ~13.333 s + + def test_sampler_prefers_sharp_in_mixed_bins(self, quality_gradient_video: Path): + with VideoDecoder.open(quality_gradient_video, target_size=64) as dec: + shot = ShotRange( + start=Timestamp(0, dec.timebase), + end=_ts(20.0, dec.timebase), + ) + keyframes = extract_for_shot(shot, 0, dec, SamplingConfig()) + + # ceil(20 / 4) = 5 bins, exactly one keyframe each. 
+ assert len(keyframes) == 5 + assert [kf.bucket_index for kf in keyframes] == [0, 1, 2, 3, 4] + + ts = [kf.timestamp.seconds for kf in keyframes] + # Bin 0: entirely sharp; selected ts must be in the first sharp third. + assert ts[0] < self.SHARP_END_1, f"bin 0 ts={ts[0]:.3f}s outside sharp third" + # Bin 1: mostly sharp; the scorer must reject the tail blur candidates. + assert ts[1] < self.SHARP_END_1, f"bin 1 ts={ts[1]:.3f}s in blur region" + # Bin 2: entirely blur — sampler picks something here, no preference assertion. + assert self.SHARP_END_1 <= ts[2] < self.BLUR_END + # Bin 3: head blur + tail sharp; the scorer must reach into the sharp tail. + assert ts[3] > self.BLUR_END, f"bin 3 ts={ts[3]:.3f}s in blur region" + # Bin 4: entirely sharp. + assert ts[4] > self.BLUR_END, f"bin 4 ts={ts[4]:.3f}s outside sharp third" + + def test_sharp_bins_have_higher_laplacian_than_blur_bin(self, quality_gradient_video: Path): + with VideoDecoder.open(quality_gradient_video, target_size=64) as dec: + shot = ShotRange( + start=Timestamp(0, dec.timebase), + end=_ts(20.0, dec.timebase), + ) + keyframes = extract_for_shot(shot, 0, dec, SamplingConfig()) + # The all-blur bin's Laplacian variance is bounded above by both + # sharp-only bins (0 and 4). This corroborates that the box-blur + # roundtrip survives libx264 encoding measurably. 
+ blur_lap = keyframes[2].quality.laplacian_var + sharp0_lap = keyframes[0].quality.laplacian_var + sharp4_lap = keyframes[4].quality.laplacian_var + assert blur_lap < sharp0_lap, ( + f"blur bin laplacian {blur_lap:.1f} not below sharp bin 0 {sharp0_lap:.1f}" + ) + assert blur_lap < sharp4_lap, ( + f"blur bin laplacian {blur_lap:.1f} not below sharp bin 4 {sharp4_lap:.1f}" + ) + + +class TestExtractAll: + def test_shape_matches_input(self, varied_video: Path): + with VideoDecoder.open(varied_video, target_size=64) as dec: + tb = dec.timebase + shots = [ + ShotRange(start=Timestamp(0, tb), end=Timestamp(round(0.5 * tb.den), tb)), + ShotRange( + start=Timestamp(round(0.6 * tb.den), tb), + end=Timestamp(round(1.2 * tb.den), tb), + ), + ] + result = extract_all(shots, dec, SamplingConfig()) + assert len(result) == 2 + # Both shots are < I, so each yields one bin. + assert len(result[0]) == 1 + assert len(result[1]) == 1 + assert result[0][0].shot_id == 0 + assert result[1][0].shot_id == 1 + + def test_empty_shots_returns_empty(self, varied_video: Path): + with VideoDecoder.open(varied_video) as dec: + assert extract_all([], dec, SamplingConfig()) == [] diff --git a/tests/test_sampler_regression.py b/tests/test_sampler_regression.py new file mode 100644 index 0000000..bc705fa --- /dev/null +++ b/tests/test_sampler_regression.py @@ -0,0 +1,169 @@ +"""Deterministic regression snapshot for :func:`extract_all`. + +This file stands in for the *Kino Demo regression fixture* called out in +``TASKS.md`` T5: a JSON snapshot of ``(shot_id, bin_index, timestamp, +quality)`` tuples that fails the moment the algorithm's selection or +scoring drifts. Until the real asset is available, we use an **in-memory +fake decoder** that returns deterministic noise frames so the snapshot is +bit-exact across machines and PyAV / FFmpeg versions. 
+ +To regenerate after an intentional algorithm change, run:: + + python tests/test_sampler_regression.py + +Review the diff in the JSON file before committing. +""" + +from __future__ import annotations + +import json +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import TYPE_CHECKING + +import numpy as np + +from findit_keyframe.decoder import DecodedFrame +from findit_keyframe.sampler import extract_all +from findit_keyframe.types import SamplingConfig, ShotRange, Timebase, Timestamp + +if TYPE_CHECKING: + from collections.abc import Callable + +SNAPSHOT_PATH = Path(__file__).parent / "fixtures" / "regression" / "extract_all_synthetic.json" + + +# --------------------------------------------------------------------------- # +# In-memory deterministic decoder # +# --------------------------------------------------------------------------- # + + +@dataclass +class _FakeDecoder: + """Duck-typed VideoDecoder for bit-exact regression testing. + + The sampler's only contact with the decoder is :attr:`duration_sec`, + :attr:`timebase`, and :meth:`decode_at`. We satisfy that surface + without touching PyAV so the snapshot is unaffected by codec build + differences across machines. 
+ """ + + duration_sec: float + fps: int + frame_fn: Callable[[int], np.ndarray] + + @property + def timebase(self) -> Timebase: + return Timebase(1, self.fps) + + def decode_at(self, time_sec: float) -> DecodedFrame: + max_idx = int(self.duration_sec * self.fps) - 1 + idx = max(0, min(max_idx, round(time_sec * self.fps))) + rgb = self.frame_fn(idx) + height, width = rgb.shape[:2] + return DecodedFrame( + pts=Timestamp(idx, self.timebase), + width=width, + height=height, + rgb=rgb, + ) + + +def _synthetic_noise_frame(idx: int) -> np.ndarray: + """Deterministic mid-tone noise frame keyed by frame index.""" + rng = np.random.default_rng(seed=idx) + return rng.integers(50, 201, size=(64, 64, 3), dtype=np.uint8) + + +def _build_synthetic_decoder() -> _FakeDecoder: + """A 5-second 30 fps fake decoder yielding deterministic noise.""" + return _FakeDecoder(duration_sec=5.0, fps=30, frame_fn=_synthetic_noise_frame) + + +def _shot(start_sec: float, end_sec: float, tb: Timebase) -> ShotRange: + return ShotRange( + start=Timestamp(round(start_sec * tb.den / tb.num), tb), + end=Timestamp(round(end_sec * tb.den / tb.num), tb), + ) + + +def _run_extract_all() -> list[dict]: + """Run the algorithm on a fixed input and flatten to a JSON-serialisable list. + + Inputs are chosen to exercise both code paths simultaneously: + + * shot 0 — short, single-bin (no within-bin contention). + * shot 1 — longer, multi-bin (exercises scoring + cross-bin ordering). 
+ """ + decoder = _build_synthetic_decoder() + tb = decoder.timebase + shots = [ + _shot(0.0, 0.4, tb), + _shot(0.5, 4.5, tb), + ] + config = SamplingConfig(target_interval_sec=1.0, target_size=64) + results = extract_all(shots, decoder, config) + return [ + { + "shot_id": kf.shot_id, + "bucket": kf.bucket_index, + "timestamp_sec": round(kf.timestamp.seconds, 6), + "confidence": kf.confidence.value, + "quality": {k: round(float(v), 6) for k, v in asdict(kf.quality).items()}, + } + for shot_keyframes in results + for kf in shot_keyframes + ] + + +# --------------------------------------------------------------------------- # +# Test # +# --------------------------------------------------------------------------- # + + +def test_extract_all_matches_regression_snapshot(): + """Every field of every emitted keyframe must match the JSON snapshot. + + This catches: bin partitioning changes, candidate timestamp shifts, + scoring weight or normalisation changes, fallback ordering changes. + PyAV / FFmpeg cannot be blamed because the input pipeline is in-memory. + """ + actual = _run_extract_all() + snapshot = json.loads(SNAPSHOT_PATH.read_text()) + assert snapshot["version"] == 1, f"unsupported snapshot version: {snapshot['version']}" + assert actual == snapshot["entries"], ( + "extract_all output diverged from snapshot. If intentional, " + "regenerate via `python tests/test_sampler_regression.py` and " + "audit the diff before committing." + ) + + +# --------------------------------------------------------------------------- # +# Regenerator (run directly: `python tests/test_sampler_regression.py`) # +# --------------------------------------------------------------------------- # + + +def _regenerate_snapshot() -> None: + entries = _run_extract_all() + data = { + "version": 1, + "purpose": ( + "Deterministic regression snapshot for extract_all over a " + "synthetic in-memory decoder. 
Stand-in for the Kino Demo "
+            "regression fixture in TASKS.md T5; replace when the real "
+            "asset becomes available."
+        ),
+        "input": {
+            "decoder": "in-memory _FakeDecoder, 5.0 s @ 30 fps, 64x64 noise frames",
+            "shots": ["[0.0, 0.4) s", "[0.5, 4.5) s"],
+            "config": "SamplingConfig(target_interval_sec=1.0, target_size=64)",
+        },
+        "entries": entries,
+    }
+    SNAPSHOT_PATH.parent.mkdir(parents=True, exist_ok=True)
+    SNAPSHOT_PATH.write_text(json.dumps(data, indent=2) + "\n")
+    print(f"Wrote {len(entries)} entries to {SNAPSHOT_PATH}")
+
+
+if __name__ == "__main__":
+    _regenerate_snapshot()
diff --git a/tests/test_types.py b/tests/test_types.py
new file mode 100644
index 0000000..c5f9497
--- /dev/null
+++ b/tests/test_types.py
@@ -0,0 +1,261 @@
+"""Tests for ``findit_keyframe.types``.
+
+These tests pin down invariants documented in ``TASKS.md`` §2 and the Type Map
+in ``docs/rust-porting.md``. They are the contract the Rust port must replay.
+"""
+
+from __future__ import annotations
+
+import dataclasses
+from typing import get_type_hints
+
+import pytest
+
+from findit_keyframe.types import (
+    Confidence,
+    ExtractedKeyframe,
+    QualityMetrics,
+    SamplingConfig,
+    ShotRange,
+    Timebase,
+    Timestamp,
+)
+
+# --------------------------------------------------------------------------- #
+# Timebase                                                                    #
+# --------------------------------------------------------------------------- #
+
+
+class TestTimebase:
+    def test_zero_denominator_rejected(self):
+        with pytest.raises(ValueError, match="den"):
+            Timebase(num=1, den=0)
+
+    def test_negative_denominator_rejected(self):
+        with pytest.raises(ValueError, match="den"):
+            Timebase(num=1, den=-1)
+
+    def test_zero_numerator_accepted(self):
+        # scenedetect allows num == 0; mirror that.
+        tb = Timebase(num=0, den=1)
+        assert tb.num == 0
+        assert tb.den == 1
+
+    def test_value_equality_across_reduced_forms(self):
+        # 1/2 == 2/4 == 5/10 — semantic equality, not field equality. 
+ assert Timebase(1, 2) == Timebase(2, 4) + assert Timebase(1, 2) == Timebase(5, 10) + assert Timebase(1, 1000) == Timebase(90, 90000) + + def test_inequality_across_reduced_forms(self): + assert Timebase(1, 2) != Timebase(1, 3) + assert Timebase(2, 5) != Timebase(3, 5) + + def test_hash_matches_equality(self): + assert hash(Timebase(1, 2)) == hash(Timebase(2, 4)) + assert hash(Timebase(1, 1000)) == hash(Timebase(90, 90000)) + # And distinct values usually have distinct hashes (sanity). + assert hash(Timebase(1, 2)) != hash(Timebase(1, 3)) + + def test_is_frozen(self): + tb = Timebase(1, 1000) + with pytest.raises(dataclasses.FrozenInstanceError): + tb.num = 2 # type: ignore[misc] + + def test_repr_readable(self): + assert "1/1000" in repr(Timebase(1, 1000)) + + +# --------------------------------------------------------------------------- # +# Timestamp # +# --------------------------------------------------------------------------- # + + +class TestTimestamp: + def test_seconds_simple_case(self): + # 1000 ticks at 1/1000 timebase == 1.0 second. + assert Timestamp(pts=1000, timebase=Timebase(1, 1000)).seconds == pytest.approx(1.0) + + def test_seconds_video_timebase(self): + # 90000 ticks at 1/90000 (MPEG-TS) == 1.0 second. 
+ assert Timestamp(pts=90000, timebase=Timebase(1, 90000)).seconds == pytest.approx(1.0) + + def test_seconds_zero(self): + assert Timestamp(pts=0, timebase=Timebase(1, 1000)).seconds == 0.0 + + def test_cross_timebase_equality(self): + a = Timestamp(pts=1000, timebase=Timebase(1, 1000)) + b = Timestamp(pts=90000, timebase=Timebase(1, 90000)) + assert a == b + assert hash(a) == hash(b) + + def test_ordering_same_timebase(self): + tb = Timebase(1, 1000) + assert Timestamp(500, tb) < Timestamp(1000, tb) + assert Timestamp(1000, tb) > Timestamp(500, tb) + assert Timestamp(500, tb) <= Timestamp(500, tb) + + def test_ordering_cross_timebase(self): + a = Timestamp(pts=500, timebase=Timebase(1, 1000)) # 0.5s + b = Timestamp(pts=90000, timebase=Timebase(1, 90000)) # 1.0s + assert a < b + assert b > a + + def test_is_frozen(self): + ts = Timestamp(0, Timebase(1, 1000)) + with pytest.raises(dataclasses.FrozenInstanceError): + ts.pts = 1 # type: ignore[misc] + + +# --------------------------------------------------------------------------- # +# ShotRange # +# --------------------------------------------------------------------------- # + + +class TestShotRange: + def test_duration_sec_basic(self): + tb = Timebase(1, 1000) + sr = ShotRange(start=Timestamp(0, tb), end=Timestamp(5000, tb)) + assert sr.duration_sec == pytest.approx(5.0) + + def test_duration_sec_cross_timebase(self): + sr = ShotRange( + start=Timestamp(0, Timebase(1, 1000)), + end=Timestamp(90000, Timebase(1, 90000)), + ) + assert sr.duration_sec == pytest.approx(1.0) + + def test_end_before_start_rejected(self): + tb = Timebase(1, 1000) + with pytest.raises(ValueError, match="end"): + ShotRange(start=Timestamp(1000, tb), end=Timestamp(500, tb)) + + def test_zero_duration_rejected(self): + tb = Timebase(1, 1000) + with pytest.raises(ValueError, match="end"): + ShotRange(start=Timestamp(1000, tb), end=Timestamp(1000, tb)) + + def test_is_frozen(self): + tb = Timebase(1, 1000) + sr = ShotRange(Timestamp(0, tb), 
Timestamp(1000, tb)) + with pytest.raises(dataclasses.FrozenInstanceError): + sr.start = Timestamp(1, tb) # type: ignore[misc] + + +# --------------------------------------------------------------------------- # +# SamplingConfig # +# --------------------------------------------------------------------------- # + + +class TestSamplingConfig: + def test_defaults_match_spec(self): + c = SamplingConfig() + assert c.target_interval_sec == pytest.approx(4.0) + assert c.candidates_per_bin == 6 + assert c.max_frames_per_shot == 16 + assert c.boundary_shrink_pct == pytest.approx(0.02) + assert c.fallback_expand_pct == pytest.approx(0.20) + assert c.target_size == 384 + + def test_is_mutable(self): + # SamplingConfig is one of two intentionally non-frozen dataclasses. + c = SamplingConfig() + c.target_size = 256 + assert c.target_size == 256 + + def test_replace_returns_new_instance(self): + c1 = SamplingConfig() + c2 = dataclasses.replace(c1, target_size=256) + assert c2.target_size == 256 + assert c1.target_size == 384 + + +# --------------------------------------------------------------------------- # +# Confidence # +# --------------------------------------------------------------------------- # + + +class TestConfidence: + def test_three_levels(self): + assert {Confidence.High, Confidence.Low, Confidence.Degraded} == set(Confidence) + + def test_string_value_lowercase(self): + # The CLI manifest uses lowercase string form. 
+ assert Confidence.High.value == "high" + assert Confidence.Low.value == "low" + assert Confidence.Degraded.value == "degraded" + + +# --------------------------------------------------------------------------- # +# QualityMetrics # +# --------------------------------------------------------------------------- # + + +class TestQualityMetrics: + def test_construction_and_fields(self): + q = QualityMetrics( + laplacian_var=215.4, + mean_luma=0.41, + luma_variance=1820.7, + entropy=7.31, + saliency_mass=0.62, + ) + assert q.laplacian_var == pytest.approx(215.4) + assert q.mean_luma == pytest.approx(0.41) + assert q.luma_variance == pytest.approx(1820.7) + assert q.entropy == pytest.approx(7.31) + assert q.saliency_mass == pytest.approx(0.62) + + def test_is_frozen(self): + q = QualityMetrics(0.0, 0.0, 0.0, 0.0, 0.0) + with pytest.raises(dataclasses.FrozenInstanceError): + q.entropy = 1.0 # type: ignore[misc] + + def test_all_fields_are_floats(self): + hints = get_type_hints(QualityMetrics) + assert hints["laplacian_var"] is float + assert hints["mean_luma"] is float + assert hints["luma_variance"] is float + assert hints["entropy"] is float + assert hints["saliency_mass"] is float + + +# --------------------------------------------------------------------------- # +# ExtractedKeyframe # +# --------------------------------------------------------------------------- # + + +class TestExtractedKeyframe: + def _make(self, **overrides): + defaults = { + "shot_id": 0, + "timestamp": Timestamp(1000, Timebase(1, 1000)), + "bucket_index": 0, + "rgb": b"\x00" * (4 * 4 * 3), + "width": 4, + "height": 4, + "quality": QualityMetrics(0.0, 0.5, 1.0, 7.0, 0.0), + "confidence": Confidence.High, + } + return ExtractedKeyframe(**(defaults | overrides)) + + def test_construction(self): + kf = self._make() + assert kf.shot_id == 0 + assert kf.bucket_index == 0 + assert kf.width == 4 + assert kf.height == 4 + assert kf.confidence is Confidence.High + assert kf.quality.entropy == 
pytest.approx(7.0) + + def test_is_mutable(self): + # Per TASKS.md §2 — ExtractedKeyframe is intentionally mutable so callers + # can attach downstream metadata before serialising. + kf = self._make() + kf.confidence = Confidence.Low + assert kf.confidence is Confidence.Low + + def test_rgb_is_bytes(self): + kf = self._make() + assert isinstance(kf.rgb, bytes) + assert len(kf.rgb) == kf.width * kf.height * 3