1#![recursion_limit = "1024"]
2
3extern crate proc_macro;
4
5use quote::quote_spanned;
6use swc_macros_common::prelude::*;
7use syn::{parse::Parse, *};
8
9#[proc_macro_derive(StringEnum, attributes(string_enum))]
65pub fn derive_string_enum(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
66 let input = syn::parse::<syn::DeriveInput>(input).expect("failed to parse derive input");
67 let mut tts = TokenStream::new();
68
69 make_as_str(&input).to_tokens(&mut tts);
70 make_from_str(&input).to_tokens(&mut tts);
71
72 make_serialize(&input).to_tokens(&mut tts);
73 make_deserialize(&input).to_tokens(&mut tts);
74
75 derive_fmt(&input, quote_spanned!(Span::call_site() => std::fmt::Debug)).to_tokens(&mut tts);
76 derive_fmt(
77 &input,
78 quote_spanned!(Span::call_site() => std::fmt::Display),
79 )
80 .to_tokens(&mut tts);
81
82 print("derive(StringEnum)", tts)
83}
84
85fn derive_fmt(i: &DeriveInput, trait_path: TokenStream) -> ItemImpl {
86 let ty = &i.ident;
87
88 let item: ItemImpl = parse_quote!(
89 impl #trait_path for #ty {
90 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
91 let s = self.as_str();
92 #trait_path::fmt(s, f)
93 }
94 }
95 );
96
97 item.with_generics(i.generics.clone())
98}
99
100fn get_str_value(attrs: &[Attribute]) -> String {
101 let docs: Vec<_> = attrs.iter().filter_map(doc_str).collect();
103 for raw_line in docs {
104 let line = raw_line.trim();
105 if line.starts_with('`') && line.ends_with('`') {
106 let mut s: String = line.split_at(1).1.into();
107 let new_len = s.len() - 1;
108 s.truncate(new_len);
109 return s;
110 }
111 }
112
113 panic!("StringEnum: Cannot determine string value of this variant")
114}
115
116fn make_from_str(i: &DeriveInput) -> ItemImpl {
117 let arms = Binder::new_from(i)
118 .variants()
119 .into_iter()
120 .map(|v| {
121 let qual_name = v.qual_path();
123
124 let str_value = get_str_value(v.attrs());
125
126 let mut pat: Pat = Pat::Lit(ExprLit {
127 attrs: Default::default(),
128 lit: Lit::Str(LitStr::new(&str_value, Span::call_site())),
129 });
130
131 for attr in v
133 .attrs()
134 .iter()
135 .filter(|attr| is_attr_name(attr, "string_enum"))
136 {
137 if let Meta::List(meta) = &attr.meta {
138 let mut cases = Punctuated::default();
139
140 cases.push(pat);
141
142 for item in parse2::<FieldAttr>(meta.tokens.clone())
143 .expect("failed to parse `#[string_enum]`")
144 .aliases
145 {
146 cases.push(Pat::Lit(PatLit {
147 attrs: Default::default(),
148 lit: Lit::Str(item.alias),
149 }));
150 }
151
152 pat = Pat::Or(PatOr {
153 attrs: Default::default(),
154 leading_vert: None,
155 cases,
156 });
157 continue;
158 }
159
160 panic!("Unsupported meta: {:#?}", attr.meta);
161 }
162
163 let body = match *v.data() {
164 Fields::Unit => Box::new(parse_quote!(return Ok(#qual_name))),
165 _ => unreachable!("StringEnum requires all variants not to have fields"),
166 };
167
168 Arm {
169 body,
170 attrs: v
171 .attrs()
172 .iter()
173 .filter(|attr| is_attr_name(attr, "cfg"))
174 .cloned()
175 .collect(),
176 pat,
177 guard: None,
178 fat_arrow_token: Default::default(),
179 comma: Some(Token)),
180 }
181 })
182 .chain(::std::iter::once(parse_quote!(_ => Err(()))))
183 .collect();
184
185 let body = Expr::Match(ExprMatch {
186 attrs: Default::default(),
187 match_token: Default::default(),
188 brace_token: Default::default(),
189 expr: Box::new(parse_quote!(s)),
190 arms,
191 });
192
193 let ty = &i.ident;
194 let item: ItemImpl = parse_quote!(
195 impl ::std::str::FromStr for #ty {
196 type Err = ();
197
198 fn from_str(s: &str) -> Result<Self, ()> {
199 #body
200 }
201 }
202 );
203 item.with_generics(i.generics.clone())
204}
205
206fn make_as_str(i: &DeriveInput) -> ItemImpl {
207 let arms = Binder::new_from(i)
208 .variants()
209 .into_iter()
210 .map(|v| {
211 let qual_name = v.qual_path();
213
214 let str_value = get_str_value(v.attrs());
215
216 let body = Box::new(parse_quote!(return #str_value));
217
218 let pat = match *v.data() {
219 Fields::Unit => Box::new(Pat::Path(PatPath {
220 qself: None,
221 path: qual_name,
222 attrs: Default::default(),
223 })),
224 _ => Box::new(Pat::Struct(PatStruct {
225 attrs: Default::default(),
226 qself: None,
227 path: qual_name,
228 brace_token: Default::default(),
229 fields: Default::default(),
230 rest: Some(PatRest {
231 attrs: Default::default(),
232 dot2_token: Default::default(),
233 }),
234 })),
235 };
236
237 Arm {
238 body,
239 attrs: v
240 .attrs()
241 .iter()
242 .filter(|attr| is_attr_name(attr, "cfg"))
243 .cloned()
244 .collect(),
245 pat: Pat::Reference(PatReference {
246 and_token: Default::default(),
247 mutability: None,
248 pat,
249 attrs: Default::default(),
250 }),
251 guard: None,
252 fat_arrow_token: Default::default(),
253 comma: Some(Token)),
254 }
255 })
256 .collect();
257
258 let body = Expr::Match(ExprMatch {
259 attrs: Default::default(),
260 match_token: Default::default(),
261 brace_token: Default::default(),
262 expr: Box::new(parse_quote!(self)),
263 arms,
264 });
265
266 let ty = &i.ident;
267 let as_str = make_as_str_ident();
268 let item: ItemImpl = parse_quote!(
269 impl #ty {
270 pub fn #as_str(&self) -> &'static str {
271 #body
272 }
273 }
274 );
275
276 item.with_generics(i.generics.clone())
277}
278
279fn make_as_str_ident() -> Ident {
280 Ident::new("as_str", call_site())
281}
282
283fn make_serialize(i: &DeriveInput) -> ItemImpl {
284 let ty = &i.ident;
285 let item: ItemImpl = parse_quote!(
286 #[cfg(feature = "serde")]
287 impl ::serde::Serialize for #ty {
288 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
289 where
290 S: ::serde::Serializer,
291 {
292 serializer.serialize_str(self.as_str())
293 }
294 }
295 );
296
297 item.with_generics(i.generics.clone())
298}
299
300fn make_deserialize(i: &DeriveInput) -> ItemImpl {
301 let ty = &i.ident;
302 let item: ItemImpl = parse_quote!(
303 #[cfg(feature = "serde")]
304 impl<'de> ::serde::Deserialize<'de> for #ty {
305 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
306 where
307 D: ::serde::Deserializer<'de>,
308 {
309 struct StrVisitor;
310
311 impl<'de> ::serde::de::Visitor<'de> for StrVisitor {
312 type Value = #ty;
313
314 fn expecting(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
315 write!(f, "one of (TODO)")
317 }
318
319 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
320 where
321 E: ::serde::de::Error,
322 {
323 value.parse().map_err(|()| E::unknown_variant(value, &[]))
325 }
326 }
327
328 deserializer.deserialize_str(StrVisitor)
329 }
330 }
331 );
332
333 item.with_generics(i.generics.clone())
334}
335
336struct FieldAttr {
337 aliases: Punctuated<FieldAttrItem, Token![,]>,
338}
339
340impl Parse for FieldAttr {
341 fn parse(input: parse::ParseStream) -> Result<Self> {
342 Ok(Self {
343 aliases: input.call(Punctuated::parse_terminated)?,
344 })
345 }
346}
347
348struct FieldAttrItem {
350 alias: LitStr,
351}
352
353impl Parse for FieldAttrItem {
354 fn parse(input: parse::ParseStream) -> Result<Self> {
355 let name: Ident = input.parse()?;
356
357 assert!(
358 name == "alias",
359 "#[derive(StringEnum) only supports `#[string_enum(alias(\"text\"))]]"
360 );
361
362 let alias;
363 parenthesized!(alias in input);
364
365 Ok(Self {
366 alias: alias.parse()?,
367 })
368 }
369}