1use swc_macros_common::prelude::*;
2use syn::{
3 parse::{Parse, ParseStream},
4 *,
5};
6
7struct VariantAttr {
8 tags: Punctuated<Lit, Token![,]>,
9}
10
11impl Parse for VariantAttr {
12 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
13 Ok(VariantAttr {
14 tags: input.call(Punctuated::parse_terminated)?,
15 })
16 }
17}
18
19pub fn expand(
20 DeriveInput {
21 generics,
22 ident,
23 data,
24 ..
25 }: DeriveInput,
26) -> ItemImpl {
27 let data = match data {
28 Data::Enum(data) => data,
29 _ => unreachable!("expand_enum is called with none-enum item"),
30 };
31
32 let mut has_wildcard = false;
33
34 let deserialize = {
35 let mut all_tags: Punctuated<_, token::Comma> = Default::default();
36 let tag_match_arms = data
37 .variants
38 .iter()
39 .filter(|v| !crate::encoding::is_unknown(&v.attrs))
40 .map(|variant| {
41 let field_type = match variant.fields {
42 Fields::Unnamed(ref fields) => {
43 assert_eq!(
44 fields.unnamed.len(),
45 1,
46 "#[ast_node] enum cannot contain variant with multiple fields"
47 );
48
49 fields.unnamed.last().unwrap().ty.clone()
50 }
51 _ => {
52 unreachable!("#[ast_node] enum cannot contain named fields or unit variant")
53 }
54 };
55 let tags = variant
56 .attrs
57 .iter()
58 .filter_map(|attr| -> Option<VariantAttr> {
59 if !is_attr_name(attr, "tag") {
60 return None;
61 }
62 let tokens = match &attr.meta {
63 Meta::List(meta) => meta.tokens.clone(),
64 _ => {
65 panic!("#[tag] attribute must be in form of #[tag(..)]")
66 }
67 };
68 let tags = parse2(tokens).expect("failed to parse #[tag] attribute");
69
70 Some(tags)
71 })
72 .flat_map(|v| v.tags)
73 .collect::<Punctuated<_, token::Comma>>();
74
75 assert!(
76 !tags.is_empty(),
77 "All #[ast_node] enum variants have one or more tag"
78 );
79
80 if tags.len() == 1
82 && match tags.first() {
83 Some(Lit::Str(s)) => &*s.value() == "*",
84 _ => false,
85 }
86 {
87 has_wildcard = true;
88 } else {
89 for tag in tags.iter() {
90 all_tags.push(tag.clone());
91 }
92 }
93
94 let vi = &variant.ident;
95
96 Arm {
97 attrs: Default::default(),
98 pat: Pat::Path(parse_quote!(__TypeVariant::#vi)),
99 guard: Default::default(),
100 fat_arrow_token: Token),
101 body: parse_quote!(
102 swc_common::private::serde::Result::map(
103 <#field_type as serde::Deserialize>::deserialize(
104 swc_common::private::serde::de::ContentDeserializer::<
105 __D::Error,
106 >::new(__content),
107 ),
108 Self::#vi,
109 )
110 ),
111 comma: Some(Token)),
112 }
113 })
114 .collect::<Vec<Arm>>();
115
116 let tag_expr: Expr = {
117 let mut visit_str_arms = Vec::new();
118 let mut visit_bytes_arms = Vec::new();
119
120 for variant in &data.variants {
121 if crate::encoding::is_unknown(&variant.attrs) {
122 continue;
123 }
124
125 let tags = variant
126 .attrs
127 .iter()
128 .filter_map(|attr| -> Option<VariantAttr> {
129 if !is_attr_name(attr, "tag") {
130 return None;
131 }
132 let tokens = match &attr.meta {
133 Meta::List(meta) => meta.tokens.clone(),
134 _ => {
135 panic!("#[tag] attribute must be in form of #[tag(..)]")
136 }
137 };
138 let tags = parse2(tokens).expect("failed to parse #[tag] attribute");
139
140 Some(tags)
141 })
142 .flat_map(|v| v.tags)
143 .collect::<Punctuated<_, token::Comma>>();
144
145 assert!(
146 !tags.is_empty(),
147 "All #[ast_node] enum variants have one or more tag"
148 );
149 let (str_pat, bytes_pat) = {
150 if tags.len() == 1
151 && match tags.first() {
152 Some(Lit::Str(s)) => &*s.value() == "*",
153 _ => false,
154 }
155 {
156 (
157 Pat::Wild(PatWild {
158 attrs: Default::default(),
159 underscore_token: Token),
160 }),
161 Pat::Wild(PatWild {
162 attrs: Default::default(),
163 underscore_token: Token),
164 }),
165 )
166 } else {
167 fn make_pat(lit: Lit) -> (Pat, Pat) {
168 let s = match lit.clone() {
169 Lit::Str(s) => s.value(),
170 _ => {
171 unreachable!()
172 }
173 };
174 (
175 Pat::Lit(PatLit {
176 attrs: Default::default(),
177 lit,
178 }),
179 Pat::Lit(PatLit {
180 attrs: Default::default(),
181 lit: Lit::ByteStr(LitByteStr::new(s.as_bytes(), call_site())),
182 }),
183 )
184 }
185 if tags.len() == 1 {
186 make_pat(tags.into_iter().next().unwrap())
187 } else {
188 let mut str_cases = Punctuated::new();
189 let mut bytes_cases = Punctuated::new();
190
191 for tag in tags {
192 let (str_pat, bytes_pat) = make_pat(tag);
193 str_cases.push(str_pat);
194 bytes_cases.push(bytes_pat);
195 }
196
197 (
198 Pat::Or(PatOr {
199 attrs: Default::default(),
200 leading_vert: Default::default(),
201 cases: str_cases,
202 }),
203 Pat::Or(PatOr {
204 attrs: Default::default(),
205 leading_vert: Default::default(),
206 cases: bytes_cases,
207 }),
208 )
209 }
210 }
211 };
212 visit_str_arms.push(Arm {
213 attrs: Default::default(),
214 pat: str_pat,
215 guard: None,
216 fat_arrow_token: Token),
217 body: {
218 let vi = &variant.ident;
219
220 parse_quote!(Ok(__TypeVariant::#vi))
221 },
222 comma: Some(Token)),
223 });
224 visit_bytes_arms.push(Arm {
225 attrs: Default::default(),
226 pat: bytes_pat,
227 guard: None,
228 fat_arrow_token: Token),
229 body: {
230 let vi = &variant.ident;
231
232 parse_quote!(Ok(__TypeVariant::#vi))
233 },
234 comma: Some(Token)),
235 });
236 }
237
238 if !has_wildcard {
239 visit_str_arms.push(Arm {
240 attrs: Default::default(),
241 pat: Pat::Wild(PatWild {
242 attrs: Default::default(),
243 underscore_token: Token),
244 }),
245 guard: None,
246 fat_arrow_token: Token),
247 body: parse_quote!(swc_common::private::serde::Err(
248 serde::de::Error::unknown_variant(__value, VARIANTS,)
249 )),
250 comma: Some(Token)),
251 });
252 visit_bytes_arms.push(Arm {
253 attrs: Default::default(),
254 pat: Pat::Wild(PatWild {
255 attrs: Default::default(),
256 underscore_token: Token!(_)(ident.span()),
257 }),
258 guard: None,
259 fat_arrow_token: Token),
260 body: parse_quote!({
261 let __value = &swc_common::private::serde::from_utf8_lossy(__value);
262 swc_common::private::serde::Err(serde::de::Error::unknown_variant(
263 __value, VARIANTS,
264 ))
265 }),
266 comma: Some(Token)),
267 });
268 }
269
270 let visit_str_body = Expr::Match(ExprMatch {
271 attrs: Default::default(),
272 match_token: Default::default(),
273 expr: parse_quote!(__value),
274 brace_token: Default::default(),
275 arms: visit_str_arms,
276 });
277 let visit_bytes_body = Expr::Match(ExprMatch {
278 attrs: Default::default(),
279 match_token: Default::default(),
280 expr: parse_quote!(__value),
281 brace_token: Default::default(),
282 arms: visit_bytes_arms,
283 });
284
285 parse_quote!({
286 static VARIANTS: &[&str] = &[#all_tags];
287
288 struct __TypeVariantVisitor;
289
290 impl<'de> serde::de::Visitor<'de> for __TypeVariantVisitor {
291 type Value = __TypeVariant;
292
293 fn expecting(
294 &self,
295 __formatter: &mut swc_common::private::serde::Formatter,
296 ) -> swc_common::private::serde::fmt::Result {
297 swc_common::private::serde::Formatter::write_str(
298 __formatter,
299 "variant identifier",
300 )
301 }
302
303 fn visit_str<__E>(
304 self,
305 __value: &str,
306 ) -> swc_common::private::serde::Result<Self::Value, __E>
307 where
308 __E: serde::de::Error,
309 {
310 #visit_str_body
311 }
312
313 fn visit_bytes<__E>(
314 self,
315 __value: &[u8],
316 ) -> swc_common::private::serde::Result<Self::Value, __E>
317 where
318 __E: serde::de::Error,
319 {
320 #visit_bytes_body
321 }
322 }
323
324 impl<'de> serde::Deserialize<'de> for __TypeVariant {
325 #[inline]
326 fn deserialize<__D>(
327 __deserializer: __D,
328 ) -> swc_common::private::serde::Result<Self, __D::Error>
329 where
330 __D: serde::Deserializer<'de>,
331 {
332 serde::Deserializer::deserialize_identifier(
333 __deserializer,
334 __TypeVariantVisitor,
335 )
336 }
337 }
338
339 let ty = swc_common::serializer::Type::deserialize(
340 swc_common::private::serde::de::ContentRefDeserializer::<__D::Error>::new(
341 &__content,
342 ),
343 )?;
344
345 let __tagged = __TypeVariant::deserialize(
346 swc_common::private::serde::de::ContentDeserializer::<__D::Error>::new(
347 swc_common::private::serde::de::Content::Str(&ty.ty),
348 ),
349 )?;
350
351 __tagged
352 })
353 };
354
355 let match_type_expr = Expr::Match(ExprMatch {
356 attrs: Default::default(),
357 match_token: Default::default(),
358 expr: parse_quote!(__tagged),
359 brace_token: Default::default(),
360 arms: tag_match_arms,
361 });
362
363 let variants: Punctuated<Variant, Token![,]> = {
364 data.variants
365 .iter()
366 .filter(|v| !crate::encoding::is_unknown(&v.attrs))
367 .cloned()
368 .map(|variant| Variant {
369 attrs: Default::default(),
370 fields: Fields::Unit,
371 ..variant
372 })
373 .collect()
374 };
375 let item: ItemImpl = parse_quote!(
376 #[cfg(feature = "serde-impl")]
377 impl<'de> serde::Deserialize<'de> for #ident {
378 #[allow(unreachable_code)]
379 fn deserialize<__D>(__deserializer: __D) -> ::std::result::Result<Self, __D::Error>
380 where
381 __D: serde::Deserializer<'de>,
382 {
383 enum __TypeVariant {
384 #variants,
385 }
386
387 let __content = swc_common::private::content::deserialize_content(__deserializer)?;
388
389 let __tagged = #tag_expr;
390
391 #match_type_expr
392 }
393 }
394 );
395
396 item.with_generics(generics)
397 };
398
399 deserialize
400}