diff --git a/compiler/rustc_middle/src/ty/sty.rs b/compiler/rustc_middle/src/ty/sty.rs index 9835211a74865..7c6d6ea1cb6ee 100644 --- a/compiler/rustc_middle/src/ty/sty.rs +++ b/compiler/rustc_middle/src/ty/sty.rs @@ -8,7 +8,9 @@ use crate::infer::canonical::Canonical; use crate::ty::fold::ValidateBoundVars; use crate::ty::subst::{GenericArg, InternalSubsts, Subst, SubstsRef}; use crate::ty::InferTy::{self, *}; -use crate::ty::{self, AdtDef, DefIdTree, Discr, Term, Ty, TyCtxt, TypeFlags, TypeFoldable}; +use crate::ty::{ + self, AdtDef, DefIdTree, Discr, Term, Ty, TyCtxt, TypeFlags, TypeFoldable, TypeVisitor, +}; use crate::ty::{DelaySpanBugEmitted, List, ParamEnv}; use polonius_engine::Atom; use rustc_data_structures::captures::Captures; @@ -24,7 +26,7 @@ use std::borrow::Cow; use std::cmp::Ordering; use std::fmt; use std::marker::PhantomData; -use std::ops::{Deref, Range}; +use std::ops::{ControlFlow, Deref, Range}; use ty::util::IntTypeExt; #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, TyEncodable, TyDecodable)] @@ -2072,6 +2074,24 @@ impl<'tcx> Ty<'tcx> { !matches!(self.kind(), Param(_) | Infer(_) | Error(_)) } + /// Checks whether a type recursively contains another type + /// + /// Example: `Option<()>` contains `()` + pub fn contains(self, other: Ty<'tcx>) -> bool { + struct ContainsTyVisitor<'tcx>(Ty<'tcx>); + + impl<'tcx> TypeVisitor<'tcx> for ContainsTyVisitor<'tcx> { + type BreakTy = (); + + fn visit_ty(&mut self, t: Ty<'tcx>) -> ControlFlow { + if self.0 == t { ControlFlow::BREAK } else { t.super_visit_with(self) } + } + } + + let cf = self.visit_with(&mut ContainsTyVisitor(other)); + cf.is_break() + } + /// Returns the type and mutability of `*ty`. /// /// The parameter `explicit` indicates if this is an *explicit* dereference. diff --git a/compiler/rustc_typeck/src/check/fn_ctxt/suggestions.rs b/compiler/rustc_typeck/src/check/fn_ctxt/suggestions.rs index 86cf850d72322..f9c482713f1fe 100644 --- a/compiler/rustc_typeck/src/check/fn_ctxt/suggestions.rs +++ b/compiler/rustc_typeck/src/check/fn_ctxt/suggestions.rs @@ -8,8 +8,12 @@ use rustc_errors::{Applicability, DiagnosticBuilder}; use rustc_hir as hir; use rustc_hir::def::{CtorOf, DefKind}; use rustc_hir::lang_items::LangItem; -use rustc_hir::{Expr, ExprKind, ItemKind, Node, Path, QPath, Stmt, StmtKind, TyKind}; +use rustc_hir::{ + Expr, ExprKind, GenericBound, ItemKind, Node, Path, QPath, Stmt, StmtKind, TyKind, + WherePredicate, +}; use rustc_infer::infer::{self, TyCtxtInferExt}; + use rustc_middle::lint::in_external_macro; use rustc_middle::ty::{self, Binder, Ty}; use rustc_span::symbol::{kw, sym}; @@ -559,6 +563,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { let ty = self.tcx.erase_late_bound_regions(ty); if self.can_coerce(expected, ty) { err.span_label(sp, format!("expected `{}` because of return type", expected)); + self.try_suggest_return_impl_trait(err, expected, ty, fn_id); return true; } false @@ -566,6 +571,115 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { } } + /// check whether the return type is a generic type with a trait bound + /// only suggest this if the generic param is not present in the arguments + /// if this is true, hint them towards changing the return type to `impl Trait` + /// ``` + /// fn cant_name_it u32>() -> T { + /// || 3 + /// } + /// ``` + fn try_suggest_return_impl_trait( + &self, + err: &mut DiagnosticBuilder<'_>, + expected: Ty<'tcx>, + found: Ty<'tcx>, + fn_id: hir::HirId, + ) { + // Only apply the suggestion if: + // - the return type is a generic parameter + // - the generic param is not used as a fn param + // - the generic param has at least one bound + // - the generic param doesn't appear in any other bounds where it's not the Self type + // Suggest: + // - Changing the return type to be `impl ` + + debug!("try_suggest_return_impl_trait, expected = {:?}, found = {:?}", expected, found); + + let ty::Param(expected_ty_as_param) = expected.kind() else { return }; + + let fn_node = self.tcx.hir().find(fn_id); + + let Some(hir::Node::Item(hir::Item { + kind: + hir::ItemKind::Fn( + hir::FnSig { decl: hir::FnDecl { inputs: fn_parameters, output: fn_return, .. }, .. }, + hir::Generics { params, where_clause, .. }, + _body_id, + ), + .. + })) = fn_node else { return }; + + let Some(expected_generic_param) = params.get(expected_ty_as_param.index as usize) else { return }; + + // get all where BoundPredicates here, because they are used in to cases below + let where_predicates = where_clause + .predicates + .iter() + .filter_map(|p| match p { + WherePredicate::BoundPredicate(hir::WhereBoundPredicate { + bounds, + bounded_ty, + .. + }) => { + // FIXME: Maybe these calls to `ast_ty_to_ty` can be removed (and the ones below) + let ty = >::ast_ty_to_ty(self, bounded_ty); + Some((ty, bounds)) + } + _ => None, + }) + .map(|(ty, bounds)| match ty.kind() { + ty::Param(param_ty) if param_ty == expected_ty_as_param => Ok(Some(bounds)), + // check whether there is any predicate that contains our `T`, like `Option: Send` + _ => match ty.contains(expected) { + true => Err(()), + false => Ok(None), + }, + }) + .collect::, _>>(); + + let Ok(where_predicates) = where_predicates else { return }; + + // now get all predicates in the same types as the where bounds, so we can chain them + let predicates_from_where = + where_predicates.iter().flatten().map(|bounds| bounds.iter()).flatten(); + + // extract all bounds from the source code using their spans + let all_matching_bounds_strs = expected_generic_param + .bounds + .iter() + .chain(predicates_from_where) + .filter_map(|bound| match bound { + GenericBound::Trait(_, _) => { + self.tcx.sess.source_map().span_to_snippet(bound.span()).ok() + } + _ => None, + }) + .collect::>(); + + if all_matching_bounds_strs.len() == 0 { + return; + } + + let all_bounds_str = all_matching_bounds_strs.join(" + "); + + let ty_param_used_in_fn_params = fn_parameters.iter().any(|param| { + let ty = >::ast_ty_to_ty(self, param); + matches!(ty.kind(), ty::Param(fn_param_ty_param) if expected_ty_as_param == fn_param_ty_param) + }); + + if ty_param_used_in_fn_params { + return; + } + + err.span_suggestion( + fn_return.span(), + "consider using an impl return type", + format!("impl {}", all_bounds_str), + Applicability::MaybeIncorrect, + ); + } + pub(in super::super) fn suggest_missing_break_or_return_expr( &self, err: &mut DiagnosticBuilder<'_>, diff --git a/src/test/ui/return/return-impl-trait-bad.rs b/src/test/ui/return/return-impl-trait-bad.rs new file mode 100644 index 0000000000000..e3f6ddb9a1497 --- /dev/null +++ b/src/test/ui/return/return-impl-trait-bad.rs @@ -0,0 +1,31 @@ +trait Trait {} +impl Trait for () {} + +fn bad_echo(_t: T) -> T { + "this should not suggest impl Trait" //~ ERROR mismatched types +} + +fn bad_echo_2(_t: T) -> T { + "this will not suggest it, because that would probably be wrong" //~ ERROR mismatched types +} + +fn other_bounds_bad() -> T +where + T: Send, + Option: Send, +{ + "don't suggest this, because Option places additional constraints" //~ ERROR mismatched types +} + +// FIXME: implement this check +trait GenericTrait {} + +fn used_in_trait() -> T +where + T: Send, + (): GenericTrait, +{ + "don't suggest this, because the generic param is used in the bound." //~ ERROR mismatched types +} + +fn main() {} diff --git a/src/test/ui/return/return-impl-trait-bad.stderr b/src/test/ui/return/return-impl-trait-bad.stderr new file mode 100644 index 0000000000000..237b85ee66a10 --- /dev/null +++ b/src/test/ui/return/return-impl-trait-bad.stderr @@ -0,0 +1,59 @@ +error[E0308]: mismatched types + --> $DIR/return-impl-trait-bad.rs:5:5 + | +LL | fn bad_echo(_t: T) -> T { + | - - expected `T` because of return type + | | + | this type parameter +LL | "this should not suggest impl Trait" + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected type parameter `T`, found `&str` + | + = note: expected type parameter `T` + found reference `&'static str` + +error[E0308]: mismatched types + --> $DIR/return-impl-trait-bad.rs:9:5 + | +LL | fn bad_echo_2(_t: T) -> T { + | - - expected `T` because of return type + | | + | this type parameter +LL | "this will not suggest it, because that would probably be wrong" + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected type parameter `T`, found `&str` + | + = note: expected type parameter `T` + found reference `&'static str` + +error[E0308]: mismatched types + --> $DIR/return-impl-trait-bad.rs:17:5 + | +LL | fn other_bounds_bad() -> T + | - - expected `T` because of return type + | | + | this type parameter +... +LL | "don't suggest this, because Option places additional constraints" + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected type parameter `T`, found `&str` + | + = note: expected type parameter `T` + found reference `&'static str` + +error[E0308]: mismatched types + --> $DIR/return-impl-trait-bad.rs:28:5 + | +LL | fn used_in_trait() -> T + | - - + | | | + | | expected `T` because of return type + | | help: consider using an impl return type: `impl Send` + | this type parameter +... +LL | "don't suggest this, because the generic param is used in the bound." + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected type parameter `T`, found `&str` + | + = note: expected type parameter `T` + found reference `&'static str` + +error: aborting due to 4 previous errors + +For more information about this error, try `rustc --explain E0308`. diff --git a/src/test/ui/return/return-impl-trait.fixed b/src/test/ui/return/return-impl-trait.fixed new file mode 100644 index 0000000000000..ff2b02f73ea65 --- /dev/null +++ b/src/test/ui/return/return-impl-trait.fixed @@ -0,0 +1,30 @@ +// run-rustfix + +trait Trait {} +impl Trait for () {} + +// this works +fn foo() -> impl Trait { + () +} + +fn bar() -> impl Trait + std::marker::Sync + Send +where + T: Send, +{ + () //~ ERROR mismatched types +} + +fn other_bounds() -> impl Trait +where + T: Trait, + Vec: Clone, +{ + () //~ ERROR mismatched types +} + +fn main() { + foo(); + bar::<()>(); + other_bounds::<()>(); +} diff --git a/src/test/ui/return/return-impl-trait.rs b/src/test/ui/return/return-impl-trait.rs new file mode 100644 index 0000000000000..e905d712f622d --- /dev/null +++ b/src/test/ui/return/return-impl-trait.rs @@ -0,0 +1,30 @@ +// run-rustfix + +trait Trait {} +impl Trait for () {} + +// this works +fn foo() -> impl Trait { + () +} + +fn bar() -> T +where + T: Send, +{ + () //~ ERROR mismatched types +} + +fn other_bounds() -> T +where + T: Trait, + Vec: Clone, +{ + () //~ ERROR mismatched types +} + +fn main() { + foo(); + bar::<()>(); + other_bounds::<()>(); +} diff --git a/src/test/ui/return/return-impl-trait.stderr b/src/test/ui/return/return-impl-trait.stderr new file mode 100644 index 0000000000000..43d40972fcac0 --- /dev/null +++ b/src/test/ui/return/return-impl-trait.stderr @@ -0,0 +1,34 @@ +error[E0308]: mismatched types + --> $DIR/return-impl-trait.rs:15:5 + | +LL | fn bar() -> T + | - - + | | | + | | expected `T` because of return type + | this type parameter help: consider using an impl return type: `impl Trait + std::marker::Sync + Send` +... +LL | () + | ^^ expected type parameter `T`, found `()` + | + = note: expected type parameter `T` + found unit type `()` + +error[E0308]: mismatched types + --> $DIR/return-impl-trait.rs:23:5 + | +LL | fn other_bounds() -> T + | - - + | | | + | | expected `T` because of return type + | | help: consider using an impl return type: `impl Trait` + | this type parameter +... +LL | () + | ^^ expected type parameter `T`, found `()` + | + = note: expected type parameter `T` + found unit type `()` + +error: aborting due to 2 previous errors + +For more information about this error, try `rustc --explain E0308`.