Skip to content

Commit

Permalink
Single commit with all changes
Browse files Browse the repository at this point in the history
  • Loading branch information
ZuseZ4 committed Aug 16, 2024
1 parent 730d5d4 commit 808353c
Show file tree
Hide file tree
Showing 80 changed files with 4,195 additions and 56 deletions.
4 changes: 4 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,7 @@
path = src/tools/rustc-perf
url = https://github.com/rust-lang/rustc-perf.git
shallow = true
[submodule "src/tools/enzyme"]
path = src/tools/enzyme
url = [email protected]:EnzymeAD/Enzyme.git
shallow = true
2 changes: 2 additions & 0 deletions Cargo.lock
Original file line number Diff line number Diff line change
Expand Up @@ -4148,6 +4148,7 @@ dependencies = [
name = "rustc_monomorphize"
version = "0.0.0"
dependencies = [
"rustc_ast",
"rustc_data_structures",
"rustc_errors",
"rustc_fluent_macro",
Expand All @@ -4156,6 +4157,7 @@ dependencies = [
"rustc_middle",
"rustc_session",
"rustc_span",
"rustc_symbol_mangling",
"rustc_target",
"serde",
"serde_json",
Expand Down
7 changes: 7 additions & 0 deletions compiler/rustc_ast/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2729,6 +2729,13 @@ impl FnRetTy {
FnRetTy::Ty(ty) => ty.span,
}
}

pub fn has_ret(&self) -> bool {
match self {
FnRetTy::Default(_) => false,
FnRetTy::Ty(_) => true,
}
}
}

#[derive(Clone, Copy, PartialEq, Encodable, Decodable, Debug)]
Expand Down
269 changes: 269 additions & 0 deletions compiler/rustc_ast/src/expand/autodiff_attrs.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,269 @@
use std::fmt::{self, Display, Formatter};
use std::str::FromStr;

use crate::expand::typetree::TypeTree;
use crate::expand::{Decodable, Encodable, HashStable_Generic};
use crate::ptr::P;
use crate::{Ty, TyKind};

#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub enum DiffMode {
Inactive,
Source,
Forward,
Reverse,
ForwardFirst,
ReverseFirst,
}

pub fn is_rev(mode: DiffMode) -> bool {
match mode {
DiffMode::Reverse | DiffMode::ReverseFirst => true,
_ => false,
}
}
pub fn is_fwd(mode: DiffMode) -> bool {
match mode {
DiffMode::Forward | DiffMode::ForwardFirst => true,
_ => false,
}
}

impl Display for DiffMode {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match self {
DiffMode::Inactive => write!(f, "Inactive"),
DiffMode::Source => write!(f, "Source"),
DiffMode::Forward => write!(f, "Forward"),
DiffMode::Reverse => write!(f, "Reverse"),
DiffMode::ForwardFirst => write!(f, "ForwardFirst"),
DiffMode::ReverseFirst => write!(f, "ReverseFirst"),
}
}
}

pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
if activity == DiffActivity::None {
// Only valid if primal returns (), but we can't check that here.
return true;
}
match mode {
DiffMode::Inactive => false,
DiffMode::Source => false,
DiffMode::Forward | DiffMode::ForwardFirst => {
activity == DiffActivity::Dual
|| activity == DiffActivity::DualOnly
|| activity == DiffActivity::Const
}
DiffMode::Reverse | DiffMode::ReverseFirst => {
activity == DiffActivity::Const
|| activity == DiffActivity::Active
|| activity == DiffActivity::ActiveOnly
}
}
}
fn is_ptr_or_ref(ty: &Ty) -> bool {
match ty.kind {
TyKind::Ptr(_) | TyKind::Ref(_, _) => true,
_ => false,
}
}
// TODO We should make this more robust to also
// accept aliases of f32 and f64
//fn is_float(ty: &Ty) -> bool {
// false
//}
pub fn valid_ty_for_activity(ty: &P<Ty>, activity: DiffActivity) -> bool {
if is_ptr_or_ref(ty) {
return activity == DiffActivity::Dual
|| activity == DiffActivity::DualOnly
|| activity == DiffActivity::Duplicated
|| activity == DiffActivity::DuplicatedOnly
|| activity == DiffActivity::Const;
}
true
//if is_scalar_ty(&ty) {
// return activity == DiffActivity::Active || activity == DiffActivity::ActiveOnly ||
// activity == DiffActivity::Const;
//}
}
pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
return match mode {
DiffMode::Inactive => false,
DiffMode::Source => false,
DiffMode::Forward | DiffMode::ForwardFirst => {
// These are the only valid cases
activity == DiffActivity::Dual
|| activity == DiffActivity::DualOnly
|| activity == DiffActivity::Const
}
DiffMode::Reverse | DiffMode::ReverseFirst => {
// These are the only valid cases
activity == DiffActivity::Active
|| activity == DiffActivity::ActiveOnly
|| activity == DiffActivity::Const
|| activity == DiffActivity::Duplicated
|| activity == DiffActivity::DuplicatedOnly
}
};
}
pub fn invalid_input_activities(mode: DiffMode, activity_vec: &[DiffActivity]) -> Option<usize> {
for i in 0..activity_vec.len() {
if !valid_input_activity(mode, activity_vec[i]) {
return Some(i);
}
}
None
}

#[allow(dead_code)]
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub enum DiffActivity {
None,
Const,
Active,
ActiveOnly,
Dual,
DualOnly,
Duplicated,
DuplicatedOnly,
FakeActivitySize,
}

