diff --git a/Cargo.lock b/Cargo.lock index fdb72f2..a83c23c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -553,6 +553,7 @@ name = "embers-cli" version = "0.1.0" dependencies = [ "assert_cmd", + "base64", "clap", "embers-client", "embers-core", @@ -613,6 +614,7 @@ name = "embers-server" version = "0.1.0" dependencies = [ "alacritty_terminal", + "base64", "embers-core", "embers-protocol", "portable-pty", @@ -632,6 +634,7 @@ dependencies = [ "embers-core", "embers-protocol", "embers-server", + "libc", "portable-pty", "tempfile", "tokio", diff --git a/crates/embers-cli/Cargo.toml b/crates/embers-cli/Cargo.toml index 6ea3cbc..815c323 100644 --- a/crates/embers-cli/Cargo.toml +++ b/crates/embers-cli/Cargo.toml @@ -18,6 +18,7 @@ name = "embers-cli" path = "src/bin/embers-cli.rs" [dependencies] +base64.workspace = true clap.workspace = true embers-client = { path = "../embers-client" } embers-core = { path = "../embers-core" } diff --git a/crates/embers-cli/src/lib.rs b/crates/embers-cli/src/lib.rs index 0392faf..a8cea15 100644 --- a/crates/embers-cli/src/lib.rs +++ b/crates/embers-cli/src/lib.rs @@ -1,13 +1,19 @@ mod interactive; +use std::ffi::OsString; use std::fs::{self, OpenOptions}; use std::io::Write; use std::num::NonZeroU64; #[cfg(unix)] +use std::os::unix::ffi::OsStringExt; +#[cfg(unix)] use std::os::unix::fs::{MetadataExt, OpenOptionsExt, PermissionsExt}; +#[cfg(windows)] +use std::os::windows::ffi::OsStringExt; use std::path::{Path, PathBuf}; use std::process::{Command as ProcessCommand, Stdio}; +use base64::Engine as _; use clap::{Parser, Subcommand}; use embers_core::{ BufferId, FloatGeometry, FloatingId, MuxError, NodeId, Result, SessionId, SplitDirection, @@ -71,6 +77,21 @@ pub enum Command { }, #[command(name = "__serve", hide = true)] Serve, + #[command(name = "__runtime-keeper", hide = true)] + RuntimeKeeper { + #[arg(long = "keeper-socket")] + keeper_socket: PathBuf, + #[arg(long)] + cols: u16, + #[arg(long)] + rows: u16, + #[arg(long)] + cwd: Option, + #[arg(long = "env", value_parser = parse_env_arg)] + env: Vec<(String, OsString)>, + #[arg(last = true)] + command: Vec, + }, Ping { #[arg(default_value = "phase0")] payload: String, @@ -226,9 +247,9 @@ async fn execute(socket: &Path, command: Command) -> Result { let mut connection = CliConnection::connect(socket).await?; match command { - Command::Attach { .. } | Command::Serve => Err(MuxError::internal( - "interactive commands must be dispatched through run()", - )), + Command::Attach { .. } | Command::Serve | Command::RuntimeKeeper { .. } => Err( + MuxError::internal("interactive commands must be dispatched through run()"), + ), Command::Ping { payload } => { let response = connection .request(ClientMessage::Ping(PingRequest { @@ -621,31 +642,56 @@ async fn execute(socket: &Path, command: Command) -> Result { } pub async fn run(cli: Cli) -> Result<()> { - let socket = resolve_socket_path(cli.socket.as_deref()); - validate_runtime_socket_parent(&socket)?; + let Cli { + socket, + config, + command, + .. + } = cli; - match cli.command { - None => { - ensure_server_process(&socket).await?; - interactive::run(socket, None, cli.config).await - } - Some(Command::Attach { target }) => { - if !server_is_available(&socket).await { - return Err(MuxError::not_found(format!( - "no embers server is listening on {}", - socket.display() - ))); - } - interactive::run(socket, target, cli.config).await - } - Some(Command::Serve) => run_server(socket).await, - Some(command) => { - ensure_server_process(&socket).await?; - let output = execute(&socket, command).await?; - if !output.is_empty() { - println!("{output}"); + match command { + Some(Command::RuntimeKeeper { + keeper_socket, + cols, + rows, + cwd, + env, + command, + }) => embers_server::run_runtime_keeper(embers_server::RuntimeKeeperCli { + socket_path: keeper_socket, + command, + cwd, + env: env.into_iter().collect(), + size: embers_core::PtySize::new(cols, rows), + }), + command => { + let socket = resolve_socket_path(socket.as_deref()); + validate_runtime_socket_parent(&socket)?; + + match command { + None => { + ensure_server_process(&socket).await?; + interactive::run(socket, None, config).await + } + Some(Command::Attach { target }) => { + if !server_is_available(&socket).await { + return Err(MuxError::not_found(format!( + "no embers server is listening on {}", + socket.display() + ))); + } + interactive::run(socket, target, config).await + } + Some(Command::Serve) => run_server(socket).await, + Some(command) => { + ensure_server_process(&socket).await?; + let output = execute(&socket, command).await?; + if !output.is_empty() { + println!("{output}"); + } + Ok(()) + } } - Ok(()) } } } @@ -676,6 +722,46 @@ fn default_runtime_dir() -> PathBuf { PathBuf::from("/tmp").join(format!("embers-{}", effective_uid())) } +fn parse_env_arg(value: &str) -> std::result::Result<(String, OsString), String> { + let Some((key, env_value)) = value.split_once('=') else { + return Err("expected KEY=VALUE".to_owned()); + }; + if key.is_empty() { + return Err("environment key must not be empty".to_owned()); + } + Ok((key.to_owned(), decode_runtime_keeper_env_value(env_value)?)) +} + +fn decode_runtime_keeper_env_value(value: &str) -> std::result::Result { + let Some(encoded) = value.strip_prefix("base64:") else { + return Ok(OsString::from(value)); + }; + let decoded = base64::engine::general_purpose::STANDARD + .decode(encoded) + .map_err(|error| format!("invalid base64 environment value: {error}"))?; + #[cfg(unix)] + { + Ok(OsString::from_vec(decoded)) + } + #[cfg(windows)] + { + if decoded.len() % 2 != 0 { + return Err("invalid UTF-16LE environment value: odd-length byte sequence".to_owned()); + } + let wide = decoded + .chunks_exact(2) + .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]])) + .collect::>(); + Ok(OsString::from_wide(&wide)) + } + #[cfg(all(not(unix), not(windows)))] + { + String::from_utf8(decoded) + .map(OsString::from) + .map_err(|error| format!("invalid UTF-8 environment value: {error}")) + } +} + #[cfg(unix)] fn effective_uid() -> u32 { unsafe { libc::geteuid() } @@ -1585,9 +1671,20 @@ fn default_title(command: &[String], fallback: &str) -> String { #[cfg(test)] mod tests { + #[cfg(windows)] + use base64::Engine as _; use clap::Parser; use embers_core::NodeId; use embers_protocol::{TabRecord, TabsRecord}; + #[cfg(windows)] + use std::ffi::OsString; + #[cfg(unix)] + use std::ffi::OsString; + #[cfg(unix)] + use std::os::unix::ffi::OsStringExt; + #[cfg(windows)] + use std::os::windows::ffi::OsStringExt; + use std::path::Path; use super::{Cli, resolve_window_index, split_scoped_required, split_scoped_target}; @@ -1614,6 +1711,59 @@ mod tests { } } + #[test] + fn runtime_keeper_uses_distinct_keeper_socket_flag() { + let cli = Cli::try_parse_from([ + "embers", + "__runtime-keeper", + "--socket", + "/tmp/global.sock", + "--keeper-socket", + "/tmp/keeper.sock", + "--cols", + "80", + "--rows", + "24", + "--", + "/bin/sh", + ]) + .expect("cli parses"); + + assert_eq!(cli.socket.as_deref(), Some(Path::new("/tmp/global.sock"))); + match cli.command { + Some(super::Command::RuntimeKeeper { + keeper_socket, + cols, + rows, + command, + .. + }) => { + assert_eq!(keeper_socket, Path::new("/tmp/keeper.sock")); + assert_eq!((cols, rows), (80, 24)); + assert_eq!(command, vec!["/bin/sh"]); + } + other => panic!("expected runtime keeper command, got {other:?}"), + } + } + + #[cfg(unix)] + #[test] + fn runtime_keeper_env_values_decode_base64_losslessly() { + let (key, value) = super::parse_env_arg("KEY=base64:AP8=").expect("env parses"); + assert_eq!(key, "KEY"); + assert_eq!(value, OsString::from_vec(vec![0, 255])); + } + + #[cfg(windows)] + #[test] + fn runtime_keeper_env_values_decode_utf16le_losslessly() { + let encoded = base64::engine::general_purpose::STANDARD.encode([0x00, 0xD8, 0x61, 0x00]); + let (key, value) = + super::parse_env_arg(&format!("KEY=base64:{encoded}")).expect("env parses"); + assert_eq!(key, "KEY"); + assert_eq!(value, OsString::from_wide(&[0xD800, 0x0061])); + } + #[test] fn scoped_targets_split_session_prefix() { assert_eq!( diff --git a/crates/embers-cli/tests/interactive.rs b/crates/embers-cli/tests/interactive.rs index 3d93919..0577d3b 100644 --- a/crates/embers-cli/tests/interactive.rs +++ b/crates/embers-cli/tests/interactive.rs @@ -3,7 +3,7 @@ use std::path::Path; use std::time::Duration; use embers_core::PtySize; -use embers_test_support::{PtyHarness, TestServer, cargo_bin, cargo_bin_path}; +use embers_test_support::{PtyHarness, TestServer, acquire_test_lock, cargo_bin, cargo_bin_path}; use tempfile::tempdir; use crate::support::{run_cli, stdout}; @@ -165,6 +165,7 @@ fn first_client_id_finds_attached_row() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn embers_without_subcommand_starts_server_and_client() { + let _guard = acquire_test_lock().await.expect("acquire test lock"); let tempdir = tempdir().expect("tempdir"); let socket_path = tempdir.path().join("embers.sock"); let socket_arg = socket_path.to_string_lossy().into_owned(); @@ -204,6 +205,7 @@ async fn embers_without_subcommand_starts_server_and_client() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn attach_subcommand_connects_to_running_server() { + let _guard = acquire_test_lock().await.expect("acquire test lock"); let server = TestServer::start().await.expect("start server"); let binary = cargo_bin_path("embers"); let binary_dir = binary.parent().expect("binary dir"); @@ -248,6 +250,7 @@ async fn attach_subcommand_connects_to_running_server() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn client_commands_can_switch_and_detach_a_live_attached_client() { + let _guard = acquire_test_lock().await.expect("acquire test lock"); let server = TestServer::start().await.expect("start server"); run_cli(&server, ["new-session", "main"]); @@ -302,6 +305,7 @@ async fn client_commands_can_switch_and_detach_a_live_attached_client() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn page_up_enters_local_scrollback_and_shows_indicator() { + let _guard = acquire_test_lock().await.expect("acquire test lock"); let tempdir = tempdir().expect("tempdir"); let socket_path = tempdir.path().join("embers.sock"); let socket_arg = socket_path.to_string_lossy().into_owned(); @@ -321,6 +325,7 @@ async fn page_up_enters_local_scrollback_and_shows_indicator() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn local_selection_yank_emits_osc52_clipboard_sequence() { + let _guard = acquire_test_lock().await.expect("acquire test lock"); let tempdir = tempdir().expect("tempdir"); let socket_path = tempdir.path().join("embers.sock"); let socket_arg = socket_path.to_string_lossy().into_owned(); diff --git a/crates/embers-cli/tests/panes.rs b/crates/embers-cli/tests/panes.rs index 4e06071..8f935e0 100644 --- a/crates/embers-cli/tests/panes.rs +++ b/crates/embers-cli/tests/panes.rs @@ -2,13 +2,14 @@ use std::time::Duration; use embers_core::RequestId; use embers_protocol::{BufferRequest, ClientMessage, ServerResponse}; -use embers_test_support::{TestConnection, TestServer}; +use embers_test_support::{TestConnection, TestServer, acquire_test_lock}; use tokio::time::sleep; use crate::support::{run_cli, session_snapshot_by_name, stdout}; #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn pane_commands_round_trip_through_cli() { + let _guard = acquire_test_lock().await.expect("acquire test lock"); let server = TestServer::start().await.expect("start server"); run_cli(&server, ["new-session", "alpha"]); @@ -129,6 +130,7 @@ async fn pane_commands_round_trip_through_cli() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn detached_buffers_can_be_listed_and_attached_via_cli() { + let _guard = acquire_test_lock().await.expect("acquire test lock"); let server = TestServer::start().await.expect("start server"); run_cli(&server, ["new-session", "alpha"]); diff --git a/crates/embers-core/src/metadata.rs b/crates/embers-core/src/metadata.rs index 5523186..718d6fb 100644 --- a/crates/embers-core/src/metadata.rs +++ b/crates/embers-core/src/metadata.rs @@ -1,7 +1,9 @@ use std::path::PathBuf; use std::time::SystemTime; -#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub struct Timestamp(pub SystemTime); impl Timestamp { @@ -16,7 +18,7 @@ impl Default for Timestamp { } } -#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)] pub enum ActivityState { #[default] Idle, @@ -24,7 +26,7 @@ pub enum ActivityState { Bell, } -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct EntityMetadata { pub title: Option, pub cwd: Option, diff --git a/crates/embers-core/src/snapshot.rs b/crates/embers-core/src/snapshot.rs index 345667f..1e40994 100644 --- a/crates/embers-core/src/snapshot.rs +++ b/crates/embers-core/src/snapshot.rs @@ -1,14 +1,16 @@ use std::path::PathBuf; +use serde::{Deserialize, Serialize}; + use crate::geometry::PtySize; -#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)] pub struct CursorPosition { pub row: u16, pub col: u16, } -#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)] pub enum CursorShape { #[default] Block, @@ -16,13 +18,13 @@ pub enum CursorShape { Beam, } -#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct CursorState { pub position: CursorPosition, pub shape: CursorShape, } -#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)] pub struct TerminalModes { pub alternate_screen: bool, pub mouse_reporting: bool, @@ -30,7 +32,7 @@ pub struct TerminalModes { pub bracketed_paste: bool, } -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct SnapshotLine { pub text: String, } @@ -43,7 +45,7 @@ impl From<&str> for SnapshotLine { } } -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct TerminalSnapshot { pub sequence: u64, pub size: PtySize, diff --git a/crates/embers-server/Cargo.toml b/crates/embers-server/Cargo.toml index e54dfb3..6960cf3 100644 --- a/crates/embers-server/Cargo.toml +++ b/crates/embers-server/Cargo.toml @@ -8,6 +8,7 @@ version.workspace = true [dependencies] alacritty_terminal = "0.25.1" +base64.workspace = true embers-core = { path = "../embers-core" } embers-protocol = { path = "../embers-protocol" } portable-pty.workspace = true diff --git a/crates/embers-server/src/buffer_runtime.rs b/crates/embers-server/src/buffer_runtime.rs index 97ec825..4acda30 100644 --- a/crates/embers-server/src/buffer_runtime.rs +++ b/crates/embers-server/src/buffer_runtime.rs @@ -1,18 +1,55 @@ -use std::any::Any; use std::collections::BTreeMap; -use std::ffi::OsString; +use std::env; +use std::ffi::{OsStr, OsString}; +use std::fs; use std::io::{Read, Write}; -use std::path::Path; +#[cfg(unix)] +use std::os::unix::ffi::OsStrExt; +#[cfg(unix)] +use std::os::unix::fs::PermissionsExt; +use std::os::unix::net::{UnixListener, UnixStream}; +#[cfg(windows)] +use std::os::windows::ffi::OsStrExt; +use std::path::{Path, PathBuf}; +use std::process::{Command as ProcessCommand, Stdio}; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::{Arc, Mutex}; use std::thread; +use std::time::Duration; -use embers_core::{BufferId, MuxError, PtySize, Result}; +use base64::Engine as _; +use embers_core::{ActivityState, BufferId, MuxError, PtySize, Result, TerminalSnapshot}; use portable_pty::{ Child, ChildKiller, CommandBuilder, MasterPty, NativePtySystem, PtySize as PortablePtySize, PtySystem, }; +use serde::{Deserialize, Serialize}; use tracing::error; +use crate::{AlacrittyTerminalBackend, RawByteRouter, TerminalBackend}; + +const CONNECT_RETRY_DELAY: Duration = Duration::from_millis(25); +const CONNECT_RETRY_ATTEMPTS: usize = 1200; +const STATUS_POLL_INTERVAL: Duration = Duration::from_millis(50); +const MAX_FRAME_SIZE: usize = 16 * 1024 * 1024; + +#[derive(Clone, Debug)] +pub struct BufferRuntimeUpdate { + pub sequence: u64, + pub activity: ActivityState, + pub title: Option>, +} + +#[derive(Clone, Debug)] +pub struct BufferRuntimeStatus { + pub pid: Option, + pub sequence: u64, + pub activity: ActivityState, + pub title: Option, + pub running: bool, + pub exit_code: Option, +} + #[derive(Clone)] pub struct BufferRuntimeHandle { inner: Arc, @@ -21,128 +58,273 @@ pub struct BufferRuntimeHandle { struct BufferRuntimeInner { buffer_id: BufferId, pid: Option, - master: Mutex>, - writer: Mutex>, - killer: Mutex>, + socket_path: PathBuf, + connection: Mutex, + stop: AtomicBool, threads: Mutex, } #[derive(Default)] struct RuntimeThreads { - reader: Option>, - wait: Option>, + poller: Option>, } #[derive(Clone)] pub struct BufferRuntimeCallbacks { - pub on_output: Arc) + Send + Sync>, + pub on_output: Arc, pub on_exit: Arc) + Send + Sync>, } +#[derive(Clone)] +pub struct RuntimeKeeperCli { + pub socket_path: PathBuf, + pub command: Vec, + pub cwd: Option, + pub env: BTreeMap, + pub size: PtySize, +} + +struct KeeperConnection { + stream: UnixStream, +} + +#[derive(Serialize, Deserialize)] +enum KeeperRequest { + Status, + Write { bytes: Vec }, + Resize { size: PtySize }, + Snapshot { cwd: Option }, + VisibleSnapshot { cwd: Option }, + ScrollbackSlice { start_line: u64, line_count: u32 }, + Kill, +} + +#[derive(Serialize, Deserialize)] +enum KeeperResponse { + Status(KeeperStatus), + Snapshot(KeeperSnapshot), + VisibleSnapshot(TerminalSnapshot), + ScrollbackSlice(KeeperScrollbackSlice), + Ok, + Error { message: String }, +} + +#[derive(Clone, Serialize, Deserialize)] +struct KeeperStatus { + pid: Option, + sequence: u64, + activity: ActivityState, + title: Option, + running: bool, + exit_code: Option, +} + +#[derive(Clone, Serialize, Deserialize)] +pub struct KeeperSnapshot { + pub sequence: u64, + pub size: PtySize, + pub lines: Vec, + pub title: Option, + pub cwd: Option, +} + +#[derive(Clone, Serialize, Deserialize)] +pub struct KeeperScrollbackSlice { + pub start_line: u64, + pub total_lines: u64, + pub lines: Vec, +} + +struct KeeperRuntime { + surface: Mutex, + master: Mutex>, + writer: Mutex>, + killer: Mutex>, + sequence: AtomicU64, + activity: Mutex, + exit_code: Mutex>>, + pid: Option, +} + +struct KeeperSurface { + router: RawByteRouter, + backend: Box, + size: PtySize, +} + impl std::fmt::Debug for BufferRuntimeHandle { fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { formatter .debug_struct("BufferRuntimeHandle") .field("buffer_id", &self.inner.buffer_id) .field("pid", &self.inner.pid) + .field("socket_path", &self.inner.socket_path) .finish() } } impl BufferRuntimeHandle { - pub fn spawn( + pub async fn spawn( buffer_id: BufferId, + socket_path: PathBuf, command: &[String], cwd: Option<&Path>, env: &BTreeMap, size: PtySize, callbacks: BufferRuntimeCallbacks, ) -> Result { - let Some(program) = command.first() else { - return Err(MuxError::invalid_input("buffer command must not be empty")); - }; - - let pty_system = NativePtySystem::default(); - let pair = pty_system - .openpty(to_portable_size(size)) - .map_err(|error| MuxError::pty(error.to_string()))?; + let command = command.to_vec(); + let cwd = cwd.map(Path::to_path_buf); + let env = env.clone(); + tokio::task::spawn_blocking(move || { + Self::spawn_blocking(buffer_id, socket_path, command, cwd, env, size, callbacks) + }) + .await + .map_err(|error| MuxError::internal(error.to_string()))? + } - let mut command_builder = CommandBuilder::new(program); - command_builder.args(&command[1..]); - if let Some(cwd) = cwd { - command_builder.cwd(cwd); + fn spawn_blocking( + buffer_id: BufferId, + socket_path: PathBuf, + command: Vec, + cwd: Option, + env: BTreeMap, + size: PtySize, + callbacks: BufferRuntimeCallbacks, + ) -> Result { + if command.is_empty() { + return Err(MuxError::invalid_input("buffer command must not be empty")); } - for (key, value) in env { - command_builder.env(key, value); + if let Some(parent) = socket_path.parent() { + fs::create_dir_all(parent)?; + } + if socket_path.exists() { + let _ = fs::remove_file(&socket_path); } - let mut child = pair - .slave - .spawn_command(command_builder) - .map_err(|error| MuxError::pty(error.to_string()))?; - let pid = child.process_id(); - let mut killer = child.clone_killer(); - let reader = pair - .master - .try_clone_reader() - .map_err(|error| MuxError::pty(error.to_string()))?; - let writer = pair - .master - .take_writer() - .map_err(|error| MuxError::pty(error.to_string()))?; - - let on_output = callbacks.on_output.clone(); - let reader_handle = thread::Builder::new() - .name(format!("buffer-{buffer_id}-reader")) - .spawn(move || read_loop(buffer_id, reader, on_output)) - .map_err(|error| { - let _ = killer.kill(); - let _ = child.wait(); - MuxError::internal(error.to_string()) - })?; - - let on_exit = callbacks.on_exit.clone(); - let wait_handle = match thread::Builder::new() - .name(format!("buffer-{buffer_id}-wait")) - .spawn(move || wait_loop(buffer_id, child, on_exit)) - { - Ok(handle) => handle, - Err(error) => { - let _ = killer.kill(); - join_thread(buffer_id, "reader", reader_handle); - return Err(MuxError::internal(error.to_string())); - } + let cli = RuntimeKeeperCli { + socket_path: socket_path.clone(), + command, + cwd, + env, + size, }; + spawn_runtime_keeper(cli)?; - Ok(Self { - inner: Arc::new(BufferRuntimeInner { - buffer_id, - pid, - master: Mutex::new(pair.master), - writer: Mutex::new(writer), - killer: Mutex::new(killer), - threads: Mutex::new(RuntimeThreads { - reader: Some(reader_handle), - wait: Some(wait_handle), - }), - }), + Self::attach_blocking(buffer_id, socket_path, callbacks) + } + + pub async fn attach( + buffer_id: BufferId, + socket_path: PathBuf, + callbacks: BufferRuntimeCallbacks, + ) -> Result { + tokio::task::spawn_blocking(move || { + Self::attach_blocking(buffer_id, socket_path, callbacks) }) + .await + .map_err(|error| MuxError::internal(error.to_string()))? + } + + fn attach_blocking( + buffer_id: BufferId, + socket_path: PathBuf, + callbacks: BufferRuntimeCallbacks, + ) -> Result { + let stream = connect_to_keeper(&socket_path)?; + let mut connection = KeeperConnection { stream }; + let initial = connection.status()?; + let inner = Arc::new(BufferRuntimeInner { + buffer_id, + pid: initial.pid, + socket_path, + connection: Mutex::new(connection), + stop: AtomicBool::new(false), + threads: Mutex::new(RuntimeThreads::default()), + }); + + let poller = spawn_status_poller(inner.clone(), callbacks, initial)?; + inner + .threads + .lock() + .map_err(|_| MuxError::internal("buffer runtime thread registry lock poisoned"))? + .poller = Some(poller); + + Ok(Self { inner }) } pub fn pid(&self) -> Option { self.inner.pid } + pub fn socket_path(&self) -> &Path { + &self.inner.socket_path + } + + pub async fn status(&self) -> Result { + let inner = self.inner.clone(); + tokio::task::spawn_blocking(move || { + let mut connection = inner + .connection + .lock() + .map_err(|_| MuxError::internal("buffer runtime connection lock poisoned"))?; + connection.status() + }) + .await + .map_err(|error| MuxError::internal(error.to_string()))? + } + + pub async fn capture_snapshot(&self, cwd: Option) -> Result { + let inner = self.inner.clone(); + tokio::task::spawn_blocking(move || { + let mut connection = inner + .connection + .lock() + .map_err(|_| MuxError::internal("buffer runtime connection lock poisoned"))?; + connection.snapshot(cwd) + }) + .await + .map_err(|error| MuxError::internal(error.to_string()))? + } + + pub async fn capture_visible_snapshot(&self, cwd: Option) -> Result { + let inner = self.inner.clone(); + tokio::task::spawn_blocking(move || { + let mut connection = inner + .connection + .lock() + .map_err(|_| MuxError::internal("buffer runtime connection lock poisoned"))?; + connection.visible_snapshot(cwd) + }) + .await + .map_err(|error| MuxError::internal(error.to_string()))? + } + + pub async fn capture_scrollback_slice( + &self, + start_line: u64, + line_count: u32, + ) -> Result { + let inner = self.inner.clone(); + tokio::task::spawn_blocking(move || { + let mut connection = inner + .connection + .lock() + .map_err(|_| MuxError::internal("buffer runtime connection lock poisoned"))?; + connection.scrollback_slice(start_line, line_count) + }) + .await + .map_err(|error| MuxError::internal(error.to_string()))? + } + pub async fn write(&self, bytes: Vec) -> Result<()> { let inner = self.inner.clone(); tokio::task::spawn_blocking(move || { - let mut writer = inner - .writer + let mut connection = inner + .connection .lock() - .map_err(|_| MuxError::internal("buffer runtime writer lock poisoned"))?; - writer.write_all(&bytes)?; - writer.flush()?; - Ok(()) + .map_err(|_| MuxError::internal("buffer runtime connection lock poisoned"))?; + connection.write(bytes) }) .await .map_err(|error| MuxError::internal(error.to_string()))? @@ -151,13 +333,11 @@ impl BufferRuntimeHandle { pub async fn resize(&self, size: PtySize) -> Result<()> { let inner = self.inner.clone(); tokio::task::spawn_blocking(move || { - let master = inner - .master + let mut connection = inner + .connection .lock() - .map_err(|_| MuxError::internal("buffer runtime master lock poisoned"))?; - master - .resize(to_portable_size(size)) - .map_err(|error| MuxError::pty(error.to_string())) + .map_err(|_| MuxError::internal("buffer runtime connection lock poisoned"))?; + connection.resize(size) }) .await .map_err(|error| MuxError::internal(error.to_string()))? @@ -166,13 +346,11 @@ impl BufferRuntimeHandle { pub async fn kill(&self) -> Result<()> { let inner = self.inner.clone(); tokio::task::spawn_blocking(move || { - let mut killer = inner - .killer + let mut connection = inner + .connection .lock() - .map_err(|_| MuxError::internal("buffer runtime killer lock poisoned"))?; - killer - .kill() - .map_err(|error| MuxError::pty(error.to_string())) + .map_err(|_| MuxError::internal("buffer runtime connection lock poisoned"))?; + connection.kill() }) .await .map_err(|error| MuxError::internal(error.to_string()))? @@ -188,6 +366,7 @@ impl BufferRuntimeHandle { impl BufferRuntimeInner { fn join_threads_blocking(&self) { + self.stop.store(true, Ordering::Relaxed); let mut threads = match self.threads.lock() { Ok(threads) => threads, Err(poisoned) => { @@ -198,14 +377,13 @@ impl BufferRuntimeInner { poisoned.into_inner() } }; - let RuntimeThreads { reader, wait } = std::mem::take(&mut *threads); + let poller = threads.poller.take(); drop(threads); - if let Some(handle) = reader { - join_thread(self.buffer_id, "reader", handle); - } - if let Some(handle) = wait { - join_thread(self.buffer_id, "wait", handle); + if let Some(poller) = poller + && poller.thread().id() != thread::current().id() + { + let _ = poller.join(); } } } @@ -216,29 +394,677 @@ impl Drop for BufferRuntimeInner { } } -fn read_loop( - buffer_id: BufferId, - mut reader: Box, - on_output: Arc) + Send + Sync>, -) { +impl KeeperConnection { + fn request(&mut self, request: KeeperRequest) -> Result { + write_message(&mut self.stream, &request)?; + match read_message(&mut self.stream)? { + Some(KeeperResponse::Error { message }) => Err(MuxError::transport(message)), + Some(response) => Ok(response), + None => Err(MuxError::transport("runtime keeper disconnected")), + } + } + + fn status(&mut self) -> Result { + match self.request(KeeperRequest::Status)? { + KeeperResponse::Status(status) => Ok(BufferRuntimeStatus { + pid: status.pid, + sequence: status.sequence, + activity: status.activity, + title: status.title, + running: status.running, + exit_code: status.exit_code, + }), + other => Err(MuxError::protocol(format!( + "unexpected runtime keeper status response: {other_kind}", + other_kind = keeper_response_kind(&other) + ))), + } + } + + fn write(&mut self, bytes: Vec) -> Result<()> { + match self.request(KeeperRequest::Write { bytes })? { + KeeperResponse::Ok => Ok(()), + other => Err(MuxError::protocol(format!( + "unexpected runtime keeper write response: {other_kind}", + other_kind = keeper_response_kind(&other) + ))), + } + } + + fn resize(&mut self, size: PtySize) -> Result<()> { + match self.request(KeeperRequest::Resize { size })? { + KeeperResponse::Ok => Ok(()), + other => Err(MuxError::protocol(format!( + "unexpected runtime keeper resize response: {other_kind}", + other_kind = keeper_response_kind(&other) + ))), + } + } + + fn snapshot(&mut self, cwd: Option) -> Result { + match self.request(KeeperRequest::Snapshot { cwd })? { + KeeperResponse::Snapshot(snapshot) => Ok(snapshot), + other => Err(MuxError::protocol(format!( + "unexpected runtime keeper snapshot response: {other_kind}", + other_kind = keeper_response_kind(&other) + ))), + } + } + + fn visible_snapshot(&mut self, cwd: Option) -> Result { + match self.request(KeeperRequest::VisibleSnapshot { cwd })? { + KeeperResponse::VisibleSnapshot(snapshot) => Ok(snapshot), + other => Err(MuxError::protocol(format!( + "unexpected runtime keeper visible snapshot response: {other_kind}", + other_kind = keeper_response_kind(&other) + ))), + } + } + + fn scrollback_slice( + &mut self, + start_line: u64, + line_count: u32, + ) -> Result { + match self.request(KeeperRequest::ScrollbackSlice { + start_line, + line_count, + })? { + KeeperResponse::ScrollbackSlice(slice) => Ok(slice), + other => Err(MuxError::protocol(format!( + "unexpected runtime keeper scrollback response: {other_kind}", + other_kind = keeper_response_kind(&other) + ))), + } + } + + fn kill(&mut self) -> Result<()> { + match self.request(KeeperRequest::Kill)? { + KeeperResponse::Ok => Ok(()), + other => Err(MuxError::protocol(format!( + "unexpected runtime keeper kill response: {other_kind}", + other_kind = keeper_response_kind(&other) + ))), + } + } +} + +impl KeeperSurface { + fn new(size: PtySize) -> Self { + Self { + router: RawByteRouter, + backend: Box::new(AlacrittyTerminalBackend::new(size)), + size, + } + } + + fn route_output(&mut self, bytes: &[u8]) -> ActivityState { + self.router.route_output(self.backend.as_mut(), bytes); + self.backend.take_activity() + } + + fn resize(&mut self, size: PtySize) { + self.size = size; + self.backend.resize(size); + } + + fn capture_lines(&self) -> Vec { + self.backend.capture_scrollback() + } + + fn capture_visible_snapshot(&self, sequence: u64, cwd: Option) -> TerminalSnapshot { + self.backend.visible_snapshot(sequence, self.size, cwd) + } + + fn capture_scrollback_slice(&self, start_line: u64, line_count: u32) -> KeeperScrollbackSlice { + let slice = self + .backend + .capture_scrollback_slice(start_line, line_count); + KeeperScrollbackSlice { + start_line: slice.start_line, + total_lines: slice.total_lines, + lines: slice.lines, + } + } +} + +pub fn run_runtime_keeper(cli: RuntimeKeeperCli) -> Result<()> { + let Some(program) = cli.command.first() else { + return Err(MuxError::invalid_input( + "runtime keeper command must not be empty", + )); + }; + + if let Some(parent) = cli.socket_path.parent() { + fs::create_dir_all(parent)?; + } + if cli.socket_path.exists() { + let _ = fs::remove_file(&cli.socket_path); + } + let listener = UnixListener::bind(&cli.socket_path)?; + let _cleanup = SocketCleanup::new(cli.socket_path.clone()); + + let pty_system = NativePtySystem::default(); + let pair = pty_system + .openpty(to_portable_size(cli.size)) + .map_err(|error| MuxError::pty(error.to_string()))?; + + let mut command_builder = CommandBuilder::new(program); + command_builder.args(&cli.command[1..]); + if let Some(cwd) = &cli.cwd { + command_builder.cwd(cwd); + } + for (key, value) in &cli.env { + command_builder.env(key, value); + } + + let child = pair + .slave + .spawn_command(command_builder) + .map_err(|error| MuxError::pty(error.to_string()))?; + let pid = child.process_id(); + let killer = child.clone_killer(); + let reader = pair + .master + .try_clone_reader() + .map_err(|error| MuxError::pty(error.to_string()))?; + let writer = pair + .master + .take_writer() + .map_err(|error| MuxError::pty(error.to_string()))?; + + let runtime = Arc::new(KeeperRuntime { + surface: Mutex::new(KeeperSurface::new(cli.size)), + master: Mutex::new(pair.master), + writer: Mutex::new(writer), + killer: Mutex::new(killer), + sequence: AtomicU64::new(0), + activity: Mutex::new(ActivityState::Idle), + exit_code: Mutex::new(None), + pid, + }); + + let reader_runtime = runtime.clone(); + let reader_join = thread::Builder::new() + .name(format!("keeper-reader-{}", cli.socket_path.display())) + .spawn(move || keeper_read_loop(reader_runtime, reader)) + .map_err(|error| MuxError::internal(error.to_string()))?; + let wait_runtime = runtime.clone(); + let wait_join = thread::Builder::new() + .name(format!("keeper-wait-{}", cli.socket_path.display())) + .spawn(move || keeper_wait_loop(wait_runtime, child)) + .map_err(|error| MuxError::internal(error.to_string()))?; + let mut terminate = false; + while !terminate { + let (mut stream, _) = listener.accept()?; + terminate = handle_keeper_client(runtime.clone(), &mut stream)?; + } + + let _ = reader_join.join(); + let _ = wait_join.join(); + Ok(()) +} + +fn handle_keeper_client(runtime: Arc, stream: &mut UnixStream) -> Result { + loop { + let request = match read_message::(stream) { + Ok(Some(request)) => request, + Ok(None) => return Ok(false), + Err(error) => { + let response = KeeperResponse::Error { + message: error.to_string(), + }; + if write_message(stream, &response).is_err() { + return Ok(false); + } + continue; + } + }; + let (response, terminate) = match handle_keeper_request(&runtime, request) { + Ok(result) => result, + Err(error) => { + let response = KeeperResponse::Error { + message: error.to_string(), + }; + if write_message(stream, &response).is_err() { + return Ok(false); + } + continue; + } + }; + if write_message(stream, &response).is_err() { + return Ok(false); + } + if terminate { + return Ok(true); + } + } +} + +fn handle_keeper_request( + runtime: &Arc, + request: KeeperRequest, +) -> Result<(KeeperResponse, bool)> { + match request { + KeeperRequest::Status => Ok((KeeperResponse::Status(runtime.status()?), false)), + KeeperRequest::Write { bytes } => { + runtime.write(bytes)?; + Ok((KeeperResponse::Ok, false)) + } + KeeperRequest::Resize { size } => { + runtime.resize(size)?; + Ok((KeeperResponse::Ok, false)) + } + KeeperRequest::Snapshot { cwd } => { + Ok((KeeperResponse::Snapshot(runtime.snapshot(cwd)?), false)) + } + KeeperRequest::VisibleSnapshot { cwd } => Ok(( + KeeperResponse::VisibleSnapshot(runtime.visible_snapshot(cwd)?), + false, + )), + KeeperRequest::ScrollbackSlice { + start_line, + line_count, + } => Ok(( + KeeperResponse::ScrollbackSlice(runtime.scrollback_slice(start_line, line_count)?), + false, + )), + KeeperRequest::Kill => { + runtime.kill()?; + Ok((KeeperResponse::Ok, false)) + } + } +} + +impl KeeperRuntime { + fn status(&self) -> Result { + let exit_code = *self + .exit_code + .lock() + .map_err(|_| MuxError::internal("runtime keeper exit lock poisoned"))?; + let surface = self + .surface + .lock() + .map_err(|_| MuxError::internal("runtime keeper surface lock poisoned"))?; + let activity = *self + .activity + .lock() + .map_err(|_| MuxError::internal("runtime keeper activity lock poisoned"))?; + let sequence = self.sequence.load(Ordering::Relaxed); + let title = surface.backend.metadata().title.clone(); + Ok(KeeperStatus { + pid: self.pid, + sequence, + activity, + title, + running: exit_code.is_none(), + exit_code: exit_code.flatten(), + }) + } + + fn write(&self, bytes: Vec) -> Result<()> { + if self + .exit_code + .lock() + .map_err(|_| MuxError::internal("runtime keeper exit lock poisoned"))? + .is_some() + { + return Err(MuxError::conflict("buffer runtime has already exited")); + } + let mut writer = self + .writer + .lock() + .map_err(|_| MuxError::internal("runtime keeper writer lock poisoned"))?; + writer.write_all(&bytes)?; + writer.flush()?; + Ok(()) + } + + fn resize(&self, size: PtySize) -> Result<()> { + let master = self + .master + .lock() + .map_err(|_| MuxError::internal("runtime keeper master lock poisoned"))?; + master + .resize(to_portable_size(size)) + .map_err(|error| MuxError::pty(error.to_string()))?; + self.surface + .lock() + .map_err(|_| MuxError::internal("runtime keeper surface lock poisoned"))? + .resize(size); + Ok(()) + } + + fn snapshot(&self, cwd: Option) -> Result { + let surface = self + .surface + .lock() + .map_err(|_| MuxError::internal("runtime keeper surface lock poisoned"))?; + Ok(KeeperSnapshot { + sequence: self.sequence.load(Ordering::Relaxed), + size: surface.size, + lines: surface.capture_lines(), + title: surface.backend.metadata().title, + cwd, + }) + } + + fn visible_snapshot(&self, cwd: Option) -> Result { + let surface = self + .surface + .lock() + .map_err(|_| MuxError::internal("runtime keeper surface lock poisoned"))?; + Ok(surface.capture_visible_snapshot(self.sequence.load(Ordering::Relaxed), cwd)) + } + + fn scrollback_slice(&self, start_line: u64, line_count: u32) -> Result { + let surface = self + .surface + .lock() + .map_err(|_| MuxError::internal("runtime keeper surface lock poisoned"))?; + Ok(surface.capture_scrollback_slice(start_line, line_count)) + } + + fn kill(&self) -> Result<()> { + let mut killer = self + .killer + .lock() + .map_err(|_| MuxError::internal("runtime keeper killer lock poisoned"))?; + killer + .kill() + .map_err(|error| MuxError::pty(error.to_string())) + } +} + +fn keeper_read_loop(runtime: Arc, mut reader: Box) { let mut buffer = [0_u8; 4096]; loop { match reader.read(&mut buffer) { Ok(0) => break, - Ok(read) => on_output(buffer_id, buffer[..read].to_vec()), + Ok(read) => { + let mut surface = match runtime.surface.lock() { + Ok(surface) => surface, + Err(_) => break, + }; + let activity = surface.route_output(&buffer[..read]); + runtime.sequence.fetch_add(1, Ordering::Relaxed); + if let Ok(mut state) = runtime.activity.lock() { + *state = activity; + } + } Err(error) if error.kind() == std::io::ErrorKind::Interrupted => continue, Err(_) => break, } } } -fn wait_loop( - buffer_id: BufferId, - mut child: Box, - on_exit: Arc) + Send + Sync>, -) { +fn keeper_wait_loop(runtime: Arc, mut child: Box) { let exit_code = child.wait().ok().and_then(exit_status_code); - on_exit(buffer_id, exit_code); + if let Ok(mut state) = runtime.exit_code.lock() { + *state = Some(exit_code); + } +} + +fn spawn_status_poller( + inner: Arc, + callbacks: BufferRuntimeCallbacks, + initial: BufferRuntimeStatus, +) -> Result> { + thread::Builder::new() + .name(format!("buffer-{}-poller", inner.buffer_id)) + .spawn(move || { + let mut last_sequence = initial.sequence; + let mut last_title = initial.title.clone(); + let mut last_activity = initial.activity; + let mut saw_exit = !initial.running; + + while !inner.stop.load(Ordering::Relaxed) { + let status = { + let mut connection = match inner.connection.lock() { + Ok(connection) => connection, + Err(_) => break, + }; + match connection.status() { + Ok(status) => status, + Err(error) => { + error!(%error, %inner.buffer_id, "status poll failed"); + (callbacks.on_exit)(inner.buffer_id, None); + break; + } + } + }; + + if status.sequence != last_sequence + || status.title != last_title + || status.activity != last_activity + { + let title = (status.title != last_title).then(|| status.title.clone()); + (callbacks.on_output)( + inner.buffer_id, + BufferRuntimeUpdate { + sequence: status.sequence, + activity: status.activity, + title, + }, + ); + last_sequence = status.sequence; + last_title = status.title.clone(); + last_activity = status.activity; + } + + if !saw_exit && !status.running { + saw_exit = true; + (callbacks.on_exit)(inner.buffer_id, status.exit_code); + } + + thread::sleep(STATUS_POLL_INTERVAL); + } + }) + .map_err(|error| MuxError::internal(error.to_string())) +} + +fn connect_to_keeper(socket_path: &Path) -> Result { + for _ in 0..CONNECT_RETRY_ATTEMPTS { + match UnixStream::connect(socket_path) { + Ok(stream) => return Ok(stream), + Err(error) if error.kind() == std::io::ErrorKind::NotFound => { + thread::sleep(CONNECT_RETRY_DELAY); + } + Err(error) if error.kind() == std::io::ErrorKind::ConnectionRefused => { + return Err(error.into()); + } + Err(error) => return Err(error.into()), + } + } + Err(MuxError::timeout(format!( + "timed out connecting to runtime keeper {}", + socket_path.display() + ))) +} + +fn spawn_runtime_keeper(cli: RuntimeKeeperCli) -> Result<()> { + if let Some(keeper_exe) = resolve_runtime_keeper_executable() { + let mut keeper = ProcessCommand::new(keeper_exe); + keeper + .arg("__runtime-keeper") + .arg("--keeper-socket") + .arg(&cli.socket_path) + .arg("--cols") + .arg(cli.size.cols.to_string()) + .arg("--rows") + .arg(cli.size.rows.to_string()); + if let Some(cwd) = &cli.cwd { + keeper.arg("--cwd").arg(cwd); + } + for (key, value) in &cli.env { + keeper.arg("--env").arg(format!( + "{}=base64:{}", + key, + encode_runtime_keeper_env_value(value.as_os_str()) + )); + } + keeper.arg("--"); + keeper.args(&cli.command); + keeper.stdin(Stdio::null()); + keeper.stdout(Stdio::null()); + keeper.stderr(Stdio::null()); + keeper.spawn()?; + return Ok(()); + } + + thread::Builder::new() + .name(format!("runtime-keeper-{}", cli.socket_path.display())) + .spawn(move || { + if let Err(error) = run_runtime_keeper(cli) { + error!(%error, "runtime keeper thread failed"); + } + }) + .map_err(|error| MuxError::internal(error.to_string()))?; + Ok(()) +} + +fn resolve_runtime_keeper_executable() -> Option { + if let Some(path) = env::var_os("EMBERS_RUNTIME_KEEPER_BIN").map(PathBuf::from) + && is_executable_file(&path) + { + return Some(path); + } + if let Some(path) = env::var_os("CARGO_BIN_EXE_embers").map(PathBuf::from) + && is_executable_file(&path) + { + return Some(path); + } + let current_exe = env::current_exe().ok(); + if let Some(current_exe) = current_exe.as_ref() { + if current_exe + .file_stem() + .and_then(|name| name.to_str()) + .is_some_and(|name| name == "embers" || name == "embers-cli") + && is_executable_file(current_exe) + { + return Some(current_exe.clone()); + } + + if let Some(parent) = current_exe.parent() { + if parent.file_name().is_some_and(|name| name == "deps") { + let candidate = parent.parent()?.join(binary_name("embers")); + if is_executable_file(&candidate) { + return Some(candidate); + } + } + + for stem in ["embers", "embers-cli", "embers-runtime-keeper"] { + let candidate = parent.join(binary_name(stem)); + if is_executable_file(&candidate) { + return Some(candidate); + } + } + } + } + + for stem in ["embers", "embers-cli", "embers-runtime-keeper"] { + if let Some(path) = resolve_binary_on_path(stem) { + return Some(path); + } + } + + None +} + +fn binary_name(stem: &str) -> String { + if cfg!(windows) { + format!("{stem}.exe") + } else { + stem.to_owned() + } +} + +fn is_executable_file(path: &Path) -> bool { + let Ok(metadata) = path.metadata() else { + return false; + }; + if !metadata.is_file() { + return false; + } + #[cfg(unix)] + if metadata.permissions().mode() & 0o111 == 0 { + return false; + } + true +} + +fn resolve_binary_on_path(stem: &str) -> Option { + let path = env::var_os("PATH")?; + let binary_name = binary_name(stem); + for entry in env::split_paths(&path) { + let candidate = entry.join(&binary_name); + if is_executable_file(&candidate) { + return Some(candidate); + } + } + None +} + +fn encode_runtime_keeper_env_value(value: &OsStr) -> String { + #[cfg(unix)] + { + base64::engine::general_purpose::STANDARD.encode(value.as_bytes()) + } + #[cfg(windows)] + { + let encoded = value + .encode_wide() + .flat_map(|unit| unit.to_le_bytes()) + .collect::>(); + base64::engine::general_purpose::STANDARD.encode(encoded) + } + #[cfg(all(not(unix), not(windows)))] + { + base64::engine::general_purpose::STANDARD.encode(value.to_string_lossy().as_bytes()) + } +} + +fn write_message(stream: &mut UnixStream, value: &T) -> Result<()> { + let payload = + serde_json::to_vec(value).map_err(|error| MuxError::internal(error.to_string()))?; + let len = u32::try_from(payload.len()) + .map_err(|_| MuxError::internal("runtime keeper payload exceeded u32 length"))?; + stream.write_all(&len.to_le_bytes())?; + stream.write_all(&payload)?; + stream.flush()?; + Ok(()) +} + +fn read_message Deserialize<'de>>(stream: &mut UnixStream) -> Result> { + let mut len_bytes = [0_u8; 4]; + match stream.read_exact(&mut len_bytes) { + Ok(()) => {} + Err(error) if error.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None), + Err(error) => return Err(error.into()), + } + let len = usize::try_from(u32::from_le_bytes(len_bytes)) + .map_err(|_| MuxError::protocol("runtime keeper frame length exceeds platform limits"))?; + if len == 0 || len > MAX_FRAME_SIZE { + return Err(MuxError::protocol(format!( + "runtime keeper frame length {len} is out of range" + ))); + } + let mut payload = vec![0_u8; len]; + stream.read_exact(&mut payload)?; + let value = + serde_json::from_slice(&payload).map_err(|error| MuxError::internal(error.to_string()))?; + Ok(Some(value)) +} + +fn keeper_response_kind(response: &KeeperResponse) -> &'static str { + match response { + KeeperResponse::Status(_) => "status", + KeeperResponse::Snapshot(_) => "snapshot", + KeeperResponse::VisibleSnapshot(_) => "visible_snapshot", + KeeperResponse::ScrollbackSlice(_) => "scrollback_slice", + KeeperResponse::Ok => "ok", + KeeperResponse::Error { .. } => "error", + } } fn exit_status_code(status: portable_pty::ExitStatus) -> Option { @@ -258,23 +1084,152 @@ fn to_portable_size(size: PtySize) -> PortablePtySize { } } -fn join_thread(buffer_id: BufferId, role: &str, handle: thread::JoinHandle<()>) { - if let Err(payload) = handle.join() { - error!( - %buffer_id, - thread = role, - panic = %panic_payload_message(payload), - "buffer runtime thread panicked" - ); +struct SocketCleanup { + socket_path: PathBuf, +} + +impl SocketCleanup { + fn new(socket_path: PathBuf) -> Self { + Self { socket_path } } } -fn panic_payload_message(payload: Box) -> String { - match payload.downcast::() { - Ok(message) => *message, - Err(payload) => match payload.downcast::<&'static str>() { - Ok(message) => (*message).to_owned(), - Err(_) => "non-string panic payload".to_owned(), - }, +impl Drop for SocketCleanup { + fn drop(&mut self) { + let _ = fs::remove_file(&self.socket_path); + } +} + +#[cfg(test)] +mod tests { + use std::io::Write; + use std::os::unix::net::UnixStream; + use std::sync::Arc; + use std::sync::mpsc; + use std::thread; + use std::time::{Duration, Instant}; + + use embers_core::{ActivityState, BufferId, MuxError}; + + use super::{ + BufferRuntimeCallbacks, BufferRuntimeInner, BufferRuntimeStatus, KeeperConnection, + MAX_FRAME_SIZE, RuntimeThreads, read_message, spawn_status_poller, + }; + + #[test] + fn join_threads_waits_for_poller_shutdown() { + let (stream, _peer) = UnixStream::pair().expect("create socket pair"); + let inner = Arc::new(BufferRuntimeInner { + buffer_id: BufferId(1), + pid: None, + socket_path: "/tmp/test-buffer.sock".into(), + connection: std::sync::Mutex::new(KeeperConnection { stream }), + stop: std::sync::atomic::AtomicBool::new(false), + threads: std::sync::Mutex::new(RuntimeThreads::default()), + }); + let (tx, rx) = mpsc::channel(); + let poller_inner = inner.clone(); + let poller = thread::spawn(move || { + while !poller_inner.stop.load(std::sync::atomic::Ordering::Relaxed) { + thread::sleep(Duration::from_millis(5)); + } + thread::sleep(Duration::from_millis(40)); + tx.send(()).expect("send shutdown notification"); + }); + inner.threads.lock().expect("lock thread registry").poller = Some(poller); + + let started = Instant::now(); + inner.join_threads_blocking(); + + assert!( + started.elapsed() >= Duration::from_millis(40), + "join should wait for the poller to finish" + ); + rx.try_recv() + .expect("poller should finish before join returns"); + } + + #[test] + fn read_message_rejects_empty_frame() { + let (mut stream, mut peer) = UnixStream::pair().expect("create socket pair"); + peer.write_all(&0_u32.to_le_bytes()) + .expect("write frame length"); + drop(peer); + + let error = match read_message::(&mut stream) { + Err(error) => error, + Ok(_) => panic!("expected frame error"), + }; + + assert!(matches!(error, MuxError::Protocol(_))); + assert!(error.to_string().contains("out of range")); + } + + #[test] + fn read_message_rejects_oversized_frame() { + let (mut stream, mut peer) = UnixStream::pair().expect("create socket pair"); + peer.write_all( + &(u32::try_from(MAX_FRAME_SIZE).expect("frame size fits in u32") + 1).to_le_bytes(), + ) + .expect("write frame length"); + drop(peer); + + let error = match read_message::(&mut stream) { + Err(error) => error, + Ok(_) => panic!("expected frame error"), + }; + + assert!(matches!(error, MuxError::Protocol(_))); + assert!(error.to_string().contains("out of range")); + } + + #[test] + fn status_poller_exits_on_status_error() { + let (stream, peer) = UnixStream::pair().expect("create socket pair"); + drop(peer); + let inner = Arc::new(BufferRuntimeInner { + buffer_id: BufferId(1), + pid: None, + socket_path: "/tmp/test-buffer.sock".into(), + connection: std::sync::Mutex::new(KeeperConnection { stream }), + stop: std::sync::atomic::AtomicBool::new(false), + threads: std::sync::Mutex::new(RuntimeThreads::default()), + }); + let (exit_tx, exit_rx) = mpsc::channel(); + let (output_tx, output_rx) = mpsc::channel(); + let poller = spawn_status_poller( + inner, + BufferRuntimeCallbacks { + on_output: Arc::new(move |buffer_id, _| { + output_tx + .send(buffer_id) + .expect("send unexpected output notification"); + }), + on_exit: Arc::new(move |buffer_id, exit_code| { + exit_tx + .send((buffer_id, exit_code)) + .expect("send exit notification"); + }), + }, + BufferRuntimeStatus { + pid: None, + sequence: 0, + activity: ActivityState::Idle, + title: None, + running: true, + exit_code: None, + }, + ) + .expect("spawn poller"); + + poller.join().expect("poller exits cleanly"); + + assert_eq!( + exit_rx + .recv_timeout(Duration::from_secs(1)) + .expect("poller should report exit"), + (BufferId(1), None) + ); + assert!(output_rx.try_recv().is_err()); } } diff --git a/crates/embers-server/src/config.rs b/crates/embers-server/src/config.rs index 782a70a..551d7e2 100644 --- a/crates/embers-server/src/config.rs +++ b/crates/embers-server/src/config.rs @@ -8,6 +8,7 @@ pub const SOCKET_ENV_VAR: &str = "EMBERS_SOCKET"; pub struct ServerConfig { pub socket_path: PathBuf, pub workspace_path: PathBuf, + pub runtime_dir: PathBuf, pub buffer_env: BTreeMap, } @@ -19,9 +20,11 @@ impl ServerConfig { socket_path.as_os_str().to_owned(), ); let workspace_path = socket_path.with_extension("workspace.json"); + let runtime_dir = socket_path.with_extension("runtimes"); Self { socket_path, workspace_path, + runtime_dir, buffer_env, } } diff --git a/crates/embers-server/src/lib.rs b/crates/embers-server/src/lib.rs index 6d2867d..b467d7c 100644 --- a/crates/embers-server/src/lib.rs +++ b/crates/embers-server/src/lib.rs @@ -8,7 +8,10 @@ mod protocol; mod server; mod terminal_backend; -pub use buffer_runtime::{BufferRuntimeCallbacks, BufferRuntimeHandle}; +pub use buffer_runtime::{ + BufferRuntimeCallbacks, BufferRuntimeHandle, BufferRuntimeStatus, BufferRuntimeUpdate, + RuntimeKeeperCli, run_runtime_keeper, +}; pub use config::{SOCKET_ENV_VAR, ServerConfig}; pub use model::{ Buffer, BufferAttachment, BufferState, BufferViewNode, BufferViewState, ExitedBuffer, diff --git a/crates/embers-server/src/model.rs b/crates/embers-server/src/model.rs index f8dc1de..dffa4d5 100644 --- a/crates/embers-server/src/model.rs +++ b/crates/embers-server/src/model.rs @@ -1,4 +1,5 @@ use std::collections::BTreeMap; +use std::fmt; use std::path::PathBuf; use embers_core::{ @@ -17,13 +18,14 @@ pub struct Session { pub created_at: Timestamp, } -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone)] pub struct Buffer { pub id: BufferId, pub title: String, pub command: Vec, pub cwd: Option, pub env: BTreeMap, + runtime_socket_path: Option, pub state: BufferState, pub attachment: BufferAttachment, pub pty_size: PtySize, @@ -32,6 +34,76 @@ pub struct Buffer { pub created_at: Timestamp, } +impl Buffer { + pub(crate) fn new( + id: BufferId, + title: impl Into, + command: Vec, + cwd: Option, + env: BTreeMap, + ) -> Self { + Self { + id, + title: title.into(), + command, + cwd, + env, + runtime_socket_path: None, + state: BufferState::Created, + attachment: BufferAttachment::Detached, + pty_size: PtySize::new(80, 24), + activity: ActivityState::Idle, + last_snapshot_seq: 0, + created_at: Timestamp::now(), + } + } + + pub(crate) fn runtime_socket_path(&self) -> Option<&PathBuf> { + self.runtime_socket_path.as_ref() + } + + pub(crate) fn set_runtime_socket_path(&mut self, runtime_socket_path: Option) { + self.runtime_socket_path = runtime_socket_path; + } +} + +impl fmt::Debug for Buffer { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + formatter + .debug_struct("Buffer") + .field("id", &self.id) + .field("title", &self.title) + .field("command", &self.command) + .field("cwd", &self.cwd) + .field("env", &self.env) + .field("state", &self.state) + .field("attachment", &self.attachment) + .field("pty_size", &self.pty_size) + .field("activity", &self.activity) + .field("last_snapshot_seq", &self.last_snapshot_seq) + .field("created_at", &self.created_at) + .finish() + } +} + +impl PartialEq for Buffer { + fn eq(&self, other: &Self) -> bool { + self.id == other.id + && self.title == other.title + && self.command == other.command + && self.cwd == other.cwd + && self.env == other.env + && self.state == other.state + && self.attachment == other.attachment + && self.pty_size == other.pty_size + && self.activity == other.activity + && self.last_snapshot_seq == other.last_snapshot_seq + && self.created_at == other.created_at + } +} + +impl Eq for Buffer {} + #[derive(Clone, Debug, Default, PartialEq, Eq)] pub struct RunningBuffer { pub pid: Option, diff --git a/crates/embers-server/src/persist.rs b/crates/embers-server/src/persist.rs index 510a678..302abb7 100644 --- a/crates/embers-server/src/persist.rs +++ b/crates/embers-server/src/persist.rs @@ -54,6 +54,8 @@ pub struct PersistedBuffer { pub command: Vec, pub cwd: Option, pub env: BTreeMap, + #[serde(default)] + pub runtime_socket_path: Option, pub state: PersistedBufferState, pub attachment: PersistedBufferAttachment, pub pty_size: PtySize, @@ -295,6 +297,7 @@ pub fn persisted_buffer(buffer: &Buffer) -> PersistedBuffer { command: buffer.command.clone(), cwd: buffer.cwd.clone(), env: buffer.env.clone(), + runtime_socket_path: buffer.runtime_socket_path().cloned(), state: persisted_buffer_state(&buffer.state), attachment: persisted_buffer_attachment(&buffer.attachment), pty_size: buffer.pty_size, @@ -305,19 +308,21 @@ pub fn persisted_buffer(buffer: &Buffer) -> PersistedBuffer { } pub fn restored_buffer(buffer: PersistedBuffer) -> Result { - Ok(Buffer { - id: BufferId(buffer.id), - title: buffer.title, - command: buffer.command, - cwd: buffer.cwd, - env: buffer.env, - state: restored_buffer_state(buffer.state)?, - attachment: restored_buffer_attachment(buffer.attachment), - pty_size: buffer.pty_size, - activity: restored_activity(buffer.activity), - last_snapshot_seq: buffer.last_snapshot_seq, - created_at: timestamp_from_millis(buffer.created_at_ms)?, - }) + let mut restored = Buffer::new( + BufferId(buffer.id), + buffer.title, + buffer.command, + buffer.cwd, + buffer.env, + ); + restored.set_runtime_socket_path(buffer.runtime_socket_path); + restored.state = restored_buffer_state(buffer.state)?; + restored.attachment = restored_buffer_attachment(buffer.attachment); + restored.pty_size = buffer.pty_size; + restored.activity = restored_activity(buffer.activity); + restored.last_snapshot_seq = buffer.last_snapshot_seq; + restored.created_at = timestamp_from_millis(buffer.created_at_ms)?; + Ok(restored) } pub fn persisted_node(node: &Node) -> PersistedNode { diff --git a/crates/embers-server/src/server.rs b/crates/embers-server/src/server.rs index 9a5cd0d..04e9c4e 100644 --- a/crates/embers-server/src/server.rs +++ b/crates/embers-server/src/server.rs @@ -2,6 +2,8 @@ use std::collections::{BTreeMap, BTreeSet}; use std::ffi::OsString; use std::fs; #[cfg(unix)] +use std::os::unix::ffi::OsStrExt; +#[cfg(unix)] use std::os::unix::fs::PermissionsExt; use std::path::{Path, PathBuf}; use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; @@ -30,9 +32,8 @@ use tracing::{debug, error, info}; use crate::persist::{load_workspace, save_workspace}; use crate::protocol::{buffer_record, floating_record, session_record, session_snapshot}; use crate::{ - AlacrittyTerminalBackend, BackendDamage, BufferAttachment, BufferRuntimeCallbacks, - BufferRuntimeHandle, BufferState, RawByteRouter, ServerConfig, ServerState, TabEntry, - TerminalBackend, + BufferAttachment, BufferRuntimeCallbacks, BufferRuntimeHandle, BufferRuntimeStatus, + BufferRuntimeUpdate, BufferState, ServerConfig, ServerState, TabEntry, }; #[derive(Debug)] @@ -51,14 +52,17 @@ impl Server { } let restored_state = load_workspace(&self.config.workspace_path)?; - let listener = UnixListener::bind(&self.config.socket_path)?; - set_socket_permissions(&self.config.socket_path)?; let socket_path = self.config.socket_path.clone(); let runtime = Arc::new(Runtime::new( restored_state.unwrap_or_default(), + self.config.socket_path.clone(), self.config.workspace_path.clone(), + self.config.runtime_dir.clone(), self.config.buffer_env.clone(), )); + runtime.restore_buffer_runtimes().await?; + let listener = UnixListener::bind(&self.config.socket_path)?; + set_socket_permissions(&self.config.socket_path)?; let shutdown_signal = runtime.shutdown.clone(); let (shutdown_tx, mut shutdown_rx) = oneshot::channel(); @@ -276,8 +280,9 @@ struct Runtime { state: Mutex, buffer_runtimes: Mutex>, buffer_shutdown_intents: StdMutex>, - buffer_surfaces: Mutex>, + socket_path: PathBuf, workspace_path: PathBuf, + runtime_dir: PathBuf, buffer_env: BTreeMap, subscriptions: Mutex>, clients: Mutex>, @@ -288,21 +293,6 @@ struct Runtime { state_tasks: TaskCounter, } -struct BufferSurface { - router: RawByteRouter, - backend: Box, - size: PtySize, -} - -impl std::fmt::Debug for BufferSurface { - fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter - .debug_struct("BufferSurface") - .field("size", &self.size) - .finish() - } -} - #[derive(Clone, Default)] struct TaskCounter { inner: Arc, @@ -363,74 +353,21 @@ impl Drop for TaskTicket { } } -impl BufferSurface { - fn new(size: PtySize) -> Self { - Self { - router: RawByteRouter, - backend: Box::new(AlacrittyTerminalBackend::new(size)), - size, - } - } - - fn route_input(&mut self, bytes: Vec) -> Vec { - self.router.route_input(bytes) - } - - fn route_output(&mut self, bytes: &[u8]) { - self.router.route_output(self.backend.as_mut(), bytes); - } - - fn resize(&mut self, size: PtySize) { - self.size = size; - self.backend.resize(size); - } - - fn capture_lines(&self) -> Vec { - self.backend.capture_scrollback() - } - - fn capture_visible_snapshot( - &self, - sequence: u64, - cwd: Option, - ) -> embers_core::TerminalSnapshot { - self.backend.visible_snapshot(sequence, self.size, cwd) - } - - fn capture_scrollback_slice( - &self, - start_line: u64, - line_count: u32, - ) -> crate::BackendScrollbackSlice { - self.backend - .capture_scrollback_slice(start_line, line_count) - } - - fn metadata(&self) -> crate::BackendMetadata { - self.backend.metadata() - } - - fn take_activity(&mut self) -> embers_core::ActivityState { - self.backend.take_activity() - } - - fn damage(&mut self) -> BackendDamage { - self.backend.take_damage() - } -} - impl Runtime { fn new( state: ServerState, + socket_path: PathBuf, workspace_path: PathBuf, + runtime_dir: PathBuf, buffer_env: BTreeMap, ) -> Self { Self { state: Mutex::new(state), buffer_runtimes: Mutex::new(BTreeMap::new()), buffer_shutdown_intents: StdMutex::new(BTreeSet::new()), - buffer_surfaces: Mutex::new(BTreeMap::new()), + socket_path, workspace_path, + runtime_dir, buffer_env, subscriptions: Mutex::new(BTreeMap::new()), clients: Mutex::new(BTreeMap::new()), @@ -461,6 +398,99 @@ impl Runtime { .remove(&buffer_id) } + fn runtime_socket_path(&self, buffer_id: BufferId) -> Result { + let path = self + .runtime_dir + .join(format!("buffer-{}.sock", buffer_id.0)); + validate_keeper_socket_path(&self.socket_path, &path)?; + Ok(path) + } + + fn buffer_runtime_callbacks(self: &Arc) -> BufferRuntimeCallbacks { + let output_handle = tokio::runtime::Handle::current(); + let exit_handle = output_handle.clone(); + let output_runtime = self.clone(); + let exit_runtime = self.clone(); + let output_tasks = self.state_tasks.clone(); + let exit_tasks = self.state_tasks.clone(); + + BufferRuntimeCallbacks { + on_output: Arc::new(move |buffer_id, update| { + let runtime = output_runtime.clone(); + let task = output_tasks.enter(); + std::mem::drop(output_handle.spawn(async move { + let _task = task; + runtime.record_buffer_update(buffer_id, update).await; + })); + }), + on_exit: Arc::new(move |buffer_id, exit_code| { + let runtime = exit_runtime.clone(); + let task = exit_tasks.enter(); + std::mem::drop(exit_handle.spawn(async move { + let _task = task; + runtime.record_buffer_exit(buffer_id, exit_code).await; + })); + }), + } + } + + async fn restore_buffer_runtimes(self: &Arc) -> Result<()> { + let buffers = { + let state = self.state.lock().await; + state.buffers.values().cloned().collect::>() + }; + + for buffer in buffers { + let Some(socket_path) = buffer.runtime_socket_path().cloned() else { + if matches!(buffer.state, BufferState::Running(_) | BufferState::Created) { + let mut state = self.state.lock().await; + let _ = + state.mark_buffer_interrupted(buffer.id, buffer_pid_hint(&buffer.state)); + } + continue; + }; + if !socket_path.exists() { + debug!( + %buffer.id, + socket_path = %socket_path.display(), + "skipping runtime restore because keeper socket is missing" + ); + let mut state = self.state.lock().await; + let _ = state.set_buffer_runtime_socket_path(buffer.id, None); + let _ = state.mark_buffer_interrupted(buffer.id, buffer_pid_hint(&buffer.state)); + continue; + } + + match self + .attach_buffer_runtime(buffer.id, socket_path.clone()) + .await + { + Ok((runtime, status)) => { + let mut state = self.state.lock().await; + let _ = + state.set_buffer_runtime_socket_path(buffer.id, Some(socket_path.clone())); + apply_runtime_status(&mut state, buffer.id, &status); + drop(state); + self.buffer_runtimes.lock().await.insert(buffer.id, runtime); + } + Err(error) => { + debug!( + %buffer.id, + socket_path = %socket_path.display(), + %error, + "failed to restore buffer runtime" + ); + let mut state = self.state.lock().await; + let _ = state.set_buffer_runtime_socket_path(buffer.id, None); + let _ = + state.mark_buffer_interrupted(buffer.id, buffer_pid_hint(&buffer.state)); + } + } + } + + Ok(()) + } + async fn register_client( &self, connection_id: u64, @@ -859,7 +889,6 @@ impl Runtime { if let Err(error) = self.spawn_buffer_runtime(buffer_id).await { let mut state = self.state.lock().await; let _ = state.remove_buffer(buffer_id); - self.buffer_surfaces.lock().await.remove(&buffer_id); return (mux_error_response(Some(request_id), error), Vec::new()); } @@ -986,7 +1015,7 @@ impl Runtime { request_id, buffer_id, force: _, - } => match self.running_buffer_runtime(buffer_id).await { + } => match self.buffer_runtime(buffer_id).await { Ok(runtime) => match runtime.kill().await { Ok(()) => (ServerResponse::Ok(OkResponse { request_id }), Vec::new()), Err(error) => (mux_error_response(Some(request_id), error), Vec::new()), @@ -1031,14 +1060,11 @@ impl Runtime { request_id, buffer_id, bytes, - } => match self.running_buffer_runtime(buffer_id).await { - Ok(runtime) => { - let bytes = self.route_input_bytes(buffer_id, bytes).await; - match runtime.write(bytes).await { - Ok(()) => (ServerResponse::Ok(OkResponse { request_id }), Vec::new()), - Err(error) => (mux_error_response(Some(request_id), error), Vec::new()), - } - } + } => match self.buffer_runtime(buffer_id).await { + Ok(runtime) => match runtime.write(bytes).await { + Ok(()) => (ServerResponse::Ok(OkResponse { request_id }), Vec::new()), + Err(error) => (mux_error_response(Some(request_id), error), Vec::new()), + }, Err(error) => (mux_error_response(Some(request_id), error), Vec::new()), }, InputRequest::Resize { @@ -1047,7 +1073,7 @@ impl Runtime { cols, rows, } => { - let runtime = match self.running_buffer_runtime(buffer_id).await { + let runtime = match self.buffer_runtime(buffer_id).await { Ok(runtime) => runtime, Err(error) => return (mux_error_response(Some(request_id), error), Vec::new()), }; @@ -1076,11 +1102,12 @@ impl Runtime { return (mux_error_response(Some(request_id), error), Vec::new()); } } - let damage = self.resize_surface(buffer_id, size).await; ( ServerResponse::Ok(OkResponse { request_id }), - render_events(buffer_id, damage), + vec![ServerEvent::RenderInvalidated(RenderInvalidatedEvent { + buffer_id, + })], ) } } @@ -1548,61 +1575,53 @@ impl Runtime { (buffer.command, buffer.cwd, buffer.pty_size, buffer.env) }; - let output_handle = tokio::runtime::Handle::current(); - let exit_handle = output_handle.clone(); - let output_runtime = self.clone(); - let exit_runtime = self.clone(); - let output_tasks = self.state_tasks.clone(); - let exit_tasks = self.state_tasks.clone(); let mut buffer_env = self.buffer_env.clone(); for (key, value) in env_hints { buffer_env.insert(key, OsString::from(value)); } let runtime = BufferRuntimeHandle::spawn( buffer_id, + self.runtime_socket_path(buffer_id)?, &command, cwd.as_deref(), &buffer_env, size, - BufferRuntimeCallbacks { - on_output: Arc::new(move |buffer_id, bytes| { - let runtime = output_runtime.clone(); - let _task = output_tasks.enter(); - std::mem::drop(output_handle.spawn(async move { - let _task = _task; - runtime.record_buffer_output(buffer_id, bytes).await; - })); - }), - on_exit: Arc::new(move |buffer_id, exit_code| { - let runtime = exit_runtime.clone(); - let _task = exit_tasks.enter(); - std::mem::drop(exit_handle.spawn(async move { - let _task = _task; - runtime.record_buffer_exit(buffer_id, exit_code).await; - })); - }), - }, - )?; + self.buffer_runtime_callbacks(), + ) + .await?; + let status = runtime.status().await?; { let mut state = self.state.lock().await; - if let Err(error) = state.mark_buffer_running(buffer_id, runtime.pid()) { + if let Err(error) = state.mark_buffer_running(buffer_id, status.pid) { let _ = runtime.kill().await; let _ = runtime.join_threads().await; return Err(error); } + state.set_buffer_runtime_socket_path( + buffer_id, + Some(runtime.socket_path().to_path_buf()), + )?; + apply_runtime_status(&mut state, buffer_id, &status); } - self.buffer_surfaces - .lock() - .await - .entry(buffer_id) - .or_insert_with(|| BufferSurface::new(size)); self.buffer_runtimes.lock().await.insert(buffer_id, runtime); Ok(()) } - async fn running_buffer_runtime(&self, buffer_id: BufferId) -> Result { + async fn attach_buffer_runtime( + self: &Arc, + buffer_id: BufferId, + socket_path: PathBuf, + ) -> Result<(BufferRuntimeHandle, BufferRuntimeStatus)> { + let runtime = + BufferRuntimeHandle::attach(buffer_id, socket_path, self.buffer_runtime_callbacks()) + .await?; + let status = runtime.status().await?; + Ok((runtime, status)) + } + + async fn buffer_runtime(&self, buffer_id: BufferId) -> Result { if let Some(runtime) = self.buffer_runtimes.lock().await.get(&buffer_id).cloned() { return Ok(runtime); } @@ -1634,21 +1653,17 @@ impl Runtime { let state = self.state.lock().await; state.buffer(buffer_id)?.clone() }; - let lines = self - .buffer_surfaces - .lock() - .await - .get(&buffer_id) - .map(BufferSurface::capture_lines) - .unwrap_or_default(); + let runtime = self.buffer_runtime(buffer_id).await?; + let snapshot = runtime.capture_snapshot(buffer.cwd.clone()).await?; + self.sync_buffer_runtime_status(buffer_id, &runtime).await?; Ok(SnapshotResponse { request_id, buffer_id, - sequence: buffer.last_snapshot_seq, - size: buffer.pty_size, - lines, - title: Some(buffer.title), + sequence: snapshot.sequence, + size: snapshot.size, + lines: snapshot.lines, + title: snapshot.title.or(Some(buffer.title)), cwd: buffer.cwd.map(|path| path.display().to_string()), }) } @@ -1662,13 +1677,9 @@ impl Runtime { let state = self.state.lock().await; state.buffer(buffer_id)?.clone() }; - let snapshot = { - let mut surfaces = self.buffer_surfaces.lock().await; - surfaces - .entry(buffer_id) - .or_insert_with(|| BufferSurface::new(buffer.pty_size)) - .capture_visible_snapshot(buffer.last_snapshot_seq, buffer.cwd.clone()) - }; + let runtime = self.buffer_runtime(buffer_id).await?; + let snapshot = runtime.capture_visible_snapshot(buffer.cwd.clone()).await?; + self.sync_buffer_runtime_status(buffer_id, &runtime).await?; Ok(VisibleSnapshotResponse { request_id, @@ -1695,17 +1706,11 @@ impl Runtime { start_line: u64, line_count: u32, ) -> Result { - let buffer = { - let state = self.state.lock().await; - state.buffer(buffer_id)?.clone() - }; - let slice = { - let mut surfaces = self.buffer_surfaces.lock().await; - surfaces - .entry(buffer_id) - .or_insert_with(|| BufferSurface::new(buffer.pty_size)) - .capture_scrollback_slice(start_line, line_count) - }; + let runtime = self.buffer_runtime(buffer_id).await?; + let slice = runtime + .capture_scrollback_slice(start_line, line_count) + .await?; + self.sync_buffer_runtime_status(buffer_id, &runtime).await?; Ok(ScrollbackSliceResponse { request_id, @@ -1716,74 +1721,52 @@ impl Runtime { }) } - async fn route_input_bytes(&self, buffer_id: BufferId, bytes: Vec) -> Vec { - match self.buffer_surfaces.lock().await.get_mut(&buffer_id) { - Some(surface) => surface.route_input(bytes), - None => bytes, - } - } - - async fn resize_surface(&self, buffer_id: BufferId, size: PtySize) -> BackendDamage { - let mut surfaces = self.buffer_surfaces.lock().await; - let surface = surfaces - .entry(buffer_id) - .or_insert_with(|| BufferSurface::new(size)); - surface.resize(size); - surface.damage() - } - - async fn record_buffer_output(&self, buffer_id: BufferId, bytes: Vec) { - let size = { + async fn record_buffer_update(&self, buffer_id: BufferId, update: BufferRuntimeUpdate) { + let updated = { let mut state = self.state.lock().await; - if let Err(error) = state.note_buffer_output(buffer_id) { - debug!(%buffer_id, %error, "dropping PTY output for unknown buffer"); + let Some(buffer) = state.buffers.get_mut(&buffer_id) else { return; - } - match state.buffer(buffer_id) { - Ok(buffer) => buffer.pty_size, - Err(error) => { - debug!(%buffer_id, %error, "buffer disappeared while recording output"); - return; + }; + if update.sequence <= buffer.last_snapshot_seq { + false + } else { + buffer.last_snapshot_seq = update.sequence; + buffer.activity = update.activity; + if let Some(title) = update.title { + match title { + Some(title) => buffer.title = title, + None => buffer.title.clear(), + } } + true } }; - let (metadata, activity, damage) = { - let mut surfaces = self.buffer_surfaces.lock().await; - let surface = surfaces - .entry(buffer_id) - .or_insert_with(|| BufferSurface::new(size)); - surface.resize(size); - surface.route_output(&bytes); - ( - surface.metadata(), - surface.take_activity(), - surface.damage(), + if updated { + self.broadcast( + vec![ServerEvent::RenderInvalidated(RenderInvalidatedEvent { + buffer_id, + })], + &[], ) - }; - - { - let mut state = self.state.lock().await; - if let Some(title) = metadata.title - && let Err(error) = state.set_buffer_title(buffer_id, title) - { - debug!(%buffer_id, %error, "failed to apply terminal title update"); - } - if let Err(error) = state.set_buffer_activity(buffer_id, activity) { - debug!(%buffer_id, %error, "failed to apply buffer activity update"); - } + .await; } - - self.broadcast(render_events(buffer_id, damage), &[]).await; } async fn record_buffer_exit(&self, buffer_id: BufferId, exit_code: Option) { - let runtime = self.buffer_runtimes.lock().await.remove(&buffer_id); let should_interrupt = self.take_buffer_shutdown_intent(buffer_id); + if should_interrupt { + let runtime = self.buffer_runtimes.lock().await.remove(&buffer_id); + drop(runtime); + } let updated = { let mut state = self.state.lock().await; let result = if should_interrupt { - state.mark_buffer_interrupted(buffer_id) + let pid = state + .buffers + .get(&buffer_id) + .and_then(|buffer| buffer_pid_hint(&buffer.state)); + state.mark_buffer_interrupted(buffer_id, pid) } else { state.mark_buffer_exited(buffer_id, exit_code) }; @@ -1796,12 +1779,6 @@ impl Runtime { } }; - if let Some(runtime) = runtime - && let Err(error) = runtime.join_threads().await - { - debug!(%buffer_id, %error, "failed to join buffer runtime threads"); - } - if updated { self.broadcast( vec![ServerEvent::RenderInvalidated(RenderInvalidatedEvent { @@ -1813,6 +1790,27 @@ impl Runtime { } } + async fn sync_buffer_runtime_status( + &self, + buffer_id: BufferId, + runtime: &BufferRuntimeHandle, + ) -> Result<()> { + let status = runtime.status().await?; + self.record_buffer_update( + buffer_id, + BufferRuntimeUpdate { + sequence: status.sequence, + activity: status.activity, + title: Some(status.title.clone()), + }, + ) + .await; + if !status.running { + self.record_buffer_exit(buffer_id, status.exit_code).await; + } + Ok(()) + } + async fn shutdown_runtimes(&self) { let runtimes: Vec<_> = { let runtimes = self.buffer_runtimes.lock().await; @@ -1829,13 +1827,11 @@ impl Runtime { .collect() }; for runtime in runtimes { - if let Err(error) = runtime.kill().await { - debug!(%error, "failed to kill buffer runtime during shutdown"); - } if let Err(error) = runtime.join_threads().await { debug!(%error, "failed to join buffer runtime threads during shutdown"); } } + self.buffer_runtimes.lock().await.clear(); } async fn broadcast( @@ -2175,6 +2171,36 @@ fn set_socket_permissions(socket_path: &Path) -> Result<()> { Ok(()) } +/// Maximum Unix-domain socket path length in bytes for runtime keeper sockets. +/// These values come from `sockaddr_un.sun_path`: macOS exposes 104 bytes per +/// `unix(4)`, while other Unix/Linux platforms expose 108 bytes per `unix(7)`. +/// `validate_keeper_socket_path` uses this limit to validate keeper socket +/// paths derived from the server socket path before binding. +#[cfg(target_os = "macos")] +const UNIX_SOCKET_PATH_LIMIT: usize = 104; +/// Maximum Unix-domain socket path length in bytes for runtime keeper sockets. +/// These values come from `sockaddr_un.sun_path`: macOS exposes 104 bytes per +/// `unix(4)`, while other Unix/Linux platforms expose 108 bytes per `unix(7)`. +/// `validate_keeper_socket_path` uses this limit to validate keeper socket +/// paths derived from the server socket path before binding. +#[cfg(all(unix, not(target_os = "macos")))] +const UNIX_SOCKET_PATH_LIMIT: usize = 108; + +fn validate_keeper_socket_path(server_socket_path: &Path, keeper_socket_path: &Path) -> Result<()> { + #[cfg(unix)] + { + let len = keeper_socket_path.as_os_str().as_bytes().len(); + if len > UNIX_SOCKET_PATH_LIMIT { + return Err(MuxError::invalid_input(format!( + "runtime keeper socket path is too long ({len} bytes, max {UNIX_SOCKET_PATH_LIMIT}): {} (runtime_dir derived from server socket {}). Use a shorter server socket path.", + keeper_socket_path.display(), + server_socket_path.display(), + ))); + } + } + Ok(()) +} + fn protocol_tab_index(index: u32) -> Result { usize::try_from(index) .map_err(|_| MuxError::invalid_input(format!("tab index {index} exceeds platform limits"))) @@ -2254,24 +2280,48 @@ fn protocol_error_to_mux(error: ProtocolError) -> MuxError { MuxError::protocol(error.to_string()) } -fn render_events(buffer_id: BufferId, damage: BackendDamage) -> Vec { - match damage { - BackendDamage::None => Vec::new(), - BackendDamage::Full | BackendDamage::Partial(_) => { - vec![ServerEvent::RenderInvalidated(RenderInvalidatedEvent { - buffer_id, - })] - } +fn apply_runtime_status( + state: &mut ServerState, + buffer_id: BufferId, + status: &BufferRuntimeStatus, +) { + if let Some(buffer) = state.buffers.get_mut(&buffer_id) { + buffer.last_snapshot_seq = status.sequence; + } + if let Some(title) = &status.title { + let _ = state.set_buffer_title(buffer_id, title.clone()); + } + let _ = state.set_buffer_activity(buffer_id, status.activity); + if status.running { + let _ = state.mark_buffer_running(buffer_id, status.pid); + } else { + let _ = state.mark_buffer_exited(buffer_id, status.exit_code); + } +} + +fn buffer_pid_hint(state: &BufferState) -> Option { + match state { + BufferState::Running(running) => running.pid, + BufferState::Interrupted(interrupted) => interrupted.last_known_pid, + BufferState::Created | BufferState::Exited(_) => None, } } #[cfg(test)] mod tests { use std::collections::BTreeMap; + #[cfg(unix)] + use std::os::unix::net::UnixListener as StdUnixListener; use std::path::PathBuf; + use std::sync::Arc; - use super::{Runtime, ShutdownSignal, wait_for_shutdown}; - use crate::ServerState; + use embers_core::ActivityState; + use embers_protocol::{ServerEnvelope, ServerEvent}; + use tempfile::tempdir; + use tokio::sync::mpsc; + + use super::{Runtime, ShutdownSignal, Subscription, wait_for_shutdown}; + use crate::{BufferRuntimeUpdate, BufferState, ServerState}; use tokio::time::{Duration, timeout}; @@ -2290,7 +2340,9 @@ mod tests { fn buffer_shutdown_intents_are_consumed_per_buffer() { let runtime = Runtime::new( ServerState::new(), + PathBuf::from("server.sock"), PathBuf::from("workspace"), + PathBuf::from("runtime"), BTreeMap::new(), ); runtime @@ -2303,4 +2355,197 @@ mod tests { assert!(!runtime.take_buffer_shutdown_intent(embers_core::BufferId(1))); assert!(!runtime.take_buffer_shutdown_intent(embers_core::BufferId(2))); } + + #[tokio::test] + async fn record_buffer_update_ignores_stale_sequences() { + let runtime = Runtime::new( + ServerState::new(), + PathBuf::from("server.sock"), + PathBuf::from("workspace"), + PathBuf::from("runtime"), + BTreeMap::new(), + ); + let buffer_id = { + let mut state = runtime.state.lock().await; + let buffer_id = state.create_buffer("current-title", vec!["/bin/sh".to_owned()], None); + let buffer = state + .buffers + .get_mut(&buffer_id) + .expect("buffer is created"); + buffer.last_snapshot_seq = 5; + buffer.activity = ActivityState::Activity; + buffer_id + }; + let (sender, mut receiver) = mpsc::unbounded_channel(); + runtime.subscriptions.lock().await.insert( + 1, + Subscription { + connection_id: 1, + session_id: None, + sender, + }, + ); + + runtime + .record_buffer_update( + buffer_id, + BufferRuntimeUpdate { + sequence: 5, + activity: ActivityState::Bell, + title: Some(Some("stale-title".to_owned())), + }, + ) + .await; + + let buffer = runtime + .state + .lock() + .await + .buffer(buffer_id) + .expect("buffer exists") + .clone(); + assert_eq!(buffer.last_snapshot_seq, 5); + assert_eq!(buffer.activity, ActivityState::Activity); + assert_eq!(buffer.title, "current-title"); + assert!(receiver.try_recv().is_err()); + + runtime + .record_buffer_update( + buffer_id, + BufferRuntimeUpdate { + sequence: 6, + activity: ActivityState::Bell, + title: Some(Some("fresh-title".to_owned())), + }, + ) + .await; + + let buffer = runtime + .state + .lock() + .await + .buffer(buffer_id) + .expect("buffer exists") + .clone(); + assert_eq!(buffer.last_snapshot_seq, 6); + assert_eq!(buffer.activity, ActivityState::Bell); + assert_eq!(buffer.title, "fresh-title"); + assert!(matches!( + receiver.try_recv(), + Ok(ServerEnvelope::Event(ServerEvent::RenderInvalidated(event))) + if event.buffer_id == buffer_id + )); + } + + #[tokio::test] + async fn record_buffer_update_clears_title() { + let runtime = Runtime::new( + ServerState::new(), + PathBuf::from("server.sock"), + PathBuf::from("workspace"), + PathBuf::from("runtime"), + BTreeMap::new(), + ); + let buffer_id = { + let mut state = runtime.state.lock().await; + let buffer_id = state.create_buffer("current-title", vec!["/bin/sh".to_owned()], None); + let buffer = state + .buffers + .get_mut(&buffer_id) + .expect("buffer is created"); + buffer.last_snapshot_seq = 5; + buffer_id + }; + + runtime + .record_buffer_update( + buffer_id, + BufferRuntimeUpdate { + sequence: 6, + activity: ActivityState::Idle, + title: Some(None), + }, + ) + .await; + + let buffer = runtime + .state + .lock() + .await + .buffer(buffer_id) + .expect("buffer exists") + .clone(); + assert_eq!(buffer.last_snapshot_seq, 6); + assert_eq!(buffer.title, ""); + } + + #[tokio::test] + async fn restore_buffer_runtimes_clears_missing_socket_paths() { + let tempdir = tempdir().expect("tempdir"); + let mut state = ServerState::new(); + let buffer_id = state.create_buffer("buffer", vec!["/bin/sh".to_owned()], None); + state + .mark_buffer_running(buffer_id, Some(42)) + .expect("mark running"); + state + .set_buffer_runtime_socket_path( + buffer_id, + Some(tempdir.path().join("missing-runtime.sock")), + ) + .expect("set runtime socket path"); + + let runtime = Arc::new(Runtime::new( + state, + tempdir.path().join("server.sock"), + tempdir.path().join("workspace.json"), + tempdir.path().join("runtime"), + BTreeMap::new(), + )); + + runtime + .restore_buffer_runtimes() + .await + .expect("restore succeeds"); + + let state = runtime.state.lock().await; + let buffer = state.buffer(buffer_id).expect("buffer exists"); + assert!(matches!(buffer.state, BufferState::Interrupted(_))); + assert_eq!(buffer.runtime_socket_path(), None); + } + + #[cfg(unix)] + #[tokio::test] + async fn restore_buffer_runtimes_clears_unreachable_socket_paths() { + let tempdir = tempdir().expect("tempdir"); + let socket_path = tempdir.path().join("stale-runtime.sock"); + let listener = StdUnixListener::bind(&socket_path).expect("bind stale socket"); + drop(listener); + + let mut state = ServerState::new(); + let buffer_id = state.create_buffer("buffer", vec!["/bin/sh".to_owned()], None); + state + .mark_buffer_running(buffer_id, Some(42)) + .expect("mark running"); + state + .set_buffer_runtime_socket_path(buffer_id, Some(socket_path.clone())) + .expect("set runtime socket path"); + + let runtime = Arc::new(Runtime::new( + state, + tempdir.path().join("server.sock"), + tempdir.path().join("workspace.json"), + tempdir.path().join("runtime"), + BTreeMap::new(), + )); + + runtime + .restore_buffer_runtimes() + .await + .expect("restore succeeds"); + + let state = runtime.state.lock().await; + let buffer = state.buffer(buffer_id).expect("buffer exists"); + assert!(matches!(buffer.state, BufferState::Interrupted(_))); + assert_eq!(buffer.runtime_socket_path(), None); + } } diff --git a/crates/embers-server/src/state.rs b/crates/embers-server/src/state.rs index 35cfbdb..416d58d 100644 --- a/crates/embers-server/src/state.rs +++ b/crates/embers-server/src/state.rs @@ -122,7 +122,7 @@ impl ServerState { let safe_next_node_id = next_id_after_max(nodes.keys().map(|id| id.0)); let safe_next_floating_id = next_id_after_max(floating.keys().map(|id| id.0)); - let mut state = Self { + let state = Self { sessions, buffers, nodes, @@ -132,7 +132,6 @@ impl ServerState { node_ids: IdAllocator::new(next_node_id.max(safe_next_node_id)), floating_ids: IdAllocator::new(next_floating_id.max(safe_next_floating_id)), }; - state.interrupt_unrecoverable_buffers(); state.validate()?; Ok(state) } @@ -346,22 +345,8 @@ impl ServerState { env: BTreeMap, ) -> BufferId { let buffer_id = self.buffer_ids.next(); - self.buffers.insert( - buffer_id, - Buffer { - id: buffer_id, - title: title.into(), - command, - cwd, - env, - state: BufferState::Created, - attachment: BufferAttachment::Detached, - pty_size: PtySize::new(80, 24), - activity: ActivityState::Idle, - last_snapshot_seq: 0, - created_at: Timestamp::now(), - }, - ); + self.buffers + .insert(buffer_id, Buffer::new(buffer_id, title, command, cwd, env)); buffer_id } @@ -388,6 +373,27 @@ impl ServerState { Ok(()) } + pub fn set_buffer_runtime_socket_path( + &mut self, + buffer_id: BufferId, + runtime_socket_path: Option, + ) -> Result<()> { + self.buffer_mut(buffer_id)? + .set_runtime_socket_path(runtime_socket_path); + Ok(()) + } + + pub fn mark_buffer_interrupted(&mut self, buffer_id: BufferId, pid: Option) -> Result<()> { + let buffer = self.buffer_mut(buffer_id)?; + if matches!(buffer.state, BufferState::Exited(_)) { + return Ok(()); + } + buffer.state = BufferState::Interrupted(InterruptedBuffer { + last_known_pid: pid, + }); + Ok(()) + } + pub fn mark_buffer_exited( &mut self, buffer_id: BufferId, @@ -401,17 +407,6 @@ impl ServerState { Ok(()) } - pub fn mark_buffer_interrupted(&mut self, buffer_id: BufferId) -> Result<()> { - let buffer = self.buffer_mut(buffer_id)?; - let last_known_pid = match &buffer.state { - BufferState::Running(running) => running.pid, - BufferState::Interrupted(interrupted) => interrupted.last_known_pid, - BufferState::Created | BufferState::Exited(_) => None, - }; - buffer.state = BufferState::Interrupted(InterruptedBuffer { last_known_pid }); - Ok(()) - } - pub fn interrupt_unrecoverable_buffers(&mut self) { for buffer in self.buffers.values_mut() { buffer.state = match &buffer.state { diff --git a/crates/embers-server/tests/persistence.rs b/crates/embers-server/tests/persistence.rs index ecb4a84..dc1d7ff 100644 --- a/crates/embers-server/tests/persistence.rs +++ b/crates/embers-server/tests/persistence.rs @@ -1,7 +1,7 @@ use embers_core::{BufferId, RequestId, init_test_tracing}; use embers_protocol::{ - BufferRecordState, BufferRequest, BufferResponse, BuffersResponse, ClientMessage, - ProtocolClient, ServerResponse, SessionRequest, SessionSnapshotResponse, + BufferRecordState, BufferRequest, BufferResponse, BuffersResponse, ClientMessage, InputRequest, + ProtocolClient, ServerResponse, SessionRequest, SessionSnapshotResponse, SnapshotResponse, }; use embers_server::{Server, ServerConfig}; use tempfile::tempdir; @@ -43,8 +43,38 @@ async fn request_buffers(client: &mut ProtocolClient, request: BufferRequest) -> } } +async fn wait_for_snapshot_line( + client: &mut ProtocolClient, + request_id: RequestId, + buffer_id: BufferId, + expected: &str, +) -> SnapshotResponse { + let deadline = Instant::now() + Duration::from_secs(2); + loop { + if let Ok(ServerResponse::Snapshot(snapshot)) = client + .request(&ClientMessage::Buffer(BufferRequest::Capture { + request_id, + buffer_id, + })) + .await + && snapshot.lines.iter().any(|line| line.contains(expected)) + { + return snapshot; + } + + if Instant::now() >= deadline { + break; + } + sleep(Duration::from_millis(25)).await; + } + + panic!( + "capture for buffer {buffer_id} did not contain expected line '{expected}' before timeout" + ); +} + #[tokio::test] -async fn clean_restart_restores_workspace_and_marks_live_buffers_interrupted() { +async fn clean_restart_restores_workspace_and_keeps_live_buffers_running() { init_test_tracing(); let tempdir = tempdir().expect("tempdir"); @@ -158,7 +188,7 @@ async fn clean_restart_restores_workspace_and_marks_live_buffers_interrupted() { .iter() .find(|buffer| buffer.id == attached_id) .expect("attached buffer restored"); - assert_eq!(attached_buffer.state, BufferRecordState::Interrupted); + assert_eq!(attached_buffer.state, BufferRecordState::Running); assert!(attached_buffer.attachment_node_id.is_some()); let buffers = request_buffers( @@ -176,22 +206,45 @@ async fn clean_restart_restores_workspace_and_marks_live_buffers_interrupted() { .iter() .find(|buffer| buffer.id == detached_id) .expect("detached buffer restored"); - assert_eq!(detached_buffer.state, BufferRecordState::Interrupted); + assert_eq!(detached_buffer.state, BufferRecordState::Running); assert_eq!(detached_buffer.attachment_node_id, None); - let send_err = client - .request(&ClientMessage::Buffer(BufferRequest::Get { + match client + .request(&ClientMessage::Input(InputRequest::Send { request_id: RequestId(7), - buffer_id: BufferId(detached_id.0), + buffer_id: attached_id, + bytes: b"printf restarted-attached\\n\r".to_vec(), })) .await - .expect("buffer get succeeds"); - match send_err { - ServerResponse::Buffer(response) => { - assert_eq!(response.buffer.state, BufferRecordState::Interrupted); - } - other => panic!("expected restored interrupted buffer, got {other:?}"), + .expect("send to attached buffer succeeds") + { + ServerResponse::Ok(_) => {} + other => panic!("expected ok response, got {other:?}"), } + match client + .request(&ClientMessage::Input(InputRequest::Send { + request_id: RequestId(8), + buffer_id: detached_id, + bytes: b"printf restarted-detached\\n\r".to_vec(), + })) + .await + .expect("send to detached buffer succeeds") + { + ServerResponse::Ok(_) => {} + other => panic!("expected ok response, got {other:?}"), + } + + let _attached_capture = + wait_for_snapshot_line(&mut client, RequestId(9), attached_id, "restarted-attached").await; + + let _detached_capture = wait_for_snapshot_line( + &mut client, + RequestId(10), + detached_id, + "restarted-detached", + ) + .await; + handle.shutdown().await.expect("shutdown restarted server"); } diff --git a/crates/embers-test-support/Cargo.toml b/crates/embers-test-support/Cargo.toml index 50a66f5..b3961c0 100644 --- a/crates/embers-test-support/Cargo.toml +++ b/crates/embers-test-support/Cargo.toml @@ -11,6 +11,7 @@ assert_cmd.workspace = true embers-core = { path = "../embers-core" } embers-protocol = { path = "../embers-protocol" } embers-server = { path = "../embers-server" } +libc.workspace = true portable-pty.workspace = true tempfile.workspace = true tokio.workspace = true diff --git a/crates/embers-test-support/src/lib.rs b/crates/embers-test-support/src/lib.rs index dab8150..0338336 100644 --- a/crates/embers-test-support/src/lib.rs +++ b/crates/embers-test-support/src/lib.rs @@ -2,8 +2,10 @@ mod cli; mod protocol; mod pty; mod server; +mod test_lock; pub use cli::{cargo_bin, cargo_bin_path}; pub use protocol::TestConnection; pub use pty::PtyHarness; pub use server::TestServer; +pub use test_lock::{InterprocessTestLock, acquire_test_lock}; diff --git a/crates/embers-test-support/src/test_lock.rs b/crates/embers-test-support/src/test_lock.rs new file mode 100644 index 0000000..de3809b --- /dev/null +++ b/crates/embers-test-support/src/test_lock.rs @@ -0,0 +1,115 @@ +use std::fs::{File, OpenOptions}; +use std::io; +#[cfg(unix)] +use std::os::fd::AsRawFd; +#[cfg(not(unix))] +use std::path::PathBuf; +use std::sync::{Arc, OnceLock}; + +#[cfg(not(unix))] +use std::thread; +#[cfg(not(unix))] +use std::time::{Duration, Instant}; +use tokio::sync::{Mutex, OwnedMutexGuard}; + +const TEST_LOCK_FILE_NAME: &str = "embers-integration-tests.lock"; +#[cfg(not(unix))] +const FILE_LOCK_RETRY_DELAY: Duration = Duration::from_millis(10); +#[cfg(not(unix))] +const FILE_LOCK_TIMEOUT: Duration = Duration::from_secs(10); + +fn process_lock() -> Arc> { + static LOCK: OnceLock>> = OnceLock::new(); + LOCK.get_or_init(|| Arc::new(Mutex::new(()))).clone() +} + +pub struct InterprocessTestLock { + _process_guard: OwnedMutexGuard<()>, + file: File, + #[cfg(not(unix))] + path: PathBuf, +} + +pub async fn acquire_test_lock() -> io::Result { + let process_guard = process_lock().lock_owned().await; + let path = std::env::temp_dir().join(TEST_LOCK_FILE_NAME); + + #[cfg(unix)] + let file = tokio::task::spawn_blocking(move || acquire_file_lock(path)) + .await + .map_err(|error| io::Error::other(error.to_string()))??; + + #[cfg(not(unix))] + let (file, path) = tokio::task::spawn_blocking(move || acquire_file_lock(path)) + .await + .map_err(|error| io::Error::other(error.to_string()))??; + + Ok(InterprocessTestLock { + _process_guard: process_guard, + file, + #[cfg(not(unix))] + path, + }) +} + +#[cfg(unix)] +fn acquire_file_lock(path: std::path::PathBuf) -> io::Result { + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent)?; + } + let file = OpenOptions::new() + .read(true) + .write(true) + .create(true) + .truncate(false) + .open(path)?; + let result = unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_EX) }; + if result != 0 { + return Err(io::Error::last_os_error()); + } + Ok(file) +} + +#[cfg(not(unix))] +fn acquire_file_lock(path: PathBuf) -> io::Result<(File, PathBuf)> { + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent)?; + } + let started = Instant::now(); + loop { + match OpenOptions::new() + .read(true) + .write(true) + .create_new(true) + .open(&path) + { + Ok(file) => return Ok((file, path)), + Err(error) if error.kind() == io::ErrorKind::AlreadyExists => { + thread::sleep(FILE_LOCK_RETRY_DELAY); + if started.elapsed() >= FILE_LOCK_TIMEOUT { + return Err(io::Error::new( + io::ErrorKind::TimedOut, + format!( + "timed out acquiring integration test lock at {}; remove the orphaned lock file if no other test process is using it", + path.display() + ), + )); + } + } + Err(error) => return Err(error), + } + } +} + +impl Drop for InterprocessTestLock { + fn drop(&mut self) { + #[cfg(unix)] + { + let _ = unsafe { libc::flock(self.file.as_raw_fd(), libc::LOCK_UN) }; + } + #[cfg(not(unix))] + { + let _ = std::fs::remove_file(&self.path); + } + } +} diff --git a/crates/embers-test-support/tests/buffer_runtime.rs b/crates/embers-test-support/tests/buffer_runtime.rs index 0b2a688..6700541 100644 --- a/crates/embers-test-support/tests/buffer_runtime.rs +++ b/crates/embers-test-support/tests/buffer_runtime.rs @@ -5,7 +5,7 @@ use embers_protocol::{ BufferRecord, BufferRecordState, BufferRequest, ClientMessage, InputRequest, OkResponse, ServerResponse, SnapshotResponse, }; -use embers_test_support::{TestConnection, TestServer}; +use embers_test_support::{TestConnection, TestServer, acquire_test_lock}; use tokio::time::sleep; async fn create_buffer(connection: &mut TestConnection, command: &[&str]) -> BufferRecord { @@ -178,6 +178,7 @@ async fn wait_for_exit( #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn detached_buffers_accept_input_and_keep_running_after_detach_requests() { + let _guard = acquire_test_lock().await.expect("acquire test lock"); let server = TestServer::start().await.expect("start server"); let mut connection = TestConnection::connect(server.socket_path()) .await @@ -215,6 +216,7 @@ async fn detached_buffers_accept_input_and_keep_running_after_detach_requests() #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn resize_and_kill_requests_update_buffer_state_and_preserve_capture() { + let _guard = acquire_test_lock().await.expect("acquire test lock"); let server = TestServer::start().await.expect("start server"); let mut connection = TestConnection::connect(server.socket_path()) .await @@ -245,6 +247,7 @@ async fn resize_and_kill_requests_update_buffer_state_and_preserve_capture() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn capture_preserves_scrollback_for_long_output() { + let _guard = acquire_test_lock().await.expect("acquire test lock"); let server = TestServer::start().await.expect("start server"); let mut connection = TestConnection::connect(server.socket_path()) .await @@ -270,6 +273,7 @@ async fn capture_preserves_scrollback_for_long_output() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn visible_snapshot_surfaces_terminal_modes_and_cursor_metadata() { + let _guard = acquire_test_lock().await.expect("acquire test lock"); let server = TestServer::start().await.expect("start server"); let mut connection = TestConnection::connect(server.socket_path()) .await @@ -305,6 +309,7 @@ async fn visible_snapshot_surfaces_terminal_modes_and_cursor_metadata() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn scrollback_slice_returns_history_while_full_capture_stays_available() { + let _guard = acquire_test_lock().await.expect("acquire test lock"); let server = TestServer::start().await.expect("start server"); let mut connection = TestConnection::connect(server.socket_path()) .await