Skip to content

Commit

Permalink
[flake8-pyi] Improve autofix for nested and mixed type unions `unne…
Browse files Browse the repository at this point in the history
…cessary-type-union` (`PYI055`) (#14272)

## Summary

This PR improves the fix for `PYI055` to be able to handle nested and
mixed type unions.

It also marks the fix as unsafe when comments are present. 
 
<!-- What's the purpose of the change? What does it do, and why? -->

## Test Plan

<!-- How was it tested? -->
  • Loading branch information
sbrugman authored Nov 12, 2024
1 parent 2b6d66b commit bd30701
Show file tree
Hide file tree
Showing 5 changed files with 237 additions and 136 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ def func():
# PYI055
x: type[requests_mock.Mocker] | type[httpretty] | type[str] = requests_mock.Mocker
y: Union[type[requests_mock.Mocker], type[httpretty], type[str]] = requests_mock.Mocker
z: Union[ # comment
type[requests_mock.Mocker], # another comment
type[httpretty], type[str]] = requests_mock.Mocker


def func():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@ z: Union[float, complex]

def func(arg: type[int, float] | str) -> None: ...

# OK
# PYI055
item: type[requests_mock.Mocker] | type[httpretty] = requests_mock.Mocker

def func():
# PYI055
item: type[requests_mock.Mocker] | type[httpretty] | type[str] = requests_mock.Mocker
item2: Union[type[requests_mock.Mocker], type[httpretty], type[str]] = requests_mock.Mocker
item3: Union[ # comment
type[requests_mock.Mocker], # another comment
type[httpretty], type[str]] = requests_mock.Mocker
165 changes: 104 additions & 61 deletions crates/ruff_linter/src/rules/flake8_pyi/rules/unnecessary_type_union.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use ast::ExprContext;
use ruff_diagnostics::{Diagnostic, Edit, Fix, FixAvailability, Violation};
use ruff_diagnostics::{Applicability, Diagnostic, Edit, Fix, FixAvailability, Violation};
use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::helpers::pep_604_union;
use ruff_python_ast::name::Name;
Expand All @@ -25,21 +25,28 @@ use crate::checkers::ast::Checker;
/// ```pyi
/// field: type[int | float] | str
/// ```
///
/// ## Fix safety
///
/// This rule's fix is marked as safe in most cases; however, the fix will
/// flatten nested unions type expressions into a single top-level union.
///
/// The fix is marked as unsafe when comments are present within the type
/// expression.
#[violation]
pub struct UnnecessaryTypeUnion {
members: Vec<Name>,
is_pep604_union: bool,
union_kind: UnionKind,
}

impl Violation for UnnecessaryTypeUnion {
const FIX_AVAILABILITY: FixAvailability = FixAvailability::Sometimes;

#[derive_message_formats]
fn message(&self) -> String {
let union_str = if self.is_pep604_union {
self.members.join(" | ")
} else {
format!("Union[{}]", self.members.join(", "))
let union_str = match self.union_kind {
UnionKind::PEP604 => self.members.join(" | "),
UnionKind::TypingUnion => format!("Union[{}]", self.members.join(", ")),
};

format!(
Expand All @@ -63,43 +70,85 @@ pub(crate) fn unnecessary_type_union<'a>(checker: &mut Checker, union: &'a Expr)

// Check if `union` is a PEP604 union (e.g. `float | int`) or a `typing.Union[float, int]`
let subscript = union.as_subscript_expr();
if subscript.is_some_and(|subscript| !semantic.match_typing_expr(&subscript.value, "Union")) {
return;
}
let mut union_kind = match subscript {
Some(subscript) => {
if !semantic.match_typing_expr(&subscript.value, "Union") {
return;
}
UnionKind::TypingUnion
}
None => UnionKind::PEP604,
};

let mut type_exprs: Vec<&Expr> = Vec::new();
let mut other_exprs: Vec<&Expr> = Vec::new();

let mut collect_type_exprs = |expr: &'a Expr, _parent: &'a Expr| match expr {
Expr::Subscript(ast::ExprSubscript { slice, value, .. }) => {
if semantic.match_builtin_expr(value, "type") {
type_exprs.push(slice);
} else {
other_exprs.push(expr);
let mut collect_type_exprs = |expr: &'a Expr, parent: &'a Expr| {
// If a PEP604-style union is used within a `typing.Union`, then the fix can
// use PEP604-style unions.
if matches!(parent, Expr::BinOp(_)) {
union_kind = UnionKind::PEP604;
}
match expr {
Expr::Subscript(ast::ExprSubscript { slice, value, .. }) => {
if semantic.match_builtin_expr(value, "type") {
type_exprs.push(slice);
} else {
other_exprs.push(expr);
}
}
_ => other_exprs.push(expr),
}
_ => other_exprs.push(expr),
};

traverse_union(&mut collect_type_exprs, semantic, union);

if type_exprs.len() > 1 {
let type_members: Vec<Name> = type_exprs
.clone()
.into_iter()
.map(|type_expr| Name::new(checker.locator().slice(type_expr)))
.collect();

let mut diagnostic = Diagnostic::new(
UnnecessaryTypeUnion {
members: type_members.clone(),
is_pep604_union: subscript.is_none(),
},
union.range(),
);

if semantic.has_builtin_binding("type") {
let content = if let Some(subscript) = subscript {
// Return if zero or one `type` expressions are found.
if type_exprs.len() <= 1 {
return;
}

let type_members: Vec<Name> = type_exprs
.iter()
.map(|type_expr| Name::new(checker.locator().slice(type_expr)))
.collect();

let mut diagnostic = Diagnostic::new(
UnnecessaryTypeUnion {
members: type_members.clone(),
union_kind,
},
union.range(),
);

if semantic.has_builtin_binding("type") {
// Construct the content for the [`Fix`] based on if we encountered a PEP604 union.
let content = match union_kind {
UnionKind::PEP604 => {
let elts: Vec<Expr> = type_exprs.into_iter().cloned().collect();
let types = Expr::Subscript(ast::ExprSubscript {
value: Box::new(Expr::Name(ast::ExprName {
id: Name::new_static("type"),
ctx: ExprContext::Load,
range: TextRange::default(),
})),
slice: Box::new(pep_604_union(&elts)),
ctx: ExprContext::Load,
range: TextRange::default(),
});

if other_exprs.is_empty() {
checker.generator().expr(&types)
} else {
let elts: Vec<Expr> = std::iter::once(types)
.chain(other_exprs.into_iter().cloned())
.collect();
checker.generator().expr(&pep_604_union(&elts))
}
}
UnionKind::TypingUnion => {
// When subscript is None, it uses the pervious match case.
let subscript = subscript.unwrap();
let types = &Expr::Subscript(ast::ExprSubscript {
value: Box::new(Expr::Name(ast::ExprName {
id: Name::new_static("type"),
Expand Down Expand Up @@ -151,35 +200,29 @@ pub(crate) fn unnecessary_type_union<'a>(checker: &mut Checker, union: &'a Expr)

checker.generator().expr(&union)
}
} else {
let elts: Vec<Expr> = type_exprs.into_iter().cloned().collect();
let types = Expr::Subscript(ast::ExprSubscript {
value: Box::new(Expr::Name(ast::ExprName {
id: Name::new_static("type"),
ctx: ExprContext::Load,
range: TextRange::default(),
})),
slice: Box::new(pep_604_union(&elts)),
ctx: ExprContext::Load,
range: TextRange::default(),
});

if other_exprs.is_empty() {
checker.generator().expr(&types)
} else {
let elts: Vec<Expr> = std::iter::once(types)
.chain(other_exprs.into_iter().cloned())
.collect();
checker.generator().expr(&pep_604_union(&elts))
}
};
}
};

diagnostic.set_fix(Fix::safe_edit(Edit::range_replacement(
content,
union.range(),
)));
}
// Mark [`Fix`] as unsafe when comments are in range.
let applicability = if checker.comment_ranges().intersects(union.range()) {
Applicability::Unsafe
} else {
Applicability::Safe
};

checker.diagnostics.push(diagnostic);
diagnostic.set_fix(Fix::applicable_edit(
Edit::range_replacement(content, union.range()),
applicability,
));
}

checker.diagnostics.push(diagnostic);
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum UnionKind {
/// E.g., `typing.Union[int, str]`
TypingUnion,
/// E.g., `int | str`
PEP604,
}
Loading

0 comments on commit bd30701

Please sign in to comment.