swc_ecma_minifier/compress/optimize/
if_return.rs

1use swc_common::{util::take::Take, DUMMY_SP};
2use swc_ecma_ast::*;
3use swc_ecma_transforms_optimization::debug_assert_valid;
4use swc_ecma_utils::{ExprCtx, StmtLike};
5use swc_ecma_visit::{noop_visit_type, Visit, VisitWith};
6
7use super::Optimizer;
8#[cfg(feature = "debug")]
9use crate::debug::dump;
10use crate::{
11    compress::{
12        optimize::BitCtx,
13        util::{eval_to_undefined, is_pure_undefined},
14    },
15    util::ExprOptExt,
16};
17
18/// Methods related to the option `if_return`. All methods are noop if
19/// `if_return` is false.
20impl Optimizer<'_> {
21    pub(super) fn merge_if_returns(
22        &mut self,
23        stmts: &mut Vec<Stmt>,
24        terminates: bool,
25        is_fn_body: bool,
26    ) {
27        if !self.options.if_return {
28            return;
29        }
30
31        for stmt in stmts.iter_mut() {
32            self.merge_nested_if_returns(stmt, terminates);
33
34            debug_assert_valid(&*stmt);
35        }
36
37        if terminates || is_fn_body {
38            self.merge_if_returns_inner(stmts, !is_fn_body);
39        }
40    }
41
42    #[allow(clippy::only_used_in_recursion)]
43    fn merge_nested_if_returns(&mut self, s: &mut Stmt, can_work: bool) {
44        let terminate = can_merge_as_if_return(&*s);
45
46        match s {
47            Stmt::Block(s) => {
48                self.merge_if_returns(&mut s.stmts, terminate, false);
49
50                debug_assert_valid(&*s);
51            }
52            Stmt::If(s) => {
53                self.merge_nested_if_returns(&mut s.cons, can_work);
54
55                debug_assert_valid(&s.cons);
56
57                if let Some(alt) = s.alt.as_deref_mut() {
58                    self.merge_nested_if_returns(alt, can_work);
59
60                    debug_assert_valid(&*alt);
61                }
62            }
63            _ => {}
64        }
65    }
66
67    /// Merge simple return statements in if statements.
68    ///
69    /// # Example
70    ///
71    /// ## Input
72    ///
73    /// ```js
74    /// function foo() {
75    ///     if (a) return foo();
76    ///     return bar()
77    /// }
78    /// ```
79    ///
80    /// ## Output
81    ///
82    /// ```js
83    /// function foo() {
84    ///     return a ? foo() : bar();
85    /// }
86    /// ```
87    fn merge_if_returns_inner(&mut self, stmts: &mut Vec<Stmt>, should_preserve_last_return: bool) {
88        if !self.options.if_return {
89            return;
90        }
91
92        // for stmt in stmts.iter_mut() {
93        //     let ctx = Ctx {
94        //         is_nested_if_return_merging: true,
95        //         ..self.ctx.clone()
96        //     };
97        //     self.with_ctx(ctx).merge_nested_if_returns(stmt, terminate);
98        // }
99
100        if stmts.len() <= 1 {
101            return;
102        }
103
104        let idx_of_not_mergable =
105            stmts
106                .iter()
107                .enumerate()
108                .rposition(|(idx, stmt)| match stmt.as_stmt() {
109                    Some(v) => !self.can_merge_stmt_as_if_return(v, stmts.len() - 1 == idx),
110                    None => true,
111                });
112        let skip = idx_of_not_mergable.map(|v| v + 1).unwrap_or(0);
113        trace_op!("if_return: Skip = {}", skip);
114        let mut last_idx = stmts.len() - 1;
115
116        {
117            loop {
118                let s = stmts.get(last_idx);
119                let s = match s {
120                    Some(s) => s,
121                    _ => break,
122                };
123
124                if let Stmt::Decl(Decl::Var(v)) = s {
125                    if v.decls.iter().all(|v| v.init.is_none()) {
126                        if last_idx == 0 {
127                            break;
128                        }
129                        last_idx -= 1;
130                        continue;
131                    }
132                }
133
134                break;
135            }
136        }
137
138        if last_idx <= skip {
139            log_abort!("if_return: [x] Aborting because of skip");
140            return;
141        }
142
143        {
144            let stmts = &stmts[skip..=last_idx];
145            let return_count: usize = stmts.iter().map(count_leaping_returns).sum();
146
147            // There's no return statement so merging requires injecting unnecessary `void
148            // 0`
149            if return_count == 0 {
150                log_abort!("if_return: [x] Aborting because we failed to find return");
151                return;
152            }
153
154            // If the last statement is a return statement and last - 1 is an if statement
155            // is without return, we don't need to fold it as `void 0` is too much for such
156            // cases.
157
158            let if_return_count = stmts
159                .iter()
160                .filter(|s| match s {
161                    Stmt::If(IfStmt {
162                        cons, alt: None, ..
163                    }) => always_terminates_with_return_arg(cons),
164                    _ => false,
165                })
166                .count();
167
168            fn is_return_undefined(expr_ctx: ExprCtx, s: &Stmt) -> bool {
169                let Stmt::Return(s) = s else {
170                    return false;
171                };
172
173                match &s.arg {
174                    None => true,
175                    Some(e) => eval_to_undefined(expr_ctx, e),
176                }
177            }
178
179            if stmts.len() >= 2 {
180                match (
181                    &stmts[stmts.len() - 2].as_stmt(),
182                    &stmts[stmts.len() - 1].as_stmt(),
183                ) {
184                    (
185                        Some(Stmt::If(IfStmt {
186                            alt: None, cons, ..
187                        })),
188                        Some(Stmt::Expr(_)),
189                    ) if is_return_undefined(self.ctx.expr_ctx, cons) => {}
190                    (_, Some(Stmt::If(IfStmt { alt: None, .. }) | Stmt::Expr(..)))
191                        if if_return_count <= 1 =>
192                    {
193                        log_abort!(
194                            "if_return: [x] Aborting because last stmt is a not return stmt"
195                        );
196                        return;
197                    }
198
199                    (
200                        Some(Stmt::If(IfStmt {
201                            cons, alt: None, ..
202                        })),
203                        Some(Stmt::Return(..)),
204                    ) => match &**cons {
205                        Stmt::Return(ReturnStmt { arg: Some(..), .. }) => {}
206                        _ => {
207                            log_abort!(
208                                "if_return: [x] Aborting because stmt before last is an if stmt \
209                                 and cons of it is not a return stmt"
210                            );
211                            return;
212                        }
213                    },
214
215                    (
216                        Some(Stmt::Block(BlockStmt { stmts: s1, .. })),
217                        Some(Stmt::Block(BlockStmt { stmts: s2, .. })),
218                    ) if s1.iter().any(|s| matches!(s, Stmt::Return(..)))
219                        && s2.iter().any(|s| matches!(s, Stmt::Return(..))) =>
220                    {
221                        log_abort!("if_return: [x] Aborting because early return is observed");
222                        return;
223                    }
224
225                    _ => {}
226                }
227            }
228        }
229
230        {
231            let stmts = &stmts[..=last_idx];
232            let start = stmts
233                .iter()
234                .enumerate()
235                .skip(skip)
236                .position(|(idx, stmt)| match stmt.as_stmt() {
237                    Some(v) => self.can_merge_stmt_as_if_return(v, stmts.len() - 1 == idx),
238                    None => false,
239                })
240                .unwrap_or(0);
241
242            let ends_with_mergable = stmts
243                .last()
244                .map(|stmt| match stmt.as_stmt() {
245                    Some(Stmt::If(IfStmt { alt: None, .. }))
246                        if self.ctx.bit_ctx.contains(BitCtx::IsNestedIfReturnMerging) =>
247                    {
248                        false
249                    }
250                    Some(s) => self.can_merge_stmt_as_if_return(s, true),
251                    _ => false,
252                })
253                .unwrap();
254
255            if stmts.len() == start + skip + 1 || !ends_with_mergable {
256                return;
257            }
258
259            let can_merge =
260                stmts
261                    .iter()
262                    .enumerate()
263                    .skip(skip)
264                    .all(|(idx, stmt)| match stmt.as_stmt() {
265                        Some(s) => self.can_merge_stmt_as_if_return(s, stmts.len() - 1 == idx),
266                        _ => false,
267                    });
268            if !can_merge {
269                return;
270            }
271        }
272
273        report_change!("if_return: Merging returns");
274
275        self.changed = true;
276
277        let mut cur: Option<Box<Expr>> = None;
278        let mut new = Vec::with_capacity(stmts.len());
279
280        let len = stmts.len();
281
282        for (idx, stmt) in stmts.take().into_iter().enumerate() {
283            if let Some(not_mergable) = idx_of_not_mergable {
284                if idx < not_mergable {
285                    new.push(stmt);
286                    continue;
287                }
288            }
289            if idx > last_idx {
290                new.push(stmt);
291                continue;
292            }
293
294            let stmt = if !self.can_merge_stmt_as_if_return(&stmt, len - 1 == idx) {
295                debug_assert_eq!(cur, None);
296                new.push(stmt);
297                continue;
298            } else {
299                stmt
300            };
301            let is_nonconditional_return = matches!(stmt, Stmt::Return(..));
302            let new_expr = self.merge_if_returns_to(stmt, Vec::new());
303            match new_expr {
304                Expr::Seq(v) => match &mut cur {
305                    Some(cur) => match &mut **cur {
306                        Expr::Cond(cur) => {
307                            let seq = get_rightmost_alt_of_cond(cur).force_seq();
308                            seq.exprs.extend(v.exprs);
309                        }
310                        Expr::Seq(cur) => {
311                            cur.exprs.extend(v.exprs);
312                        }
313                        _ => {
314                            unreachable!(
315                                "if_return: cur must be one of None, Expr::Seq or Expr::Cond(with \
316                                 alt Expr::Seq)"
317                            )
318                        }
319                    },
320                    None => cur = Some(v.into()),
321                },
322                Expr::Cond(v) => match &mut cur {
323                    Some(cur) => match &mut **cur {
324                        Expr::Cond(cur) => {
325                            let alt = get_rightmost_alt_of_cond(cur);
326
327                            let (span, exprs) = {
328                                let prev_seq = alt.force_seq();
329                                prev_seq.exprs.push(v.test);
330                                let exprs = prev_seq.exprs.take();
331
332                                (prev_seq.span, exprs)
333                            };
334
335                            *alt = CondExpr {
336                                span: DUMMY_SP,
337                                test: SeqExpr { span, exprs }.into(),
338                                cons: v.cons,
339                                alt: v.alt,
340                            }
341                            .into();
342                        }
343                        Expr::Seq(prev_seq) => {
344                            prev_seq.exprs.push(v.test);
345                            let exprs = prev_seq.exprs.take();
346
347                            *cur = CondExpr {
348                                span: DUMMY_SP,
349                                test: Box::new(
350                                    SeqExpr {
351                                        span: prev_seq.span,
352                                        exprs,
353                                    }
354                                    .into(),
355                                ),
356                                cons: v.cons,
357                                alt: v.alt,
358                            }
359                            .into();
360                        }
361                        _ => {
362                            unreachable!(
363                                "if_return: cur must be one of None, Expr::Seq or Expr::Cond(with \
364                                 alt Expr::Seq)"
365                            )
366                        }
367                    },
368                    None => cur = Some(v.into()),
369                },
370                _ => {
371                    unreachable!(
372                        "if_return: merge_if_returns_to should return one of None, Expr::Seq or \
373                         Expr::Cond"
374                    )
375                }
376            }
377
378            if is_nonconditional_return {
379                break;
380            }
381        }
382
383        if let Some(mut cur) = cur {
384            self.normalize_expr(&mut cur);
385
386            match &*cur {
387                Expr::Seq(seq)
388                    if !should_preserve_last_return
389                        && seq
390                            .exprs
391                            .last()
392                            .map(|v| is_pure_undefined(self.ctx.expr_ctx, v))
393                            .unwrap_or(true) =>
394                {
395                    let expr = self.ignore_return_value(&mut cur);
396
397                    if let Some(cur) = expr {
398                        new.push(
399                            ExprStmt {
400                                span: DUMMY_SP,
401                                expr: Box::new(cur),
402                            }
403                            .into(),
404                        )
405                    } else {
406                        trace_op!("if_return: Ignoring return value");
407                    }
408                }
409                Expr::Cond(cond)
410                    if !should_preserve_last_return
411                        && eval_to_undefined(self.ctx.expr_ctx, &cond.cons)
412                        && eval_to_undefined(self.ctx.expr_ctx, &cond.alt) =>
413                {
414                    let expr = self.ignore_return_value(&mut cur);
415
416                    if let Some(cur) = expr {
417                        new.push(
418                            ExprStmt {
419                                span: DUMMY_SP,
420                                expr: Box::new(cur),
421                            }
422                            .into(),
423                        )
424                    } else {
425                        trace_op!("if_return: Ignoring return value");
426                    }
427                }
428                _ => {
429                    new.push(
430                        ReturnStmt {
431                            span: DUMMY_SP,
432                            arg: Some(cur),
433                        }
434                        .into(),
435                    );
436                }
437            }
438        }
439
440        *stmts = new;
441    }
442
443    /// This method returns [Expr::Seq] or [Expr::Cond].
444    ///
445    /// `exprs` is a simple optimization.
446    fn merge_if_returns_to(&mut self, stmt: Stmt, mut exprs: Vec<Box<Expr>>) -> Expr {
447        //
448        match stmt {
449            Stmt::Block(s) => {
450                assert_eq!(s.stmts.len(), 1);
451                self.merge_if_returns_to(s.stmts.into_iter().next().unwrap(), exprs)
452            }
453
454            Stmt::If(IfStmt {
455                span,
456                test,
457                cons,
458                alt,
459                ..
460            }) => {
461                let cons = Box::new(self.merge_if_returns_to(*cons, Vec::new()));
462                let alt = match alt {
463                    Some(alt) => Box::new(self.merge_if_returns_to(*alt, Vec::new())),
464                    None => Expr::undefined(DUMMY_SP),
465                };
466
467                exprs.push(test);
468
469                CondExpr {
470                    span,
471                    test: SeqExpr {
472                        span: DUMMY_SP,
473                        exprs,
474                    }
475                    .into(),
476                    cons,
477                    alt,
478                }
479                .into()
480            }
481            Stmt::Expr(stmt) => {
482                exprs.push(
483                    UnaryExpr {
484                        span: DUMMY_SP,
485                        op: op!("void"),
486                        arg: stmt.expr,
487                    }
488                    .into(),
489                );
490                SeqExpr {
491                    span: DUMMY_SP,
492                    exprs,
493                }
494                .into()
495            }
496            Stmt::Return(stmt) => {
497                let span = stmt.span;
498                exprs.push(stmt.arg.unwrap_or_else(|| Expr::undefined(span)));
499                SeqExpr {
500                    span: DUMMY_SP,
501                    exprs,
502                }
503                .into()
504            }
505            _ => unreachable!(),
506        }
507    }
508
509    fn can_merge_stmt_as_if_return(&self, s: &Stmt, is_last: bool) -> bool {
510        // if !res {
511        //     trace!("Cannot merge: {}", dump(s));
512        // }
513
514        match s {
515            Stmt::Expr(..) => true,
516            Stmt::Return(..) => is_last,
517            Stmt::Block(s) => {
518                s.stmts.len() == 1 && self.can_merge_stmt_as_if_return(&s.stmts[0], is_last)
519            }
520            Stmt::If(stmt) => {
521                matches!(&*stmt.cons, Stmt::Return(..))
522                    && matches!(
523                        stmt.alt.as_deref(),
524                        None | Some(Stmt::Return(..) | Stmt::Expr(..))
525                    )
526            }
527            _ => false,
528        }
529    }
530}
531
532fn get_rightmost_alt_of_cond(e: &mut CondExpr) -> &mut Expr {
533    match &mut *e.alt {
534        Expr::Cond(alt) => get_rightmost_alt_of_cond(alt),
535        alt => alt,
536    }
537}
538
539fn count_leaping_returns<N>(n: &N) -> usize
540where
541    N: VisitWith<ReturnFinder>,
542{
543    let mut v = ReturnFinder::default();
544    n.visit_with(&mut v);
545    v.count
546}
547
548#[derive(Default)]
549pub(super) struct ReturnFinder {
550    count: usize,
551}
552
553impl Visit for ReturnFinder {
554    noop_visit_type!(fail);
555
556    fn visit_return_stmt(&mut self, n: &ReturnStmt) {
557        n.visit_children_with(self);
558        self.count += 1;
559    }
560
561    fn visit_function(&mut self, _: &Function) {}
562
563    fn visit_arrow_expr(&mut self, _: &ArrowExpr) {}
564}
565
566fn always_terminates_with_return_arg(s: &Stmt) -> bool {
567    match s {
568        Stmt::Return(ReturnStmt { arg: Some(..), .. }) => true,
569        Stmt::If(IfStmt { cons, alt, .. }) => {
570            always_terminates_with_return_arg(cons)
571                && alt
572                    .as_deref()
573                    .map(always_terminates_with_return_arg)
574                    .unwrap_or(false)
575        }
576        Stmt::Block(s) => s.stmts.iter().any(always_terminates_with_return_arg),
577
578        _ => false,
579    }
580}
581
582fn can_merge_as_if_return(s: &Stmt) -> bool {
583    fn cost(s: &Stmt) -> Option<isize> {
584        if let Stmt::Block(..) = s {
585            if !swc_ecma_utils::StmtExt::terminates(s) {
586                return None;
587            }
588        }
589
590        match s {
591            Stmt::Return(ReturnStmt { arg: Some(..), .. }) => Some(-1),
592
593            Stmt::Return(ReturnStmt { arg: None, .. }) => Some(0),
594
595            Stmt::Throw(..) | Stmt::Break(..) | Stmt::Continue(..) => Some(0),
596
597            Stmt::If(IfStmt { cons, alt, .. }) => {
598                Some(cost(cons)? + alt.as_deref().and_then(cost).unwrap_or(0))
599            }
600            Stmt::Block(s) => {
601                let mut sum = 0;
602                let mut found = false;
603                for s in s.stmts.iter().rev() {
604                    let c = cost(s);
605                    if let Some(c) = c {
606                        found = true;
607                        sum += c;
608                    }
609                }
610                if found {
611                    Some(sum)
612                } else {
613                    None
614                }
615            }
616
617            _ => None,
618        }
619    }
620
621    let c = cost(s);
622
623    trace_op!("merging cost of `{}` = {:?}", dump(s, false), c);
624
625    c.unwrap_or(0) < 0
626}