impl Display for DiffActivity {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
DiffActivity::None => write!(f, "None"),
DiffActivity::Const => write!(f, "Const"),
DiffActivity::Active => write!(f, "Active"),
DiffActivity::ActiveOnly => write!(f, "ActiveOnly"),
DiffActivity::Dual => write!(f, "Dual"),
DiffActivity::DualOnly => write!(f, "DualOnly"),
DiffActivity::Duplicated => write!(f, "Duplicated"),
DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),
DiffActivity::FakeActivitySize => write!(f, "FakeActivitySize"),
}
}
}

impl FromStr for DiffMode {
type Err = ();

fn from_str(s: &str) -> Result<DiffMode, ()> {
match s {
"Inactive" => Ok(DiffMode::Inactive),
"Source" => Ok(DiffMode::Source),
"Forward" => Ok(DiffMode::Forward),
"Reverse" => Ok(DiffMode::Reverse),
"ForwardFirst" => Ok(DiffMode::ForwardFirst),
"ReverseFirst" => Ok(DiffMode::ReverseFirst),
_ => Err(()),
}
}
}
impl FromStr for DiffActivity {
type Err = ();

fn from_str(s: &str) -> Result<DiffActivity, ()> {
match s {
"None" => Ok(DiffActivity::None),
"Active" => Ok(DiffActivity::Active),
"ActiveOnly" => Ok(DiffActivity::ActiveOnly),
"Const" => Ok(DiffActivity::Const),
"Dual" => Ok(DiffActivity::Dual),
"DualOnly" => Ok(DiffActivity::DualOnly),
"Duplicated" => Ok(DiffActivity::Duplicated),
"DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),
_ => Err(()),
}
}
}

#[allow(dead_code)]
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct AutoDiffAttrs {
pub mode: DiffMode,
pub ret_activity: DiffActivity,
pub input_activity: Vec<DiffActivity>,
}

impl AutoDiffAttrs {
pub fn has_ret_activity(&self) -> bool {
match self.ret_activity {
DiffActivity::None => false,
_ => true,
}
}
pub fn has_active_only_ret(&self) -> bool {
match self.ret_activity {
DiffActivity::ActiveOnly => true,
_ => false,
}
}
}

impl AutoDiffAttrs {
pub fn inactive() -> Self {
AutoDiffAttrs {
mode: DiffMode::Inactive,
ret_activity: DiffActivity::None,
input_activity: Vec::new(),
}
}
pub fn source() -> Self {
AutoDiffAttrs {
mode: DiffMode::Source,
ret_activity: DiffActivity::None,
input_activity: Vec::new(),
}
}

pub fn is_active(&self) -> bool {
match self.mode {
DiffMode::Inactive => false,
_ => true,
}
}

pub fn is_source(&self) -> bool {
match self.mode {
DiffMode::Source => true,
_ => false,
}
}
pub fn apply_autodiff(&self) -> bool {
match self.mode {
DiffMode::Inactive => false,
DiffMode::Source => false,
_ => true,
}
}

pub fn into_item(
self,
source: String,
target: String,
inputs: Vec<TypeTree>,
output: TypeTree,
) -> AutoDiffItem {
AutoDiffItem { source, target, inputs, output, attrs: self }
}
}

#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct AutoDiffItem {
pub source: String,
pub target: String,
pub attrs: AutoDiffAttrs,
pub inputs: Vec<TypeTree>,
pub output: TypeTree,
}

impl fmt::Display for AutoDiffItem {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Differentiating {} -> {}", self.source, self.target)?;
write!(f, " with attributes: {:?}", self.attrs)?;
write!(f, " with inputs: {:?}", self.inputs)?;
write!(f, " with output: {:?}", self.output)
}
}
2 changes: 2 additions & 0 deletions compiler/rustc_ast/src/expand/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ use rustc_span::symbol::Ident;
use crate::MetaItem;

pub mod allocator;
pub mod autodiff_attrs;
pub mod typetree;

#[derive(Debug, Clone, Encodable, Decodable, HashStable_Generic)]
pub struct StrippedCfgItem<ModId = DefId> {
Expand Down
69 changes: 69 additions & 0 deletions compiler/rustc_ast/src/expand/typetree.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use std::fmt;

use crate::expand::{Decodable, Encodable, HashStable_Generic};

#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub enum Kind {
Anything,
Integer,
Pointer,
Half,
Float,
Double,
Unknown,
}

#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct TypeTree(pub Vec<Type>);

impl TypeTree {
pub fn new() -> Self {
Self(Vec::new())
}
pub fn all_ints() -> Self {
Self(vec![Type { offset: -1, size: 1, kind: Kind::Integer, child: TypeTree::new() }])
}
pub fn int(size: usize) -> Self {
let mut ints = Vec::with_capacity(size);
for i in 0..size {
ints.push(Type {
offset: i as isize,
size: 1,
kind: Kind::Integer,
child: TypeTree::new(),
});
}
Self(ints)
}
}

#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct FncTree {
pub args: Vec<TypeTree>,
pub ret: TypeTree,
}

#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct Type {
pub offset: isize,
pub size: usize,
pub kind: Kind,
pub child: TypeTree,
}

impl Type {
pub fn add_offset(self, add: isize) -> Self {
let offset = match self.offset {
-1 => add,
x => add + x,
};

Self { size: self.size, kind: self.kind, child: self.child, offset }
}
}

impl fmt::Display for Type {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
<Self as fmt::Debug>::fmt(self, f)
}
}
4 changes: 4 additions & 0 deletions compiler/rustc_builtin_macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ name = "rustc_builtin_macros"
version = "0.0.0"
edition = "2021"


[lints.rust]
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(llvm_enzyme)'] }

[lib]
doctest = false

Expand Down
Loading

0 comments on commit 808353c

Please sign in to comment.