fix(actor): train GSP predictor via direct MSE instead of DDPG actor-critic#24

Merged
jdbloom merged 1 commit into main from fix/gsp-direct-mse-training on Apr 13, 2026
Conversation

jdbloom (Contributor) commented Apr 13, 2026

Summary

Replaces the DDPG/RDDPG/TD3 actor-critic training path for GSP prediction networks with direct supervised MSE regression against the ground-truth delta-theta label. This is Option A (minimal fix) from the analysis at Stelaris `docs/research/2026-04-13-gsp-information-collapse-analysis.md`.

Root cause

Training the GSP predictor as a DDPG agent on a clipped negative-MSE reward `r = clip(-|pred - label|², -2, 0)` is a category error. DDPG's deterministic policy gradient flows through a Q-critic whose value landscape becomes flat when the reward is clipped or when the policy converges to any constant output — the DPG update then vanishes and the predictor freezes in place.

Literature consensus (UNREAL, ICM, RND, Predictron) is that auxiliary predictors are always trained by direct gradient descent on a supervised loss, never by RL on an aux reward.
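The clipping alone already zeroes the direct gradient outside the clip region, even before the Q-critic is involved. A tiny PyTorch sketch (illustrative values, not from the repo):

```python
import torch

# Gradient of the clipped negative-MSE "reward" w.r.t. the prediction.
# Once |pred - label| > sqrt(2), the clamp is active and the gradient is
# exactly zero: no learning signal flows back to the predictor.
pred = torch.tensor([3.0], requires_grad=True)
label = torch.tensor([0.0])
reward = torch.clamp(-(pred - label).pow(2), min=-2.0, max=0.0)
reward.backward()
print(pred.grad)  # zero inside the clipped region
```

In the actual DDPG path the signal additionally passes through the critic's value landscape, which is where the freeze described above occurs; the snippet only isolates the clipping contribution.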

Live evidence from the diagnostic batch

Measured on the 12-config GSP diagnostic batch at 50-150 episodes into training:

| Variant | pred std | corr(pred, target) | vs constant-mean baseline |
| --- | --- | --- | --- |
| DDQN+GSP | 0.071 | +0.005 | worse (MSE 0.060 vs 0.052) |
| DDPG+GSP | 0.072 | +0.005 | worse |
| DDPG+R-GSP-N | 0.00019 | noisy | collapsed to constant zero |
| DDQN+A-GSP-N | 0.016 | -0.01 | less pathological |

The attention variant (the only one that uses direct MSE via `learn_attention`) did NOT collapse — the controlled experiment that isolates the training mechanism as the cause.
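For reference, the table's three diagnostics take only a few lines of NumPy. The arrays below are synthetic stand-ins shaped to mimic the collapsed R-GSP-N variant (all values illustrative, not repo data):

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic stand-ins: a target with std ~0.23 (variance ~0.052, matching
# the trivial-mean MSE above) and a collapsed, near-constant predictor.
target = rng.normal(0.0, 0.23, size=1000)
pred = rng.normal(0.0, 2e-4, size=1000)

pred_std = pred.std()                                 # ~2e-4: collapse signature
corr = np.corrcoef(pred, target)[0, 1]                # ~0: no target information
pred_mse = np.mean((pred - target) ** 2)
trivial_mse = np.mean((target - target.mean()) ** 2)  # constant-mean baseline
print(pred_std, corr, pred_mse, trivial_mse)
```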

Changes

  1. Add `learn_gsp_mse(networks, recurrent)` in `gsp_rl/src/actors/learning_aids.py`. Samples (state, label) from `networks['replay']` (label in the action field by the RL-CT call-site convention), forwards the state through `networks['actor']`, and minimizes `F.mse_loss(pred, label)`. Handles both the 5-tuple and 7-tuple sample shapes from `sample_memory`.
  2. Rewire `Actor.learn_gsp()` in `actor.py` to dispatch to `learn_gsp_mse` for DDPG/RDDPG/TD3 schemes. Attention scheme unchanged.
  3. Preserve `last_gsp_loss` diagnostic field — `learn_gsp_mse` returns a plain float that gets stored exactly like before.

Test plan

  • 2 new tests in `tests/test_actor/test_gsp_direct_mse.py`:
    • `test_learn_gsp_mse_beats_trivial_mean_baseline_on_linear_task` — RED under the old DDPG path (pred MSE 0.0550 vs trivial 0.0525), GREEN under direct MSE. This test is the unit-level reproduction of the live collapse signature.
    • `test_learn_gsp_populates_last_gsp_loss_for_ddpg_variant` — confirms the diagnostic field still populates.
  • Full actor + learning_aids suite: 72/72 pass
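A self-contained sketch of what the first test exercises, with a stand-in linear predictor (the real test drives the repo's actor networks; sizes and the net here are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
# Linear state -> label task, trained by direct MSE for 200 learn steps,
# then compared against the trivial constant-mean baseline.
states = torch.randn(256, 4)
labels = states @ torch.randn(4, 1)

net = nn.Linear(4, 1)
opt = torch.optim.Adam(net.parameters(), lr=1e-2)
for _ in range(200):
    opt.zero_grad()
    loss = F.mse_loss(net(states), labels)
    loss.backward()
    opt.step()

pred_mse = F.mse_loss(net(states), labels).item()
trivial_mse = labels.var(unbiased=False).item()  # MSE of predicting the mean
print(pred_mse < trivial_mse)                    # GREEN under direct MSE
```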

Companion PR

Required in RL-CollectiveTransport: Main.py must pass the ground-truth `label` as the 2nd arg to `store_gsp_transition` for all variants (currently passing the previous prediction for non-attention variants). The two PRs are runtime-coupled — merging only this one and not the companion leaves the GSP training path reading the old prediction field as if it were a label.

🤖 Generated with Claude Code

fix(actor): train GSP predictor via direct MSE instead of DDPG actor-critic

Replaces the DDPG/RDDPG/TD3 actor-critic training path for GSP prediction
networks with direct supervised MSE regression against the ground-truth
delta-theta label. This is Option A (minimal fix) from the analysis at
Stelaris docs/research/2026-04-13-gsp-information-collapse-analysis.md.

Root cause: training the GSP predictor as a DDPG agent on a clipped
negative-MSE reward r = clip(-|pred - label|^2, -2, 0) is a category
error. DDPG's deterministic policy gradient flows through a Q-critic
whose value landscape becomes flat when the reward is clipped or when
the policy converges to any constant output — the DPG update then
vanishes and the predictor freezes in place.

Live evidence from the diagnostic batch (50-150 eps):
- DDQN+GSP predictor MSE 0.0601 vs trivial-mean MSE 0.0521 (worse
  than predicting a constant)
- DDPG+R-GSP-N pred std collapsed to 0.00019 (near-constant zero)
- Correlation(pred, target) ~= 0 across all DDPG-trained variants
- A-GSP-N (the only variant trained by direct MSE in learn_attention)
  did NOT collapse — the controlled experiment that isolates the
  training mechanism as the cause.

This commit adds learn_gsp_mse(networks, recurrent) in learning_aids.py
and rewires actor.learn_gsp() to dispatch to it for DDPG/RDDPG/TD3
schemes. The attention scheme is unchanged. The method samples
(state, label) pairs from the replay buffer (label in the action field
by the RL-CT call-site convention) and minimizes MSE directly against
the label with a non-vanishing gradient.

Tests: 2 new cases in tests/test_actor/test_gsp_direct_mse.py:
- test_learn_gsp_mse_beats_trivial_mean_baseline_on_linear_task
  verifies that after 200 learn steps on a linear state→label mapping,
  the predictor's MSE is lower than predicting the mean. Under the
  previous DDPG path this test fails by ~0.5%, matching the live
  collapse signature.
- test_learn_gsp_populates_last_gsp_loss_for_ddpg_variant confirms
  the diagnostic field still populates under the new dispatch.

Full actor + learning_aids suite: 72/72 pass.

Companion change required in RL-CollectiveTransport: Main.py's
store_gsp_transition call sites must pass the ground-truth label
(currently passing the previous prediction for non-attention variants).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@jdbloom jdbloom merged commit 0be62c6 into main Apr 13, 2026
4 checks passed
@jdbloom jdbloom deleted the fix/gsp-direct-mse-training branch April 13, 2026 13:54