Skip to content

Commit

Permalink
Simplify IntVarValue/FloatVarValue
Browse files Browse the repository at this point in the history
  • Loading branch information
compiler-errors committed Apr 17, 2024
1 parent 3fba278 commit 6e5acfb
Show file tree
Hide file tree
Showing 9 changed files with 127 additions and 171 deletions.
26 changes: 15 additions & 11 deletions compiler/rustc_infer/src/infer/freshen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
//! inferencer knows "so far".
use super::InferCtxt;
use rustc_data_structures::fx::FxHashMap;
use rustc_middle::infer::unify_key::ToType;
use rustc_middle::ty::fold::TypeFolder;
use rustc_middle::ty::{self, Ty, TyCtxt, TypeFoldable, TypeSuperFoldable, TypeVisitableExt};
use std::collections::hash_map::Entry;
Expand Down Expand Up @@ -203,22 +202,27 @@ impl<'a, 'tcx> TypeFreshener<'a, 'tcx> {

ty::IntVar(v) => {
let mut inner = self.infcx.inner.borrow_mut();
let input = inner
.int_unification_table()
.probe_value(v)
.map(|v| v.to_type(self.infcx.tcx))
.ok_or_else(|| ty::IntVar(inner.int_unification_table().find(v)));
let value = inner.int_unification_table().probe_value(v);
let input = match value {
ty::IntVarValue::IntType(ty) => Ok(Ty::new_int(self.infcx.tcx, ty)),
ty::IntVarValue::UintType(ty) => Ok(Ty::new_uint(self.infcx.tcx, ty)),
ty::IntVarValue::Unknown => {
Err(ty::IntVar(inner.int_unification_table().find(v)))
}
};
drop(inner);
Some(self.freshen_ty(input, |n| Ty::new_fresh_int(self.infcx.tcx, n)))
}

ty::FloatVar(v) => {
let mut inner = self.infcx.inner.borrow_mut();
let input = inner
.float_unification_table()
.probe_value(v)
.map(|v| v.to_type(self.infcx.tcx))
.ok_or_else(|| ty::FloatVar(inner.float_unification_table().find(v)));
let value = inner.float_unification_table().probe_value(v);
let input = match value {
ty::FloatVarValue::Known(ty) => Ok(Ty::new_float(self.infcx.tcx, ty)),
ty::FloatVarValue::Unknown => {
Err(ty::FloatVar(inner.float_unification_table().find(v)))
}
};
drop(inner);
Some(self.freshen_ty(input, |n| Ty::new_fresh_float(self.infcx.tcx, n)))
}
Expand Down
59 changes: 31 additions & 28 deletions compiler/rustc_infer/src/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ pub use lexical_region_resolve::RegionResolutionError;
pub use relate::combine::CombineFields;
pub use relate::combine::ObligationEmittingRelation;
pub use relate::StructurallyRelateAliases;
pub use rustc_middle::ty::IntVarValue;
pub use BoundRegionConversionTime::*;
pub use RegionVariableOrigin::*;
pub use SubregionOrigin::*;
Expand All @@ -28,9 +27,9 @@ use rustc_data_structures::unify as ut;
use rustc_errors::{Diag, DiagCtxt, ErrorGuaranteed};
use rustc_hir::def_id::{DefId, LocalDefId};
use rustc_middle::infer::canonical::{Canonical, CanonicalVarValues};
use rustc_middle::infer::unify_key::ConstVariableOrigin;
use rustc_middle::infer::unify_key::ConstVariableValue;
use rustc_middle::infer::unify_key::EffectVarValue;
use rustc_middle::infer::unify_key::{ConstVariableOrigin, ToType};
use rustc_middle::infer::unify_key::{ConstVidKey, EffectVidKey};
use rustc_middle::mir::interpret::{ErrorHandled, EvalToValTreeResult};
use rustc_middle::mir::ConstraintCategory;
Expand Down Expand Up @@ -799,13 +798,13 @@ impl<'tcx> InferCtxt<'tcx> {
vars.extend(
(0..inner.int_unification_table().len())
.map(|i| ty::IntVid::from_u32(i as u32))
.filter(|&vid| inner.int_unification_table().probe_value(vid).is_none())
.filter(|&vid| inner.int_unification_table().probe_value(vid).is_unknown())
.map(|v| Ty::new_int_var(self.tcx, v)),
);
vars.extend(
(0..inner.float_unification_table().len())
.map(|i| ty::FloatVid::from_u32(i as u32))
.filter(|&vid| inner.float_unification_table().probe_value(vid).is_none())
.filter(|&vid| inner.float_unification_table().probe_value(vid).is_unknown())
.map(|v| Ty::new_float_var(self.tcx, v)),
);
vars
Expand Down Expand Up @@ -1041,15 +1040,15 @@ impl<'tcx> InferCtxt<'tcx> {
}

fn next_int_var_id(&self) -> IntVid {
self.inner.borrow_mut().int_unification_table().new_key(None)
self.inner.borrow_mut().int_unification_table().new_key(ty::IntVarValue::Unknown)
}

pub fn next_int_var(&self) -> Ty<'tcx> {
Ty::new_int_var(self.tcx, self.next_int_var_id())
}

fn next_float_var_id(&self) -> FloatVid {
self.inner.borrow_mut().float_unification_table().new_key(None)
self.inner.borrow_mut().float_unification_table().new_key(ty::FloatVarValue::Unknown)
}

pub fn next_float_var(&self) -> Ty<'tcx> {
Expand Down Expand Up @@ -1279,19 +1278,18 @@ impl<'tcx> InferCtxt<'tcx> {
known.map(|t| self.shallow_resolve(t))
}

ty::IntVar(v) => self
.inner
.borrow_mut()
.int_unification_table()
.probe_value(v)
.map(|v| v.to_type(self.tcx)),
ty::IntVar(v) => match self.inner.borrow_mut().int_unification_table().probe_value(v) {
ty::IntVarValue::Unknown => None,
ty::IntVarValue::IntType(ty) => Some(Ty::new_int(self.tcx, ty)),
ty::IntVarValue::UintType(ty) => Some(Ty::new_uint(self.tcx, ty)),
},

ty::FloatVar(v) => self
.inner
.borrow_mut()
.float_unification_table()
.probe_value(v)
.map(|v| v.to_type(self.tcx)),
ty::FloatVar(v) => {
match self.inner.borrow_mut().float_unification_table().probe_value(v) {
ty::FloatVarValue::Unknown => None,
ty::FloatVarValue::Known(ty) => Some(Ty::new_float(self.tcx, ty)),
}
}

ty::FreshTy(_) | ty::FreshIntTy(_) | ty::FreshFloatTy(_) => None,
}
Expand Down Expand Up @@ -1342,21 +1340,26 @@ impl<'tcx> InferCtxt<'tcx> {
/// or else the root int var in the unification table.
pub fn opportunistic_resolve_int_var(&self, vid: ty::IntVid) -> Ty<'tcx> {
let mut inner = self.inner.borrow_mut();
if let Some(value) = inner.int_unification_table().probe_value(vid) {
value.to_type(self.tcx)
} else {
Ty::new_int_var(self.tcx, inner.int_unification_table().find(vid))
let value = inner.int_unification_table().probe_value(vid);
match value {
ty::IntVarValue::IntType(ty) => Ty::new_int(self.tcx, ty),
ty::IntVarValue::UintType(ty) => Ty::new_uint(self.tcx, ty),
ty::IntVarValue::Unknown => {
Ty::new_int_var(self.tcx, inner.int_unification_table().find(vid))
}
}
}

/// Resolves a float var to a rigid int type, if it was constrained to one,
/// or else the root float var in the unification table.
pub fn opportunistic_resolve_float_var(&self, vid: ty::FloatVid) -> Ty<'tcx> {
let mut inner = self.inner.borrow_mut();
if let Some(value) = inner.float_unification_table().probe_value(vid) {
value.to_type(self.tcx)
} else {
Ty::new_float_var(self.tcx, inner.float_unification_table().find(vid))
let value = inner.float_unification_table().probe_value(vid);
match value {
ty::FloatVarValue::Known(ty) => Ty::new_float(self.tcx, ty),
ty::FloatVarValue::Unknown => {
Ty::new_float_var(self.tcx, inner.float_unification_table().find(vid))
}
}
}

Expand Down Expand Up @@ -1667,15 +1670,15 @@ impl<'tcx> InferCtxt<'tcx> {
// If `inlined_probe_value` returns a value it's always a
// `ty::Int(_)` or `ty::UInt(_)`, which never matches a
// `ty::Infer(_)`.
self.inner.borrow_mut().int_unification_table().inlined_probe_value(v).is_some()
!self.inner.borrow_mut().int_unification_table().inlined_probe_value(v).is_unknown()
}

TyOrConstInferVar::TyFloat(v) => {
// If `probe_value` returns a value it's always a
// `ty::Float(_)`, which never matches a `ty::Infer(_)`.
//
// Not `inlined_probe_value(v)` because this call site is colder.
self.inner.borrow_mut().float_unification_table().probe_value(v).is_some()
!self.inner.borrow_mut().float_unification_table().probe_value(v).is_unknown()
}

TyOrConstInferVar::Const(v) => {
Expand Down
79 changes: 19 additions & 60 deletions compiler/rustc_infer/src/infer/relate/combine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use crate::infer::{DefineOpaqueTypes, InferCtxt, TypeTrace};
use crate::traits::{Obligation, PredicateObligations};
use rustc_middle::infer::canonical::OriginalQueryValues;
use rustc_middle::infer::unify_key::EffectVarValue;
use rustc_middle::ty::error::{ExpectedFound, TypeError};
use rustc_middle::ty::error::TypeError;
use rustc_middle::ty::relate::{RelateResult, TypeRelation};
use rustc_middle::ty::{self, InferConst, ToPredicate, Ty, TyCtxt, TypeVisitableExt};
use rustc_middle::ty::{IntType, UintType};
Expand Down Expand Up @@ -57,40 +57,38 @@ impl<'tcx> InferCtxt<'tcx> {
match (a.kind(), b.kind()) {
// Relate integral variables to other types
(&ty::Infer(ty::IntVar(a_id)), &ty::Infer(ty::IntVar(b_id))) => {
self.inner
.borrow_mut()
.int_unification_table()
.unify_var_var(a_id, b_id)
.map_err(|e| int_unification_error(true, e))?;
self.inner.borrow_mut().int_unification_table().union(a_id, b_id);
Ok(a)
}
(&ty::Infer(ty::IntVar(v_id)), &ty::Int(v)) => {
self.unify_integral_variable(true, v_id, IntType(v))
self.unify_integral_variable(v_id, IntType(v));
Ok(b)
}
(&ty::Int(v), &ty::Infer(ty::IntVar(v_id))) => {
self.unify_integral_variable(false, v_id, IntType(v))
self.unify_integral_variable(v_id, IntType(v));
Ok(a)
}
(&ty::Infer(ty::IntVar(v_id)), &ty::Uint(v)) => {
self.unify_integral_variable(true, v_id, UintType(v))
self.unify_integral_variable(v_id, UintType(v));
Ok(b)
}
(&ty::Uint(v), &ty::Infer(ty::IntVar(v_id))) => {
self.unify_integral_variable(false, v_id, UintType(v))
self.unify_integral_variable(v_id, UintType(v));
Ok(a)
}

// Relate floating-point variables to other types
(&ty::Infer(ty::FloatVar(a_id)), &ty::Infer(ty::FloatVar(b_id))) => {
self.inner
.borrow_mut()
.float_unification_table()
.unify_var_var(a_id, b_id)
.map_err(|e| float_unification_error(true, e))?;
self.inner.borrow_mut().float_unification_table().union(a_id, b_id);
Ok(a)
}
(&ty::Infer(ty::FloatVar(v_id)), &ty::Float(v)) => {
self.unify_float_variable(true, v_id, v)
self.unify_float_variable(v_id, ty::FloatVarValue::Known(v));
Ok(b)
}
(&ty::Float(v), &ty::Infer(ty::FloatVar(v_id))) => {
self.unify_float_variable(false, v_id, v)
self.unify_float_variable(v_id, ty::FloatVarValue::Known(v));
Ok(a)
}

// We don't expect `TyVar` or `Fresh*` vars at this point with lazy norm.
Expand Down Expand Up @@ -264,35 +262,12 @@ impl<'tcx> InferCtxt<'tcx> {
}
}

fn unify_integral_variable(
&self,
vid_is_expected: bool,
vid: ty::IntVid,
val: ty::IntVarValue,
) -> RelateResult<'tcx, Ty<'tcx>> {
self.inner
.borrow_mut()
.int_unification_table()
.unify_var_value(vid, Some(val))
.map_err(|e| int_unification_error(vid_is_expected, e))?;
match val {
IntType(v) => Ok(Ty::new_int(self.tcx, v)),
UintType(v) => Ok(Ty::new_uint(self.tcx, v)),
}
fn unify_integral_variable(&self, vid: ty::IntVid, val: ty::IntVarValue) {
self.inner.borrow_mut().int_unification_table().union_value(vid, val);
}

fn unify_float_variable(
&self,
vid_is_expected: bool,
vid: ty::FloatVid,
val: ty::FloatTy,
) -> RelateResult<'tcx, Ty<'tcx>> {
self.inner
.borrow_mut()
.float_unification_table()
.unify_var_value(vid, Some(ty::FloatVarValue(val)))
.map_err(|e| float_unification_error(vid_is_expected, e))?;
Ok(Ty::new_float(self.tcx, val))
fn unify_float_variable(&self, vid: ty::FloatVid, val: ty::FloatVarValue) {
self.inner.borrow_mut().float_unification_table().union_value(vid, val);
}

fn unify_effect_variable(&self, vid: ty::EffectVid, val: ty::Const<'tcx>) -> ty::Const<'tcx> {
Expand Down Expand Up @@ -364,19 +339,3 @@ pub trait ObligationEmittingRelation<'tcx>: TypeRelation<'tcx> {
/// Register `AliasRelate` obligation(s) that both types must be related to each other.
fn register_type_relate_obligation(&mut self, a: Ty<'tcx>, b: Ty<'tcx>);
}

fn int_unification_error<'tcx>(
a_is_expected: bool,
v: (ty::IntVarValue, ty::IntVarValue),
) -> TypeError<'tcx> {
let (a, b) = v;
TypeError::IntMismatch(ExpectedFound::new(a_is_expected, a, b))
}

fn float_unification_error<'tcx>(
a_is_expected: bool,
v: (ty::FloatVarValue, ty::FloatVarValue),
) -> TypeError<'tcx> {
let (ty::FloatVarValue(a), ty::FloatVarValue(b)) = v;
TypeError::FloatMismatch(ExpectedFound::new(a_is_expected, a, b))
}
4 changes: 2 additions & 2 deletions compiler/rustc_infer/src/infer/relate/lattice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ where

let infcx = this.infcx();

let a = infcx.inner.borrow_mut().type_variables().replace_if_possible(a);
let b = infcx.inner.borrow_mut().type_variables().replace_if_possible(b);
let a = infcx.shallow_resolve(a);
let b = infcx.shallow_resolve(b);

match (a.kind(), b.kind()) {
// If one side is known to be a variable and one is not,
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_infer/src/infer/relate/type_relating.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ impl<'tcx> TypeRelation<'tcx> for TypeRelating<'_, '_, 'tcx> {
}

let infcx = self.fields.infcx;
let a = infcx.inner.borrow_mut().type_variables().replace_if_possible(a);
let b = infcx.inner.borrow_mut().type_variables().replace_if_possible(b);
let a = infcx.shallow_resolve(a);
let b = infcx.shallow_resolve(b);

match (a.kind(), b.kind()) {
(&ty::Infer(TyVar(a_id)), &ty::Infer(TyVar(b_id))) => {
Expand Down
16 changes: 1 addition & 15 deletions compiler/rustc_middle/src/infer/unify_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,21 +86,6 @@ impl<'tcx> UnifyValue for RegionVariableValue<'tcx> {
}
}

impl ToType for ty::IntVarValue {
fn to_type<'tcx>(&self, tcx: TyCtxt<'tcx>) -> Ty<'tcx> {
match *self {
ty::IntType(i) => Ty::new_int(tcx, i),
ty::UintType(i) => Ty::new_uint(tcx, i),
}
}
}

impl ToType for ty::FloatVarValue {
fn to_type<'tcx>(&self, tcx: TyCtxt<'tcx>) -> Ty<'tcx> {
Ty::new_float(tcx, self.0)
}
}

// Generic consts.

#[derive(Copy, Clone, Debug)]
Expand Down Expand Up @@ -211,6 +196,7 @@ impl<'tcx> EffectVarValue<'tcx> {

impl<'tcx> UnifyValue for EffectVarValue<'tcx> {
type Error = NoError;

fn unify_values(value1: &Self, value2: &Self) -> Result<Self, Self::Error> {
match (*value1, *value2) {
(EffectVarValue::Unknown, EffectVarValue::Unknown) => Ok(EffectVarValue::Unknown),
Expand Down
Loading

0 comments on commit 6e5acfb

Please sign in to comment.