Skip to content

Commit

Permalink
Merge pull request #82 from EnzymeAD/morett
Browse files Browse the repository at this point in the history
add tt to mem calls
  • Loading branch information
jedbrown authored Mar 29, 2024
2 parents 518390f + ccb9fab commit 5c5de7c
Show file tree
Hide file tree
Showing 13 changed files with 246 additions and 42 deletions.
1 change: 1 addition & 0 deletions compiler/rustc_codegen_llvm/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ impl<'ll, 'tcx> ArgAbiExt<'ll, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
scratch_align,
bx.const_usize(self.layout.size.bytes()),
MemFlags::empty(),
None,
);

bx.lifetime_end(llscratch, scratch_size);
Expand Down
56 changes: 51 additions & 5 deletions compiler/rustc_codegen_llvm/src/back/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -703,10 +703,56 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool,
let f_ty = LLVMRustGetFunctionType(src);

let inner_param_num = LLVMCountParams(src);
let mut outer_args: Vec<&Value> = get_params(tgt);
let outer_param_num = LLVMCountParams(tgt);
let outer_args: Vec<&Value> = get_params(tgt);
let inner_args: Vec<&Value> = get_params(src);
let mut call_args: Vec<&Value> = vec![];

if inner_param_num as usize != outer_args.len() {
panic!("Args len shouldn't differ. Please report this. {} : {}", inner_param_num, outer_args.len());
if inner_param_num == outer_param_num {
call_args = outer_args;
} else {
dbg!("Different number of args, adjusting");
let mut outer_pos: usize = 0;
let mut inner_pos: usize = 0;
// copy over if they are identical.
// If not, skip the outer arg (and assert it's int).
while outer_pos < outer_param_num as usize {
let inner_arg = inner_args[inner_pos];
let outer_arg = outer_args[outer_pos];
let inner_arg_ty = llvm::LLVMTypeOf(inner_arg);
let outer_arg_ty = llvm::LLVMTypeOf(outer_arg);
if inner_arg_ty == outer_arg_ty {
call_args.push(outer_arg);
inner_pos += 1;
outer_pos += 1;
} else {
// out: (ptr, <>int1, ptr, int2)
// inner: (ptr, <>ptr, int)
// goal: (ptr, ptr, int1), skipping int2
// we are here: <>
assert!(llvm::LLVMRustGetTypeKind(outer_arg_ty) == llvm::TypeKind::Integer);
assert!(llvm::LLVMRustGetTypeKind(inner_arg_ty) == llvm::TypeKind::Pointer);
let next_outer_arg = outer_args[outer_pos + 1];
let next_inner_arg = inner_args[inner_pos + 1];
let next_outer_arg_ty = llvm::LLVMTypeOf(next_outer_arg);
let next_inner_arg_ty = llvm::LLVMTypeOf(next_inner_arg);
assert!(llvm::LLVMRustGetTypeKind(next_outer_arg_ty) == llvm::TypeKind::Pointer);
assert!(llvm::LLVMRustGetTypeKind(next_inner_arg_ty) == llvm::TypeKind::Integer);
let next2_outer_arg = outer_args[outer_pos + 2];
let next2_outer_arg_ty = llvm::LLVMTypeOf(next2_outer_arg);
assert!(llvm::LLVMRustGetTypeKind(next2_outer_arg_ty) == llvm::TypeKind::Integer);
call_args.push(next_outer_arg);
call_args.push(outer_arg);

outer_pos += 3;
inner_pos += 2;
}
}
}


if inner_param_num as usize != call_args.len() {
panic!("Args len shouldn't differ. Please report this. {} : {}", inner_param_num, call_args.len());
}

let inner_fnc_name = llvm::get_value_name(src);
Expand All @@ -719,8 +765,8 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool,
builder,
f_ty,
src,
outer_args.as_mut_ptr(),
outer_args.len(),
call_args.as_mut_ptr(),
call_args.len(),
c_inner_fnc_name.as_ptr(),
);

