Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimization for expressions #37

Draft
wants to merge 17 commits into
base: rewrite-2023
Choose a base branch
from
17 changes: 15 additions & 2 deletions prusti-encoder/src/encoders/const.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,19 @@ use vir::{CallableIdent, Arity};

pub struct ConstEnc;


#[derive(Clone)]
pub struct ConstEncOutput<'vir>(pub vir::Expr<'vir>);


impl<'vir> task_encoder::Optimizable for ConstEncOutput<'vir> {}

impl<'vir> From<vir::Expr<'vir>> for ConstEncOutput<'vir> {
fn from(value: vir::Expr<'vir>) -> Self {
Self(value)
}
}

#[derive(Clone, Debug)]
pub struct ConstEncOutputRef<'vir> {
pub base_name: String,
Expand All @@ -28,7 +41,7 @@ impl TaskEncoder for ConstEnc {
usize, // current encoding depth
DefId, // DefId of the current function
);
type OutputFullLocal<'vir> = vir::Expr<'vir>;
type OutputFullLocal<'vir> = ConstEncOutput<'vir>;
type EncodingError = ();

fn task_to_key<'vir>(task: &Self::TaskDescription<'vir>) -> Self::TaskKey<'vir> {
Expand Down Expand Up @@ -94,6 +107,6 @@ impl TaskEncoder for ConstEnc {
}),
mir::ConstantKind::Ty(_) => todo!(),
};
Ok((res, ()))
Ok((res.into(), ()))
}
}
3 changes: 3 additions & 0 deletions prusti-encoder/src/encoders/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ pub struct GenericEncOutput<'vir> {
pub domain_type: vir::Domain<'vir>,
}

impl<'vir> task_encoder::Optimizable for GenericEncOutput<'vir> {}


impl TaskEncoder for GenericEnc {
task_encoder::encoder_cache!(GenericEnc);

Expand Down
3 changes: 3 additions & 0 deletions prusti-encoder/src/encoders/local_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ pub struct MirLocalDefEncOutput<'vir> {
}
pub type MirLocalDefEncError = ();


impl<'vir> task_encoder::Optimizable for MirLocalDefEncOutput<'vir> {}

#[derive(Clone, Copy)]
pub struct LocalDef<'vir> {
pub local: vir::Local<'vir>,
Expand Down
3 changes: 3 additions & 0 deletions prusti-encoder/src/encoders/mir_builtin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ pub struct MirBuiltinEncOutput<'vir> {
pub function: vir::Function<'vir>,
}

impl<'vir> task_encoder::Optimizable for MirBuiltinEncOutput<'vir> {}


use crate::encoders::SnapshotEnc;

impl TaskEncoder for MirBuiltinEnc {
Expand Down
15 changes: 12 additions & 3 deletions prusti-encoder/src/encoders/mir_impure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use task_encoder::{
TaskEncoder,
TaskEncoderDependencies,
};
use vir::{MethodIdent, UnknownArity, CallableIdent};
use vir::{with_vcx, CallableIdent, MethodIdent, Optimizable, UnknownArity};

pub struct MirImpureEnc;

Expand All @@ -32,6 +32,15 @@ pub struct MirImpureEncOutput<'vir> {
pub method: vir::Method<'vir>,
}

impl<'vir> task_encoder::Optimizable for MirImpureEncOutput<'vir> {
fn optimize(self) -> Self {
let method = self.method.optimize();
let method = with_vcx(|vcx| vcx.alloc(method));
MirImpureEncOutput { method }
}
}


use crate::encoders::{PredicateEnc, ConstEnc, MirBuiltinEnc, MirFunctionEnc, MirLocalDefEnc, MirSpecEnc};

const ENCODE_REACH_BB: bool = false;
Expand Down Expand Up @@ -391,7 +400,7 @@ impl<'tcx, 'vir, 'enc> EncVisitor<'tcx, 'vir, 'enc> {
ty_out.ref_to_snap.apply(self.vcx, [self.encode_place(Place::from(source))])
}
mir::Operand::Constant(box constant) =>
self.deps.require_local::<ConstEnc>((constant.literal, 0, self.def_id)).unwrap()
self.deps.require_local::<ConstEnc>((constant.literal, 0, self.def_id)).unwrap().0
}
}

