From 6283abf4785cecb1b174e30c8d047f0484dbcbe9 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sat, 13 Jul 2024 16:00:52 -0400 Subject: [PATCH 1/7] fix wrapper --- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 2 +- compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 260e66d2ce92a..27d9489bbfe3d 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -1028,7 +1028,7 @@ extern "C" { Attr: &'a Attribute ); - pub fn LLVMRustAddRetAttr(V: &Value, attr: AttributeKind); + pub fn LLVMRustAddRetFncAttr(V: &Value, attr: Attribute); pub fn LLVMRustRemoveFncAttr(V: &Value, attr: AttributeKind); pub fn LLVMRustHasDbgMetadata(I: &Value) -> bool; pub fn LLVMRustHasMetadata(I: &Value, KindID: c_uint) -> bool; diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index 8f17c26f10177..43ed4c50613a3 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -865,9 +865,9 @@ extern "C" void LLVMRustAddFncParamAttr(LLVMValueRef F, unsigned i, } extern "C" void LLVMRustAddRetFncAttr(LLVMValueRef F, - LLVMRustAttribute RustAttr) { + LLVMAttributeRef RustAttr) { if (auto *Fn = dyn_cast(unwrap(F))) { - Fn->addRetAttr(fromRust(RustAttr)); + Fn->addRetAttr(unwrap(RustAttr)); } } From 94c6241403d27be2635ea4c04db7bac9ab39decd Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sat, 13 Jul 2024 16:04:02 -0400 Subject: [PATCH 2/7] wip add_tt2 variant --- compiler/rustc_codegen_llvm/src/builder.rs | 43 ++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index f7afe9cbefb7a..91c5df99e7b8e 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -136,6 +136,49 @@ macro_rules! builder_methods_for_value_instructions { })+ } } +fn add_tt2<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context, fn_def: &'ll Value, tt: FncTree) { + let inputs = tt.args; + let ret_tt: TypeTree = tt.ret; + let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) }; + let llvm_data_layout = + std::str::from_utf8(unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }.to_bytes()) + .expect("got a non-UTF8 data-layout from LLVM"); + let attr_name = "enzyme_type"; + let c_attr_name = std::ffi::CString::new(attr_name).unwrap(); + for (i, &ref input) in inputs.iter().enumerate() { + let c_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx); + let c_str = unsafe { llvm::EnzymeTypeTreeToString(c_tt.inner) }; + let c_str = unsafe { std::ffi::CStr::from_ptr(c_str) }; + unsafe { + let attr = llvm::LLVMCreateStringAttribute( + llcx, + c_attr_name.as_ptr(), + c_attr_name.as_bytes().len() as c_uint, + c_str.as_ptr(), + c_str.to_bytes().len() as c_uint, + ); + llvm::LLVMRustAddFncParamAttr(fn_def, i as u32, attr); + } + unsafe { llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()) }; + } + let ret_attr = unsafe { + let c_tt = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx); + let c_str = llvm::EnzymeTypeTreeToString(c_tt.inner); + let c_str = std::ffi::CStr::from_ptr(c_str); + let attr = llvm::LLVMCreateStringAttribute( + llcx, + c_attr_name.as_ptr(), + c_attr_name.as_bytes().len() as c_uint, + c_str.as_ptr(), + c_str.to_bytes().len() as c_uint, + ); + llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()); + attr + }; + unsafe { + llvm::LLVMRustAddRetAttr(fn_def, ret_attr); + } +} fn add_tt<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context,val: &'ll Value, tt: FncTree) { let inputs = tt.args; From 9783e02eb13950ee9cd3c01fdc4f73727755bf56 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sat, 13 Jul 2024 16:24:09 -0400 Subject: [PATCH 3/7] add tt writes into module --- compiler/rustc_codegen_llvm/src/back/write.rs | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 657db58831c1b..ca4bc70c96bbb 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -55,6 +55,7 @@ use llvm::{ LLVMRustLLVMHasZlibCompressionForDebugSymbols, LLVMRustLLVMHasZstdCompressionForDebugSymbols, LLVMGetNextBasicBlock, }; use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity, DiffMode}; +use rustc_ast::expand::typetree::FncTree; use rustc_codegen_ssa::back::link::ensure_removed; use rustc_codegen_ssa::back::write::{ BitcodeSection, CodegenContext, EmitObj, ModuleConfig, TargetMachineFactoryConfig, @@ -1091,6 +1092,24 @@ pub(crate) unsafe fn differentiate( llvm::set_loose_types(true); } + // Before dumping the module, we want all the tt to become part of the module. + for item in &diff_items { + let llvm_data_layout = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) }; + let llvm_data_layout = + std::str::from_utf8(unsafe { CStr::from_ptr(llvm_data_layout) }.to_bytes()) + .expect("got a non-UTF8 data-layout from LLVM"); + //let input_tts: Vec = + // item.inputs.iter().map(|x| to_enzyme_typetree(x.clone(), llvm_data_layout, llcx)).collect(); + //let output_tt = to_enzyme_typetree(item.output, llvm_data_layout, llcx); + let tt: FncTree = FncTree { + args: item.inputs, + ret: item.output, + }; + let name = CString::new(item.source.clone()).unwrap(); + let fn_def: llvm::Value = llvm::LLVMGetNamedFunction(llmod, name.as_ptr()).unwrap(); + add_tt2(llmod, llcx, fn_def, tt); + } + if std::env::var("ENZYME_PRINT_MOD_BEFORE").is_ok() { unsafe { LLVMDumpModule(llmod); From fd12dadf2e30a8c1b716ba278e128fc1d5767b02 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sat, 13 Jul 2024 16:28:53 -0400 Subject: [PATCH 4/7] don't pass tt anymore through Enzyme API --- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 27d9489bbfe3d..b7cf37041b3c0 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -895,8 +895,10 @@ pub(crate) unsafe fn enzyme_rust_forward_diff( let mut known_values = vec![kv_tmp; input_activity.len()]; let dummy_type = CFnTypeInfo { - Arguments: args_tree.as_mut_ptr(), - Return: output_tt.inner.clone(), + Arguments: std::ptr::null_mut(), + Return: std::ptr::null_mut(), + //Arguments: args_tree.as_mut_ptr(), + //Return: output_tt.inner.clone(), KnownValues: known_values.as_mut_ptr(), }; @@ -980,8 +982,10 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff( let mut known_values = vec![kv_tmp; input_tts.len()]; let dummy_type = CFnTypeInfo { - Arguments: args_tree.as_mut_ptr(), - Return: output_tt.inner.clone(), + Arguments: std::ptr::null_mut(), + Return: std::ptr::null_mut(), + //Arguments: args_tree.as_mut_ptr(), + //Return: output_tt.inner.clone(), KnownValues: known_values.as_mut_ptr(), }; @@ -1023,12 +1027,12 @@ extern "C" { //pub fn LLVMEraseFromParent(BB: &BasicBlock) -> &Value; // Enzyme pub fn LLVMRustAddFncParamAttr<'a>( - Instr: &'a Value, + F: &'a Value, index: c_uint, Attr: &'a Attribute ); - pub fn LLVMRustAddRetFncAttr(V: &Value, attr: Attribute); + pub fn LLVMRustAddRetFncAttr(F: &Value, attr: Attribute); pub fn LLVMRustRemoveFncAttr(V: &Value, attr: AttributeKind); pub fn LLVMRustHasDbgMetadata(I: &Value) -> bool; pub fn LLVMRustHasMetadata(I: &Value, KindID: c_uint) -> bool; From 0d61b1bb4fe572cb7ba5ff2c3942fcc2fb6bcfe0 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sat, 13 Jul 2024 17:20:47 -0400 Subject: [PATCH 5/7] fixup --- compiler/rustc_codegen_llvm/src/back/write.rs | 4 ++-- compiler/rustc_codegen_llvm/src/builder.rs | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index ca4bc70c96bbb..1d88425c906ae 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -1106,8 +1106,8 @@ pub(crate) unsafe fn differentiate( ret: item.output, }; let name = CString::new(item.source.clone()).unwrap(); - let fn_def: llvm::Value = llvm::LLVMGetNamedFunction(llmod, name.as_ptr()).unwrap(); - add_tt2(llmod, llcx, fn_def, tt); + let fn_def: &llvm::Value = llvm::LLVMGetNamedFunction(llmod, name.as_ptr()).unwrap(); + crate::builder::add_tt2(llmod, llcx, fn_def, tt); } if std::env::var("ENZYME_PRINT_MOD_BEFORE").is_ok() { diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index 91c5df99e7b8e..733dd76ed105c 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -136,7 +136,7 @@ macro_rules! builder_methods_for_value_instructions { })+ } } -fn add_tt2<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context, fn_def: &'ll Value, tt: FncTree) { +pub fn add_tt2<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context, fn_def: &'ll Value, tt: FncTree) { let inputs = tt.args; let ret_tt: TypeTree = tt.ret; let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) }; @@ -176,7 +176,7 @@ fn add_tt2<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context, fn_def: &'ll attr }; unsafe { - llvm::LLVMRustAddRetAttr(fn_def, ret_attr); + llvm::LLVMRustAddRetFncAttr(fn_def, ret_attr); } } From 9bf71926cd12b098e726805a22f8b9eb5cf2bbf8 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sat, 13 Jul 2024 18:17:19 -0400 Subject: [PATCH 6/7] make it compile --- compiler/rustc_codegen_llvm/src/back/write.rs | 4 ++-- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 1d88425c906ae..f437e52355ef0 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -1102,8 +1102,8 @@ pub(crate) unsafe fn differentiate( // item.inputs.iter().map(|x| to_enzyme_typetree(x.clone(), llvm_data_layout, llcx)).collect(); //let output_tt = to_enzyme_typetree(item.output, llvm_data_layout, llcx); let tt: FncTree = FncTree { - args: item.inputs, - ret: item.output, + args: item.inputs.clone(), + ret: item.output.clone(), }; let name = CString::new(item.source.clone()).unwrap(); let fn_def: &llvm::Value = llvm::LLVMGetNamedFunction(llmod, name.as_ptr()).unwrap(); diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index b7cf37041b3c0..c5efa5d5b1da5 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -849,7 +849,7 @@ pub(crate) unsafe fn enzyme_rust_forward_diff( input_diffactivity: Vec, ret_diffactivity: DiffActivity, input_tts: Vec, - output_tt: TypeTree, + _output_tt: TypeTree, void_ret: bool, ) -> (&Value, Vec) { let ret_activity = cdiffe_from(ret_diffactivity); @@ -878,7 +878,7 @@ pub(crate) unsafe fn enzyme_rust_forward_diff( }; trace!("ret_primary_ret: {}", &ret_primary_ret); - let mut args_tree = input_tts.iter().map(|x| x.inner).collect::>(); + //let mut args_tree = input_tts.iter().map(|x| x.inner).collect::>(); //let mut args_tree = vec![TypeTree::new().inner; typetree.input_tt.len()]; // We don't support volatile / extern / (global?) values. @@ -937,7 +937,7 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff( rust_input_activity: Vec, ret_activity: DiffActivity, input_tts: Vec, - output_tt: TypeTree, + _output_tt: TypeTree, ) -> (&Value, Vec) { let (primary_ret, ret_activity) = match ret_activity { DiffActivity::Const => (true, CDIFFE_TYPE::DFT_CONSTANT), @@ -963,7 +963,7 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff( input_activity.push(cdiffe_from(x)); } - let mut args_tree = input_tts.iter().map(|x| x.inner).collect::>(); + //let args_tree = input_tts.iter().map(|x| x.inner).collect::>(); // We don't support volatile / extern / (global?) values. // Just because I didn't had time to test them, and it seems less urgent. @@ -1032,7 +1032,7 @@ extern "C" { Attr: &'a Attribute ); - pub fn LLVMRustAddRetFncAttr(F: &Value, attr: Attribute); + pub fn LLVMRustAddRetFncAttr(F: &Value, attr: &Attribute); pub fn LLVMRustRemoveFncAttr(V: &Value, attr: AttributeKind); pub fn LLVMRustHasDbgMetadata(I: &Value) -> bool; pub fn LLVMRustHasMetadata(I: &Value, KindID: c_uint) -> bool; From 6f71350b282a9221c3ccd80a68a36ddb9a89f53c Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Mon, 15 Jul 2024 19:24:12 -0400 Subject: [PATCH 7/7] now working againg --- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 37 +++++++++++---------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index c5efa5d5b1da5..0bf2eb5fb3a2f 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -848,7 +848,7 @@ pub(crate) unsafe fn enzyme_rust_forward_diff( fnc: &Value, input_diffactivity: Vec, ret_diffactivity: DiffActivity, - input_tts: Vec, + _input_tts: Vec, _output_tt: TypeTree, void_ret: bool, ) -> (&Value, Vec) { @@ -883,8 +883,7 @@ pub(crate) unsafe fn enzyme_rust_forward_diff( // We don't support volatile / extern / (global?) values. // Just because I didn't had time to test them, and it seems less urgent. - let args_uncacheable = vec![0; input_tts.len()]; - assert!(args_uncacheable.len() == input_activity.len()); + let args_uncacheable = vec![0; input_activity.len()]; let num_fnc_args = LLVMCountParams(fnc); trace!("num_fnc_args: {}", num_fnc_args); trace!("input_activity.len(): {}", input_activity.len()); @@ -894,11 +893,16 @@ pub(crate) unsafe fn enzyme_rust_forward_diff( let mut known_values = vec![kv_tmp; input_activity.len()]; + let tree_tmp = TypeTree::new(); + let mut args_tree = vec![tree_tmp.inner; input_activity.len()]; + + //let mut args_tree = vec![std::ptr::null_mut(); input_activity.len()]; + //let ret_tt = std::ptr::null_mut(); + //let mut args_tree = vec![TypeTree::new().inner; input_tts.len()]; + let ret_tt = TypeTree::new(); let dummy_type = CFnTypeInfo { - Arguments: std::ptr::null_mut(), - Return: std::ptr::null_mut(), - //Arguments: args_tree.as_mut_ptr(), - //Return: output_tt.inner.clone(), + Arguments: args_tree.as_mut_ptr(), + Return: ret_tt.inner, KnownValues: known_values.as_mut_ptr(), }; @@ -967,12 +971,7 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff( // We don't support volatile / extern / (global?) values. // Just because I didn't had time to test them, and it seems less urgent. - let args_uncacheable = vec![0; input_tts.len()]; - if args_uncacheable.len() != input_activity.len() { - dbg!("args_uncacheable.len(): {}", args_uncacheable.len()); - dbg!("input_activity.len(): {}", input_activity.len()); - } - assert!(args_uncacheable.len() == input_activity.len()); + let args_uncacheable = vec![0; input_activity.len()]; let num_fnc_args = LLVMCountParams(fnc); println!("num_fnc_args: {}", num_fnc_args); println!("input_activity.len(): {}", input_activity.len()); @@ -981,11 +980,15 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff( let mut known_values = vec![kv_tmp; input_tts.len()]; + let tree_tmp = TypeTree::new(); + let mut args_tree = vec![tree_tmp.inner; input_tts.len()]; + //let mut args_tree = vec![TypeTree::new().inner; input_tts.len()]; + let ret_tt = TypeTree::new(); + //let mut args_tree = vec![std::ptr::null_mut(); input_tts.len()]; + //let ret_tt = std::ptr::null_mut(); let dummy_type = CFnTypeInfo { - Arguments: std::ptr::null_mut(), - Return: std::ptr::null_mut(), - //Arguments: args_tree.as_mut_ptr(), - //Return: output_tt.inner.clone(), + Arguments: args_tree.as_mut_ptr(), + Return: ret_tt.inner, KnownValues: known_values.as_mut_ptr(), };