diff --git a/src/app.rs b/src/app.rs index 04f21d8..2ea8072 100644 --- a/src/app.rs +++ b/src/app.rs @@ -30,6 +30,7 @@ const TEMPLATE_NAME: &str = "main.html"; static TEMPLATE_ENV: OnceLock> = OnceLock::new(); const MERMAID_JS: &str = include_str!("../static/js/mermaid.min.js"); const MERMAID_ETAG: &str = concat!("\"", env!("CARGO_PKG_VERSION"), "\""); +const MAX_PORT_ATTEMPTS: u16 = 10; type SharedMarkdownState = Arc>; @@ -321,6 +322,29 @@ fn new_router( Ok(router) } +async fn bind_with_retry(hostname: &str, port: u16) -> Result<(TcpListener, u16)> { + let mut last_err = None; + for offset in 0..MAX_PORT_ATTEMPTS { + let try_port = match port.checked_add(offset) { + Some(p) => p, + None => break, + }; + match TcpListener::bind((hostname, try_port)).await { + Ok(listener) => return Ok((listener, try_port)), + Err(e) if e.kind() == std::io::ErrorKind::AddrInUse => last_err = Some(e), + Err(e) => return Err(e.into()), + } + } + Err(last_err + .map(|e| anyhow::anyhow!(e)) + .unwrap_or_else(|| anyhow::anyhow!("no valid port in range")) + .context(format!( + "could not bind to ports {}--{}", + port, + port.saturating_add(MAX_PORT_ATTEMPTS - 1) + ))) +} + pub(crate) async fn serve_markdown( base_dir: PathBuf, tracked_files: Vec, @@ -334,9 +358,13 @@ pub(crate) async fn serve_markdown( let first_file = tracked_files.first().cloned(); let router = new_router(base_dir.clone(), tracked_files, is_directory_mode)?; - let listener = TcpListener::bind((hostname, port)).await?; + let (listener, actual_port) = bind_with_retry(hostname, port).await?; + + if actual_port != port { + println!("⚠ Port {port} in use, using {actual_port} instead"); + } - let listen_addr = format_host(hostname, port); + let listen_addr = format_host(hostname, actual_port); if is_directory_mode { println!("📁 Serving markdown files from: {}", base_dir.display()); @@ -349,7 +377,7 @@ pub(crate) async fn serve_markdown( println!("\nPress Ctrl+C to stop the server"); if open { - let browse_addr = format_host(&browsable_host(hostname), port); + let browse_addr = format_host(&browsable_host(hostname), actual_port); open_browser(&format!("http://{browse_addr}"))?; } @@ -837,6 +865,23 @@ mod tests { assert_eq!(browsable_host("example.com"), "example.com"); } + #[tokio::test] + async fn test_bind_retries_on_addr_in_use() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let blocked_port = listener.local_addr().unwrap().port(); + + let (retry_listener, actual_port) = + bind_with_retry("127.0.0.1", blocked_port).await.unwrap(); + + assert!( + actual_port > blocked_port, + "Should bind to a higher port when requested port is in use" + ); + + drop(retry_listener); + drop(listener); + } + use axum_test::TestServer; use std::time::Duration; use tempfile::{Builder, NamedTempFile, TempDir};