Skip to content

Commit

Permalink
Remove some allocations in argument detection (#5481)
Browse files Browse the repository at this point in the history
## Summary

Drive-by PR to remove some allocations around argument name matching.
  • Loading branch information
charliermarsh authored Jul 3, 2023
1 parent d2450c2 commit dadad0e
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 97 deletions.
2 changes: 1 addition & 1 deletion crates/ruff/src/checkers/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ pub(crate) struct Checker<'a> {
deferred: Deferred<'a>,
pub(crate) diagnostics: Vec<Diagnostic>,
// Check-specific state.
pub(crate) flake8_bugbear_seen: Vec<&'a Expr>,
pub(crate) flake8_bugbear_seen: Vec<&'a ast::ExprName>,
}

impl<'a> Checker<'a> {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,31 +1,28 @@
use rustpython_parser::ast::{self, Constant, Expr, Keyword, Ranged};
use rustpython_parser::ast::{Expr, Keyword, Ranged};

use ruff_diagnostics::{Diagnostic, Violation};
use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::helpers::SimpleCallArgs;
use ruff_python_ast::helpers::{is_const_none, SimpleCallArgs};

use crate::checkers::ast::Checker;

#[violation]
pub struct RequestWithoutTimeout {
pub timeout: Option<String>,
implicit: bool,
}

impl Violation for RequestWithoutTimeout {
#[derive_message_formats]
fn message(&self) -> String {
let RequestWithoutTimeout { timeout } = self;
match timeout {
Some(value) => {
format!("Probable use of requests call with timeout set to `{value}`")
}
None => format!("Probable use of requests call without timeout"),
let RequestWithoutTimeout { implicit } = self;
if *implicit {
format!("Probable use of requests call without timeout")
} else {
format!("Probable use of requests call with timeout set to `None`")
}
}
}

const HTTP_VERBS: [&str; 7] = ["get", "options", "head", "post", "put", "patch", "delete"];

/// S113
pub(crate) fn request_without_timeout(
checker: &mut Checker,
Expand All @@ -37,30 +34,26 @@ pub(crate) fn request_without_timeout(
.semantic()
.resolve_call_path(func)
.map_or(false, |call_path| {
HTTP_VERBS
.iter()
.any(|func_name| call_path.as_slice() == ["requests", func_name])
matches!(
call_path.as_slice(),
[
"requests",
"get" | "options" | "head" | "post" | "put" | "patch" | "delete"
]
)
})
{
let call_args = SimpleCallArgs::new(args, keywords);
if let Some(timeout_arg) = call_args.keyword_argument("timeout") {
if let Some(timeout) = match timeout_arg {
Expr::Constant(ast::ExprConstant {
value: value @ Constant::None,
..
}) => Some(checker.generator().constant(value)),
_ => None,
} {
if let Some(timeout) = call_args.keyword_argument("timeout") {
if is_const_none(timeout) {
checker.diagnostics.push(Diagnostic::new(
RequestWithoutTimeout {
timeout: Some(timeout),
},
timeout_arg.range(),
RequestWithoutTimeout { implicit: false },
timeout.range(),
));
}
} else {
checker.diagnostics.push(Diagnostic::new(
RequestWithoutTimeout { timeout: None },
RequestWithoutTimeout { implicit: true },
func.range(),
));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use rustc_hash::FxHashSet;
use rustpython_parser::ast::{self, Comprehension, Expr, ExprContext, Ranged, Stmt};

use ruff_diagnostics::{Diagnostic, Violation};
use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::helpers::collect_arg_names;
use ruff_python_ast::helpers::includes_arg_name;
use ruff_python_ast::types::Node;
use ruff_python_ast::visitor;
use ruff_python_ast::visitor::Visitor;
Expand Down Expand Up @@ -58,19 +57,17 @@ impl Violation for FunctionUsesLoopVariable {

#[derive(Default)]
struct LoadedNamesVisitor<'a> {
// Tuple of: name, defining expression, and defining range.
loaded: Vec<(&'a str, &'a Expr)>,
// Tuple of: name, defining expression, and defining range.
stored: Vec<(&'a str, &'a Expr)>,
loaded: Vec<&'a ast::ExprName>,
stored: Vec<&'a ast::ExprName>,
}

/// `Visitor` to collect all used identifiers in a statement.
impl<'a> Visitor<'a> for LoadedNamesVisitor<'a> {
fn visit_expr(&mut self, expr: &'a Expr) {
match expr {
Expr::Name(ast::ExprName { id, ctx, range: _ }) => match ctx {
ExprContext::Load => self.loaded.push((id, expr)),
ExprContext::Store => self.stored.push((id, expr)),
Expr::Name(name) => match &name.ctx {
ExprContext::Load => self.loaded.push(name),
ExprContext::Store => self.stored.push(name),
ExprContext::Del => {}
},
_ => visitor::walk_expr(self, expr),
Expand All @@ -80,7 +77,7 @@ impl<'a> Visitor<'a> for LoadedNamesVisitor<'a> {

#[derive(Default)]
struct SuspiciousVariablesVisitor<'a> {
names: Vec<(&'a str, &'a Expr)>,
names: Vec<&'a ast::ExprName>,
safe_functions: Vec<&'a Expr>,
}

Expand All @@ -95,17 +92,20 @@ impl<'a> Visitor<'a> for SuspiciousVariablesVisitor<'a> {
let mut visitor = LoadedNamesVisitor::default();
visitor.visit_body(body);

// Collect all argument names.
let mut arg_names = collect_arg_names(args);
arg_names.extend(visitor.stored.iter().map(|(id, ..)| id));

// Treat any non-arguments as "suspicious".
self.names.extend(
visitor
.loaded
.into_iter()
.filter(|(id, ..)| !arg_names.contains(id)),
);
self.names
.extend(visitor.loaded.into_iter().filter(|loaded| {
if visitor.stored.iter().any(|stored| stored.id == loaded.id) {
return false;
}

if includes_arg_name(&loaded.id, args) {
return false;
}

true
}));

return;
}
Stmt::Return(ast::StmtReturn {
Expand All @@ -132,10 +132,9 @@ impl<'a> Visitor<'a> for SuspiciousVariablesVisitor<'a> {
}) => {
match func.as_ref() {
Expr::Name(ast::ExprName { id, .. }) => {
let id = id.as_str();
if id == "filter" || id == "reduce" || id == "map" {
if matches!(id.as_str(), "filter" | "reduce" | "map") {
for arg in args {
if matches!(arg, Expr::Lambda(_)) {
if arg.is_lambda_expr() {
self.safe_functions.push(arg);
}
}
Expand All @@ -159,7 +158,7 @@ impl<'a> Visitor<'a> for SuspiciousVariablesVisitor<'a> {

for keyword in keywords {
if keyword.arg.as_ref().map_or(false, |arg| arg == "key")
&& matches!(keyword.value, Expr::Lambda(_))
&& keyword.value.is_lambda_expr()
{
self.safe_functions.push(&keyword.value);
}
Expand All @@ -175,17 +174,19 @@ impl<'a> Visitor<'a> for SuspiciousVariablesVisitor<'a> {
let mut visitor = LoadedNamesVisitor::default();
visitor.visit_expr(body);

// Collect all argument names.
let mut arg_names = collect_arg_names(args);
arg_names.extend(visitor.stored.iter().map(|(id, ..)| id));

// Treat any non-arguments as "suspicious".
self.names.extend(
visitor
.loaded
.iter()
.filter(|(id, ..)| !arg_names.contains(id)),
);
self.names
.extend(visitor.loaded.into_iter().filter(|loaded| {
if visitor.stored.iter().any(|stored| stored.id == loaded.id) {
return false;
}

if includes_arg_name(&loaded.id, args) {
return false;
}

true
}));

return;
}
Expand All @@ -198,15 +199,15 @@ impl<'a> Visitor<'a> for SuspiciousVariablesVisitor<'a> {

#[derive(Default)]
struct NamesFromAssignmentsVisitor<'a> {
names: FxHashSet<&'a str>,
names: Vec<&'a str>,
}

/// `Visitor` to collect all names used in an assignment expression.
impl<'a> Visitor<'a> for NamesFromAssignmentsVisitor<'a> {
fn visit_expr(&mut self, expr: &'a Expr) {
match expr {
Expr::Name(ast::ExprName { id, .. }) => {
self.names.insert(id.as_str());
self.names.push(id.as_str());
}
Expr::Starred(ast::ExprStarred { value, .. }) => {
self.visit_expr(value);
Expand All @@ -223,7 +224,7 @@ impl<'a> Visitor<'a> for NamesFromAssignmentsVisitor<'a> {

#[derive(Default)]
struct AssignedNamesVisitor<'a> {
names: FxHashSet<&'a str>,
names: Vec<&'a str>,
}

/// `Visitor` to collect all used identifiers in a statement.
Expand Down Expand Up @@ -257,7 +258,7 @@ impl<'a> Visitor<'a> for AssignedNamesVisitor<'a> {
}

fn visit_expr(&mut self, expr: &'a Expr) {
if matches!(expr, Expr::Lambda(_)) {
if expr.is_lambda_expr() {
// Don't recurse.
return;
}
Expand Down Expand Up @@ -300,15 +301,15 @@ pub(crate) fn function_uses_loop_variable<'a>(checker: &mut Checker<'a>, node: &

// If a variable was used in a function or lambda body, and assigned in the
// loop, flag it.
for (name, expr) in suspicious_variables {
if reassigned_in_loop.contains(name) {
if !checker.flake8_bugbear_seen.contains(&expr) {
checker.flake8_bugbear_seen.push(expr);
for name in suspicious_variables {
if reassigned_in_loop.contains(&name.id.as_str()) {
if !checker.flake8_bugbear_seen.contains(&name) {
checker.flake8_bugbear_seen.push(name);
checker.diagnostics.push(Diagnostic::new(
FunctionUsesLoopVariable {
name: name.to_string(),
name: name.id.to_string(),
},
expr.range(),
name.range(),
));
}
}
Expand Down
4 changes: 2 additions & 2 deletions crates/ruff/src/rules/flake8_pytest_style/rules/fixture.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use ruff_diagnostics::{AlwaysAutofixableViolation, Violation};
use ruff_diagnostics::{Diagnostic, Edit, Fix};
use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::call_path::collect_call_path;
use ruff_python_ast::helpers::collect_arg_names;
use ruff_python_ast::helpers::includes_arg_name;
use ruff_python_ast::identifier::Identifier;
use ruff_python_ast::visitor;
use ruff_python_ast::visitor::Visitor;
Expand Down Expand Up @@ -446,7 +446,7 @@ fn check_fixture_decorator_name(checker: &mut Checker, decorator: &Decorator) {

/// PT021
fn check_fixture_addfinalizer(checker: &mut Checker, args: &Arguments, body: &[Stmt]) {
if !collect_arg_names(args).contains(&"request") {
if !includes_arg_name("request", args) {
return;
}

Expand Down
19 changes: 11 additions & 8 deletions crates/ruff/src/rules/flake8_pytest_style/rules/patch.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
use rustc_hash::FxHashSet;
use rustpython_parser::ast::{self, Expr, Keyword, Ranged};
use rustpython_parser::ast::{self, Arguments, Expr, Keyword, Ranged};

use ruff_diagnostics::{Diagnostic, Violation};
use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::call_path::collect_call_path;
use ruff_python_ast::helpers::{collect_arg_names, SimpleCallArgs};
use ruff_python_ast::helpers::{includes_arg_name, SimpleCallArgs};
use ruff_python_ast::visitor;
use ruff_python_ast::visitor::Visitor;

Expand All @@ -18,10 +17,10 @@ impl Violation for PytestPatchWithLambda {
}
}

#[derive(Default)]
/// Visitor that checks references the argument names in the lambda body.
#[derive(Debug)]
struct LambdaBodyVisitor<'a> {
names: FxHashSet<&'a str>,
arguments: &'a Arguments,
uses_args: bool,
}

Expand All @@ -32,11 +31,15 @@ where
fn visit_expr(&mut self, expr: &'b Expr) {
match expr {
Expr::Name(ast::ExprName { id, .. }) => {
if self.names.contains(&id.as_str()) {
if includes_arg_name(id, self.arguments) {
self.uses_args = true;
}
}
_ => visitor::walk_expr(self, expr),
_ => {
if !self.uses_args {
visitor::walk_expr(self, expr);
}
}
}
}
}
Expand All @@ -60,7 +63,7 @@ fn check_patch_call(
{
// Walk the lambda body.
let mut visitor = LambdaBodyVisitor {
names: collect_arg_names(args),
arguments: args,
uses_args: false,
};
visitor.visit_expr(body);
Expand Down
Loading

0 comments on commit dadad0e

Please sign in to comment.