Expand All @@ -409,7 +418,7 @@ impl<'tcx, 'vir, 'enc> EncVisitor<'tcx, 'vir, 'enc> {
}
mir::Operand::Constant(box constant) => {
let ty_out = self.deps.require_ref::<PredicateEnc>(ty).unwrap();
let constant = self.deps.require_local::<ConstEnc>((constant.literal, 0, self.def_id)).unwrap();
let constant = self.deps.require_local::<ConstEnc>((constant.literal, 0, self.def_id)).unwrap().0;
(constant, ty_out)
}
};
Expand Down
6 changes: 5 additions & 1 deletion prusti-encoder/src/encoders/mir_pure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ pub struct MirPureEncOutput<'vir> {
pub expr: ExprRet<'vir>,
}

impl<'vir> task_encoder::Optimizable for MirPureEncOutput<'vir> {}



#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum PureKind {
Closure,
Expand Down Expand Up @@ -624,7 +628,7 @@ impl<'tcx, 'vir: 'enc, 'enc> Enc<'tcx, 'vir, 'enc>
mir::Operand::Copy(place)
| mir::Operand::Move(place) => self.encode_place(curr_ver, place),
mir::Operand::Constant(box constant) =>
self.deps.require_local::<ConstEnc>((constant.literal, self.encoding_depth, self.def_id)).unwrap().lift(),
self.deps.require_local::<ConstEnc>((constant.literal, self.encoding_depth, self.def_id)).unwrap().0.lift(),
}
}

Expand Down
11 changes: 10 additions & 1 deletion prusti-encoder/src/encoders/mir_pure_function.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use prusti_rustc_interface::{middle::{mir, ty}, span::def_id::DefId};

use task_encoder::{TaskEncoder, TaskEncoderDependencies};
use vir::{Reify, FunctionIdent, UnknownArity, CallableIdent};
use vir::{CallableIdent, FunctionIdent, Optimizable, Reify, UnknownArity};

use crate::encoders::{
MirPureEnc, MirPureEncTask, mir_pure::PureKind, MirSpecEnc, MirLocalDefEnc,
Expand All @@ -28,6 +28,15 @@ pub struct MirFunctionEncOutput<'vir> {
pub function: vir::Function<'vir>,
}

impl<'vir> task_encoder::Optimizable for MirFunctionEncOutput<'vir> {
fn optimize(self) -> Self {
let function = self.function.optimize();
let function = vir::with_vcx(|vcx| vcx.alloc(function));

MirFunctionEncOutput { function }
}
}

impl TaskEncoder for MirFunctionEnc {
task_encoder::encoder_cache!(MirFunctionEnc);

Expand Down
3 changes: 3 additions & 0 deletions prusti-encoder/src/encoders/pure/spec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ pub struct MirSpecEncOutput<'vir> {
pub post_args: &'vir [vir::Expr<'vir>],
}

impl<'vir> task_encoder::Optimizable for MirSpecEncOutput<'vir> {}


impl TaskEncoder for MirSpecEnc {
task_encoder::encoder_cache!(MirSpecEnc);

Expand Down
5 changes: 3 additions & 2 deletions prusti-encoder/src/encoders/spec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ use prusti_rustc_interface::{
};
use prusti_interface::specs::typed::{DefSpecificationMap, ProcedureSpecification};
use task_encoder::{
TaskEncoder,
TaskEncoderDependencies,
Optimizable, TaskEncoder, TaskEncoderDependencies
};

pub struct SpecEnc;
Expand All @@ -19,6 +18,8 @@ pub struct SpecEncOutput<'vir> {
pub posts: &'vir [DefId],
}

impl<'vir> Optimizable for SpecEncOutput<'vir> {}

use std::cell::RefCell;
thread_local! {
static DEF_SPEC_MAP: RefCell<Option<DefSpecificationMap>> = RefCell::new(Default::default());
Expand Down
24 changes: 18 additions & 6 deletions prusti-encoder/src/encoders/type/domain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,21 @@ pub struct DomainEncOutputRef<'vir> {
}
impl<'vir> task_encoder::OutputRefAny for DomainEncOutputRef<'vir> {}