Expand Down
13 changes: 12 additions & 1 deletion compiler/rustc_codegen_llvm/src/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@ use rustc_data_structures::small_c_str::SmallCStr;
use rustc_middle::dep_graph;
use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrs;
use rustc_middle::mir::mono::{Linkage, Visibility};
use rustc_middle::ty::TyCtxt;
use rustc_session::config::DebugInfo;
use rustc_span::symbol::Symbol;
use rustc_target::spec::SanitizerSet;

use rustc_middle::mir::mono::MonoItem;
use rustc_middle::ty::{ParamEnv, TyCtxt, fnc_typetrees};

use std::time::Instant;

pub struct ValueIter<'ll> {
Expand Down Expand Up @@ -86,6 +88,15 @@ pub fn compile_codegen_unit(tcx: TyCtxt<'_>, cgu_name: Symbol) -> (ModuleCodegen
let mono_items = cx.codegen_unit.items_in_deterministic_order(cx.tcx);
for &(mono_item, data) in &mono_items {
mono_item.predefine::<Builder<'_, '_, '_>>(&cx, data.linkage, data.visibility);
let inst = match mono_item {
MonoItem::Fn(instance) => instance,
_ => continue,
};
let fn_ty = inst.ty(tcx, ParamEnv::empty());
let _fnc_tree = fnc_typetrees(tcx, fn_ty, &mut vec![]);
//trace!("codegen_module: predefine fn {}", inst);
//trace!("{} \n {:?} \n {:?}", inst, fn_ty, _fnc_tree);
// Manuel: TODO
}

// ... and now that we have everything pre-defined, fill out those definitions.
Expand Down
64 changes: 58 additions & 6 deletions compiler/rustc_codegen_llvm/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ use std::iter;
use std::ops::Deref;
use std::ptr;

use crate::typetree::to_enzyme_typetree;
use rustc_ast::expand::typetree::{TypeTree, FncTree};

// All Builders must have an llfn associated with them
#[must_use]
pub struct Builder<'a, 'll, 'tcx> {
Expand Down Expand Up @@ -134,6 +137,35 @@ macro_rules! builder_methods_for_value_instructions {
}
}

fn add_tt<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context,val: &'ll Value, tt: FncTree) {
let inputs = tt.args;
let _ret: TypeTree = tt.ret;
let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) };
let llvm_data_layout =
std::str::from_utf8(unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }.to_bytes())
.expect("got a non-UTF8 data-layout from LLVM");
let attr_name = "enzyme_type";
let c_attr_name = std::ffi::CString::new(attr_name).unwrap();
for (i, &ref input) in inputs.iter().enumerate() {
let c_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx);
let c_str = unsafe { llvm::EnzymeTypeTreeToString(c_tt.inner) };
let c_str = unsafe { std::ffi::CStr::from_ptr(c_str) };
unsafe {
let attr = llvm::LLVMCreateStringAttribute(
llcx,
c_attr_name.as_ptr(),
c_attr_name.as_bytes().len() as c_uint,
c_str.as_ptr(),
c_str.to_bytes().len() as c_uint,
);
llvm::LLVMRustAddParamAttr(val, i as u32, attr);
}
unsafe { llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()) };
}
dbg!(&val);
}


impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
fn build(cx: &'a CodegenCx<'ll, 'tcx>, llbb: &'ll BasicBlock) -> Self {
let bx = Builder::with_cx(cx);
Expand Down Expand Up @@ -874,11 +906,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
src_align: Align,
size: &'ll Value,
flags: MemFlags,
tt: Option<FncTree>,
) {
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
let size = self.intcast(size, self.type_isize(), false);
let is_volatile = flags.contains(MemFlags::VOLATILE);
unsafe {
let val = unsafe {
llvm::LLVMRustBuildMemCpy(
self.llbuilder,
dst,
Expand All @@ -887,7 +920,14 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
src_align.bytes() as c_uint,
size,
is_volatile,
);
)
};
if let Some(tt) = tt {
let llmod = self.cx.llmod;
let llcx = self.cx.llcx;
add_tt(llmod, llcx, val, tt);
} else {
trace!("builder: no tt");
}
}

