From a5bf501ebd86912c5e404b4b7c7a4d7e0b85a242 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20K=C3=A4nner?= Date: Mon, 19 Aug 2024 22:26:43 +0200 Subject: [PATCH] implement enum variant generation --- src/lib.rs | 29 ++++++ src/parser.rs | 255 +++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 283 insertions(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 80037f2..544814a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,12 +2,41 @@ //! extra information about the field. use darling::{util::PathList, FromMeta}; +use parser::Enum; use proc_macro::TokenStream; use quote::quote; use syn::{parse_macro_input, parse_quote, DataStruct, DeriveInput}; mod parser; +/// Implements `miniconf::Tree` for the given enum. +/// +/// This implementation of `miniconf::Tree` adds a value and a variants path. +/// The value path serializes the enum. +/// The variants path serializes all possible variants as an array. +/// Optionally default and description paths can be generated. +/// The description is generated from the docstring describing the enum. +/// +/// # Example +/// ``` +/// use macroconf::ConfigEnum; +/// use serde::{Serialize, Deserialize}; +/// +/// /// Description +/// #[derive(Default, ConfigEnum, Serialize, Deserialize)] +/// enum Test { +/// #[default] +/// Variant1, +/// Variant2, +/// Variant3, +/// } +/// ``` +#[proc_macro_derive(ConfigEnum, attributes(default))] +pub fn config_enum(item: TokenStream) -> TokenStream { + let input = parse_macro_input!(item as Enum); + input.generate_tree().into() +} + /// Creates structs for the values to extend them with extra metadata. /// /// supported metadata is `min`, `max` and `default`. Doc comments are parsed as `description` diff --git a/src/parser.rs b/src/parser.rs index 3ba7d67..dabd9e1 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -10,7 +10,7 @@ use darling::{ }; use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote}; -use syn::{parse_quote, spanned::Spanned}; +use syn::{braced, parse::Parse, parse_quote, spanned::Spanned, Token}; #[derive(Debug, FromField)] #[darling(attributes(config))] @@ -515,3 +515,256 @@ impl Config { &mut fields.fields } } + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Enum { + doc: Option, + ident: syn::Ident, + variants: Vec, + default: Option, +} + +impl Enum { + pub fn generate_tree(&self) -> TokenStream { + let mut tokens = self.generate_tree_key(); + tokens.extend(self.generate_tree_serialize()); + tokens.extend(self.generate_tree_deserialize()); + tokens.extend(self.generate_tree_any()); + tokens + } + + fn keys(&self) -> Vec<(&'static str, syn::Expr)> { + let variants = &self.variants; + let mut keys = vec![ + ("value", parse_quote!(self)), + ("variants", parse_quote!([#(Self::#variants,)*])), + ]; + if let Some(ref doc) = self.doc { + keys.push(("description", parse_quote!(#doc))); + } + if let Some(ref default) = self.default { + keys.push(("default", parse_quote!(Self::#default))); + } + keys + } + + fn generate_tree_key(&self) -> TokenStream { + let ident = &self.ident; + let keys = self.keys(); + let num_keys = keys.len(); + let max_length = keys.iter().map(|(path, _)| path.len()).max(); + let keys = keys.iter().map(|(path, _)| path); + quote! { + impl ::miniconf::KeyLookup for #ident { + const LEN: usize = #num_keys; + const NAMES: &'static [&'static str] = &[#(#keys,)*]; + + fn name_to_index(value: &str) -> Option { + Self::NAMES.iter().position(|name| *name == value) + } + } + + impl ::miniconf::TreeKey<1> for #ident { + fn metadata() -> ::miniconf::Metadata { + let mut metadata = ::miniconf::Metadata::default(); + metadata.max_depth = 1; + metadata.count = #num_keys; + metadata.max_length = #max_length; + metadata + } + + fn traverse_by_key(mut keys: K, mut func: F) -> ::core::result::Result> + where + K: ::miniconf::Keys, + // Writing this to return an iterator instead of using a callback + // would have worse performance (O(n^2) instead of O(n) for matching) + F: FnMut(usize, Option<&'static str>, usize) -> ::core::result::Result<(), E>, + { + let ::core::result::Result::Ok(key) = keys.next::() else { return ::core::result::Result::Ok(0) }; + let index = ::miniconf::Key::find::(&key).ok_or(::miniconf::Traversal::NotFound(1))?; + let name = ::NAMES + .get(index) + .ok_or(::miniconf::Traversal::NotFound(1))?; + func(index, Some(name), #num_keys).map_err(|err| ::miniconf::Error::Inner(1, err))?; + ::core::result::Result::Ok(1) + } + } + } + } + + fn generate_tree_serialize(&self) -> TokenStream { + let ident = &self.ident; + let matches = self + .keys() + .iter() + .enumerate() + .map(|(i, (_, expr))| { + quote! { + #i => ::serde::Serialize::serialize(&#expr, ser).map_err(|err| ::miniconf::Error::Inner(0, err)), + } + }) + .collect::>(); + quote! { + impl ::miniconf::TreeSerialize<1> for #ident { + fn serialize_by_key( + &self, + mut keys: K, + ser: S, + ) -> ::core::result::Result> + where + K: ::miniconf::Keys, + S: ::serde::Serializer, + { + let ::core::result::Result::Ok(key) = keys.next::() else { + return ::serde::Serialize::serialize(self, ser).map_err(|err| ::miniconf::Error::Inner(0, err)).map(|_| 0); + }; + let index = ::miniconf::Key::find::(&key).ok_or(::miniconf::Traversal::NotFound(0))?; + if !keys.finalize() { + return ::core::result::Result::Err(::miniconf::Traversal::TooLong(0).into()); + } + match index { + #(#matches)* + _ => unreachable!(), + }?; + Ok(0) + } + } + } + } + + fn generate_tree_deserialize(&self) -> TokenStream { + let ident = &self.ident; + let num_keys = self.keys().len(); + quote! { + impl<'de> ::miniconf::TreeDeserialize<'de, 1> for #ident { + fn deserialize_by_key( + &mut self, + mut keys: K, + de: D, + ) -> ::core::result::Result> + where + K: ::miniconf::Keys, + D: ::serde::Deserializer<'de>, + { + let ::core::result::Result::Ok(key) = keys.next::() else { + ::deserialize(de).map_err(|err| ::miniconf::Error::Inner(0, err))?; + return ::core::result::Result::Ok(0); + }; + let index = ::miniconf::Key::find::(&key).ok_or(::miniconf::Traversal::NotFound(1))?; + if !keys.finalize() { + return ::core::result::Result::Err(::miniconf::Traversal::TooLong(1).into()); + } + match index { + 0 => { + ::deserialize(de).map_err(|err| ::miniconf::Error::Inner(0, err))?; + Ok(0) + } + 1..=#num_keys => ::core::result::Result::Err(::miniconf::Traversal::Access(0, "Cannot write limits").into()), + _ => unreachable!(), + } + } + } + } + } + + fn generate_tree_any(&self) -> TokenStream { + let ident = &self.ident; + let num_keys = self.keys().len(); + quote! { + impl ::miniconf::TreeAny<1> for #ident { + fn ref_any_by_key(&self, mut keys: K) -> ::core::result::Result<&dyn ::core::any::Any, ::miniconf::Traversal> + where + K: ::miniconf::Keys, + { + let ::core::result::Result::Ok(key) = keys.next::() else { + return ::core::result::Result::Ok(self); + }; + let index = ::miniconf::Key::find::(&key).ok_or(::miniconf::Traversal::NotFound(1))?; + if !keys.finalize() { + return ::core::result::Result::Err(::miniconf::Traversal::TooLong(1)); + } + match index { + 0 => ::core::result::Result::Ok(self), + 1..#num_keys => ::core::result::Result::Err(::miniconf::Traversal::Access(1, "cannot return reference to local variable")), + _ => unreachable!(), + } + } + + fn mut_any_by_key(&mut self, mut keys: K) -> ::core::result::Result<&mut dyn ::core::any::Any, ::miniconf::Traversal> + where + K: ::miniconf::Keys, + { + let ::core::result::Result::Ok(key) = keys.next::() else { + return Ok(self); + }; + let index = ::miniconf::Key::find::(&key).ok_or(::miniconf::Traversal::NotFound(1))?; + if !keys.finalize() { + return ::core::result::Result::Err(::miniconf::Traversal::TooLong(1)); + } + match index { + 0 => Ok(self), + 1..#num_keys => ::core::result::Result::Err(::miniconf::Traversal::Access(1, "cannot return reference to local variable")), + _ => unreachable!(), + } + } + } + } + } +} + +impl Parse for Enum { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let attrs = input.call(syn::Attribute::parse_outer)?; + let doc = attrs + .iter() + .find(|attr| attr.path().is_ident("doc")) + .map(|attr| { + let meta = attr.meta.require_name_value()?; + if let syn::Expr::Lit(syn::ExprLit { + attrs: _, + lit: syn::Lit::Str(ref lit), + }) = meta.value + { + Ok(lit.value().trim().to_owned()) + } else { + Err(syn::Error::new_spanned( + &meta.value, + "Expected string literal for doc comment", + )) + } + }) + .transpose()?; + + let _vis = input.parse::()?; + let _enum_token = input.parse::()?; + + let ident = input.parse::()?; + + let content; + let _brace = braced!(content in input); + let variants = content.parse_terminated(syn::Variant::parse, Token![,])?; + if let Some(variant) = variants.iter().find(|variant| !variant.fields.is_empty()) { + return Err(syn::Error::new_spanned( + &variant.fields, + "only unit variants are supported for now", + )); + } + let default = variants + .iter() + .find(|variant| { + variant + .attrs + .iter() + .any(|attr| attr.path().is_ident("default")) + }) + .map(|variant| variant.ident.clone()); + let variants = variants.into_iter().map(|variant| variant.ident).collect(); + + Ok(Self { + doc, + ident, + variants, + default, + }) + } +}