ast_node/
enum_deserialize.rs

1use swc_macros_common::prelude::*;
2use syn::{
3    parse::{Parse, ParseStream},
4    *,
5};
6
7struct VariantAttr {
8    tags: Punctuated<Lit, Token![,]>,
9}
10
11impl Parse for VariantAttr {
12    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
13        Ok(VariantAttr {
14            tags: input.call(Punctuated::parse_terminated)?,
15        })
16    }
17}
18
19pub fn expand(
20    DeriveInput {
21        generics,
22        ident,
23        data,
24        ..
25    }: DeriveInput,
26) -> ItemImpl {
27    let data = match data {
28        Data::Enum(data) => data,
29        _ => unreachable!("expand_enum is called with none-enum item"),
30    };
31
32    let mut has_wildcard = false;
33
34    let deserialize = {
35        let mut all_tags: Punctuated<_, token::Comma> = Default::default();
36        let tag_match_arms = data
37            .variants
38            .iter()
39            .filter(|v| !crate::encoding::is_unknown(&v.attrs))
40            .map(|variant| {
41                let field_type = match variant.fields {
42                    Fields::Unnamed(ref fields) => {
43                        assert_eq!(
44                            fields.unnamed.len(),
45                            1,
46                            "#[ast_node] enum cannot contain variant with multiple fields"
47                        );
48
49                        fields.unnamed.last().unwrap().ty.clone()
50                    }
51                    _ => {
52                        unreachable!("#[ast_node] enum cannot contain named fields or unit variant")
53                    }
54                };
55                let tags = variant
56                    .attrs
57                    .iter()
58                    .filter_map(|attr| -> Option<VariantAttr> {
59                        if !is_attr_name(attr, "tag") {
60                            return None;
61                        }
62                        let tokens = match &attr.meta {
63                            Meta::List(meta) => meta.tokens.clone(),
64                            _ => {
65                                panic!("#[tag] attribute must be in form of #[tag(..)]")
66                            }
67                        };
68                        let tags = parse2(tokens).expect("failed to parse #[tag] attribute");
69
70                        Some(tags)
71                    })
72                    .flat_map(|v| v.tags)
73                    .collect::<Punctuated<_, token::Comma>>();
74
75                assert!(
76                    !tags.is_empty(),
77                    "All #[ast_node] enum variants have one or more tag"
78                );
79
80                // TODO: Clean up this code
81                if tags.len() == 1
82                    && match tags.first() {
83                        Some(Lit::Str(s)) => &*s.value() == "*",
84                        _ => false,
85                    }
86                {
87                    has_wildcard = true;
88                } else {
89                    for tag in tags.iter() {
90                        all_tags.push(tag.clone());
91                    }
92                }
93
94                let vi = &variant.ident;
95
96                Arm {
97                    attrs: Default::default(),
98                    pat: Pat::Path(parse_quote!(__TypeVariant::#vi)),
99                    guard: Default::default(),
100                    fat_arrow_token: Token![=>](variant.ident.span()),
101                    body: parse_quote!(
102                        swc_common::private::serde::Result::map(
103                            <#field_type as serde::Deserialize>::deserialize(
104                                swc_common::private::serde::de::ContentDeserializer::<
105                                    __D::Error,
106                                >::new(__content),
107                            ),
108                            Self::#vi,
109                        )
110                    ),
111                    comma: Some(Token![,](variant.ident.span())),
112                }
113            })
114            .collect::<Vec<Arm>>();
115
116        let tag_expr: Expr = {
117            let mut visit_str_arms = Vec::new();
118            let mut visit_bytes_arms = Vec::new();
119
120            for variant in &data.variants {
121                if crate::encoding::is_unknown(&variant.attrs) {
122                    continue;
123                }
124
125                let tags = variant
126                    .attrs
127                    .iter()
128                    .filter_map(|attr| -> Option<VariantAttr> {
129                        if !is_attr_name(attr, "tag") {
130                            return None;
131                        }
132                        let tokens = match &attr.meta {
133                            Meta::List(meta) => meta.tokens.clone(),
134                            _ => {
135                                panic!("#[tag] attribute must be in form of #[tag(..)]")
136                            }
137                        };
138                        let tags = parse2(tokens).expect("failed to parse #[tag] attribute");
139
140                        Some(tags)
141                    })
142                    .flat_map(|v| v.tags)
143                    .collect::<Punctuated<_, token::Comma>>();
144
145                assert!(
146                    !tags.is_empty(),
147                    "All #[ast_node] enum variants have one or more tag"
148                );
149                let (str_pat, bytes_pat) = {
150                    if tags.len() == 1
151                        && match tags.first() {
152                            Some(Lit::Str(s)) => &*s.value() == "*",
153                            _ => false,
154                        }
155                    {
156                        (
157                            Pat::Wild(PatWild {
158                                attrs: Default::default(),
159                                underscore_token: Token![_](variant.ident.span()),
160                            }),
161                            Pat::Wild(PatWild {
162                                attrs: Default::default(),
163                                underscore_token: Token![_](variant.ident.span()),
164                            }),
165                        )
166                    } else {
167                        fn make_pat(lit: Lit) -> (Pat, Pat) {
168                            let s = match lit.clone() {
169                                Lit::Str(s) => s.value(),
170                                _ => {
171                                    unreachable!()
172                                }
173                            };
174                            (
175                                Pat::Lit(PatLit {
176                                    attrs: Default::default(),
177                                    lit,
178                                }),
179                                Pat::Lit(PatLit {
180                                    attrs: Default::default(),
181                                    lit: Lit::ByteStr(LitByteStr::new(s.as_bytes(), call_site())),
182                                }),
183                            )
184                        }
185                        if tags.len() == 1 {
186                            make_pat(tags.into_iter().next().unwrap())
187                        } else {
188                            let mut str_cases = Punctuated::new();
189                            let mut bytes_cases = Punctuated::new();
190
191                            for tag in tags {
192                                let (str_pat, bytes_pat) = make_pat(tag);
193                                str_cases.push(str_pat);
194                                bytes_cases.push(bytes_pat);
195                            }
196
197                            (
198                                Pat::Or(PatOr {
199                                    attrs: Default::default(),
200                                    leading_vert: Default::default(),
201                                    cases: str_cases,
202                                }),
203                                Pat::Or(PatOr {
204                                    attrs: Default::default(),
205                                    leading_vert: Default::default(),
206                                    cases: bytes_cases,
207                                }),
208                            )
209                        }
210                    }
211                };
212                visit_str_arms.push(Arm {
213                    attrs: Default::default(),
214                    pat: str_pat,
215                    guard: None,
216                    fat_arrow_token: Token![=>](variant.ident.span()),
217                    body: {
218                        let vi = &variant.ident;
219
220                        parse_quote!(Ok(__TypeVariant::#vi))
221                    },
222                    comma: Some(Token![,](variant.ident.span())),
223                });
224                visit_bytes_arms.push(Arm {
225                    attrs: Default::default(),
226                    pat: bytes_pat,
227                    guard: None,
228                    fat_arrow_token: Token![=>](variant.ident.span()),
229                    body: {
230                        let vi = &variant.ident;
231
232                        parse_quote!(Ok(__TypeVariant::#vi))
233                    },
234                    comma: Some(Token![,](variant.ident.span())),
235                });
236            }
237
238            if !has_wildcard {
239                visit_str_arms.push(Arm {
240                    attrs: Default::default(),
241                    pat: Pat::Wild(PatWild {
242                        attrs: Default::default(),
243                        underscore_token: Token![_](ident.span()),
244                    }),
245                    guard: None,
246                    fat_arrow_token: Token![=>](ident.span()),
247                    body: parse_quote!(swc_common::private::serde::Err(
248                        serde::de::Error::unknown_variant(__value, VARIANTS,)
249                    )),
250                    comma: Some(Token![,](ident.span())),
251                });
252                visit_bytes_arms.push(Arm {
253                    attrs: Default::default(),
254                    pat: Pat::Wild(PatWild {
255                        attrs: Default::default(),
256                        underscore_token: Token!(_)(ident.span()),
257                    }),
258                    guard: None,
259                    fat_arrow_token: Token![=>](ident.span()),
260                    body: parse_quote!({
261                        let __value = &swc_common::private::serde::from_utf8_lossy(__value);
262                        swc_common::private::serde::Err(serde::de::Error::unknown_variant(
263                            __value, VARIANTS,
264                        ))
265                    }),
266                    comma: Some(Token![,](ident.span())),
267                });
268            }
269
270            let visit_str_body = Expr::Match(ExprMatch {
271                attrs: Default::default(),
272                match_token: Default::default(),
273                expr: parse_quote!(__value),
274                brace_token: Default::default(),
275                arms: visit_str_arms,
276            });
277            let visit_bytes_body = Expr::Match(ExprMatch {
278                attrs: Default::default(),
279                match_token: Default::default(),
280                expr: parse_quote!(__value),
281                brace_token: Default::default(),
282                arms: visit_bytes_arms,
283            });
284
285            parse_quote!({
286                static VARIANTS: &[&str] = &[#all_tags];
287
288                struct __TypeVariantVisitor;
289
290                impl<'de> serde::de::Visitor<'de> for __TypeVariantVisitor {
291                    type Value = __TypeVariant;
292
293                    fn expecting(
294                        &self,
295                        __formatter: &mut swc_common::private::serde::Formatter,
296                    ) -> swc_common::private::serde::fmt::Result {
297                        swc_common::private::serde::Formatter::write_str(
298                            __formatter,
299                            "variant identifier",
300                        )
301                    }
302
303                    fn visit_str<__E>(
304                        self,
305                        __value: &str,
306                    ) -> swc_common::private::serde::Result<Self::Value, __E>
307                    where
308                        __E: serde::de::Error,
309                    {
310                        #visit_str_body
311                    }
312
313                    fn visit_bytes<__E>(
314                        self,
315                        __value: &[u8],
316                    ) -> swc_common::private::serde::Result<Self::Value, __E>
317                    where
318                        __E: serde::de::Error,
319                    {
320                        #visit_bytes_body
321                    }
322                }
323
324                impl<'de> serde::Deserialize<'de> for __TypeVariant {
325                    #[inline]
326                    fn deserialize<__D>(
327                        __deserializer: __D,
328                    ) -> swc_common::private::serde::Result<Self, __D::Error>
329                    where
330                        __D: serde::Deserializer<'de>,
331                    {
332                        serde::Deserializer::deserialize_identifier(
333                            __deserializer,
334                            __TypeVariantVisitor,
335                        )
336                    }
337                }
338
339                let ty = swc_common::serializer::Type::deserialize(
340                    swc_common::private::serde::de::ContentRefDeserializer::<__D::Error>::new(
341                        &__content,
342                    ),
343                )?;
344
345                let __tagged = __TypeVariant::deserialize(
346                    swc_common::private::serde::de::ContentDeserializer::<__D::Error>::new(
347                        swc_common::private::serde::de::Content::Str(&ty.ty),
348                    ),
349                )?;
350
351                __tagged
352            })
353        };
354
355        let match_type_expr = Expr::Match(ExprMatch {
356            attrs: Default::default(),
357            match_token: Default::default(),
358            expr: parse_quote!(__tagged),
359            brace_token: Default::default(),
360            arms: tag_match_arms,
361        });
362
363        let variants: Punctuated<Variant, Token![,]> = {
364            data.variants
365                .iter()
366                .filter(|v| !crate::encoding::is_unknown(&v.attrs))
367                .cloned()
368                .map(|variant| Variant {
369                    attrs: Default::default(),
370                    fields: Fields::Unit,
371                    ..variant
372                })
373                .collect()
374        };
375        let item: ItemImpl = parse_quote!(
376            #[cfg(feature = "serde-impl")]
377            impl<'de> serde::Deserialize<'de> for #ident {
378                #[allow(unreachable_code)]
379                fn deserialize<__D>(__deserializer: __D) -> ::std::result::Result<Self, __D::Error>
380                where
381                    __D: serde::Deserializer<'de>,
382                {
383                    enum __TypeVariant {
384                        #variants,
385                    }
386
387                    let __content = swc_common::private::content::deserialize_content(__deserializer)?;
388
389                    let __tagged = #tag_expr;
390
391                    #match_type_expr
392                }
393            }
394        );
395
396        item.with_generics(generics)
397    };
398
399    deserialize
400}