#[derive(Clone)]
pub struct DomainEncOutput<'vir>(pub vir::Domain<'vir>);

impl<'vir> task_encoder::Optimizable for DomainEncOutput<'vir> {}

impl<'vir> From<vir::Domain<'vir>> for DomainEncOutput<'vir> {
fn from(value: vir::Domain<'vir>) -> Self {
DomainEncOutput(value)
}
}

use crate::encoders::SnapshotEnc;

pub fn all_outputs<'vir>() -> Vec<vir::Domain<'vir>> {
pub fn all_outputs<'vir>() -> Vec<DomainEncOutput<'vir>> {
DomainEnc::all_outputs()
}

Expand All @@ -90,7 +102,7 @@ impl TaskEncoder for DomainEnc {

type OutputRef<'vir> = DomainEncOutputRef<'vir>;
type OutputFullDependency<'vir> = DomainEncSpecifics<'vir>;
type OutputFullLocal<'vir> = vir::Domain<'vir>;
type OutputFullLocal<'vir> = DomainEncOutput<'vir>;
//type OutputFullDependency<'vir> = DomainEncOutputDep<'vir>;

type EncodingError = ();
Expand All @@ -109,7 +121,7 @@ impl TaskEncoder for DomainEnc {
Self::EncodingError,
Option<Self::OutputFullDependency<'vir>>,
)> {
vir::with_vcx(|vcx| match task_key.kind() {
(vir::with_vcx(|vcx| match task_key.kind() {
TyKind::Bool | TyKind::Char | TyKind::Int(_) | TyKind::Uint(_) | TyKind::Float(_) => {
let (base_name, prim_type) = match task_key.kind() {
TyKind::Bool => (String::from("Bool"), &vir::TypeData::Bool),
Expand Down Expand Up @@ -197,7 +209,7 @@ impl TaskEncoder for DomainEnc {
Ok((enc.finalize(), specifics))
}
kind => todo!("{kind:?}"),
})
}))
}
}

Expand Down Expand Up @@ -531,13 +543,13 @@ impl<'vir, 'tcx> DomainEncData<'vir, 'tcx> {
domain: self.domain,
}
}
fn finalize(self) -> vir::Domain<'vir> {
fn finalize(self) -> DomainEncOutput<'vir> {
self.vcx.mk_domain(
self.domain.name(),
self.domain.arity().args(),
self.vcx.alloc_slice(&self.axioms),
self.vcx.alloc_slice(&self.functions),
)
).into()
}
}

Expand Down
3 changes: 3 additions & 0 deletions prusti-encoder/src/encoders/type/predicate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ pub struct PredicateEncOutput<'vir> {
pub method_assign: vir::Method<'vir>,
}

impl<'vir> task_encoder::Optimizable for PredicateEncOutput<'vir> {}


use super::{snapshot::SnapshotEnc, domain::{DomainDataPrim, DomainDataStruct, DomainDataEnum, DiscrBounds}};

impl TaskEncoder for PredicateEnc {
Expand Down
3 changes: 3 additions & 0 deletions prusti-encoder/src/encoders/type/snapshot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ pub struct SnapshotEncOutput<'vir> {
pub specifics: DomainEncSpecifics<'vir>,
}


impl<'vir> task_encoder::Optimizable for SnapshotEncOutput<'vir> {}

use super::domain::{DomainEnc, DomainEncSpecifics};

impl TaskEncoder for SnapshotEnc {
Expand Down
3 changes: 3 additions & 0 deletions prusti-encoder/src/encoders/type/viper_tuple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ pub struct ViperTupleEncOutput<'vir> {
tuple: Option<DomainDataStruct<'vir>>,
}

impl<'vir> task_encoder::Optimizable for ViperTupleEncOutput<'vir> {}


