From a96446aec29ddc4e09cf0e8046d5499c77844f16 Mon Sep 17 00:00:00 2001 From: Colin Marc Date: Mon, 23 Mar 2026 19:32:22 +0100 Subject: [PATCH] Allow creating servers with multiple ApiDescriptions This adds the ability to merge multiple `ApiDescription` instances by combining their routes. The routes are not allowed to overlap. Fixes #1069 --- dropshot/src/router.rs | 259 ++++++++++++++++++++++++++++++++++++++--- dropshot/src/server.rs | 34 +++++- 2 files changed, 271 insertions(+), 22 deletions(-) diff --git a/dropshot/src/router.rs b/dropshot/src/router.rs index e624a2219..acabad0a1 100644 --- a/dropshot/src/router.rs +++ b/dropshot/src/router.rs @@ -426,6 +426,15 @@ impl HttpRouter { self.has_versioned_routes } + /// Merge all routes from `other` into `self`. + /// + /// This handles conflicts identically to `insert` (ie it panics). + pub fn extend(&mut self, other: HttpRouter) { + for endpoint in other.into_endpoints() { + self.insert(endpoint); + } + } + #[cfg(test)] pub fn lookup_route_unversioned( &self, @@ -573,6 +582,12 @@ impl HttpRouter { ) -> HttpRouterIter<'a, Context> { HttpRouterIter::new(self, version) } + + /// Consumes this router and return an iterator over all registered + /// endpoints. + pub fn into_endpoints(self) -> IntoEndpoints { + IntoEndpoints::new(self) + } } /// Given a list of handlers, return the first one matching the given semver @@ -607,6 +622,73 @@ fn insert_var( varnames.insert(new_varname.clone()); } +/// A consuming iterator over all endpoints in an `HttpRouter`. +pub struct IntoEndpoints { + /// Remaining nodes to visit. + nodes: Vec>, + /// Remaining method groups from the current node. + methods: std::collections::btree_map::IntoValues< + String, + Vec>, + >, + /// Remaining endpoints from the current group. + current: std::vec::IntoIter>, +} + +impl IntoEndpoints { + /// Create a new `IntoEndpoints` iterator starting at the root of the tree. + fn new(router: HttpRouter) -> Self { + let nodes = match router.root.edges { + None => vec![], + Some(edges) => match edges { + HttpRouterEdges::Literals(map) => { + map.into_values().map(|v| *v).collect() + } + HttpRouterEdges::VariableSingle(_, child) + | HttpRouterEdges::VariableRest(_, child) => { + vec![*child] + } + }, + }; + + let mut methods = router.root.methods.into_values(); + let current = methods.next().unwrap_or_default().into_iter(); + + Self { nodes, methods, current } + } +} + +impl Iterator for IntoEndpoints { + type Item = ApiEndpoint; + + fn next(&mut self) -> Option { + loop { + if let Some(endpoint) = self.current.next() { + return Some(endpoint); + } + + if let Some(handlers) = self.methods.next() { + self.current = handlers.into_iter(); + continue; + } + + let node = self.nodes.pop()?; + match node.edges { + None => (), + Some(HttpRouterEdges::Literals(map)) => { + self.nodes.extend(map.into_values().map(|n| *n)) + } + Some(HttpRouterEdges::VariableSingle(_, child)) + | Some(HttpRouterEdges::VariableRest(_, child)) => { + self.nodes.push(*child); + } + } + + self.methods = node.methods.into_values(); + } + } +} + /// Route Interator implementation. We perform a preorder, depth first traversal /// of the tree starting from the root node. For each node, we enumerate the /// methods and then descend into its children (or single child in the case of @@ -1636,8 +1718,17 @@ mod test { #[test] fn test_iter_null() { let router = HttpRouter::<()>::new(); - let ret: Vec<_> = router.endpoints(None).map(|x| (x.0, x.1)).collect(); - assert_eq!(ret, vec![]); + assert_eq!( + router.endpoints(None).map(|x| (x.0, x.1)).collect::>(), + vec![] + ); + assert_eq!( + router + .into_endpoints() + .map(|x| (x.path, x.method.to_string())) + .collect::>(), + vec![] + ); } #[test] @@ -1653,17 +1744,24 @@ mod test { Method::GET, "/projects/{project_id}/instances", )); - let ret: Vec<_> = router.endpoints(None).map(|x| (x.0, x.1)).collect(); + + let expected = vec![ + ("/".to_string(), "GET".to_string()), + ("/projects/{project_id}/instances".to_string(), "GET".to_string()), + ]; + assert_eq!( - ret, - vec![ - ("/".to_string(), "GET".to_string(),), - ( - "/projects/{project_id}/instances".to_string(), - "GET".to_string(), - ), - ] + router.endpoints(None).map(|x| (x.0, x.1)).collect::>(), + expected, ); + + assert_eq!( + router + .into_endpoints() + .map(|x| (x.path, x.method.to_string())) + .collect::>(), + expected, + ) } #[test] @@ -1679,13 +1777,22 @@ mod test { Method::POST, "/", )); - let ret: Vec<_> = router.endpoints(None).map(|x| (x.0, x.1)).collect(); + + let expected = vec![ + ("/".to_string(), "GET".to_string()), + ("/".to_string(), "POST".to_string()), + ]; + + assert_eq!( + router.endpoints(None).map(|x| (x.0, x.1)).collect::>(), + expected + ); assert_eq!( - ret, - vec![ - ("/".to_string(), "GET".to_string(),), - ("/".to_string(), "POST".to_string(),), - ] + router + .into_endpoints() + .map(|x| (x.path, x.method.to_string())) + .collect::>(), + expected ); } @@ -1808,4 +1915,122 @@ mod test { } } } + + #[test] + fn test_extend() { + let mut router1 = HttpRouter::new(); + router1.insert(new_endpoint( + new_handler_named("root_get"), + Method::GET, + "/", + )); + + let mut router2 = HttpRouter::new(); + router2.insert(new_endpoint( + new_handler_named("root_post"), + Method::POST, + "/", + )); + + router2.insert(new_endpoint( + new_handler_named("i"), + Method::GET, + "/projects/{project_id}/instances", + )); + + router1.extend(router2); + + let result = + router1.lookup_route_unversioned(&Method::GET, "/".into()).unwrap(); + assert_eq!(result.handler.label(), "root_get"); + + let result = router1 + .lookup_route_unversioned(&Method::POST, "/".into()) + .unwrap(); + assert_eq!(result.handler.label(), "root_post"); + + let result = router1 + .lookup_route_unversioned( + &Method::GET, + "/projects/p1/instances".into(), + ) + .unwrap(); + assert_eq!(result.handler.label(), "i"); + } + + #[test] + fn test_extend_variables() { + // Extend where the other router uses path variables. + let mut router1 = HttpRouter::new(); + router1.insert(new_endpoint( + new_handler_named("h1"), + Method::GET, + "/projects", + )); + + let mut router2 = HttpRouter::new(); + router2.insert(new_endpoint( + new_handler_named("h2"), + Method::GET, + "/projects/{project_id}/instances", + )); + + router1.extend(router2); + + let result = router1 + .lookup_route_unversioned(&Method::GET, "/projects".into()) + .unwrap(); + assert_eq!(result.handler.label(), "h1"); + + let result = router1 + .lookup_route_unversioned( + &Method::GET, + "/projects/p1/instances".into(), + ) + .unwrap(); + assert_eq!(result.handler.label(), "h2"); + assert_eq!( + *result.endpoint.variables.get("project_id").unwrap(), + VariableValue::String("p1".to_string()) + ); + } + + #[test] + fn test_extend_empty() { + let mut router1 = HttpRouter::new(); + let mut router2 = HttpRouter::new(); + router2.insert(new_endpoint( + new_handler_named("h1"), + Method::GET, + "/foo", + )); + + router1.extend(router2); + + let result = router1 + .lookup_route_unversioned(&Method::GET, "/foo".into()) + .unwrap(); + assert_eq!(result.handler.label(), "h1"); + } + + #[test] + #[should_panic(expected = "URI path \"/foo\": attempted to create \ + duplicate route for method \"GET\"")] + fn test_extend_conflict() { + let mut router1 = HttpRouter::new(); + router1.insert(new_endpoint( + new_handler_named("h1"), + Method::GET, + "/foo", + )); + + let mut router2 = HttpRouter::new(); + router2.insert(new_endpoint( + new_handler_named("h2"), + Method::GET, + "/foo", + )); + + router1.extend(router2); + } } diff --git a/dropshot/src/server.rs b/dropshot/src/server.rs index 0bda75fb7..57a251557 100644 --- a/dropshot/src/server.rs +++ b/dropshot/src/server.rs @@ -159,7 +159,7 @@ impl HttpServerStarter { fn new_internal( config: &ConfigDropshot, - api: ApiDescription, + router: HttpRouter, private: C, log: &Logger, tls: Option, @@ -207,7 +207,6 @@ impl HttpServerStarter { .transpose()?; let handler_waitgroup = WaitGroup::new(); - let router = api.into_router(); if let VersionPolicy::Unversioned = version_policy { if router.has_versioned_routes() { return Err(BuildError::UnversionedServerHasVersionedRoutes); @@ -1141,7 +1140,7 @@ pub struct ServerBuilder { // required caller-provided values private: C, log: Logger, - api: DebugIgnore>, + router: HttpRouter, // optional caller-provided values config: ConfigDropshot, @@ -1164,13 +1163,38 @@ impl ServerBuilder { ServerBuilder { private, log, - api: DebugIgnore(api), + router: api.into_router(), + config: Default::default(), + version_policy: VersionPolicy::Unversioned, + tls: Default::default(), + } + } + + /// Create an empty Dropshot server + /// + /// The server will have no routes. You can add routes with + /// [`ServerBuilder::routes`]. + pub fn empty(private: C, log: Logger) -> ServerBuilder { + ServerBuilder { + private, + log, + router: HttpRouter::new(), config: Default::default(), version_policy: VersionPolicy::Unversioned, tls: Default::default(), } } + /// Add routes from an [ApiDescription] + /// + /// # Panics + /// + /// Panics if the added routes conflict with previously-added routes. + pub fn description(mut self, api: ApiDescription) -> Self { + self.router.extend(api.into_router()); + self + } + /// Specify the server configuration pub fn config(mut self, config: ConfigDropshot) -> Self { self.config = config; @@ -1223,7 +1247,7 @@ impl ServerBuilder { pub fn build_starter(self) -> Result, BuildError> { HttpServerStarter::new_internal( &self.config, - self.api.0, + self.router, self.private, &self.log, self.tls,