From f7a9c50d151f40b75cb02f85289efefbdbf5c165 Mon Sep 17 00:00:00 2001 From: Mikhail Babenko Date: Tue, 7 Apr 2020 18:45:12 +0300 Subject: [PATCH] rework VisitResult design, add helper method in wf --- chalk-derive/src/lib.rs | 27 ++++++++++++---- chalk-ir/src/visit.rs | 17 ++++++++-- chalk-ir/src/visit/boring_impls.rs | 51 ++++++++++++++++++------------ chalk-ir/src/visit/visitors.rs | 12 ++++--- chalk-solve/src/wf.rs | 35 ++++++++++---------- 5 files changed, 87 insertions(+), 55 deletions(-) diff --git a/chalk-derive/src/lib.rs b/chalk-derive/src/lib.rs index 61ccd1c1941..e637938cf46 100644 --- a/chalk-derive/src/lib.rs +++ b/chalk-derive/src/lib.rs @@ -494,11 +494,16 @@ fn derive_visit_body(type_name: &Ident, data: Data) -> proc_macro2::TokenStream Data::Struct(s) => { let fields = s.fields.into_iter().map(|f| { let name = f.ident.as_ref().expect("Unnamed field in a struct"); - quote! { .and_then(|| self.#name.visit_with(visitor, outer_binder)) } + quote! { + result = result.combine(self.#name.visit_with(visitor, outer_binder)); + if result.return_early() { return result; } + } }); quote! { - R::new() - #(#fields)* + let mut result = R::new(); + #(#fields)* + + result } } Data::Enum(e) => { @@ -509,8 +514,12 @@ fn derive_visit_body(type_name: &Ident, data: Data) -> proc_macro2::TokenStream let fnames: &Vec<_> = &fields.named.iter().map(|f| &f.ident).collect(); quote! { #type_name :: #variant { #(#fnames),* } => { - R::new() - #(.and_then(|| #fnames.visit_with(visitor, outer_binder)))* + let mut result = R::new(); + #( + result = result.combine(#fnames.visit_with(visitor, outer_binder)); + if result.return_early() { return result; } + )* + result } } } @@ -521,8 +530,12 @@ fn derive_visit_body(type_name: &Ident, data: Data) -> proc_macro2::TokenStream .collect(); quote! { #type_name::#variant( #(ref #names),* ) => { - R::new() - #(.and_then(|| #names.visit_with(visitor, outer_binder)))* + let mut result = R::new(); + #( + result = result.combine(#names.visit_with(visitor, outer_binder)); + if result.return_early() { return result; } + )* + result } } } diff --git a/chalk-ir/src/visit.rs b/chalk-ir/src/visit.rs index 3f87571334a..6449ce2a8aa 100644 --- a/chalk-ir/src/visit.rs +++ b/chalk-ir/src/visit.rs @@ -15,14 +15,25 @@ pub use visitors::VisitExt; pub trait VisitResult: Sized { fn new() -> Self; - fn and_then(self, op: impl FnOnce() -> Self) -> Self; + fn return_early(&self) -> bool; + fn combine(self, other: Self) -> Self; + + fn and_then(self, op: impl FnOnce() -> Self) -> Self { + if self.return_early() { + self + } else { + self.combine(op()) + } + } } impl VisitResult for () { fn new() -> () {} - fn and_then(self, op: impl FnOnce() -> ()) -> () { - op() + + fn return_early(&self) -> bool { + false } + fn combine(self, _other: Self) {} } /// A "visitor" recursively folds some term -- that is, some bit of IR, diff --git a/chalk-ir/src/visit/boring_impls.rs b/chalk-ir/src/visit/boring_impls.rs index 2a623af7a81..c24f1b2b765 100644 --- a/chalk-ir/src/visit/boring_impls.rs +++ b/chalk-ir/src/visit/boring_impls.rs @@ -12,6 +12,27 @@ use crate::{ use chalk_engine::{context::Context, ExClause, FlounderedSubgoal, Literal}; use std::{marker::PhantomData, sync::Arc}; +pub fn visit_iter<'i, T, I, IT, R>( + it: IT, + visitor: &mut dyn Visitor<'i, I, Result = R>, + outer_binder: DebruijnIndex, +) -> R +where + T: Visit, + I: 'i + Interner, + IT: Iterator, + R: VisitResult, +{ + let mut result = R::new(); + for e in it { + result = result.combine(e.visit_with(visitor, outer_binder)); + if result.return_early() { + return result; + } + } + result +} + impl, I: Interner> Visit for &T { fn visit_with<'i, R: VisitResult>( &self, @@ -34,11 +55,7 @@ impl, I: Interner> Visit for Vec { where I: 'i, { - let mut result = R::new(); - for e in self { - result = result.and_then(|| e.visit_with(visitor, outer_binder)) - } - result + visit_iter(self.iter(), visitor, outer_binder) } } @@ -75,8 +92,12 @@ macro_rules! tuple_visit { { #[allow(non_snake_case)] let &($(ref $n),*) = self; - R::new() - $(.and_then(|| $n.visit_with(visitor, outer_binder)))* + let mut result = R::new(); + $( + result = result.combine($n.visit_with(visitor, outer_binder)); + if result.return_early() { return result; } + )* + result } } } @@ -127,13 +148,7 @@ impl Visit for Substitution { I: 'i, { let interner = visitor.interner(); - let mut result = R::new(); - - for p in self.iter(interner) { - result = result.and_then(|| p.visit_with(visitor, outer_binder)); - } - - result + visit_iter(self.iter(interner), visitor, outer_binder) } } @@ -147,13 +162,7 @@ impl Visit for Goals { I: 'i, { let interner = visitor.interner(); - let mut result = R::new(); - - for p in self.iter(interner) { - result = result.and_then(|| p.visit_with(visitor, outer_binder)); - } - - result + visit_iter(self.iter(interner), visitor, outer_binder) } } diff --git a/chalk-ir/src/visit/visitors.rs b/chalk-ir/src/visit/visitors.rs index fa3f858f4d3..c4175175189 100644 --- a/chalk-ir/src/visit/visitors.rs +++ b/chalk-ir/src/visit/visitors.rs @@ -27,11 +27,13 @@ impl VisitResult for FindAny { fn new() -> Self { FindAny { found: false } } - fn and_then(self, op: impl FnOnce() -> Self) -> Self { - if self.found { - self - } else { - op() + + fn return_early(&self) -> bool { + self.found + } + fn combine(self, other: Self) -> Self { + FindAny { + found: self.found || other.found, } } } diff --git a/chalk-solve/src/wf.rs b/chalk-solve/src/wf.rs index dc773b71dd8..3c988e3edf4 100644 --- a/chalk-solve/src/wf.rs +++ b/chalk-solve/src/wf.rs @@ -54,6 +54,12 @@ impl<'i, I: Interner> InputTypeCollector<'i, I> { interner, } } + + fn types_in(interner: &'i I, value: impl Visit) -> Vec> { + let mut collector = Self::new(interner); + value.visit_with(&mut collector, DebruijnIndex::INNERMOST); + collector.types + } } impl<'i, 't, I: Interner> Visitor<'i, I> for InputTypeCollector<'i, I> { @@ -165,13 +171,10 @@ where .map(|wc| wc.into_from_env_goal(interner)), |gb| { // WellFormed(Vec), for each field type `Vec` or type that appears in the where clauses - let mut type_collector = InputTypeCollector::new(gb.interner()); + let types = + InputTypeCollector::types_in(gb.interner(), (&fields, &where_clauses)); - // ...in a field type... - fields.visit_with(&mut type_collector, DebruijnIndex::INNERMOST); - // ...in a where clause. - where_clauses.visit_with(&mut type_collector, DebruijnIndex::INNERMOST); - gb.all(type_collector.types.into_iter().map(|ty| ty.well_formed())) + gb.all(types.into_iter().map(|ty| ty.well_formed())) }, ) }); @@ -255,14 +258,12 @@ fn impl_header_wf_goal( // we would retrieve `HashSet`, `Box`, `Vec>`, `(HashSet, Vec>)`. // We will have to prove that these types are well-formed (e.g. an additional `K: Hash` // bound would be needed here). - let mut type_collector = InputTypeCollector::new(interner); - where_clauses.visit_with(&mut type_collector, DebruijnIndex::INNERMOST); + let types = InputTypeCollector::types_in(gb.interner(), &where_clauses); // Things to prove well-formed: input types of the where-clauses, projection types // appearing in the header, associated type values, and of course the trait ref. - debug!("verify_trait_impl: input_types={:?}", type_collector.types); - let goals = type_collector - .types + debug!("verify_trait_impl: input_types={:?}", types); + let goals = types .into_iter() .map(|ty| ty.well_formed().cast(interner)) .chain(Some((*trait_ref).clone().well_formed().cast(interner))); @@ -298,11 +299,9 @@ fn impl_wf_environment<'i, I: Interner>( // // Inside here, we can rely on the fact that `K: Hash` holds // } // ``` - let mut type_collector = InputTypeCollector::new(interner); - trait_ref.visit_with(&mut type_collector, DebruijnIndex::INNERMOST); + let types = InputTypeCollector::types_in(interner, trait_ref); - let types_wf = type_collector - .types + let types_wf = types .into_iter() .map(move |ty| ty.into_from_env_goal(interner).cast(interner)); @@ -403,12 +402,10 @@ fn compute_assoc_ty_goal( .cloned() .map(|qwc| qwc.into_from_env_goal(interner)), |gb| { - let mut type_collector = InputTypeCollector::new(interner); - value_ty.visit_with(&mut type_collector, DebruijnIndex::INNERMOST); + let types = InputTypeCollector::types_in(gb.interner(), value_ty); // We require that `WellFormed(T)` for each type that appears in the value - let wf_goals = type_collector - .types + let wf_goals = types .into_iter() .map(|ty| ty.well_formed()) .casted(interner);