impl<'vir> ViperTupleEncOutput<'vir> {
pub fn mk_cons<'tcx, Curr, Next>(
&self,
Expand Down
34 changes: 27 additions & 7 deletions prusti-encoder/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,26 @@ use prusti_rustc_interface::{
hir,
};


const ENABLE_OPTIMIZATION : bool = true;

// Wrapper Trait for task_encoder::Optimizable to allow toggling of optimization
// TODO: replace with config
trait MaybeOptimize {
fn optimize(self) -> Self;
}

impl<T> MaybeOptimize for T where T : task_encoder::Optimizable {
fn optimize(self) -> Self {
if ENABLE_OPTIMIZATION {
task_encoder::Optimizable::optimize(self)
}
else {
self
}
}
}

pub fn test_entrypoint<'tcx>(
tcx: ty::TyCtxt<'tcx>,
body: EnvBody<'tcx>,
Expand Down Expand Up @@ -63,34 +83,34 @@ pub fn test_entrypoint<'tcx>(
let mut viper_code = String::new();

header(&mut viper_code, "methods");
for output in crate::encoders::MirImpureEnc::all_outputs() {
for output in crate::encoders::MirImpureEnc::all_outputs().optimize() {
viper_code.push_str(&format!("{:?}\n", output.method));
}

header(&mut viper_code, "functions");
for output in crate::encoders::MirFunctionEnc::all_outputs() {
for output in crate::encoders::MirFunctionEnc::all_outputs().optimize() {
viper_code.push_str(&format!("{:?}\n", output.function));
}

header(&mut viper_code, "MIR builtins");
for output in crate::encoders::MirBuiltinEnc::all_outputs() {
for output in crate::encoders::MirBuiltinEnc::all_outputs().optimize() {
viper_code.push_str(&format!("{:?}\n", output.function));
}

header(&mut viper_code, "generics");
for output in crate::encoders::GenericEnc::all_outputs() {
for output in crate::encoders::GenericEnc::all_outputs().optimize() {
viper_code.push_str(&format!("{:?}\n", output.snapshot_param));
viper_code.push_str(&format!("{:?}\n", output.predicate_param));
viper_code.push_str(&format!("{:?}\n", output.domain_type));
}

header(&mut viper_code, "snapshots");
for output in crate::encoders::DomainEnc_all_outputs() {
viper_code.push_str(&format!("{:?}\n", output));
for output in crate::encoders::DomainEnc_all_outputs().optimize() {
viper_code.push_str(&format!("{:?}\n", output.0));
}

header(&mut viper_code, "types");
for output in crate::encoders::PredicateEnc::all_outputs() {
for output in crate::encoders::PredicateEnc::all_outputs().optimize() {
for field in output.fields {
viper_code.push_str(&format!("{:?}", field));
}
Expand Down
14 changes: 13 additions & 1 deletion task-encoder/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,18 @@ use std::cell::RefCell;
pub trait OutputRefAny {}
impl OutputRefAny for () {}


pub trait Optimizable: Sized {
fn optimize(self) -> Self {
self
}
}

impl<T> Optimizable for Vec<T> where T: Optimizable {
fn optimize(self) -> Self {
self.into_iter().map(|e|e.optimize()).collect()
}
}
pub enum TaskEncoderCacheState<'vir, E: TaskEncoder + 'vir + ?Sized> {
// None, // indicated by absence in the cache

Expand Down Expand Up @@ -177,7 +189,7 @@ pub trait TaskEncoder {
/// Fully encoded output for this task. When encoding items which can be
/// dependencies (such as methods), this output should only be emitted in
/// one Viper program.
type OutputFullLocal<'vir>: Clone;
type OutputFullLocal<'vir>: Clone + Optimizable;

/// Fully encoded output for this task for dependents. When encoding items
/// which can be dependencies (such as methods), this output should be
Expand Down
5 changes: 3 additions & 2 deletions vir/src/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ pub enum ConstData {
Null,
}

#[derive(PartialEq, Eq)]
pub enum TypeData<'vir> {
Int {
bit_width: u8,
Expand All @@ -102,12 +103,12 @@ pub enum TypeData<'vir> {
Unsupported(UnsupportedType<'vir>)
}

#[derive(Clone)]
#[derive(Clone, PartialEq, Eq)]
pub struct UnsupportedType<'vir> {
pub name: &'vir str,
}

#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct DomainParamData<'vir> {
pub name: &'vir str, // TODO: identifiers
}
Expand Down
Loading