Skip to content

Commit

Permalink
Rudimentary support for dynamic trait objects (#664)
Browse files Browse the repository at this point in the history
* add stuff to generic_const test

* next: onto lower type

* next: into constraint_gen

* yay, dyn00 works (doesn't keel over)

* yay, dyn01 checks/fails as expected

* allow region vars in dyn

* Update crates/flux-middle/src/fhir/lift.rs

* use Dynamic instead of PolyTraitRef

* remove DynKind (only support Dyn) else panic

* implement to_rustc for Dynamic/ExistentialPredicates

* Update crates/flux-middle/src/rustc/lowering.rs

* check dyn exi-preds are equal during subtyping

* use span_bug

* dont fill lifetimes in fill_generic_args

Co-authored-by: Nico Lehmann <[email protected]>
  • Loading branch information
ranjitjhala and nilehmann authored Jul 29, 2024
1 parent cf6ff67 commit aed06a7
Show file tree
Hide file tree
Showing 21 changed files with 400 additions and 27 deletions.
7 changes: 7 additions & 0 deletions crates/flux-fhir-analysis/src/conv/fill_holes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,13 @@ impl<'genv, 'tcx> Zipper<'genv, 'tcx> {
(rty::BaseTy::Param(pty_a), ty::TyKind::Param(pty_b)) => {
debug_assert_eq!(pty_a, pty_b);
}
(
rty::BaseTy::Dynamic(poly_trait_refs, re_a),
ty::TyKind::Dynamic(poly_trait_refs_b, re_b),
) => {
self.zip_region(re_a, re_b);
debug_assert_eq!(poly_trait_refs.len(), poly_trait_refs_b.len());
}
(rty::BaseTy::Closure(..), _) => {
bug!("unexpected closure {a:?}");
}
Expand Down
41 changes: 37 additions & 4 deletions crates/flux-fhir-analysis/src/conv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
//! 3. Refinements are well-sorted.
mod fill_holes;

use std::{borrow::Borrow, iter};

use flux_common::{bug, iter::IterExt, span_bug};
Expand All @@ -24,7 +23,7 @@ use flux_middle::{
refining::{self, Refiner},
AdtSortDef, ESpan, WfckResults, INNERMOST,
},
rustc,
rustc::{self},
};
use itertools::Itertools;
use rustc_data_structures::fx::FxIndexMap;
Expand Down Expand Up @@ -435,6 +434,29 @@ impl<'a, 'genv, 'tcx> ConvCtxt<'a, 'genv, 'tcx> {
Ok(clauses)
}

fn conv_poly_trait_ref_dyn(
&mut self,
env: &mut Env,
poly_trait_ref: &fhir::PolyTraitRef,
) -> QueryResult<rty::Binder<rty::ExistentialPredicate>> {
let trait_segment = poly_trait_ref.trait_ref.last_segment();

if !poly_trait_ref.bound_generic_params.is_empty() {
bug!(
"unexpected! conv_poly_dyn_trait_ref bound_generic_params={:?}",
poly_trait_ref.bound_generic_params
);
}

let def_id = poly_trait_ref.trait_def_id();
let mut into = vec![];
self.conv_generic_args_into(env, def_id, trait_segment.args, &mut into)?;

let exi_trait_ref = rty::ExistentialTraitRef { def_id, args: into.into() };
let exi_pred = rty::ExistentialPredicate::Trait(exi_trait_ref);
Ok(rty::Binder::new(exi_pred, List::empty()))
}

/// Converts a `T: Trait<T0, ..., A0 = S0, ...>` bound
fn conv_poly_trait_ref(
&mut self,
Expand All @@ -453,8 +475,8 @@ impl<'a, 'genv, 'tcx> ConvCtxt<'a, 'genv, 'tcx> {
vec![self.ty_to_generic_arg(self_param.kind, bounded_ty_span, bounded_ty)?];
self.conv_generic_args_into(env, trait_id, trait_segment.args, &mut args)?;
self.fill_generic_args_defaults(trait_id, &mut args)?;

let trait_ref = rty::TraitRef { def_id: trait_id, args: args.into() };

let pred = rty::TraitPredicate { trait_ref: trait_ref.clone() };
let vars = poly_trait_ref
.bound_generic_params
Expand Down Expand Up @@ -788,6 +810,18 @@ impl<'a, 'genv, 'tcx> ConvCtxt<'a, 'genv, 'tcx> {
let alias_ty = rty::AliasTy::new(def_id, args, refine_args);
Ok(rty::Ty::alias(rty::AliasKind::Opaque, alias_ty))
}
fhir::TyKind::TraitObject(poly_traits, lft, syn) => {
let exi_preds: List<_> = poly_traits
.iter()
.map(|poly_trait| self.conv_poly_trait_ref_dyn(env, poly_trait))
.try_collect()?;
let region = self.conv_lifetime(env, *lft);
if matches!(syn, rustc_ast::TraitObjectSyntax::Dyn) {
Ok(rty::Ty::dynamic(exi_preds, region))
} else {
span_bug!(ty.span, "dyn* traits not supported yet")
}
}
}
}

Expand Down Expand Up @@ -1161,7 +1195,6 @@ impl<'a, 'genv, 'tcx> ConvCtxt<'a, 'genv, 'tcx> {
bug!("unexpected generic param: {param:?}");
}
}

Ok(())
}

