1use swc_common::{util::take::Take, DUMMY_SP};
2use swc_ecma_ast::*;
3use swc_ecma_transforms_optimization::debug_assert_valid;
4use swc_ecma_utils::{ExprCtx, StmtLike};
5use swc_ecma_visit::{noop_visit_type, Visit, VisitWith};
6
7use super::Optimizer;
8#[cfg(feature = "debug")]
9use crate::debug::dump;
10use crate::{
11 compress::{
12 optimize::BitCtx,
13 util::{eval_to_undefined, is_pure_undefined},
14 },
15 util::ExprOptExt,
16};
17
18impl Optimizer<'_> {
21 pub(super) fn merge_if_returns(
22 &mut self,
23 stmts: &mut Vec<Stmt>,
24 terminates: bool,
25 is_fn_body: bool,
26 ) {
27 if !self.options.if_return {
28 return;
29 }
30
31 for stmt in stmts.iter_mut() {
32 self.merge_nested_if_returns(stmt, terminates);
33
34 debug_assert_valid(&*stmt);
35 }
36
37 if terminates || is_fn_body {
38 self.merge_if_returns_inner(stmts, !is_fn_body);
39 }
40 }
41
42 #[allow(clippy::only_used_in_recursion)]
43 fn merge_nested_if_returns(&mut self, s: &mut Stmt, can_work: bool) {
44 let terminate = can_merge_as_if_return(&*s);
45
46 match s {
47 Stmt::Block(s) => {
48 self.merge_if_returns(&mut s.stmts, terminate, false);
49
50 debug_assert_valid(&*s);
51 }
52 Stmt::If(s) => {
53 self.merge_nested_if_returns(&mut s.cons, can_work);
54
55 debug_assert_valid(&s.cons);
56
57 if let Some(alt) = s.alt.as_deref_mut() {
58 self.merge_nested_if_returns(alt, can_work);
59
60 debug_assert_valid(&*alt);
61 }
62 }
63 _ => {}
64 }
65 }
66
    /// Merges the trailing run of mergeable statements in `stmts` into one
    /// statement.
    ///
    /// Mergeable statements are those accepted by
    /// `can_merge_stmt_as_if_return`: expression statements, `if (..) return ..`
    /// statements, single-statement blocks, and a `return` in last position.
    /// The run is folded into one expression built from sequence and
    /// conditional expressions.
    ///
    /// The result is normally emitted as `return <expr>;`. When
    /// `should_preserve_last_return` is `false` and the merged value evaluates
    /// to `undefined`, it is instead passed through `ignore_return_value` and
    /// emitted as an expression statement, or dropped if nothing remains.
    ///
    /// Sets `self.changed` when a merge happens. Returns without changes when
    /// the `if_return` option is off, when `stmts` has at most one element, or
    /// when one of the bail-out checks below fires.
    fn merge_if_returns_inner(&mut self, stmts: &mut Vec<Stmt>, should_preserve_last_return: bool) {
        if !self.options.if_return {
            return;
        }

        if stmts.len() <= 1 {
            return;
        }

        // Index of the last statement that cannot be merged. Only the
        // statements after it (`skip..`) are merge candidates.
        let idx_of_not_mergable =
            stmts
                .iter()
                .enumerate()
                .rposition(|(idx, stmt)| match stmt.as_stmt() {
                    Some(v) => !self.can_merge_stmt_as_if_return(v, stmts.len() - 1 == idx),
                    None => true,
                });
        let skip = idx_of_not_mergable.map(|v| v + 1).unwrap_or(0);
        trace_op!("if_return: Skip = {}", skip);
        let mut last_idx = stmts.len() - 1;

        // Step `last_idx` back over trailing `var` declarations whose
        // declarators all lack an initializer.
        //
        // NOTE(review): such a declaration is itself rejected by
        // `can_merge_stmt_as_if_return`, so `idx_of_not_mergable` already
        // points at the last statement. That makes `skip > last_idx`, and the
        // check right after this loop aborts. As written, merging only
        // proceeds with `last_idx == stmts.len() - 1`.
        //
        // If this is ever relaxed, be careful: the merge loop below uses the
        // full length (`len - 1 == idx`) to decide which `return` is last,
        // while the checks use the sliced length.
        {
            loop {
                let s = stmts.get(last_idx);
                let s = match s {
                    Some(s) => s,
                    _ => break,
                };

                if let Stmt::Decl(Decl::Var(v)) = s {
                    if v.decls.iter().all(|v| v.init.is_none()) {
                        if last_idx == 0 {
                            break;
                        }
                        last_idx -= 1;
                        continue;
                    }
                }

                break;
            }
        }

        // At most one candidate statement: nothing to merge.
        if last_idx <= skip {
            log_abort!("if_return: [x] Aborting because of skip");
            return;
        }

        // Heuristic bail-outs, evaluated on the candidate range only.
        {
            let stmts = &stmts[skip..=last_idx];
            // `return` statements in the candidates. `ReturnFinder` does not
            // descend into function or arrow bodies.
            let return_count: usize = stmts.iter().map(count_leaping_returns).sum();

            if return_count == 0 {
                log_abort!("if_return: [x] Aborting because we failed to find return");
                return;
            }

            // Number of `if` statements without `else` whose consequent
            // always ends in `return <arg>`.
            let if_return_count = stmts
                .iter()
                .filter(|s| match s {
                    Stmt::If(IfStmt {
                        cons, alt: None, ..
                    }) => always_terminates_with_return_arg(cons),
                    _ => false,
                })
                .count();

            // `true` for `return;` or for a `return` whose argument
            // `eval_to_undefined` says evaluates to `undefined`.
            fn is_return_undefined(expr_ctx: ExprCtx, s: &Stmt) -> bool {
                let Stmt::Return(s) = s else {
                    return false;
                };

                match &s.arg {
                    None => true,
                    Some(e) => eval_to_undefined(expr_ctx, e),
                }
            }

            // Inspect the last two candidate statements.
            if stmts.len() >= 2 {
                match (
                    &stmts[stmts.len() - 2].as_stmt(),
                    &stmts[stmts.len() - 1].as_stmt(),
                ) {
                    // `if (..) return undefined; expr;` is explicitly allowed.
                    (
                        Some(Stmt::If(IfStmt {
                            alt: None, cons, ..
                        })),
                        Some(Stmt::Expr(_)),
                    ) if is_return_undefined(self.ctx.expr_ctx, cons) => {}
                    // The run ends in a non-`return` and there is at most one
                    // if-return.
                    (_, Some(Stmt::If(IfStmt { alt: None, .. }) | Stmt::Expr(..)))
                        if if_return_count <= 1 =>
                    {
                        log_abort!(
                            "if_return: [x] Aborting because last stmt is a not return stmt"
                        );
                        return;
                    }

                    // `if (..) <cons>; return ..;` requires `<cons>` to be
                    // `return <arg>` directly.
                    (
                        Some(Stmt::If(IfStmt {
                            cons, alt: None, ..
                        })),
                        Some(Stmt::Return(..)),
                    ) => match &**cons {
                        Stmt::Return(ReturnStmt { arg: Some(..), .. }) => {}
                        _ => {
                            log_abort!(
                                "if_return: [x] Aborting because stmt before last is an if stmt \
                                 and cons of it is not a return stmt"
                            );
                            return;
                        }
                    },

                    // Two consecutive blocks that both contain a direct
                    // `return`.
                    (
                        Some(Stmt::Block(BlockStmt { stmts: s1, .. })),
                        Some(Stmt::Block(BlockStmt { stmts: s2, .. })),
                    ) if s1.iter().any(|s| matches!(s, Stmt::Return(..)))
                        && s2.iter().any(|s| matches!(s, Stmt::Return(..))) =>
                    {
                        log_abort!("if_return: [x] Aborting because early return is observed");
                        return;
                    }

                    _ => {}
                }
            }
        }

        {
            let stmts = &stmts[..=last_idx];
            // Position of the first mergeable statement at or after `skip`,
            // counted from `skip` (not from 0).
            let start = stmts
                .iter()
                .enumerate()
                .skip(skip)
                .position(|(idx, stmt)| match stmt.as_stmt() {
                    Some(v) => self.can_merge_stmt_as_if_return(v, stmts.len() - 1 == idx),
                    None => false,
                })
                .unwrap_or(0);

            // The last statement must be mergeable. During nested if-return
            // merging, a trailing `if` without `else` is rejected.
            let ends_with_mergable = stmts
                .last()
                .map(|stmt| match stmt.as_stmt() {
                    Some(Stmt::If(IfStmt { alt: None, .. }))
                        if self.ctx.bit_ctx.contains(BitCtx::IsNestedIfReturnMerging) =>
                    {
                        false
                    }
                    Some(s) => self.can_merge_stmt_as_if_return(s, true),
                    _ => false,
                })
                .unwrap();

            // Only one statement would be merged, or the tail is unmergeable.
            if stmts.len() == start + skip + 1 || !ends_with_mergable {
                return;
            }

            // Every candidate from `skip` onward must be mergeable.
            let can_merge =
                stmts
                    .iter()
                    .enumerate()
                    .skip(skip)
                    .all(|(idx, stmt)| match stmt.as_stmt() {
                        Some(s) => self.can_merge_stmt_as_if_return(s, stmts.len() - 1 == idx),
                        _ => false,
                    });
            if !can_merge {
                return;
            }
        }

        report_change!("if_return: Merging returns");

        self.changed = true;

        // Accumulated merged expression. It is always `None`, an `Expr::Seq`,
        // or an `Expr::Cond` whose rightmost alternative is being extended.
        let mut cur: Option<Box<Expr>> = None;
        let mut new = Vec::with_capacity(stmts.len());

        let len = stmts.len();

        for (idx, stmt) in stmts.take().into_iter().enumerate() {
            // Statements before the last unmergeable one are kept verbatim.
            if let Some(not_mergable) = idx_of_not_mergable {
                if idx < not_mergable {
                    new.push(stmt);
                    continue;
                }
            }
            if idx > last_idx {
                new.push(stmt);
                continue;
            }

            let stmt = if !self.can_merge_stmt_as_if_return(&stmt, len - 1 == idx) {
                // Only expected before any merging has started.
                debug_assert_eq!(cur, None);
                new.push(stmt);
                continue;
            } else {
                stmt
            };
            let is_nonconditional_return = matches!(stmt, Stmt::Return(..));
            let new_expr = self.merge_if_returns_to(stmt, Vec::new());
            match new_expr {
                // A sequence is appended to the current sequence, or to the
                // rightmost alternative of the current conditional.
                Expr::Seq(v) => match &mut cur {
                    Some(cur) => match &mut **cur {
                        Expr::Cond(cur) => {
                            let seq = get_rightmost_alt_of_cond(cur).force_seq();
                            seq.exprs.extend(v.exprs);
                        }
                        Expr::Seq(cur) => {
                            cur.exprs.extend(v.exprs);
                        }
                        _ => {
                            unreachable!(
                                "if_return: cur must be one of None, Expr::Seq or Expr::Cond(with \
                                 alt Expr::Seq)"
                            )
                        }
                    },
                    None => cur = Some(v.into()),
                },
                // A conditional takes everything accumulated so far as a
                // prefix of its test, and becomes the new rightmost branch.
                Expr::Cond(v) => match &mut cur {
                    Some(cur) => match &mut **cur {
                        Expr::Cond(cur) => {
                            let alt = get_rightmost_alt_of_cond(cur);

                            let (span, exprs) = {
                                let prev_seq = alt.force_seq();
                                prev_seq.exprs.push(v.test);
                                let exprs = prev_seq.exprs.take();

                                (prev_seq.span, exprs)
                            };

                            *alt = CondExpr {
                                span: DUMMY_SP,
                                test: SeqExpr { span, exprs }.into(),
                                cons: v.cons,
                                alt: v.alt,
                            }
                            .into();
                        }
                        Expr::Seq(prev_seq) => {
                            prev_seq.exprs.push(v.test);
                            let exprs = prev_seq.exprs.take();

                            *cur = CondExpr {
                                span: DUMMY_SP,
                                test: Box::new(
                                    SeqExpr {
                                        span: prev_seq.span,
                                        exprs,
                                    }
                                    .into(),
                                ),
                                cons: v.cons,
                                alt: v.alt,
                            }
                            .into();
                        }
                        _ => {
                            unreachable!(
                                "if_return: cur must be one of None, Expr::Seq or Expr::Cond(with \
                                 alt Expr::Seq)"
                            )
                        }
                    },
                    None => cur = Some(v.into()),
                },
                _ => {
                    unreachable!(
                        "if_return: merge_if_returns_to should return one of None, Expr::Seq or \
                         Expr::Cond"
                    )
                }
            }

            // A plain `return` ends the merged run.
            if is_nonconditional_return {
                break;
            }
        }

        if let Some(mut cur) = cur {
            self.normalize_expr(&mut cur);

            match &*cur {
                // The value is `undefined` and the return need not be kept:
                // emit only the side effects.
                Expr::Seq(seq)
                    if !should_preserve_last_return
                        && seq
                            .exprs
                            .last()
                            .map(|v| is_pure_undefined(self.ctx.expr_ctx, v))
                            .unwrap_or(true) =>
                {
                    let expr = self.ignore_return_value(&mut cur);

                    if let Some(cur) = expr {
                        new.push(
                            ExprStmt {
                                span: DUMMY_SP,
                                expr: Box::new(cur),
                            }
                            .into(),
                        )
                    } else {
                        trace_op!("if_return: Ignoring return value");
                    }
                }
                // Same, for a conditional whose branches both evaluate to
                // `undefined`.
                Expr::Cond(cond)
                    if !should_preserve_last_return
                        && eval_to_undefined(self.ctx.expr_ctx, &cond.cons)
                        && eval_to_undefined(self.ctx.expr_ctx, &cond.alt) =>
                {
                    let expr = self.ignore_return_value(&mut cur);

                    if let Some(cur) = expr {
                        new.push(
                            ExprStmt {
                                span: DUMMY_SP,
                                expr: Box::new(cur),
                            }
                            .into(),
                        )
                    } else {
                        trace_op!("if_return: Ignoring return value");
                    }
                }
                _ => {
                    new.push(
                        ReturnStmt {
                            span: DUMMY_SP,
                            arg: Some(cur),
                        }
                        .into(),
                    );
                }
            }
        }

        *stmts = new;
    }
