From b0a5b8e9d0ba3dff61c72068a90c3e9e18242ea6 Mon Sep 17 00:00:00 2001 From: 0xrusowsky <0xrusowsky@proton.me> Date: Wed, 3 Dec 2025 19:12:46 +0100 Subject: [PATCH] fix: track shared variables in modifier's "before" and "after" blocks --- .../sol/codesize/unwrapped_modifier_logic.rs | 270 ++++++++++++++---- .../lint/testdata/UnwrappedModifierLogic.sol | 40 +++ .../testdata/UnwrappedModifierLogic.stderr | 86 ++++++ 3 files changed, 337 insertions(+), 59 deletions(-) diff --git a/crates/lint/src/sol/codesize/unwrapped_modifier_logic.rs b/crates/lint/src/sol/codesize/unwrapped_modifier_logic.rs index 64f7b27134732..df545a17685de 100644 --- a/crates/lint/src/sol/codesize/unwrapped_modifier_logic.rs +++ b/crates/lint/src/sol/codesize/unwrapped_modifier_logic.rs @@ -5,8 +5,10 @@ use crate::{ }; use solar::{ ast, - sema::hir::{self, Res}, + data_structures::{Never, map::FxHashSet}, + sema::hir::{self, Res, Visit}, }; +use std::ops::ControlFlow; declare_forge_lint!( UNWRAPPED_MODIFIER_LOGIC, @@ -46,66 +48,159 @@ impl<'hir> LateLintPass<'hir> for UnwrappedModifierLogic { } } -impl UnwrappedModifierLogic { - /// Returns `true` if an expr is not a built-in ('require' or 'assert') call or a lib function. - fn is_valid_expr(&self, hir: &hir::Hir<'_>, expr: &hir::Expr<'_>) -> bool { - if let hir::ExprKind::Call(func_expr, _, _) = &expr.kind { - if let hir::ExprKind::Ident(resolutions) = &func_expr.kind { - return !resolutions.iter().any(|r| matches!(r, Res::Builtin(_))); - } +/// Visitor that collects used variable IDs from expressions. +struct UsedVarCollector<'hir> { + hir: &'hir hir::Hir<'hir>, + vars: FxHashSet, +} + +impl<'hir> hir::Visit<'hir> for UsedVarCollector<'hir> { + type BreakValue = Never; - if let hir::ExprKind::Member(base, _) = &func_expr.kind - && let hir::ExprKind::Ident(resolutions) = &base.kind - { - return resolutions.iter().any(|r| { - matches!(r, Res::Item(hir::ItemId::Contract(id)) if hir.contract(*id).kind == ast::ContractKind::Library) - }); + fn hir(&self) -> &'hir hir::Hir<'hir> { + self.hir + } + + fn visit_expr(&mut self, expr: &'hir hir::Expr<'hir>) -> ControlFlow { + if let hir::ExprKind::Ident(resolutions) = &expr.kind { + for res in *resolutions { + if let Res::Item(hir::ItemId::Variable(var_id)) = res { + self.vars.insert(*var_id); + } } } - - false + self.walk_expr(expr) } +} - /// Checks if a block of statements is complex and should be wrapped in a helper function. - /// - /// This always is 'false' the modifier contains assembly. We assume that if devs know how to - /// use assembly, they will also know how to reduce the codesize of their contracts and they - /// have a good reason to use it on their modifiers. - /// - /// This is 'true' if the block contains: - /// 1. Any statement that is not a placeholder or a valid expression. - /// 2. More than one simple call expression. - fn stmts_require_wrapping(&self, hir: &hir::Hir<'_>, stmts: &[hir::Stmt<'_>]) -> bool { - let (mut res, mut has_valid_stmt) = (false, false); +impl UnwrappedModifierLogic { + /// Checks if statements require wrapping into a helper function. + /// Returns `false` if assembly is detected (HIR represents it as `Err`). + fn requires_wrapping( + &self, + hir: &hir::Hir<'_>, + stmts: &[hir::Stmt<'_>], + allow_one_decl: bool, + ) -> bool { + let (mut has_trivial_call, mut has_decl) = (false, false); for stmt in stmts { match &stmt.kind { - hir::StmtKind::Placeholder => continue, + hir::StmtKind::Placeholder => {} hir::StmtKind::Expr(expr) => { - if !self.is_valid_expr(hir, expr) || has_valid_stmt { - res = true; + if !self.is_trivial_call(hir, expr) || has_trivial_call || has_decl { + return true; } - has_valid_stmt = true; + has_trivial_call = true; } // HIR doesn't support assembly yet: // hir::StmtKind::Err(_) => return false, - _ => res = true, + hir::StmtKind::DeclSingle(_) | hir::StmtKind::DeclMulti(_, _) if allow_one_decl => { + if has_trivial_call || has_decl { + return true; + } + has_decl = true; + } + _ => return true, } } + false + } - res + /// Collects top-level declared variable IDs from statements. + fn collect_declared_vars(hir: &hir::Hir<'_>, stmts: &[hir::Stmt<'_>]) -> Vec { + let is_stmt_var = + |id: &hir::VariableId| matches!(hir.variable(*id).kind, hir::VarKind::Statement); + let mut vars = Vec::new(); + for stmt in stmts { + match &stmt.kind { + hir::StmtKind::DeclSingle(id) if is_stmt_var(id) => vars.push(*id), + hir::StmtKind::DeclMulti(ids, _) => { + vars.extend(ids.iter().flatten().filter(|id| is_stmt_var(id)).copied()) + } + _ => {} + } + } + vars } - fn get_snippet<'a>( - &self, + /// Collects all variables referenced in a statement block. + fn collect_used_vars( + hir: &hir::Hir<'_>, + stmts: &[hir::Stmt<'_>], + ) -> FxHashSet { + let mut collector = UsedVarCollector { hir, vars: FxHashSet::default() }; + for stmt in stmts { + let _ = collector.visit_stmt(stmt); + } + collector.vars + } + + /// Finds variables declared in "before" that are used in "after". + fn collect_shared_locals( + hir: &hir::Hir<'_>, + before: &[hir::Stmt<'_>], + after: &[hir::Stmt<'_>], + ) -> Vec { + if after.is_empty() || before.is_empty() { + return Vec::new(); + } + let declared_before = Self::collect_declared_vars(hir, before); + if declared_before.is_empty() { + return Vec::new(); + } + let used_after = Self::collect_used_vars(hir, after); + declared_before.into_iter().filter(|id| used_after.contains(id)).collect() + } + + /// Returns `true` if the expression is a "trivial" call that doesn't require wrapping. + fn is_trivial_call(&self, hir: &hir::Hir<'_>, expr: &hir::Expr<'_>) -> bool { + let hir::ExprKind::Call(func_expr, _, _) = &expr.kind else { + return false; + }; + + match &func_expr.kind { + // Direct function call: trivial if not a builtin + hir::ExprKind::Ident(resolutions) => { + !resolutions.iter().any(|r| matches!(r, Res::Builtin(_))) + } + // Member call: trivial if calling a library function + hir::ExprKind::Member(base, _) => { + if let hir::ExprKind::Ident(resolutions) = &base.kind { + resolutions.iter().any(|r| { + matches!(r, Res::Item(hir::ItemId::Contract(id)) + if hir.contract(*id).kind == ast::ContractKind::Library) + }) + } else { + false + } + } + _ => false, + } + } + + /// Extracts (type, name, decl) strings for a variable. + fn extract_var_info( ctx: &LintContext, hir: &hir::Hir<'_>, + var_id: hir::VariableId, + ) -> Option<(String, String, String)> { + let var = hir.variable(var_id); + let ty = ctx.span_to_snippet(var.ty.span)?; + let name = var.name?.to_string(); + Some((ty.clone(), name.clone(), format!("{ty} {name}"))) + } + + fn get_snippet<'hir>( + &self, + ctx: &LintContext, + hir: &'hir hir::Hir<'hir>, func: &hir::Function<'_>, - before: &'a [hir::Stmt<'a>], - after: &'a [hir::Stmt<'a>], + before: &'hir [hir::Stmt<'hir>], + after: &'hir [hir::Stmt<'hir>], ) -> Option { - let wrap_before = !before.is_empty() && self.stmts_require_wrapping(hir, before); - let wrap_after = !after.is_empty() && self.stmts_require_wrapping(hir, after); + let wrap_before = !before.is_empty() && self.requires_wrapping(hir, before, true); + let wrap_after = !after.is_empty() && self.requires_wrapping(hir, after, false); if !(wrap_before || wrap_after) { return None; @@ -113,37 +208,81 @@ impl UnwrappedModifierLogic { let binding = func.name.unwrap(); let modifier_name = binding.name.as_str(); - let mut param_list = vec![]; + let mut param_names = vec![]; let mut param_decls = vec![]; - for var_id in func.parameters { - let var = hir.variable(*var_id); - let ty = ctx - .span_to_snippet(var.ty.span) - .unwrap_or_else(|| "/* unknown type */".to_string()); - - // solidity functions should always have named parameters - if let Some(ident) = var.name { - param_list.push(ident.to_string()); - param_decls.push(format!("{ty} {}", ident.to_string())); + if let Some((_, name, decl)) = Self::extract_var_info(ctx, hir, *var_id) { + param_names.push(name); + param_decls.push(decl); } } - let param_list = param_list.join(", "); - let param_decls = param_decls.join(", "); + // Extract type and name info for shared locals + let shared_locals = Self::collect_shared_locals(hir, before, after); + let (mut shared_types, mut shared_names, mut shared_decls) = (vec![], vec![], vec![]); + for var_id in &shared_locals { + if let Some((ty, name, decl)) = Self::extract_var_info(ctx, hir, *var_id) { + shared_types.push(ty); + shared_names.push(name); + shared_decls.push(decl); + } + } let body_indent = " ".repeat(ctx.get_span_indentation( before.first().or(after.first()).map(|stmt| stmt.span).unwrap_or(func.span), )); + + // Build format strings for different shared variable counts + let (assignment, returns_decl, return_stmt) = match shared_locals.len() { + 0 => (String::new(), String::new(), String::new()), + 1 => ( + format!("{} {} = ", shared_types[0], shared_names[0]), + format!(" returns ({})", shared_types[0]), + format!("\n{body_indent}return {};", shared_names[0]), + ), + _ => ( + format!("({}) = ", shared_decls.join(", ")), + format!(" returns ({})", shared_types.join(", ")), + format!("\n{body_indent}return ({});", shared_names.join(", ")), + ), + }; + + let param_names = param_names.join(", "); + let param_decls = param_decls.join(", "); + + let after_args = if shared_locals.is_empty() { + param_names.clone() + } else if param_names.is_empty() { + shared_names.join(", ") + } else { + format!("{}, {}", param_names, shared_names.join(", ")) + }; + let body = match (wrap_before, wrap_after) { (true, true) => format!( - "{body_indent}_{modifier_name}Before({param_list});\n{body_indent}_;\n{body_indent}_{modifier_name}After({param_list});" + "{body_indent}{assignment}_{modifier_name}Before({param_names});\n{body_indent}_;\n{body_indent}_{modifier_name}After({after_args});" ), (true, false) => { - format!("{body_indent}_{modifier_name}({param_list});\n{body_indent}_;") + // Before is wrapped, after isn't complex enough to wrap - keep after inline + let after_stmts = after + .iter() + .filter_map(|s| ctx.span_to_snippet(s.span)) + .map(|code| format!("\n{body_indent}{code}")) + .collect::(); + format!( + "{body_indent}{assignment}_{modifier_name}({param_names});\n{body_indent}_;{after_stmts}" + ) } (false, true) => { - format!("{body_indent}_;\n{body_indent}_{modifier_name}({param_list});") + // Before isn't wrapped, so include its statements inline + let before_stmts = before + .iter() + .filter_map(|s| ctx.span_to_snippet(s.span)) + .map(|code| format!("{body_indent}{code}\n")) + .collect::(); + format!( + "{before_stmts}{body_indent}_;\n{body_indent}_{modifier_name}({after_args});" + ) } _ => unreachable!(), }; @@ -152,22 +291,35 @@ impl UnwrappedModifierLogic { let mut replacement = format!("modifier {modifier_name}({param_decls}) {{\n{body}\n{mod_indent}}}"); - let build_func = |stmts: &[hir::Stmt<'_>], suffix: &str| { + let build_func = |stmts: &[hir::Stmt<'_>], suffix: &str, is_before: bool| { let body_stmts = stmts .iter() .filter_map(|s| ctx.span_to_snippet(s.span)) .map(|code| format!("\n{body_indent}{code}")) .collect::(); + + let extra_params = if !is_before && !shared_decls.is_empty() { + if param_decls.is_empty() { + shared_decls.join(", ") + } else { + format!("{}, {}", param_decls, shared_decls.join(", ")) + } + } else { + param_decls.clone() + }; + + let returns = if is_before && !returns_decl.is_empty() { &returns_decl } else { "" }; + let ret_stmt = if is_before && !return_stmt.is_empty() { &return_stmt } else { "" }; format!( - "\n\n{mod_indent}function _{modifier_name}{suffix}({param_decls}) internal {{{body_stmts}\n{mod_indent}}}" + "\n\n{mod_indent}function _{modifier_name}{suffix}({extra_params}) internal{returns} {{{body_stmts}{ret_stmt}\n{mod_indent}}}" ) }; if wrap_before { - replacement.push_str(&build_func(before, if wrap_after { "Before" } else { "" })); + replacement.push_str(&build_func(before, if wrap_after { "Before" } else { "" }, true)); } if wrap_after { - replacement.push_str(&build_func(after, if wrap_before { "After" } else { "" })); + replacement.push_str(&build_func(after, if wrap_before { "After" } else { "" }, false)); } Some( diff --git a/crates/lint/testdata/UnwrappedModifierLogic.sol b/crates/lint/testdata/UnwrappedModifierLogic.sol index 2a1f541baa7e9..cdc0f66d467c3 100644 --- a/crates/lint/testdata/UnwrappedModifierLogic.sol +++ b/crates/lint/testdata/UnwrappedModifierLogic.sol @@ -81,6 +81,46 @@ contract UnwrappedModifierLogicTest { checkPrivate(msg.sender); } + /// ----------------------------------------------------------------------- + /// Shared local variables + /// ----------------------------------------------------------------------- + + function gasLeft() internal returns (uint256) { return 1; } + function gasLeftMulti() internal returns (uint256, bool) { return (1, true); } + function _payMeSubsidizedGasAfter(uint256, uint256) internal {} + function _refund(uint256) internal {} + + // Single shared variable: declared before, used after + modifier payMeSubsidizedGas(uint256 amount) { + uint256 pre = gasLeft(); + _; + _payMeSubsidizedGasAfter(pre, amount); + } + + // Multiple shared variables + modifier payMeFixedGasAmount() { //~NOTE: wrap modifier logic to reduce code size + uint256 pre = gasLeft(); + uint256 amount = 12345; + _; + _payMeSubsidizedGasAfter(pre, amount); + } + + modifier payMeSubsidizedGasAndRefund(uint256 amount) { //~NOTE: wrap modifier logic to reduce code size + (uint256 pre, bool success) = gasLeftMulti(); + _; + _payMeSubsidizedGasAfter(pre, amount); + _refund(pre); + } + + // Multiple shared variables + modifier payMeFixedGasAmountAndRefund() { //~NOTE: wrap modifier logic to reduce code size + uint256 pre = gasLeft(); + uint256 amount = 12345; + _; + _payMeSubsidizedGasAfter(pre, amount); + _refund(pre); + } + /// ----------------------------------------------------------------------- /// Bad patterns (multiple valid statements before or after placeholder) /// ----------------------------------------------------------------------- diff --git a/crates/lint/testdata/UnwrappedModifierLogic.stderr b/crates/lint/testdata/UnwrappedModifierLogic.stderr index d5d02817d0c04..5a2f88b892872 100644 --- a/crates/lint/testdata/UnwrappedModifierLogic.stderr +++ b/crates/lint/testdata/UnwrappedModifierLogic.stderr @@ -1,3 +1,89 @@ +note[unwrapped-modifier-logic]: wrap modifier logic to reduce code size + --> ROOT/testdata/UnwrappedModifierLogic.sol:LL:CC + | +LL | / modifier payMeFixedGasAmount() { +LL | | uint256 pre = gasLeft(); +LL | | uint256 amount = 12345; +LL | | _; +LL | | _payMeSubsidizedGasAfter(pre, amount); +LL | | } + | |_____^ + | +help: wrap modifier logic to reduce code size + | +LL ~ modifier payMeFixedGasAmount() { +LL + (uint256 pre, uint256 amount) = _payMeFixedGasAmount(); +LL + _; +LL + _payMeSubsidizedGasAfter(pre, amount); +LL + } +LL + +LL + function _payMeFixedGasAmount() internal returns (uint256, uint256) { +LL + uint256 pre = gasLeft(); +LL + uint256 amount = 12345; +LL + return (pre, amount); +LL + } + | + = help: https://book.getfoundry.sh/reference/forge/forge-lint#unwrapped-modifier-logic + +note[unwrapped-modifier-logic]: wrap modifier logic to reduce code size + --> ROOT/testdata/UnwrappedModifierLogic.sol:LL:CC + | +LL | / modifier payMeSubsidizedGasAndRefund(uint256 amount) { +LL | | (uint256 pre, bool success) = gasLeftMulti(); +LL | | _; +LL | | _payMeSubsidizedGasAfter(pre, amount); +LL | | _refund(pre); +LL | | } + | |_____^ + | +help: wrap modifier logic to reduce code size + | +LL ~ modifier payMeSubsidizedGasAndRefund(uint256 amount) { +LL + (uint256 pre, bool success) = gasLeftMulti(); +LL + _; +LL + _payMeSubsidizedGasAndRefund(amount, pre); +LL + } +LL + +LL + function _payMeSubsidizedGasAndRefund(uint256 amount, uint256 pre) internal { +LL + _payMeSubsidizedGasAfter(pre, amount); +LL + _refund(pre); +LL + } + | + = help: https://book.getfoundry.sh/reference/forge/forge-lint#unwrapped-modifier-logic + +note[unwrapped-modifier-logic]: wrap modifier logic to reduce code size + --> ROOT/testdata/UnwrappedModifierLogic.sol:LL:CC + | +LL | / modifier payMeFixedGasAmountAndRefund() { +LL | | uint256 pre = gasLeft(); +LL | | uint256 amount = 12345; +LL | | _; +LL | | _payMeSubsidizedGasAfter(pre, amount); +LL | | _refund(pre); +LL | | } + | |_____^ + | +help: wrap modifier logic to reduce code size + | +LL ~ modifier payMeFixedGasAmountAndRefund() { +LL + (uint256 pre, uint256 amount) = _payMeFixedGasAmountAndRefundBefore(); +LL + _; +LL + _payMeFixedGasAmountAndRefundAfter(pre, amount); +LL + } +LL + +LL + function _payMeFixedGasAmountAndRefundBefore() internal returns (uint256, uint256) { +LL + uint256 pre = gasLeft(); +LL + uint256 amount = 12345; +LL + return (pre, amount); +LL + } +LL + +LL + function _payMeFixedGasAmountAndRefundAfter(uint256 pre, uint256 amount) internal { +LL + _payMeSubsidizedGasAfter(pre, amount); +LL + _refund(pre); +LL + } + | + = help: https://book.getfoundry.sh/reference/forge/forge-lint#unwrapped-modifier-logic + note[unwrapped-modifier-logic]: wrap modifier logic to reduce code size --> ROOT/testdata/UnwrappedModifierLogic.sol:LL:CC |