Skip to content

Commit

Permalink
Refactor fnc tt (#133)
Browse files Browse the repository at this point in the history
* fix wrapper

* add tt writes into module

* don't pass tt anymore through Enzyme API
  • Loading branch information
ZuseZ4 authored Jul 15, 2024
1 parent eba256f commit 269d384
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 19 deletions.
19 changes: 19 additions & 0 deletions compiler/rustc_codegen_llvm/src/back/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ use llvm::{
LLVMRustLLVMHasZlibCompressionForDebugSymbols, LLVMRustLLVMHasZstdCompressionForDebugSymbols, LLVMGetNextBasicBlock,
};
use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity, DiffMode};
use rustc_ast::expand::typetree::FncTree;
use rustc_codegen_ssa::back::link::ensure_removed;
use rustc_codegen_ssa::back::write::{
BitcodeSection, CodegenContext, EmitObj, ModuleConfig, TargetMachineFactoryConfig,
Expand Down Expand Up @@ -1091,6 +1092,24 @@ pub(crate) unsafe fn differentiate(
llvm::set_loose_types(true);
}

// Before dumping the module, we want all the tt to become part of the module.
for item in &diff_items {
let llvm_data_layout = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) };
let llvm_data_layout =
std::str::from_utf8(unsafe { CStr::from_ptr(llvm_data_layout) }.to_bytes())
.expect("got a non-UTF8 data-layout from LLVM");
//let input_tts: Vec<TypeTree> =
// item.inputs.iter().map(|x| to_enzyme_typetree(x.clone(), llvm_data_layout, llcx)).collect();
//let output_tt = to_enzyme_typetree(item.output, llvm_data_layout, llcx);
let tt: FncTree = FncTree {
args: item.inputs.clone(),
ret: item.output.clone(),
};
let name = CString::new(item.source.clone()).unwrap();
let fn_def: &llvm::Value = llvm::LLVMGetNamedFunction(llmod, name.as_ptr()).unwrap();
crate::builder::add_tt2(llmod, llcx, fn_def, tt);
}

if std::env::var("ENZYME_PRINT_MOD_BEFORE").is_ok() {
unsafe {
LLVMDumpModule(llmod);
Expand Down
43 changes: 43 additions & 0 deletions compiler/rustc_codegen_llvm/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,49 @@ macro_rules! builder_methods_for_value_instructions {
})+
}
}
pub fn add_tt2<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context, fn_def: &'ll Value, tt: FncTree) {
let inputs = tt.args;
let ret_tt: TypeTree = tt.ret;
let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) };
let llvm_data_layout =
std::str::from_utf8(unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }.to_bytes())
.expect("got a non-UTF8 data-layout from LLVM");
let attr_name = "enzyme_type";
let c_attr_name = std::ffi::CString::new(attr_name).unwrap();
for (i, &ref input) in inputs.iter().enumerate() {
let c_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx);
let c_str = unsafe { llvm::EnzymeTypeTreeToString(c_tt.inner) };
let c_str = unsafe { std::ffi::CStr::from_ptr(c_str) };
unsafe {
let attr = llvm::LLVMCreateStringAttribute(
llcx,
c_attr_name.as_ptr(),
c_attr_name.as_bytes().len() as c_uint,
c_str.as_ptr(),
c_str.to_bytes().len() as c_uint,
);
llvm::LLVMRustAddFncParamAttr(fn_def, i as u32, attr);
}
unsafe { llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()) };
}
let ret_attr = unsafe {
let c_tt = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx);
let c_str = llvm::EnzymeTypeTreeToString(c_tt.inner);
let c_str = std::ffi::CStr::from_ptr(c_str);
let attr = llvm::LLVMCreateStringAttribute(
llcx,
c_attr_name.as_ptr(),
c_attr_name.as_bytes().len() as c_uint,
c_str.as_ptr(),
c_str.to_bytes().len() as c_uint,
);
llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr());
attr
};
unsafe {
llvm::LLVMRustAddRetFncAttr(fn_def, ret_attr);
}
}

fn add_tt<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context,val: &'ll Value, tt: FncTree) {
let inputs = tt.args;
Expand Down
41 changes: 24 additions & 17 deletions compiler/rustc_codegen_llvm/src/llvm/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -848,8 +848,8 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
fnc: &Value,
input_diffactivity: Vec<DiffActivity>,
ret_diffactivity: DiffActivity,
input_tts: Vec<TypeTree>,
output_tt: TypeTree,
_input_tts: Vec<TypeTree>,
_output_tt: TypeTree,
void_ret: bool,
) -> (&Value, Vec<usize>) {
let ret_activity = cdiffe_from(ret_diffactivity);
Expand Down Expand Up @@ -878,13 +878,12 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
};
trace!("ret_primary_ret: {}", &ret_primary_ret);

