From 79d1834cdbbda470ba5bb8bcb084ff7c9895e176 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Mon, 25 Dec 2023 22:32:34 -0500 Subject: [PATCH 001/100] wip --- compiler/rustc_builtin_macros/src/autodiff.rs | 97 +++++++++++++++++++ compiler/rustc_builtin_macros/src/lib.rs | 3 + compiler/rustc_expand/src/base.rs | 2 + compiler/rustc_resolve/src/macros.rs | 2 + compiler/rustc_span/src/symbol.rs | 2 + library/core/src/macros/mod.rs | 13 +++ library/core/src/prelude/v1.rs | 5 + library/std/src/lib.rs | 13 +++ library/std/src/prelude/v1.rs | 5 + 9 files changed, 142 insertions(+) create mode 100644 compiler/rustc_builtin_macros/src/autodiff.rs diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs new file mode 100644 index 0000000000000..27f321930fa35 --- /dev/null +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -0,0 +1,97 @@ +use crate::errors; +use crate::util::check_builtin_macro_attribute; + +use rustc_ast::ptr::P; +use rustc_ast::{self as ast, FnHeader, FnSig, Generics, StmtKind}; +use rustc_ast::{Fn, ItemKind, Stmt, TyKind, Unsafe}; +use rustc_expand::base::{Annotatable, ExtCtxt}; +use rustc_span::symbol::{kw, sym, Ident}; +use rustc_span::Span; +use thin_vec::{thin_vec, ThinVec}; + +pub fn expand( + ecx: &mut ExtCtxt<'_>, + _span: Span, + meta_item: &ast::MetaItem, + item: Annotatable, +) -> Vec { + //check_builtin_macro_attribute(ecx, meta_item, sym::alloc_error_handler); + //check_builtin_macro_attribute(ecx, meta_item, sym::autodiff); + + let orig_item = item.clone(); + + // Allow using `#[alloc_error_handler]` on an item statement + // FIXME - if we get deref patterns, use them to reduce duplication here + let (item, is_stmt, sig_span) = if let Annotatable::Item(item) = &item + && let ItemKind::Fn(fn_kind) = &item.kind + { + (item, false, ecx.with_def_site_ctxt(fn_kind.sig.span)) + } else if let Annotatable::Stmt(stmt) = &item + && let StmtKind::Item(item) = &stmt.kind + && let 
ItemKind::Fn(fn_kind) = &item.kind + { + (item, true, ecx.with_def_site_ctxt(fn_kind.sig.span)) + } else { + ecx.sess.dcx().emit_err(errors::AllocErrorMustBeFn { span: item.span() }); + return vec![orig_item]; + }; + + // Generate a bunch of new items using the AllocFnFactory + let span = ecx.with_def_site_ctxt(item.span); + + // Generate item statements for the allocator methods. + let stmts = thin_vec![generate_handler(ecx, item.ident, span, sig_span)]; + + // Generate anonymous constant serving as container for the allocator methods. + let const_ty = ecx.ty(sig_span, TyKind::Tup(ThinVec::new())); + let const_body = ecx.expr_block(ecx.block(span, stmts)); + let const_item = ecx.item_const(span, Ident::new(kw::Underscore, span), const_ty, const_body); + let const_item = if is_stmt { + Annotatable::Stmt(P(ecx.stmt_item(span, const_item))) + } else { + Annotatable::Item(const_item) + }; + + // Return the original item and the new methods. + vec![orig_item, const_item] +} + +// #[rustc_std_internal_symbol] +// unsafe fn __rg_oom(size: usize, align: usize) -> ! 
{ +// handler(core::alloc::Layout::from_size_align_unchecked(size, align)) +// } +fn generate_handler(cx: &ExtCtxt<'_>, handler: Ident, span: Span, sig_span: Span) -> Stmt { + let usize = cx.path_ident(span, Ident::new(sym::usize, span)); + let ty_usize = cx.ty_path(usize); + let size = Ident::from_str_and_span("size", span); + let align = Ident::from_str_and_span("align", span); + + let layout_new = cx.std_path(&[sym::alloc, sym::Layout, sym::from_size_align_unchecked]); + let layout_new = cx.expr_path(cx.path(span, layout_new)); + let layout = cx.expr_call( + span, + layout_new, + thin_vec![cx.expr_ident(span, size), cx.expr_ident(span, align)], + ); + + let call = cx.expr_call_ident(sig_span, handler, thin_vec![layout]); + + let never = ast::FnRetTy::Ty(cx.ty(span, TyKind::Never)); + let params = thin_vec![cx.param(span, size, ty_usize.clone()), cx.param(span, align, ty_usize)]; + let decl = cx.fn_decl(params, never); + let header = FnHeader { unsafety: Unsafe::Yes(span), ..FnHeader::default() }; + let sig = FnSig { decl, header, span: span }; + + let body = Some(cx.block_expr(call)); + let kind = ItemKind::Fn(Box::new(Fn { + defaultness: ast::Defaultness::Final, + sig, + generics: Generics::default(), + body, + })); + + let attrs = thin_vec![cx.attr_word(sym::rustc_std_internal_symbol, span)]; + + let item = cx.item(span, Ident::from_str_and_span("__rg_oom", span), attrs, kind); + cx.stmt_item(sig_span, item) +} diff --git a/compiler/rustc_builtin_macros/src/lib.rs b/compiler/rustc_builtin_macros/src/lib.rs index f60b73fbe9b13..7c64c148b8fd2 100644 --- a/compiler/rustc_builtin_macros/src/lib.rs +++ b/compiler/rustc_builtin_macros/src/lib.rs @@ -14,6 +14,7 @@ #![feature(lint_reasons)] #![feature(proc_macro_internals)] #![feature(proc_macro_quote)] +#![cfg_attr(not(bootstrap), feature(autodiff))] #![recursion_limit = "256"] extern crate proc_macro; @@ -29,6 +30,7 @@ use rustc_span::symbol::sym; mod alloc_error_handler; mod assert; +mod autodiff; mod cfg; mod 
cfg_accessible; mod cfg_eval; @@ -105,6 +107,7 @@ pub fn register_builtin_macros(resolver: &mut dyn ResolverExpand) { register_attr! { alloc_error_handler: alloc_error_handler::expand, + autodiff: autodiff::expand, bench: test::expand_bench, cfg_accessible: cfg_accessible::Expander, cfg_eval: cfg_eval::expand, diff --git a/compiler/rustc_expand/src/base.rs b/compiler/rustc_expand/src/base.rs index 1fd4d2d55dd81..fff7810ea76a2 100644 --- a/compiler/rustc_expand/src/base.rs +++ b/compiler/rustc_expand/src/base.rs @@ -917,6 +917,8 @@ pub trait ResolverExpand { fn check_unused_macros(&mut self); + //fn autodiff() -> bool; + // Resolver interfaces for specific built-in macros. /// Does `#[derive(...)]` attribute with the given `ExpnId` have built-in `Copy` inside it? fn has_derive_copy(&self, expn_id: LocalExpnId) -> bool; diff --git a/compiler/rustc_resolve/src/macros.rs b/compiler/rustc_resolve/src/macros.rs index 1001286b6c2d7..7ea8bf31d095b 100644 --- a/compiler/rustc_resolve/src/macros.rs +++ b/compiler/rustc_resolve/src/macros.rs @@ -344,6 +344,8 @@ impl<'a, 'tcx> ResolverExpand for Resolver<'a, 'tcx> { self.containers_deriving_copy.contains(&expn_id) } + // TODO: add autodiff? + fn resolve_derives( &mut self, expn_id: LocalExpnId, diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index 95106cc64c129..60eae2c3db2d6 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -69,6 +69,7 @@ symbols! { // Keywords that are used in unstable Rust or reserved for future use. Abstract: "abstract", + //Autodiff: "autodiff", Become: "become", Box: "box", Do: "do", @@ -438,6 +439,7 @@ symbols! 
{ attributes, augmented_assignments, auto_traits, + autodiff, automatically_derived, avx, avx512_target_feature, diff --git a/library/core/src/macros/mod.rs b/library/core/src/macros/mod.rs index 7f5908e477cfd..510c53f1fafcf 100644 --- a/library/core/src/macros/mod.rs +++ b/library/core/src/macros/mod.rs @@ -1486,6 +1486,19 @@ pub(crate) mod builtin { ($file:expr $(,)?) => {{ /* compiler built-in */ }}; } + /// Attribute macro applied to a function to request its derivative via + /// automatic differentiation (autodiff). + /// + /// See [the reference] for more info. + /// + /// [the reference]: ../../../reference/attributes/derive.html + #[unstable(feature = "autodiff", issue = "none")] + #[rustc_builtin_macro] + #[cfg(not(bootstrap))] + pub macro autodiff($item:item) { + /* compiler built-in */ + } + /// Asserts that a boolean expression is `true` at runtime. /// /// This will invoke the [`panic!`] macro if the provided expression cannot be diff --git a/library/core/src/prelude/v1.rs b/library/core/src/prelude/v1.rs index 10525a16f3a66..a934e62086c8e 100644 --- a/library/core/src/prelude/v1.rs +++ b/library/core/src/prelude/v1.rs @@ -83,6 +83,11 @@ pub use crate::macros::builtin::{ #[unstable(feature = "derive_const", issue = "none")] pub use crate::macros::builtin::derive_const; +#[cfg(not(bootstrap))] +#[unstable(feature = "autodiff", issue = "none")] +#[rustc_builtin_macro] +pub use crate::macros::builtin::autodiff; + #[unstable( feature = "cfg_accessible", issue = "64797", diff --git a/library/std/src/lib.rs b/library/std/src/lib.rs index 6365366297c43..2b7d685e8d0a2 100644 --- a/library/std/src/lib.rs +++ b/library/std/src/lib.rs @@ -255,6 +255,8 @@ #![allow(unused_features)] // // Features: + +#![cfg_attr(not(bootstrap), feature(autodiff))] #![cfg_attr(test, feature(internal_output_capture, print_internals, update_panic_count, rt))] #![cfg_attr( all(target_vendor = "fortanix", target_env = "sgx"), @@ -662,6 +664,17 @@ pub use core::{ module_path, option_env, 
stringify, trace_macros, }; + +// #[unstable( +// feature = "autodiff", +// issue = "87555", +// reason = "`autodiff` is not stable enough for use and is subject to change" +// )] +// #[cfg(not(bootstrap))] +// pub use core::autodiff; + + + #[unstable( feature = "concat_bytes", issue = "87555", diff --git a/library/std/src/prelude/v1.rs b/library/std/src/prelude/v1.rs index 7a7a773763559..0be8fa35e100d 100644 --- a/library/std/src/prelude/v1.rs +++ b/library/std/src/prelude/v1.rs @@ -67,6 +67,11 @@ pub use core::prelude::v1::{ #[unstable(feature = "derive_const", issue = "none")] pub use core::prelude::v1::derive_const; +#[unstable(feature = "autodiff", issue = "none")] +#[cfg(not(bootstrap))] +#[rustc_builtin_macro] +pub use core::prelude::v1::autodiff; + // Do not `doc(no_inline)` either. #[unstable( feature = "cfg_accessible", From 12292caa4326447a262a836b9f6ae2d1112e7f2c Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Tue, 26 Dec 2023 01:04:32 -0500 Subject: [PATCH 002/100] got autodiff macro to do something --- compiler/rustc_builtin_macros/messages.ftl | 2 + compiler/rustc_builtin_macros/src/autodiff.rs | 126 ++++++++---------- compiler/rustc_builtin_macros/src/errors.rs | 7 + 3 files changed, 67 insertions(+), 68 deletions(-) diff --git a/compiler/rustc_builtin_macros/messages.ftl b/compiler/rustc_builtin_macros/messages.ftl index dda466b026d91..07c9b588e1faf 100644 --- a/compiler/rustc_builtin_macros/messages.ftl +++ b/compiler/rustc_builtin_macros/messages.ftl @@ -1,6 +1,8 @@ builtin_macros_alloc_error_must_be_fn = alloc_error_handler must be a function builtin_macros_alloc_must_statics = allocators must be statics +builtin_macros_autodiff = autodiff must be applied to function + builtin_macros_asm_clobber_abi = clobber_abi builtin_macros_asm_clobber_no_reg = asm with `clobber_abi` must specify explicit registers for outputs builtin_macros_asm_clobber_outputs = generic outputs diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs 
b/compiler/rustc_builtin_macros/src/autodiff.rs index 27f321930fa35..5a737c74c5ae1 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -1,5 +1,7 @@ +#![allow(unused)] + use crate::errors; -use crate::util::check_builtin_macro_attribute; +//use crate::util::check_builtin_macro_attribute; use rustc_ast::ptr::P; use rustc_ast::{self as ast, FnHeader, FnSig, Generics, StmtKind}; @@ -8,6 +10,7 @@ use rustc_expand::base::{Annotatable, ExtCtxt}; use rustc_span::symbol::{kw, sym, Ident}; use rustc_span::Span; use thin_vec::{thin_vec, ThinVec}; +use rustc_span::Symbol; pub fn expand( ecx: &mut ExtCtxt<'_>, @@ -18,80 +21,67 @@ pub fn expand( //check_builtin_macro_attribute(ecx, meta_item, sym::alloc_error_handler); //check_builtin_macro_attribute(ecx, meta_item, sym::autodiff); - let orig_item = item.clone(); + dbg!(&meta_item); + let input = item.clone(); + let orig_item: P = item.clone().expect_item(); + let mut d_item: P = item.clone().expect_item(); - // Allow using `#[alloc_error_handler]` on an item statement - // FIXME - if we get deref patterns, use them to reduce duplication here - let (item, is_stmt, sig_span) = if let Annotatable::Item(item) = &item - && let ItemKind::Fn(fn_kind) = &item.kind - { - (item, false, ecx.with_def_site_ctxt(fn_kind.sig.span)) - } else if let Annotatable::Stmt(stmt) = &item - && let StmtKind::Item(item) = &stmt.kind - && let ItemKind::Fn(fn_kind) = &item.kind + // Allow using `#[autodiff(...)]` on a Fn + let (fn_item, _ty_span) = if let Annotatable::Item(item) = &item + && let ItemKind::Fn(box ast::Fn { sig, .. }) = &item.kind { - (item, true, ecx.with_def_site_ctxt(fn_kind.sig.span)) - } else { - ecx.sess.dcx().emit_err(errors::AllocErrorMustBeFn { span: item.span() }); - return vec![orig_item]; - }; - - // Generate a bunch of new items using the AllocFnFactory - let span = ecx.with_def_site_ctxt(item.span); - - // Generate item statements for the allocator methods. 
- let stmts = thin_vec![generate_handler(ecx, item.ident, span, sig_span)]; - - // Generate anonymous constant serving as container for the allocator methods. - let const_ty = ecx.ty(sig_span, TyKind::Tup(ThinVec::new())); - let const_body = ecx.expr_block(ecx.block(span, stmts)); - let const_item = ecx.item_const(span, Ident::new(kw::Underscore, span), const_ty, const_body); - let const_item = if is_stmt { - Annotatable::Stmt(P(ecx.stmt_item(span, const_item))) + dbg!(&item); + (item, ecx.with_def_site_ctxt(sig.span)) } else { - Annotatable::Item(const_item) + ecx.sess + .dcx() + .emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); + return vec![input]; }; - - // Return the original item and the new methods. - vec![orig_item, const_item] + let _x: &ItemKind = &fn_item.kind; + d_item.ident.name = + Symbol::intern(format!("d_{}", fn_item.ident.name).as_str()); + let orig_annotatable = Annotatable::Item(orig_item.clone()); + let d_annotatable = Annotatable::Item(d_item.clone()); + return vec![orig_annotatable, d_annotatable]; } // #[rustc_std_internal_symbol] // unsafe fn __rg_oom(size: usize, align: usize) -> ! 
{ // handler(core::alloc::Layout::from_size_align_unchecked(size, align)) // } -fn generate_handler(cx: &ExtCtxt<'_>, handler: Ident, span: Span, sig_span: Span) -> Stmt { - let usize = cx.path_ident(span, Ident::new(sym::usize, span)); - let ty_usize = cx.ty_path(usize); - let size = Ident::from_str_and_span("size", span); - let align = Ident::from_str_and_span("align", span); - - let layout_new = cx.std_path(&[sym::alloc, sym::Layout, sym::from_size_align_unchecked]); - let layout_new = cx.expr_path(cx.path(span, layout_new)); - let layout = cx.expr_call( - span, - layout_new, - thin_vec![cx.expr_ident(span, size), cx.expr_ident(span, align)], - ); - - let call = cx.expr_call_ident(sig_span, handler, thin_vec![layout]); - - let never = ast::FnRetTy::Ty(cx.ty(span, TyKind::Never)); - let params = thin_vec![cx.param(span, size, ty_usize.clone()), cx.param(span, align, ty_usize)]; - let decl = cx.fn_decl(params, never); - let header = FnHeader { unsafety: Unsafe::Yes(span), ..FnHeader::default() }; - let sig = FnSig { decl, header, span: span }; - - let body = Some(cx.block_expr(call)); - let kind = ItemKind::Fn(Box::new(Fn { - defaultness: ast::Defaultness::Final, - sig, - generics: Generics::default(), - body, - })); - - let attrs = thin_vec![cx.attr_word(sym::rustc_std_internal_symbol, span)]; - - let item = cx.item(span, Ident::from_str_and_span("__rg_oom", span), attrs, kind); - cx.stmt_item(sig_span, item) -} +//fn generate_handler(cx: &ExtCtxt<'_>, handler: Ident, span: Span, sig_span: Span) -> Stmt { +// let usize = cx.path_ident(span, Ident::new(sym::usize, span)); +// let ty_usize = cx.ty_path(usize); +// let size = Ident::from_str_and_span("size", span); +// let align = Ident::from_str_and_span("align", span); +// +// let layout_new = cx.std_path(&[sym::alloc, sym::Layout, sym::from_size_align_unchecked]); +// let layout_new = cx.expr_path(cx.path(span, layout_new)); +// let layout = cx.expr_call( +// span, +// layout_new, +// 
thin_vec![cx.expr_ident(span, size), cx.expr_ident(span, align)], +// ); +// +// let call = cx.expr_call_ident(sig_span, handler, thin_vec![layout]); +// +// let never = ast::FnRetTy::Ty(cx.ty(span, TyKind::Never)); +// let params = thin_vec![cx.param(span, size, ty_usize.clone()), cx.param(span, align, ty_usize)]; +// let decl = cx.fn_decl(params, never); +// let header = FnHeader { unsafety: Unsafe::Yes(span), ..FnHeader::default() }; +// let sig = FnSig { decl, header, span: span }; +// +// let body = Some(cx.block_expr(call)); +// let kind = ItemKind::Fn(Box::new(Fn { +// defaultness: ast::Defaultness::Final, +// sig, +// generics: Generics::default(), +// body, +// })); +// +// let attrs = thin_vec![cx.attr_word(sym::rustc_std_internal_symbol, span)]; +// +// let item = cx.item(span, Ident::from_str_and_span("__rg_oom", span), attrs, kind); +// cx.stmt_item(sig_span, item) +//} diff --git a/compiler/rustc_builtin_macros/src/errors.rs b/compiler/rustc_builtin_macros/src/errors.rs index e07eb2e490b71..7c032d98d5d8d 100644 --- a/compiler/rustc_builtin_macros/src/errors.rs +++ b/compiler/rustc_builtin_macros/src/errors.rs @@ -164,6 +164,13 @@ pub(crate) struct AllocMustStatics { pub(crate) span: Span, } +#[derive(Diagnostic)] +#[diag(builtin_macros_autodiff)] +pub(crate) struct AutoDiffInvalidApplication { + #[primary_span] + pub(crate) span: Span, +} + #[derive(Diagnostic)] #[diag(builtin_macros_concat_bytes_invalid)] pub(crate) struct ConcatBytesInvalid { From 175d236df0137b8bcf3d272218434d18990f0084 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Tue, 26 Dec 2023 01:30:22 -0500 Subject: [PATCH 003/100] cleanup, first helper --- .../rustc_ast/src/expand/autodiff_attrs.rs | 156 ++++++++++++++++++ compiler/rustc_ast/src/expand/typetree.rs | 68 ++++++++ compiler/rustc_builtin_macros/src/autodiff.rs | 1 + compiler/rustc_passes/messages.ftl | 4 + compiler/rustc_passes/src/check_attr.rs | 13 ++ compiler/rustc_passes/src/errors.rs | 8 + 6 files changed, 250 
insertions(+) create mode 100644 compiler/rustc_ast/src/expand/autodiff_attrs.rs create mode 100644 compiler/rustc_ast/src/expand/typetree.rs diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs new file mode 100644 index 0000000000000..2126868751b21 --- /dev/null +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -0,0 +1,156 @@ +use super::typetree::TypeTree; +use std::str::FromStr; +use rustc_data_structures::stable_hasher::{HashStable, StableHasher};//, StableOrd}; +use crate::HashStableContext; + +#[allow(dead_code)] +#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug)] +pub enum DiffMode { + Inactive, + Source, + Forward, + Reverse, +} + +#[allow(dead_code)] +#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug)] +pub enum DiffActivity { + None, + Active, + Const, + Duplicated, + DuplicatedNoNeed, +} +fn clause_diffactivity_discriminant(value: &DiffActivity) -> usize { + match value { + DiffActivity::None => 0, + DiffActivity::Active => 1, + DiffActivity::Const => 2, + DiffActivity::Duplicated => 3, + DiffActivity::DuplicatedNoNeed => 4, + } +} +fn clause_diffmode_discriminant(value: &DiffMode) -> usize { + match value { + DiffMode::Inactive => 0, + DiffMode::Source => 1, + DiffMode::Forward => 2, + DiffMode::Reverse => 3, + } +} + + +impl HashStable for DiffMode { + fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { + clause_diffmode_discriminant(self).hash_stable(hcx, hasher); + } +} + +impl HashStable for DiffActivity { + fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { + clause_diffactivity_discriminant(self).hash_stable(hcx, hasher); + } +} + + +impl FromStr for DiffActivity { + type Err = (); + + fn from_str(s: &str) -> Result { + match s { + "None" => Ok(DiffActivity::None), + "Active" => Ok(DiffActivity::Active), + "Const" => Ok(DiffActivity::Const), + "Duplicated" => Ok(DiffActivity::Duplicated), + "DuplicatedNoNeed" => 
Ok(DiffActivity::DuplicatedNoNeed), + _ => Err(()), + } + } +} + +#[allow(dead_code)] +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug)] +pub struct AutoDiffAttrs { + pub mode: DiffMode, + pub ret_activity: DiffActivity, + pub input_activity: Vec, +} + +impl HashStable for AutoDiffAttrs { + fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { + self.mode.hash_stable(hcx, hasher); + self.ret_activity.hash_stable(hcx, hasher); + self.input_activity.hash_stable(hcx, hasher); + } +} + +impl AutoDiffAttrs { + pub fn inactive() -> Self { + AutoDiffAttrs { + mode: DiffMode::Inactive, + ret_activity: DiffActivity::None, + input_activity: Vec::new(), + } + } + + pub fn is_active(&self) -> bool { + match self.mode { + DiffMode::Inactive => false, + _ => { + dbg!(&self); + true + }, + } + } + + pub fn is_source(&self) -> bool { + dbg!(&self); + match self.mode { + DiffMode::Source => true, + _ => false, + } + } + pub fn apply_autodiff(&self) -> bool { + match self.mode { + DiffMode::Inactive => false, + DiffMode::Source => false, + _ => { + dbg!(&self); + true + }, + } + } + + pub fn into_item( + self, + source: String, + target: String, + inputs: Vec, + output: TypeTree, + ) -> AutoDiffItem { + dbg!(&self); + AutoDiffItem { source, target, inputs, output, attrs: self } + } +} + +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] +pub struct AutoDiffItem { + pub source: String, + pub target: String, + pub attrs: AutoDiffAttrs, + pub inputs: Vec, + pub output: TypeTree, +} + +//impl HashStable for AutoDiffItem { +// fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { +// self.source.hash_stable(hcx, hasher); +// self.target.hash_stable(hcx, hasher); +// self.attrs.hash_stable(hcx, hasher); +// for tt in &self.inputs { +// tt.0.hash_stable(hcx, hasher); +// } +// //self.inputs.hash_stable(hcx, hasher); +// self.output.0.hash_stable(hcx, hasher); +// } +//} diff --git a/compiler/rustc_ast/src/expand/typetree.rs 
b/compiler/rustc_ast/src/expand/typetree.rs new file mode 100644 index 0000000000000..4b154650a4a0e --- /dev/null +++ b/compiler/rustc_ast/src/expand/typetree.rs @@ -0,0 +1,68 @@ +use std::fmt; +//use rustc_data_structures::stable_hasher::{HashStable};//, StableHasher}; +//use crate::HashStableContext; + + +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] +pub enum Kind { + Anything, + Integer, + Pointer, + Half, + Float, + Double, + Unknown, +} +//impl HashStable for Kind { +// fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { +// clause_kind_discriminant(self).hash_stable(hcx, hasher); +// } +//} +//fn clause_kind_discriminant(value: &Kind) -> usize { +// match value { +// Kind::Anything => 0, +// Kind::Integer => 1, +// Kind::Pointer => 2, +// Kind::Half => 3, +// Kind::Float => 4, +// Kind::Double => 5, +// Kind::Unknown => 6, +// } +//} + +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] +pub struct TypeTree(pub Vec); + +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] +pub struct Type { + pub offset: isize, + pub size: usize, + pub kind: Kind, + pub child: TypeTree, +} + +//impl HashStable for Type { +// fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { +// self.offset.hash_stable(hcx, hasher); +// self.size.hash_stable(hcx, hasher); +// self.kind.hash_stable(hcx, hasher); +// self.child.0.hash_stable(hcx, hasher); +// } +//} + +impl Type { + pub fn add_offset(self, add: isize) -> Self { + let offset = match self.offset { + -1 => add, + x => add + x, + }; + + Self { size: self.size, kind: self.kind, child: self.child, offset } + } +} + +impl fmt::Display for Type { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + ::fmt(self, f) + } +} diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 5a737c74c5ae1..e5f4ee160c3bf 100644 --- 
a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -2,6 +2,7 @@ use crate::errors; //use crate::util::check_builtin_macro_attribute; +//use crate::util::check_autodiff; use rustc_ast::ptr::P; use rustc_ast::{self as ast, FnHeader, FnSig, Generics, StmtKind}; diff --git a/compiler/rustc_passes/messages.ftl b/compiler/rustc_passes/messages.ftl index be50aad13032f..5a8fdc8c9e301 100644 --- a/compiler/rustc_passes/messages.ftl +++ b/compiler/rustc_passes/messages.ftl @@ -13,6 +13,10 @@ passes_abi_ne = passes_abi_of = fn_abi_of({$fn_name}) = {$fn_abi} +passes_autodiff_attr = + `#[autodiff]` should be applied to a function + .label = not a function + passes_allow_incoherent_impl = `rustc_allow_incoherent_impl` attribute should be applied to impl items. .label = the only currently supported targets are inherent methods diff --git a/compiler/rustc_passes/src/check_attr.rs b/compiler/rustc_passes/src/check_attr.rs index c5073048be3e8..5e6c3ca9963d4 100644 --- a/compiler/rustc_passes/src/check_attr.rs +++ b/compiler/rustc_passes/src/check_attr.rs @@ -232,6 +232,7 @@ impl CheckAttrVisitor<'_> { self.check_generic_attr(hir_id, attr, target, Target::Fn); self.check_proc_macro(hir_id, target, ProcMacroKind::Derive) } + sym::autodiff => self.check_autodiff(hir_id, attr, span, target), _ => {} } @@ -2382,6 +2383,18 @@ impl CheckAttrVisitor<'_> { self.abort.set(true); } } + + /// Checks if `#[autodiff]` is applied to an item other than a function item. 
+ fn check_autodiff(&self, _hir_id: HirId, _attr: &Attribute, span: Span, target: Target) { + dbg!("check_autodiff"); + match target { + Target::Fn => {} + _ => { + self.tcx.sess.emit_err(errors::AutoDiffAttr { attr_span: span }); + self.abort.set(true); + } + } + } } impl<'tcx> Visitor<'tcx> for CheckAttrVisitor<'tcx> { diff --git a/compiler/rustc_passes/src/errors.rs b/compiler/rustc_passes/src/errors.rs index 856256a064149..fb8cb98346e14 100644 --- a/compiler/rustc_passes/src/errors.rs +++ b/compiler/rustc_passes/src/errors.rs @@ -24,6 +24,14 @@ pub struct IncorrectDoNotRecommendLocation { pub span: Span, } +#[derive(Diagnostic)] +#[diag(passes_autodiff_attr)] +pub struct AutoDiffAttr { + #[primary_span] + #[label] + pub attr_span: Span, +} + #[derive(LintDiagnostic)] #[diag(passes_outer_crate_level_attr)] pub struct OuterCrateLevelAttr; From 095eabd4c3cd3e115bc0440dba2ef1a1766e48eb Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Thu, 4 Jan 2024 11:06:18 -0500 Subject: [PATCH 004/100] adding Enzyme submodule --- .gitmodules | 4 ++++ src/tools/enzyme | 1 + 2 files changed, 5 insertions(+) create mode 160000 src/tools/enzyme diff --git a/.gitmodules b/.gitmodules index 9bb68b37081f5..52517890060b8 100644 --- a/.gitmodules +++ b/.gitmodules @@ -43,3 +43,7 @@ path = library/backtrace url = https://github.com/rust-lang/backtrace-rs.git shallow = true +[submodule "src/tools/enzyme"] + path = src/tools/enzyme + url = https://github.com/enzymeAD/enzyme + shallow = true diff --git a/src/tools/enzyme b/src/tools/enzyme new file mode 160000 index 0000000000000..5422797090b89 --- /dev/null +++ b/src/tools/enzyme @@ -0,0 +1 @@ +Subproject commit 5422797090b89c1e22e836eb74a852de544febc1 From c3b73ce5dd1d45293f6043684280ff041989b046 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Thu, 4 Jan 2024 11:25:10 -0500 Subject: [PATCH 005/100] adding extra error message --- compiler/rustc_ast/src/expand/mod.rs | 2 ++ compiler/rustc_codegen_llvm/messages.ftl | 3 +++ 
compiler/rustc_codegen_llvm/src/errors.rs | 9 +++++++++ 3 files changed, 14 insertions(+) diff --git a/compiler/rustc_ast/src/expand/mod.rs b/compiler/rustc_ast/src/expand/mod.rs index 942347383ce31..b8434374a3594 100644 --- a/compiler/rustc_ast/src/expand/mod.rs +++ b/compiler/rustc_ast/src/expand/mod.rs @@ -5,6 +5,8 @@ use rustc_span::{def_id::DefId, symbol::Ident}; use crate::MetaItem; pub mod allocator; +pub mod typetree; +pub mod autodiff_attrs; #[derive(Debug, Clone, Encodable, Decodable, HashStable_Generic)] pub struct StrippedCfgItem { diff --git a/compiler/rustc_codegen_llvm/messages.ftl b/compiler/rustc_codegen_llvm/messages.ftl index 7a86ddc7556a0..bcf4d0096d5f3 100644 --- a/compiler/rustc_codegen_llvm/messages.ftl +++ b/compiler/rustc_codegen_llvm/messages.ftl @@ -60,6 +60,9 @@ codegen_llvm_prepare_thin_lto_module_with_llvm_err = failed to prepare thin LTO codegen_llvm_run_passes = failed to run LLVM passes codegen_llvm_run_passes_with_llvm_err = failed to run LLVM passes: {$llvm_err} +codegen_llvm_prepare_autodiff = failed to prepare AutoDiff: src: {$src}, target: {$target}, {$error} +codegen_llvm_prepare_autodiff_with_llvm_err = failed to prepare AutoDiff: {$llvm_err}, src: {$src}, target: {$target}, {$error} + codegen_llvm_sanitizer_memtag_requires_mte = `-Zsanitizer=memtag` requires `-Ctarget-feature=+mte` diff --git a/compiler/rustc_codegen_llvm/src/errors.rs b/compiler/rustc_codegen_llvm/src/errors.rs index 8db97d577ca76..ea0c55f92250e 100644 --- a/compiler/rustc_codegen_llvm/src/errors.rs +++ b/compiler/rustc_codegen_llvm/src/errors.rs @@ -182,6 +182,12 @@ pub enum LlvmError<'a> { PrepareThinLtoModule, #[diag(codegen_llvm_parse_bitcode)] ParseBitcode, + #[diag(codegen_llvm_prepare_autodiff)] + PrepareAutoDiff { + src: String, + target: String, + error: String, + }, } pub(crate) struct WithLlvmError<'a>(pub LlvmError<'a>, pub String); @@ -203,6 +209,9 @@ impl IntoDiagnostic<'_, G> for WithLlvmError<'_> { } PrepareThinLtoModule => 
fluent::codegen_llvm_prepare_thin_lto_module_with_llvm_err, ParseBitcode => fluent::codegen_llvm_parse_bitcode_with_llvm_err, + PrepareAutoDiff { .. } => { + fluent::codegen_llvm_prepare_autodiff_with_llvm_err + } }; let mut diag = self.0.into_diagnostic(dcx, level); diag.set_primary_message(msg_with_llvm_err); From e4e3430ea3fde744ed94ace4f949a81bbbe22481 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Fri, 5 Jan 2024 06:08:38 -0500 Subject: [PATCH 006/100] adding ffi, not building enzyme yet --- compiler/rustc_codegen_llvm/src/lib.rs | 4 + compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 492 ++++++++++++++++++++ compiler/rustc_codegen_llvm/src/typetree.rs | 33 ++ 3 files changed, 529 insertions(+) create mode 100644 compiler/rustc_codegen_llvm/src/typetree.rs diff --git a/compiler/rustc_codegen_llvm/src/lib.rs b/compiler/rustc_codegen_llvm/src/lib.rs index a81056ed3ad62..9726140f12b50 100644 --- a/compiler/rustc_codegen_llvm/src/lib.rs +++ b/compiler/rustc_codegen_llvm/src/lib.rs @@ -29,6 +29,10 @@ use back::owned_target_machine::OwnedTargetMachine; use back::write::{create_informational_target_machine, create_target_machine}; use errors::ParseTargetMachineConfig; + +#[allow(unused_imports)] +use llvm::TypeTree; + pub use llvm_util::target_features; use rustc_ast::expand::allocator::AllocatorKind; use rustc_codegen_ssa::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule}; diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 81702baa8c053..e479229271d95 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -1,6 +1,9 @@ #![allow(non_camel_case_types)] #![allow(non_upper_case_globals)] +#[allow(unused_imports)] +use rustc_ast::expand::autodiff_attrs::DiffActivity; + use super::debuginfo::{ DIArray, DIBasicType, DIBuilder, DICompositeType, DIDerivedType, DIDescriptor, DIEnumerator, DIFile, DIFlags, DIGlobalVariableExpression, DILexicalBlock, 
DILocation, DINameSpace, @@ -11,6 +14,8 @@ use super::debuginfo::{ use libc::{c_char, c_int, c_uint, size_t}; use libc::{c_ulonglong, c_void}; +use core::fmt; +use std::ffi::{CStr, CString}; use std::marker::PhantomData; use super::RustString; @@ -832,7 +837,184 @@ pub type SelfProfileAfterPassCallback = unsafe extern "C" fn(*mut c_void); pub type GetSymbolsCallback = unsafe extern "C" fn(*mut c_void, *const c_char) -> *mut c_void; pub type GetSymbolsErrorCallback = unsafe extern "C" fn(*const c_char) -> *mut c_void; +#[repr(C)] +pub enum LLVMVerifierFailureAction { + LLVMAbortProcessAction, + LLVMPrintMessageAction, + LLVMReturnStatusAction, +} + +#[allow(dead_code)] +pub(crate) unsafe fn enzyme_rust_forward_diff( + logic_ref: EnzymeLogicRef, + type_analysis: EnzymeTypeAnalysisRef, + fnc: &Value, + input_diffactivity: Vec, + ret_diffactivity: DiffActivity, + mut ret_primary_ret: bool, + input_tts: Vec, + output_tt: TypeTree, +) -> &Value { + let ret_activity = cdiffe_from(ret_diffactivity); + assert!(ret_activity != CDIFFE_TYPE::DFT_OUT_DIFF); + let mut input_activity: Vec = vec![]; + for input in input_diffactivity { + let act = cdiffe_from(input); + assert!(act == CDIFFE_TYPE::DFT_CONSTANT || act == CDIFFE_TYPE::DFT_DUP_ARG || act == CDIFFE_TYPE::DFT_DUP_NONEED); + input_activity.push(act); + } + + if ret_activity == CDIFFE_TYPE::DFT_DUP_ARG { + if ret_primary_ret != true { + dbg!("overwriting ret_primary_ret!"); + } + ret_primary_ret = true; + } else if ret_activity == CDIFFE_TYPE::DFT_DUP_NONEED { + if ret_primary_ret != false { + dbg!("overwriting ret_primary_ret!"); + } + ret_primary_ret = false; + } + + let mut args_tree = input_tts.iter().map(|x| x.inner).collect::>(); + //let mut args_tree = vec![TypeTree::new().inner; typetree.input_tt.len()]; + + // We don't support volatile / extern / (global?) values. + // Just because I didn't have time to test them, and it seems less urgent. 
+ let args_uncacheable = vec![0; input_activity.len()]; + + let kv_tmp = IntList { data: std::ptr::null_mut(), size: 0 }; + + let mut known_values = vec![kv_tmp; input_activity.len()]; + + let dummy_type = CFnTypeInfo { + Arguments: args_tree.as_mut_ptr(), + Return: output_tt.inner.clone(), + KnownValues: known_values.as_mut_ptr(), + }; + + EnzymeCreateForwardDiff( + logic_ref, // Logic + std::ptr::null(), + std::ptr::null(), + fnc, + ret_activity, // LLVM function, return type + input_activity.as_ptr(), + input_activity.len(), // constant arguments + type_analysis, // type analysis struct + ret_primary_ret as u8, + CDerivativeMode::DEM_ForwardMode, // return value, dret_used, top_level which was 1 + 1, // free memory + 1, // vector mode width + Option::None, + dummy_type, // additional_arg, type info (return + args) + args_uncacheable.as_ptr(), + args_uncacheable.len(), // uncacheable arguments + std::ptr::null_mut(), // write augmented function to this + ) +} + +#[allow(dead_code)] +pub(crate) unsafe fn enzyme_rust_reverse_diff( + logic_ref: EnzymeLogicRef, + type_analysis: EnzymeTypeAnalysisRef, + fnc: &Value, + input_activity: Vec, + ret_activity: DiffActivity, + mut ret_primary_ret: bool, + diff_primary_ret: bool, + input_tts: Vec, + output_tt: TypeTree, +) -> &Value { + let ret_activity = cdiffe_from(ret_activity); + assert!(ret_activity == CDIFFE_TYPE::DFT_CONSTANT || ret_activity == CDIFFE_TYPE::DFT_OUT_DIFF); + let input_activity: Vec = input_activity.iter().map(|&x| cdiffe_from(x)).collect(); + + dbg!(&fnc); + + if ret_activity == CDIFFE_TYPE::DFT_DUP_ARG { + if ret_primary_ret != true { + dbg!("overwriting ret_primary_ret!"); + } + ret_primary_ret = true; + } else if ret_activity == CDIFFE_TYPE::DFT_DUP_NONEED { + if ret_primary_ret != false { + dbg!("overwriting ret_primary_ret!"); + } + ret_primary_ret = false; + } + + let mut args_tree = input_tts.iter().map(|x| x.inner).collect::>(); + + // We don't support volatile / extern / (global?) values. 
+ // Just because I didn't had time to test them, and it seems less urgent. + let args_uncacheable = vec![0; input_tts.len()]; + assert!(args_uncacheable.len() == input_activity.len()); + let num_fnc_args = LLVMCountParams(fnc); + println!("num_fnc_args: {}", num_fnc_args); + println!("input_activity.len(): {}", input_activity.len()); + assert!(num_fnc_args == input_activity.len() as u32); + let kv_tmp = IntList { data: std::ptr::null_mut(), size: 0 }; + + + let mut known_values = vec![kv_tmp; input_tts.len()]; + + let dummy_type = CFnTypeInfo { + Arguments: args_tree.as_mut_ptr(), + Return: output_tt.inner.clone(), + KnownValues: known_values.as_mut_ptr(), + }; + + let res = EnzymeCreatePrimalAndGradient( + logic_ref, // Logic + std::ptr::null(), + std::ptr::null(), + fnc, + ret_activity, // LLVM function, return type + input_activity.as_ptr(), + input_activity.len(), // constant arguments + type_analysis, // type analysis struct + ret_primary_ret as u8, + diff_primary_ret as u8, //0 + CDerivativeMode::DEM_ReverseModeCombined, // return value, dret_used, top_level which was 1 + 1, // vector mode width + 1, // free memory + Option::None, + 0, // do not force anonymous tape + dummy_type, // additional_arg, type info (return + args) + args_uncacheable.as_ptr(), + args_uncacheable.len(), // uncacheable arguments + std::ptr::null_mut(), // write augmented function to this + 0, + ); + dbg!(&res); + res +} + extern "C" { + // Enzyme + pub fn LLVMGetReturnType(T: &Type) -> &Type; + pub fn LLVMDumpModule(M: &Module); + pub fn LLVMCountStructElementTypes(T: &Type) -> c_uint; + pub fn LLVMDeleteFunction(V: &Value); + pub fn LLVMVerifyFunction(V: &Value, action: LLVMVerifierFailureAction) -> bool; + pub fn LLVMGetParams(Fnc: &Value, parms: *mut &Value); + pub fn LLVMBuildCall2<'a>( + arg1: &Builder<'a>, + ty: &Type, + func: &Value, + args: *mut &Value, + num_args: size_t, + name: *const c_char, + ) -> &'a Value; + pub fn LLVMGetBasicBlockTerminator(B: &BasicBlock) -> &Value; 
+ pub fn LLVMAddFunction<'a>(M: &Module, Name: *const c_char, Ty: &Type) -> &'a Value; + pub fn LLVMGetFirstFunction(M: &Module) -> Option<&Value>; + pub fn LLVMGetNextFunction(V: &Value) -> Option<&Value>; + pub fn LLVMGetNamedFunction(M: &Module, Name: *const c_char) -> Option<&Value>; + pub fn LLVMGlobalGetValueType(val: &Value) -> &Type; + pub fn LLVMRustGetFunctionType(fnc: &Value) -> &Type; + // Create and destroy contexts. pub fn LLVMContextDispose(C: &'static mut Context); pub fn LLVMGetMDKindIDInContext(C: &Context, Name: *const c_char, SLen: c_uint) -> c_uint; @@ -996,6 +1178,16 @@ extern "C" { Value: *const c_char, ValueLen: c_uint, ) -> &Attribute; + pub fn LLVMRemoveStringAttributeAtIndex(F : &Value, Idx: c_uint, K: *const c_char, KLen : c_uint); + pub fn LLVMGetStringAttributeAtIndex(F : &Value, Idx: c_uint, K: *const c_char, KLen : c_uint) -> &Attribute; + pub fn LLVMAddAttributeAtIndex(F : &Value, Idx: c_uint, K: &Attribute); + pub fn LLVMRemoveEnumAttributeAtIndex(F : &Value, Idx: c_uint, K: Attribute); + pub fn LLVMGetEnumAttributeAtIndex(F : &Value, Idx: c_uint, K: Attribute) -> &Attribute; + pub fn LLVMIsEnumAttribute(A : &Attribute) -> bool; + pub fn LLVMIsStringAttribute(A : &Attribute) -> bool; + pub fn LLVMRustAddEnumAttributeAtIndex(C: &Context, V: &Value, index: c_uint, attr: AttributeKind); + pub fn LLVMRustRemoveEnumAttributeAtIndex(V: &Value, index: c_uint, attr: AttributeKind); + pub fn LLVMRustGetEnumAttributeAtIndex(V: &Value, index: c_uint, attr: AttributeKind) ->&Attribute; // Operations on functions pub fn LLVMSetFunctionCallConv(Fn: &Value, CC: c_uint); @@ -2405,3 +2597,303 @@ extern "C" { error_callback: GetSymbolsErrorCallback, ) -> *mut c_void; } + +// Enzyme +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct EnzymeOpaqueTypeAnalysis { + _unused: [u8; 0], +} +pub type EnzymeTypeAnalysisRef = *mut EnzymeOpaqueTypeAnalysis; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct EnzymeOpaqueLogic { + _unused: [u8; 0], +} +pub 
type EnzymeLogicRef = *mut EnzymeOpaqueLogic; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct EnzymeOpaqueAugmentedReturn { + _unused: [u8; 0], +} +pub type EnzymeAugmentedReturnPtr = *mut EnzymeOpaqueAugmentedReturn; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct IntList { + pub data: *mut i64, + pub size: size_t, +} +#[repr(u32)] +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub enum CConcreteType { + DT_Anything = 0, + DT_Integer = 1, + DT_Pointer = 2, + DT_Half = 3, + DT_Float = 4, + DT_Double = 5, + DT_Unknown = 6, +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct EnzymeTypeTree { + _unused: [u8; 0], +} +pub type CTypeTreeRef = *mut EnzymeTypeTree; +extern "C" { + fn EnzymeNewTypeTree() -> CTypeTreeRef; +} +extern "C" { + fn EnzymeFreeTypeTree(CTT: CTypeTreeRef); +} +extern "C" { + pub fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8); +} +extern "C" { + pub fn EnzymeSetCLInteger(arg1: *mut ::std::os::raw::c_void, arg2: i64); +} + +extern "C" { + pub static mut MaxIntOffset: c_void; + pub static mut MaxTypeOffset: c_void; + pub static mut EnzymeMaxTypeDepth: c_void; + + pub static mut EnzymePrintPerf: c_void; + pub static mut EnzymePrintActivity: c_void; + pub static mut EnzymePrintType: c_void; + pub static mut EnzymePrint: c_void; + pub static mut EnzymeStrictAliasing: c_void; +} + +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct CFnTypeInfo { + #[doc = " Types of arguments, assumed of size len(Arguments)"] + pub Arguments: *mut CTypeTreeRef, + #[doc = " Type of return"] + pub Return: CTypeTreeRef, + #[doc = " The specific constant(s) known to represented by an argument, if constant"] + pub KnownValues: *mut IntList, +} +#[repr(u32)] +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub enum CDIFFE_TYPE { + DFT_OUT_DIFF = 0, + DFT_DUP_ARG = 1, + DFT_CONSTANT = 2, + DFT_DUP_NONEED = 3, +} + +fn cdiffe_from(act: DiffActivity) -> CDIFFE_TYPE { + return match act { + DiffActivity::None => 
CDIFFE_TYPE::DFT_CONSTANT, + DiffActivity::Active => CDIFFE_TYPE::DFT_OUT_DIFF, + DiffActivity::Const => CDIFFE_TYPE::DFT_CONSTANT, + DiffActivity::Duplicated => CDIFFE_TYPE::DFT_DUP_ARG, + DiffActivity::DuplicatedNoNeed => CDIFFE_TYPE::DFT_DUP_NONEED, + }; +} + +#[repr(u32)] +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub enum CDerivativeMode { + DEM_ForwardMode = 0, + DEM_ReverseModePrimal = 1, + DEM_ReverseModeGradient = 2, + DEM_ReverseModeCombined = 3, + DEM_ForwardModeSplit = 4, +} +extern "C" { + #[allow(dead_code)] + fn EnzymeCreatePrimalAndGradient<'a>( + arg1: EnzymeLogicRef, + _builderCtx: *const u8, // &'a Builder<'_>, + _callerCtx: *const u8,// &'a Value, + todiff: &'a Value, + retType: CDIFFE_TYPE, + constant_args: *const CDIFFE_TYPE, + constant_args_size: size_t, + TA: EnzymeTypeAnalysisRef, + returnValue: u8, + dretUsed: u8, + mode: CDerivativeMode, + width: ::std::os::raw::c_uint, + freeMemory: u8, + additionalArg: Option<&Type>, + forceAnonymousTape: u8, + typeInfo: CFnTypeInfo, + _uncacheable_args: *const u8, + uncacheable_args_size: size_t, + augmented: EnzymeAugmentedReturnPtr, + AtomicAdd: u8, + ) -> &'a Value; + //) -> LLVMValueRef; +} +extern "C" { + fn EnzymeCreateForwardDiff<'a>( + arg1: EnzymeLogicRef, + _builderCtx: *const u8,// &'a Builder<'_>, + _callerCtx: *const u8,// &'a Value, + todiff: &'a Value, + retType: CDIFFE_TYPE, + constant_args: *const CDIFFE_TYPE, + constant_args_size: size_t, + TA: EnzymeTypeAnalysisRef, + returnValue: u8, + mode: CDerivativeMode, + freeMemory: u8, + width: ::std::os::raw::c_uint, + additionalArg: Option<&Type>, + typeInfo: CFnTypeInfo, + _uncacheable_args: *const u8, + uncacheable_args_size: size_t, + augmented: EnzymeAugmentedReturnPtr, + ) -> &'a Value; +} +pub type CustomRuleType = ::std::option::Option< + unsafe extern "C" fn( + direction: ::std::os::raw::c_int, + ret: CTypeTreeRef, + args: *mut CTypeTreeRef, + known_values: *mut IntList, + num_args: size_t, + fnc: &Value, + ta: *const 
::std::os::raw::c_void, + ) -> u8, +>; +extern "C" { + pub fn CreateTypeAnalysis( + Log: EnzymeLogicRef, + customRuleNames: *mut *mut ::std::os::raw::c_char, + customRules: *mut CustomRuleType, + numRules: size_t, + ) -> EnzymeTypeAnalysisRef; +} +extern "C" { + pub fn ClearTypeAnalysis(arg1: EnzymeTypeAnalysisRef); +} +extern "C" { + pub fn FreeTypeAnalysis(arg1: EnzymeTypeAnalysisRef); +} +extern "C" { + pub fn CreateEnzymeLogic(PostOpt: u8) -> EnzymeLogicRef; +} +extern "C" { + pub fn ClearEnzymeLogic(arg1: EnzymeLogicRef); +} +extern "C" { + pub fn FreeEnzymeLogic(arg1: EnzymeLogicRef); +} + +extern "C" { + fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef; + fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef; + fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool; + fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64); + fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef); + fn EnzymeTypeTreeShiftIndiciesEq( + arg1: CTypeTreeRef, + data_layout: *const c_char, + offset: i64, + max_size: i64, + add_offset: u64, + ); + fn EnzymeTypeTreeToStringFree(arg1: *const c_char); + fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char; +} + +pub struct TypeTree { + pub inner: CTypeTreeRef, +} + +impl TypeTree { + pub fn new() -> TypeTree { + let inner = unsafe { EnzymeNewTypeTree() }; + + TypeTree { inner } + } + + #[must_use] + pub fn from_type(t: CConcreteType, ctx: &Context) -> TypeTree { + let inner = unsafe { EnzymeNewTypeTreeCT(t, ctx) }; + + TypeTree { inner } + } + + #[must_use] + pub fn only(self, idx: isize) -> TypeTree { + unsafe { + EnzymeTypeTreeOnlyEq(self.inner, idx as i64); + } + self + } + + #[must_use] + pub fn data0(self) -> TypeTree { + unsafe { + EnzymeTypeTreeData0Eq(self.inner); + } + self + } + + pub fn merge(self, other: Self) -> Self { + unsafe { + EnzymeMergeTypeTree(self.inner, other.inner); + } + drop(other); + + self + } + + #[must_use] + pub fn shift(self, layout: &str, offset: isize, 
max_size: isize, add_offset: usize) -> Self { + let layout = CString::new(layout).unwrap(); + + unsafe { + EnzymeTypeTreeShiftIndiciesEq( + self.inner, + layout.as_ptr(), + offset as i64, + max_size as i64, + add_offset as u64, + ) + } + + self + } +} + +impl Clone for TypeTree { + fn clone(&self) -> Self { + let inner = unsafe { EnzymeNewTypeTreeTR(self.inner) }; + TypeTree { inner } + } +} + +impl fmt::Display for TypeTree { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let ptr = unsafe { EnzymeTypeTreeToString(self.inner) }; + let cstr = unsafe { CStr::from_ptr(ptr) }; + match cstr.to_str() { + Ok(x) => write!(f, "{}", x)?, + Err(err) => write!(f, "could not parse: {}", err)?, + } + + // delete C string pointer + unsafe { EnzymeTypeTreeToStringFree(ptr) } + + Ok(()) + } +} + +impl fmt::Debug for TypeTree { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + ::fmt(self, f) + } +} + +impl Drop for TypeTree { + fn drop(&mut self) { + unsafe { EnzymeFreeTypeTree(self.inner) } + } +} diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs new file mode 100644 index 0000000000000..c45f5ec5e7005 --- /dev/null +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -0,0 +1,33 @@ +use crate::llvm; +use rustc_ast::expand::typetree::{Kind, TypeTree}; + +pub fn to_enzyme_typetree( + tree: TypeTree, + llvm_data_layout: &str, + llcx: &llvm::Context, +) -> llvm::TypeTree { + tree.0.iter().fold(llvm::TypeTree::new(), |obj, x| { + let scalar = match x.kind { + Kind::Integer => llvm::CConcreteType::DT_Integer, + Kind::Float => llvm::CConcreteType::DT_Float, + Kind::Double => llvm::CConcreteType::DT_Double, + Kind::Pointer => llvm::CConcreteType::DT_Pointer, + _ => panic!("Unknown kind {:?}", x.kind), + }; + + let tt = llvm::TypeTree::from_type(scalar, llcx).only(-1); + + let tt = if !x.child.0.is_empty() { + let inner_tt = to_enzyme_typetree(x.child.clone(), llvm_data_layout, llcx); + 
tt.merge(inner_tt.only(-1)) + } else { + tt + }; + + if x.offset != -1 { + obj.merge(tt.shift(llvm_data_layout, 0, x.size as isize, x.offset as usize)) + } else { + obj.merge(tt) + } + }) +} From 51e0b65840806509cfb38b9b4c4c2589444f1cdc Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Fri, 5 Jan 2024 11:03:13 -0500 Subject: [PATCH 007/100] Build and link Enzyme during bootstrap --- src/bootstrap/src/core/build_steps/compile.rs | 21 ++++++ src/bootstrap/src/core/build_steps/llvm.rs | 66 +++++++++++++++++++ src/bootstrap/src/core/builder.rs | 4 ++ src/bootstrap/src/lib.rs | 4 ++ 4 files changed, 95 insertions(+) diff --git a/src/bootstrap/src/core/build_steps/compile.rs b/src/bootstrap/src/core/build_steps/compile.rs index df4d1a43dabc7..dbef8cb1eb3a5 100644 --- a/src/bootstrap/src/core/build_steps/compile.rs +++ b/src/bootstrap/src/core/build_steps/compile.rs @@ -1576,6 +1576,7 @@ pub struct Assemble { pub target_compiler: Compiler, } +#[allow(unreachable_code)] impl Step for Assemble { type Output = Compiler; const ONLY_HOSTS: bool = true; @@ -1636,6 +1637,26 @@ impl Step for Assemble { return target_compiler; } + // Build enzyme + let enzyme_install = + Some(builder.ensure(llvm::Enzyme { target: build_compiler.host })); + //let enzyme_install = if builder.config.llvm_enzyme { + // Some(builder.ensure(llvm::Enzyme { target: build_compiler.host })) + //} else { + // None + //}; + + if let Some(enzyme_install) = enzyme_install { + let src_lib = enzyme_install.join("build/Enzyme/LLVMEnzyme-17.so"); + + let libdir = builder.sysroot_libdir(build_compiler, build_compiler.host); + let target_libdir = builder.sysroot_libdir(target_compiler, target_compiler.host); + let dst_lib = libdir.join("libLLVMEnzyme-17.so"); + let target_dst_lib = target_libdir.join("libLLVMEnzyme-17.so"); + builder.copy(&src_lib, &dst_lib); + builder.copy(&src_lib, &target_dst_lib); + } + // Build the libraries for this compiler to link to (i.e., the libraries // it uses at runtime). 
NOTE: Crates the target compiler compiles don't // link to these. (FIXME: Is that correct? It seems to be correct most diff --git a/src/bootstrap/src/core/build_steps/llvm.rs b/src/bootstrap/src/core/build_steps/llvm.rs index 4b2d3e9ab4b75..ed21a092cd7a1 100644 --- a/src/bootstrap/src/core/build_steps/llvm.rs +++ b/src/bootstrap/src/core/build_steps/llvm.rs @@ -817,6 +817,72 @@ fn get_var(var_base: &str, host: &str, target: &str) -> Option { .or_else(|| env::var_os(var_base)) } +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub struct Enzyme { + pub target: TargetSelection, +} + +impl Step for Enzyme { + type Output = PathBuf; + const ONLY_HOSTS: bool = true; + + fn should_run(run: ShouldRun<'_>) -> ShouldRun<'_> { + run.path("src/tools/enzyme/enzyme") + } + + fn make_run(run: RunConfig<'_>) { + run.builder.ensure(Enzyme { target: run.target }); + } + + /// Compile Enzyme for `target`. + fn run(self, builder: &Builder<'_>) -> PathBuf { + if builder.config.dry_run() { + let out_dir = builder.enzyme_out(self.target); + return out_dir; + } + let target = self.target; + + let LlvmResult { llvm_config, .. } = builder.ensure(Llvm { target: self.target }); + + let out_dir = builder.enzyme_out(target); + let done_stamp = out_dir.join("enzyme-finished-building"); + if done_stamp.exists() { + return out_dir; + } + + builder.info(&format!("Building Enzyme for {}", target)); + let _time = helpers::timeit(&builder); + t!(fs::create_dir_all(&out_dir)); + + builder.update_submodule(&Path::new("src").join("tools").join("enzyme")); + let mut cfg = cmake::Config::new(builder.src.join("src/tools/enzyme/enzyme/")); + // TODO: Find a nicer way to use Enzyme Debug builds + //cfg.profile("Debug"); + //cfg.define("CMAKE_BUILD_TYPE", "Debug"); + configure_cmake(builder, target, &mut cfg, true, LdFlags::default(), &[]); + + // Re-use the same flags as llvm to control the level of debug information + // generated for lld. 
+ let profile = match (builder.config.llvm_optimize, builder.config.llvm_release_debuginfo) { + (false, _) => "Debug", + (true, false) => "Release", + (true, true) => "RelWithDebInfo", + }; + + cfg.out_dir(&out_dir) + .profile(profile) + .env("LLVM_CONFIG_REAL", &llvm_config) + .define("LLVM_ENABLE_ASSERTIONS", "ON") + .define("ENZYME_EXTERNAL_SHARED_LIB", "OFF") + .define("LLVM_DIR", builder.llvm_out(target)); + + cfg.build(); + + t!(File::create(&done_stamp)); + out_dir + } +} + #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] pub struct Lld { pub target: TargetSelection, diff --git a/src/bootstrap/src/core/builder.rs b/src/bootstrap/src/core/builder.rs index 753b41abaf489..b8e10a1f9aa85 100644 --- a/src/bootstrap/src/core/builder.rs +++ b/src/bootstrap/src/core/builder.rs @@ -1416,6 +1416,10 @@ impl<'a> Builder<'a> { rustflags.arg(sysroot_str); } + // https://rust-lang.zulipchat.com/#narrow/stream/182449-t-compiler.2Fhelp/topic/.E2.9C.94.20link.20new.20library.20into.20stage1.2Frustc + rustflags.arg("-l"); + rustflags.arg("LLVMEnzyme-17"); + let use_new_symbol_mangling = match self.config.rust_new_symbol_mangling { Some(setting) => { // If an explicit setting is given, use that diff --git a/src/bootstrap/src/lib.rs b/src/bootstrap/src/lib.rs index 871318de5955e..3909115140bca 100644 --- a/src/bootstrap/src/lib.rs +++ b/src/bootstrap/src/lib.rs @@ -799,6 +799,10 @@ impl Build { self.out.join(&*target.triple).join("lld") } + fn enzyme_out(&self, target: TargetSelection) -> PathBuf { + self.out.join(&*target.triple).join("enzyme") + } + /// Output directory for all documentation for a target fn doc_out(&self, target: TargetSelection) -> PathBuf { self.out.join(&*target.triple).join("doc") From f253e672783877ff2bd5f0dd6128f93565a41a6d Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sun, 7 Jan 2024 20:21:44 -0500 Subject: [PATCH 008/100] update rustc_middle, rustc_monomorphize, and (partly) cg_ssa --- .../src/coverageinfo/mapgen.rs | 2 +- 
.../src/assert_module_sources.rs | 2 +- .../src/back/symbol_export.rs | 2 +- compiler/rustc_codegen_ssa/src/base.rs | 4 +- .../rustc_codegen_ssa/src/codegen_attrs.rs | 175 ++++++++++++- compiler/rustc_middle/src/arena.rs | 1 + compiler/rustc_middle/src/query/erase.rs | 4 + compiler/rustc_middle/src/query/mod.rs | 10 +- compiler/rustc_middle/src/ty/mod.rs | 2 + compiler/rustc_monomorphize/Cargo.toml | 2 + compiler/rustc_monomorphize/src/collector.rs | 2 +- .../rustc_monomorphize/src/partitioning.rs | 233 +++++++++++++++++- compiler/rustc_resolve/src/lib.rs | 1 + 13 files changed, 425 insertions(+), 15 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/coverageinfo/mapgen.rs b/compiler/rustc_codegen_llvm/src/coverageinfo/mapgen.rs index 33bfde03a31c3..3fb4eaea16b01 100644 --- a/compiler/rustc_codegen_llvm/src/coverageinfo/mapgen.rs +++ b/compiler/rustc_codegen_llvm/src/coverageinfo/mapgen.rs @@ -409,7 +409,7 @@ fn add_unused_functions(cx: &CodegenCx<'_, '_>) { /// All items participating in code generation together with (instrumented) /// items inlined into them. 
fn codegenned_and_inlined_items(tcx: TyCtxt<'_>) -> DefIdSet { - let (items, cgus) = tcx.collect_and_partition_mono_items(()); + let (items, _, cgus) = tcx.collect_and_partition_mono_items(()); let mut visited = DefIdSet::default(); let mut result = items.clone(); diff --git a/compiler/rustc_codegen_ssa/src/assert_module_sources.rs b/compiler/rustc_codegen_ssa/src/assert_module_sources.rs index a5bd10ecb34b5..d182ed88e4d10 100644 --- a/compiler/rustc_codegen_ssa/src/assert_module_sources.rs +++ b/compiler/rustc_codegen_ssa/src/assert_module_sources.rs @@ -46,7 +46,7 @@ pub fn assert_module_sources(tcx: TyCtxt<'_>, set_reuse: &dyn Fn(&mut CguReuseTr } let available_cgus = - tcx.collect_and_partition_mono_items(()).1.iter().map(|cgu| cgu.name()).collect(); + tcx.collect_and_partition_mono_items(()).2.iter().map(|cgu| cgu.name()).collect(); let mut ams = AssertModuleSource { tcx, diff --git a/compiler/rustc_codegen_ssa/src/back/symbol_export.rs b/compiler/rustc_codegen_ssa/src/back/symbol_export.rs index 54b523cb6bd54..defbff19c7778 100644 --- a/compiler/rustc_codegen_ssa/src/back/symbol_export.rs +++ b/compiler/rustc_codegen_ssa/src/back/symbol_export.rs @@ -331,7 +331,7 @@ fn exported_symbols_provider_local( // external linkage is enough for monomorphization to be linked to. let need_visibility = tcx.sess.target.dynamic_linking && !tcx.sess.target.only_cdylib; - let (_, cgus) = tcx.collect_and_partition_mono_items(()); + let (_, _, cgus) = tcx.collect_and_partition_mono_items(()); for (mono_item, data) in cgus.iter().flat_map(|cgu| cgu.items().iter()) { if data.linkage != Linkage::External { diff --git a/compiler/rustc_codegen_ssa/src/base.rs b/compiler/rustc_codegen_ssa/src/base.rs index b87f4b6bf8907..776467c73174e 100644 --- a/compiler/rustc_codegen_ssa/src/base.rs +++ b/compiler/rustc_codegen_ssa/src/base.rs @@ -594,7 +594,7 @@ pub fn codegen_crate( // Run the monomorphization collector and partition the collected items into // codegen units. 
- let codegen_units = tcx.collect_and_partition_mono_items(()).1; + let codegen_units = tcx.collect_and_partition_mono_items(()).2; // Force all codegen_unit queries so they are already either red or green // when compile_codegen_unit accesses them. We are not able to re-execute @@ -980,7 +980,7 @@ pub fn provide(providers: &mut Providers) { config::OptLevel::SizeMin => config::OptLevel::Default, }; - let (defids, _) = tcx.collect_and_partition_mono_items(cratenum); + let (defids, _, _) = tcx.collect_and_partition_mono_items(cratenum); let any_for_speed = defids.items().any(|id| { let CodegenFnAttrs { optimize, .. } = tcx.codegen_fn_attrs(*id); diff --git a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs index e529956b1baf6..4a86909ce83b3 100644 --- a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs +++ b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs @@ -1,4 +1,5 @@ -use rustc_ast::{ast, attr, MetaItemKind, NestedMetaItem}; +use rustc_ast::{ast, attr, MetaItem, MetaItemKind, NestedMetaItem}; +use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode}; use rustc_attr::{list_contains_name, InlineAttr, InstructionSetAttr, OptimizeAttr}; use rustc_errors::struct_span_err; use rustc_hir as hir; @@ -14,6 +15,8 @@ use rustc_span::symbol::Ident; use rustc_span::{sym, Span}; use rustc_target::spec::{abi, SanitizerSet}; +use std::str::FromStr; + use crate::errors; use crate::target_features::from_target_feature; use crate::{ @@ -689,6 +692,174 @@ fn check_link_name_xor_ordinal( } } +fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { + //let attrs = tcx.get_attrs(id, sym::autodiff_into); + let attrs = tcx.get_attrs(id, sym::autodiff); + + let attrs = attrs + .into_iter() + .filter(|attr| attr.name_or_empty() == sym::autodiff) + //.filter(|attr| attr.name_or_empty() == sym::autodiff_into) + .collect::>(); + if attrs.len() > 0 { + dbg!("autodiff_attrs len = > 0: {}", attrs.len()); + } + + // 
check for exactly one autodiff attribute on extern block + let msg_once = "autodiff attribute can only be applied once"; + let attr = match &attrs[..] { + &[] => return AutoDiffAttrs::inactive(), + &[elm] => elm, + x => { + tcx.sess + .struct_span_err(x[1].span, msg_once) + .span_label(x[1].span, "more than one") + .emit(); + + return AutoDiffAttrs::inactive(); + } + }; + + let list = attr.meta_item_list().unwrap_or_default(); + + // empty autodiff attribute macros (i.e. `#[autodiff]`) are used to mark source functions + if list.len() == 0 { + return AutoDiffAttrs { + mode: DiffMode::Source, + ret_activity: DiffActivity::None, + input_activity: Vec::new(), + }; + } + + let msg_ad_mode = "autodiff attribute must contain autodiff mode"; + let mode = match &list[0] { + NestedMetaItem::MetaItem(MetaItem { path: ref p2, kind: MetaItemKind::Word, .. }) => { + p2.segments.first().unwrap().ident + } + _ => { + tcx.sess + .struct_span_err(attr.span, msg_ad_mode) + .span_label(attr.span, "empty argument list") + .emit(); + + return AutoDiffAttrs::inactive(); + } + }; + + // parse mode + let msg_mode = "mode should be either forward or reverse"; + let mode = match mode.as_str() { + //map(|x| x.as_str()) { + "Forward" => DiffMode::Forward, + "Reverse" => DiffMode::Reverse, + _ => { + tcx.sess + .struct_span_err(attr.span, msg_mode) + .span_label(attr.span, "invalid mode") + .emit(); + + return AutoDiffAttrs::inactive(); + } + }; + + let msg_ret_activity = "autodiff attribute must contain the return activity"; + let ret_symbol = match &list[1] { + NestedMetaItem::MetaItem(MetaItem { path: ref p2, kind: MetaItemKind::Word, .. 
}) => { + p2.segments.first().unwrap().ident + } + _ => { + tcx.sess + .struct_span_err(attr.span, msg_ret_activity) + .span_label(attr.span, "missing return activity") + .emit(); + + return AutoDiffAttrs::inactive(); + } + }; + + let msg_unknown_ret_activity = "unknown return activity"; + let ret_activity = match DiffActivity::from_str(ret_symbol.as_str()) { + Ok(x) => x, + Err(_) => { + tcx.sess + .struct_span_err(attr.span, msg_unknown_ret_activity) + .span_label(attr.span, "invalid return activity") + .emit(); + + return AutoDiffAttrs::inactive(); + } + }; + + let msg_arg_activity = "autodiff attribute must contain the return activity"; + let mut arg_activities: Vec = vec![]; + for arg in &list[2..] { + let arg_symbol = match arg { + NestedMetaItem::MetaItem(MetaItem { + path: ref p2, kind: MetaItemKind::Word, .. + }) => p2.segments.first().unwrap().ident, + _ => { + tcx.sess + .struct_span_err( + attr.span, msg_arg_activity, + ) + .span_label(attr.span, "missing return activity") + .emit(); + + return AutoDiffAttrs::inactive(); + } + }; + + match DiffActivity::from_str(arg_symbol.as_str()) { + Ok(arg_activity) => arg_activities.push(arg_activity), + Err(_) => { + tcx.sess + .struct_span_err(attr.span, msg_unknown_ret_activity) + .span_label(attr.span, "invalid input activity") + .emit(); + + return AutoDiffAttrs::inactive(); + } + } + } + + let msg_fwd_incompatible_ret = "Forward Mode is incompatible with Active ret"; + let msg_fwd_incompatible_arg = "Forward Mode is incompatible with Active ret"; + let msg_rev_incompatible_arg = "Reverse Mode is only compatible with Active, None, or Const ret"; + if mode == DiffMode::Forward { + if ret_activity == DiffActivity::Active { + tcx.sess + .struct_span_err(attr.span, msg_fwd_incompatible_ret) + .span_label(attr.span, "invalid return activity") + .emit(); + return AutoDiffAttrs::inactive(); + } + if arg_activities.iter().filter(|&x| *x == DiffActivity::Active).count() > 0 { + tcx.sess + .struct_span_err(attr.span, 
msg_fwd_incompatible_arg) + .span_label(attr.span, "invalid input activity") + .emit(); + return AutoDiffAttrs::inactive(); + } + } + + if mode == DiffMode::Reverse { + if ret_activity == DiffActivity::Duplicated + || ret_activity == DiffActivity::DuplicatedNoNeed + { + tcx.sess + .struct_span_err( + attr.span, msg_rev_incompatible_arg, + ) + .span_label(attr.span, "invalid return activity") + .emit(); + return AutoDiffAttrs::inactive(); + } + } + + AutoDiffAttrs { mode, ret_activity, input_activity: arg_activities } +} + pub fn provide(providers: &mut Providers) { - *providers = Providers { codegen_fn_attrs, should_inherit_track_caller, ..*providers }; + *providers = + Providers { codegen_fn_attrs, should_inherit_track_caller, autodiff_attrs, ..*providers }; } diff --git a/compiler/rustc_middle/src/arena.rs b/compiler/rustc_middle/src/arena.rs index 52fd494a10db0..92cc1adcc7245 100644 --- a/compiler/rustc_middle/src/arena.rs +++ b/compiler/rustc_middle/src/arena.rs @@ -88,6 +88,7 @@ macro_rules! arena_types { [] upvars_mentioned: rustc_data_structures::fx::FxIndexMap, [] object_safety_violations: rustc_middle::traits::ObjectSafetyViolation, [] codegen_unit: rustc_middle::mir::mono::CodegenUnit<'tcx>, + [] autodiff_item: rustc_ast::expand::autodiff_attrs::AutoDiffItem, [decode] attribute: rustc_ast::Attribute, [] name_set: rustc_data_structures::unord::UnordSet, [] ordered_name_set: rustc_data_structures::fx::FxIndexSet, diff --git a/compiler/rustc_middle/src/query/erase.rs b/compiler/rustc_middle/src/query/erase.rs index b9200f1abf161..b9476a7ee92d6 100644 --- a/compiler/rustc_middle/src/query/erase.rs +++ b/compiler/rustc_middle/src/query/erase.rs @@ -200,6 +200,10 @@ impl EraseType for (&'_ T0, &'_ [T1]) { type Result = [u8; size_of::<(&'static (), &'static [()])>()]; } +impl EraseType for (&'_ T0, &'_ [T1], &'_ [T2]) { + type Result = [u8; size_of::<(&'static (), &'static [()], &'static [()])>()]; +} + macro_rules! trivial { ($($ty:ty),+ $(,)?) 
=> { $( diff --git a/compiler/rustc_middle/src/query/mod.rs b/compiler/rustc_middle/src/query/mod.rs index 3a54f5f6b3d01..6524c6ea00e8a 100644 --- a/compiler/rustc_middle/src/query/mod.rs +++ b/compiler/rustc_middle/src/query/mod.rs @@ -55,6 +55,7 @@ use crate::ty::{GenericArg, GenericArgsRef}; use rustc_arena::TypedArena; use rustc_ast as ast; use rustc_ast::expand::{allocator::AllocatorKind, StrippedCfgItem}; +use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem}; use rustc_attr as attr; use rustc_data_structures::fingerprint::Fingerprint; use rustc_data_structures::fx::{FxHashMap, FxIndexMap, FxIndexSet}; @@ -1227,6 +1228,13 @@ rustc_queries! { separate_provide_extern } + /// The list autodiff extern functions in current crate + query autodiff_attrs(def_id: DefId) -> &'tcx AutoDiffAttrs { + desc { |tcx| "computing autodiff attributes of `{}`", tcx.def_path_str(def_id) } + arena_cache + cache_on_disk_if { def_id.is_local() } + } + query asm_target_features(def_id: DefId) -> &'tcx FxIndexSet { desc { |tcx| "computing target features for inline asm of `{}`", tcx.def_path_str(def_id) } } @@ -1880,7 +1888,7 @@ rustc_queries! { separate_provide_extern } - query collect_and_partition_mono_items(_: ()) -> (&'tcx DefIdSet, &'tcx [CodegenUnit<'tcx>]) { + query collect_and_partition_mono_items(_: ()) -> (&'tcx DefIdSet, &'tcx [AutoDiffItem], &'tcx [CodegenUnit<'tcx>]) { eval_always desc { "collect_and_partition_mono_items" } } diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 36b8d387d6966..a4abdf84fc0ce 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -195,6 +195,8 @@ pub struct ResolverAstLowering { pub node_id_to_def_id: NodeMap, pub def_id_to_node_id: IndexVec, + /// Mapping of autodiff function IDs + pub autodiff_map: FxHashMap, pub trait_map: NodeMap>, /// List functions and methods for which lifetime elision was successful. 
pub lifetime_elision_allowed: FxHashSet, diff --git a/compiler/rustc_monomorphize/Cargo.toml b/compiler/rustc_monomorphize/Cargo.toml index c7f1b9fa78454..bce8d9c4a9878 100644 --- a/compiler/rustc_monomorphize/Cargo.toml +++ b/compiler/rustc_monomorphize/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [dependencies] # tidy-alphabetical-start +rustc_ast = { path = "../rustc_ast" } rustc_data_structures = { path = "../rustc_data_structures" } rustc_errors = { path = "../rustc_errors" } rustc_fluent_macro = { path = "../rustc_fluent_macro" } @@ -13,6 +14,7 @@ rustc_macros = { path = "../rustc_macros" } rustc_middle = { path = "../rustc_middle" } rustc_session = { path = "../rustc_session" } rustc_span = { path = "../rustc_span" } +rustc_symbol_mangling = { path = "../rustc_symbol_mangling" } rustc_target = { path = "../rustc_target" } serde = "1" serde_json = "1" diff --git a/compiler/rustc_monomorphize/src/collector.rs b/compiler/rustc_monomorphize/src/collector.rs index a68bfcd06d505..644afb0251407 100644 --- a/compiler/rustc_monomorphize/src/collector.rs +++ b/compiler/rustc_monomorphize/src/collector.rs @@ -206,7 +206,7 @@ pub enum MonoItemCollectionMode { pub struct UsageMap<'tcx> { // Maps every mono item to the mono items used by it. - used_map: FxHashMap, Vec>>, + pub used_map: FxHashMap, Vec>>, // Maps every mono item to the mono items that use it. user_map: FxHashMap, Vec>>, diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs index d47d3e5e7d3b5..fac9a215eb674 100644 --- a/compiler/rustc_monomorphize/src/partitioning.rs +++ b/compiler/rustc_monomorphize/src/partitioning.rs @@ -92,12 +92,15 @@ //! source-level module, functions from the same module will be available for //! inlining, even when they are not marked `#[inline]`. 
+ use std::cmp; use std::collections::hash_map::Entry; use std::fs::{self, File}; use std::io::{BufWriter, Write}; use std::path::{Path, PathBuf}; +use rustc_ast::expand::typetree::{Kind, Type, TypeTree}; +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_data_structures::fx::{FxHashMap, FxHashSet}; use rustc_data_structures::sync; use rustc_hir::def::DefKind; @@ -111,10 +114,12 @@ use rustc_middle::mir::mono::{ }; use rustc_middle::query::Providers; use rustc_middle::ty::print::{characteristic_def_id_of_type, with_no_trimmed_paths}; -use rustc_middle::ty::{self, visit::TypeVisitableExt, InstanceDef, TyCtxt}; +use rustc_middle::ty::{self, visit::TypeVisitableExt, InstanceDef, TyCtxt, ParamEnv, ParamEnvAnd, Adt, Ty}; use rustc_session::config::{DumpMonoStatsFormat, SwitchWithOptPath}; use rustc_session::CodegenUnits; use rustc_span::symbol::Symbol; +use rustc_symbol_mangling::symbol_name_for_instance_in_crate; +use rustc_target::abi::FieldsShape; use crate::collector::UsageMap; use crate::collector::{self, MonoItemCollectionMode}; @@ -248,7 +253,17 @@ where &mut can_be_internalized, export_generics, ); - if visibility == Visibility::Hidden && can_be_internalized { + + //if visibility == Visibility::Hidden && can_be_internalized { + let autodiff_active = characteristic_def_id + .map(|x| cx.tcx.autodiff_attrs(x).is_active()) + .unwrap_or(false); + if autodiff_active { + dbg!("place_mono_items: autodiff_active"); + dbg!(&mono_item); + } + + if !autodiff_active && visibility == Visibility::Hidden && can_be_internalized { internalization_candidates.insert(mono_item); } let size_estimate = mono_item.size_estimate(cx.tcx); @@ -1084,7 +1099,7 @@ where } } -fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[CodegenUnit<'_>]) { +fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[AutoDiffItem], &[CodegenUnit<'_>]) { let collection_mode = match tcx.sess.opts.unstable_opts.print_mono_items { Some(ref s) => { let 
mode = s.to_lowercase(); @@ -1143,6 +1158,68 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[Co }) .collect(); + let autodiff_items = items + .iter() + .filter_map(|item| match *item { + MonoItem::Fn(ref instance) => Some((item, instance)), + _ => None, + }) + .filter_map(|(item, instance)| { + let target_id = instance.def_id(); + let target_attrs = tcx.autodiff_attrs(target_id); + if !target_attrs.apply_autodiff() { + return None; + } + //println!("target_id: {:?}", target_id); + + let target_symbol = + symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE); + //let range = usage_map.used_map.get(&item).unwrap(); + //TODO: check if last and next line are correct after rebasing + + println!("target_symbol: {:?}", target_symbol); + println!("target_attrs: {:?}", target_attrs); + println!("target_id: {:?}", target_id); + //print item + println!("item: {:?}", item); + let source = usage_map.used_map.get(&item).unwrap() + .into_iter() + .filter_map(|item| match *item { + MonoItem::Fn(ref instance_s) => { + let source_id = instance_s.def_id(); + println!("source_id_inner: {:?}", source_id); + println!("instance_s: {:?}", instance_s); + + if tcx.autodiff_attrs(source_id).is_active() { + println!("source_id is active"); + return Some(instance_s); + } + //target_symbol: "_ZN14rosenbrock_rev12d_rosenbrock17h3352c4f00c3082daE" + //target_attrs: AutoDiffAttrs { mode: Reverse, ret_activity: Active, input_activity: [Duplicated] } + //target_id: DefId(0:8 ~ rosenbrock_rev[2708]::d_rosenbrock) + //item: Fn(Instance { def: Item(DefId(0:8 ~ rosenbrock_rev[2708]::d_rosenbrock)), args: [] }) + //source_id_inner: DefId(0:4 ~ rosenbrock_rev[2708]::main) + //instance_s: Instance { def: Item(DefId(0:4 ~ rosenbrock_rev[2708]::main)), args: [] } + + + None + } + _ => None, + }) + .next(); + println!("source: {:?}", source); + + source.map(|inst| { + println!("source_id: {:?}", inst.def_id()); + let (inputs, output) = 
fnc_typetrees(inst.ty(tcx, ParamEnv::empty()), tcx); + let symb = symbol_name_for_instance_in_crate(tcx, inst.clone(), LOCAL_CRATE); + + target_attrs.clone().into_item(symb, target_symbol, inputs, output) + }) + }); + + let autodiff_items = tcx.arena.alloc_from_iter(autodiff_items); + // Output monomorphization stats per def_id if let SwitchWithOptPath::Enabled(ref path) = tcx.sess.opts.unstable_opts.dump_mono_stats { if let Err(err) = @@ -1202,8 +1279,152 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[Co println!("MONO_ITEM {item}"); } } + if autodiff_items.len() > 0 { + println!("AUTODIFF ITEMS EXIST"); + for item in &mut *autodiff_items { + dbg!(&item); + } + } + + (tcx.arena.alloc(mono_items), autodiff_items, codegen_units) +} + +pub fn typetree_empty() -> TypeTree { + TypeTree(vec![]) +} + +pub fn typetree_from_ty<'a>(ty: Ty<'a>, tcx: TyCtxt<'a>, depth: usize) -> TypeTree { + if ty.is_unsafe_ptr() || ty.is_ref() || ty.is_box() { + if ty.is_fn_ptr() { + unimplemented!("what to do whith fn ptr?"); + } + let inner_ty = ty.builtin_deref(true).unwrap().ty; + let child = typetree_from_ty(inner_ty, tcx, depth + 1); + let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child }; + return TypeTree(vec![tt]); + } + + if ty.is_scalar() { + assert!(!ty.is_any_ptr()); + let (kind, size) = if ty.is_integral() { + (Kind::Integer, 8) + } else { + assert!(ty.is_floating_point()); + match ty { + x if x == tcx.types.f32 => (Kind::Float, 4), + x if x == tcx.types.f64 => (Kind::Double, 8), + _ => panic!("floatTy scalar that is neither f32 nor f64"), + } + }; + return TypeTree(vec![Type { offset: -1, child: typetree_empty(), kind, size }]); + } + + let param_env_and = ParamEnvAnd { param_env: ParamEnv::empty(), value: ty }; + + let layout = tcx.layout_of(param_env_and); + assert!(layout.is_ok()); + + let layout = layout.unwrap().layout; + let fields = layout.fields(); + let max_size = layout.size(); + + if ty.is_adt() { + let adt_def = 
ty.ty_adt_def().unwrap(); + let substs = match ty.kind() { + Adt(_, subst_ref) => subst_ref, + _ => panic!(""), + }; + + if adt_def.is_struct() { + let (offsets, _memory_index) = match fields { + FieldsShape::Arbitrary { offsets: o, memory_index: m } => (o, m), + _ => panic!(""), + }; + + let fields = adt_def.all_fields(); + let fields = fields + .into_iter() + .zip(offsets.into_iter()) + .filter_map(|(field, offset)| { + let field_ty: Ty<'_> = field.ty(tcx, substs); + let field_ty: Ty<'_> = + tcx.normalize_erasing_regions(ParamEnv::empty(), field_ty); + + if field_ty.is_phantom_data() { + return None; + } + + let mut child = typetree_from_ty(field_ty, tcx, depth + 1).0; + + for c in &mut child { + if c.offset == -1 { + c.offset = offset.bytes() as isize + } else { + c.offset += offset.bytes() as isize; + } + } + + Some(child) + }) + .flatten() + .collect::>(); + + let ret_tt = TypeTree(fields); + return ret_tt; + } else { + unimplemented!("adt that isn't a struct"); + } + } + + if ty.is_array() { + let (stride, count) = match fields { + FieldsShape::Array { stride: s, count: c } => (s, c), + _ => panic!(""), + }; + let byte_stride = stride.bytes_usize(); + let byte_max_size = max_size.bytes_usize(); + + assert!(byte_stride * *count as usize == byte_max_size); + assert!(*count > 0); // return empty TT for empty? 
+ let sub_ty = ty.builtin_index().unwrap(); + let subtt = typetree_from_ty(sub_ty, tcx, depth + 1); + + // calculate size of subtree + let param_env_and = ParamEnvAnd { param_env: ParamEnv::empty(), value: sub_ty }; + let size = tcx.layout_of(param_env_and).unwrap().size.bytes() as usize; + let tt = TypeTree( + std::iter::repeat(subtt) + .take(*count as usize) + .enumerate() + .map(|(idx, x)| x.0.into_iter().map(move |x| x.add_offset((idx * size) as isize))) + .flatten() + .collect(), + ); + + return tt; + } + + if ty.is_slice() { + let sub_ty = ty.builtin_index().unwrap(); + let subtt = typetree_from_ty(sub_ty, tcx, depth + 1); + + return subtt; + } + + typetree_empty() +} + +pub fn fnc_typetrees<'tcx>(fn_ty: Ty<'tcx>, tcx: TyCtxt<'tcx>) -> (Vec, TypeTree) { + let fnc_binder: ty::Binder<'_, ty::FnSig<'_>> = fn_ty.fn_sig(tcx); + + // TODO: verify. + let x: ty::FnSig<'_> = fnc_binder.skip_binder(); + + let inputs = x.inputs().into_iter().map(|x| typetree_from_ty(*x, tcx, 0)).collect(); + + let output = typetree_from_ty(x.output(), tcx, 0); - (tcx.arena.alloc(mono_items), codegen_units) + (inputs, output) } /// Outputs stats about instantiation counts and estimated size, per `MonoItem`'s @@ -1288,12 +1509,12 @@ pub fn provide(providers: &mut Providers) { providers.collect_and_partition_mono_items = collect_and_partition_mono_items; providers.is_codegened_item = |tcx, def_id| { - let (all_mono_items, _) = tcx.collect_and_partition_mono_items(()); + let (all_mono_items, _, _) = tcx.collect_and_partition_mono_items(()); all_mono_items.contains(&def_id) }; providers.codegen_unit = |tcx, name| { - let (_, all) = tcx.collect_and_partition_mono_items(()); + let (_, _, all) = tcx.collect_and_partition_mono_items(()); all.iter() .find(|cgu| cgu.name() == name) .unwrap_or_else(|| panic!("failed to find cgu with name {name:?}")) diff --git a/compiler/rustc_resolve/src/lib.rs b/compiler/rustc_resolve/src/lib.rs index a14f3d494fb4d..13c495b7a4368 100644 --- 
a/compiler/rustc_resolve/src/lib.rs +++ b/compiler/rustc_resolve/src/lib.rs @@ -1539,6 +1539,7 @@ impl<'a, 'tcx> Resolver<'a, 'tcx> { next_node_id: self.next_node_id, node_id_to_def_id: self.node_id_to_def_id, def_id_to_node_id: self.def_id_to_node_id, + autodiff_map: Default::default(), trait_map: self.trait_map, lifetime_elision_allowed: self.lifetime_elision_allowed, lint_buffer: Steal::new(self.lint_buffer), From b39406603951a8e4e355fc6d5a57ebe74529c1f5 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sun, 7 Jan 2024 20:36:14 -0500 Subject: [PATCH 009/100] update the backend, cg_ssa and cg_llvm --- Cargo.lock | 2 + compiler/rustc_codegen_llvm/src/attributes.rs | 3 + compiler/rustc_codegen_llvm/src/back/lto.rs | 2 +- compiler/rustc_codegen_llvm/src/back/write.rs | 315 +++++++++++++++++- compiler/rustc_codegen_llvm/src/context.rs | 4 + compiler/rustc_codegen_llvm/src/lib.rs | 35 +- compiler/rustc_codegen_ssa/src/back/lto.rs | 24 +- compiler/rustc_codegen_ssa/src/back/write.rs | 46 ++- compiler/rustc_codegen_ssa/src/base.rs | 7 +- compiler/rustc_codegen_ssa/src/traits/misc.rs | 1 + .../rustc_codegen_ssa/src/traits/write.rs | 12 + .../rustc_llvm/llvm-wrapper/RustWrapper.cpp | 32 +- 12 files changed, 464 insertions(+), 19 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5d78e29de0e08..847f11573b811 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4261,6 +4261,7 @@ dependencies = [ name = "rustc_monomorphize" version = "0.0.0" dependencies = [ + "rustc_ast", "rustc_data_structures", "rustc_errors", "rustc_fluent_macro", @@ -4269,6 +4270,7 @@ dependencies = [ "rustc_middle", "rustc_session", "rustc_span", + "rustc_symbol_mangling", "rustc_target", "serde", "serde_json", diff --git a/compiler/rustc_codegen_llvm/src/attributes.rs b/compiler/rustc_codegen_llvm/src/attributes.rs index 481741bb1277a..348551ea4e4b0 100644 --- a/compiler/rustc_codegen_llvm/src/attributes.rs +++ b/compiler/rustc_codegen_llvm/src/attributes.rs @@ -294,6 +294,7 @@ pub fn 
from_fn_attrs<'ll, 'tcx>( instance: ty::Instance<'tcx>, ) { let codegen_fn_attrs = cx.tcx.codegen_fn_attrs(instance.def_id()); + let autodiff_attrs = cx.tcx.autodiff_attrs(instance.def_id()); let mut to_add = SmallVec::<[_; 16]>::new(); @@ -311,6 +312,8 @@ pub fn from_fn_attrs<'ll, 'tcx>( let inline = if codegen_fn_attrs.inline == InlineAttr::None && instance.def.requires_inline(cx.tcx) { InlineAttr::Hint + } else if autodiff_attrs.is_active() { + InlineAttr::Never } else { codegen_fn_attrs.inline }; diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index e9e8ade09b77b..55bf6b295d27b 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -714,7 +714,7 @@ pub unsafe fn optimize_thin_module( let llcx = llvm::LLVMRustContextCreate(cgcx.fewer_names); let llmod_raw = parse_module(llcx, module_name, thin_module.data(), &dcx)? as *const _; let mut module = ModuleCodegen { - module_llvm: ModuleLlvm { llmod_raw, llcx, tm: ManuallyDrop::new(tm) }, + module_llvm: ModuleLlvm { llmod_raw, llcx, tm: ManuallyDrop::new(tm), typetrees: Default::default() }, name: thin_module.name().to_string(), kind: ModuleKind::Regular, }; diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 75f99f964d0a9..c726b62912ef1 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -5,18 +5,32 @@ use crate::back::profiling::{ }; use crate::base; use crate::common; +use crate::DiffTypeTree; use crate::errors::{ CopyBitcode, FromLlvmDiag, FromLlvmOptimizationDiag, LlvmError, UnknownCompression, WithLlvmError, WriteBytecode, }; use crate::llvm::{self, DiagnosticInfo, PassManager}; +use crate::llvm::{LLVMReplaceAllUsesWith, LLVMVerifyFunction, Value, LLVMRustGetEnumAttributeAtIndex, LLVMRustAddEnumAttributeAtIndex, LLVMRustRemoveEnumAttributeAtIndex, + enzyme_rust_forward_diff, 
enzyme_rust_reverse_diff, BasicBlock, CreateEnzymeLogic, + CreateTypeAnalysis, EnzymeLogicRef, EnzymeTypeAnalysisRef, LLVMAddFunction, + LLVMAppendBasicBlockInContext, LLVMBuildCall2, LLVMBuildExtractValue, LLVMBuildRet, + LLVMCountParams, LLVMCountStructElementTypes, LLVMCreateBuilderInContext, LLVMDeleteFunction, + LLVMDisposeBuilder, LLVMGetBasicBlockTerminator, LLVMGetElementType, LLVMGetModuleContext, + LLVMGetParams, LLVMGetReturnType, LLVMPositionBuilderAtEnd, LLVMSetValueName2, LLVMTypeOf, + LLVMVoidTypeInContext, LLVMGlobalGetValueType, LLVMGetStringAttributeAtIndex, + LLVMIsStringAttribute, LLVMRemoveStringAttributeAtIndex, AttributeKind, + LLVMGetFirstFunction, LLVMGetNextFunction, LLVMIsEnumAttribute, + LLVMCreateStringAttribute, LLVMRustAddFunctionAttributes, LLVMDumpModule}; use crate::llvm_util; use crate::type_::Type; +use crate::typetree::to_enzyme_typetree; use crate::LlvmCodegenBackend; use crate::ModuleLlvm; use llvm::{ LLVMRustLLVMHasZlibCompressionForDebugSymbols, LLVMRustLLVMHasZstdCompressionForDebugSymbols, }; +use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity, DiffMode}; use rustc_codegen_ssa::back::link::ensure_removed; use rustc_codegen_ssa::back::write::{ BitcodeSection, CodegenContext, EmitObj, ModuleConfig, TargetMachineFactoryConfig, @@ -26,6 +40,7 @@ use rustc_codegen_ssa::traits::*; use rustc_codegen_ssa::{CompiledModule, ModuleCodegen}; use rustc_data_structures::profiling::SelfProfilerRef; use rustc_data_structures::small_c_str::SmallCStr; +use rustc_data_structures::fx::FxHashMap; use rustc_errors::{DiagCtxt, FatalError, Level}; use rustc_fs_util::{link_or_copy, path_to_c_string}; use rustc_middle::ty::TyCtxt; @@ -37,7 +52,7 @@ use rustc_target::spec::{CodeModel, RelocModel, SanitizerSet, SplitDebuginfo, Tl use crate::llvm::diagnostic::OptimizationDiagnosticKind; use libc::{c_char, c_int, c_uint, c_void, size_t}; -use std::ffi::CString; +use std::ffi::{CStr, CString}; use std::fs; use std::io::{self, Write}; use 
std::path::{Path, PathBuf}; @@ -511,8 +526,19 @@ pub(crate) unsafe fn llvm_optimize( opt_level: config::OptLevel, opt_stage: llvm::OptStage, ) -> Result<(), FatalError> { - let unroll_loops = + // Enzyme: + // We want to simplify / optimize functions before AD. + // However, benchmarks show that optimizations increasing the code size + // tend to reduce AD performance. Therefore activate them first, then differentiate the code + // and finally re-optimize the module, now with all optimizations available. + // RIP compile time. + // let unroll_loops = + // opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin; + let _unroll_loops = opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin; + let unroll_loops = false; + let vectorize_slp = false; + let vectorize_loop = false; let using_thin_buffers = opt_stage == llvm::OptStage::PreLinkThinLTO || config.bitcode_needed(); let pgo_gen_path = get_pgo_gen_path(config); let pgo_use_path = get_pgo_use_path(config); @@ -567,8 +593,8 @@ pub(crate) unsafe fn llvm_optimize( using_thin_buffers, config.merge_functions, unroll_loops, - config.vectorize_slp, - config.vectorize_loop, + vectorize_slp, + vectorize_loop, config.emit_lifetime_markers, sanitizer_options.as_ref(), pgo_gen_path.as_ref().map_or(std::ptr::null(), |s| s.as_ptr()), @@ -589,6 +615,261 @@ pub(crate) unsafe fn llvm_optimize( result.into_result().map_err(|()| llvm_err(dcx, LlvmError::RunLlvmPasses)) } +fn get_params(fnc: &Value) -> Vec<&Value> { + unsafe { + let param_num = LLVMCountParams(fnc) as usize; + let mut fnc_args: Vec<&Value> = vec![]; + fnc_args.reserve(param_num); + LLVMGetParams(fnc, fnc_args.as_mut_ptr()); + fnc_args.set_len(param_num); + fnc_args + } +} + +// TODO: Here we could start adding length checks for the shaddow args. 
+unsafe fn create_wrapper<'a>( + llmod: &'a llvm::Module, + fnc: &'a Value, + u_type: &Type, + fnc_name: String, +) -> (&'a Value, &'a BasicBlock, Vec<&'a Value>, Vec<&'a Value>, CString) { + let context = LLVMGetModuleContext(llmod); + let inner_fnc_name = "inner_".to_string() + &fnc_name; + let c_inner_fnc_name = CString::new(inner_fnc_name.clone()).unwrap(); + LLVMSetValueName2(fnc, c_inner_fnc_name.as_ptr(), inner_fnc_name.len() as usize); + + let c_outer_fnc_name = CString::new(fnc_name).unwrap(); + let outer_fnc: &Value = + LLVMAddFunction(llmod, c_outer_fnc_name.as_ptr(), LLVMGetElementType(u_type) as &Type); + + let entry = "fnc_entry".to_string(); + let c_entry = CString::new(entry).unwrap(); + let basic_block = LLVMAppendBasicBlockInContext(context, outer_fnc, c_entry.as_ptr()); + + let outer_params: Vec<&Value> = get_params(outer_fnc); + let inner_params: Vec<&Value> = get_params(fnc); + + (outer_fnc, basic_block, outer_params, inner_params, c_inner_fnc_name) +} + +pub(crate) unsafe fn extract_return_type<'a>( + llmod: &'a llvm::Module, + fnc: &'a Value, + u_type: &Type, + fnc_name: String, +) -> &'a Value { + let context = llvm::LLVMGetModuleContext(llmod); + + let inner_param_num = LLVMCountParams(fnc); + let (outer_fnc, outer_bb, mut outer_args, _inner_args, c_inner_fnc_name) = + create_wrapper(llmod, fnc, u_type, fnc_name); + + if inner_param_num as usize != outer_args.len() { + panic!("Args len shouldn't differ. Please report this."); + } + + let builder = LLVMCreateBuilderInContext(context); + LLVMPositionBuilderAtEnd(builder, outer_bb); + let struct_ret = LLVMBuildCall2( + builder, + u_type, + fnc, + outer_args.as_mut_ptr(), + outer_args.len(), + c_inner_fnc_name.as_ptr(), + ); + // We can use an arbitrary name here, since it will be used to store a tmp value. 
+ let inner_grad_name = "foo".to_string(); + let c_inner_grad_name = CString::new(inner_grad_name).unwrap(); + let struct_ret = LLVMBuildExtractValue(builder, struct_ret, 0, c_inner_grad_name.as_ptr()); + let _ret = LLVMBuildRet(builder, struct_ret); + let _terminator = LLVMGetBasicBlockTerminator(outer_bb); + LLVMDisposeBuilder(builder); + let _fnc_ok = + LLVMVerifyFunction(outer_fnc, llvm::LLVMVerifierFailureAction::LLVMAbortProcessAction); + outer_fnc +} + +// As unsafe as it can be. +#[allow(unused_variables)] +#[allow(unused)] +pub(crate) unsafe fn enzyme_ad( + llmod: &llvm::Module, + llcx: &llvm::Context, + diag_handler: &DiagCtxt, + item: AutoDiffItem, +) -> Result<(), FatalError> { + dbg!("cg_llvm enzyme_ad"); + let autodiff_mode = item.attrs.mode; + let rust_name = item.source; + let rust_name2 = &item.target; + + let args_activity = item.attrs.input_activity.clone(); + let ret_activity: DiffActivity = item.attrs.ret_activity; + + // get target and source function + let name = CString::new(rust_name.to_owned()).unwrap(); + let name2 = CString::new(rust_name2.clone()).unwrap(); + let src_fnc_opt = llvm::LLVMGetNamedFunction(llmod, name.as_c_str().as_ptr()); + let src_fnc = match src_fnc_opt { + Some(x) => x, + None => { + return Err(llvm_err(diag_handler, LlvmError::PrepareAutoDiff{ + src: rust_name.to_owned(), + target: rust_name2.to_owned(), + error: "could not find src function".to_owned(), + })); + } + }; + let target_fnc_opt = llvm::LLVMGetNamedFunction(llmod, name2.as_ptr()); + let target_fnc = match target_fnc_opt { + Some(x) => x, + None => { + return Err(llvm_err(diag_handler, LlvmError::PrepareAutoDiff{ + src: rust_name.to_owned(), + target: rust_name2.to_owned(), + error: "could not find target function".to_owned(), + })); + } + }; + let src_num_args = llvm::LLVMCountParams(src_fnc); + let target_num_args = llvm::LLVMCountParams(target_fnc); + assert!(src_num_args <= target_num_args); + + // create enzyme typetrees + let llvm_data_layout = unsafe 
{ llvm::LLVMGetDataLayoutStr(&*llmod) }; + let llvm_data_layout = + std::str::from_utf8(unsafe { CStr::from_ptr(llvm_data_layout) }.to_bytes()) + .expect("got a non-UTF8 data-layout from LLVM"); + + let input_tts = + item.inputs.into_iter().map(|x| to_enzyme_typetree(x, llvm_data_layout, llcx)).collect(); + let output_tt = to_enzyme_typetree(item.output, llvm_data_layout, llcx); + + let opt = 1; + let ret_primary_ret = false; + let diff_primary_ret = false; + let logic_ref: EnzymeLogicRef = CreateEnzymeLogic(opt as u8); + let type_analysis: EnzymeTypeAnalysisRef = + CreateTypeAnalysis(logic_ref, std::ptr::null_mut(), std::ptr::null_mut(), 0); + + llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymeStrictAliasing), 0); + + if std::env::var("ENZYME_PRINT_TA").is_ok() { + llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrintType), 1); + } + if std::env::var("ENZYME_PRINT_AA").is_ok() { + llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrintActivity), 1); + } + if std::env::var("ENZYME_PRINT_PERF").is_ok() { + llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrintPerf), 1); + } + if std::env::var("ENZYME_PRINT").is_ok() { + llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrint), 1); + } + + let mut res: &Value = match item.attrs.mode { + DiffMode::Forward => enzyme_rust_forward_diff( + logic_ref, + type_analysis, + src_fnc, + args_activity, + ret_activity, + ret_primary_ret, + input_tts, + output_tt, + ), + DiffMode::Reverse => enzyme_rust_reverse_diff( + logic_ref, + type_analysis, + src_fnc, + args_activity, + ret_activity, + ret_primary_ret, + diff_primary_ret, + input_tts, + output_tt, + ), + _ => unreachable!(), + }; + let f_return_type = LLVMGetReturnType(LLVMGlobalGetValueType(res)); + + let void_type = LLVMVoidTypeInContext(llcx); + if item.attrs.mode == DiffMode::Reverse && f_return_type != void_type { + let num_elem_in_ret_struct = LLVMCountStructElementTypes(f_return_type); + if num_elem_in_ret_struct == 1 { + let 
u_type = LLVMTypeOf(target_fnc); + res = extract_return_type(llmod, res, u_type, rust_name2.clone()); // TODO: check if name or name2 + } + } + LLVMSetValueName2(res, name2.as_ptr(), rust_name2.len()); + LLVMReplaceAllUsesWith(target_fnc, res); + LLVMDeleteFunction(target_fnc); + + Ok(()) +} + +pub(crate) unsafe fn differentiate( + module: &ModuleCodegen, + cgcx: &CodegenContext, + diff_items: Vec, + _typetrees: FxHashMap, + _config: &ModuleConfig, +) -> Result<(), FatalError> { + dbg!("cg_llvm differentiate"); + dbg!(&diff_items); + + let llmod = module.module_llvm.llmod(); + let llcx = &module.module_llvm.llcx; + let diag_handler = cgcx.create_dcx(); + + llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymeStrictAliasing), 0); + + if std::env::var("ENZYME_PRINT_MOD").is_ok() { + unsafe {LLVMDumpModule(llmod);} + } + if std::env::var("ENZYME_TT_DEPTH").is_ok() { + let depth = std::env::var("ENZYME_TT_DEPTH").unwrap(); + let depth = depth.parse::().unwrap(); + assert!(depth >= 1); + llvm::EnzymeSetCLInteger(std::ptr::addr_of_mut!(llvm::EnzymeMaxTypeDepth), depth); + } + if std::env::var("ENZYME_TT_WIDTH").is_ok() { + let width = std::env::var("ENZYME_TT_WIDTH").unwrap(); + let width = width.parse::().unwrap(); + assert!(width >= 1); + llvm::EnzymeSetCLInteger(std::ptr::addr_of_mut!(llvm::MaxTypeOffset), width); + } + + for item in diff_items { + let res = enzyme_ad(llmod, llcx, &diag_handler, item); + assert!(res.is_ok()); + } + + let mut f = LLVMGetFirstFunction(llmod); + loop { + if let Some(lf) = f { + f = LLVMGetNextFunction(lf); + let myhwattr = "enzyme_hw"; + let attr = LLVMGetStringAttributeAtIndex(lf, c_uint::MAX, myhwattr.as_ptr() as *const c_char, myhwattr.as_bytes().len() as c_uint); + if LLVMIsStringAttribute(attr) { + LLVMRemoveStringAttributeAtIndex(lf, c_uint::MAX, myhwattr.as_ptr() as *const c_char, myhwattr.as_bytes().len() as c_uint); + } else { + LLVMRustRemoveEnumAttributeAtIndex(lf, c_uint::MAX, AttributeKind::SanitizeHWAddress); + } + + + 
} else { + break; + } + } + if std::env::var("ENZYME_PRINT_MOD_AFTER").is_ok() { + unsafe {LLVMDumpModule(llmod);} + } + + Ok(()) +} + // Unsafe due to LLVM calls. pub(crate) unsafe fn optimize( cgcx: &CodegenContext, @@ -611,6 +892,32 @@ pub(crate) unsafe fn optimize( llvm::LLVMWriteBitcodeToFile(llmod, out.as_ptr()); } + // This code enables Enzyme to differentiate code containing Rust enums. + // By adding the SanitizeHWAddress attribute we prevent LLVM from Optimizing + // away the enums and allows Enzyme to understand why a value can be of different types in + // different code sections. We remove this attribute after Enzyme is done, to not affect the + // rest of the compilation. + { + let mut f = LLVMGetFirstFunction(llmod); + loop { + if let Some(lf) = f { + f = LLVMGetNextFunction(lf); + let myhwattr = "enzyme_hw"; + let myhwv = ""; + let prevattr = LLVMRustGetEnumAttributeAtIndex(lf, c_uint::MAX, AttributeKind::SanitizeHWAddress); + if LLVMIsEnumAttribute(prevattr) { + let attr = LLVMCreateStringAttribute(llcx, myhwattr.as_ptr() as *const c_char, myhwattr.as_bytes().len() as c_uint, myhwv.as_ptr() as *const c_char, myhwv.as_bytes().len() as c_uint); + LLVMRustAddFunctionAttributes(lf, c_uint::MAX, &attr, 1); + } else { + LLVMRustAddEnumAttributeAtIndex(llcx, lf, c_uint::MAX, AttributeKind::SanitizeHWAddress); + } + + } else { + break; + } + } + } + if let Some(opt_level) = config.opt_level { let opt_stage = match cgcx.lto { Lto::Fat => llvm::OptStage::PreLinkFatLTO, diff --git a/compiler/rustc_codegen_llvm/src/context.rs b/compiler/rustc_codegen_llvm/src/context.rs index 3053c4e0daaa9..c58373bf168f6 100644 --- a/compiler/rustc_codegen_llvm/src/context.rs +++ b/compiler/rustc_codegen_llvm/src/context.rs @@ -639,6 +639,10 @@ impl<'ll, 'tcx> MiscMethods<'tcx> for CodegenCx<'ll, 'tcx> { None } } + + fn create_autodiff(&self) -> Vec { + return vec![]; + } } impl<'ll> CodegenCx<'ll, '_> { diff --git a/compiler/rustc_codegen_llvm/src/lib.rs 
b/compiler/rustc_codegen_llvm/src/lib.rs index 9726140f12b50..b52042f55e71d 100644 --- a/compiler/rustc_codegen_llvm/src/lib.rs +++ b/compiler/rustc_codegen_llvm/src/lib.rs @@ -35,6 +35,7 @@ use llvm::TypeTree; pub use llvm_util::target_features; use rustc_ast::expand::allocator::AllocatorKind; +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_codegen_ssa::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule}; use rustc_codegen_ssa::back::write::{ CodegenContext, FatLtoInput, ModuleConfig, TargetMachineFactoryConfig, TargetMachineFactoryFn, @@ -42,7 +43,7 @@ use rustc_codegen_ssa::back::write::{ use rustc_codegen_ssa::traits::*; use rustc_codegen_ssa::ModuleCodegen; use rustc_codegen_ssa::{CodegenResults, CompiledModule}; -use rustc_data_structures::fx::FxIndexMap; +use rustc_data_structures::fx::{FxHashMap, FxIndexMap}; use rustc_errors::{DiagCtxt, ErrorGuaranteed, FatalError}; use rustc_metadata::EncodedMetadata; use rustc_middle::dep_graph::{WorkProduct, WorkProductId}; @@ -80,6 +81,7 @@ mod debuginfo; mod declare; mod errors; mod intrinsic; +mod typetree; // The following is a workaround that replaces `pub mod llvm;` and that fixes issue 53912. 
#[path = "llvm/mod.rs"] @@ -175,6 +177,7 @@ impl WriteBackendMethods for LlvmCodegenBackend { type TargetMachineError = crate::errors::LlvmError<'static>; type ThinData = back::lto::ThinData; type ThinBuffer = back::lto::ThinBuffer; + type TypeTree = DiffTypeTree; fn print_pass_timings(&self) { unsafe { let mut size = 0; @@ -257,6 +260,23 @@ impl WriteBackendMethods for LlvmCodegenBackend { fn serialize_module(module: ModuleCodegen) -> (String, Self::ModuleBuffer) { (module.name, back::lto::ModuleBuffer::new(module.module_llvm.llmod())) } + /// Generate autodiff rules + fn autodiff( + cgcx: &CodegenContext, + module: &ModuleCodegen, + diff_fncs: Vec, + typetrees: FxHashMap, + config: &ModuleConfig, + ) -> Result<(), FatalError> { + dbg!("cg_llvm autodiff"); + unsafe { back::write::differentiate(module, cgcx, diff_fncs, typetrees, config) } + } + + // The typetrees contain all information, their order therefore is irrelevant. + #[allow(rustc::potential_query_instability)] + fn typetrees(module: &mut Self::Module) -> FxHashMap { + module.typetrees.drain().collect() + } } unsafe impl Send for LlvmCodegenBackend {} // Llvm is on a per-thread basis @@ -409,6 +429,13 @@ impl CodegenBackend for LlvmCodegenBackend { } } +#[derive(Clone, Debug)] +pub struct DiffTypeTree { + pub ret_tt: TypeTree, + pub input_tt: Vec, +} + +#[allow(dead_code)] pub struct ModuleLlvm { llcx: &'static mut llvm::Context, llmod_raw: *const llvm::Module, @@ -416,6 +443,7 @@ pub struct ModuleLlvm { // This field is `ManuallyDrop` because it is important that the `TargetMachine` // is disposed prior to the `Context` being disposed otherwise UAFs can occur. 
tm: ManuallyDrop, + typetrees: FxHashMap, } unsafe impl Send for ModuleLlvm {} @@ -430,6 +458,7 @@ impl ModuleLlvm { llmod_raw, llcx, tm: ManuallyDrop::new(create_target_machine(tcx, mod_name)), + typetrees: Default::default(), } } } @@ -442,6 +471,7 @@ impl ModuleLlvm { llmod_raw, llcx, tm: ManuallyDrop::new(create_informational_target_machine(tcx.sess)), + typetrees: Default::default(), } } } @@ -463,7 +493,8 @@ impl ModuleLlvm { } }; - Ok(ModuleLlvm { llmod_raw, llcx, tm: ManuallyDrop::new(tm) }) + Ok(ModuleLlvm { llmod_raw, llcx, tm: ManuallyDrop::new(tm), + typetrees: Default::default() }) } } diff --git a/compiler/rustc_codegen_ssa/src/back/lto.rs b/compiler/rustc_codegen_ssa/src/back/lto.rs index cb6244050df24..cf03fb78cf1d2 100644 --- a/compiler/rustc_codegen_ssa/src/back/lto.rs +++ b/compiler/rustc_codegen_ssa/src/back/lto.rs @@ -1,8 +1,10 @@ use super::write::CodegenContext; +use crate::back::write::ModuleConfig; use crate::traits::*; use crate::ModuleCodegen; -use rustc_data_structures::memmap::Mmap; +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; +use rustc_data_structures::{fx::FxHashMap, memmap::Mmap}; use rustc_errors::FatalError; use std::ffi::CString; @@ -76,6 +78,26 @@ impl LtoModuleCodegen { } } + /// Run autodiff on Fat LTO module + pub unsafe fn autodiff( + self, + cgcx: &CodegenContext, + diff_fncs: Vec, + typetrees: FxHashMap, + config: &ModuleConfig, + ) -> Result, FatalError> { + match &self { + LtoModuleCodegen::Fat { ref module, .. } => { + { + B::autodiff(cgcx, &module, diff_fncs, typetrees, config)?; + } + }, + _ => {}, + } + + Ok(self) + } + /// A "gauge" of how costly it is to optimize this module, used to sort /// biggest modules first. 
pub fn cost(&self) -> u64 { diff --git a/compiler/rustc_codegen_ssa/src/back/write.rs b/compiler/rustc_codegen_ssa/src/back/write.rs index 53ae085a72160..166a0852803e4 100644 --- a/compiler/rustc_codegen_ssa/src/back/write.rs +++ b/compiler/rustc_codegen_ssa/src/back/write.rs @@ -9,6 +9,7 @@ use crate::{ }; use jobserver::{Acquired, Client}; use rustc_ast::attr; +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_data_structures::fx::{FxHashMap, FxIndexMap}; use rustc_data_structures::memmap::Mmap; use rustc_data_structures::profiling::{SelfProfilerRef, VerboseTimingGuard}; @@ -374,6 +375,8 @@ impl CodegenContext { fn generate_lto_work( cgcx: &CodegenContext, + autodiff: Vec, + typetrees: FxHashMap, needs_fat_lto: Vec>, needs_thin_lto: Vec<(String, B::ThinBuffer)>, import_only_modules: Vec<(SerializedModule, WorkProduct)>, @@ -382,8 +385,12 @@ fn generate_lto_work( if !needs_fat_lto.is_empty() { assert!(needs_thin_lto.is_empty()); - let module = + let mut module = B::run_fat_lto(cgcx, needs_fat_lto, import_only_modules).unwrap_or_else(|e| e.raise()); + if cgcx.lto == Lto::Fat { + let config = cgcx.config(ModuleKind::Regular); + module = unsafe { module.autodiff(cgcx, autodiff, typetrees, config).unwrap() }; + } // We are adding a single work item, so the cost doesn't matter. vec![(WorkItem::LTO(module), 0)] } else { @@ -970,6 +977,8 @@ pub(crate) enum Message { work_product: WorkProduct, }, + AddAutoDiffItems(Vec), + /// The frontend has finished generating everything for all codegen units. /// Sent from the main thread. CodegenComplete, @@ -1264,6 +1273,7 @@ fn start_executing_work( // This is where we collect codegen units that have gone all the way // through codegen and LLVM. 
+ let mut autodiff_items = Vec::new(); let mut compiled_modules = vec![]; let mut compiled_allocator_module = None; let mut needs_link = Vec::new(); @@ -1271,6 +1281,7 @@ fn start_executing_work( let mut needs_thin_lto = Vec::new(); let mut lto_import_only_modules = Vec::new(); let mut started_lto = false; + let mut typetrees = FxHashMap::::default(); /// Possible state transitions: /// - Ongoing -> Completed @@ -1375,9 +1386,14 @@ fn start_executing_work( let needs_thin_lto = mem::take(&mut needs_thin_lto); let import_only_modules = mem::take(&mut lto_import_only_modules); - for (work, cost) in - generate_lto_work(&cgcx, needs_fat_lto, needs_thin_lto, import_only_modules) - { + for (work, cost) in generate_lto_work( + &cgcx, + autodiff_items.clone(), + typetrees.clone(), + needs_fat_lto, + needs_thin_lto, + import_only_modules, + ) { let insertion_index = work_items .binary_search_by_key(&cost, |&(_, cost)| cost) .unwrap_or_else(|e| e); @@ -1490,7 +1506,16 @@ fn start_executing_work( } } - Message::CodegenDone { llvm_work_item, cost } => { + Message::CodegenDone { mut llvm_work_item, cost } => { + //// extract build typetrees + match &mut llvm_work_item { + WorkItem::Optimize(module) => { + let tt = B::typetrees(&mut module.module_llvm); + typetrees.extend(tt); + } + _ => {}, + } + // We keep the queue sorted by estimated processing cost, // so that more expensive items are processed earlier. 
This // is good for throughput as it gives the main thread more @@ -1512,6 +1537,12 @@ fn start_executing_work( main_thread_state = MainThreadState::Idle; } + Message::AddAutoDiffItems(mut items) => { + dbg!("AddAutoDiffItems"); + autodiff_items.append(&mut items); + } + + Message::CodegenComplete => { if codegen_state != Aborted { codegen_state = Completed; @@ -1981,6 +2012,11 @@ impl OngoingCodegen { drop(self.coordinator.sender.send(Box::new(Message::CodegenComplete::))); } + pub fn submit_autodiff_items(&self, items: Vec) { + dbg!("submit_autodiff_items"); + drop(self.coordinator.sender.send(Box::new(Message::::AddAutoDiffItems(items)))); + } + pub fn check_for_errors(&self, sess: &Session) { self.shared_emitter_main.check(sess, false); } diff --git a/compiler/rustc_codegen_ssa/src/base.rs b/compiler/rustc_codegen_ssa/src/base.rs index 776467c73174e..94a8d2cc4d4c9 100644 --- a/compiler/rustc_codegen_ssa/src/base.rs +++ b/compiler/rustc_codegen_ssa/src/base.rs @@ -594,7 +594,8 @@ pub fn codegen_crate( // Run the monomorphization collector and partition the collected items into // codegen units. - let codegen_units = tcx.collect_and_partition_mono_items(()).2; + let (_, autodiff_fncs, codegen_units) = tcx.collect_and_partition_mono_items(()); + let autodiff_fncs = autodiff_fncs.to_vec(); // Force all codegen_unit queries so they are already either red or green // when compile_codegen_unit accesses them. We are not able to re-execute @@ -663,6 +664,10 @@ pub fn codegen_crate( ); } + if !autodiff_fncs.is_empty() { + ongoing_codegen.submit_autodiff_items(autodiff_fncs); + } + // For better throughput during parallel processing by LLVM, we used to sort // CGUs largest to smallest. 
This would lead to better thread utilization // by, for example, preventing a large CGU from being processed last and diff --git a/compiler/rustc_codegen_ssa/src/traits/misc.rs b/compiler/rustc_codegen_ssa/src/traits/misc.rs index 04e2b8796c46a..5f64dd3367661 100644 --- a/compiler/rustc_codegen_ssa/src/traits/misc.rs +++ b/compiler/rustc_codegen_ssa/src/traits/misc.rs @@ -19,4 +19,5 @@ pub trait MiscMethods<'tcx>: BackendTypes { fn apply_target_cpu_attr(&self, llfn: Self::Function); /// Declares the extern "C" main function for the entry point. Returns None if the symbol already exists. fn declare_c_main(&self, fn_type: Self::Type) -> Option; + fn create_autodiff(&self) -> Vec; } diff --git a/compiler/rustc_codegen_ssa/src/traits/write.rs b/compiler/rustc_codegen_ssa/src/traits/write.rs index 048540894ac9b..18f3855c5d5f9 100644 --- a/compiler/rustc_codegen_ssa/src/traits/write.rs +++ b/compiler/rustc_codegen_ssa/src/traits/write.rs @@ -2,6 +2,8 @@ use crate::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule}; use crate::back::write::{CodegenContext, FatLtoInput, ModuleConfig}; use crate::{CompiledModule, ModuleCodegen}; +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; +use rustc_data_structures::fx::FxHashMap; use rustc_errors::{DiagCtxt, FatalError}; use rustc_middle::dep_graph::WorkProduct; @@ -12,6 +14,7 @@ pub trait WriteBackendMethods: 'static + Sized + Clone { type ModuleBuffer: ModuleBufferMethods; type ThinData: Send + Sync; type ThinBuffer: ThinBufferMethods; + type TypeTree: Clone; /// Merge all modules into main_module and returning it fn run_link( @@ -58,6 +61,15 @@ pub trait WriteBackendMethods: 'static + Sized + Clone { ) -> Result; fn prepare_thin(module: ModuleCodegen) -> (String, Self::ThinBuffer); fn serialize_module(module: ModuleCodegen) -> (String, Self::ModuleBuffer); + /// Generate autodiff rules + fn autodiff( + cgcx: &CodegenContext, + module: &ModuleCodegen, + diff_fncs: Vec, + typetrees: FxHashMap, + config: &ModuleConfig, + 
) -> Result<(), FatalError>; + fn typetrees(module: &mut Self::Module) -> FxHashMap; } pub trait ThinBufferMethods: Send + Sync { diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index 0df7b7eed11f9..49873e5e02cd4 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -300,17 +300,39 @@ extern "C" void LLVMRustAddFunctionAttributes(LLVMValueRef Fn, unsigned Index, AddAttributes(F, Index, Attrs, AttrsLen); } +extern "C" LLVMAttributeRef +LLVMRustCreateAttrNoValue(LLVMContextRef C, LLVMRustAttribute RustAttr) { + return wrap(Attribute::get(*unwrap(C), fromRust(RustAttr))); +} + +extern "C" LLVMTypeRef LLVMRustGetFunctionType(LLVMValueRef Fn) { + auto Ftype = unwrap(Fn)->getFunctionType(); + return wrap(Ftype); +} + +extern "C" void LLVMRustRemoveEnumAttributeAtIndex(LLVMValueRef F, size_t index, + LLVMRustAttribute RustAttr) { + LLVMRemoveEnumAttributeAtIndex(F, index, fromRust(RustAttr)); +} + +extern "C" void LLVMRustAddEnumAttributeAtIndex(LLVMContextRef C, + LLVMValueRef F, size_t index, + LLVMRustAttribute RustAttr) { + LLVMAddAttributeAtIndex(F, index, LLVMRustCreateAttrNoValue(C, RustAttr)); +} + +extern "C" LLVMAttributeRef +LLVMRustGetEnumAttributeAtIndex(LLVMValueRef F, size_t index, + LLVMRustAttribute RustAttr) { + return LLVMGetEnumAttributeAtIndex(F, index, fromRust(RustAttr)); +} + extern "C" void LLVMRustAddCallSiteAttributes(LLVMValueRef Instr, unsigned Index, LLVMAttributeRef *Attrs, size_t AttrsLen) { CallBase *Call = unwrap(Instr); AddAttributes(Call, Index, Attrs, AttrsLen); } -extern "C" LLVMAttributeRef LLVMRustCreateAttrNoValue(LLVMContextRef C, - LLVMRustAttribute RustAttr) { - return wrap(Attribute::get(*unwrap(C), fromRust(RustAttr))); -} - extern "C" LLVMAttributeRef LLVMRustCreateAlignmentAttr(LLVMContextRef C, uint64_t Bytes) { return wrap(Attribute::getWithAlignment(*unwrap(C), llvm::Align(Bytes))); From 
670054202ad12fe6926edafb75dda45c9f00b9d2 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Mon, 15 Jan 2024 18:02:02 -0500 Subject: [PATCH 010/100] add minimal ast based macro parser --- compiler/rustc_ast/src/ast.rs | 6 + .../rustc_ast/src/expand/autodiff_attrs.rs | 59 ++++- compiler/rustc_ast/src/lib.rs | 1 + compiler/rustc_builtin_macros/src/autodiff.rs | 222 +++++++++++++----- 4 files changed, 231 insertions(+), 57 deletions(-) diff --git a/compiler/rustc_ast/src/ast.rs b/compiler/rustc_ast/src/ast.rs index 3496cfc38c84e..259f76abd0816 100644 --- a/compiler/rustc_ast/src/ast.rs +++ b/compiler/rustc_ast/src/ast.rs @@ -2574,6 +2574,12 @@ impl FnRetTy { FnRetTy::Ty(ty) => ty.span, } } + pub fn has_ret(&self) -> bool { + match self { + FnRetTy::Default(_) => false, + FnRetTy::Ty(_) => true, + } + } } #[derive(Clone, Copy, PartialEq, Encodable, Decodable, Debug)] diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs index 2126868751b21..f14dfc6ba49a8 100644 --- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -1,7 +1,11 @@ -use super::typetree::TypeTree; -use std::str::FromStr; use rustc_data_structures::stable_hasher::{HashStable, StableHasher};//, StableOrd}; use crate::HashStableContext; +use crate::expand::typetree::TypeTree; +use thin_vec::ThinVec; +//use rustc_expand::base::{Annotatable, ExtCtxt}; +use std::str::FromStr; + +use crate::NestedMetaItem; #[allow(dead_code)] #[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug)] @@ -52,12 +56,22 @@ impl HashStable for DiffActivity { } } +impl FromStr for DiffMode { + type Err = (); + fn from_str(s: &str) -> Result { match s { + "Inactive" => Ok(DiffMode::Inactive), + "Source" => Ok(DiffMode::Source), + "Forward" => Ok(DiffMode::Forward), + "Reverse" => Ok(DiffMode::Reverse), + _ => Err(()), + } + } +} impl FromStr for DiffActivity { type Err = (); - fn from_str(s: &str) -> Result { - 
match s { + fn from_str(s: &str) -> Result { match s { "None" => Ok(DiffActivity::None), "Active" => Ok(DiffActivity::Active), "Const" => Ok(DiffActivity::Const), @@ -76,6 +90,43 @@ pub struct AutoDiffAttrs { pub input_activity: Vec, } +fn name(x: &NestedMetaItem) -> String { + let segments = &x.meta_item().unwrap().path.segments; + assert!(segments.len() == 1); + segments[0].ident.name.to_string() +} + +impl AutoDiffAttrs{ + pub fn has_ret_activity(&self) -> bool { + match self.ret_activity { + DiffActivity::None => false, + _ => true, + } + } + pub fn from_ast(meta_item: &ThinVec, has_ret: bool) -> Self { + let mode = name(&meta_item[1]); + let mode = DiffMode::from_str(&mode).unwrap(); + let activities: Vec = meta_item[2..].iter().map(|x| { + let activity_str = name(&x); + DiffActivity::from_str(&activity_str).unwrap() + }).collect(); + + // If a return type exist, we need to split the last activity, + // otherwise we return None as placeholder. + let (ret_activity, input_activity) = if has_ret { + activities.split_last().unwrap() + } else { + (&DiffActivity::None, activities.as_slice()) + }; + + AutoDiffAttrs { + mode, + ret_activity: *ret_activity, + input_activity: input_activity.to_vec(), + } + } +} + impl HashStable for AutoDiffAttrs { fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { self.mode.hash_stable(hcx, hasher); diff --git a/compiler/rustc_ast/src/lib.rs b/compiler/rustc_ast/src/lib.rs index 7e713a49a8cfa..e69d222c418a0 100644 --- a/compiler/rustc_ast/src/lib.rs +++ b/compiler/rustc_ast/src/lib.rs @@ -50,6 +50,7 @@ pub mod ptr; pub mod token; pub mod tokenstream; pub mod visit; +//pub mod autodiff_attrs; pub use self::ast::*; pub use self::ast_traits::{AstDeref, AstNodeWrapper, HasAttrs, HasNodeId, HasSpan, HasTokens}; diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index e5f4ee160c3bf..5ca3cf63565fd 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ 
b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -1,17 +1,19 @@ -#![allow(unused)] - -use crate::errors; +#![allow(unused_imports)] //use crate::util::check_builtin_macro_attribute; //use crate::util::check_autodiff; +use std::string::String; +use crate::errors; use rustc_ast::ptr::P; -use rustc_ast::{self as ast, FnHeader, FnSig, Generics, StmtKind}; -use rustc_ast::{Fn, ItemKind, Stmt, TyKind, Unsafe}; +use rustc_ast::{BindingAnnotation, ByRef}; +use rustc_ast::{self as ast, FnHeader, FnSig, Generics, StmtKind, NestedMetaItem, MetaItemKind}; +use rustc_ast::{Fn, ItemKind, Stmt, TyKind, Unsafe, PatKind}; use rustc_expand::base::{Annotatable, ExtCtxt}; use rustc_span::symbol::{kw, sym, Ident}; use rustc_span::Span; use thin_vec::{thin_vec, ThinVec}; use rustc_span::Symbol; +use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode}; pub fn expand( ecx: &mut ExtCtxt<'_>, @@ -20,69 +22,183 @@ pub fn expand( item: Annotatable, ) -> Vec { //check_builtin_macro_attribute(ecx, meta_item, sym::alloc_error_handler); - //check_builtin_macro_attribute(ecx, meta_item, sym::autodiff); - dbg!(&meta_item); + let meta_item_vec: ThinVec = match meta_item.kind { + ast::MetaItemKind::List(ref vec) => vec.clone(), + _ => { + ecx.sess + .dcx() + .emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); + return vec![item]; + } + }; let input = item.clone(); let orig_item: P = item.clone().expect_item(); let mut d_item: P = item.clone().expect_item(); + let primal = orig_item.ident.clone(); - // Allow using `#[autodiff(...)]` on a Fn - let (fn_item, _ty_span) = if let Annotatable::Item(item) = &item + // Allow using `#[autodiff(...)]` only on a Fn + let (fn_item, has_ret, sig, sig_span) = if let Annotatable::Item(item) = &item && let ItemKind::Fn(box ast::Fn { sig, .. 
}) = &item.kind { - dbg!(&item); - (item, ecx.with_def_site_ctxt(sig.span)) + (item, sig.decl.output.has_ret(), sig, ecx.with_def_site_ctxt(sig.span)) } else { ecx.sess .dcx() .emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); return vec![input]; }; - let _x: &ItemKind = &fn_item.kind; - d_item.ident.name = - Symbol::intern(format!("d_{}", fn_item.ident.name).as_str()); + let x: AutoDiffAttrs = AutoDiffAttrs::from_ast(&meta_item_vec, has_ret); + dbg!(&x); + let span = ecx.with_def_site_ctxt(fn_item.span); + + let (d_decl, old_names, new_args) = gen_enzyme_decl(ecx, &sig.decl, &x, span, sig_span); + let d_body = gen_enzyme_body(ecx, primal, &old_names, &new_args, span, sig_span); + let meta_item_name = meta_item_vec[0].meta_item().unwrap(); + d_item.ident = meta_item_name.path.segments[0].ident; + // update d_item + if let ItemKind::Fn(box ast::Fn { sig, body, .. }) = &mut d_item.kind { + *sig.decl = d_decl; + *body = Some(d_body); + } else { + ecx.sess + .dcx() + .emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); + return vec![input]; + } + let orig_annotatable = Annotatable::Item(orig_item.clone()); let d_annotatable = Annotatable::Item(d_item.clone()); return vec![orig_annotatable, d_annotatable]; } -// #[rustc_std_internal_symbol] -// unsafe fn __rg_oom(size: usize, align: usize) -> ! 
{ -// handler(core::alloc::Layout::from_size_align_unchecked(size, align)) -// } -//fn generate_handler(cx: &ExtCtxt<'_>, handler: Ident, span: Span, sig_span: Span) -> Stmt { -// let usize = cx.path_ident(span, Ident::new(sym::usize, span)); -// let ty_usize = cx.ty_path(usize); -// let size = Ident::from_str_and_span("size", span); -// let align = Ident::from_str_and_span("align", span); -// -// let layout_new = cx.std_path(&[sym::alloc, sym::Layout, sym::from_size_align_unchecked]); -// let layout_new = cx.expr_path(cx.path(span, layout_new)); -// let layout = cx.expr_call( -// span, -// layout_new, -// thin_vec![cx.expr_ident(span, size), cx.expr_ident(span, align)], -// ); -// -// let call = cx.expr_call_ident(sig_span, handler, thin_vec![layout]); -// -// let never = ast::FnRetTy::Ty(cx.ty(span, TyKind::Never)); -// let params = thin_vec![cx.param(span, size, ty_usize.clone()), cx.param(span, align, ty_usize)]; -// let decl = cx.fn_decl(params, never); -// let header = FnHeader { unsafety: Unsafe::Yes(span), ..FnHeader::default() }; -// let sig = FnSig { decl, header, span: span }; -// -// let body = Some(cx.block_expr(call)); -// let kind = ItemKind::Fn(Box::new(Fn { -// defaultness: ast::Defaultness::Final, -// sig, -// generics: Generics::default(), -// body, -// })); -// -// let attrs = thin_vec![cx.attr_word(sym::rustc_std_internal_symbol, span)]; -// -// let item = cx.item(span, Ident::from_str_and_span("__rg_oom", span), attrs, kind); -// cx.stmt_item(sig_span, item) -//} +// shadow arguments must be mutable references or ptrs, because Enzyme will write into them. 
+fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty { + let mut ty = ty.clone(); + match ty.kind { + TyKind::Ptr(ref mut mut_ty) => { + mut_ty.mutbl = ast::Mutability::Mut; + } + TyKind::Ref(_, ref mut mut_ty) => { + mut_ty.mutbl = ast::Mutability::Mut; + } + _ => { + panic!("unsupported type: {:?}", ty); + } + } + ty +} + + +// The body of our generated functions will consist of three black_Box calls. +// The first will call the primal function with the original arguments. +// The second will just take the shadow arguments. +// The third will (unsafely) call std::mem::zeroed(), to match the return type of the new function +// (whatever that might be). This way we surpress rustc from optimizing anyt argument away. +fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_names: &[String], span: Span, sig_span: Span) -> P { + let blackbox_path = ecx.std_path(&[Symbol::intern("hint"), Symbol::intern("black_box")]); + let zeroed_path = ecx.std_path(&[Symbol::intern("mem"), Symbol::intern("zeroed")]); + + let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal)); + let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path)); + let zeroed_call_expr = ecx.expr_path(ecx.path(span, zeroed_path)); + + let mem_zeroed_call: Stmt = ecx.stmt_expr(ecx.expr_call( + span, + zeroed_call_expr.clone(), + thin_vec![], + )); + let unsafe_block_with_zeroed_call: P = ecx.expr_block(P(ast::Block { + stmts: thin_vec![mem_zeroed_call], + id: ast::DUMMY_NODE_ID, + rules: ast::BlockCheckMode::Unsafe(ast::UserProvided), + span: sig_span, + tokens: None, + could_be_bare_literal: false, + })); + // create ::core::hint::black_box(array(arr)); + let _primal_call = ecx.expr_call( + span, + primal_call_expr.clone(), + old_names.iter().map(|name| { + ecx.expr_path(ecx.path_ident(span, Ident::from_str(name))) + }).collect(), + ); + + // create ::core::hint::black_box(grad_arr, tang_y)); + let black_box1 = ecx.expr_call( + sig_span, + blackbox_call_expr.clone(), + 
new_names.iter().map(|arg| { + ecx.expr_path(ecx.path_ident(span, Ident::from_str(arg))) + }).collect(), + ); + + // create ::core::hint::black_box(unsafe { ::core::mem::zeroed() }) + let black_box2 = ecx.expr_call( + sig_span, + blackbox_call_expr.clone(), + thin_vec![unsafe_block_with_zeroed_call.clone()], + ); + + let mut body = ecx.block(span, ThinVec::new()); + body.stmts.push(ecx.stmt_expr(black_box1)); + body.stmts.push(ecx.stmt_expr(black_box2)); + body +} + +// Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must +// be pointers or references. Those receive a shadow argument, which is a mutable reference/pointer. +// Active arguments must be scalars. Their shadow argument is added to the return type (and will be +// zero-initialized by Enzyme). Active arguments are not handled yet. +// Each argument of the primal function (and the return type if existing) must be annotated with an +// activity. +fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, decl: &ast::FnDecl, x: &AutoDiffAttrs, _span: Span, _sig_span: Span) + -> (ast::FnDecl, Vec, Vec) { + assert!(decl.inputs.len() == x.input_activity.len()); + assert!(decl.output.has_ret() == x.has_ret_activity()); + let mut d_decl = decl.clone(); + let mut d_inputs = Vec::new(); + let mut new_inputs = Vec::new(); + let mut old_names = Vec::new(); + for (arg, activity) in decl.inputs.iter().zip(x.input_activity.iter()) { + dbg!(&arg); + d_inputs.push(arg.clone()); + match activity { + DiffActivity::Duplicated => { + let mut shadow_arg = arg.clone(); + shadow_arg.ty = P(assure_mut_ref(&arg.ty)); + // adjust name depending on mode + let old_name = if let PatKind::Ident(_, ident, _) = shadow_arg.pat.kind { + ident.name + } else { + dbg!(&shadow_arg.pat); + panic!("not an ident?"); + }; + old_names.push(old_name.to_string()); + let name: String = match x.mode { + DiffMode::Reverse => format!("d{}", old_name), + DiffMode::Forward => format!("b{}", old_name), + _ => panic!("unsupported 
mode: {}", old_name), + }; + dbg!(&name); + new_inputs.push(name.clone()); + shadow_arg.pat = P(ast::Pat { + // TODO: Check id + id: ast::DUMMY_NODE_ID, + kind: PatKind::Ident(BindingAnnotation::NONE, + Ident::from_str_and_span(&name, shadow_arg.pat.span), + None, + ), + span: shadow_arg.pat.span, + tokens: shadow_arg.pat.tokens.clone(), + }); + + d_inputs.push(shadow_arg); + } + _ => {}, + } + } + d_decl.inputs = d_inputs.into(); + (d_decl, old_names, new_inputs) +} From bf72f16b4054f1d5e5d8b759cd83ca823b334efd Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Tue, 16 Jan 2024 20:37:02 -0500 Subject: [PATCH 011/100] yeet --- compiler/rustc_builtin_macros/src/autodiff.rs | 64 +++++++++++-------- compiler/rustc_expand/src/build.rs | 3 + 2 files changed, 42 insertions(+), 25 deletions(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 5ca3cf63565fd..274107613d1a3 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -1,4 +1,5 @@ #![allow(unused_imports)] +#![allow(unused_variables)] //use crate::util::check_builtin_macro_attribute; //use crate::util::check_autodiff; @@ -32,9 +33,7 @@ pub fn expand( return vec![item]; } }; - let input = item.clone(); let orig_item: P = item.clone().expect_item(); - let mut d_item: P = item.clone().expect_item(); let primal = orig_item.ident.clone(); // Allow using `#[autodiff(...)]` only on a Fn @@ -46,29 +45,27 @@ pub fn expand( ecx.sess .dcx() .emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); - return vec![input]; + return vec![item]; }; let x: AutoDiffAttrs = AutoDiffAttrs::from_ast(&meta_item_vec, has_ret); dbg!(&x); let span = ecx.with_def_site_ctxt(fn_item.span); - let (d_decl, old_names, new_args) = gen_enzyme_decl(ecx, &sig.decl, &x, span, sig_span); + let (d_sig, old_names, new_args) = gen_enzyme_decl(ecx, &sig, &x, span, sig_span); let d_body = gen_enzyme_body(ecx, primal, 
&old_names, &new_args, span, sig_span); - let meta_item_name = meta_item_vec[0].meta_item().unwrap(); - d_item.ident = meta_item_name.path.segments[0].ident; - // update d_item - if let ItemKind::Fn(box ast::Fn { sig, body, .. }) = &mut d_item.kind { - *sig.decl = d_decl; - *body = Some(d_body); - } else { - ecx.sess - .dcx() - .emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); - return vec![input]; - } + let d_ident = meta_item_vec[0].meta_item().unwrap().path.segments[0].ident; + + // The first element of it is the name of the function to be generated + let asdf = ItemKind::Fn(Box::new(ast::Fn { + defaultness: ast::Defaultness::Final, + sig: d_sig, + generics: Generics::default(), + body: Some(d_body), + })); + let d_fn = ecx.item(span, d_ident, rustc_ast::AttrVec::default(), asdf); let orig_annotatable = Annotatable::Item(orig_item.clone()); - let d_annotatable = Annotatable::Item(d_item.clone()); + let d_annotatable = Annotatable::Item(d_fn); return vec![orig_annotatable, d_annotatable]; } @@ -98,6 +95,9 @@ fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty { fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_names: &[String], span: Span, sig_span: Span) -> P { let blackbox_path = ecx.std_path(&[Symbol::intern("hint"), Symbol::intern("black_box")]); let zeroed_path = ecx.std_path(&[Symbol::intern("mem"), Symbol::intern("zeroed")]); + let empty_loop_block = ecx.block(span, ThinVec::new()); + let loop_expr = ecx.expr_loop(span, empty_loop_block); + let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal)); let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path)); @@ -117,13 +117,18 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n could_be_bare_literal: false, })); // create ::core::hint::black_box(array(arr)); - let _primal_call = ecx.expr_call( + let primal_call = ecx.expr_call( span, primal_call_expr.clone(), old_names.iter().map(|name| { 
ecx.expr_path(ecx.path_ident(span, Ident::from_str(name))) }).collect(), ); + let black_box0 = ecx.expr_call( + sig_span, + blackbox_call_expr.clone(), + thin_vec![primal_call.clone()], + ); // create ::core::hint::black_box(grad_arr, tang_y)); let black_box1 = ecx.expr_call( @@ -135,15 +140,18 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n ); // create ::core::hint::black_box(unsafe { ::core::mem::zeroed() }) - let black_box2 = ecx.expr_call( + let _black_box2 = ecx.expr_call( sig_span, blackbox_call_expr.clone(), thin_vec![unsafe_block_with_zeroed_call.clone()], ); let mut body = ecx.block(span, ThinVec::new()); - body.stmts.push(ecx.stmt_expr(black_box1)); - body.stmts.push(ecx.stmt_expr(black_box2)); + body.stmts.push(ecx.stmt_expr(primal_call)); + //body.stmts.push(ecx.stmt_expr(black_box0)); + //body.stmts.push(ecx.stmt_expr(black_box1)); + //body.stmts.push(ecx.stmt_expr(black_box2)); + body.stmts.push(ecx.stmt_expr(loop_expr)); body } @@ -153,8 +161,9 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n // zero-initialized by Enzyme). Active arguments are not handled yet. // Each argument of the primal function (and the return type if existing) must be annotated with an // activity. 
-fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, decl: &ast::FnDecl, x: &AutoDiffAttrs, _span: Span, _sig_span: Span) - -> (ast::FnDecl, Vec, Vec) { +fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span: Span, _sig_span: Span) + -> (ast::FnSig, Vec, Vec) { + let decl: P = sig.decl.clone(); assert!(decl.inputs.len() == x.input_activity.len()); assert!(decl.output.has_ret() == x.has_ret_activity()); let mut d_decl = decl.clone(); @@ -162,7 +171,7 @@ fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, decl: &ast::FnDecl, x: &AutoDiffAttrs, _s let mut new_inputs = Vec::new(); let mut old_names = Vec::new(); for (arg, activity) in decl.inputs.iter().zip(x.input_activity.iter()) { - dbg!(&arg); + //dbg!(&arg); d_inputs.push(arg.clone()); match activity { DiffActivity::Duplicated => { @@ -200,5 +209,10 @@ fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, decl: &ast::FnDecl, x: &AutoDiffAttrs, _s } } d_decl.inputs = d_inputs.into(); - (d_decl, old_names, new_inputs) + let d_sig = FnSig { + header: sig.header.clone(), + decl: d_decl, + span, + }; + (d_sig, old_names, new_inputs) } diff --git a/compiler/rustc_expand/src/build.rs b/compiler/rustc_expand/src/build.rs index f9bfebee12e92..51d87ac03487e 100644 --- a/compiler/rustc_expand/src/build.rs +++ b/compiler/rustc_expand/src/build.rs @@ -405,6 +405,9 @@ impl<'a> ExtCtxt<'a> { pub fn expr_tuple(&self, sp: Span, exprs: ThinVec>) -> P { self.expr(sp, ast::ExprKind::Tup(exprs)) } + pub fn expr_loop(&self, sp: Span, block: P) -> P { + self.expr(sp, ast::ExprKind::Loop(block, None, sp)) + } pub fn expr_fail(&self, span: Span, msg: Symbol) -> P { self.expr_call_global( From 001bc8ac600d16a0ce4000f9121433a68e36e7d1 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Tue, 23 Jan 2024 20:50:10 -0500 Subject: [PATCH 012/100] compiles and runs binaries, but without Enzyme yet --- .../rustc_ast/src/expand/autodiff_attrs.rs | 16 +++++------ compiler/rustc_builtin_macros/src/autodiff.rs | 28 +++++++++++-------- 2 files changed, 24 
insertions(+), 20 deletions(-) diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs index f14dfc6ba49a8..a8e7c7eca5611 100644 --- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -83,7 +83,7 @@ impl FromStr for DiffActivity { } #[allow(dead_code)] -#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug)] +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] pub struct AutoDiffAttrs { pub mode: DiffMode, pub ret_activity: DiffActivity, @@ -127,13 +127,13 @@ impl AutoDiffAttrs{ } } -impl HashStable for AutoDiffAttrs { - fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { - self.mode.hash_stable(hcx, hasher); - self.ret_activity.hash_stable(hcx, hasher); - self.input_activity.hash_stable(hcx, hasher); - } -} +//impl HashStable for AutoDiffAttrs { +// fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { +// self.mode.hash_stable(hcx, hasher); +// self.ret_activity.hash_stable(hcx, hasher); +// self.input_activity.hash_stable(hcx, hasher); +// } +//} impl AutoDiffAttrs { pub fn inactive() -> Self { diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 274107613d1a3..fa1dc8e330436 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -18,7 +18,7 @@ use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode}; pub fn expand( ecx: &mut ExtCtxt<'_>, - _span: Span, + expand_span: Span, meta_item: &ast::MetaItem, item: Annotatable, ) -> Vec { @@ -40,7 +40,7 @@ pub fn expand( let (fn_item, has_ret, sig, sig_span) = if let Annotatable::Item(item) = &item && let ItemKind::Fn(box ast::Fn { sig, .. 
}) = &item.kind { - (item, sig.decl.output.has_ret(), sig, ecx.with_def_site_ctxt(sig.span)) + (item, sig.decl.output.has_ret(), sig, ecx.with_call_site_ctxt(sig.span)) } else { ecx.sess .dcx() @@ -49,10 +49,14 @@ pub fn expand( }; let x: AutoDiffAttrs = AutoDiffAttrs::from_ast(&meta_item_vec, has_ret); dbg!(&x); - let span = ecx.with_def_site_ctxt(fn_item.span); + //let span = ecx.with_def_site_ctxt(sig_span); + let span = ecx.with_def_site_ctxt(expand_span); + //let span = ecx.with_def_site_ctxt(fn_item.span); let (d_sig, old_names, new_args) = gen_enzyme_decl(ecx, &sig, &x, span, sig_span); - let d_body = gen_enzyme_body(ecx, primal, &old_names, &new_args, span, sig_span); + let new_decl_span = d_sig.span; + //let d_body = gen_enzyme_body(ecx, primal, &old_names, &new_args, span, span); + let d_body = gen_enzyme_body(ecx, primal, &old_names, &new_args, span, sig_span, new_decl_span); let d_ident = meta_item_vec[0].meta_item().unwrap().path.segments[0].ident; // The first element of it is the name of the function to be generated @@ -92,7 +96,7 @@ fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty { // The second will just take the shadow arguments. // The third will (unsafely) call std::mem::zeroed(), to match the return type of the new function // (whatever that might be). This way we surpress rustc from optimizing anyt argument away. 
-fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_names: &[String], span: Span, sig_span: Span) -> P { +fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_names: &[String], span: Span, sig_span: Span, new_decl_span: Span) -> P { let blackbox_path = ecx.std_path(&[Symbol::intern("hint"), Symbol::intern("black_box")]); let zeroed_path = ecx.std_path(&[Symbol::intern("mem"), Symbol::intern("zeroed")]); let empty_loop_block = ecx.block(span, ThinVec::new()); @@ -118,14 +122,14 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n })); // create ::core::hint::black_box(array(arr)); let primal_call = ecx.expr_call( - span, - primal_call_expr.clone(), + new_decl_span, + primal_call_expr, old_names.iter().map(|name| { - ecx.expr_path(ecx.path_ident(span, Ident::from_str(name))) + ecx.expr_path(ecx.path_ident(new_decl_span, Ident::from_str(name))) }).collect(), ); let black_box0 = ecx.expr_call( - sig_span, + new_decl_span, blackbox_call_expr.clone(), thin_vec![primal_call.clone()], ); @@ -140,17 +144,17 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n ); // create ::core::hint::black_box(unsafe { ::core::mem::zeroed() }) - let _black_box2 = ecx.expr_call( + let black_box2 = ecx.expr_call( sig_span, blackbox_call_expr.clone(), thin_vec![unsafe_block_with_zeroed_call.clone()], ); let mut body = ecx.block(span, ThinVec::new()); - body.stmts.push(ecx.stmt_expr(primal_call)); + //body.stmts.push(ecx.stmt_expr(primal_call)); //body.stmts.push(ecx.stmt_expr(black_box0)); //body.stmts.push(ecx.stmt_expr(black_box1)); - //body.stmts.push(ecx.stmt_expr(black_box2)); + body.stmts.push(ecx.stmt_expr(black_box2)); body.stmts.push(ecx.stmt_expr(loop_expr)); body } From e3ba2d2e329b96aa1ea0514e21c4848cbed0a09d Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Wed, 24 Jan 2024 22:39:19 -0500 Subject: [PATCH 013/100] I will clean this up, I promise! 
--- .../rustc_ast/src/expand/autodiff_attrs.rs | 14 ++++- compiler/rustc_builtin_macros/src/autodiff.rs | 55 ++++++++++++++++++- compiler/rustc_codegen_llvm/src/lib.rs | 1 + compiler/rustc_codegen_ssa/src/back/write.rs | 1 + .../rustc_codegen_ssa/src/codegen_attrs.rs | 35 +++++------- compiler/rustc_feature/src/builtin_attrs.rs | 7 +++ .../rustc_monomorphize/src/partitioning.rs | 6 +- compiler/rustc_span/src/symbol.rs | 1 + 8 files changed, 92 insertions(+), 28 deletions(-) diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs index a8e7c7eca5611..3b77d1c347b7a 100644 --- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -90,10 +90,13 @@ pub struct AutoDiffAttrs { pub input_activity: Vec, } -fn name(x: &NestedMetaItem) -> String { +fn first_ident(x: &NestedMetaItem) -> rustc_span::symbol::Ident { let segments = &x.meta_item().unwrap().path.segments; assert!(segments.len() == 1); - segments[0].ident.name.to_string() + segments[0].ident +} +fn name(x: &NestedMetaItem) -> String { + first_ident(x).name.to_string() } impl AutoDiffAttrs{ @@ -143,6 +146,13 @@ impl AutoDiffAttrs { input_activity: Vec::new(), } } + pub fn source() -> Self { + AutoDiffAttrs { + mode: DiffMode::Source, + ret_activity: DiffActivity::None, + input_activity: Vec::new(), + } + } pub fn is_active(&self) -> bool { match self.mode { diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index fa1dc8e330436..8b414930fdc76 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -1,5 +1,6 @@ #![allow(unused_imports)] #![allow(unused_variables)] +#![allow(unused_mut)] //use crate::util::check_builtin_macro_attribute; //use crate::util::check_autodiff; @@ -9,12 +10,20 @@ use rustc_ast::ptr::P; use rustc_ast::{BindingAnnotation, ByRef}; use rustc_ast::{self as ast, FnHeader, FnSig, 
Generics, StmtKind, NestedMetaItem, MetaItemKind}; use rustc_ast::{Fn, ItemKind, Stmt, TyKind, Unsafe, PatKind}; +use rustc_ast::tokenstream::*; use rustc_expand::base::{Annotatable, ExtCtxt}; use rustc_span::symbol::{kw, sym, Ident}; use rustc_span::Span; use thin_vec::{thin_vec, ThinVec}; use rustc_span::Symbol; use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode}; +use rustc_ast::token::{Token, TokenKind}; + +fn first_ident(x: &NestedMetaItem) -> rustc_span::symbol::Ident { + let segments = &x.meta_item().unwrap().path.segments; + assert!(segments.len() == 1); + segments[0].ident +} pub fn expand( ecx: &mut ExtCtxt<'_>, @@ -33,7 +42,8 @@ pub fn expand( return vec![item]; } }; - let orig_item: P = item.clone().expect_item(); + let mut orig_item: P = item.clone().expect_item(); + //dbg!(&orig_item.tokens); let primal = orig_item.ident.clone(); // Allow using `#[autodiff(...)]` only on a Fn @@ -47,6 +57,25 @@ pub fn expand( .emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); return vec![item]; }; + // create TokenStream from vec elemtents: + // meta_item doesn't have a .tokens field + let ts: Vec = meta_item_vec.clone()[1..].iter().map(|x| { + let val = first_ident(x); + let t = Token::from_ast_ident(val); + t + }).collect(); + let comma: Token = Token::new(TokenKind::Comma, Span::default()); + let mut ts: Vec = vec![]; + for t in meta_item_vec.clone()[1..].iter() { + let val = first_ident(t); + let t = Token::from_ast_ident(val); + ts.push(TokenTree::Token(t, Spacing::Joint)); + ts.push(TokenTree::Token(comma.clone(), Spacing::Alone)); + } + dbg!(&ts); + let ts: TokenStream = TokenStream::from_iter(ts); + dbg!(&ts); + let x: AutoDiffAttrs = AutoDiffAttrs::from_ast(&meta_item_vec, has_ret); dbg!(&x); //let span = ecx.with_def_site_ctxt(sig_span); @@ -66,7 +95,29 @@ pub fn expand( generics: Generics::default(), body: Some(d_body), })); - let d_fn = ecx.item(span, d_ident, rustc_ast::AttrVec::default(), asdf); + let mut tmp 
= P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::autodiff_into))); + let mut attr: ast::Attribute = ast::Attribute { + kind: ast::AttrKind::Normal(tmp.clone()), + id: ast::AttrId::from_u32(0), + style: ast::AttrStyle::Outer, + span: span, + }; + orig_item.attrs.push(attr); + + // Now update for d_fn + tmp.item.args = rustc_ast::AttrArgs::Delimited(rustc_ast::DelimArgs { + dspan: DelimSpan::dummy(), + delim: rustc_ast::token::Delimiter::Parenthesis, + tokens: ts, + }); + let mut attr2: ast::Attribute = ast::Attribute { + kind: ast::AttrKind::Normal(tmp), + id: ast::AttrId::from_u32(0), + style: ast::AttrStyle::Outer, + span: span, + }; + let attr_vec: rustc_ast::AttrVec = thin_vec![attr2]; + let d_fn = ecx.item(span, d_ident, attr_vec, asdf); let orig_annotatable = Annotatable::Item(orig_item.clone()); let d_annotatable = Annotatable::Item(d_fn); diff --git a/compiler/rustc_codegen_llvm/src/lib.rs b/compiler/rustc_codegen_llvm/src/lib.rs index b52042f55e71d..6f08211a926c8 100644 --- a/compiler/rustc_codegen_llvm/src/lib.rs +++ b/compiler/rustc_codegen_llvm/src/lib.rs @@ -269,6 +269,7 @@ impl WriteBackendMethods for LlvmCodegenBackend { config: &ModuleConfig, ) -> Result<(), FatalError> { dbg!("cg_llvm autodiff"); + dbg!("Differentiating {} functions", diff_fncs.len()); unsafe { back::write::differentiate(module, cgcx, diff_fncs, typetrees, config) } } diff --git a/compiler/rustc_codegen_ssa/src/back/write.rs b/compiler/rustc_codegen_ssa/src/back/write.rs index 166a0852803e4..4328509d84a39 100644 --- a/compiler/rustc_codegen_ssa/src/back/write.rs +++ b/compiler/rustc_codegen_ssa/src/back/write.rs @@ -382,6 +382,7 @@ fn generate_lto_work( import_only_modules: Vec<(SerializedModule, WorkProduct)>, ) -> Vec<(WorkItem, u64)> { let _prof_timer = cgcx.prof.generic_activity("codegen_generate_lto_work"); + dbg!("Differentiating {} functions", autodiff.len()); if !needs_fat_lto.is_empty() { assert!(needs_thin_lto.is_empty()); diff --git 
a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs index 4a86909ce83b3..a60dbacba66fc 100644 --- a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs +++ b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs @@ -693,16 +693,15 @@ fn check_link_name_xor_ordinal( } fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { - //let attrs = tcx.get_attrs(id, sym::autodiff_into); - let attrs = tcx.get_attrs(id, sym::autodiff); + let attrs = tcx.get_attrs(id, sym::autodiff_into); let attrs = attrs .into_iter() - .filter(|attr| attr.name_or_empty() == sym::autodiff) - //.filter(|attr| attr.name_or_empty() == sym::autodiff_into) + .filter(|attr| attr.name_or_empty() == sym::autodiff_into) .collect::>(); - if attrs.len() > 0 { - dbg!("autodiff_attrs len = > 0: {}", attrs.len()); + + if !attrs.is_empty() { + dbg!("autodiff_attrs amount = {}", attrs.len()); } // check for exactly one autodiff attribute on extern block @@ -723,18 +722,12 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { let list = attr.meta_item_list().unwrap_or_default(); // empty autodiff attribute macros (i.e. `#[autodiff]`) are used to mark source functions - if list.len() == 0 { - return AutoDiffAttrs { - mode: DiffMode::Source, - ret_activity: DiffActivity::None, - input_activity: Vec::new(), - }; - } + if list.len() == 0 { return AutoDiffAttrs::source(); } let msg_ad_mode = "autodiff attribute must contain autodiff mode"; - let mode = match &list[0] { - NestedMetaItem::MetaItem(MetaItem { path: ref p2, kind: MetaItemKind::Word, .. }) => { - p2.segments.first().unwrap().ident + let (mode, list) = match list.split_first() { + Some((NestedMetaItem::MetaItem(MetaItem { path: ref p1, kind: MetaItemKind::Word, .. 
}), list)) => { + (p1.segments.first().unwrap().ident, list) } _ => { tcx.sess @@ -749,7 +742,6 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { // parse mode let msg_mode = "mode should be either forward or reverse"; let mode = match mode.as_str() { - //map(|x| x.as_str()) { "Forward" => DiffMode::Forward, "Reverse" => DiffMode::Reverse, _ => { @@ -763,9 +755,9 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { }; let msg_ret_activity = "autodiff attribute must contain the return activity"; - let ret_symbol = match &list[1] { - NestedMetaItem::MetaItem(MetaItem { path: ref p2, kind: MetaItemKind::Word, .. }) => { - p2.segments.first().unwrap().ident + let (ret_symbol, list) = match list.split_last() { + Some((NestedMetaItem::MetaItem(MetaItem { path: ref p1, kind: MetaItemKind::Word, .. }), list)) => { + (p1.segments.first().unwrap().ident, list) } _ => { tcx.sess @@ -792,7 +784,7 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { let msg_arg_activity = "autodiff attribute must contain the return activity"; let mut arg_activities: Vec = vec![]; - for arg in &list[2..] { + for arg in list { let arg_symbol = match arg { NestedMetaItem::MetaItem(MetaItem { path: ref p2, kind: MetaItemKind::Word, .. 
@@ -846,6 +838,7 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { if ret_activity == DiffActivity::Duplicated || ret_activity == DiffActivity::DuplicatedNoNeed { + dbg!("ret_activity = {:?}", ret_activity); tcx.sess .struct_span_err( attr.span, msg_rev_incompatible_arg, diff --git a/compiler/rustc_feature/src/builtin_attrs.rs b/compiler/rustc_feature/src/builtin_attrs.rs index 5523543cd4fb9..d84bcd07e3d1b 100644 --- a/compiler/rustc_feature/src/builtin_attrs.rs +++ b/compiler/rustc_feature/src/builtin_attrs.rs @@ -361,6 +361,13 @@ pub const BUILTIN_ATTRIBUTES: &[BuiltinAttribute] = &[ large_assignments, experimental!(move_size_limit) ), + // Autodiff + ungated!( + autodiff_into, Normal, + template!(Word, List: r#""...""#), + DuplicatesOk, + ), + // Entry point: gated!(unix_sigpipe, Normal, template!(Word, NameValueStr: "inherit|sig_ign|sig_dfl"), ErrorFollowing, experimental!(unix_sigpipe)), ungated!(start, Normal, template!(Word), WarnFollowing), diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs index fac9a215eb674..bbe78b527de7f 100644 --- a/compiler/rustc_monomorphize/src/partitioning.rs +++ b/compiler/rustc_monomorphize/src/partitioning.rs @@ -1184,7 +1184,7 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[Au println!("item: {:?}", item); let source = usage_map.used_map.get(&item).unwrap() .into_iter() - .filter_map(|item| match *item { + .find_map(|item| match *item { MonoItem::Fn(ref instance_s) => { let source_id = instance_s.def_id(); println!("source_id_inner: {:?}", source_id); @@ -1205,8 +1205,8 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[Au None } _ => None, - }) - .next(); + }); + //.next(); println!("source: {:?}", source); source.map(|inst| { diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index 60eae2c3db2d6..a77bc0860cb54 100644 --- a/compiler/rustc_span/src/symbol.rs 
+++ b/compiler/rustc_span/src/symbol.rs @@ -440,6 +440,7 @@ symbols! { augmented_assignments, auto_traits, autodiff, + autodiff_into, automatically_derived, avx, avx512_target_feature, From 3e895756ab0ea911ba3f4cb6d62aeb3cf95fe27c Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Wed, 24 Jan 2024 23:12:14 -0500 Subject: [PATCH 014/100] remove leftovers --- .../rustc_ast/src/expand/autodiff_attrs.rs | 22 --------------- compiler/rustc_ast/src/expand/typetree.rs | 28 ------------------- 2 files changed, 50 deletions(-) diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs index 3b77d1c347b7a..9acc068f76936 100644 --- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -2,7 +2,6 @@ use rustc_data_structures::stable_hasher::{HashStable, StableHasher};//, StableO use crate::HashStableContext; use crate::expand::typetree::TypeTree; use thin_vec::ThinVec; -//use rustc_expand::base::{Annotatable, ExtCtxt}; use std::str::FromStr; use crate::NestedMetaItem; @@ -130,14 +129,6 @@ impl AutoDiffAttrs{ } } -//impl HashStable for AutoDiffAttrs { -// fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { -// self.mode.hash_stable(hcx, hasher); -// self.ret_activity.hash_stable(hcx, hasher); -// self.input_activity.hash_stable(hcx, hasher); -// } -//} - impl AutoDiffAttrs { pub fn inactive() -> Self { AutoDiffAttrs { @@ -202,16 +193,3 @@ pub struct AutoDiffItem { pub inputs: Vec, pub output: TypeTree, } - -//impl HashStable for AutoDiffItem { -// fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { -// self.source.hash_stable(hcx, hasher); -// self.target.hash_stable(hcx, hasher); -// self.attrs.hash_stable(hcx, hasher); -// for tt in &self.inputs { -// tt.0.hash_stable(hcx, hasher); -// } -// //self.inputs.hash_stable(hcx, hasher); -// self.output.0.hash_stable(hcx, hasher); -// } -//} diff --git a/compiler/rustc_ast/src/expand/typetree.rs 
b/compiler/rustc_ast/src/expand/typetree.rs index 4b154650a4a0e..bab35d55e26b8 100644 --- a/compiler/rustc_ast/src/expand/typetree.rs +++ b/compiler/rustc_ast/src/expand/typetree.rs @@ -1,7 +1,4 @@ use std::fmt; -//use rustc_data_structures::stable_hasher::{HashStable};//, StableHasher}; -//use crate::HashStableContext; - #[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] pub enum Kind { @@ -13,22 +10,6 @@ pub enum Kind { Double, Unknown, } -//impl HashStable for Kind { -// fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { -// clause_kind_discriminant(self).hash_stable(hcx, hasher); -// } -//} -//fn clause_kind_discriminant(value: &Kind) -> usize { -// match value { -// Kind::Anything => 0, -// Kind::Integer => 1, -// Kind::Pointer => 2, -// Kind::Half => 3, -// Kind::Float => 4, -// Kind::Double => 5, -// Kind::Unknown => 6, -// } -//} #[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] pub struct TypeTree(pub Vec); @@ -41,15 +22,6 @@ pub struct Type { pub child: TypeTree, } -//impl HashStable for Type { -// fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { -// self.offset.hash_stable(hcx, hasher); -// self.size.hash_stable(hcx, hasher); -// self.kind.hash_stable(hcx, hasher); -// self.child.0.hash_stable(hcx, hasher); -// } -//} - impl Type { pub fn add_offset(self, add: isize) -> Self { let offset = match self.offset { From f17e81d8f0d4cceea1e427b5044b52d45f8e6511 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Wed, 24 Jan 2024 23:22:40 -0500 Subject: [PATCH 015/100] mark autodiff_into as builtin and permanently unstable --- compiler/rustc_feature/src/builtin_attrs.rs | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/compiler/rustc_feature/src/builtin_attrs.rs b/compiler/rustc_feature/src/builtin_attrs.rs index d84bcd07e3d1b..115907ec17ca1 100644 --- a/compiler/rustc_feature/src/builtin_attrs.rs +++ 
b/compiler/rustc_feature/src/builtin_attrs.rs @@ -361,13 +361,6 @@ pub const BUILTIN_ATTRIBUTES: &[BuiltinAttribute] = &[ large_assignments, experimental!(move_size_limit) ), - // Autodiff - ungated!( - autodiff_into, Normal, - template!(Word, List: r#""...""#), - DuplicatesOk, - ), - // Entry point: gated!(unix_sigpipe, Normal, template!(Word, NameValueStr: "inherit|sig_ign|sig_dfl"), ErrorFollowing, experimental!(unix_sigpipe)), ungated!(start, Normal, template!(Word), WarnFollowing), @@ -595,6 +588,9 @@ pub const BUILTIN_ATTRIBUTES: &[BuiltinAttribute] = &[ IMPL_DETAIL, ), rustc_attr!(rustc_proc_macro_decls, Normal, template!(Word), WarnFollowing, INTERNAL_UNSTABLE), + // Autodiff + rustc_attr!(autodiff_into, Normal, template!(Word, List: r#""...""#), DuplicatesOk, INTERNAL_UNSTABLE), + rustc_attr!( rustc_macro_transparency, Normal, template!(NameValueStr: "transparent|semitransparent|opaque"), ErrorFollowing, From c97f625cf1338025529d6e27d1b8a702a225db28 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Thu, 25 Jan 2024 14:43:35 -0500 Subject: [PATCH 016/100] follow rustc_ naming convention --- compiler/rustc_builtin_macros/src/autodiff.rs | 2 +- compiler/rustc_codegen_ssa/src/codegen_attrs.rs | 4 ++-- compiler/rustc_feature/src/builtin_attrs.rs | 3 +-- compiler/rustc_span/src/symbol.rs | 2 +- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 8b414930fdc76..27960d64092fc 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -95,7 +95,7 @@ pub fn expand( generics: Generics::default(), body: Some(d_body), })); - let mut tmp = P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::autodiff_into))); + let mut tmp = P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff))); let mut attr: ast::Attribute = ast::Attribute { kind: ast::AttrKind::Normal(tmp.clone()), id: 
ast::AttrId::from_u32(0), diff --git a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs index a60dbacba66fc..9184f2ce3e13a 100644 --- a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs +++ b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs @@ -693,11 +693,11 @@ fn check_link_name_xor_ordinal( } fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { - let attrs = tcx.get_attrs(id, sym::autodiff_into); + let attrs = tcx.get_attrs(id, sym::rustc_autodiff); let attrs = attrs .into_iter() - .filter(|attr| attr.name_or_empty() == sym::autodiff_into) + .filter(|attr| attr.name_or_empty() == sym::rustc_autodiff) .collect::>(); if !attrs.is_empty() { diff --git a/compiler/rustc_feature/src/builtin_attrs.rs b/compiler/rustc_feature/src/builtin_attrs.rs index 115907ec17ca1..1725ed887ec6e 100644 --- a/compiler/rustc_feature/src/builtin_attrs.rs +++ b/compiler/rustc_feature/src/builtin_attrs.rs @@ -588,8 +588,7 @@ pub const BUILTIN_ATTRIBUTES: &[BuiltinAttribute] = &[ IMPL_DETAIL, ), rustc_attr!(rustc_proc_macro_decls, Normal, template!(Word), WarnFollowing, INTERNAL_UNSTABLE), - // Autodiff - rustc_attr!(autodiff_into, Normal, template!(Word, List: r#""...""#), DuplicatesOk, INTERNAL_UNSTABLE), + rustc_attr!(rustc_autodiff, Normal, template!(Word, List: r#""...""#), DuplicatesOk, INTERNAL_UNSTABLE), rustc_attr!( rustc_macro_transparency, Normal, diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index a77bc0860cb54..b1c1f0feef8ad 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -440,7 +440,6 @@ symbols! { augmented_assignments, auto_traits, autodiff, - autodiff_into, automatically_derived, avx, avx512_target_feature, @@ -1375,6 +1374,7 @@ symbols! 
{ rustc_allow_incoherent_impl, rustc_allowed_through_unstable_modules, rustc_attrs, + rustc_autodiff, rustc_box, rustc_builtin_macro, rustc_capture_analysis, From 494b102f648cc98cdf30299fdfa4e59ac04d94a4 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Thu, 25 Jan 2024 18:02:09 -0500 Subject: [PATCH 017/100] It works (somewhat) --- compiler/rustc_builtin_macros/src/autodiff.rs | 67 +++++++++++-------- compiler/rustc_expand/src/build.rs | 4 ++ 2 files changed, 44 insertions(+), 27 deletions(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 27960d64092fc..deaea847beb9d 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -72,20 +72,15 @@ pub fn expand( ts.push(TokenTree::Token(t, Spacing::Joint)); ts.push(TokenTree::Token(comma.clone(), Spacing::Alone)); } - dbg!(&ts); let ts: TokenStream = TokenStream::from_iter(ts); - dbg!(&ts); let x: AutoDiffAttrs = AutoDiffAttrs::from_ast(&meta_item_vec, has_ret); dbg!(&x); - //let span = ecx.with_def_site_ctxt(sig_span); let span = ecx.with_def_site_ctxt(expand_span); - //let span = ecx.with_def_site_ctxt(fn_item.span); - let (d_sig, old_names, new_args) = gen_enzyme_decl(ecx, &sig, &x, span, sig_span); + let (d_sig, old_names, new_args, idents) = gen_enzyme_decl(ecx, &sig, &x, span, sig_span); let new_decl_span = d_sig.span; - //let d_body = gen_enzyme_body(ecx, primal, &old_names, &new_args, span, span); - let d_body = gen_enzyme_body(ecx, primal, &old_names, &new_args, span, sig_span, new_decl_span); + let d_body = gen_enzyme_body(ecx, primal, &old_names, &new_args, span, sig_span, new_decl_span, &sig, &d_sig, idents); let d_ident = meta_item_vec[0].meta_item().unwrap().path.segments[0].ident; // The first element of it is the name of the function to be generated @@ -147,14 +142,13 @@ fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty { // The second will just take the shadow arguments. 
// The third will (unsafely) call std::mem::zeroed(), to match the return type of the new function // (whatever that might be). This way we surpress rustc from optimizing anyt argument away. -fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_names: &[String], span: Span, sig_span: Span, new_decl_span: Span) -> P { +fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_names: &[String], span: Span, sig_span: Span, new_decl_span: Span, sig: &ast::FnSig, d_sig: &ast::FnSig, idents: Vec) -> P { let blackbox_path = ecx.std_path(&[Symbol::intern("hint"), Symbol::intern("black_box")]); let zeroed_path = ecx.std_path(&[Symbol::intern("mem"), Symbol::intern("zeroed")]); let empty_loop_block = ecx.block(span, ThinVec::new()); let loop_expr = ecx.expr_loop(span, empty_loop_block); - let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal)); let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path)); let zeroed_call_expr = ecx.expr_path(ecx.path(span, zeroed_path)); @@ -172,18 +166,11 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n could_be_bare_literal: false, })); // create ::core::hint::black_box(array(arr)); - let primal_call = ecx.expr_call( - new_decl_span, - primal_call_expr, - old_names.iter().map(|name| { - ecx.expr_path(ecx.path_ident(new_decl_span, Ident::from_str(name))) - }).collect(), - ); - let black_box0 = ecx.expr_call( - new_decl_span, - blackbox_call_expr.clone(), - thin_vec![primal_call.clone()], - ); + //let black_box0 = ecx.expr_call( + // new_decl_span, + // blackbox_call_expr.clone(), + // thin_vec![primal_call.clone()], + //); // create ::core::hint::black_box(grad_arr, tang_y)); let black_box1 = ecx.expr_call( @@ -201,15 +188,38 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n thin_vec![unsafe_block_with_zeroed_call.clone()], ); + let primal_call = gen_primal_call(ecx, span, primal, sig, idents); + let mut 
body = ecx.block(span, ThinVec::new()); - //body.stmts.push(ecx.stmt_expr(primal_call)); + body.stmts.push(ecx.stmt_semi(primal_call)); //body.stmts.push(ecx.stmt_expr(black_box0)); //body.stmts.push(ecx.stmt_expr(black_box1)); - body.stmts.push(ecx.stmt_expr(black_box2)); + //body.stmts.push(ecx.stmt_expr(black_box2)); body.stmts.push(ecx.stmt_expr(loop_expr)); body } +fn gen_primal_call(ecx: &ExtCtxt<'_>, span: Span, primal: Ident, sig: &ast::FnSig, idents: Vec) -> P{ +//pub struct Param { +// pub attrs: AttrVec, +// pub ty: P, +// pub pat: P, +// pub id: NodeId, +// pub span: Span, +// pub is_placeholder: bool, +//} + let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal)); + let args = idents.iter().map(|arg| { + ecx.expr_path(ecx.path_ident(span, *arg)) + }).collect(); + let primal_call = ecx.expr_call( + span, + primal_call_expr, + args, + ); + primal_call +} + // Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must // be pointers or references. Those receive a shadow argument, which is a mutable reference/pointer. // Active arguments must be scalars. Their shadow argument is added to the return type (and will be @@ -217,7 +227,7 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n // Each argument of the primal function (and the return type if existing) must be annotated with an // activity. 
fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span: Span, _sig_span: Span) - -> (ast::FnSig, Vec, Vec) { + -> (ast::FnSig, Vec, Vec, Vec) { let decl: P = sig.decl.clone(); assert!(decl.inputs.len() == x.input_activity.len()); assert!(decl.output.has_ret() == x.has_ret_activity()); @@ -225,6 +235,7 @@ fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span let mut d_inputs = Vec::new(); let mut new_inputs = Vec::new(); let mut old_names = Vec::new(); + let mut idents = Vec::new(); for (arg, activity) in decl.inputs.iter().zip(x.input_activity.iter()) { //dbg!(&arg); d_inputs.push(arg.clone()); @@ -234,6 +245,7 @@ fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span shadow_arg.ty = P(assure_mut_ref(&arg.ty)); // adjust name depending on mode let old_name = if let PatKind::Ident(_, ident, _) = shadow_arg.pat.kind { + idents.push(ident.clone()); ident.name } else { dbg!(&shadow_arg.pat); @@ -247,17 +259,18 @@ fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span }; dbg!(&name); new_inputs.push(name.clone()); + let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span); shadow_arg.pat = P(ast::Pat { // TODO: Check id id: ast::DUMMY_NODE_ID, kind: PatKind::Ident(BindingAnnotation::NONE, - Ident::from_str_and_span(&name, shadow_arg.pat.span), + ident, None, ), span: shadow_arg.pat.span, tokens: shadow_arg.pat.tokens.clone(), }); - + //idents.push(ident); d_inputs.push(shadow_arg); } _ => {}, @@ -269,5 +282,5 @@ fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span decl: d_decl, span, }; - (d_sig, old_names, new_inputs) + (d_sig, old_names, new_inputs, idents) } diff --git a/compiler/rustc_expand/src/build.rs b/compiler/rustc_expand/src/build.rs index 51d87ac03487e..1ee77d4a6429f 100644 --- a/compiler/rustc_expand/src/build.rs +++ b/compiler/rustc_expand/src/build.rs @@ -157,6 +157,10 @@ impl<'a> ExtCtxt<'a> { ast::Stmt { id: 
ast::DUMMY_NODE_ID, span: expr.span, kind: ast::StmtKind::Expr(expr) } } + pub fn stmt_semi(&self, expr: P) -> ast::Stmt { + ast::Stmt { id: ast::DUMMY_NODE_ID, span: expr.span, kind: ast::StmtKind::Semi(expr) } + } + pub fn stmt_let_pat(&self, sp: Span, pat: P, ex: P) -> ast::Stmt { let local = P(ast::Local { pat, From 66f86a636eb70806af6563760f7f03542d286936 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Fri, 2 Feb 2024 16:45:24 -0500 Subject: [PATCH 018/100] Fwd: Rename Duplicated -> Dual --- .../rustc_ast/src/expand/autodiff_attrs.rs | 40 +++---------------- compiler/rustc_builtin_macros/src/autodiff.rs | 2 +- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 2 + 3 files changed, 9 insertions(+), 35 deletions(-) diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs index 9acc068f76936..20070f20d3092 100644 --- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -1,5 +1,3 @@ -use rustc_data_structures::stable_hasher::{HashStable, StableHasher};//, StableOrd}; -use crate::HashStableContext; use crate::expand::typetree::TypeTree; use thin_vec::ThinVec; use std::str::FromStr; @@ -7,7 +5,7 @@ use std::str::FromStr; use crate::NestedMetaItem; #[allow(dead_code)] -#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug)] +#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] pub enum DiffMode { Inactive, Source, @@ -16,44 +14,16 @@ pub enum DiffMode { } #[allow(dead_code)] -#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug)] +#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] pub enum DiffActivity { None, Active, Const, + Dual, + DualNoNeed, Duplicated, DuplicatedNoNeed, } -fn clause_diffactivity_discriminant(value: &DiffActivity) -> usize { - match value { - DiffActivity::None => 0, - DiffActivity::Active => 1, - DiffActivity::Const => 2, - 
DiffActivity::Duplicated => 3, - DiffActivity::DuplicatedNoNeed => 4, - } -} -fn clause_diffmode_discriminant(value: &DiffMode) -> usize { - match value { - DiffMode::Inactive => 0, - DiffMode::Source => 1, - DiffMode::Forward => 2, - DiffMode::Reverse => 3, - } -} - - -impl HashStable for DiffMode { - fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { - clause_diffmode_discriminant(self).hash_stable(hcx, hasher); - } -} - -impl HashStable for DiffActivity { - fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { - clause_diffactivity_discriminant(self).hash_stable(hcx, hasher); - } -} impl FromStr for DiffMode { type Err = (); @@ -74,6 +44,8 @@ impl FromStr for DiffActivity { "None" => Ok(DiffActivity::None), "Active" => Ok(DiffActivity::Active), "Const" => Ok(DiffActivity::Const), + "Dual" => Ok(DiffActivity::Dual), + "DualNoNeed" => Ok(DiffActivity::DualNoNeed), "Duplicated" => Ok(DiffActivity::Duplicated), "DuplicatedNoNeed" => Ok(DiffActivity::DuplicatedNoNeed), _ => Err(()), diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index deaea847beb9d..ad0a398141f45 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -240,7 +240,7 @@ fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span //dbg!(&arg); d_inputs.push(arg.clone()); match activity { - DiffActivity::Duplicated => { + DiffActivity::Duplicated | DiffActivity::Dual => { let mut shadow_arg = arg.clone(); shadow_arg.ty = P(assure_mut_ref(&arg.ty)); // adjust name depending on mode diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index e479229271d95..2c3f0df1e5e2c 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -2689,6 +2689,8 @@ fn cdiffe_from(act: DiffActivity) -> CDIFFE_TYPE { DiffActivity::None => CDIFFE_TYPE::DFT_CONSTANT, 
DiffActivity::Active => CDIFFE_TYPE::DFT_OUT_DIFF, DiffActivity::Const => CDIFFE_TYPE::DFT_CONSTANT, + DiffActivity::Dual => CDIFFE_TYPE::DFT_DUP_ARG, + DiffActivity::DualNoNeed => CDIFFE_TYPE::DFT_DUP_NONEED, DiffActivity::Duplicated => CDIFFE_TYPE::DFT_DUP_ARG, DiffActivity::DuplicatedNoNeed => CDIFFE_TYPE::DFT_DUP_NONEED, }; From 59bc30fe9a23041257325bdd902678b81aefff06 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sat, 3 Feb 2024 19:51:07 -0500 Subject: [PATCH 019/100] properly allow interior instability --- library/core/src/macros/mod.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/library/core/src/macros/mod.rs b/library/core/src/macros/mod.rs index 510c53f1fafcf..041d9990354ca 100644 --- a/library/core/src/macros/mod.rs +++ b/library/core/src/macros/mod.rs @@ -1493,6 +1493,7 @@ pub(crate) mod builtin { /// /// [the reference]: ../../../reference/attributes/derive.html #[unstable(feature = "autodiff", issue = "none")] + #[allow_internal_unstable(rustc_attrs)] #[rustc_builtin_macro] #[cfg(not(bootstrap))] pub macro autodiff($item:item) { From 12b180e756db7281161e5d5f72ba3625252249d3 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sat, 3 Feb 2024 19:52:02 -0500 Subject: [PATCH 020/100] make d_fnc decl a bit more precise --- compiler/rustc_builtin_macros/src/autodiff.rs | 83 ++++++++++++------- 1 file changed, 54 insertions(+), 29 deletions(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index ad0a398141f45..f394899d6cba2 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -90,23 +90,23 @@ pub fn expand( generics: Generics::default(), body: Some(d_body), })); - let mut tmp = P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff))); + let mut rustc_ad_attr = P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff))); let mut attr: ast::Attribute = ast::Attribute { - kind: 
ast::AttrKind::Normal(tmp.clone()), + kind: ast::AttrKind::Normal(rustc_ad_attr.clone()), id: ast::AttrId::from_u32(0), style: ast::AttrStyle::Outer, span: span, }; - orig_item.attrs.push(attr); + orig_item.attrs.push(attr.clone()); // Now update for d_fn - tmp.item.args = rustc_ast::AttrArgs::Delimited(rustc_ast::DelimArgs { + rustc_ad_attr.item.args = rustc_ast::AttrArgs::Delimited(rustc_ast::DelimArgs { dspan: DelimSpan::dummy(), delim: rustc_ast::token::Delimiter::Parenthesis, tokens: ts, }); let mut attr2: ast::Attribute = ast::Attribute { - kind: ast::AttrKind::Normal(tmp), + kind: ast::AttrKind::Normal(rustc_ad_attr), id: ast::AttrId::from_u32(0), style: ast::AttrStyle::Outer, span: span, @@ -165,12 +165,13 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n tokens: None, could_be_bare_literal: false, })); + let primal_call = gen_primal_call(ecx, span, primal, sig, idents); // create ::core::hint::black_box(array(arr)); - //let black_box0 = ecx.expr_call( - // new_decl_span, - // blackbox_call_expr.clone(), - // thin_vec![primal_call.clone()], - //); + let black_box0 = ecx.expr_call( + new_decl_span, + blackbox_call_expr.clone(), + thin_vec![primal_call.clone()], + ); // create ::core::hint::black_box(grad_arr, tang_y)); let black_box1 = ecx.expr_call( @@ -188,26 +189,17 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n thin_vec![unsafe_block_with_zeroed_call.clone()], ); - let primal_call = gen_primal_call(ecx, span, primal, sig, idents); let mut body = ecx.block(span, ThinVec::new()); body.stmts.push(ecx.stmt_semi(primal_call)); - //body.stmts.push(ecx.stmt_expr(black_box0)); - //body.stmts.push(ecx.stmt_expr(black_box1)); - //body.stmts.push(ecx.stmt_expr(black_box2)); + body.stmts.push(ecx.stmt_semi(black_box0)); + body.stmts.push(ecx.stmt_semi(black_box1)); + //body.stmts.push(ecx.stmt_semi(black_box2)); body.stmts.push(ecx.stmt_expr(loop_expr)); body } fn gen_primal_call(ecx: 
&ExtCtxt<'_>, span: Span, primal: Ident, sig: &ast::FnSig, idents: Vec) -> P{ -//pub struct Param { -// pub attrs: AttrVec, -// pub ty: P, -// pub pat: P, -// pub id: NodeId, -// pub span: Span, -// pub is_placeholder: bool, -//} let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal)); let args = idents.iter().map(|arg| { ecx.expr_path(ecx.path_ident(span, *arg)) @@ -228,16 +220,14 @@ fn gen_primal_call(ecx: &ExtCtxt<'_>, span: Span, primal: Ident, sig: &ast::FnSi // activity. fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span: Span, _sig_span: Span) -> (ast::FnSig, Vec, Vec, Vec) { - let decl: P = sig.decl.clone(); - assert!(decl.inputs.len() == x.input_activity.len()); - assert!(decl.output.has_ret() == x.has_ret_activity()); - let mut d_decl = decl.clone(); + assert!(sig.decl.inputs.len() == x.input_activity.len()); + assert!(sig.decl.output.has_ret() == x.has_ret_activity()); + let mut d_decl = sig.decl.clone(); let mut d_inputs = Vec::new(); let mut new_inputs = Vec::new(); let mut old_names = Vec::new(); let mut idents = Vec::new(); - for (arg, activity) in decl.inputs.iter().zip(x.input_activity.iter()) { - //dbg!(&arg); + for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) { d_inputs.push(arg.clone()); match activity { DiffActivity::Duplicated | DiffActivity::Dual => { @@ -273,7 +263,42 @@ fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span //idents.push(ident); d_inputs.push(shadow_arg); } - _ => {}, + _ => {dbg!(&activity);}, + } + } + + // If we return a scalar in the primal and the scalar is active, + // then add it as last arg to the inputs. 
+ if x.mode == DiffMode::Reverse { + match x.ret_activity { + DiffActivity::Active => { + let ty = match d_decl.output { + rustc_ast::FnRetTy::Ty(ref ty) => ty.clone(), + rustc_ast::FnRetTy::Default(span) => { + panic!("Did not expect Default ret ty: {:?}", span); + } + }; + let name = "dret".to_string(); + let ident = Ident::from_str_and_span(&name, ty.span); + let shadow_arg = ast::Param { + attrs: ThinVec::new(), + ty: ty.clone(), + pat: P(ast::Pat { + id: ast::DUMMY_NODE_ID, + kind: PatKind::Ident(BindingAnnotation::NONE, + ident, + None, + ), + span: ty.span, + tokens: None, + }), + id: ast::DUMMY_NODE_ID, + span: ty.span, + is_placeholder: false, + }; + d_inputs.push(shadow_arg); + } + _ => {} } } d_decl.inputs = d_inputs.into(); From 47d6d3c0bdfb689a48ea7f1796cf0aca18baac0a Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sun, 4 Feb 2024 23:19:11 -0500 Subject: [PATCH 021/100] running x.py fmt --- .../rustc_ast/src/expand/autodiff_attrs.rs | 31 ++-- compiler/rustc_ast/src/expand/mod.rs | 2 +- compiler/rustc_builtin_macros/src/autodiff.rs | 135 ++++++++++-------- compiler/rustc_codegen_llvm/src/back/lto.rs | 7 +- compiler/rustc_codegen_llvm/src/back/write.rs | 133 +++++++++++------ compiler/rustc_codegen_llvm/src/errors.rs | 10 +- compiler/rustc_codegen_llvm/src/lib.rs | 8 +- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 47 ++++-- compiler/rustc_codegen_ssa/src/back/lto.rs | 8 +- compiler/rustc_codegen_ssa/src/back/write.rs | 13 +- .../rustc_codegen_ssa/src/codegen_attrs.rs | 31 ++-- compiler/rustc_middle/src/query/mod.rs | 2 +- .../rustc_monomorphize/src/partitioning.rs | 47 ++---- library/std/src/lib.rs | 4 - src/bootstrap/src/core/build_steps/compile.rs | 3 +- 15 files changed, 266 insertions(+), 215 deletions(-) diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs index 20070f20d3092..a0f632e1ba105 100644 --- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs +++ 
b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -1,6 +1,6 @@ use crate::expand::typetree::TypeTree; -use thin_vec::ThinVec; use std::str::FromStr; +use thin_vec::ThinVec; use crate::NestedMetaItem; @@ -28,7 +28,8 @@ pub enum DiffActivity { impl FromStr for DiffMode { type Err = (); - fn from_str(s: &str) -> Result { match s { + fn from_str(s: &str) -> Result { + match s { "Inactive" => Ok(DiffMode::Inactive), "Source" => Ok(DiffMode::Source), "Forward" => Ok(DiffMode::Forward), @@ -40,7 +41,8 @@ impl FromStr for DiffMode { impl FromStr for DiffActivity { type Err = (); - fn from_str(s: &str) -> Result { match s { + fn from_str(s: &str) -> Result { + match s { "None" => Ok(DiffActivity::None), "Active" => Ok(DiffActivity::Active), "Const" => Ok(DiffActivity::Const), @@ -70,7 +72,7 @@ fn name(x: &NestedMetaItem) -> String { first_ident(x).name.to_string() } -impl AutoDiffAttrs{ +impl AutoDiffAttrs { pub fn has_ret_activity(&self) -> bool { match self.ret_activity { DiffActivity::None => false, @@ -80,10 +82,13 @@ impl AutoDiffAttrs{ pub fn from_ast(meta_item: &ThinVec, has_ret: bool) -> Self { let mode = name(&meta_item[1]); let mode = DiffMode::from_str(&mode).unwrap(); - let activities: Vec = meta_item[2..].iter().map(|x| { - let activity_str = name(&x); - DiffActivity::from_str(&activity_str).unwrap() - }).collect(); + let activities: Vec = meta_item[2..] + .iter() + .map(|x| { + let activity_str = name(&x); + DiffActivity::from_str(&activity_str).unwrap() + }) + .collect(); // If a return type exist, we need to split the last activity, // otherwise we return None as placeholder. 
@@ -93,11 +98,7 @@ impl AutoDiffAttrs{ (&DiffActivity::None, activities.as_slice()) }; - AutoDiffAttrs { - mode, - ret_activity: *ret_activity, - input_activity: input_activity.to_vec(), - } + AutoDiffAttrs { mode, ret_activity: *ret_activity, input_activity: input_activity.to_vec() } } } @@ -123,7 +124,7 @@ impl AutoDiffAttrs { _ => { dbg!(&self); true - }, + } } } @@ -141,7 +142,7 @@ impl AutoDiffAttrs { _ => { dbg!(&self); true - }, + } } } diff --git a/compiler/rustc_ast/src/expand/mod.rs b/compiler/rustc_ast/src/expand/mod.rs index b8434374a3594..29cc9121aebbd 100644 --- a/compiler/rustc_ast/src/expand/mod.rs +++ b/compiler/rustc_ast/src/expand/mod.rs @@ -5,8 +5,8 @@ use rustc_span::{def_id::DefId, symbol::Ident}; use crate::MetaItem; pub mod allocator; -pub mod typetree; pub mod autodiff_attrs; +pub mod typetree; #[derive(Debug, Clone, Encodable, Decodable, HashStable_Generic)] pub struct StrippedCfgItem { diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index f394899d6cba2..a8c2d30472597 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -4,20 +4,20 @@ //use crate::util::check_builtin_macro_attribute; //use crate::util::check_autodiff; -use std::string::String; use crate::errors; +use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode}; use rustc_ast::ptr::P; -use rustc_ast::{BindingAnnotation, ByRef}; -use rustc_ast::{self as ast, FnHeader, FnSig, Generics, StmtKind, NestedMetaItem, MetaItemKind}; -use rustc_ast::{Fn, ItemKind, Stmt, TyKind, Unsafe, PatKind}; +use rustc_ast::token::{Token, TokenKind}; use rustc_ast::tokenstream::*; +use rustc_ast::{self as ast, FnHeader, FnSig, Generics, MetaItemKind, NestedMetaItem, StmtKind}; +use rustc_ast::{BindingAnnotation, ByRef}; +use rustc_ast::{Fn, ItemKind, PatKind, Stmt, TyKind, Unsafe}; use rustc_expand::base::{Annotatable, ExtCtxt}; use rustc_span::symbol::{kw, sym, 
Ident}; use rustc_span::Span; -use thin_vec::{thin_vec, ThinVec}; use rustc_span::Symbol; -use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode}; -use rustc_ast::token::{Token, TokenKind}; +use std::string::String; +use thin_vec::{thin_vec, ThinVec}; fn first_ident(x: &NestedMetaItem) -> rustc_span::symbol::Ident { let segments = &x.meta_item().unwrap().path.segments; @@ -36,9 +36,7 @@ pub fn expand( let meta_item_vec: ThinVec = match meta_item.kind { ast::MetaItemKind::List(ref vec) => vec.clone(), _ => { - ecx.sess - .dcx() - .emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); + ecx.sess.dcx().emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); return vec![item]; } }; @@ -52,18 +50,19 @@ pub fn expand( { (item, sig.decl.output.has_ret(), sig, ecx.with_call_site_ctxt(sig.span)) } else { - ecx.sess - .dcx() - .emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); + ecx.sess.dcx().emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); return vec![item]; }; // create TokenStream from vec elemtents: // meta_item doesn't have a .tokens field - let ts: Vec = meta_item_vec.clone()[1..].iter().map(|x| { - let val = first_ident(x); - let t = Token::from_ast_ident(val); - t - }).collect(); + let ts: Vec = meta_item_vec.clone()[1..] 
+ .iter() + .map(|x| { + let val = first_ident(x); + let t = Token::from_ast_ident(val); + t + }) + .collect(); let comma: Token = Token::new(TokenKind::Comma, Span::default()); let mut ts: Vec = vec![]; for t in meta_item_vec.clone()[1..].iter() { @@ -80,7 +79,18 @@ pub fn expand( let (d_sig, old_names, new_args, idents) = gen_enzyme_decl(ecx, &sig, &x, span, sig_span); let new_decl_span = d_sig.span; - let d_body = gen_enzyme_body(ecx, primal, &old_names, &new_args, span, sig_span, new_decl_span, &sig, &d_sig, idents); + let d_body = gen_enzyme_body( + ecx, + primal, + &old_names, + &new_args, + span, + sig_span, + new_decl_span, + &sig, + &d_sig, + idents, + ); let d_ident = meta_item_vec[0].meta_item().unwrap().path.segments[0].ident; // The first element of it is the name of the function to be generated @@ -90,7 +100,8 @@ pub fn expand( generics: Generics::default(), body: Some(d_body), })); - let mut rustc_ad_attr = P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff))); + let mut rustc_ad_attr = + P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff))); let mut attr: ast::Attribute = ast::Attribute { kind: ast::AttrKind::Normal(rustc_ad_attr.clone()), id: ast::AttrId::from_u32(0), @@ -136,27 +147,33 @@ fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty { ty } - // The body of our generated functions will consist of three black_Box calls. // The first will call the primal function with the original arguments. // The second will just take the shadow arguments. // The third will (unsafely) call std::mem::zeroed(), to match the return type of the new function // (whatever that might be). This way we surpress rustc from optimizing anyt argument away. 
-fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_names: &[String], span: Span, sig_span: Span, new_decl_span: Span, sig: &ast::FnSig, d_sig: &ast::FnSig, idents: Vec) -> P { +fn gen_enzyme_body( + ecx: &ExtCtxt<'_>, + primal: Ident, + old_names: &[String], + new_names: &[String], + span: Span, + sig_span: Span, + new_decl_span: Span, + sig: &ast::FnSig, + d_sig: &ast::FnSig, + idents: Vec, +) -> P { let blackbox_path = ecx.std_path(&[Symbol::intern("hint"), Symbol::intern("black_box")]); let zeroed_path = ecx.std_path(&[Symbol::intern("mem"), Symbol::intern("zeroed")]); let empty_loop_block = ecx.block(span, ThinVec::new()); let loop_expr = ecx.expr_loop(span, empty_loop_block); - let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path)); let zeroed_call_expr = ecx.expr_path(ecx.path(span, zeroed_path)); - let mem_zeroed_call: Stmt = ecx.stmt_expr(ecx.expr_call( - span, - zeroed_call_expr.clone(), - thin_vec![], - )); + let mem_zeroed_call: Stmt = + ecx.stmt_expr(ecx.expr_call(span, zeroed_call_expr.clone(), thin_vec![])); let unsafe_block_with_zeroed_call: P = ecx.expr_block(P(ast::Block { stmts: thin_vec![mem_zeroed_call], id: ast::DUMMY_NODE_ID, @@ -167,19 +184,17 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n })); let primal_call = gen_primal_call(ecx, span, primal, sig, idents); // create ::core::hint::black_box(array(arr)); - let black_box0 = ecx.expr_call( - new_decl_span, - blackbox_call_expr.clone(), - thin_vec![primal_call.clone()], - ); + let black_box0 = + ecx.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![primal_call.clone()]); // create ::core::hint::black_box(grad_arr, tang_y)); let black_box1 = ecx.expr_call( sig_span, blackbox_call_expr.clone(), - new_names.iter().map(|arg| { - ecx.expr_path(ecx.path_ident(span, Ident::from_str(arg))) - }).collect(), + new_names + .iter() + .map(|arg| ecx.expr_path(ecx.path_ident(span, Ident::from_str(arg)))) + 
.collect(), ); // create ::core::hint::black_box(unsafe { ::core::mem::zeroed() }) @@ -189,7 +204,6 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n thin_vec![unsafe_block_with_zeroed_call.clone()], ); - let mut body = ecx.block(span, ThinVec::new()); body.stmts.push(ecx.stmt_semi(primal_call)); body.stmts.push(ecx.stmt_semi(black_box0)); @@ -199,16 +213,16 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n body } -fn gen_primal_call(ecx: &ExtCtxt<'_>, span: Span, primal: Ident, sig: &ast::FnSig, idents: Vec) -> P{ +fn gen_primal_call( + ecx: &ExtCtxt<'_>, + span: Span, + primal: Ident, + sig: &ast::FnSig, + idents: Vec, +) -> P { let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal)); - let args = idents.iter().map(|arg| { - ecx.expr_path(ecx.path_ident(span, *arg)) - }).collect(); - let primal_call = ecx.expr_call( - span, - primal_call_expr, - args, - ); + let args = idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect(); + let primal_call = ecx.expr_call(span, primal_call_expr, args); primal_call } @@ -218,8 +232,13 @@ fn gen_primal_call(ecx: &ExtCtxt<'_>, span: Span, primal: Ident, sig: &ast::FnSi // zero-initialized by Enzyme). Active arguments are not handled yet. // Each argument of the primal function (and the return type if existing) must be annotated with an // activity. 
-fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span: Span, _sig_span: Span) - -> (ast::FnSig, Vec, Vec, Vec) { +fn gen_enzyme_decl( + _ecx: &ExtCtxt<'_>, + sig: &ast::FnSig, + x: &AutoDiffAttrs, + span: Span, + _sig_span: Span, +) -> (ast::FnSig, Vec, Vec, Vec) { assert!(sig.decl.inputs.len() == x.input_activity.len()); assert!(sig.decl.output.has_ret() == x.has_ret_activity()); let mut d_decl = sig.decl.clone(); @@ -253,17 +272,16 @@ fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span shadow_arg.pat = P(ast::Pat { // TODO: Check id id: ast::DUMMY_NODE_ID, - kind: PatKind::Ident(BindingAnnotation::NONE, - ident, - None, - ), + kind: PatKind::Ident(BindingAnnotation::NONE, ident, None), span: shadow_arg.pat.span, tokens: shadow_arg.pat.tokens.clone(), }); //idents.push(ident); d_inputs.push(shadow_arg); } - _ => {dbg!(&activity);}, + _ => { + dbg!(&activity); + } } } @@ -285,10 +303,7 @@ fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span ty: ty.clone(), pat: P(ast::Pat { id: ast::DUMMY_NODE_ID, - kind: PatKind::Ident(BindingAnnotation::NONE, - ident, - None, - ), + kind: PatKind::Ident(BindingAnnotation::NONE, ident, None), span: ty.span, tokens: None, }), @@ -302,10 +317,6 @@ fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span } } d_decl.inputs = d_inputs.into(); - let d_sig = FnSig { - header: sig.header.clone(), - decl: d_decl, - span, - }; + let d_sig = FnSig { header: sig.header.clone(), decl: d_decl, span }; (d_sig, old_names, new_inputs, idents) } diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index 55bf6b295d27b..cf5badf7b99c1 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -714,7 +714,12 @@ pub unsafe fn optimize_thin_module( let llcx = llvm::LLVMRustContextCreate(cgcx.fewer_names); let llmod_raw = parse_module(llcx, 
module_name, thin_module.data(), &dcx)? as *const _; let mut module = ModuleCodegen { - module_llvm: ModuleLlvm { llmod_raw, llcx, tm: ManuallyDrop::new(tm), typetrees: Default::default() }, + module_llvm: ModuleLlvm { + llmod_raw, + llcx, + tm: ManuallyDrop::new(tm), + typetrees: Default::default(), + }, name: thin_module.name().to_string(), kind: ModuleKind::Regular, }; diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index c726b62912ef1..dc1f988617377 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -5,26 +5,29 @@ use crate::back::profiling::{ }; use crate::base; use crate::common; -use crate::DiffTypeTree; use crate::errors::{ CopyBitcode, FromLlvmDiag, FromLlvmOptimizationDiag, LlvmError, UnknownCompression, WithLlvmError, WriteBytecode, }; use crate::llvm::{self, DiagnosticInfo, PassManager}; -use crate::llvm::{LLVMReplaceAllUsesWith, LLVMVerifyFunction, Value, LLVMRustGetEnumAttributeAtIndex, LLVMRustAddEnumAttributeAtIndex, LLVMRustRemoveEnumAttributeAtIndex, - enzyme_rust_forward_diff, enzyme_rust_reverse_diff, BasicBlock, CreateEnzymeLogic, - CreateTypeAnalysis, EnzymeLogicRef, EnzymeTypeAnalysisRef, LLVMAddFunction, +use crate::llvm::{ + enzyme_rust_forward_diff, enzyme_rust_reverse_diff, AttributeKind, BasicBlock, + CreateEnzymeLogic, CreateTypeAnalysis, EnzymeLogicRef, EnzymeTypeAnalysisRef, LLVMAddFunction, LLVMAppendBasicBlockInContext, LLVMBuildCall2, LLVMBuildExtractValue, LLVMBuildRet, - LLVMCountParams, LLVMCountStructElementTypes, LLVMCreateBuilderInContext, LLVMDeleteFunction, - LLVMDisposeBuilder, LLVMGetBasicBlockTerminator, LLVMGetElementType, LLVMGetModuleContext, - LLVMGetParams, LLVMGetReturnType, LLVMPositionBuilderAtEnd, LLVMSetValueName2, LLVMTypeOf, - LLVMVoidTypeInContext, LLVMGlobalGetValueType, LLVMGetStringAttributeAtIndex, - LLVMIsStringAttribute, LLVMRemoveStringAttributeAtIndex, AttributeKind, - 
LLVMGetFirstFunction, LLVMGetNextFunction, LLVMIsEnumAttribute, - LLVMCreateStringAttribute, LLVMRustAddFunctionAttributes, LLVMDumpModule}; + LLVMCountParams, LLVMCountStructElementTypes, LLVMCreateBuilderInContext, + LLVMCreateStringAttribute, LLVMDeleteFunction, LLVMDisposeBuilder, LLVMDumpModule, + LLVMGetBasicBlockTerminator, LLVMGetElementType, LLVMGetFirstFunction, LLVMGetModuleContext, + LLVMGetNextFunction, LLVMGetParams, LLVMGetReturnType, LLVMGetStringAttributeAtIndex, + LLVMGlobalGetValueType, LLVMIsEnumAttribute, LLVMIsStringAttribute, LLVMPositionBuilderAtEnd, + LLVMRemoveStringAttributeAtIndex, LLVMReplaceAllUsesWith, LLVMRustAddEnumAttributeAtIndex, + LLVMRustAddFunctionAttributes, LLVMRustGetEnumAttributeAtIndex, + LLVMRustRemoveEnumAttributeAtIndex, LLVMSetValueName2, LLVMTypeOf, LLVMVerifyFunction, + LLVMVoidTypeInContext, Value, +}; use crate::llvm_util; use crate::type_::Type; use crate::typetree::to_enzyme_typetree; +use crate::DiffTypeTree; use crate::LlvmCodegenBackend; use crate::ModuleLlvm; use llvm::{ @@ -38,9 +41,9 @@ use rustc_codegen_ssa::back::write::{ }; use rustc_codegen_ssa::traits::*; use rustc_codegen_ssa::{CompiledModule, ModuleCodegen}; +use rustc_data_structures::fx::FxHashMap; use rustc_data_structures::profiling::SelfProfilerRef; use rustc_data_structures::small_c_str::SmallCStr; -use rustc_data_structures::fx::FxHashMap; use rustc_errors::{DiagCtxt, FatalError, Level}; use rustc_fs_util::{link_or_copy, path_to_c_string}; use rustc_middle::ty::TyCtxt; @@ -714,22 +717,28 @@ pub(crate) unsafe fn enzyme_ad( let src_fnc = match src_fnc_opt { Some(x) => x, None => { - return Err(llvm_err(diag_handler, LlvmError::PrepareAutoDiff{ - src: rust_name.to_owned(), - target: rust_name2.to_owned(), - error: "could not find src function".to_owned(), - })); + return Err(llvm_err( + diag_handler, + LlvmError::PrepareAutoDiff { + src: rust_name.to_owned(), + target: rust_name2.to_owned(), + error: "could not find src function".to_owned(), + 
}, + )); } }; let target_fnc_opt = llvm::LLVMGetNamedFunction(llmod, name2.as_ptr()); let target_fnc = match target_fnc_opt { Some(x) => x, None => { - return Err(llvm_err(diag_handler, LlvmError::PrepareAutoDiff{ - src: rust_name.to_owned(), - target: rust_name2.to_owned(), - error: "could not find target function".to_owned(), - })); + return Err(llvm_err( + diag_handler, + LlvmError::PrepareAutoDiff { + src: rust_name.to_owned(), + target: rust_name2.to_owned(), + error: "could not find target function".to_owned(), + }, + )); } }; let src_num_args = llvm::LLVMCountParams(src_fnc); @@ -756,13 +765,13 @@ pub(crate) unsafe fn enzyme_ad( llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymeStrictAliasing), 0); if std::env::var("ENZYME_PRINT_TA").is_ok() { - llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrintType), 1); + llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrintType), 1); } if std::env::var("ENZYME_PRINT_AA").is_ok() { llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrintActivity), 1); } if std::env::var("ENZYME_PRINT_PERF").is_ok() { - llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrintPerf), 1); + llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrintPerf), 1); } if std::env::var("ENZYME_PRINT").is_ok() { llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrint), 1); @@ -826,7 +835,9 @@ pub(crate) unsafe fn differentiate( llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymeStrictAliasing), 0); if std::env::var("ENZYME_PRINT_MOD").is_ok() { - unsafe {LLVMDumpModule(llmod);} + unsafe { + LLVMDumpModule(llmod); + } } if std::env::var("ENZYME_TT_DEPTH").is_ok() { let depth = std::env::var("ENZYME_TT_DEPTH").unwrap(); @@ -849,22 +860,36 @@ pub(crate) unsafe fn differentiate( let mut f = LLVMGetFirstFunction(llmod); loop { if let Some(lf) = f { - f = LLVMGetNextFunction(lf); - let myhwattr = "enzyme_hw"; - let attr = LLVMGetStringAttributeAtIndex(lf, c_uint::MAX, myhwattr.as_ptr() as *const c_char, 
myhwattr.as_bytes().len() as c_uint); - if LLVMIsStringAttribute(attr) { - LLVMRemoveStringAttributeAtIndex(lf, c_uint::MAX, myhwattr.as_ptr() as *const c_char, myhwattr.as_bytes().len() as c_uint); - } else { - LLVMRustRemoveEnumAttributeAtIndex(lf, c_uint::MAX, AttributeKind::SanitizeHWAddress); - } - - + f = LLVMGetNextFunction(lf); + let myhwattr = "enzyme_hw"; + let attr = LLVMGetStringAttributeAtIndex( + lf, + c_uint::MAX, + myhwattr.as_ptr() as *const c_char, + myhwattr.as_bytes().len() as c_uint, + ); + if LLVMIsStringAttribute(attr) { + LLVMRemoveStringAttributeAtIndex( + lf, + c_uint::MAX, + myhwattr.as_ptr() as *const c_char, + myhwattr.as_bytes().len() as c_uint, + ); + } else { + LLVMRustRemoveEnumAttributeAtIndex( + lf, + c_uint::MAX, + AttributeKind::SanitizeHWAddress, + ); + } } else { break; } } if std::env::var("ENZYME_PRINT_MOD_AFTER").is_ok() { - unsafe {LLVMDumpModule(llmod);} + unsafe { + LLVMDumpModule(llmod); + } } Ok(()) @@ -901,17 +926,31 @@ pub(crate) unsafe fn optimize( let mut f = LLVMGetFirstFunction(llmod); loop { if let Some(lf) = f { - f = LLVMGetNextFunction(lf); - let myhwattr = "enzyme_hw"; - let myhwv = ""; - let prevattr = LLVMRustGetEnumAttributeAtIndex(lf, c_uint::MAX, AttributeKind::SanitizeHWAddress); - if LLVMIsEnumAttribute(prevattr) { - let attr = LLVMCreateStringAttribute(llcx, myhwattr.as_ptr() as *const c_char, myhwattr.as_bytes().len() as c_uint, myhwv.as_ptr() as *const c_char, myhwv.as_bytes().len() as c_uint); - LLVMRustAddFunctionAttributes(lf, c_uint::MAX, &attr, 1); - } else { - LLVMRustAddEnumAttributeAtIndex(llcx, lf, c_uint::MAX, AttributeKind::SanitizeHWAddress); - } - + f = LLVMGetNextFunction(lf); + let myhwattr = "enzyme_hw"; + let myhwv = ""; + let prevattr = LLVMRustGetEnumAttributeAtIndex( + lf, + c_uint::MAX, + AttributeKind::SanitizeHWAddress, + ); + if LLVMIsEnumAttribute(prevattr) { + let attr = LLVMCreateStringAttribute( + llcx, + myhwattr.as_ptr() as *const c_char, + myhwattr.as_bytes().len() as 
c_uint, + myhwv.as_ptr() as *const c_char, + myhwv.as_bytes().len() as c_uint, + ); + LLVMRustAddFunctionAttributes(lf, c_uint::MAX, &attr, 1); + } else { + LLVMRustAddEnumAttributeAtIndex( + llcx, + lf, + c_uint::MAX, + AttributeKind::SanitizeHWAddress, + ); + } } else { break; } diff --git a/compiler/rustc_codegen_llvm/src/errors.rs b/compiler/rustc_codegen_llvm/src/errors.rs index ea0c55f92250e..129e692fd7fa4 100644 --- a/compiler/rustc_codegen_llvm/src/errors.rs +++ b/compiler/rustc_codegen_llvm/src/errors.rs @@ -183,11 +183,7 @@ pub enum LlvmError<'a> { #[diag(codegen_llvm_parse_bitcode)] ParseBitcode, #[diag(codegen_llvm_prepare_autodiff)] - PrepareAutoDiff { - src: String, - target: String, - error: String, - }, + PrepareAutoDiff { src: String, target: String, error: String }, } pub(crate) struct WithLlvmError<'a>(pub LlvmError<'a>, pub String); @@ -209,9 +205,7 @@ impl IntoDiagnostic<'_, G> for WithLlvmError<'_> { } PrepareThinLtoModule => fluent::codegen_llvm_prepare_thin_lto_module_with_llvm_err, ParseBitcode => fluent::codegen_llvm_parse_bitcode_with_llvm_err, - PrepareAutoDiff { .. } => { - fluent::codegen_llvm_prepare_autodiff_with_llvm_err - } + PrepareAutoDiff { .. 
} => fluent::codegen_llvm_prepare_autodiff_with_llvm_err, }; let mut diag = self.0.into_diagnostic(dcx, level); diag.set_primary_message(msg_with_llvm_err); diff --git a/compiler/rustc_codegen_llvm/src/lib.rs b/compiler/rustc_codegen_llvm/src/lib.rs index 6f08211a926c8..d47b1bff0f4e8 100644 --- a/compiler/rustc_codegen_llvm/src/lib.rs +++ b/compiler/rustc_codegen_llvm/src/lib.rs @@ -494,8 +494,12 @@ impl ModuleLlvm { } }; - Ok(ModuleLlvm { llmod_raw, llcx, tm: ManuallyDrop::new(tm), - typetrees: Default::default() }) + Ok(ModuleLlvm { + llmod_raw, + llcx, + tm: ManuallyDrop::new(tm), + typetrees: Default::default(), + }) } } diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 2c3f0df1e5e2c..5bd6d8c1cc1e8 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -860,7 +860,11 @@ pub(crate) unsafe fn enzyme_rust_forward_diff( let mut input_activity: Vec = vec![]; for input in input_diffactivity { let act = cdiffe_from(input); - assert!(act == CDIFFE_TYPE::DFT_CONSTANT || act == CDIFFE_TYPE::DFT_DUP_ARG || act == CDIFFE_TYPE::DFT_DUP_NONEED); + assert!( + act == CDIFFE_TYPE::DFT_CONSTANT + || act == CDIFFE_TYPE::DFT_DUP_ARG + || act == CDIFFE_TYPE::DFT_DUP_NONEED + ); input_activity.push(act); } @@ -956,7 +960,6 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff( assert!(num_fnc_args == input_activity.len() as u32); let kv_tmp = IntList { data: std::ptr::null_mut(), size: 0 }; - let mut known_values = vec![kv_tmp; input_tts.len()]; let dummy_type = CFnTypeInfo { @@ -980,7 +983,7 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff( 1, // vector mode width 1, // free memory Option::None, - 0, // do not force anonymous tape + 0, // do not force anonymous tape dummy_type, // additional_arg, type info (return + args) args_uncacheable.as_ptr(), args_uncacheable.len(), // uncacheable arguments @@ -1178,16 +1181,30 @@ extern "C" { Value: *const c_char, ValueLen: c_uint, ) 
-> &Attribute; - pub fn LLVMRemoveStringAttributeAtIndex(F : &Value, Idx: c_uint, K: *const c_char, KLen : c_uint); - pub fn LLVMGetStringAttributeAtIndex(F : &Value, Idx: c_uint, K: *const c_char, KLen : c_uint) -> &Attribute; - pub fn LLVMAddAttributeAtIndex(F : &Value, Idx: c_uint, K: &Attribute); - pub fn LLVMRemoveEnumAttributeAtIndex(F : &Value, Idx: c_uint, K: Attribute); - pub fn LLVMGetEnumAttributeAtIndex(F : &Value, Idx: c_uint, K: Attribute) -> &Attribute; - pub fn LLVMIsEnumAttribute(A : &Attribute) -> bool; - pub fn LLVMIsStringAttribute(A : &Attribute) -> bool; - pub fn LLVMRustAddEnumAttributeAtIndex(C: &Context, V: &Value, index: c_uint, attr: AttributeKind); + pub fn LLVMRemoveStringAttributeAtIndex(F: &Value, Idx: c_uint, K: *const c_char, KLen: c_uint); + pub fn LLVMGetStringAttributeAtIndex( + F: &Value, + Idx: c_uint, + K: *const c_char, + KLen: c_uint, + ) -> &Attribute; + pub fn LLVMAddAttributeAtIndex(F: &Value, Idx: c_uint, K: &Attribute); + pub fn LLVMRemoveEnumAttributeAtIndex(F: &Value, Idx: c_uint, K: Attribute); + pub fn LLVMGetEnumAttributeAtIndex(F: &Value, Idx: c_uint, K: Attribute) -> &Attribute; + pub fn LLVMIsEnumAttribute(A: &Attribute) -> bool; + pub fn LLVMIsStringAttribute(A: &Attribute) -> bool; + pub fn LLVMRustAddEnumAttributeAtIndex( + C: &Context, + V: &Value, + index: c_uint, + attr: AttributeKind, + ); pub fn LLVMRustRemoveEnumAttributeAtIndex(V: &Value, index: c_uint, attr: AttributeKind); - pub fn LLVMRustGetEnumAttributeAtIndex(V: &Value, index: c_uint, attr: AttributeKind) ->&Attribute; + pub fn LLVMRustGetEnumAttributeAtIndex( + V: &Value, + index: c_uint, + attr: AttributeKind, + ) -> &Attribute; // Operations on functions pub fn LLVMSetFunctionCallConv(Fn: &Value, CC: c_uint); @@ -2710,7 +2727,7 @@ extern "C" { fn EnzymeCreatePrimalAndGradient<'a>( arg1: EnzymeLogicRef, _builderCtx: *const u8, // &'a Builder<'_>, - _callerCtx: *const u8,// &'a Value, + _callerCtx: *const u8, // &'a Value, todiff: &'a Value, 
retType: CDIFFE_TYPE, constant_args: *const CDIFFE_TYPE, @@ -2734,8 +2751,8 @@ extern "C" { extern "C" { fn EnzymeCreateForwardDiff<'a>( arg1: EnzymeLogicRef, - _builderCtx: *const u8,// &'a Builder<'_>, - _callerCtx: *const u8,// &'a Value, + _builderCtx: *const u8, // &'a Builder<'_>, + _callerCtx: *const u8, // &'a Value, todiff: &'a Value, retType: CDIFFE_TYPE, constant_args: *const CDIFFE_TYPE, diff --git a/compiler/rustc_codegen_ssa/src/back/lto.rs b/compiler/rustc_codegen_ssa/src/back/lto.rs index cf03fb78cf1d2..1e749f570ac51 100644 --- a/compiler/rustc_codegen_ssa/src/back/lto.rs +++ b/compiler/rustc_codegen_ssa/src/back/lto.rs @@ -88,11 +88,9 @@ impl LtoModuleCodegen { ) -> Result, FatalError> { match &self { LtoModuleCodegen::Fat { ref module, .. } => { - { - B::autodiff(cgcx, &module, diff_fncs, typetrees, config)?; - } - }, - _ => {}, + B::autodiff(cgcx, &module, diff_fncs, typetrees, config)?; + } + _ => {} } Ok(self) diff --git a/compiler/rustc_codegen_ssa/src/back/write.rs b/compiler/rustc_codegen_ssa/src/back/write.rs index 4328509d84a39..8e2598e3de7f4 100644 --- a/compiler/rustc_codegen_ssa/src/back/write.rs +++ b/compiler/rustc_codegen_ssa/src/back/write.rs @@ -964,12 +964,18 @@ pub(crate) enum Message { /// The backend has finished processing a work item for a codegen unit. /// Sent from a backend worker thread. - WorkItem { result: Result, Option>, worker_id: usize }, + WorkItem { + result: Result, Option>, + worker_id: usize, + }, /// The frontend has finished generating something (backend IR or a /// post-LTO artifact) for a codegen unit, and it should be passed to the /// backend. Sent from the main thread. - CodegenDone { llvm_work_item: WorkItem, cost: u64 }, + CodegenDone { + llvm_work_item: WorkItem, + cost: u64, + }, /// Similar to `CodegenDone`, but for reusing a pre-LTO artifact /// Sent from the main thread. 
@@ -1514,7 +1520,7 @@ fn start_executing_work( let tt = B::typetrees(&mut module.module_llvm); typetrees.extend(tt); } - _ => {}, + _ => {} } // We keep the queue sorted by estimated processing cost, @@ -1543,7 +1549,6 @@ fn start_executing_work( autodiff_items.append(&mut items); } - Message::CodegenComplete => { if codegen_state != Aborted { codegen_state = Completed; diff --git a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs index 9184f2ce3e13a..33dedfd77c685 100644 --- a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs +++ b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs @@ -1,5 +1,5 @@ -use rustc_ast::{ast, attr, MetaItem, MetaItemKind, NestedMetaItem}; use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode}; +use rustc_ast::{ast, attr, MetaItem, MetaItemKind, NestedMetaItem}; use rustc_attr::{list_contains_name, InlineAttr, InstructionSetAttr, OptimizeAttr}; use rustc_errors::struct_span_err; use rustc_hir as hir; @@ -722,13 +722,16 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { let list = attr.meta_item_list().unwrap_or_default(); // empty autodiff attribute macros (i.e. `#[autodiff]`) are used to mark source functions - if list.len() == 0 { return AutoDiffAttrs::source(); } + if list.len() == 0 { + return AutoDiffAttrs::source(); + } let msg_ad_mode = "autodiff attribute must contain autodiff mode"; let (mode, list) = match list.split_first() { - Some((NestedMetaItem::MetaItem(MetaItem { path: ref p1, kind: MetaItemKind::Word, .. }), list)) => { - (p1.segments.first().unwrap().ident, list) - } + Some(( + NestedMetaItem::MetaItem(MetaItem { path: ref p1, kind: MetaItemKind::Word, .. 
}), + list, + )) => (p1.segments.first().unwrap().ident, list), _ => { tcx.sess .struct_span_err(attr.span, msg_ad_mode) @@ -756,9 +759,10 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { let msg_ret_activity = "autodiff attribute must contain the return activity"; let (ret_symbol, list) = match list.split_last() { - Some((NestedMetaItem::MetaItem(MetaItem { path: ref p1, kind: MetaItemKind::Word, .. }), list)) => { - (p1.segments.first().unwrap().ident, list) - } + Some(( + NestedMetaItem::MetaItem(MetaItem { path: ref p1, kind: MetaItemKind::Word, .. }), + list, + )) => (p1.segments.first().unwrap().ident, list), _ => { tcx.sess .struct_span_err(attr.span, msg_ret_activity) @@ -791,9 +795,7 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { }) => p2.segments.first().unwrap().ident, _ => { tcx.sess - .struct_span_err( - attr.span, msg_arg_activity, - ) + .struct_span_err(attr.span, msg_arg_activity) .span_label(attr.span, "missing return activity") .emit(); @@ -816,7 +818,8 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { let msg_fwd_incompatible_ret = "Forward Mode is incompatible with Active ret"; let msg_fwd_incompatible_arg = "Forward Mode is incompatible with Active ret"; - let msg_rev_incompatible_arg = "Reverse Mode is only compatible with Active, None, or Const ret"; + let msg_rev_incompatible_arg = + "Reverse Mode is only compatible with Active, None, or Const ret"; if mode == DiffMode::Forward { if ret_activity == DiffActivity::Active { tcx.sess @@ -840,9 +843,7 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { { dbg!("ret_activity = {:?}", ret_activity); tcx.sess - .struct_span_err( - attr.span, msg_rev_incompatible_arg, - ) + .struct_span_err(attr.span, msg_rev_incompatible_arg) .span_label(attr.span, "invalid return activity") .emit(); return AutoDiffAttrs::inactive(); diff --git a/compiler/rustc_middle/src/query/mod.rs b/compiler/rustc_middle/src/query/mod.rs index 
6524c6ea00e8a..836387a82442f 100644 --- a/compiler/rustc_middle/src/query/mod.rs +++ b/compiler/rustc_middle/src/query/mod.rs @@ -54,8 +54,8 @@ use crate::ty::{ use crate::ty::{GenericArg, GenericArgsRef}; use rustc_arena::TypedArena; use rustc_ast as ast; -use rustc_ast::expand::{allocator::AllocatorKind, StrippedCfgItem}; use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem}; +use rustc_ast::expand::{allocator::AllocatorKind, StrippedCfgItem}; use rustc_attr as attr; use rustc_data_structures::fingerprint::Fingerprint; use rustc_data_structures::fx::{FxHashMap, FxIndexMap, FxIndexSet}; diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs index bbe78b527de7f..6342840e64f04 100644 --- a/compiler/rustc_monomorphize/src/partitioning.rs +++ b/compiler/rustc_monomorphize/src/partitioning.rs @@ -92,15 +92,14 @@ //! source-level module, functions from the same module will be available for //! inlining, even when they are not marked `#[inline]`. 
- use std::cmp; use std::collections::hash_map::Entry; use std::fs::{self, File}; use std::io::{BufWriter, Write}; use std::path::{Path, PathBuf}; -use rustc_ast::expand::typetree::{Kind, Type, TypeTree}; use rustc_ast::expand::autodiff_attrs::AutoDiffItem; +use rustc_ast::expand::typetree::{Kind, Type, TypeTree}; use rustc_data_structures::fx::{FxHashMap, FxHashSet}; use rustc_data_structures::sync; use rustc_hir::def::DefKind; @@ -114,7 +113,9 @@ use rustc_middle::mir::mono::{ }; use rustc_middle::query::Providers; use rustc_middle::ty::print::{characteristic_def_id_of_type, with_no_trimmed_paths}; -use rustc_middle::ty::{self, visit::TypeVisitableExt, InstanceDef, TyCtxt, ParamEnv, ParamEnvAnd, Adt, Ty}; +use rustc_middle::ty::{ + self, visit::TypeVisitableExt, Adt, InstanceDef, ParamEnv, ParamEnvAnd, Ty, TyCtxt, +}; use rustc_session::config::{DumpMonoStatsFormat, SwitchWithOptPath}; use rustc_session::CodegenUnits; use rustc_span::symbol::Symbol; @@ -255,9 +256,8 @@ where ); //if visibility == Visibility::Hidden && can_be_internalized { - let autodiff_active = characteristic_def_id - .map(|x| cx.tcx.autodiff_attrs(x).is_active()) - .unwrap_or(false); + let autodiff_active = + characteristic_def_id.map(|x| cx.tcx.autodiff_attrs(x).is_active()).unwrap_or(false); if autodiff_active { dbg!("place_mono_items: autodiff_active"); dbg!(&mono_item); @@ -1099,7 +1099,10 @@ where } } -fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[AutoDiffItem], &[CodegenUnit<'_>]) { +fn collect_and_partition_mono_items( + tcx: TyCtxt<'_>, + (): (), +) -> (&DefIdSet, &[AutoDiffItem], &[CodegenUnit<'_>]) { let collection_mode = match tcx.sess.opts.unstable_opts.print_mono_items { Some(ref s) => { let mode = s.to_lowercase(); @@ -1158,7 +1161,7 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[Au }) .collect(); - let autodiff_items = items + let autodiff_items = items .iter() .filter_map(|item| match *item { MonoItem::Fn(ref 
instance) => Some((item, instance)), @@ -1170,43 +1173,21 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[Au if !target_attrs.apply_autodiff() { return None; } - //println!("target_id: {:?}", target_id); let target_symbol = symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE); - //let range = usage_map.used_map.get(&item).unwrap(); - //TODO: check if last and next line are correct after rebasing - - println!("target_symbol: {:?}", target_symbol); - println!("target_attrs: {:?}", target_attrs); - println!("target_id: {:?}", target_id); - //print item - println!("item: {:?}", item); - let source = usage_map.used_map.get(&item).unwrap() - .into_iter() - .find_map(|item| match *item { + + let source = + usage_map.used_map.get(&item).unwrap().into_iter().find_map(|item| match *item { MonoItem::Fn(ref instance_s) => { let source_id = instance_s.def_id(); - println!("source_id_inner: {:?}", source_id); - println!("instance_s: {:?}", instance_s); - if tcx.autodiff_attrs(source_id).is_active() { - println!("source_id is active"); return Some(instance_s); } - //target_symbol: "_ZN14rosenbrock_rev12d_rosenbrock17h3352c4f00c3082daE" - //target_attrs: AutoDiffAttrs { mode: Reverse, ret_activity: Active, input_activity: [Duplicated] } - //target_id: DefId(0:8 ~ rosenbrock_rev[2708]::d_rosenbrock) - //item: Fn(Instance { def: Item(DefId(0:8 ~ rosenbrock_rev[2708]::d_rosenbrock)), args: [] }) - //source_id_inner: DefId(0:4 ~ rosenbrock_rev[2708]::main) - //instance_s: Instance { def: Item(DefId(0:4 ~ rosenbrock_rev[2708]::main)), args: [] } - - None } _ => None, }); - //.next(); println!("source: {:?}", source); source.map(|inst| { diff --git a/library/std/src/lib.rs b/library/std/src/lib.rs index 2b7d685e8d0a2..11e21c577dd70 100644 --- a/library/std/src/lib.rs +++ b/library/std/src/lib.rs @@ -255,7 +255,6 @@ #![allow(unused_features)] // // Features: - #![cfg_attr(not(bootstrap), feature(autodiff))] #![cfg_attr(test, 
feature(internal_output_capture, print_internals, update_panic_count, rt))] #![cfg_attr( @@ -664,7 +663,6 @@ pub use core::{ module_path, option_env, stringify, trace_macros, }; - // #[unstable( // feature = "autodiff", // issue = "87555", @@ -673,8 +671,6 @@ pub use core::{ // #[cfg(not(bootstrap))] // pub use core::autodiff; - - #[unstable( feature = "concat_bytes", issue = "87555", diff --git a/src/bootstrap/src/core/build_steps/compile.rs b/src/bootstrap/src/core/build_steps/compile.rs index dbef8cb1eb3a5..561b04bc3f41a 100644 --- a/src/bootstrap/src/core/build_steps/compile.rs +++ b/src/bootstrap/src/core/build_steps/compile.rs @@ -1638,8 +1638,7 @@ impl Step for Assemble { } // Build enzyme - let enzyme_install = - Some(builder.ensure(llvm::Enzyme { target: build_compiler.host })); + let enzyme_install = Some(builder.ensure(llvm::Enzyme { target: build_compiler.host })); //let enzyme_install = if builder.config.llvm_enzyme { // Some(builder.ensure(llvm::Enzyme { target: build_compiler.host })) //} else { From e4b7ef1ecfc61842b545a8857db77ec96113fd57 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Mon, 5 Feb 2024 01:46:19 -0500 Subject: [PATCH 022/100] various fn_decl fixes --- compiler/rustc_builtin_macros/src/autodiff.rs | 125 ++++++++++-------- 1 file changed, 73 insertions(+), 52 deletions(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index a8c2d30472597..de337e21134d0 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -5,6 +5,7 @@ //use crate::util::check_autodiff; use crate::errors; +use rustc_ast::FnRetTy; use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode}; use rustc_ast::ptr::P; use rustc_ast::token::{Token, TokenKind}; @@ -41,7 +42,6 @@ pub fn expand( } }; let mut orig_item: P = item.clone().expect_item(); - //dbg!(&orig_item.tokens); let primal = orig_item.ident.clone(); // Allow using 
`#[autodiff(...)]` only on a Fn @@ -77,7 +77,7 @@ pub fn expand( dbg!(&x); let span = ecx.with_def_site_ctxt(expand_span); - let (d_sig, old_names, new_args, idents) = gen_enzyme_decl(ecx, &sig, &x, span, sig_span); + let (d_sig, old_names, new_args, idents) = gen_enzyme_decl(&sig, &x, span); let new_decl_span = d_sig.span; let d_body = gen_enzyme_body( ecx, @@ -147,11 +147,11 @@ fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty { ty } -// The body of our generated functions will consist of three black_Box calls. +// The body of our generated functions will consist of two black_Box calls. // The first will call the primal function with the original arguments. -// The second will just take the shadow arguments. -// The third will (unsafely) call std::mem::zeroed(), to match the return type of the new function -// (whatever that might be). This way we surpress rustc from optimizing anyt argument away. +// The second will just take a tuple containing the new arguments. +// This way we surpress rustc from optimizing any argument away. 
+// The last line will 'loop {}', to match the return type of the new function fn gen_enzyme_body( ecx: &ExtCtxt<'_>, primal: Ident, @@ -184,31 +184,25 @@ fn gen_enzyme_body( })); let primal_call = gen_primal_call(ecx, span, primal, sig, idents); // create ::core::hint::black_box(array(arr)); - let black_box0 = + let black_box_primal_call = ecx.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![primal_call.clone()]); - // create ::core::hint::black_box(grad_arr, tang_y)); - let black_box1 = ecx.expr_call( - sig_span, - blackbox_call_expr.clone(), - new_names - .iter() - .map(|arg| ecx.expr_path(ecx.path_ident(span, Ident::from_str(arg)))) - .collect(), - ); + // create ::core::hint::black_box((grad_arr, tang_y)); + let tup_args = new_names + .iter() + .map(|arg| ecx.expr_path(ecx.path_ident(span, Ident::from_str(arg)))) + .collect(); - // create ::core::hint::black_box(unsafe { ::core::mem::zeroed() }) - let black_box2 = ecx.expr_call( + let black_box_remaining_args = ecx.expr_call( sig_span, blackbox_call_expr.clone(), - thin_vec![unsafe_block_with_zeroed_call.clone()], + thin_vec![ecx.expr_tuple(sig_span, tup_args)], ); let mut body = ecx.block(span, ThinVec::new()); body.stmts.push(ecx.stmt_semi(primal_call)); - body.stmts.push(ecx.stmt_semi(black_box0)); - body.stmts.push(ecx.stmt_semi(black_box1)); - //body.stmts.push(ecx.stmt_semi(black_box2)); + body.stmts.push(ecx.stmt_semi(black_box_primal_call)); + body.stmts.push(ecx.stmt_semi(black_box_remaining_args)); body.stmts.push(ecx.stmt_expr(loop_expr)); body } @@ -233,11 +227,9 @@ fn gen_primal_call( // Each argument of the primal function (and the return type if existing) must be annotated with an // activity. 
fn gen_enzyme_decl( - _ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span: Span, - _sig_span: Span, ) -> (ast::FnSig, Vec, Vec, Vec) { assert!(sig.decl.inputs.len() == x.input_activity.len()); assert!(sig.decl.output.has_ret() == x.has_ret_activity()); @@ -246,15 +238,19 @@ fn gen_enzyme_decl( let mut new_inputs = Vec::new(); let mut old_names = Vec::new(); let mut idents = Vec::new(); + let mut act_ret = ThinVec::new(); for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) { d_inputs.push(arg.clone()); match activity { + DiffActivity::Active => { + assert!(x.mode == DiffMode::Reverse); + act_ret.push(arg.ty.clone()); + } DiffActivity::Duplicated | DiffActivity::Dual => { let mut shadow_arg = arg.clone(); shadow_arg.ty = P(assure_mut_ref(&arg.ty)); // adjust name depending on mode - let old_name = if let PatKind::Ident(_, ident, _) = shadow_arg.pat.kind { - idents.push(ident.clone()); + let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind { ident.name } else { dbg!(&shadow_arg.pat); @@ -276,47 +272,72 @@ fn gen_enzyme_decl( span: shadow_arg.pat.span, tokens: shadow_arg.pat.tokens.clone(), }); - //idents.push(ident); d_inputs.push(shadow_arg); } _ => { dbg!(&activity); } } + if let PatKind::Ident(_, ident, _) = arg.pat.kind { + idents.push(ident.clone()); + } else { + panic!("not an ident?"); + } } // If we return a scalar in the primal and the scalar is active, // then add it as last arg to the inputs. 
- if x.mode == DiffMode::Reverse { - match x.ret_activity { - DiffActivity::Active => { - let ty = match d_decl.output { - rustc_ast::FnRetTy::Ty(ref ty) => ty.clone(), - rustc_ast::FnRetTy::Default(span) => { - panic!("Did not expect Default ret ty: {:?}", span); - } - }; - let name = "dret".to_string(); - let ident = Ident::from_str_and_span(&name, ty.span); - let shadow_arg = ast::Param { - attrs: ThinVec::new(), - ty: ty.clone(), - pat: P(ast::Pat { - id: ast::DUMMY_NODE_ID, - kind: PatKind::Ident(BindingAnnotation::NONE, ident, None), - span: ty.span, - tokens: None, - }), + if let DiffMode::Reverse = x.mode { + if let DiffActivity::Active = x.ret_activity { + let ty = match d_decl.output { + FnRetTy::Ty(ref ty) => ty.clone(), + FnRetTy::Default(span) => { + panic!("Did not expect Default ret ty: {:?}", span); + } + }; + let name = "dret".to_string(); + let ident = Ident::from_str_and_span(&name, ty.span); + let shadow_arg = ast::Param { + attrs: ThinVec::new(), + ty: ty.clone(), + pat: P(ast::Pat { id: ast::DUMMY_NODE_ID, + kind: PatKind::Ident(BindingAnnotation::NONE, ident, None), span: ty.span, - is_placeholder: false, - }; - d_inputs.push(shadow_arg); - } - _ => {} + tokens: None, + }), + id: ast::DUMMY_NODE_ID, + span: ty.span, + is_placeholder: false, + }; + d_inputs.push(shadow_arg); + new_inputs.push(name); } } d_decl.inputs = d_inputs.into(); + + // If we have an active input scalar, add it's gradient to the + // return type. This might require changing the return type to a + // tuple. 
+ if act_ret.len() > 0 { + let mut ret_ty = match d_decl.output { + FnRetTy::Ty(ref ty) => { + act_ret.insert(0, ty.clone()); + let kind = TyKind::Tup(act_ret); + P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None }) + } + FnRetTy::Default(span) => { + if act_ret.len() == 1 { + act_ret[0].clone() + } else { + let kind = TyKind::Tup(act_ret.iter().map(|arg| arg.clone()).collect()); + P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None }) + } + } + }; + d_decl.output = FnRetTy::Ty(ret_ty); + } + let d_sig = FnSig { header: sig.header.clone(), decl: d_decl, span }; (d_sig, old_names, new_inputs, idents) } From 9cbdf680c0632b5719265803eb1625535c4eac69 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Mon, 5 Feb 2024 02:40:51 -0500 Subject: [PATCH 023/100] More precise Enzyme settings --- .../rustc_ast/src/expand/autodiff_attrs.rs | 11 +++--- compiler/rustc_codegen_llvm/src/back/write.rs | 5 --- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 38 ++++++++----------- .../rustc_codegen_ssa/src/codegen_attrs.rs | 2 +- 4 files changed, 23 insertions(+), 33 deletions(-) diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs index a0f632e1ba105..c6febf6f9467a 100644 --- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -17,12 +17,13 @@ pub enum DiffMode { #[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] pub enum DiffActivity { None, - Active, Const, + Active, + ActiveOnly, Dual, - DualNoNeed, + DualOnly, Duplicated, - DuplicatedNoNeed, + DuplicatedOnly, } impl FromStr for DiffMode { @@ -47,9 +48,9 @@ impl FromStr for DiffActivity { "Active" => Ok(DiffActivity::Active), "Const" => Ok(DiffActivity::Const), "Dual" => Ok(DiffActivity::Dual), - "DualNoNeed" => Ok(DiffActivity::DualNoNeed), + "DualOnly" => Ok(DiffActivity::DualOnly), "Duplicated" => Ok(DiffActivity::Duplicated), - "DuplicatedNoNeed" => 
Ok(DiffActivity::DuplicatedNoNeed), + "DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly), _ => Err(()), } } diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index dc1f988617377..132a68a1b254f 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -756,8 +756,6 @@ pub(crate) unsafe fn enzyme_ad( let output_tt = to_enzyme_typetree(item.output, llvm_data_layout, llcx); let opt = 1; - let ret_primary_ret = false; - let diff_primary_ret = false; let logic_ref: EnzymeLogicRef = CreateEnzymeLogic(opt as u8); let type_analysis: EnzymeTypeAnalysisRef = CreateTypeAnalysis(logic_ref, std::ptr::null_mut(), std::ptr::null_mut(), 0); @@ -784,7 +782,6 @@ pub(crate) unsafe fn enzyme_ad( src_fnc, args_activity, ret_activity, - ret_primary_ret, input_tts, output_tt, ), @@ -794,8 +791,6 @@ pub(crate) unsafe fn enzyme_ad( src_fnc, args_activity, ret_activity, - ret_primary_ret, - diff_primary_ret, input_tts, output_tt, ), diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 5bd6d8c1cc1e8..e7aca0ab77689 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -851,10 +851,10 @@ pub(crate) unsafe fn enzyme_rust_forward_diff( fnc: &Value, input_diffactivity: Vec, ret_diffactivity: DiffActivity, - mut ret_primary_ret: bool, input_tts: Vec, output_tt: TypeTree, ) -> &Value { + let mut ret_primary_ret = false; let ret_activity = cdiffe_from(ret_diffactivity); assert!(ret_activity != CDIFFE_TYPE::DFT_OUT_DIFF); let mut input_activity: Vec = vec![]; @@ -925,29 +925,22 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff( fnc: &Value, input_activity: Vec, ret_activity: DiffActivity, - mut ret_primary_ret: bool, - diff_primary_ret: bool, input_tts: Vec, output_tt: TypeTree, ) -> &Value { - let ret_activity = cdiffe_from(ret_activity); - assert!(ret_activity == 
CDIFFE_TYPE::DFT_CONSTANT || ret_activity == CDIFFE_TYPE::DFT_OUT_DIFF); + let (primary_ret, diff_ret, ret_activity) = match ret_activity { + DiffActivity::Const => (true, false, CDIFFE_TYPE::DFT_CONSTANT), + DiffActivity::Active => (true, true, CDIFFE_TYPE::DFT_DUP_ARG), + DiffActivity::ActiveOnly => (false, true, CDIFFE_TYPE::DFT_DUP_NONEED), + DiffActivity::None => (false, false, CDIFFE_TYPE::DFT_CONSTANT), + _ => panic!("Invalid return activity"), + }; + + //assert!(ret_activity == CDIFFE_TYPE::DFT_CONSTANT || ret_activity == CDIFFE_TYPE::DFT_OUT_DIFF); let input_activity: Vec = input_activity.iter().map(|&x| cdiffe_from(x)).collect(); dbg!(&fnc); - if ret_activity == CDIFFE_TYPE::DFT_DUP_ARG { - if ret_primary_ret != true { - dbg!("overwriting ret_primary_ret!"); - } - ret_primary_ret = true; - } else if ret_activity == CDIFFE_TYPE::DFT_DUP_NONEED { - if ret_primary_ret != false { - dbg!("overwriting ret_primary_ret!"); - } - ret_primary_ret = false; - } - let mut args_tree = input_tts.iter().map(|x| x.inner).collect::>(); // We don't support volatile / extern / (global?) values. 
@@ -977,8 +970,8 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff( input_activity.as_ptr(), input_activity.len(), // constant arguments type_analysis, // type analysis struct - ret_primary_ret as u8, - diff_primary_ret as u8, //0 + primary_ret as u8, + diff_ret as u8, //0 CDerivativeMode::DEM_ReverseModeCombined, // return value, dret_used, top_level which was 1 1, // vector mode width 1, // free memory @@ -2704,12 +2697,13 @@ pub enum CDIFFE_TYPE { fn cdiffe_from(act: DiffActivity) -> CDIFFE_TYPE { return match act { DiffActivity::None => CDIFFE_TYPE::DFT_CONSTANT, - DiffActivity::Active => CDIFFE_TYPE::DFT_OUT_DIFF, DiffActivity::Const => CDIFFE_TYPE::DFT_CONSTANT, + DiffActivity::Active => CDIFFE_TYPE::DFT_OUT_DIFF, + DiffActivity::ActiveOnly => CDIFFE_TYPE::DFT_OUT_DIFF, DiffActivity::Dual => CDIFFE_TYPE::DFT_DUP_ARG, - DiffActivity::DualNoNeed => CDIFFE_TYPE::DFT_DUP_NONEED, + DiffActivity::DualOnly => CDIFFE_TYPE::DFT_DUP_NONEED, DiffActivity::Duplicated => CDIFFE_TYPE::DFT_DUP_ARG, - DiffActivity::DuplicatedNoNeed => CDIFFE_TYPE::DFT_DUP_NONEED, + DiffActivity::DuplicatedOnly => CDIFFE_TYPE::DFT_DUP_NONEED, }; } diff --git a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs index 33dedfd77c685..42e390d66d1ba 100644 --- a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs +++ b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs @@ -839,7 +839,7 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { if mode == DiffMode::Reverse { if ret_activity == DiffActivity::Duplicated - || ret_activity == DiffActivity::DuplicatedNoNeed + || ret_activity == DiffActivity::DuplicatedOnly { dbg!("ret_activity = {:?}", ret_activity); tcx.sess From bdc8e5d1a1fa89a88bac33b4bb7a6a9e7f74d33c Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Mon, 5 Feb 2024 14:54:46 -0500 Subject: [PATCH 024/100] cleanup, reduce diff --- compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 
deletions(-) diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index 49873e5e02cd4..a2976e1f510c2 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -300,6 +300,14 @@ extern "C" void LLVMRustAddFunctionAttributes(LLVMValueRef Fn, unsigned Index, AddAttributes(F, Index, Attrs, AttrsLen); } +extern "C" void LLVMRustAddCallSiteAttributes(LLVMValueRef Instr, + unsigned Index, + LLVMAttributeRef *Attrs, + size_t AttrsLen) { + CallBase *Call = unwrap(Instr); + AddAttributes(Call, Index, Attrs, AttrsLen); +} + extern "C" LLVMAttributeRef LLVMRustCreateAttrNoValue(LLVMContextRef C, LLVMRustAttribute RustAttr) { return wrap(Attribute::get(*unwrap(C), fromRust(RustAttr))); @@ -327,12 +335,6 @@ LLVMRustGetEnumAttributeAtIndex(LLVMValueRef F, size_t index, return LLVMGetEnumAttributeAtIndex(F, index, fromRust(RustAttr)); } -extern "C" void LLVMRustAddCallSiteAttributes(LLVMValueRef Instr, unsigned Index, - LLVMAttributeRef *Attrs, size_t AttrsLen) { - CallBase *Call = unwrap(Instr); - AddAttributes(Call, Index, Attrs, AttrsLen); -} - extern "C" LLVMAttributeRef LLVMRustCreateAlignmentAttr(LLVMContextRef C, uint64_t Bytes) { return wrap(Attribute::getWithAlignment(*unwrap(C), llvm::Align(Bytes))); From c8c4ea3ac90051268a45faf9d9787ebde11d1744 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Thu, 8 Feb 2024 22:09:33 -0500 Subject: [PATCH 025/100] cleanups --- compiler/rustc_builtin_macros/src/autodiff.rs | 2 +- compiler/rustc_codegen_llvm/src/back/write.rs | 3 +-- compiler/rustc_codegen_llvm/src/lib.rs | 2 -- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 3 +-- compiler/rustc_codegen_ssa/src/back/write.rs | 2 -- compiler/rustc_monomorphize/src/partitioning.rs | 5 ----- 6 files changed, 3 insertions(+), 14 deletions(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 
de337e21134d0..5978ad05b0dfa 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -5,11 +5,11 @@ //use crate::util::check_autodiff; use crate::errors; -use rustc_ast::FnRetTy; use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode}; use rustc_ast::ptr::P; use rustc_ast::token::{Token, TokenKind}; use rustc_ast::tokenstream::*; +use rustc_ast::FnRetTy; use rustc_ast::{self as ast, FnHeader, FnSig, Generics, MetaItemKind, NestedMetaItem, StmtKind}; use rustc_ast::{BindingAnnotation, ByRef}; use rustc_ast::{Fn, ItemKind, PatKind, Stmt, TyKind, Unsafe}; diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 132a68a1b254f..ce77678696da7 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -702,7 +702,6 @@ pub(crate) unsafe fn enzyme_ad( diag_handler: &DiagCtxt, item: AutoDiffItem, ) -> Result<(), FatalError> { - dbg!("cg_llvm enzyme_ad"); let autodiff_mode = item.attrs.mode; let rust_name = item.source; let rust_name2 = &item.target; @@ -743,6 +742,7 @@ pub(crate) unsafe fn enzyme_ad( }; let src_num_args = llvm::LLVMCountParams(src_fnc); let target_num_args = llvm::LLVMCountParams(target_fnc); + // A really simple check assert!(src_num_args <= target_num_args); // create enzyme typetrees @@ -820,7 +820,6 @@ pub(crate) unsafe fn differentiate( _typetrees: FxHashMap, _config: &ModuleConfig, ) -> Result<(), FatalError> { - dbg!("cg_llvm differentiate"); dbg!(&diff_items); let llmod = module.module_llvm.llmod(); diff --git a/compiler/rustc_codegen_llvm/src/lib.rs b/compiler/rustc_codegen_llvm/src/lib.rs index d47b1bff0f4e8..a6b625139ef22 100644 --- a/compiler/rustc_codegen_llvm/src/lib.rs +++ b/compiler/rustc_codegen_llvm/src/lib.rs @@ -268,8 +268,6 @@ impl WriteBackendMethods for LlvmCodegenBackend { typetrees: FxHashMap, config: &ModuleConfig, ) -> Result<(), FatalError> { 
- dbg!("cg_llvm autodiff"); - dbg!("Differentiating {} functions", diff_fncs.len()); unsafe { back::write::differentiate(module, cgcx, diff_fncs, typetrees, config) } } diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index e7aca0ab77689..2ef24a42cfda0 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -936,7 +936,6 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff( _ => panic!("Invalid return activity"), }; - //assert!(ret_activity == CDIFFE_TYPE::DFT_CONSTANT || ret_activity == CDIFFE_TYPE::DFT_OUT_DIFF); let input_activity: Vec = input_activity.iter().map(|&x| cdiffe_from(x)).collect(); dbg!(&fnc); @@ -971,7 +970,7 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff( input_activity.len(), // constant arguments type_analysis, // type analysis struct primary_ret as u8, - diff_ret as u8, //0 + diff_ret as u8, //0 CDerivativeMode::DEM_ReverseModeCombined, // return value, dret_used, top_level which was 1 1, // vector mode width 1, // free memory diff --git a/compiler/rustc_codegen_ssa/src/back/write.rs b/compiler/rustc_codegen_ssa/src/back/write.rs index 8e2598e3de7f4..04d4843911087 100644 --- a/compiler/rustc_codegen_ssa/src/back/write.rs +++ b/compiler/rustc_codegen_ssa/src/back/write.rs @@ -1545,7 +1545,6 @@ fn start_executing_work( } Message::AddAutoDiffItems(mut items) => { - dbg!("AddAutoDiffItems"); autodiff_items.append(&mut items); } @@ -2019,7 +2018,6 @@ impl OngoingCodegen { } pub fn submit_autodiff_items(&self, items: Vec) { - dbg!("submit_autodiff_items"); drop(self.coordinator.sender.send(Box::new(Message::::AddAutoDiffItems(items)))); } diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs index 6342840e64f04..ad1141726d970 100644 --- a/compiler/rustc_monomorphize/src/partitioning.rs +++ b/compiler/rustc_monomorphize/src/partitioning.rs @@ -258,10 +258,6 @@ where //if visibility == 
Visibility::Hidden && can_be_internalized { let autodiff_active = characteristic_def_id.map(|x| cx.tcx.autodiff_attrs(x).is_active()).unwrap_or(false); - if autodiff_active { - dbg!("place_mono_items: autodiff_active"); - dbg!(&mono_item); - } if !autodiff_active && visibility == Visibility::Hidden && can_be_internalized { internalization_candidates.insert(mono_item); @@ -1188,7 +1184,6 @@ fn collect_and_partition_mono_items( } _ => None, }); - println!("source: {:?}", source); source.map(|inst| { println!("source_id: {:?}", inst.def_id()); From 7b0d0f112f7b6ecb0d4477e22fd4682d66b7d855 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Thu, 8 Feb 2024 22:10:36 -0500 Subject: [PATCH 026/100] cleanups2 --- .../rustc_ast/src/expand/autodiff_attrs.rs | 39 ++++++ .../rustc_codegen_ssa/src/codegen_attrs.rs | 128 +++++++----------- 2 files changed, 91 insertions(+), 76 deletions(-) diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs index c6febf6f9467a..5db90e94c6f8c 100644 --- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -13,6 +13,45 @@ pub enum DiffMode { Reverse, } +pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool { + match mode { + DiffMode::Inactive => false, + DiffMode::Source => false, + DiffMode::Forward => { + // Doesn't recognize all illegal cases (insufficient information) + activity != DiffActivity::Active && activity != DiffActivity::ActiveOnly + && activity != DiffActivity::Duplicated && activity != DiffActivity::DuplicatedOnly + } + DiffMode::Reverse => { + // Doesn't recognize all illegal cases (insufficient information) + activity != DiffActivity::Duplicated && activity != DiffActivity::DuplicatedOnly + && activity != DiffActivity::Dual && activity != DiffActivity::DualOnly + } + } +} + +pub fn valid_input_activities(mode: DiffMode, activity_vec: &[DiffActivity]) -> bool { + for &activity in activity_vec { + 
let valid = match mode { + DiffMode::Inactive => false, + DiffMode::Source => false, + DiffMode::Forward => { + // These are the only valid cases + activity == DiffActivity::Dual || activity == DiffActivity::DualOnly || activity == DiffActivity::Const + } + DiffMode::Reverse => { + // These are the only valid cases + activity == DiffActivity::Active || activity == DiffActivity::ActiveOnly || activity == DiffActivity::Const + || activity == DiffActivity::Duplicated || activity == DiffActivity::DuplicatedOnly + } + }; + if !valid { + return false; + } + } + true +} + #[allow(dead_code)] #[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] pub enum DiffActivity { diff --git a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs index 42e390d66d1ba..08cd614cc8eb5 100644 --- a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs +++ b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs @@ -1,4 +1,4 @@ -use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode}; +use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode, valid_ret_activity, valid_input_activities}; use rustc_ast::{ast, attr, MetaItem, MetaItemKind, NestedMetaItem}; use rustc_attr::{list_contains_name, InlineAttr, InstructionSetAttr, OptimizeAttr}; use rustc_errors::struct_span_err; @@ -692,6 +692,11 @@ fn check_link_name_xor_ordinal( } } +/// We now check the #[rustc_autodiff] attributes which we generated from the #[autodiff(...)] +/// macros. There are two forms. The pure one without args to mark primal functions (the functions +/// being differentiated). The other form is #[rustc_autodiff(Mode, ActivityList)] on top of the +/// placeholder functions. We wrote the rustc_autodiff attributes ourself, so this should never +/// panic, unless we introduced a bug when parsing the autodiff macro. 
fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { let attrs = tcx.get_attrs(id, sym::rustc_autodiff); @@ -726,20 +731,22 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { return AutoDiffAttrs::source(); } - let msg_ad_mode = "autodiff attribute must contain autodiff mode"; - let (mode, list) = match list.split_first() { - Some(( - NestedMetaItem::MetaItem(MetaItem { path: ref p1, kind: MetaItemKind::Word, .. }), - list, - )) => (p1.segments.first().unwrap().ident, list), - _ => { - tcx.sess - .struct_span_err(attr.span, msg_ad_mode) - .span_label(attr.span, "empty argument list") - .emit(); - - return AutoDiffAttrs::inactive(); - } + let [mode, input_activities @ .., ret_activity] = &list[..] else { + tcx.sess + .struct_span_err(attr.span, msg_once) + .span_label(attr.span, "Implementation bug in autodiff_attrs. Please report this!") + .emit(); + return AutoDiffAttrs::inactive(); + }; + let mode = if let NestedMetaItem::MetaItem(MetaItem { path: ref p1, .. }) = mode { + p1.segments.first().unwrap().ident + } else { + let msg = "autodiff attribute must contain autodiff mode"; + tcx.sess + .struct_span_err(attr.span, msg) + .span_label(attr.span, "empty argument list") + .emit(); + return AutoDiffAttrs::inactive(); }; // parse mode @@ -752,27 +759,23 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { .struct_span_err(attr.span, msg_mode) .span_label(attr.span, "invalid mode") .emit(); - return AutoDiffAttrs::inactive(); } }; - let msg_ret_activity = "autodiff attribute must contain the return activity"; - let (ret_symbol, list) = match list.split_last() { - Some(( - NestedMetaItem::MetaItem(MetaItem { path: ref p1, kind: MetaItemKind::Word, .. 
}), - list, - )) => (p1.segments.first().unwrap().ident, list), - _ => { - tcx.sess - .struct_span_err(attr.span, msg_ret_activity) - .span_label(attr.span, "missing return activity") - .emit(); - - return AutoDiffAttrs::inactive(); - } + // First read the ret symbol from the attribute + let ret_symbol = if let NestedMetaItem::MetaItem(MetaItem { path: ref p1, .. }) = ret_activity { + p1.segments.first().unwrap().ident + } else { + let msg = "autodiff attribute must contain the return activity"; + tcx.sess + .struct_span_err(attr.span, msg) + .span_label(attr.span, "missing return activity") + .emit(); + return AutoDiffAttrs::inactive(); }; + // Then parse it into an actual DiffActivity let msg_unknown_ret_activity = "unknown return activity"; let ret_activity = match DiffActivity::from_str(ret_symbol.as_str()) { Ok(x) => x, @@ -781,26 +784,22 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { .struct_span_err(attr.span, msg_unknown_ret_activity) .span_label(attr.span, "invalid return activity") .emit(); - return AutoDiffAttrs::inactive(); } }; + // Now parse all the intermediate (inptut) activities let msg_arg_activity = "autodiff attribute must contain the return activity"; let mut arg_activities: Vec = vec![]; - for arg in list { - let arg_symbol = match arg { - NestedMetaItem::MetaItem(MetaItem { - path: ref p2, kind: MetaItemKind::Word, .. - }) => p2.segments.first().unwrap().ident, - _ => { - tcx.sess - .struct_span_err(attr.span, msg_arg_activity) - .span_label(attr.span, "missing return activity") - .emit(); - - return AutoDiffAttrs::inactive(); - } + for arg in input_activities { + let arg_symbol = if let NestedMetaItem::MetaItem(MetaItem { path: ref p2, .. 
}) = arg { + p2.segments.first().unwrap().ident + } else { + tcx.sess + .struct_span_err(attr.span, msg_arg_activity) + .span_label(attr.span, "Implementation bug, please report this!") + .emit(); + return AutoDiffAttrs::inactive(); }; match DiffActivity::from_str(arg_symbol.as_str()) { @@ -810,45 +809,22 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { .struct_span_err(attr.span, msg_unknown_ret_activity) .span_label(attr.span, "invalid input activity") .emit(); - return AutoDiffAttrs::inactive(); } } } - let msg_fwd_incompatible_ret = "Forward Mode is incompatible with Active ret"; - let msg_fwd_incompatible_arg = "Forward Mode is incompatible with Active ret"; - let msg_rev_incompatible_arg = - "Reverse Mode is only compatible with Active, None, or Const ret"; - if mode == DiffMode::Forward { - if ret_activity == DiffActivity::Active { - tcx.sess - .struct_span_err(attr.span, msg_fwd_incompatible_ret) - .span_label(attr.span, "invalid return activity") - .emit(); - return AutoDiffAttrs::inactive(); - } - if arg_activities.iter().filter(|&x| *x == DiffActivity::Active).count() > 0 { - tcx.sess - .struct_span_err(attr.span, msg_fwd_incompatible_arg) - .span_label(attr.span, "invalid input activity") - .emit(); - return AutoDiffAttrs::inactive(); - } + let msg = "Invalid activity for mode"; + let valid_input = valid_input_activities(mode, &arg_activities); + let valid_ret = valid_ret_activity(mode, ret_activity); + if !valid_input || !valid_ret { + tcx.sess + .struct_span_err(attr.span, msg) + .span_label(attr.span, "invalid activity") + .emit(); + return AutoDiffAttrs::inactive(); } - if mode == DiffMode::Reverse { - if ret_activity == DiffActivity::Duplicated - || ret_activity == DiffActivity::DuplicatedOnly - { - dbg!("ret_activity = {:?}", ret_activity); - tcx.sess - .struct_span_err(attr.span, msg_rev_incompatible_arg) - .span_label(attr.span, "invalid return activity") - .emit(); - return AutoDiffAttrs::inactive(); - } - } AutoDiffAttrs { 
mode, ret_activity, input_activity: arg_activities } } From a06caa7a98ee8a38cebdb6a7568e501bb5ad3e9e Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Fri, 9 Feb 2024 00:25:34 -0500 Subject: [PATCH 027/100] make Enzyme build optional --- src/bootstrap/configure.py | 1 + src/bootstrap/src/core/build_steps/compile.rs | 11 +++++------ src/bootstrap/src/core/builder.rs | 6 ++++-- src/bootstrap/src/core/config/config.rs | 6 ++++++ 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/bootstrap/configure.py b/src/bootstrap/configure.py index 544a42d9ada1a..dfc07a4b7ddae 100755 --- a/src/bootstrap/configure.py +++ b/src/bootstrap/configure.py @@ -72,6 +72,7 @@ def v(*args): o("optimize-llvm", "llvm.optimize", "build optimized LLVM") o("llvm-assertions", "llvm.assertions", "build LLVM with assertions") o("llvm-plugins", "llvm.plugins", "build LLVM with plugin interface") +o("llvm-enzyme", "llvm.enzyme", "build LLVM with enzyme") o("debug-assertions", "rust.debug-assertions", "build with debugging assertions") o("debug-assertions-std", "rust.debug-assertions-std", "build the standard library with debugging assertions") o("overflow-checks", "rust.overflow-checks", "build with overflow checks") diff --git a/src/bootstrap/src/core/build_steps/compile.rs b/src/bootstrap/src/core/build_steps/compile.rs index 561b04bc3f41a..4f5844ae51c80 100644 --- a/src/bootstrap/src/core/build_steps/compile.rs +++ b/src/bootstrap/src/core/build_steps/compile.rs @@ -1638,12 +1638,11 @@ impl Step for Assemble { } // Build enzyme - let enzyme_install = Some(builder.ensure(llvm::Enzyme { target: build_compiler.host })); - //let enzyme_install = if builder.config.llvm_enzyme { - // Some(builder.ensure(llvm::Enzyme { target: build_compiler.host })) - //} else { - // None - //}; + let enzyme_install = if builder.config.llvm_enzyme { + Some(builder.ensure(llvm::Enzyme { target: build_compiler.host })) + } else { + None + }; if let Some(enzyme_install) = enzyme_install { let src_lib = 
enzyme_install.join("build/Enzyme/LLVMEnzyme-17.so"); diff --git a/src/bootstrap/src/core/builder.rs b/src/bootstrap/src/core/builder.rs index b8e10a1f9aa85..eb028c367d797 100644 --- a/src/bootstrap/src/core/builder.rs +++ b/src/bootstrap/src/core/builder.rs @@ -1417,8 +1417,10 @@ impl<'a> Builder<'a> { } // https://rust-lang.zulipchat.com/#narrow/stream/182449-t-compiler.2Fhelp/topic/.E2.9C.94.20link.20new.20library.20into.20stage1.2Frustc - rustflags.arg("-l"); - rustflags.arg("LLVMEnzyme-17"); + if self.config.llvm_enzyme { + rustflags.arg("-l"); + rustflags.arg("LLVMEnzyme-17"); + } let use_new_symbol_mangling = match self.config.rust_new_symbol_mangling { Some(setting) => { diff --git a/src/bootstrap/src/core/config/config.rs b/src/bootstrap/src/core/config/config.rs index f1e1b89d9ba71..c7ab615ac4fe0 100644 --- a/src/bootstrap/src/core/config/config.rs +++ b/src/bootstrap/src/core/config/config.rs @@ -211,6 +211,7 @@ pub struct Config { pub llvm_assertions: bool, pub llvm_tests: bool, pub llvm_plugins: bool, + pub llvm_enzyme: bool, pub llvm_optimize: bool, pub llvm_thin_lto: bool, pub llvm_release_debuginfo: bool, @@ -872,6 +873,7 @@ define_config! 
{ release_debuginfo: Option = "release-debuginfo", assertions: Option = "assertions", tests: Option = "tests", + enzyme: Option = "enzyme", plugins: Option = "plugins", ccache: Option = "ccache", static_libstdcpp: Option = "static-libstdcpp", @@ -1494,6 +1496,7 @@ impl Config { // we'll infer default values for them later let mut llvm_assertions = None; let mut llvm_tests = None; + let mut llvm_enzyme = None; let mut llvm_plugins = None; let mut debug = None; let mut debug_assertions = None; @@ -1697,6 +1700,7 @@ impl Config { release_debuginfo, assertions, tests, + enzyme, plugins, ccache, static_libstdcpp, @@ -1729,6 +1733,7 @@ impl Config { set(&mut config.ninja_in_file, ninja); llvm_assertions = assertions; llvm_tests = tests; + llvm_enzyme = enzyme; llvm_plugins = plugins; set(&mut config.llvm_optimize, optimize_toml); set(&mut config.llvm_thin_lto, thin_lto); @@ -1885,6 +1890,7 @@ impl Config { config.llvm_assertions = llvm_assertions.unwrap_or(false); config.llvm_tests = llvm_tests.unwrap_or(false); + config.llvm_enzyme = llvm_enzyme.unwrap_or(true); config.llvm_plugins = llvm_plugins.unwrap_or(false); config.rust_optimize = optimize.unwrap_or(RustOptimize::Bool(true)); From 2de98a22831bffd0084ea827b9668d9570b4f264 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Fri, 9 Feb 2024 18:07:51 -0500 Subject: [PATCH 028/100] Implement fallback to optionally not link and build Enzyme --- compiler/rustc_codegen_llvm/src/back/write.rs | 20 +- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 602 ++++++++++++------ config.example.toml | 3 + 3 files changed, 409 insertions(+), 216 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index ce77678696da7..feec822f38f9e 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -760,19 +760,19 @@ pub(crate) unsafe fn enzyme_ad( let type_analysis: EnzymeTypeAnalysisRef = CreateTypeAnalysis(logic_ref, 
std::ptr::null_mut(), std::ptr::null_mut(), 0); - llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymeStrictAliasing), 0); + llvm::set_strict_aliasing(false); if std::env::var("ENZYME_PRINT_TA").is_ok() { - llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrintType), 1); + llvm::set_print_type(true); } if std::env::var("ENZYME_PRINT_AA").is_ok() { - llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrintActivity), 1); + llvm::set_print_activity(true); } if std::env::var("ENZYME_PRINT_PERF").is_ok() { - llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrintPerf), 1); + llvm::set_print_perf(true); } if std::env::var("ENZYME_PRINT").is_ok() { - llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrint), 1); + llvm::set_print(true); } let mut res: &Value = match item.attrs.mode { @@ -826,7 +826,7 @@ pub(crate) unsafe fn differentiate( let llcx = &module.module_llvm.llcx; let diag_handler = cgcx.create_dcx(); - llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymeStrictAliasing), 0); + llvm::set_strict_aliasing(false); if std::env::var("ENZYME_PRINT_MOD").is_ok() { unsafe { @@ -835,15 +835,15 @@ pub(crate) unsafe fn differentiate( } if std::env::var("ENZYME_TT_DEPTH").is_ok() { let depth = std::env::var("ENZYME_TT_DEPTH").unwrap(); - let depth = depth.parse::().unwrap(); + let depth = depth.parse::().unwrap(); assert!(depth >= 1); - llvm::EnzymeSetCLInteger(std::ptr::addr_of_mut!(llvm::EnzymeMaxTypeDepth), depth); + llvm::set_max_int_offset(depth); } if std::env::var("ENZYME_TT_WIDTH").is_ok() { let width = std::env::var("ENZYME_TT_WIDTH").unwrap(); - let width = width.parse::().unwrap(); + let width = width.parse::().unwrap(); assert!(width >= 1); - llvm::EnzymeSetCLInteger(std::ptr::addr_of_mut!(llvm::MaxTypeOffset), width); + llvm::set_max_type_offset(width); } for item in diff_items { diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 2ef24a42cfda0..8621f0424b047 100644 
--- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -1,7 +1,6 @@ #![allow(non_camel_case_types)] #![allow(non_upper_case_globals)] -#[allow(unused_imports)] use rustc_ast::expand::autodiff_attrs::DiffActivity; use super::debuginfo::{ @@ -14,8 +13,6 @@ use super::debuginfo::{ use libc::{c_char, c_int, c_uint, size_t}; use libc::{c_ulonglong, c_void}; -use core::fmt; -use std::ffi::{CStr, CString}; use std::marker::PhantomData; use super::RustString; @@ -2607,117 +2604,29 @@ extern "C" { ) -> *mut c_void; } -// Enzyme -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct EnzymeOpaqueTypeAnalysis { - _unused: [u8; 0], -} -pub type EnzymeTypeAnalysisRef = *mut EnzymeOpaqueTypeAnalysis; -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct EnzymeOpaqueLogic { - _unused: [u8; 0], -} -pub type EnzymeLogicRef = *mut EnzymeOpaqueLogic; -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct EnzymeOpaqueAugmentedReturn { - _unused: [u8; 0], -} -pub type EnzymeAugmentedReturnPtr = *mut EnzymeOpaqueAugmentedReturn; -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct IntList { - pub data: *mut i64, - pub size: size_t, -} -#[repr(u32)] -#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] -pub enum CConcreteType { - DT_Anything = 0, - DT_Integer = 1, - DT_Pointer = 2, - DT_Half = 3, - DT_Float = 4, - DT_Double = 5, - DT_Unknown = 6, -} -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct EnzymeTypeTree { - _unused: [u8; 0], -} -pub type CTypeTreeRef = *mut EnzymeTypeTree; -extern "C" { - fn EnzymeNewTypeTree() -> CTypeTreeRef; -} -extern "C" { - fn EnzymeFreeTypeTree(CTT: CTypeTreeRef); -} -extern "C" { - pub fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8); -} -extern "C" { - pub fn EnzymeSetCLInteger(arg1: *mut ::std::os::raw::c_void, arg2: i64); -} +#[cfg(AD_FALLBACK)] +pub use self::Fallback_AD::*; -extern "C" { - pub static mut MaxIntOffset: c_void; - pub static mut MaxTypeOffset: c_void; - pub static mut 
EnzymeMaxTypeDepth: c_void; +#[cfg(AD_FALLBACK)] +pub mod Fallback_AD { + #![allow(unused_variables)] + use super::*; - pub static mut EnzymePrintPerf: c_void; - pub static mut EnzymePrintActivity: c_void; - pub static mut EnzymePrintType: c_void; - pub static mut EnzymePrint: c_void; - pub static mut EnzymeStrictAliasing: c_void; -} + pub fn EnzymeNewTypeTree() -> CTypeTreeRef { unimplemented!() } + pub fn EnzymeFreeTypeTree(CTT: CTypeTreeRef) { unimplemented!() } + pub fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8) { unimplemented!() } + pub fn EnzymeSetCLInteger(arg1: *mut ::std::os::raw::c_void, arg2: i64) { unimplemented!() } -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct CFnTypeInfo { - #[doc = " Types of arguments, assumed of size len(Arguments)"] - pub Arguments: *mut CTypeTreeRef, - #[doc = " Type of return"] - pub Return: CTypeTreeRef, - #[doc = " The specific constant(s) known to represented by an argument, if constant"] - pub KnownValues: *mut IntList, -} -#[repr(u32)] -#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] -pub enum CDIFFE_TYPE { - DFT_OUT_DIFF = 0, - DFT_DUP_ARG = 1, - DFT_CONSTANT = 2, - DFT_DUP_NONEED = 3, -} - -fn cdiffe_from(act: DiffActivity) -> CDIFFE_TYPE { - return match act { - DiffActivity::None => CDIFFE_TYPE::DFT_CONSTANT, - DiffActivity::Const => CDIFFE_TYPE::DFT_CONSTANT, - DiffActivity::Active => CDIFFE_TYPE::DFT_OUT_DIFF, - DiffActivity::ActiveOnly => CDIFFE_TYPE::DFT_OUT_DIFF, - DiffActivity::Dual => CDIFFE_TYPE::DFT_DUP_ARG, - DiffActivity::DualOnly => CDIFFE_TYPE::DFT_DUP_NONEED, - DiffActivity::Duplicated => CDIFFE_TYPE::DFT_DUP_ARG, - DiffActivity::DuplicatedOnly => CDIFFE_TYPE::DFT_DUP_NONEED, - }; -} + pub fn set_max_int_offset(offset: u64) { unimplemented!() } + pub fn set_max_type_offset(offset: u64) { unimplemented!() } + pub fn set_max_type_depth(depth: u64) { unimplemented!() } + pub fn set_print_perf(print: bool) { unimplemented!() } + pub fn set_print_activity(print: bool) { 
unimplemented!() } + pub fn set_print_type(print: bool) { unimplemented!() } + pub fn set_print(print: bool) { unimplemented!() } + pub fn set_strict_aliasing(strict: bool) { unimplemented!() } -#[repr(u32)] -#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] -pub enum CDerivativeMode { - DEM_ForwardMode = 0, - DEM_ReverseModePrimal = 1, - DEM_ReverseModeGradient = 2, - DEM_ReverseModeCombined = 3, - DEM_ForwardModeSplit = 4, -} -extern "C" { - #[allow(dead_code)] - fn EnzymeCreatePrimalAndGradient<'a>( + pub fn EnzymeCreatePrimalAndGradient<'a>( arg1: EnzymeLogicRef, _builderCtx: *const u8, // &'a Builder<'_>, _callerCtx: *const u8, // &'a Value, @@ -2738,11 +2647,10 @@ extern "C" { uncacheable_args_size: size_t, augmented: EnzymeAugmentedReturnPtr, AtomicAdd: u8, - ) -> &'a Value; - //) -> LLVMValueRef; -} -extern "C" { - fn EnzymeCreateForwardDiff<'a>( + ) -> &'a Value { + unimplemented!() + } + pub fn EnzymeCreateForwardDiff<'a>( arg1: EnzymeLogicRef, _builderCtx: *const u8, // &'a Builder<'_>, _callerCtx: *const u8, // &'a Value, @@ -2760,8 +2668,9 @@ extern "C" { _uncacheable_args: *const u8, uncacheable_args_size: size_t, augmented: EnzymeAugmentedReturnPtr, - ) -> &'a Value; -} + ) -> &'a Value { + unimplemented!() + } pub type CustomRuleType = ::std::option::Option< unsafe extern "C" fn( direction: ::std::os::raw::c_int, @@ -2781,131 +2690,412 @@ extern "C" { numRules: size_t, ) -> EnzymeTypeAnalysisRef; } -extern "C" { - pub fn ClearTypeAnalysis(arg1: EnzymeTypeAnalysisRef); -} -extern "C" { - pub fn FreeTypeAnalysis(arg1: EnzymeTypeAnalysisRef); -} -extern "C" { - pub fn CreateEnzymeLogic(PostOpt: u8) -> EnzymeLogicRef; -} -extern "C" { - pub fn ClearEnzymeLogic(arg1: EnzymeLogicRef); -} -extern "C" { - pub fn FreeEnzymeLogic(arg1: EnzymeLogicRef); -} + pub fn ClearTypeAnalysis(arg1: EnzymeTypeAnalysisRef) { unimplemented!() } + pub fn FreeTypeAnalysis(arg1: EnzymeTypeAnalysisRef) { unimplemented!() } + pub fn CreateEnzymeLogic(PostOpt: u8) -> 
EnzymeLogicRef { unimplemented!() } + pub fn ClearEnzymeLogic(arg1: EnzymeLogicRef) { unimplemented!() } + pub fn FreeEnzymeLogic(arg1: EnzymeLogicRef) { unimplemented!() } -extern "C" { - fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef; - fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef; - fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool; - fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64); - fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef); - fn EnzymeTypeTreeShiftIndiciesEq( + pub fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef { + unimplemented!() + } + pub fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef { + unimplemented!() + } + pub fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool { + unimplemented!() + } + pub fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64) { + unimplemented!() + } + pub fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef) { + unimplemented!() + } + pub fn EnzymeTypeTreeShiftIndiciesEq( arg1: CTypeTreeRef, data_layout: *const c_char, offset: i64, max_size: i64, add_offset: u64, - ); - fn EnzymeTypeTreeToStringFree(arg1: *const c_char); - fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char; + ) { + unimplemented!() + } + pub fn EnzymeTypeTreeToStringFree(arg1: *const c_char) { + unimplemented!() + } + pub fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char { + unimplemented!() + } } -pub struct TypeTree { - pub inner: CTypeTreeRef, -} -impl TypeTree { - pub fn new() -> TypeTree { - let inner = unsafe { EnzymeNewTypeTree() }; +// Enzyme specific, but doesn't require Enzyme to be build +pub use self::Shared_AD::*; +pub mod Shared_AD { + // Depending on the AD backend (Enzyme or Fallback), some functions might or might not be + // unsafe. So we just allways call them in an unsafe context. 
+ #![allow(unused_unsafe)] + #![allow(unused_variables)] + + use libc::size_t; + use super::Context; + + #[cfg(AD_FALLBACK)] + use super::Fallback_AD::*; + #[cfg(not(AD_FALLBACK))] + use super::Enzyme_AD::*; + + use core::fmt; + use std::ffi::{CStr, CString}; + use rustc_ast::expand::autodiff_attrs::DiffActivity; + #[repr(u32)] + #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] + pub enum CDIFFE_TYPE { + DFT_OUT_DIFF = 0, + DFT_DUP_ARG = 1, + DFT_CONSTANT = 2, + DFT_DUP_NONEED = 3, + } - TypeTree { inner } + pub fn cdiffe_from(act: DiffActivity) -> CDIFFE_TYPE { + return match act { + DiffActivity::None => CDIFFE_TYPE::DFT_CONSTANT, + DiffActivity::Const => CDIFFE_TYPE::DFT_CONSTANT, + DiffActivity::Active => CDIFFE_TYPE::DFT_OUT_DIFF, + DiffActivity::ActiveOnly => CDIFFE_TYPE::DFT_OUT_DIFF, + DiffActivity::Dual => CDIFFE_TYPE::DFT_DUP_ARG, + DiffActivity::DualOnly => CDIFFE_TYPE::DFT_DUP_NONEED, + DiffActivity::Duplicated => CDIFFE_TYPE::DFT_DUP_ARG, + DiffActivity::DuplicatedOnly => CDIFFE_TYPE::DFT_DUP_NONEED, + }; } - #[must_use] - pub fn from_type(t: CConcreteType, ctx: &Context) -> TypeTree { - let inner = unsafe { EnzymeNewTypeTreeCT(t, ctx) }; + #[repr(u32)] + #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] + pub enum CDerivativeMode { + DEM_ForwardMode = 0, + DEM_ReverseModePrimal = 1, + DEM_ReverseModeGradient = 2, + DEM_ReverseModeCombined = 3, + DEM_ForwardModeSplit = 4, + } + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct EnzymeOpaqueTypeAnalysis { + _unused: [u8; 0], + } + pub type EnzymeTypeAnalysisRef = *mut EnzymeOpaqueTypeAnalysis; + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct EnzymeOpaqueLogic { + _unused: [u8; 0], + } + pub type EnzymeLogicRef = *mut EnzymeOpaqueLogic; + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct EnzymeOpaqueAugmentedReturn { + _unused: [u8; 0], + } + pub type EnzymeAugmentedReturnPtr = *mut EnzymeOpaqueAugmentedReturn; + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct IntList { + 
pub data: *mut i64, + pub size: size_t, + } + #[repr(u32)] + #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] + pub enum CConcreteType { + DT_Anything = 0, + DT_Integer = 1, + DT_Pointer = 2, + DT_Half = 3, + DT_Float = 4, + DT_Double = 5, + DT_Unknown = 6, + } - TypeTree { inner } + pub type CTypeTreeRef = *mut EnzymeTypeTree; + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct EnzymeTypeTree { + _unused: [u8; 0], + } + pub struct TypeTree { + pub inner: CTypeTreeRef, } - #[must_use] - pub fn only(self, idx: isize) -> TypeTree { - unsafe { - EnzymeTypeTreeOnlyEq(self.inner, idx as i64); + impl TypeTree { + pub fn new() -> TypeTree { + let inner = unsafe { EnzymeNewTypeTree() }; + TypeTree { inner } } - self - } - #[must_use] - pub fn data0(self) -> TypeTree { - unsafe { - EnzymeTypeTreeData0Eq(self.inner); + #[must_use] + pub fn from_type(t: CConcreteType, ctx: &Context) -> TypeTree { + let inner = unsafe { EnzymeNewTypeTreeCT(t, ctx) }; + TypeTree { inner } + } + + #[must_use] + pub fn only(self, idx: isize) -> TypeTree { + unsafe { + EnzymeTypeTreeOnlyEq(self.inner, idx as i64); + } + self + } + + #[must_use] + pub fn data0(self) -> TypeTree { + unsafe { + EnzymeTypeTreeData0Eq(self.inner); + } + self + } + + pub fn merge(self, other: Self) -> Self { + unsafe { + EnzymeMergeTypeTree(self.inner, other.inner); + } + drop(other); + self + } + + #[must_use] + pub fn shift(self, layout: &str, offset: isize, max_size: isize, add_offset: usize) -> Self { + let layout = CString::new(layout).unwrap(); + + unsafe { + EnzymeTypeTreeShiftIndiciesEq( + self.inner, + layout.as_ptr(), + offset as i64, + max_size as i64, + add_offset as u64, + ) + } + + self } - self } - pub fn merge(self, other: Self) -> Self { - unsafe { - EnzymeMergeTypeTree(self.inner, other.inner); + impl Clone for TypeTree { + fn clone(&self) -> Self { + let inner = unsafe { EnzymeNewTypeTreeTR(self.inner) }; + TypeTree { inner } } - drop(other); + } + + impl fmt::Display for TypeTree { + fn 
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let ptr = unsafe { EnzymeTypeTreeToString(self.inner) }; + let cstr = unsafe { CStr::from_ptr(ptr) }; + match cstr.to_str() { + Ok(x) => write!(f, "{}", x)?, + Err(err) => write!(f, "could not parse: {}", err)?, + } - self + // delete C string pointer + unsafe { EnzymeTypeTreeToStringFree(ptr) } + + Ok(()) + } } - #[must_use] - pub fn shift(self, layout: &str, offset: isize, max_size: isize, add_offset: usize) -> Self { - let layout = CString::new(layout).unwrap(); - - unsafe { - EnzymeTypeTreeShiftIndiciesEq( - self.inner, - layout.as_ptr(), - offset as i64, - max_size as i64, - add_offset as u64, - ) + impl fmt::Debug for TypeTree { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + ::fmt(self, f) } + } - self + impl Drop for TypeTree { + fn drop(&mut self) { + unsafe { EnzymeFreeTypeTree(self.inner) } + } } -} -impl Clone for TypeTree { - fn clone(&self) -> Self { - let inner = unsafe { EnzymeNewTypeTreeTR(self.inner) }; - TypeTree { inner } + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct CFnTypeInfo { + #[doc = " Types of arguments, assumed of size len(Arguments)"] + pub Arguments: *mut CTypeTreeRef, + #[doc = " Type of return"] + pub Return: CTypeTreeRef, + #[doc = " The specific constant(s) known to represented by an argument, if constant"] + pub KnownValues: *mut IntList, } } -impl fmt::Display for TypeTree { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let ptr = unsafe { EnzymeTypeTreeToString(self.inner) }; - let cstr = unsafe { CStr::from_ptr(ptr) }; - match cstr.to_str() { - Ok(x) => write!(f, "{}", x)?, - Err(err) => write!(f, "could not parse: {}", err)?, - } +#[cfg(not(AD_FALLBACK))] +pub use self::Enzyme_AD::*; - // delete C string pointer - unsafe { EnzymeTypeTreeToStringFree(ptr) } +// Enzyme is an optional component, so we do need to provide a fallback when it is ont getting +// compiled. 
We deny the usage of #[autodiff(..)] on a higher level, so a placeholder implementation +// here is completely fine. +#[cfg(not(AD_FALLBACK))] +pub mod Enzyme_AD { +use super::*; - Ok(()) - } +use super::debuginfo::{ + DIArray, DIBasicType, DIBuilder, DICompositeType, DIDerivedType, DIDescriptor, DIEnumerator, + DIFile, DIFlags, DIGlobalVariableExpression, DILexicalBlock, DILocation, DINameSpace, + DISPFlags, DIScope, DISubprogram, DISubrange, DITemplateTypeParameter, DIType, DIVariable, + DebugEmissionKind, DebugNameTableKind, +}; + +use libc::{c_char, c_int, c_uint, size_t}; +use libc::{c_ulonglong, c_void}; + +use std::marker::PhantomData; + +use super::RustString; +use core::fmt; +use std::ffi::{CStr, CString}; + +extern "C" { + fn EnzymeNewTypeTree() -> CTypeTreeRef; + fn EnzymeFreeTypeTree(CTT: CTypeTreeRef); + pub fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8); + pub fn EnzymeSetCLInteger(arg1: *mut ::std::os::raw::c_void, arg2: i64); } -impl fmt::Debug for TypeTree { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - ::fmt(self, f) +extern "C" { + static mut MaxIntOffset: c_void; + static mut MaxTypeOffset: c_void; + static mut EnzymeMaxTypeDepth: c_void; + + static mut EnzymePrintPerf: c_void; + static mut EnzymePrintActivity: c_void; + static mut EnzymePrintType: c_void; + static mut EnzymePrint: c_void; + static mut EnzymeStrictAliasing: c_void; +} +pub fn set_max_int_offset(offset: u64) { + unsafe { + EnzymeSetCLInteger(std::ptr::addr_of_mut!(llvm::MaxIntOffset), offset); } } - -impl Drop for TypeTree { - fn drop(&mut self) { - unsafe { EnzymeFreeTypeTree(self.inner) } +pub fn set_max_type_offset(offset: u64) { + unsafe { + EnzymeSetCLInteger(std::ptr::addr_of_mut!(llvm::MaxTypeOffset), offset); + } +} +pub fn set_max_type_depth(depth: u64) { + unsafe { + EnzymeSetCLInteger(std::ptr::addr_of_mut!(llvm::EnzymeMaxTypeDepth), depth); + } +} +pub fn set_print_perf(print: bool) { + unsafe { + 
EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrintPerf), print as u8); + } +} +pub fn set_print_activity(print: bool) { + unsafe { + EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrintActivity), print as u8); + } +} +pub fn set_print_type(print: bool) { + unsafe { + EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrintType), print as u8); + } +} +pub fn set_print(print: bool) { + unsafe { + EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrint), print as u8); } } +pub fn set_strict_aliasing(strict: bool) { + unsafe { + EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymeStrictAliasing), strict as u8); + } +} +extern "C" { + pub fn EnzymeCreatePrimalAndGradient<'a>( + arg1: EnzymeLogicRef, + _builderCtx: *const u8, // &'a Builder<'_>, + _callerCtx: *const u8, // &'a Value, + todiff: &'a Value, + retType: CDIFFE_TYPE, + constant_args: *const CDIFFE_TYPE, + constant_args_size: size_t, + TA: EnzymeTypeAnalysisRef, + returnValue: u8, + dretUsed: u8, + mode: CDerivativeMode, + width: ::std::os::raw::c_uint, + freeMemory: u8, + additionalArg: Option<&Type>, + forceAnonymousTape: u8, + typeInfo: CFnTypeInfo, + _uncacheable_args: *const u8, + uncacheable_args_size: size_t, + augmented: EnzymeAugmentedReturnPtr, + AtomicAdd: u8, + ) -> &'a Value; +} +extern "C" { + pub fn EnzymeCreateForwardDiff<'a>( + arg1: EnzymeLogicRef, + _builderCtx: *const u8, // &'a Builder<'_>, + _callerCtx: *const u8, // &'a Value, + todiff: &'a Value, + retType: CDIFFE_TYPE, + constant_args: *const CDIFFE_TYPE, + constant_args_size: size_t, + TA: EnzymeTypeAnalysisRef, + returnValue: u8, + mode: CDerivativeMode, + freeMemory: u8, + width: ::std::os::raw::c_uint, + additionalArg: Option<&Type>, + typeInfo: CFnTypeInfo, + _uncacheable_args: *const u8, + uncacheable_args_size: size_t, + augmented: EnzymeAugmentedReturnPtr, + ) -> &'a Value; +} +pub type CustomRuleType = ::std::option::Option< + unsafe extern "C" fn( + direction: ::std::os::raw::c_int, + ret: CTypeTreeRef, + args: *mut 
CTypeTreeRef, + known_values: *mut IntList, + num_args: size_t, + fnc: &Value, + ta: *const ::std::os::raw::c_void, + ) -> u8, +>; +extern "C" { + pub fn CreateTypeAnalysis( + Log: EnzymeLogicRef, + customRuleNames: *mut *mut ::std::os::raw::c_char, + customRules: *mut CustomRuleType, + numRules: size_t, + ) -> EnzymeTypeAnalysisRef; +} +//extern "C" { +// pub fn ClearTypeAnalysis(arg1: EnzymeTypeAnalysisRef); +// pub fn FreeTypeAnalysis(arg1: EnzymeTypeAnalysisRef); +// pub fn CreateEnzymeLogic(PostOpt: u8) -> EnzymeLogicRef; +// pub fn ClearEnzymeLogic(arg1: EnzymeLogicRef); +// pub fn FreeEnzymeLogic(arg1: EnzymeLogicRef); +//} + +extern "C" { + fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef; + fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef; + fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool; + fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64); + fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef); + fn EnzymeTypeTreeShiftIndiciesEq( + arg1: CTypeTreeRef, + data_layout: *const c_char, + offset: i64, + max_size: i64, + add_offset: u64, + ); + fn EnzymeTypeTreeToStringFree(arg1: *const c_char); + fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char; +} +} diff --git a/config.example.toml b/config.example.toml index 4cf7c1e81990c..53995ac7d7202 100644 --- a/config.example.toml +++ b/config.example.toml @@ -149,6 +149,9 @@ # Whether to build the clang compiler. #clang = false +# Wheter to build Enzyme as AutoDiff backend. +#enzyme = true + # Whether to enable llvm compilation warnings. 
#enable-warnings = false From 293ab97828f5f2ed59aaf274974f39a4a6da8ad3 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Fri, 9 Feb 2024 19:43:10 -0500 Subject: [PATCH 029/100] adding CI and updating README --- README.md | 63 +++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 61 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 5d5beaf1b7a20..6a37cac337ce0 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,68 @@ -# The Rust Programming Language +# The Rust Programming Language + AutoDiff [![Rust Community](https://img.shields.io/badge/Rust_Community%20-Join_us-brightgreen?style=plastic&logo=rust)](https://www.rust-lang.org/community) This is the main source code repository for [Rust]. It contains the compiler, -standard library, and documentation. +standard library, and documentation. It is modified to use Enzyme for AutoDiff. + + +Please configure this fork using the following command: + +``` +mkdir build +cd build +../configure --enable-llvm-link-shared --enable-llvm-plugins --enable-llvm-enzyme --release-channel=nightly --enable-llvm-assertions --enable-clang --enable-lld --enable-option-checking --enable-ninja --disable-docs +``` + +Afterwards you can build rustc using: +``` +../x.py build --stage 1 library/std library/proc_macro library/test tools/rustdoc +``` + +Afterwards `rustup toolchain link` will allow you to use it through cargo: +``` +rustup toolchain link enzyme `pwd`/build/`rustup target list --installed`/stage1 +rustup toolchain install nightly # enables -Z unstable-options +``` + +You can then look at examples in the `library/autodiff/examples/*` folder and run them with + +```bash +# rosenbrock forward iteration +cargo +enzyme run --example rosenbrock_fwd_iter --release + +# or all of them +cargo +enzyme test --examples +``` + +## Enzyme Config +To help with debugging, Enzyme can be configured using environment variables. 
+```bash +export ENZYME_PRINT_TA=1 +export ENZYME_PRINT_AA=1 +export ENZYME_PRINT=1 +export ENZYME_PRINT_MOD=1 +export ENZYME_PRINT_MOD_AFTER=1 +``` +The first three will print TypeAnalysis, ActivityAnalysis and the llvm-ir on a function basis, respectively. +The last two variables will print the whole module directly before and after Enzyme differentiated the functions. + +When experimenting with flags please make sure that EnzymeStrictAliasing=0 +is not changed, since it is required for Enzyme to handle enums correctly. + +## Bug reporting +Bugs are pretty much expected at this point of the development process. +In order to help us please minimize the Rust code as far as possible. +This tool might be a nicer helper: https://github.com/Nilstrieb/cargo-minimize +If you have some knowledge of LLVM-IR we also greatly appreciate it if you could help +us by compiling your minimized Rust code to LLVM-IR and reducing it further. + +The only exception to this strategy are errors based on "Can not deduce type of X", +where reducing your example will make it harder for us to understand the origin of the bug. +In this case please just try to inline all dependencies into a single crate or even file, +without deleting used code. 
+ + [Rust]: https://www.rust-lang.org/ From 4889ad3565a1cb9be06b7c4332591a092cac9c15 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sun, 11 Feb 2024 02:12:41 -0500 Subject: [PATCH 030/100] fix mem leak, fix logic bug, cleanup --- compiler/rustc_codegen_llvm/src/back/write.rs | 4 +- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 86 +++++++++---------- 2 files changed, 43 insertions(+), 47 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index feec822f38f9e..1b6cdd7ec3db5 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -11,7 +11,7 @@ use crate::errors::{ }; use crate::llvm::{self, DiagnosticInfo, PassManager}; use crate::llvm::{ - enzyme_rust_forward_diff, enzyme_rust_reverse_diff, AttributeKind, BasicBlock, + enzyme_rust_forward_diff, enzyme_rust_reverse_diff, AttributeKind, BasicBlock, FreeTypeAnalysis, CreateEnzymeLogic, CreateTypeAnalysis, EnzymeLogicRef, EnzymeTypeAnalysisRef, LLVMAddFunction, LLVMAppendBasicBlockInContext, LLVMBuildCall2, LLVMBuildExtractValue, LLVMBuildRet, LLVMCountParams, LLVMCountStructElementTypes, LLVMCreateBuilderInContext, @@ -809,6 +809,8 @@ pub(crate) unsafe fn enzyme_ad( LLVMSetValueName2(res, name2.as_ptr(), rust_name2.len()); LLVMReplaceAllUsesWith(target_fnc, res); LLVMDeleteFunction(target_fnc); + // TODO: implement drop for wrapper type? 
+ FreeTypeAnalysis(type_analysis); Ok(()) } diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 8621f0424b047..bdb0778415310 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -925,13 +925,17 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff( input_tts: Vec, output_tt: TypeTree, ) -> &Value { - let (primary_ret, diff_ret, ret_activity) = match ret_activity { - DiffActivity::Const => (true, false, CDIFFE_TYPE::DFT_CONSTANT), - DiffActivity::Active => (true, true, CDIFFE_TYPE::DFT_DUP_ARG), - DiffActivity::ActiveOnly => (false, true, CDIFFE_TYPE::DFT_DUP_NONEED), - DiffActivity::None => (false, false, CDIFFE_TYPE::DFT_CONSTANT), + let (primary_ret, ret_activity) = match ret_activity { + DiffActivity::Const => (true, CDIFFE_TYPE::DFT_CONSTANT), + DiffActivity::Active => (true, CDIFFE_TYPE::DFT_DUP_ARG), + DiffActivity::None => (false, CDIFFE_TYPE::DFT_CONSTANT), _ => panic!("Invalid return activity"), }; + // This only is needed for split-mode AD, which we don't support. 
+ // See Julia: + // https://github.com/EnzymeAD/Enzyme.jl/blob/a511e4e6979d6161699f5c9919d49801c0764a09/src/compiler.jl#L3132 + // https://github.com/EnzymeAD/Enzyme.jl/blob/a511e4e6979d6161699f5c9919d49801c0764a09/src/compiler.jl#L3092 + let diff_ret = false; let input_activity: Vec = input_activity.iter().map(|&x| cdiffe_from(x)).collect(); @@ -2690,7 +2694,7 @@ extern "C" { numRules: size_t, ) -> EnzymeTypeAnalysisRef; } - pub fn ClearTypeAnalysis(arg1: EnzymeTypeAnalysisRef) { unimplemented!() } + //pub fn ClearTypeAnalysis(arg1: EnzymeTypeAnalysisRef) { unimplemented!() } pub fn FreeTypeAnalysis(arg1: EnzymeTypeAnalysisRef) { unimplemented!() } pub fn CreateEnzymeLogic(PostOpt: u8) -> EnzymeLogicRef { unimplemented!() } pub fn ClearEnzymeLogic(arg1: EnzymeLogicRef) { unimplemented!() } @@ -2936,25 +2940,12 @@ pub use self::Enzyme_AD::*; pub mod Enzyme_AD { use super::*; -use super::debuginfo::{ - DIArray, DIBasicType, DIBuilder, DICompositeType, DIDerivedType, DIDescriptor, DIEnumerator, - DIFile, DIFlags, DIGlobalVariableExpression, DILexicalBlock, DILocation, DINameSpace, - DISPFlags, DIScope, DISubprogram, DISubrange, DITemplateTypeParameter, DIType, DIVariable, - DebugEmissionKind, DebugNameTableKind, -}; - -use libc::{c_char, c_int, c_uint, size_t}; -use libc::{c_ulonglong, c_void}; - -use std::marker::PhantomData; - -use super::RustString; -use core::fmt; -use std::ffi::{CStr, CString}; +use libc::{c_char, size_t}; +use libc::c_void; extern "C" { - fn EnzymeNewTypeTree() -> CTypeTreeRef; - fn EnzymeFreeTypeTree(CTT: CTypeTreeRef); + pub fn EnzymeNewTypeTree() -> CTypeTreeRef; + pub fn EnzymeFreeTypeTree(CTT: CTypeTreeRef); pub fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8); pub fn EnzymeSetCLInteger(arg1: *mut ::std::os::raw::c_void, arg2: i64); } @@ -2971,43 +2962,46 @@ extern "C" { static mut EnzymeStrictAliasing: c_void; } pub fn set_max_int_offset(offset: u64) { + let offset = offset.try_into().unwrap(); unsafe { - 
EnzymeSetCLInteger(std::ptr::addr_of_mut!(llvm::MaxIntOffset), offset); + EnzymeSetCLInteger(std::ptr::addr_of_mut!(MaxIntOffset), offset); } } pub fn set_max_type_offset(offset: u64) { + let offset = offset.try_into().unwrap(); unsafe { - EnzymeSetCLInteger(std::ptr::addr_of_mut!(llvm::MaxTypeOffset), offset); + EnzymeSetCLInteger(std::ptr::addr_of_mut!(MaxTypeOffset), offset); } } pub fn set_max_type_depth(depth: u64) { + let depth = depth.try_into().unwrap(); unsafe { - EnzymeSetCLInteger(std::ptr::addr_of_mut!(llvm::EnzymeMaxTypeDepth), depth); + EnzymeSetCLInteger(std::ptr::addr_of_mut!(EnzymeMaxTypeDepth), depth); } } pub fn set_print_perf(print: bool) { unsafe { - EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrintPerf), print as u8); + EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintPerf), print as u8); } } pub fn set_print_activity(print: bool) { unsafe { - EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrintActivity), print as u8); + EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintActivity), print as u8); } } pub fn set_print_type(print: bool) { unsafe { - EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrintType), print as u8); + EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintType), print as u8); } } pub fn set_print(print: bool) { unsafe { - EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrint), print as u8); + EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrint), print as u8); } } pub fn set_strict_aliasing(strict: bool) { unsafe { - EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymeStrictAliasing), strict as u8); + EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeStrictAliasing), strict as u8); } } extern "C" { @@ -3074,28 +3068,28 @@ extern "C" { numRules: size_t, ) -> EnzymeTypeAnalysisRef; } -//extern "C" { -// pub fn ClearTypeAnalysis(arg1: EnzymeTypeAnalysisRef); -// pub fn FreeTypeAnalysis(arg1: EnzymeTypeAnalysisRef); -// pub fn CreateEnzymeLogic(PostOpt: u8) -> EnzymeLogicRef; -// pub fn ClearEnzymeLogic(arg1: 
EnzymeLogicRef); -// pub fn FreeEnzymeLogic(arg1: EnzymeLogicRef); -//} +extern "C" { + //pub(super) fn ClearTypeAnalysis(arg1: EnzymeTypeAnalysisRef); + pub fn FreeTypeAnalysis(arg1: EnzymeTypeAnalysisRef); + pub fn CreateEnzymeLogic(PostOpt: u8) -> EnzymeLogicRef; + pub fn ClearEnzymeLogic(arg1: EnzymeLogicRef); + pub fn FreeEnzymeLogic(arg1: EnzymeLogicRef); +} extern "C" { - fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef; - fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef; - fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool; - fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64); - fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef); - fn EnzymeTypeTreeShiftIndiciesEq( + pub(super) fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef; + pub(super) fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef; + pub(super) fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool; + pub(super) fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64); + pub(super) fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef); + pub(super) fn EnzymeTypeTreeShiftIndiciesEq( arg1: CTypeTreeRef, data_layout: *const c_char, offset: i64, max_size: i64, add_offset: u64, ); - fn EnzymeTypeTreeToStringFree(arg1: *const c_char); - fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char; + pub(super) fn EnzymeTypeTreeToStringFree(arg1: *const c_char); + pub(super) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char; } } From 769622dcf6475ab5610b788976a94cd322366232 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sun, 11 Feb 2024 03:53:48 -0500 Subject: [PATCH 031/100] try to make enzyme configurable, not working yet --- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index bdb0778415310..4c7705e3081ef 100644 --- 
a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -1,5 +1,6 @@ #![allow(non_camel_case_types)] #![allow(non_upper_case_globals)] +#![allow(unexpected_cfgs)] use rustc_ast::expand::autodiff_attrs::DiffActivity; @@ -2608,10 +2609,10 @@ extern "C" { ) -> *mut c_void; } -#[cfg(AD_FALLBACK)] +#[cfg(autodiff_fallback)] pub use self::Fallback_AD::*; -#[cfg(AD_FALLBACK)] +#[cfg(autodiff_fallback)] pub mod Fallback_AD { #![allow(unused_variables)] use super::*; @@ -2744,9 +2745,9 @@ pub mod Shared_AD { use libc::size_t; use super::Context; - #[cfg(AD_FALLBACK)] + #[cfg(autodiff_fallback)] use super::Fallback_AD::*; - #[cfg(not(AD_FALLBACK))] + #[cfg(not(autodiff_fallback))] use super::Enzyme_AD::*; use core::fmt; @@ -2930,13 +2931,13 @@ pub mod Shared_AD { } } -#[cfg(not(AD_FALLBACK))] +#[cfg(not(autodiff_fallback))] pub use self::Enzyme_AD::*; // Enzyme is an optional component, so we do need to provide a fallback when it is ont getting // compiled. We deny the usage of #[autodiff(..)] on a higher level, so a placeholder implementation // here is completely fine. 
-#[cfg(not(AD_FALLBACK))] +#[cfg(not(autodiff_fallback))] pub mod Enzyme_AD { use super::*; From e3945466a940ae7b7b271af9746efa4ee5666f01 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sun, 11 Feb 2024 21:17:41 -0500 Subject: [PATCH 032/100] small cleanups, make Enzyme config nicer --- compiler/rustc_builtin_macros/messages.ftl | 1 + compiler/rustc_builtin_macros/src/autodiff.rs | 70 ++++++++----------- compiler/rustc_builtin_macros/src/errors.rs | 7 ++ compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 14 ++-- compiler/rustc_session/src/config.rs | 2 + compiler/rustc_span/src/symbol.rs | 2 + src/bootstrap/src/core/build_steps/compile.rs | 3 + src/bootstrap/src/core/config/config.rs | 4 ++ src/bootstrap/src/lib.rs | 3 + 9 files changed, 57 insertions(+), 49 deletions(-) diff --git a/compiler/rustc_builtin_macros/messages.ftl b/compiler/rustc_builtin_macros/messages.ftl index 07c9b588e1faf..e1739fe52cc14 100644 --- a/compiler/rustc_builtin_macros/messages.ftl +++ b/compiler/rustc_builtin_macros/messages.ftl @@ -2,6 +2,7 @@ builtin_macros_alloc_error_must_be_fn = alloc_error_handler must be a function builtin_macros_alloc_must_statics = allocators must be statics builtin_macros_autodiff = autodiff must be applied to function +builtin_macros_autodiff_not_build = this rustc version does not support autodiff builtin_macros_asm_clobber_abi = clobber_abi builtin_macros_asm_clobber_no_reg = asm with `clobber_abi` must specify explicit registers for outputs diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 5978ad05b0dfa..eeaefbaaf1a56 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -1,6 +1,4 @@ #![allow(unused_imports)] -#![allow(unused_variables)] -#![allow(unused_mut)] //use crate::util::check_builtin_macro_attribute; //use crate::util::check_autodiff; @@ -20,12 +18,25 @@ use rustc_span::Symbol; use std::string::String; use thin_vec::{thin_vec, 
ThinVec}; +#[cfg(llvm_enzyme)] fn first_ident(x: &NestedMetaItem) -> rustc_span::symbol::Ident { let segments = &x.meta_item().unwrap().path.segments; assert!(segments.len() == 1); segments[0].ident } +#[cfg(not(llvm_enzyme))] +pub fn expand( + ecx: &mut ExtCtxt<'_>, + _expand_span: Span, + meta_item: &ast::MetaItem, + item: Annotatable, +) -> Vec { + ecx.sess.dcx().emit_err(errors::AutoDiffSupportNotBuild { span: meta_item.span }); + return vec![item]; +} + +#[cfg(llvm_enzyme)] pub fn expand( ecx: &mut ExtCtxt<'_>, expand_span: Span, @@ -45,24 +56,16 @@ pub fn expand( let primal = orig_item.ident.clone(); // Allow using `#[autodiff(...)]` only on a Fn - let (fn_item, has_ret, sig, sig_span) = if let Annotatable::Item(item) = &item + let (has_ret, sig, sig_span) = if let Annotatable::Item(item) = &item && let ItemKind::Fn(box ast::Fn { sig, .. }) = &item.kind { - (item, sig.decl.output.has_ret(), sig, ecx.with_call_site_ctxt(sig.span)) + (sig.decl.output.has_ret(), sig, ecx.with_call_site_ctxt(sig.span)) } else { ecx.sess.dcx().emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); return vec![item]; }; // create TokenStream from vec elemtents: // meta_item doesn't have a .tokens field - let ts: Vec = meta_item_vec.clone()[1..] 
- .iter() - .map(|x| { - let val = first_ident(x); - let t = Token::from_ast_ident(val); - t - }) - .collect(); let comma: Token = Token::new(TokenKind::Comma, Span::default()); let mut ts: Vec = vec![]; for t in meta_item_vec.clone()[1..].iter() { @@ -77,18 +80,15 @@ pub fn expand( dbg!(&x); let span = ecx.with_def_site_ctxt(expand_span); - let (d_sig, old_names, new_args, idents) = gen_enzyme_decl(&sig, &x, span); + let (d_sig, new_args, idents) = gen_enzyme_decl(&sig, &x, span); let new_decl_span = d_sig.span; let d_body = gen_enzyme_body( ecx, primal, - &old_names, &new_args, span, sig_span, new_decl_span, - &sig, - &d_sig, idents, ); let d_ident = meta_item_vec[0].meta_item().unwrap().path.segments[0].ident; @@ -102,7 +102,7 @@ pub fn expand( })); let mut rustc_ad_attr = P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff))); - let mut attr: ast::Attribute = ast::Attribute { + let attr: ast::Attribute = ast::Attribute { kind: ast::AttrKind::Normal(rustc_ad_attr.clone()), id: ast::AttrId::from_u32(0), style: ast::AttrStyle::Outer, @@ -116,7 +116,7 @@ pub fn expand( delim: rustc_ast::token::Delimiter::Parenthesis, tokens: ts, }); - let mut attr2: ast::Attribute = ast::Attribute { + let attr2: ast::Attribute = ast::Attribute { kind: ast::AttrKind::Normal(rustc_ad_attr), id: ast::AttrId::from_u32(0), style: ast::AttrStyle::Outer, @@ -131,6 +131,7 @@ pub fn expand( } // shadow arguments must be mutable references or ptrs, because Enzyme will write into them. +#[cfg(llvm_enzyme)] fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty { let mut ty = ty.clone(); match ty.kind { @@ -152,37 +153,21 @@ fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty { // The second will just take a tuple containing the new arguments. // This way we surpress rustc from optimizing any argument away. 
// The last line will 'loop {}', to match the return type of the new function +#[cfg(llvm_enzyme)] fn gen_enzyme_body( ecx: &ExtCtxt<'_>, primal: Ident, - old_names: &[String], new_names: &[String], span: Span, sig_span: Span, new_decl_span: Span, - sig: &ast::FnSig, - d_sig: &ast::FnSig, idents: Vec, ) -> P { let blackbox_path = ecx.std_path(&[Symbol::intern("hint"), Symbol::intern("black_box")]); - let zeroed_path = ecx.std_path(&[Symbol::intern("mem"), Symbol::intern("zeroed")]); let empty_loop_block = ecx.block(span, ThinVec::new()); let loop_expr = ecx.expr_loop(span, empty_loop_block); - let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path)); - let zeroed_call_expr = ecx.expr_path(ecx.path(span, zeroed_path)); - - let mem_zeroed_call: Stmt = - ecx.stmt_expr(ecx.expr_call(span, zeroed_call_expr.clone(), thin_vec![])); - let unsafe_block_with_zeroed_call: P = ecx.expr_block(P(ast::Block { - stmts: thin_vec![mem_zeroed_call], - id: ast::DUMMY_NODE_ID, - rules: ast::BlockCheckMode::Unsafe(ast::UserProvided), - span: sig_span, - tokens: None, - could_be_bare_literal: false, - })); - let primal_call = gen_primal_call(ecx, span, primal, sig, idents); + let primal_call = gen_primal_call(ecx, span, primal, idents); // create ::core::hint::black_box(array(arr)); let black_box_primal_call = ecx.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![primal_call.clone()]); @@ -207,11 +192,11 @@ fn gen_enzyme_body( body } +#[cfg(llvm_enzyme)] fn gen_primal_call( ecx: &ExtCtxt<'_>, span: Span, primal: Ident, - sig: &ast::FnSig, idents: Vec, ) -> P { let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal)); @@ -226,17 +211,18 @@ fn gen_primal_call( // zero-initialized by Enzyme). Active arguments are not handled yet. // Each argument of the primal function (and the return type if existing) must be annotated with an // activity. 
+#[cfg(llvm_enzyme)] fn gen_enzyme_decl( sig: &ast::FnSig, x: &AutoDiffAttrs, span: Span, -) -> (ast::FnSig, Vec, Vec, Vec) { +) -> (ast::FnSig, Vec, Vec) { assert!(sig.decl.inputs.len() == x.input_activity.len()); assert!(sig.decl.output.has_ret() == x.has_ret_activity()); let mut d_decl = sig.decl.clone(); let mut d_inputs = Vec::new(); let mut new_inputs = Vec::new(); - let mut old_names = Vec::new(); + //let mut old_names = Vec::new(); let mut idents = Vec::new(); let mut act_ret = ThinVec::new(); for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) { @@ -256,7 +242,7 @@ fn gen_enzyme_decl( dbg!(&shadow_arg.pat); panic!("not an ident?"); }; - old_names.push(old_name.to_string()); + //old_names.push(old_name.to_string()); let name: String = match x.mode { DiffMode::Reverse => format!("d{}", old_name), DiffMode::Forward => format!("b{}", old_name), @@ -320,7 +306,7 @@ fn gen_enzyme_decl( // return type. This might require changing the return type to a // tuple. if act_ret.len() > 0 { - let mut ret_ty = match d_decl.output { + let ret_ty = match d_decl.output { FnRetTy::Ty(ref ty) => { act_ret.insert(0, ty.clone()); let kind = TyKind::Tup(act_ret); @@ -339,5 +325,5 @@ fn gen_enzyme_decl( } let d_sig = FnSig { header: sig.header.clone(), decl: d_decl, span }; - (d_sig, old_names, new_inputs, idents) + (d_sig, new_inputs, idents) } diff --git a/compiler/rustc_builtin_macros/src/errors.rs b/compiler/rustc_builtin_macros/src/errors.rs index 7c032d98d5d8d..a3dccaf11d50d 100644 --- a/compiler/rustc_builtin_macros/src/errors.rs +++ b/compiler/rustc_builtin_macros/src/errors.rs @@ -171,6 +171,13 @@ pub(crate) struct AutoDiffInvalidApplication { pub(crate) span: Span, } +#[derive(Diagnostic)] +#[diag(builtin_macros_autodiff_not_build)] +pub(crate) struct AutoDiffSupportNotBuild { + #[primary_span] + pub(crate) span: Span, +} + #[derive(Diagnostic)] #[diag(builtin_macros_concat_bytes_invalid)] pub(crate) struct ConcatBytesInvalid { diff --git 
a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 4c7705e3081ef..ea74f1859dcca 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -1,6 +1,6 @@ #![allow(non_camel_case_types)] #![allow(non_upper_case_globals)] -#![allow(unexpected_cfgs)] +//#![allow(unexpected_cfgs)] use rustc_ast::expand::autodiff_attrs::DiffActivity; @@ -2609,10 +2609,10 @@ extern "C" { ) -> *mut c_void; } -#[cfg(autodiff_fallback)] +#[cfg(not(llvm_enzyme))] pub use self::Fallback_AD::*; -#[cfg(autodiff_fallback)] +#[cfg(not(llvm_enzyme))] pub mod Fallback_AD { #![allow(unused_variables)] use super::*; @@ -2745,9 +2745,9 @@ pub mod Shared_AD { use libc::size_t; use super::Context; - #[cfg(autodiff_fallback)] + #[cfg(not(llvm_enzyme))] use super::Fallback_AD::*; - #[cfg(not(autodiff_fallback))] + #[cfg(llvm_enzyme)] use super::Enzyme_AD::*; use core::fmt; @@ -2931,13 +2931,13 @@ pub mod Shared_AD { } } -#[cfg(not(autodiff_fallback))] +#[cfg(llvm_enzyme)] pub use self::Enzyme_AD::*; // Enzyme is an optional component, so we do need to provide a fallback when it is ont getting // compiled. We deny the usage of #[autodiff(..)] on a higher level, so a placeholder implementation // here is completely fine. -#[cfg(not(autodiff_fallback))] +#[cfg(llvm_enzyme)] pub mod Enzyme_AD { use super::*; diff --git a/compiler/rustc_session/src/config.rs b/compiler/rustc_session/src/config.rs index 0d731e330f3d9..2219fd5e951a8 100644 --- a/compiler/rustc_session/src/config.rs +++ b/compiler/rustc_session/src/config.rs @@ -1273,6 +1273,7 @@ fn default_configuration(sess: &Session) -> Cfg { // // NOTE: These insertions should be kept in sync with // `CheckCfg::fill_well_known` below. 
+ ins_none!(sym::autodiff_fallback); if sess.opts.debug_assertions { ins_none!(sym::debug_assertions); @@ -1460,6 +1461,7 @@ impl CheckCfg { // // When adding a new config here you should also update // `tests/ui/check-cfg/well-known-values.rs`. + ins!(sym::autodiff_fallback, no_values); ins!(sym::debug_assertions, no_values); diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index b1c1f0feef8ad..c1783c56a0125 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -440,6 +440,7 @@ symbols! { augmented_assignments, auto_traits, autodiff, + autodiff_fallback, automatically_derived, avx, avx512_target_feature, @@ -493,6 +494,7 @@ symbols! { cfg_accessible, cfg_attr, cfg_attr_multi, + cfg_autodiff_fallback, cfg_doctest, cfg_eval, cfg_hide, diff --git a/src/bootstrap/src/core/build_steps/compile.rs b/src/bootstrap/src/core/build_steps/compile.rs index 4f5844ae51c80..cdb798c47d571 100644 --- a/src/bootstrap/src/core/build_steps/compile.rs +++ b/src/bootstrap/src/core/build_steps/compile.rs @@ -1048,6 +1048,9 @@ pub fn rustc_cargo_env( if builder.config.rust_verify_llvm_ir { cargo.env("RUSTC_VERIFY_LLVM_IR", "1"); } + if builder.config.llvm_enzyme { + cargo.rustflag("--cfg=llvm_enzyme"); + } // Note that this is disabled if LLVM itself is disabled or we're in a check // build. If we are in a check build we still go ahead here presuming we've diff --git a/src/bootstrap/src/core/config/config.rs b/src/bootstrap/src/core/config/config.rs index c7ab615ac4fe0..1023d3f0a584c 100644 --- a/src/bootstrap/src/core/config/config.rs +++ b/src/bootstrap/src/core/config/config.rs @@ -1095,6 +1095,7 @@ define_config! 
{ codegen_backends: Option> = "codegen-backends", lld: Option = "lld", lld_mode: Option = "use-lld", + llvm_enzyme: Option = "llvm-enzyme", llvm_tools: Option = "llvm-tools", deny_warnings: Option = "deny-warnings", backtrace_on_ice: Option = "backtrace-on-ice", @@ -1545,6 +1546,7 @@ impl Config { save_toolstates, codegen_backends, lld, + llvm_enzyme, llvm_tools, deny_warnings, backtrace_on_ice, @@ -1634,6 +1636,8 @@ impl Config { } set(&mut config.llvm_tools_enabled, llvm_tools); + config.llvm_enzyme = + llvm_enzyme.unwrap_or(config.channel == "dev" || config.channel == "nightly"); config.rustc_parallel = parallel_compiler.unwrap_or(config.channel == "dev" || config.channel == "nightly"); config.rustc_default_linker = default_linker; diff --git a/src/bootstrap/src/lib.rs b/src/bootstrap/src/lib.rs index 3909115140bca..5038c888bca53 100644 --- a/src/bootstrap/src/lib.rs +++ b/src/bootstrap/src/lib.rs @@ -76,6 +76,9 @@ const LLD_FILE_NAMES: &[&str] = &["ld.lld", "ld64.lld", "lld-link", "wasm-ld"]; /// (Mode restriction, config name, config values (if any)) const EXTRA_CHECK_CFGS: &[(Option, &str, Option<&[&'static str]>)] = &[ (None, "bootstrap", None), + (Some(Mode::Rustc), "llvm_enzyme", None), + (Some(Mode::Codegen), "llvm_enzyme", None), + (Some(Mode::ToolRustc), "llvm_enzyme", None), (Some(Mode::Rustc), "parallel_compiler", None), (Some(Mode::ToolRustc), "parallel_compiler", None), (Some(Mode::ToolRustc), "rust_analyzer", None), From 84c418c17918f481fa67335ec0445874f575c9e9 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sun, 11 Feb 2024 21:43:26 -0500 Subject: [PATCH 033/100] Update submodule 'enzyme' --- src/tools/enzyme | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tools/enzyme b/src/tools/enzyme index 5422797090b89..a309cc083f64f 160000 --- a/src/tools/enzyme +++ b/src/tools/enzyme @@ -1 +1 @@ -Subproject commit 5422797090b89c1e22e836eb74a852de544febc1 +Subproject commit a309cc083f64fb6c17689a81feffa804bc7fcb3d From 
4c2a59f4f4a708d6b4f12b74544894335a7c7266 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sat, 17 Feb 2024 07:39:51 -0500 Subject: [PATCH 034/100] fix reverse_mode setting, various cleanups --- compiler/rustc_builtin_macros/src/autodiff.rs | 51 ++++++++++++------- compiler/rustc_codegen_llvm/src/attributes.rs | 5 +- compiler/rustc_codegen_llvm/src/back/write.rs | 1 + compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 24 ++++----- compiler/rustc_codegen_ssa/src/back/lto.rs | 2 +- compiler/rustc_codegen_ssa/src/back/write.rs | 4 +- .../rustc_codegen_ssa/src/codegen_attrs.rs | 23 ++++----- compiler/rustc_expand/src/build.rs | 3 ++ .../rustc_monomorphize/src/partitioning.rs | 12 +++-- 9 files changed, 75 insertions(+), 50 deletions(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index eeaefbaaf1a56..de9f0a517fbd1 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -91,7 +91,7 @@ pub fn expand( new_decl_span, idents, ); - let d_ident = meta_item_vec[0].meta_item().unwrap().path.segments[0].ident; + let d_ident = first_ident(&meta_item_vec[0]); // The first element of it is the name of the function to be generated let asdf = ItemKind::Fn(Box::new(ast::Fn { @@ -102,11 +102,12 @@ pub fn expand( })); let mut rustc_ad_attr = P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff))); - let attr: ast::Attribute = ast::Attribute { + let mut attr: ast::Attribute = ast::Attribute { kind: ast::AttrKind::Normal(rustc_ad_attr.clone()), - id: ast::AttrId::from_u32(0), + //id: ast::DUMMY_TR_ID, + id: ast::AttrId::from_u32(12341), // TODO: fix style: ast::AttrStyle::Outer, - span: span, + span, }; orig_item.attrs.push(attr.clone()); @@ -116,21 +117,15 @@ pub fn expand( delim: rustc_ast::token::Delimiter::Parenthesis, tokens: ts, }); - let attr2: ast::Attribute = ast::Attribute { - kind: ast::AttrKind::Normal(rustc_ad_attr), - id: 
ast::AttrId::from_u32(0), - style: ast::AttrStyle::Outer, - span: span, - }; - let attr_vec: rustc_ast::AttrVec = thin_vec![attr2]; - let d_fn = ecx.item(span, d_ident, attr_vec, asdf); + attr.kind = ast::AttrKind::Normal(rustc_ad_attr); + let d_fn = ecx.item(span, d_ident, thin_vec![attr], asdf); - let orig_annotatable = Annotatable::Item(orig_item.clone()); + let orig_annotatable = Annotatable::Item(orig_item); let d_annotatable = Annotatable::Item(d_fn); return vec![orig_annotatable, d_annotatable]; } -// shadow arguments must be mutable references or ptrs, because Enzyme will write into them. +// shadow arguments in reverse mode must be mutable references or ptrs, because Enzyme will write into them. #[cfg(llvm_enzyme)] fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty { let mut ty = ty.clone(); @@ -165,6 +160,25 @@ fn gen_enzyme_body( ) -> P { let blackbox_path = ecx.std_path(&[Symbol::intern("hint"), Symbol::intern("black_box")]); let empty_loop_block = ecx.block(span, ThinVec::new()); + let noop = ast::InlineAsm { + template: vec![ast::InlineAsmTemplatePiece::String("NOP".to_string())], + template_strs: Box::new([]), + operands: vec![], + clobber_abis: vec![], + options: ast::InlineAsmOptions::PURE & ast::InlineAsmOptions::NOMEM, + line_spans: vec![], + }; + let noop_expr = ecx.expr_asm(span, P(noop)); + let unsf = ast::BlockCheckMode::Unsafe(ast::UnsafeSource::CompilerGenerated); + let unsf_block = ast::Block { + stmts: thin_vec![ecx.stmt_semi(noop_expr)], + id: ast::DUMMY_NODE_ID, + tokens: None, + rules: unsf, + span, + could_be_bare_literal: false, + }; + let unsf_expr = ecx.expr_block(P(unsf_block)); let loop_expr = ecx.expr_loop(span, empty_loop_block); let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path)); let primal_call = gen_primal_call(ecx, span, primal, idents); @@ -185,7 +199,7 @@ fn gen_enzyme_body( ); let mut body = ecx.block(span, ThinVec::new()); - body.stmts.push(ecx.stmt_semi(primal_call)); + 
body.stmts.push(ecx.stmt_semi(unsf_expr)); body.stmts.push(ecx.stmt_semi(black_box_primal_call)); body.stmts.push(ecx.stmt_semi(black_box_remaining_args)); body.stmts.push(ecx.stmt_expr(loop_expr)); @@ -234,7 +248,11 @@ fn gen_enzyme_decl( } DiffActivity::Duplicated | DiffActivity::Dual => { let mut shadow_arg = arg.clone(); - shadow_arg.ty = P(assure_mut_ref(&arg.ty)); + // We += into the shadow in reverse mode. + // Otherwise copy mutability of the original argument. + if activity == &DiffActivity::Duplicated { + shadow_arg.ty = P(assure_mut_ref(&arg.ty)); + } // adjust name depending on mode let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind { ident.name @@ -242,7 +260,6 @@ fn gen_enzyme_decl( dbg!(&shadow_arg.pat); panic!("not an ident?"); }; - //old_names.push(old_name.to_string()); let name: String = match x.mode { DiffMode::Reverse => format!("d{}", old_name), DiffMode::Forward => format!("b{}", old_name), diff --git a/compiler/rustc_codegen_llvm/src/attributes.rs b/compiler/rustc_codegen_llvm/src/attributes.rs index 348551ea4e4b0..032a89c2affa7 100644 --- a/compiler/rustc_codegen_llvm/src/attributes.rs +++ b/compiler/rustc_codegen_llvm/src/attributes.rs @@ -1,5 +1,6 @@ //! Set and unset common attributes on LLVM values. 
+use rustc_ast::expand::autodiff_attrs::AutoDiffAttrs; use rustc_codegen_ssa::traits::*; use rustc_hir::def_id::DefId; use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrFlags; @@ -294,7 +295,7 @@ pub fn from_fn_attrs<'ll, 'tcx>( instance: ty::Instance<'tcx>, ) { let codegen_fn_attrs = cx.tcx.codegen_fn_attrs(instance.def_id()); - let autodiff_attrs = cx.tcx.autodiff_attrs(instance.def_id()); + let autodiff_attrs: &AutoDiffAttrs = cx.tcx.autodiff_attrs(instance.def_id()); let mut to_add = SmallVec::<[_; 16]>::new(); @@ -313,6 +314,8 @@ pub fn from_fn_attrs<'ll, 'tcx>( if codegen_fn_attrs.inline == InlineAttr::None && instance.def.requires_inline(cx.tcx) { InlineAttr::Hint } else if autodiff_attrs.is_active() { + dbg!("autodiff_attrs.is_active()"); + dbg!(&autodiff_attrs); InlineAttr::Never } else { codegen_fn_attrs.inline diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 1b6cdd7ec3db5..4f3e628bda8e5 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -702,6 +702,7 @@ pub(crate) unsafe fn enzyme_ad( diag_handler: &DiagCtxt, item: AutoDiffItem, ) -> Result<(), FatalError> { + dbg!("\n\n\n\n\n\n AUTO DIFF \n"); let autodiff_mode = item.attrs.mode; let rust_name = item.source; let rust_name2 = &item.target; diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index ea74f1859dcca..8414baf046160 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -852,7 +852,6 @@ pub(crate) unsafe fn enzyme_rust_forward_diff( input_tts: Vec, output_tt: TypeTree, ) -> &Value { - let mut ret_primary_ret = false; let ret_activity = cdiffe_from(ret_diffactivity); assert!(ret_activity != CDIFFE_TYPE::DFT_OUT_DIFF); let mut input_activity: Vec = vec![]; @@ -866,17 +865,12 @@ pub(crate) unsafe fn enzyme_rust_forward_diff( input_activity.push(act); } - if 
ret_activity == CDIFFE_TYPE::DFT_DUP_ARG { - if ret_primary_ret != true { - dbg!("overwriting ret_primary_ret!"); - } - ret_primary_ret = true; - } else if ret_activity == CDIFFE_TYPE::DFT_DUP_NONEED { - if ret_primary_ret != false { - dbg!("overwriting ret_primary_ret!"); - } - ret_primary_ret = false; - } + let ret_primary_ret = match ret_activity { + CDIFFE_TYPE::DFT_CONSTANT => true, + CDIFFE_TYPE::DFT_DUP_ARG => true, + CDIFFE_TYPE::DFT_DUP_NONEED => false, + _ => panic!("Implementation error in enzyme_rust_forward_diff."), + }; let mut args_tree = input_tts.iter().map(|x| x.inner).collect::>(); //let mut args_tree = vec![TypeTree::new().inner; typetree.input_tt.len()]; @@ -916,7 +910,6 @@ pub(crate) unsafe fn enzyme_rust_forward_diff( ) } -#[allow(dead_code)] pub(crate) unsafe fn enzyme_rust_reverse_diff( logic_ref: EnzymeLogicRef, type_analysis: EnzymeTypeAnalysisRef, @@ -928,7 +921,7 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff( ) -> &Value { let (primary_ret, ret_activity) = match ret_activity { DiffActivity::Const => (true, CDIFFE_TYPE::DFT_CONSTANT), - DiffActivity::Active => (true, CDIFFE_TYPE::DFT_DUP_ARG), + DiffActivity::Active => (true, CDIFFE_TYPE::DFT_OUT_DIFF), DiffActivity::None => (false, CDIFFE_TYPE::DFT_CONSTANT), _ => panic!("Invalid return activity"), }; @@ -962,6 +955,9 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff( KnownValues: known_values.as_mut_ptr(), }; + dbg!(&primary_ret); + dbg!(&ret_activity); + dbg!(&input_activity); let res = EnzymeCreatePrimalAndGradient( logic_ref, // Logic std::ptr::null(), diff --git a/compiler/rustc_codegen_ssa/src/back/lto.rs b/compiler/rustc_codegen_ssa/src/back/lto.rs index 1e749f570ac51..f78b74c9545be 100644 --- a/compiler/rustc_codegen_ssa/src/back/lto.rs +++ b/compiler/rustc_codegen_ssa/src/back/lto.rs @@ -90,7 +90,7 @@ impl LtoModuleCodegen { LtoModuleCodegen::Fat { ref module, .. 
} => { B::autodiff(cgcx, &module, diff_fncs, typetrees, config)?; } - _ => {} + _ => panic!("autodiff called with non-fat LTO module"), } Ok(self) diff --git a/compiler/rustc_codegen_ssa/src/back/write.rs b/compiler/rustc_codegen_ssa/src/back/write.rs index 04d4843911087..47da3254eb39e 100644 --- a/compiler/rustc_codegen_ssa/src/back/write.rs +++ b/compiler/rustc_codegen_ssa/src/back/write.rs @@ -382,7 +382,9 @@ fn generate_lto_work( import_only_modules: Vec<(SerializedModule, WorkProduct)>, ) -> Vec<(WorkItem, u64)> { let _prof_timer = cgcx.prof.generic_activity("codegen_generate_lto_work"); - dbg!("Differentiating {} functions", autodiff.len()); + //let error_msg = format!("Found {} Functions, but {} TypeTrees", autodiff.len(), typetrees.len()); + // Don't assert yet, bc. apparently we add them later. + //assert!(autodiff.len() == typetrees.len(), "{}", error_msg); if !needs_fat_lto.is_empty() { assert!(needs_thin_lto.is_empty()); diff --git a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs index 08cd614cc8eb5..7edc9cf7641f7 100644 --- a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs +++ b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs @@ -701,33 +701,30 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { let attrs = tcx.get_attrs(id, sym::rustc_autodiff); let attrs = attrs - .into_iter() .filter(|attr| attr.name_or_empty() == sym::rustc_autodiff) .collect::>(); - if !attrs.is_empty() { - dbg!("autodiff_attrs amount = {}", attrs.len()); - } - // check for exactly one autodiff attribute on extern block - let msg_once = "autodiff attribute can only be applied once"; - let attr = match &attrs[..] 
{ - &[] => return AutoDiffAttrs::inactive(), - &[elm] => elm, - x => { + let msg_once = "cg_ssa: autodiff attribute can only be applied once"; + let attr = match attrs.len() { + 0 => return AutoDiffAttrs::inactive(), + 1 => attrs.get(0).unwrap(), + _ => { tcx.sess - .struct_span_err(x[1].span, msg_once) - .span_label(x[1].span, "more than one") + .struct_span_err(attrs[1].span, msg_once) + .span_label(attrs[1].span, "more than one") .emit(); - return AutoDiffAttrs::inactive(); } }; + dbg!("autodiff_attr = {:?}", &attr); let list = attr.meta_item_list().unwrap_or_default(); + dbg!("autodiff_attrs list = {:?}", &list); // empty autodiff attribute macros (i.e. `#[autodiff]`) are used to mark source functions if list.len() == 0 { + dbg!("autodiff_attrs: source"); return AutoDiffAttrs::source(); } diff --git a/compiler/rustc_expand/src/build.rs b/compiler/rustc_expand/src/build.rs index 1ee77d4a6429f..0037267f30be9 100644 --- a/compiler/rustc_expand/src/build.rs +++ b/compiler/rustc_expand/src/build.rs @@ -412,6 +412,9 @@ impl<'a> ExtCtxt<'a> { pub fn expr_loop(&self, sp: Span, block: P) -> P { self.expr(sp, ast::ExprKind::Loop(block, None, sp)) } + pub fn expr_asm(&self, sp: Span, expr: P) -> P { + self.expr(sp, ast::ExprKind::InlineAsm(expr)) + } pub fn expr_fail(&self, span: Span, msg: Symbol) -> P { self.expr_call_global( diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs index ad1141726d970..8046e4392f475 100644 --- a/compiler/rustc_monomorphize/src/partitioning.rs +++ b/compiler/rustc_monomorphize/src/partitioning.rs @@ -255,9 +255,11 @@ where export_generics, ); - //if visibility == Visibility::Hidden && can_be_internalized { - let autodiff_active = - characteristic_def_id.map(|x| cx.tcx.autodiff_attrs(x).is_active()).unwrap_or(false); + // We can't differentiate something that got inlined. 
+ let autodiff_active = match characteristic_def_id { + Some(def_id) => cx.tcx.autodiff_attrs(def_id).is_active(), + None => false, + }; if !autodiff_active && visibility == Visibility::Hidden && can_be_internalized { internalization_candidates.insert(mono_item); @@ -1166,6 +1168,10 @@ fn collect_and_partition_mono_items( .filter_map(|(item, instance)| { let target_id = instance.def_id(); let target_attrs = tcx.autodiff_attrs(target_id); + if target_attrs.is_source() { + dbg!("source"); + dbg!(&target_attrs); + } if !target_attrs.apply_autodiff() { return None; } From e54da7b4cd09cbcd1d9bf72b6a78b3a85912a44f Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sat, 17 Feb 2024 07:56:20 -0500 Subject: [PATCH 035/100] allow differentiating a fnc multiple times --- compiler/rustc_builtin_macros/src/autodiff.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index de9f0a517fbd1..9662c0cc0e53b 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -109,7 +109,10 @@ pub fn expand( style: ast::AttrStyle::Outer, span, }; - orig_item.attrs.push(attr.clone()); + // don't add it multiple times: + if !orig_item.iter().any(|a| a.id == attr.id) { + orig_item.attrs.push(attr.clone()); + } // Now update for d_fn rustc_ad_attr.item.args = rustc_ast::AttrArgs::Delimited(rustc_ast::DelimArgs { From c90ab5ad9459bb64b79414f8b0054c102cf1e421 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sat, 17 Feb 2024 08:45:29 -0500 Subject: [PATCH 036/100] improve inlining prevention --- compiler/rustc_builtin_macros/src/autodiff.rs | 31 ++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 9662c0cc0e53b..d51ef51872249 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ 
b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -102,6 +102,25 @@ pub fn expand( })); let mut rustc_ad_attr = P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff))); + let ts2: Vec = vec![ + TokenTree::Token( + Token::new(TokenKind::Ident(sym::never, false), span), + Spacing::Joint, + )]; + let never_arg = ast::DelimArgs { + dspan: ast::tokenstream::DelimSpan::from_single(span), + delim: ast::token::Delimiter::Parenthesis, + tokens: ast::tokenstream::TokenStream::from_iter(ts2), + }; + let inline_item = ast::AttrItem { + path: ast::Path::from_ident(Ident::with_dummy_span(sym::inline)), + args: ast::AttrArgs::Delimited(never_arg), + tokens: None, + }; + let inline_never_attr = P(ast::NormalAttr { + item: inline_item, + tokens: None, + }); let mut attr: ast::Attribute = ast::Attribute { kind: ast::AttrKind::Normal(rustc_ad_attr.clone()), //id: ast::DUMMY_TR_ID, @@ -109,10 +128,20 @@ pub fn expand( style: ast::AttrStyle::Outer, span, }; + let inline_never : ast::Attribute = ast::Attribute { + kind: ast::AttrKind::Normal(inline_never_attr), + //id: ast::DUMMY_TR_ID, + id: ast::AttrId::from_u32(12342), // TODO: fix + style: ast::AttrStyle::Outer, + span, + }; // don't add it multiple times: - if !orig_item.iter().any(|a| a.id == attr.id) { + if !orig_item.attrs.iter().any(|a| a.id == attr.id) { orig_item.attrs.push(attr.clone()); } + if !orig_item.attrs.iter().any(|a| a.id == inline_never.id) { + orig_item.attrs.push(inline_never); + } // Now update for d_fn rustc_ad_attr.item.args = rustc_ast::AttrArgs::Delimited(rustc_ast::DelimArgs { From f8c263bce116d9a0dcb6fe236c8ab0b7932639e9 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sun, 18 Feb 2024 02:10:25 -0500 Subject: [PATCH 037/100] use propper rustc error handler for mode/act check --- .../rustc_ast/src/expand/autodiff_attrs.rs | 52 ++++++++++---- compiler/rustc_builtin_macros/messages.ftl | 1 + compiler/rustc_builtin_macros/src/autodiff.rs | 69 +++++++++++++++---- 
compiler/rustc_builtin_macros/src/errors.rs | 9 +++ 4 files changed, 105 insertions(+), 26 deletions(-) diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs index 5db90e94c6f8c..10958ec1bd434 100644 --- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -1,7 +1,7 @@ use crate::expand::typetree::TypeTree; use std::str::FromStr; use thin_vec::ThinVec; - +use std::fmt::{Display, Formatter}; use crate::NestedMetaItem; #[allow(dead_code)] @@ -13,6 +13,17 @@ pub enum DiffMode { Reverse, } +impl Display for DiffMode { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + DiffMode::Inactive => write!(f, "Inactive"), + DiffMode::Source => write!(f, "Source"), + DiffMode::Forward => write!(f, "Forward"), + DiffMode::Reverse => write!(f, "Reverse"), + } + } +} + pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool { match mode { DiffMode::Inactive => false, @@ -30,26 +41,28 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool { } } -pub fn valid_input_activities(mode: DiffMode, activity_vec: &[DiffActivity]) -> bool { - for &activity in activity_vec { - let valid = match mode { +pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool { + return match mode { DiffMode::Inactive => false, DiffMode::Source => false, DiffMode::Forward => { // These are the only valid cases - activity == DiffActivity::Dual || activity == DiffActivity::DualOnly || activity == DiffActivity::Const + activity == DiffActivity::Dual || + activity == DiffActivity::DualOnly || + activity == DiffActivity::Const } DiffMode::Reverse => { // These are the only valid cases - activity == DiffActivity::Active || activity == DiffActivity::ActiveOnly || activity == DiffActivity::Const - || activity == DiffActivity::Duplicated || activity == DiffActivity::DuplicatedOnly + activity == DiffActivity::Active || + 
activity == DiffActivity::ActiveOnly || + activity == DiffActivity::Const || + activity == DiffActivity::Duplicated || + activity == DiffActivity::DuplicatedOnly } }; - if !valid { - return false; - } - } - true +} +pub fn valid_input_activities(mode: DiffMode, activity_vec: &[DiffActivity]) -> bool { + return activity_vec.iter().any(|&x| !valid_input_activity(mode, x)); } #[allow(dead_code)] @@ -65,6 +78,21 @@ pub enum DiffActivity { DuplicatedOnly, } +impl Display for DiffActivity { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + DiffActivity::None => write!(f, "None"), + DiffActivity::Const => write!(f, "Const"), + DiffActivity::Active => write!(f, "Active"), + DiffActivity::ActiveOnly => write!(f, "ActiveOnly"), + DiffActivity::Dual => write!(f, "Dual"), + DiffActivity::DualOnly => write!(f, "DualOnly"), + DiffActivity::Duplicated => write!(f, "Duplicated"), + DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"), + } + } +} + impl FromStr for DiffMode { type Err = (); diff --git a/compiler/rustc_builtin_macros/messages.ftl b/compiler/rustc_builtin_macros/messages.ftl index e1739fe52cc14..f0ab49b0c3fe9 100644 --- a/compiler/rustc_builtin_macros/messages.ftl +++ b/compiler/rustc_builtin_macros/messages.ftl @@ -3,6 +3,7 @@ builtin_macros_alloc_must_statics = allocators must be statics builtin_macros_autodiff = autodiff must be applied to function builtin_macros_autodiff_not_build = this rustc version does not support autodiff +builtin_macros_autodiff_mode_activity = {$act} can not be used in {$mode} Mode builtin_macros_asm_clobber_abi = clobber_abi builtin_macros_asm_clobber_no_reg = asm with `clobber_abi` must specify explicit registers for outputs diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index d51ef51872249..56d827b620cf7 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -3,7 +3,7 @@ //use 
crate::util::check_autodiff; use crate::errors; -use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode}; +use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity}; use rustc_ast::ptr::P; use rustc_ast::token::{Token, TokenKind}; use rustc_ast::tokenstream::*; @@ -80,7 +80,7 @@ pub fn expand( dbg!(&x); let span = ecx.with_def_site_ctxt(expand_span); - let (d_sig, new_args, idents) = gen_enzyme_decl(&sig, &x, span); + let (d_sig, new_args, idents) = gen_enzyme_decl(ecx, &sig, &x, span); let new_decl_span = d_sig.span; let d_body = gen_enzyme_body( ecx, @@ -175,6 +175,26 @@ fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty { ty } +// TODO We should make this more robust to also +// accept aliases of f32 and f64 +#[cfg(llvm_enzyme)] +fn is_float(ty: &ast::Ty) -> bool { + match ty.kind { + TyKind::Path(_, ref path) => { + let last = path.segments.last().unwrap(); + last.ident.name == sym::f32 || last.ident.name == sym::f64 + } + _ => false, + } +} +#[cfg(llvm_enzyme)] +fn is_ptr_or_ref(ty: &ast::Ty) -> bool { + match ty.kind { + TyKind::Ptr(_) | TyKind::Ref(_, _) => true, + _ => false, + } +} + // The body of our generated functions will consist of two black_Box calls. // The first will call the primal function with the original arguments. // The second will just take a tuple containing the new arguments. @@ -259,6 +279,7 @@ fn gen_primal_call( // activity. 
#[cfg(llvm_enzyme)] fn gen_enzyme_decl( + ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span: Span, @@ -273,31 +294,50 @@ fn gen_enzyme_decl( let mut act_ret = ThinVec::new(); for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) { d_inputs.push(arg.clone()); + if !valid_input_activity(x.mode, *activity) { + ecx.sess.dcx().emit_err(errors::AutoDiffInvalidApplicationModeAct { + span, + mode: x.mode.to_string(), + act: activity.to_string() + }); + } match activity { DiffActivity::Active => { - assert!(x.mode == DiffMode::Reverse); + assert!(is_float(&arg.ty)); act_ret.push(arg.ty.clone()); } - DiffActivity::Duplicated | DiffActivity::Dual => { + DiffActivity::Duplicated => { + assert!(is_ptr_or_ref(&arg.ty)); let mut shadow_arg = arg.clone(); // We += into the shadow in reverse mode. - // Otherwise copy mutability of the original argument. - if activity == &DiffActivity::Duplicated { - shadow_arg.ty = P(assure_mut_ref(&arg.ty)); - } - // adjust name depending on mode + shadow_arg.ty = P(assure_mut_ref(&arg.ty)); let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind { ident.name } else { dbg!(&shadow_arg.pat); panic!("not an ident?"); }; - let name: String = match x.mode { - DiffMode::Reverse => format!("d{}", old_name), - DiffMode::Forward => format!("b{}", old_name), - _ => panic!("unsupported mode: {}", old_name), + let name: String = format!("d{}", old_name); + new_inputs.push(name.clone()); + let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span); + shadow_arg.pat = P(ast::Pat { + // TODO: Check id + id: ast::DUMMY_NODE_ID, + kind: PatKind::Ident(BindingAnnotation::NONE, ident, None), + span: shadow_arg.pat.span, + tokens: shadow_arg.pat.tokens.clone(), + }); + d_inputs.push(shadow_arg); + } + DiffActivity::Dual => { + let mut shadow_arg = arg.clone(); + let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind { + ident.name + } else { + dbg!(&shadow_arg.pat); + panic!("not an ident?"); }; - dbg!(&name); 
+ let name: String = format!("b{}", old_name); new_inputs.push(name.clone()); let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span); shadow_arg.pat = P(ast::Pat { @@ -311,6 +351,7 @@ fn gen_enzyme_decl( } _ => { dbg!(&activity); + panic!("Not implemented"); } } if let PatKind::Ident(_, ident, _) = arg.pat.kind { diff --git a/compiler/rustc_builtin_macros/src/errors.rs b/compiler/rustc_builtin_macros/src/errors.rs index a3dccaf11d50d..fcd1fa44c0061 100644 --- a/compiler/rustc_builtin_macros/src/errors.rs +++ b/compiler/rustc_builtin_macros/src/errors.rs @@ -164,6 +164,15 @@ pub(crate) struct AllocMustStatics { pub(crate) span: Span, } +#[derive(Diagnostic)] +#[diag(builtin_macros_autodiff_mode_activity)] +pub(crate) struct AutoDiffInvalidApplicationModeAct { + #[primary_span] + pub(crate) span: Span, + pub(crate) mode: String, + pub(crate) act: String, +} + #[derive(Diagnostic)] #[diag(builtin_macros_autodiff)] pub(crate) struct AutoDiffInvalidApplication { From 273c2b5c3c0329a82ef9e5475d78aa0f5fd5e542 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sun, 18 Feb 2024 02:39:08 -0500 Subject: [PATCH 038/100] use propper rustc error handler for type/act check --- .../rustc_ast/src/expand/autodiff_attrs.rs | 26 +++++++++++++++- compiler/rustc_builtin_macros/messages.ftl | 1 + compiler/rustc_builtin_macros/src/autodiff.rs | 31 +++++-------------- compiler/rustc_builtin_macros/src/errors.rs | 7 +++++ 4 files changed, 41 insertions(+), 24 deletions(-) diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs index 10958ec1bd434..0d3af81d935da 100644 --- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -3,6 +3,8 @@ use std::str::FromStr; use thin_vec::ThinVec; use std::fmt::{Display, Formatter}; use crate::NestedMetaItem; +use crate::ptr::P; +use crate::{Ty, TyKind}; #[allow(dead_code)] #[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, 
Debug, HashStable_Generic)] @@ -40,7 +42,29 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool { } } } - +fn is_ptr_or_ref(ty: &Ty) -> bool { + match ty.kind { + TyKind::Ptr(_) | TyKind::Ref(_, _) => true, + _ => false, + } +} +// TODO We should make this more robust to also +// accept aliases of f32 and f64 +//fn is_float(ty: &Ty) -> bool { +// false +//} +pub fn valid_ty_for_activity(ty: &P, activity: DiffActivity) -> bool { + if is_ptr_or_ref(ty) { + return activity == DiffActivity::Dual || activity == DiffActivity::DualOnly || + activity == DiffActivity::Duplicated || activity == DiffActivity::DuplicatedOnly || + activity == DiffActivity::Const; + } + true + //if is_scalar_ty(&ty) { + // return activity == DiffActivity::Active || activity == DiffActivity::ActiveOnly || + // activity == DiffActivity::Const; + //} +} pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool { return match mode { DiffMode::Inactive => false, diff --git a/compiler/rustc_builtin_macros/messages.ftl b/compiler/rustc_builtin_macros/messages.ftl index f0ab49b0c3fe9..e0bb8fef688cb 100644 --- a/compiler/rustc_builtin_macros/messages.ftl +++ b/compiler/rustc_builtin_macros/messages.ftl @@ -4,6 +4,7 @@ builtin_macros_alloc_must_statics = allocators must be statics builtin_macros_autodiff = autodiff must be applied to function builtin_macros_autodiff_not_build = this rustc version does not support autodiff builtin_macros_autodiff_mode_activity = {$act} can not be used in {$mode} Mode +builtin_macros_autodiff_ty_activity = {$act} can not be used for this type builtin_macros_asm_clobber_abi = clobber_abi builtin_macros_asm_clobber_no_reg = asm with `clobber_abi` must specify explicit registers for outputs diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 56d827b620cf7..1a6ec45578c11 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ 
b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -3,7 +3,7 @@ //use crate::util::check_autodiff; use crate::errors; -use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity}; +use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ty_for_activity}; use rustc_ast::ptr::P; use rustc_ast::token::{Token, TokenKind}; use rustc_ast::tokenstream::*; @@ -175,25 +175,6 @@ fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty { ty } -// TODO We should make this more robust to also -// accept aliases of f32 and f64 -#[cfg(llvm_enzyme)] -fn is_float(ty: &ast::Ty) -> bool { - match ty.kind { - TyKind::Path(_, ref path) => { - let last = path.segments.last().unwrap(); - last.ident.name == sym::f32 || last.ident.name == sym::f64 - } - _ => false, - } -} -#[cfg(llvm_enzyme)] -fn is_ptr_or_ref(ty: &ast::Ty) -> bool { - match ty.kind { - TyKind::Ptr(_) | TyKind::Ref(_, _) => true, - _ => false, - } -} // The body of our generated functions will consist of two black_Box calls. // The first will call the primal function with the original arguments. @@ -277,7 +258,7 @@ fn gen_primal_call( // zero-initialized by Enzyme). Active arguments are not handled yet. // Each argument of the primal function (and the return type if existing) must be annotated with an // activity. -#[cfg(llvm_enzyme)] +//#[cfg(llvm_enzyme)] fn gen_enzyme_decl( ecx: &ExtCtxt<'_>, sig: &ast::FnSig, @@ -301,13 +282,17 @@ fn gen_enzyme_decl( act: activity.to_string() }); } + if !valid_ty_for_activity(&arg.ty, *activity) { + ecx.sess.dcx().emit_err(errors::AutoDiffInvalidTypeForActivity { + span: arg.ty.span, + act: activity.to_string() + }); + } match activity { DiffActivity::Active => { - assert!(is_float(&arg.ty)); act_ret.push(arg.ty.clone()); } DiffActivity::Duplicated => { - assert!(is_ptr_or_ref(&arg.ty)); let mut shadow_arg = arg.clone(); // We += into the shadow in reverse mode. 
shadow_arg.ty = P(assure_mut_ref(&arg.ty)); diff --git a/compiler/rustc_builtin_macros/src/errors.rs b/compiler/rustc_builtin_macros/src/errors.rs index fcd1fa44c0061..d7824cb83e68b 100644 --- a/compiler/rustc_builtin_macros/src/errors.rs +++ b/compiler/rustc_builtin_macros/src/errors.rs @@ -163,6 +163,13 @@ pub(crate) struct AllocMustStatics { #[primary_span] pub(crate) span: Span, } +#[derive(Diagnostic)] +#[diag(builtin_macros_autodiff_ty_activity)] +pub(crate) struct AutoDiffInvalidTypeForActivity { + #[primary_span] + pub(crate) span: Span, + pub(crate) act: String, +} #[derive(Diagnostic)] #[diag(builtin_macros_autodiff_mode_activity)] From bdb944978d937fc15fdfa177465f937795f225f6 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sun, 18 Feb 2024 16:53:56 -0500 Subject: [PATCH 039/100] add CI back --- .github/workflows/enzyme-ci.yml | 38 +++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 .github/workflows/enzyme-ci.yml diff --git a/.github/workflows/enzyme-ci.yml b/.github/workflows/enzyme-ci.yml new file mode 100644 index 0000000000000..edc48879479e8 --- /dev/null +++ b/.github/workflows/enzyme-ci.yml @@ -0,0 +1,38 @@ +name: Rust CI + +on: + push: + branches: + - master + pull_request: + branches: + - master + merge_group: + +jobs: + build: + name: Rust Integration CI LLVM ${{ matrix.llvm }} ${{ matrix.build }} ${{ matrix.os }} + runs-on: ${{ matrix.os }} + + strategy: + fail-fast: false + matrix: + os: [openstack22] + + timeout-minutes: 600 + steps: + - name: checkout the source code + uses: actions/checkout@v3 + with: + fetch-depth: 2 + - name: build + run: | + mkdir build + cd build + ../configure --enable-llvm-link-shared --enable-llvm-plugins --enable-llvm-enzyme --release-channel=nightly --enable-llvm-assertions --enable-lld --enable-option-checking --enable-ninja --disable-docs + ../x.py build --stage 1 library/std library/proc_macro library/test tools/rustdoc + rustup toolchain link enzyme `pwd`/build/`rustup 
target list --installed`/stage1 + rustup toolchain install nightly # enables -Z unstable-options + - name: test + run: | + cargo +enzyme test --examples From 9445d7627b9a570ef7b9177f6f642720021bcd4d Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Mon, 19 Feb 2024 00:56:28 -0500 Subject: [PATCH 040/100] better error, fix activity check --- .../rustc_ast/src/expand/autodiff_attrs.rs | 27 ++++++++++++------- .../rustc_codegen_ssa/src/codegen_attrs.rs | 15 ++++++----- 2 files changed, 26 insertions(+), 16 deletions(-) diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs index 0d3af81d935da..8120b7d3e1f8e 100644 --- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -31,14 +31,14 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool { DiffMode::Inactive => false, DiffMode::Source => false, DiffMode::Forward => { - // Doesn't recognize all illegal cases (insufficient information) - activity != DiffActivity::Active && activity != DiffActivity::ActiveOnly - && activity != DiffActivity::Duplicated && activity != DiffActivity::DuplicatedOnly + activity == DiffActivity::Dual || + activity == DiffActivity::DualOnly || + activity == DiffActivity::Const } DiffMode::Reverse => { - // Doesn't recognize all illegal cases (insufficient information) - activity != DiffActivity::Duplicated && activity != DiffActivity::DuplicatedOnly - && activity != DiffActivity::Dual && activity != DiffActivity::DualOnly + activity == DiffActivity::Const || + activity == DiffActivity::Active || + activity == DiffActivity::ActiveOnly } } } @@ -55,8 +55,10 @@ fn is_ptr_or_ref(ty: &Ty) -> bool { //} pub fn valid_ty_for_activity(ty: &P, activity: DiffActivity) -> bool { if is_ptr_or_ref(ty) { - return activity == DiffActivity::Dual || activity == DiffActivity::DualOnly || - activity == DiffActivity::Duplicated || activity == DiffActivity::DuplicatedOnly || + return 
activity == DiffActivity::Dual || + activity == DiffActivity::DualOnly || + activity == DiffActivity::Duplicated || + activity == DiffActivity::DuplicatedOnly || activity == DiffActivity::Const; } true @@ -85,8 +87,13 @@ pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool { } }; } -pub fn valid_input_activities(mode: DiffMode, activity_vec: &[DiffActivity]) -> bool { - return activity_vec.iter().any(|&x| !valid_input_activity(mode, x)); +pub fn invalid_input_activities(mode: DiffMode, activity_vec: &[DiffActivity]) -> Option { + for i in 0..activity_vec.len() { + if !valid_input_activity(mode, activity_vec[i]) { + return Some(i); + } + } + None } #[allow(dead_code)] diff --git a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs index 7edc9cf7641f7..3e231c80bfd01 100644 --- a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs +++ b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs @@ -1,4 +1,4 @@ -use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode, valid_ret_activity, valid_input_activities}; +use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode, valid_ret_activity, invalid_input_activities}; use rustc_ast::{ast, attr, MetaItem, MetaItemKind, NestedMetaItem}; use rustc_attr::{list_contains_name, InlineAttr, InstructionSetAttr, OptimizeAttr}; use rustc_errors::struct_span_err; @@ -811,10 +811,14 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { } } - let msg = "Invalid activity for mode"; - let valid_input = valid_input_activities(mode, &arg_activities); - let valid_ret = valid_ret_activity(mode, ret_activity); - if !valid_input || !valid_ret { + let mut msg = "".to_string(); + if let Some(i) = invalid_input_activities(mode, &arg_activities) { + msg = format!("Invalid input activity {} for {} mode", arg_activities[i], mode); + } + if !valid_ret_activity(mode, ret_activity) { + msg = format!("Invalid return activity {} for {} mode", 
ret_activity, mode); + } + if msg != "".to_string() { tcx.sess .struct_span_err(attr.span, msg) .span_label(attr.span, "invalid activity") @@ -822,7 +826,6 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { return AutoDiffAttrs::inactive(); } - AutoDiffAttrs { mode, ret_activity, input_activity: arg_activities } } From 3e0ba90b46c39658e6d50671231bf4a912232e63 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Mon, 19 Feb 2024 03:42:08 -0500 Subject: [PATCH 041/100] add parser support for const --- compiler/rustc_builtin_macros/src/autodiff.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 1a6ec45578c11..36095cfd4dcd1 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -334,6 +334,9 @@ fn gen_enzyme_decl( }); d_inputs.push(shadow_arg); } + DiffActivity::Const => { + // Nothing to do here. + } _ => { dbg!(&activity); panic!("Not implemented"); From a74948f17b18b20866863956827e214b0d24a48d Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Tue, 27 Feb 2024 12:07:32 -0500 Subject: [PATCH 042/100] fix for opaque ptrs --- compiler/rustc_codegen_llvm/src/back/write.rs | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 4f3e628bda8e5..ff814f94e7481 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -16,12 +16,12 @@ use crate::llvm::{ LLVMAppendBasicBlockInContext, LLVMBuildCall2, LLVMBuildExtractValue, LLVMBuildRet, LLVMCountParams, LLVMCountStructElementTypes, LLVMCreateBuilderInContext, LLVMCreateStringAttribute, LLVMDeleteFunction, LLVMDisposeBuilder, LLVMDumpModule, - LLVMGetBasicBlockTerminator, LLVMGetElementType, LLVMGetFirstFunction, LLVMGetModuleContext, - LLVMGetNextFunction, 
LLVMGetParams, LLVMGetReturnType, LLVMGetStringAttributeAtIndex, + LLVMGetBasicBlockTerminator, LLVMGetFirstFunction, LLVMGetModuleContext, + LLVMGetNextFunction, LLVMGetParams, LLVMGetReturnType, LLVMRustGetFunctionType, LLVMGetStringAttributeAtIndex, LLVMGlobalGetValueType, LLVMIsEnumAttribute, LLVMIsStringAttribute, LLVMPositionBuilderAtEnd, LLVMRemoveStringAttributeAtIndex, LLVMReplaceAllUsesWith, LLVMRustAddEnumAttributeAtIndex, LLVMRustAddFunctionAttributes, LLVMRustGetEnumAttributeAtIndex, - LLVMRustRemoveEnumAttributeAtIndex, LLVMSetValueName2, LLVMTypeOf, LLVMVerifyFunction, + LLVMRustRemoveEnumAttributeAtIndex, LLVMSetValueName2, LLVMVerifyFunction, LLVMVoidTypeInContext, Value, }; use crate::llvm_util; @@ -642,8 +642,9 @@ unsafe fn create_wrapper<'a>( LLVMSetValueName2(fnc, c_inner_fnc_name.as_ptr(), inner_fnc_name.len() as usize); let c_outer_fnc_name = CString::new(fnc_name).unwrap(); + //let u_type = LLVMGetReturnType(u_type); let outer_fnc: &Value = - LLVMAddFunction(llmod, c_outer_fnc_name.as_ptr(), LLVMGetElementType(u_type) as &Type); + LLVMAddFunction(llmod, c_outer_fnc_name.as_ptr(), u_type); let entry = "fnc_entry".to_string(); let c_entry = CString::new(entry).unwrap(); @@ -661,6 +662,7 @@ pub(crate) unsafe fn extract_return_type<'a>( u_type: &Type, fnc_name: String, ) -> &'a Value { + let f_ty = LLVMRustGetFunctionType(fnc); let context = llvm::LLVMGetModuleContext(llmod); let inner_param_num = LLVMCountParams(fnc); @@ -675,12 +677,13 @@ pub(crate) unsafe fn extract_return_type<'a>( LLVMPositionBuilderAtEnd(builder, outer_bb); let struct_ret = LLVMBuildCall2( builder, - u_type, + f_ty, fnc, outer_args.as_mut_ptr(), outer_args.len(), c_inner_fnc_name.as_ptr(), ); + // We can use an arbitrary name here, since it will be used to store a tmp value. 
let inner_grad_name = "foo".to_string(); let c_inner_grad_name = CString::new(inner_grad_name).unwrap(); @@ -690,6 +693,7 @@ pub(crate) unsafe fn extract_return_type<'a>( LLVMDisposeBuilder(builder); let _fnc_ok = LLVMVerifyFunction(outer_fnc, llvm::LLVMVerifierFailureAction::LLVMAbortProcessAction); + dbg!(&outer_fnc); outer_fnc } @@ -702,7 +706,6 @@ pub(crate) unsafe fn enzyme_ad( diag_handler: &DiagCtxt, item: AutoDiffItem, ) -> Result<(), FatalError> { - dbg!("\n\n\n\n\n\n AUTO DIFF \n"); let autodiff_mode = item.attrs.mode; let rust_name = item.source; let rust_name2 = &item.target; @@ -803,7 +806,8 @@ pub(crate) unsafe fn enzyme_ad( if item.attrs.mode == DiffMode::Reverse && f_return_type != void_type { let num_elem_in_ret_struct = LLVMCountStructElementTypes(f_return_type); if num_elem_in_ret_struct == 1 { - let u_type = LLVMTypeOf(target_fnc); + + let u_type = LLVMRustGetFunctionType(target_fnc); res = extract_return_type(llmod, res, u_type, rust_name2.clone()); // TODO: check if name or name2 } } From 0ec88ee76af32f94635ebd7b219f3f0c538faa1d Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Fri, 8 Mar 2024 11:22:30 -0500 Subject: [PATCH 043/100] imprv placeholder towards removing indirection --- compiler/rustc_codegen_llvm/src/back/write.rs | 65 ++++++++++++++++++- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 7 ++ .../rustc_llvm/llvm-wrapper/RustWrapper.cpp | 15 +++++ 3 files changed, 86 insertions(+), 1 deletion(-) diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index ff814f94e7481..4bf4b7c7a5d4c 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -1,3 +1,8 @@ +use crate::llvm::LLVMGetFirstBasicBlock; +use crate::llvm::LLVMRustGetTerminator; +use crate::llvm::LLVMRustEraseInstFromParent; +use crate::llvm::LLVMRustEraseBBFromParent; +//use crate::llvm::LLVMEraseFromParent; use crate::back::lto::ThinBuffer; use 
crate::back::owned_target_machine::OwnedTargetMachine; use crate::back::profiling::{ @@ -31,7 +36,7 @@ use crate::DiffTypeTree; use crate::LlvmCodegenBackend; use crate::ModuleLlvm; use llvm::{ - LLVMRustLLVMHasZlibCompressionForDebugSymbols, LLVMRustLLVMHasZstdCompressionForDebugSymbols, + LLVMRustLLVMHasZlibCompressionForDebugSymbols, LLVMRustLLVMHasZstdCompressionForDebugSymbols, LLVMGetNextBasicBlock, }; use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity, DiffMode}; use rustc_codegen_ssa::back::link::ensure_removed; @@ -629,6 +634,62 @@ fn get_params(fnc: &Value) -> Vec<&Value> { } } +// DESIGN: +// Today we have our placeholder function, and our Enzyme generated one. +// We create a wrapper function and delete the placeholder body. +// We then call the wrapper from the placeholder. +// +// Soon, we won't delete the whole placeholder, but just the loop, +// and the two inline asm sections. For now we can still call the wrapper. +// In the future we call our Enzyme generated function directly and unwrap the return +// struct in our original placeholder. 
+// +// define internal double @_ZN2ad3bar17ha38374e821680177E(ptr align 8 %0, ptr align 8 %1, double %2) unnamed_addr #17 !dbg !13678 { +// %4 = alloca double, align 8 +// %5 = alloca ptr, align 8 +// %6 = alloca ptr, align 8 +// %7 = alloca { ptr, double }, align 8 +// store ptr %0, ptr %6, align 8 +// call void @llvm.dbg.declare(metadata ptr %6, metadata !13682, metadata !DIExpression()), !dbg !13685 +// store ptr %1, ptr %5, align 8 +// call void @llvm.dbg.declare(metadata ptr %5, metadata !13683, metadata !DIExpression()), !dbg !13685 +// store double %2, ptr %4, align 8 +// call void @llvm.dbg.declare(metadata ptr %4, metadata !13684, metadata !DIExpression()), !dbg !13686 +// call void asm sideeffect alignstack inteldialect "NOP", "~{dirflag},~{fpsr},~{flags},~{memory}"(), !dbg !13687, !srcloc !23 +// %8 = call double @_ZN2ad3foo17h95b548a9411653b2E(ptr align 8 %0), !dbg !13687 +// %9 = call double @_ZN4core4hint9black_box17h7bd67a41b0f12bdfE(double %8), !dbg !13687 +// store ptr %1, ptr %7, align 8, !dbg !13687 +// %10 = getelementptr inbounds { ptr, double }, ptr %7, i32 0, i32 1, !dbg !13687 +// store double %2, ptr %10, align 8, !dbg !13687 +// %11 = getelementptr inbounds { ptr, double }, ptr %7, i32 0, i32 0, !dbg !13687 +// %12 = load ptr, ptr %11, align 8, !dbg !13687, !nonnull !23, !align !1047, !noundef !23 +// %13 = getelementptr inbounds { ptr, double }, ptr %7, i32 0, i32 1, !dbg !13687 +// %14 = load double, ptr %13, align 8, !dbg !13687, !noundef !23 +// %15 = call { ptr, double } @_ZN4core4hint9black_box17h669f3b22afdcb487E(ptr align 8 %12, double %14), !dbg !13687 +// %16 = extractvalue { ptr, double } %15, 0, !dbg !13687 +// %17 = extractvalue { ptr, double } %15, 1, !dbg !13687 +// br label %18, !dbg !13687 +// +//18: ; preds = %18, %3 +// br label %18, !dbg !13687 + +#[allow(unused_variables)] +#[allow(unused)] +unsafe fn cleanup<'a>(fnc: &'a Value) { + // first, remove all calls from fnc + let bb = LLVMGetFirstBasicBlock(fnc); + let bb2 
= LLVMGetNextBasicBlock(bb); + let br = LLVMRustGetTerminator(bb); + LLVMRustEraseInstFromParent(br); + + LLVMRustEraseBBFromParent(bb2); + //LLVMEraseFromParent(bb); + dbg!(&fnc); + //let bb2 = LLVMGet + //LLVMEraseFromParent + +} + // TODO: Here we could start adding length checks for the shaddow args. unsafe fn create_wrapper<'a>( llmod: &'a llvm::Module, @@ -744,6 +805,8 @@ pub(crate) unsafe fn enzyme_ad( )); } }; + dbg!(&target_fnc); + cleanup(target_fnc); let src_num_args = llvm::LLVMCountParams(src_fnc); let target_num_args = llvm::LLVMCountParams(target_fnc); // A really simple check diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 8414baf046160..453817c2c40ac 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -985,7 +985,13 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff( } extern "C" { + // TODO: can I just ignore the non void return + // EraseFromParent doesn't exist :( + //pub fn LLVMEraseFromParent(BB: &BasicBlock) -> &Value; // Enzyme + pub fn LLVMRustEraseBBFromParent(B: &BasicBlock); + pub fn LLVMRustEraseInstFromParent(V: &Value); + pub fn LLVMRustGetTerminator<'a>(B: &BasicBlock) -> &'a Value; pub fn LLVMGetReturnType(T: &Type) -> &Type; pub fn LLVMDumpModule(M: &Module); pub fn LLVMCountStructElementTypes(T: &Type) -> c_uint; @@ -1215,6 +1221,7 @@ extern "C" { // Operations on instructions pub fn LLVMIsAInstruction(Val: &Value) -> Option<&Value>; pub fn LLVMGetFirstBasicBlock(Fn: &Value) -> &BasicBlock; + pub fn LLVMGetNextBasicBlock(Fn: &BasicBlock) -> &BasicBlock; // Operations on call sites pub fn LLVMSetInstructionCallConv(Instr: &Value, CC: c_uint); diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index a2976e1f510c2..a5e8257b7f18c 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -313,6 
+313,21 @@ LLVMRustCreateAttrNoValue(LLVMContextRef C, LLVMRustAttribute RustAttr) { return wrap(Attribute::get(*unwrap(C), fromRust(RustAttr))); } +extern "C" LLVMValueRef LLVMRustGetTerminator(LLVMBasicBlockRef BB) { + Instruction *ret = unwrap(BB)->getTerminator(); + return wrap(ret); +} + +extern "C" void LLVMRustEraseInstFromParent(LLVMValueRef Instr) { + if (auto I = dyn_cast(unwrap(Instr))) { + I->eraseFromParent(); + } +} + +extern "C" void LLVMRustEraseBBFromParent(LLVMBasicBlockRef BB) { + unwrap(BB)->eraseFromParent(); +} + extern "C" LLVMTypeRef LLVMRustGetFunctionType(LLVMValueRef Fn) { auto Ftype = unwrap(Fn)->getFunctionType(); return wrap(Ftype); From 691d2f9ef5cf5485246747f4e365cb87450080db Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sun, 10 Mar 2024 18:30:19 -0400 Subject: [PATCH 044/100] simplify wrapper --- .../rustc_ast/src/expand/autodiff_attrs.rs | 1 - compiler/rustc_codegen_llvm/src/back/write.rs | 153 +++++++++--------- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 9 ++ .../rustc_llvm/llvm-wrapper/RustWrapper.cpp | 78 +++++++++ 4 files changed, 165 insertions(+), 76 deletions(-) diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs index 8120b7d3e1f8e..1d5648ffb2f56 100644 --- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -228,7 +228,6 @@ impl AutoDiffAttrs { } pub fn is_source(&self) -> bool { - dbg!(&self); match self.mode { DiffMode::Source => true, _ => false, diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 4bf4b7c7a5d4c..833abe56bcdb4 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -1,4 +1,15 @@ +#![allow(unused_imports)] +#![allow(unused_variables)] use crate::llvm::LLVMGetFirstBasicBlock; +use crate::llvm::LLVMRustEraseInstBefore; +use crate::llvm::LLVMRustHasDbgMetadata; +use 
crate::llvm::LLVMRustHasMetadata; +use crate::llvm::LLVMRustRemoveFncAttr; +use crate::llvm::LLVMMetadataAsValue; +use crate::llvm::LLVMRustGetLastInstruction; +use crate::llvm::LLVMRustDIGetInstMetadata; +use crate::llvm::LLVMRustDIGetInstMetadataOfTy; +use crate::llvm::LLVMRustgetFirstNonPHIOrDbgOrLifetime; use crate::llvm::LLVMRustGetTerminator; use crate::llvm::LLVMRustEraseInstFromParent; use crate::llvm::LLVMRustEraseBBFromParent; @@ -35,6 +46,8 @@ use crate::typetree::to_enzyme_typetree; use crate::DiffTypeTree; use crate::LlvmCodegenBackend; use crate::ModuleLlvm; +use llvm::LLVMRustDISetInstMetadata; +//use llvm::LLVMGetValueName2; use llvm::{ LLVMRustLLVMHasZlibCompressionForDebugSymbols, LLVMRustLLVMHasZstdCompressionForDebugSymbols, LLVMGetNextBasicBlock, }; @@ -542,6 +555,9 @@ pub(crate) unsafe fn llvm_optimize( // RIP compile time. // let unroll_loops = // opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin; + + + let _unroll_loops = opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin; let unroll_loops = false; @@ -673,89 +689,86 @@ fn get_params(fnc: &Value) -> Vec<&Value> { //18: ; preds = %18, %3 // br label %18, !dbg !13687 -#[allow(unused_variables)] -#[allow(unused)] -unsafe fn cleanup<'a>(fnc: &'a Value) { + +unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool, + llmod: &'a llvm::Module, llcx: &llvm::Context) { // first, remove all calls from fnc - let bb = LLVMGetFirstBasicBlock(fnc); + let bb = LLVMGetFirstBasicBlock(tgt); let bb2 = LLVMGetNextBasicBlock(bb); let br = LLVMRustGetTerminator(bb); LLVMRustEraseInstFromParent(br); - LLVMRustEraseBBFromParent(bb2); - //LLVMEraseFromParent(bb); - dbg!(&fnc); - //let bb2 = LLVMGet - //LLVMEraseFromParent -} + // now add a call to inner. + // append call to src at end of bb. + let f_ty = LLVMRustGetFunctionType(src); -// TODO: Here we could start adding length checks for the shaddow args. 
-unsafe fn create_wrapper<'a>( - llmod: &'a llvm::Module, - fnc: &'a Value, - u_type: &Type, - fnc_name: String, -) -> (&'a Value, &'a BasicBlock, Vec<&'a Value>, Vec<&'a Value>, CString) { - let context = LLVMGetModuleContext(llmod); - let inner_fnc_name = "inner_".to_string() + &fnc_name; - let c_inner_fnc_name = CString::new(inner_fnc_name.clone()).unwrap(); - LLVMSetValueName2(fnc, c_inner_fnc_name.as_ptr(), inner_fnc_name.len() as usize); - - let c_outer_fnc_name = CString::new(fnc_name).unwrap(); - //let u_type = LLVMGetReturnType(u_type); - let outer_fnc: &Value = - LLVMAddFunction(llmod, c_outer_fnc_name.as_ptr(), u_type); - - let entry = "fnc_entry".to_string(); - let c_entry = CString::new(entry).unwrap(); - let basic_block = LLVMAppendBasicBlockInContext(context, outer_fnc, c_entry.as_ptr()); - - let outer_params: Vec<&Value> = get_params(outer_fnc); - let inner_params: Vec<&Value> = get_params(fnc); - - (outer_fnc, basic_block, outer_params, inner_params, c_inner_fnc_name) -} - -pub(crate) unsafe fn extract_return_type<'a>( - llmod: &'a llvm::Module, - fnc: &'a Value, - u_type: &Type, - fnc_name: String, -) -> &'a Value { - let f_ty = LLVMRustGetFunctionType(fnc); - let context = llvm::LLVMGetModuleContext(llmod); - - let inner_param_num = LLVMCountParams(fnc); - let (outer_fnc, outer_bb, mut outer_args, _inner_args, c_inner_fnc_name) = - create_wrapper(llmod, fnc, u_type, fnc_name); + let inner_param_num = LLVMCountParams(src); + let mut outer_args: Vec<&Value> = get_params(tgt); if inner_param_num as usize != outer_args.len() { - panic!("Args len shouldn't differ. Please report this."); + panic!("Args len shouldn't differ. Please report this. 
{} : {}", inner_param_num, outer_args.len()); } - let builder = LLVMCreateBuilderInContext(context); - LLVMPositionBuilderAtEnd(builder, outer_bb); - let struct_ret = LLVMBuildCall2( + let inner_fnc_name = llvm::get_value_name(src); + let c_inner_fnc_name = CString::new(inner_fnc_name).unwrap(); + + let builder = LLVMCreateBuilderInContext(llcx); + let last_inst = LLVMRustGetLastInstruction(bb).unwrap(); + LLVMPositionBuilderAtEnd(builder, bb); + let mut struct_ret = LLVMBuildCall2( builder, f_ty, - fnc, + src, outer_args.as_mut_ptr(), outer_args.len(), c_inner_fnc_name.as_ptr(), ); - // We can use an arbitrary name here, since it will be used to store a tmp value. - let inner_grad_name = "foo".to_string(); - let c_inner_grad_name = CString::new(inner_grad_name).unwrap(); - let struct_ret = LLVMBuildExtractValue(builder, struct_ret, 0, c_inner_grad_name.as_ptr()); + + // Add dummy dbg info to our newly generated call, if we have any. + let inst = LLVMRustgetFirstNonPHIOrDbgOrLifetime(bb).unwrap(); + let md_ty = llvm::LLVMGetMDKindIDInContext( + llcx, + "dbg".as_ptr() as *const c_char, + "dbg".len() as c_uint, + ); + + if LLVMRustHasMetadata(last_inst, md_ty) { + let md = LLVMRustDIGetInstMetadata(last_inst); + let md_val = LLVMMetadataAsValue(llcx, md); + let md2 = llvm::LLVMSetMetadata(struct_ret, md_ty, md_val); + } else { + dbg!("No dbg info"); + dbg!(&inst); + } + + // Our placeholder originally ended with `loop {}`, and therefore got the noreturn fnc attr. + // This is not true anymore, so we remove it. + LLVMRustRemoveFncAttr(tgt, AttributeKind::NoReturn); + + dbg!(&tgt); + + // Now clean up placeholder code. 
+ LLVMRustEraseInstBefore(bb, last_inst); + + let f_return_type = LLVMGetReturnType(LLVMGlobalGetValueType(src)); + let void_type = LLVMVoidTypeInContext(llcx); + // Now unwrap the struct_ret if it's actually a struct + if rev_mode && f_return_type != void_type { + let num_elem_in_ret_struct = LLVMCountStructElementTypes(f_return_type); + if num_elem_in_ret_struct == 1 { + let inner_grad_name = "foo".to_string(); + let c_inner_grad_name = CString::new(inner_grad_name).unwrap(); + struct_ret = LLVMBuildExtractValue(builder, struct_ret, 0, c_inner_grad_name.as_ptr()); + } + } let _ret = LLVMBuildRet(builder, struct_ret); - let _terminator = LLVMGetBasicBlockTerminator(outer_bb); LLVMDisposeBuilder(builder); + + dbg!(&tgt); let _fnc_ok = - LLVMVerifyFunction(outer_fnc, llvm::LLVMVerifierFailureAction::LLVMAbortProcessAction); - dbg!(&outer_fnc); - outer_fnc + LLVMVerifyFunction(tgt, llvm::LLVMVerifierFailureAction::LLVMAbortProcessAction); } // As unsafe as it can be. @@ -805,8 +818,6 @@ pub(crate) unsafe fn enzyme_ad( )); } }; - dbg!(&target_fnc); - cleanup(target_fnc); let src_num_args = llvm::LLVMCountParams(src_fnc); let target_num_args = llvm::LLVMCountParams(target_fnc); // A really simple check @@ -866,23 +877,15 @@ pub(crate) unsafe fn enzyme_ad( let f_return_type = LLVMGetReturnType(LLVMGlobalGetValueType(res)); let void_type = LLVMVoidTypeInContext(llcx); - if item.attrs.mode == DiffMode::Reverse && f_return_type != void_type { - let num_elem_in_ret_struct = LLVMCountStructElementTypes(f_return_type); - if num_elem_in_ret_struct == 1 { - - let u_type = LLVMRustGetFunctionType(target_fnc); - res = extract_return_type(llmod, res, u_type, rust_name2.clone()); // TODO: check if name or name2 - } - } - LLVMSetValueName2(res, name2.as_ptr(), rust_name2.len()); - LLVMReplaceAllUsesWith(target_fnc, res); - LLVMDeleteFunction(target_fnc); + let rev_mode = item.attrs.mode == DiffMode::Reverse; + create_call(target_fnc, res, rev_mode, llmod, llcx); // TODO: implement 
drop for wrapper type? FreeTypeAnalysis(type_analysis); Ok(()) } + pub(crate) unsafe fn differentiate( module: &ModuleCodegen, cgcx: &CodegenContext, diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 453817c2c40ac..7554149d6e235 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -989,6 +989,15 @@ extern "C" { // EraseFromParent doesn't exist :( //pub fn LLVMEraseFromParent(BB: &BasicBlock) -> &Value; // Enzyme + pub fn LLVMRustRemoveFncAttr(V: &Value, attr: AttributeKind); + pub fn LLVMRustHasDbgMetadata(I: &Value) -> bool; + pub fn LLVMRustHasMetadata(I: &Value, KindID: c_uint) -> bool; + pub fn LLVMRustEraseInstBefore(BB: &BasicBlock, I: &Value); + pub fn LLVMRustgetFirstNonPHIOrDbgOrLifetime<'a>(BB: &BasicBlock) -> Option<&'a Value>; + pub fn LLVMRustGetLastInstruction<'a>(BB: &BasicBlock) -> Option<&'a Value>; + pub fn LLVMRustDIGetInstMetadataOfTy(I: &Value, KindID: c_uint) -> &Metadata; + pub fn LLVMRustDIGetInstMetadata(I: &Value) -> &Metadata; + pub fn LLVMRustDISetInstMetadata<'a>(I: &Value, MD: &'a Metadata); pub fn LLVMRustEraseBBFromParent(B: &BasicBlock); pub fn LLVMRustEraseInstFromParent(V: &Value); pub fn LLVMRustGetTerminator<'a>(B: &BasicBlock) -> &'a Value; diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index a5e8257b7f18c..8ee8098b9bebb 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -800,6 +800,84 @@ extern "C" bool LLVMRustHasModuleFlag(LLVMModuleRef M, const char *Name, return unwrap(M)->getModuleFlag(StringRef(Name, Len)) != nullptr; } +extern "C" LLVMValueRef +LLVMRustgetFirstNonPHIOrDbgOrLifetime(LLVMBasicBlockRef BB) { + if (auto *I = unwrap(BB)->getFirstNonPHIOrDbgOrLifetime()) + return wrap(I); + return nullptr; +} + +// pub fn LLVMRustGetLastInstruction<'a>(BB: &BasicBlock) -> 
Option<&'a Value>; +extern "C" LLVMValueRef LLVMRustGetLastInstruction(LLVMBasicBlockRef BB) { + auto Point = unwrap(BB)->rbegin(); + if (Point != unwrap(BB)->rend()) + return wrap(&*Point); + return nullptr; +} + +extern "C" void LLVMRustEraseInstBefore(LLVMBasicBlockRef bb, LLVMValueRef I) { + auto &BB = *unwrap(bb); + auto &Inst = *unwrap(I); + auto It = BB.begin(); + while (&*It != &Inst) + ++It; + assert(It != BB.end()); + // Delete in rev order to ensure no dangling references. + while (It != BB.begin()) { + auto Prev = std::prev(It); + It->eraseFromParent(); + It = Prev; + } + It->eraseFromParent(); +} + +extern "C" bool LLVMRustHasMetadata(LLVMValueRef inst, unsigned kindID) { + if (auto *I = dyn_cast(unwrap(inst))) { + return I->hasMetadata(kindID); + } + return false; +} + +extern "C" bool LLVMRustHasDbgMetadata(LLVMValueRef inst) { + if (auto *I = dyn_cast(unwrap(inst))) { + return false; + // return I->hasDbgValues(); + } + return false; +} + +extern "C" void LLVMRustRemoveFncAttr(LLVMValueRef F, + LLVMRustAttribute RustAttr) { + if (auto *Fn = dyn_cast(unwrap(F))) { + Fn->removeFnAttr(fromRust(RustAttr)); + } +} + +extern "C" LLVMMetadataRef LLVMRustDIGetInstMetadata(LLVMValueRef x) { + if (auto *I = dyn_cast(unwrap(x))) { + // auto *MD = I->getMetadata(LLVMContext::MD_dbg); + auto *MD = I->getDebugLoc().getAsMDNode(); + return wrap(MD); + } + return nullptr; +} + +extern "C" LLVMMetadataRef LLVMRustDIGetInstMetadataOfTy(LLVMValueRef x, + unsigned kindID) { + if (auto *I = dyn_cast(unwrap(x))) { + auto *MD = I->getMetadata(kindID); + return wrap(MD); + } + return nullptr; +} + +extern "C" void LLVMRustDISetInstMetadata(LLVMValueRef Inst, + LLVMMetadataRef Desc) { + if (auto *I = dyn_cast(unwrap(Inst))) { + I->setMetadata(LLVMContext::MD_dbg, unwrap(Desc)); + } +} + extern "C" void LLVMRustGlobalAddMetadata( LLVMValueRef Global, unsigned Kind, LLVMMetadataRef MD) { unwrap(Global)->addMetadata(Kind, *unwrap(MD)); From 
1cc021666ad0f48d3840ef55ee4467cb43cff905 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sun, 10 Mar 2024 19:37:23 -0400 Subject: [PATCH 045/100] cmake improvement --- src/bootstrap/src/core/build_steps/llvm.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/bootstrap/src/core/build_steps/llvm.rs b/src/bootstrap/src/core/build_steps/llvm.rs index ed21a092cd7a1..4e18806493e3e 100644 --- a/src/bootstrap/src/core/build_steps/llvm.rs +++ b/src/bootstrap/src/core/build_steps/llvm.rs @@ -842,7 +842,7 @@ impl Step for Enzyme { } let target = self.target; - let LlvmResult { llvm_config, .. } = builder.ensure(Llvm { target: self.target }); + let LlvmResult { llvm_config, llvm_cmake_dir } = builder.ensure(Llvm { target }); let out_dir = builder.enzyme_out(target); let done_stamp = out_dir.join("enzyme-finished-building"); @@ -874,7 +874,7 @@ impl Step for Enzyme { .env("LLVM_CONFIG_REAL", &llvm_config) .define("LLVM_ENABLE_ASSERTIONS", "ON") .define("ENZYME_EXTERNAL_SHARED_LIB", "OFF") - .define("LLVM_DIR", builder.llvm_out(target)); + .define("LLVM_DIR", &llvm_cmake_dir); cfg.build(); From 618bcef2184c0d04a977fdbe69db5d4209ef89ea Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sun, 10 Mar 2024 22:40:13 -0400 Subject: [PATCH 046/100] replace loop {} by actually correct return, use default for floats --- compiler/rustc_builtin_macros/src/autodiff.rs | 94 ++++++++++++++++++- 1 file changed, 89 insertions(+), 5 deletions(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 36095cfd4dcd1..bab4d3a87bc8a 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -80,10 +80,16 @@ pub fn expand( dbg!(&x); let span = ecx.with_def_site_ctxt(expand_span); + let n_active: u32 = x.input_activity.iter() + .filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly) + .count() as u32; let (d_sig, new_args, idents) = 
gen_enzyme_decl(ecx, &sig, &x, span); let new_decl_span = d_sig.span; let d_body = gen_enzyme_body( ecx, + n_active, + &sig, + &d_sig, primal, &new_args, span, @@ -184,6 +190,9 @@ fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty { #[cfg(llvm_enzyme)] fn gen_enzyme_body( ecx: &ExtCtxt<'_>, + n_active: u32, + sig: &ast::FnSig, + d_sig: &ast::FnSig, primal: Ident, new_names: &[String], span: Span, @@ -192,6 +201,7 @@ fn gen_enzyme_body( idents: Vec, ) -> P { let blackbox_path = ecx.std_path(&[Symbol::intern("hint"), Symbol::intern("black_box")]); + //let default_path = ecx.def_site_path(&[Symbol::intern("f32"), Symbol::intern("default")]); let empty_loop_block = ecx.block(span, ThinVec::new()); let noop = ast::InlineAsm { template: vec![ast::InlineAsmTemplatePiece::String("NOP".to_string())], @@ -212,13 +222,13 @@ fn gen_enzyme_body( could_be_bare_literal: false, }; let unsf_expr = ecx.expr_block(P(unsf_block)); - let loop_expr = ecx.expr_loop(span, empty_loop_block); + let _loop_expr = ecx.expr_loop(span, empty_loop_block); let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path)); + //let default_call_expr = ecx.expr_path(ecx.path(span, default_path)); let primal_call = gen_primal_call(ecx, span, primal, idents); // create ::core::hint::black_box(array(arr)); let black_box_primal_call = ecx.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![primal_call.clone()]); - // create ::core::hint::black_box((grad_arr, tang_y)); let tup_args = new_names .iter() @@ -233,9 +243,83 @@ fn gen_enzyme_body( let mut body = ecx.block(span, ThinVec::new()); body.stmts.push(ecx.stmt_semi(unsf_expr)); - body.stmts.push(ecx.stmt_semi(black_box_primal_call)); + body.stmts.push(ecx.stmt_semi(black_box_primal_call.clone())); body.stmts.push(ecx.stmt_semi(black_box_remaining_args)); - body.stmts.push(ecx.stmt_expr(loop_expr)); + + + if !d_sig.decl.output.has_ret() { + // there is no return type that we have to match, () works fine. 
+ return body; + } + + let primal_ret = sig.decl.output.has_ret(); + + if primal_ret && n_active == 0 { + // We only have the primal ret. + body.stmts.push(ecx.stmt_expr(black_box_primal_call.clone())); + return body; + } + + if !primal_ret && n_active == 1 { + // Again no tuple return, so return default float val. + let ty = match d_sig.decl.output { + FnRetTy::Ty(ref ty) => ty.clone(), + FnRetTy::Default(span) => { + panic!("Did not expect Default ret ty: {:?}", span); + } + }; + let arg = ty.kind.is_simple_path().unwrap(); + let sl: Vec = vec![arg, Symbol::intern("default")]; + let tmp = ecx.def_site_path(&sl); + let default_call_expr = ecx.expr_path(ecx.path(span, tmp)); + let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]); + body.stmts.push(ecx.stmt_expr(default_call_expr)); + return body; + } + + let mut exprs = ThinVec::>::new(); + if primal_ret { + // We have both primal ret and active floats. + // primal ret is first, by construction. + exprs.push(primal_call.clone()); + } + + // Now construct default placeholder for each active float. + // Is there something nicer than f32::default() and f64::default()? 
+ let mut d_ret_ty = match d_sig.decl.output { + FnRetTy::Ty(ref ty) => ty.clone(), + FnRetTy::Default(span) => { + panic!("Did not expect Default ret ty: {:?}", span); + } + }; + let mut d_ret_ty = match d_ret_ty.kind { + TyKind::Tup(ref mut tys) => { + tys.clone() + } + _ => { + // We messed up construction of d_sig + panic!("Did not expect non-tuple ret ty: {:?}", d_ret_ty); + } + }; + if primal_ret { + // We have extra handling above for the primal ret + d_ret_ty = d_ret_ty[1..].to_vec().into(); + } + + for arg in d_ret_ty.iter() { + let arg = arg.kind.is_simple_path().unwrap(); + let sl: Vec = vec![arg, Symbol::intern("default")]; + let tmp = ecx.def_site_path(&sl); + let default_call_expr = ecx.expr_path(ecx.path(span, tmp)); + let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]); + exprs.push(default_call_expr); + }; + + let ret_tuple: P = ecx.expr_tuple(span, exprs); + let ret = ecx.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![ret_tuple]); + body.stmts.push(ecx.stmt_expr(ret)); + //body.stmts.push(ecx.stmt_expr(ret_tuple)); + body } @@ -258,7 +342,7 @@ fn gen_primal_call( // zero-initialized by Enzyme). Active arguments are not handled yet. // Each argument of the primal function (and the return type if existing) must be annotated with an // activity. 
-//#[cfg(llvm_enzyme)] +#[cfg(llvm_enzyme)] fn gen_enzyme_decl( ecx: &ExtCtxt<'_>, sig: &ast::FnSig, From a46e6678ba0590e0eb651045dccf900f934587ec Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sun, 10 Mar 2024 22:50:18 -0400 Subject: [PATCH 047/100] adjust cleanup to new return handling --- compiler/rustc_codegen_llvm/src/back/write.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 833abe56bcdb4..35e1db9455c8b 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -692,12 +692,11 @@ fn get_params(fnc: &Value) -> Vec<&Value> { unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool, llmod: &'a llvm::Module, llcx: &llvm::Context) { + dbg!(&tgt); // first, remove all calls from fnc let bb = LLVMGetFirstBasicBlock(tgt); - let bb2 = LLVMGetNextBasicBlock(bb); let br = LLVMRustGetTerminator(bb); LLVMRustEraseInstFromParent(br); - LLVMRustEraseBBFromParent(bb2); // now add a call to inner. // append call to src at end of bb. @@ -745,7 +744,7 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool, // Our placeholder originally ended with `loop {}`, and therefore got the noreturn fnc attr. // This is not true anymore, so we remove it. 
- LLVMRustRemoveFncAttr(tgt, AttributeKind::NoReturn); + //LLVMRustRemoveFncAttr(tgt, AttributeKind::NoReturn); dbg!(&tgt); From 4698879c6deaa1de4413cfb4966954d2bedc8b54 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Wed, 13 Mar 2024 19:42:54 -0400 Subject: [PATCH 048/100] improve error handling for unrecognized mode --- .../rustc_ast/src/expand/autodiff_attrs.rs | 33 ----------- compiler/rustc_builtin_macros/messages.ftl | 1 + compiler/rustc_builtin_macros/src/autodiff.rs | 57 ++++++++++++++++--- compiler/rustc_builtin_macros/src/errors.rs | 8 +++ 4 files changed, 58 insertions(+), 41 deletions(-) diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs index 1d5648ffb2f56..43f45fc5758b2 100644 --- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -1,8 +1,6 @@ use crate::expand::typetree::TypeTree; use std::str::FromStr; -use thin_vec::ThinVec; use std::fmt::{Display, Formatter}; -use crate::NestedMetaItem; use crate::ptr::P; use crate::{Ty, TyKind}; @@ -162,15 +160,6 @@ pub struct AutoDiffAttrs { pub input_activity: Vec, } -fn first_ident(x: &NestedMetaItem) -> rustc_span::symbol::Ident { - let segments = &x.meta_item().unwrap().path.segments; - assert!(segments.len() == 1); - segments[0].ident -} -fn name(x: &NestedMetaItem) -> String { - first_ident(x).name.to_string() -} - impl AutoDiffAttrs { pub fn has_ret_activity(&self) -> bool { match self.ret_activity { @@ -178,27 +167,6 @@ impl AutoDiffAttrs { _ => true, } } - pub fn from_ast(meta_item: &ThinVec, has_ret: bool) -> Self { - let mode = name(&meta_item[1]); - let mode = DiffMode::from_str(&mode).unwrap(); - let activities: Vec = meta_item[2..] - .iter() - .map(|x| { - let activity_str = name(&x); - DiffActivity::from_str(&activity_str).unwrap() - }) - .collect(); - - // If a return type exist, we need to split the last activity, - // otherwise we return None as placeholder. 
- let (ret_activity, input_activity) = if has_ret { - activities.split_last().unwrap() - } else { - (&DiffActivity::None, activities.as_slice()) - }; - - AutoDiffAttrs { mode, ret_activity: *ret_activity, input_activity: input_activity.to_vec() } - } } impl AutoDiffAttrs { @@ -221,7 +189,6 @@ impl AutoDiffAttrs { match self.mode { DiffMode::Inactive => false, _ => { - dbg!(&self); true } } diff --git a/compiler/rustc_builtin_macros/messages.ftl b/compiler/rustc_builtin_macros/messages.ftl index e0bb8fef688cb..6f67f9effe090 100644 --- a/compiler/rustc_builtin_macros/messages.ftl +++ b/compiler/rustc_builtin_macros/messages.ftl @@ -4,6 +4,7 @@ builtin_macros_alloc_must_statics = allocators must be statics builtin_macros_autodiff = autodiff must be applied to function builtin_macros_autodiff_not_build = this rustc version does not support autodiff builtin_macros_autodiff_mode_activity = {$act} can not be used in {$mode} Mode +builtin_macros_autodiff_mode = unknown Mode: `{$mode}`. Use `Forward` or `Reverse` builtin_macros_autodiff_ty_activity = {$act} can not be used for this type builtin_macros_asm_clobber_abi = clobber_abi diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index bab4d3a87bc8a..7a7aab3b1c7b4 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -17,13 +17,7 @@ use rustc_span::Span; use rustc_span::Symbol; use std::string::String; use thin_vec::{thin_vec, ThinVec}; - -#[cfg(llvm_enzyme)] -fn first_ident(x: &NestedMetaItem) -> rustc_span::symbol::Ident { - let segments = &x.meta_item().unwrap().path.segments; - assert!(segments.len() == 1); - segments[0].ident -} +use std::str::FromStr; #[cfg(not(llvm_enzyme))] pub fn expand( @@ -36,6 +30,48 @@ pub fn expand( return vec![item]; } +#[cfg(llvm_enzyme)] +fn first_ident(x: &NestedMetaItem) -> rustc_span::symbol::Ident { + let segments = &x.meta_item().unwrap().path.segments; + 
assert!(segments.len() == 1); + segments[0].ident +} + +#[cfg(llvm_enzyme)] +fn name(x: &NestedMetaItem) -> String { + first_ident(x).name.to_string() +} + +#[cfg(llvm_enzyme)] +pub fn from_ast(ecx: &mut ExtCtxt<'_>, meta_item: &ThinVec, has_ret: bool) -> AutoDiffAttrs { + + let mode = name(&meta_item[1]); + let mode = match DiffMode::from_str(&mode) { + Ok(x) => x, + Err(_) => { + ecx.sess.dcx().emit_err(errors::AutoDiffInvalidMode { span: meta_item[1].span(), mode}); + return AutoDiffAttrs::inactive(); + }, + }; + let activities: Vec = meta_item[2..] + .iter() + .map(|x| { + let activity_str = name(&x); + DiffActivity::from_str(&activity_str).unwrap() + }) + .collect(); + + // If a return type exist, we need to split the last activity, + // otherwise we return None as placeholder. + let (ret_activity, input_activity) = if has_ret { + activities.split_last().unwrap() + } else { + (&DiffActivity::None, activities.as_slice()) + }; + + AutoDiffAttrs { mode, ret_activity: *ret_activity, input_activity: input_activity.to_vec() } +} + #[cfg(llvm_enzyme)] pub fn expand( ecx: &mut ExtCtxt<'_>, @@ -76,7 +112,12 @@ pub fn expand( } let ts: TokenStream = TokenStream::from_iter(ts); - let x: AutoDiffAttrs = AutoDiffAttrs::from_ast(&meta_item_vec, has_ret); + let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret); + if !x.is_active() { + // We encountered an error, so we return the original item. + // This allows us to potentially parse other attributes. 
+ return vec![item]; + } dbg!(&x); let span = ecx.with_def_site_ctxt(expand_span); diff --git a/compiler/rustc_builtin_macros/src/errors.rs b/compiler/rustc_builtin_macros/src/errors.rs index d7824cb83e68b..763372f569f91 100644 --- a/compiler/rustc_builtin_macros/src/errors.rs +++ b/compiler/rustc_builtin_macros/src/errors.rs @@ -180,6 +180,14 @@ pub(crate) struct AutoDiffInvalidApplicationModeAct { pub(crate) act: String, } +#[derive(Diagnostic)] +#[diag(builtin_macros_autodiff_mode)] +pub(crate) struct AutoDiffInvalidMode { + #[primary_span] + pub(crate) span: Span, + pub(crate) mode: String, +} + #[derive(Diagnostic)] #[diag(builtin_macros_autodiff)] pub(crate) struct AutoDiffInvalidApplication { From 72429704e47d201c9838ea027e534b8d5811ddb5 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Thu, 14 Mar 2024 13:11:07 -0400 Subject: [PATCH 049/100] fixing forward mode with Dual return --- compiler/rustc_builtin_macros/src/autodiff.rs | 66 ++++++++++++++----- compiler/rustc_codegen_llvm/src/attributes.rs | 2 - compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 8 ++- .../rustc_codegen_ssa/src/codegen_attrs.rs | 5 +- 4 files changed, 57 insertions(+), 24 deletions(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 7a7aab3b1c7b4..3251bf753ebab 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -128,6 +128,7 @@ pub fn expand( let new_decl_span = d_sig.span; let d_body = gen_enzyme_body( ecx, + &x, n_active, &sig, &d_sig, @@ -231,6 +232,7 @@ fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty { #[cfg(llvm_enzyme)] fn gen_enzyme_body( ecx: &ExtCtxt<'_>, + x: &AutoDiffAttrs, n_active: u32, sig: &ast::FnSig, d_sig: &ast::FnSig, @@ -242,7 +244,6 @@ fn gen_enzyme_body( idents: Vec, ) -> P { let blackbox_path = ecx.std_path(&[Symbol::intern("hint"), Symbol::intern("black_box")]); - //let default_path = ecx.def_site_path(&[Symbol::intern("f32"), 
Symbol::intern("default")]); let empty_loop_block = ecx.block(span, ThinVec::new()); let noop = ast::InlineAsm { template: vec![ast::InlineAsmTemplatePiece::String("NOP".to_string())], @@ -265,12 +266,9 @@ fn gen_enzyme_body( let unsf_expr = ecx.expr_block(P(unsf_block)); let _loop_expr = ecx.expr_loop(span, empty_loop_block); let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path)); - //let default_call_expr = ecx.expr_path(ecx.path(span, default_path)); let primal_call = gen_primal_call(ecx, span, primal, idents); - // create ::core::hint::black_box(array(arr)); let black_box_primal_call = ecx.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![primal_call.clone()]); - // create ::core::hint::black_box((grad_arr, tang_y)); let tup_args = new_names .iter() .map(|arg| ecx.expr_path(ecx.path_ident(span, Ident::from_str(arg)))) @@ -295,7 +293,7 @@ fn gen_enzyme_body( let primal_ret = sig.decl.output.has_ret(); - if primal_ret && n_active == 0 { + if primal_ret && n_active == 0 && x.mode == DiffMode::Reverse { // We only have the primal ret. 
body.stmts.push(ecx.stmt_expr(black_box_primal_call.clone())); return body; @@ -342,24 +340,40 @@ fn gen_enzyme_body( panic!("Did not expect non-tuple ret ty: {:?}", d_ret_ty); } }; - if primal_ret { - // We have extra handling above for the primal ret - d_ret_ty = d_ret_ty[1..].to_vec().into(); - } + if x.mode == DiffMode::Forward { + if x.ret_activity == DiffActivity::Dual { + assert!(d_ret_ty.len() == 2); + // both should be identical, by construction + let arg = d_ret_ty[0].kind.is_simple_path().unwrap(); + let arg2 = d_ret_ty[1].kind.is_simple_path().unwrap(); + assert!(arg == arg2); + let sl: Vec = vec![arg, Symbol::intern("default")]; + let tmp = ecx.def_site_path(&sl); + let default_call_expr = ecx.expr_path(ecx.path(span, tmp)); + let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]); + exprs.push(default_call_expr); + } + } else { + assert!(x.mode == DiffMode::Reverse); - for arg in d_ret_ty.iter() { - let arg = arg.kind.is_simple_path().unwrap(); - let sl: Vec = vec![arg, Symbol::intern("default")]; - let tmp = ecx.def_site_path(&sl); - let default_call_expr = ecx.expr_path(ecx.path(span, tmp)); - let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]); - exprs.push(default_call_expr); - }; + if primal_ret { + // We have extra handling above for the primal ret + d_ret_ty = d_ret_ty[1..].to_vec().into(); + } + + for arg in d_ret_ty.iter() { + let arg = arg.kind.is_simple_path().unwrap(); + let sl: Vec = vec![arg, Symbol::intern("default")]; + let tmp = ecx.def_site_path(&sl); + let default_call_expr = ecx.expr_path(ecx.path(span, tmp)); + let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]); + exprs.push(default_call_expr); + }; + } let ret_tuple: P = ecx.expr_tuple(span, exprs); let ret = ecx.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![ret_tuple]); body.stmts.push(ecx.stmt_expr(ret)); - //body.stmts.push(ecx.stmt_expr(ret_tuple)); body } @@ 
-505,6 +519,22 @@ fn gen_enzyme_decl( } d_decl.inputs = d_inputs.into(); + if let DiffMode::Forward = x.mode { + if let DiffActivity::Dual = x.ret_activity { + let ty = match d_decl.output { + FnRetTy::Ty(ref ty) => ty.clone(), + FnRetTy::Default(span) => { + panic!("Did not expect Default ret ty: {:?}", span); + } + }; + // Dual can only be used for f32/f64 ret. + // In that case we return now a tuple with two floats. + let kind = TyKind::Tup(thin_vec![ty.clone(), ty.clone()]); + let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None }); + d_decl.output = FnRetTy::Ty(ty); + } + } + // If we have an active input scalar, add it's gradient to the // return type. This might require changing the return type to a // tuple. diff --git a/compiler/rustc_codegen_llvm/src/attributes.rs b/compiler/rustc_codegen_llvm/src/attributes.rs index 032a89c2affa7..7be288886d039 100644 --- a/compiler/rustc_codegen_llvm/src/attributes.rs +++ b/compiler/rustc_codegen_llvm/src/attributes.rs @@ -314,8 +314,6 @@ pub fn from_fn_attrs<'ll, 'tcx>( if codegen_fn_attrs.inline == InlineAttr::None && instance.def.requires_inline(cx.tcx) { InlineAttr::Hint } else if autodiff_attrs.is_active() { - dbg!("autodiff_attrs.is_active()"); - dbg!(&autodiff_attrs); InlineAttr::Never } else { codegen_fn_attrs.inline diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 7554149d6e235..ac00042c677b7 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -877,7 +877,13 @@ pub(crate) unsafe fn enzyme_rust_forward_diff( // We don't support volatile / extern / (global?) values. // Just because I didn't had time to test them, and it seems less urgent. 
- let args_uncacheable = vec![0; input_activity.len()]; + dbg!(&fnc); + let args_uncacheable = vec![0; input_tts.len()]; + assert!(args_uncacheable.len() == input_activity.len()); + let num_fnc_args = LLVMCountParams(fnc); + println!("num_fnc_args: {}", num_fnc_args); + println!("input_activity.len(): {}", input_activity.len()); + assert!(num_fnc_args == input_activity.len() as u32); let kv_tmp = IntList { data: std::ptr::null_mut(), size: 0 }; diff --git a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs index 3e231c80bfd01..76b1f220c9c4d 100644 --- a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs +++ b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs @@ -717,14 +717,13 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { return AutoDiffAttrs::inactive(); } }; - dbg!("autodiff_attr = {:?}", &attr); let list = attr.meta_item_list().unwrap_or_default(); - dbg!("autodiff_attrs list = {:?}", &list); + //dbg!("autodiff_attrs list = {:?}", &list); // empty autodiff attribute macros (i.e. 
`#[autodiff]`) are used to mark source functions if list.len() == 0 { - dbg!("autodiff_attrs: source"); + //dbg!("autodiff_attrs: source"); return AutoDiffAttrs::source(); } From 78bd09aac0a7dea64c783ecf4d77b2890b4470e3 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Mon, 18 Mar 2024 19:35:57 -0400 Subject: [PATCH 050/100] add handling for ->() --- compiler/rustc_ast/src/expand/autodiff_attrs.rs | 4 ++++ compiler/rustc_builtin_macros/src/autodiff.rs | 6 ++++++ 2 files changed, 10 insertions(+) diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs index 43f45fc5758b2..9cee364543af7 100644 --- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -25,6 +25,10 @@ impl Display for DiffMode { } pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool { + if activity == DiffActivity::None { + // Only valid if primal returns (), but we can't check that here. + return true; + } match mode { DiffMode::Inactive => false, DiffMode::Source => false, diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 3251bf753ebab..d77bc5c2ce5a3 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -110,6 +110,12 @@ pub fn expand( ts.push(TokenTree::Token(t, Spacing::Joint)); ts.push(TokenTree::Token(comma.clone(), Spacing::Alone)); } + if !sig.decl.output.has_ret() { + // We don't want users to provide a return activity if the function doesn't return anything. + // For simplicity, we just add a dummy token to the end of the list. 
+ let t = Token::new(TokenKind::Ident(sym::None, false), Span::default()); + ts.push(TokenTree::Token(t, Spacing::Joint)); + } let ts: TokenStream = TokenStream::from_iter(ts); let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret); From 6933f08f7ba38f6a1b6751ee7873ad59e3d59f6a Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Mon, 18 Mar 2024 23:38:27 -0400 Subject: [PATCH 051/100] enforce fat-lto when using autodiff --- compiler/rustc_codegen_ssa/src/back/write.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/compiler/rustc_codegen_ssa/src/back/write.rs b/compiler/rustc_codegen_ssa/src/back/write.rs index 47da3254eb39e..0a9117072b7ac 100644 --- a/compiler/rustc_codegen_ssa/src/back/write.rs +++ b/compiler/rustc_codegen_ssa/src/back/write.rs @@ -397,6 +397,7 @@ fn generate_lto_work( // We are adding a single work item, so the cost doesn't matter. vec![(WorkItem::LTO(module), 0)] } else { + assert!(autodiff.is_empty()); assert!(needs_fat_lto.is_empty()); let (lto_modules, copy_jobs) = B::run_thin_lto(cgcx, needs_thin_lto, import_only_modules) .unwrap_or_else(|e| e.raise()); From 7352472d6f64e5309aaa924ae6b862ea22c65a6a Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Tue, 19 Mar 2024 01:00:10 -0400 Subject: [PATCH 052/100] handle more () ret cases --- compiler/rustc_builtin_macros/src/autodiff.rs | 5 ++++- compiler/rustc_codegen_llvm/src/back/write.rs | 9 ++++++++- compiler/rustc_codegen_ssa/src/codegen_attrs.rs | 9 ++++----- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index d77bc5c2ce5a3..1443d1dfd5447 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -379,7 +379,10 @@ fn gen_enzyme_body( let ret_tuple: P = ecx.expr_tuple(span, exprs); let ret = ecx.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![ret_tuple]); - body.stmts.push(ecx.stmt_expr(ret)); + if 
d_sig.decl.output.has_ret() { + // If we return (), we don't have to match the return type. + body.stmts.push(ecx.stmt_expr(ret)); + } body } diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 35e1db9455c8b..c0a9d4e9d23d6 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -1,6 +1,7 @@ #![allow(unused_imports)] #![allow(unused_variables)] use crate::llvm::LLVMGetFirstBasicBlock; +use crate::llvm::LLVMBuildRetVoid; use crate::llvm::LLVMRustEraseInstBefore; use crate::llvm::LLVMRustHasDbgMetadata; use crate::llvm::LLVMRustHasMetadata; @@ -762,7 +763,13 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool, struct_ret = LLVMBuildExtractValue(builder, struct_ret, 0, c_inner_grad_name.as_ptr()); } } - let _ret = LLVMBuildRet(builder, struct_ret); + if f_return_type != void_type { + dbg!("Returning struct"); + let _ret = LLVMBuildRet(builder, struct_ret); + } else { + dbg!("Returning void"); + let _ret = LLVMBuildRetVoid(builder); + } LLVMDisposeBuilder(builder); dbg!(&tgt); diff --git a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs index 76b1f220c9c4d..4d1dd44031669 100644 --- a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs +++ b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs @@ -704,8 +704,9 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { .filter(|attr| attr.name_or_empty() == sym::rustc_autodiff) .collect::>(); - // check for exactly one autodiff attribute on extern block - let msg_once = "cg_ssa: autodiff attribute can only be applied once"; + // check for exactly one autodiff attribute on placeholder functions. + // There should only be one, since we generate a new placeholder per ad macro. + let msg_once = "cg_ssa: implementation bug. 
Autodiff attribute can only be applied once"; let attr = match attrs.len() { 0 => return AutoDiffAttrs::inactive(), 1 => attrs.get(0).unwrap(), @@ -719,11 +720,9 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { }; let list = attr.meta_item_list().unwrap_or_default(); - //dbg!("autodiff_attrs list = {:?}", &list); // empty autodiff attribute macros (i.e. `#[autodiff]`) are used to mark source functions if list.len() == 0 { - //dbg!("autodiff_attrs: source"); return AutoDiffAttrs::source(); } @@ -784,7 +783,7 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { } }; - // Now parse all the intermediate (inptut) activities + // Now parse all the intermediate (input) activities let msg_arg_activity = "autodiff attribute must contain the return activity"; let mut arg_activities: Vec = vec![]; for arg in input_activities { From daad666ad8ea3299785dbc6bad36371cf4e17d44 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Tue, 19 Mar 2024 01:01:17 -0400 Subject: [PATCH 053/100] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 6a37cac337ce0..425ed87f3718c 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ cd build Afterwards you can build rustc using: ``` -../x.py build --stage 1 library/std library/proc_macro library/test tools/rustdoc +../x.py build --stage 1 library ``` Afterwards rustc toolchain link will allow you to use it through cargo: From 294f1373a644ff258e0af386bac619210ab762fc Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Wed, 20 Mar 2024 01:34:38 -0400 Subject: [PATCH 054/100] updating enzyme to latest master --- src/tools/enzyme | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tools/enzyme b/src/tools/enzyme index a309cc083f64f..880fb3272e4e3 160000 --- a/src/tools/enzyme +++ b/src/tools/enzyme @@ -1 +1 @@ -Subproject commit a309cc083f64fb6c17689a81feffa804bc7fcb3d +Subproject commit 
880fb3272e4e34a8535a36f0d0a0fb232630a2b9 From 215089d5970edd8ac511f1ec333aa029763a2677 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Wed, 20 Mar 2024 01:46:39 -0400 Subject: [PATCH 055/100] rust-alloc testing --- src/tools/enzyme | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tools/enzyme b/src/tools/enzyme index 880fb3272e4e3..7e3e90f428706 160000 --- a/src/tools/enzyme +++ b/src/tools/enzyme @@ -1 +1 @@ -Subproject commit 880fb3272e4e34a8535a36f0d0a0fb232630a2b9 +Subproject commit 7e3e90f4287068a41d3fb5a99127ee2857353b04 From 1eff53cb545c6b0b012b955b065786c032565a34 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Tue, 26 Mar 2024 15:59:37 -0400 Subject: [PATCH 056/100] div bugfixes --- compiler/rustc_ast/src/expand/typetree.rs | 12 + compiler/rustc_builtin_macros/src/autodiff.rs | 5 +- compiler/rustc_codegen_llvm/src/back/write.rs | 6 +- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 19 +- .../rustc_llvm/llvm-wrapper/RustWrapper.cpp | 7 + compiler/rustc_middle/messages.ftl | 2 + compiler/rustc_middle/src/error.rs | 6 + compiler/rustc_middle/src/ty/mod.rs | 219 +++++++++++++++++- .../rustc_monomorphize/src/partitioning.rs | 209 ++++++----------- 9 files changed, 338 insertions(+), 147 deletions(-) diff --git a/compiler/rustc_ast/src/expand/typetree.rs b/compiler/rustc_ast/src/expand/typetree.rs index bab35d55e26b8..dfb9372f92f9c 100644 --- a/compiler/rustc_ast/src/expand/typetree.rs +++ b/compiler/rustc_ast/src/expand/typetree.rs @@ -14,6 +14,18 @@ pub enum Kind { #[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] pub struct TypeTree(pub Vec); +impl TypeTree { + pub fn new() -> Self { + Self(Vec::new()) + } +} + +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] +pub struct FncTree { + pub args: Vec, + pub ret: TypeTree, +} + #[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] pub struct Type { pub offset: isize, diff --git 
a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 1443d1dfd5447..674f20577934c 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -204,7 +204,10 @@ pub fn expand( tokens: ts, }); attr.kind = ast::AttrKind::Normal(rustc_ad_attr); - let d_fn = ecx.item(span, d_ident, thin_vec![attr], asdf); + let mut d_fn = ecx.item(span, d_ident, thin_vec![attr], asdf); + + // Copy visibility from original function + d_fn.vis = orig_item.vis.clone(); let orig_annotatable = Annotatable::Item(orig_item); let d_annotatable = Annotatable::Item(d_fn); diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index c0a9d4e9d23d6..4f1b1fbe35845 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -48,7 +48,6 @@ use crate::DiffTypeTree; use crate::LlvmCodegenBackend; use crate::ModuleLlvm; use llvm::LLVMRustDISetInstMetadata; -//use llvm::LLVMGetValueName2; use llvm::{ LLVMRustLLVMHasZlibCompressionForDebugSymbols, LLVMRustLLVMHasZstdCompressionForDebugSymbols, LLVMGetNextBasicBlock, }; @@ -907,6 +906,11 @@ pub(crate) unsafe fn differentiate( llvm::set_strict_aliasing(false); + if std::env::var("ENZYME_LOOSE_TYPES").is_ok() { + dbg!("Setting loose types to true"); + llvm::set_loose_types(true); + } + if std::env::var("ENZYME_PRINT_MOD").is_ok() { unsafe { LLVMDumpModule(llmod); diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index ac00042c677b7..bb2b01f09227e 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -1,6 +1,5 @@ #![allow(non_camel_case_types)] #![allow(non_upper_case_globals)] -//#![allow(unexpected_cfgs)] use rustc_ast::expand::autodiff_attrs::DiffActivity; @@ -1827,6 +1826,12 @@ extern "C" { AttrsLen: size_t, ); + pub fn LLVMRustAddParamAttr<'a>( 
+ Instr: &'a Value, + index: c_uint, + Attr: &'a Attribute + ); + pub fn LLVMRustBuildInvoke<'a>( B: &Builder<'a>, Ty: &'a Type, @@ -2648,6 +2653,7 @@ pub mod Fallback_AD { pub fn set_print_type(print: bool) { unimplemented!() } pub fn set_print(print: bool) { unimplemented!() } pub fn set_strict_aliasing(strict: bool) { unimplemented!() } + pub fn set_loose_types(loose: bool) { unimplemented!() } pub fn EnzymeCreatePrimalAndGradient<'a>( arg1: EnzymeLogicRef, @@ -2979,6 +2985,7 @@ extern "C" { static mut EnzymePrintType: c_void; static mut EnzymePrint: c_void; static mut EnzymeStrictAliasing: c_void; + static mut looseTypeAnalysis: c_void; } pub fn set_max_int_offset(offset: u64) { let offset = offset.try_into().unwrap(); @@ -3023,6 +3030,12 @@ pub fn set_strict_aliasing(strict: bool) { EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeStrictAliasing), strict as u8); } } +pub fn set_loose_types(loose: bool) { + unsafe { + EnzymeSetCLBool(std::ptr::addr_of_mut!(looseTypeAnalysis), loose as u8); + } +} + extern "C" { pub fn EnzymeCreatePrimalAndGradient<'a>( arg1: EnzymeLogicRef, @@ -3108,7 +3121,7 @@ extern "C" { max_size: i64, add_offset: u64, ); - pub(super) fn EnzymeTypeTreeToStringFree(arg1: *const c_char); - pub(super) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char; + pub fn EnzymeTypeTreeToStringFree(arg1: *const c_char); + pub fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char; } } diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index 8ee8098b9bebb..548040579b392 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -871,6 +871,13 @@ extern "C" LLVMMetadataRef LLVMRustDIGetInstMetadataOfTy(LLVMValueRef x, return nullptr; } +extern "C" void LLVMRustAddParamAttr(LLVMValueRef call, unsigned i, + LLVMAttributeRef RustAttr) { + if (auto *CI = dyn_cast(unwrap(call))) { + CI->addParamAttr(i, unwrap(RustAttr)); + } 
+} + extern "C" void LLVMRustDISetInstMetadata(LLVMValueRef Inst, LLVMMetadataRef Desc) { if (auto *I = dyn_cast(unwrap(Inst))) { diff --git a/compiler/rustc_middle/messages.ftl b/compiler/rustc_middle/messages.ftl index 27d555d7e26c7..31e3594115665 100644 --- a/compiler/rustc_middle/messages.ftl +++ b/compiler/rustc_middle/messages.ftl @@ -85,3 +85,5 @@ middle_unknown_layout = middle_values_too_big = values of the type `{$ty}` are too big for the current architecture + +middle_unsupported_union = we don't support unions yet: '{$ty_name}' diff --git a/compiler/rustc_middle/src/error.rs b/compiler/rustc_middle/src/error.rs index 3c5536570872a..4f763663a50fe 100644 --- a/compiler/rustc_middle/src/error.rs +++ b/compiler/rustc_middle/src/error.rs @@ -151,5 +151,11 @@ pub struct ErroneousConstant { pub span: Span, } +#[derive(Diagnostic)] +#[diag(middle_unsupported_union)] +pub struct UnsupportedUnion { + pub ty_name: String, +} + /// Used by `rustc_const_eval` pub use crate::fluent_generated::middle_adjust_for_foreign_abi_error; diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index a4abdf84fc0ce..d16da3fc46dbc 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -10,6 +10,10 @@ //! 
["The `ty` module: representing types"]: https://rustc-dev-guide.rust-lang.org/ty.html #![allow(rustc::usage_of_ty_tykind)] +#![allow(unused_imports)] + +use rustc_ast::expand::typetree::{Type, Kind, TypeTree, FncTree}; +use rustc_target::abi::FieldsShape; pub use self::fold::{FallibleTypeFolder, TypeFoldable, TypeFolder, TypeSuperFoldable}; pub use self::visit::{TypeSuperVisitable, TypeVisitable, TypeVisitableExt, TypeVisitor}; @@ -17,7 +21,7 @@ pub use self::AssocItemContainer::*; pub use self::BorrowKind::*; pub use self::IntVarValue::*; pub use self::Variance::*; -use crate::error::{OpaqueHiddenTypeMismatch, TypeMismatchReason}; +use crate::error::{OpaqueHiddenTypeMismatch, TypeMismatchReason, UnsupportedUnion}; use crate::metadata::ModChild; use crate::middle::privacy::EffectiveVisibilities; use crate::mir::{Body, CoroutineLayout}; @@ -2715,3 +2719,216 @@ mod size_asserts { static_assert_size!(WithCachedTypeInfo>, 56); // tidy-alphabetical-end } + +pub fn typetree_from<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { + let mut visited = vec![]; + let ty = typetree_from_ty(ty, tcx, 0, false, &mut visited); + let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child: ty }; + return TypeTree(vec![tt]); +} + +pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>) -> FncTree { + if !fn_ty.is_fn() { + return FncTree { args: vec![], ret: TypeTree::new() }; + } + let fnc_binder: ty::Binder<'_, ty::FnSig<'_>> = fn_ty.fn_sig(tcx); + + // TODO: verify. 
+ let x: ty::FnSig<'_> = match fnc_binder.no_bound_vars() { + Some(x) => x, + None => return FncTree { args: vec![], ret: TypeTree::new() }, + }; + + let mut visited = vec![]; + let mut args = vec![]; + for arg in x.inputs() { + visited.clear(); + let arg_tt = typetree_from_ty(*arg, tcx, 0, false, &mut visited); + args.push(arg_tt); + } + + visited.clear(); + let ret = typetree_from_ty(x.output(), tcx, 0, false, &mut visited); + + FncTree { args, ret } +} + +pub fn typetree_from_ty<'a>(ty: Ty<'a>, tcx: TyCtxt<'a>, depth: usize, safety: bool, visited: &mut Vec>) -> TypeTree { + if depth > 20 { + dbg!(&ty); + } + if visited.contains(&ty) { + // recursive type + dbg!("recursive type"); + dbg!(&ty); + return TypeTree::new(); + } + visited.push(ty); + + if ty.is_unsafe_ptr() || ty.is_ref() || ty.is_box() { + if ty.is_fn_ptr() { + unimplemented!("what to do whith fn ptr?"); + } + let inner_ty = ty.builtin_deref(true).unwrap().ty; + //visited.push(inner_ty); + let child = typetree_from_ty(inner_ty, tcx, depth + 1, safety, visited); + let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child }; + visited.pop(); + return TypeTree(vec![tt]); + } + + + if ty.is_closure() || ty.is_coroutine() || ty.is_fresh() || ty.is_fn() { + visited.pop(); + return TypeTree::new(); + } + + if ty.is_scalar() { + let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() { + (Kind::Integer, ty.primitive_size(tcx).bytes_usize()) + } else if ty.is_floating_point() { + match ty { + x if x == tcx.types.f32 => (Kind::Float, 4), + x if x == tcx.types.f64 => (Kind::Double, 8), + _ => panic!("floatTy scalar that is neither f32 nor f64"), + } + } else { + panic!("scalar that is neither integral nor floating point"); + }; + visited.pop(); + return TypeTree(vec![Type { offset: -1, child: TypeTree::new(), kind, size }]); + } + + let param_env_and = ParamEnvAnd { param_env: ParamEnv::empty(), value: ty }; + + let layout = tcx.layout_of(param_env_and); + assert!(layout.is_ok()); + + let 
layout = layout.unwrap().layout; + let fields = layout.fields(); + let max_size = layout.size(); + + + + if ty.is_adt() && !ty.is_simd() { + let adt_def = ty.ty_adt_def().unwrap(); + + if adt_def.is_struct() { + let (offsets, _memory_index) = match fields { + // Manuel TODO: + FieldsShape::Arbitrary { offsets: o, memory_index: m } => (o, m), + FieldsShape::Array { .. } => {return TypeTree::new();}, //e.g. core::arch::x86_64::__m128i, TODO: later + FieldsShape::Union(_) => {return TypeTree::new();}, + FieldsShape::Primitive => {return TypeTree::new();}, + //_ => {dbg!(&adt_def); panic!("")}, + }; + + let substs = match ty.kind() { + Adt(_, subst_ref) => subst_ref, + _ => panic!(""), + }; + + let fields = adt_def.all_fields(); + let fields = fields + .into_iter() + .zip(offsets.into_iter()) + .filter_map(|(field, offset)| { + let field_ty: Ty<'_> = field.ty(tcx, substs); + let field_ty: Ty<'_> = + tcx.normalize_erasing_regions(ParamEnv::empty(), field_ty); + + if field_ty.is_phantom_data() { + return None; + } + + //visited.push(field_ty); + let mut child = typetree_from_ty(field_ty, tcx, depth + 1, safety, visited).0; + + for c in &mut child { + if c.offset == -1 { + c.offset = offset.bytes() as isize + } else { + c.offset += offset.bytes() as isize; + } + } + + Some(child) + }) + .flatten() + .collect::>(); + + visited.pop(); + let ret_tt = TypeTree(fields); + return ret_tt; + } else if adt_def.is_enum() { + // Enzyme can't represent enums, so let it figure it out itself, without seeeding + // typetree + //unimplemented!("adt that is an enum"); + } else { + //let ty_name = tcx.def_path_debug_str(adt_def.did()); + //tcx.sess.emit_fatal(UnsupportedUnion { ty_name }); + } + } + + if ty.is_simd() { + trace!("simd"); + let (_size, inner_ty) = ty.simd_size_and_type(tcx); + //visited.push(inner_ty); + let _sub_tt = typetree_from_ty(inner_ty, tcx, depth + 1, safety, visited); + //let tt = TypeTree( + // std::iter::repeat(subtt) + // .take(*count as usize) + // .enumerate() 
+ // .map(|(idx, x)| x.0.into_iter().map(move |x| x.add_offset((idx * size) as isize))) + // .flatten() + // .collect(), + //); + // TODO + visited.pop(); + return TypeTree::new(); + } + + if ty.is_array() { + let (stride, count) = match fields { + FieldsShape::Array { stride: s, count: c } => (s, c), + _ => panic!(""), + }; + let byte_stride = stride.bytes_usize(); + let byte_max_size = max_size.bytes_usize(); + + assert!(byte_stride * *count as usize == byte_max_size); + if (*count as usize) == 0 { + return TypeTree::new(); + } + let sub_ty = ty.builtin_index().unwrap(); + //visited.push(sub_ty); + let subtt = typetree_from_ty(sub_ty, tcx, depth + 1, safety, visited); + + // calculate size of subtree + let param_env_and = ParamEnvAnd { param_env: ParamEnv::empty(), value: sub_ty }; + let size = tcx.layout_of(param_env_and).unwrap().size.bytes() as usize; + let tt = TypeTree( + std::iter::repeat(subtt) + .take(*count as usize) + .enumerate() + .map(|(idx, x)| x.0.into_iter().map(move |x| x.add_offset((idx * size) as isize))) + .flatten() + .collect(), + ); + + visited.pop(); + return tt; + } + + if ty.is_slice() { + let sub_ty = ty.builtin_index().unwrap(); + //visited.push(sub_ty); + let subtt = typetree_from_ty(sub_ty, tcx, depth + 1, safety, visited); + + visited.pop(); + return subtt; + } + + visited.pop(); + TypeTree::new() +} diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs index 8046e4392f475..e15dbe9a7c9d6 100644 --- a/compiler/rustc_monomorphize/src/partitioning.rs +++ b/compiler/rustc_monomorphize/src/partitioning.rs @@ -98,8 +98,7 @@ use std::fs::{self, File}; use std::io::{BufWriter, Write}; use std::path::{Path, PathBuf}; -use rustc_ast::expand::autodiff_attrs::AutoDiffItem; -use rustc_ast::expand::typetree::{Kind, Type, TypeTree}; +use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, AutoDiffAttrs}; use rustc_data_structures::fx::{FxHashMap, FxHashSet}; use rustc_data_structures::sync; 
use rustc_hir::def::DefKind; @@ -114,13 +113,13 @@ use rustc_middle::mir::mono::{ use rustc_middle::query::Providers; use rustc_middle::ty::print::{characteristic_def_id_of_type, with_no_trimmed_paths}; use rustc_middle::ty::{ - self, visit::TypeVisitableExt, Adt, InstanceDef, ParamEnv, ParamEnvAnd, Ty, TyCtxt, + self, visit::TypeVisitableExt, InstanceDef, ParamEnv, TyCtxt, + fnc_typetrees }; use rustc_session::config::{DumpMonoStatsFormat, SwitchWithOptPath}; use rustc_session::CodegenUnits; use rustc_span::symbol::Symbol; use rustc_symbol_mangling::symbol_name_for_instance_in_crate; -use rustc_target::abi::FieldsShape; use crate::collector::UsageMap; use crate::collector::{self, MonoItemCollectionMode}; @@ -1167,7 +1166,7 @@ fn collect_and_partition_mono_items( }) .filter_map(|(item, instance)| { let target_id = instance.def_id(); - let target_attrs = tcx.autodiff_attrs(target_id); + let target_attrs: &AutoDiffAttrs = tcx.autodiff_attrs(target_id); if target_attrs.is_source() { dbg!("source"); dbg!(&target_attrs); @@ -1193,7 +1192,9 @@ fn collect_and_partition_mono_items( source.map(|inst| { println!("source_id: {:?}", inst.def_id()); - let (inputs, output) = fnc_typetrees(inst.ty(tcx, ParamEnv::empty()), tcx); + let fnc_tree = fnc_typetrees(tcx, inst.ty(tcx, ParamEnv::empty())); + let (inputs, output) = (fnc_tree.args, fnc_tree.ret); + //check_types(inst.ty(tcx, ParamEnv::empty()), tcx, &target_attrs.input_activity); let symb = symbol_name_for_instance_in_crate(tcx, inst.clone(), LOCAL_CRATE); target_attrs.clone().into_item(symb, target_symbol, inputs, output) @@ -1271,143 +1272,69 @@ fn collect_and_partition_mono_items( (tcx.arena.alloc(mono_items), autodiff_items, codegen_units) } -pub fn typetree_empty() -> TypeTree { - TypeTree(vec![]) +#[allow(dead_code)] +enum Checks { + Slice(usize), + Enum(usize), } -pub fn typetree_from_ty<'a>(ty: Ty<'a>, tcx: TyCtxt<'a>, depth: usize) -> TypeTree { - if ty.is_unsafe_ptr() || ty.is_ref() || ty.is_box() { - if 
ty.is_fn_ptr() { - unimplemented!("what to do whith fn ptr?"); - } - let inner_ty = ty.builtin_deref(true).unwrap().ty; - let child = typetree_from_ty(inner_ty, tcx, depth + 1); - let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child }; - return TypeTree(vec![tt]); - } - - if ty.is_scalar() { - assert!(!ty.is_any_ptr()); - let (kind, size) = if ty.is_integral() { - (Kind::Integer, 8) - } else { - assert!(ty.is_floating_point()); - match ty { - x if x == tcx.types.f32 => (Kind::Float, 4), - x if x == tcx.types.f64 => (Kind::Double, 8), - _ => panic!("floatTy scalar that is neither f32 nor f64"), - } - }; - return TypeTree(vec![Type { offset: -1, child: typetree_empty(), kind, size }]); - } - - let param_env_and = ParamEnvAnd { param_env: ParamEnv::empty(), value: ty }; - - let layout = tcx.layout_of(param_env_and); - assert!(layout.is_ok()); - - let layout = layout.unwrap().layout; - let fields = layout.fields(); - let max_size = layout.size(); - - if ty.is_adt() { - let adt_def = ty.ty_adt_def().unwrap(); - let substs = match ty.kind() { - Adt(_, subst_ref) => subst_ref, - _ => panic!(""), - }; - - if adt_def.is_struct() { - let (offsets, _memory_index) = match fields { - FieldsShape::Arbitrary { offsets: o, memory_index: m } => (o, m), - _ => panic!(""), - }; - - let fields = adt_def.all_fields(); - let fields = fields - .into_iter() - .zip(offsets.into_iter()) - .filter_map(|(field, offset)| { - let field_ty: Ty<'_> = field.ty(tcx, substs); - let field_ty: Ty<'_> = - tcx.normalize_erasing_regions(ParamEnv::empty(), field_ty); - - if field_ty.is_phantom_data() { - return None; - } - - let mut child = typetree_from_ty(field_ty, tcx, depth + 1).0; - - for c in &mut child { - if c.offset == -1 { - c.offset = offset.bytes() as isize - } else { - c.offset += offset.bytes() as isize; - } - } - - Some(child) - }) - .flatten() - .collect::>(); - - let ret_tt = TypeTree(fields); - return ret_tt; - } else { - unimplemented!("adt that isn't a struct"); - } - } - - 
if ty.is_array() { - let (stride, count) = match fields { - FieldsShape::Array { stride: s, count: c } => (s, c), - _ => panic!(""), - }; - let byte_stride = stride.bytes_usize(); - let byte_max_size = max_size.bytes_usize(); - - assert!(byte_stride * *count as usize == byte_max_size); - assert!(*count > 0); // return empty TT for empty? - let sub_ty = ty.builtin_index().unwrap(); - let subtt = typetree_from_ty(sub_ty, tcx, depth + 1); - - // calculate size of subtree - let param_env_and = ParamEnvAnd { param_env: ParamEnv::empty(), value: sub_ty }; - let size = tcx.layout_of(param_env_and).unwrap().size.bytes() as usize; - let tt = TypeTree( - std::iter::repeat(subtt) - .take(*count as usize) - .enumerate() - .map(|(idx, x)| x.0.into_iter().map(move |x| x.add_offset((idx * size) as isize))) - .flatten() - .collect(), - ); - - return tt; - } - - if ty.is_slice() { - let sub_ty = ty.builtin_index().unwrap(); - let subtt = typetree_from_ty(sub_ty, tcx, depth + 1); - - return subtt; - } - - typetree_empty() -} - -pub fn fnc_typetrees<'tcx>(fn_ty: Ty<'tcx>, tcx: TyCtxt<'tcx>) -> (Vec, TypeTree) { - let fnc_binder: ty::Binder<'_, ty::FnSig<'_>> = fn_ty.fn_sig(tcx); - - // TODO: verify. - let x: ty::FnSig<'_> = fnc_binder.skip_binder(); - - let inputs = x.inputs().into_iter().map(|x| typetree_from_ty(*x, tcx, 0)).collect(); - - let output = typetree_from_ty(x.output(), tcx, 0); - - (inputs, output) -} +// ty is going to get duplicated. So we need to find DST +// inside to later make sure that its shadow is at least +// equally large. +//pub fn check_for_types<'a>(ty: Ty<'a>, tcx: TyCtxt<'a>, pos: usize) { +// // Find all DST inside of this type. 
+// let johnnie = ty.walk(); +// let mut check_positions = 0; +// let mut check_vec = Vec::new(); +// while let Some(ty) = johnnie.next() { +// if ty.is_trivially_sized(tcx) { +// ty.skip_current_subtree() +// } +// if ty.is_slice() { +// check_vec.push(Checks::Slice(check_positions)); +// } +// if ty.is_enum() { +// check_vec.push(Checks::Enum(check_positions)); +// } +// if ty.is_adt() { +// let adt_def = ty.ty_adt_def().unwrap(); +// if adt_def.is_struct() { +// let fields = adt_def.all_fields(); +// for field in fields { +// let field_ty: Ty<'_> = field.ty(tcx, ty.substs); +// if field_ty.is_phantom_data() { +// continue; +// } +// if field_ty.is_trivially_sized(tcx) { +// continue; +// } +// check_vec.push(Checks::Enum(check_positions)); +// } +// } +// } +// +// //ty::Str | ty::Slice(_) | ty::Dynamic(..) | ty::Foreign(..) => false, +// //ty::Tuple(tys) => tys.iter().all(|ty| ty.is_trivially_sized(tcx)), +// //ty::Adt(def, _args) => def.sized_constraint(tcx).skip_binder().is_empty(), +// check_positions += 1; +// } +//} +//pub fn check_types<'tcx>(fn_ty: Ty<'tcx>, tcx: TyCtxt<'tcx>, activities: &[DiffActivity]) { +// let fnc_binder: ty::Binder<'_, ty::FnSig<'_>> = fn_ty.fn_sig(tcx); +// +// // TODO: verify. +// let x: ty::FnSig<'_> = fnc_binder.skip_binder(); +// let _inputs = x.inputs(); +// for i in 0..activities.len() { +// match activities[i] { +// DiffActivity::Const => continue, +// DiffActivity::Active => continue, +// DiffActivity::ActiveOnly => continue, +// _ => {}, +// } +// check_for_types(inputs[i], tcx, i); +// } +//} /// Outputs stats about instantiation counts and estimated size, per `MonoItem`'s /// def, to a file in the given output directory. 
From 15f738971cf7b2403fcb4f55db3918ba07f91be9 Mon Sep 17 00:00:00 2001 From: Jed Brown Date: Thu, 28 Mar 2024 09:33:22 -0600 Subject: [PATCH 057/100] ci: ensure nightly and rustup are available --- .github/workflows/enzyme-ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/enzyme-ci.yml b/.github/workflows/enzyme-ci.yml index edc48879479e8..ad3cedb23e75d 100644 --- a/.github/workflows/enzyme-ci.yml +++ b/.github/workflows/enzyme-ci.yml @@ -22,9 +22,10 @@ jobs: timeout-minutes: 600 steps: - name: checkout the source code - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 2 + - uses: dtolnay/rust-toolchain@nightly - name: build run: | mkdir build @@ -32,7 +33,6 @@ jobs: ../configure --enable-llvm-link-shared --enable-llvm-plugins --enable-llvm-enzyme --release-channel=nightly --enable-llvm-assertions --enable-lld --enable-option-checking --enable-ninja --disable-docs ../x.py build --stage 1 library/std library/proc_macro library/test tools/rustdoc rustup toolchain link enzyme `pwd`/build/`rustup target list --installed`/stage1 - rustup toolchain install nightly # enables -Z unstable-options - name: test run: | cargo +enzyme test --examples From cf3516313d8f917ac4a16f1c3790d0b7a7c4ffe2 Mon Sep 17 00:00:00 2001 From: Jed Brown Date: Thu, 28 Mar 2024 18:04:11 -0600 Subject: [PATCH 058/100] ci: cache LLVM build and use rustbook for testing --- .github/workflows/enzyme-ci.yml | 27 ++++++++++++++++++++++----- README.md | 2 +- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/.github/workflows/enzyme-ci.yml b/.github/workflows/enzyme-ci.yml index ad3cedb23e75d..8e7e358bc16c5 100644 --- a/.github/workflows/enzyme-ci.yml +++ b/.github/workflows/enzyme-ci.yml @@ -21,18 +21,35 @@ jobs: timeout-minutes: 600 steps: - - name: checkout the source code + - name: Checkout Rust source uses: actions/checkout@v4 with: fetch-depth: 2 - uses: dtolnay/rust-toolchain@nightly - - name: build + - name: Get LLVM 
commit hash + id: llvm-commit + run: echo "HEAD=$(git -C src/llvm-project rev-parse HEAD)" >> $GITHUB_OUTPUT + - name: Cache LLVM + id: cache-llvm + uses: actions/cache@v4 + with: + path: build/build/x86_64-unknown-linux-gnu/llvm + key: ${{ matrix.os }}-llvm-${{ steps.llvm-commit.outputs.HEAD }} + - name: Build run: | - mkdir build + mkdir -p build cd build + rm -f config.toml ../configure --enable-llvm-link-shared --enable-llvm-plugins --enable-llvm-enzyme --release-channel=nightly --enable-llvm-assertions --enable-lld --enable-option-checking --enable-ninja --disable-docs ../x.py build --stage 1 library/std library/proc_macro library/test tools/rustdoc - rustup toolchain link enzyme `pwd`/build/`rustup target list --installed`/stage1 - - name: test + rustup toolchain link enzyme build/host/stage1 + - name: checkout Enzyme/rustbook + uses: actions/checkout@v4 + with: + repository: EnzymeAD/rustbook + ref: mdbook-test + path: rustbook + - name: test Enzyme/rustbook + working-directory: rustbook run: | cargo +enzyme test --examples diff --git a/README.md b/README.md index 425ed87f3718c..594c69d0453be 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ Afterwards you can build rustc using: Afterwards rustc toolchain link will allow you to use it through cargo: ``` -rustup toolchain link enzyme `pwd`/build/`rustup target list --installed`/stage1 +rustup toolchain link enzyme build/host/stage1 rustup toolchain install nightly # enables -Z unstable-options ``` From fceee05791e2b5f25e82dadb3ef7694179c488e7 Mon Sep 17 00:00:00 2001 From: Jed Brown Date: Thu, 28 Mar 2024 18:37:28 -0600 Subject: [PATCH 059/100] build: fix stamp logic for rebuilding enzyme --- src/bootstrap/src/core/build_steps/llvm.rs | 28 +++++++++++++++++++--- src/bootstrap/src/lib.rs | 3 +++ 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/src/bootstrap/src/core/build_steps/llvm.rs b/src/bootstrap/src/core/build_steps/llvm.rs index 4e18806493e3e..6736b985ebeba 100644 --- 
a/src/bootstrap/src/core/build_steps/llvm.rs +++ b/src/bootstrap/src/core/build_steps/llvm.rs @@ -844,13 +844,34 @@ impl Step for Enzyme { let LlvmResult { llvm_config, llvm_cmake_dir } = builder.ensure(Llvm { target }); + static STAMP_HASH_MEMO: OnceLock = OnceLock::new(); + let smart_stamp_hash = STAMP_HASH_MEMO.get_or_init(|| { + generate_smart_stamp_hash( + &builder.config.src.join("src/tools/enzyme"), + &builder.enzyme_info.sha().unwrap_or_default(), + ) + }); + let out_dir = builder.enzyme_out(target); - let done_stamp = out_dir.join("enzyme-finished-building"); - if done_stamp.exists() { + let stamp = out_dir.join("enzyme-finished-building"); + let stamp = HashStamp::new(stamp, Some(smart_stamp_hash)); + + if stamp.is_done() { + if stamp.hash.is_none() { + builder.info( + "Could not determine the Enzyme submodule commit hash. \ + Assuming that an Enzyme rebuild is not necessary.", + ); + builder.info(&format!( + "To force Enzyme to rebuild, remove the file `{}`", + stamp.path.display() + )); + } return out_dir; } builder.info(&format!("Building Enzyme for {}", target)); + t!(stamp.remove()); let _time = helpers::timeit(&builder); t!(fs::create_dir_all(&out_dir)); @@ -878,7 +899,8 @@ impl Step for Enzyme { cfg.build(); - t!(File::create(&done_stamp)); + t!(stamp.write()); + out_dir } } diff --git a/src/bootstrap/src/lib.rs b/src/bootstrap/src/lib.rs index 5038c888bca53..a733e55c8f9d6 100644 --- a/src/bootstrap/src/lib.rs +++ b/src/bootstrap/src/lib.rs @@ -168,6 +168,7 @@ pub struct Build { clippy_info: GitInfo, miri_info: GitInfo, rustfmt_info: GitInfo, + enzyme_info: GitInfo, in_tree_llvm_info: GitInfo, local_rebuild: bool, fail_fast: bool, @@ -331,6 +332,7 @@ impl Build { let clippy_info = GitInfo::new(omit_git_hash, &src.join("src/tools/clippy")); let miri_info = GitInfo::new(omit_git_hash, &src.join("src/tools/miri")); let rustfmt_info = GitInfo::new(omit_git_hash, &src.join("src/tools/rustfmt")); + let enzyme_info = GitInfo::new(omit_git_hash, 
&src.join("src/tools/enzyme")); // we always try to use git for LLVM builds let in_tree_llvm_info = GitInfo::new(false, &src.join("src/llvm-project")); @@ -413,6 +415,7 @@ impl Build { clippy_info, miri_info, rustfmt_info, + enzyme_info, in_tree_llvm_info, cc: RefCell::new(HashMap::new()), cxx: RefCell::new(HashMap::new()), From ea3d213777039dc4b72523081c410d8365f7181a Mon Sep 17 00:00:00 2001 From: Jed Brown Date: Thu, 28 Mar 2024 21:32:07 -0600 Subject: [PATCH 060/100] ci: rustbook tests workspace, not examples --- .github/workflows/enzyme-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/enzyme-ci.yml b/.github/workflows/enzyme-ci.yml index 8e7e358bc16c5..85ad9e62066fb 100644 --- a/.github/workflows/enzyme-ci.yml +++ b/.github/workflows/enzyme-ci.yml @@ -52,4 +52,4 @@ jobs: - name: test Enzyme/rustbook working-directory: rustbook run: | - cargo +enzyme test --examples + cargo +enzyme test --workspace From 518390f8e5023a4ec92913ff2a9afe49ffcecef8 Mon Sep 17 00:00:00 2001 From: Jed Brown Date: Thu, 28 Mar 2024 22:46:31 -0600 Subject: [PATCH 061/100] ci: cache Enzyme and use rustbook@main --- .github/workflows/enzyme-ci.yml | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/.github/workflows/enzyme-ci.yml b/.github/workflows/enzyme-ci.yml index 85ad9e62066fb..c233b049e1d6d 100644 --- a/.github/workflows/enzyme-ci.yml +++ b/.github/workflows/enzyme-ci.yml @@ -35,6 +35,15 @@ jobs: with: path: build/build/x86_64-unknown-linux-gnu/llvm key: ${{ matrix.os }}-llvm-${{ steps.llvm-commit.outputs.HEAD }} + - name: Get Enzyme commit hash + id: enzyme-commit + run: echo "HEAD=$(git -C src/tools/enzyme rev-parse HEAD)" >> $GITHUB_OUTPUT + - name: Cache Enzyme + id: cache-enzyme + uses: actions/cache@v4 + with: + path: build/build/x86_64-unknown-linux-gnu/enzyme + key: ${{ matrix.os }}-enzyme-${{ steps.enzyme-commit.outputs.HEAD }} - name: Build run: | mkdir -p build @@ -47,7 +56,7 @@ jobs: uses: 
actions/checkout@v4 with: repository: EnzymeAD/rustbook - ref: mdbook-test + ref: main path: rustbook - name: test Enzyme/rustbook working-directory: rustbook From e5f35f4b208a91da00010e28dd22ae124c8447e4 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Fri, 29 Mar 2024 14:32:39 -0400 Subject: [PATCH 062/100] add tt to mem calls --- compiler/rustc_codegen_llvm/src/abi.rs | 1 + compiler/rustc_codegen_llvm/src/base.rs | 13 +++- compiler/rustc_codegen_llvm/src/builder.rs | 64 +++++++++++++++++-- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 4 ++ compiler/rustc_codegen_ssa/src/base.rs | 16 ++++- .../rustc_codegen_ssa/src/mir/intrinsic.rs | 25 +++++++- compiler/rustc_codegen_ssa/src/mir/operand.rs | 2 +- compiler/rustc_codegen_ssa/src/mir/rvalue.rs | 6 +- .../rustc_codegen_ssa/src/mir/statement.rs | 2 +- .../rustc_codegen_ssa/src/traits/builder.rs | 5 ++ 10 files changed, 123 insertions(+), 15 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/abi.rs b/compiler/rustc_codegen_llvm/src/abi.rs index 97dc401251cf0..4f36f6c825373 100644 --- a/compiler/rustc_codegen_llvm/src/abi.rs +++ b/compiler/rustc_codegen_llvm/src/abi.rs @@ -248,6 +248,7 @@ impl<'ll, 'tcx> ArgAbiExt<'ll, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> { scratch_align, bx.const_usize(self.layout.size.bytes()), MemFlags::empty(), + None, ); bx.lifetime_end(llscratch, scratch_size); diff --git a/compiler/rustc_codegen_llvm/src/base.rs b/compiler/rustc_codegen_llvm/src/base.rs index 5dc271ccddb7c..60b33eb72db2e 100644 --- a/compiler/rustc_codegen_llvm/src/base.rs +++ b/compiler/rustc_codegen_llvm/src/base.rs @@ -27,11 +27,13 @@ use rustc_data_structures::small_c_str::SmallCStr; use rustc_middle::dep_graph; use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrs; use rustc_middle::mir::mono::{Linkage, Visibility}; -use rustc_middle::ty::TyCtxt; use rustc_session::config::DebugInfo; use rustc_span::symbol::Symbol; use rustc_target::spec::SanitizerSet; +use rustc_middle::mir::mono::MonoItem; +use 
rustc_middle::ty::{ParamEnv, TyCtxt, fnc_typetrees}; + use std::time::Instant; pub struct ValueIter<'ll> { @@ -86,6 +88,15 @@ pub fn compile_codegen_unit(tcx: TyCtxt<'_>, cgu_name: Symbol) -> (ModuleCodegen let mono_items = cx.codegen_unit.items_in_deterministic_order(cx.tcx); for &(mono_item, data) in &mono_items { mono_item.predefine::>(&cx, data.linkage, data.visibility); + let inst = match mono_item { + MonoItem::Fn(instance) => instance, + _ => continue, + }; + let fn_ty = inst.ty(tcx, ParamEnv::empty()); + let _fnc_tree = fnc_typetrees(tcx, fn_ty); + //trace!("codegen_module: predefine fn {}", inst); + //trace!("{} \n {:?} \n {:?}", inst, fn_ty, _fnc_tree); + // Manuel: TODO } // ... and now that we have everything pre-defined, fill out those definitions. diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index 8f60175a6031c..3a82a40fa1035 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -30,6 +30,9 @@ use std::iter; use std::ops::Deref; use std::ptr; +use crate::typetree::to_enzyme_typetree; +use rustc_ast::expand::typetree::{TypeTree, FncTree}; + // All Builders must have an llfn associated with them #[must_use] pub struct Builder<'a, 'll, 'tcx> { @@ -134,6 +137,35 @@ macro_rules! 
builder_methods_for_value_instructions { } } +fn add_tt<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context,val: &'ll Value, tt: FncTree) { + let inputs = tt.args; + let _ret: TypeTree = tt.ret; + let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) }; + let llvm_data_layout = + std::str::from_utf8(unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }.to_bytes()) + .expect("got a non-UTF8 data-layout from LLVM"); + let attr_name = "enzyme_type"; + let c_attr_name = std::ffi::CString::new(attr_name).unwrap(); + for (i, &ref input) in inputs.iter().enumerate() { + let c_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx); + let c_str = unsafe { llvm::EnzymeTypeTreeToString(c_tt.inner) }; + let c_str = unsafe { std::ffi::CStr::from_ptr(c_str) }; + unsafe { + let attr = llvm::LLVMCreateStringAttribute( + llcx, + c_attr_name.as_ptr(), + c_attr_name.as_bytes().len() as c_uint, + c_str.as_ptr(), + c_str.to_bytes().len() as c_uint, + ); + llvm::LLVMRustAddParamAttr(val, i as u32, attr); + } + unsafe { llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()) }; + } + dbg!(&val); +} + + impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { fn build(cx: &'a CodegenCx<'ll, 'tcx>, llbb: &'ll BasicBlock) -> Self { let bx = Builder::with_cx(cx); @@ -874,11 +906,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { src_align: Align, size: &'ll Value, flags: MemFlags, + tt: Option, ) { assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported"); let size = self.intcast(size, self.type_isize(), false); let is_volatile = flags.contains(MemFlags::VOLATILE); - unsafe { + let val = unsafe { llvm::LLVMRustBuildMemCpy( self.llbuilder, dst, @@ -887,7 +920,14 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { src_align.bytes() as c_uint, size, is_volatile, - ); + ) + }; + if let Some(tt) = tt { + let llmod = self.cx.llmod; + let llcx = self.cx.llcx; + 
add_tt(llmod, llcx, val, tt); + } else { + trace!("builder: no tt"); } } @@ -899,11 +939,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { src_align: Align, size: &'ll Value, flags: MemFlags, + tt: Option, ) { assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memmove not supported"); let size = self.intcast(size, self.type_isize(), false); let is_volatile = flags.contains(MemFlags::VOLATILE); - unsafe { + let val = unsafe { llvm::LLVMRustBuildMemMove( self.llbuilder, dst, @@ -912,7 +953,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { src_align.bytes() as c_uint, size, is_volatile, - ); + ) + }; + if let Some(tt) = tt { + let llmod = self.cx.llmod; + let llcx = self.cx.llcx; + add_tt(llmod, llcx, val, tt); } } @@ -923,9 +969,10 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { size: &'ll Value, align: Align, flags: MemFlags, + tt: Option, ) { let is_volatile = flags.contains(MemFlags::VOLATILE); - unsafe { + let val = unsafe { llvm::LLVMRustBuildMemSet( self.llbuilder, ptr, @@ -933,7 +980,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { fill_byte, size, is_volatile, - ); + ) + }; + if let Some(tt) = tt { + let llmod = self.cx.llmod; + let llcx = self.cx.llcx; + add_tt(llmod, llcx, val, tt); } } diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index bb2b01f09227e..b490dd40302df 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -945,6 +945,10 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff( // We don't support volatile / extern / (global?) values. // Just because I didn't had time to test them, and it seems less urgent. 
let args_uncacheable = vec![0; input_tts.len()]; + if args_uncacheable.len() != input_activity.len() { + dbg!("args_uncacheable.len(): {}", args_uncacheable.len()); + dbg!("input_activity.len(): {}", input_activity.len()); + } assert!(args_uncacheable.len() == input_activity.len()); let num_fnc_args = LLVMCountParams(fnc); println!("num_fnc_args: {}", num_fnc_args); diff --git a/compiler/rustc_codegen_ssa/src/base.rs b/compiler/rustc_codegen_ssa/src/base.rs index 94a8d2cc4d4c9..8fc840f48f5e1 100644 --- a/compiler/rustc_codegen_ssa/src/base.rs +++ b/compiler/rustc_codegen_ssa/src/base.rs @@ -44,6 +44,9 @@ use std::time::{Duration, Instant}; use itertools::Itertools; +use rustc_middle::ty::typetree_from; +use rustc_ast::expand::typetree::{TypeTree, FncTree}; + pub fn bin_op_to_icmp_predicate(op: hir::BinOpKind, signed: bool) -> IntPredicate { match op { hir::BinOpKind::Eq => IntPredicate::IntEQ, @@ -357,6 +360,7 @@ pub fn wants_new_eh_instructions(sess: &Session) -> bool { wants_wasm_eh(sess) || wants_msvc_seh(sess) } +// Manuel TODO pub fn memcpy_ty<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( bx: &mut Bx, dst: Bx::Value, @@ -370,6 +374,13 @@ pub fn memcpy_ty<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( if size == 0 { return; } + let my_ty = layout.ty; + let tcx: TyCtxt<'_> = bx.cx().tcx(); + let fnc_tree: TypeTree = typetree_from(tcx, my_ty); + let fnc_tree: FncTree = FncTree { + args: vec![fnc_tree.clone(), fnc_tree.clone()], + ret: TypeTree::new(), + }; if flags == MemFlags::empty() && let Some(bty) = bx.cx().scalar_copy_backend_type(layout) @@ -377,8 +388,11 @@ pub fn memcpy_ty<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( let temp = bx.load(bty, src, src_align); bx.store(temp, dst, dst_align); } else { - bx.memcpy(dst, dst_align, src, src_align, bx.cx().const_usize(size), flags); + trace!("my_ty: {:?}, enzyme tt: {:?}", my_ty, fnc_tree); + trace!("memcpy_ty: {:?} -> {:?} (size={}, align={:?})", src, dst, size, dst_align); + bx.memcpy(dst, dst_align, src, src_align, 
bx.cx().const_usize(size), flags, Some(fnc_tree)); } + //let (_args, _ret): (Vec, TypeTree) = (fnc_tree.args, fnc_tree.ret); } pub fn codegen_instance<'a, 'tcx: 'a, Bx: BuilderMethods<'a, 'tcx>>( diff --git a/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs b/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs index a5bffc33d393c..e8364ca0545cd 100644 --- a/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs +++ b/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs @@ -16,6 +16,10 @@ use rustc_target::abi::{ WrappingRange, }; +use rustc_middle::ty::typetree_from; +use rustc_ast::expand::typetree::{TypeTree, FncTree}; +use crate::rustc_middle::ty::layout::HasTyCtxt; + fn copy_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( bx: &mut Bx, allow_overlap: bool, @@ -25,15 +29,23 @@ fn copy_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( src: Bx::Value, count: Bx::Value, ) { + let tcx: TyCtxt<'_> = bx.cx().tcx(); + let fnc_tree: TypeTree = typetree_from(tcx, ty); + let fnc_tree: FncTree = FncTree { + args: vec![fnc_tree.clone(), fnc_tree.clone()], + ret: TypeTree::new(), + }; + let layout = bx.layout_of(ty); let size = layout.size; let align = layout.align.abi; let size = bx.mul(bx.const_usize(size.bytes()), count); let flags = if volatile { MemFlags::VOLATILE } else { MemFlags::empty() }; + trace!("copy: mir ty: {:?}, enzyme tt: {:?}", ty, fnc_tree); if allow_overlap { - bx.memmove(dst, align, src, align, size, flags); + bx.memmove(dst, align, src, align, size, flags, Some(fnc_tree)); } else { - bx.memcpy(dst, align, src, align, size, flags); + bx.memcpy(dst, align, src, align, size, flags, Some(fnc_tree)); } } @@ -45,12 +57,19 @@ fn memset_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( val: Bx::Value, count: Bx::Value, ) { + let tcx: TyCtxt<'_> = bx.cx().tcx(); + let fnc_tree: TypeTree = typetree_from(tcx, ty); + let fnc_tree: FncTree = FncTree { + args: vec![fnc_tree.clone(), fnc_tree.clone()], + ret: TypeTree::new(), + }; + let layout = bx.layout_of(ty); let size = 
layout.size; let align = layout.align.abi; let size = bx.mul(bx.const_usize(size.bytes()), count); let flags = if volatile { MemFlags::VOLATILE } else { MemFlags::empty() }; - bx.memset(dst, val, size, align, flags); + bx.memset(dst, val, size, align, flags, Some(fnc_tree)); } impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { diff --git a/compiler/rustc_codegen_ssa/src/mir/operand.rs b/compiler/rustc_codegen_ssa/src/mir/operand.rs index 794cbd315b795..7a6686de99a29 100644 --- a/compiler/rustc_codegen_ssa/src/mir/operand.rs +++ b/compiler/rustc_codegen_ssa/src/mir/operand.rs @@ -482,7 +482,7 @@ impl<'a, 'tcx, V: CodegenObject> OperandValue { let neg_address = bx.neg(address); let offset = bx.and(neg_address, align_minus_1); let dst = bx.inbounds_gep(bx.type_i8(), alloca, &[offset]); - bx.memcpy(dst, min_align, llptr, min_align, size, MemFlags::empty()); + bx.memcpy(dst, min_align, llptr, min_align, size, MemFlags::empty(), None); // Store the allocated region and the extra to the indirect place. 
let indirect_operand = OperandValue::Pair(dst, llextra); diff --git a/compiler/rustc_codegen_ssa/src/mir/rvalue.rs b/compiler/rustc_codegen_ssa/src/mir/rvalue.rs index 02b51dfe5bf7f..2ba03eb247d82 100644 --- a/compiler/rustc_codegen_ssa/src/mir/rvalue.rs +++ b/compiler/rustc_codegen_ssa/src/mir/rvalue.rs @@ -100,15 +100,17 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { // Use llvm.memset.p0i8.* to initialize all zero arrays if bx.cx().const_to_opt_u128(v, false) == Some(0) { + //let ty = bx.cx().val_ty(v); let fill = bx.cx().const_u8(0); - bx.memset(start, fill, size, dest.align, MemFlags::empty()); + bx.memset(start, fill, size, dest.align, MemFlags::empty(), None); return; } // Use llvm.memset.p0i8.* to initialize byte arrays let v = bx.from_immediate(v); if bx.cx().val_ty(v) == bx.cx().type_i8() { - bx.memset(start, v, size, dest.align, MemFlags::empty()); + //let ty = bx.cx().type_i8(); + bx.memset(start, v, size, dest.align, MemFlags::empty(), None); return; } } diff --git a/compiler/rustc_codegen_ssa/src/mir/statement.rs b/compiler/rustc_codegen_ssa/src/mir/statement.rs index a158fc6e26074..90cfd430f2dfd 100644 --- a/compiler/rustc_codegen_ssa/src/mir/statement.rs +++ b/compiler/rustc_codegen_ssa/src/mir/statement.rs @@ -85,7 +85,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { let align = pointee_layout.align; let dst = dst_val.immediate(); let src = src_val.immediate(); - bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty()); + bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty(), None); } mir::StatementKind::FakeRead(..) | mir::StatementKind::Retag { .. 
} diff --git a/compiler/rustc_codegen_ssa/src/traits/builder.rs b/compiler/rustc_codegen_ssa/src/traits/builder.rs index aa411f002a0c6..af2e85ab87caf 100644 --- a/compiler/rustc_codegen_ssa/src/traits/builder.rs +++ b/compiler/rustc_codegen_ssa/src/traits/builder.rs @@ -22,6 +22,8 @@ use rustc_target::abi::call::FnAbi; use rustc_target::abi::{Abi, Align, Scalar, Size, WrappingRange}; use rustc_target::spec::HasTargetSpec; +use rustc_ast::expand::typetree::FncTree; + #[derive(Copy, Clone)] pub enum OverflowOp { Add, @@ -238,6 +240,7 @@ pub trait BuilderMethods<'a, 'tcx>: src_align: Align, size: Self::Value, flags: MemFlags, + tt: Option, ); fn memmove( &mut self, @@ -247,6 +250,7 @@ pub trait BuilderMethods<'a, 'tcx>: src_align: Align, size: Self::Value, flags: MemFlags, + tt: Option, ); fn memset( &mut self, @@ -255,6 +259,7 @@ pub trait BuilderMethods<'a, 'tcx>: size: Self::Value, align: Align, flags: MemFlags, + tt: Option, ); fn select( From ccb9fab864674371b183c574f9d9b32166885336 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Fri, 29 Mar 2024 18:54:11 -0400 Subject: [PATCH 063/100] handle slice abi. Wohooo. 
Also other fixes --- compiler/rustc_codegen_llvm/src/back/write.rs | 56 +++++++++++++++++-- compiler/rustc_codegen_llvm/src/base.rs | 2 +- compiler/rustc_middle/src/ty/mod.rs | 52 ++++++++++++++--- .../rustc_monomorphize/src/partitioning.rs | 42 +++++++++----- 4 files changed, 124 insertions(+), 28 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 4f1b1fbe35845..d916cfb7d8291 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -703,10 +703,56 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool, let f_ty = LLVMRustGetFunctionType(src); let inner_param_num = LLVMCountParams(src); - let mut outer_args: Vec<&Value> = get_params(tgt); + let outer_param_num = LLVMCountParams(tgt); + let outer_args: Vec<&Value> = get_params(tgt); + let inner_args: Vec<&Value> = get_params(src); + let mut call_args: Vec<&Value> = vec![]; - if inner_param_num as usize != outer_args.len() { - panic!("Args len shouldn't differ. Please report this. {} : {}", inner_param_num, outer_args.len()); + if inner_param_num == outer_param_num { + call_args = outer_args; + } else { + dbg!("Different number of args, adjusting"); + let mut outer_pos: usize = 0; + let mut inner_pos: usize = 0; + // copy over if they are identical. + // If not, skip the outer arg (and assert it's int). 
+ while outer_pos < outer_param_num as usize { + let inner_arg = inner_args[inner_pos]; + let outer_arg = outer_args[outer_pos]; + let inner_arg_ty = llvm::LLVMTypeOf(inner_arg); + let outer_arg_ty = llvm::LLVMTypeOf(outer_arg); + if inner_arg_ty == outer_arg_ty { + call_args.push(outer_arg); + inner_pos += 1; + outer_pos += 1; + } else { + // out: (ptr, <>int1, ptr, int2) + // inner: (ptr, <>ptr, int) + // goal: (ptr, ptr, int1), skipping int2 + // we are here: <> + assert!(llvm::LLVMRustGetTypeKind(outer_arg_ty) == llvm::TypeKind::Integer); + assert!(llvm::LLVMRustGetTypeKind(inner_arg_ty) == llvm::TypeKind::Pointer); + let next_outer_arg = outer_args[outer_pos + 1]; + let next_inner_arg = inner_args[inner_pos + 1]; + let next_outer_arg_ty = llvm::LLVMTypeOf(next_outer_arg); + let next_inner_arg_ty = llvm::LLVMTypeOf(next_inner_arg); + assert!(llvm::LLVMRustGetTypeKind(next_outer_arg_ty) == llvm::TypeKind::Pointer); + assert!(llvm::LLVMRustGetTypeKind(next_inner_arg_ty) == llvm::TypeKind::Integer); + let next2_outer_arg = outer_args[outer_pos + 2]; + let next2_outer_arg_ty = llvm::LLVMTypeOf(next2_outer_arg); + assert!(llvm::LLVMRustGetTypeKind(next2_outer_arg_ty) == llvm::TypeKind::Integer); + call_args.push(next_outer_arg); + call_args.push(outer_arg); + + outer_pos += 3; + inner_pos += 2; + } + } + } + + + if inner_param_num as usize != call_args.len() { + panic!("Args len shouldn't differ. Please report this. 
{} : {}", inner_param_num, call_args.len()); } let inner_fnc_name = llvm::get_value_name(src); @@ -719,8 +765,8 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool, builder, f_ty, src, - outer_args.as_mut_ptr(), - outer_args.len(), + call_args.as_mut_ptr(), + call_args.len(), c_inner_fnc_name.as_ptr(), ); diff --git a/compiler/rustc_codegen_llvm/src/base.rs b/compiler/rustc_codegen_llvm/src/base.rs index 60b33eb72db2e..5d76510ab91ae 100644 --- a/compiler/rustc_codegen_llvm/src/base.rs +++ b/compiler/rustc_codegen_llvm/src/base.rs @@ -93,7 +93,7 @@ pub fn compile_codegen_unit(tcx: TyCtxt<'_>, cgu_name: Symbol) -> (ModuleCodegen _ => continue, }; let fn_ty = inst.ty(tcx, ParamEnv::empty()); - let _fnc_tree = fnc_typetrees(tcx, fn_ty); + let _fnc_tree = fnc_typetrees(tcx, fn_ty, &mut vec![]); //trace!("codegen_module: predefine fn {}", inst); //trace!("{} \n {:?} \n {:?}", inst, fn_ty, _fnc_tree); // Manuel: TODO diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index d16da3fc46dbc..1239048155c18 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -2727,23 +2727,59 @@ pub fn typetree_from<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { return TypeTree(vec![tt]); } -pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>) -> FncTree { +use rustc_ast::expand::autodiff_attrs::DiffActivity; + +pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec) -> FncTree { if !fn_ty.is_fn() { return FncTree { args: vec![], ret: TypeTree::new() }; } let fnc_binder: ty::Binder<'_, ty::FnSig<'_>> = fn_ty.fn_sig(tcx); - // TODO: verify. - let x: ty::FnSig<'_> = match fnc_binder.no_bound_vars() { - Some(x) => x, - None => return FncTree { args: vec![], ret: TypeTree::new() }, - }; + // TODO: cleanup + // Ok, sorry whoever reviews this. + // If we call this on arbitrary rust functions which we don't differentiate directly, + // then we have no da vec. 
We might encounter complex types, so do it properly. + // If we have a da vec, we know we are difrectly differentiating that fnc, + // so can assume it's something simpler, where skip_binder is ok (for now). + let x: ty::FnSig<'_>; + if da.is_empty() { + x = match fnc_binder.no_bound_vars() { + Some(x) => x, + None => return FncTree { args: vec![], ret: TypeTree::new() }, + } + } else { + x = fnc_binder.skip_binder(); + } + dbg!("creating fncTree"); + let mut offset = 0; let mut visited = vec![]; let mut args = vec![]; - for arg in x.inputs() { + for (i, ty) in x.inputs().iter().enumerate() { visited.clear(); - let arg_tt = typetree_from_ty(*arg, tcx, 0, false, &mut visited); + if ty.is_unsafe_ptr() || ty.is_ref() || ty.is_box() { + if ty.is_fn_ptr() { + unimplemented!("what to do whith fn ptr?"); + } + let inner_ty = ty.builtin_deref(true).unwrap().ty; + if inner_ty.is_slice() { + // We know that the lenght will be passed as extra arg. + let child = typetree_from_ty(inner_ty, tcx, 0, false, &mut visited); + let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child }; + args.push(TypeTree(vec![tt])); + let i64_tt = Type { offset: -1, kind: Kind::Integer, size: 8, child: TypeTree::new() }; + args.push(TypeTree(vec![i64_tt])); + if !da.is_empty() { + // We are looking at a slice. The length of that slice will become an + // extra integer on llvm level. Integers are always const. 
+ da.insert(i + 1 + offset, DiffActivity::Const); + offset += 1; + } + dbg!("ABI MATCHING\n\n\n"); + continue; + } + } + let arg_tt = typetree_from_ty(*ty, tcx, 0, false, &mut visited); args.push(arg_tt); } diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs index e15dbe9a7c9d6..6ab79c969291a 100644 --- a/compiler/rustc_monomorphize/src/partitioning.rs +++ b/compiler/rustc_monomorphize/src/partitioning.rs @@ -125,6 +125,9 @@ use crate::collector::UsageMap; use crate::collector::{self, MonoItemCollectionMode}; use crate::errors::{CouldntDumpMonoStats, SymbolAlreadyDefined, UnknownCguCollectionMode}; +use rustc_ast::expand::autodiff_attrs::DiffActivity; + + struct PartitioningCx<'a, 'tcx> { tcx: TyCtxt<'tcx>, usage_map: &'a UsageMap<'tcx>, @@ -1158,21 +1161,24 @@ fn collect_and_partition_mono_items( }) .collect(); - let autodiff_items = items + let autodiff_items2: Vec<_> = items .iter() .filter_map(|item| match *item { MonoItem::Fn(ref instance) => Some((item, instance)), _ => None, - }) - .filter_map(|(item, instance)| { + }).collect(); + let mut autodiff_items: Vec = vec![]; + + for (item, instance) in autodiff_items2 { let target_id = instance.def_id(); let target_attrs: &AutoDiffAttrs = tcx.autodiff_attrs(target_id); + let mut input_activities: Vec = target_attrs.input_activity.clone(); if target_attrs.is_source() { dbg!("source"); dbg!(&target_attrs); } if !target_attrs.apply_autodiff() { - return None; + continue; } let target_symbol = @@ -1189,17 +1195,25 @@ fn collect_and_partition_mono_items( } _ => None, }); + let inst = match source { + Some(source) => source, + None => continue, + }; - source.map(|inst| { - println!("source_id: {:?}", inst.def_id()); - let fnc_tree = fnc_typetrees(tcx, inst.ty(tcx, ParamEnv::empty())); - let (inputs, output) = (fnc_tree.args, fnc_tree.ret); - //check_types(inst.ty(tcx, ParamEnv::empty()), tcx, &target_attrs.input_activity); - let symb = 
symbol_name_for_instance_in_crate(tcx, inst.clone(), LOCAL_CRATE); - - target_attrs.clone().into_item(symb, target_symbol, inputs, output) - }) - }); + println!("source_id: {:?}", inst.def_id()); + let fn_ty = inst.ty(tcx, ParamEnv::empty()); + assert!(fn_ty.is_fn()); + let fnc_tree = fnc_typetrees(tcx, fn_ty, &mut input_activities); + let (inputs, output) = (fnc_tree.args, fnc_tree.ret); + //check_types(inst.ty(tcx, ParamEnv::empty()), tcx, &target_attrs.input_activity); + let symb = symbol_name_for_instance_in_crate(tcx, inst.clone(), LOCAL_CRATE); + + let mut new_target_attrs = target_attrs.clone(); + new_target_attrs.input_activity = input_activities; + let itm = new_target_attrs.into_item(symb, target_symbol, inputs, output); + dbg!(&itm); + autodiff_items.push(itm); + }; let autodiff_items = tcx.arena.alloc_from_iter(autodiff_items); From 7ac8be73aa56a13e091d0d6ed5308ac940cd40d8 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Fri, 29 Mar 2024 23:00:09 -0400 Subject: [PATCH 064/100] simplify binder handling (#83) * simplify binder handling --- compiler/rustc_middle/src/ty/mod.rs | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 1239048155c18..3cfb7550b5a62 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -2735,21 +2735,12 @@ pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec> = fn_ty.fn_sig(tcx); - // TODO: cleanup - // Ok, sorry whoever reviews this. - // If we call this on arbitrary rust functions which we don't differentiate directly, - // then we have no da vec. We might encounter complex types, so do it properly. - // If we have a da vec, we know we are difrectly differentiating that fnc, - // so can assume it's something simpler, where skip_binder is ok (for now). 
- let x: ty::FnSig<'_>; - if da.is_empty() { - x = match fnc_binder.no_bound_vars() { - Some(x) => x, - None => return FncTree { args: vec![], ret: TypeTree::new() }, - } - } else { - x = fnc_binder.skip_binder(); - } + // If rustc compiles the unmodified primal, we know that this copy of the function + // also has correct lifetimes. We know that Enzyme won't free the shadow too early + // (or actually at all), so let's strip lifetimes when computing the layout. + // Recommended by compiler-errors: + // https://discord.com/channels/273534239310479360/957720175619215380/1223454360676208751 + let x = tcx.instantiate_bound_regions_with_erased(fnc_binder); dbg!("creating fncTree"); let mut offset = 0; From 3e116e51bf75fe31d3cfcf40d1318f6e11604828 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sun, 31 Mar 2024 00:02:14 -0400 Subject: [PATCH 065/100] update enzyme (#85) --- src/tools/enzyme | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tools/enzyme b/src/tools/enzyme index 7e3e90f428706..328332fe1c3c9 160000 --- a/src/tools/enzyme +++ b/src/tools/enzyme @@ -1 +1 @@ -Subproject commit 7e3e90f4287068a41d3fb5a99127ee2857353b04 +Subproject commit 328332fe1c3c99a60445c13feee566c61b724999 From 1f65049c72dadad6872fd4e8e52a0c69f6dc54d2 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sun, 31 Mar 2024 01:08:56 -0400 Subject: [PATCH 066/100] Upgr enzyme2 (#87) * update enzyme --- src/tools/enzyme | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tools/enzyme b/src/tools/enzyme index 328332fe1c3c9..0bf10624447d0 160000 --- a/src/tools/enzyme +++ b/src/tools/enzyme @@ -1 +1 @@ -Subproject commit 328332fe1c3c99a60445c13feee566c61b724999 +Subproject commit 0bf10624447d0f68eb537cc111a13f48f464f2e1 From ca1aa97aaa5ba3c34a1cc32889513ebcf94073e1 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sun, 31 Mar 2024 11:25:11 -0400 Subject: [PATCH 067/100] No dbg call (#88) * move from dbg to trace --- 
.../rustc_ast/src/expand/autodiff_attrs.rs | 17 +++++++++---- compiler/rustc_codegen_llvm/src/back/write.rs | 20 ++++------------ compiler/rustc_codegen_llvm/src/builder.rs | 1 - compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 24 +++++++++++++------ compiler/rustc_middle/src/ty/mod.rs | 9 +++---- .../rustc_monomorphize/src/partitioning.rs | 8 +++---- compiler/rustc_passes/src/check_attr.rs | 1 - 7 files changed, 41 insertions(+), 39 deletions(-) diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs index 9cee364543af7..5855fc845520e 100644 --- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -1,6 +1,6 @@ use crate::expand::typetree::TypeTree; use std::str::FromStr; -use std::fmt::{Display, Formatter}; +use std::fmt::{self, Display, Formatter}; use crate::ptr::P; use crate::{Ty, TyKind}; @@ -14,7 +14,7 @@ pub enum DiffMode { } impl Display for DiffMode { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match self { DiffMode::Inactive => write!(f, "Inactive"), DiffMode::Source => write!(f, "Source"), @@ -209,7 +209,6 @@ impl AutoDiffAttrs { DiffMode::Inactive => false, DiffMode::Source => false, _ => { - dbg!(&self); true } } @@ -222,7 +221,6 @@ impl AutoDiffAttrs { inputs: Vec, output: TypeTree, ) -> AutoDiffItem { - dbg!(&self); AutoDiffItem { source, target, inputs, output, attrs: self } } } @@ -235,3 +233,14 @@ pub struct AutoDiffItem { pub inputs: Vec, pub output: TypeTree, } + +impl fmt::Display for AutoDiffItem { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Differentiating {} -> {}", self.source, self.target)?; + write!(f, " with attributes: {:?}", self.attrs)?; + write!(f, " with inputs: {:?}", self.inputs)?; + write!(f, " with output: {:?}", self.output) + } +} + + diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs 
b/compiler/rustc_codegen_llvm/src/back/write.rs index d916cfb7d8291..ec31d9dc0b7cd 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -692,7 +692,6 @@ fn get_params(fnc: &Value) -> Vec<&Value> { unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool, llmod: &'a llvm::Module, llcx: &llvm::Context) { - dbg!(&tgt); // first, remove all calls from fnc let bb = LLVMGetFirstBasicBlock(tgt); let br = LLVMRustGetTerminator(bb); @@ -711,7 +710,7 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool, if inner_param_num == outer_param_num { call_args = outer_args; } else { - dbg!("Different number of args, adjusting"); + trace!("Different number of args, adjusting"); let mut outer_pos: usize = 0; let mut inner_pos: usize = 0; // copy over if they are identical. @@ -784,16 +783,9 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool, let md_val = LLVMMetadataAsValue(llcx, md); let md2 = llvm::LLVMSetMetadata(struct_ret, md_ty, md_val); } else { - dbg!("No dbg info"); - dbg!(&inst); + trace!("No dbg info"); } - // Our placeholder originally ended with `loop {}`, and therefore got the noreturn fnc attr. - // This is not true anymore, so we remove it. - //LLVMRustRemoveFncAttr(tgt, AttributeKind::NoReturn); - - dbg!(&tgt); - // Now clean up placeholder code. 
LLVMRustEraseInstBefore(bb, last_inst); @@ -809,15 +801,11 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool, } } if f_return_type != void_type { - dbg!("Returning struct"); let _ret = LLVMBuildRet(builder, struct_ret); } else { - dbg!("Returning void"); let _ret = LLVMBuildRetVoid(builder); } LLVMDisposeBuilder(builder); - - dbg!(&tgt); let _fnc_ok = LLVMVerifyFunction(tgt, llvm::LLVMVerifierFailureAction::LLVMAbortProcessAction); } @@ -944,7 +932,9 @@ pub(crate) unsafe fn differentiate( _typetrees: FxHashMap, _config: &ModuleConfig, ) -> Result<(), FatalError> { - dbg!(&diff_items); + for item in &diff_items { + trace!("{}", item); + } let llmod = module.module_llvm.llmod(); let llcx = &module.module_llvm.llcx; diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index 3a82a40fa1035..f7afe9cbefb7a 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -162,7 +162,6 @@ fn add_tt<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context,val: &'ll Valu } unsafe { llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()) }; } - dbg!(&val); } diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index b490dd40302df..a3a89f7854f36 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -876,7 +876,6 @@ pub(crate) unsafe fn enzyme_rust_forward_diff( // We don't support volatile / extern / (global?) values. // Just because I didn't had time to test them, and it seems less urgent. 
- dbg!(&fnc); let args_uncacheable = vec![0; input_tts.len()]; assert!(args_uncacheable.len() == input_activity.len()); let num_fnc_args = LLVMCountParams(fnc); @@ -938,8 +937,6 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff( let input_activity: Vec = input_activity.iter().map(|&x| cdiffe_from(x)).collect(); - dbg!(&fnc); - let mut args_tree = input_tts.iter().map(|x| x.inner).collect::>(); // We don't support volatile / extern / (global?) values. @@ -964,9 +961,11 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff( KnownValues: known_values.as_mut_ptr(), }; - dbg!(&primary_ret); - dbg!(&ret_activity); - dbg!(&input_activity); + trace!("{}", &primary_ret); + trace!("{}", &ret_activity); + for i in &input_activity { + trace!("{}", &i); + } let res = EnzymeCreatePrimalAndGradient( logic_ref, // Logic std::ptr::null(), @@ -989,7 +988,6 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff( std::ptr::null_mut(), // write augmented function to this 0, ); - dbg!(&res); res } @@ -2790,6 +2788,18 @@ pub mod Shared_AD { DFT_DUP_NONEED = 3, } + impl fmt::Display for CDIFFE_TYPE { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let value = match self { + CDIFFE_TYPE::DFT_OUT_DIFF => "DFT_OUT_DIFF", + CDIFFE_TYPE::DFT_DUP_ARG => "DFT_DUP_ARG", + CDIFFE_TYPE::DFT_CONSTANT => "DFT_CONSTANT", + CDIFFE_TYPE::DFT_DUP_NONEED => "DFT_DUP_NONEED", + }; + write!(f, "{}", value) + } + } + pub fn cdiffe_from(act: DiffActivity) -> CDIFFE_TYPE { return match act { DiffActivity::None => CDIFFE_TYPE::DFT_CONSTANT, diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 3cfb7550b5a62..91f9b6846cff7 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -2741,7 +2741,6 @@ pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec(ty: Ty<'a>, tcx: TyCtxt<'a>, depth: usize, safety: bool, visited: 
&mut Vec>) -> TypeTree { if depth > 20 { - dbg!(&ty); + trace!("depth > 20 for ty: {}", &ty); } if visited.contains(&ty) { // recursive type - dbg!("recursive type"); - dbg!(&ty); + trace!("recursive type: {}", &ty); return TypeTree::new(); } visited.push(ty); @@ -2847,7 +2845,6 @@ pub fn typetree_from_ty<'a>(ty: Ty<'a>, tcx: TyCtxt<'a>, depth: usize, safety: b FieldsShape::Array { .. } => {return TypeTree::new();}, //e.g. core::arch::x86_64::__m128i, TODO: later FieldsShape::Union(_) => {return TypeTree::new();}, FieldsShape::Primitive => {return TypeTree::new();}, - //_ => {dbg!(&adt_def); panic!("")}, }; let substs = match ty.kind() { diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs index 6ab79c969291a..0c3d8a64ea317 100644 --- a/compiler/rustc_monomorphize/src/partitioning.rs +++ b/compiler/rustc_monomorphize/src/partitioning.rs @@ -1174,8 +1174,7 @@ fn collect_and_partition_mono_items( let target_attrs: &AutoDiffAttrs = tcx.autodiff_attrs(target_id); let mut input_activities: Vec = target_attrs.input_activity.clone(); if target_attrs.is_source() { - dbg!("source"); - dbg!(&target_attrs); + trace!("source found: {:?}", target_id); } if !target_attrs.apply_autodiff() { continue; @@ -1211,7 +1210,6 @@ fn collect_and_partition_mono_items( let mut new_target_attrs = target_attrs.clone(); new_target_attrs.input_activity = input_activities; let itm = new_target_attrs.into_item(symb, target_symbol, inputs, output); - dbg!(&itm); autodiff_items.push(itm); }; @@ -1277,9 +1275,9 @@ fn collect_and_partition_mono_items( } } if autodiff_items.len() > 0 { - println!("AUTODIFF ITEMS EXIST"); + trace!("AUTODIFF ITEMS EXIST"); for item in &mut *autodiff_items { - dbg!(&item); + trace!("{}", &item); } } diff --git a/compiler/rustc_passes/src/check_attr.rs b/compiler/rustc_passes/src/check_attr.rs index 5e6c3ca9963d4..64243a4c438eb 100644 --- a/compiler/rustc_passes/src/check_attr.rs +++ 
b/compiler/rustc_passes/src/check_attr.rs @@ -2386,7 +2386,6 @@ impl CheckAttrVisitor<'_> { /// Checks if `#[autodiff]` is applied to an item other than a function item. fn check_autodiff(&self, _hir_id: HirId, _attr: &Attribute, span: Span, target: Target) { - dbg!("check_autodiff"); match target { Target::Fn => {} _ => { From 3a582e4b87888a8b6b336406327fae9da0583020 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sun, 31 Mar 2024 12:14:06 -0400 Subject: [PATCH 068/100] Update enzyme (#89) * updating enzyme * unrelatedly remove last dbg --- compiler/rustc_builtin_macros/src/autodiff.rs | 1 - src/tools/enzyme | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 674f20577934c..66dd7f746b789 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -124,7 +124,6 @@ pub fn expand( // This allows us to potentially parse other attributes. 
return vec![item]; } - dbg!(&x); let span = ecx.with_def_site_ctxt(expand_span); let n_active: u32 = x.input_activity.iter() diff --git a/src/tools/enzyme b/src/tools/enzyme index 0bf10624447d0..cad0ac27dfcfc 160000 --- a/src/tools/enzyme +++ b/src/tools/enzyme @@ -1 +1 @@ -Subproject commit 0bf10624447d0f68eb537cc111a13f48f464f2e1 +Subproject commit cad0ac27dfcfc5be075a2cdc8b751018f25197af From 415c7146f1a5b2243e139b585a362e37e7da6d51 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sun, 31 Mar 2024 12:31:47 -0400 Subject: [PATCH 069/100] make post opts optional (#90) --- compiler/rustc_codegen_llvm/src/back/write.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index ec31d9dc0b7cd..68f6b4c4aa4a8 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -872,7 +872,10 @@ pub(crate) unsafe fn enzyme_ad( item.inputs.into_iter().map(|x| to_enzyme_typetree(x, llvm_data_layout, llcx)).collect(); let output_tt = to_enzyme_typetree(item.output, llvm_data_layout, llcx); - let opt = 1; + let mut opt = 1; + if std::env::var("ENZYME_DISABLE_OPTS").is_ok() { + opt = 0; + } let logic_ref: EnzymeLogicRef = CreateEnzymeLogic(opt as u8); let type_analysis: EnzymeTypeAnalysisRef = CreateTypeAnalysis(logic_ref, std::ptr::null_mut(), std::ptr::null_mut(), 0); From fca0b585ffaa8f4a5d1a464cb0a7215428fb4fd7 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sun, 31 Mar 2024 15:49:57 -0400 Subject: [PATCH 070/100] prevent ice by first checking that macro is applicable (#91) --- compiler/rustc_builtin_macros/src/autodiff.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 66dd7f746b789..b901b89cb41ff 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ 
b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -88,9 +88,6 @@ pub fn expand( return vec![item]; } }; - let mut orig_item: P = item.clone().expect_item(); - let primal = orig_item.ident.clone(); - // Allow using `#[autodiff(...)]` only on a Fn let (has_ret, sig, sig_span) = if let Annotatable::Item(item) = &item && let ItemKind::Fn(box ast::Fn { sig, .. }) = &item.kind @@ -100,6 +97,11 @@ pub fn expand( ecx.sess.dcx().emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); return vec![item]; }; + + // Now we know that item is a Item::Fn + let mut orig_item: P = item.clone().expect_item(); + let primal = orig_item.ident.clone(); + // create TokenStream from vec elemtents: // meta_item doesn't have a .tokens field let comma: Token = Token::new(TokenKind::Comma, Span::default()); From e0e8180fae84d7ab1f9076b062086a706f71fcc7 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sun, 31 Mar 2024 19:45:11 -0400 Subject: [PATCH 071/100] catch wrong spelled activities (#92) --- compiler/rustc_builtin_macros/messages.ftl | 1 + compiler/rustc_builtin_macros/src/autodiff.rs | 19 ++++++++++++------- compiler/rustc_builtin_macros/src/errors.rs | 7 +++++++ 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/compiler/rustc_builtin_macros/messages.ftl b/compiler/rustc_builtin_macros/messages.ftl index 6f67f9effe090..5a261700dfe90 100644 --- a/compiler/rustc_builtin_macros/messages.ftl +++ b/compiler/rustc_builtin_macros/messages.ftl @@ -1,6 +1,7 @@ builtin_macros_alloc_error_must_be_fn = alloc_error_handler must be a function builtin_macros_alloc_must_statics = allocators must be statics +builtin_macros_autodiff_unknown_activity = did not recognize activity {$act} builtin_macros_autodiff = autodiff must be applied to function builtin_macros_autodiff_not_build = this rustc version does not support autodiff builtin_macros_autodiff_mode_activity = {$act} can not be used in {$mode} Mode diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs 
b/compiler/rustc_builtin_macros/src/autodiff.rs index b901b89cb41ff..49f1b743d1824 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -53,13 +53,18 @@ pub fn from_ast(ecx: &mut ExtCtxt<'_>, meta_item: &ThinVec, has_ return AutoDiffAttrs::inactive(); }, }; - let activities: Vec = meta_item[2..] - .iter() - .map(|x| { - let activity_str = name(&x); - DiffActivity::from_str(&activity_str).unwrap() - }) - .collect(); + let mut activities: Vec = vec![]; + for x in &meta_item[2..] { + let activity_str = name(&x); + let res = DiffActivity::from_str(&activity_str); + match res { + Ok(x) => activities.push(x), + Err(_) => { + ecx.sess.dcx().emit_err(errors::AutoDiffUnknownActivity { span: x.span(), act: activity_str }); + return AutoDiffAttrs::inactive(); + } + }; + } // If a return type exist, we need to split the last activity, // otherwise we return None as placeholder. diff --git a/compiler/rustc_builtin_macros/src/errors.rs b/compiler/rustc_builtin_macros/src/errors.rs index 763372f569f91..03321e889749a 100644 --- a/compiler/rustc_builtin_macros/src/errors.rs +++ b/compiler/rustc_builtin_macros/src/errors.rs @@ -164,6 +164,13 @@ pub(crate) struct AllocMustStatics { pub(crate) span: Span, } #[derive(Diagnostic)] +#[diag(builtin_macros_autodiff_unknown_activity)] +pub(crate) struct AutoDiffUnknownActivity { + #[primary_span] + pub(crate) span: Span, + pub(crate) act: String, +} +#[derive(Diagnostic)] #[diag(builtin_macros_autodiff_ty_activity)] pub(crate) struct AutoDiffInvalidTypeForActivity { #[primary_span] From eafa14285bf2ce588880b0f6d7448b588a534e43 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Mon, 1 Apr 2024 02:36:12 -0400 Subject: [PATCH 072/100] safety check for slices! 
(#93) --- .../rustc_ast/src/expand/autodiff_attrs.rs | 2 + compiler/rustc_codegen_llvm/src/back/write.rs | 132 +++++++++++++++++- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 29 +++- compiler/rustc_middle/src/ty/mod.rs | 11 +- 4 files changed, 159 insertions(+), 15 deletions(-) diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs index 5855fc845520e..4da71ad6650a9 100644 --- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -109,6 +109,7 @@ pub enum DiffActivity { DualOnly, Duplicated, DuplicatedOnly, + FakeActivitySize } impl Display for DiffActivity { @@ -122,6 +123,7 @@ impl Display for DiffActivity { DiffActivity::DualOnly => write!(f, "DualOnly"), DiffActivity::Duplicated => write!(f, "Duplicated"), DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"), + DiffActivity::FakeActivitySize => write!(f, "FakeActivitySize"), } } } diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 68f6b4c4aa4a8..2bbe1f03f91d8 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -1,6 +1,8 @@ #![allow(unused_imports)] #![allow(unused_variables)] use crate::llvm::LLVMGetFirstBasicBlock; +use crate::llvm::LLVMBuildCondBr; +use crate::llvm::LLVMBuildICmp; use crate::llvm::LLVMBuildRetVoid; use crate::llvm::LLVMRustEraseInstBefore; use crate::llvm::LLVMRustHasDbgMetadata; @@ -47,6 +49,7 @@ use crate::typetree::to_enzyme_typetree; use crate::DiffTypeTree; use crate::LlvmCodegenBackend; use crate::ModuleLlvm; +use llvm::IntPredicate; use llvm::LLVMRustDISetInstMetadata; use llvm::{ LLVMRustLLVMHasZlibCompressionForDebugSymbols, LLVMRustLLVMHasZstdCompressionForDebugSymbols, LLVMGetNextBasicBlock, @@ -691,7 +694,8 @@ fn get_params(fnc: &Value) -> Vec<&Value> { unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool, - llmod: &'a 
llvm::Module, llcx: &llvm::Context) { + llmod: &'a llvm::Module, llcx: &llvm::Context, size_positions: &[usize]) { + dbg!("size_positions: {:?}", size_positions); // first, remove all calls from fnc let bb = LLVMGetFirstBasicBlock(tgt); let br = LLVMRustGetTerminator(bb); @@ -707,6 +711,11 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool, let inner_args: Vec<&Value> = get_params(src); let mut call_args: Vec<&Value> = vec![]; + let mut safety_vals = vec![]; + let builder = LLVMCreateBuilderInContext(llcx); + let last_inst = LLVMRustGetLastInstruction(bb).unwrap(); + LLVMPositionBuilderAtEnd(builder, bb); + if inner_param_num == outer_param_num { call_args = outer_args; } else { @@ -745,21 +754,71 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool, outer_pos += 3; inner_pos += 2; + + // Now we assert if int1 <= int2 + let res = LLVMBuildICmp( + builder, + IntPredicate::IntULE as u32, + outer_arg, + next2_outer_arg, + "safety_check".as_ptr() as *const c_char); + safety_vals.push(res); } } } - if inner_param_num as usize != call_args.len() { panic!("Args len shouldn't differ. Please report this. {} : {}", inner_param_num, call_args.len()); } + // Now add the safety checks. + if !safety_vals.is_empty() { + dbg!("Adding safety checks"); + // first we create one bb per check and two more for the fail and success case. 
+ let fail_bb = LLVMAppendBasicBlockInContext(llcx, tgt, "ad_safety_fail".as_ptr() as *const c_char); + let success_bb = LLVMAppendBasicBlockInContext(llcx, tgt, "ad_safety_success".as_ptr() as *const c_char); + let mut err_bb = vec![]; + for i in 0..safety_vals.len() { + let name: String = format!("ad_safety_err_{}", i); + err_bb.push(LLVMAppendBasicBlockInContext(llcx, tgt, name.as_ptr() as *const c_char)); + } + for (i, &val) in safety_vals.iter().enumerate() { + LLVMBuildCondBr(builder, val, err_bb[i], fail_bb); + LLVMPositionBuilderAtEnd(builder, err_bb[i]); + } + LLVMBuildCondBr(builder, safety_vals.last().unwrap(), success_bb, fail_bb); + LLVMPositionBuilderAtEnd(builder, fail_bb); + + + + let mut arg_vec = vec![add_panic_msg_to_global(llmod, llcx)]; + let name1 = "_ZN4core9panicking14panic_explicit17h8607a79b2acfb83bE"; + let name2 = "_RN4core9panicking14panic_explicit17h8607a79b2acfb83bE"; + let cname1 = CString::new(name1).unwrap(); + let cname2 = CString::new(name2).unwrap(); + + let fnc1 = llvm::LLVMGetNamedFunction(llmod, cname1.as_ptr() as *const c_char); + let call; + if fnc1.is_none() { + let fnc2 = llvm::LLVMGetNamedFunction(llmod, cname1.as_ptr() as *const c_char); + assert!(fnc2.is_some()); + let fnc2 = fnc2.unwrap(); + let ty = LLVMRustGetFunctionType(fnc2); + // now call with msg + call = LLVMBuildCall2(builder, ty, fnc2, arg_vec.as_mut_ptr(), arg_vec.len(), name2.as_ptr() as *const c_char); + } else { + let fnc1 = fnc1.unwrap(); + let ty = LLVMRustGetFunctionType(fnc1); + call = LLVMBuildCall2(builder, ty, fnc1, arg_vec.as_mut_ptr(), arg_vec.len(), name1.as_ptr() as *const c_char); + } + llvm::LLVMSetTailCall(call, 1); + llvm::LLVMBuildUnreachable(builder); + LLVMPositionBuilderAtEnd(builder, success_bb); + } + let inner_fnc_name = llvm::get_value_name(src); let c_inner_fnc_name = CString::new(inner_fnc_name).unwrap(); - let builder = LLVMCreateBuilderInContext(llcx); - let last_inst = LLVMRustGetLastInstruction(bb).unwrap(); - 
LLVMPositionBuilderAtEnd(builder, bb); let mut struct_ret = LLVMBuildCall2( builder, f_ty, @@ -788,6 +847,7 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool, // Now clean up placeholder code. LLVMRustEraseInstBefore(bb, last_inst); + //dbg!(&tgt); let f_return_type = LLVMGetReturnType(LLVMGlobalGetValueType(src)); let void_type = LLVMVoidTypeInContext(llcx); @@ -810,6 +870,61 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool, LLVMVerifyFunction(tgt, llvm::LLVMVerifierFailureAction::LLVMAbortProcessAction); } +unsafe fn add_panic_msg_to_global<'a>(llmod: &'a llvm::Module, llcx: &'a llvm::Context) -> &'a llvm::Value { + use llvm::*; + + // Convert the message to a CString + let msg = "autodiff safety check failed!"; + let cmsg = CString::new(msg).unwrap(); + + let msg_global_name = "ad_safety_msg".to_string(); + let cmsg_global_name = CString::new(msg_global_name).unwrap(); + + // Get the length of the message + let msg_len = msg.len(); + + // Create the array type + let i8_array_type = LLVMRustArrayType(LLVMInt8TypeInContext(llcx), msg_len as u64); + + // Create the string constant + let string_const_val = LLVMConstStringInContext(llcx, cmsg.as_ptr() as *const i8, msg_len as u32, 0); + + // Create the array initializer + let mut array_elems: Vec<_> = Vec::with_capacity(msg_len); + for i in 0..msg_len { + let char_value = LLVMConstInt(LLVMInt8TypeInContext(llcx), cmsg.as_bytes()[i] as u64, 0); + array_elems.push(char_value); + } + let array_initializer = LLVMConstArray(LLVMInt8TypeInContext(llcx), array_elems.as_mut_ptr(), msg_len as u32); + + // Create the struct type + let global_type = LLVMStructTypeInContext(llcx, [i8_array_type].as_mut_ptr(), 1, 0); + + // Create the struct initializer + let struct_initializer = LLVMConstStructInContext(llcx, [array_initializer].as_mut_ptr(), 1, 0); + + // Add the global variable to the module + let global_var = LLVMAddGlobal(llmod, global_type, cmsg_global_name.as_ptr() as 
*const i8); + LLVMRustSetLinkage(global_var, Linkage::PrivateLinkage); + LLVMSetInitializer(global_var, struct_initializer); + + //let msg_global_name = "ad_safety_msg".to_string(); + //let cmsg_global_name = CString::new(msg_global_name).unwrap(); + //let msg = "autodiff safety check failed!"; + //let cmsg = CString::new(msg).unwrap(); + //let msg_len = msg.len(); + //let i8_array_type = llvm::LLVMRustArrayType(llvm::LLVMInt8TypeInContext(llcx), msg_len as u64); + //let global_type = llvm::LLVMStructTypeInContext(llcx, [i8_array_type].as_mut_ptr(), 1, 0); + //let string_const_val = llvm::LLVMConstStringInContext(llcx, cmsg.as_ptr() as *const c_char, msg_len as u32, 0); + //let initializer = llvm::LLVMConstStructInContext(llcx, [string_const_val].as_mut_ptr(), 1, 0); + //let global = llvm::LLVMAddGlobal(llmod, global_type, cmsg_global_name.as_ptr() as *const c_char); + //llvm::LLVMRustSetLinkage(global, llvm::Linkage::PrivateLinkage); + //llvm::LLVMSetInitializer(global, initializer); + //llvm::LLVMSetUnnamedAddress(global, llvm::UnnamedAddr::Global); + + global_var +} + // As unsafe as it can be. #[allow(unused_variables)] #[allow(unused)] @@ -895,7 +1010,7 @@ pub(crate) unsafe fn enzyme_ad( llvm::set_print(true); } - let mut res: &Value = match item.attrs.mode { + let mut tmp = match item.attrs.mode { DiffMode::Forward => enzyme_rust_forward_diff( logic_ref, type_analysis, @@ -916,11 +1031,14 @@ pub(crate) unsafe fn enzyme_ad( ), _ => unreachable!(), }; + let mut res: &Value = tmp.0; + let size_positions: Vec = tmp.1; + let f_return_type = LLVMGetReturnType(LLVMGlobalGetValueType(res)); let void_type = LLVMVoidTypeInContext(llcx); let rev_mode = item.attrs.mode == DiffMode::Reverse; - create_call(target_fnc, res, rev_mode, llmod, llcx); + create_call(target_fnc, res, rev_mode, llmod, llcx, &size_positions); // TODO: implement drop for wrapper type? 
FreeTypeAnalysis(type_analysis); diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index a3a89f7854f36..b86cdeb09bdcb 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -850,7 +850,7 @@ pub(crate) unsafe fn enzyme_rust_forward_diff( ret_diffactivity: DiffActivity, input_tts: Vec, output_tt: TypeTree, -) -> &Value { +) -> (&Value, Vec) { let ret_activity = cdiffe_from(ret_diffactivity); assert!(ret_activity != CDIFFE_TYPE::DFT_OUT_DIFF); let mut input_activity: Vec = vec![]; @@ -893,7 +893,7 @@ pub(crate) unsafe fn enzyme_rust_forward_diff( KnownValues: known_values.as_mut_ptr(), }; - EnzymeCreateForwardDiff( + let res = EnzymeCreateForwardDiff( logic_ref, // Logic std::ptr::null(), std::ptr::null(), @@ -911,18 +911,19 @@ pub(crate) unsafe fn enzyme_rust_forward_diff( args_uncacheable.as_ptr(), args_uncacheable.len(), // uncacheable arguments std::ptr::null_mut(), // write augmented function to this - ) + ); + (res, vec![]) } pub(crate) unsafe fn enzyme_rust_reverse_diff( logic_ref: EnzymeLogicRef, type_analysis: EnzymeTypeAnalysisRef, fnc: &Value, - input_activity: Vec, + rust_input_activity: Vec, ret_activity: DiffActivity, input_tts: Vec, output_tt: TypeTree, -) -> &Value { +) -> (&Value, Vec) { let (primary_ret, ret_activity) = match ret_activity { DiffActivity::Const => (true, CDIFFE_TYPE::DFT_CONSTANT), DiffActivity::Active => (true, CDIFFE_TYPE::DFT_OUT_DIFF), @@ -935,7 +936,16 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff( // https://github.com/EnzymeAD/Enzyme.jl/blob/a511e4e6979d6161699f5c9919d49801c0764a09/src/compiler.jl#L3092 let diff_ret = false; - let input_activity: Vec = input_activity.iter().map(|&x| cdiffe_from(x)).collect(); + let mut primal_sizes = vec![]; + let mut input_activity: Vec = vec![]; + for (i, &x) in rust_input_activity.iter().enumerate() { + if is_size(x) { + primal_sizes.push(i); + 
input_activity.push(CDIFFE_TYPE::DFT_CONSTANT); + continue; + } + input_activity.push(cdiffe_from(x)); + } let mut args_tree = input_tts.iter().map(|x| x.inner).collect::>(); @@ -988,7 +998,7 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff( std::ptr::null_mut(), // write augmented function to this 0, ); - res + (res, primal_sizes) } extern "C" { @@ -2810,9 +2820,14 @@ pub mod Shared_AD { DiffActivity::DualOnly => CDIFFE_TYPE::DFT_DUP_NONEED, DiffActivity::Duplicated => CDIFFE_TYPE::DFT_DUP_ARG, DiffActivity::DuplicatedOnly => CDIFFE_TYPE::DFT_DUP_NONEED, + DiffActivity::FakeActivitySize => panic!("Implementation error"), }; } + pub fn is_size(act: DiffActivity) -> bool { + return act == DiffActivity::FakeActivitySize; + } + #[repr(u32)] #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] pub enum CDerivativeMode { diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 91f9b6846cff7..0e63a2dcf48e3 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -2762,7 +2762,16 @@ pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec DiffActivity::FakeActivitySize, + DiffActivity::Const => DiffActivity::Const, + _ => panic!("unexpected activity for ptr/ref"), + }; + da.insert(i + 1 + offset, activity); offset += 1; } trace!("ABI MATCHING!"); From b288111287eafbfb5e16dccfef3deb9910d9344f Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Mon, 1 Apr 2024 14:59:15 -0400 Subject: [PATCH 073/100] update enzyme (#94) --- src/tools/enzyme | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tools/enzyme b/src/tools/enzyme index cad0ac27dfcfc..b1f676c07cd45 160000 --- a/src/tools/enzyme +++ b/src/tools/enzyme @@ -1 +1 @@ -Subproject commit cad0ac27dfcfc5be075a2cdc8b751018f25197af +Subproject commit b1f676c07cd45fa3fe711b7bd2297178d55061d2 From 672f84ad659cb70503d8d5550d90e356aa2de62a Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Mon, 1 Apr 2024 18:59:33 -0400 
Subject: [PATCH 074/100] default to C++-Enzyme opt pipeline (#95) --- compiler/rustc_codegen_llvm/src/back/lto.rs | 4 +- compiler/rustc_codegen_llvm/src/back/write.rs | 56 ++++++++++++++----- 2 files changed, 46 insertions(+), 14 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index cf5badf7b99c1..f28af4ba49bad 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -616,7 +616,9 @@ pub(crate) fn run_pass_manager( } let opt_stage = if thin { llvm::OptStage::ThinLTO } else { llvm::OptStage::FatLTO }; let opt_level = config.opt_level.unwrap_or(config::OptLevel::No); - write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage)?; + // We will run this again with different values in the context of automatic differentiation. + let first_run = true; + write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run)?; } debug!("lto done"); Ok(()) diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 2bbe1f03f91d8..e09aec69ec44c 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -549,6 +549,7 @@ pub(crate) unsafe fn llvm_optimize( config: &ModuleConfig, opt_level: config::OptLevel, opt_stage: llvm::OptStage, + first_run: bool, ) -> Result<(), FatalError> { // Enzyme: // We want to simplify / optimize functions before AD. @@ -556,16 +557,23 @@ pub(crate) unsafe fn llvm_optimize( // tend to reduce AD performance. Therefore activate them first, then differentiate the code // and finally re-optimize the module, now with all optimizations available. // RIP compile time. 
- // let unroll_loops = - // opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin; + let unroll_loops; + let vectorize_slp; + let vectorize_loop; - - let _unroll_loops = + if first_run { + unroll_loops = false; + vectorize_slp = false; + vectorize_loop = false; + } else { + unroll_loops = opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin; - let unroll_loops = false; - let vectorize_slp = false; - let vectorize_loop = false; + vectorize_slp = config.vectorize_slp; + vectorize_loop = config.vectorize_loop; + dbg!("Enzyme: Running with unroll_loops: {}, vectorize_slp: {}, vectorize_loop: {}", unroll_loops, vectorize_slp, vectorize_loop); + } + let using_thin_buffers = opt_stage == llvm::OptStage::PreLinkThinLTO || config.bitcode_needed(); let pgo_gen_path = get_pgo_gen_path(config); let pgo_use_path = get_pgo_use_path(config); @@ -987,11 +995,13 @@ pub(crate) unsafe fn enzyme_ad( item.inputs.into_iter().map(|x| to_enzyme_typetree(x, llvm_data_layout, llcx)).collect(); let output_tt = to_enzyme_typetree(item.output, llvm_data_layout, llcx); - let mut opt = 1; - if std::env::var("ENZYME_DISABLE_OPTS").is_ok() { - opt = 0; + let mut fnc_opt = false; + if std::env::var("ENZYME_ENABLE_FNC_OPT").is_ok() { + dbg!("Disabling optimizations for Enzyme"); + fnc_opt = true; } - let logic_ref: EnzymeLogicRef = CreateEnzymeLogic(opt as u8); + + let logic_ref: EnzymeLogicRef = CreateEnzymeLogic(fnc_opt as u8); let type_analysis: EnzymeTypeAnalysisRef = CreateTypeAnalysis(logic_ref, std::ptr::null_mut(), std::ptr::null_mut(), 0); @@ -1051,7 +1061,7 @@ pub(crate) unsafe fn differentiate( cgcx: &CodegenContext, diff_items: Vec, _typetrees: FxHashMap, - _config: &ModuleConfig, + config: &ModuleConfig, ) -> Result<(), FatalError> { for item in &diff_items { trace!("{}", item); @@ -1086,6 +1096,7 @@ pub(crate) unsafe fn differentiate( llvm::set_max_type_offset(width); } + let differentiate = !diff_items.is_empty(); for item in 
diff_items { let res = enzyme_ad(llmod, llcx, &diag_handler, item); assert!(res.is_ok()); @@ -1126,6 +1137,23 @@ pub(crate) unsafe fn differentiate( } } + + if std::env::var("ENZYME_NO_MOD_OPT_AFTER").is_ok() || !differentiate { + trace!("Skipping module optimization after automatic differentiation"); + } else { + if let Some(opt_level) = config.opt_level { + let opt_stage = match cgcx.lto { + Lto::Fat => llvm::OptStage::PreLinkFatLTO, + Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO, + _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO, + _ => llvm::OptStage::PreLinkNoLTO, + }; + let first_run = false; + dbg!("Running Module Optimization after differentiation"); + llvm_optimize(cgcx, &diag_handler, module, config, opt_level, opt_stage, first_run)?; + } + } + Ok(()) } @@ -1198,7 +1226,9 @@ pub(crate) unsafe fn optimize( _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO, _ => llvm::OptStage::PreLinkNoLTO, }; - return llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage); + // Second run only relevant for AD + let first_run = true; + return llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run); } Ok(()) } From af8a8a610d99226659818d32528c6050d19244d2 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Mon, 1 Apr 2024 19:09:14 -0400 Subject: [PATCH 075/100] fix wrong variable usage --- compiler/rustc_codegen_llvm/src/back/write.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index e09aec69ec44c..fca72d269a380 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -808,7 +808,7 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool, let fnc1 = llvm::LLVMGetNamedFunction(llmod, cname1.as_ptr() as *const c_char); let call; if fnc1.is_none() { - let fnc2 = 
llvm::LLVMGetNamedFunction(llmod, cname1.as_ptr() as *const c_char); + let fnc2 = llvm::LLVMGetNamedFunction(llmod, cname2.as_ptr() as *const c_char); assert!(fnc2.is_some()); let fnc2 = fnc2.unwrap(); let ty = LLVMRustGetFunctionType(fnc2); From 7468cf8a108956acbdf406faccbe27874acad1ef Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Mon, 1 Apr 2024 21:16:06 -0400 Subject: [PATCH 076/100] fix panic calling code and iteration order (#96) * fix panic calling code and iteration order * handle rust name mangling * fix iteration order --- compiler/rustc_codegen_llvm/src/back/write.rs | 46 +++++++++++-------- compiler/rustc_middle/src/ty/mod.rs | 15 ++++-- 2 files changed, 38 insertions(+), 23 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index fca72d269a380..52f657e7c46d6 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -798,27 +798,15 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool, LLVMPositionBuilderAtEnd(builder, fail_bb); + let panic_name: CString = get_panic_name(llmod); let mut arg_vec = vec![add_panic_msg_to_global(llmod, llcx)]; - let name1 = "_ZN4core9panicking14panic_explicit17h8607a79b2acfb83bE"; - let name2 = "_RN4core9panicking14panic_explicit17h8607a79b2acfb83bE"; - let cname1 = CString::new(name1).unwrap(); - let cname2 = CString::new(name2).unwrap(); - - let fnc1 = llvm::LLVMGetNamedFunction(llmod, cname1.as_ptr() as *const c_char); - let call; - if fnc1.is_none() { - let fnc2 = llvm::LLVMGetNamedFunction(llmod, cname2.as_ptr() as *const c_char); - assert!(fnc2.is_some()); - let fnc2 = fnc2.unwrap(); - let ty = LLVMRustGetFunctionType(fnc2); - // now call with msg - call = LLVMBuildCall2(builder, ty, fnc2, arg_vec.as_mut_ptr(), arg_vec.len(), name2.as_ptr() as *const c_char); - } else { - let fnc1 = fnc1.unwrap(); - let ty = LLVMRustGetFunctionType(fnc1); - call = 
LLVMBuildCall2(builder, ty, fnc1, arg_vec.as_mut_ptr(), arg_vec.len(), name1.as_ptr() as *const c_char); - } + + let fnc1 = llvm::LLVMGetNamedFunction(llmod, panic_name.as_ptr() as *const c_char); + assert!(fnc1.is_some()); + let fnc1 = fnc1.unwrap(); + let ty = LLVMRustGetFunctionType(fnc1); + let call = LLVMBuildCall2(builder, ty, fnc1, arg_vec.as_mut_ptr(), arg_vec.len(), panic_name.as_ptr() as *const c_char); llvm::LLVMSetTailCall(call, 1); llvm::LLVMBuildUnreachable(builder); LLVMPositionBuilderAtEnd(builder, success_bb); @@ -877,7 +865,25 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool, let _fnc_ok = LLVMVerifyFunction(tgt, llvm::LLVMVerifierFailureAction::LLVMAbortProcessAction); } - +unsafe fn get_panic_name(llmod: &llvm::Module) -> CString { + // The names are mangled and their ending changes based on a hash, so just take whichever. + let mut f = LLVMGetFirstFunction(llmod); + loop { + if let Some(lf) = f { + f = LLVMGetNextFunction(lf); + let fnc_name = llvm::get_value_name(lf); + let fnc_name: String = String::from_utf8(fnc_name.to_vec()).unwrap(); + if fnc_name.starts_with("_ZN4core9panicking14panic_explicit") { + return CString::new(fnc_name).unwrap(); + } else if fnc_name.starts_with("_RN4core9panicking14panic_explicit") { + return CString::new(fnc_name).unwrap(); + } + } else { + break; + } + } + panic!("Could not find panic function"); +} unsafe fn add_panic_msg_to_global<'a>(llmod: &'a llvm::Module, llcx: &'a llvm::Context) -> &'a llvm::Value { use llvm::*; diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 0e63a2dcf48e3..4002ee4ace6ff 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -2742,7 +2742,8 @@ pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec DiffActivity::Const, _ => panic!("unexpected activity for ptr/ref"), }; - da.insert(i + 1 + offset, activity); - 
offset += 1; + new_activities.push(activity); + new_positions.push(i + 1); } trace!("ABI MATCHING!"); continue; @@ -2782,6 +2783,14 @@ pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec Date: Tue, 2 Apr 2024 02:43:28 -0400 Subject: [PATCH 077/100] enable runtime activity (#98) --- compiler/rustc_codegen_llvm/src/back/write.rs | 5 +++++ compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 7 +++++++ 2 files changed, 12 insertions(+) diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 52f657e7c46d6..961934d1fa667 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -1102,6 +1102,11 @@ pub(crate) unsafe fn differentiate( llvm::set_max_type_offset(width); } + if std::env::var("ENZYME_RUNTIME_ACTIVITY").is_ok() { + dbg!("Setting runtime activity check to true"); + llvm::set_runtime_activity_check(true); + } + let differentiate = !diff_items.is_empty(); for item in diff_items { let res = enzyme_ad(llmod, llcx, &diag_handler, item); diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index b86cdeb09bdcb..77352e94c6d0b 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -2657,6 +2657,7 @@ pub mod Fallback_AD { pub fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8) { unimplemented!() } pub fn EnzymeSetCLInteger(arg1: *mut ::std::os::raw::c_void, arg2: i64) { unimplemented!() } + pub fn set_runtime_activity_check(check: bool) { unimplemented!() } pub fn set_max_int_offset(offset: u64) { unimplemented!() } pub fn set_max_type_offset(offset: u64) { unimplemented!() } pub fn set_max_type_depth(depth: u64) { unimplemented!() } @@ -3009,6 +3010,7 @@ extern "C" { static mut MaxTypeOffset: c_void; static mut EnzymeMaxTypeDepth: c_void; + static mut EnzymeRuntimeActivityCheck: c_void; static mut EnzymePrintPerf: c_void; static 
mut EnzymePrintActivity: c_void; static mut EnzymePrintType: c_void; @@ -3016,6 +3018,11 @@ extern "C" { static mut EnzymeStrictAliasing: c_void; static mut looseTypeAnalysis: c_void; } +pub fn set_runtime_activity_check(check: bool) { + unsafe { + EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeRuntimeActivityCheck), check as u8); + } +} pub fn set_max_int_offset(offset: u64) { let offset = offset.try_into().unwrap(); unsafe { From 683aa31d5b112ec822c1cd09779ab972d5453fdb Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Tue, 2 Apr 2024 02:53:06 -0400 Subject: [PATCH 078/100] printing (#99) --- compiler/rustc_codegen_llvm/src/back/write.rs | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 961934d1fa667..dd5168ead45e5 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -1084,7 +1084,7 @@ pub(crate) unsafe fn differentiate( llvm::set_loose_types(true); } - if std::env::var("ENZYME_PRINT_MOD").is_ok() { + if std::env::var("ENZYME_PRINT_MOD_BEFORE").is_ok() { unsafe { LLVMDumpModule(llmod); } @@ -1142,7 +1142,7 @@ pub(crate) unsafe fn differentiate( break; } } - if std::env::var("ENZYME_PRINT_MOD_AFTER").is_ok() { + if std::env::var("ENZYME_PRINT_MOD_AFTER_ENZYME").is_ok() { unsafe { LLVMDumpModule(llmod); } @@ -1159,12 +1159,22 @@ pub(crate) unsafe fn differentiate( _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO, _ => llvm::OptStage::PreLinkNoLTO, }; - let first_run = false; + let mut first_run = false; dbg!("Running Module Optimization after differentiation"); + if std::env::var("ENZYME_NO_VEC_UNROLL").is_ok() { + // disables vectorization and loop unrolling + first_run = true; + } llvm_optimize(cgcx, &diag_handler, module, config, opt_level, opt_stage, first_run)?; } } + if std::env::var("ENZYME_PRINT_MOD_AFTER_OPTS").is_ok() { + unsafe { + 
LLVMDumpModule(llmod); + } + } + Ok(()) } From 8e61557a6acb4c697290b69a8572336feda0f889 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Tue, 2 Apr 2024 14:12:15 -0400 Subject: [PATCH 079/100] new flag (#101) --- compiler/rustc_codegen_llvm/src/back/write.rs | 6 ++++++ compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 7 +++++++ 2 files changed, 13 insertions(+) diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index dd5168ead45e5..54dca723d8dc6 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -1089,6 +1089,12 @@ pub(crate) unsafe fn differentiate( LLVMDumpModule(llmod); } } + + if std::env::var("ENZYME_INLINE").is_ok() { + dbg!("Setting inline to true"); + llvm::set_inline(true); + } + if std::env::var("ENZYME_TT_DEPTH").is_ok() { let depth = std::env::var("ENZYME_TT_DEPTH").unwrap(); let depth = depth.parse::().unwrap(); diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 77352e94c6d0b..a86e6ce8d65f0 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -2657,6 +2657,7 @@ pub mod Fallback_AD { pub fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8) { unimplemented!() } pub fn EnzymeSetCLInteger(arg1: *mut ::std::os::raw::c_void, arg2: i64) { unimplemented!() } + pub fn set_inline(val: bool) { unimplemented!() } pub fn set_runtime_activity_check(check: bool) { unimplemented!() } pub fn set_max_int_offset(offset: u64) { unimplemented!() } pub fn set_max_type_offset(offset: u64) { unimplemented!() } @@ -3017,6 +3018,7 @@ extern "C" { static mut EnzymePrint: c_void; static mut EnzymeStrictAliasing: c_void; static mut looseTypeAnalysis: c_void; + static mut EnzymeInline: c_void; } pub fn set_runtime_activity_check(check: bool) { unsafe { @@ -3071,6 +3073,11 @@ pub fn set_loose_types(loose: bool) { 
EnzymeSetCLBool(std::ptr::addr_of_mut!(looseTypeAnalysis), loose as u8); } } +pub fn set_inline(val: bool) { + unsafe { + EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeInline), val as u8); + } +} extern "C" { pub fn EnzymeCreatePrimalAndGradient<'a>( From eab555d386d52d18cae89a39c84ca7e11dc41986 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Tue, 2 Apr 2024 16:50:45 -0400 Subject: [PATCH 080/100] updatenzyme (#102) --- src/tools/enzyme | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tools/enzyme b/src/tools/enzyme index b1f676c07cd45..10bf684c27439 160000 --- a/src/tools/enzyme +++ b/src/tools/enzyme @@ -1 +1 @@ -Subproject commit b1f676c07cd45fa3fe711b7bd2297178d55061d2 +Subproject commit 10bf684c27439a9ba47388d6bd14ed66be2dd3d7 From 1218cb24d608c5ab434802637797ba1bb0171cca Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Tue, 2 Apr 2024 20:22:58 -0400 Subject: [PATCH 081/100] Warn if the shadow will have internal immutability (#103) * Warn if the shadow will have internal immutability * improve error msg --- compiler/rustc_codegen_llvm/src/base.rs | 2 +- compiler/rustc_middle/messages.ftl | 2 + compiler/rustc_middle/src/error.rs | 8 ++ compiler/rustc_middle/src/ty/mod.rs | 75 +++++++++++++++---- .../rustc_monomorphize/src/partitioning.rs | 3 +- 5 files changed, 75 insertions(+), 15 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/base.rs b/compiler/rustc_codegen_llvm/src/base.rs index 5d76510ab91ae..344c5763483c8 100644 --- a/compiler/rustc_codegen_llvm/src/base.rs +++ b/compiler/rustc_codegen_llvm/src/base.rs @@ -93,7 +93,7 @@ pub fn compile_codegen_unit(tcx: TyCtxt<'_>, cgu_name: Symbol) -> (ModuleCodegen _ => continue, }; let fn_ty = inst.ty(tcx, ParamEnv::empty()); - let _fnc_tree = fnc_typetrees(tcx, fn_ty, &mut vec![]); + let _fnc_tree = fnc_typetrees(tcx, fn_ty, &mut vec![], None); //trace!("codegen_module: predefine fn {}", inst); //trace!("{} \n {:?} \n {:?}", inst, fn_ty, _fnc_tree); // Manuel: TODO diff --git 
a/compiler/rustc_middle/messages.ftl b/compiler/rustc_middle/messages.ftl index 31e3594115665..a1e8e0b1d8f90 100644 --- a/compiler/rustc_middle/messages.ftl +++ b/compiler/rustc_middle/messages.ftl @@ -1,6 +1,8 @@ middle_adjust_for_foreign_abi_error = target architecture {$arch} does not support `extern {$abi}` ABI +middle_autodiff_unsafe_inner_const_ref = reading from a `Duplicated` const {$ty} is unsafe + middle_assert_async_resume_after_panic = `async fn` resumed after panicking middle_assert_async_resume_after_return = `async fn` resumed after completion diff --git a/compiler/rustc_middle/src/error.rs b/compiler/rustc_middle/src/error.rs index 4f763663a50fe..b362b9facc98b 100644 --- a/compiler/rustc_middle/src/error.rs +++ b/compiler/rustc_middle/src/error.rs @@ -17,6 +17,14 @@ pub struct DropCheckOverflow<'tcx> { pub overflow_ty: Ty<'tcx>, } +#[derive(Diagnostic)] +#[diag(middle_autodiff_unsafe_inner_const_ref)] +pub struct AutodiffUnsafeInnerConstRef { + #[primary_span] + pub span: Span, + pub ty: String, +} + #[derive(Diagnostic)] #[diag(middle_opaque_hidden_type_mismatch)] pub struct OpaqueHiddenTypeMismatch<'tcx> { diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 4002ee4ace6ff..ed0b08f823367 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -21,7 +21,7 @@ pub use self::AssocItemContainer::*; pub use self::BorrowKind::*; pub use self::IntVarValue::*; pub use self::Variance::*; -use crate::error::{OpaqueHiddenTypeMismatch, TypeMismatchReason, UnsupportedUnion}; +use crate::error::{OpaqueHiddenTypeMismatch, TypeMismatchReason, UnsupportedUnion, AutodiffUnsafeInnerConstRef}; use crate::metadata::ModChild; use crate::middle::privacy::EffectiveVisibilities; use crate::mir::{Body, CoroutineLayout}; @@ -105,6 +105,7 @@ pub use self::typeck_results::{ CanonicalUserType, CanonicalUserTypeAnnotation, CanonicalUserTypeAnnotations, IsIdentity, TypeckResults, UserType, 
UserTypeAnnotationIndex, }; +use crate::query::Key; pub mod _match; pub mod abstract_const; @@ -2722,14 +2723,25 @@ mod size_asserts { pub fn typetree_from<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { let mut visited = vec![]; - let ty = typetree_from_ty(ty, tcx, 0, false, &mut visited); + let ty = typetree_from_ty(ty, tcx, 0, false, &mut visited, None); let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child: ty }; return TypeTree(vec![tt]); } use rustc_ast::expand::autodiff_attrs::DiffActivity; -pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec) -> FncTree { +// This function combines three tasks. To avoid traversing each type 3x, we combine them. +// 1. Create a TypeTree from a Ty. This is the main task. +// 2. IFF da is not empty, we also want to adjust DiffActivity to account for future MIR->LLVM +// lowering. E.g. fat ptr are going to introduce an extra int. +// 3. IFF da is not empty, we are creating TT for a function directly differentiated (has an +// autodiff macro on top). Here we want to make sure that shadows are mutable internally. +// We know the outermost ref/ptr indirection is mutable - we generate it like that. +// We now have to make sure that inner ptr/ref are mutable too, or issue a warning. +// Not an error, because it only causes issues if they are actually read, which we don't check +// yet. We should add such analysis to reliably either issue an error or accept without warning. +// If there only were some research to do that... 
+pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec, span: Option) -> FncTree { if !fn_ty.is_fn() { return FncTree { args: vec![], ret: TypeTree::new() }; } @@ -2747,6 +2759,19 @@ pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec true, + _ => false, + } + } else { + false + }; + visited.clear(); if ty.is_unsafe_ptr() || ty.is_ref() || ty.is_box() { if ty.is_fn_ptr() { @@ -2755,7 +2780,7 @@ pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec(ty: Ty<'a>, tcx: TyCtxt<'a>, depth: usize, safety: bool, visited: &mut Vec>) -> TypeTree { +fn typetree_from_ty<'a>(ty: Ty<'a>, tcx: TyCtxt<'a>, depth: usize, safety: bool, visited: &mut Vec>, span: Option) -> TypeTree { if depth > 20 { trace!("depth > 20 for ty: {}", &ty); } @@ -2812,9 +2837,33 @@ pub fn typetree_from_ty<'a>(ty: Ty<'a>, tcx: TyCtxt<'a>, depth: usize, safety: b if ty.is_fn_ptr() { unimplemented!("what to do whith fn ptr?"); } - let inner_ty = ty.builtin_deref(true).unwrap().ty; + + let inner_ty_and_mut = ty.builtin_deref(true).unwrap(); + let is_mut = inner_ty_and_mut.mutbl == hir::Mutability::Mut; + let inner_ty = inner_ty_and_mut.ty; + + // Now account for inner mutability. 
+ if !is_mut && depth > 0 && safety { + let ptr_ty: String = if ty.is_ref() { + "ref" + } else if ty.is_unsafe_ptr() { + "ptr" + } else { + assert!(ty.is_box()); + "box" + }.to_string(); + + // If we have mutability, we also have a span + assert!(span.is_some()); + let span = span.unwrap(); + + tcx.sess + .dcx() + .emit_warning(AutodiffUnsafeInnerConstRef{span, ty: ptr_ty}); + } + //visited.push(inner_ty); - let child = typetree_from_ty(inner_ty, tcx, depth + 1, safety, visited); + let child = typetree_from_ty(inner_ty, tcx, depth + 1, safety, visited, span); let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child }; visited.pop(); return TypeTree(vec![tt]); @@ -2884,7 +2933,7 @@ pub fn typetree_from_ty<'a>(ty: Ty<'a>, tcx: TyCtxt<'a>, depth: usize, safety: b } //visited.push(field_ty); - let mut child = typetree_from_ty(field_ty, tcx, depth + 1, safety, visited).0; + let mut child = typetree_from_ty(field_ty, tcx, depth + 1, safety, visited, span).0; for c in &mut child { if c.offset == -1 { @@ -2916,7 +2965,7 @@ pub fn typetree_from_ty<'a>(ty: Ty<'a>, tcx: TyCtxt<'a>, depth: usize, safety: b trace!("simd"); let (_size, inner_ty) = ty.simd_size_and_type(tcx); //visited.push(inner_ty); - let _sub_tt = typetree_from_ty(inner_ty, tcx, depth + 1, safety, visited); + let _sub_tt = typetree_from_ty(inner_ty, tcx, depth + 1, safety, visited, span); //let tt = TypeTree( // std::iter::repeat(subtt) // .take(*count as usize) @@ -2944,7 +2993,7 @@ pub fn typetree_from_ty<'a>(ty: Ty<'a>, tcx: TyCtxt<'a>, depth: usize, safety: b } let sub_ty = ty.builtin_index().unwrap(); //visited.push(sub_ty); - let subtt = typetree_from_ty(sub_ty, tcx, depth + 1, safety, visited); + let subtt = typetree_from_ty(sub_ty, tcx, depth + 1, safety, visited, span); // calculate size of subtree let param_env_and = ParamEnvAnd { param_env: ParamEnv::empty(), value: sub_ty }; @@ -2965,7 +3014,7 @@ pub fn typetree_from_ty<'a>(ty: Ty<'a>, tcx: TyCtxt<'a>, depth: usize, safety: b if 
ty.is_slice() { let sub_ty = ty.builtin_index().unwrap(); //visited.push(sub_ty); - let subtt = typetree_from_ty(sub_ty, tcx, depth + 1, safety, visited); + let subtt = typetree_from_ty(sub_ty, tcx, depth + 1, safety, visited, span); visited.pop(); return subtt; diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs index 0c3d8a64ea317..36a3c4f34a490 100644 --- a/compiler/rustc_monomorphize/src/partitioning.rs +++ b/compiler/rustc_monomorphize/src/partitioning.rs @@ -1202,7 +1202,8 @@ fn collect_and_partition_mono_items( println!("source_id: {:?}", inst.def_id()); let fn_ty = inst.ty(tcx, ParamEnv::empty()); assert!(fn_ty.is_fn()); - let fnc_tree = fnc_typetrees(tcx, fn_ty, &mut input_activities); + let span = tcx.def_span(inst.def_id()); + let fnc_tree = fnc_typetrees(tcx, fn_ty, &mut input_activities, Some(span)); let (inputs, output) = (fnc_tree.args, fnc_tree.ret); //check_types(inst.ty(tcx, ParamEnv::empty()), tcx, &target_attrs.input_activity); let symb = symbol_name_for_instance_in_crate(tcx, inst.clone(), LOCAL_CRATE); From 1e32009f0289a30e798efef49a8db38fb5459e3f Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Wed, 3 Apr 2024 18:32:24 -0400 Subject: [PATCH 082/100] impl new modes for higher order ad (#106) * impl new modes for higher order ad * fix fwd void ret case --- .../rustc_ast/src/expand/autodiff_attrs.rs | 27 ++++++- compiler/rustc_builtin_macros/src/autodiff.rs | 12 +-- compiler/rustc_codegen_llvm/src/back/write.rs | 77 ++++++++++++------- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 35 ++++++--- .../rustc_codegen_ssa/src/codegen_attrs.rs | 2 + 5 files changed, 106 insertions(+), 47 deletions(-) diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs index 4da71ad6650a9..927cbfc50b681 100644 --- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -11,6 +11,21 @@ pub enum 
DiffMode { Source, Forward, Reverse, + ForwardFirst, + ReverseFirst, +} + +pub fn is_rev(mode: DiffMode) -> bool { + match mode { + DiffMode::Reverse | DiffMode::ReverseFirst => true, + _ => false, + } +} +pub fn is_fwd(mode: DiffMode) -> bool { + match mode { + DiffMode::Forward | DiffMode::ForwardFirst => true, + _ => false, + } } impl Display for DiffMode { @@ -20,6 +35,8 @@ impl Display for DiffMode { DiffMode::Source => write!(f, "Source"), DiffMode::Forward => write!(f, "Forward"), DiffMode::Reverse => write!(f, "Reverse"), + DiffMode::ForwardFirst => write!(f, "ForwardFirst"), + DiffMode::ReverseFirst => write!(f, "ReverseFirst"), } } } @@ -32,12 +49,12 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool { match mode { DiffMode::Inactive => false, DiffMode::Source => false, - DiffMode::Forward => { + DiffMode::Forward | DiffMode::ForwardFirst => { activity == DiffActivity::Dual || activity == DiffActivity::DualOnly || activity == DiffActivity::Const } - DiffMode::Reverse => { + DiffMode::Reverse | DiffMode::ReverseFirst => { activity == DiffActivity::Const || activity == DiffActivity::Active || activity == DiffActivity::ActiveOnly @@ -73,13 +90,13 @@ pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool { return match mode { DiffMode::Inactive => false, DiffMode::Source => false, - DiffMode::Forward => { + DiffMode::Forward | DiffMode::ForwardFirst => { // These are the only valid cases activity == DiffActivity::Dual || activity == DiffActivity::DualOnly || activity == DiffActivity::Const } - DiffMode::Reverse => { + DiffMode::Reverse | DiffMode::ReverseFirst => { // These are the only valid cases activity == DiffActivity::Active || activity == DiffActivity::ActiveOnly || @@ -137,6 +154,8 @@ impl FromStr for DiffMode { "Source" => Ok(DiffMode::Source), "Forward" => Ok(DiffMode::Forward), "Reverse" => Ok(DiffMode::Reverse), + "ForwardFirst" => Ok(DiffMode::ForwardFirst), + "ReverseFirst" => 
Ok(DiffMode::ReverseFirst), _ => Err(()), } } diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 49f1b743d1824..f203769e8a4cf 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -3,7 +3,7 @@ //use crate::util::check_autodiff; use crate::errors; -use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ty_for_activity}; +use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode, is_fwd, is_rev, valid_input_activity, valid_ty_for_activity}; use rustc_ast::ptr::P; use rustc_ast::token::{Token, TokenKind}; use rustc_ast::tokenstream::*; @@ -308,7 +308,7 @@ fn gen_enzyme_body( let primal_ret = sig.decl.output.has_ret(); - if primal_ret && n_active == 0 && x.mode == DiffMode::Reverse { + if primal_ret && n_active == 0 && is_rev(x.mode) { // We only have the primal ret. body.stmts.push(ecx.stmt_expr(black_box_primal_call.clone())); return body; @@ -355,7 +355,7 @@ fn gen_enzyme_body( panic!("Did not expect non-tuple ret ty: {:?}", d_ret_ty); } }; - if x.mode == DiffMode::Forward { + if is_fwd(x.mode) { if x.ret_activity == DiffActivity::Dual { assert!(d_ret_ty.len() == 2); // both should be identical, by construction @@ -369,7 +369,7 @@ fn gen_enzyme_body( exprs.push(default_call_expr); } } else { - assert!(x.mode == DiffMode::Reverse); + assert!(is_rev(x.mode)); if primal_ret { // We have extra handling above for the primal ret @@ -508,7 +508,7 @@ fn gen_enzyme_decl( // If we return a scalar in the primal and the scalar is active, // then add it as last arg to the inputs. 
- if let DiffMode::Reverse = x.mode { + if is_rev(x.mode) { if let DiffActivity::Active = x.ret_activity { let ty = match d_decl.output { FnRetTy::Ty(ref ty) => ty.clone(), @@ -537,7 +537,7 @@ fn gen_enzyme_decl( } d_decl.inputs = d_inputs.into(); - if let DiffMode::Forward = x.mode { + if is_fwd(x.mode) { if let DiffActivity::Dual = x.ret_activity { let ty = match d_decl.output { FnRetTy::Ty(ref ty) => ty.clone(), diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 54dca723d8dc6..615196c41304f 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -703,7 +703,7 @@ fn get_params(fnc: &Value) -> Vec<&Value> { unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool, llmod: &'a llvm::Module, llcx: &llvm::Context, size_positions: &[usize]) { - dbg!("size_positions: {:?}", size_positions); + // first, remove all calls from fnc let bb = LLVMGetFirstBasicBlock(tgt); let br = LLVMRustGetTerminator(bb); @@ -843,7 +843,6 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool, // Now clean up placeholder code. LLVMRustEraseInstBefore(bb, last_inst); - //dbg!(&tgt); let f_return_type = LLVMGetReturnType(LLVMGlobalGetValueType(src)); let void_type = LLVMVoidTypeInContext(llcx); @@ -865,6 +864,7 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool, let _fnc_ok = LLVMVerifyFunction(tgt, llvm::LLVMVerifierFailureAction::LLVMAbortProcessAction); } + unsafe fn get_panic_name(llmod: &llvm::Module) -> CString { // The names are mangled and their ending changes based on a hash, so just take whichever. 
let mut f = LLVMGetFirstFunction(llmod); @@ -922,21 +922,7 @@ unsafe fn add_panic_msg_to_global<'a>(llmod: &'a llvm::Module, llcx: &'a llvm::C LLVMRustSetLinkage(global_var, Linkage::PrivateLinkage); LLVMSetInitializer(global_var, struct_initializer); - //let msg_global_name = "ad_safety_msg".to_string(); - //let cmsg_global_name = CString::new(msg_global_name).unwrap(); - //let msg = "autodiff safety check failed!"; - //let cmsg = CString::new(msg).unwrap(); - //let msg_len = msg.len(); - //let i8_array_type = llvm::LLVMRustArrayType(llvm::LLVMInt8TypeInContext(llcx), msg_len as u64); - //let global_type = llvm::LLVMStructTypeInContext(llcx, [i8_array_type].as_mut_ptr(), 1, 0); - //let string_const_val = llvm::LLVMConstStringInContext(llcx, cmsg.as_ptr() as *const c_char, msg_len as u32, 0); - //let initializer = llvm::LLVMConstStructInContext(llcx, [string_const_val].as_mut_ptr(), 1, 0); - //let global = llvm::LLVMAddGlobal(llmod, global_type, cmsg_global_name.as_ptr() as *const c_char); - //llvm::LLVMRustSetLinkage(global, llvm::Linkage::PrivateLinkage); - //llvm::LLVMSetInitializer(global, initializer); - //llvm::LLVMSetUnnamedAddress(global, llvm::UnnamedAddr::Global); - - global_var + global_var } // As unsafe as it can be. 
@@ -947,6 +933,7 @@ pub(crate) unsafe fn enzyme_ad( llcx: &llvm::Context, diag_handler: &DiagCtxt, item: AutoDiffItem, + logic_ref: EnzymeLogicRef, ) -> Result<(), FatalError> { let autodiff_mode = item.attrs.mode; let rust_name = item.source; @@ -1001,13 +988,6 @@ pub(crate) unsafe fn enzyme_ad( item.inputs.into_iter().map(|x| to_enzyme_typetree(x, llvm_data_layout, llcx)).collect(); let output_tt = to_enzyme_typetree(item.output, llvm_data_layout, llcx); - let mut fnc_opt = false; - if std::env::var("ENZYME_ENABLE_FNC_OPT").is_ok() { - dbg!("Disabling optimizations for Enzyme"); - fnc_opt = true; - } - - let logic_ref: EnzymeLogicRef = CreateEnzymeLogic(fnc_opt as u8); let type_analysis: EnzymeTypeAnalysisRef = CreateTypeAnalysis(logic_ref, std::ptr::null_mut(), std::ptr::null_mut(), 0); @@ -1026,7 +1006,18 @@ pub(crate) unsafe fn enzyme_ad( llvm::set_print(true); } - let mut tmp = match item.attrs.mode { + let mode = match autodiff_mode { + DiffMode::Forward => DiffMode::Forward, + DiffMode::Reverse => DiffMode::Reverse, + DiffMode::ForwardFirst => DiffMode::Forward, + DiffMode::ReverseFirst => DiffMode::Reverse, + _ => unreachable!(), + }; + + let void_type = LLVMVoidTypeInContext(llcx); + let return_type = LLVMGetReturnType(LLVMGlobalGetValueType(src_fnc)); + let void_ret = void_type == return_type; + let mut tmp = match mode { DiffMode::Forward => enzyme_rust_forward_diff( logic_ref, type_analysis, @@ -1035,6 +1026,7 @@ pub(crate) unsafe fn enzyme_ad( ret_activity, input_tts, output_tt, + void_ret, ), DiffMode::Reverse => enzyme_rust_reverse_diff( logic_ref, @@ -1052,7 +1044,6 @@ pub(crate) unsafe fn enzyme_ad( let f_return_type = LLVMGetReturnType(LLVMGlobalGetValueType(res)); - let void_type = LLVMVoidTypeInContext(llcx); let rev_mode = item.attrs.mode == DiffMode::Reverse; create_call(target_fnc, res, rev_mode, llmod, llcx, &size_positions); // TODO: implement drop for wrapper type? 
@@ -1114,8 +1105,40 @@ pub(crate) unsafe fn differentiate( } let differentiate = !diff_items.is_empty(); + let mut first_order_items: Vec = vec![]; + let mut higher_order_items: Vec = vec![]; for item in diff_items { - let res = enzyme_ad(llmod, llcx, &diag_handler, item); + if item.attrs.mode == DiffMode::ForwardFirst || item.attrs.mode == DiffMode::ReverseFirst{ + first_order_items.push(item); + } else { + // default + higher_order_items.push(item); + } + } + + let mut fnc_opt = false; + if std::env::var("ENZYME_ENABLE_FNC_OPT").is_ok() { + dbg!("Enable extra optimizations for Enzyme"); + fnc_opt = true; + } + + // If a function is a base for some higher order ad, always optimize + let fnc_opt_base = true; + let logic_ref_opt: EnzymeLogicRef = CreateEnzymeLogic(fnc_opt_base as u8); + + for item in first_order_items { + let res = enzyme_ad(llmod, llcx, &diag_handler, item, logic_ref_opt); + assert!(res.is_ok()); + } + + // For the rest, follow the user choice on debug vs release. + // Reuse the opt one if possible for better compile time (Enzyme internal caching). 
+ let logic_ref = match fnc_opt { + true => logic_ref_opt, + false => CreateEnzymeLogic(fnc_opt as u8), + }; + for item in higher_order_items { + let res = enzyme_ad(llmod, llcx, &diag_handler, item, logic_ref); assert!(res.is_ok()); } diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index a86e6ce8d65f0..ed8b4be25dc19 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -850,6 +850,7 @@ pub(crate) unsafe fn enzyme_rust_forward_diff( ret_diffactivity: DiffActivity, input_tts: Vec, output_tt: TypeTree, + void_ret: bool, ) -> (&Value, Vec) { let ret_activity = cdiffe_from(ret_diffactivity); assert!(ret_activity != CDIFFE_TYPE::DFT_OUT_DIFF); @@ -864,12 +865,18 @@ pub(crate) unsafe fn enzyme_rust_forward_diff( input_activity.push(act); } - let ret_primary_ret = match ret_activity { - CDIFFE_TYPE::DFT_CONSTANT => true, - CDIFFE_TYPE::DFT_DUP_ARG => true, - CDIFFE_TYPE::DFT_DUP_NONEED => false, - _ => panic!("Implementation error in enzyme_rust_forward_diff."), + // if we have void ret, this must be false; + let ret_primary_ret = if void_ret { + false + } else { + match ret_activity { + CDIFFE_TYPE::DFT_CONSTANT => true, + CDIFFE_TYPE::DFT_DUP_ARG => true, + CDIFFE_TYPE::DFT_DUP_NONEED => false, + _ => panic!("Implementation error in enzyme_rust_forward_diff."), + } }; + trace!("ret_primary_ret: {}", &ret_primary_ret); let mut args_tree = input_tts.iter().map(|x| x.inner).collect::>(); //let mut args_tree = vec![TypeTree::new().inner; typetree.input_tt.len()]; @@ -879,8 +886,8 @@ pub(crate) unsafe fn enzyme_rust_forward_diff( let args_uncacheable = vec![0; input_tts.len()]; assert!(args_uncacheable.len() == input_activity.len()); let num_fnc_args = LLVMCountParams(fnc); - println!("num_fnc_args: {}", num_fnc_args); - println!("input_activity.len(): {}", input_activity.len()); + trace!("num_fnc_args: {}", num_fnc_args); + trace!("input_activity.len(): {}", 
input_activity.len()); assert!(num_fnc_args == input_activity.len() as u32); let kv_tmp = IntList { data: std::ptr::null_mut(), size: 0 }; @@ -893,6 +900,11 @@ pub(crate) unsafe fn enzyme_rust_forward_diff( KnownValues: known_values.as_mut_ptr(), }; + trace!("ret_activity: {}", &ret_activity); + for i in &input_activity { + trace!("input_activity i: {}", &i); + } + trace!("before calling Enzyme"); let res = EnzymeCreateForwardDiff( logic_ref, // Logic std::ptr::null(), @@ -912,6 +924,7 @@ pub(crate) unsafe fn enzyme_rust_forward_diff( args_uncacheable.len(), // uncacheable arguments std::ptr::null_mut(), // write augmented function to this ); + trace!("after calling Enzyme"); (res, vec![]) } @@ -971,11 +984,12 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff( KnownValues: known_values.as_mut_ptr(), }; - trace!("{}", &primary_ret); - trace!("{}", &ret_activity); + trace!("primary_ret: {}", &primary_ret); + trace!("ret_activity: {}", &ret_activity); for i in &input_activity { - trace!("{}", &i); + trace!("input_activity i: {}", &i); } + trace!("before calling Enzyme"); let res = EnzymeCreatePrimalAndGradient( logic_ref, // Logic std::ptr::null(), @@ -998,6 +1012,7 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff( std::ptr::null_mut(), // write augmented function to this 0, ); + trace!("after calling Enzyme"); (res, primal_sizes) } diff --git a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs index 4d1dd44031669..f0b993a3844e1 100644 --- a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs +++ b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs @@ -749,6 +749,8 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { let mode = match mode.as_str() { "Forward" => DiffMode::Forward, "Reverse" => DiffMode::Reverse, + "ForwardFirst" => DiffMode::ForwardFirst, + "ReverseFirst" => DiffMode::ReverseFirst, _ => { tcx.sess .struct_span_err(attr.span, msg_mode) From b5a31307e2f591164d56dd0ee166484c93ac1db0 Mon Sep 17 
00:00:00 2001 From: Manuel Drehwald Date: Sat, 6 Apr 2024 01:11:43 -0400 Subject: [PATCH 083/100] adding two more flags (#109) --- compiler/rustc_codegen_llvm/src/back/lto.rs | 4 +- compiler/rustc_codegen_llvm/src/back/write.rs | 83 +++++++++++++++---- 2 files changed, 68 insertions(+), 19 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index f28af4ba49bad..d061571e9ded7 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -618,7 +618,9 @@ pub(crate) fn run_pass_manager( let opt_level = config.opt_level.unwrap_or(config::OptLevel::No); // We will run this again with different values in the context of automatic differentiation. let first_run = true; - write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run)?; + let noop = false; + dbg!("running llvm pm opt pipeline"); + write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run, noop)?; } debug!("lto done"); Ok(()) diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 615196c41304f..c08dbac0e7a79 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -550,7 +550,11 @@ pub(crate) unsafe fn llvm_optimize( opt_level: config::OptLevel, opt_stage: llvm::OptStage, first_run: bool, + noop: bool, ) -> Result<(), FatalError> { + if noop { + return Ok(()); + } // Enzyme: // We want to simplify / optimize functions before AD. 
// However, benchmarks show that optimizations increasing the code size @@ -724,6 +728,13 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool, let last_inst = LLVMRustGetLastInstruction(bb).unwrap(); LLVMPositionBuilderAtEnd(builder, bb); + let safety_run_checks; + if std::env::var("ENZYME_NO_SAFETY_CHECKS").is_ok() { + safety_run_checks = false; + } else { + safety_run_checks = true; + } + if inner_param_num == outer_param_num { call_args = outer_args; } else { @@ -763,14 +774,18 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool, outer_pos += 3; inner_pos += 2; - // Now we assert if int1 <= int2 - let res = LLVMBuildICmp( - builder, - IntPredicate::IntULE as u32, - outer_arg, - next2_outer_arg, - "safety_check".as_ptr() as *const c_char); - safety_vals.push(res); + + if safety_run_checks { + + // Now we assert if int1 <= int2 + let res = LLVMBuildICmp( + builder, + IntPredicate::IntULE as u32, + outer_arg, + next2_outer_arg, + "safety_check".as_ptr() as *const c_char); + safety_vals.push(res); + } } } } @@ -782,17 +797,18 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool, // Now add the safety checks. if !safety_vals.is_empty() { dbg!("Adding safety checks"); + assert!(safety_run_checks); // first we create one bb per check and two more for the fail and success case. 
let fail_bb = LLVMAppendBasicBlockInContext(llcx, tgt, "ad_safety_fail".as_ptr() as *const c_char); let success_bb = LLVMAppendBasicBlockInContext(llcx, tgt, "ad_safety_success".as_ptr() as *const c_char); - let mut err_bb = vec![]; - for i in 0..safety_vals.len() { - let name: String = format!("ad_safety_err_{}", i); - err_bb.push(LLVMAppendBasicBlockInContext(llcx, tgt, name.as_ptr() as *const c_char)); - } - for (i, &val) in safety_vals.iter().enumerate() { - LLVMBuildCondBr(builder, val, err_bb[i], fail_bb); - LLVMPositionBuilderAtEnd(builder, err_bb[i]); + for i in 1..safety_vals.len() { + // 'or' all safety checks together + // Doing some binary tree style or'ing here would be more efficient, + // but I assume LLVM will opt it anyway + let prev = safety_vals[i - 1]; + let curr = safety_vals[i]; + let res = llvm::LLVMBuildOr(builder, prev, curr, "safety_check".as_ptr() as *const c_char); + safety_vals[i] = res; } LLVMBuildCondBr(builder, safety_vals.last().unwrap(), success_bb, fail_bb); LLVMPositionBuilderAtEnd(builder, fail_bb); @@ -1194,7 +1210,31 @@ pub(crate) unsafe fn differentiate( // disables vectorization and loop unrolling first_run = true; } - llvm_optimize(cgcx, &diag_handler, module, config, opt_level, opt_stage, first_run)?; + if std::env::var("ENZYME_ALT_PIPELINE").is_ok() { + dbg!("Running first postAD optimization"); + first_run = true; + } + let noop = false; + llvm_optimize(cgcx, &diag_handler, module, config, opt_level, opt_stage, first_run, noop)?; + } + if std::env::var("ENZYME_ALT_PIPELINE").is_ok() { + dbg!("Running Second postAD optimization"); + if let Some(opt_level) = config.opt_level { + let opt_stage = match cgcx.lto { + Lto::Fat => llvm::OptStage::PreLinkFatLTO, + Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO, + _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO, + _ => llvm::OptStage::PreLinkNoLTO, + }; + let mut first_run = false; + dbg!("Running Module Optimization after 
differentiation"); + if std::env::var("ENZYME_NO_VEC_UNROLL").is_ok() { + // enables vectorization and loop unrolling + first_run = false; + } + let noop = false; + llvm_optimize(cgcx, &diag_handler, module, config, opt_level, opt_stage, first_run, noop)?; + } } } @@ -1278,7 +1318,14 @@ pub(crate) unsafe fn optimize( }; // Second run only relevant for AD let first_run = true; - return llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run); + let noop; + if std::env::var("ENZYME_ALT_PIPELINE").is_ok() { + noop = true; + dbg!("Skipping PreAD optimization"); + } else { + noop = false; + } + return llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run, noop); } Ok(()) } From af4d7660895557179b88c845a5cc9776ab069b1f Mon Sep 17 00:00:00 2001 From: I-Al-Istannen Date: Sun, 7 Apr 2024 14:22:16 +0200 Subject: [PATCH 084/100] chore: Use correct commit hash as github cache key if submodule is not checked out (#107) * chore: Use correct commit hash as github cache key * chore: Clone with submodules This will also initialize all the "doc" submodules, which might prove too much. 
--- .github/workflows/enzyme-ci.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/enzyme-ci.yml b/.github/workflows/enzyme-ci.yml index c233b049e1d6d..a08267fcdb725 100644 --- a/.github/workflows/enzyme-ci.yml +++ b/.github/workflows/enzyme-ci.yml @@ -24,11 +24,12 @@ jobs: - name: Checkout Rust source uses: actions/checkout@v4 with: + submodules: true # check out all submodules so the cache can work correctly fetch-depth: 2 - uses: dtolnay/rust-toolchain@nightly - name: Get LLVM commit hash id: llvm-commit - run: echo "HEAD=$(git -C src/llvm-project rev-parse HEAD)" >> $GITHUB_OUTPUT + run: echo "HEAD=$(git rev-parse HEAD:src/llvm-project)" >> $GITHUB_OUTPUT - name: Cache LLVM id: cache-llvm uses: actions/cache@v4 @@ -37,7 +38,7 @@ jobs: key: ${{ matrix.os }}-llvm-${{ steps.llvm-commit.outputs.HEAD }} - name: Get Enzyme commit hash id: enzyme-commit - run: echo "HEAD=$(git -C src/tools/enzyme rev-parse HEAD)" >> $GITHUB_OUTPUT + run: echo "HEAD=$(git rev-parse HEAD:src/tools/enzyme)" >> $GITHUB_OUTPUT - name: Cache Enzyme id: cache-enzyme uses: actions/cache@v4 From 948848d5abcdb128055d3ff779c0ce57fda27215 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Thu, 11 Apr 2024 18:26:30 -0400 Subject: [PATCH 085/100] support self ty (#110) * finish method / trait support --- compiler/rustc_builtin_macros/src/autodiff.rs | 134 +++++++++++++----- compiler/rustc_expand/src/build.rs | 4 + 2 files changed, 106 insertions(+), 32 deletions(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index f203769e8a4cf..9787d8ffe7aa7 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -19,6 +19,8 @@ use std::string::String; use thin_vec::{thin_vec, ThinVec}; use std::str::FromStr; +use rustc_ast::AssocItemKind; + #[cfg(not(llvm_enzyme))] pub fn expand( ecx: &mut ExtCtxt<'_>, @@ -82,10 +84,39 @@ pub fn expand( ecx: &mut 
ExtCtxt<'_>, expand_span: Span, meta_item: &ast::MetaItem, - item: Annotatable, + mut item: Annotatable, ) -> Vec { //check_builtin_macro_attribute(ecx, meta_item, sym::alloc_error_handler); + // first get the annotable item: + let (sig, is_impl): (FnSig, bool) = match &item { + Annotatable::Item(ref iitem) => { + let sig = match &iitem.kind { + ItemKind::Fn(box ast::Fn { sig, .. }) => sig, + _ => { + ecx.sess.dcx().emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); + return vec![item]; + } + }; + (sig.clone(), false) + }, + Annotatable::ImplItem(ref assoc_item) => { + let sig = match &assoc_item.kind { + ast::AssocItemKind::Fn(box ast::Fn { sig, .. }) => sig, + _ => { + ecx.sess.dcx().emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); + return vec![item]; + } + }; + (sig.clone(), true) + }, + _ => { + dbg!(&item); + ecx.sess.dcx().emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); + return vec![item]; + } + }; + let meta_item_vec: ThinVec = match meta_item.kind { ast::MetaItemKind::List(ref vec) => vec.clone(), _ => { @@ -93,19 +124,22 @@ pub fn expand( return vec![item]; } }; - // Allow using `#[autodiff(...)]` only on a Fn - let (has_ret, sig, sig_span) = if let Annotatable::Item(item) = &item - && let ItemKind::Fn(box ast::Fn { sig, .. 
}) = &item.kind - { - (sig.decl.output.has_ret(), sig, ecx.with_call_site_ctxt(sig.span)) - } else { - ecx.sess.dcx().emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); - return vec![item]; - }; - // Now we know that item is a Item::Fn - let mut orig_item: P = item.clone().expect_item(); - let primal = orig_item.ident.clone(); + let has_ret = sig.decl.output.has_ret(); + let sig_span = ecx.with_call_site_ctxt(sig.span); + + let (vis, primal) = match &item { + Annotatable::Item(ref iitem) => { + (iitem.vis.clone(), iitem.ident.clone()) + }, + Annotatable::ImplItem(ref assoc_item) => { + (assoc_item.vis.clone(), assoc_item.ident.clone()) + }, + _ => { + ecx.sess.dcx().emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); + return vec![item]; + } + }; // create TokenStream from vec elemtents: // meta_item doesn't have a .tokens field @@ -154,12 +188,12 @@ pub fn expand( let d_ident = first_ident(&meta_item_vec[0]); // The first element of it is the name of the function to be generated - let asdf = ItemKind::Fn(Box::new(ast::Fn { + let asdf = Box::new(ast::Fn { defaultness: ast::Defaultness::Final, sig: d_sig, generics: Generics::default(), body: Some(d_body), - })); + }); let mut rustc_ad_attr = P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff))); let ts2: Vec = vec![ @@ -195,13 +229,6 @@ pub fn expand( style: ast::AttrStyle::Outer, span, }; - // don't add it multiple times: - if !orig_item.attrs.iter().any(|a| a.id == attr.id) { - orig_item.attrs.push(attr.clone()); - } - if !orig_item.attrs.iter().any(|a| a.id == inline_never.id) { - orig_item.attrs.push(inline_never); - } // Now update for d_fn rustc_ad_attr.item.args = rustc_ast::AttrArgs::Delimited(rustc_ast::DelimArgs { @@ -210,13 +237,51 @@ pub fn expand( tokens: ts, }); attr.kind = ast::AttrKind::Normal(rustc_ad_attr); - let mut d_fn = ecx.item(span, d_ident, thin_vec![attr], asdf); - // Copy visibility from original function - d_fn.vis = 
orig_item.vis.clone(); + // Don't add it multiple times: + let orig_annotatable: Annotatable = match item { + Annotatable::Item(ref mut iitem) => { + if !iitem.attrs.iter().any(|a| a.id == attr.id) { + iitem.attrs.push(attr.clone()); + } + if !iitem.attrs.iter().any(|a| a.id == inline_never.id) { + iitem.attrs.push(inline_never.clone()); + } + Annotatable::Item(iitem.clone()) + }, + Annotatable::ImplItem(ref mut assoc_item) => { + if !assoc_item.attrs.iter().any(|a| a.id == attr.id) { + assoc_item.attrs.push(attr.clone()); + } + if !assoc_item.attrs.iter().any(|a| a.id == inline_never.id) { + assoc_item.attrs.push(inline_never.clone()); + } + Annotatable::ImplItem(assoc_item.clone()) + }, + _ => { + panic!("not supported"); + } + }; + + let d_annotatable = if is_impl { + let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf); + let d_fn = P(ast::AssocItem { + attrs: thin_vec![attr.clone(), inline_never], + id: ast::DUMMY_NODE_ID, + span, + vis, + ident: d_ident, + kind: assoc_item, + tokens: None, + }); + Annotatable::ImplItem(d_fn) + } else { + let mut d_fn = ecx.item(span, d_ident, thin_vec![attr.clone()], ItemKind::Fn(asdf)); + d_fn.vis = vis; + Annotatable::Item(d_fn) + }; + trace!("Generated function: {:?}", d_annotatable); - let orig_annotatable = Annotatable::Item(orig_item); - let d_annotatable = Annotatable::Item(d_fn); return vec![orig_annotatable, d_annotatable]; } @@ -403,10 +468,16 @@ fn gen_primal_call( primal: Ident, idents: Vec, ) -> P { - let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal)); - let args = idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect(); - let primal_call = ecx.expr_call(span, primal_call_expr, args); - primal_call + let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower; + if has_self { + let args: ThinVec<_> = idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect(); + let self_expr = ecx.expr_self(span); + ecx.expr_method_call(span, self_expr, 
primal, args.clone()) + } else { + let args: ThinVec<_> = idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect(); + let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal)); + ecx.expr_call(span, primal_call_expr, args) + } } // Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must @@ -427,7 +498,6 @@ fn gen_enzyme_decl( let mut d_decl = sig.decl.clone(); let mut d_inputs = Vec::new(); let mut new_inputs = Vec::new(); - //let mut old_names = Vec::new(); let mut idents = Vec::new(); let mut act_ret = ThinVec::new(); for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) { diff --git a/compiler/rustc_expand/src/build.rs b/compiler/rustc_expand/src/build.rs index 0037267f30be9..ac75c942b2b3f 100644 --- a/compiler/rustc_expand/src/build.rs +++ b/compiler/rustc_expand/src/build.rs @@ -287,6 +287,10 @@ impl<'a> ExtCtxt<'a> { self.expr(sp, ast::ExprKind::Paren(e)) } + pub fn expr_method_call(&self, span: Span, expr: P, ident: Ident, args: ThinVec>) -> P { + let seg = ast::PathSegment::from_ident(ident); + self.expr(span, ast::ExprKind::MethodCall(Box::new(ast::MethodCall { seg, receiver: expr, args, span }))) + } pub fn expr_call( &self, span: Span, From ffdbffb73b8737d9be949b60beb4b01371ce0bd2 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Thu, 11 Apr 2024 19:14:42 -0400 Subject: [PATCH 086/100] add rust error for activity to arg count missmatch (#111) --- compiler/rustc_builtin_macros/messages.ftl | 1 + compiler/rustc_builtin_macros/src/autodiff.rs | 9 +++++++++ compiler/rustc_builtin_macros/src/errors.rs | 9 ++++++++- 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/compiler/rustc_builtin_macros/messages.ftl b/compiler/rustc_builtin_macros/messages.ftl index 5a261700dfe90..d5d7562f36a11 100644 --- a/compiler/rustc_builtin_macros/messages.ftl +++ b/compiler/rustc_builtin_macros/messages.ftl @@ -5,6 +5,7 @@ builtin_macros_autodiff_unknown_activity = did not 
recognize activity {$act} builtin_macros_autodiff = autodiff must be applied to function builtin_macros_autodiff_not_build = this rustc version does not support autodiff builtin_macros_autodiff_mode_activity = {$act} can not be used in {$mode} Mode +builtin_macros_autodiff_number_activities = expected {$expected} activities, but found {$found} builtin_macros_autodiff_mode = unknown Mode: `{$mode}`. Use `Forward` or `Reverse` builtin_macros_autodiff_ty_activity = {$act} can not be used for this type diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 9787d8ffe7aa7..c0f6763dc3549 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -493,6 +493,15 @@ fn gen_enzyme_decl( x: &AutoDiffAttrs, span: Span, ) -> (ast::FnSig, Vec, Vec) { + let sig_args = sig.decl.inputs.len() + if sig.decl.output.has_ret() { 1 } else { 0 }; + let num_activities = x.input_activity.len() + if x.has_ret_activity() { 1 } else { 0 }; + if sig_args != num_activities { + ecx.sess.dcx().emit_fatal(errors::AutoDiffInvalidNumberActivities { + span, + expected: sig_args, + found: num_activities, + }); + } assert!(sig.decl.inputs.len() == x.input_activity.len()); assert!(sig.decl.output.has_ret() == x.has_ret_activity()); let mut d_decl = sig.decl.clone(); diff --git a/compiler/rustc_builtin_macros/src/errors.rs b/compiler/rustc_builtin_macros/src/errors.rs index 03321e889749a..9d4e64027c14f 100644 --- a/compiler/rustc_builtin_macros/src/errors.rs +++ b/compiler/rustc_builtin_macros/src/errors.rs @@ -177,7 +177,14 @@ pub(crate) struct AutoDiffInvalidTypeForActivity { pub(crate) span: Span, pub(crate) act: String, } - +#[derive(Diagnostic)] +#[diag(builtin_macros_autodiff_number_activities)] +pub(crate) struct AutoDiffInvalidNumberActivities { + #[primary_span] + pub(crate) span: Span, + pub(crate) expected: usize, + pub(crate) found: usize, +} #[derive(Diagnostic)] 
#[diag(builtin_macros_autodiff_mode_activity)] pub(crate) struct AutoDiffInvalidApplicationModeAct { From 17c772f90c5e8b45725ffdc10e0a00ad221e3895 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Fri, 12 Apr 2024 15:40:04 -0400 Subject: [PATCH 087/100] unbreak ActiveOnly return (#112) --- .../rustc_ast/src/expand/autodiff_attrs.rs | 7 ++ compiler/rustc_builtin_macros/src/autodiff.rs | 93 +++++++++++++------ compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 1 + 3 files changed, 72 insertions(+), 29 deletions(-) diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs index 927cbfc50b681..8fb5e41577330 100644 --- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -167,6 +167,7 @@ impl FromStr for DiffActivity { match s { "None" => Ok(DiffActivity::None), "Active" => Ok(DiffActivity::Active), + "ActiveOnly" => Ok(DiffActivity::ActiveOnly), "Const" => Ok(DiffActivity::Const), "Dual" => Ok(DiffActivity::Dual), "DualOnly" => Ok(DiffActivity::DualOnly), @@ -192,6 +193,12 @@ impl AutoDiffAttrs { _ => true, } } + pub fn has_active_only_ret(&self) -> bool { + match self.ret_activity { + DiffActivity::ActiveOnly => true, + _ => false, + } + } } impl AutoDiffAttrs { diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index c0f6763dc3549..83ce5bac25cf9 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -111,7 +111,6 @@ pub fn expand( (sig.clone(), true) }, _ => { - dbg!(&item); ecx.sess.dcx().emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); return vec![item]; } @@ -280,7 +279,6 @@ pub fn expand( d_fn.vis = vis; Annotatable::Item(d_fn) }; - trace!("Generated function: {:?}", d_annotatable); return vec![orig_annotatable, d_annotatable]; } @@ -371,7 +369,9 @@ fn gen_enzyme_body( return body; } - let primal_ret = sig.decl.output.has_ret(); 
+ // having an active-only return means we'll drop the original return type. + // So that can be treated identical to not having one in the first place. + let primal_ret = sig.decl.output.has_ret() && !x.has_active_only_ret(); if primal_ret && n_active == 0 && is_rev(x.mode) { // We only have the primal ret. @@ -405,16 +405,26 @@ fn gen_enzyme_body( // Now construct default placeholder for each active float. // Is there something nicer than f32::default() and f64::default()? - let mut d_ret_ty = match d_sig.decl.output { + let d_ret_ty = match d_sig.decl.output { FnRetTy::Ty(ref ty) => ty.clone(), FnRetTy::Default(span) => { panic!("Did not expect Default ret ty: {:?}", span); } }; - let mut d_ret_ty = match d_ret_ty.kind { - TyKind::Tup(ref mut tys) => { + let mut d_ret_ty = match d_ret_ty.kind.clone() { + TyKind::Tup(ref tys) => { tys.clone() } + TyKind::Path(_, rustc_ast::Path { segments, .. }) => { + if segments.len() == 1 && segments[0].args.is_none() { + let id = vec![segments[0].ident]; + let kind = TyKind::Path(None, ecx.path(span, id)); + let ty = P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None }); + thin_vec![ty] + } else { + panic!("Expected tuple or simple path return type"); + } + } _ => { // We messed up construction of d_sig panic!("Did not expect non-tuple ret ty: {:?}", d_ret_ty); @@ -585,33 +595,41 @@ fn gen_enzyme_decl( } } + let active_only_ret = x.ret_activity == DiffActivity::ActiveOnly; + if active_only_ret { + assert!(is_rev(x.mode)); + } + // If we return a scalar in the primal and the scalar is active, // then add it as last arg to the inputs. 
if is_rev(x.mode) { - if let DiffActivity::Active = x.ret_activity { - let ty = match d_decl.output { - FnRetTy::Ty(ref ty) => ty.clone(), - FnRetTy::Default(span) => { - panic!("Did not expect Default ret ty: {:?}", span); - } - }; - let name = "dret".to_string(); - let ident = Ident::from_str_and_span(&name, ty.span); - let shadow_arg = ast::Param { - attrs: ThinVec::new(), - ty: ty.clone(), - pat: P(ast::Pat { + match x.ret_activity { + DiffActivity::Active | DiffActivity::ActiveOnly => { + let ty = match d_decl.output { + FnRetTy::Ty(ref ty) => ty.clone(), + FnRetTy::Default(span) => { + panic!("Did not expect Default ret ty: {:?}", span); + } + }; + let name = "dret".to_string(); + let ident = Ident::from_str_and_span(&name, ty.span); + let shadow_arg = ast::Param { + attrs: ThinVec::new(), + ty: ty.clone(), + pat: P(ast::Pat { + id: ast::DUMMY_NODE_ID, + kind: PatKind::Ident(BindingAnnotation::NONE, ident, None), + span: ty.span, + tokens: None, + }), id: ast::DUMMY_NODE_ID, - kind: PatKind::Ident(BindingAnnotation::NONE, ident, None), span: ty.span, - tokens: None, - }), - id: ast::DUMMY_NODE_ID, - span: ty.span, - is_placeholder: false, - }; - d_inputs.push(shadow_arg); - new_inputs.push(name); + is_placeholder: false, + }; + d_inputs.push(shadow_arg); + new_inputs.push(name); + } + _ => {} } } d_decl.inputs = d_inputs.into(); @@ -630,15 +648,31 @@ fn gen_enzyme_decl( let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None }); d_decl.output = FnRetTy::Ty(ty); } + if let DiffActivity::DualOnly = x.ret_activity { + // No need to change the return type, + // we will just return the shadow in place + // of the primal return. + } } + // If we use ActiveOnly, drop the original return value. + d_decl.output = if active_only_ret { + FnRetTy::Default(span) + } else { + d_decl.output.clone() + }; + + trace!("act_ret: {:?}", act_ret); + // If we have an active input scalar, add it's gradient to the // return type. 
This might require changing the return type to a // tuple. if act_ret.len() > 0 { let ret_ty = match d_decl.output { FnRetTy::Ty(ref ty) => { - act_ret.insert(0, ty.clone()); + if !active_only_ret { + act_ret.insert(0, ty.clone()); + } let kind = TyKind::Tup(act_ret); P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None }) } @@ -655,5 +689,6 @@ fn gen_enzyme_decl( } let d_sig = FnSig { header: sig.header.clone(), decl: d_decl, span }; + trace!("Generated signature: {:?}", d_sig); (d_sig, new_inputs, idents) } diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index ed8b4be25dc19..69049ca752d26 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -940,6 +940,7 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff( let (primary_ret, ret_activity) = match ret_activity { DiffActivity::Const => (true, CDIFFE_TYPE::DFT_CONSTANT), DiffActivity::Active => (true, CDIFFE_TYPE::DFT_OUT_DIFF), + DiffActivity::ActiveOnly => (false, CDIFFE_TYPE::DFT_OUT_DIFF), DiffActivity::None => (false, CDIFFE_TYPE::DFT_CONSTANT), _ => panic!("Invalid return activity"), }; From 9a411dca17d2e04ff87895d72e6616041c4a7052 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Fri, 12 Apr 2024 17:54:11 -0400 Subject: [PATCH 088/100] handle DualOnly ret more reliably (#113) --- compiler/rustc_builtin_macros/src/autodiff.rs | 18 +++++++++++++----- compiler/rustc_middle/src/ty/mod.rs | 1 + 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 83ce5bac25cf9..2d2ef62e8fba9 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -461,12 +461,20 @@ fn gen_enzyme_body( }; } - let ret_tuple: P = ecx.expr_tuple(span, exprs); - let ret = ecx.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![ret_tuple]); - if 
d_sig.decl.output.has_ret() { - // If we return (), we don't have to match the return type. - body.stmts.push(ecx.stmt_expr(ret)); + let ret : P; + if exprs.len() > 1 { + let ret_tuple: P = ecx.expr_tuple(span, exprs); + ret = ecx.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![ret_tuple]); + } else if exprs.len() == 1 { + let ret_scal = exprs.pop().unwrap(); + ret = ecx.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![ret_scal]); + } else { + assert!(!d_sig.decl.output.has_ret()); + // We don't have to match the return type. + return body; } + assert!(d_sig.decl.output.has_ret()); + body.stmts.push(ecx.stmt_expr(ret)); body } diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index ed0b08f823367..e15c35d6b0639 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -2762,6 +2762,7 @@ pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec Date: Sat, 13 Apr 2024 16:54:03 -0400 Subject: [PATCH 089/100] last commit broke the right order of updating attributes (#114) --- compiler/rustc_builtin_macros/src/autodiff.rs | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 2d2ef62e8fba9..7f47b62776e93 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -229,14 +229,6 @@ pub fn expand( span, }; - // Now update for d_fn - rustc_ad_attr.item.args = rustc_ast::AttrArgs::Delimited(rustc_ast::DelimArgs { - dspan: DelimSpan::dummy(), - delim: rustc_ast::token::Delimiter::Parenthesis, - tokens: ts, - }); - attr.kind = ast::AttrKind::Normal(rustc_ad_attr); - // Don't add it multiple times: let orig_annotatable: Annotatable = match item { Annotatable::Item(ref mut iitem) => { @@ -262,6 +254,14 @@ pub fn expand( } }; + // Now update for d_fn + rustc_ad_attr.item.args = 
rustc_ast::AttrArgs::Delimited(rustc_ast::DelimArgs { + dspan: DelimSpan::dummy(), + delim: rustc_ast::token::Delimiter::Parenthesis, + tokens: ts, + }); + attr.kind = ast::AttrKind::Normal(rustc_ad_attr); + let d_annotatable = if is_impl { let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf); let d_fn = P(ast::AssocItem { @@ -275,7 +275,7 @@ pub fn expand( }); Annotatable::ImplItem(d_fn) } else { - let mut d_fn = ecx.item(span, d_ident, thin_vec![attr.clone()], ItemKind::Fn(asdf)); + let mut d_fn = ecx.item(span, d_ident, thin_vec![attr.clone(), inline_never], ItemKind::Fn(asdf)); d_fn.vis = vis; Annotatable::Item(d_fn) }; From 5747648b7e6cb35c750f72b5d46a1b889e0ca921 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sun, 14 Apr 2024 13:50:11 -0400 Subject: [PATCH 090/100] fix higher-order { float } ret case (#115) --- compiler/rustc_codegen_llvm/src/back/write.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index c08dbac0e7a79..a362f1640c2e0 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -840,7 +840,6 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool, c_inner_fnc_name.as_ptr(), ); - // Add dummy dbg info to our newly generated call, if we have any. 
let inst = LLVMRustgetFirstNonPHIOrDbgOrLifetime(bb).unwrap(); let md_ty = llvm::LLVMGetMDKindIDInContext( @@ -863,7 +862,7 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool, let f_return_type = LLVMGetReturnType(LLVMGlobalGetValueType(src)); let void_type = LLVMVoidTypeInContext(llcx); // Now unwrap the struct_ret if it's actually a struct - if rev_mode && f_return_type != void_type { + if f_return_type != void_type { let num_elem_in_ret_struct = LLVMCountStructElementTypes(f_return_type); if num_elem_in_ret_struct == 1 { let inner_grad_name = "foo".to_string(); From 15bfd28f14ec7114dc40d7a3e69d4db1b1813039 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sun, 14 Apr 2024 15:20:58 -0400 Subject: [PATCH 091/100] fix fwd test case (#116) * fix fwd test case * simplify --- compiler/rustc_codegen_llvm/src/back/write.rs | 3 ++- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 1 + compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp | 4 ++++ 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index a362f1640c2e0..657db58831c1b 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -860,9 +860,10 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool, LLVMRustEraseInstBefore(bb, last_inst); let f_return_type = LLVMGetReturnType(LLVMGlobalGetValueType(src)); + let f_is_struct = llvm::LLVMRustIsStructType(f_return_type); let void_type = LLVMVoidTypeInContext(llcx); // Now unwrap the struct_ret if it's actually a struct - if f_return_type != void_type { + if f_is_struct { let num_elem_in_ret_struct = LLVMCountStructElementTypes(f_return_type); if num_elem_in_ret_struct == 1 { let inner_grad_name = "foo".to_string(); diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 69049ca752d26..5b94b80502109 100644 --- 
a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -1035,6 +1035,7 @@ extern "C" { pub fn LLVMRustEraseInstFromParent(V: &Value); pub fn LLVMRustGetTerminator<'a>(B: &BasicBlock) -> &'a Value; pub fn LLVMGetReturnType(T: &Type) -> &Type; + pub fn LLVMRustIsStructType(T: &Type) -> bool; pub fn LLVMDumpModule(M: &Module); pub fn LLVMCountStructElementTypes(T: &Type) -> c_uint; pub fn LLVMDeleteFunction(V: &Value); diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index 548040579b392..078c8918939b0 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -300,6 +300,10 @@ extern "C" void LLVMRustAddFunctionAttributes(LLVMValueRef Fn, unsigned Index, AddAttributes(F, Index, Attrs, AttrsLen); } +extern "C" bool LLVMRustIsStructType(LLVMTypeRef Ty) { + return unwrap(Ty)->isStructTy(); +} + extern "C" void LLVMRustAddCallSiteAttributes(LLVMValueRef Instr, unsigned Index, LLVMAttributeRef *Attrs, From b445b6c3048b258113003b63a7b6b750334f0f8c Mon Sep 17 00:00:00 2001 From: Jed Brown Date: Fri, 19 Apr 2024 15:29:32 -0600 Subject: [PATCH 092/100] ci: update rustbook tests to use ENYZME_LOOSE_TYPES (#117) * ci: update rustbook tests to use ENYZME_LOOSE_TYPES * Update .github/workflows/enzyme-ci.yml Co-authored-by: Jed Brown --------- Co-authored-by: Manuel Drehwald --- .github/workflows/enzyme-ci.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/enzyme-ci.yml b/.github/workflows/enzyme-ci.yml index a08267fcdb725..fa44fb0eacbf0 100644 --- a/.github/workflows/enzyme-ci.yml +++ b/.github/workflows/enzyme-ci.yml @@ -62,4 +62,5 @@ jobs: - name: test Enzyme/rustbook working-directory: rustbook run: | - cargo +enzyme test --workspace + cargo +enzyme test + ENZYME_LOOSE_TYPES=1 cargo +enzyme test -p samples-loose-types From 4cae48171d92c3e369b55c98574dc9abf5f4177c Mon Sep 17 
00:00:00 2001 From: Manuel Drehwald Date: Fri, 19 Apr 2024 18:32:53 -0400 Subject: [PATCH 093/100] Update enzyme-ci.yml (#84) * Update enzyme-ci.yml We don't need to build LLD (takes 2.5 min), we just want to use some LLD for linking. We also don't need llvm plugins. * Update enzyme-ci.yml * Update .github/workflows/enzyme-ci.yml Co-authored-by: Jed Brown * Update enzyme-ci.yml * Update enzyme-ci.yml or will it be --enable-use-lld? stay tuned. * Update enzyme-ci.yml * Update enzyme-ci.yml authored-by: @jedbrown * Update enzyme-ci.yml * Update enzyme-ci.yml * Update .github/workflows/enzyme-ci.yml Co-authored-by: Tim Gymnich * Update enzyme-ci.yml * Use vendored llvm * Update enzyme-ci.yml * Update enzyme-ci.yml * Update enzyme-ci.yml * Update enzyme-ci.yml * Revert "Update enzyme-ci.yml" This reverts commit 9ec11394556c7db454a69183d15fa36524380b78. * Revert "Update enzyme-ci.yml" This reverts commit 84fe2693025b13f34a10f69550bef76270ace172. * Fxing lld rebuild Co-authored-by: I-Al-Istannen --------- Co-authored-by: I-Al-Istannen Co-authored-by: Jed Brown Co-authored-by: Tim Gymnich Co-authored-by: William Moses Co-authored-by: Tim Gymnich --- .github/workflows/enzyme-ci.yml | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/.github/workflows/enzyme-ci.yml b/.github/workflows/enzyme-ci.yml index fa44fb0eacbf0..c9a3b8245851f 100644 --- a/.github/workflows/enzyme-ci.yml +++ b/.github/workflows/enzyme-ci.yml @@ -47,10 +47,20 @@ jobs: key: ${{ matrix.os }}-enzyme-${{ steps.enzyme-commit.outputs.HEAD }} - name: Build run: | + wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add - + sudo apt-add-repository "deb http://apt.llvm.org/`lsb_release -c | cut -f2`/ llvm-toolchain-`lsb_release -c | cut -f2`-17 main" || true + sudo apt-get -y update + sudo apt-get install -y lld-17 + mkdir lld-path-manipulation + ln -s "$(which lld-17)" lld-path-manipulation/lld + ln -s "$(which lld-17)" lld-path-manipulation/ld + ln -s "$(which 
lld-17)" lld-path-manipulation/ld.lld + ln -s "$(which lld-17)" lld-path-manipulation/lld-17 + export PATH="$PWD/lld-path-manipulation:$PATH" mkdir -p build cd build rm -f config.toml - ../configure --enable-llvm-link-shared --enable-llvm-plugins --enable-llvm-enzyme --release-channel=nightly --enable-llvm-assertions --enable-lld --enable-option-checking --enable-ninja --disable-docs + ../configure --enable-llvm-link-shared --enable-llvm-enzyme --set=rust.use-lld=true --release-channel=nightly --enable-llvm-assertions --enable-option-checking --enable-ninja --disable-docs ../x.py build --stage 1 library/std library/proc_macro library/test tools/rustdoc rustup toolchain link enzyme build/host/stage1 - name: checkout Enzyme/rustbook From c9b28e79153a17630d03a1fc4e672eeff2179e21 Mon Sep 17 00:00:00 2001 From: I-Al-Istannen Date: Mon, 22 Apr 2024 16:56:28 +0200 Subject: [PATCH 094/100] chore: Cache random stuff (#118) * Be a bit more aggressive with caching --- .github/workflows/enzyme-ci.yml | 14 ++++++++++++++ src/bootstrap/src/core/build_steps/llvm.rs | 7 ++++++- src/bootstrap/src/lib.rs | 11 ++++++++++- 3 files changed, 30 insertions(+), 2 deletions(-) diff --git a/.github/workflows/enzyme-ci.yml b/.github/workflows/enzyme-ci.yml index c9a3b8245851f..19aa039478b39 100644 --- a/.github/workflows/enzyme-ci.yml +++ b/.github/workflows/enzyme-ci.yml @@ -45,6 +45,20 @@ jobs: with: path: build/build/x86_64-unknown-linux-gnu/enzyme key: ${{ matrix.os }}-enzyme-${{ steps.enzyme-commit.outputs.HEAD }} + - name: Cache bootstrap/stage0 artifacts for incremental builds + uses: actions/cache@v4 + with: + path: | + build/build/bootstrap/ + build/build/x86_64-unknown-linux-gnu/stage0-rustc/ + build/build/x86_64-unknown-linux-gnu/stage0-std/ + build/build/x86_64-unknown-linux-gnu/stage0-tools/ + build/build/x86_64-unknown-linux-gnu/stage1-std/ + # Approximate stable hash. 
It doesn't matter too much when this goes out of sync as it just caches + # some stage0/stage1 dependencies and stdlibs which *hopefully* are hash-keyed. + key: enzyme-rust-incremental-${{ runner.os }}-${{ hashFiles('src/**/Cargo.lock', 'Cargo.lock') }} + restore-keys: | + enzyme-rust-incremental-${{ runner.os }} - name: Build run: | wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add - diff --git a/src/bootstrap/src/core/build_steps/llvm.rs b/src/bootstrap/src/core/build_steps/llvm.rs index 6736b985ebeba..57f15faca1315 100644 --- a/src/bootstrap/src/core/build_steps/llvm.rs +++ b/src/bootstrap/src/core/build_steps/llvm.rs @@ -1211,7 +1211,12 @@ impl HashStamp { fn is_done(&self) -> bool { match fs::read(&self.path) { - Ok(h) => self.hash.as_deref().unwrap_or(b"") == h.as_slice(), + Ok(h) => { + let unwrapped = self.hash.as_deref().unwrap_or(b""); + let res = unwrapped == h.as_slice(); + eprintln!("Result for {:?}: {res:?} for expected '{unwrapped:?}' and read '{h:?}'", self.path); + res + }, Err(e) if e.kind() == io::ErrorKind::NotFound => false, Err(e) => { panic!("failed to read stamp file `{}`: {}", self.path.display(), e); diff --git a/src/bootstrap/src/lib.rs b/src/bootstrap/src/lib.rs index a733e55c8f9d6..0726fc6ced40f 100644 --- a/src/bootstrap/src/lib.rs +++ b/src/bootstrap/src/lib.rs @@ -1866,6 +1866,9 @@ pub fn generate_smart_stamp_hash(dir: &Path, additional_input: &str) -> String { .map(|o| String::from_utf8(o.stdout).unwrap_or_default()) .unwrap_or_default(); + eprintln!("Computing stamp for {dir:?}"); + eprintln!("Diff output: {diff:?}"); + let status = Command::new("git") .current_dir(dir) .arg("status") @@ -1876,13 +1879,19 @@ pub fn generate_smart_stamp_hash(dir: &Path, additional_input: &str) -> String { .map(|o| String::from_utf8(o.stdout).unwrap_or_default()) .unwrap_or_default(); + eprintln!("Status output: {status:?}"); + eprintln!("Additional input: {additional_input:?}"); let mut hasher = sha2::Sha256::new(); 
hasher.update(diff); hasher.update(status); hasher.update(additional_input); - hex_encode(hasher.finalize().as_slice()) + let result = hex_encode(hasher.finalize().as_slice()); + + eprintln!("Final hash: {result:?}"); + + result } /// Ensures that the behavior dump directory is properly initialized. From c00250e65e89e25f89b362b9dbfda3608f0a4b57 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Tue, 23 Apr 2024 16:12:02 -0400 Subject: [PATCH 095/100] fix ci (#119) --- .github/workflows/enzyme-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/enzyme-ci.yml b/.github/workflows/enzyme-ci.yml index 19aa039478b39..2016b83462c2a 100644 --- a/.github/workflows/enzyme-ci.yml +++ b/.github/workflows/enzyme-ci.yml @@ -74,7 +74,7 @@ jobs: mkdir -p build cd build rm -f config.toml - ../configure --enable-llvm-link-shared --enable-llvm-enzyme --set=rust.use-lld=true --release-channel=nightly --enable-llvm-assertions --enable-option-checking --enable-ninja --disable-docs + ../configure --enable-llvm-link-shared --enable-llvm-plugins --enable-llvm-enzyme --set=rust.use-lld=true --release-channel=nightly --enable-llvm-assertions --enable-option-checking --enable-ninja --disable-docs ../x.py build --stage 1 library/std library/proc_macro library/test tools/rustdoc rustup toolchain link enzyme build/host/stage1 - name: checkout Enzyme/rustbook From 75b31f9239243eabe0dba0884999eb4426e2269c Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Wed, 1 May 2024 18:25:02 -0400 Subject: [PATCH 096/100] Update README.md (#120) --- README.md | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 594c69d0453be..c147b0262f2e8 100644 --- a/README.md +++ b/README.md @@ -25,14 +25,10 @@ rustup toolchain link enzyme build/host/stage1 rustup toolchain install nightly # enables -Z unstable-options ``` -You can then look at examples in the `library/autodiff/examples/*` folder and run them with +You can then run 
an examples from our [docs](https://enzyme.mit.edu/index.fcgi/rust/usage/usage.html) using ```bash -# rosenbrock forward iteration -cargo +enzyme run --example rosenbrock_fwd_iter --release - -# or all of them -cargo +enzyme test --examples +cargo +enzyme run --release ``` ## Enzyme Config From fc4ab22afa28d6034f9070246bc7afcc3bec380d Mon Sep 17 00:00:00 2001 From: I-Al-Istannen Date: Wed, 8 May 2024 15:44:30 +0200 Subject: [PATCH 097/100] ci: Print if stamp is missing --- src/bootstrap/src/core/build_steps/llvm.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/bootstrap/src/core/build_steps/llvm.rs b/src/bootstrap/src/core/build_steps/llvm.rs index 57f15faca1315..bc3cb6550c3f6 100644 --- a/src/bootstrap/src/core/build_steps/llvm.rs +++ b/src/bootstrap/src/core/build_steps/llvm.rs @@ -1217,7 +1217,10 @@ impl HashStamp { eprintln!("Result for {:?}: {res:?} for expected '{unwrapped:?}' and read '{h:?}'", self.path); res }, - Err(e) if e.kind() == io::ErrorKind::NotFound => false, + Err(e) if e.kind() == io::ErrorKind::NotFound => { + eprintln!("No existing stamp found at {:?}", self.path); + false + }, Err(e) => { panic!("failed to read stamp file `{}`: {}", self.path.display(), e); } From 9e4c0009b5c62dfeb11922e8c071820bc582d207 Mon Sep 17 00:00:00 2001 From: I-Al-Istannen Date: Fri, 10 May 2024 11:22:59 +0200 Subject: [PATCH 098/100] ci: Add tmate --- .github/workflows/enzyme-ci.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/enzyme-ci.yml b/.github/workflows/enzyme-ci.yml index 2016b83462c2a..59ac706f5406c 100644 --- a/.github/workflows/enzyme-ci.yml +++ b/.github/workflows/enzyme-ci.yml @@ -59,6 +59,8 @@ jobs: key: enzyme-rust-incremental-${{ runner.os }}-${{ hashFiles('src/**/Cargo.lock', 'Cargo.lock') }} restore-keys: | enzyme-rust-incremental-${{ runner.os }} + - name: Setup tmate session + uses: mxschmitt/action-tmate@v3 - name: Build run: | wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key 
add - From bd475ebc1d49be3db4ad09740b678233e91cc754 Mon Sep 17 00:00:00 2001 From: I-Al-Istannen Date: Fri, 10 May 2024 11:23:44 +0200 Subject: [PATCH 099/100] ci: Do not run further --- .github/workflows/enzyme-ci.yml | 60 +++++++++++++++++---------------- 1 file changed, 31 insertions(+), 29 deletions(-) diff --git a/.github/workflows/enzyme-ci.yml b/.github/workflows/enzyme-ci.yml index 59ac706f5406c..6bc9298f3bcf3 100644 --- a/.github/workflows/enzyme-ci.yml +++ b/.github/workflows/enzyme-ci.yml @@ -61,32 +61,34 @@ jobs: enzyme-rust-incremental-${{ runner.os }} - name: Setup tmate session uses: mxschmitt/action-tmate@v3 - - name: Build - run: | - wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add - - sudo apt-add-repository "deb http://apt.llvm.org/`lsb_release -c | cut -f2`/ llvm-toolchain-`lsb_release -c | cut -f2`-17 main" || true - sudo apt-get -y update - sudo apt-get install -y lld-17 - mkdir lld-path-manipulation - ln -s "$(which lld-17)" lld-path-manipulation/lld - ln -s "$(which lld-17)" lld-path-manipulation/ld - ln -s "$(which lld-17)" lld-path-manipulation/ld.lld - ln -s "$(which lld-17)" lld-path-manipulation/lld-17 - export PATH="$PWD/lld-path-manipulation:$PATH" - mkdir -p build - cd build - rm -f config.toml - ../configure --enable-llvm-link-shared --enable-llvm-plugins --enable-llvm-enzyme --set=rust.use-lld=true --release-channel=nightly --enable-llvm-assertions --enable-option-checking --enable-ninja --disable-docs - ../x.py build --stage 1 library/std library/proc_macro library/test tools/rustdoc - rustup toolchain link enzyme build/host/stage1 - - name: checkout Enzyme/rustbook - uses: actions/checkout@v4 - with: - repository: EnzymeAD/rustbook - ref: main - path: rustbook - - name: test Enzyme/rustbook - working-directory: rustbook - run: | - cargo +enzyme test - ENZYME_LOOSE_TYPES=1 cargo +enzyme test -p samples-loose-types + - name: Fail + run: 'false' + # - name: Build + # run: | + # wget -O - 
https://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add - + # sudo apt-add-repository "deb http://apt.llvm.org/`lsb_release -c | cut -f2`/ llvm-toolchain-`lsb_release -c | cut -f2`-17 main" || true + # sudo apt-get -y update + # sudo apt-get install -y lld-17 + # mkdir lld-path-manipulation + # ln -s "$(which lld-17)" lld-path-manipulation/lld + # ln -s "$(which lld-17)" lld-path-manipulation/ld + # ln -s "$(which lld-17)" lld-path-manipulation/ld.lld + # ln -s "$(which lld-17)" lld-path-manipulation/lld-17 + # export PATH="$PWD/lld-path-manipulation:$PATH" + # mkdir -p build + # cd build + # rm -f config.toml + # ../configure --enable-llvm-link-shared --enable-llvm-plugins --enable-llvm-enzyme --set=rust.use-lld=true --release-channel=nightly --enable-llvm-assertions --enable-option-checking --enable-ninja --disable-docs + # ../x.py build --stage 1 library/std library/proc_macro library/test tools/rustdoc + # rustup toolchain link enzyme build/host/stage1 + # - name: checkout Enzyme/rustbook + # uses: actions/checkout@v4 + # with: + # repository: EnzymeAD/rustbook + # ref: main + # path: rustbook + # - name: test Enzyme/rustbook + # working-directory: rustbook + # run: | + # cargo +enzyme test + # ENZYME_LOOSE_TYPES=1 cargo +enzyme test -p samples-loose-types From 0ebffb1b8975f89b6d8a165feda2f940a105f756 Mon Sep 17 00:00:00 2001 From: I-Al-Istannen Date: Fri, 10 May 2024 11:34:36 +0200 Subject: [PATCH 100/100] ci: create cache folders manually? 
--- .github/workflows/enzyme-ci.yml | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/.github/workflows/enzyme-ci.yml b/.github/workflows/enzyme-ci.yml index 6bc9298f3bcf3..4e87085ebc92a 100644 --- a/.github/workflows/enzyme-ci.yml +++ b/.github/workflows/enzyme-ci.yml @@ -27,6 +27,15 @@ jobs: submodules: true # check out all submodules so the cache can work correctly fetch-depth: 2 - uses: dtolnay/rust-toolchain@nightly + - name: Create cache folders + run: | + mkdir -p build/build/x86_64-unknown-linux-gnu/llvm + mkdir -p build/build/x86_64-unknown-linux-gnu/enzyme + mkdir -p build/build/bootstrap/ + mkdir -p build/build/x86_64-unknown-linux-gnu/stage0-rustc/ + mkdir -p build/build/x86_64-unknown-linux-gnu/stage0-std/ + mkdir -p build/build/x86_64-unknown-linux-gnu/stage0-tools/ + mkdir -p build/build/x86_64-unknown-linux-gnu/stage1-std/ - name: Get LLVM commit hash id: llvm-commit run: echo "HEAD=$(git rev-parse HEAD:src/llvm-project)" >> $GITHUB_OUTPUT