Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
259 changes: 242 additions & 17 deletions dropshot/src/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,15 @@ impl<Context: ServerContext> HttpRouter<Context> {
self.has_versioned_routes
}

/// Merge all routes from `other` into `self`.
///
/// This handles conflicts identically to `insert` (ie it panics).
pub fn extend(&mut self, other: HttpRouter<Context>) {
for endpoint in other.into_endpoints() {
self.insert(endpoint);
}
}

#[cfg(test)]
pub fn lookup_route_unversioned(
&self,
Expand Down Expand Up @@ -573,6 +582,12 @@ impl<Context: ServerContext> HttpRouter<Context> {
) -> HttpRouterIter<'a, Context> {
HttpRouterIter::new(self, version)
}

/// Consumes this router and return an iterator over all registered
/// endpoints.
pub fn into_endpoints(self) -> IntoEndpoints<Context> {
IntoEndpoints::new(self)
}
}

/// Given a list of handlers, return the first one matching the given semver
Expand Down Expand Up @@ -607,6 +622,73 @@ fn insert_var(
varnames.insert(new_varname.clone());
}

/// A consuming iterator over all endpoints in an `HttpRouter`.
pub struct IntoEndpoints<Context: ServerContext> {
/// Remaining nodes to visit.
nodes: Vec<HttpRouterNode<Context>>,
/// Remaining method groups from the current node.
methods: std::collections::btree_map::IntoValues<
String,
Vec<ApiEndpoint<Context>>,
>,
/// Remaining endpoints from the current group.
current: std::vec::IntoIter<ApiEndpoint<Context>>,
}

impl<C: ServerContext> IntoEndpoints<C> {
/// Create a new `IntoEndpoints` iterator starting at the root of the tree.
fn new(router: HttpRouter<C>) -> Self {
let nodes = match router.root.edges {
None => vec![],
Some(edges) => match edges {
HttpRouterEdges::Literals(map) => {
map.into_values().map(|v| *v).collect()
}
HttpRouterEdges::VariableSingle(_, child)
| HttpRouterEdges::VariableRest(_, child) => {
vec![*child]
}
},
};

let mut methods = router.root.methods.into_values();
let current = methods.next().unwrap_or_default().into_iter();

Self { nodes, methods, current }
}
}

impl<Context: ServerContext> Iterator for IntoEndpoints<Context> {
type Item = ApiEndpoint<Context>;

fn next(&mut self) -> Option<Self::Item> {
loop {
if let Some(endpoint) = self.current.next() {
return Some(endpoint);
}

if let Some(handlers) = self.methods.next() {
self.current = handlers.into_iter();
continue;
}

let node = self.nodes.pop()?;
match node.edges {
None => (),
Some(HttpRouterEdges::Literals(map)) => {
self.nodes.extend(map.into_values().map(|n| *n))
}
Some(HttpRouterEdges::VariableSingle(_, child))
| Some(HttpRouterEdges::VariableRest(_, child)) => {
self.nodes.push(*child);
}
}

self.methods = node.methods.into_values();
}
}
}

/// Route Interator implementation. We perform a preorder, depth first traversal
/// of the tree starting from the root node. For each node, we enumerate the
/// methods and then descend into its children (or single child in the case of
Expand Down Expand Up @@ -1636,8 +1718,17 @@ mod test {
#[test]
fn test_iter_null() {
let router = HttpRouter::<()>::new();
let ret: Vec<_> = router.endpoints(None).map(|x| (x.0, x.1)).collect();
assert_eq!(ret, vec![]);
assert_eq!(
router.endpoints(None).map(|x| (x.0, x.1)).collect::<Vec<_>>(),
vec![]
);
assert_eq!(
router
.into_endpoints()
.map(|x| (x.path, x.method.to_string()))
.collect::<Vec<_>>(),
vec![]
);
}

#[test]
Expand All @@ -1653,17 +1744,24 @@ mod test {
Method::GET,
"/projects/{project_id}/instances",
));
let ret: Vec<_> = router.endpoints(None).map(|x| (x.0, x.1)).collect();

let expected = vec![
("/".to_string(), "GET".to_string()),
("/projects/{project_id}/instances".to_string(), "GET".to_string()),
];

assert_eq!(
ret,
vec![
("/".to_string(), "GET".to_string(),),
(
"/projects/{project_id}/instances".to_string(),
"GET".to_string(),
),
]
router.endpoints(None).map(|x| (x.0, x.1)).collect::<Vec<_>>(),
expected,
);

assert_eq!(
router
.into_endpoints()
.map(|x| (x.path, x.method.to_string()))
.collect::<Vec<_>>(),
expected,
)
}

#[test]
Expand All @@ -1679,13 +1777,22 @@ mod test {
Method::POST,
"/",
));
let ret: Vec<_> = router.endpoints(None).map(|x| (x.0, x.1)).collect();

let expected = vec![
("/".to_string(), "GET".to_string()),
("/".to_string(), "POST".to_string()),
];

assert_eq!(
router.endpoints(None).map(|x| (x.0, x.1)).collect::<Vec<_>>(),
expected
);
assert_eq!(
ret,
vec![
("/".to_string(), "GET".to_string(),),
("/".to_string(), "POST".to_string(),),
]
router
.into_endpoints()
.map(|x| (x.path, x.method.to_string()))
.collect::<Vec<_>>(),
expected
);
}