Expand All @@ -899,11 +939,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
src_align: Align,
size: &'ll Value,
flags: MemFlags,
tt: Option<FncTree>,
) {
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memmove not supported");
let size = self.intcast(size, self.type_isize(), false);
let is_volatile = flags.contains(MemFlags::VOLATILE);
unsafe {
let val = unsafe {
llvm::LLVMRustBuildMemMove(
self.llbuilder,
dst,
Expand All @@ -912,7 +953,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
src_align.bytes() as c_uint,
size,
is_volatile,
);
)
};
if let Some(tt) = tt {
let llmod = self.cx.llmod;
let llcx = self.cx.llcx;
add_tt(llmod, llcx, val, tt);
}
}

Expand All @@ -923,17 +969,23 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
size: &'ll Value,
align: Align,
flags: MemFlags,
tt: Option<FncTree>,
) {
let is_volatile = flags.contains(MemFlags::VOLATILE);
unsafe {
let val = unsafe {
llvm::LLVMRustBuildMemSet(
self.llbuilder,
ptr,
align.bytes() as c_uint,
fill_byte,
size,
is_volatile,
);
)
};
if let Some(tt) = tt {
let llmod = self.cx.llmod;
let llcx = self.cx.llcx;
add_tt(llmod, llcx, val, tt);
}
}

Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_codegen_llvm/src/llvm/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -945,6 +945,10 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
// We don't support volatile / extern / (global?) values.
// Just because I didn't have time to test them, and it seems less urgent.
let args_uncacheable = vec![0; input_tts.len()];
if args_uncacheable.len() != input_activity.len() {
dbg!("args_uncacheable.len(): {}", args_uncacheable.len());
dbg!("input_activity.len(): {}", input_activity.len());
}
assert!(args_uncacheable.len() == input_activity.len());
let num_fnc_args = LLVMCountParams(fnc);
println!("num_fnc_args: {}", num_fnc_args);
Expand Down
16 changes: 15 additions & 1 deletion compiler/rustc_codegen_ssa/src/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ use std::time::{Duration, Instant};

use itertools::Itertools;

use rustc_middle::ty::typetree_from;
use rustc_ast::expand::typetree::{TypeTree, FncTree};

pub fn bin_op_to_icmp_predicate(op: hir::BinOpKind, signed: bool) -> IntPredicate {
match op {
hir::BinOpKind::Eq => IntPredicate::IntEQ,
Expand Down Expand Up @@ -357,6 +360,7 @@ pub fn wants_new_eh_instructions(sess: &Session) -> bool {
wants_wasm_eh(sess) || wants_msvc_seh(sess)
}

// Manuel TODO
pub fn memcpy_ty<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
bx: &mut Bx,
dst: Bx::Value,
Expand All @@ -370,15 +374,25 @@ pub fn memcpy_ty<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
if size == 0 {
return;
}
let my_ty = layout.ty;
let tcx: TyCtxt<'_> = bx.cx().tcx();
let fnc_tree: TypeTree = typetree_from(tcx, my_ty);
let fnc_tree: FncTree = FncTree {
args: vec![fnc_tree.clone(), fnc_tree.clone()],
ret: TypeTree::new(),
};

if flags == MemFlags::empty()
&& let Some(bty) = bx.cx().scalar_copy_backend_type(layout)
{
let temp = bx.load(bty, src, src_align);
bx.store(temp, dst, dst_align);
} else {
bx.memcpy(dst, dst_align, src, src_align, bx.cx().const_usize(size), flags);
trace!("my_ty: {:?}, enzyme tt: {:?}", my_ty, fnc_tree);
trace!("memcpy_ty: {:?} -> {:?} (size={}, align={:?})", src, dst, size, dst_align);
bx.memcpy(dst, dst_align, src, src_align, bx.cx().const_usize(size), flags, Some(fnc_tree));
}
//let (_args, _ret): (Vec<TypeTree>, TypeTree) = (fnc_tree.args, fnc_tree.ret);
}

pub fn codegen_instance<'a, 'tcx: 'a, Bx: BuilderMethods<'a, 'tcx>>(
Expand Down
25 changes: 22 additions & 3 deletions compiler/rustc_codegen_ssa/src/mir/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ use rustc_target::abi::{
WrappingRange,
};

use rustc_middle::ty::typetree_from;
use rustc_ast::expand::typetree::{TypeTree, FncTree};
use crate::rustc_middle::ty::layout::HasTyCtxt;

fn copy_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
bx: &mut Bx,
allow_overlap: bool,
Expand All @@ -25,15 +29,23 @@ fn copy_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
src: Bx::Value,
count: Bx::Value,
) {
let tcx: TyCtxt<'_> = bx.cx().tcx();
let fnc_tree: TypeTree = typetree_from(tcx, ty);
let fnc_tree: FncTree = FncTree {
args: vec![fnc_tree.clone(), fnc_tree.clone()],
ret: TypeTree::new(),
};

let layout = bx.layout_of(ty);
let size = layout.size;
let align = layout.align.abi;
let size = bx.mul(bx.const_usize(size.bytes()), count);
let flags = if volatile { MemFlags::VOLATILE } else { MemFlags::empty() };
trace!("copy: mir ty: {:?}, enzyme tt: {:?}", ty, fnc_tree);
if allow_overlap {
bx.memmove(dst, align, src, align, size, flags);
bx.memmove(dst, align, src, align, size, flags, Some(fnc_tree));
} else {
bx.memcpy(dst, align, src, align, size, flags);
bx.memcpy(dst, align, src, align, size, flags, Some(fnc_tree));
}
}

