// string_enum/
// lib.rs

1#![recursion_limit = "1024"]
2
3extern crate proc_macro;
4
5use quote::quote_spanned;
6use swc_macros_common::prelude::*;
7use syn::{parse::Parse, *};
8
9/// Creates `.as_str()` and then implements `Debug` and `Display` using it.
10///
11///# Input
12/// Enum with \`str_value\`-style **doc** comment for each variant.
13///
14/// e.g.
15///
16///```no_run
17/// pub enum BinOp {
18///     /// `+`
19///     Add,
20///     /// `-`
21///     Minus,
22/// }
23/// ```
24///
25/// Currently, \`str_value\` must be live in it's own line.
26///
27///# Output
28///
29///  - `pub fn as_str(&self) -> &'static str`
30///  - `impl serde::Serialize` with `cfg(feature = "serde")`
31///  - `impl serde::Deserialize` with `cfg(feature = "serde")`
32///  - `impl FromStr`
33///  - `impl Debug`
34///  - `impl Display`
35///
36///# Example
37///
38///
39///```
40/// #[macro_use]
41/// extern crate string_enum;
42/// extern crate serde;
43///
44/// #[derive(StringEnum)]
45/// pub enum Tokens {
46///     /// `a`
47///     A,
48///     /// `bar`
49///     B,
50/// }
51/// # fn main() {
52///
53/// assert_eq!(Tokens::A.as_str(), "a");
54/// assert_eq!(Tokens::B.as_str(), "bar");
55///
56/// assert_eq!(Tokens::A.to_string(), "a");
57/// assert_eq!(format!("{:?}", Tokens::A), format!("{:?}", "a"));
58///
59/// # }
60/// ```
61///
62///
63/// All formatting flags are handled correctly.
64#[proc_macro_derive(StringEnum, attributes(string_enum))]
65pub fn derive_string_enum(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
66    let input = syn::parse::<syn::DeriveInput>(input).expect("failed to parse derive input");
67    let mut tts = TokenStream::new();
68
69    make_as_str(&input).to_tokens(&mut tts);
70    make_from_str(&input).to_tokens(&mut tts);
71
72    make_serialize(&input).to_tokens(&mut tts);
73    make_deserialize(&input).to_tokens(&mut tts);
74
75    derive_fmt(&input, quote_spanned!(Span::call_site() => std::fmt::Debug)).to_tokens(&mut tts);
76    derive_fmt(
77        &input,
78        quote_spanned!(Span::call_site() => std::fmt::Display),
79    )
80    .to_tokens(&mut tts);
81
82    print("derive(StringEnum)", tts)
83}
84
85fn derive_fmt(i: &DeriveInput, trait_path: TokenStream) -> ItemImpl {
86    let ty = &i.ident;
87
88    let item: ItemImpl = parse_quote!(
89        impl #trait_path for #ty {
90            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
91                let s = self.as_str();
92                #trait_path::fmt(s, f)
93            }
94        }
95    );
96
97    item.with_generics(i.generics.clone())
98}
99
100fn get_str_value(attrs: &[Attribute]) -> String {
101    // TODO: Accept multiline string
102    let docs: Vec<_> = attrs.iter().filter_map(doc_str).collect();
103    for raw_line in docs {
104        let line = raw_line.trim();
105        if line.starts_with('`') && line.ends_with('`') {
106            let mut s: String = line.split_at(1).1.into();
107            let new_len = s.len() - 1;
108            s.truncate(new_len);
109            return s;
110        }
111    }
112
113    panic!("StringEnum: Cannot determine string value of this variant")
114}
115
116fn make_from_str(i: &DeriveInput) -> ItemImpl {
117    let arms = Binder::new_from(i)
118        .variants()
119        .into_iter()
120        .map(|v| {
121            // Qualified path of variant.
122            let qual_name = v.qual_path();
123
124            let str_value = get_str_value(v.attrs());
125
126            let mut pat: Pat = Pat::Lit(ExprLit {
127                attrs: Default::default(),
128                lit: Lit::Str(LitStr::new(&str_value, Span::call_site())),
129            });
130
131            // Handle `string_enum(alias("foo"))`
132            for attr in v
133                .attrs()
134                .iter()
135                .filter(|attr| is_attr_name(attr, "string_enum"))
136            {
137                if let Meta::List(meta) = &attr.meta {
138                    let mut cases = Punctuated::default();
139
140                    cases.push(pat);
141
142                    for item in parse2::<FieldAttr>(meta.tokens.clone())
143                        .expect("failed to parse `#[string_enum]`")
144                        .aliases
145                    {
146                        cases.push(Pat::Lit(PatLit {
147                            attrs: Default::default(),
148                            lit: Lit::Str(item.alias),
149                        }));
150                    }
151
152                    pat = Pat::Or(PatOr {
153                        attrs: Default::default(),
154                        leading_vert: None,
155                        cases,
156                    });
157                    continue;
158                }
159
160                panic!("Unsupported meta: {:#?}", attr.meta);
161            }
162
163            let body = match *v.data() {
164                Fields::Unit => Box::new(parse_quote!(return Ok(#qual_name))),
165                _ => unreachable!("StringEnum requires all variants not to have fields"),
166            };
167
168            Arm {
169                body,
170                attrs: v
171                    .attrs()
172                    .iter()
173                    .filter(|attr| is_attr_name(attr, "cfg"))
174                    .cloned()
175                    .collect(),
176                pat,
177                guard: None,
178                fat_arrow_token: Default::default(),
179                comma: Some(Token![,](def_site())),
180            }
181        })
182        .chain(::std::iter::once(parse_quote!(_ => Err(()))))
183        .collect();
184
185    let body = Expr::Match(ExprMatch {
186        attrs: Default::default(),
187        match_token: Default::default(),
188        brace_token: Default::default(),
189        expr: Box::new(parse_quote!(s)),
190        arms,
191    });
192
193    let ty = &i.ident;
194    let item: ItemImpl = parse_quote!(
195        impl ::std::str::FromStr for #ty {
196            type Err = ();
197
198            fn from_str(s: &str) -> Result<Self, ()> {
199                #body
200            }
201        }
202    );
203    item.with_generics(i.generics.clone())
204}
205
206fn make_as_str(i: &DeriveInput) -> ItemImpl {
207    let arms = Binder::new_from(i)
208        .variants()
209        .into_iter()
210        .map(|v| {
211            // Qualified path of variant.
212            let qual_name = v.qual_path();
213
214            let str_value = get_str_value(v.attrs());
215
216            let body = Box::new(parse_quote!(return #str_value));
217
218            let pat = match *v.data() {
219                Fields::Unit => Box::new(Pat::Path(PatPath {
220                    qself: None,
221                    path: qual_name,
222                    attrs: Default::default(),
223                })),
224                _ => Box::new(Pat::Struct(PatStruct {
225                    attrs: Default::default(),
226                    qself: None,
227                    path: qual_name,
228                    brace_token: Default::default(),
229                    fields: Default::default(),
230                    rest: Some(PatRest {
231                        attrs: Default::default(),
232                        dot2_token: Default::default(),
233                    }),
234                })),
235            };
236
237            Arm {
238                body,
239                attrs: v
240                    .attrs()
241                    .iter()
242                    .filter(|attr| is_attr_name(attr, "cfg"))
243                    .cloned()
244                    .collect(),
245                pat: Pat::Reference(PatReference {
246                    and_token: Default::default(),
247                    mutability: None,
248                    pat,
249                    attrs: Default::default(),
250                }),
251                guard: None,
252                fat_arrow_token: Default::default(),
253                comma: Some(Token![,](def_site())),
254            }
255        })
256        .collect();
257
258    let body = Expr::Match(ExprMatch {
259        attrs: Default::default(),
260        match_token: Default::default(),
261        brace_token: Default::default(),
262        expr: Box::new(parse_quote!(self)),
263        arms,
264    });
265
266    let ty = &i.ident;
267    let as_str = make_as_str_ident();
268    let item: ItemImpl = parse_quote!(
269        impl #ty {
270            pub fn #as_str(&self) -> &'static str {
271                #body
272            }
273        }
274    );
275
276    item.with_generics(i.generics.clone())
277}
278
279fn make_as_str_ident() -> Ident {
280    Ident::new("as_str", call_site())
281}
282
283fn make_serialize(i: &DeriveInput) -> ItemImpl {
284    let ty = &i.ident;
285    let item: ItemImpl = parse_quote!(
286        #[cfg(feature = "serde")]
287        impl ::serde::Serialize for #ty {
288            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
289            where
290                S: ::serde::Serializer,
291            {
292                serializer.serialize_str(self.as_str())
293            }
294        }
295    );
296
297    item.with_generics(i.generics.clone())
298}
299
300fn make_deserialize(i: &DeriveInput) -> ItemImpl {
301    let ty = &i.ident;
302    let item: ItemImpl = parse_quote!(
303        #[cfg(feature = "serde")]
304        impl<'de> ::serde::Deserialize<'de> for #ty {
305            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
306            where
307                D: ::serde::Deserializer<'de>,
308            {
309                struct StrVisitor;
310
311                impl<'de> ::serde::de::Visitor<'de> for StrVisitor {
312                    type Value = #ty;
313
314                    fn expecting(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
315                        // TODO: List strings
316                        write!(f, "one of (TODO)")
317                    }
318
319                    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
320                    where
321                        E: ::serde::de::Error,
322                    {
323                        // TODO
324                        value.parse().map_err(|()| E::unknown_variant(value, &[]))
325                    }
326                }
327
328                deserializer.deserialize_str(StrVisitor)
329            }
330        }
331    );
332
333    item.with_generics(i.generics.clone())
334}
335
336struct FieldAttr {
337    aliases: Punctuated<FieldAttrItem, Token![,]>,
338}
339
340impl Parse for FieldAttr {
341    fn parse(input: parse::ParseStream) -> Result<Self> {
342        Ok(Self {
343            aliases: input.call(Punctuated::parse_terminated)?,
344        })
345    }
346}
347
348/// `alias("text")` in `#[string_enum(alias("text"))]`.
349struct FieldAttrItem {
350    alias: LitStr,
351}
352
353impl Parse for FieldAttrItem {
354    fn parse(input: parse::ParseStream) -> Result<Self> {
355        let name: Ident = input.parse()?;
356
357        assert!(
358            name == "alias",
359            "#[derive(StringEnum) only supports `#[string_enum(alias(\"text\"))]]"
360        );
361
362        let alias;
363        parenthesized!(alias in input);
364
365        Ok(Self {
366            alias: alias.parse()?,
367        })
368    }
369}