Skip to content

Commit

Permalink
refactor(transformer/react-refresh): using SemanticInjector to insert…
Browse files Browse the repository at this point in the history
… statements
  • Loading branch information
Dunqing authored and overlookmotel committed Oct 25, 2024
1 parent b2a17ae commit c057449
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 82 deletions.
7 changes: 7 additions & 0 deletions crates/oxc_ast/src/ast_impl/js.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1007,6 +1007,13 @@ impl<'a> Function<'a> {
}
}

impl GetAddress for Function<'_> {
#[inline]
fn address(&self) -> Address {
Address::from_ptr(self)
}
}

impl<'a> FormalParameters<'a> {
/// Number of parameters bound in this parameter list.
pub fn parameters_count(&self) -> usize {
Expand Down
126 changes: 44 additions & 82 deletions crates/oxc_transformer/src/jsx/refresh.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
use std::iter::once;

use base64::prelude::{Engine, BASE64_STANDARD};
use rustc_hash::FxHashMap;
use sha1::{Digest, Sha1};

use oxc_allocator::CloneIn;
use oxc_ast::{ast::*, match_expression, AstBuilder, NONE};
use oxc_semantic::{Reference, ReferenceFlags, ScopeFlags, ScopeId, SymbolFlags, SymbolId};
use oxc_semantic::{Reference, ReferenceFlags, ScopeFlags, ScopeId, SymbolFlags};
use oxc_span::{Atom, GetSpan, SPAN};
use oxc_syntax::operator::AssignmentOperator;
use oxc_traverse::{Ancestor, BoundIdentifier, Traverse, TraverseCtx};
Expand Down Expand Up @@ -107,7 +105,6 @@ pub struct ReactRefresh<'a, 'ctx> {
/// Used to wrap call expression with signature.
/// (eg: hoc(() => {}) -> _s1(hoc(_s1(() => {}))))
last_signature: Option<(BindingIdentifier<'a>, oxc_allocator::Vec<'a, Argument<'a>>)>,
extra_statements: FxHashMap<SymbolId, oxc_allocator::Vec<'a, Statement<'a>>>,
// (function_scope_id, (hook_name, hook_key, custom_hook_callee)
hook_calls: FxHashMap<ScopeId, Vec<(Atom<'a>, Atom<'a>)>>,
non_builtin_hooks_callee: FxHashMap<ScopeId, Vec<Option<Expression<'a>>>>,
Expand All @@ -127,7 +124,6 @@ impl<'a, 'ctx> ReactRefresh<'a, 'ctx> {
registrations: Vec::default(),
ctx,
last_signature: None,
extra_statements: FxHashMap::default(),
hook_calls: FxHashMap::default(),
non_builtin_hooks_callee: FxHashMap::default(),
}
Expand Down Expand Up @@ -196,30 +192,18 @@ impl<'a, 'ctx> Traverse<'a> for ReactRefresh<'a, 'ctx> {
stmts: &mut oxc_allocator::Vec<'a, Statement<'a>>,
ctx: &mut TraverseCtx<'a>,
) {
// TODO: check is there any function declaration

let mut new_stmts = ctx.ast.vec_with_capacity(stmts.len() + 1);

let declarations = self.signature_declarator_items.pop().unwrap();
if !declarations.is_empty() {
new_stmts.push(Statement::from(ctx.ast.declaration_variable(
SPAN,
VariableDeclarationKind::Var,
declarations,
false,
)));
stmts.insert(
0,
Statement::from(ctx.ast.declaration_variable(
SPAN,
VariableDeclarationKind::Var,
declarations,
false,
)),
);
}
new_stmts.extend(stmts.drain(..).flat_map(move |stmt| {
let symbol_ids = get_symbol_id_from_function_and_declarator(&stmt);
let extra_stmts = symbol_ids
.into_iter()
.filter_map(|symbol_id| self.extra_statements.remove(&symbol_id))
.flatten()
.collect::<Vec<_>>();
once(stmt).chain(extra_stmts)
}));

*stmts = new_stmts;
}

fn exit_expression(&mut self, expr: &mut Expression<'a>, ctx: &mut TraverseCtx<'a>) {
Expand Down Expand Up @@ -268,7 +252,6 @@ impl<'a, 'ctx> Traverse<'a> for ReactRefresh<'a, 'ctx> {

let first_argument = Argument::from(id_binding.create_read_expression(ctx));
arguments.insert(0, first_argument);

let statement = ctx.ast.statement_expression(
SPAN,
ctx.ast.expression_call(
Expand All @@ -279,10 +262,19 @@ impl<'a, 'ctx> Traverse<'a> for ReactRefresh<'a, 'ctx> {
false,
),
);
self.extra_statements
.entry(id_binding.symbol_id)
.or_insert(ctx.ast.vec())
.push(statement);

let mut target_ancestor = ctx.ancestor(1);
for ancestor in ctx.ancestors().skip(2) {
if !matches!(
ancestor,
Ancestor::VariableDeclarationDeclarations(_)
| Ancestor::ExportNamedDeclarationDeclaration(_)
) {
break;
}
target_ancestor = ancestor;
}
self.ctx.statement_injector.insert_after(&target_ancestor, statement);
return;
}
}
Expand Down Expand Up @@ -334,18 +326,27 @@ impl<'a, 'ctx> Traverse<'a> for ReactRefresh<'a, 'ctx> {
arguments.insert(0, Argument::from(id_binding.create_read_expression(ctx)));

let binding = BoundIdentifier::from_binding_ident(&binding_identifier);
self.extra_statements.entry(id_binding.symbol_id).or_insert(ctx.ast.vec()).push(
ctx.ast.statement_expression(
SPAN,
ctx.ast.expression_call(
SPAN,
binding.create_read_expression(ctx),
NONE,
arguments,
false,
),
),
);
let callee = binding.create_read_expression(ctx);
let expr = ctx.ast.expression_call(SPAN, callee, NONE, arguments, false);
let statement = ctx.ast.statement_expression(SPAN, expr);

let mut target_ancestor = Ancestor::None;
for ancestor in ctx.ancestors() {
if !matches!(
ancestor,
Ancestor::ExportNamedDeclarationDeclaration(_)
| Ancestor::ExportDefaultDeclarationDeclaration(_)
) {
break;
}
target_ancestor = ancestor;
}

if matches!(target_ancestor, Ancestor::None) {
self.ctx.statement_injector.insert_after(func, statement);
} else {
self.ctx.statement_injector.insert_after(&target_ancestor, statement);
}
}

fn enter_call_expression(
Expand Down Expand Up @@ -898,42 +899,3 @@ fn is_builtin_hook(hook_name: &str) -> bool {
"useOptimistic"
)
}

fn get_symbol_id_from_function_and_declarator(stmt: &Statement<'_>) -> Vec<SymbolId> {
let mut symbol_ids = vec![];
match stmt {
Statement::FunctionDeclaration(ref func) => {
if !func.is_typescript_syntax() {
symbol_ids.push(func.symbol_id().unwrap());
}
}
Statement::VariableDeclaration(ref decl) => {
symbol_ids.extend(decl.declarations.iter().filter_map(|decl| {
decl.id.get_binding_identifier().and_then(|id| id.symbol_id.get())
}));
}
Statement::ExportNamedDeclaration(ref export_decl) => {
if let Some(Declaration::FunctionDeclaration(func)) = &export_decl.declaration {
if !func.is_typescript_syntax() {
symbol_ids.push(func.symbol_id().unwrap());
}
} else if let Some(Declaration::VariableDeclaration(decl)) = &export_decl.declaration {
symbol_ids.extend(decl.declarations.iter().filter_map(|decl| {
decl.id.get_binding_identifier().and_then(|id| id.symbol_id.get())
}));
}
}
Statement::ExportDefaultDeclaration(ref export_decl) => {
if let ExportDefaultDeclarationKind::FunctionDeclaration(func) =
&export_decl.declaration
{
if let Some(id) = func.symbol_id() {
symbol_ids.push(id);
}
}
}
_ => {}
};

symbol_ids
}

0 comments on commit c057449

Please sign in to comment.