Skip to content

Commit

Permalink
Rollup merge of rust-lang#95179 - b-naber:eval-in-try-unify, r=lcnr
Browse files Browse the repository at this point in the history
Try to evaluate in try unify and postpone resolution of constants that contain inference variables

We want code like that in [`ui/const-generics/generic_const_exprs/eval-try-unify.rs`](https://github.com/rust-lang/rust/compare/master...b-naber:eval-in-try-unify?expand=1#diff-8027038201cf07a6c96abf3cbf0b0f4fdd8a64ce6292435f01c8ed995b87fe9b) to compile. To do that we need to try to evaluate constants in `try_unify_abstract_consts`, this requires us to be more careful about what constants we try to resolve, specifically we cannot try to resolve constants that still contain inference variables.

r? `@lcnr`
  • Loading branch information
Dylan-DPC authored Mar 24, 2022
2 parents 5dd844f + 11a70db commit bc1d9df
Show file tree
Hide file tree
Showing 13 changed files with 277 additions and 281 deletions.
28 changes: 22 additions & 6 deletions compiler/rustc_infer/src/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ use rustc_hir::def_id::{DefId, LocalDefId};
use rustc_middle::infer::canonical::{Canonical, CanonicalVarValues};
use rustc_middle::infer::unify_key::{ConstVarValue, ConstVariableValue};
use rustc_middle::infer::unify_key::{ConstVariableOrigin, ConstVariableOriginKind, ToType};
use rustc_middle::mir::interpret::ErrorHandled;
use rustc_middle::mir::interpret::EvalToConstValueResult;
use rustc_middle::mir::interpret::{ErrorHandled, EvalToConstValueResult};
use rustc_middle::traits::select;
use rustc_middle::ty::error::{ExpectedFound, TypeError};
use rustc_middle::ty::fold::{TypeFoldable, TypeFolder};
Expand Down Expand Up @@ -71,7 +70,6 @@ mod sub;
pub mod type_variable;
mod undo_log;

use crate::infer::canonical::OriginalQueryValues;
pub use rustc_middle::infer::unify_key;

#[must_use]
Expand Down Expand Up @@ -687,15 +685,28 @@ pub struct CombinedSnapshot<'a, 'tcx> {
impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
/// calls `tcx.try_unify_abstract_consts` after
/// canonicalizing the consts.
#[instrument(skip(self), level = "debug")]
pub fn try_unify_abstract_consts(
&self,
a: ty::Unevaluated<'tcx, ()>,
b: ty::Unevaluated<'tcx, ()>,
param_env: ty::ParamEnv<'tcx>,
) -> bool {
let canonical = self.canonicalize_query((a, b), &mut OriginalQueryValues::default());
debug!("canonical consts: {:?}", &canonical.value);
// Reject any attempt to unify two unevaluated constants that contain inference
// variables, since inference variables in queries lead to ICEs.
if a.substs.has_infer_types_or_consts()
|| b.substs.has_infer_types_or_consts()
|| param_env.has_infer_types_or_consts()
{
debug!("a or b or param_env contain infer vars in its substs -> cannot unify");
return false;
}

let param_env_and = param_env.and((a, b));
let erased = self.tcx.erase_regions(param_env_and);
debug!("after erase_regions: {:?}", erased);

self.tcx.try_unify_abstract_consts(canonical.value)
self.tcx.try_unify_abstract_consts(erased)
}

pub fn is_in_snapshot(&self) -> bool {
Expand Down Expand Up @@ -1598,22 +1609,27 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
///
/// This handles inferences variables within both `param_env` and `substs` by
/// performing the operation on their respective canonical forms.
#[instrument(skip(self), level = "debug")]
pub fn const_eval_resolve(
&self,
param_env: ty::ParamEnv<'tcx>,
unevaluated: ty::Unevaluated<'tcx>,
span: Option<Span>,
) -> EvalToConstValueResult<'tcx> {
let substs = self.resolve_vars_if_possible(unevaluated.substs);
debug!(?substs);

// Postpone the evaluation of constants whose substs depend on inference
// variables
if substs.has_infer_types_or_consts() {
debug!("substs have infer types or consts: {:?}", substs);
return Err(ErrorHandled::TooGeneric);
}

let param_env_erased = self.tcx.erase_regions(param_env);
let substs_erased = self.tcx.erase_regions(substs);
debug!(?param_env_erased);
debug!(?substs_erased);

let unevaluated = ty::Unevaluated {
def: unevaluated.def,
Expand Down
11 changes: 11 additions & 0 deletions compiler/rustc_middle/src/mir/interpret/queries.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use super::{ErrorHandled, EvalToConstValueResult, GlobalId};

use crate::mir;
use crate::ty::fold::TypeFoldable;
use crate::ty::subst::InternalSubsts;
use crate::ty::{self, TyCtxt};
use rustc_hir::def_id::DefId;
Expand Down Expand Up @@ -38,6 +39,16 @@ impl<'tcx> TyCtxt<'tcx> {
ct: ty::Unevaluated<'tcx>,
span: Option<Span>,
) -> EvalToConstValueResult<'tcx> {
// Cannot resolve `Unevaluated` constants that contain inference
// variables. We reject those here since `resolve_opt_const_arg`
// would fail otherwise.
//
// When trying to evaluate constants containing inference variables,
// use `Infcx::const_eval_resolve` instead.
if ct.substs.has_infer_types_or_consts() {
bug!("did not expect inference variables here");
}

match ty::Instance::resolve_opt_const_arg(self, param_env, ct.def, ct.substs) {
Ok(Some(instance)) => {
let cid = GlobalId { instance, promoted: ct.promoted };
Expand Down
8 changes: 4 additions & 4 deletions compiler/rustc_middle/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -331,12 +331,12 @@ rustc_queries! {
}
}

query try_unify_abstract_consts(key: (
ty::Unevaluated<'tcx, ()>, ty::Unevaluated<'tcx, ()>
)) -> bool {
query try_unify_abstract_consts(key:
ty::ParamEnvAnd<'tcx, (ty::Unevaluated<'tcx, ()>, ty::Unevaluated<'tcx, ()>
)>) -> bool {
desc {
|tcx| "trying to unify the generic constants {} and {}",
tcx.def_path_str(key.0.def.did), tcx.def_path_str(key.1.def.did)
tcx.def_path_str(key.value.0.def.did), tcx.def_path_str(key.value.1.def.did)
}
}

Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/ty/relate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,7 @@ pub fn super_relate_consts<'tcx, R: TypeRelation<'tcx>>(
(ty::ConstKind::Unevaluated(au), ty::ConstKind::Unevaluated(bu))
if tcx.features().generic_const_exprs =>
{
tcx.try_unify_abstract_consts((au.shrink(), bu.shrink()))
tcx.try_unify_abstract_consts(relation.param_env().and((au.shrink(), bu.shrink())))
}

// While this is slightly incorrect, it shouldn't matter for `min_const_generics`
Expand Down
187 changes: 110 additions & 77 deletions compiler/rustc_trait_selection/src/traits/const_evaluatable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ pub fn is_const_evaluatable<'cx, 'tcx>(
}
}

#[instrument(skip(tcx), level = "debug")]
fn satisfied_from_param_env<'tcx>(
tcx: TyCtxt<'tcx>,
ct: AbstractConst<'tcx>,
Expand All @@ -197,14 +198,17 @@ fn satisfied_from_param_env<'tcx>(
match pred.kind().skip_binder() {
ty::PredicateKind::ConstEvaluatable(uv) => {
if let Some(b_ct) = AbstractConst::new(tcx, uv)? {
let const_unify_ctxt = ConstUnifyCtxt { tcx, param_env };

// Try to unify with each subtree in the AbstractConst to allow for
// `N + 1` being const evaluatable even if theres only a `ConstEvaluatable`
// predicate for `(N + 1) * 2`
let result =
walk_abstract_const(tcx, b_ct, |b_ct| match try_unify(tcx, ct, b_ct) {
let result = walk_abstract_const(tcx, b_ct, |b_ct| {
match const_unify_ctxt.try_unify(ct, b_ct) {
true => ControlFlow::BREAK,
false => ControlFlow::CONTINUE,
});
}
});

if let ControlFlow::Break(()) = result {
debug!("is_const_evaluatable: abstract_const ~~> ok");
Expand Down Expand Up @@ -637,11 +641,13 @@ pub(super) fn thir_abstract_const<'tcx>(
pub(super) fn try_unify_abstract_consts<'tcx>(
tcx: TyCtxt<'tcx>,
(a, b): (ty::Unevaluated<'tcx, ()>, ty::Unevaluated<'tcx, ()>),
param_env: ty::ParamEnv<'tcx>,
) -> bool {
(|| {
if let Some(a) = AbstractConst::new(tcx, a)? {
if let Some(b) = AbstractConst::new(tcx, b)? {
return Ok(try_unify(tcx, a, b));
let const_unify_ctxt = ConstUnifyCtxt { tcx, param_env };
return Ok(const_unify_ctxt.try_unify(a, b));
}
}

Expand Down Expand Up @@ -689,88 +695,115 @@ where
recurse(tcx, ct, &mut f)
}

/// Tries to unify two abstract constants using structural equality.
pub(super) fn try_unify<'tcx>(
struct ConstUnifyCtxt<'tcx> {
tcx: TyCtxt<'tcx>,
mut a: AbstractConst<'tcx>,
mut b: AbstractConst<'tcx>,
) -> bool {
// We substitute generics repeatedly to allow AbstractConsts to unify where a
param_env: ty::ParamEnv<'tcx>,
}

impl<'tcx> ConstUnifyCtxt<'tcx> {
// Substitutes generics repeatedly to allow AbstractConsts to unify where a
// ConstKind::Unevalated could be turned into an AbstractConst that would unify e.g.
// Param(N) should unify with Param(T), substs: [Unevaluated("T2", [Unevaluated("T3", [Param(N)])])]
while let Node::Leaf(a_ct) = a.root(tcx) {
match AbstractConst::from_const(tcx, a_ct) {
Ok(Some(a_act)) => a = a_act,
Ok(None) => break,
Err(_) => return true,
}
}
while let Node::Leaf(b_ct) = b.root(tcx) {
match AbstractConst::from_const(tcx, b_ct) {
Ok(Some(b_act)) => b = b_act,
Ok(None) => break,
Err(_) => return true,
#[inline]
#[instrument(skip(self), level = "debug")]
fn try_replace_substs_in_root(
&self,
mut abstr_const: AbstractConst<'tcx>,
) -> Option<AbstractConst<'tcx>> {
while let Node::Leaf(ct) = abstr_const.root(self.tcx) {
match AbstractConst::from_const(self.tcx, ct) {
Ok(Some(act)) => abstr_const = act,
Ok(None) => break,
Err(_) => return None,
}
}
}

match (a.root(tcx), b.root(tcx)) {
(Node::Leaf(a_ct), Node::Leaf(b_ct)) => {
if a_ct.ty() != b_ct.ty() {
return false;
}
Some(abstr_const)
}

match (a_ct.val(), b_ct.val()) {
// We can just unify errors with everything to reduce the amount of
// emitted errors here.
(ty::ConstKind::Error(_), _) | (_, ty::ConstKind::Error(_)) => true,
(ty::ConstKind::Param(a_param), ty::ConstKind::Param(b_param)) => {
a_param == b_param
/// Tries to unify two abstract constants using structural equality.
#[instrument(skip(self), level = "debug")]
fn try_unify(&self, a: AbstractConst<'tcx>, b: AbstractConst<'tcx>) -> bool {
let a = if let Some(a) = self.try_replace_substs_in_root(a) {
a
} else {
return true;
};

let b = if let Some(b) = self.try_replace_substs_in_root(b) {
b
} else {
return true;
};

let a_root = a.root(self.tcx);
let b_root = b.root(self.tcx);
debug!(?a_root, ?b_root);

match (a_root, b_root) {
(Node::Leaf(a_ct), Node::Leaf(b_ct)) => {
let a_ct = a_ct.eval(self.tcx, self.param_env);
debug!("a_ct evaluated: {:?}", a_ct);
let b_ct = b_ct.eval(self.tcx, self.param_env);
debug!("b_ct evaluated: {:?}", b_ct);

if a_ct.ty() != b_ct.ty() {
return false;
}
(ty::ConstKind::Value(a_val), ty::ConstKind::Value(b_val)) => a_val == b_val,
// If we have `fn a<const N: usize>() -> [u8; N + 1]` and `fn b<const M: usize>() -> [u8; 1 + M]`
// we do not want to use `assert_eq!(a(), b())` to infer that `N` and `M` have to be `1`. This
// means that we only allow inference variables if they are equal.
(ty::ConstKind::Infer(a_val), ty::ConstKind::Infer(b_val)) => a_val == b_val,
// We expand generic anonymous constants at the start of this function, so this
// branch should only be taking when dealing with associated constants, at
// which point directly comparing them seems like the desired behavior.
//
// FIXME(generic_const_exprs): This isn't actually the case.
// We also take this branch for concrete anonymous constants and
// expand generic anonymous constants with concrete substs.
(ty::ConstKind::Unevaluated(a_uv), ty::ConstKind::Unevaluated(b_uv)) => {
a_uv == b_uv

match (a_ct.val(), b_ct.val()) {
// We can just unify errors with everything to reduce the amount of
// emitted errors here.
(ty::ConstKind::Error(_), _) | (_, ty::ConstKind::Error(_)) => true,
(ty::ConstKind::Param(a_param), ty::ConstKind::Param(b_param)) => {
a_param == b_param
}
(ty::ConstKind::Value(a_val), ty::ConstKind::Value(b_val)) => a_val == b_val,
// If we have `fn a<const N: usize>() -> [u8; N + 1]` and `fn b<const M: usize>() -> [u8; 1 + M]`
// we do not want to use `assert_eq!(a(), b())` to infer that `N` and `M` have to be `1`. This
// means that we only allow inference variables if they are equal.
(ty::ConstKind::Infer(a_val), ty::ConstKind::Infer(b_val)) => a_val == b_val,
// We expand generic anonymous constants at the start of this function, so this
// branch should only be taking when dealing with associated constants, at
// which point directly comparing them seems like the desired behavior.
//
// FIXME(generic_const_exprs): This isn't actually the case.
// We also take this branch for concrete anonymous constants and
// expand generic anonymous constants with concrete substs.
(ty::ConstKind::Unevaluated(a_uv), ty::ConstKind::Unevaluated(b_uv)) => {
a_uv == b_uv
}
// FIXME(generic_const_exprs): We may want to either actually try
// to evaluate `a_ct` and `b_ct` if they are are fully concrete or something like
// this, for now we just return false here.
_ => false,
}
// FIXME(generic_const_exprs): We may want to either actually try
// to evaluate `a_ct` and `b_ct` if they are are fully concrete or something like
// this, for now we just return false here.
_ => false,
}
(Node::Binop(a_op, al, ar), Node::Binop(b_op, bl, br)) if a_op == b_op => {
self.try_unify(a.subtree(al), b.subtree(bl))
&& self.try_unify(a.subtree(ar), b.subtree(br))
}
(Node::UnaryOp(a_op, av), Node::UnaryOp(b_op, bv)) if a_op == b_op => {
self.try_unify(a.subtree(av), b.subtree(bv))
}
(Node::FunctionCall(a_f, a_args), Node::FunctionCall(b_f, b_args))
if a_args.len() == b_args.len() =>
{
self.try_unify(a.subtree(a_f), b.subtree(b_f))
&& iter::zip(a_args, b_args)
.all(|(&an, &bn)| self.try_unify(a.subtree(an), b.subtree(bn)))
}
(Node::Cast(a_kind, a_operand, a_ty), Node::Cast(b_kind, b_operand, b_ty))
if (a_ty == b_ty) && (a_kind == b_kind) =>
{
self.try_unify(a.subtree(a_operand), b.subtree(b_operand))
}
// use this over `_ => false` to make adding variants to `Node` less error prone
(Node::Cast(..), _)
| (Node::FunctionCall(..), _)
| (Node::UnaryOp(..), _)
| (Node::Binop(..), _)
| (Node::Leaf(..), _) => false,
}
(Node::Binop(a_op, al, ar), Node::Binop(b_op, bl, br)) if a_op == b_op => {
try_unify(tcx, a.subtree(al), b.subtree(bl))
&& try_unify(tcx, a.subtree(ar), b.subtree(br))
}
(Node::UnaryOp(a_op, av), Node::UnaryOp(b_op, bv)) if a_op == b_op => {
try_unify(tcx, a.subtree(av), b.subtree(bv))
}
(Node::FunctionCall(a_f, a_args), Node::FunctionCall(b_f, b_args))
if a_args.len() == b_args.len() =>
{
try_unify(tcx, a.subtree(a_f), b.subtree(b_f))
&& iter::zip(a_args, b_args)
.all(|(&an, &bn)| try_unify(tcx, a.subtree(an), b.subtree(bn)))
}
(Node::Cast(a_kind, a_operand, a_ty), Node::Cast(b_kind, b_operand, b_ty))
if (a_ty == b_ty) && (a_kind == b_kind) =>
{
try_unify(tcx, a.subtree(a_operand), b.subtree(b_operand))
}
// use this over `_ => false` to make adding variants to `Node` less error prone
(Node::Cast(..), _)
| (Node::FunctionCall(..), _)
| (Node::UnaryOp(..), _)
| (Node::Binop(..), _)
| (Node::Leaf(..), _) => false,
}
}
6 changes: 5 additions & 1 deletion compiler/rustc_trait_selection/src/traits/fulfill.rs
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,11 @@ impl<'a, 'b, 'tcx> FulfillProcessor<'a, 'b, 'tcx> {
if let (ty::ConstKind::Unevaluated(a), ty::ConstKind::Unevaluated(b)) =
(c1.val(), c2.val())
{
if infcx.try_unify_abstract_consts(a.shrink(), b.shrink()) {
if infcx.try_unify_abstract_consts(
a.shrink(),
b.shrink(),
obligation.param_env,
) {
return ProcessResult::Changed(vec![]);
}
}
Expand Down
5 changes: 4 additions & 1 deletion compiler/rustc_trait_selection/src/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -862,7 +862,10 @@ pub fn provide(providers: &mut ty::query::Providers) {
ty::WithOptConstParam { did, const_param_did: Some(param_did) },
)
},
try_unify_abstract_consts: const_evaluatable::try_unify_abstract_consts,
try_unify_abstract_consts: |tcx, param_env_and| {
let (param_env, (a, b)) = param_env_and.into_parts();
const_evaluatable::try_unify_abstract_consts(tcx, (a, b), param_env)
},
..*providers
};
}
6 changes: 5 additions & 1 deletion compiler/rustc_trait_selection/src/traits/select/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,11 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
if let (ty::ConstKind::Unevaluated(a), ty::ConstKind::Unevaluated(b)) =
(c1.val(), c2.val())
{
if self.infcx.try_unify_abstract_consts(a.shrink(), b.shrink()) {
if self.infcx.try_unify_abstract_consts(
a.shrink(),
b.shrink(),
obligation.param_env,
) {
return Ok(EvaluatedToOk);
}
}
Expand Down
Loading

0 comments on commit bc1d9df

Please sign in to comment.