Expand Down
5 changes: 5 additions & 0 deletions crates/flux-middle/src/fhir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use flux_common::{bug, span_bug};
use flux_syntax::surface::ParamMode;
pub use flux_syntax::surface::{BinOp, UnOp};
use itertools::Itertools;
use rustc_ast::TraitObjectSyntax;
use rustc_data_structures::fx::{FxIndexMap, FxIndexSet};
use rustc_hash::FxHashMap;
pub use rustc_hir::PrimTy;
Expand Down Expand Up @@ -538,6 +539,7 @@ pub enum TyKind<'fhir> {
Array(&'fhir Ty<'fhir>, ConstArg),
RawPtr(&'fhir Ty<'fhir>, Mutability),
OpaqueDef(ItemId, &'fhir [GenericArg<'fhir>], &'fhir [RefineArg<'fhir>], bool),
TraitObject(&'fhir [PolyTraitRef<'fhir>], Lifetime, TraitObjectSyntax),
Never,
Hole(FhirId),
}
Expand Down Expand Up @@ -1243,6 +1245,9 @@ impl fmt::Debug for Ty<'_> {
"impl trait <def_id = {def_id:?}, args = {args:?}, refine = {refine_args:?}>"
)
}
TyKind::TraitObject(poly_traits, _lft, _syntax) => {
write!(f, "dyn {poly_traits:?}")
}
}
}
}
Expand Down
8 changes: 8 additions & 0 deletions crates/flux-middle/src/fhir/lift.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,14 @@ impl<'a, 'genv, 'tcx> LiftCtxt<'a, 'genv, 'tcx> {
let args = self.lift_generic_args(args)?;
fhir::TyKind::OpaqueDef(item_id, args, &[], in_trait_def)
}
hir::TyKind::TraitObject(poly_traits, lft, syntax) => {
let poly_traits = try_alloc_slice!(self.genv, poly_traits, |poly_trait| {
self.lift_poly_trait_ref(*poly_trait)
})?;

let lft = self.lift_lifetime(lft)?;
fhir::TyKind::TraitObject(poly_traits, lft, syntax)
}
_ => {
return self.emit_unsupported(&format!(
"unsupported type: `{}`",
Expand Down
4 changes: 4 additions & 0 deletions crates/flux-middle/src/fhir/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,10 @@ pub fn walk_ty<'v, V: Visitor<'v>>(vis: &mut V, ty: &Ty<'v>) {
}
TyKind::Never => {}
TyKind::Hole(_) => {}
TyKind::TraitObject(poly_traits, lft, _) => {
walk_list!(vis, visit_poly_trait_ref, poly_traits);
vis.visit_lifetime(&lft);
}
}
}

Expand Down
43 changes: 39 additions & 4 deletions crates/flux-middle/src/rty/fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ use super::{
projections,
subst::EVarSubstFolder,
AliasReft, AliasTy, BaseTy, BinOp, Binder, BoundVariableKind, Clause, ClauseKind, Const,
CoroutineObligPredicate, Ensures, Expr, ExprKind, FnOutput, FnSig, FnTraitPredicate, FuncSort,
GenericArg, Invariant, KVar, Lambda, Name, Opaqueness, OutlivesPredicate, PolyFuncSort,
ProjectionPredicate, PtrKind, Qualifier, ReBound, Region, Sort, SubsetTy, TraitPredicate,
TraitRef, Ty, TyKind,
CoroutineObligPredicate, Ensures, ExistentialPredicate, ExistentialTraitRef, Expr, ExprKind,
FnOutput, FnSig, FnTraitPredicate, FuncSort, GenericArg, Invariant, KVar, Lambda, Name,
Opaqueness, OutlivesPredicate, PolyFuncSort, ProjectionPredicate, PtrKind, Qualifier, ReBound,
Region, Sort, SubsetTy, TraitPredicate, TraitRef, Ty, TyKind,
};
use crate::{
global_env::GlobalEnv,
Expand Down Expand Up @@ -488,6 +488,37 @@ impl TypeFoldable for TraitRef {
}
}

impl TypeVisitable for ExistentialTraitRef {
fn visit_with<V: TypeVisitor>(&self, visitor: &mut V) -> ControlFlow<V::BreakTy> {
self.args.visit_with(visitor)
}
}

impl TypeFoldable for ExistentialTraitRef {
fn try_fold_with<F: FallibleTypeFolder>(&self, folder: &mut F) -> Result<Self, F::Error> {
Ok(ExistentialTraitRef { def_id: self.def_id, args: self.args.try_fold_with(folder)? })
}
}

impl TypeVisitable for ExistentialPredicate {
fn visit_with<V: TypeVisitor>(&self, visitor: &mut V) -> ControlFlow<V::BreakTy> {
match self {
ExistentialPredicate::Trait(exi_trait_ref) => exi_trait_ref.visit_with(visitor),
}
}
}

impl TypeFoldable for ExistentialPredicate {
fn try_fold_with<F: FallibleTypeFolder>(&self, folder: &mut F) -> Result<Self, F::Error> {
match self {
ExistentialPredicate::Trait(exi_trait_ref) => {
let exi_trait_ref = exi_trait_ref.try_fold_with(folder)?;
Ok(ExistentialPredicate::Trait(exi_trait_ref))
}
}
}
}

impl TypeVisitable for CoroutineObligPredicate {
fn visit_with<V: TypeVisitor>(&self, visitor: &mut V) -> ControlFlow<V::BreakTy> {
self.upvar_tys.visit_with(visitor)?;
Expand Down Expand Up @@ -940,6 +971,7 @@ impl TypeSuperVisitable for BaseTy {
| BaseTy::Closure(_, _)
| BaseTy::Never
| BaseTy::Param(_) => ControlFlow::Continue(()),
BaseTy::Dynamic(exi_preds, _) => exi_preds.visit_with(visitor),
}
}
}
Expand Down Expand Up @@ -979,6 +1011,9 @@ impl TypeSuperFoldable for BaseTy {
args.try_fold_with(folder)?,
)
}
BaseTy::Dynamic(exi_preds, region) => {
BaseTy::Dynamic(exi_preds.try_fold_with(folder)?, region.try_fold_with(folder)?)
}
};
Ok(bty)
}
Expand Down
47 changes: 46 additions & 1 deletion crates/flux-middle/src/rty/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ mod pretty;
pub mod projections;
pub mod refining;
pub mod subst;

