Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor fnc tt #133

Merged
merged 7 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions compiler/rustc_codegen_llvm/src/back/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ use llvm::{
LLVMRustLLVMHasZlibCompressionForDebugSymbols, LLVMRustLLVMHasZstdCompressionForDebugSymbols, LLVMGetNextBasicBlock,
};
use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity, DiffMode};
use rustc_ast::expand::typetree::FncTree;
use rustc_codegen_ssa::back::link::ensure_removed;
use rustc_codegen_ssa::back::write::{
BitcodeSection, CodegenContext, EmitObj, ModuleConfig, TargetMachineFactoryConfig,
Expand Down Expand Up @@ -1091,6 +1092,24 @@ pub(crate) unsafe fn differentiate(
llvm::set_loose_types(true);
}

// Before dumping the module, we want all the tt to become part of the module.
for item in &diff_items {
let llvm_data_layout = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) };
let llvm_data_layout =
std::str::from_utf8(unsafe { CStr::from_ptr(llvm_data_layout) }.to_bytes())
.expect("got a non-UTF8 data-layout from LLVM");
//let input_tts: Vec<TypeTree> =
// item.inputs.iter().map(|x| to_enzyme_typetree(x.clone(), llvm_data_layout, llcx)).collect();
//let output_tt = to_enzyme_typetree(item.output, llvm_data_layout, llcx);
let tt: FncTree = FncTree {
args: item.inputs.clone(),
ret: item.output.clone(),
};
let name = CString::new(item.source.clone()).unwrap();
let fn_def: &llvm::Value = llvm::LLVMGetNamedFunction(llmod, name.as_ptr()).unwrap();
crate::builder::add_tt2(llmod, llcx, fn_def, tt);
}

if std::env::var("ENZYME_PRINT_MOD_BEFORE").is_ok() {
unsafe {
LLVMDumpModule(llmod);
Expand Down
43 changes: 43 additions & 0 deletions compiler/rustc_codegen_llvm/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,49 @@ macro_rules! builder_methods_for_value_instructions {
})+
}
}
pub fn add_tt2<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context, fn_def: &'ll Value, tt: FncTree) {
let inputs = tt.args;
let ret_tt: TypeTree = tt.ret;
let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) };
let llvm_data_layout =
std::str::from_utf8(unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }.to_bytes())
.expect("got a non-UTF8 data-layout from LLVM");
let attr_name = "enzyme_type";
let c_attr_name = std::ffi::CString::new(attr_name).unwrap();
for (i, &ref input) in inputs.iter().enumerate() {
let c_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx);
let c_str = unsafe { llvm::EnzymeTypeTreeToString(c_tt.inner) };
let c_str = unsafe { std::ffi::CStr::from_ptr(c_str) };
unsafe {
let attr = llvm::LLVMCreateStringAttribute(
llcx,
c_attr_name.as_ptr(),
c_attr_name.as_bytes().len() as c_uint,
c_str.as_ptr(),
c_str.to_bytes().len() as c_uint,
);
llvm::LLVMRustAddFncParamAttr(fn_def, i as u32, attr);
}
unsafe { llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()) };
}
let ret_attr = unsafe {
let c_tt = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx);
let c_str = llvm::EnzymeTypeTreeToString(c_tt.inner);
let c_str = std::ffi::CStr::from_ptr(c_str);
let attr = llvm::LLVMCreateStringAttribute(
llcx,
c_attr_name.as_ptr(),
c_attr_name.as_bytes().len() as c_uint,
c_str.as_ptr(),
c_str.to_bytes().len() as c_uint,
);
llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr());
attr
};
unsafe {
llvm::LLVMRustAddRetFncAttr(fn_def, ret_attr);
}
}

