fix(actor): train GSP predictor via direct MSE instead of DDPG actor-critic (#24)
Merged
Replaces the DDPG/RDDPG/TD3 actor-critic training path for GSP prediction networks with direct supervised MSE regression against the ground-truth delta-theta label. This is Option A (minimal fix) from the analysis at Stelaris `docs/research/2026-04-13-gsp-information-collapse-analysis.md`.

Root cause: training the GSP predictor as a DDPG agent on a clipped negative-MSE reward `r = clip(-|pred - label|², -2, 0)` is a category error. DDPG's deterministic policy gradient flows through a Q-critic whose value landscape becomes flat when the reward is clipped or when the policy converges to any constant output — the DPG update then vanishes and the predictor freezes in place.

Live evidence from the diagnostic batch (50-150 episodes):

- DDQN+GSP predictor MSE 0.0601 vs trivial-mean MSE 0.0521 (worse than predicting a constant)
- DDPG+R-GSP-N prediction std collapsed to 0.00019 (near-constant output)
- Correlation(pred, target) ≈ 0 across all DDPG-trained variants
- A-GSP-N (the only variant trained by direct MSE in `learn_attention`) did NOT collapse — the controlled experiment that isolates the training mechanism as the cause.

This commit adds `learn_gsp_mse(networks, recurrent)` in `learning_aids.py` and rewires `actor.learn_gsp()` to dispatch to it for the DDPG/RDDPG/TD3 schemes; the attention scheme is unchanged. The method samples (state, label) pairs from the replay buffer (the label sits in the action field by the RL-CT call-site convention) and minimizes MSE directly against the label with a non-vanishing gradient.

Tests: 2 new cases in `tests/test_actor/test_gsp_direct_mse.py`:

- `test_learn_gsp_mse_beats_trivial_mean_baseline_on_linear_task` verifies that after 200 learn steps on a linear state→label mapping, the predictor's MSE is lower than predicting the mean. Under the previous DDPG path this test fails by ~0.5%, matching the live collapse signature.
- `test_learn_gsp_populates_last_gsp_loss_for_ddpg_variant` confirms the diagnostic field still populates under the new dispatch.

Full actor + learning_aids suite: 72/72 pass.

Companion change required in RL-CollectiveTransport: Main.py's `store_gsp_transition` call sites must pass the ground-truth label (they currently pass the previous prediction for non-attention variants).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
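The direct-MSE mechanism the commit switches to can be illustrated with a minimal sketch. This is not the repository's actual `learn_gsp_mse` — the predictor here is a plain linear model, the replay buffer a random array, and all names are illustrative — but it shows the core point: supervised regression on (state, label) minibatches has a gradient that never vanishes, so the predictor is guaranteed to beat the trivial-mean baseline on a learnable task.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for the replay buffer: (state, label) pairs, where the label is
# the ground-truth delta-theta stored in the action field (illustrative).
states = rng.normal(size=(512, 4))
true_w = np.array([0.5, -0.3, 0.1, 0.2])
labels = states @ true_w + 0.01 * rng.normal(size=512)

# Linear predictor trained by direct MSE gradient descent (the Option A path).
w = np.zeros(4)
lr = 0.1
for _ in range(200):
    idx = rng.integers(0, 512, size=64)      # sample a minibatch from the buffer
    s, y = states[idx], labels[idx]
    grad = 2 * s.T @ (s @ w - y) / len(idx)  # d(MSE)/dw: nonzero whenever pred != label
    w -= lr * grad

mse = np.mean((states @ w - labels) ** 2)
baseline = np.mean((labels.mean() - labels) ** 2)  # trivial-mean predictor
assert mse < baseline
```

The same contrast is what the new linear-task test checks: under direct MSE the predictor's error drops below the mean baseline, whereas the collapsed DDPG-trained variants ended up at or above it.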
Summary
Replaces the DDPG/RDDPG/TD3 actor-critic training path for GSP prediction networks with direct supervised MSE regression against the ground-truth delta-theta label. This is Option A (minimal fix) from the analysis at Stelaris `docs/research/2026-04-13-gsp-information-collapse-analysis.md`.
Root cause
Training the GSP predictor as a DDPG agent on a clipped negative-MSE reward `r = clip(-|pred - label|², -2, 0)` is a category error. DDPG's deterministic policy gradient flows through a Q-critic whose value landscape becomes flat when the reward is clipped or when the policy converges to any constant output — the DPG update then vanishes and the predictor freezes in place.
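The clipping half of this argument can be checked numerically: wherever the squared error exceeds the clip threshold, the reward saturates at -2 and carries zero gradient with respect to the prediction, while an unclipped point still does. (This sketch only covers the clipping failure mode; the constant-policy failure mode additionally runs through the learned Q-critic and is not reproduced here.)

```python
import numpy as np

def clipped_reward(pred, label):
    # The aux reward from the buggy path: r = clip(-(pred - label)^2, -2, 0)
    return np.clip(-(pred - label) ** 2, -2.0, 0.0)

def grad(f, pred, label, eps=1e-6):
    # Central finite difference of the reward w.r.t. the prediction.
    return (f(pred + eps, label) - f(pred - eps, label)) / (2 * eps)

label = 0.0
far = 3.0    # (pred - label)^2 = 9 > 2: reward saturates at the clip floor
near = 0.5   # inside the clip region: gradient survives

g_far = grad(clipped_reward, far, label)
g_near = grad(clipped_reward, near, label)
assert abs(g_far) < 1e-6   # saturated: the learning signal is exactly flat
assert abs(g_near) > 0.5   # unclipped: gradient -2*(pred - label) = -1 here
```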
Literature consensus (UNREAL, ICM, RND, Predictron) is that auxiliary predictors are always trained by direct gradient descent on a supervised loss, never by RL on an aux reward.
Live evidence from the diagnostic batch
Measured on the 12-config GSP diagnostic batch at 50-150 episodes into training:

- DDQN+GSP predictor MSE 0.0601 vs trivial-mean MSE 0.0521 (worse than predicting a constant)
- DDPG+R-GSP-N prediction std collapsed to 0.00019 (near-constant output)
- Correlation(pred, target) ≈ 0 across all DDPG-trained variants
The attention variant (the only one that uses direct MSE via `learn_attention`) did NOT collapse — the controlled experiment that isolates the training mechanism as the cause.
Changes
- Adds `learn_gsp_mse(networks, recurrent)` in `learning_aids.py`: samples (state, label) pairs from the replay buffer (the label sits in the action field by the RL-CT call-site convention) and minimizes MSE directly against the label with a non-vanishing gradient.
- Rewires `actor.learn_gsp()` to dispatch to it for the DDPG/RDDPG/TD3 schemes; the attention scheme is unchanged.
Test plan
- 2 new cases in `tests/test_actor/test_gsp_direct_mse.py`: `test_learn_gsp_mse_beats_trivial_mean_baseline_on_linear_task` (after 200 learn steps on a linear state→label mapping, the predictor's MSE is lower than predicting the mean; the previous DDPG path fails this, matching the live collapse signature) and `test_learn_gsp_populates_last_gsp_loss_for_ddpg_variant` (the diagnostic field still populates under the new dispatch).
- Full actor + learning_aids suite: 72/72 pass.
Companion PR
Required in RL-CollectiveTransport: Main.py must pass the ground-truth `label` as the 2nd arg to `store_gsp_transition` for all variants (currently passing the previous prediction for non-attention variants). The two PRs are runtime-coupled — merging only this one and not the companion leaves the GSP training path reading the old prediction field as if it were a label.
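The buffer convention the coupling hinges on can be sketched as follows. Every name here (`store_gsp_transition`, the argument order, the deque buffer) is an assumption taken from the PR text, not the real RL-CollectiveTransport API; the point is only that the second argument lands in the slot `learn_gsp_mse` reads as the regression target.

```python
from collections import deque

buffer = deque(maxlen=1000)  # illustrative stand-in for the replay buffer

def store_gsp_transition(state, label, next_state):
    # 'label' must be the ground-truth delta-theta. Passing the previous
    # prediction here (the pre-companion behavior for non-attention
    # variants) would make the MSE path regress onto its own output.
    buffer.append((state, label, next_state))

store_gsp_transition([0.1, 0.2], 0.05, [0.2, 0.3])
_, stored_label, _ = buffer[0]
assert stored_label == 0.05  # the action slot now holds the true label
```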
🤖 Generated with Claude Code