use std::{borrow::Cow, hash::Hash, iter, slice, sync::LazyLock};

pub use evars::{EVar, EVarGen};
Expand Down Expand Up @@ -230,6 +229,36 @@ impl TraitRef {
}
}

#[derive(Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
pub enum ExistentialPredicate {
Trait(ExistentialTraitRef),
}

impl Binder<ExistentialPredicate> {
fn to_rustc<'tcx>(
&self,
tcx: TyCtxt<'tcx>,
) -> rustc_middle::ty::Binder<'tcx, rustc_middle::ty::ExistentialPredicate<'tcx>> {
assert!(self.vars.is_empty());
match self.value {
ExistentialPredicate::Trait(ref exi_trait_ref) => {
let exi_trait_ref = rustc_middle::ty::ExistentialTraitRef {
def_id: exi_trait_ref.def_id,
args: exi_trait_ref.args.to_rustc(tcx),
};
let exi_pred = rustc_middle::ty::ExistentialPredicate::Trait(exi_trait_ref);
rustc_middle::ty::Binder::bind_with_vars(exi_pred, rustc_middle::ty::List::empty())
}
}
}
}

#[derive(Debug, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
pub struct ExistentialTraitRef {
pub def_id: DefId,
pub args: GenericArgs,
}

#[derive(PartialEq, Eq, Hash, Debug, Clone, TyEncodable, TyDecodable)]
pub struct ProjectionPredicate {
pub projection_ty: AliasTy,
Expand Down Expand Up @@ -622,6 +651,10 @@ impl Ty {
Self::alias(AliasKind::Projection, alias_ty)
}

pub fn dynamic(preds: impl Into<List<Binder<ExistentialPredicate>>>, region: Region) -> Ty {
BaseTy::Dynamic(preds.into(), region).to_ty()
}

pub fn strg_ref(re: Region, path: Path, ty: Ty) -> Ty {
TyKind::StrgRef(re, path, ty).intern()
}
Expand Down Expand Up @@ -951,6 +984,7 @@ pub enum BaseTy {
Never,
Closure(DefId, /* upvar_tys */ List<Ty>),
Coroutine(DefId, /*resume_ty: */ Ty, /* upvar_tys: */ List<Ty>),
Dynamic(List<Binder<ExistentialPredicate>>, Region),
Param(ParamTy),
}

Expand Down Expand Up @@ -1099,6 +1133,7 @@ impl BaseTy {
| BaseTy::Array(_, _)
| BaseTy::Closure(_, _)
| BaseTy::Coroutine(..)
| BaseTy::Dynamic(_, _)
| BaseTy::Never => Sort::unit(),
}
}
Expand Down Expand Up @@ -1131,6 +1166,14 @@ impl BaseTy {
BaseTy::Array(_, _) => todo!(),
BaseTy::Never => tcx.types.never,
BaseTy::Closure(_, _) => todo!(),
BaseTy::Dynamic(exi_preds, re) => {
let preds: Vec<_> = exi_preds
.iter()
.map(|pred| pred.to_rustc(tcx))
.collect_vec();
let preds = tcx.mk_poly_existential_predicates(&preds);
ty::Ty::new_dynamic(tcx, preds, re.to_rustc(tcx), rustc_middle::ty::DynKind::Dyn)
}
BaseTy::Coroutine(def_id, resume_ty, upvars) => {
todo!("Generator {def_id:?} {resume_ty:?} {upvars:?}")
// let args = args.iter().map(|arg| into_rustc_generic_arg(tcx, arg));
Expand Down Expand Up @@ -2055,6 +2098,8 @@ impl_slice_internable!(
InferMode,
Sort,
GenericParamDef,
TraitRef,
Binder<ExistentialPredicate>,
Clause,
PolyVariant,
Invariant,
Expand Down
14 changes: 14 additions & 0 deletions crates/flux-middle/src/rty/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,17 @@ impl Pretty for List<Ty> {
}
}

impl Pretty for ExistentialPredicate {
fn fmt(&self, _cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
define_scoped!(_cx, f);
match self {
ExistentialPredicate::Trait(exi_trait_ref) => {
w!("{exi_trait_ref:?}")
}
}
}
}

impl Pretty for BaseTy {
fn fmt(&self, cx: &PrettyCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
define_scoped!(cx, f);
Expand Down Expand Up @@ -388,6 +399,9 @@ impl Pretty for BaseTy {
}
Ok(())
}
BaseTy::Dynamic(exi_preds, _) => {
w!("dyn {:?}", join!(", ", exi_preds))
}
}
}
}
Expand Down
31 changes: 31 additions & 0 deletions crates/flux-middle/src/rty/refining.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,30 @@ impl<'genv, 'tcx> Refiner<'genv, 'tcx> {
Ok(rty::ClauseKind::FnTrait(pred))
}