fn add_tt<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context,val: &'ll Value, tt: FncTree) {
let inputs = tt.args;
Expand Down
24 changes: 14 additions & 10 deletions compiler/rustc_codegen_llvm/src/llvm/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -849,7 +849,7 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
input_diffactivity: Vec<DiffActivity>,
ret_diffactivity: DiffActivity,
input_tts: Vec<TypeTree>,
output_tt: TypeTree,
_output_tt: TypeTree,
void_ret: bool,
) -> (&Value, Vec<usize>) {
let ret_activity = cdiffe_from(ret_diffactivity);
Expand Down Expand Up @@ -878,7 +878,7 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
};
trace!("ret_primary_ret: {}", &ret_primary_ret);

let mut args_tree = input_tts.iter().map(|x| x.inner).collect::<Vec<_>>();
//let mut args_tree = input_tts.iter().map(|x| x.inner).collect::<Vec<_>>();
//let mut args_tree = vec![TypeTree::new().inner; typetree.input_tt.len()];

// We don't support volatile / extern / (global?) values.
Expand All @@ -895,8 +895,10 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
let mut known_values = vec![kv_tmp; input_activity.len()];

let dummy_type = CFnTypeInfo {
Arguments: args_tree.as_mut_ptr(),
Return: output_tt.inner.clone(),
Arguments: std::ptr::null_mut(),
Return: std::ptr::null_mut(),
//Arguments: args_tree.as_mut_ptr(),
//Return: output_tt.inner.clone(),
KnownValues: known_values.as_mut_ptr(),
};

Expand Down Expand Up @@ -935,7 +937,7 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
rust_input_activity: Vec<DiffActivity>,
ret_activity: DiffActivity,
input_tts: Vec<TypeTree>,
output_tt: TypeTree,
_output_tt: TypeTree,
) -> (&Value, Vec<usize>) {
let (primary_ret, ret_activity) = match ret_activity {
DiffActivity::Const => (true, CDIFFE_TYPE::DFT_CONSTANT),
Expand All @@ -961,7 +963,7 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
input_activity.push(cdiffe_from(x));
}

let mut args_tree = input_tts.iter().map(|x| x.inner).collect::<Vec<_>>();
//let args_tree = input_tts.iter().map(|x| x.inner).collect::<Vec<_>>();

// We don't support volatile / extern / (global?) values.
// Just because I didn't had time to test them, and it seems less urgent.
Expand All @@ -980,8 +982,10 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
let mut known_values = vec![kv_tmp; input_tts.len()];

let dummy_type = CFnTypeInfo {
Arguments: args_tree.as_mut_ptr(),
Return: output_tt.inner.clone(),
Arguments: std::ptr::null_mut(),
ZuseZ4 marked this conversation as resolved.
Show resolved Hide resolved
Return: std::ptr::null_mut(),
//Arguments: args_tree.as_mut_ptr(),
//Return: output_tt.inner.clone(),
KnownValues: known_values.as_mut_ptr(),
};

Expand Down Expand Up @@ -1023,12 +1027,12 @@ extern "C" {
//pub fn LLVMEraseFromParent(BB: &BasicBlock) -> &Value;
// Enzyme
pub fn LLVMRustAddFncParamAttr<'a>(
Instr: &'a Value,
F: &'a Value,
index: c_uint,
Attr: &'a Attribute
);

pub fn LLVMRustAddRetAttr(V: &Value, attr: AttributeKind);
pub fn LLVMRustAddRetFncAttr(F: &Value, attr: &Attribute);
pub fn LLVMRustRemoveFncAttr(V: &Value, attr: AttributeKind);
pub fn LLVMRustHasDbgMetadata(I: &Value) -> bool;
pub fn LLVMRustHasMetadata(I: &Value, KindID: c_uint) -> bool;
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -865,9 +865,9 @@ extern "C" void LLVMRustAddFncParamAttr(LLVMValueRef F, unsigned i,
}

extern "C" void LLVMRustAddRetFncAttr(LLVMValueRef F,
LLVMRustAttribute RustAttr) {
LLVMAttributeRef RustAttr) {
if (auto *Fn = dyn_cast<Function>(unwrap<Value>(F))) {
Fn->addRetAttr(fromRust(RustAttr));
Fn->addRetAttr(unwrap(RustAttr));
}
}

Expand Down
Loading