diff --git a/nearest-derive/src/emit.rs b/nearest-derive/src/emit.rs index c15a147..7677c35 100644 --- a/nearest-derive/src/emit.rs +++ b/nearest-derive/src/emit.rs @@ -4,7 +4,10 @@ use syn::{Data, DataEnum, DataStruct, DeriveInput, Fields, Variant}; use crate::{ attrs::{parse_field_attrs, parse_variant_attrs}, - util::{capitalize, has_flat_bound, opt_where_clause, to_snake_case}, + util::{ + capitalize, combine_where, flat_bounded_param_names, has_flat_bound, opt_where_clause, + to_snake_case, + }, }; // --------------------------------------------------------------------------- @@ -176,10 +179,25 @@ fn builder_generics( // Structs // --------------------------------------------------------------------------- +/// Generate a where clause that adds `Flat` bounds to unbounded type parameters. +fn outer_where_clause(generics: &syn::Generics) -> TokenStream { + let already_bounded = flat_bounded_param_names(generics); + let preds: Vec<_> = generics + .type_params() + .filter(|tp| !already_bounded.contains(&tp.ident.to_string())) + .map(|tp| { + let ident = &tp.ident; + quote! { #ident: ::nearest::Flat } + }) + .collect(); + combine_where(generics.where_clause.as_ref(), &preds) +} + fn gen_emit_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream { let name = &input.ident; let vis = &input.vis; - let (impl_generics, ty_generics, input_where) = input.generics.split_for_impl(); + let (impl_generics, ty_generics, _) = input.generics.split_for_impl(); + let outer_where = outer_where_clause(&input.generics); match &data.fields { Fields::Named(named) => { @@ -234,7 +252,7 @@ fn gen_emit_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream { }; quote! { - impl #impl_generics #name #ty_generics #input_where { + impl #impl_generics #name #ty_generics #outer_where { #vis fn make #fn_make_generics (#(#fn_params),*) -> impl ::nearest::Emit { struct __Builder #bsg { #bpf @@ -308,7 +326,7 @@ fn gen_emit_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream { }; quote! { - impl #impl_generics #name #ty_generics #input_where { + impl #impl_generics #name #ty_generics #outer_where { #vis fn make #fn_make_generics (#(#fn_params),*) -> impl ::nearest::Emit { struct __Builder #bsg { #bpf @@ -340,7 +358,8 @@ fn gen_emit_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream { fn gen_emit_enum(input: &DeriveInput, data: &DataEnum) -> TokenStream { let name = &input.ident; let vis = &input.vis; - let (impl_generics, ty_generics, input_where) = input.generics.split_for_impl(); + let (impl_generics, ty_generics, _) = input.generics.split_for_impl(); + let outer_where = outer_where_clause(&input.generics); let methods: Vec<_> = data .variants @@ -350,7 +369,7 @@ fn gen_emit_enum(input: &DeriveInput, data: &DataEnum) -> TokenStream { .collect(); quote! { - impl #impl_generics #name #ty_generics #input_where { + impl #impl_generics #name #ty_generics #outer_where { #(#methods)* } }