pub fn refine_existential_predicate(
&self,
exi_pred: &rustc::ty::Binder<rustc::ty::ExistentialPredicate>,
) -> QueryResult<rty::Binder<rty::ExistentialPredicate>> {
assert!(exi_pred.vars().is_empty());
let exi_pred = match exi_pred.as_ref().skip_binder() {
rustc::ty::ExistentialPredicate::Trait(exi_trait_ref) => {
rty::ExistentialPredicate::Trait(self.refine_exi_trait_ref(exi_trait_ref)?)
}
};
Ok(rty::Binder::new(exi_pred, List::empty()))
}

pub fn refine_exi_trait_ref(
&self,
exi_trait_ref: &rustc::ty::ExistentialTraitRef,
) -> QueryResult<rty::ExistentialTraitRef> {
let exi_trait_ref = rty::ExistentialTraitRef {
def_id: exi_trait_ref.def_id,
args: self.refine_generic_args(exi_trait_ref.def_id, &exi_trait_ref.args)?,
};
Ok(exi_trait_ref)
}

pub fn refine_trait_ref(&self, trait_ref: &rustc::ty::TraitRef) -> QueryResult<rty::TraitRef> {
let trait_ref = rty::TraitRef {
def_id: trait_ref.def_id,
Expand Down Expand Up @@ -358,6 +382,13 @@ impl<'genv, 'tcx> Refiner<'genv, 'tcx> {
rustc::ty::TyKind::RawPtr(ty, mu) => {
rty::BaseTy::RawPtr(self.as_default().refine_ty(ty)?, *mu)
}
rustc::ty::TyKind::Dynamic(exi_preds, r) => {
let exi_preds = exi_preds
.iter()
.map(|ty| self.refine_existential_predicate(ty))
.try_collect()?;
rty::BaseTy::Dynamic(exi_preds, *r)
}
};
Ok(TyOrBase::Base((self.refine)(bty)))
}
Expand Down
Loading

0 comments on commit aed06a7

Please sign in to comment.