Expand Down Expand Up @@ -1808,4 +1915,122 @@ mod test {
}
}
}

#[test]
fn test_extend() {
let mut router1 = HttpRouter::new();
router1.insert(new_endpoint(
new_handler_named("root_get"),
Method::GET,
"/",
));

let mut router2 = HttpRouter::new();
router2.insert(new_endpoint(
new_handler_named("root_post"),
Method::POST,
"/",
));

router2.insert(new_endpoint(
new_handler_named("i"),
Method::GET,
"/projects/{project_id}/instances",
));

router1.extend(router2);

let result =
router1.lookup_route_unversioned(&Method::GET, "/".into()).unwrap();
assert_eq!(result.handler.label(), "root_get");

let result = router1
.lookup_route_unversioned(&Method::POST, "/".into())
.unwrap();
assert_eq!(result.handler.label(), "root_post");

let result = router1
.lookup_route_unversioned(
&Method::GET,
"/projects/p1/instances".into(),
)
.unwrap();
assert_eq!(result.handler.label(), "i");
}

#[test]
fn test_extend_variables() {
// Extend where the other router uses path variables.
let mut router1 = HttpRouter::new();
router1.insert(new_endpoint(
new_handler_named("h1"),
Method::GET,
"/projects",
));

let mut router2 = HttpRouter::new();
router2.insert(new_endpoint(
new_handler_named("h2"),
Method::GET,
"/projects/{project_id}/instances",
));

router1.extend(router2);

let result = router1
.lookup_route_unversioned(&Method::GET, "/projects".into())
.unwrap();
assert_eq!(result.handler.label(), "h1");

let result = router1
.lookup_route_unversioned(
&Method::GET,
"/projects/p1/instances".into(),
)
.unwrap();
assert_eq!(result.handler.label(), "h2");
assert_eq!(
*result.endpoint.variables.get("project_id").unwrap(),
VariableValue::String("p1".to_string())
);
}

#[test]
fn test_extend_empty() {
let mut router1 = HttpRouter::new();
let mut router2 = HttpRouter::new();
router2.insert(new_endpoint(
new_handler_named("h1"),
Method::GET,
"/foo",
));

router1.extend(router2);

let result = router1
.lookup_route_unversioned(&Method::GET, "/foo".into())
.unwrap();
assert_eq!(result.handler.label(), "h1");
}

#[test]
#[should_panic(expected = "URI path \"/foo\": attempted to create \
duplicate route for method \"GET\"")]
fn test_extend_conflict() {
let mut router1 = HttpRouter::new();
router1.insert(new_endpoint(
new_handler_named("h1"),
Method::GET,
"/foo",
));

let mut router2 = HttpRouter::new();
router2.insert(new_endpoint(
new_handler_named("h2"),
Method::GET,
"/foo",
));

router1.extend(router2);
}
}
34 changes: 29 additions & 5 deletions dropshot/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ impl<C: ServerContext> HttpServerStarter<C> {

fn new_internal(
config: &ConfigDropshot,
api: ApiDescription<C>,
router: HttpRouter<C>,
private: C,
log: &Logger,
tls: Option<ConfigTls>,
Expand Down Expand Up @@ -207,7 +207,6 @@ impl<C: ServerContext> HttpServerStarter<C> {
.transpose()?;
let handler_waitgroup = WaitGroup::new();

let router = api.into_router();
if let VersionPolicy::Unversioned = version_policy {
if router.has_versioned_routes() {
return Err(BuildError::UnversionedServerHasVersionedRoutes);
Expand Down Expand Up @@ -1141,7 +1140,7 @@ pub struct ServerBuilder<C: ServerContext> {
// required caller-provided values
private: C,
log: Logger,
api: DebugIgnore<ApiDescription<C>>,
router: HttpRouter<C>,

// optional caller-provided values
config: ConfigDropshot,
Expand All @@ -1164,13 +1163,38 @@ impl<C: ServerContext> ServerBuilder<C> {
ServerBuilder {
private,
log,
api: DebugIgnore(api),
router: api.into_router(),
config: Default::default(),
version_policy: VersionPolicy::Unversioned,
tls: Default::default(),
}
}

/// Create an empty Dropshot server
///
/// The server will have no routes. You can add routes with
/// [`ServerBuilder::routes`].
pub fn empty(private: C, log: Logger) -> ServerBuilder<C> {
ServerBuilder {
private,
log,
router: HttpRouter::new(),
config: Default::default(),
version_policy: VersionPolicy::Unversioned,
tls: Default::default(),
}
}

/// Add routes from an [ApiDescription]
///
/// # Panics
///
/// Panics if the added routes conflict with previously-added routes.
pub fn description(mut self, api: ApiDescription<C>) -> Self {
self.router.extend(api.into_router());
self
}

/// Specify the server configuration
pub fn config(mut self, config: ConfigDropshot) -> Self {
self.config = config;
Expand Down Expand Up @@ -1223,7 +1247,7 @@ impl<C: ServerContext> ServerBuilder<C> {
pub fn build_starter(self) -> Result<HttpServerStarter<C>, BuildError> {
HttpServerStarter::new_internal(
&self.config,
self.api.0,
self.router,
self.private,
&self.log,
self.tls,
Expand Down