swc_typescript/fast_dts/inferrer.rs

use swc_common::{Spanned, DUMMY_SP};
use swc_ecma_ast::{
    ArrowExpr, BindingIdent, BlockStmtOrExpr, Class, ClassDecl, Expr, FnDecl, Function, Ident,
    Lit, ReturnStmt, Stmt, TsEntityName, TsEnumDecl, TsInterfaceDecl, TsKeywordTypeKind,
    TsParenthesizedType, TsType, TsTypeAliasDecl, TsTypeAnn, TsUnionOrIntersectionType,
    TsUnionType, UnaryExpr, UnaryOp,
};
use swc_ecma_visit::{Visit, VisitWith};

use super::{
    util::types::{ts_keyword_type, type_ann},
    FastDts,
};
13
14impl FastDts {
15    pub(crate) fn infer_type_from_expr(&mut self, e: &Expr) -> Option<Box<TsType>> {
16        match e {
17            Expr::Ident(ident) if ident.sym.as_str() == "undefined" => {
18                Some(ts_keyword_type(TsKeywordTypeKind::TsUndefinedKeyword))
19            }
20            Expr::Lit(lit) => match lit {
21                Lit::Str(_) => Some(ts_keyword_type(TsKeywordTypeKind::TsStringKeyword)),
22                Lit::Bool(_) => Some(ts_keyword_type(TsKeywordTypeKind::TsBooleanKeyword)),
23                Lit::Num(_) => Some(ts_keyword_type(TsKeywordTypeKind::TsNumberKeyword)),
24                Lit::BigInt(_) => Some(ts_keyword_type(TsKeywordTypeKind::TsBigIntKeyword)),
25                Lit::Null(_) => Some(ts_keyword_type(TsKeywordTypeKind::TsNullKeyword)),
26                Lit::Regex(_) | Lit::JSXText(_) => None,
27                #[cfg(swc_ast_unknown)]
28                _ => panic!("unable to access unknown nodes"),
29            },
30            Expr::Tpl(_) => Some(ts_keyword_type(TsKeywordTypeKind::TsStringKeyword)),
31            Expr::Fn(fn_expr) => self.transform_fn_to_ts_type(
32                &fn_expr.function,
33                fn_expr.ident.as_ref().map(|ident| ident.span),
34            ),
35            Expr::Arrow(arrow_expr) => self.transform_arrow_expr_to_ts_type(arrow_expr),
36            Expr::Array(arr) => {
37                self.array_inferred(arr.span);
38                Some(ts_keyword_type(TsKeywordTypeKind::TsUnknownKeyword))
39            }
40            Expr::Object(obj) => self.transform_object_to_ts_type(obj, false),
41            Expr::Class(class) => {
42                self.inferred_type_of_class_expression(class.span());
43                Some(ts_keyword_type(TsKeywordTypeKind::TsUnknownKeyword))
44            }
45            Expr::Paren(expr) => self.infer_type_from_expr(&expr.expr),
46            Expr::TsNonNull(non_null) => self.infer_type_from_expr(&non_null.expr),
47            Expr::TsSatisfies(satisifies) => self.infer_type_from_expr(&satisifies.expr),
48            Expr::TsConstAssertion(assertion) => self.transform_expr_to_ts_type(&assertion.expr),
49            Expr::TsAs(ts_as) => Some(ts_as.type_ann.clone()),
50            Expr::TsTypeAssertion(type_assertion) => Some(type_assertion.type_ann.clone()),
51            Expr::Unary(unary) if Self::can_infer_unary_expr(unary) => {
52                self.infer_type_from_expr(&unary.arg)
53            }
54            _ => None,
55        }
56    }
57
58    pub(crate) fn infer_function_return_type(
59        &mut self,
60        function: &Function,
61    ) -> Option<Box<TsTypeAnn>> {
62        if function.return_type.is_some() {
63            return function.return_type.clone();
64        }
65
66        if function.is_async || function.is_generator {
67            return None;
68        }
69
70        function
71            .body
72            .as_ref()
73            .and_then(|body| ReturnTypeInferrer::infer(self, &body.stmts))
74            .map(type_ann)
75    }
76
77    pub(crate) fn infer_arrow_return_type(&mut self, arrow: &ArrowExpr) -> Option<Box<TsTypeAnn>> {
78        if arrow.return_type.is_some() {
79            return arrow.return_type.clone();
80        }
81
82        if arrow.is_async || arrow.is_generator {
83            return None;
84        }
85
86        match arrow.body.as_ref() {
87            BlockStmtOrExpr::BlockStmt(block_stmt) => {
88                ReturnTypeInferrer::infer(self, &block_stmt.stmts)
89            }
90            BlockStmtOrExpr::Expr(expr) => self.infer_type_from_expr(expr),
91            #[cfg(swc_ast_unknown)]
92            _ => panic!("unable to access unknown nodes"),
93        }
94        .map(type_ann)
95    }
96
97    pub(crate) fn need_to_infer_type_from_expression(expr: &Expr) -> bool {
98        match expr {
99            Expr::Lit(lit) => !(lit.is_str() || lit.is_num() || lit.is_big_int() || lit.is_bool()),
100            Expr::Tpl(tpl) => !tpl.exprs.is_empty(),
101            Expr::Unary(unary) => !Self::can_infer_unary_expr(unary),
102            _ => true,
103        }
104    }
105
106    pub(crate) fn can_infer_unary_expr(unary: &UnaryExpr) -> bool {
107        let is_arithmetic = matches!(unary.op, UnaryOp::Plus | UnaryOp::Minus);
108        let is_number_lit = match unary.arg.as_ref() {
109            Expr::Lit(lit) => lit.is_num() || lit.is_big_int(),
110            _ => false,
111        };
112        is_arithmetic && is_number_lit
113    }
114}
115
116#[derive(Default)]
117pub struct ReturnTypeInferrer {
118    value_bindings: Vec<Ident>,
119    type_bindings: Vec<Ident>,
120    return_expr: Option<Option<Box<Expr>>>,
121    return_expr_count: u8,
122}
123
124impl ReturnTypeInferrer {
125    pub fn infer(transformer: &mut FastDts, stmts: &[Stmt]) -> Option<Box<TsType>> {
126        let mut visitor = ReturnTypeInferrer::default();
127        stmts.visit_children_with(&mut visitor);
128
129        let expr = visitor.return_expr??;
130        let Some(mut expr_type) = transformer.infer_type_from_expr(&expr) else {
131            return if expr.is_fn_expr() || expr.is_arrow() {
132                Some(ts_keyword_type(TsKeywordTypeKind::TsUnknownKeyword))
133            } else {
134                None
135            };
136        };
137
138        if let Some((ref_name, is_value)) = match expr_type.as_ref() {
139            TsType::TsTypeRef(type_ref) => Some((type_ref.type_name.as_ident()?, false)),
140            TsType::TsTypeQuery(type_query) => {
141                Some((type_query.expr_name.as_ts_entity_name()?.as_ident()?, true))
142            }
143            _ => None,
144        } {
145            let is_defined = if is_value {
146                visitor.value_bindings.contains(ref_name)
147            } else {
148                visitor.type_bindings.contains(ref_name)
149            };
150
151            if is_defined {
152                transformer.type_containing_private_name(ref_name.sym.as_str(), ref_name.span);
153            }
154        }
155
156        if visitor.return_expr_count > 1 {
157            if expr_type.is_ts_fn_or_constructor_type() {
158                expr_type = Box::new(TsType::TsParenthesizedType(TsParenthesizedType {
159                    span: DUMMY_SP,
160                    type_ann: expr_type,
161                }));
162            }
163
164            expr_type = Box::new(TsType::TsUnionOrIntersectionType(
165                TsUnionOrIntersectionType::TsUnionType(TsUnionType {
166                    span: DUMMY_SP,
167                    types: vec![
168                        expr_type,
169                        ts_keyword_type(TsKeywordTypeKind::TsUndefinedKeyword),
170                    ],
171                }),
172            ))
173        }
174
175        Some(expr_type)
176    }
177}
178
179impl Visit for ReturnTypeInferrer {
180    fn visit_arrow_expr(&mut self, _node: &ArrowExpr) {}
181
182    fn visit_function(&mut self, _node: &Function) {}
183
184    fn visit_class(&mut self, _node: &Class) {}
185
186    fn visit_binding_ident(&mut self, node: &BindingIdent) {
187        self.value_bindings.push(node.id.clone());
188    }
189
190    fn visit_ts_type_alias_decl(&mut self, node: &TsTypeAliasDecl) {
191        self.type_bindings.push(node.id.clone());
192    }
193
194    fn visit_return_stmt(&mut self, node: &ReturnStmt) {
195        self.return_expr_count += 1;
196        if self.return_expr_count > 1 {
197            if let Some(return_expr) = &self.return_expr {
198                if return_expr.is_some() {
199                    self.return_expr = None;
200                    return;
201                }
202            } else {
203                return;
204            }
205        }
206        self.return_expr = Some(node.arg.clone());
207    }
208}