Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 48 additions & 3 deletions src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ const TEMPLATE_NAME: &str = "main.html";
static TEMPLATE_ENV: OnceLock<Environment<'static>> = OnceLock::new();
const MERMAID_JS: &str = include_str!("../static/js/mermaid.min.js");
const MERMAID_ETAG: &str = concat!("\"", env!("CARGO_PKG_VERSION"), "\"");
const MAX_PORT_ATTEMPTS: u16 = 10;

type SharedMarkdownState = Arc<Mutex<MarkdownState>>;

Expand Down Expand Up @@ -321,6 +322,29 @@ fn new_router(
Ok(router)
}

async fn bind_with_retry(hostname: &str, port: u16) -> Result<(TcpListener, u16)> {
let mut last_err = None;
for offset in 0..MAX_PORT_ATTEMPTS {
let try_port = match port.checked_add(offset) {
Some(p) => p,
None => break,
};
match TcpListener::bind((hostname, try_port)).await {
Ok(listener) => return Ok((listener, try_port)),
Err(e) if e.kind() == std::io::ErrorKind::AddrInUse => last_err = Some(e),
Err(e) => return Err(e.into()),
}
}
Err(last_err
.map(|e| anyhow::anyhow!(e))
.unwrap_or_else(|| anyhow::anyhow!("no valid port in range"))
.context(format!(
"could not bind to ports {}--{}",
port,
port.saturating_add(MAX_PORT_ATTEMPTS - 1)
)))
}

pub(crate) async fn serve_markdown(
base_dir: PathBuf,
tracked_files: Vec<PathBuf>,
Expand All @@ -334,9 +358,13 @@ pub(crate) async fn serve_markdown(
let first_file = tracked_files.first().cloned();
let router = new_router(base_dir.clone(), tracked_files, is_directory_mode)?;

let listener = TcpListener::bind((hostname, port)).await?;
let (listener, actual_port) = bind_with_retry(hostname, port).await?;

if actual_port != port {
println!("⚠ Port {port} in use, using {actual_port} instead");
}

let listen_addr = format_host(hostname, port);
let listen_addr = format_host(hostname, actual_port);

if is_directory_mode {
println!("📁 Serving markdown files from: {}", base_dir.display());
Expand All @@ -349,7 +377,7 @@ pub(crate) async fn serve_markdown(
println!("\nPress Ctrl+C to stop the server");

if open {
let browse_addr = format_host(&browsable_host(hostname), port);
let browse_addr = format_host(&browsable_host(hostname), actual_port);
open_browser(&format!("http://{browse_addr}"))?;
}

Expand Down Expand Up @@ -837,6 +865,23 @@ mod tests {
assert_eq!(browsable_host("example.com"), "example.com");
}

#[tokio::test]
async fn test_bind_retries_on_addr_in_use() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let blocked_port = listener.local_addr().unwrap().port();

let (retry_listener, actual_port) =
bind_with_retry("127.0.0.1", blocked_port).await.unwrap();

assert!(
actual_port > blocked_port,
"Should bind to a higher port when requested port is in use"
);

drop(retry_listener);
drop(listener);
}

use axum_test::TestServer;
use std::time::Duration;
use tempfile::{Builder, NamedTempFile, TempDir};
Expand Down