let mut args_tree = input_tts.iter().map(|x| x.inner).collect::<Vec<_>>();
//let mut args_tree = input_tts.iter().map(|x| x.inner).collect::<Vec<_>>();
//let mut args_tree = vec![TypeTree::new().inner; typetree.input_tt.len()];

// We don't support volatile / extern / (global?) values.
// Just because I didn't had time to test them, and it seems less urgent.
let args_uncacheable = vec![0; input_tts.len()];
assert!(args_uncacheable.len() == input_activity.len());
let args_uncacheable = vec![0; input_activity.len()];
let num_fnc_args = LLVMCountParams(fnc);
trace!("num_fnc_args: {}", num_fnc_args);
trace!("input_activity.len(): {}", input_activity.len());
Expand All @@ -894,9 +893,16 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(

let mut known_values = vec![kv_tmp; input_activity.len()];

let tree_tmp = TypeTree::new();
let mut args_tree = vec![tree_tmp.inner; input_activity.len()];

//let mut args_tree = vec![std::ptr::null_mut(); input_activity.len()];
//let ret_tt = std::ptr::null_mut();
//let mut args_tree = vec![TypeTree::new().inner; input_tts.len()];
let ret_tt = TypeTree::new();
let dummy_type = CFnTypeInfo {
Arguments: args_tree.as_mut_ptr(),
Return: output_tt.inner.clone(),
Return: ret_tt.inner,
KnownValues: known_values.as_mut_ptr(),
};

Expand Down Expand Up @@ -935,7 +941,7 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
rust_input_activity: Vec<DiffActivity>,
ret_activity: DiffActivity,
input_tts: Vec<TypeTree>,
output_tt: TypeTree,
_output_tt: TypeTree,
) -> (&Value, Vec<usize>) {
let (primary_ret, ret_activity) = match ret_activity {
DiffActivity::Const => (true, CDIFFE_TYPE::DFT_CONSTANT),
Expand All @@ -961,16 +967,11 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
input_activity.push(cdiffe_from(x));
}

let mut args_tree = input_tts.iter().map(|x| x.inner).collect::<Vec<_>>();
//let args_tree = input_tts.iter().map(|x| x.inner).collect::<Vec<_>>();

// We don't support volatile / extern / (global?) values.
// Just because I didn't had time to test them, and it seems less urgent.
let args_uncacheable = vec![0; input_tts.len()];
if args_uncacheable.len() != input_activity.len() {
dbg!("args_uncacheable.len(): {}", args_uncacheable.len());
dbg!("input_activity.len(): {}", input_activity.len());
}
assert!(args_uncacheable.len() == input_activity.len());
let args_uncacheable = vec![0; input_activity.len()];
let num_fnc_args = LLVMCountParams(fnc);
println!("num_fnc_args: {}", num_fnc_args);
println!("input_activity.len(): {}", input_activity.len());
Expand All @@ -979,9 +980,15 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(

let mut known_values = vec![kv_tmp; input_tts.len()];

let tree_tmp = TypeTree::new();
let mut args_tree = vec![tree_tmp.inner; input_tts.len()];
//let mut args_tree = vec![TypeTree::new().inner; input_tts.len()];
let ret_tt = TypeTree::new();
//let mut args_tree = vec![std::ptr::null_mut(); input_tts.len()];
//let ret_tt = std::ptr::null_mut();
let dummy_type = CFnTypeInfo {
Arguments: args_tree.as_mut_ptr(),
Return: output_tt.inner.clone(),
Return: ret_tt.inner,
KnownValues: known_values.as_mut_ptr(),
};

Expand Down Expand Up @@ -1023,12 +1030,12 @@ extern "C" {
//pub fn LLVMEraseFromParent(BB: &BasicBlock) -> &Value;
// Enzyme
pub fn LLVMRustAddFncParamAttr<'a>(
Instr: &'a Value,
F: &'a Value,
index: c_uint,
Attr: &'a Attribute
);

pub fn LLVMRustAddRetAttr(V: &Value, attr: AttributeKind);
pub fn LLVMRustAddRetFncAttr(F: &Value, attr: &Attribute);
pub fn LLVMRustRemoveFncAttr(V: &Value, attr: AttributeKind);
pub fn LLVMRustHasDbgMetadata(I: &Value) -> bool;
pub fn LLVMRustHasMetadata(I: &Value, KindID: c_uint) -> bool;
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -865,9 +865,9 @@ extern "C" void LLVMRustAddFncParamAttr(LLVMValueRef F, unsigned i,
}

extern "C" void LLVMRustAddRetFncAttr(LLVMValueRef F,
LLVMRustAttribute RustAttr) {
LLVMAttributeRef RustAttr) {
if (auto *Fn = dyn_cast<Function>(unwrap<Value>(F))) {
Fn->addRetAttr(fromRust(RustAttr));
Fn->addRetAttr(unwrap(RustAttr));
}
}

Expand Down

0 comments on commit 269d384

Please sign in to comment.