diff --git a/crates/flux-driver/src/callbacks.rs b/crates/flux-driver/src/callbacks.rs index fee101ebb1..85e9d7d743 100644 --- a/crates/flux-driver/src/callbacks.rs +++ b/crates/flux-driver/src/callbacks.rs @@ -228,7 +228,8 @@ impl<'genv, 'tcx> CrateChecker<'genv, 'tcx> { } DefKind::Impl { of_trait } => { if of_trait { - refineck::compare_impl_item::check_impl_against_trait(self.genv, def_id)?; + refineck::compare_impl_item::check_impl_against_trait(self.genv, def_id) + .emit(&self.genv)?; } Ok(()) } diff --git a/crates/flux-infer/src/infer.rs b/crates/flux-infer/src/infer.rs index 7a1bb9cffe..51119cca54 100644 --- a/crates/flux-infer/src/infer.rs +++ b/crates/flux-infer/src/infer.rs @@ -453,16 +453,11 @@ impl<'genv, 'tcx> InferCtxtAt<'_, '_, 'genv, 'tcx> { if let rty::ClauseKind::Projection(projection_pred) = clause.kind_skipping_binder() { let impl_elem = BaseTy::projection(projection_pred.projection_ty) .to_ty() - .normalize_projections( - self.infcx.genv, - self.infcx.region_infcx, - self.infcx.def_id, - )?; - let term = projection_pred.term.to_ty().normalize_projections( - self.infcx.genv, - self.infcx.region_infcx, - self.infcx.def_id, - )?; + .normalize_projections(self.infcx)?; + let term = projection_pred + .term + .to_ty() + .normalize_projections(self.infcx)?; // TODO: does this really need to be invariant? https://github.com/flux-rs/flux/pull/478#issuecomment-1654035374 self.subtyping(&impl_elem, &term, reason)?; @@ -964,7 +959,7 @@ impl<'a, E: LocEnv> Sub<'a, E> { let alias_ty = pred.projection_ty.with_self_ty(bty.to_subset_ty_ctor()); let ty1 = BaseTy::Alias(AliasKind::Projection, alias_ty) .to_ty() - .normalize_projections(infcx.genv, infcx.region_infcx, infcx.def_id)?; + .normalize_projections(infcx)?; let ty2 = pred.term.to_ty(); self.tys(infcx, &ty1, &ty2)?; } diff --git a/crates/flux-infer/src/projections.rs b/crates/flux-infer/src/projections.rs index 9bae4ce9ad..c0f552c9b2 100644 --- a/crates/flux-infer/src/projections.rs +++ b/crates/flux-infer/src/projections.rs @@ -15,59 +15,47 @@ use flux_middle::{ }; use flux_rustc_bridge::{lowering::Lower, ToRustc}; use rustc_hir::def_id::DefId; -use rustc_infer::{infer::InferCtxt, traits::Obligation}; +use rustc_infer::traits::Obligation; use rustc_middle::{ traits::{ImplSource, ObligationCause}, ty::TyCtxt, }; use rustc_trait_selection::traits::SelectionContext; +use crate::infer::InferCtxt; + pub trait NormalizeExt: TypeFoldable { - fn normalize_projections<'tcx>( - &self, - genv: GlobalEnv<'_, 'tcx>, - infcx: &rustc_infer::infer::InferCtxt<'tcx>, - callsite_def_id: DefId, - ) -> QueryResult; + fn normalize_projections<'tcx>(&self, infcx: &mut InferCtxt) -> QueryResult; } impl NormalizeExt for T { - fn normalize_projections<'tcx>( - &self, - genv: GlobalEnv<'_, 'tcx>, - infcx: &rustc_infer::infer::InferCtxt<'tcx>, - callsite_def_id: DefId, - ) -> QueryResult { - let mut normalizer = Normalizer::new(genv, infcx, callsite_def_id)?; + fn normalize_projections<'tcx>(&self, infcx: &mut InferCtxt) -> QueryResult { + let mut normalizer = Normalizer::new(infcx.branch())?; self.erase_regions().try_fold_with(&mut normalizer) } } -struct Normalizer<'genv, 'tcx, 'cx> { - genv: GlobalEnv<'genv, 'tcx>, - selcx: SelectionContext<'cx, 'tcx>, - def_id: DefId, +struct Normalizer<'infcx, 'genv, 'tcx> { + infcx: InferCtxt<'infcx, 'genv, 'tcx>, + selcx: SelectionContext<'infcx, 'tcx>, param_env: List, } -impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> { - fn new( - genv: GlobalEnv<'genv, 'tcx>, - infcx: &'cx InferCtxt<'tcx>, - callsite_def_id: DefId, - ) -> QueryResult { - let param_env = genv - .predicates_of(callsite_def_id)? +impl<'infcx, 'genv, 'tcx> Normalizer<'infcx, 'genv, 'tcx> { + fn new(infcx: InferCtxt<'infcx, 'genv, 'tcx>) -> QueryResult { + let param_env = infcx + .genv + .predicates_of(infcx.def_id)? .instantiate_identity() .predicates .clone(); - let selcx = SelectionContext::new(infcx); - Ok(Normalizer { genv, selcx, def_id: callsite_def_id, param_env }) + let selcx = SelectionContext::new(infcx.region_infcx); + Ok(Normalizer { infcx, selcx, param_env }) } fn get_impl_id_of_alias_reft(&mut self, alias_reft: &AliasReft) -> QueryResult> { let tcx = self.tcx(); - let def_id = self.def_id; + let def_id = self.def_id(); let selcx = &mut self.selcx; let trait_pred = Obligation::new( @@ -90,7 +78,7 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> { ) -> QueryResult { if let Some(impl_def_id) = self.get_impl_id_of_alias_reft(alias_reft)? { let impl_trait_ref = self - .genv + .genv() .impl_trait_ref(impl_def_id)? .unwrap() .skip_binder(); @@ -105,7 +93,7 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> { let tcx = self.tcx(); let pred = self - .genv + .genv() .assoc_refinement_def(impl_def_id, alias_reft.name)? .instantiate(tcx, &args, &[]); @@ -125,7 +113,7 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> { ) -> QueryResult { let projection_ty = obligation.to_rustc(self.tcx()); let cause = ObligationCause::dummy(); - let param_env = self.tcx().param_env(self.def_id); + let param_env = self.rustc_param_env(); let ty = rustc_trait_selection::traits::normalize_projection_ty( &mut self.selcx, @@ -137,7 +125,7 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> { ) .expect_type(); let rustc_ty = ty.lower(self.tcx()).unwrap(); - Ok(Refiner::default_for_item(self.genv, self.def_id)? + Ok(Refiner::default_for_item(self.genv(), self.def_id())? .refine_ty_or_base(&rustc_ty)? .expect_base()) } @@ -158,7 +146,7 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> { return Ok((ty != orig_ty, ty)); } if candidates.len() > 1 { - bug!("ambiguity when resolving `{obligation:?}` in {:?}", self.def_id); + bug!("ambiguity when resolving `{obligation:?}` in {:?}", self.def_id()); } let ctor = self.confirm_candidate(candidates.pop().unwrap(), obligation)?; Ok((true, ctor)) @@ -192,7 +180,7 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> { impl_def_id: DefId, ) -> QueryResult { let mut projection_preds: Vec<_> = self - .genv + .genv() .predicates_of(impl_def_id)? .skip_binder() .predicates @@ -241,7 +229,7 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> { // => {T -> {v. i32[v] | v > 0}, A -> Global} let impl_trait_ref = self - .genv + .genv() .impl_trait_ref(impl_def_id)? .unwrap() .skip_binder(); @@ -269,7 +257,7 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> { let tcx = self.tcx(); Ok(self - .genv + .genv() .type_of(assoc_type_id)? .instantiate(tcx, &args, &[]) .expect_subset_ty_ctor()) @@ -299,7 +287,7 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> { && let BaseTy::Alias(AliasKind::Opaque, alias_ty) = ctor.as_bty_skipping_binder() { debug_assert!(!alias_ty.has_escaping_bvars()); - let bounds = self.genv.item_bounds(alias_ty.def_id)?.instantiate( + let bounds = self.genv().item_bounds(alias_ty.def_id)?.instantiate( self.tcx(), &alias_ty.args, &alias_ty.refine_args, @@ -335,12 +323,20 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> { Ok(()) } + fn def_id(&self) -> DefId { + self.infcx.def_id + } + + fn genv(&self) -> GlobalEnv<'genv, 'tcx> { + self.infcx.genv + } + fn tcx(&self) -> TyCtxt<'tcx> { self.selcx.tcx() } fn rustc_param_env(&self) -> rustc_middle::ty::ParamEnv<'tcx> { - self.selcx.tcx().param_env(self.def_id) + self.selcx.tcx().param_env(self.def_id()) } } @@ -367,7 +363,7 @@ impl FallibleTypeFolder for Normalizer<'_, '_, '_> { fn try_fold_sort(&mut self, sort: &Sort) -> Result { match sort { Sort::Alias(AliasKind::Weak, alias_ty) => { - self.genv + self.genv() .normalize_weak_alias_sort(alias_ty)? .try_fold_with(self) } @@ -393,9 +389,9 @@ impl FallibleTypeFolder for Normalizer<'_, '_, '_> { match ty.kind() { TyKind::Indexed(BaseTy::Alias(AliasKind::Weak, alias_ty), idx) => { Ok(self - .genv + .genv() .type_of(alias_ty.def_id)? - .instantiate(self.genv.tcx(), &alias_ty.args, &alias_ty.refine_args) + .instantiate(self.tcx(), &alias_ty.args, &alias_ty.refine_args) .expect_ctor() .replace_bound_reft(idx)) } @@ -446,7 +442,7 @@ impl FallibleTypeFolder for Normalizer<'_, '_, '_> { c.to_rustc(self.tcx()) .normalize_internal(self.tcx(), self.rustc_param_env()) .lower(self.tcx()) - .map_err(|e| QueryErr::unsupported(self.def_id, e.into_err())) + .map_err(|e| QueryErr::unsupported(self.def_id(), e.into_err())) } } diff --git a/crates/flux-refineck/src/checker.rs b/crates/flux-refineck/src/checker.rs index a24808c757..46e33373e0 100644 --- a/crates/flux-refineck/src/checker.rs +++ b/crates/flux-refineck/src/checker.rs @@ -162,12 +162,12 @@ impl<'ck, 'genv, 'tcx> Checker<'ck, 'genv, 'tcx, ShapeMode> { let inherited = Inherited::new(&mut mode, ghost_stmts)?; let body = genv.mir(local_id).with_span(span)?; - let infcx = root_ctxt.infcx(def_id, &body.infcx); + let mut infcx = root_ctxt.infcx(def_id, &body.infcx); let poly_sig = genv .fn_sig(local_id) .with_span(span)? .instantiate_identity() - .normalize_projections(infcx.genv, infcx.region_infcx, infcx.def_id) + .normalize_projections(&mut infcx) .with_span(span)?; Checker::run(infcx, local_id, inherited, poly_sig)?; @@ -194,12 +194,12 @@ impl<'ck, 'genv, 'tcx> Checker<'ck, 'genv, 'tcx, RefineMode> { let mut mode = RefineMode { bb_envs }; let inherited = Inherited::new(&mut mode, ghost_stmts)?; let body = genv.mir(local_id).with_span(span)?; - let infcx = root_ctxt.infcx(def_id, &body.infcx); + let mut infcx = root_ctxt.infcx(def_id, &body.infcx); let poly_sig = genv .fn_sig(def_id) .with_span(span)? .instantiate_identity() - .normalize_projections(infcx.genv, infcx.region_infcx, infcx.def_id) + .normalize_projections(&mut infcx) .with_span(span)?; Checker::run(infcx, local_id, inherited, poly_sig)?; @@ -234,7 +234,7 @@ fn check_fn_subtyping( let super_sig = super_sig .replace_bound_vars(|_| rty::ReErased, |sort, _| infcx.define_vars(sort)) - .normalize_projections(infcx.genv, infcx.region_infcx, *def_id)?; + .normalize_projections(&mut infcx)?; // 1. Unpack `T_g` input types let actuals = super_sig @@ -253,7 +253,7 @@ fn check_fn_subtyping( let sub_sig = sub_sig.instantiate(tcx, sub_args, &refine_args); let sub_sig = sub_sig .replace_bound_vars(|_| rty::ReErased, |sort, mode| infcx.fresh_infer_var(sort, mode)) - .normalize_projections(infcx.genv, infcx.region_infcx, *def_id)?; + .normalize_projections(infcx)?; // 3. INPUT subtyping (g-input <: f-input) for requires in super_sig.requires() { @@ -310,8 +310,9 @@ pub(crate) fn trait_impl_subtyping<'genv, 'tcx>( let Some((impl_trait_ref, trait_method_id)) = find_trait_item(genv, def_id)? else { return Ok(None); }; + let impl_method_id = def_id.to_def_id(); // Skip the check if either the trait-method or the impl-method are marked as `trusted_impl` - if genv.has_trusted_impl(trait_method_id) || genv.has_trusted_impl(def_id.to_def_id()) { + if genv.has_trusted_impl(trait_method_id) || genv.has_trusted_impl(impl_method_id) { return Ok(None); } @@ -328,13 +329,13 @@ pub(crate) fn trait_impl_subtyping<'genv, 'tcx>( .tcx() .infer_ctxt() .build(TypingMode::non_body_analysis()); - let mut infcx = root_ctxt.infcx(trait_method_id, &rustc_infcx); + let mut infcx = root_ctxt.infcx(impl_method_id, &rustc_infcx); let trait_fn_sig = genv.fn_sig(trait_method_id)?; - let impl_sig = genv.fn_sig(def_id)?; + let impl_sig = genv.fn_sig(impl_method_id)?; check_fn_subtyping( &mut infcx, - &def_id.to_def_id(), + &impl_method_id, impl_sig, &impl_args, &trait_fn_sig.instantiate(tcx, &trait_args, &trait_refine_args), @@ -422,7 +423,7 @@ impl<'ck, 'genv, 'tcx, M: Mode> Checker<'ck, 'genv, 'tcx, M> { let fn_sig = poly_sig .replace_bound_vars(|_| rty::ReErased, |sort, _| infcx.define_vars(sort)) - .normalize_projections(infcx.genv, infcx.region_infcx, infcx.def_id) + .normalize_projections(&mut infcx) .with_span(span)?; let mut env = TypeEnv::new(&mut infcx, &body, &fn_sig); @@ -782,7 +783,7 @@ impl<'ck, 'genv, 'tcx, M: Mode> Checker<'ck, 'genv, 'tcx, M> { let fn_sig = fn_sig .instantiate(tcx, &generic_args, &refine_args) .replace_bound_vars(|_| rty::ReErased, |sort, mode| infcx.fresh_infer_var(sort, mode)) - .normalize_projections(genv, infcx.region_infcx, infcx.def_id) + .normalize_projections(infcx) .with_span(span)?; let mut at = infcx.at(span); diff --git a/crates/flux-refineck/src/compare_impl_item.rs b/crates/flux-refineck/src/compare_impl_item.rs index 90086326b9..1923704973 100644 --- a/crates/flux-refineck/src/compare_impl_item.rs +++ b/crates/flux-refineck/src/compare_impl_item.rs @@ -1,112 +1,111 @@ -use flux_common::result::ResultExt; -use flux_infer::projections::NormalizeExt as _; -use flux_middle::{def_id_to_string, global_env::GlobalEnv, MaybeExternId}; +use flux_common::result::ErrorEmitter; +use flux_infer::{ + infer::{GlobalEnvExt as _, InferCtxt}, + projections::NormalizeExt as _, +}; +use flux_middle::{ + def_id_to_string, global_env::GlobalEnv, queries::QueryResult, rty::TraitRef, MaybeExternId, +}; use rustc_hash::FxHashSet; use rustc_infer::infer::TyCtxtInferExt; use rustc_middle::ty::TypingMode; -use rustc_span::{def_id::DefId, ErrorGuaranteed, Symbol}; -type Result = std::result::Result; +use rustc_span::{def_id::DefId, Symbol}; -pub fn check_impl_against_trait(genv: GlobalEnv, impl_id: MaybeExternId) -> Result { +pub fn check_impl_against_trait(genv: GlobalEnv, impl_id: MaybeExternId) -> QueryResult { let trait_id = genv.tcx().trait_id_of_impl(impl_id.resolved_id()).unwrap(); - let impl_assoc_refts = genv.assoc_refinements_of(impl_id).emit(&genv)?; - let trait_assoc_refts = genv.assoc_refinements_of(trait_id).emit(&genv)?; + let impl_assoc_refts = genv.assoc_refinements_of(impl_id)?; + let trait_assoc_refts = genv.assoc_refinements_of(trait_id)?; let impl_names: FxHashSet<_> = impl_assoc_refts.items.iter().map(|x| x.name).collect(); for trait_assoc_reft in &trait_assoc_refts.items { let name = trait_assoc_reft.name; - let has_default = genv - .default_assoc_refinement_def(trait_id, name) - .emit(&genv)? - .is_some(); + let has_default = genv.default_assoc_refinement_def(trait_id, name)?.is_some(); if !impl_names.contains(&name) && !has_default { let span = genv.tcx().def_span(impl_id); - return Err(genv.sess().emit_err(errors::MissingAssocReft::new( - span, - name, - def_id_to_string(trait_id), - ))); + Err(genv.emit(errors::MissingAssocReft::new(span, name, def_id_to_string(trait_id))))?; } } + let impl_trait_ref = genv + .impl_trait_ref(impl_id.resolved_id())? + .unwrap() + .instantiate_identity(); + + let mut root_ctxt = genv + .infcx_root(trait_id, genv.infer_opts(impl_id.local_id())) + .with_generic_args(&impl_trait_ref.args) + .build()?; + let rustc_infcx = genv + .tcx() + .infer_ctxt() + .build(TypingMode::non_body_analysis()); + let mut infcx = root_ctxt.infcx(trait_id, &rustc_infcx); + for impl_assoc_reft in &impl_assoc_refts.items { let name = impl_assoc_reft.name; if trait_assoc_refts.find(name).is_none() { let fhir_impl_assoc_reft = genv .map() - .expect_item(impl_id.local_id()) - .emit(&genv)? + .expect_item(impl_id.local_id())? .expect_impl() .find_assoc_reft(name) .unwrap(); - return Err(genv.sess().emit_err(errors::InvalidAssocReft::new( + Err(genv.emit(errors::InvalidAssocReft::new( fhir_impl_assoc_reft.span, name, def_id_to_string(trait_id), - ))); + )))?; } - check_assoc_reft(genv, impl_id, trait_id, impl_assoc_reft.name)?; + check_assoc_reft(&mut infcx, impl_id, &impl_trait_ref, trait_id, impl_assoc_reft.name)?; } Ok(()) } fn check_assoc_reft( - genv: GlobalEnv, + infcx: &mut InferCtxt, impl_id: MaybeExternId, + impl_trait_ref: &TraitRef, trait_id: DefId, name: Symbol, -) -> Result { - let infcx = genv - .tcx() - .infer_ctxt() - .build(TypingMode::non_body_analysis()); - - let impl_span = genv +) -> QueryResult { + let impl_span = infcx + .genv .map() - .expect_item(impl_id.local_id()) - .emit(&genv)? + .expect_item(impl_id.local_id())? .expect_impl() .find_assoc_reft(name) .unwrap() .span; - let impl_trait_ref = genv - .impl_trait_ref(impl_id.resolved_id()) - .emit(&genv)? - .unwrap() - .instantiate_identity(); - - let Some(impl_sort) = genv.sort_of_assoc_reft(impl_id, name).emit(genv.sess())? else { - return Err(genv.sess().emit_err(errors::InvalidAssocReft::new( + let Some(impl_sort) = infcx.genv.sort_of_assoc_reft(impl_id, name)? else { + return Err(infcx.genv.emit(errors::InvalidAssocReft::new( impl_span, name, def_id_to_string(trait_id), - ))); + )))?; }; let impl_sort = impl_sort .instantiate_identity() - .normalize_projections(genv, &infcx, impl_id.resolved_id()) - .emit(&genv)?; + .normalize_projections(infcx)?; - let Some(trait_sort) = genv.sort_of_assoc_reft(trait_id, name).emit(genv.sess())? else { - return Err(genv.sess().emit_err(errors::InvalidAssocReft::new( + let Some(trait_sort) = infcx.genv.sort_of_assoc_reft(trait_id, name)? else { + return Err(infcx.genv.emit(errors::InvalidAssocReft::new( impl_span, name, def_id_to_string(trait_id), - ))); + )))?; }; let trait_sort = trait_sort - .instantiate(genv.tcx(), &impl_trait_ref.args, &[]) - .normalize_projections(genv, &infcx, impl_id.resolved_id()) - .emit(&genv)?; + .instantiate(infcx.tcx(), &impl_trait_ref.args, &[]) + .normalize_projections(infcx)?; if impl_sort != trait_sort { - return Err(genv - .sess() - .emit_err(errors::IncompatibleSort::new(impl_span, name, trait_sort, impl_sort))); + return Err(infcx + .genv + .emit(errors::IncompatibleSort::new(impl_span, name, trait_sort, impl_sort)))?; } Ok(()) diff --git a/crates/flux-refineck/src/type_env/place_ty.rs b/crates/flux-refineck/src/type_env/place_ty.rs index 87e6139977..1d9133957d 100644 --- a/crates/flux-refineck/src/type_env/place_ty.rs +++ b/crates/flux-refineck/src/type_env/place_ty.rs @@ -781,7 +781,7 @@ fn downcast( /// * `x.fld : T[A := t ..][i := e...]` /// i.e. by substituting the type and value indices using the types and values from `x`. fn downcast_struct( - infcx: &InferCtxt, + infcx: &mut InferCtxt, adt: &AdtDef, args: &[GenericArg], idx: &Expr, @@ -795,7 +795,7 @@ fn downcast_struct( Ok(struct_variant(infcx.genv, adt.did())? .instantiate(tcx, args, &[]) .replace_bound_refts(&flds) - .normalize_projections(infcx.genv, infcx.region_infcx, infcx.def_id)? + .normalize_projections(infcx)? .fields .to_vec()) } @@ -829,7 +829,7 @@ fn downcast_enum( .expect("enums cannot be opaque") .instantiate(tcx, args, &[]) .replace_bound_refts_with(|sort, _, _| infcx.define_vars(sort)) - .normalize_projections(infcx.genv, infcx.region_infcx, infcx.def_id)?; + .normalize_projections(infcx)?; // FIXME(nilehmann) We could assert idx1 == variant_def.idx directly, but for aggregate sorts there // are currently two problems.