Expand All @@ -45,12 +57,19 @@ fn memset_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
val: Bx::Value,
count: Bx::Value,
) {
let tcx: TyCtxt<'_> = bx.cx().tcx();
let fnc_tree: TypeTree = typetree_from(tcx, ty);
let fnc_tree: FncTree = FncTree {
args: vec![fnc_tree.clone(), fnc_tree.clone()],
ret: TypeTree::new(),
};

let layout = bx.layout_of(ty);
let size = layout.size;
let align = layout.align.abi;
let size = bx.mul(bx.const_usize(size.bytes()), count);
let flags = if volatile { MemFlags::VOLATILE } else { MemFlags::empty() };
bx.memset(dst, val, size, align, flags);
bx.memset(dst, val, size, align, flags, Some(fnc_tree));
}

impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_codegen_ssa/src/mir/operand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ impl<'a, 'tcx, V: CodegenObject> OperandValue<V> {
let neg_address = bx.neg(address);
let offset = bx.and(neg_address, align_minus_1);
let dst = bx.inbounds_gep(bx.type_i8(), alloca, &[offset]);
bx.memcpy(dst, min_align, llptr, min_align, size, MemFlags::empty());
bx.memcpy(dst, min_align, llptr, min_align, size, MemFlags::empty(), None);

// Store the allocated region and the extra to the indirect place.
let indirect_operand = OperandValue::Pair(dst, llextra);
Expand Down
6 changes: 4 additions & 2 deletions compiler/rustc_codegen_ssa/src/mir/rvalue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,17 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {

// Use llvm.memset.p0i8.* to initialize all zero arrays
if bx.cx().const_to_opt_u128(v, false) == Some(0) {
//let ty = bx.cx().val_ty(v);
let fill = bx.cx().const_u8(0);
bx.memset(start, fill, size, dest.align, MemFlags::empty());
bx.memset(start, fill, size, dest.align, MemFlags::empty(), None);
return;
}

// Use llvm.memset.p0i8.* to initialize byte arrays
let v = bx.from_immediate(v);
if bx.cx().val_ty(v) == bx.cx().type_i8() {
bx.memset(start, v, size, dest.align, MemFlags::empty());
//let ty = bx.cx().type_i8();
bx.memset(start, v, size, dest.align, MemFlags::empty(), None);
return;
}
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_codegen_ssa/src/mir/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
let align = pointee_layout.align;
let dst = dst_val.immediate();
let src = src_val.immediate();
bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty());
bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty(), None);
}
mir::StatementKind::FakeRead(..)
| mir::StatementKind::Retag { .. }
Expand Down
Loading

0 comments on commit 5c5de7c

Please sign in to comment.