From 0fa571cbe301d29f8250cac1d53188bb21383a1d Mon Sep 17 00:00:00 2001 From: jif-oai Date: Fri, 27 Feb 2026 13:51:18 +0100 Subject: [PATCH 1/4] feat: polluted memories --- codex-rs/core/config.schema.json | 4 + codex-rs/core/src/config/mod.rs | 3 + codex-rs/core/src/config/types.rs | 7 + codex-rs/core/src/mcp_tool_call.rs | 19 ++ codex-rs/core/src/personality_migration.rs | 1 + codex-rs/core/src/rollout/metadata.rs | 83 +++++ codex-rs/core/src/rollout/recorder.rs | 2 + codex-rs/core/src/rollout/tests.rs | 1 + codex-rs/core/src/state_db.rs | 24 ++ codex-rs/core/src/stream_events_utils.rs | 22 ++ .../core/tests/suite/personality_migration.rs | 2 + codex-rs/core/tests/suite/sqlite_state.rs | 156 +++++++++ codex-rs/protocol/src/protocol.rs | 3 + codex-rs/state/src/extract.rs | 1 + codex-rs/state/src/model/thread_metadata.rs | 2 + codex-rs/state/src/runtime/memories.rs | 313 +++++++++++++++++- codex-rs/state/src/runtime/threads.rs | 97 ++++++ 17 files changed, 736 insertions(+), 4 deletions(-) diff --git a/codex-rs/core/config.schema.json b/codex-rs/core/config.schema.json index 59140d9d160..dd4327934e1 100644 --- a/codex-rs/core/config.schema.json +++ b/codex-rs/core/config.schema.json @@ -644,6 +644,10 @@ "format": "int64", "type": "integer" }, + "no_memories_if_mcp_or_web_search": { + "description": "When `true`, web searches and MCP tool calls mark the thread `memory_mode` as `\"polluted\"`.", + "type": "boolean" + }, "phase_1_model": { "description": "Model used for thread summarisation.", "type": "string" diff --git a/codex-rs/core/src/config/mod.rs b/codex-rs/core/src/config/mod.rs index 5eb1ad4c30c..5e3246ae228 100644 --- a/codex-rs/core/src/config/mod.rs +++ b/codex-rs/core/src/config/mod.rs @@ -2490,6 +2490,7 @@ persistence = "none" let memories = r#" [memories] +no_memories_if_mcp_or_web_search = true generate_memories = false use_memories = false max_raw_memories_for_global = 512 @@ -2504,6 +2505,7 @@ phase_2_model = "gpt-5" toml::from_str::(memories).expect("TOML deserialization should succeed"); assert_eq!( Some(MemoriesToml { + no_memories_if_mcp_or_web_search: Some(true), generate_memories: Some(false), use_memories: Some(false), max_raw_memories_for_global: Some(512), @@ -2526,6 +2528,7 @@ phase_2_model = "gpt-5" assert_eq!( config.memories, MemoriesConfig { + no_memories_if_mcp_or_web_search: true, generate_memories: false, use_memories: false, max_raw_memories_for_global: 512, diff --git a/codex-rs/core/src/config/types.rs b/codex-rs/core/src/config/types.rs index 3ee85226e1a..9c4756238be 100644 --- a/codex-rs/core/src/config/types.rs +++ b/codex-rs/core/src/config/types.rs @@ -371,6 +371,8 @@ pub struct FeedbackConfigToml { #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, JsonSchema)] #[schemars(deny_unknown_fields)] pub struct MemoriesToml { + /// When `true`, web searches and MCP tool calls mark the thread `memory_mode` as `"polluted"`. + pub no_memories_if_mcp_or_web_search: Option, /// When `false`, newly created threads are stored with `memory_mode = "disabled"` in the state DB. pub generate_memories: Option, /// When `false`, skip injecting memory usage instructions into developer prompts. @@ -394,6 +396,7 @@ pub struct MemoriesToml { /// Effective memories settings after defaults are applied. #[derive(Debug, Clone, PartialEq, Eq)] pub struct MemoriesConfig { + pub no_memories_if_mcp_or_web_search: bool, pub generate_memories: bool, pub use_memories: bool, pub max_raw_memories_for_global: usize, @@ -408,6 +411,7 @@ pub struct MemoriesConfig { impl Default for MemoriesConfig { fn default() -> Self { Self { + no_memories_if_mcp_or_web_search: false, generate_memories: true, use_memories: true, max_raw_memories_for_global: DEFAULT_MEMORIES_MAX_RAW_MEMORIES_FOR_GLOBAL, @@ -425,6 +429,9 @@ impl From for MemoriesConfig { fn from(toml: MemoriesToml) -> Self { let defaults = Self::default(); Self { + no_memories_if_mcp_or_web_search: toml + .no_memories_if_mcp_or_web_search + .unwrap_or(defaults.no_memories_if_mcp_or_web_search), generate_memories: toml.generate_memories.unwrap_or(defaults.generate_memories), use_memories: toml.use_memories.unwrap_or(defaults.use_memories), max_raw_memories_for_global: toml diff --git a/codex-rs/core/src/mcp_tool_call.rs b/codex-rs/core/src/mcp_tool_call.rs index 63f83bb3f84..f76eddec6a2 100644 --- a/codex-rs/core/src/mcp_tool_call.rs +++ b/codex-rs/core/src/mcp_tool_call.rs @@ -15,6 +15,7 @@ use crate::protocol::EventMsg; use crate::protocol::McpInvocation; use crate::protocol::McpToolCallBeginEvent; use crate::protocol::McpToolCallEndEvent; +use crate::state_db; use codex_protocol::mcp::CallToolResult; use codex_protocol::models::FunctionCallOutputBody; use codex_protocol::models::FunctionCallOutputPayload; @@ -121,6 +122,7 @@ pub(crate) async fn handle_mcp_tool_call( }); notify_mcp_tool_call_event(sess.as_ref(), turn_context, tool_call_begin_event) .await; + maybe_mark_thread_memory_mode_polluted(sess.as_ref(), turn_context).await; let start = Instant::now(); let result = sess @@ -189,6 +191,7 @@ pub(crate) async fn handle_mcp_tool_call( invocation: invocation.clone(), }); notify_mcp_tool_call_event(sess.as_ref(), turn_context, tool_call_begin_event).await; + maybe_mark_thread_memory_mode_polluted(sess.as_ref(), turn_context).await; let start = Instant::now(); // Perform the tool call. @@ -224,6 +227,22 @@ pub(crate) async fn handle_mcp_tool_call( ResponseInputItem::McpToolCallOutput { call_id, result } } +async fn maybe_mark_thread_memory_mode_polluted(sess: &Session, turn_context: &TurnContext) { + if !turn_context + .config + .memories + .no_memories_if_mcp_or_web_search + { + return; + } + state_db::mark_thread_memory_mode_polluted( + sess.services.state_db.as_deref(), + sess.conversation_id, + "mcp_tool_call", + ) + .await; +} + fn sanitize_mcp_tool_result_for_model( supports_image_input: bool, result: Result, diff --git a/codex-rs/core/src/personality_migration.rs b/codex-rs/core/src/personality_migration.rs index c207295e96f..934ff89379b 100644 --- a/codex-rs/core/src/personality_migration.rs +++ b/codex-rs/core/src/personality_migration.rs @@ -177,6 +177,7 @@ mod tests { model_provider: None, base_instructions: None, dynamic_tools: None, + memory_mode: None, }, git: None, }; diff --git a/codex-rs/core/src/rollout/metadata.rs b/codex-rs/core/src/rollout/metadata.rs index 9196931dafe..e42c67d1256 100644 --- a/codex-rs/core/src/rollout/metadata.rs +++ b/codex-rs/core/src/rollout/metadata.rs @@ -129,6 +129,13 @@ pub(crate) async fn extract_metadata_from_rollout( } Ok(ExtractionOutcome { metadata, + memory_mode: items.iter().rev().find_map(|item| match item { + RolloutItem::SessionMeta(meta_line) => meta_line.meta.memory_mode.clone(), + RolloutItem::ResponseItem(_) + | RolloutItem::Compacted(_) + | RolloutItem::TurnContext(_) + | RolloutItem::EventMsg(_) => None, + }), parse_errors, }) } @@ -272,6 +279,7 @@ pub(crate) async fn backfill_sessions( ); } let mut metadata = outcome.metadata; + let memory_mode = outcome.memory_mode.unwrap_or_else(|| "enabled".to_string()); if rollout.archived && metadata.archived_at.is_none() { let fallback_archived_at = metadata.updated_at; metadata.archived_at = file_modified_time_utc(&rollout.path) @@ -282,6 +290,17 @@ pub(crate) async fn backfill_sessions( stats.failed = stats.failed.saturating_add(1); warn!("failed to upsert rollout {}: {err}", rollout.path.display()); } else { + if let Err(err) = runtime + .set_thread_memory_mode(metadata.id, memory_mode.as_str()) + .await + { + stats.failed = stats.failed.saturating_add(1); + warn!( + "failed to restore memory mode for {}: {err}", + rollout.path.display() + ); + continue; + } stats.upserted = stats.upserted.saturating_add(1); if let Ok(meta_line) = rollout::list::read_session_meta_line(&rollout.path).await @@ -519,6 +538,7 @@ mod tests { model_provider: Some("openai".to_string()), base_instructions: None, dynamic_tools: None, + memory_mode: None, }; let session_meta_line = SessionMetaLine { meta: session_meta, @@ -543,9 +563,71 @@ mod tests { expected.updated_at = file_modified_time_utc(&path).await.expect("mtime"); assert_eq!(outcome.metadata, expected); + assert_eq!(outcome.memory_mode, None); assert_eq!(outcome.parse_errors, 0); } + #[tokio::test] + async fn extract_metadata_from_rollout_returns_latest_memory_mode() { + let dir = tempdir().expect("tempdir"); + let uuid = Uuid::new_v4(); + let id = ThreadId::from_string(&uuid.to_string()).expect("thread id"); + let path = dir + .path() + .join(format!("rollout-2026-01-27T12-34-56-{uuid}.jsonl")); + + let session_meta = SessionMeta { + id, + forked_from_id: None, + timestamp: "2026-01-27T12:34:56Z".to_string(), + cwd: dir.path().to_path_buf(), + originator: "cli".to_string(), + cli_version: "0.0.0".to_string(), + source: SessionSource::default(), + agent_nickname: None, + agent_role: None, + model_provider: Some("openai".to_string()), + base_instructions: None, + dynamic_tools: None, + memory_mode: None, + }; + let polluted_meta = SessionMeta { + memory_mode: Some("polluted".to_string()), + ..session_meta.clone() + }; + let lines = vec![ + RolloutLine { + timestamp: "2026-01-27T12:34:56Z".to_string(), + item: RolloutItem::SessionMeta(SessionMetaLine { + meta: session_meta, + git: None, + }), + }, + RolloutLine { + timestamp: "2026-01-27T12:35:00Z".to_string(), + item: RolloutItem::SessionMeta(SessionMetaLine { + meta: polluted_meta, + git: None, + }), + }, + ]; + let mut file = File::create(&path).expect("create rollout"); + for line in lines { + writeln!( + file, + "{}", + serde_json::to_string(&line).expect("serialize rollout line") + ) + .expect("write rollout line"); + } + + let outcome = extract_metadata_from_rollout(&path, "openai", None) + .await + .expect("extract"); + + assert_eq!(outcome.memory_mode.as_deref(), Some("polluted")); + } + #[test] fn builder_from_items_falls_back_to_filename() { let dir = tempdir().expect("tempdir"); @@ -669,6 +751,7 @@ mod tests { model_provider: Some("test-provider".to_string()), base_instructions: None, dynamic_tools: None, + memory_mode: None, }; let session_meta_line = SessionMetaLine { meta: session_meta, diff --git a/codex-rs/core/src/rollout/recorder.rs b/codex-rs/core/src/rollout/recorder.rs index d7404fd2c69..0497277e865 100644 --- a/codex-rs/core/src/rollout/recorder.rs +++ b/codex-rs/core/src/rollout/recorder.rs @@ -412,6 +412,8 @@ impl RolloutRecorder { } else { Some(dynamic_tools) }, + memory_mode: (!config.memories.generate_memories) + .then_some("disabled".to_string()), }; ( diff --git a/codex-rs/core/src/rollout/tests.rs b/codex-rs/core/src/rollout/tests.rs index 0eabaee745d..14266285f43 100644 --- a/codex-rs/core/src/rollout/tests.rs +++ b/codex-rs/core/src/rollout/tests.rs @@ -1109,6 +1109,7 @@ async fn test_updated_at_uses_file_mtime() -> Result<()> { model_provider: Some("test-provider".into()), base_instructions: None, dynamic_tools: None, + memory_mode: None, }, git: None, }), diff --git a/codex-rs/core/src/state_db.rs b/codex-rs/core/src/state_db.rs index 93c7e75f978..fb73961f877 100644 --- a/codex-rs/core/src/state_db.rs +++ b/codex-rs/core/src/state_db.rs @@ -337,6 +337,19 @@ pub async fn persist_dynamic_tools( } } +pub async fn mark_thread_memory_mode_polluted( + context: Option<&codex_state::StateRuntime>, + thread_id: ThreadId, + stage: &str, +) { + let Some(ctx) = context else { + return; + }; + if let Err(err) = ctx.mark_thread_memory_mode_polluted(thread_id).await { + warn!("state db mark_thread_memory_mode_polluted failed during {stage}: {err}"); + } +} + /// Reconcile rollout items into SQLite, falling back to scanning the rollout file. pub async fn reconcile_rollout( context: Option<&codex_state::StateRuntime>, @@ -375,6 +388,7 @@ pub async fn reconcile_rollout( } }; let mut metadata = outcome.metadata; + let memory_mode = outcome.memory_mode.unwrap_or_else(|| "enabled".to_string()); metadata.cwd = normalize_cwd_for_state_db(&metadata.cwd); match archived_only { Some(true) if metadata.archived_at.is_none() => { @@ -392,6 +406,16 @@ pub async fn reconcile_rollout( ); return; } + if let Err(err) = ctx + .set_thread_memory_mode(metadata.id, memory_mode.as_str()) + .await + { + warn!( + "state db reconcile_rollout memory_mode update failed {}: {err}", + rollout_path.display() + ); + return; + } if let Ok(meta_line) = crate::rollout::list::read_session_meta_line(rollout_path).await { persist_dynamic_tools( Some(ctx), diff --git a/codex-rs/core/src/stream_events_utils.rs b/codex-rs/core/src/stream_events_utils.rs index fe72db09f11..abb4aa86cd6 100644 --- a/codex-rs/core/src/stream_events_utils.rs +++ b/codex-rs/core/src/stream_events_utils.rs @@ -58,9 +58,31 @@ pub(crate) async fn record_completed_response_item( ) { sess.record_conversation_items(turn_context, std::slice::from_ref(item)) .await; + maybe_mark_thread_memory_mode_polluted_from_web_search(sess, turn_context, item).await; record_stage1_output_usage_for_completed_item(turn_context, item).await; } +async fn maybe_mark_thread_memory_mode_polluted_from_web_search( + sess: &Session, + turn_context: &TurnContext, + item: &ResponseItem, +) { + if !turn_context + .config + .memories + .no_memories_if_mcp_or_web_search + || !matches!(item, ResponseItem::WebSearchCall { .. }) + { + return; + } + state_db::mark_thread_memory_mode_polluted( + sess.services.state_db.as_deref(), + sess.conversation_id, + "record_completed_response_item", + ) + .await; +} + async fn record_stage1_output_usage_for_completed_item( turn_context: &TurnContext, item: &ResponseItem, diff --git a/codex-rs/core/tests/suite/personality_migration.rs b/codex-rs/core/tests/suite/personality_migration.rs index dfff16ea03a..adbd86cb23d 100644 --- a/codex-rs/core/tests/suite/personality_migration.rs +++ b/codex-rs/core/tests/suite/personality_migration.rs @@ -71,6 +71,7 @@ async fn write_rollout_with_user_event(dir: &Path, thread_id: ThreadId) -> io::R model_provider: None, base_instructions: None, dynamic_tools: None, + memory_mode: None, }, git: None, }; @@ -114,6 +115,7 @@ async fn write_rollout_with_meta_only(dir: &Path, thread_id: ThreadId) -> io::Re model_provider: None, base_instructions: None, dynamic_tools: None, + memory_mode: None, }, git: None, }; diff --git a/codex-rs/core/tests/suite/sqlite_state.rs b/codex-rs/core/tests/suite/sqlite_state.rs index 113e837b300..0b1a76aa48e 100644 --- a/codex-rs/core/tests/suite/sqlite_state.rs +++ b/codex-rs/core/tests/suite/sqlite_state.rs @@ -1,23 +1,36 @@ use anyhow::Result; +use codex_core::config::types::McpServerConfig; +use codex_core::config::types::McpServerTransportConfig; use codex_core::features::Feature; use codex_protocol::ThreadId; use codex_protocol::dynamic_tools::DynamicToolSpec; +use codex_protocol::protocol::AskForApproval; use codex_protocol::protocol::EventMsg; +use codex_protocol::protocol::Op; use codex_protocol::protocol::RolloutItem; use codex_protocol::protocol::RolloutLine; +use codex_protocol::protocol::SandboxPolicy; use codex_protocol::protocol::SessionMeta; use codex_protocol::protocol::SessionMetaLine; use codex_protocol::protocol::SessionSource; use codex_protocol::protocol::UserMessageEvent; +use codex_protocol::user_input::UserInput; use core_test_support::responses; use core_test_support::responses::ev_completed; use core_test_support::responses::ev_function_call; use core_test_support::responses::ev_response_created; +use core_test_support::responses::ev_web_search_call_done; +use core_test_support::responses::mount_sse_once; use core_test_support::responses::mount_sse_sequence; use core_test_support::responses::start_mock_server; +use core_test_support::skip_if_no_network; +use core_test_support::stdio_server_bin; use core_test_support::test_codex::test_codex; +use core_test_support::wait_for_event; +use core_test_support::wait_for_event_match; use pretty_assertions::assert_eq; use serde_json::json; +use std::collections::HashMap; use std::fs; use tokio::time::Duration; use tracing_subscriber::prelude::*; @@ -128,6 +141,7 @@ async fn backfill_scans_existing_rollouts() -> Result<()> { model_provider: None, base_instructions: None, dynamic_tools: Some(dynamic_tools_for_hook), + memory_mode: None, }, git: None, }; @@ -253,6 +267,148 @@ async fn user_messages_persist_in_state_db() -> Result<()> { Ok(()) } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn web_search_marks_thread_memory_mode_polluted_when_configured() -> Result<()> { + let server = start_mock_server().await; + mount_sse_sequence( + &server, + vec![responses::sse(vec![ + ev_response_created("resp-1"), + ev_web_search_call_done("ws-1", "completed", "weather seattle"), + ev_completed("resp-1"), + ])], + ) + .await; + + let mut builder = test_codex().with_config(|config| { + config.features.enable(Feature::Sqlite); + config.memories.no_memories_if_mcp_or_web_search = true; + }); + let test = builder.build(&server).await?; + let db = test.codex.state_db().expect("state db enabled"); + let thread_id = test.session_configured.session_id; + + test.submit_turn("search the web").await?; + + let mut memory_mode = None; + for _ in 0..100 { + memory_mode = db.get_thread_memory_mode(thread_id).await?; + if memory_mode.as_deref() == Some("polluted") { + break; + } + tokio::time::sleep(Duration::from_millis(25)).await; + } + + assert_eq!(memory_mode.as_deref(), Some("polluted")); + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn mcp_call_marks_thread_memory_mode_polluted_when_configured() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let call_id = "call-123"; + let server_name = "rmcp"; + let tool_name = format!("mcp__{server_name}__echo"); + mount_sse_once( + &server, + responses::sse(vec![ + ev_response_created("resp-1"), + ev_function_call(call_id, &tool_name, "{\"message\":\"ping\"}"), + ev_completed("resp-1"), + ]), + ) + .await; + mount_sse_once( + &server, + responses::sse(vec![ + responses::ev_assistant_message("msg-1", "rmcp echo tool completed."), + ev_completed("resp-2"), + ]), + ) + .await; + + let rmcp_test_server_bin = stdio_server_bin()?; + let mut builder = test_codex().with_config(move |config| { + config.features.enable(Feature::Sqlite); + config.memories.no_memories_if_mcp_or_web_search = true; + + let mut servers = config.mcp_servers.get().clone(); + servers.insert( + server_name.to_string(), + McpServerConfig { + transport: McpServerTransportConfig::Stdio { + command: rmcp_test_server_bin, + args: Vec::new(), + env: Some(HashMap::from([( + "MCP_TEST_VALUE".to_string(), + "propagated-env".to_string(), + )])), + env_vars: Vec::new(), + cwd: None, + }, + enabled: true, + required: false, + disabled_reason: None, + startup_timeout_sec: Some(Duration::from_secs(10)), + tool_timeout_sec: None, + enabled_tools: None, + disabled_tools: None, + scopes: None, + oauth_resource: None, + }, + ); + config + .mcp_servers + .set(servers) + .expect("test mcp servers should accept any configuration"); + }); + let test = builder.build(&server).await?; + let db = test.codex.state_db().expect("state db enabled"); + let thread_id = test.session_configured.session_id; + + test.codex + .submit(Op::UserTurn { + items: vec![UserInput::Text { + text: "call the rmcp echo tool".to_string(), + text_elements: Vec::new(), + }], + final_output_json_schema: None, + cwd: test.cwd_path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::new_read_only_policy(), + model: test.session_configured.model.clone(), + effort: None, + summary: None, + collaboration_mode: None, + personality: None, + }) + .await?; + wait_for_event(&test.codex, |event| { + matches!(event, EventMsg::McpToolCallEnd(_)) + }) + .await; + wait_for_event_match(&test.codex, |event| match event { + EventMsg::Error(err) => Some(Err(anyhow::anyhow!(err.message.clone()))), + EventMsg::TurnComplete(_) => Some(Ok(())), + _ => None, + }) + .await?; + + let mut memory_mode = None; + for _ in 0..100 { + memory_mode = db.get_thread_memory_mode(thread_id).await?; + if memory_mode.as_deref() == Some("polluted") { + break; + } + tokio::time::sleep(Duration::from_millis(25)).await; + } + + assert_eq!(memory_mode.as_deref(), Some("polluted")); + Ok(()) +} + #[tokio::test(flavor = "current_thread")] async fn tool_call_logs_include_thread_id() -> Result<()> { let server = start_mock_server().await; diff --git a/codex-rs/protocol/src/protocol.rs b/codex-rs/protocol/src/protocol.rs index b5aaf47ea65..cec797b178c 100644 --- a/codex-rs/protocol/src/protocol.rs +++ b/codex-rs/protocol/src/protocol.rs @@ -2058,6 +2058,8 @@ pub struct SessionMeta { pub base_instructions: Option, #[serde(skip_serializing_if = "Option::is_none")] pub dynamic_tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub memory_mode: Option, } impl Default for SessionMeta { @@ -2075,6 +2077,7 @@ impl Default for SessionMeta { model_provider: None, base_instructions: None, dynamic_tools: None, + memory_mode: None, } } } diff --git a/codex-rs/state/src/extract.rs b/codex-rs/state/src/extract.rs index f9782389790..4c920845594 100644 --- a/codex-rs/state/src/extract.rs +++ b/codex-rs/state/src/extract.rs @@ -242,6 +242,7 @@ mod tests { model_provider: Some("openai".to_string()), base_instructions: None, dynamic_tools: None, + memory_mode: None, }, git: None, }), diff --git a/codex-rs/state/src/model/thread_metadata.rs b/codex-rs/state/src/model/thread_metadata.rs index 9533b8fedeb..96162680126 100644 --- a/codex-rs/state/src/model/thread_metadata.rs +++ b/codex-rs/state/src/model/thread_metadata.rs @@ -45,6 +45,8 @@ pub struct ThreadsPage { pub struct ExtractionOutcome { /// The extracted thread metadata. pub metadata: ThreadMetadata, + /// The explicit thread memory mode from rollout metadata, if present. + pub memory_mode: Option, /// The number of rollout lines that failed to parse. pub parse_errors: usize, } diff --git a/codex-rs/state/src/runtime/memories.rs b/codex-rs/state/src/runtime/memories.rs index 908919d0aa0..4984e5a3e0b 100644 --- a/codex-rs/state/src/runtime/memories.rs +++ b/codex-rs/state/src/runtime/memories.rs @@ -252,7 +252,8 @@ SELECT FROM stage1_outputs AS so LEFT JOIN threads AS t ON t.id = so.thread_id -WHERE length(trim(so.raw_memory)) > 0 OR length(trim(so.rollout_summary)) > 0 +WHERE t.memory_mode = 'enabled' + AND (length(trim(so.raw_memory)) > 0 OR length(trim(so.rollout_summary)) > 0) ORDER BY so.source_updated_at DESC, so.thread_id DESC LIMIT ? "#, @@ -279,11 +280,13 @@ LIMIT ? /// `thread_id DESC` /// - previously selected rows are identified by `selected_for_phase2 = 1` /// - `previous_selected` contains the current persisted rows that belonged - /// to the last successful phase-2 baseline + /// to the last successful phase-2 baseline, even if those threads are no + /// longer memory-eligible /// - `retained_thread_ids` records which current rows still match the exact /// snapshot selected in the last successful phase-2 run /// - removed rows are previously selected rows that are still present in - /// `stage1_outputs` but fall outside the current top-`n` selection + /// `stage1_outputs` but are no longer in the current selection, including + /// threads that are no longer memory-eligible pub async fn get_phase2_input_selection( &self, n: usize, @@ -311,7 +314,8 @@ SELECT FROM stage1_outputs AS so LEFT JOIN threads AS t ON t.id = so.thread_id -WHERE (length(trim(so.raw_memory)) > 0 OR length(trim(so.rollout_summary)) > 0) +WHERE t.memory_mode = 'enabled' + AND (length(trim(so.raw_memory)) > 0 OR length(trim(so.rollout_summary)) > 0) AND ( (so.last_usage IS NOT NULL AND so.last_usage >= ?) OR (so.last_usage IS NULL AND so.source_updated_at >= ?) @@ -396,6 +400,51 @@ ORDER BY so.source_updated_at DESC, so.thread_id DESC }) } + /// Marks a thread as polluted and enqueues phase-2 forgetting when the + /// thread participated in the last successful phase-2 baseline. + pub async fn mark_thread_memory_mode_polluted( + &self, + thread_id: ThreadId, + ) -> anyhow::Result { + let now = Utc::now().timestamp(); + let thread_id = thread_id.to_string(); + let mut tx = self.pool.begin().await?; + let rows_affected = sqlx::query( + r#" +UPDATE threads +SET memory_mode = 'polluted' +WHERE id = ? AND memory_mode != 'polluted' + "#, + ) + .bind(thread_id.as_str()) + .execute(&mut *tx) + .await? + .rows_affected(); + + if rows_affected == 0 { + tx.commit().await?; + return Ok(false); + } + + let selected_for_phase2 = sqlx::query_scalar::<_, i64>( + r#" +SELECT selected_for_phase2 +FROM stage1_outputs +WHERE thread_id = ? + "#, + ) + .bind(thread_id.as_str()) + .fetch_optional(&mut *tx) + .await? + .unwrap_or(0); + if selected_for_phase2 != 0 { + enqueue_global_consolidation_with_executor(&mut *tx, now).await?; + } + + tx.commit().await?; + Ok(true) + } + /// Attempts to claim a stage-1 job for a thread at `source_updated_at`. /// /// Claim semantics: @@ -2433,6 +2482,71 @@ VALUES (?, ?, ?, ?, ?) let _ = tokio::fs::remove_dir_all(codex_home).await; } + #[tokio::test] + async fn list_stage1_outputs_for_global_skips_polluted_threads() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let thread_id_enabled = + ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let thread_id_polluted = + ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + + for (thread_id, workspace) in [ + (thread_id_enabled, "workspace-enabled"), + (thread_id_polluted, "workspace-polluted"), + ] { + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + thread_id, + codex_home.join(workspace), + )) + .await + .expect("upsert thread"); + + let claim = runtime + .try_claim_stage1_job(thread_id, owner, 100, 3600, 64) + .await + .expect("claim stage1"); + let ownership_token = match claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage1 claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded( + thread_id, + ownership_token.as_str(), + 100, + "raw memory", + "summary", + None, + ) + .await + .expect("mark stage1 succeeded"), + "stage1 success should persist output" + ); + } + + runtime + .set_thread_memory_mode(thread_id_polluted, "polluted") + .await + .expect("mark thread polluted"); + + let outputs = runtime + .list_stage1_outputs_for_global(10) + .await + .expect("list stage1 outputs for global"); + assert_eq!(outputs.len(), 1); + assert_eq!(outputs[0].thread_id, thread_id_enabled); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + #[tokio::test] async fn get_phase2_input_selection_reports_added_retained_and_removed_rows() { let codex_home = unique_temp_dir(); @@ -2545,6 +2659,197 @@ VALUES (?, ?, ?, ?, ?) let _ = tokio::fs::remove_dir_all(codex_home).await; } + #[tokio::test] + async fn get_phase2_input_selection_marks_polluted_previous_selection_as_removed() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let thread_id_enabled = + ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let thread_id_polluted = + ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + + for (thread_id, updated_at) in [(thread_id_enabled, 100), (thread_id_polluted, 101)] { + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + thread_id, + codex_home.join(thread_id.to_string()), + )) + .await + .expect("upsert thread"); + + let claim = runtime + .try_claim_stage1_job(thread_id, owner, updated_at, 3600, 64) + .await + .expect("claim stage1"); + let ownership_token = match claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage1 claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded( + thread_id, + ownership_token.as_str(), + updated_at, + &format!("raw-{updated_at}"), + &format!("summary-{updated_at}"), + None, + ) + .await + .expect("mark stage1 succeeded"), + "stage1 success should persist output" + ); + } + + let claim = runtime + .try_claim_global_phase2_job(owner, 3600) + .await + .expect("claim phase2"); + let (ownership_token, input_watermark) = match claim { + Phase2JobClaimOutcome::Claimed { + ownership_token, + input_watermark, + } => (ownership_token, input_watermark), + other => panic!("unexpected phase2 claim outcome: {other:?}"), + }; + let selected_outputs = runtime + .list_stage1_outputs_for_global(10) + .await + .expect("list stage1 outputs for global"); + assert!( + runtime + .mark_global_phase2_job_succeeded( + ownership_token.as_str(), + input_watermark, + &selected_outputs, + ) + .await + .expect("mark phase2 success"), + "phase2 success should persist selected rows" + ); + + runtime + .set_thread_memory_mode(thread_id_polluted, "polluted") + .await + .expect("mark thread polluted"); + + let selection = runtime + .get_phase2_input_selection(2, 36_500) + .await + .expect("load phase2 input selection"); + + assert_eq!(selection.selected.len(), 1); + assert_eq!(selection.selected[0].thread_id, thread_id_enabled); + assert_eq!(selection.previous_selected.len(), 2); + assert!( + selection + .previous_selected + .iter() + .any(|item| item.thread_id == thread_id_enabled) + ); + assert!( + selection + .previous_selected + .iter() + .any(|item| item.thread_id == thread_id_polluted) + ); + assert_eq!(selection.retained_thread_ids, vec![thread_id_enabled]); + assert_eq!(selection.removed.len(), 1); + assert_eq!(selection.removed[0].thread_id, thread_id_polluted); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn mark_thread_memory_mode_polluted_enqueues_phase2_for_selected_threads() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + thread_id, + codex_home.join("workspace"), + )) + .await + .expect("upsert thread"); + + let claim = runtime + .try_claim_stage1_job(thread_id, owner, 100, 3600, 64) + .await + .expect("claim stage1"); + let ownership_token = match claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage1 claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded( + thread_id, + ownership_token.as_str(), + 100, + "raw", + "summary", + None, + ) + .await + .expect("mark stage1 succeeded"), + "stage1 success should persist output" + ); + + let phase2_claim = runtime + .try_claim_global_phase2_job(owner, 3600) + .await + .expect("claim phase2"); + let (phase2_token, input_watermark) = match phase2_claim { + Phase2JobClaimOutcome::Claimed { + ownership_token, + input_watermark, + } => (ownership_token, input_watermark), + other => panic!("unexpected phase2 claim outcome: {other:?}"), + }; + let selected_outputs = runtime + .list_stage1_outputs_for_global(10) + .await + .expect("list stage1 outputs"); + assert!( + runtime + .mark_global_phase2_job_succeeded( + phase2_token.as_str(), + input_watermark, + &selected_outputs, + ) + .await + .expect("mark phase2 success"), + "phase2 success should persist selected rows" + ); + + assert!( + runtime + .mark_thread_memory_mode_polluted(thread_id) + .await + .expect("mark thread polluted"), + "thread should transition to polluted" + ); + + let next_claim = runtime + .try_claim_global_phase2_job(owner, 3600) + .await + .expect("claim phase2 after pollution"); + assert!(matches!(next_claim, Phase2JobClaimOutcome::Claimed { .. })); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + #[tokio::test] async fn get_phase2_input_selection_treats_regenerated_selected_rows_as_added() { let codex_home = unique_temp_dir(); diff --git a/codex-rs/state/src/runtime/threads.rs b/codex-rs/state/src/runtime/threads.rs index e02cfb3e4a2..f9e939a4a15 100644 --- a/codex-rs/state/src/runtime/threads.rs +++ b/codex-rs/state/src/runtime/threads.rs @@ -35,6 +35,14 @@ WHERE id = ? .transpose() } + pub async fn get_thread_memory_mode(&self, id: ThreadId) -> anyhow::Result> { + let row = sqlx::query("SELECT memory_mode FROM threads WHERE id = ?") + .bind(id.to_string()) + .fetch_optional(self.pool.as_ref()) + .await?; + Ok(row.and_then(|row| row.try_get("memory_mode").ok())) + } + /// Get dynamic tools for a thread, if present. pub async fn get_dynamic_tools( &self, @@ -199,6 +207,19 @@ FROM threads .await } + pub async fn set_thread_memory_mode( + &self, + thread_id: ThreadId, + memory_mode: &str, + ) -> anyhow::Result { + let result = sqlx::query("UPDATE threads SET memory_mode = ? WHERE id = ?") + .bind(memory_mode) + .bind(thread_id.to_string()) + .execute(self.pool.as_ref()) + .await?; + Ok(result.rows_affected() > 0) + } + async fn upsert_thread_with_creation_memory_mode( &self, metadata: &crate::ThreadMetadata, @@ -357,6 +378,16 @@ ON CONFLICT(thread_id, position) DO NOTHING } return Err(err); } + if let Some(memory_mode) = extract_memory_mode(items) + && let Err(err) = self + .set_thread_memory_mode(builder.id, memory_mode.as_str()) + .await + { + if let Some(otel) = otel { + otel.counter(DB_ERROR_METRIC, 1, &[("stage", "set_thread_memory_mode")]); + } + return Err(err); + } let dynamic_tools = extract_dynamic_tools(items); if let Some(dynamic_tools) = dynamic_tools && let Err(err) = self @@ -438,6 +469,16 @@ pub(super) fn extract_dynamic_tools(items: &[RolloutItem]) -> Option Option { + items.iter().rev().find_map(|item| match item { + RolloutItem::SessionMeta(meta_line) => meta_line.meta.memory_mode.clone(), + RolloutItem::ResponseItem(_) + | RolloutItem::Compacted(_) + | RolloutItem::TurnContext(_) + | RolloutItem::EventMsg(_) => None, + }) +} + pub(super) fn push_thread_filters<'a>( builder: &mut QueryBuilder<'a, Sqlite>, archived_only: bool, @@ -518,7 +559,11 @@ mod tests { use super::*; use crate::runtime::test_support::test_thread_metadata; use crate::runtime::test_support::unique_temp_dir; + use codex_protocol::protocol::SessionMeta; + use codex_protocol::protocol::SessionMetaLine; + use codex_protocol::protocol::SessionSource; use pretty_assertions::assert_eq; + use std::path::PathBuf; #[tokio::test] async fn upsert_thread_keeps_creation_memory_mode_for_existing_rows() { @@ -557,4 +602,56 @@ mod tests { .expect("memory mode should remain readable"); assert_eq!(memory_mode, "disabled"); } + + #[tokio::test] + async fn apply_rollout_items_restores_memory_mode_from_session_meta() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("state db should initialize"); + let thread_id = + ThreadId::from_string("00000000-0000-0000-0000-000000000456").expect("valid thread id"); + let metadata = test_thread_metadata(&codex_home, thread_id, codex_home.clone()); + + runtime + .upsert_thread(&metadata) + .await + .expect("initial upsert should succeed"); + + let builder = ThreadMetadataBuilder::new( + thread_id, + metadata.rollout_path.clone(), + metadata.created_at, + SessionSource::Cli, + ); + let items = vec![RolloutItem::SessionMeta(SessionMetaLine { + meta: SessionMeta { + id: thread_id, + forked_from_id: None, + timestamp: metadata.created_at.to_rfc3339(), + cwd: PathBuf::new(), + originator: String::new(), + cli_version: String::new(), + source: SessionSource::Cli, + agent_nickname: None, + agent_role: None, + model_provider: None, + base_instructions: None, + dynamic_tools: None, + memory_mode: Some("polluted".to_string()), + }, + git: None, + })]; + + runtime + .apply_rollout_items(&builder, &items, None, None) + .await + .expect("apply_rollout_items should succeed"); + + let memory_mode = runtime + .get_thread_memory_mode(thread_id) + .await + .expect("memory mode should load"); + assert_eq!(memory_mode.as_deref(), Some("polluted")); + } } From 5f857c2540985006ea1d3f22b288cd4bcbff7c8b Mon Sep 17 00:00:00 2001 From: jif-oai Date: Fri, 27 Feb 2026 14:15:51 +0100 Subject: [PATCH 2/4] nit fix 2 --- codex-rs/app-server/tests/common/rollout.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/codex-rs/app-server/tests/common/rollout.rs b/codex-rs/app-server/tests/common/rollout.rs index 8122e7cd012..8146f7ae93b 100644 --- a/codex-rs/app-server/tests/common/rollout.rs +++ b/codex-rs/app-server/tests/common/rollout.rs @@ -84,6 +84,7 @@ pub fn create_fake_rollout_with_source( model_provider: model_provider.map(str::to_string), base_instructions: None, dynamic_tools: None, + memory_mode: None, }; let payload = serde_json::to_value(SessionMetaLine { meta, @@ -165,6 +166,7 @@ pub fn create_fake_rollout_with_text_elements( model_provider: model_provider.map(str::to_string), base_instructions: None, dynamic_tools: None, + memory_mode: None, }; let payload = serde_json::to_value(SessionMetaLine { meta, From 44738dc36bd92f7861b1f082f0bd0a955fe49b83 Mon Sep 17 00:00:00 2001 From: jif-oai Date: Fri, 27 Feb 2026 14:40:39 +0100 Subject: [PATCH 3/4] add integration tests --- codex-rs/core/tests/suite/memories.rs | 219 ++++++++++++++++++++++---- 1 file changed, 190 insertions(+), 29 deletions(-) diff --git a/codex-rs/core/tests/suite/memories.rs b/codex-rs/core/tests/suite/memories.rs index 8f96178e867..fa14ec6ab8b 100644 --- a/codex-rs/core/tests/suite/memories.rs +++ b/codex-rs/core/tests/suite/memories.rs @@ -11,7 +11,9 @@ use core_test_support::responses::ResponsesRequest; use core_test_support::responses::ev_assistant_message; use core_test_support::responses::ev_completed; use core_test_support::responses::ev_response_created; +use core_test_support::responses::ev_web_search_call_done; use core_test_support::responses::mount_sse_once; +use core_test_support::responses::mount_sse_sequence; use core_test_support::responses::sse; use core_test_support::responses::start_mock_server; use core_test_support::test_codex::TestCodex; @@ -157,6 +159,145 @@ async fn memories_startup_phase2_tracks_added_and_removed_inputs_across_runs() - Ok(()) } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn web_search_pollution_removes_selected_thread_from_phase2_inputs() -> Result<()> { + let server = start_mock_server().await; + let home = Arc::new(TempDir::new()?); + let db = init_state_db(&home).await?; + + let mut initial_builder = test_codex().with_home(home.clone()).with_config(|config| { + config.features.enable(Feature::Sqlite); + config.features.enable(Feature::MemoryTool); + config.memories.max_raw_memories_for_global = 1; + config.memories.no_memories_if_mcp_or_web_search = true; + }); + let initial = initial_builder.build(&server).await?; + let rollout_path = initial + .session_configured + .rollout_path + .clone() + .expect("rollout path"); + let thread_id = initial.session_configured.session_id; + let updated_at = { + let deadline = Instant::now() + Duration::from_secs(10); + loop { + if let Some(metadata) = db.get_thread(thread_id).await? { + break metadata.updated_at; + } + assert!( + Instant::now() < deadline, + "timed out waiting for thread metadata for {thread_id}" + ); + tokio::time::sleep(Duration::from_millis(50)).await; + } + }; + + seed_stage1_output_for_existing_thread( + db.as_ref(), + thread_id, + updated_at.timestamp(), + "raw memory seeded for web search pollution", + "rollout summary seeded for web search pollution", + Some("pollution-rollout"), + ) + .await?; + + shutdown_test_codex(&initial).await?; + + let responses = mount_sse_sequence( + &server, + vec![ + sse(vec![ + ev_response_created("resp-phase2-1"), + ev_assistant_message("msg-phase2-1", "phase2 complete"), + ev_completed("resp-phase2-1"), + ]), + sse(vec![ + ev_response_created("resp-web-1"), + ev_web_search_call_done("ws-1", "completed", "weather seattle"), + ev_completed("resp-web-1"), + ]), + sse(vec![ + ev_response_created("resp-phase2-2"), + ev_assistant_message("msg-phase2-2", "phase2 after pollution complete"), + ev_completed("resp-phase2-2"), + ]), + ], + ) + .await; + + let mut resumed_builder = test_codex().with_home(home.clone()).with_config(|config| { + config.features.enable(Feature::Sqlite); + config.features.enable(Feature::MemoryTool); + config.memories.max_raw_memories_for_global = 1; + config.memories.no_memories_if_mcp_or_web_search = true; + }); + let resumed = resumed_builder + .resume(&server, home.clone(), rollout_path) + .await?; + + let first_phase2_request = wait_for_request(&responses, 1).await.remove(0); + let first_phase2_prompt = phase2_prompt_text(&first_phase2_request); + assert!( + first_phase2_prompt.contains("- selected inputs this run: 1"), + "expected seeded thread to be selected before pollution: {first_phase2_prompt}" + ); + assert!( + first_phase2_prompt.contains("- newly added since the last successful Phase 2 run: 1"), + "expected seeded thread to be added before pollution: {first_phase2_prompt}" + ); + assert!( + first_phase2_prompt.contains(&format!("- [added] thread_id={thread_id},")), + "expected selected thread in first phase2 prompt: {first_phase2_prompt}" + ); + + wait_for_phase2_success(db.as_ref(), thread_id).await?; + + resumed + .submit_turn("search the web for weather seattle") + .await?; + assert_eq!( + { + let deadline = Instant::now() + Duration::from_secs(10); + loop { + let memory_mode = db.get_thread_memory_mode(thread_id).await?; + if memory_mode.as_deref() == Some("polluted") { + break memory_mode; + } + assert!( + Instant::now() < deadline, + "timed out waiting for polluted memory mode for {thread_id}" + ); + tokio::time::sleep(Duration::from_millis(50)).await; + } + } + .as_deref(), + Some("polluted") + ); + + let requests = wait_for_request(&responses, 3).await; + let second_phase2_prompt = phase2_prompt_text(&requests[2]); + assert!( + second_phase2_prompt.contains("- selected inputs this run: 0"), + "expected polluted thread to be excluded from selected inputs: {second_phase2_prompt}" + ); + assert!( + second_phase2_prompt.contains("- newly added since the last successful Phase 2 run: 0"), + "expected no added inputs after pollution: {second_phase2_prompt}" + ); + assert!( + second_phase2_prompt.contains("- removed from the last successful Phase 2 run: 1"), + "expected polluted thread to show up as removed: {second_phase2_prompt}" + ); + assert!( + second_phase2_prompt.contains(&format!("- thread_id={thread_id},")), + "expected polluted thread in removed section: {second_phase2_prompt}" + ); + + shutdown_test_codex(&resumed).await?; + Ok(()) +} + async fn build_test_codex(server: &wiremock::MockServer, home: Arc) -> Result { let mut builder = test_codex().with_home(home).with_config(|config| { config.features.enable(Feature::Sqlite); @@ -195,46 +336,33 @@ async fn seed_stage1_output( let metadata = metadata_builder.build("test-provider"); db.upsert_thread(&metadata).await?; - let claim = db - .try_claim_stage1_job( - thread_id, - ThreadId::new(), - updated_at.timestamp(), - 3_600, - 64, - ) - .await?; - let ownership_token = match claim { - codex_state::Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, - other => panic!("unexpected stage-1 claim outcome: {other:?}"), - }; - - assert!( - db.mark_stage1_job_succeeded( - thread_id, - &ownership_token, - updated_at.timestamp(), - raw_memory, - rollout_summary, - Some(rollout_slug), - ) - .await?, - "stage-1 success should enqueue global consolidation" - ); + seed_stage1_output_for_existing_thread( + db, + thread_id, + updated_at.timestamp(), + raw_memory, + rollout_summary, + Some(rollout_slug), + ) + .await?; Ok(thread_id) } async fn wait_for_single_request(mock: &ResponseMock) -> ResponsesRequest { + wait_for_request(mock, 1).await.remove(0) +} + +async fn wait_for_request(mock: &ResponseMock, expected_count: usize) -> Vec { let deadline = Instant::now() + Duration::from_secs(10); loop { let requests = mock.requests(); - if let Some(request) = requests.into_iter().next() { - return request; + if requests.len() >= expected_count { + return requests; } assert!( Instant::now() < deadline, - "timed out waiting for phase2 request" + "timed out waiting for {expected_count} phase2 requests" ); tokio::time::sleep(Duration::from_millis(50)).await; } @@ -272,6 +400,39 @@ async fn wait_for_phase2_success( } } +async fn seed_stage1_output_for_existing_thread( + db: &codex_state::StateRuntime, + thread_id: ThreadId, + updated_at: i64, + raw_memory: &str, + rollout_summary: &str, + rollout_slug: Option<&str>, +) -> Result<()> { + let owner = ThreadId::new(); + let claim = db + .try_claim_stage1_job(thread_id, owner, updated_at, 3_600, 64) + .await?; + let ownership_token = match claim { + codex_state::Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage-1 claim outcome: {other:?}"), + }; + + assert!( + db.mark_stage1_job_succeeded( + thread_id, + &ownership_token, + updated_at, + raw_memory, + rollout_summary, + rollout_slug, + ) + .await?, + "stage-1 success should enqueue global consolidation" + ); + + Ok(()) +} + async fn read_rollout_summary_bodies(memory_root: &Path) -> Result> { let mut dir = tokio::fs::read_dir(memory_root.join("rollout_summaries")).await?; let mut summaries = Vec::new(); From 5d4165caa36e2443125968ecbf7017311b87e12f Mon Sep 17 00:00:00 2001 From: jif-oai Date: Fri, 27 Feb 2026 14:45:13 +0100 Subject: [PATCH 4/4] fix test --- codex-rs/core/tests/suite/memories.rs | 61 ++++++++++++++++----------- 1 file changed, 36 insertions(+), 25 deletions(-) diff --git a/codex-rs/core/tests/suite/memories.rs b/codex-rs/core/tests/suite/memories.rs index fa14ec6ab8b..fce663cf93c 100644 --- a/codex-rs/core/tests/suite/memories.rs +++ b/codex-rs/core/tests/suite/memories.rs @@ -160,7 +160,7 @@ async fn memories_startup_phase2_tracks_added_and_removed_inputs_across_runs() - } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn web_search_pollution_removes_selected_thread_from_phase2_inputs() -> Result<()> { +async fn web_search_pollution_moves_selected_thread_into_removed_phase2_inputs() -> Result<()> { let server = start_mock_server().await; let home = Arc::new(TempDir::new()?); let db = init_state_db(&home).await?; @@ -172,6 +172,16 @@ async fn web_search_pollution_removes_selected_thread_from_phase2_inputs() -> Re config.memories.no_memories_if_mcp_or_web_search = true; }); let initial = initial_builder.build(&server).await?; + mount_sse_once( + &server, + sse(vec![ + ev_response_created("resp-initial-1"), + ev_assistant_message("msg-initial-1", "initial turn complete"), + ev_completed("resp-initial-1"), + ]), + ) + .await; + initial.submit_turn("hello before memories").await?; let rollout_path = initial .session_configured .rollout_path @@ -217,11 +227,6 @@ async fn web_search_pollution_removes_selected_thread_from_phase2_inputs() -> Re ev_web_search_call_done("ws-1", "completed", "weather seattle"), ev_completed("resp-web-1"), ]), - sse(vec![ - ev_response_created("resp-phase2-2"), - ev_assistant_message("msg-phase2-2", "phase2 after pollution complete"), - ev_completed("resp-phase2-2"), - ]), ], ) .await; @@ -233,7 +238,7 @@ async fn web_search_pollution_removes_selected_thread_from_phase2_inputs() -> Re config.memories.no_memories_if_mcp_or_web_search = true; }); let resumed = resumed_builder - .resume(&server, home.clone(), rollout_path) + .resume(&server, home.clone(), rollout_path.clone()) .await?; let first_phase2_request = wait_for_request(&responses, 1).await.remove(0); @@ -275,24 +280,30 @@ async fn web_search_pollution_removes_selected_thread_from_phase2_inputs() -> Re Some("polluted") ); - let requests = wait_for_request(&responses, 3).await; - let second_phase2_prompt = phase2_prompt_text(&requests[2]); - assert!( - second_phase2_prompt.contains("- selected inputs this run: 0"), - "expected polluted thread to be excluded from selected inputs: {second_phase2_prompt}" - ); - assert!( - second_phase2_prompt.contains("- newly added since the last successful Phase 2 run: 0"), - "expected no added inputs after pollution: {second_phase2_prompt}" - ); - assert!( - second_phase2_prompt.contains("- removed from the last successful Phase 2 run: 1"), - "expected polluted thread to show up as removed: {second_phase2_prompt}" - ); - assert!( - second_phase2_prompt.contains(&format!("- thread_id={thread_id},")), - "expected polluted thread in removed section: {second_phase2_prompt}" - ); + let selection = { + let deadline = Instant::now() + Duration::from_secs(10); + loop { + let selection = db.get_phase2_input_selection(1, 30).await?; + if selection.selected.is_empty() + && selection.retained_thread_ids.is_empty() + && selection.removed.len() == 1 + && selection.removed[0].thread_id == thread_id + { + break selection; + } + assert!( + Instant::now() < deadline, + "timed out waiting for polluted thread to move into removed phase2 inputs: \ + {selection:?}" + ); + tokio::time::sleep(Duration::from_millis(50)).await; + } + }; + assert_eq!(responses.requests().len(), 2); + assert!(selection.selected.is_empty()); + assert_eq!(selection.retained_thread_ids, Vec::::new()); + assert_eq!(selection.removed.len(), 1); + assert_eq!(selection.removed[0].thread_id, thread_id); shutdown_test_codex(&resumed).await?; Ok(())