442
443 fn merge_if_returns_to(&mut self, stmt: Stmt, mut exprs: Vec<Box<Expr>>) -> Expr {
447 match stmt {
449 Stmt::Block(s) => {
450 assert_eq!(s.stmts.len(), 1);
451 self.merge_if_returns_to(s.stmts.into_iter().next().unwrap(), exprs)
452 }
453
454 Stmt::If(IfStmt {
455 span,
456 test,
457 cons,
458 alt,
459 ..
460 }) => {
461 let cons = Box::new(self.merge_if_returns_to(*cons, Vec::new()));
462 let alt = match alt {
463 Some(alt) => Box::new(self.merge_if_returns_to(*alt, Vec::new())),
464 None => Expr::undefined(DUMMY_SP),
465 };
466
467 exprs.push(test);
468
469 CondExpr {
470 span,
471 test: SeqExpr {
472 span: DUMMY_SP,
473 exprs,
474 }
475 .into(),
476 cons,
477 alt,
478 }
479 .into()
480 }
481 Stmt::Expr(stmt) => {
482 exprs.push(
483 UnaryExpr {
484 span: DUMMY_SP,
485 op: op!("void"),
486 arg: stmt.expr,
487 }
488 .into(),
489 );
490 SeqExpr {
491 span: DUMMY_SP,
492 exprs,
493 }
494 .into()
495 }
496 Stmt::Return(stmt) => {
497 let span = stmt.span;
498 exprs.push(stmt.arg.unwrap_or_else(|| Expr::undefined(span)));
499 SeqExpr {
500 span: DUMMY_SP,
501 exprs,
502 }
503 .into()
504 }
505 _ => unreachable!(),
506 }
507 }
508
509 fn can_merge_stmt_as_if_return(&self, s: &Stmt, is_last: bool) -> bool {
510 match s {
515 Stmt::Expr(..) => true,
516 Stmt::Return(..) => is_last,
517 Stmt::Block(s) => {
518 s.stmts.len() == 1 && self.can_merge_stmt_as_if_return(&s.stmts[0], is_last)
519 }
520 Stmt::If(stmt) => {
521 matches!(&*stmt.cons, Stmt::Return(..))
522 && matches!(
523 stmt.alt.as_deref(),
524 None | Some(Stmt::Return(..) | Stmt::Expr(..))
525 )
526 }
527 _ => false,
528 }
529 }
530}
531
532fn get_rightmost_alt_of_cond(e: &mut CondExpr) -> &mut Expr {
533 match &mut *e.alt {
534 Expr::Cond(alt) => get_rightmost_alt_of_cond(alt),
535 alt => alt,
536 }
537}
538
539fn count_leaping_returns<N>(n: &N) -> usize
540where
541 N: VisitWith<ReturnFinder>,
542{
543 let mut v = ReturnFinder::default();
544 n.visit_with(&mut v);
545 v.count
546}
547
548#[derive(Default)]
549pub(super) struct ReturnFinder {
550 count: usize,
551}
552
553impl Visit for ReturnFinder {
554 noop_visit_type!(fail);
555
556 fn visit_return_stmt(&mut self, n: &ReturnStmt) {
557 n.visit_children_with(self);
558 self.count += 1;
559 }
560
561 fn visit_function(&mut self, _: &Function) {}
562
563 fn visit_arrow_expr(&mut self, _: &ArrowExpr) {}
564}
565
566fn always_terminates_with_return_arg(s: &Stmt) -> bool {
567 match s {
568 Stmt::Return(ReturnStmt { arg: Some(..), .. }) => true,
569 Stmt::If(IfStmt { cons, alt, .. }) => {
570 always_terminates_with_return_arg(cons)
571 && alt
572 .as_deref()
573 .map(always_terminates_with_return_arg)
574 .unwrap_or(false)
575 }
576 Stmt::Block(s) => s.stmts.iter().any(always_terminates_with_return_arg),
577
578 _ => false,
579 }
580}
581
582fn can_merge_as_if_return(s: &Stmt) -> bool {
583 fn cost(s: &Stmt) -> Option<isize> {
584 if let Stmt::Block(..) = s {
585 if !swc_ecma_utils::StmtExt::terminates(s) {
586 return None;
587 }
588 }
589
590 match s {
591 Stmt::Return(ReturnStmt { arg: Some(..), .. }) => Some(-1),
592
593 Stmt::Return(ReturnStmt { arg: None, .. }) => Some(0),
594
595 Stmt::Throw(..) | Stmt::Break(..) | Stmt::Continue(..) => Some(0),
596
597 Stmt::If(IfStmt { cons, alt, .. }) => {
598 Some(cost(cons)? + alt.as_deref().and_then(cost).unwrap_or(0))
599 }
600 Stmt::Block(s) => {
601 let mut sum = 0;
602 let mut found = false;
603 for s in s.stmts.iter().rev() {
604 let c = cost(s);
605 if let Some(c) = c {
606 found = true;
607 sum += c;
608 }
609 }
610 if found {
611 Some(sum)
612 } else {
613 None
614 }
615 }
616
617 _ => None,
618 }
619 }
620
621 let c = cost(s);
622
623 trace_op!("merging cost of `{}` = {:?}", dump(s, false), c);
624
625 c.unwrap_or(0) < 0
626}