From a05bee9ccba93fcd4b6e9d6adb864829ba8768c6 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Wed, 14 Aug 2024 11:38:54 +0200 Subject: [PATCH 01/47] Start rewriting PTX parser --- Cargo.toml | 5 + gen/Cargo.toml | 15 + gen/src/lib.rs | 860 +++++++++++++++++++++++++++++++++++++++++ gen_impl/Cargo.toml | 12 + gen_impl/src/lib.rs | 718 ++++++++++++++++++++++++++++++++++ gen_impl/src/parser.rs | 793 +++++++++++++++++++++++++++++++++++++ ptx_parser/Cargo.toml | 9 + ptx_parser/src/main.rs | 437 +++++++++++++++++++++ 8 files changed, 2849 insertions(+) create mode 100644 gen/Cargo.toml create mode 100644 gen/src/lib.rs create mode 100644 gen_impl/Cargo.toml create mode 100644 gen_impl/src/lib.rs create mode 100644 gen_impl/src/parser.rs create mode 100644 ptx_parser/Cargo.toml create mode 100644 ptx_parser/src/main.rs diff --git a/Cargo.toml b/Cargo.toml index 63719813..7f38976a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,7 @@ [workspace] +resolver = "2" + members = [ "cuda_base", "cuda_types", @@ -15,6 +17,9 @@ members = [ "zluda_redirect", "zluda_ml", "ptx", + "gen", + "gen_impl", + "ptx_parser" ] default-members = ["zluda_lib", "zluda_ml", "zluda_inject", "zluda_redirect"] diff --git a/gen/Cargo.toml b/gen/Cargo.toml new file mode 100644 index 00000000..e24be0f4 --- /dev/null +++ b/gen/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "gen" +version = "0.1.0" +edition = "2021" + +[lib] +proc-macro = true + +[dependencies] +gen_impl = { path = "../gen_impl" } +convert_case = "0.6.0" +rustc-hash = "2.0.0" +syn = "2.0.67" +quote = "1.0" +proc-macro2 = "1.0.86" diff --git a/gen/src/lib.rs b/gen/src/lib.rs new file mode 100644 index 00000000..f39150ff --- /dev/null +++ b/gen/src/lib.rs @@ -0,0 +1,860 @@ +use gen_impl::parser; +use proc_macro2::{Span, TokenStream}; +use quote::{format_ident, quote, ToTokens}; +use rustc_hash::{FxHashMap, FxHashSet}; +use std::{collections::hash_map, hash::Hash, rc::Rc}; +use syn::{parse_macro_input, punctuated::Punctuated, Ident, 
ItemEnum, Token, TypePath, Variant}; + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#vectors +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#fundamental-types +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#alternate-floating-point-data-formats +#[rustfmt::skip] +static POSTFIX_MODIFIERS: &[&str] = &[ + ".v2", ".v4", + ".s8", ".s16", ".s32", ".s64", + ".u8", ".u16", ".u32", ".u64", + ".f16", ".f16x2", ".f32", ".f64", + ".b8", ".b16", ".b32", ".b64", ".b128", + ".pred", + ".bf16", ".e4m3", ".e5m2", ".tf32", +]; + +static POSTFIX_TYPES: &[&str] = &["ScalarType", "VectorPrefix"]; + +struct OpcodeDefinitions { + definitions: Vec, + block_selection: Vec, usize)>>, +} + +impl OpcodeDefinitions { + fn new(opcode: &Ident, definitions: Vec) -> Self { + let mut selections = vec![None; definitions.len()]; + let mut generation = 0usize; + loop { + let mut selected_something = false; + let unselected = selections + .iter() + .enumerate() + .filter_map(|(idx, s)| if s.is_none() { Some(idx) } else { None }) + .collect::>(); + match &*unselected { + [] => break, + [remaining] => { + selections[*remaining] = Some((None, generation)); + break; + } + _ => {} + } + 'check_definitions: for i in unselected.iter().copied() { + // just pick the first alternative and attempt every modifier + 'check_candidates: for candidate in definitions[i] + .unordered_modifiers + .iter() + .chain(definitions[i].ordered_modifiers.iter()) + { + let candidate = if let DotModifierRef::Direct { + optional: false, + value, + } = candidate + { + value + } else { + continue; + }; + // check all other unselected patterns + for j in unselected.iter().copied() { + if i == j { + continue; + } + if definitions[j].possible_modifiers.contains(candidate) { + continue 'check_candidates; + } + } + // it's unique + selections[i] = Some((Some(candidate), generation)); + selected_something = true; + continue 'check_definitions; + } + } + if 
!selected_something { + panic!( + "Failed to generate pattern selection for `{}`. State: {:?}", + opcode, + selections.into_iter().rev().collect::>() + ); + } + generation += 1; + } + let mut block_selection = Vec::new(); + for current_generation in 0usize.. { + let mut current_generation_definitions = Vec::new(); + for (idx, selection) in selections.iter_mut().enumerate() { + match selection { + Some((modifier, generation)) => { + if *generation == current_generation { + current_generation_definitions.push((modifier.cloned(), idx)); + *selection = None; + } + } + None => {} + } + } + if current_generation_definitions.is_empty() { + break; + } + block_selection.push(current_generation_definitions); + } + #[cfg(debug_assertions)] + { + let selected = block_selection + .iter() + .map(|x| x.len()) + .reduce(|x, y| x + y) + .unwrap(); + if selected != definitions.len() { + panic!( + "Internal error when generating pattern selection for `{}`: {:?}", + opcode, &block_selection + ); + } + } + Self { + definitions, + block_selection, + } + } + + fn get_enum_types( + parse_definitions: &[parser::OpcodeDefinition], + ) -> FxHashMap> { + let mut result = FxHashMap::default(); + for parser::OpcodeDefinition(_, rules) in parse_definitions.iter() { + for rule in rules { + let type_ = match rule.type_ { + Some(ref type_) => type_.clone(), + None => continue, + }; + let insert_values = |set: &mut FxHashSet<_>| { + for value in rule.alternatives.iter().cloned() { + set.insert(value); + } + }; + match result.entry(type_) { + hash_map::Entry::Occupied(mut entry) => insert_values(entry.get_mut()), + hash_map::Entry::Vacant(entry) => { + insert_values(entry.insert(FxHashSet::default())) + } + }; + } + } + result + } +} + +struct SingleOpcodeDefinition { + possible_modifiers: FxHashSet, + unordered_modifiers: Vec, + ordered_modifiers: Vec, + arguments: parser::Arguments, + code_block: parser::CodeBlock, +} + +impl SingleOpcodeDefinition { + fn function_arguments_declarations(&self) -> 
impl Iterator + '_ { + self.unordered_modifiers + .iter() + .chain(self.ordered_modifiers.iter()) + .filter_map(|modf| { + let type_ = modf.type_of(); + type_.map(|t| { + let name = modf.ident(); + quote! { #name : #t } + }) + }) + .chain(self.arguments.0.iter().map(|arg| { + let name = &arg.ident; + if arg.optional { + quote! { #name : Option<&str> } + } else { + quote! { #name : &str } + } + })) + } + + fn function_arguments(&self) -> impl Iterator + '_ { + self.unordered_modifiers + .iter() + .chain(self.ordered_modifiers.iter()) + .filter_map(|modf| { + let type_ = modf.type_of(); + type_.map(|_| { + let name = modf.ident(); + quote! { #name } + }) + }) + .chain(self.arguments.0.iter().map(|arg| { + let name = &arg.ident; + quote! { #name } + })) + } + + fn extract_and_insert( + output: &mut FxHashMap>, + parser::OpcodeDefinition(pattern_seq, rules): parser::OpcodeDefinition, + ) { + let mut rules = rules + .into_iter() + .map(|r| (r.modifier.clone(), Rc::new(r))) + .collect::>(); + let mut last_opcode = pattern_seq.0.last().unwrap().0 .0.name.clone(); + for (opcode_decl, code_block) in pattern_seq.0.into_iter().rev() { + let current_opcode = opcode_decl.0.name.clone(); + if last_opcode != current_opcode { + rules = FxHashMap::default(); + } + let mut possible_modifiers = FxHashSet::default(); + for (_, options) in rules.iter() { + possible_modifiers.extend(options.alternatives.iter().cloned()); + } + let parser::OpcodeDecl(instruction, arguments) = opcode_decl; + let mut unordered_modifiers = instruction + .modifiers + .into_iter() + .map( + |parser::MaybeDotModifier { optional, modifier }| match rules.get(&modifier) { + Some(alts) => { + if alts.alternatives.len() == 1 && alts.type_.is_none() { + DotModifierRef::Direct { + optional, + value: alts.alternatives[0].clone(), + } + } else { + DotModifierRef::Indirect { + optional, + value: alts.clone(), + } + } + } + None => { + possible_modifiers.insert(modifier.clone()); + DotModifierRef::Direct { + optional, + 
value: modifier, + } + } + }, + ) + .collect::>(); + let ordered_modifiers = Self::extract_ordered_modifiers(&mut unordered_modifiers); + let entry = Self { + possible_modifiers, + unordered_modifiers, + ordered_modifiers, + arguments, + code_block, + }; + multihash_extend(output, current_opcode.clone(), entry); + last_opcode = current_opcode; + } + } + + fn extract_ordered_modifiers( + unordered_modifiers: &mut Vec, + ) -> Vec { + let mut result = Vec::new(); + loop { + let is_ordered = match unordered_modifiers.last() { + Some(DotModifierRef::Direct { value, .. }) => { + let name = value.to_string(); + POSTFIX_MODIFIERS.contains(&&*name) + } + Some(DotModifierRef::Indirect { value, .. }) => { + let type_ = value.type_.to_token_stream().to_string(); + //panic!("{} {}", type_, POSTFIX_TYPES.contains(&&*type_)); + POSTFIX_TYPES.contains(&&*type_) + } + None => break, + }; + if is_ordered { + result.push(unordered_modifiers.pop().unwrap()); + } else { + break; + } + } + if unordered_modifiers.len() == 1 { + result.push(unordered_modifiers.pop().unwrap()); + } + result.reverse(); + result + } +} + +#[proc_macro] +pub fn derive_parser(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream { + let parse_definitions = parse_macro_input!(tokens as gen_impl::parser::ParseDefinitions); + let mut definitions = FxHashMap::default(); + let types = OpcodeDefinitions::get_enum_types(&parse_definitions.definitions); + let enum_types_tokens = emit_enum_types(types, parse_definitions.additional_enums); + for definition in parse_definitions.definitions.into_iter() { + SingleOpcodeDefinition::extract_and_insert(&mut definitions, definition); + } + let definitions = definitions + .into_iter() + .map(|(k, v)| { + let v = OpcodeDefinitions::new(&k, v); + (k, v) + }) + .collect::>(); + let mut token_enum = parse_definitions.token_type; + let (_, all_modifier) = write_definitions_into_tokens(&definitions, &mut token_enum.variants); + let token_impl = 
emit_parse_function(&token_enum.ident, &definitions, all_modifier); + let tokens = quote! { + #enum_types_tokens + + #token_enum + + #token_impl + }; + tokens.into() +} + +fn emit_enum_types( + types: FxHashMap>, + mut existing_enums: FxHashMap, +) -> TokenStream { + let token_types = types.into_iter().filter_map(|(type_, variants)| { + match type_ { + syn::Type::Path(TypePath { + qself: None, + ref path, + }) => { + if let Some(ident) = path.get_ident() { + if let Some(enum_) = existing_enums.get_mut(ident) { + enum_.variants.extend(variants.into_iter().map(|modifier| { + let ident = modifier.variant_capitalized(); + let variant: syn::Variant = syn::parse_quote! { + #ident + }; + variant + })); + return None; + } + } + } + _ => {} + } + let variants = variants.iter().map(|v| v.variant_capitalized()); + Some(quote! { + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + enum #type_ { + #(#variants),* + } + }) + }); + let mut result = TokenStream::new(); + for tokens in token_types { + tokens.to_tokens(&mut result); + } + for (_, enum_) in existing_enums { + quote! { #enum_ }.to_tokens(&mut result); + } + result +} + +fn emit_parse_function( + type_name: &Ident, + defs: &FxHashMap, + all_modifier: FxHashSet<&parser::DotModifier>, +) -> TokenStream { + use std::fmt::Write; + let fns_ = defs + .iter() + .map(|(opcode, defs)| { + defs.definitions.iter().enumerate().map(|(idx, def)| { + let mut fn_name = opcode.to_string(); + write!(&mut fn_name, "_{}", idx).ok(); + let fn_name = Ident::new(&fn_name, Span::call_site()); + let code_block = &def.code_block.0; + let args = def.function_arguments_declarations(); + quote! { + fn #fn_name( #(#args),* ) -> Instruction #code_block + } + }) + }) + .flatten(); + let selectors = defs.iter().map(|(opcode, def)| { + let opcode_variant = Ident::new(&capitalize(&opcode.to_string()), opcode.span()); + let mut result = TokenStream::new(); + let mut selectors = TokenStream::new(); + quote! 
{ + if false { + unsafe { std::hint::unreachable_unchecked() } + } + } + .to_tokens(&mut selectors); + let mut has_default_selector = false; + for selection_layer in def.block_selection.iter() { + for (selection_key, selected_definition) in selection_layer { + let def_parser = emit_definition_parser(type_name, (opcode,*selected_definition), &def.definitions[*selected_definition]); + match selection_key { + Some(selection_key) => { + let selection_key = + selection_key.dot_capitalized(); + quote! { + else if modifiers.contains(& #type_name :: #selection_key) { + #def_parser + } + } + .to_tokens(&mut selectors); + } + None => { + has_default_selector = true; + quote! { + else { + #def_parser + } + } + .to_tokens(&mut selectors); + } + } + } + } + if !has_default_selector { + quote! { + else { + return Err(winnow::error::ErrMode::from_error_kind(stream, winnow::error::ErrorKind::Token)) + } + } + .to_tokens(&mut selectors); + } + quote! { + #opcode_variant => { + let modifers_start = stream.checkpoint(); + let modifiers = take_while(0.., Token::modifier).parse_next(stream)?; + #selectors + } + } + .to_tokens(&mut result); + result + }); + let modifier_names = all_modifier.iter().map(|m| m.dot_capitalized()); + quote! 
{ + impl<'input> #type_name<'input> { + fn modifier(self) -> bool { + match self { + #( + #type_name :: #modifier_names => true, + )* + _ => false + } + } + } + + #(#fns_)* + + fn parse_instruction<'input>(stream: &mut (impl winnow::stream::Stream, Slice = &'input [#type_name<'input>]> + winnow::stream::StreamIsPartial)) -> winnow::error::PResult> + { + use winnow::Parser; + use winnow::token::*; + use winnow::combinator::*; + let opcode = any.parse_next(stream)?; + let modifiers_start = stream.checkpoint(); + Ok(match opcode { + #( + #type_name :: #selectors + )* + _ => return Err(winnow::error::ErrMode::from_error_kind(stream, winnow::error::ErrorKind::Token)) + }) + } + } +} + +fn emit_definition_parser( + token_type: &Ident, + (opcode, fn_idx): (&Ident, usize), + definition: &SingleOpcodeDefinition, +) -> TokenStream { + let return_error_ref = quote! { + return Err(winnow::error::ErrMode::from_error_kind(&stream, winnow::error::ErrorKind::Token)) + }; + let return_error = quote! { + return Err(winnow::error::ErrMode::from_error_kind(stream, winnow::error::ErrorKind::Token)) + }; + let ordered_parse_declarations = definition.ordered_modifiers.iter().map(|modifier| { + modifier.type_of().map(|type_| { + let name = modifier.ident(); + quote! { + let #name : #type_; + } + }) + }); + let ordered_parse = definition.ordered_modifiers.iter().rev().map(|modifier| { + let arg_name = modifier.ident(); + let arg_type = modifier.type_of(); + match modifier { + DotModifierRef::Direct { optional, value } => { + let variant = value.dot_capitalized(); + if *optional { + quote! { + #arg_name = opt(any.verify(|t| *t == #token_type :: #variant)).parse_next(&mut stream)?.is_some(); + } + } else { + quote! 
{ + any.verify(|t| *t == #token_type :: #variant).parse_next(&mut stream)?; + } + } + } + DotModifierRef::Indirect { optional, value } => { + let variants = value.alternatives.iter().map(|alt| { + let type_ = value.type_.as_ref().unwrap(); + let token_variant = alt.dot_capitalized(); + let parsed_variant = alt.variant_capitalized(); + quote! { + #token_type :: #token_variant => #type_ :: #parsed_variant, + } + }); + if *optional { + quote! { + #arg_name = opt(any.verify_map(|tok| { + Some(match tok { + #(#variants)* + _ => return None + }) + })).parse_next(&mut stream)?; + } + } else { + quote! { + #arg_name = any.verify_map(|tok| { + Some(match tok { + #(#variants)* + _ => return None + }) + }).parse_next(&mut stream)?; + } + } + } + } + }); + let unordered_parse_declarations = definition.unordered_modifiers.iter().map(|modifier| { + let name = modifier.ident(); + let type_ = modifier.type_of_check(); + quote! { + let mut #name : #type_ = std::default::Default::default(); + } + }); + let unordered_parse = definition + .unordered_modifiers + .iter() + .map(|modifier| match modifier { + DotModifierRef::Direct { value, .. } => { + let name = value.ident(); + let token_variant = value.dot_capitalized(); + quote! { + #token_type :: #token_variant => { + if #name { + #return_error_ref; + } + #name = true; + } + } + } + DotModifierRef::Indirect { value, .. } => { + let variable = value.modifier.ident(); + let type_ = value.type_.as_ref().unwrap(); + let alternatives = value.alternatives.iter().map(|alt| { + let token_variant = alt.dot_capitalized(); + let enum_variant = alt.variant_capitalized(); + quote! { + #token_type :: #token_variant => { + if #variable.is_some() { + #return_error_ref; + } + #variable = Some(#type_ :: #enum_variant); + } + } + }); + quote! 
{ + #(#alternatives)* + } + } + }); + let unordered_parse_validations = + definition + .unordered_modifiers + .iter() + .map(|modifier| match modifier { + DotModifierRef::Direct { + optional: false, + value, + } => { + let variable = value.ident(); + quote! { + if !#variable { + #return_error; + } + } + } + DotModifierRef::Direct { optional: true, .. } => TokenStream::new(), + DotModifierRef::Indirect { + optional: false, + value, + } => { + let variable = value.modifier.ident(); + quote! { + let #variable = match #variable { + Some(x) => x, + None => #return_error + }; + } + } + DotModifierRef::Indirect { optional: true, .. } => TokenStream::new(), + }); + let arguments_parse = definition.arguments.0.iter().enumerate().map(|(idx, arg)| { + let comma = if idx == 0 { + quote! { empty } + } else { + quote! { any.verify(|t| *t == #token_type::Comma) } + }; + let pre_bracket = if arg.pre_bracket { + quote! { + any.verify(|t| *t == #token_type::LBracket).map(|_| ()) + } + } else { + quote! { + empty + } + }; + let pre_pipe = if arg.pre_pipe { + quote! { + any.verify(|t| *t == #token_type::Or).map(|_| ()) + } + } else { + quote! { + empty + } + }; + let can_be_negated = if arg.can_be_negated { + quote! { + opt(any.verify(|t| *t == #token_type::Not)).map(|o| o.is_some()) + } + } else { + quote! { + empty + } + }; + let ident = { + quote! { + any.verify_map(|t| match t { #token_type::Ident(s) => Some(s), _ => None }) + } + }; + let post_bracket = if arg.post_bracket { + quote! { + any.verify(|t| *t == #token_type::RBracket).map(|_| ()) + } + } else { + quote! { + empty + } + }; + let parser = quote! { + (#comma, #pre_bracket, #pre_pipe, #can_be_negated, #ident, #post_bracket) + }; + let arg_name = &arg.ident; + if arg.optional { + quote! { + let #arg_name = opt(#parser.map(|(_, _, _, _, name, _)| name)).parse_next(stream)?; + } + } else { + quote! 
{ + let #arg_name = #parser.map(|(_, _, _, _, name, _)| name).parse_next(stream)?; + } + } + }); + let fn_args = definition.function_arguments(); + let fn_name = format_ident!("{}_{}", opcode, fn_idx); + let fn_call = quote! { + #fn_name( #(#fn_args),* ) + }; + quote! { + #(#unordered_parse_declarations)* + #(#ordered_parse_declarations)* + { + let mut stream = ReverseStream(modifiers); + #(#ordered_parse)* + let mut stream: &[#token_type] = stream.0; + for token in stream.iter().copied() { + match token { + #(#unordered_parse)* + _ => #return_error_ref + } + } + } + #(#unordered_parse_validations)* + #(#arguments_parse)* + #fn_call + } +} + +fn write_definitions_into_tokens<'a>( + defs: &'a FxHashMap, + variants: &mut Punctuated, +) -> (Vec<&'a Ident>, FxHashSet<&'a parser::DotModifier>) { + let mut all_opcodes = Vec::new(); + let mut all_modifiers = FxHashSet::default(); + for (opcode, definitions) in defs.iter() { + all_opcodes.push(opcode); + let opcode_as_string = opcode.to_string(); + let variant_name = Ident::new(&capitalize(&opcode_as_string), opcode.span()); + let arg: Variant = syn::parse_quote! { + #[token(#opcode_as_string)] + #variant_name + }; + variants.push(arg); + for definition in definitions.definitions.iter() { + for modifier in definition.possible_modifiers.iter() { + all_modifiers.insert(modifier); + } + } + } + for modifier in all_modifiers.iter() { + let modifier_as_string = modifier.to_string(); + let variant_name = modifier.dot_capitalized(); + let arg: Variant = syn::parse_quote! 
{ + #[token(#modifier_as_string)] + #variant_name + }; + variants.push(arg); + } + (all_opcodes, all_modifiers) +} + +fn capitalize(s: &str) -> String { + let mut c = s.chars(); + match c.next() { + None => String::new(), + Some(f) => f.to_uppercase().collect::() + c.as_str(), + } +} + +fn multihash_extend(multimap: &mut FxHashMap>, k: K, v: V) { + match multimap.entry(k) { + hash_map::Entry::Occupied(mut entry) => entry.get_mut().push(v), + hash_map::Entry::Vacant(entry) => { + entry.insert(vec![v]); + } + } +} + +enum DotModifierRef { + Direct { + optional: bool, + value: parser::DotModifier, + }, + Indirect { + optional: bool, + value: Rc, + }, +} + +impl DotModifierRef { + fn ident(&self) -> Ident { + match self { + DotModifierRef::Direct { value, .. } => value.ident(), + DotModifierRef::Indirect { value, .. } => value.modifier.ident(), + } + } + + fn type_of(&self) -> Option { + Some(match self { + DotModifierRef::Direct { optional: true, .. } => syn::parse_quote! { bool }, + DotModifierRef::Direct { + optional: false, .. + } => return None, + DotModifierRef::Indirect { optional, value } => { + let type_ = value + .type_ + .as_ref() + .expect("Indirect modifer must have a type"); + if *optional { + syn::parse_quote! { Option<#type_> } + } else { + type_.clone() + } + } + }) + } + + fn type_of_check(&self) -> syn::Type { + match self { + DotModifierRef::Direct { .. } => syn::parse_quote! { bool }, + DotModifierRef::Indirect { value, .. } => { + let type_ = value + .type_ + .as_ref() + .expect("Indirect modifer must have a type"); + syn::parse_quote! 
{ Option<#type_> } + } + } + } +} + +impl Hash for DotModifierRef { + fn hash(&self, state: &mut H) { + match self { + DotModifierRef::Direct { optional, value } => { + optional.hash(state); + value.hash(state); + } + DotModifierRef::Indirect { optional, value } => { + optional.hash(state); + (value.as_ref() as *const parser::Rule).hash(state); + } + } + } +} + +impl PartialEq for DotModifierRef { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + ( + Self::Direct { + optional: l_optional, + value: l_value, + }, + Self::Direct { + optional: r_optional, + value: r_value, + }, + ) => l_optional == r_optional && l_value == r_value, + ( + Self::Indirect { + optional: l_optional, + value: l_value, + }, + Self::Indirect { + optional: r_optional, + value: r_value, + }, + ) => { + l_optional == r_optional + && l_value.as_ref() as *const parser::Rule + == r_value.as_ref() as *const parser::Rule + } + _ => false, + } + } +} + +impl Eq for DotModifierRef {} + +#[proc_macro] +pub fn generate_instruction_type(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = parse_macro_input!(tokens as gen_impl::GenerateInstructionType); + let mut result = proc_macro2::TokenStream::new(); + input.emit_arg_types(&mut result); + input.emit_instruction_type(&mut result); + input.emit_visit(&mut result); + input.emit_visit_mut(&mut result); + input.emit_visit_map(&mut result); + result.into() +} diff --git a/gen_impl/Cargo.toml b/gen_impl/Cargo.toml new file mode 100644 index 00000000..ff93f98c --- /dev/null +++ b/gen_impl/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "gen_impl" +version = "0.1.0" +edition = "2021" + +[lib] + +[dependencies] +syn = { version = "2.0.67", features = ["extra-traits", "full"] } +quote = "1.0" +proc-macro2 = "1.0.86" +rustc-hash = "2.0.0" diff --git a/gen_impl/src/lib.rs b/gen_impl/src/lib.rs new file mode 100644 index 00000000..4c7f2ab2 --- /dev/null +++ b/gen_impl/src/lib.rs @@ -0,0 +1,718 @@ +use proc_macro2::TokenStream; 
+use quote::{format_ident, quote, ToTokens}; +use syn::{ + braced, parse::Parse, punctuated::Punctuated, token, Expr, Ident, Token, Type, TypeParam, +}; + +pub mod parser; + +pub struct GenerateInstructionType { + pub name: Ident, + pub type_parameters: Punctuated, + pub short_parameters: Punctuated, + pub variants: Punctuated, +} + +impl GenerateInstructionType { + pub fn emit_arg_types(&self, tokens: &mut TokenStream) { + for v in self.variants.iter() { + v.emit_type(&self.type_parameters, tokens); + } + } + + pub fn emit_instruction_type(&self, tokens: &mut TokenStream) { + let type_name = &self.name; + let type_parameters = &self.type_parameters; + let variants = self.variants.iter().map(|v| v.emit_variant()); + quote! { + enum #type_name<#type_parameters> { + #(#variants),* + } + } + .to_tokens(tokens); + } + + pub fn emit_visit(&self, tokens: &mut TokenStream) { + self.emit_visit_impl(VisitKind::Ref, tokens, InstructionVariant::emit_visit) + } + + pub fn emit_visit_mut(&self, tokens: &mut TokenStream) { + self.emit_visit_impl( + VisitKind::RefMut, + tokens, + InstructionVariant::emit_visit_mut, + ) + } + + pub fn emit_visit_map(&self, tokens: &mut TokenStream) { + self.emit_visit_impl(VisitKind::Map, tokens, InstructionVariant::emit_visit_map) + } + + fn emit_visit_impl( + &self, + kind: VisitKind, + tokens: &mut TokenStream, + mut fn_: impl FnMut(&InstructionVariant, &Ident, &mut TokenStream), + ) { + let type_name = &self.name; + let type_parameters = &self.type_parameters; + let short_parameters = &self.short_parameters; + let mut inner_tokens = TokenStream::new(); + for v in self.variants.iter() { + fn_(v, type_name, &mut inner_tokens); + } + let visit_ref = kind.reference(); + let visitor_type = format_ident!("Visitor{}", kind.type_suffix()); + let visit_fn = format_ident!("visit{}", kind.fn_suffix()); + let visit_slice_fn = format_ident!("visit{}_slice", kind.fn_suffix()); + let (type_parameters, visitor_parameters, return_type) = if kind == 
VisitKind::Map { + ( + quote! { <#type_parameters, To: Operand> }, + quote! { <#short_parameters, To> }, + quote! { #type_name }, + ) + } else { + ( + quote! { <#type_parameters> }, + quote! { <#short_parameters> }, + quote! { () }, + ) + }; + quote! { + fn #visit_fn #type_parameters (i: #visit_ref #type_name<#short_parameters>, visitor: &mut impl #visitor_type #visitor_parameters ) -> #return_type { + match i { + #inner_tokens + } + } + }.to_tokens(tokens); + if kind == VisitKind::Map { + return; + } + quote! { + fn #visit_slice_fn #type_parameters (instructions: #visit_ref [#type_name<#short_parameters>], visitor: &mut impl #visitor_type #visitor_parameters) { + for i in instructions { + #visit_fn(i, visitor) + } + } + }.to_tokens(tokens); + } +} + +#[derive(Clone, Copy, PartialEq, Eq)] +enum VisitKind { + Ref, + RefMut, + Map, +} + +impl VisitKind { + fn fn_suffix(self) -> &'static str { + match self { + VisitKind::Ref => "", + VisitKind::RefMut => "_mut", + VisitKind::Map => "_map", + } + } + + fn type_suffix(self) -> &'static str { + match self { + VisitKind::Ref => "", + VisitKind::RefMut => "Mut", + VisitKind::Map => "Map", + } + } + + fn reference(self) -> Option { + match self { + VisitKind::Ref => Some(quote! { & }), + VisitKind::RefMut => Some(quote! 
{ &mut }), + VisitKind::Map => None, + } + } +} + +impl Parse for GenerateInstructionType { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + input.parse::()?; + let name = input.parse::()?; + input.parse::()?; + let type_parameters = Punctuated::parse_separated_nonempty(input)?; + let short_parameters = type_parameters + .iter() + .map(|p: &TypeParam| p.ident.clone()) + .collect(); + input.parse::]>()?; + let variants_buffer; + braced!(variants_buffer in input); + let variants = variants_buffer.parse_terminated(InstructionVariant::parse, Token![,])?; + Ok(Self { + name, + type_parameters, + short_parameters, + variants, + }) + } +} + +pub struct InstructionVariant { + pub name: Ident, + pub type_: Option, + pub space: Option, + pub data: Option, + pub arguments: Option, +} + +impl InstructionVariant { + fn args_name(&self) -> Ident { + format_ident!("{}Args", self.name) + } + + fn emit_variant(&self) -> TokenStream { + let name = &self.name; + let data = match &self.data { + None => { + quote! {} + } + Some(data_type) => { + quote! { + data: #data_type, + } + } + }; + let arguments = match &self.arguments { + None => { + quote! {} + } + Some(args) => { + let args_name = self.args_name(); + match &args.generic { + None => { + quote! { + arguments: #args_name, + } + } + Some(generics) => { + quote! { + arguments: #args_name <#generics>, + } + } + } + } + }; + quote! { + #name { #data #arguments } + } + } + + fn emit_visit(&self, enum_: &Ident, tokens: &mut TokenStream) { + self.emit_visit_impl(enum_, tokens, InstructionArguments::emit_visit) + } + + fn emit_visit_mut(&self, enum_: &Ident, tokens: &mut TokenStream) { + self.emit_visit_impl(enum_, tokens, InstructionArguments::emit_visit_mut) + } + + fn emit_visit_impl( + &self, + enum_: &Ident, + tokens: &mut TokenStream, + mut fn_: impl FnMut(&InstructionArguments, &Option, &Option) -> TokenStream, + ) { + let name = &self.name; + let arguments = match &self.arguments { + None => { + quote! 
{ + #enum_ :: #name { .. } => { } + } + .to_tokens(tokens); + return; + } + Some(args) => args, + }; + let arg_calls = fn_(arguments, &self.type_, &self.space); + quote! { + #enum_ :: #name { arguments, data } => { + #arg_calls + } + } + .to_tokens(tokens); + } + + fn emit_visit_map(&self, enum_: &Ident, tokens: &mut TokenStream) { + let name = &self.name; + let arguments = &self.arguments.as_ref().map(|_| quote! { arguments,}); + let data = &self.data.as_ref().map(|_| quote! { data,}); + let mut arg_calls = None; + let arguments_init = self.arguments.as_ref().map(|arguments| { + let arg_type = self.args_name(); + arg_calls = Some(arguments.emit_visit_map(&self.type_, &self.space)); + let arg_names = arguments.fields.iter().map(|arg| &arg.name); + quote! { + arguments: #arg_type { #(#arg_names),* } + } + }); + quote! { + #enum_ :: #name { #data #arguments } => { + #arg_calls + #enum_ :: #name { #data #arguments_init } + } + } + .to_tokens(tokens); + } + + fn emit_type( + &self, + type_parameters: &Punctuated, + tokens: &mut TokenStream, + ) { + let arguments = match self.arguments { + Some(ref a) => a, + None => return, + }; + let name = self.args_name(); + let type_parameters = if arguments.generic.is_some() { + Some(quote! { <#type_parameters> }) + } else { + None + }; + let fields = arguments.fields.iter().map(ArgumentField::emit_field); + quote! 
{ + struct #name #type_parameters { + #(#fields),* + } + } + .to_tokens(tokens); + } +} + +impl Parse for InstructionVariant { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let name = input.parse::()?; + let properties_buffer; + braced!(properties_buffer in input); + let properties = properties_buffer.parse_terminated(VariantProperty::parse, Token![,])?; + let mut type_ = None; + let mut space = None; + let mut data = None; + let mut arguments = None; + for property in properties { + match property { + VariantProperty::Type(t) => type_ = Some(t), + VariantProperty::Space(s) => space = Some(s), + VariantProperty::Data(d) => data = Some(d), + VariantProperty::Arguments(a) => arguments = Some(a), + } + } + Ok(Self { + name, + type_, + space, + data, + arguments, + }) + } +} + +enum VariantProperty { + Type(Expr), + Space(Expr), + Data(Type), + Arguments(InstructionArguments), +} + +impl VariantProperty { + pub fn parse(input: syn::parse::ParseStream) -> syn::Result { + let lookahead = input.lookahead1(); + Ok(if lookahead.peek(Token![type]) { + input.parse::()?; + input.parse::()?; + VariantProperty::Type(input.parse::()?) + } else if lookahead.peek(Ident) { + let key = input.parse::()?; + match &*key.to_string() { + "data" => { + input.parse::()?; + VariantProperty::Data(input.parse::()?) + } + "space" => { + input.parse::()?; + VariantProperty::Space(input.parse::()?) + } + "arguments" => { + let generics = if input.peek(Token![<]) { + input.parse::()?; + let gen_params = + Punctuated::::parse_separated_nonempty(input)?; + input.parse::]>()?; + Some(gen_params) + } else { + None + }; + input.parse::()?; + let fields; + braced!(fields in input); + VariantProperty::Arguments(InstructionArguments::parse(generics, &fields)?) + } + x => { + return Err(syn::Error::new( + key.span(), + format!( + "Unexpected key `{}`. 
Expected `type`, `data` or `arguments`.", + x + ), + )) + } + } + } else { + return Err(lookahead.error()); + }) + } +} + +pub struct InstructionArguments { + pub generic: Option>, + pub fields: Punctuated, +} + +impl InstructionArguments { + pub fn parse( + generic: Option>, + input: syn::parse::ParseStream, + ) -> syn::Result { + let fields = Punctuated::::parse_terminated_with( + input, + ArgumentField::parse, + )?; + Ok(Self { generic, fields }) + } + + fn emit_visit(&self, parent_type: &Option, parent_space: &Option) -> TokenStream { + self.emit_visit_impl(parent_type, parent_space, ArgumentField::emit_visit) + } + + fn emit_visit_mut( + &self, + parent_type: &Option, + parent_space: &Option, + ) -> TokenStream { + self.emit_visit_impl(parent_type, parent_space, ArgumentField::emit_visit_mut) + } + + fn emit_visit_map( + &self, + parent_type: &Option, + parent_space: &Option, + ) -> TokenStream { + self.emit_visit_impl(parent_type, parent_space, ArgumentField::emit_visit_map) + } + + fn emit_visit_impl( + &self, + parent_type: &Option, + parent_space: &Option, + mut fn_: impl FnMut(&ArgumentField, &Option, &Option) -> TokenStream, + ) -> TokenStream { + let field_calls = self + .fields + .iter() + .map(|f| fn_(f, parent_type, parent_space)); + quote! { + #(#field_calls)* + } + } +} + +pub struct ArgumentField { + pub name: Ident, + pub is_dst: bool, + pub repr: Type, + pub space: Option, + pub type_: Option, +} + +impl ArgumentField { + fn parse_block( + input: syn::parse::ParseStream, + ) -> syn::Result<(Type, Option, Option)> { + let content; + braced!(content in input); + let all_fields = + Punctuated::::parse_terminated_with(&content, |content| { + let lookahead = content.lookahead1(); + Ok(if lookahead.peek(Token![type]) { + content.parse::()?; + content.parse::()?; + ExprOrPath::Type(content.parse::()?) 
+ } else if lookahead.peek(Ident) { + let name_ident = content.parse::()?; + content.parse::()?; + match &*name_ident.to_string() { + "repr" => ExprOrPath::Repr(content.parse::()?), + "space" => ExprOrPath::Space(content.parse::()?), + name => { + return Err(syn::Error::new( + name_ident.span(), + format!("Unexpected key `{}`, expected `repr` or `space", name), + )) + } + } + } else { + return Err(lookahead.error()); + }) + })?; + let mut repr = None; + let mut type_ = None; + let mut space = None; + for exp_or_path in all_fields { + match exp_or_path { + ExprOrPath::Repr(r) => repr = Some(r), + ExprOrPath::Type(t) => type_ = Some(t), + ExprOrPath::Space(s) => space = Some(s), + } + } + Ok((repr.unwrap(), type_, space)) + } + + fn parse_basic(input: &syn::parse::ParseBuffer) -> syn::Result { + input.parse::() + } + + fn emit_visit(&self, parent_type: &Option, parent_space: &Option) -> TokenStream { + self.emit_visit_impl(parent_type, parent_space, false) + } + + fn emit_visit_mut( + &self, + parent_type: &Option, + parent_space: &Option, + ) -> TokenStream { + self.emit_visit_impl(parent_type, parent_space, true) + } + + fn emit_visit_impl( + &self, + parent_type: &Option, + parent_space: &Option, + is_mut: bool, + ) -> TokenStream { + let type_ = self.type_.as_ref().or(parent_type.as_ref()).unwrap(); + let space = self + .space + .as_ref() + .or(parent_space.as_ref()) + .map(|space| quote! { #space }) + .unwrap_or_else(|| quote! { StateSpace::Reg }); + let is_dst = self.is_dst; + let name = &self.name; + let arguments_name = if is_mut { + quote! { + &mut arguments.#name + } + } else { + quote! { + & arguments.#name + } + }; + quote! 
{{ + let type_ = #type_; + let space = #space; + visitor.visit(#arguments_name, &type_, space, #is_dst); + }} + } + + fn emit_visit_map( + &self, + parent_type: &Option, + parent_space: &Option, + ) -> TokenStream { + let type_ = self.type_.as_ref().or(parent_type.as_ref()).unwrap(); + let space = self + .space + .as_ref() + .or(parent_space.as_ref()) + .map(|space| quote! { #space }) + .unwrap_or_else(|| quote! { StateSpace::Reg }); + let is_dst = self.is_dst; + let name = &self.name; + quote! { + let #name = { + let type_ = #type_; + let space = #space; + visitor.visit(arguments.#name, &type_, space, #is_dst) + }; + } + } + + fn is_dst(name: &Ident) -> syn::Result { + if name.to_string().starts_with("dst") { + Ok(true) + } else if name.to_string().starts_with("src") { + Ok(false) + } else { + return Err(syn::Error::new( + name.span(), + format!( + "Could not guess if `{}` is a read or write argument. Name should start with `dst` or `src`", + name + ), + )); + } + } + + fn emit_field(&self) -> TokenStream { + let name = &self.name; + let type_ = &self.repr; + quote! { + #name: #type_ + } + } +} + +impl Parse for ArgumentField { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let name = input.parse::()?; + let is_dst = Self::is_dst(&name)?; + input.parse::()?; + let lookahead = input.lookahead1(); + let (repr, type_, space) = if lookahead.peek(token::Brace) { + Self::parse_block(input)? + } else if lookahead.peek(syn::Ident) { + (Self::parse_basic(input)?, None, None) + } else { + return Err(lookahead.error()); + }; + Ok(Self { + name, + is_dst, + repr, + type_, + space, + }) + } +} + +enum ExprOrPath { + Repr(Type), + Type(Expr), + Space(Expr), +} + +#[cfg(test)] +mod tests { + use super::*; + use proc_macro2::Span; + use quote::{quote, ToTokens}; + + fn to_string(x: impl ToTokens) -> String { + quote! { #x }.to_string() + } + + #[test] + fn parse_argument_field_basic() { + let input = quote! 
{ + dst: P::Operand + }; + let arg = syn::parse2::(input).unwrap(); + assert_eq!("dst", arg.name.to_string()); + assert_eq!("P :: Operand", to_string(arg.repr)); + assert!(matches!(arg.type_, None)); + } + + #[test] + fn parse_argument_field_block() { + let input = quote! { + dst: { + type: ScalarType::U32, + space: StateSpace::Global, + repr: P::Operand, + } + }; + let arg = syn::parse2::(input).unwrap(); + assert_eq!("dst", arg.name.to_string()); + assert_eq!("ScalarType :: U32", to_string(arg.type_.unwrap())); + assert_eq!("StateSpace :: Global", to_string(arg.space.unwrap())); + assert_eq!("P :: Operand", to_string(arg.repr)); + } + + #[test] + fn parse_argument_field_block_untyped() { + let input = quote! { + dst: { + repr: P::Operand, + } + }; + let arg = syn::parse2::(input).unwrap(); + assert_eq!("dst", arg.name.to_string()); + assert_eq!("P :: Operand", to_string(arg.repr)); + assert!(matches!(arg.type_, None)); + } + + #[test] + fn parse_variant_complex() { + let input = quote! { + Ld { + type: ScalarType::U32, + space: StateSpace::Global, + data: LdDetails, + arguments

: { + dst: { + repr: P::Operand, + type: ScalarType::U32, + space: StateSpace::Shared, + }, + src: P::Operand, + }, + } + }; + let variant = syn::parse2::(input).unwrap(); + assert_eq!("Ld", variant.name.to_string()); + assert_eq!("ScalarType :: U32", to_string(variant.type_.unwrap())); + assert_eq!("StateSpace :: Global", to_string(variant.space.unwrap())); + assert_eq!("LdDetails", to_string(variant.data.unwrap())); + let arguments = variant.arguments.unwrap(); + assert_eq!("P", to_string(arguments.generic)); + let mut fields = arguments.fields.into_iter(); + let dst = fields.next().unwrap(); + assert_eq!("P :: Operand", to_string(dst.repr)); + assert_eq!("ScalarType :: U32", to_string(dst.type_)); + assert_eq!("StateSpace :: Shared", to_string(dst.space)); + let src = fields.next().unwrap(); + assert_eq!("P :: Operand", to_string(src.repr)); + assert!(matches!(src.type_, None)); + assert!(matches!(src.space, None)); + } + + #[test] + fn visit_variant() { + let input = quote! { + Ld { + type: ScalarType::U32, + data: LdDetails, + arguments

: { + dst: { + repr: P::Operand, + type: ScalarType::U32 + }, + src: P::Operand, + }, + } + }; + let variant = syn::parse2::(input).unwrap(); + let mut output = TokenStream::new(); + variant.emit_visit(&Ident::new("Instruction", Span::call_site()), &mut output); + assert_eq!(output.to_string(), "Instruction :: Ld { arguments , data } => { { let type_ = ScalarType :: U32 ; let space = StateSpace :: Reg ; visitor . visit (& arguments . dst , & type_ , space , true) ; } { let type_ = ScalarType :: U32 ; let space = StateSpace :: Reg ; visitor . visit (& arguments . src , & type_ , space , false) ; } }"); + } + + #[test] + fn visit_variant_empty() { + let input = quote! { + Ret { + data: RetData + } + }; + let variant = syn::parse2::(input).unwrap(); + let mut output = TokenStream::new(); + variant.emit_visit(&Ident::new("Instruction", Span::call_site()), &mut output); + assert_eq!(output.to_string(), "Instruction :: Ret { .. } => { }"); + } +} diff --git a/gen_impl/src/parser.rs b/gen_impl/src/parser.rs new file mode 100644 index 00000000..c8da61d1 --- /dev/null +++ b/gen_impl/src/parser.rs @@ -0,0 +1,793 @@ +use proc_macro2::Span; +use proc_macro2::TokenStream; +use quote::quote; +use quote::ToTokens; +use rustc_hash::FxHashMap; +use std::fmt::Write; +use syn::bracketed; +use syn::parse::Peek; +use syn::punctuated::Punctuated; +use syn::LitInt; +use syn::Type; +use syn::{braced, parse::Parse, token, Ident, ItemEnum, Token}; + +pub struct ParseDefinitions { + pub token_type: ItemEnum, + pub additional_enums: FxHashMap, + pub definitions: Vec, +} + +impl Parse for ParseDefinitions { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let token_type = input.parse::()?; + let mut additional_enums = FxHashMap::default(); + while input.peek(Token![#]) { + let enum_ = input.parse::()?; + additional_enums.insert(enum_.ident.clone(), enum_); + } + let mut definitions = Vec::new(); + while !input.is_empty() { + definitions.push(input.parse::()?); + } + Ok(Self { + 
token_type, + additional_enums, + definitions, + }) + } +} + +pub struct OpcodeDefinition(pub Patterns, pub Vec); + +impl Parse for OpcodeDefinition { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let patterns = input.parse::()?; + let mut rules = Vec::new(); + while Rule::peek(input) { + rules.push(input.parse::()?); + input.parse::()?; + } + Ok(Self(patterns, rules)) + } +} + +pub struct Patterns(pub Vec<(OpcodeDecl, CodeBlock)>); + +impl Parse for Patterns { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let mut result = Vec::new(); + loop { + if !OpcodeDecl::peek(input) { + break; + } + let decl = input.parse::()?; + let code_block = input.parse::()?; + result.push((decl, code_block)) + } + Ok(Self(result)) + } +} + +pub struct OpcodeDecl(pub Instruction, pub Arguments); + +impl OpcodeDecl { + fn peek(input: syn::parse::ParseStream) -> bool { + Instruction::peek(input) + } +} + +impl Parse for OpcodeDecl { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + Ok(Self( + input.parse::()?, + input.parse::()?, + )) + } +} + +pub struct CodeBlock(pub proc_macro2::Group); + +impl Parse for CodeBlock { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + input.parse::]>()?; + let group = input.parse::()?; + Ok(Self(group)) + } +} + +pub struct Rule { + pub modifier: DotModifier, + pub type_: Option, + pub alternatives: Vec, +} + +impl Rule { + fn peek(input: syn::parse::ParseStream) -> bool { + DotModifier::peek(input) + } + + fn parse_alternatives(input: syn::parse::ParseStream) -> syn::Result> { + let mut result = Vec::new(); + Self::parse_with_alternative(input, &mut result)?; + loop { + if !input.peek(Token![,]) { + break; + } + input.parse::()?; + Self::parse_with_alternative(input, &mut result)?; + } + Ok(result) + } + + fn parse_with_alternative( + input: &syn::parse::ParseBuffer, + result: &mut Vec, + ) -> Result<(), syn::Error> { + input.parse::()?; + let part1 = input.parse::()?; + if input.peek(token::Brace) 
{ + result.push(DotModifier { + part1: part1.clone(), + part2: None, + }); + let suffix_content; + braced!(suffix_content in input); + let suffixes = Punctuated::::parse_separated_nonempty( + &suffix_content, + )?; + for part2 in suffixes { + result.push(DotModifier { + part1: part1.clone(), + part2: Some(part2), + }); + } + } else if IdentOrTypeSuffix::peek(input) { + let part2 = Some(IdentOrTypeSuffix::parse(input)?); + result.push(DotModifier { part1, part2 }); + } else { + result.push(DotModifier { part1, part2: None }); + } + Ok(()) + } +} + +#[derive(PartialEq, Eq, Hash, Clone)] +struct IdentOrTypeSuffix(IdentLike); + +impl IdentOrTypeSuffix { + fn peek(input: syn::parse::ParseStream) -> bool { + input.peek(Token![::]) + } +} + +impl ToTokens for IdentOrTypeSuffix { + fn to_tokens(&self, tokens: &mut TokenStream) { + let ident = &self.0; + quote! { :: #ident }.to_tokens(tokens) + } +} + +impl Parse for IdentOrTypeSuffix { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + input.parse::()?; + Ok(Self(input.parse::()?)) + } +} + +impl Parse for Rule { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let modifier = input.parse::()?; + let type_ = if input.peek(Token![:]) { + input.parse::()?; + Some(input.parse::()?) 
+ } else { + None + }; + input.parse::()?; + let content; + braced!(content in input); + let alternatives = Self::parse_alternatives(&content)?; + Ok(Self { + modifier, + type_, + alternatives, + }) + } +} + +pub struct Instruction { + pub name: Ident, + pub modifiers: Vec, +} +impl Instruction { + fn peek(input: syn::parse::ParseStream) -> bool { + input.peek(Ident) + } +} + +impl Parse for Instruction { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let instruction = input.parse::()?; + let mut modifiers = Vec::new(); + loop { + if !MaybeDotModifier::peek(input) { + break; + } + modifiers.push(MaybeDotModifier::parse(input)?); + } + Ok(Self { + name: instruction, + modifiers, + }) + } +} + +pub struct MaybeDotModifier { + pub optional: bool, + pub modifier: DotModifier, +} + +impl MaybeDotModifier { + fn peek(input: syn::parse::ParseStream) -> bool { + input.peek(token::Brace) || DotModifier::peek(input) + } +} + +impl Parse for MaybeDotModifier { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + Ok(if input.peek(token::Brace) { + let content; + braced!(content in input); + let modifier = DotModifier::parse(&content)?; + Self { + modifier, + optional: true, + } + } else { + let modifier = DotModifier::parse(input)?; + Self { + modifier, + optional: false, + } + }) + } +} + +#[derive(PartialEq, Eq, Hash, Clone)] +pub struct DotModifier { + part1: IdentLike, + part2: Option, +} + +impl std::fmt::Display for DotModifier { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, ".")?; + self.part1.fmt(f)?; + if let Some(ref part2) = self.part2 { + write!(f, "::")?; + part2.0.fmt(f)?; + } + Ok(()) + } +} + +impl std::fmt::Debug for DotModifier { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + std::fmt::Display::fmt(&self, f) + } +} + +impl DotModifier { + pub fn ident(&self) -> Ident { + let mut result = String::new(); + write!(&mut result, "{}", self.part1).unwrap(); + if let Some(ref 
part2) = self.part2 { + write!(&mut result, "_{}", part2.0).unwrap(); + } else { + match self.part1 { + IdentLike::Type | IdentLike::Const => result.push('_'), + IdentLike::Ident(_) | IdentLike::Integer(_) => {} + } + } + Ident::new(&result.to_ascii_lowercase(), Span::call_site()) + } + + pub fn variant_capitalized(&self) -> Ident { + self.capitalized_impl(String::new()) + } + + pub fn dot_capitalized(&self) -> Ident { + self.capitalized_impl("Dot".to_string()) + } + + fn capitalized_impl(&self, prefix: String) -> Ident { + let mut temp = String::new(); + write!(&mut temp, "{}", &self.part1).unwrap(); + if let Some(IdentOrTypeSuffix(ref part2)) = self.part2 { + write!(&mut temp, "_{}", part2).unwrap(); + } + let mut result = prefix; + let mut capitalize = true; + for c in temp.chars() { + if c == '_' { + capitalize = true; + continue; + } + let c = if capitalize { + capitalize = false; + c.to_ascii_uppercase() + } else { + c + }; + result.push(c); + } + Ident::new(&result, Span::call_site()) + } + + pub fn tokens(&self) -> TokenStream { + let part1 = &self.part1; + let part2 = &self.part2; + match self.part2 { + None => quote! { . #part1 }, + Some(_) => quote! { . 
#part1 #part2 }, + } + } + + fn peek(input: syn::parse::ParseStream) -> bool { + input.peek(Token![.]) + } +} + +impl Parse for DotModifier { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + input.parse::()?; + let part1 = input.parse::()?; + if IdentOrTypeSuffix::peek(input) { + let part2 = Some(IdentOrTypeSuffix::parse(input)?); + Ok(Self { part1, part2 }) + } else { + Ok(Self { part1, part2: None }) + } + } +} + +#[derive(PartialEq, Eq, Hash, Clone)] +enum IdentLike { + Type, + Const, + Ident(Ident), + Integer(LitInt), +} + +impl std::fmt::Display for IdentLike { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + IdentLike::Type => f.write_str("type"), + IdentLike::Const => f.write_str("const"), + IdentLike::Ident(ident) => write!(f, "{}", ident), + IdentLike::Integer(integer) => write!(f, "{}", integer), + } + } +} + +impl ToTokens for IdentLike { + fn to_tokens(&self, tokens: &mut TokenStream) { + match self { + IdentLike::Type => quote! { type }.to_tokens(tokens), + IdentLike::Const => quote! { const }.to_tokens(tokens), + IdentLike::Ident(ident) => quote! { #ident }.to_tokens(tokens), + IdentLike::Integer(int) => quote! { #int }.to_tokens(tokens), + } + } +} + +impl Parse for IdentLike { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let lookahead = input.lookahead1(); + Ok(if lookahead.peek(Token![const]) { + input.parse::()?; + IdentLike::Const + } else if lookahead.peek(Token![type]) { + input.parse::()?; + IdentLike::Type + } else if lookahead.peek(Ident) { + IdentLike::Ident(input.parse::()?) + } else if lookahead.peek(LitInt) { + IdentLike::Integer(input.parse::()?) 
+ } else { + return Err(lookahead.error()); + }) + } +} + +// Arguments decalaration can loook like this: +// a{, b} +// That's why we don't parse Arguments as Punctuated +#[derive(PartialEq, Eq)] +pub struct Arguments(pub Vec); + +impl Parse for Arguments { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let mut result = Vec::new(); + loop { + if input.peek(Token![,]) { + input.parse::()?; + } + let mut optional = false; + let mut can_be_negated = false; + let mut pre_pipe = false; + let ident; + let lookahead = input.lookahead1(); + if lookahead.peek(token::Brace) { + let content; + braced!(content in input); + let lookahead = content.lookahead1(); + if lookahead.peek(Token![!]) { + content.parse::()?; + can_be_negated = true; + ident = input.parse::()?; + } else if lookahead.peek(Token![,]) { + optional = true; + content.parse::()?; + ident = content.parse::()?; + } else { + return Err(lookahead.error()); + } + } else if lookahead.peek(token::Bracket) { + let bracketed; + bracketed!(bracketed in input); + if bracketed.peek(Token![|]) { + optional = true; + bracketed.parse::()?; + pre_pipe = true; + ident = bracketed.parse::()?; + } else { + let mut sub_args = Self::parse(&bracketed)?; + sub_args.0.first_mut().unwrap().pre_bracket = true; + sub_args.0.last_mut().unwrap().post_bracket = true; + if peek_brace_token(input, Token![.]) { + let optional_suffix; + braced!(optional_suffix in input); + optional_suffix.parse::()?; + let unified_ident = optional_suffix.parse::()?; + if unified_ident.to_string() != "unified" { + return Err(syn::Error::new( + unified_ident.span(), + format!("Exptected `unified`, got `{}`", unified_ident), + )); + } + for a in sub_args.0.iter_mut() { + a.unified = true; + } + } + result.extend(sub_args.0); + continue; + } + } else if lookahead.peek(Ident) { + ident = input.parse::()?; + } else if lookahead.peek(Token![|]) { + input.parse::()?; + pre_pipe = true; + ident = input.parse::()?; + } else { + break; + } + 
result.push(Argument { + optional, + pre_pipe, + can_be_negated, + pre_bracket: false, + ident, + post_bracket: false, + unified: false, + }); + } + Ok(Self(result)) + } +} + +// This is effectively input.peek(token::Brace) && input.peek2(Token![.]) +// input.peek2 is supposed to skip over next token, but it skips over whole +// braced token group. Not sure if it's a bug +fn peek_brace_token(input: syn::parse::ParseStream, _t: T) -> bool { + use syn::token::Token; + let cursor = input.cursor(); + cursor + .group(proc_macro2::Delimiter::Brace) + .map_or(false, |(content, ..)| T::Token::peek(content)) +} + +#[derive(PartialEq, Eq)] +pub struct Argument { + pub optional: bool, + pub pre_bracket: bool, + pub pre_pipe: bool, + pub can_be_negated: bool, + pub ident: Ident, + pub post_bracket: bool, + pub unified: bool, +} + +#[cfg(test)] +mod tests { + use super::{Arguments, DotModifier, MaybeDotModifier}; + use quote::{quote, ToTokens}; + + #[test] + fn parse_modifier_complex() { + let input = quote! { + .level::eviction_priority + }; + let modifier = syn::parse2::(input).unwrap(); + assert_eq!( + ". level :: eviction_priority", + modifier.tokens().to_string() + ); + } + + #[test] + fn parse_modifier_optional() { + let input = quote! { + { .level::eviction_priority } + }; + let maybe_modifider = syn::parse2::(input).unwrap(); + assert_eq!( + ". level :: eviction_priority", + maybe_modifider.modifier.tokens().to_string() + ); + assert!(maybe_modifider.optional); + } + + #[test] + fn parse_type_token() { + let input = quote! { + . type + }; + let maybe_modifier = syn::parse2::(input).unwrap(); + assert_eq!(". type", maybe_modifier.modifier.tokens().to_string()); + assert!(!maybe_modifier.optional); + } + + #[test] + fn arguments_memory() { + let input = quote! 
{ + [a], b + }; + let arguments = syn::parse2::(input).unwrap(); + let a = &arguments.0[0]; + assert!(!a.optional); + assert_eq!("a", a.ident.to_string()); + assert!(a.pre_bracket); + assert!(!a.pre_pipe); + assert!(a.post_bracket); + assert!(!a.can_be_negated); + let b = &arguments.0[1]; + assert!(!b.optional); + assert_eq!("b", b.ident.to_string()); + assert!(!b.pre_bracket); + assert!(!b.pre_pipe); + assert!(!b.post_bracket); + assert!(!b.can_be_negated); + } + + #[test] + fn arguments_optional() { + let input = quote! { + b{, cache_policy} + }; + let arguments = syn::parse2::(input).unwrap(); + let b = &arguments.0[0]; + assert!(!b.optional); + assert_eq!("b", b.ident.to_string()); + assert!(!b.pre_bracket); + assert!(!b.pre_pipe); + assert!(!b.post_bracket); + assert!(!b.can_be_negated); + let cache_policy = &arguments.0[1]; + assert!(cache_policy.optional); + assert_eq!("cache_policy", cache_policy.ident.to_string()); + assert!(!cache_policy.pre_bracket); + assert!(!cache_policy.pre_pipe); + assert!(!cache_policy.post_bracket); + assert!(!cache_policy.can_be_negated); + } + + #[test] + fn arguments_optional_pred() { + let input = quote! { + p[|q], a + }; + let arguments = syn::parse2::(input).unwrap(); + assert_eq!(arguments.0.len(), 3); + let p = &arguments.0[0]; + assert!(!p.optional); + assert_eq!("p", p.ident.to_string()); + assert!(!p.pre_bracket); + assert!(!p.pre_pipe); + assert!(!p.post_bracket); + assert!(!p.can_be_negated); + let q = &arguments.0[1]; + assert!(q.optional); + assert_eq!("q", q.ident.to_string()); + assert!(!q.pre_bracket); + assert!(q.pre_pipe); + assert!(!q.post_bracket); + assert!(!q.can_be_negated); + let a = &arguments.0[2]; + assert!(!a.optional); + assert_eq!("a", a.ident.to_string()); + assert!(!a.pre_bracket); + assert!(!a.pre_pipe); + assert!(!a.post_bracket); + assert!(!a.can_be_negated); + } + + #[test] + fn arguments_optional_with_negate() { + let input = quote! 
{ + b, {!}c + }; + let arguments = syn::parse2::(input).unwrap(); + assert_eq!(arguments.0.len(), 2); + let b = &arguments.0[0]; + assert!(!b.optional); + assert_eq!("b", b.ident.to_string()); + assert!(!b.pre_bracket); + assert!(!b.pre_pipe); + assert!(!b.post_bracket); + assert!(!b.can_be_negated); + let c = &arguments.0[1]; + assert!(!c.optional); + assert_eq!("c", c.ident.to_string()); + assert!(!c.pre_bracket); + assert!(!c.pre_pipe); + assert!(!c.post_bracket); + assert!(c.can_be_negated); + } + + #[test] + fn arguments_tex() { + let input = quote! { + d[|p], [a{, b}, c], dpdx, dpdy {, e} + }; + let arguments = syn::parse2::(input).unwrap(); + assert_eq!(arguments.0.len(), 8); + { + let d = &arguments.0[0]; + assert!(!d.optional); + assert_eq!("d", d.ident.to_string()); + assert!(!d.pre_bracket); + assert!(!d.pre_pipe); + assert!(!d.post_bracket); + assert!(!d.can_be_negated); + } + { + let p = &arguments.0[1]; + assert!(p.optional); + assert_eq!("p", p.ident.to_string()); + assert!(!p.pre_bracket); + assert!(p.pre_pipe); + assert!(!p.post_bracket); + assert!(!p.can_be_negated); + } + { + let a = &arguments.0[2]; + assert!(!a.optional); + assert_eq!("a", a.ident.to_string()); + assert!(a.pre_bracket); + assert!(!a.pre_pipe); + assert!(!a.post_bracket); + assert!(!a.can_be_negated); + } + { + let b = &arguments.0[3]; + assert!(b.optional); + assert_eq!("b", b.ident.to_string()); + assert!(!b.pre_bracket); + assert!(!b.pre_pipe); + assert!(!b.post_bracket); + assert!(!b.can_be_negated); + } + { + let c = &arguments.0[4]; + assert!(!c.optional); + assert_eq!("c", c.ident.to_string()); + assert!(!c.pre_bracket); + assert!(!c.pre_pipe); + assert!(c.post_bracket); + assert!(!c.can_be_negated); + } + { + let dpdx = &arguments.0[5]; + assert!(!dpdx.optional); + assert_eq!("dpdx", dpdx.ident.to_string()); + assert!(!dpdx.pre_bracket); + assert!(!dpdx.pre_pipe); + assert!(!dpdx.post_bracket); + assert!(!dpdx.can_be_negated); + } + { + let dpdy = &arguments.0[6]; + 
assert!(!dpdy.optional); + assert_eq!("dpdy", dpdy.ident.to_string()); + assert!(!dpdy.pre_bracket); + assert!(!dpdy.pre_pipe); + assert!(!dpdy.post_bracket); + assert!(!dpdy.can_be_negated); + } + { + let e = &arguments.0[7]; + assert!(e.optional); + assert_eq!("e", e.ident.to_string()); + assert!(!e.pre_bracket); + assert!(!e.pre_pipe); + assert!(!e.post_bracket); + assert!(!e.can_be_negated); + } + } + + #[test] + fn rule_multi() { + let input = quote! { + .ss: StateSpace = { .global, .local, .param{::func}, .shared{::cta, ::cluster} } + }; + let rule = syn::parse2::(input).unwrap(); + assert_eq!(". ss", rule.modifier.tokens().to_string()); + assert_eq!( + "StateSpace", + rule.type_.unwrap().to_token_stream().to_string() + ); + let alts = rule + .alternatives + .iter() + .map(|m| m.tokens().to_string()) + .collect::>(); + assert_eq!( + vec![ + ". global", + ". local", + ". param", + ". param :: func", + ". shared", + ". shared :: cta", + ". shared :: cluster" + ], + alts + ); + } + + #[test] + fn rule_multi2() { + let input = quote! { + .cop: StCacheOperator = { .wb, .cg, .cs, .wt } + }; + let rule = syn::parse2::(input).unwrap(); + assert_eq!(". cop", rule.modifier.tokens().to_string()); + assert_eq!( + "StCacheOperator", + rule.type_.unwrap().to_token_stream().to_string() + ); + let alts = rule + .alternatives + .iter() + .map(|m| m.tokens().to_string()) + .collect::>(); + assert_eq!(vec![". wb", ". cg", ". cs", ". wt",], alts); + } + + #[test] + fn args_unified() { + let input = quote! 
{ + d, [a]{.unified}{, cache_policy} + }; + let args = syn::parse2::(input).unwrap(); + let a = &args.0[1]; + assert!(!a.optional); + assert_eq!("a", a.ident.to_string()); + assert!(a.pre_bracket); + assert!(!a.pre_pipe); + assert!(a.post_bracket); + assert!(!a.can_be_negated); + assert!(a.unified); + } +} diff --git a/ptx_parser/Cargo.toml b/ptx_parser/Cargo.toml new file mode 100644 index 00000000..d5e3d5db --- /dev/null +++ b/ptx_parser/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "ptx_parser" +version = "0.1.0" +edition = "2021" + +[dependencies] +logos = "0.14" +winnow = { version = "0.6.18", features = ["debug"] } +gen = { path = "../gen" } diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs new file mode 100644 index 00000000..8af8ede1 --- /dev/null +++ b/ptx_parser/src/main.rs @@ -0,0 +1,437 @@ +use gen::derive_parser; +use logos::Logos; +use std::mem; +use winnow::{ + error::{ContextError, ParserError}, + stream::{Offset, Stream, StreamIsPartial}, +}; + +pub trait Operand {} + +pub trait Visitor { + fn visit(&mut self, args: &T, type_: &Type, space: StateSpace, is_dst: bool); +} + +pub trait VisitorMut { + fn visit(&mut self, args: &mut T, type_: &Type, space: StateSpace, is_dst: bool); +} + +pub trait VisitorMap { + fn visit(&mut self, args: From, type_: &Type, space: StateSpace, is_dst: bool) -> To; +} + +gen::generate_instruction_type!( + enum Instruction { + Ld { + type: { &data.typ }, + data: LdDetails, + arguments: { + dst: T, + src: { + repr: T, + space: { data.state_space }, + } + } + }, + Add { + type: { data.type_().into() }, + data: ArithDetails, + arguments: { + dst: T, + src1: T, + src2: T, + } + }, + St { + type: { &data.typ }, + data: StData, + arguments: { + src1: { + repr: T, + space: { data.state_space }, + }, + src2: T, + } + }, + Ret { + data: RetData + }, + Trap { } + } +); + +pub struct LdDetails { + pub qualifier: LdStQualifier, + pub state_space: StateSpace, + pub caching: LdCacheOperator, + pub typ: Type, + pub 
non_coherent: bool, +} + +#[derive(Copy, Clone)] +pub enum ArithDetails { + Unsigned(ScalarType), + Signed(ArithSInt), + Float(ArithFloat), +} + +impl ArithDetails { + fn type_(&self) -> ScalarType { + match self { + ArithDetails::Unsigned(t) => *t, + ArithDetails::Signed(arith) => arith.typ, + ArithDetails::Float(arith) => arith.typ, + } + } +} + +#[derive(Copy, Clone)] +pub struct ArithSInt { + pub typ: ScalarType, + pub saturate: bool, +} + +#[derive(Copy, Clone)] +pub struct ArithFloat { + pub typ: ScalarType, + pub rounding: Option, + pub flush_to_zero: Option, + pub saturate: bool, +} + +#[derive(PartialEq, Eq, Copy, Clone)] +pub enum RoundingMode { + NearestEven, + Zero, + NegativeInf, + PositiveInf, +} + +#[derive(Copy, Clone, PartialEq, Eq)] +pub enum LdCacheOperator { + Cached, + L2Only, + Streaming, + LastUse, + Uncached, +} + +#[derive(PartialEq, Eq, Clone, Hash)] +pub enum Type { + // .param.b32 foo; + Scalar(ScalarType), + // .param.v2.b32 foo; + Vector(ScalarType, u8), + // .param.b32 foo[4]; + Array(ScalarType, Vec), +} + +impl From for Type { + fn from(value: ScalarType) -> Self { + Type::Scalar(value) + } +} + +#[derive(Copy, Clone, PartialEq, Eq)] +pub enum LdStQualifier { + Weak, + Volatile, + Relaxed(MemScope), + Acquire(MemScope), + Release(MemScope), +} + +pub struct StData { + pub qualifier: LdStQualifier, + pub state_space: StateSpace, + pub caching: StCacheOperator, + pub typ: Type, +} + +#[derive(PartialEq, Eq)] +pub enum StCacheOperator { + Writeback, + L2Only, + Streaming, + Writethrough, +} + +#[derive(Copy, Clone)] +pub struct RetData { + pub uniform: bool, +} + +pub struct ParsedOperand {} + +impl Operand for ParsedOperand {} + +#[derive(Debug)] +struct ReverseStream<'a, T>(pub &'a [T]); + +impl<'i, T> Stream for ReverseStream<'i, T> +where + T: Clone + ::std::fmt::Debug, +{ + type Token = T; + type Slice = &'i [T]; + + type IterOffsets = + std::iter::Enumerate>>>; + + type Checkpoint = &'i [T]; + + #[inline(always)] + fn 
iter_offsets(&self) -> Self::IterOffsets { + self.0.iter().rev().cloned().enumerate() + } + + #[inline(always)] + fn eof_offset(&self) -> usize { + self.0.len() + } + + #[inline(always)] + fn next_token(&mut self) -> Option { + let (token, next) = self.0.split_last()?; + self.0 = next; + Some(token.clone()) + } + + #[inline(always)] + fn offset_for

(&self, predicate: P) -> Option + where + P: Fn(Self::Token) -> bool, + { + self.0.iter().rev().position(|b| predicate(b.clone())) + } + + #[inline(always)] + fn offset_at(&self, tokens: usize) -> Result { + if let Some(needed) = tokens + .checked_sub(self.0.len()) + .and_then(std::num::NonZeroUsize::new) + { + Err(winnow::error::Needed::Size(needed)) + } else { + Ok(tokens) + } + } + + #[inline(always)] + fn next_slice(&mut self, offset: usize) -> Self::Slice { + let offset = self.0.len() - offset; + let (next, slice) = self.0.split_at(offset); + self.0 = next; + slice + } + + #[inline(always)] + fn checkpoint(&self) -> Self::Checkpoint { + self.0 + } + + #[inline(always)] + fn reset(&mut self, checkpoint: &Self::Checkpoint) { + self.0 = checkpoint; + } + + #[inline(always)] + fn raw(&self) -> &dyn std::fmt::Debug { + self + } +} + +impl<'a, T> Offset<&'a [T]> for ReverseStream<'a, T> { + #[inline] + fn offset_from(&self, start: &&'a [T]) -> usize { + let fst = start.as_ptr(); + let snd = self.0.as_ptr(); + + debug_assert!( + snd <= fst, + "`Offset::offset_from({snd:?}, {fst:?})` only accepts slices of `self`" + ); + (fst as usize - snd as usize) / std::mem::size_of::() + } +} + +impl<'a, T> StreamIsPartial for ReverseStream<'a, T> { + type PartialState = (); + + fn complete(&mut self) -> Self::PartialState {} + + fn restore_partial(&mut self, _state: Self::PartialState) {} + + fn is_partial_supported() -> bool { + false + } +} + +// Modifiers are turned into arguments to the blocks, with type: +// * If it is an alternative: +// * If it is mandatory then its type is Foo (as defined by the relevant rule) +// * If it is optional then its type is Option +// * Otherwise: +// * If it is mandatory then it is skipped +// * If it is optional then its type is `bool` + +derive_parser!( + #[derive(Logos, PartialEq, Eq, Debug, Clone, Copy)] + #[logos(skip r"\s+")] + enum Token<'input> { + #[token(",")] + Comma, + #[regex(r"[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+", |lex| 
lex.slice(), priority = 0)] + Ident(&'input str), + #[token("|")] + Or, + #[token("!")] + Not, + #[token(";")] + Semicolon, + #[token("[")] + LBracket, + #[token("]")] + RBracket, + #[regex(r"[0-9]+U?")] + Decimal + } + + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + pub enum StateSpace { + Reg + } + + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + pub enum MemScope { } + + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + pub enum ScalarType { } + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st + st{.weak}{.ss}{.cop}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => { + todo!() + } + st.volatile{.ss}{.vec}.type [a], b => { + todo!() + } + st.relaxed.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => { + todo!() + } + st.release.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => { + todo!() + } + st.mmio.relaxed.sys{.global}.type [a], b => { + todo!() + } + + .ss: StateSpace = { .global, .local, .param{::func}, .shared{::cta, ::cluster} }; + .level::eviction_priority: EvictionPriority = + { .L1::evict_normal, .L1::evict_unchanged, .L1::evict_first, .L1::evict_last, .L1::no_allocate }; + .level::cache_hint = { .L2::cache_hint }; + .cop: RawStCacheOperator = { .wb, .cg, .cs, .wt }; + .scope: MemScope = { .cta, .cluster, .gpu, .sys }; + .vec: VectorPrefix = { .v2, .v4 }; + .type: ScalarType = { .b8, .b16, .b32, .b64, .b128, + .u8, .u16, .u32, .u64, + .s8, .s16, .s32, .s64, + .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-ld + ld{.weak}{.ss}{.cop}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{.unified}{, cache_policy} => { + todo!() + } + ld.volatile{.ss}{.level::prefetch_size}{.vec}.type d, [a] => { + todo!() + } + 
ld.relaxed.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{, cache_policy} => { + todo!() + } + ld.acquire.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{, cache_policy} => { + todo!() + } + ld.mmio.relaxed.sys{.global}.type d, [a] => { + todo!() + } + + .ss: StateSpace = { .const, .global, .local, .param{::entry, ::func}, .shared{::cta, ::cluster} }; + .cop: RawCacheOp = { .ca, .cg, .cs, .lu, .cv }; + .level::eviction_priority: EvictionPriority = + { .L1::evict_normal, .L1::evict_unchanged, .L1::evict_first, .L1::evict_last, .L1::no_allocate }; + .level::cache_hint = { .L2::cache_hint }; + .level::prefetch_size: PrefetchSize = { .L2::64B, .L2::128B, .L2::256B }; + .scope: MemScope = { .cta, .cluster, .gpu, .sys }; + .vec: VectorPrefix = { .v2, .v4 }; + .type: ScalarType = { .b8, .b16, .b32, .b64, .b128, + .u8, .u16, .u32, .u64, + .s8, .s16, .s32, .s64, + .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-add + add.type d, a, b => { + todo!() + } + add{.sat}.s32 d, a, b => { + todo!() + } + + .type: ScalarType = { .u16, .u32, .u64, + .s16, .s64, + .u16x2, .s16x2 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-add + add{.rnd}{.ftz}{.sat}.f32 d, a, b => { + todo!() + } + add{.rnd}.f64 d, a, b => { + todo!() + } + + .rnd: RawFloatRounding = { .rn, .rz, .rm, .rp }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-add + add{.rnd}{.ftz}{.sat}.f16 d, a, b => { + todo!() + } + add{.rnd}{.ftz}{.sat}.f16x2 d, a, b => { + todo!() + } + add{.rnd}.bf16 d, a, b => { + todo!() + } + add{.rnd}.bf16x2 d, a, b => { + todo!() + } + + .rnd: RawFloatRounding = { .rn }; + + ret => { + todo!() + } + +); + +fn main() { + use winnow::combinator::*; + use winnow::token::*; + use winnow::Parser; + + let mut input: &[Token] = 
&[][..]; + let x = opt(any::<_, ContextError>.verify_map(|t| { println!("MAP");Some(true) })).parse_next(&mut input).unwrap(); + dbg!(x); + let lexer = Token::lexer( + " + ld.u64 temp, [in_addr]; + add.u64 temp2, temp, 1; + st.u64 [out_addr], temp2; + ret; + ", + ); + let tokens = lexer.map(|t| t.unwrap()).collect::>(); + println!("{:?}", &tokens); + let mut stream = &tokens[..]; + parse_instruction(&mut stream).unwrap(); + //parse_prefix(&mut lexer); + let mut parser = &*tokens; + println!("{}", mem::size_of::()); +} From 8d7c88c095a013261cca1c6e5cbfb1acaac05624 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Thu, 15 Aug 2024 03:26:38 +0200 Subject: [PATCH 02/47] Fully parse operands --- gen/src/lib.rs | 14 +- gen_impl/src/lib.rs | 2 +- ptx_parser/Cargo.toml | 1 + ptx_parser/src/ast.rs | 16 +++ ptx_parser/src/main.rs | 318 ++++++++++++++++++++++++++++++++++++++--- 5 files changed, 323 insertions(+), 28 deletions(-) create mode 100644 ptx_parser/src/ast.rs diff --git a/gen/src/lib.rs b/gen/src/lib.rs index f39150ff..30e4595d 100644 --- a/gen/src/lib.rs +++ b/gen/src/lib.rs @@ -174,9 +174,9 @@ impl SingleOpcodeDefinition { .chain(self.arguments.0.iter().map(|arg| { let name = &arg.ident; if arg.optional { - quote! { #name : Option<&str> } + quote! { #name : Option> } } else { - quote! { #name : &str } + quote! { #name : ParsedOperand<'input> } } })) } @@ -377,7 +377,7 @@ fn emit_parse_function( let code_block = &def.code_block.0; let args = def.function_arguments_declarations(); quote! 
{ - fn #fn_name( #(#args),* ) -> Instruction #code_block + fn #fn_name<'input>( #(#args),* ) -> Instruction> #code_block } }) }) @@ -452,7 +452,7 @@ fn emit_parse_function( #(#fns_)* - fn parse_instruction<'input>(stream: &mut (impl winnow::stream::Stream, Slice = &'input [#type_name<'input>]> + winnow::stream::StreamIsPartial)) -> winnow::error::PResult> + fn parse_instruction<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> winnow::error::PResult>> { use winnow::Parser; use winnow::token::*; @@ -642,9 +642,9 @@ fn emit_definition_parser( empty } }; - let ident = { + let operand = { quote! { - any.verify_map(|t| match t { #token_type::Ident(s) => Some(s), _ => None }) + ParsedOperand::parse } }; let post_bracket = if arg.post_bracket { @@ -657,7 +657,7 @@ fn emit_definition_parser( } }; let parser = quote! { - (#comma, #pre_bracket, #pre_pipe, #can_be_negated, #ident, #post_bracket) + (#comma, #pre_bracket, #pre_pipe, #can_be_negated, #operand, #post_bracket) }; let arg_name = &arg.ident; if arg.optional { diff --git a/gen_impl/src/lib.rs b/gen_impl/src/lib.rs index 4c7f2ab2..6b606af2 100644 --- a/gen_impl/src/lib.rs +++ b/gen_impl/src/lib.rs @@ -67,7 +67,7 @@ impl GenerateInstructionType { let visit_slice_fn = format_ident!("visit{}_slice", kind.fn_suffix()); let (type_parameters, visitor_parameters, return_type) = if kind == VisitKind::Map { ( - quote! { <#type_parameters, To: Operand> }, + quote! { <#type_parameters, To> }, quote! { <#short_parameters, To> }, quote! 
{ #type_name }, ) diff --git a/ptx_parser/Cargo.toml b/ptx_parser/Cargo.toml index d5e3d5db..951d508a 100644 --- a/ptx_parser/Cargo.toml +++ b/ptx_parser/Cargo.toml @@ -7,3 +7,4 @@ edition = "2021" logos = "0.14" winnow = { version = "0.6.18", features = ["debug"] } gen = { path = "../gen" } +thiserror = "1.0" diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs new file mode 100644 index 00000000..ae4eaba3 --- /dev/null +++ b/ptx_parser/src/ast.rs @@ -0,0 +1,16 @@ +#[derive(Clone)] +pub enum ParsedOperand { + Reg(Ident), + RegOffset(Ident, i32), + Imm(ImmediateValue), + VecMember(Ident, u8), + VecPack(Vec), +} + +#[derive(Copy, Clone)] +pub enum ImmediateValue { + U64(u64), + S64(i64), + F32(f32), + F64(f64), +} diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs index 8af8ede1..4f3ed415 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/main.rs @@ -1,27 +1,66 @@ use gen::derive_parser; use logos::Logos; use std::mem; +use std::num::{ParseFloatError, ParseIntError}; +use winnow::combinator::{alt, empty, fail, opt}; +use winnow::stream::SliceLen; +use winnow::token::{any, literal}; use winnow::{ error::{ContextError, ParserError}, stream::{Offset, Stream, StreamIsPartial}, + PResult, }; +use winnow::{prelude::*, Stateful}; + +mod ast; pub trait Operand {} -pub trait Visitor { +pub trait Visitor { fn visit(&mut self, args: &T, type_: &Type, space: StateSpace, is_dst: bool); } -pub trait VisitorMut { +pub trait VisitorMut { fn visit(&mut self, args: &mut T, type_: &Type, space: StateSpace, is_dst: bool); } -pub trait VisitorMap { +pub trait VisitorMap { fn visit(&mut self, args: From, type_: &Type, space: StateSpace, is_dst: bool) -> To; } +#[derive(Clone)] +pub struct MovDetails { + pub typ: Type, + pub src_is_address: bool, + // two fields below are in use by member moves + pub dst_width: u8, + pub src_width: u8, + // This is in use by auto-generated movs + pub relaxed_src2_conv: bool, +} + +impl MovDetails { + pub fn new(typ: Type) -> Self { + 
MovDetails { + typ, + src_is_address: false, + dst_width: 0, + src_width: 0, + relaxed_src2_conv: false, + } + } +} + gen::generate_instruction_type!( - enum Instruction { + enum Instruction { + Mov { + type: { &data.typ }, + data: MovDetails, + arguments: { + dst: T, + src: T + } + }, Ld { type: { &data.typ }, data: LdDetails, @@ -161,9 +200,212 @@ pub struct RetData { pub uniform: bool, } -pub struct ParsedOperand {} +type ParserState<'a, 'input> = Stateful<&'a [Token<'input>], Vec>; + +fn ident<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> PResult<&'input str> { + any.verify_map(|t| { + if let Token::Ident(text) = t { + Some(text) + } else { + None + } + }) + .parse_next(stream) +} + +fn num<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> PResult<(&'input str, u32, bool)> { + any.verify_map(|t| { + Some(match t { + Token::Hex(s) => { + if s.ends_with('U') { + (&s[2..s.len() - 1], 16, true) + } else { + (&s[2..], 16, false) + } + } + Token::Decimal(s) => { + let radix = if s.starts_with('0') { 8 } else { 10 }; + if s.ends_with('U') { + (&s[..s.len() - 1], radix, true) + } else { + (s, radix, false) + } + } + _ => return None, + }) + }) + .parse_next(stream) +} + +fn take_error<'a, 'input: 'a, O, E>( + mut parser: impl Parser, Result, E>, +) -> impl Parser, O, E> { + move |input: &mut ParserState<'a, 'input>| { + Ok(match parser.parse_next(input)? 
{ + Ok(x) => x, + Err((x, err)) => { + input.state.push(err); + x + } + }) + } +} + +fn int_immediate<'a, 'input>(input: &mut ParserState<'a, 'input>) -> PResult { + take_error((opt(Token::Minus), num).map(|(neg, x)| { + let (num, radix, is_unsigned) = x; + if neg.is_some() { + match i64::from_str_radix(num, radix) { + Ok(x) => Ok(ast::ImmediateValue::S64(-x)), + Err(err) => Err((ast::ImmediateValue::S64(0), PtxError::from(err))), + } + } else if is_unsigned { + match u64::from_str_radix(num, radix) { + Ok(x) => Ok(ast::ImmediateValue::U64(x)), + Err(err) => Err((ast::ImmediateValue::U64(0), PtxError::from(err))), + } + } else { + match i64::from_str_radix(num, radix) { + Ok(x) => Ok(ast::ImmediateValue::S64(x)), + Err(_) => match u64::from_str_radix(num, radix) { + Ok(x) => Ok(ast::ImmediateValue::U64(x)), + Err(err) => Err((ast::ImmediateValue::U64(0), PtxError::from(err))), + }, + } + } + })) + .parse_next(input) +} + +fn f32<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> PResult { + take_error(any.verify_map(|t| match t { + Token::F32(f) => Some(match u32::from_str_radix(&f[2..], 16) { + Ok(x) => Ok(f32::from_bits(x)), + Err(err) => Err((0.0, PtxError::from(err))), + }), + _ => None, + })) + .parse_next(stream) +} + +fn f64<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> PResult { + take_error(any.verify_map(|t| match t { + Token::F64(f) => Some(match u64::from_str_radix(&f[2..], 16) { + Ok(x) => Ok(f64::from_bits(x)), + Err(err) => Err((0.0, PtxError::from(err))), + }), + _ => None, + })) + .parse_next(stream) +} + +fn s32<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> PResult { + take_error((opt(Token::Minus), num).map(|(sign, x)| { + let (text, radix, _) = x; + match i32::from_str_radix(text, radix) { + Ok(x) => Ok(if sign.is_some() { -x } else { x }), + Err(err) => Err((0, PtxError::from(err))), + } + })) + .parse_next(stream) +} + +fn immediate_value<'a, 'input>( + stream: &mut ParserState<'a, 'input>, +) -> PResult { + alt(( + 
int_immediate, + f32.map(ast::ImmediateValue::F32), + f64.map(ast::ImmediateValue::F64), + )) + .parse_next(stream) +} + +impl ast::ParsedOperand { + fn parse<'a, 'input>( + stream: &mut ParserState<'a, 'input>, + ) -> PResult> { + use winnow::combinator::*; + use winnow::token::any; + fn vector_index<'input>(inp: &'input str) -> Result { + match inp { + "x" | "r" => Ok(0), + "y" | "g" => Ok(1), + "z" | "b" => Ok(2), + "w" | "a" => Ok(3), + _ => Err(PtxError::WrongVectorElement), + } + } + fn ident_operands<'a, 'input>( + stream: &mut ParserState<'a, 'input>, + ) -> PResult> { + let main_ident = ident.parse_next(stream)?; + alt(( + preceded(Token::Plus, s32) + .map(move |offset| ast::ParsedOperand::RegOffset(main_ident, offset)), + take_error(preceded(Token::Dot, ident).map(move |suffix| { + let vector_index = vector_index(suffix) + .map_err(move |e| (ast::ParsedOperand::VecMember(main_ident, 0), e))?; + Ok(ast::ParsedOperand::VecMember(main_ident, vector_index)) + })), + empty.value(ast::ParsedOperand::Reg(main_ident)), + )) + .parse_next(stream) + } + fn vector_operand<'a, 'input>( + stream: &mut ParserState<'a, 'input>, + ) -> PResult> { + let (_, r1, _, r2) = + (Token::LBracket, ident, Token::Comma, ident).parse_next(stream)?; + dispatch! 
{any; + Token::LBracket => empty.map(|_| vec![r1, r2]), + Token::Comma => (ident, Token::Comma, ident, Token::LBracket).map(|(r3, _, r4, _)| vec![r1, r2, r3, r4]), + _ => fail + } + .parse_next(stream) + } + alt(( + ident_operands, + immediate_value.map(ast::ParsedOperand::Imm), + vector_operand.map(ast::ParsedOperand::VecPack), + )) + .parse_next(stream) + } +} -impl Operand for ParsedOperand {} +#[derive(Debug, thiserror::Error)] +pub enum PtxError { + #[error("{source}")] + ParseInt { + #[from] + source: ParseIntError, + }, + #[error("{source}")] + ParseFloat { + #[from] + source: ParseFloatError, + }, + #[error("")] + SyntaxError, + #[error("")] + NonF32Ftz, + #[error("")] + WrongArrayType, + #[error("")] + WrongVectorElement, + #[error("")] + MultiArrayVariable, + #[error("")] + ZeroDimensionArray, + #[error("")] + ArrayInitalizer, + #[error("")] + NonExternPointer, + #[error("{start}:{end}")] + UnrecognizedStatement { start: usize, end: usize }, + #[error("{start}:{end}")] + UnrecognizedDirective { start: usize, end: usize }, +} #[derive(Debug)] struct ReverseStream<'a, T>(pub &'a [T]); @@ -180,24 +422,20 @@ where type Checkpoint = &'i [T]; - #[inline(always)] fn iter_offsets(&self) -> Self::IterOffsets { self.0.iter().rev().cloned().enumerate() } - #[inline(always)] fn eof_offset(&self) -> usize { self.0.len() } - #[inline(always)] fn next_token(&mut self) -> Option { let (token, next) = self.0.split_last()?; self.0 = next; Some(token.clone()) } - #[inline(always)] fn offset_for

(&self, predicate: P) -> Option where P: Fn(Self::Token) -> bool, @@ -205,7 +443,6 @@ where self.0.iter().rev().position(|b| predicate(b.clone())) } - #[inline(always)] fn offset_at(&self, tokens: usize) -> Result { if let Some(needed) = tokens .checked_sub(self.0.len()) @@ -217,7 +454,6 @@ where } } - #[inline(always)] fn next_slice(&mut self, offset: usize) -> Self::Slice { let offset = self.0.len() - offset; let (next, slice) = self.0.split_at(offset); @@ -225,24 +461,20 @@ where slice } - #[inline(always)] fn checkpoint(&self) -> Self::Checkpoint { self.0 } - #[inline(always)] fn reset(&mut self, checkpoint: &Self::Checkpoint) { self.0 = checkpoint; } - #[inline(always)] fn raw(&self) -> &dyn std::fmt::Debug { self } } impl<'a, T> Offset<&'a [T]> for ReverseStream<'a, T> { - #[inline] fn offset_from(&self, start: &&'a [T]) -> usize { let fst = start.as_ptr(); let snd = self.0.as_ptr(); @@ -267,6 +499,14 @@ impl<'a, T> StreamIsPartial for ReverseStream<'a, T> { } } +impl<'input, I: Stream + StreamIsPartial, E: ParserError> Parser + for Token<'input> +{ + fn parse_next(&mut self, input: &mut I) -> PResult { + any.parse_next(input) + } +} + // Modifiers are turned into arguments to the blocks, with type: // * If it is an alternative: // * If it is mandatory then its type is Foo (as defined by the relevant rule) @@ -275,12 +515,16 @@ impl<'a, T> StreamIsPartial for ReverseStream<'a, T> { // * If it is mandatory then it is skipped // * If it is optional then its type is `bool` +type ParsedOperand<'input> = ast::ParsedOperand<&'input str>; + derive_parser!( #[derive(Logos, PartialEq, Eq, Debug, Clone, Copy)] #[logos(skip r"\s+")] enum Token<'input> { #[token(",")] Comma, + #[token(".")] + Dot, #[regex(r"[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+", |lex| lex.slice(), priority = 0)] Ident(&'input str), #[token("|")] @@ -293,8 +537,18 @@ derive_parser!( LBracket, #[token("]")] RBracket, - #[regex(r"[0-9]+U?")] - Decimal + #[regex(r"0[fF][0-9a-zA-Z]{8}", |lex| 
lex.slice())] + F32(&'input str), + #[regex(r"0[dD][0-9a-zA-Z]{16}", |lex| lex.slice())] + F64(&'input str), + #[regex(r"0[xX][0-9a-zA-Z]+U?", |lex| lex.slice())] + Hex(&'input str), + #[regex(r"[0-9]+U?", |lex| lex.slice())] + Decimal(&'input str), + #[token("-")] + Minus, + #[token("+")] + Plus, } #[derive(Copy, Clone, PartialEq, Eq, Hash)] @@ -308,6 +562,20 @@ derive_parser!( #[derive(Copy, Clone, PartialEq, Eq, Hash)] pub enum ScalarType { } + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov + + mov.type d, a => { + Instruction::Mov{ + data: MovDetails::new(type_.into()), + arguments: MovArgs { dst: d, src: a } + } + } + .type: ScalarType = { .pred, + .b16, .b32, .b64, + .u16, .u32, .u64, + .s16, .s32, .s64, + .f32, .f64 }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st st{.weak}{.ss}{.cop}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => { todo!() @@ -416,8 +684,15 @@ fn main() { use winnow::token::*; use winnow::Parser; + println!("{}", mem::size_of::()); + let mut input: &[Token] = &[][..]; - let x = opt(any::<_, ContextError>.verify_map(|t| { println!("MAP");Some(true) })).parse_next(&mut input).unwrap(); + let x = opt(any::<_, ContextError>.verify_map(|t| { + println!("MAP"); + Some(true) + })) + .parse_next(&mut input) + .unwrap(); dbg!(x); let lexer = Token::lexer( " @@ -429,7 +704,10 @@ fn main() { ); let tokens = lexer.map(|t| t.unwrap()).collect::>(); println!("{:?}", &tokens); - let mut stream = &tokens[..]; + let mut stream = ParserState { + input: &tokens[..], + state: Vec::new(), + }; parse_instruction(&mut stream).unwrap(); //parse_prefix(&mut lexer); let mut parser = &*tokens; From dbd37f97ad75ac06168f4c26ddc884558f857483 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Thu, 15 Aug 2024 18:51:11 +0200 Subject: [PATCH 03/47] Clean up and improve ident parsing --- gen/src/lib.rs | 111 
++++++++++++++++------------------------- gen_impl/src/parser.rs | 49 +++++++++++++----- ptx_parser/src/main.rs | 8 +-- 3 files changed, 84 insertions(+), 84 deletions(-) diff --git a/gen/src/lib.rs b/gen/src/lib.rs index 30e4595d..3ab5e433 100644 --- a/gen/src/lib.rs +++ b/gen/src/lib.rs @@ -55,6 +55,7 @@ impl OpcodeDefinitions { let candidate = if let DotModifierRef::Direct { optional: false, value, + .. } = candidate { value @@ -227,11 +228,13 @@ impl SingleOpcodeDefinition { DotModifierRef::Direct { optional, value: alts.alternatives[0].clone(), + name: modifier, } } else { DotModifierRef::Indirect { optional, value: alts.clone(), + name: modifier, } } } @@ -239,7 +242,8 @@ impl SingleOpcodeDefinition { possible_modifiers.insert(modifier.clone()); DotModifierRef::Direct { optional, - value: modifier, + value: modifier.clone(), + name: modifier, } } }, @@ -306,8 +310,9 @@ pub fn derive_parser(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream }) .collect::>(); let mut token_enum = parse_definitions.token_type; - let (_, all_modifier) = write_definitions_into_tokens(&definitions, &mut token_enum.variants); - let token_impl = emit_parse_function(&token_enum.ident, &definitions, all_modifier); + let (all_opcode, all_modifier) = + write_definitions_into_tokens(&definitions, &mut token_enum.variants); + let token_impl = emit_parse_function(&token_enum.ident, &definitions, all_opcode, all_modifier); let tokens = quote! { #enum_types_tokens @@ -364,6 +369,7 @@ fn emit_enum_types( fn emit_parse_function( type_name: &Ident, defs: &FxHashMap, + all_opcode: Vec<&Ident>, all_modifier: FxHashSet<&parser::DotModifier>, ) -> TokenStream { use std::fmt::Write; @@ -437,9 +443,24 @@ fn emit_parse_function( .to_tokens(&mut result); result }); + let opcodes = all_opcode.into_iter().map(|op_ident| { + let op = op_ident.to_string(); + let variant = Ident::new(&capitalize(&op), op_ident.span()); + let value = op; + quote! 
{ + #type_name :: #variant => Some(#value), + } + }); let modifier_names = all_modifier.iter().map(|m| m.dot_capitalized()); quote! { impl<'input> #type_name<'input> { + fn opcode_text(self) -> Option<&'static str> { + match self { + #(#opcodes)* + _ => None + } + } + fn modifier(self) -> bool { match self { #( @@ -452,7 +473,7 @@ fn emit_parse_function( #(#fns_)* - fn parse_instruction<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> winnow::error::PResult>> + fn parse_instruction<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> winnow::error::PResult>> { use winnow::Parser; use winnow::token::*; @@ -490,9 +511,8 @@ fn emit_definition_parser( }); let ordered_parse = definition.ordered_modifiers.iter().rev().map(|modifier| { let arg_name = modifier.ident(); - let arg_type = modifier.type_of(); match modifier { - DotModifierRef::Direct { optional, value } => { + DotModifierRef::Direct { optional, value, .. } => { let variant = value.dot_capitalized(); if *optional { quote! { @@ -504,7 +524,7 @@ fn emit_definition_parser( } } } - DotModifierRef::Indirect { optional, value } => { + DotModifierRef::Indirect { optional, value, .. } => { let variants = value.alternatives.iter().map(|alt| { let type_ = value.type_.as_ref().unwrap(); let token_variant = alt.dot_capitalized(); @@ -546,8 +566,8 @@ fn emit_definition_parser( .unordered_modifiers .iter() .map(|modifier| match modifier { - DotModifierRef::Direct { value, .. } => { - let name = value.ident(); + DotModifierRef::Direct { name, value, .. } => { + let name = name.ident(); let token_variant = value.dot_capitalized(); quote! { #token_type :: #token_variant => { @@ -558,8 +578,8 @@ fn emit_definition_parser( } } } - DotModifierRef::Indirect { value, .. } => { - let variable = value.modifier.ident(); + DotModifierRef::Indirect { value, name, .. 
} => { + let variable = name.ident(); let type_ = value.type_.as_ref().unwrap(); let alternatives = value.alternatives.iter().map(|alt| { let token_variant = alt.dot_capitalized(); @@ -585,9 +605,10 @@ fn emit_definition_parser( .map(|modifier| match modifier { DotModifierRef::Direct { optional: false, - value, + name, + .. } => { - let variable = value.ident(); + let variable = name.ident(); quote! { if !#variable { #return_error; @@ -597,9 +618,10 @@ fn emit_definition_parser( DotModifierRef::Direct { optional: true, .. } => TokenStream::new(), DotModifierRef::Indirect { optional: false, - value, + name, + .. } => { - let variable = value.modifier.ident(); + let variable = name.ident(); quote! { let #variable = match #variable { Some(x) => x, @@ -749,9 +771,11 @@ enum DotModifierRef { Direct { optional: bool, value: parser::DotModifier, + name: parser::DotModifier, }, Indirect { optional: bool, + name: parser::DotModifier, value: Rc, }, } @@ -759,8 +783,8 @@ enum DotModifierRef { impl DotModifierRef { fn ident(&self) -> Ident { match self { - DotModifierRef::Direct { value, .. } => value.ident(), - DotModifierRef::Indirect { value, .. } => value.modifier.ident(), + DotModifierRef::Direct { name, .. } => name.ident(), + DotModifierRef::Indirect { name, .. } => name.ident(), } } @@ -770,7 +794,9 @@ impl DotModifierRef { DotModifierRef::Direct { optional: false, .. } => return None, - DotModifierRef::Indirect { optional, value } => { + DotModifierRef::Indirect { + optional, value, .. 
+ } => { let type_ = value .type_ .as_ref() @@ -798,55 +824,6 @@ impl DotModifierRef { } } -impl Hash for DotModifierRef { - fn hash(&self, state: &mut H) { - match self { - DotModifierRef::Direct { optional, value } => { - optional.hash(state); - value.hash(state); - } - DotModifierRef::Indirect { optional, value } => { - optional.hash(state); - (value.as_ref() as *const parser::Rule).hash(state); - } - } - } -} - -impl PartialEq for DotModifierRef { - fn eq(&self, other: &Self) -> bool { - match (self, other) { - ( - Self::Direct { - optional: l_optional, - value: l_value, - }, - Self::Direct { - optional: r_optional, - value: r_value, - }, - ) => l_optional == r_optional && l_value == r_value, - ( - Self::Indirect { - optional: l_optional, - value: l_value, - }, - Self::Indirect { - optional: r_optional, - value: r_value, - }, - ) => { - l_optional == r_optional - && l_value.as_ref() as *const parser::Rule - == r_value.as_ref() as *const parser::Rule - } - _ => false, - } - } -} - -impl Eq for DotModifierRef {} - #[proc_macro] pub fn generate_instruction_type(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream { let input = parse_macro_input!(tokens as gen_impl::GenerateInstructionType); diff --git a/gen_impl/src/parser.rs b/gen_impl/src/parser.rs index c8da61d1..b57d6ece 100644 --- a/gen_impl/src/parser.rs +++ b/gen_impl/src/parser.rs @@ -7,6 +7,7 @@ use std::fmt::Write; use syn::bracketed; use syn::parse::Peek; use syn::punctuated::Punctuated; +use syn::spanned::Spanned; use syn::LitInt; use syn::Type; use syn::{braced, parse::Parse, token, Ident, ItemEnum, Token}; @@ -155,6 +156,10 @@ impl Rule { struct IdentOrTypeSuffix(IdentLike); impl IdentOrTypeSuffix { + fn span(&self) -> Span { + self.0.span() + } + fn peek(input: syn::parse::ParseStream) -> bool { input.peek(Token![::]) } @@ -278,6 +283,15 @@ impl std::fmt::Debug for DotModifier { } impl DotModifier { + pub fn span(&self) -> Span { + let part1 = self.part1.span(); + if let Some(ref part2) = 
self.part2 { + part1.join(part2.span()).unwrap_or(part1) + } else { + part1 + } + } + pub fn ident(&self) -> Ident { let mut result = String::new(); write!(&mut result, "{}", self.part1).unwrap(); @@ -285,11 +299,11 @@ impl DotModifier { write!(&mut result, "_{}", part2.0).unwrap(); } else { match self.part1 { - IdentLike::Type | IdentLike::Const => result.push('_'), + IdentLike::Type(_) | IdentLike::Const(_) => result.push('_'), IdentLike::Ident(_) | IdentLike::Integer(_) => {} } } - Ident::new(&result.to_ascii_lowercase(), Span::call_site()) + Ident::new(&result.to_ascii_lowercase(), self.span()) } pub fn variant_capitalized(&self) -> Ident { @@ -321,7 +335,7 @@ impl DotModifier { }; result.push(c); } - Ident::new(&result, Span::call_site()) + Ident::new(&result, self.span()) } pub fn tokens(&self) -> TokenStream { @@ -353,17 +367,28 @@ impl Parse for DotModifier { #[derive(PartialEq, Eq, Hash, Clone)] enum IdentLike { - Type, - Const, + Type(Token![type]), + Const(Token![const]), Ident(Ident), Integer(LitInt), } +impl IdentLike { + fn span(&self) -> Span { + match self { + IdentLike::Type(c) => c.span(), + IdentLike::Const(t) => t.span(), + IdentLike::Ident(i) => i.span(), + IdentLike::Integer(l) => l.span(), + } + } +} + impl std::fmt::Display for IdentLike { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - IdentLike::Type => f.write_str("type"), - IdentLike::Const => f.write_str("const"), + IdentLike::Type(_) => f.write_str("type"), + IdentLike::Const(_) => f.write_str("const"), IdentLike::Ident(ident) => write!(f, "{}", ident), IdentLike::Integer(integer) => write!(f, "{}", integer), } @@ -373,8 +398,8 @@ impl std::fmt::Display for IdentLike { impl ToTokens for IdentLike { fn to_tokens(&self, tokens: &mut TokenStream) { match self { - IdentLike::Type => quote! { type }.to_tokens(tokens), - IdentLike::Const => quote! { const }.to_tokens(tokens), + IdentLike::Type(_) => quote! 
{ type }.to_tokens(tokens), + IdentLike::Const(_) => quote! { const }.to_tokens(tokens), IdentLike::Ident(ident) => quote! { #ident }.to_tokens(tokens), IdentLike::Integer(int) => quote! { #int }.to_tokens(tokens), } @@ -385,11 +410,9 @@ impl Parse for IdentLike { fn parse(input: syn::parse::ParseStream) -> syn::Result { let lookahead = input.lookahead1(); Ok(if lookahead.peek(Token![const]) { - input.parse::()?; - IdentLike::Const + IdentLike::Const(input.parse::()?) } else if lookahead.peek(Token![type]) { - input.parse::()?; - IdentLike::Type + IdentLike::Type(input.parse::()?) } else if lookahead.peek(Ident) { IdentLike::Ident(input.parse::()?) } else if lookahead.peek(LitInt) { diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs index 4f3ed415..96f08b6e 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/main.rs @@ -2,9 +2,8 @@ use gen::derive_parser; use logos::Logos; use std::mem; use std::num::{ParseFloatError, ParseIntError}; -use winnow::combinator::{alt, empty, fail, opt}; -use winnow::stream::SliceLen; -use winnow::token::{any, literal}; +use winnow::combinator::*; +use winnow::token::any; use winnow::{ error::{ContextError, ParserError}, stream::{Offset, Stream, StreamIsPartial}, @@ -206,6 +205,8 @@ fn ident<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> PResult<&'input st any.verify_map(|t| { if let Token::Ident(text) = t { Some(text) + } else if let Some(text) = t.opcode_text() { + Some(text) } else { None } @@ -563,7 +564,6 @@ derive_parser!( pub enum ScalarType { } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov - mov.type d, a => { Instruction::Mov{ data: MovDetails::new(type_.into()), From ba17906de8381482241dc151d4891845a84bc71e Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Thu, 15 Aug 2024 19:30:09 +0200 Subject: [PATCH 04/47] Pass parser state to instruction callbacks --- gen/src/lib.rs | 6 +++--- ptx_parser/src/main.rs | 35 
++++++++++++++++++----------------- 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/gen/src/lib.rs b/gen/src/lib.rs index 3ab5e433..6bea2df7 100644 --- a/gen/src/lib.rs +++ b/gen/src/lib.rs @@ -383,7 +383,7 @@ fn emit_parse_function( let code_block = &def.code_block.0; let args = def.function_arguments_declarations(); quote! { - fn #fn_name<'input>( #(#args),* ) -> Instruction> #code_block + fn #fn_name<'input>(state: &mut PtxParserState, #(#args),* ) -> Instruction> #code_block } }) }) @@ -473,7 +473,7 @@ fn emit_parse_function( #(#fns_)* - fn parse_instruction<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> winnow::error::PResult>> + fn parse_instruction<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> winnow::error::PResult>> { use winnow::Parser; use winnow::token::*; @@ -695,7 +695,7 @@ fn emit_definition_parser( let fn_args = definition.function_arguments(); let fn_name = format_ident!("{}_{}", opcode, fn_idx); let fn_call = quote! { - #fn_name( #(#fn_args),* ) + #fn_name(&mut stream.state, #(#fn_args),* ) }; quote! 
{ #(#unordered_parse_declarations)* diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs index 96f08b6e..7786debc 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/main.rs @@ -199,9 +199,10 @@ pub struct RetData { pub uniform: bool, } -type ParserState<'a, 'input> = Stateful<&'a [Token<'input>], Vec>; +type PtxParserState = Vec; +type PtxParser<'a, 'input> = Stateful<&'a [Token<'input>], PtxParserState>; -fn ident<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> PResult<&'input str> { +fn ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input str> { any.verify_map(|t| { if let Token::Ident(text) = t { Some(text) @@ -214,7 +215,7 @@ fn ident<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> PResult<&'input st .parse_next(stream) } -fn num<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> PResult<(&'input str, u32, bool)> { +fn num<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(&'input str, u32, bool)> { any.verify_map(|t| { Some(match t { Token::Hex(s) => { @@ -239,9 +240,9 @@ fn num<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> PResult<(&'input str } fn take_error<'a, 'input: 'a, O, E>( - mut parser: impl Parser, Result, E>, -) -> impl Parser, O, E> { - move |input: &mut ParserState<'a, 'input>| { + mut parser: impl Parser, Result, E>, +) -> impl Parser, O, E> { + move |input: &mut PtxParser<'a, 'input>| { Ok(match parser.parse_next(input)? 
{ Ok(x) => x, Err((x, err)) => { @@ -252,7 +253,7 @@ fn take_error<'a, 'input: 'a, O, E>( } } -fn int_immediate<'a, 'input>(input: &mut ParserState<'a, 'input>) -> PResult { +fn int_immediate<'a, 'input>(input: &mut PtxParser<'a, 'input>) -> PResult { take_error((opt(Token::Minus), num).map(|(neg, x)| { let (num, radix, is_unsigned) = x; if neg.is_some() { @@ -278,7 +279,7 @@ fn int_immediate<'a, 'input>(input: &mut ParserState<'a, 'input>) -> PResult(stream: &mut ParserState<'a, 'input>) -> PResult { +fn f32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { take_error(any.verify_map(|t| match t { Token::F32(f) => Some(match u32::from_str_radix(&f[2..], 16) { Ok(x) => Ok(f32::from_bits(x)), @@ -289,7 +290,7 @@ fn f32<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> PResult { .parse_next(stream) } -fn f64<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> PResult { +fn f64<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { take_error(any.verify_map(|t| match t { Token::F64(f) => Some(match u64::from_str_radix(&f[2..], 16) { Ok(x) => Ok(f64::from_bits(x)), @@ -300,7 +301,7 @@ fn f64<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> PResult { .parse_next(stream) } -fn s32<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> PResult { +fn s32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { take_error((opt(Token::Minus), num).map(|(sign, x)| { let (text, radix, _) = x; match i32::from_str_radix(text, radix) { @@ -312,7 +313,7 @@ fn s32<'a, 'input>(stream: &mut ParserState<'a, 'input>) -> PResult { } fn immediate_value<'a, 'input>( - stream: &mut ParserState<'a, 'input>, + stream: &mut PtxParser<'a, 'input>, ) -> PResult { alt(( int_immediate, @@ -324,7 +325,7 @@ fn immediate_value<'a, 'input>( impl ast::ParsedOperand { fn parse<'a, 'input>( - stream: &mut ParserState<'a, 'input>, + stream: &mut PtxParser<'a, 'input>, ) -> PResult> { use winnow::combinator::*; use winnow::token::any; @@ -338,7 +339,7 @@ impl ast::ParsedOperand { 
} } fn ident_operands<'a, 'input>( - stream: &mut ParserState<'a, 'input>, + stream: &mut PtxParser<'a, 'input>, ) -> PResult> { let main_ident = ident.parse_next(stream)?; alt(( @@ -354,7 +355,7 @@ impl ast::ParsedOperand { .parse_next(stream) } fn vector_operand<'a, 'input>( - stream: &mut ParserState<'a, 'input>, + stream: &mut PtxParser<'a, 'input>, ) -> PResult> { let (_, r1, _, r2) = (Token::LBracket, ident, Token::Comma, ident).parse_next(stream)?; @@ -565,9 +566,9 @@ derive_parser!( // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov mov.type d, a => { - Instruction::Mov{ + Instruction::Mov { data: MovDetails::new(type_.into()), - arguments: MovArgs { dst: d, src: a } + arguments: MovArgs { dst: d, src: a }, } } .type: ScalarType = { .pred, @@ -704,7 +705,7 @@ fn main() { ); let tokens = lexer.map(|t| t.unwrap()).collect::>(); println!("{:?}", &tokens); - let mut stream = ParserState { + let mut stream = PtxParser { input: &tokens[..], state: Vec::new(), }; From 0da45ea7d8e3febf681e0036a17627950bac6a7d Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Thu, 15 Aug 2024 22:24:53 +0200 Subject: [PATCH 05/47] Add parsing of st, allow associating type with a non-alternative modifier --- gen/src/lib.rs | 123 +++++++++++++++++++++++++++++------ gen_impl/src/parser.rs | 17 +++-- ptx_parser/src/ast.rs | 19 ++++++ ptx_parser/src/main.rs | 143 ++++++++++++++++++++++++++++++----------- 4 files changed, 239 insertions(+), 63 deletions(-) diff --git a/gen/src/lib.rs b/gen/src/lib.rs index 6bea2df7..6ea01367 100644 --- a/gen/src/lib.rs +++ b/gen/src/lib.rs @@ -3,7 +3,9 @@ use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote, ToTokens}; use rustc_hash::{FxHashMap, FxHashSet}; use std::{collections::hash_map, hash::Hash, rc::Rc}; -use syn::{parse_macro_input, punctuated::Punctuated, Ident, ItemEnum, Token, TypePath, Variant}; +use syn::{ + parse_macro_input, punctuated::Punctuated, Ident, 
ItemEnum, Token, Type, TypePath, Variant, +}; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#vectors // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#fundamental-types @@ -46,7 +48,7 @@ impl OpcodeDefinitions { _ => {} } 'check_definitions: for i in unselected.iter().copied() { - // just pick the first alternative and attempt every modifier + // Attempt every modifier 'check_candidates: for candidate in definitions[i] .unordered_modifiers .iter() @@ -203,32 +205,31 @@ impl SingleOpcodeDefinition { output: &mut FxHashMap>, parser::OpcodeDefinition(pattern_seq, rules): parser::OpcodeDefinition, ) { - let mut rules = rules - .into_iter() - .map(|r| (r.modifier.clone(), Rc::new(r))) - .collect::>(); + let (mut named_rules, mut unnamed_rules) = gather_rules(rules); let mut last_opcode = pattern_seq.0.last().unwrap().0 .0.name.clone(); for (opcode_decl, code_block) in pattern_seq.0.into_iter().rev() { let current_opcode = opcode_decl.0.name.clone(); if last_opcode != current_opcode { - rules = FxHashMap::default(); + named_rules = FxHashMap::default(); + unnamed_rules = FxHashMap::default(); } let mut possible_modifiers = FxHashSet::default(); - for (_, options) in rules.iter() { + for (_, options) in named_rules.iter() { possible_modifiers.extend(options.alternatives.iter().cloned()); } let parser::OpcodeDecl(instruction, arguments) = opcode_decl; let mut unordered_modifiers = instruction .modifiers .into_iter() - .map( - |parser::MaybeDotModifier { optional, modifier }| match rules.get(&modifier) { + .map(|parser::MaybeDotModifier { optional, modifier }| { + match named_rules.get(&modifier) { Some(alts) => { if alts.alternatives.len() == 1 && alts.type_.is_none() { DotModifierRef::Direct { optional, value: alts.alternatives[0].clone(), name: modifier, + type_: alts.type_.clone(), } } else { DotModifierRef::Indirect { @@ -239,15 +240,17 @@ impl SingleOpcodeDefinition { } } None => { + let type_ = 
unnamed_rules.get(&modifier).cloned(); possible_modifiers.insert(modifier.clone()); DotModifierRef::Direct { optional, value: modifier.clone(), name: modifier, + type_, } } - }, - ) + } + }) .collect::>(); let ordered_modifiers = Self::extract_ordered_modifiers(&mut unordered_modifiers); let entry = Self { @@ -293,6 +296,29 @@ impl SingleOpcodeDefinition { } } +fn gather_rules( + rules: Vec, +) -> ( + FxHashMap>, + FxHashMap, +) { + let mut named = FxHashMap::default(); + let mut unnamed = FxHashMap::default(); + for rule in rules { + match rule.modifier { + Some(ref modifier) => { + named.insert(modifier.clone(), Rc::new(rule)); + } + None => unnamed.extend( + rule.alternatives + .into_iter() + .map(|alt| (alt, rule.type_.as_ref().unwrap().clone())), + ), + } + } + (named, unnamed) +} + #[proc_macro] pub fn derive_parser(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream { let parse_definitions = parse_macro_input!(tokens as gen_impl::parser::ParseDefinitions); @@ -512,7 +538,7 @@ fn emit_definition_parser( let ordered_parse = definition.ordered_modifiers.iter().rev().map(|modifier| { let arg_name = modifier.ident(); match modifier { - DotModifierRef::Direct { optional, value, .. } => { + DotModifierRef::Direct { optional, value, type_: None, .. } => { let variant = value.dot_capitalized(); if *optional { quote! { @@ -524,6 +550,7 @@ fn emit_definition_parser( } } } + DotModifierRef::Direct { type_: Some(_), .. } => { todo!() } DotModifierRef::Indirect { optional, value, .. } => { let variants = value.alternatives.iter().map(|alt| { let type_ = value.type_.as_ref().unwrap(); @@ -566,7 +593,12 @@ fn emit_definition_parser( .unordered_modifiers .iter() .map(|modifier| match modifier { - DotModifierRef::Direct { name, value, .. } => { + DotModifierRef::Direct { + name, + value, + type_: None, + .. + } => { let name = name.ident(); let token_variant = value.dot_capitalized(); quote! 
{ @@ -578,6 +610,24 @@ fn emit_definition_parser( } } } + DotModifierRef::Direct { + name, + value, + type_: Some(type_), + .. + } => { + let variable = name.ident(); + let token_variant = value.dot_capitalized(); + let enum_variant = value.variant_capitalized(); + quote! { + #token_type :: #token_variant => { + if #variable.is_some() { + #return_error_ref; + } + #variable = Some(#type_ :: #enum_variant); + } + } + } DotModifierRef::Indirect { value, name, .. } => { let variable = name.ident(); let type_ = value.type_.as_ref().unwrap(); @@ -606,6 +656,7 @@ fn emit_definition_parser( DotModifierRef::Direct { optional: false, name, + type_: None, .. } => { let variable = name.ident(); @@ -615,7 +666,20 @@ fn emit_definition_parser( } } } - DotModifierRef::Direct { optional: true, .. } => TokenStream::new(), + DotModifierRef::Direct { + optional: false, + name, + type_: Some(type_), + .. + } => { + let variable = name.ident(); + quote! { + let #variable = match #variable { + Some(x) => x, + None => #return_error + }; + } + } DotModifierRef::Indirect { optional: false, name, @@ -629,7 +693,8 @@ fn emit_definition_parser( }; } } - DotModifierRef::Indirect { optional: true, .. } => TokenStream::new(), + DotModifierRef::Direct { optional: true, .. } + | DotModifierRef::Indirect { optional: true, .. } => TokenStream::new(), }); let arguments_parse = definition.arguments.0.iter().enumerate().map(|(idx, arg)| { let comma = if idx == 0 { @@ -772,6 +837,7 @@ enum DotModifierRef { optional: bool, value: parser::DotModifier, name: parser::DotModifier, + type_: Option, }, Indirect { optional: bool, @@ -790,10 +856,26 @@ impl DotModifierRef { fn type_of(&self) -> Option { Some(match self { - DotModifierRef::Direct { optional: true, .. } => syn::parse_quote! { bool }, DotModifierRef::Direct { - optional: false, .. + optional: true, + type_: None, + .. + } => syn::parse_quote! { bool }, + DotModifierRef::Direct { + optional: false, + type_: None, + .. 
} => return None, + DotModifierRef::Direct { + optional: true, + type_: Some(type_), + .. + } => syn::parse_quote! { Option<#type_> }, + DotModifierRef::Direct { + optional: false, + type_: Some(type_), + .. + } => type_.clone(), DotModifierRef::Indirect { optional, value, .. } => { @@ -812,7 +894,10 @@ impl DotModifierRef { fn type_of_check(&self) -> syn::Type { match self { - DotModifierRef::Direct { .. } => syn::parse_quote! { bool }, + DotModifierRef::Direct { type_: None, .. } => syn::parse_quote! { bool }, + DotModifierRef::Direct { + type_: Some(type_), .. + } => syn::parse_quote! { Option<#type_> }, DotModifierRef::Indirect { value, .. } => { let type_ = value .type_ diff --git a/gen_impl/src/parser.rs b/gen_impl/src/parser.rs index b57d6ece..6834cbca 100644 --- a/gen_impl/src/parser.rs +++ b/gen_impl/src/parser.rs @@ -97,7 +97,7 @@ impl Parse for CodeBlock { } pub struct Rule { - pub modifier: DotModifier, + pub modifier: Option, pub type_: Option, pub alternatives: Vec, } @@ -105,6 +105,7 @@ pub struct Rule { impl Rule { fn peek(input: syn::parse::ParseStream) -> bool { DotModifier::peek(input) + || (input.peek(Ident) && input.peek2(Token![=]) && !input.peek3(Token![>])) } fn parse_alternatives(input: syn::parse::ParseStream) -> syn::Result> { @@ -181,12 +182,16 @@ impl Parse for IdentOrTypeSuffix { impl Parse for Rule { fn parse(input: syn::parse::ParseStream) -> syn::Result { - let modifier = input.parse::()?; - let type_ = if input.peek(Token![:]) { - input.parse::()?; - Some(input.parse::()?) 
+ let (modifier, type_) = if DotModifier::peek(input) { + let modifier = Some(input.parse::()?); + if input.peek(Token![:]) { + input.parse::()?; + (modifier, Some(input.parse::()?)) + } else { + (modifier, None) + } } else { - None + (None, Some(input.parse::()?)) }; input.parse::()?; let content; diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index ae4eaba3..c45a241a 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -1,3 +1,5 @@ +use super::MemScope; + #[derive(Clone)] pub enum ParsedOperand { Reg(Ident), @@ -14,3 +16,20 @@ pub enum ImmediateValue { F32(f32), F64(f64), } + +#[derive(Copy, Clone, PartialEq, Eq)] +pub enum StCacheOperator { + Writeback, + L2Only, + Streaming, + Writethrough, +} + +#[derive(Copy, Clone, PartialEq, Eq)] +pub enum LdStQualifier { + Weak, + Volatile, + Relaxed(MemScope), + Acquire(MemScope), + Release(MemScope), +} diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs index 7786debc..dd9e6d21 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/main.rs @@ -39,9 +39,9 @@ pub struct MovDetails { } impl MovDetails { - pub fn new(typ: Type) -> Self { + fn new(vector: Option, scalar: ScalarType) -> Self { MovDetails { - typ, + typ: Type::maybe_vector(vector, scalar), src_is_address: false, dst_width: 0, src_width: 0, @@ -99,7 +99,7 @@ gen::generate_instruction_type!( ); pub struct LdDetails { - pub qualifier: LdStQualifier, + pub qualifier: ast::LdStQualifier, pub state_space: StateSpace, pub caching: LdCacheOperator, pub typ: Type, @@ -164,41 +164,54 @@ pub enum Type { Array(ScalarType, Vec), } +impl Type { + fn maybe_vector(vector: Option, scalar: ScalarType) -> Self { + match vector { + Some(VectorPrefix::V2) => Type::Vector(scalar, 2), + Some(VectorPrefix::V4) => Type::Vector(scalar, 4), + None => Type::Scalar(scalar), + } + } +} + impl From for Type { fn from(value: ScalarType) -> Self { Type::Scalar(value) } } -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum LdStQualifier { - Weak, - Volatile, - 
Relaxed(MemScope), - Acquire(MemScope), - Release(MemScope), -} - pub struct StData { - pub qualifier: LdStQualifier, + pub qualifier: ast::LdStQualifier, pub state_space: StateSpace, - pub caching: StCacheOperator, + pub caching: ast::StCacheOperator, pub typ: Type, } -#[derive(PartialEq, Eq)] -pub enum StCacheOperator { - Writeback, - L2Only, - Streaming, - Writethrough, -} - #[derive(Copy, Clone)] pub struct RetData { pub uniform: bool, } +impl From for ast::StCacheOperator { + fn from(value: RawStCacheOperator) -> Self { + match value { + RawStCacheOperator::Wb => ast::StCacheOperator::Writeback, + RawStCacheOperator::Cg => ast::StCacheOperator::L2Only, + RawStCacheOperator::Cs => ast::StCacheOperator::Streaming, + RawStCacheOperator::Wt => ast::StCacheOperator::Writethrough, + } + } +} + +impl From for ast::LdStQualifier { + fn from(value: RawLdStQualifier) -> Self { + match value { + RawLdStQualifier::Weak => ast::LdStQualifier::Weak, + RawLdStQualifier::Volatile => ast::LdStQualifier::Volatile, + } + } +} + type PtxParserState = Vec; type PtxParser<'a, 'input> = Stateful<&'a [Token<'input>], PtxParserState>; @@ -312,9 +325,7 @@ fn s32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { .parse_next(stream) } -fn immediate_value<'a, 'input>( - stream: &mut PtxParser<'a, 'input>, -) -> PResult { +fn immediate_value<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { alt(( int_immediate, f32.map(ast::ImmediateValue::F32), @@ -388,6 +399,8 @@ pub enum PtxError { source: ParseFloatError, }, #[error("")] + Todo, + #[error("")] SyntaxError, #[error("")] NonF32Ftz, @@ -555,7 +568,8 @@ derive_parser!( #[derive(Copy, Clone, PartialEq, Eq, Hash)] pub enum StateSpace { - Reg + Reg, + Generic, } #[derive(Copy, Clone, PartialEq, Eq, Hash)] @@ -565,33 +579,84 @@ derive_parser!( pub enum ScalarType { } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov - mov.type d, a => { + mov{.vec}.type d, a 
=> { Instruction::Mov { - data: MovDetails::new(type_.into()), + data: MovDetails::new(vec, type_), arguments: MovArgs { dst: d, src: a }, } } - .type: ScalarType = { .pred, - .b16, .b32, .b64, - .u16, .u32, .u64, - .s16, .s32, .s64, - .f32, .f64 }; + .vec: VectorPrefix = { .v2, .v4 }; + .type: ScalarType = { .pred, + .b16, .b32, .b64, + .u16, .u32, .u64, + .s16, .s32, .s64, + .f32, .f64 }; // https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st st{.weak}{.ss}{.cop}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => { - todo!() + if level_eviction_priority.is_some() || level_cache_hint || cache_policy.is_some() { + state.push(PtxError::Todo); + } + Instruction::St { + data: StData { + qualifier: weak.unwrap_or(RawLdStQualifier::Weak).into(), + state_space: ss.unwrap_or(StateSpace::Generic), + caching: cop.unwrap_or(RawStCacheOperator::Wb).into(), + typ: Type::maybe_vector(vec, type_) + }, + arguments: StArgs { src1:a, src2:b } + } } st.volatile{.ss}{.vec}.type [a], b => { - todo!() + Instruction::St { + data: StData { + qualifier: volatile.into(), + state_space: ss.unwrap_or(StateSpace::Generic), + caching: ast::StCacheOperator::Writeback, + typ: Type::maybe_vector(vec, type_) + }, + arguments: StArgs { src1:a, src2:b } + } } st.relaxed.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => { - todo!() + if level_eviction_priority.is_some() || level_cache_hint || cache_policy.is_some() { + state.push(PtxError::Todo); + } + Instruction::St { + data: StData { + qualifier: ast::LdStQualifier::Relaxed(scope), + state_space: ss.unwrap_or(StateSpace::Generic), + caching: ast::StCacheOperator::Writeback, + typ: Type::maybe_vector(vec, type_) + }, + arguments: StArgs { src1:a, src2:b } + } } st.release.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => { - todo!() + if level_eviction_priority.is_some() || 
level_cache_hint || cache_policy.is_some() { + state.push(PtxError::Todo); + } + Instruction::St { + data: StData { + qualifier: ast::LdStQualifier::Release(scope), + state_space: ss.unwrap_or(StateSpace::Generic), + caching: ast::StCacheOperator::Writeback, + typ: Type::maybe_vector(vec, type_) + }, + arguments: StArgs { src1:a, src2:b } + } } st.mmio.relaxed.sys{.global}.type [a], b => { - todo!() + state.push(PtxError::Todo); + Instruction::St { + data: StData { + qualifier: ast::LdStQualifier::Relaxed(MemScope::Sys), + state_space: global.unwrap_or(StateSpace::Generic), + caching: ast::StCacheOperator::Writeback, + typ: type_.into() + }, + arguments: StArgs { src1:a, src2:b } + } } .ss: StateSpace = { .global, .local, .param{::func}, .shared{::cta, ::cluster} }; @@ -605,6 +670,8 @@ derive_parser!( .u8, .u16, .u32, .u64, .s8, .s16, .s32, .s64, .f32, .f64 }; + RawLdStQualifier = { .weak, .volatile }; + StateSpace = { .global }; // https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-ld ld{.weak}{.ss}{.cop}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{.unified}{, cache_policy} => { From 0112880f2742183558bbfd27022f080a8e8817fd Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 16 Aug 2024 16:02:26 +0200 Subject: [PATCH 06/47] Parse ld, add, ret --- gen/src/lib.rs | 68 +++++++++--- ptx_parser/src/ast.rs | 48 +++++++++ ptx_parser/src/main.rs | 239 +++++++++++++++++++++++++++++++++++------ 3 files changed, 309 insertions(+), 46 deletions(-) diff --git a/gen/src/lib.rs b/gen/src/lib.rs index 6ea01367..93b31fe9 100644 --- a/gen/src/lib.rs +++ b/gen/src/lib.rs @@ -2,9 +2,10 @@ use gen_impl::parser; use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote, ToTokens}; use rustc_hash::{FxHashMap, FxHashSet}; -use std::{collections::hash_map, hash::Hash, rc::Rc}; +use std::{collections::hash_map, hash::Hash, iter, rc::Rc}; use syn::{ - parse_macro_input, 
punctuated::Punctuated, Ident, ItemEnum, Token, Type, TypePath, Variant, + parse_macro_input, parse_quote, punctuated::Punctuated, Ident, ItemEnum, Token, Type, TypePath, + Variant, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#vectors @@ -176,10 +177,15 @@ impl SingleOpcodeDefinition { }) .chain(self.arguments.0.iter().map(|arg| { let name = &arg.ident; + let arg_type = if arg.unified { + quote! { (ParsedOperand<'input>, bool) } + } else { + quote! { ParsedOperand<'input> } + }; if arg.optional { - quote! { #name : Option> } + quote! { #name : Option<#arg_type> } } else { - quote! { #name : ParsedOperand<'input> } + quote! { #name : #arg_type } } })) } @@ -477,7 +483,8 @@ fn emit_parse_function( #type_name :: #variant => Some(#value), } }); - let modifier_names = all_modifier.iter().map(|m| m.dot_capitalized()); + let modifier_names = iter::once(Ident::new("DotUnified", Span::call_site())) + .chain(all_modifier.iter().map(|m| m.dot_capitalized())); quote! { impl<'input> #type_name<'input> { fn opcode_text(self) -> Option<&'static str> { @@ -550,7 +557,16 @@ fn emit_definition_parser( } } } - DotModifierRef::Direct { type_: Some(_), .. } => { todo!() } + DotModifierRef::Direct { optional: false, type_: Some(type_), name, value } => { + let variable = name.ident(); + let variant = value.dot_capitalized(); + let parsed_variant = value.variant_capitalized(); + quote! { + any.verify(|t| *t == #token_type :: #variant).parse_next(&mut stream)?; + #variable = #type_ :: #parsed_variant; + } + } + DotModifierRef::Direct { optional: true, type_: Some(_), .. } => { todo!() } DotModifierRef::Indirect { optional, value, .. } => { let variants = value.alternatives.iter().map(|alt| { let type_ = value.type_.as_ref().unwrap(); @@ -669,7 +685,7 @@ fn emit_definition_parser( DotModifierRef::Direct { optional: false, name, - type_: Some(type_), + type_: Some(_), .. 
} => { let variable = name.ident(); @@ -700,11 +716,11 @@ fn emit_definition_parser( let comma = if idx == 0 { quote! { empty } } else { - quote! { any.verify(|t| *t == #token_type::Comma) } + quote! { any.verify(|t| *t == #token_type::Comma).void() } }; let pre_bracket = if arg.pre_bracket { quote! { - any.verify(|t| *t == #token_type::LBracket).map(|_| ()) + any.verify(|t| *t == #token_type::LBracket).void() } } else { quote! { @@ -713,7 +729,7 @@ fn emit_definition_parser( }; let pre_pipe = if arg.pre_pipe { quote! { - any.verify(|t| *t == #token_type::Or).map(|_| ()) + any.verify(|t| *t == #token_type::Or).void() } } else { quote! { @@ -736,24 +752,42 @@ fn emit_definition_parser( }; let post_bracket = if arg.post_bracket { quote! { - any.verify(|t| *t == #token_type::RBracket).map(|_| ()) + any.verify(|t| *t == #token_type::RBracket).void() } } else { quote! { empty } }; - let parser = quote! { - (#comma, #pre_bracket, #pre_pipe, #can_be_negated, #operand, #post_bracket) + let unified = if arg.unified { + quote! { + opt(any.verify(|t| *t == #token_type::DotUnified).void()).map(|u| u.is_some()) + } + } else { + quote! { + empty + } + }; + let pattern = quote! { + (#comma, #pre_bracket, #pre_pipe, #can_be_negated, #operand, #post_bracket, #unified) }; let arg_name = &arg.ident; + let inner_parser = if arg.unified { + quote! { + #pattern.map(|(_, _, _, _, name, _, unified)| (name, unified)) + } + } else { + quote! { + #pattern.map(|(_, _, _, _, name, _, _)| name) + } + }; if arg.optional { quote! { - let #arg_name = opt(#parser.map(|(_, _, _, _, name, _)| name)).parse_next(stream)?; + let #arg_name = opt(#inner_parser).parse_next(stream)?; } } else { quote! { - let #arg_name = #parser.map(|(_, _, _, _, name, _)| name).parse_next(stream)?; + let #arg_name = #inner_parser.parse_next(stream)?; } } }); @@ -812,6 +846,10 @@ fn write_definitions_into_tokens<'a>( }; variants.push(arg); } + variants.push(parse_quote! 
{ + #[token(".unified")] + DotUnified + }); (all_opcodes, all_modifiers) } diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index c45a241a..a471b8ec 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -25,6 +25,46 @@ pub enum StCacheOperator { Writethrough, } +#[derive(Copy, Clone, PartialEq, Eq)] +pub enum LdCacheOperator { + Cached, + L2Only, + Streaming, + LastUse, + Uncached, +} + + + +#[derive(Copy, Clone)] +pub enum ArithDetails { + Integer(ArithInteger), + Float(ArithFloat), +} + +impl ArithDetails { + pub fn type_(&self) -> super::ScalarType { + match self { + ArithDetails::Integer(t) => t.type_, + ArithDetails::Float(arith) => arith.type_, + } + } +} + +#[derive(Copy, Clone)] +pub struct ArithInteger { + pub type_: super::ScalarType, + pub saturate: bool, +} + +#[derive(Copy, Clone)] +pub struct ArithFloat { + pub type_: super::ScalarType, + pub rounding: Option, + pub flush_to_zero: Option, + pub saturate: bool, +} + #[derive(Copy, Clone, PartialEq, Eq)] pub enum LdStQualifier { Weak, @@ -33,3 +73,11 @@ pub enum LdStQualifier { Acquire(MemScope), Release(MemScope), } + +#[derive(PartialEq, Eq, Copy, Clone)] +pub enum RoundingMode { + NearestEven, + Zero, + NegativeInf, + PositiveInf, +} diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs index dd9e6d21..eb137a5b 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/main.rs @@ -73,7 +73,7 @@ gen::generate_instruction_type!( }, Add { type: { data.type_().into() }, - data: ArithDetails, + data: ast::ArithDetails, arguments: { dst: T, src1: T, @@ -101,7 +101,7 @@ gen::generate_instruction_type!( pub struct LdDetails { pub qualifier: ast::LdStQualifier, pub state_space: StateSpace, - pub caching: LdCacheOperator, + pub caching: ast::LdCacheOperator, pub typ: Type, pub non_coherent: bool, } @@ -145,15 +145,6 @@ pub enum RoundingMode { PositiveInf, } -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum LdCacheOperator { - Cached, - L2Only, - Streaming, - LastUse, - Uncached, -} - 
#[derive(PartialEq, Eq, Clone, Hash)] pub enum Type { // .param.b32 foo; @@ -203,6 +194,18 @@ impl From for ast::StCacheOperator { } } +impl From for ast::LdCacheOperator { + fn from(value: RawLdCacheOperator) -> Self { + match value { + RawLdCacheOperator::Ca => ast::LdCacheOperator::Cached, + RawLdCacheOperator::Cg => ast::LdCacheOperator::L2Only, + RawLdCacheOperator::Cs => ast::LdCacheOperator::Streaming, + RawLdCacheOperator::Lu => ast::LdCacheOperator::LastUse, + RawLdCacheOperator::Cv => ast::LdCacheOperator::Uncached, + } + } +} + impl From for ast::LdStQualifier { fn from(value: RawLdStQualifier) -> Self { match value { @@ -212,6 +215,17 @@ impl From for ast::LdStQualifier { } } +impl From for ast::RoundingMode { + fn from(value: RawFloatRounding) -> Self { + match value { + RawFloatRounding::Rn => ast::RoundingMode::NearestEven, + RawFloatRounding::Rz => ast::RoundingMode::Zero, + RawFloatRounding::Rm => ast::RoundingMode::NegativeInf, + RawFloatRounding::Rp => ast::RoundingMode::PositiveInf, + } + } +} + type PtxParserState = Vec; type PtxParser<'a, 'input> = Stateful<&'a [Token<'input>], PtxParserState>; @@ -334,6 +348,12 @@ fn immediate_value<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult( + stream: &mut PtxParser<'a, 'input>, +) -> PResult>>> { + repeat(3.., terminated(parse_instruction, Token::Semicolon)).parse_next(stream) +} + impl ast::ParsedOperand { fn parse<'a, 'input>( stream: &mut PtxParser<'a, 'input>, @@ -518,7 +538,7 @@ impl<'input, I: Stream + StreamIsPartial, E: ParserError> Parse for Token<'input> { fn parse_next(&mut self, input: &mut I) -> PResult { - any.parse_next(input) + any.verify(|t| t == self).parse_next(input) } } @@ -540,14 +560,14 @@ derive_parser!( Comma, #[token(".")] Dot, + #[token(";")] + Semicolon, #[regex(r"[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+", |lex| lex.slice(), priority = 0)] Ident(&'input str), #[token("|")] Or, #[token("!")] Not, - #[token(";")] - Semicolon, #[token("[")] LBracket, #[token("]")] 
@@ -675,23 +695,82 @@ derive_parser!( // https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-ld ld{.weak}{.ss}{.cop}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{.unified}{, cache_policy} => { - todo!() + let (a, unified) = a; + if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || unified || cache_policy.is_some() { + state.push(PtxError::Todo); + } + Instruction::Ld { + data: LdDetails { + qualifier: weak.unwrap_or(RawLdStQualifier::Weak).into(), + state_space: ss.unwrap_or(StateSpace::Generic), + caching: cop.unwrap_or(RawLdCacheOperator::Ca).into(), + typ: Type::maybe_vector(vec, type_), + non_coherent: false + }, + arguments: LdArgs { dst:d, src:a } + } } ld.volatile{.ss}{.level::prefetch_size}{.vec}.type d, [a] => { - todo!() + if level_prefetch_size.is_some() { + state.push(PtxError::Todo); + } + Instruction::Ld { + data: LdDetails { + qualifier: volatile.into(), + state_space: ss.unwrap_or(StateSpace::Generic), + caching: ast::LdCacheOperator::Cached, + typ: Type::maybe_vector(vec, type_), + non_coherent: false + }, + arguments: LdArgs { dst:d, src:a } + } } ld.relaxed.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{, cache_policy} => { - todo!() + if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || cache_policy.is_some() { + state.push(PtxError::Todo); + } + Instruction::Ld { + data: LdDetails { + qualifier: ast::LdStQualifier::Relaxed(scope), + state_space: ss.unwrap_or(StateSpace::Generic), + caching: ast::LdCacheOperator::Cached, + typ: Type::maybe_vector(vec, type_), + non_coherent: false + }, + arguments: LdArgs { dst:d, src:a } + } } ld.acquire.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{, cache_policy} => { - todo!() + if level_eviction_priority.is_some() || level_cache_hint || 
level_prefetch_size.is_some() || cache_policy.is_some() { + state.push(PtxError::Todo); + } + Instruction::Ld { + data: LdDetails { + qualifier: ast::LdStQualifier::Acquire(scope), + state_space: ss.unwrap_or(StateSpace::Generic), + caching: ast::LdCacheOperator::Cached, + typ: Type::maybe_vector(vec, type_), + non_coherent: false + }, + arguments: LdArgs { dst:d, src:a } + } } ld.mmio.relaxed.sys{.global}.type d, [a] => { - todo!() + state.push(PtxError::Todo); + Instruction::Ld { + data: LdDetails { + qualifier: ast::LdStQualifier::Relaxed(MemScope::Sys), + state_space: global.unwrap_or(StateSpace::Generic), + caching: ast::LdCacheOperator::Cached, + typ: type_.into(), + non_coherent: false + }, + arguments: LdArgs { dst:d, src:a } + } } .ss: StateSpace = { .const, .global, .local, .param{::entry, ::func}, .shared{::cta, ::cluster} }; - .cop: RawCacheOp = { .ca, .cg, .cs, .lu, .cv }; + .cop: RawLdCacheOperator = { .ca, .cg, .cs, .lu, .cv }; .level::eviction_priority: EvictionPriority = { .L1::evict_normal, .L1::evict_unchanged, .L1::evict_first, .L1::evict_last, .L1::no_allocate }; .level::cache_hint = { .L2::cache_hint }; @@ -702,47 +781,144 @@ derive_parser!( .u8, .u16, .u32, .u64, .s8, .s16, .s32, .s64, .f32, .f64 }; + RawLdStQualifier = { .weak, .volatile }; + StateSpace = { .global }; // https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-add add.type d, a, b => { - todo!() + Instruction::Add { + data: ast::ArithDetails::Integer( + ast::ArithInteger { + type_, + saturate: false + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } } add{.sat}.s32 d, a, b => { - todo!() + Instruction::Add { + data: ast::ArithDetails::Integer( + ast::ArithInteger { + type_: s32, + saturate: sat + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } } .type: ScalarType = { .u16, .u32, .u64, .s16, .s64, .u16x2, .s16x2 }; + ScalarType = { .s32 }; // 
https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-add add{.rnd}{.ftz}{.sat}.f32 d, a, b => { - todo!() + Instruction::Add { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f32, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } } add{.rnd}.f64 d, a, b => { - todo!() + Instruction::Add { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f64, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } } .rnd: RawFloatRounding = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; // https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-add add{.rnd}{.ftz}{.sat}.f16 d, a, b => { - todo!() + Instruction::Add { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f16, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } } add{.rnd}{.ftz}{.sat}.f16x2 d, a, b => { - todo!() + Instruction::Add { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f16x2, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } } add{.rnd}.bf16 d, a, b => { - todo!() + Instruction::Add { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: bf16, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } } add{.rnd}.bf16x2 d, a, b => { - todo!() + Instruction::Add { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: bf16x2, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } } .rnd: RawFloatRounding = { .rn }; + ScalarType = { 
.f16, .f16x2, .bf16, .bf16x2 }; - ret => { - todo!() + ret{.uni} => { + Instruction::Ret { data: RetData { uniform: uni } } } ); @@ -776,7 +952,8 @@ fn main() { input: &tokens[..], state: Vec::new(), }; - parse_instruction(&mut stream).unwrap(); + let fn_body = fn_body.parse(stream).unwrap(); + println!("{}", fn_body.len()); //parse_prefix(&mut lexer); let mut parser = &*tokens; println!("{}", mem::size_of::()); From 91dbbb372b04c40e0f0ad60cbeda621fb592ee01 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 16 Aug 2024 18:29:13 +0200 Subject: [PATCH 07/47] Move all types to a separate module --- gen/src/lib.rs | 10 +- gen_impl/src/lib.rs | 22 +++-- ptx_parser/src/ast.rs | 132 ++++++++++++++++++++++++++- ptx_parser/src/main.rs | 203 ++++------------------------------------- 4 files changed, 168 insertions(+), 199 deletions(-) diff --git a/gen/src/lib.rs b/gen/src/lib.rs index 93b31fe9..67b276e2 100644 --- a/gen/src/lib.rs +++ b/gen/src/lib.rs @@ -178,9 +178,9 @@ impl SingleOpcodeDefinition { .chain(self.arguments.0.iter().map(|arg| { let name = &arg.ident; let arg_type = if arg.unified { - quote! { (ParsedOperand<'input>, bool) } + quote! { (ParsedOperandStr<'input>, bool) } } else { - quote! { ParsedOperand<'input> } + quote! { ParsedOperandStr<'input> } }; if arg.optional { quote! { #name : Option<#arg_type> } @@ -415,7 +415,7 @@ fn emit_parse_function( let code_block = &def.code_block.0; let args = def.function_arguments_declarations(); quote! 
{ - fn #fn_name<'input>(state: &mut PtxParserState, #(#args),* ) -> Instruction> #code_block + fn #fn_name<'input>(state: &mut PtxParserState, #(#args),* ) -> Instruction> #code_block } }) }) @@ -506,7 +506,7 @@ fn emit_parse_function( #(#fns_)* - fn parse_instruction<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> winnow::error::PResult>> + fn parse_instruction<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> winnow::error::PResult>> { use winnow::Parser; use winnow::token::*; @@ -747,7 +747,7 @@ fn emit_definition_parser( }; let operand = { quote! { - ParsedOperand::parse + ParsedOperandStr::parse } }; let post_bracket = if arg.post_bracket { diff --git a/gen_impl/src/lib.rs b/gen_impl/src/lib.rs index 6b606af2..7160603d 100644 --- a/gen_impl/src/lib.rs +++ b/gen_impl/src/lib.rs @@ -2,11 +2,13 @@ use proc_macro2::TokenStream; use quote::{format_ident, quote, ToTokens}; use syn::{ braced, parse::Parse, punctuated::Punctuated, token, Expr, Ident, Token, Type, TypeParam, + Visibility, }; pub mod parser; pub struct GenerateInstructionType { + pub visibility: Option, pub name: Ident, pub type_parameters: Punctuated, pub short_parameters: Punctuated, @@ -16,16 +18,17 @@ pub struct GenerateInstructionType { impl GenerateInstructionType { pub fn emit_arg_types(&self, tokens: &mut TokenStream) { for v in self.variants.iter() { - v.emit_type(&self.type_parameters, tokens); + v.emit_type(&self.visibility, &self.type_parameters, tokens); } } pub fn emit_instruction_type(&self, tokens: &mut TokenStream) { + let vis = &self.visibility; let type_name = &self.name; let type_parameters = &self.type_parameters; let variants = self.variants.iter().map(|v| v.emit_variant()); quote! 
{ - enum #type_name<#type_parameters> { + #vis enum #type_name<#type_parameters> { #(#variants),* } } @@ -133,6 +136,11 @@ impl VisitKind { impl Parse for GenerateInstructionType { fn parse(input: syn::parse::ParseStream) -> syn::Result { + let visibility = if !input.peek(Token![enum]) { + Some(input.parse::()?) + } else { + None + }; input.parse::()?; let name = input.parse::()?; input.parse::()?; @@ -146,6 +154,7 @@ impl Parse for GenerateInstructionType { braced!(variants_buffer in input); let variants = variants_buffer.parse_terminated(InstructionVariant::parse, Token![,])?; Ok(Self { + visibility, name, type_parameters, short_parameters, @@ -262,6 +271,7 @@ impl InstructionVariant { fn emit_type( &self, + vis: &Option, type_parameters: &Punctuated, tokens: &mut TokenStream, ) { @@ -275,9 +285,9 @@ impl InstructionVariant { } else { None }; - let fields = arguments.fields.iter().map(ArgumentField::emit_field); + let fields = arguments.fields.iter().map(|f| f.emit_field(vis)); quote! { - struct #name #type_parameters { + #vis struct #name #type_parameters { #(#fields),* } } @@ -559,11 +569,11 @@ impl ArgumentField { } } - fn emit_field(&self) -> TokenStream { + fn emit_field(&self, vis: &Option) -> TokenStream { let name = &self.name; let type_ = &self.repr; quote! 
{ - #name: #type_ + #vis #name: #type_ } } } diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index a471b8ec..302aef76 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -1,4 +1,113 @@ -use super::MemScope; +use super::{MemScope, ScalarType, VectorPrefix, StateSpace}; + +gen::generate_instruction_type!( + pub enum Instruction { + Mov { + type: { &data.typ }, + data: MovDetails, + arguments: { + dst: T, + src: T + } + }, + Ld { + type: { &data.typ }, + data: LdDetails, + arguments: { + dst: T, + src: { + repr: T, + space: { data.state_space }, + } + } + }, + Add { + type: { data.type_().into() }, + data: ArithDetails, + arguments: { + dst: T, + src1: T, + src2: T, + } + }, + St { + type: { &data.typ }, + data: StData, + arguments: { + src1: { + repr: T, + space: { data.state_space }, + }, + src2: T, + } + }, + Ret { + data: RetData + }, + Trap { } + } +); + +pub trait Visitor { + fn visit(&mut self, args: &T, type_: &Type, space: StateSpace, is_dst: bool); +} + +pub trait VisitorMut { + fn visit(&mut self, args: &mut T, type_: &Type, space: StateSpace, is_dst: bool); +} + +pub trait VisitorMap { + fn visit(&mut self, args: From, type_: &Type, space: StateSpace, is_dst: bool) -> To; +} + +#[derive(PartialEq, Eq, Clone, Hash)] +pub enum Type { + // .param.b32 foo; + Scalar(ScalarType), + // .param.v2.b32 foo; + Vector(ScalarType, u8), + // .param.b32 foo[4]; + Array(ScalarType, Vec), +} + +impl Type { + pub(crate) fn maybe_vector(vector: Option, scalar: ScalarType) -> Self { + match vector { + Some(VectorPrefix::V2) => Type::Vector(scalar, 2), + Some(VectorPrefix::V4) => Type::Vector(scalar, 4), + None => Type::Scalar(scalar), + } + } +} + +impl From for Type { + fn from(value: ScalarType) -> Self { + Type::Scalar(value) + } +} + +#[derive(Clone)] +pub struct MovDetails { + pub typ: super::Type, + pub src_is_address: bool, + // two fields below are in use by member moves + pub dst_width: u8, + pub src_width: u8, + // This is in use by 
auto-generated movs + pub relaxed_src2_conv: bool, +} + +impl MovDetails { + pub(crate) fn new(vector: Option, scalar: ScalarType) -> Self { + MovDetails { + typ: Type::maybe_vector(vector, scalar), + src_is_address: false, + dst_width: 0, + src_width: 0, + relaxed_src2_conv: false, + } + } +} #[derive(Clone)] pub enum ParsedOperand { @@ -81,3 +190,24 @@ pub enum RoundingMode { NegativeInf, PositiveInf, } + +pub struct LdDetails { + pub qualifier: LdStQualifier, + pub state_space: StateSpace, + pub caching: LdCacheOperator, + pub typ: Type, + pub non_coherent: bool, +} + + +pub struct StData { + pub qualifier: LdStQualifier, + pub state_space: StateSpace, + pub caching: StCacheOperator, + pub typ: Type, +} + +#[derive(Copy, Clone)] +pub struct RetData { + pub uniform: bool, +} diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs index eb137a5b..34c27dae 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/main.rs @@ -12,176 +12,7 @@ use winnow::{ use winnow::{prelude::*, Stateful}; mod ast; - -pub trait Operand {} - -pub trait Visitor { - fn visit(&mut self, args: &T, type_: &Type, space: StateSpace, is_dst: bool); -} - -pub trait VisitorMut { - fn visit(&mut self, args: &mut T, type_: &Type, space: StateSpace, is_dst: bool); -} - -pub trait VisitorMap { - fn visit(&mut self, args: From, type_: &Type, space: StateSpace, is_dst: bool) -> To; -} - -#[derive(Clone)] -pub struct MovDetails { - pub typ: Type, - pub src_is_address: bool, - // two fields below are in use by member moves - pub dst_width: u8, - pub src_width: u8, - // This is in use by auto-generated movs - pub relaxed_src2_conv: bool, -} - -impl MovDetails { - fn new(vector: Option, scalar: ScalarType) -> Self { - MovDetails { - typ: Type::maybe_vector(vector, scalar), - src_is_address: false, - dst_width: 0, - src_width: 0, - relaxed_src2_conv: false, - } - } -} - -gen::generate_instruction_type!( - enum Instruction { - Mov { - type: { &data.typ }, - data: MovDetails, - arguments: { - dst: T, 
- src: T - } - }, - Ld { - type: { &data.typ }, - data: LdDetails, - arguments: { - dst: T, - src: { - repr: T, - space: { data.state_space }, - } - } - }, - Add { - type: { data.type_().into() }, - data: ast::ArithDetails, - arguments: { - dst: T, - src1: T, - src2: T, - } - }, - St { - type: { &data.typ }, - data: StData, - arguments: { - src1: { - repr: T, - space: { data.state_space }, - }, - src2: T, - } - }, - Ret { - data: RetData - }, - Trap { } - } -); - -pub struct LdDetails { - pub qualifier: ast::LdStQualifier, - pub state_space: StateSpace, - pub caching: ast::LdCacheOperator, - pub typ: Type, - pub non_coherent: bool, -} - -#[derive(Copy, Clone)] -pub enum ArithDetails { - Unsigned(ScalarType), - Signed(ArithSInt), - Float(ArithFloat), -} - -impl ArithDetails { - fn type_(&self) -> ScalarType { - match self { - ArithDetails::Unsigned(t) => *t, - ArithDetails::Signed(arith) => arith.typ, - ArithDetails::Float(arith) => arith.typ, - } - } -} - -#[derive(Copy, Clone)] -pub struct ArithSInt { - pub typ: ScalarType, - pub saturate: bool, -} - -#[derive(Copy, Clone)] -pub struct ArithFloat { - pub typ: ScalarType, - pub rounding: Option, - pub flush_to_zero: Option, - pub saturate: bool, -} - -#[derive(PartialEq, Eq, Copy, Clone)] -pub enum RoundingMode { - NearestEven, - Zero, - NegativeInf, - PositiveInf, -} - -#[derive(PartialEq, Eq, Clone, Hash)] -pub enum Type { - // .param.b32 foo; - Scalar(ScalarType), - // .param.v2.b32 foo; - Vector(ScalarType, u8), - // .param.b32 foo[4]; - Array(ScalarType, Vec), -} - -impl Type { - fn maybe_vector(vector: Option, scalar: ScalarType) -> Self { - match vector { - Some(VectorPrefix::V2) => Type::Vector(scalar, 2), - Some(VectorPrefix::V4) => Type::Vector(scalar, 4), - None => Type::Scalar(scalar), - } - } -} - -impl From for Type { - fn from(value: ScalarType) -> Self { - Type::Scalar(value) - } -} - -pub struct StData { - pub qualifier: ast::LdStQualifier, - pub state_space: StateSpace, - pub caching: 
ast::StCacheOperator, - pub typ: Type, -} - -#[derive(Copy, Clone)] -pub struct RetData { - pub uniform: bool, -} +pub use ast::*; impl From for ast::StCacheOperator { fn from(value: RawStCacheOperator) -> Self { @@ -350,7 +181,7 @@ fn immediate_value<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult( stream: &mut PtxParser<'a, 'input>, -) -> PResult>>> { +) -> PResult>>> { repeat(3.., terminated(parse_instruction, Token::Semicolon)).parse_next(stream) } @@ -550,7 +381,7 @@ impl<'input, I: Stream + StreamIsPartial, E: ParserError> Parse // * If it is mandatory then it is skipped // * If it is optional then its type is `bool` -type ParsedOperand<'input> = ast::ParsedOperand<&'input str>; +type ParsedOperandStr<'input> = ast::ParsedOperand<&'input str>; derive_parser!( #[derive(Logos, PartialEq, Eq, Debug, Clone, Copy)] @@ -601,7 +432,7 @@ derive_parser!( // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov mov{.vec}.type d, a => { Instruction::Mov { - data: MovDetails::new(vec, type_), + data: ast::MovDetails::new(vec, type_), arguments: MovArgs { dst: d, src: a }, } } @@ -622,7 +453,7 @@ derive_parser!( qualifier: weak.unwrap_or(RawLdStQualifier::Weak).into(), state_space: ss.unwrap_or(StateSpace::Generic), caching: cop.unwrap_or(RawStCacheOperator::Wb).into(), - typ: Type::maybe_vector(vec, type_) + typ: ast::Type::maybe_vector(vec, type_) }, arguments: StArgs { src1:a, src2:b } } @@ -633,7 +464,7 @@ derive_parser!( qualifier: volatile.into(), state_space: ss.unwrap_or(StateSpace::Generic), caching: ast::StCacheOperator::Writeback, - typ: Type::maybe_vector(vec, type_) + typ: ast::Type::maybe_vector(vec, type_) }, arguments: StArgs { src1:a, src2:b } } @@ -647,7 +478,7 @@ derive_parser!( qualifier: ast::LdStQualifier::Relaxed(scope), state_space: ss.unwrap_or(StateSpace::Generic), caching: ast::StCacheOperator::Writeback, - typ: Type::maybe_vector(vec, type_) + typ: ast::Type::maybe_vector(vec, 
type_) }, arguments: StArgs { src1:a, src2:b } } @@ -661,7 +492,7 @@ derive_parser!( qualifier: ast::LdStQualifier::Release(scope), state_space: ss.unwrap_or(StateSpace::Generic), caching: ast::StCacheOperator::Writeback, - typ: Type::maybe_vector(vec, type_) + typ: ast::Type::maybe_vector(vec, type_) }, arguments: StArgs { src1:a, src2:b } } @@ -669,13 +500,13 @@ derive_parser!( st.mmio.relaxed.sys{.global}.type [a], b => { state.push(PtxError::Todo); Instruction::St { - data: StData { + data: ast::StData { qualifier: ast::LdStQualifier::Relaxed(MemScope::Sys), state_space: global.unwrap_or(StateSpace::Generic), caching: ast::StCacheOperator::Writeback, typ: type_.into() }, - arguments: StArgs { src1:a, src2:b } + arguments: ast::StArgs { src1:a, src2:b } } } @@ -704,7 +535,7 @@ derive_parser!( qualifier: weak.unwrap_or(RawLdStQualifier::Weak).into(), state_space: ss.unwrap_or(StateSpace::Generic), caching: cop.unwrap_or(RawLdCacheOperator::Ca).into(), - typ: Type::maybe_vector(vec, type_), + typ: ast::Type::maybe_vector(vec, type_), non_coherent: false }, arguments: LdArgs { dst:d, src:a } @@ -719,7 +550,7 @@ derive_parser!( qualifier: volatile.into(), state_space: ss.unwrap_or(StateSpace::Generic), caching: ast::LdCacheOperator::Cached, - typ: Type::maybe_vector(vec, type_), + typ: ast::Type::maybe_vector(vec, type_), non_coherent: false }, arguments: LdArgs { dst:d, src:a } @@ -734,7 +565,7 @@ derive_parser!( qualifier: ast::LdStQualifier::Relaxed(scope), state_space: ss.unwrap_or(StateSpace::Generic), caching: ast::LdCacheOperator::Cached, - typ: Type::maybe_vector(vec, type_), + typ: ast::Type::maybe_vector(vec, type_), non_coherent: false }, arguments: LdArgs { dst:d, src:a } @@ -749,7 +580,7 @@ derive_parser!( qualifier: ast::LdStQualifier::Acquire(scope), state_space: ss.unwrap_or(StateSpace::Generic), caching: ast::LdCacheOperator::Cached, - typ: Type::maybe_vector(vec, type_), + typ: ast::Type::maybe_vector(vec, type_), non_coherent: false }, arguments: 
LdArgs { dst:d, src:a } @@ -931,7 +762,7 @@ fn main() { println!("{}", mem::size_of::()); let mut input: &[Token] = &[][..]; - let x = opt(any::<_, ContextError>.verify_map(|t| { + let x = opt(any::<_, ContextError>.verify_map(|_| { println!("MAP"); Some(true) })) @@ -948,13 +779,11 @@ fn main() { ); let tokens = lexer.map(|t| t.unwrap()).collect::>(); println!("{:?}", &tokens); - let mut stream = PtxParser { + let stream = PtxParser { input: &tokens[..], state: Vec::new(), }; let fn_body = fn_body.parse(stream).unwrap(); println!("{}", fn_body.len()); - //parse_prefix(&mut lexer); - let mut parser = &*tokens; println!("{}", mem::size_of::()); } From 77de5c7a1522ac3608f52540aa405d302199a14c Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sun, 18 Aug 2024 03:45:09 +0200 Subject: [PATCH 08/47] Parse simplest vector add kernel --- gen/src/lib.rs | 8 +- gen_impl/src/parser.rs | 3 +- ptx_parser/Cargo.toml | 1 + ptx_parser/src/ast.rs | 89 ++++++- ptx_parser/src/main.rs | 545 ++++++++++++++++++++++++++++++++++++++++- 5 files changed, 629 insertions(+), 17 deletions(-) diff --git a/gen/src/lib.rs b/gen/src/lib.rs index 67b276e2..ebddf03f 100644 --- a/gen/src/lib.rs +++ b/gen/src/lib.rs @@ -11,15 +11,17 @@ use syn::{ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#vectors // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#fundamental-types // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#alternate-floating-point-data-formats +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#packed-floating-point-data-types +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#packed-integer-data-types #[rustfmt::skip] static POSTFIX_MODIFIERS: &[&str] = &[ ".v2", ".v4", - ".s8", ".s16", ".s32", ".s64", - ".u8", ".u16", ".u32", ".u64", + ".s8", ".s16", ".s16x2", ".s32", ".s64", + ".u8", ".u16", ".u16x2", ".u32", ".u64", ".f16", ".f16x2", ".f32", ".f64", ".b8", ".b16", ".b32", ".b64", ".b128", 
".pred", - ".bf16", ".e4m3", ".e5m2", ".tf32", + ".bf16", ".bf16x2", ".e4m3", ".e5m2", ".tf32", ]; static POSTFIX_TYPES: &[&str] = &["ScalarType", "VectorPrefix"]; diff --git a/gen_impl/src/parser.rs b/gen_impl/src/parser.rs index 6834cbca..519bf129 100644 --- a/gen_impl/src/parser.rs +++ b/gen_impl/src/parser.rs @@ -332,7 +332,8 @@ impl DotModifier { capitalize = true; continue; } - let c = if capitalize { + // Special hack to emit `BF16`` instead of `Bf16`` + let c = if capitalize || c == 'f' && result.ends_with('B') { capitalize = false; c.to_ascii_uppercase() } else { diff --git a/ptx_parser/Cargo.toml b/ptx_parser/Cargo.toml index 951d508a..4f32860c 100644 --- a/ptx_parser/Cargo.toml +++ b/ptx_parser/Cargo.toml @@ -8,3 +8,4 @@ logos = "0.14" winnow = { version = "0.6.18", features = ["debug"] } gen = { path = "../gen" } thiserror = "1.0" +bitflags = "1.2" diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 302aef76..2dabf3e8 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -1,4 +1,31 @@ -use super::{MemScope, ScalarType, VectorPrefix, StateSpace}; +use super::{MemScope, ScalarType, StateSpace, VectorPrefix}; +use bitflags::bitflags; + +pub enum Statement { + Label(P::Ident), + Variable(MultiVariable), + Instruction(Option>, Instruction

), + Block(Vec>), +} + +pub struct MultiVariable { + pub var: Variable, + pub count: Option, +} + +#[derive(Clone)] +pub struct Variable { + pub align: Option, + pub v_type: Type, + pub state_space: StateSpace, + pub name: ID, + pub array_init: Vec, +} + +pub struct PredAt { + pub not: bool, + pub label: ID, +} gen::generate_instruction_type!( pub enum Instruction { @@ -118,6 +145,14 @@ pub enum ParsedOperand { VecPack(Vec), } +impl Operand for ParsedOperand { + type Ident = Ident; +} + +pub trait Operand { + type Ident; +} + #[derive(Copy, Clone)] pub enum ImmediateValue { U64(u64), @@ -143,8 +178,6 @@ pub enum LdCacheOperator { Uncached, } - - #[derive(Copy, Clone)] pub enum ArithDetails { Integer(ArithInteger), @@ -199,7 +232,6 @@ pub struct LdDetails { pub non_coherent: bool, } - pub struct StData { pub qualifier: LdStQualifier, pub state_space: StateSpace, @@ -211,3 +243,52 @@ pub struct StData { pub struct RetData { pub uniform: bool, } + +#[derive(Copy, Clone, PartialEq, Eq)] +pub enum TuningDirective { + MaxNReg(u32), + MaxNtid(u32, u32, u32), + ReqNtid(u32, u32, u32), + MinNCtaPerSm(u32), +} + +pub struct MethodDeclaration<'input, ID> { + pub return_arguments: Vec>, + pub name: MethodName<'input, ID>, + pub input_arguments: Vec>, + pub shared_mem: Option, +} + +#[derive(Hash, PartialEq, Eq, Copy, Clone)] +pub enum MethodName<'input, ID> { + Kernel(&'input str), + Func(ID), +} + +bitflags! 
{ + pub struct LinkingDirective: u8 { + const NONE = 0b000; + const EXTERN = 0b001; + const VISIBLE = 0b10; + const WEAK = 0b100; + } +} + +pub struct Function<'a, ID, S> { + pub func_directive: MethodDeclaration<'a, ID>, + pub tuning: Vec, + pub body: Option>, +} + +pub enum Directive<'input, O: Operand> { + Variable(LinkingDirective, Variable), + Method( + LinkingDirective, + Function<'input, &'input str, Statement>, + ), +} + +pub struct Module<'input> { + pub version: (u8, u8), + pub directives: Vec>>, +} diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs index 34c27dae..7a29e635 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/main.rs @@ -2,7 +2,10 @@ use gen::derive_parser; use logos::Logos; use std::mem; use std::num::{ParseFloatError, ParseIntError}; +use winnow::ascii::{dec_uint, digit1}; use winnow::combinator::*; +use winnow::error::ErrMode; +use winnow::stream::Accumulate; use winnow::token::any; use winnow::{ error::{ContextError, ParserError}, @@ -170,6 +173,28 @@ fn s32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { .parse_next(stream) } +fn u8<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { + take_error(num.map(|x| { + let (text, radix, _) = x; + match u8::from_str_radix(text, radix) { + Ok(x) => Ok(x), + Err(err) => Err((0, PtxError::from(err))), + } + })) + .parse_next(stream) +} + +fn u32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { + take_error(num.map(|x| { + let (text, radix, _) = x; + match u32::from_str_radix(text, radix) { + Ok(x) => Ok(x), + Err(err) => Err((0, PtxError::from(err))), + } + })) + .parse_next(stream) +} + fn immediate_value<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { alt(( int_immediate, @@ -179,10 +204,402 @@ fn immediate_value<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult( +fn module<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult> { + ( + version, + target, + opt(address_size), + repeat_without_none(directive), + ) + 
.map(|(version, _, _, directives)| ast::Module { + version, + directives, + }) + .parse_next(stream) +} + +fn address_size<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + (Token::DotAddressSize, u8_literal(64)) + .void() + .parse_next(stream) +} + +fn version<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(u8, u8)> { + (Token::DotVersion, u8, Token::Dot, u8) + .map(|(_, major, _, minor)| (major, minor)) + .parse_next(stream) +} + +fn target<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(u32, Option)> { + preceded(Token::DotTarget, ident.and_then(shader_model)).parse_next(stream) +} + +fn shader_model<'a>(stream: &mut &str) -> PResult<(u32, Option)> { + ( + "sm_", + dec_uint, + opt(any.verify(|c: &char| c.is_ascii_lowercase())), + eof, + ) + .map(|(_, digits, arch_variant, _)| (digits, arch_variant)) + .parse_next(stream) +} + +fn directive<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult>>> { + (function.map(|f| { + let (linking, func) = f; + Some(ast::Directive::Method(linking, func)) + })) + .parse_next(stream) +} + +fn function<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<( + ast::LinkingDirective, + ast::Function<'input, &'input str, ast::Statement>>, +)> { + ( + linking_directives, + method_declaration, + repeat(0.., tuning_directive), + function_body, + ) + .map(|(linking, func_directive, tuning, body)| { + ( + linking, + ast::Function { + func_directive, + tuning, + body, + }, + ) + }) + .parse_next(stream) +} + +fn linking_directives<'a, 'input>( stream: &mut PtxParser<'a, 'input>, -) -> PResult>>> { - repeat(3.., terminated(parse_instruction, Token::Semicolon)).parse_next(stream) +) -> PResult { + dispatch! 
{ any; + Token::DotExtern => empty.value(ast::LinkingDirective::EXTERN), + Token::DotVisible => empty.value(ast::LinkingDirective::VISIBLE), + Token::DotWeak => empty.value(ast::LinkingDirective::WEAK), + _ => fail + } + .parse_next(stream) +} + +fn tuning_directive<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult { + dispatch! {any; + Token::DotMaxnreg => u32.map(ast::TuningDirective::MaxNReg), + Token::DotMaxntid => tuple1to3_u32.map(|(nx, ny, nz)| ast::TuningDirective::MaxNtid(nx, ny, nz)), + Token::DotReqntid => tuple1to3_u32.map(|(nx, ny, nz)| ast::TuningDirective::ReqNtid(nx, ny, nz)), + Token::DotMinnctapersm => u32.map(ast::TuningDirective::MinNCtaPerSm), + _ => fail + } + .parse_next(stream) +} + +fn method_declaration<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult> { + dispatch! {any; + Token::DotEntry => (ident, kernel_arguments).map(|(name, input_arguments)| ast::MethodDeclaration{ + return_arguments: Vec::new(), name: ast::MethodName::Kernel(name), input_arguments, shared_mem: None + }), + Token::DotFunc => (opt(fn_arguments), ident, fn_arguments).map(|(return_arguments, name,input_arguments)| { + let return_arguments = return_arguments.unwrap_or_else(|| Vec::new()); + let name = ast::MethodName::Func(name); + ast::MethodDeclaration{ return_arguments, name, input_arguments, shared_mem: None } + }), + _ => fail + } + .parse_next(stream) +} + +fn fn_arguments<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult>> { + delimited( + Token::LParen, + separated(0.., fn_input, Token::Comma), + Token::RParen, + ) + .parse_next(stream) +} + +fn kernel_arguments<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult>> { + delimited( + Token::LParen, + separated(0.., kernel_input, Token::Comma), + Token::RParen, + ) + .parse_next(stream) +} + +fn kernel_input<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult> { + preceded( + Token::DotParam, + variable_scalar_or_vector(StateSpace::Param), + ) + 
.parse_next(stream) +} + +fn fn_input<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult> { + dispatch! { any; + Token::DotParam => variable_scalar_or_vector(StateSpace::Param), + Token::DotReg => variable_scalar_or_vector(StateSpace::Reg), + _ => fail + } + .parse_next(stream) +} + +fn tuple1to3_u32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(u32, u32, u32)> { + struct Tuple3AccumulateU32 { + index: usize, + value: (u32, u32, u32), + } + + impl Accumulate for Tuple3AccumulateU32 { + fn initial(_: Option) -> Self { + Self { + index: 0, + value: (1, 1, 1), + } + } + + fn accumulate(&mut self, value: u32) { + match self.index { + 0 => { + self.value = (value, self.value.1, self.value.2); + self.index = 1; + } + 1 => { + self.value = (self.value.0, value, self.value.2); + self.index = 2; + } + 2 => { + self.value = (self.value.0, self.value.1, value); + self.index = 3; + } + _ => unreachable!(), + } + } + } + + separated::<_, _, Tuple3AccumulateU32, _, _, _, _>(1..3, u32, Token::Comma) + .map(|acc| acc.value) + .parse_next(stream) +} + +fn function_body<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult>>>> { + dispatch! 
{any; + Token::LBrace => terminated(repeat_without_none(statement), Token::RBrace).map(Some), + Token::Semicolon => empty.map(|_| None), + _ => fail + } + .parse_next(stream) +} + +fn statement<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult>>> { + alt(( + label.map(Some), + debug_directive.map(|_| None), + multi_variable.map(Some), + predicated_instruction.map(Some), + pragma.map(|_| None), + block_statement.map(Some), + )) + .parse_next(stream) +} + +fn pragma<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + (Token::DotPragma, Token::String, Token::Semicolon) + .void() + .parse_next(stream) +} + +fn multi_variable<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult>> { + ( + variable, + opt(delimited(Token::Lt, u32, Token::Gt)), + Token::Semicolon, + ) + .map(|(var, count, _)| ast::Statement::Variable(ast::MultiVariable { var, count })) + .parse_next(stream) +} + +fn variable<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult> { + dispatch! 
{any; + Token::DotReg => variable_scalar_or_vector(StateSpace::Reg), + Token::DotLocal => variable_scalar_or_vector(StateSpace::Local), + Token::DotParam => variable_scalar_or_vector(StateSpace::Param), + Token::DotShared => variable_scalar_or_vector(StateSpace::Shared), + _ => fail + } + .parse_next(stream) +} + +fn variable_scalar_or_vector<'a, 'input: 'a>( + state_space: StateSpace, +) -> impl Parser, ast::Variable<&'input str>, ContextError> { + move |stream: &mut PtxParser<'a, 'input>| { + (opt(align), scalar_vector_type, ident) + .map(|(align, v_type, name)| ast::Variable { + align, + v_type, + state_space, + name, + array_init: Vec::new(), + }) + .parse_next(stream) + } +} + +fn align<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { + preceded(Token::DotAlign, u32).parse_next(stream) +} + +fn scalar_vector_type<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { + ( + opt(alt(( + Token::DotV2.value(VectorPrefix::V2), + Token::DotV4.value(VectorPrefix::V4), + ))), + scalar_type, + ) + .map(|(prefix, scalar)| ast::Type::maybe_vector(prefix, scalar)) + .parse_next(stream) +} + +fn scalar_type<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { + any.verify_map(|t| { + Some(match t { + Token::DotS8 => ScalarType::S8, + Token::DotS16 => ScalarType::S16, + Token::DotS16x2 => ScalarType::S16x2, + Token::DotS32 => ScalarType::S32, + Token::DotS64 => ScalarType::S64, + Token::DotU8 => ScalarType::U8, + Token::DotU16 => ScalarType::U16, + Token::DotU16x2 => ScalarType::U16x2, + Token::DotU32 => ScalarType::U32, + Token::DotU64 => ScalarType::U64, + Token::DotB8 => ScalarType::B8, + Token::DotB16 => ScalarType::B16, + Token::DotB32 => ScalarType::B32, + Token::DotB64 => ScalarType::B64, + Token::DotB128 => ScalarType::B128, + Token::DotPred => ScalarType::Pred, + Token::DotF16 => ScalarType::F16, + Token::DotF16x2 => ScalarType::F16x2, + Token::DotF32 => ScalarType::F32, + Token::DotF64 => ScalarType::F64, + Token::DotBF16 => 
ScalarType::BF16, + Token::DotBF16x2 => ScalarType::BF16x2, + _ => return None, + }) + }) + .parse_next(stream) +} + +fn predicated_instruction<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult>> { + (opt(pred_at), parse_instruction, Token::Semicolon) + .map(|(p, i, _)| ast::Statement::Instruction(p, i)) + .parse_next(stream) +} + +fn pred_at<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult> { + (Token::At, opt(Token::Not), ident) + .map(|(_, not, label)| ast::PredAt { + not: not.is_some(), + label, + }) + .parse_next(stream) +} + +fn label<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult>> { + terminated(ident, Token::Colon) + .map(|l| ast::Statement::Label(l)) + .parse_next(stream) +} + +fn debug_directive<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + ( + Token::DotLoc, + u32, + u32, + u32, + opt(( + Token::Comma, + ident_literal("function_name"), + ident, + dispatch! { any; + Token::Comma => (ident_literal("inlined_at"), u32, u32, u32).void(), + Token::Plus => (u32, Token::Comma, ident_literal("inlined_at"), u32, u32, u32).void(), + _ => fail + }, + )), + ) + .void() + .parse_next(stream) +} + +fn block_statement<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult>> { + delimited(Token::LBrace, repeat_without_none(statement), Token::RBrace) + .map(|s| ast::Statement::Block(s)) + .parse_next(stream) +} + +fn repeat_without_none>( + parser: impl Parser, Error>, +) -> impl Parser, Error> { + repeat(0.., parser).fold(Vec::new, |mut acc: Vec<_>, item| { + if let Some(item) = item { + acc.push(item); + } + acc + }) +} + +fn ident_literal< + 'a, + 'input, + I: Stream> + StreamIsPartial, + E: ParserError, +>( + s: &'input str, +) -> impl Parser + 'input { + move |stream: &mut I| { + any.verify(|t| matches!(t, Token::Ident(text) if *text == s)) + .void() + .parse_next(stream) + } +} + +fn u8_literal<'a, 'input>(x: u8) -> impl Parser, (), ContextError> { + move |stream: &mut PtxParser| 
u8.verify(|t| *t == x).void().parse_next(stream) } impl ast::ParsedOperand { @@ -391,18 +808,36 @@ derive_parser!( Comma, #[token(".")] Dot, + #[token(":")] + Colon, #[token(";")] Semicolon, + #[token("@")] + At, #[regex(r"[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+", |lex| lex.slice(), priority = 0)] Ident(&'input str), + #[regex(r#""[^"]*""#)] + String, #[token("|")] Or, #[token("!")] Not, + #[token("(")] + LParen, + #[token(")")] + RParen, #[token("[")] LBracket, #[token("]")] RBracket, + #[token("{")] + LBrace, + #[token("}")] + RBrace, + #[token("<")] + Lt, + #[token(">")] + Gt, #[regex(r"0[fF][0-9a-zA-Z]{8}", |lex| lex.slice())] F32(&'input str), #[regex(r"0[dD][0-9a-zA-Z]{16}", |lex| lex.slice())] @@ -415,6 +850,36 @@ derive_parser!( Minus, #[token("+")] Plus, + #[token(".version")] + DotVersion, + #[token(".loc")] + DotLoc, + #[token(".reg")] + DotReg, + #[token(".align")] + DotAlign, + #[token(".pragma")] + DotPragma, + #[token(".maxnreg")] + DotMaxnreg, + #[token(".maxntid")] + DotMaxntid, + #[token(".reqntid")] + DotReqntid, + #[token(".minnctapersm")] + DotMinnctapersm, + #[token(".entry")] + DotEntry, + #[token(".func")] + DotFunc, + #[token(".extern")] + DotExtern, + #[token(".visible")] + DotVisible, + #[token(".target")] + DotTarget, + #[token(".address_size")] + DotAddressSize } #[derive(Copy, Clone, PartialEq, Eq, Hash)] @@ -771,10 +1236,29 @@ fn main() { dbg!(x); let lexer = Token::lexer( " - ld.u64 temp, [in_addr]; - add.u64 temp2, temp, 1; - st.u64 [out_addr], temp2; - ret; + .version 6.5 + .target sm_30 + .address_size 64 + + .visible .entry add( + .param .u64 input, + .param .u64 output + ) + { + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u64 temp; + .reg .u64 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.u64 temp, [in_addr]; + add.u64 temp2, temp, 1; + st.u64 [out_addr], temp2; + ret; + } + ", ); let tokens = lexer.map(|t| t.unwrap()).collect::>(); @@ -783,7 +1267,50 @@ fn main() { input: 
&tokens[..], state: Vec::new(), }; - let fn_body = fn_body.parse(stream).unwrap(); - println!("{}", fn_body.len()); + let module_ = module.parse(stream).unwrap(); println!("{}", mem::size_of::()); } + +#[cfg(test)] +mod tests { + use super::target; + use super::Token; + use logos::Logos; + use winnow::prelude::*; + + #[test] + fn sm_11() { + let tokens = Token::lexer(".target sm_11") + .collect::, ()>>() + .unwrap(); + let stream = super::PtxParser { + input: &tokens[..], + state: Vec::new(), + }; + assert_eq!(target.parse(stream).unwrap(), (11, None)); + } + + #[test] + fn sm_90a() { + let tokens = Token::lexer(".target sm_90a") + .collect::, ()>>() + .unwrap(); + let stream = super::PtxParser { + input: &tokens[..], + state: Vec::new(), + }; + assert_eq!(target.parse(stream).unwrap(), (90, Some('a'))); + } + + #[test] + fn sm_90ab() { + let tokens = Token::lexer(".target sm_90ab") + .collect::, ()>>() + .unwrap(); + let stream = super::PtxParser { + input: &tokens[..], + state: Vec::new(), + }; + assert!(target.parse(stream).is_err()); + } +} From 522541d5c5156b556acd2f5c8c0fada49940606b Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sun, 18 Aug 2024 18:28:00 +0200 Subject: [PATCH 09/47] Support simple module variables --- ptx_parser/src/main.rs | 155 ++++++++++++++++++++++++++++++++--------- 1 file changed, 124 insertions(+), 31 deletions(-) diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs index 7a29e635..0ac1260b 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/main.rs @@ -2,9 +2,8 @@ use gen::derive_parser; use logos::Logos; use std::mem; use std::num::{ParseFloatError, ParseIntError}; -use winnow::ascii::{dec_uint, digit1}; +use winnow::ascii::dec_uint; use winnow::combinator::*; -use winnow::error::ErrMode; use winnow::stream::Accumulate; use winnow::token::any; use winnow::{ @@ -76,6 +75,17 @@ fn ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input str> .parse_next(stream) } +fn dot_ident<'a, 'input>(stream: &mut 
PtxParser<'a, 'input>) -> PResult<&'input str> { + any.verify_map(|t| { + if let Token::DotIdent(text) = t { + Some(text) + } else { + None + } + }) + .parse_next(stream) +} + fn num<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(&'input str, u32, bool)> { any.verify_map(|t| { Some(match t { @@ -210,8 +220,9 @@ fn module<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult(stream: &mut &str) -> PResult<(u32, Option)> { fn directive<'a, 'input>( stream: &mut PtxParser<'a, 'input>, ) -> PResult>>> { - (function.map(|f| { - let (linking, func) = f; - Some(ast::Directive::Method(linking, func)) - })) + alt(( + function.map(|(linking, func)| Some(ast::Directive::Method(linking, func))), + file.map(|_| None), + section.map(|_| None), + (module_variable, Token::Semicolon) + .map(|((linking, var), _)| Some(ast::Directive::Variable(linking, var))), + )) + .parse_next(stream) +} + +fn module_variable<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<(ast::LinkingDirective, ast::Variable<&'input str>)> { + ( + linking_directives, + module_variable_state_space.flat_map(variable_scalar_or_vector), + ) + .parse_next(stream) +} + +fn module_variable_state_space<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult { + alt(( + Token::DotConst.value(StateSpace::Const), + Token::DotGlobal.value(StateSpace::Global), + Token::DotShared.value(StateSpace::Shared), + )) + .parse_next(stream) +} + +fn file<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + ( + Token::DotFile, + u32, + Token::String, + opt((Token::Comma, u32, Token::Comma, u32)), + ) + .void() + .parse_next(stream) +} + +fn section<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + ( + Token::DotSection.void(), + dot_ident.void(), + Token::LBrace.void(), + repeat::<_, _, (), _, _>(0.., section_dwarf_line), + Token::RBrace.void(), + ) + .void() + .parse_next(stream) +} + +fn section_dwarf_line<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> 
PResult<()> { + alt(( + (section_label, Token::Colon).void(), + (Token::DotB32, section_label, opt((Token::Add, u32))).void(), + (Token::DotB64, section_label, opt((Token::Add, u32))).void(), + ( + any_bit_type, + separated::<_, _, (), _, _, _, _>(1.., u32, Token::Comma), + ) + .void(), + )) .parse_next(stream) } +fn any_bit_type<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + alt((Token::DotB8, Token::DotB16, Token::DotB32, Token::DotB64)) + .void() + .parse_next(stream) +} + +fn section_label<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + alt((ident, dot_ident)).void().parse_next(stream) +} + fn function<'a, 'input>( stream: &mut PtxParser<'a, 'input>, ) -> PResult<( @@ -283,12 +365,16 @@ fn function<'a, 'input>( fn linking_directives<'a, 'input>( stream: &mut PtxParser<'a, 'input>, ) -> PResult { - dispatch! { any; - Token::DotExtern => empty.value(ast::LinkingDirective::EXTERN), - Token::DotVisible => empty.value(ast::LinkingDirective::VISIBLE), - Token::DotWeak => empty.value(ast::LinkingDirective::WEAK), - _ => fail - } + repeat( + 0.., + dispatch! 
{ any; + Token::DotExtern => empty.value(ast::LinkingDirective::EXTERN), + Token::DotVisible => empty.value(ast::LinkingDirective::VISIBLE), + Token::DotWeak => empty.value(ast::LinkingDirective::WEAK), + _ => fail + }, + ) + .fold(|| ast::LinkingDirective::NONE, |x, y| x | y) .parse_next(stream) } @@ -816,6 +902,8 @@ derive_parser!( At, #[regex(r"[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+", |lex| lex.slice(), priority = 0)] Ident(&'input str), + #[regex(r"\.[a-zA-Z][a-zA-Z0-9_$]*|\.[_$%][a-zA-Z0-9_$]+", |lex| lex.slice(), priority = 0)] + DotIdent(&'input str), #[regex(r#""[^"]*""#)] String, #[token("|")] @@ -879,7 +967,11 @@ derive_parser!( #[token(".target")] DotTarget, #[token(".address_size")] - DotAddressSize + DotAddressSize, + #[token(".action")] + DotSection, + #[token(".file")] + DotFile } #[derive(Copy, Clone, PartialEq, Eq, Hash)] @@ -1224,50 +1316,51 @@ fn main() { use winnow::token::*; use winnow::Parser; - println!("{}", mem::size_of::()); - - let mut input: &[Token] = &[][..]; - let x = opt(any::<_, ContextError>.verify_map(|_| { - println!("MAP"); - Some(true) - })) - .parse_next(&mut input) - .unwrap(); - dbg!(x); let lexer = Token::lexer( " .version 6.5 .target sm_30 .address_size 64 - .visible .entry add( + .const .align 8 .b32 constparams; + + .visible .entry const( .param .u64 input, .param .u64 output ) { .reg .u64 in_addr; .reg .u64 out_addr; - .reg .u64 temp; - .reg .u64 temp2; + .reg .b16 temp1; + .reg .b16 temp2; + .reg .b16 temp3; + .reg .b16 temp4; ld.param.u64 in_addr, [input]; ld.param.u64 out_addr, [output]; - ld.u64 temp, [in_addr]; - add.u64 temp2, temp, 1; - st.u64 [out_addr], temp2; + ld.const.b16 temp1, [constparams]; + ld.const.b16 temp2, [constparams+2]; + ld.const.b16 temp3, [constparams+4]; + ld.const.b16 temp4, [constparams+6]; + st.u16 [out_addr], temp1; + st.u16 [out_addr+2], temp2; + st.u16 [out_addr+4], temp3; + st.u16 [out_addr+6], temp4; ret; } ", ); + let tokens = lexer.clone().collect::>(); + println!("{:?}", 
&tokens); let tokens = lexer.map(|t| t.unwrap()).collect::>(); println!("{:?}", &tokens); let stream = PtxParser { input: &tokens[..], state: Vec::new(), }; - let module_ = module.parse(stream).unwrap(); + let _module = module.parse(stream).unwrap(); println!("{}", mem::size_of::()); } From cb64b04f41b39a8b4740fe1c3e9450a05a90d950 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sun, 18 Aug 2024 23:27:07 +0200 Subject: [PATCH 10/47] Add mul --- ptx_parser/src/ast.rs | 99 ++++++++++++++++++++++++------- ptx_parser/src/main.rs | 130 +++++++++++++++++++++++++++++++++++++++-- 2 files changed, 202 insertions(+), 27 deletions(-) diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 2dabf3e8..714c9b38 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -1,3 +1,5 @@ +use std::intrinsics::unreachable; + use super::{MemScope, ScalarType, StateSpace, VectorPrefix}; use bitflags::bitflags; @@ -8,25 +10,6 @@ pub enum Statement { Block(Vec>), } -pub struct MultiVariable { - pub var: Variable, - pub count: Option, -} - -#[derive(Clone)] -pub struct Variable { - pub align: Option, - pub v_type: Type, - pub state_space: StateSpace, - pub name: ID, - pub array_init: Vec, -} - -pub struct PredAt { - pub not: bool, - pub label: ID, -} - gen::generate_instruction_type!( pub enum Instruction { Mov { @@ -68,6 +51,18 @@ gen::generate_instruction_type!( src2: T, } }, + Mul { + type: { data.type_().into() }, + data: MulDetails, + arguments: { + dst: { + repr: T, + type: { data.dst_type().into() }, + }, + src1: T, + src2: T, + } + }, Ret { data: RetData }, @@ -75,6 +70,25 @@ gen::generate_instruction_type!( } ); +pub struct MultiVariable { + pub var: Variable, + pub count: Option, +} + +#[derive(Clone)] +pub struct Variable { + pub align: Option, + pub v_type: Type, + pub state_space: StateSpace, + pub name: ID, + pub array_init: Vec, +} + +pub struct PredAt { + pub not: bool, + pub label: ID, +} + pub trait Visitor { fn visit(&mut self, args: &T, type_: &Type, 
space: StateSpace, is_dst: bool); } @@ -185,7 +199,7 @@ pub enum ArithDetails { } impl ArithDetails { - pub fn type_(&self) -> super::ScalarType { + pub fn type_(&self) -> ScalarType { match self { ArithDetails::Integer(t) => t.type_, ArithDetails::Float(arith) => arith.type_, @@ -195,13 +209,13 @@ impl ArithDetails { #[derive(Copy, Clone)] pub struct ArithInteger { - pub type_: super::ScalarType, + pub type_: ScalarType, pub saturate: bool, } #[derive(Copy, Clone)] pub struct ArithFloat { - pub type_: super::ScalarType, + pub type_: ScalarType, pub rounding: Option, pub flush_to_zero: Option, pub saturate: bool, @@ -292,3 +306,44 @@ pub struct Module<'input> { pub version: (u8, u8), pub directives: Vec>>, } + +#[derive(Copy, Clone)] +pub enum MulDetails { + Integer { + type_: ScalarType, + control: MulIntControl, + }, + Float(ArithFloat), +} + +impl MulDetails { + fn type_(&self) -> ScalarType { + match self { + MulDetails::Integer { type_, .. } => *type_, + MulDetails::Float(arith) => arith.type_, + } + } + + fn dst_type(&self) -> ScalarType { + match self { + MulDetails::Integer { + type_, + control: MulIntControl::Wide, + } => match type_ { + ScalarType::U16 => ScalarType::U32, + ScalarType::S16 => ScalarType::S32, + ScalarType::U32 => ScalarType::U64, + ScalarType::S32 => ScalarType::S64, + _ => unreachable!(), + }, + _ => self.type_(), + } + } +} + +#[derive(Copy, Clone, PartialEq, Eq)] +pub enum MulIntControl { + Low, + High, + Wide, +} diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs index 0ac1260b..b087fb93 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/main.rs @@ -16,6 +16,16 @@ use winnow::{prelude::*, Stateful}; mod ast; pub use ast::*; +impl From for ast::MulIntControl { + fn from(value: RawMulIntControl) -> Self { + match value { + RawMulIntControl::Lo => ast::MulIntControl::Low, + RawMulIntControl::Hi => ast::MulIntControl::High, + RawMulIntControl::Wide => ast::MulIntControl::Wide, + } + } +} + impl From for 
ast::StCacheOperator { fn from(value: RawStCacheOperator) -> Self { match value { @@ -1066,7 +1076,6 @@ derive_parser!( arguments: ast::StArgs { src1:a, src2:b } } } - .ss: StateSpace = { .global, .local, .param{::func}, .shared{::cta, ::cluster} }; .level::eviction_priority: EvictionPriority = { .L1::evict_normal, .L1::evict_unchanged, .L1::evict_first, .L1::evict_last, .L1::no_allocate }; @@ -1156,7 +1165,6 @@ derive_parser!( arguments: LdArgs { dst:d, src:a } } } - .ss: StateSpace = { .const, .global, .local, .param{::entry, ::func}, .shared{::cta, ::cluster} }; .cop: RawLdCacheOperator = { .ca, .cg, .cs, .lu, .cv }; .level::eviction_priority: EvictionPriority = @@ -1199,7 +1207,6 @@ derive_parser!( } } } - .type: ScalarType = { .u16, .u32, .u64, .s16, .s64, .u16x2, .s16x2 }; @@ -1236,7 +1243,6 @@ derive_parser!( } } } - .rnd: RawFloatRounding = { .rn, .rz, .rm, .rp }; ScalarType = { .f32, .f64 }; @@ -1301,10 +1307,124 @@ derive_parser!( } } } - .rnd: RawFloatRounding = { .rn }; ScalarType = { .f16, .f16x2, .bf16, .bf16x2 }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul + mul.mode.type d, a, b => { + ast::Instruction::Mul { + data: ast::MulDetails::Integer { + type_, + control: mode.into() + }, + arguments: MulArgs { dst: d, src1: a, src2: b } + } + } + .mode: RawMulIntControl = { .hi, .lo }; + .type: ScalarType = { .u16, .u32, .u64, + .s16, .s32, .s64 }; + // "The .wide suffix is supported only for 16- and 32-bit integer types" + mul.wide.type d, a, b => { + ast::Instruction::Mul { + data: ast::MulDetails::Integer { + type_, + control: wide.into() + }, + arguments: MulArgs { dst: d, src1: a, src2: b } + } + } + .type: ScalarType = { .u16, .u32, + .s16, .s32 }; + RawMulIntControl = { .wide }; + + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mul + mul{.rnd}{.ftz}{.sat}.f32 d, a, b => { + ast::Instruction::Mul { + data: ast::MulDetails::Float ( + 
ArithFloat { + type_: f32, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat, + } + ), + arguments: MulArgs { dst: d, src1: a, src2: b } + } + } + mul{.rnd}.f64 d, a, b => { + ast::Instruction::Mul { + data: ast::MulDetails::Float ( + ArithFloat { + type_: f64, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false, + } + ), + arguments: MulArgs { dst: d, src1: a, src2: b } + } + } + .rnd: RawFloatRounding = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul + mul{.rnd}{.ftz}{.sat}.f16 d, a, b => { + ast::Instruction::Mul { + data: ast::MulDetails::Float ( + ArithFloat { + type_: f16, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat, + } + ), + arguments: MulArgs { dst: d, src1: a, src2: b } + } + } + mul{.rnd}{.ftz}{.sat}.f16x2 d, a, b => { + ast::Instruction::Mul { + data: ast::MulDetails::Float ( + ArithFloat { + type_: f16x2, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat, + } + ), + arguments: MulArgs { dst: d, src1: a, src2: b } + } + } + mul{.rnd}.bf16 d, a, b => { + ast::Instruction::Mul { + data: ast::MulDetails::Float ( + ArithFloat { + type_: bf16, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false, + } + ), + arguments: MulArgs { dst: d, src1: a, src2: b } + } + } + mul{.rnd}.bf16x2 d, a, b => { + ast::Instruction::Mul { + data: ast::MulDetails::Float ( + ArithFloat { + type_: bf16x2, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false, + } + ), + arguments: MulArgs { dst: d, src1: a, src2: b } + } + } + .rnd: RawFloatRounding = { .rn }; + ScalarType = { .f16, .f16x2, .bf16, .bf16x2 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret ret{.uni} => { Instruction::Ret { data: RetData { uniform: uni } } } From 
c08e6a6772b934e042136055d864898ff065c682 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Mon, 19 Aug 2024 02:23:26 +0200 Subject: [PATCH 11/47] Implement setp --- gen/Cargo.toml | 1 + gen/src/lib.rs | 79 +++++++---- gen_impl/src/lib.rs | 30 +++-- ptx_parser/src/ast.rs | 294 ++++++++++++++++++++++++++++++++++++++--- ptx_parser/src/main.rs | 39 +++++- 5 files changed, 388 insertions(+), 55 deletions(-) diff --git a/gen/Cargo.toml b/gen/Cargo.toml index e24be0f4..e26383da 100644 --- a/gen/Cargo.toml +++ b/gen/Cargo.toml @@ -13,3 +13,4 @@ rustc-hash = "2.0.0" syn = "2.0.67" quote = "1.0" proc-macro2 = "1.0.86" +either = "1.13.0" diff --git a/gen/src/lib.rs b/gen/src/lib.rs index ebddf03f..472d1fc6 100644 --- a/gen/src/lib.rs +++ b/gen/src/lib.rs @@ -1,3 +1,4 @@ +use either::Either; use gen_impl::parser; use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote, ToTokens}; @@ -28,7 +29,7 @@ static POSTFIX_TYPES: &[&str] = &["ScalarType", "VectorPrefix"]; struct OpcodeDefinitions { definitions: Vec, - block_selection: Vec, usize)>>, + block_selection: Vec>, usize)>>, } impl OpcodeDefinitions { @@ -51,33 +52,51 @@ impl OpcodeDefinitions { _ => {} } 'check_definitions: for i in unselected.iter().copied() { - // Attempt every modifier - 'check_candidates: for candidate in definitions[i] + let mut candidates = definitions[i] .unordered_modifiers .iter() .chain(definitions[i].ordered_modifiers.iter()) - { - let candidate = if let DotModifierRef::Direct { - optional: false, - value, - .. - } = candidate - { - value - } else { - continue; - }; + .filter(|modifier| match modifier { + DotModifierRef::Direct { + optional: false, .. + } + | DotModifierRef::Indirect { + optional: false, .. + } => true, + _ => false, + }) + .collect::>(); + candidates.sort_by_key(|modifier| match modifier { + DotModifierRef::Direct { .. } => 1, + DotModifierRef::Indirect { value, .. 
} => value.alternatives.len(), + }); + // Attempt every modifier + 'check_candidates: for candidate_modifier in candidates { // check all other unselected patterns for j in unselected.iter().copied() { if i == j { continue; } - if definitions[j].possible_modifiers.contains(candidate) { - continue 'check_candidates; + let candidate_set = match candidate_modifier { + DotModifierRef::Direct { value, .. } => Either::Left(iter::once(value)), + DotModifierRef::Indirect { value, .. } => { + Either::Right(value.alternatives.iter()) + } + }; + for candidate_value in candidate_set { + if definitions[j].possible_modifiers.contains(candidate_value) { + continue 'check_candidates; + } } } // it's unique - selections[i] = Some((Some(candidate), generation)); + let candidate_vec = match candidate_modifier { + DotModifierRef::Direct { value, .. } => vec![value.clone()], + DotModifierRef::Indirect { value, .. } => { + value.alternatives.iter().cloned().collect::>() + } + }; + selections[i] = Some((Some(candidate_vec), generation)); selected_something = true; continue 'check_definitions; } @@ -96,9 +115,9 @@ impl OpcodeDefinitions { let mut current_generation_definitions = Vec::new(); for (idx, selection) in selections.iter_mut().enumerate() { match selection { - Some((modifier, generation)) => { + Some((modifier_set, generation)) => { if *generation == current_generation { - current_generation_definitions.push((modifier.cloned(), idx)); + current_generation_definitions.push((modifier_set.clone(), idx)); *selection = None; } } @@ -181,6 +200,8 @@ impl SingleOpcodeDefinition { let name = &arg.ident; let arg_type = if arg.unified { quote! { (ParsedOperandStr<'input>, bool) } + } else if arg.can_be_negated { + quote! { (bool, ParsedOperandStr<'input>) } } else { quote! 
{ ParsedOperandStr<'input> } }; @@ -222,9 +243,6 @@ impl SingleOpcodeDefinition { unnamed_rules = FxHashMap::default(); } let mut possible_modifiers = FxHashSet::default(); - for (_, options) in named_rules.iter() { - possible_modifiers.extend(options.alternatives.iter().cloned()); - } let parser::OpcodeDecl(instruction, arguments) = opcode_decl; let mut unordered_modifiers = instruction .modifiers @@ -232,6 +250,7 @@ impl SingleOpcodeDefinition { .map(|parser::MaybeDotModifier { optional, modifier }| { match named_rules.get(&modifier) { Some(alts) => { + possible_modifiers.extend(alts.alternatives.iter().cloned()); if alts.alternatives.len() == 1 && alts.type_.is_none() { DotModifierRef::Direct { optional, @@ -437,11 +456,10 @@ fn emit_parse_function( for (selection_key, selected_definition) in selection_layer { let def_parser = emit_definition_parser(type_name, (opcode,*selected_definition), &def.definitions[*selected_definition]); match selection_key { - Some(selection_key) => { - let selection_key = - selection_key.dot_capitalized(); + Some(selection_keys) => { + let selection_keys = selection_keys.iter().map(|k| k.dot_capitalized()); quote! { - else if modifiers.contains(& #type_name :: #selection_key) { + else if false #(|| modifiers.contains(& #type_name :: #selection_keys))* { #def_parser } } @@ -715,7 +733,7 @@ fn emit_definition_parser( | DotModifierRef::Indirect { optional: true, .. } => TokenStream::new(), }); let arguments_parse = definition.arguments.0.iter().enumerate().map(|(idx, arg)| { - let comma = if idx == 0 { + let comma = if idx == 0 || arg.pre_pipe { quote! { empty } } else { quote! 
{ any.verify(|t| *t == #token_type::Comma).void() } @@ -774,10 +792,17 @@ fn emit_definition_parser( (#comma, #pre_bracket, #pre_pipe, #can_be_negated, #operand, #post_bracket, #unified) }; let arg_name = &arg.ident; + if arg.unified && arg.can_be_negated { + panic!("TODO: argument can't be both prefixed by `!` and suffixed by `.unified`") + } let inner_parser = if arg.unified { quote! { #pattern.map(|(_, _, _, _, name, _, unified)| (name, unified)) } + } else if arg.can_be_negated { + quote! { + #pattern.map(|(_, _, _, negated, name, _, _)| (negated, name)) + } } else { quote! { #pattern.map(|(_, _, _, _, name, _, _)| name) diff --git a/gen_impl/src/lib.rs b/gen_impl/src/lib.rs index 7160603d..57660fb0 100644 --- a/gen_impl/src/lib.rs +++ b/gen_impl/src/lib.rs @@ -70,7 +70,7 @@ impl GenerateInstructionType { let visit_slice_fn = format_ident!("visit{}_slice", kind.fn_suffix()); let (type_parameters, visitor_parameters, return_type) = if kind == VisitKind::Map { ( - quote! { <#type_parameters, To> }, + quote! { <#type_parameters, To: Operand> }, quote! { <#short_parameters, To> }, quote! { #type_name }, ) @@ -514,19 +514,29 @@ impl ArgumentField { .unwrap_or_else(|| quote! { StateSpace::Reg }); let is_dst = self.is_dst; let name = &self.name; - let arguments_name = if is_mut { - quote! { - &mut arguments.#name - } + let (operand_fn, arguments_name) = if is_mut { + ( + quote! { + VisitOperand::visit_mut + }, + quote! { + &mut arguments.#name + }, + ) } else { - quote! { - & arguments.#name - } + ( + quote! { + VisitOperand::visit + }, + quote! { + & arguments.#name + }, + ) }; quote! 
{{ let type_ = #type_; let space = #space; - visitor.visit(#arguments_name, &type_, space, #is_dst); + #operand_fn(#arguments_name, |x| visitor.visit(x, &type_, space, #is_dst)); }} } @@ -548,7 +558,7 @@ impl ArgumentField { let #name = { let type_ = #type_; let space = #space; - visitor.visit(arguments.#name, &type_, space, #is_dst) + MapOperand::map(arguments.#name, |x| visitor.visit(x, &type_, space, #is_dst)) }; } } diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 714c9b38..e456e03a 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -1,6 +1,5 @@ -use std::intrinsics::unreachable; - -use super::{MemScope, ScalarType, StateSpace, VectorPrefix}; +use super::{MemScope, RawSetpCompareOp, ScalarType, SetpBoolPostOp, StateSpace, VectorPrefix}; +use crate::{PtxError, PtxParserState}; use bitflags::bitflags; pub enum Statement { @@ -11,7 +10,7 @@ pub enum Statement { } gen::generate_instruction_type!( - pub enum Instruction { + pub enum Instruction { Mov { type: { &data.typ }, data: MovDetails, @@ -63,6 +62,52 @@ gen::generate_instruction_type!( src2: T, } }, + Setp { + data: SetpData, + arguments: { + dst1: { + repr: T, + type: ScalarType::Pred.into() + }, + dst2: { + repr: Option, + type: ScalarType::Pred.into() + }, + src1: { + repr: T, + type: data.type_.into(), + }, + src2: { + repr: T, + type: data.type_.into(), + } + } + }, + SetpBool { + data: SetpBoolData, + arguments: { + dst1: { + repr: T, + type: ScalarType::Pred.into() + }, + dst2: { + repr: Option, + type: ScalarType::Pred.into() + }, + src1: { + repr: T, + type: data.base.type_.into(), + }, + src2: { + repr: T, + type: data.base.type_.into(), + }, + src3: { + repr: T, + type: ScalarType::Pred.into() + } + } + }, Ret { data: RetData }, @@ -70,6 +115,66 @@ gen::generate_instruction_type!( } ); +pub trait Visitor { + fn visit(&mut self, args: &T, type_: &Type, space: StateSpace, is_dst: bool); +} + +pub trait VisitorMut { + fn visit(&mut self, args: &mut T, type_: &Type, space: 
StateSpace, is_dst: bool); +} + +pub trait VisitorMap { + fn visit(&mut self, args: From, type_: &Type, space: StateSpace, is_dst: bool) -> To; +} + +trait VisitOperand { + type Operand; + fn visit(&self, fn_: impl FnOnce(&Self::Operand)); + fn visit_mut(&mut self, fn_: impl FnOnce(&mut Self::Operand)); +} + +impl VisitOperand for T { + type Operand = Self; + fn visit(&self, fn_: impl FnOnce(&Self::Operand)) { + fn_(self) + } + fn visit_mut(&mut self, fn_: impl FnOnce(&mut Self::Operand)) { + fn_(self) + } +} + +impl VisitOperand for Option { + type Operand = T; + fn visit(&self, fn_: impl FnOnce(&Self::Operand)) { + self.as_ref().map(fn_); + } + fn visit_mut(&mut self, fn_: impl FnOnce(&mut Self::Operand)) { + self.as_mut().map(fn_); + } +} + +trait MapOperand: Sized { + type Input; + type Output; + fn map(self, fn_: impl FnOnce(Self::Input) -> U) -> Self::Output; +} + +impl MapOperand for T { + type Input = Self; + type Output = U; + fn map(self, fn_: impl FnOnce(T) -> U) -> U { + fn_(self) + } +} + +impl MapOperand for Option { + type Input = T; + type Output = Option; + fn map(self, fn_: impl FnOnce(T) -> U) -> Option { + self.map(|x| fn_(x)) + } +} + pub struct MultiVariable { pub var: Variable, pub count: Option, @@ -89,18 +194,6 @@ pub struct PredAt { pub label: ID, } -pub trait Visitor { - fn visit(&mut self, args: &T, type_: &Type, space: StateSpace, is_dst: bool); -} - -pub trait VisitorMut { - fn visit(&mut self, args: &mut T, type_: &Type, space: StateSpace, is_dst: bool); -} - -pub trait VisitorMap { - fn visit(&mut self, args: From, type_: &Type, space: StateSpace, is_dst: bool) -> To; -} - #[derive(PartialEq, Eq, Clone, Hash)] pub enum Type { // .param.b32 foo; @@ -121,6 +214,43 @@ impl Type { } } +impl ScalarType { + pub fn kind(self) -> ScalarKind { + match self { + ScalarType::U8 => ScalarKind::Unsigned, + ScalarType::U16 => ScalarKind::Unsigned, + ScalarType::U16x2 => ScalarKind::Unsigned, + ScalarType::U32 => ScalarKind::Unsigned, + 
ScalarType::U64 => ScalarKind::Unsigned, + ScalarType::S8 => ScalarKind::Signed, + ScalarType::S16 => ScalarKind::Signed, + ScalarType::S16x2 => ScalarKind::Signed, + ScalarType::S32 => ScalarKind::Signed, + ScalarType::S64 => ScalarKind::Signed, + ScalarType::B8 => ScalarKind::Bit, + ScalarType::B16 => ScalarKind::Bit, + ScalarType::B32 => ScalarKind::Bit, + ScalarType::B64 => ScalarKind::Bit, + ScalarType::B128 => ScalarKind::Bit, + ScalarType::F16 => ScalarKind::Float, + ScalarType::F16x2 => ScalarKind::Float, + ScalarType::F32 => ScalarKind::Float, + ScalarType::F64 => ScalarKind::Float, + ScalarType::BF16 => ScalarKind::Float, + ScalarType::BF16x2 => ScalarKind::Float, + ScalarType::Pred => ScalarKind::Pred, + } + } +} + +#[derive(Clone, Copy, PartialEq, Eq)] +pub enum ScalarKind { + Bit, + Unsigned, + Signed, + Float, + Pred, +} impl From for Type { fn from(value: ScalarType) -> Self { Type::Scalar(value) @@ -347,3 +477,135 @@ pub enum MulIntControl { High, Wide, } + +pub struct SetpData { + pub type_: ScalarType, + pub flush_to_zero: Option, + pub cmp_op: SetpCompareOp, +} + +impl SetpData { + pub(crate) fn try_parse( + errors: &mut PtxParserState, + cmp_op: super::RawSetpCompareOp, + ftz: bool, + type_: ScalarType, + ) -> Self { + let flush_to_zero = match (ftz, type_) { + (_, ScalarType::F32) => Some(ftz), + _ => { + errors.push(PtxError::NonF32Ftz); + None + } + }; + let type_kind = type_.kind(); + let cmp_op = if type_kind == ScalarKind::Float { + SetpCompareOp::Float(SetpCompareFloat::from(cmp_op)) + } else { + match SetpCompareInt::try_from(cmp_op) { + Ok(op) => SetpCompareOp::Integer(op), + Err(err) => { + errors.push(err); + SetpCompareOp::Integer(SetpCompareInt::Eq) + } + } + }; + Self { + type_, + flush_to_zero, + cmp_op, + } + } +} + +pub struct SetpBoolData { + pub base: SetpData, + pub bool_op: SetpBoolPostOp, + pub negate_src3: bool +} + +#[derive(PartialEq, Eq, Copy, Clone)] +pub enum SetpCompareOp { + Integer(SetpCompareInt), + 
Float(SetpCompareFloat), +} + +#[derive(PartialEq, Eq, Copy, Clone)] +pub enum SetpCompareInt { + Eq, + NotEq, + Less, + LessOrEq, + Greater, + GreaterOrEq, +} + +#[derive(PartialEq, Eq, Copy, Clone)] +pub enum SetpCompareFloat { + Eq, + NotEq, + Less, + LessOrEq, + Greater, + GreaterOrEq, + NanEq, + NanNotEq, + NanLess, + NanLessOrEq, + NanGreater, + NanGreaterOrEq, + IsNotNan, + IsAnyNan, +} + +impl TryFrom for SetpCompareInt { + type Error = PtxError; + + fn try_from(value: RawSetpCompareOp) -> Result { + match value { + RawSetpCompareOp::Eq => Ok(SetpCompareInt::Eq), + RawSetpCompareOp::Ne => Ok(SetpCompareInt::NotEq), + RawSetpCompareOp::Lt => Ok(SetpCompareInt::Less), + RawSetpCompareOp::Le => Ok(SetpCompareInt::LessOrEq), + RawSetpCompareOp::Gt => Ok(SetpCompareInt::Greater), + RawSetpCompareOp::Ge => Ok(SetpCompareInt::GreaterOrEq), + RawSetpCompareOp::Lo => Ok(SetpCompareInt::Less), + RawSetpCompareOp::Ls => Ok(SetpCompareInt::LessOrEq), + RawSetpCompareOp::Hi => Ok(SetpCompareInt::Greater), + RawSetpCompareOp::Hs => Ok(SetpCompareInt::GreaterOrEq), + RawSetpCompareOp::Equ => Err(PtxError::WrongType), + RawSetpCompareOp::Neu => Err(PtxError::WrongType), + RawSetpCompareOp::Ltu => Err(PtxError::WrongType), + RawSetpCompareOp::Leu => Err(PtxError::WrongType), + RawSetpCompareOp::Gtu => Err(PtxError::WrongType), + RawSetpCompareOp::Geu => Err(PtxError::WrongType), + RawSetpCompareOp::Num => Err(PtxError::WrongType), + RawSetpCompareOp::Nan => Err(PtxError::WrongType), + } + } +} + +impl From for SetpCompareFloat { + fn from(value: RawSetpCompareOp) -> Self { + match value { + RawSetpCompareOp::Eq => SetpCompareFloat::Eq, + RawSetpCompareOp::Ne => SetpCompareFloat::NotEq, + RawSetpCompareOp::Lt => SetpCompareFloat::Less, + RawSetpCompareOp::Le => SetpCompareFloat::LessOrEq, + RawSetpCompareOp::Gt => SetpCompareFloat::Greater, + RawSetpCompareOp::Ge => SetpCompareFloat::GreaterOrEq, + RawSetpCompareOp::Lo => SetpCompareFloat::Less, + RawSetpCompareOp::Ls => 
SetpCompareFloat::LessOrEq, + RawSetpCompareOp::Hi => SetpCompareFloat::Greater, + RawSetpCompareOp::Hs => SetpCompareFloat::GreaterOrEq, + RawSetpCompareOp::Equ => SetpCompareFloat::NanEq, + RawSetpCompareOp::Neu => SetpCompareFloat::NanNotEq, + RawSetpCompareOp::Ltu => SetpCompareFloat::NanLess, + RawSetpCompareOp::Leu => SetpCompareFloat::NanLessOrEq, + RawSetpCompareOp::Gtu => SetpCompareFloat::NanGreater, + RawSetpCompareOp::Geu => SetpCompareFloat::NanGreaterOrEq, + RawSetpCompareOp::Num => SetpCompareFloat::IsNotNan, + RawSetpCompareOp::Nan => SetpCompareFloat::IsAnyNan, + } + } +} diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs index b087fb93..785496d1 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/main.rs @@ -769,6 +769,8 @@ pub enum PtxError { #[error("")] NonF32Ftz, #[error("")] + WrongType, + #[error("")] WrongArrayType, #[error("")] WrongVectorElement, @@ -996,6 +998,9 @@ derive_parser!( #[derive(Copy, Clone, PartialEq, Eq, Hash)] pub enum ScalarType { } + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + pub enum SetpBoolPostOp { } + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov mov{.vec}.type d, a => { Instruction::Mov { @@ -1424,6 +1429,38 @@ derive_parser!( .rnd: RawFloatRounding = { .rn }; ScalarType = { .f16, .f16x2, .bf16, .bf16x2 }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-setp + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-comparison-instructions-setp + setp.CmpOp{.ftz}.type p[|q], a, b => { + let data = ast::SetpData::try_parse(state, cmpop, ftz, type_); + ast::Instruction::Setp { + data, + arguments: SetpArgs { dst1: p, dst2: q, src1: a, src2: b } + } + } + setp.CmpOp.BoolOp{.ftz}.type p[|q], a, b, {!}c => { + let (negate_src3, c) = c; + let base = ast::SetpData::try_parse(state, cmpop, ftz, type_); + let data = ast::SetpBoolData { + base, + bool_op: 
boolop, + negate_src3 + }; + ast::Instruction::SetpBool { + data, + arguments: SetpBoolArgs { dst1: p, dst2: q, src1: a, src2: b, src3: c } + } + } + .CmpOp: RawSetpCompareOp = { .eq, .ne, .lt, .le, .gt, .ge, + .lo, .ls, .hi, .hs, // signed + .equ, .neu, .ltu, .leu, .gtu, .geu, .num, .nan }; // float-only + .BoolOp: SetpBoolPostOp = { .and, .or, .xor }; + .type: ScalarType = { .b16, .b32, .b64, + .u16, .u32, .u64, + .s16, .s32, .s64, + .f32, .f64, + .f16, .f16x2, .bf16, .bf16x2 }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret ret{.uni} => { Instruction::Ret { data: RetData { uniform: uni } } @@ -1432,8 +1469,6 @@ derive_parser!( ); fn main() { - use winnow::combinator::*; - use winnow::token::*; use winnow::Parser; let lexer = Token::lexer( From 22492ec7f1b42b58b5ca5a61f3cb9fffa1757510 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Mon, 19 Aug 2024 13:37:04 +0200 Subject: [PATCH 12/47] Implement not, or, and, bra --- gen/src/lib.rs | 56 +++++++++-- gen_impl/src/lib.rs | 208 ++++++++++++++++++++++++++--------------- gen_impl/src/parser.rs | 34 +++++-- ptx_parser/src/ast.rs | 84 ++++++++++++----- ptx_parser/src/main.rs | 51 +++++++++- 5 files changed, 315 insertions(+), 118 deletions(-) diff --git a/gen/src/lib.rs b/gen/src/lib.rs index 472d1fc6..a110fdc4 100644 --- a/gen/src/lib.rs +++ b/gen/src/lib.rs @@ -231,7 +231,8 @@ impl SingleOpcodeDefinition { } fn extract_and_insert( - output: &mut FxHashMap>, + definitions: &mut FxHashMap>, + special_definitions: &mut FxHashMap, parser::OpcodeDefinition(pattern_seq, rules): parser::OpcodeDefinition, ) { let (mut named_rules, mut unnamed_rules) = gather_rules(rules); @@ -242,8 +243,18 @@ impl SingleOpcodeDefinition { named_rules = FxHashMap::default(); unnamed_rules = FxHashMap::default(); } - let mut possible_modifiers = FxHashSet::default(); let parser::OpcodeDecl(instruction, arguments) = opcode_decl; + if code_block.special { + if 
!instruction.modifiers.is_empty() || !arguments.0.is_empty() { + panic!( + "`{}`: no modifiers or arguments are allowed in parser definition.", + instruction.name + ); + } + special_definitions.insert(instruction.name, code_block.code); + continue; + } + let mut possible_modifiers = FxHashSet::default(); let mut unordered_modifiers = instruction .modifiers .into_iter() @@ -287,7 +298,7 @@ impl SingleOpcodeDefinition { arguments, code_block, }; - multihash_extend(output, current_opcode.clone(), entry); + multihash_extend(definitions, current_opcode.clone(), entry); last_opcode = current_opcode; } } @@ -350,10 +361,15 @@ fn gather_rules( pub fn derive_parser(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream { let parse_definitions = parse_macro_input!(tokens as gen_impl::parser::ParseDefinitions); let mut definitions = FxHashMap::default(); + let mut special_definitions = FxHashMap::default(); let types = OpcodeDefinitions::get_enum_types(&parse_definitions.definitions); let enum_types_tokens = emit_enum_types(types, parse_definitions.additional_enums); for definition in parse_definitions.definitions.into_iter() { - SingleOpcodeDefinition::extract_and_insert(&mut definitions, definition); + SingleOpcodeDefinition::extract_and_insert( + &mut definitions, + &mut special_definitions, + definition, + ); } let definitions = definitions .into_iter() @@ -363,9 +379,12 @@ pub fn derive_parser(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream }) .collect::>(); let mut token_enum = parse_definitions.token_type; - let (all_opcode, all_modifier) = - write_definitions_into_tokens(&definitions, &mut token_enum.variants); - let token_impl = emit_parse_function(&token_enum.ident, &definitions, all_opcode, all_modifier); + let (all_opcode, all_modifier) = write_definitions_into_tokens( + &definitions, + special_definitions.keys(), + &mut token_enum.variants, + ); + let token_impl = emit_parse_function(&token_enum.ident, &definitions, &special_definitions, 
all_opcode, all_modifier); let tokens = quote! { #enum_types_tokens @@ -422,6 +441,7 @@ fn emit_enum_types( fn emit_parse_function( type_name: &Ident, defs: &FxHashMap, + special_defs: &FxHashMap, all_opcode: Vec<&Ident>, all_modifier: FxHashSet<&parser::DotModifier>, ) -> TokenStream { @@ -433,7 +453,7 @@ fn emit_parse_function( let mut fn_name = opcode.to_string(); write!(&mut fn_name, "_{}", idx).ok(); let fn_name = Ident::new(&fn_name, Span::call_site()); - let code_block = &def.code_block.0; + let code_block = &def.code_block.code; let args = def.function_arguments_declarations(); quote! { fn #fn_name<'input>(state: &mut PtxParserState, #(#args),* ) -> Instruction> #code_block @@ -494,7 +514,12 @@ fn emit_parse_function( } .to_tokens(&mut result); result - }); + }).chain(special_defs.iter().map(|(opcode, code)| { + let opcode_variant = Ident::new(&capitalize(&opcode.to_string()), opcode.span()); + quote! { + #opcode_variant => { #code? } + } + })); let opcodes = all_opcode.into_iter().map(|op_ident| { let op = op_ident.to_string(); let variant = Ident::new(&capitalize(&op), op_ident.span()); @@ -749,7 +774,7 @@ fn emit_definition_parser( }; let pre_pipe = if arg.pre_pipe { quote! { - any.verify(|t| *t == #token_type::Or).void() + any.verify(|t| *t == #token_type::Pipe).void() } } else { quote! { @@ -845,6 +870,7 @@ fn emit_definition_parser( fn write_definitions_into_tokens<'a>( defs: &'a FxHashMap, + special_definitions: impl Iterator, variants: &mut Punctuated, ) -> (Vec<&'a Ident>, FxHashSet<&'a parser::DotModifier>) { let mut all_opcodes = Vec::new(); @@ -864,6 +890,16 @@ fn write_definitions_into_tokens<'a>( } } } + for opcode in special_definitions { + all_opcodes.push(opcode); + let opcode_as_string = opcode.to_string(); + let variant_name = Ident::new(&capitalize(&opcode_as_string), opcode.span()); + let arg: Variant = syn::parse_quote! 
{ + #[token(#opcode_as_string)] + #variant_name + }; + variants.push(arg); + } for modifier in all_modifiers.iter() { let modifier_as_string = modifier.to_string(); let variant_name = modifier.dot_capitalized(); diff --git a/gen_impl/src/lib.rs b/gen_impl/src/lib.rs index 57660fb0..39cc30e7 100644 --- a/gen_impl/src/lib.rs +++ b/gen_impl/src/lib.rs @@ -1,8 +1,8 @@ use proc_macro2::TokenStream; use quote::{format_ident, quote, ToTokens}; use syn::{ - braced, parse::Parse, punctuated::Punctuated, token, Expr, Ident, Token, Type, TypeParam, - Visibility, + braced, parse::Parse, punctuated::Punctuated, token, Expr, Ident, PathSegment, Token, Type, + TypeParam, Visibility, }; pub mod parser; @@ -18,7 +18,7 @@ pub struct GenerateInstructionType { impl GenerateInstructionType { pub fn emit_arg_types(&self, tokens: &mut TokenStream) { for v in self.variants.iter() { - v.emit_type(&self.visibility, &self.type_parameters, tokens); + v.emit_type(&self.visibility, tokens); } } @@ -165,7 +165,7 @@ impl Parse for GenerateInstructionType { pub struct InstructionVariant { pub name: Ident, - pub type_: Option, + pub type_: Option>, pub space: Option, pub data: Option, pub arguments: Option, @@ -225,7 +225,7 @@ impl InstructionVariant { &self, enum_: &Ident, tokens: &mut TokenStream, - mut fn_: impl FnMut(&InstructionArguments, &Option, &Option) -> TokenStream, + mut fn_: impl FnMut(&InstructionArguments, &Option>, &Option) -> TokenStream, ) { let name = &self.name; let arguments = match &self.arguments { @@ -238,9 +238,10 @@ impl InstructionVariant { } Some(args) => args, }; + let data = &self.data.as_ref().map(|_| quote! { data,}); let arg_calls = fn_(arguments, &self.type_, &self.space); quote! 
{ - #enum_ :: #name { arguments, data } => { + #enum_ :: #name { #data arguments } => { #arg_calls } } @@ -269,19 +270,14 @@ impl InstructionVariant { .to_tokens(tokens); } - fn emit_type( - &self, - vis: &Option, - type_parameters: &Punctuated, - tokens: &mut TokenStream, - ) { + fn emit_type(&self, vis: &Option, tokens: &mut TokenStream) { let arguments = match self.arguments { Some(ref a) => a, None => return, }; let name = self.args_name(); let type_parameters = if arguments.generic.is_some() { - Some(quote! { <#type_parameters> }) + Some(quote! { }) } else { None }; @@ -324,7 +320,7 @@ impl Parse for InstructionVariant { } enum VariantProperty { - Type(Expr), + Type(Option), Space(Expr), Data(Type), Arguments(InstructionArguments), @@ -336,7 +332,12 @@ impl VariantProperty { Ok(if lookahead.peek(Token![type]) { input.parse::()?; input.parse::()?; - VariantProperty::Type(input.parse::()?) + VariantProperty::Type(if input.peek(Token![!]) { + input.parse::()?; + None + } else { + Some(input.parse::()?) 
+ }) } else if lookahead.peek(Ident) { let key = input.parse::()?; match &*key.to_string() { @@ -352,7 +353,7 @@ impl VariantProperty { let generics = if input.peek(Token![<]) { input.parse::()?; let gen_params = - Punctuated::::parse_separated_nonempty(input)?; + Punctuated::::parse_separated_nonempty(input)?; input.parse::]>()?; Some(gen_params) } else { @@ -380,13 +381,13 @@ impl VariantProperty { } pub struct InstructionArguments { - pub generic: Option>, + pub generic: Option>, pub fields: Punctuated, } impl InstructionArguments { pub fn parse( - generic: Option>, + generic: Option>, input: syn::parse::ParseStream, ) -> syn::Result { let fields = Punctuated::::parse_terminated_with( @@ -396,13 +397,17 @@ impl InstructionArguments { Ok(Self { generic, fields }) } - fn emit_visit(&self, parent_type: &Option, parent_space: &Option) -> TokenStream { + fn emit_visit( + &self, + parent_type: &Option>, + parent_space: &Option, + ) -> TokenStream { self.emit_visit_impl(parent_type, parent_space, ArgumentField::emit_visit) } fn emit_visit_mut( &self, - parent_type: &Option, + parent_type: &Option>, parent_space: &Option, ) -> TokenStream { self.emit_visit_impl(parent_type, parent_space, ArgumentField::emit_visit_mut) @@ -410,7 +415,7 @@ impl InstructionArguments { fn emit_visit_map( &self, - parent_type: &Option, + parent_type: &Option>, parent_space: &Option, ) -> TokenStream { self.emit_visit_impl(parent_type, parent_space, ArgumentField::emit_visit_map) @@ -418,14 +423,19 @@ impl InstructionArguments { fn emit_visit_impl( &self, - parent_type: &Option, + parent_type: &Option>, parent_space: &Option, - mut fn_: impl FnMut(&ArgumentField, &Option, &Option) -> TokenStream, + mut fn_: impl FnMut(&ArgumentField, &Option>, &Option, bool) -> TokenStream, ) -> TokenStream { + let is_ident = if let Some(ref generic) = self.generic { + generic.len() > 1 + } else { + false + }; let field_calls = self .fields .iter() - .map(|f| fn_(f, parent_type, parent_space)); + .map(|f| 
fn_(f, parent_type, parent_space, is_ident)); quote! { #(#field_calls)* } @@ -487,25 +497,37 @@ impl ArgumentField { input.parse::() } - fn emit_visit(&self, parent_type: &Option, parent_space: &Option) -> TokenStream { - self.emit_visit_impl(parent_type, parent_space, false) + fn emit_visit( + &self, + parent_type: &Option>, + parent_space: &Option, + is_ident: bool, + ) -> TokenStream { + self.emit_visit_impl(parent_type, parent_space, is_ident, false) } fn emit_visit_mut( &self, - parent_type: &Option, + parent_type: &Option>, parent_space: &Option, + is_ident: bool, ) -> TokenStream { - self.emit_visit_impl(parent_type, parent_space, true) + self.emit_visit_impl(parent_type, parent_space, is_ident, true) } fn emit_visit_impl( &self, - parent_type: &Option, + parent_type: &Option>, parent_space: &Option, + is_ident: bool, is_mut: bool, ) -> TokenStream { - let type_ = self.type_.as_ref().or(parent_type.as_ref()).unwrap(); + let (is_typeless, type_) = match (self.type_.as_ref(), parent_type) { + (Some(type_), _) => (false, Some(type_)), + (None, None) => panic!("No type set"), + (None, Some(None)) => (true, None), + (None, Some(Some(type_))) => (false, Some(type_)), + }; let space = self .space .as_ref() @@ -514,38 +536,72 @@ impl ArgumentField { .unwrap_or_else(|| quote! { StateSpace::Reg }); let is_dst = self.is_dst; let name = &self.name; - let (operand_fn, arguments_name) = if is_mut { - ( - quote! { - VisitOperand::visit_mut - }, - quote! { - &mut arguments.#name - }, - ) + let type_space = if is_typeless { + quote! { + let type_space = None; + } } else { - ( + quote! { + let type_ = #type_; + let space = #space; + let type_space = Some((std::borrow::Borrow::::borrow(&type_), space)); + } + }; + if is_ident { + if is_mut { quote! { - VisitOperand::visit - }, + { + #type_space + visitor.visit_ident(&mut arguments.#name, type_space, #is_dst); + } + } + } else { quote! { - & arguments.#name - }, - ) - }; - quote! 
{{ - let type_ = #type_; - let space = #space; - #operand_fn(#arguments_name, |x| visitor.visit(x, &type_, space, #is_dst)); - }} + { + #type_space + visitor.visit_ident(& arguments.#name, type_space, #is_dst); + } + } + } + } else { + let (operand_fn, arguments_name) = if is_mut { + ( + quote! { + VisitOperand::visit_mut + }, + quote! { + &mut arguments.#name + }, + ) + } else { + ( + quote! { + VisitOperand::visit + }, + quote! { + & arguments.#name + }, + ) + }; + quote! {{ + #type_space + #operand_fn(#arguments_name, |x| visitor.visit(x, type_space, #is_dst)); + }} + } } fn emit_visit_map( &self, - parent_type: &Option, + parent_type: &Option>, parent_space: &Option, + is_ident: bool, ) -> TokenStream { - let type_ = self.type_.as_ref().or(parent_type.as_ref()).unwrap(); + let (is_typeless, type_) = match (self.type_.as_ref(), parent_type) { + (Some(type_), _) => (false, Some(type_)), + (None, None) => panic!("No type set"), + (None, Some(None)) => (true, None), + (None, Some(Some(type_))) => (false, Some(type_)), + }; let space = self .space .as_ref() @@ -554,11 +610,30 @@ impl ArgumentField { .unwrap_or_else(|| quote! { StateSpace::Reg }); let is_dst = self.is_dst; let name = &self.name; - quote! { - let #name = { + let type_space = if is_typeless { + quote! { + let type_space = None; + } + } else { + quote! { let type_ = #type_; let space = #space; - MapOperand::map(arguments.#name, |x| visitor.visit(x, &type_, space, #is_dst)) + let type_space = Some((std::borrow::Borrow::::borrow(&type_), space)); + } + }; + let map_call = if is_ident { + quote! { + visitor.visit_ident(arguments.#name, type_space, #is_dst) + } + } else { + quote! { + MapOperand::map(arguments.#name, |x| visitor.visit(x, type_space, #is_dst)) + } + }; + quote! { + let #name = { + #type_space + #map_call }; } } @@ -702,27 +777,6 @@ mod tests { assert!(matches!(src.space, None)); } - #[test] - fn visit_variant() { - let input = quote! 
{ - Ld { - type: ScalarType::U32, - data: LdDetails, - arguments

: { - dst: { - repr: P::Operand, - type: ScalarType::U32 - }, - src: P::Operand, - }, - } - }; - let variant = syn::parse2::(input).unwrap(); - let mut output = TokenStream::new(); - variant.emit_visit(&Ident::new("Instruction", Span::call_site()), &mut output); - assert_eq!(output.to_string(), "Instruction :: Ld { arguments , data } => { { let type_ = ScalarType :: U32 ; let space = StateSpace :: Reg ; visitor . visit (& arguments . dst , & type_ , space , true) ; } { let type_ = ScalarType :: U32 ; let space = StateSpace :: Reg ; visitor . visit (& arguments . src , & type_ , space , false) ; } }"); - } - #[test] fn visit_variant_empty() { let input = quote! { diff --git a/gen_impl/src/parser.rs b/gen_impl/src/parser.rs index 519bf129..ea5070d0 100644 --- a/gen_impl/src/parser.rs +++ b/gen_impl/src/parser.rs @@ -86,13 +86,27 @@ impl Parse for OpcodeDecl { } } -pub struct CodeBlock(pub proc_macro2::Group); +pub struct CodeBlock { + pub special: bool, + pub code: proc_macro2::Group, +} impl Parse for CodeBlock { fn parse(input: syn::parse::ParseStream) -> syn::Result { - input.parse::]>()?; - let group = input.parse::()?; - Ok(Self(group)) + let lookahead = input.lookahead1(); + let (special, code) = if lookahead.peek(Token![<]) { + input.parse::()?; + input.parse::()?; + //input.parse::]>()?; + (true, input.parse::()?) + } else if lookahead.peek(Token![=]) { + input.parse::()?; + input.parse::]>()?; + (false, input.parse::()?) + } else { + return Err(lookahead.error()); + }; + Ok(Self{special, code}) } } @@ -761,7 +775,7 @@ mod tests { .ss: StateSpace = { .global, .local, .param{::func}, .shared{::cta, ::cluster} } }; let rule = syn::parse2::(input).unwrap(); - assert_eq!(". ss", rule.modifier.tokens().to_string()); + assert_eq!(". 
ss", rule.modifier.unwrap().tokens().to_string()); assert_eq!( "StateSpace", rule.type_.unwrap().to_token_stream().to_string() @@ -791,7 +805,7 @@ mod tests { .cop: StCacheOperator = { .wb, .cg, .cs, .wt } }; let rule = syn::parse2::(input).unwrap(); - assert_eq!(". cop", rule.modifier.tokens().to_string()); + assert_eq!(". cop", rule.modifier.unwrap().tokens().to_string()); assert_eq!( "StCacheOperator", rule.type_.unwrap().to_token_stream().to_string() @@ -819,4 +833,12 @@ mod tests { assert!(!a.can_be_negated); assert!(a.unified); } + + #[test] + fn special_block() { + let input = quote! { + bra <= { bra(stream) } + }; + syn::parse2::(input).unwrap(); + } } diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index e456e03a..6f1a9e37 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -31,7 +31,7 @@ gen::generate_instruction_type!( } }, Add { - type: { data.type_().into() }, + type: { Type::from(data.type_()) }, data: ArithDetails, arguments: { dst: T, @@ -51,12 +51,12 @@ gen::generate_instruction_type!( } }, Mul { - type: { data.type_().into() }, + type: { Type::from(data.type_()) }, data: MulDetails, arguments: { dst: { repr: T, - type: { data.dst_type().into() }, + type: { Type::from(data.dst_type()) }, }, src1: T, src2: T, @@ -67,19 +67,19 @@ gen::generate_instruction_type!( arguments: { dst1: { repr: T, - type: ScalarType::Pred.into() + type: Type::from(ScalarType::Pred) }, dst2: { repr: Option, - type: ScalarType::Pred.into() + type: Type::from(ScalarType::Pred) }, src1: { repr: T, - type: data.type_.into(), + type: Type::from(data.type_), }, src2: { repr: T, - type: data.type_.into(), + type: Type::from(data.type_), } } }, @@ -88,26 +88,58 @@ gen::generate_instruction_type!( arguments: { dst1: { repr: T, - type: ScalarType::Pred.into() + type: Type::from(ScalarType::Pred) }, dst2: { repr: Option, - type: ScalarType::Pred.into() + type: Type::from(ScalarType::Pred) }, src1: { repr: T, - type: data.base.type_.into(), + type: 
Type::from(data.base.type_), }, src2: { repr: T, - type: data.base.type_.into(), + type: Type::from(data.base.type_), }, src3: { repr: T, - type: ScalarType::Pred.into() + type: Type::from(ScalarType::Pred) } } }, + Not { + data: ScalarType, + type: { Type::Scalar(data.clone()) }, + arguments: { + dst: T, + src: T, + } + }, + Or { + data: ScalarType, + type: { Type::Scalar(data.clone()) }, + arguments: { + dst: T, + src1: T, + src2: T, + } + }, + And { + data: ScalarType, + type: { Type::Scalar(data.clone()) }, + arguments: { + dst: T, + src1: T, + src2: T, + } + }, + Bra { + type: !, + arguments: { + src: T + } + }, Ret { data: RetData }, @@ -115,21 +147,26 @@ gen::generate_instruction_type!( } ); -pub trait Visitor { - fn visit(&mut self, args: &T, type_: &Type, space: StateSpace, is_dst: bool); +pub trait Visitor { + fn visit(&mut self, args: &T, type_space: Option<(&Type, StateSpace)>, is_dst: bool); + fn visit_ident(&self, args: &T::Ident, type_space: Option<(&Type, StateSpace)>, is_dst: bool); } -pub trait VisitorMut { - fn visit(&mut self, args: &mut T, type_: &Type, space: StateSpace, is_dst: bool); +pub trait VisitorMut { + fn visit(&mut self, args: &mut T, type_space: Option<(&Type, StateSpace)>, is_dst: bool); + fn visit_ident(&mut self, args: &mut T::Ident, type_space: Option<(&Type, StateSpace)>, is_dst: bool); } -pub trait VisitorMap { - fn visit(&mut self, args: From, type_: &Type, space: StateSpace, is_dst: bool) -> To; +pub trait VisitorMap { + fn visit(&mut self, args: From, type_space: Option<(&Type, StateSpace)>, is_dst: bool) -> To; + fn visit_ident(&mut self, args: From::Ident, type_space: Option<(&Type, StateSpace)>, is_dst: bool) -> To::Ident; } trait VisitOperand { - type Operand; + type Operand: Operand; + #[allow(unused)] // Used by generated code fn visit(&self, fn_: impl FnOnce(&Self::Operand)); + #[allow(unused)] // Used by generated code fn visit_mut(&mut self, fn_: impl FnOnce(&mut Self::Operand)); } @@ -156,6 +193,7 @@ impl 
VisitOperand for Option { trait MapOperand: Sized { type Input; type Output; + #[allow(unused)] // Used by generated code fn map(self, fn_: impl FnOnce(Self::Input) -> U) -> Self::Output; } @@ -289,12 +327,12 @@ pub enum ParsedOperand { VecPack(Vec), } -impl Operand for ParsedOperand { +impl Operand for ParsedOperand { type Ident = Ident; } pub trait Operand { - type Ident; + type Ident: Copy; } #[derive(Copy, Clone)] @@ -447,6 +485,7 @@ pub enum MulDetails { } impl MulDetails { + #[allow(unused)] // Used by generated code fn type_(&self) -> ScalarType { match self { MulDetails::Integer { type_, .. } => *type_, @@ -454,6 +493,7 @@ impl MulDetails { } } + #[allow(unused)] // Used by generated code fn dst_type(&self) -> ScalarType { match self { MulDetails::Integer { @@ -521,7 +561,7 @@ impl SetpData { pub struct SetpBoolData { pub base: SetpData, pub bool_op: SetpBoolPostOp, - pub negate_src3: bool + pub negate_src3: bool, } #[derive(PartialEq, Eq, Copy, Clone)] diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs index 785496d1..a6a23818 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/main.rs @@ -623,7 +623,7 @@ fn predicated_instruction<'a, 'input>( } fn pred_at<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult> { - (Token::At, opt(Token::Not), ident) + (Token::At, opt(Token::Exclamation), ident) .map(|(_, not, label)| ast::PredAt { not: not.is_some(), label, @@ -888,6 +888,21 @@ impl<'input, I: Stream + StreamIsPartial, E: ParserError> Parse } } +fn bra<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult>> { + preceded( + opt(Token::DotUni), + any.verify_map(|t| match t { + Token::Ident(ident) => Some(ast::Instruction::Bra { + arguments: BraArgs { src: ident }, + }), + _ => None, + }), + ) + .parse_next(stream) +} + // Modifiers are turned into arguments to the blocks, with type: // * If it is an alternative: // * If it is mandatory then its type is Foo (as defined by the relevant rule) @@ -919,9 +934,9 @@ derive_parser!( 
#[regex(r#""[^"]*""#)] String, #[token("|")] - Or, + Pipe, #[token("!")] - Not, + Exclamation, #[token("(")] LParen, #[token(")")] @@ -1461,6 +1476,36 @@ derive_parser!( .f32, .f64, .f16, .f16x2, .bf16, .bf16x2 }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-not + not.type d, a => { + ast::Instruction::Not { + data: type_, + arguments: NotArgs { dst: d, src: a } + } + } + .type: ScalarType = { .pred, .b16, .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-or + or.type d, a, b => { + ast::Instruction::Or { + data: type_, + arguments: OrArgs { dst: d, src1: a, src2: b } + } + } + .type: ScalarType = { .pred, .b16, .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-and + and.type d, a, b => { + ast::Instruction::And { + data: type_, + arguments: AndArgs { dst: d, src1: a, src2: b } + } + } + .type: ScalarType = { .pred, .b16, .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-bra + bra <= { bra(stream) } + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret ret{.uni} => { Instruction::Ret { data: RetData { uniform: uni } } From 34b0a67f0aba939524830463c308ed941e76fd43 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Tue, 20 Aug 2024 02:58:58 +0200 Subject: [PATCH 13/47] Add types for call instruction --- gen_impl/src/lib.rs | 129 ++++++++++++++++++++++++++++++++++-------- ptx_parser/src/ast.rs | 109 ++++++++++++++++++++++++++++++++--- 2 files changed, 207 insertions(+), 31 deletions(-) diff --git a/gen_impl/src/lib.rs b/gen_impl/src/lib.rs index 39cc30e7..08911ec1 100644 --- a/gen_impl/src/lib.rs +++ b/gen_impl/src/lib.rs @@ -1,8 +1,8 @@ use proc_macro2::TokenStream; use quote::{format_ident, quote, ToTokens}; use syn::{ - braced, parse::Parse, punctuated::Punctuated, token, 
Expr, Ident, PathSegment, Token, Type, - TypeParam, Visibility, + braced, parse::Parse, punctuated::Punctuated, token, Expr, Ident, LitBool, PathSegment, Token, + Type, TypeParam, Visibility, }; pub mod parser; @@ -168,7 +168,10 @@ pub struct InstructionVariant { pub type_: Option>, pub space: Option, pub data: Option, - pub arguments: Option, + pub arguments: Option, + pub visit: Option, + pub visit_mut: Option, + pub map: Option, } impl InstructionVariant { @@ -194,17 +197,23 @@ impl InstructionVariant { } Some(args) => { let args_name = self.args_name(); - match &args.generic { - None => { + match &args { + Arguments::Def(InstructionArguments { generic: None, .. }) => { quote! { arguments: #args_name, } } - Some(generics) => { + Arguments::Def(InstructionArguments { + generic: Some(generics), + .. + }) => { quote! { arguments: #args_name <#generics>, } } + Arguments::Decl(type_) => quote! { + arguments: #type_, + }, } } }; @@ -214,15 +223,21 @@ impl InstructionVariant { } fn emit_visit(&self, enum_: &Ident, tokens: &mut TokenStream) { - self.emit_visit_impl(enum_, tokens, InstructionArguments::emit_visit) + self.emit_visit_impl(&self.visit, enum_, tokens, InstructionArguments::emit_visit) } fn emit_visit_mut(&self, enum_: &Ident, tokens: &mut TokenStream) { - self.emit_visit_impl(enum_, tokens, InstructionArguments::emit_visit_mut) + self.emit_visit_impl( + &self.visit_mut, + enum_, + tokens, + InstructionArguments::emit_visit_mut, + ) } fn emit_visit_impl( &self, + visit_fn: &Option, enum_: &Ident, tokens: &mut TokenStream, mut fn_: impl FnMut(&InstructionArguments, &Option>, &Option) -> TokenStream, @@ -236,7 +251,14 @@ impl InstructionVariant { .to_tokens(tokens); return; } - Some(args) => args, + Some(Arguments::Decl(_)) => { + quote! { + #enum_ :: #name { data, arguments } => { #visit_fn } + } + .to_tokens(tokens); + return; + } + Some(Arguments::Def(args)) => args, }; let data = &self.data.as_ref().map(|_| quote! 
{ data,}); let arg_calls = fn_(arguments, &self.type_, &self.space); @@ -250,10 +272,24 @@ impl InstructionVariant { fn emit_visit_map(&self, enum_: &Ident, tokens: &mut TokenStream) { let name = &self.name; - let arguments = &self.arguments.as_ref().map(|_| quote! { arguments,}); let data = &self.data.as_ref().map(|_| quote! { data,}); + let arguments = match self.arguments { + None => None, + Some(Arguments::Decl(_)) => { + let map = self.map.as_ref().unwrap(); + quote! { + #enum_ :: #name { #data arguments } => { + #map + } + } + .to_tokens(tokens); + return; + } + Some(Arguments::Def(ref def)) => Some(def), + }; + let arguments_ident = &self.arguments.as_ref().map(|_| quote! { arguments,}); let mut arg_calls = None; - let arguments_init = self.arguments.as_ref().map(|arguments| { + let arguments_init = arguments.as_ref().map(|arguments| { let arg_type = self.args_name(); arg_calls = Some(arguments.emit_visit_map(&self.type_, &self.space)); let arg_names = arguments.fields.iter().map(|arg| &arg.name); @@ -262,7 +298,7 @@ impl InstructionVariant { } }); quote! 
{ - #enum_ :: #name { #data #arguments } => { + #enum_ :: #name { #data #arguments_ident } => { #arg_calls #enum_ :: #name { #data #arguments_init } } @@ -272,7 +308,8 @@ impl InstructionVariant { fn emit_type(&self, vis: &Option, tokens: &mut TokenStream) { let arguments = match self.arguments { - Some(ref a) => a, + Some(Arguments::Def(ref a)) => a, + Some(Arguments::Decl(_)) => return, None => return, }; let name = self.args_name(); @@ -301,12 +338,18 @@ impl Parse for InstructionVariant { let mut space = None; let mut data = None; let mut arguments = None; + let mut visit = None; + let mut visit_mut = None; + let mut map = None; for property in properties { match property { VariantProperty::Type(t) => type_ = Some(t), VariantProperty::Space(s) => space = Some(s), VariantProperty::Data(d) => data = Some(d), VariantProperty::Arguments(a) => arguments = Some(a), + VariantProperty::Visit(e) => visit = Some(e), + VariantProperty::VisitMut(e) => visit_mut = Some(e), + VariantProperty::Map(e) => map = Some(e), } } Ok(Self { @@ -315,6 +358,9 @@ impl Parse for InstructionVariant { space, data, arguments, + visit, + visit_mut, + map, }) } } @@ -323,7 +369,10 @@ enum VariantProperty { Type(Option), Space(Expr), Data(Type), - Arguments(InstructionArguments), + Arguments(Arguments), + Visit(Expr), + VisitMut(Expr), + Map(Expr), } impl VariantProperty { @@ -360,15 +409,33 @@ impl VariantProperty { None }; input.parse::()?; - let fields; - braced!(fields in input); - VariantProperty::Arguments(InstructionArguments::parse(generics, &fields)?) + if input.peek(token::Brace) { + let fields; + braced!(fields in input); + VariantProperty::Arguments(Arguments::Def(InstructionArguments::parse( + generics, &fields, + )?)) + } else { + VariantProperty::Arguments(Arguments::Decl(input.parse::()?)) + } + } + "visit" => { + input.parse::()?; + VariantProperty::Visit(input.parse::()?) + } + "visit_mut" => { + input.parse::()?; + VariantProperty::VisitMut(input.parse::()?) 
+ } + "map" => { + input.parse::()?; + VariantProperty::Map(input.parse::()?) } x => { return Err(syn::Error::new( key.span(), format!( - "Unexpected key `{}`. Expected `type`, `data` or `arguments`.", + "Unexpected key `{}`. Expected `type`, `data`, `arguments`, `visit, `visit_mut` or `map`.", x ), )) @@ -380,6 +447,11 @@ impl VariantProperty { } } +pub enum Arguments { + Decl(Type), + Def(InstructionArguments), +} + pub struct InstructionArguments { pub generic: Option>, pub fields: Punctuated, @@ -453,7 +525,7 @@ pub struct ArgumentField { impl ArgumentField { fn parse_block( input: syn::parse::ParseStream, - ) -> syn::Result<(Type, Option, Option)> { + ) -> syn::Result<(Type, Option, Option, Option)> { let content; braced!(content in input); let all_fields = @@ -469,6 +541,10 @@ impl ArgumentField { match &*name_ident.to_string() { "repr" => ExprOrPath::Repr(content.parse::()?), "space" => ExprOrPath::Space(content.parse::()?), + "dst" => { + let ident = content.parse::()?; + ExprOrPath::Dst(ident.value) + } name => { return Err(syn::Error::new( name_ident.span(), @@ -483,14 +559,16 @@ impl ArgumentField { let mut repr = None; let mut type_ = None; let mut space = None; + let mut is_dst = None; for exp_or_path in all_fields { match exp_or_path { ExprOrPath::Repr(r) => repr = Some(r), ExprOrPath::Type(t) => type_ = Some(t), ExprOrPath::Space(s) => space = Some(s), + ExprOrPath::Dst(x) => is_dst = Some(x), } } - Ok((repr.unwrap(), type_, space)) + Ok((repr.unwrap(), type_, space, is_dst)) } fn parse_basic(input: &syn::parse::ParseBuffer) -> syn::Result { @@ -666,16 +744,20 @@ impl ArgumentField { impl Parse for ArgumentField { fn parse(input: syn::parse::ParseStream) -> syn::Result { let name = input.parse::()?; - let is_dst = Self::is_dst(&name)?; + input.parse::()?; let lookahead = input.lookahead1(); - let (repr, type_, space) = if lookahead.peek(token::Brace) { + let (repr, type_, space, is_dst) = if lookahead.peek(token::Brace) { Self::parse_block(input)? 
} else if lookahead.peek(syn::Ident) { - (Self::parse_basic(input)?, None, None) + (Self::parse_basic(input)?, None, None, None) } else { return Err(lookahead.error()); }; + let is_dst = match is_dst { + Some(x) => x, + None => Self::is_dst(&name)?, + }; Ok(Self { name, is_dst, @@ -690,6 +772,7 @@ enum ExprOrPath { Repr(Type), Type(Expr), Space(Expr), + Dst(bool), } #[cfg(test)] diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 6f1a9e37..ab0fc58c 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -140,6 +140,13 @@ gen::generate_instruction_type!( src: T } }, + Call { + data: CallDetails, + arguments: CallArgs, + visit: arguments.visit(data, visitor), + visit_mut: arguments.visit_mut(data, visitor), + map: Instruction::Call{ arguments: arguments.map(&data, visitor), data } + }, Ret { data: RetData }, @@ -154,42 +161,66 @@ pub trait Visitor { pub trait VisitorMut { fn visit(&mut self, args: &mut T, type_space: Option<(&Type, StateSpace)>, is_dst: bool); - fn visit_ident(&mut self, args: &mut T::Ident, type_space: Option<(&Type, StateSpace)>, is_dst: bool); + fn visit_ident( + &mut self, + args: &mut T::Ident, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + ); } pub trait VisitorMap { fn visit(&mut self, args: From, type_space: Option<(&Type, StateSpace)>, is_dst: bool) -> To; - fn visit_ident(&mut self, args: From::Ident, type_space: Option<(&Type, StateSpace)>, is_dst: bool) -> To::Ident; + fn visit_ident( + &mut self, + args: From::Ident, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + ) -> To::Ident; } trait VisitOperand { type Operand: Operand; #[allow(unused)] // Used by generated code - fn visit(&self, fn_: impl FnOnce(&Self::Operand)); + fn visit(&self, fn_: impl FnMut(&Self::Operand)); #[allow(unused)] // Used by generated code - fn visit_mut(&mut self, fn_: impl FnOnce(&mut Self::Operand)); + fn visit_mut(&mut self, fn_: impl FnMut(&mut Self::Operand)); } impl VisitOperand for T { type Operand = Self; - 
fn visit(&self, fn_: impl FnOnce(&Self::Operand)) { + fn visit(&self, mut fn_: impl FnMut(&Self::Operand)) { fn_(self) } - fn visit_mut(&mut self, fn_: impl FnOnce(&mut Self::Operand)) { + fn visit_mut(&mut self, mut fn_: impl FnMut(&mut Self::Operand)) { fn_(self) } } impl VisitOperand for Option { type Operand = T; - fn visit(&self, fn_: impl FnOnce(&Self::Operand)) { + fn visit(&self, fn_: impl FnMut(&Self::Operand)) { self.as_ref().map(fn_); } - fn visit_mut(&mut self, fn_: impl FnOnce(&mut Self::Operand)) { + fn visit_mut(&mut self, fn_: impl FnMut(&mut Self::Operand)) { self.as_mut().map(fn_); } } +impl VisitOperand for Vec { + type Operand = T; + fn visit(&self, mut fn_: impl FnMut(&Self::Operand)) { + for o in self { + fn_(o) + } + } + fn visit_mut(&mut self, mut fn_: impl FnMut(&mut Self::Operand)) { + for o in self { + fn_(o) + } + } +} + trait MapOperand: Sized { type Input; type Output; @@ -649,3 +680,65 @@ impl From for SetpCompareFloat { } } } + +pub struct CallDetails { + uniform: bool, + ret_params: Vec<(Type, StateSpace)>, + param_list: Vec<(Type, StateSpace)>, +} + +pub struct CallArgs { + pub ret_params: Vec, + pub func: T::Ident, + pub param_list: Vec, +} + +impl CallArgs { + #[allow(dead_code)] // Used by generated code + fn visit(&self, details: &CallDetails, visitor: &mut impl Visitor) { + for (param, (type_, space)) in self.ret_params.iter().zip(details.ret_params.iter()) { + visitor.visit_ident(param, Some((type_, *space)), true); + } + visitor.visit_ident(&self.func, None, false); + for (param, (type_, space)) in self.param_list.iter().zip(details.param_list.iter()) { + visitor.visit(param, Some((type_, *space)), true); + } + } + + #[allow(dead_code)] // Used by generated code + fn visit_mut(&mut self, details: &CallDetails, visitor: &mut impl VisitorMut) { + for (param, (type_, space)) in self.ret_params.iter_mut().zip(details.ret_params.iter()) { + visitor.visit_ident(param, Some((type_, *space)), true); + } + visitor.visit_ident(&mut 
self.func, None, false); + for (param, (type_, space)) in self.param_list.iter_mut().zip(details.param_list.iter()) { + visitor.visit(param, Some((type_, *space)), true); + } + } + + #[allow(dead_code)] // Used by generated code + fn map( + self, + details: &CallDetails, + visitor: &mut impl VisitorMap, + ) -> CallArgs { + let ret_params = self + .ret_params + .into_iter() + .zip(details.ret_params.iter()) + .map(|(param, (type_, space))| visitor.visit_ident(param, Some((type_, *space)), true)) + .collect::>(); + let func = visitor.visit_ident(self.func, None, false); + let param_list = self + .param_list + .into_iter() + .zip(details.param_list.iter()) + .map(|(param, (type_, space))| visitor.visit(param, Some((type_, *space)), true)) + .collect::>(); + CallArgs { + ret_params, + func, + param_list, + } + } +} From c21c55dfc22d90d84f89085ce29412965daff22f Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Tue, 20 Aug 2024 03:53:18 +0200 Subject: [PATCH 14/47] Parse call instruction --- ptx_parser/Cargo.toml | 3 +- ptx_parser/src/ast.rs | 56 +++++++++------ ptx_parser/src/main.rs | 157 ++++++++++++++++++++++++++++++++++++----- 3 files changed, 177 insertions(+), 39 deletions(-) diff --git a/ptx_parser/Cargo.toml b/ptx_parser/Cargo.toml index 4f32860c..35251ee8 100644 --- a/ptx_parser/Cargo.toml +++ b/ptx_parser/Cargo.toml @@ -5,7 +5,8 @@ edition = "2021" [dependencies] logos = "0.14" -winnow = { version = "0.6.18", features = ["debug"] } +winnow = { version = "0.6.18" } gen = { path = "../gen" } thiserror = "1.0" bitflags = "1.2" +rustc-hash = "2.0.0" diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index ab0fc58c..0dabd5d3 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -557,7 +557,7 @@ pub struct SetpData { impl SetpData { pub(crate) fn try_parse( - errors: &mut PtxParserState, + state: &mut PtxParserState, cmp_op: super::RawSetpCompareOp, ftz: bool, type_: ScalarType, @@ -565,7 +565,7 @@ impl SetpData { let flush_to_zero = match (ftz, 
type_) { (_, ScalarType::F32) => Some(ftz), _ => { - errors.push(PtxError::NonF32Ftz); + state.errors.push(PtxError::NonF32Ftz); None } }; @@ -576,7 +576,7 @@ impl SetpData { match SetpCompareInt::try_from(cmp_op) { Ok(op) => SetpCompareOp::Integer(op), Err(err) => { - errors.push(err); + state.errors.push(err); SetpCompareOp::Integer(SetpCompareInt::Eq) } } @@ -682,36 +682,52 @@ impl From for SetpCompareFloat { } pub struct CallDetails { - uniform: bool, - ret_params: Vec<(Type, StateSpace)>, - param_list: Vec<(Type, StateSpace)>, + pub uniform: bool, + pub return_arguments: Vec<(Type, StateSpace)>, + pub input_arguments: Vec<(Type, StateSpace)>, } pub struct CallArgs { - pub ret_params: Vec, + pub return_arguments: Vec, pub func: T::Ident, - pub param_list: Vec, + pub input_arguments: Vec, } impl CallArgs { #[allow(dead_code)] // Used by generated code fn visit(&self, details: &CallDetails, visitor: &mut impl Visitor) { - for (param, (type_, space)) in self.ret_params.iter().zip(details.ret_params.iter()) { + for (param, (type_, space)) in self + .return_arguments + .iter() + .zip(details.return_arguments.iter()) + { visitor.visit_ident(param, Some((type_, *space)), true); } visitor.visit_ident(&self.func, None, false); - for (param, (type_, space)) in self.param_list.iter().zip(details.param_list.iter()) { + for (param, (type_, space)) in self + .input_arguments + .iter() + .zip(details.input_arguments.iter()) + { visitor.visit(param, Some((type_, *space)), true); } } #[allow(dead_code)] // Used by generated code fn visit_mut(&mut self, details: &CallDetails, visitor: &mut impl VisitorMut) { - for (param, (type_, space)) in self.ret_params.iter_mut().zip(details.ret_params.iter()) { + for (param, (type_, space)) in self + .return_arguments + .iter_mut() + .zip(details.return_arguments.iter()) + { visitor.visit_ident(param, Some((type_, *space)), true); } visitor.visit_ident(&mut self.func, None, false); - for (param, (type_, space)) in 
self.param_list.iter_mut().zip(details.param_list.iter()) { + for (param, (type_, space)) in self + .input_arguments + .iter_mut() + .zip(details.input_arguments.iter()) + { visitor.visit(param, Some((type_, *space)), true); } } @@ -722,23 +738,23 @@ impl CallArgs { details: &CallDetails, visitor: &mut impl VisitorMap, ) -> CallArgs { - let ret_params = self - .ret_params + let return_arguments = self + .return_arguments .into_iter() - .zip(details.ret_params.iter()) + .zip(details.return_arguments.iter()) .map(|(param, (type_, space))| visitor.visit_ident(param, Some((type_, *space)), true)) .collect::>(); let func = visitor.visit_ident(self.func, None, false); - let param_list = self - .param_list + let input_arguments = self + .input_arguments .into_iter() - .zip(details.param_list.iter()) + .zip(details.input_arguments.iter()) .map(|(param, (type_, space))| visitor.visit(param, Some((type_, *space)), true)) .collect::>(); CallArgs { - ret_params, + return_arguments, func, - param_list, + input_arguments, } } } diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs index a6a23818..2c602d54 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/main.rs @@ -1,5 +1,7 @@ use gen::derive_parser; use logos::Logos; +use rustc_hash::FxHashMap; +use std::fmt::Debug; use std::mem; use std::num::{ParseFloatError, ParseIntError}; use winnow::ascii::dec_uint; @@ -69,8 +71,49 @@ impl From for ast::RoundingMode { } } -type PtxParserState = Vec; -type PtxParser<'a, 'input> = Stateful<&'a [Token<'input>], PtxParserState>; +struct PtxParserState<'input> { + errors: Vec, + function_declarations: + FxHashMap<&'input str, (Vec<(ast::Type, StateSpace)>, Vec<(ast::Type, StateSpace)>)>, +} + +impl<'input> PtxParserState<'input> { + fn new() -> Self { + Self { + errors: Vec::new(), + function_declarations: FxHashMap::default(), + } + } + + fn record_function(&mut self, function_decl: &MethodDeclaration<'input, &'input str>) { + let name = match function_decl.name { + 
MethodName::Kernel(name) => name, + MethodName::Func(name) => name, + }; + let return_arguments = Self::get_type_space(&*function_decl.return_arguments); + let input_arguments = Self::get_type_space(&*function_decl.input_arguments); + // TODO: check if declarations match + self.function_declarations + .insert(name, (return_arguments, input_arguments)); + } + + fn get_type_space(input_arguments: &[Variable<&str>]) -> Vec<(Type, StateSpace)> { + input_arguments + .iter() + .map(|var| (var.v_type.clone(), var.state_space)) + .collect::>() + } +} + +impl<'input> Debug for PtxParserState<'input> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PtxParserState") + .field("errors", &self.errors) /* .field("function_decl", &self.function_decl) */ + .finish() + } +} + +type PtxParser<'a, 'input> = Stateful<&'a [Token<'input>], PtxParserState<'input>>; fn ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input str> { any.verify_map(|t| { @@ -127,7 +170,7 @@ fn take_error<'a, 'input: 'a, O, E>( Ok(match parser.parse_next(input)? 
{ Ok(x) => x, Err((x, err)) => { - input.state.push(err); + input.state.errors.push(err); x } }) @@ -353,7 +396,7 @@ fn function<'a, 'input>( ast::LinkingDirective, ast::Function<'input, &'input str, ast::Statement>>, )> { - ( + let (linking, function) = ( linking_directives, method_declaration, repeat(0.., tuning_directive), @@ -369,7 +412,9 @@ fn function<'a, 'input>( }, ) }) - .parse_next(stream) + .parse_next(stream)?; + stream.state.record_function(&function.func_directive); + Ok((linking, function)) } fn linking_directives<'a, 'input>( @@ -771,6 +816,10 @@ pub enum PtxError { #[error("")] WrongType, #[error("")] + UnknownFunction, + #[error("")] + MalformedCall, + #[error("")] WrongArrayType, #[error("")] WrongVectorElement, @@ -903,6 +952,74 @@ fn bra<'a, 'input>( .parse_next(stream) } +fn call<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult>> { + let (uni, return_arguments, name, input_arguments) = ( + opt(Token::DotUni), + opt(( + Token::LParen, + separated(1.., ident, Token::Comma).map(|x: Vec<_>| x), + Token::RParen, + Token::Comma, + ) + .map(|(_, arguments, _, _)| arguments)), + ident, + opt(( + Token::Comma.void(), + Token::LParen.void(), + separated(1.., ParsedOperand::<&'input str>::parse, Token::Comma).map(|x: Vec<_>| x), + Token::RParen.void(), + ) + .map(|(_, _, arguments, _)| arguments)), + ) + .parse_next(stream)?; + let uniform = uni.is_some(); + let recorded_fn = match stream.state.function_declarations.get(name) { + Some(decl) => decl, + None => { + stream.state.errors.push(PtxError::UnknownFunction); + return Ok(empty_call(uniform, name)); + } + }; + let return_arguments = return_arguments.unwrap_or(Vec::new()); + let input_arguments = input_arguments.unwrap_or(Vec::new()); + if recorded_fn.0.len() != return_arguments.len() || recorded_fn.1.len() != input_arguments.len() + { + stream.state.errors.push(PtxError::MalformedCall); + return Ok(empty_call(uniform, name)); + } + let data = CallDetails { + uniform, + 
return_arguments: recorded_fn.0.clone(), + input_arguments: recorded_fn.1.clone(), + }; + let arguments = CallArgs { + return_arguments, + func: name, + input_arguments, + }; + Ok(ast::Instruction::Call { data, arguments }) +} + +fn empty_call<'input>( + uniform: bool, + name: &'input str, +) -> ast::Instruction> { + ast::Instruction::Call { + data: CallDetails { + uniform, + return_arguments: Vec::new(), + input_arguments: Vec::new(), + }, + arguments: CallArgs { + return_arguments: Vec::new(), + func: name, + input_arguments: Vec::new(), + }, + } +} + // Modifiers are turned into arguments to the blocks, with type: // * If it is an alternative: // * If it is mandatory then its type is Foo (as defined by the relevant rule) @@ -1033,7 +1150,7 @@ derive_parser!( // https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st st{.weak}{.ss}{.cop}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => { if level_eviction_priority.is_some() || level_cache_hint || cache_policy.is_some() { - state.push(PtxError::Todo); + state.errors.push(PtxError::Todo); } Instruction::St { data: StData { @@ -1058,7 +1175,7 @@ derive_parser!( } st.relaxed.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => { if level_eviction_priority.is_some() || level_cache_hint || cache_policy.is_some() { - state.push(PtxError::Todo); + state.errors.push(PtxError::Todo); } Instruction::St { data: StData { @@ -1072,7 +1189,7 @@ derive_parser!( } st.release.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => { if level_eviction_priority.is_some() || level_cache_hint || cache_policy.is_some() { - state.push(PtxError::Todo); + state.errors.push(PtxError::Todo); } Instruction::St { data: StData { @@ -1085,7 +1202,7 @@ derive_parser!( } } st.mmio.relaxed.sys{.global}.type [a], b => { - state.push(PtxError::Todo); + state.errors.push(PtxError::Todo); 
Instruction::St { data: ast::StData { qualifier: ast::LdStQualifier::Relaxed(MemScope::Sys), @@ -1114,7 +1231,7 @@ derive_parser!( ld{.weak}{.ss}{.cop}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{.unified}{, cache_policy} => { let (a, unified) = a; if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || unified || cache_policy.is_some() { - state.push(PtxError::Todo); + state.errors.push(PtxError::Todo); } Instruction::Ld { data: LdDetails { @@ -1129,7 +1246,7 @@ derive_parser!( } ld.volatile{.ss}{.level::prefetch_size}{.vec}.type d, [a] => { if level_prefetch_size.is_some() { - state.push(PtxError::Todo); + state.errors.push(PtxError::Todo); } Instruction::Ld { data: LdDetails { @@ -1144,7 +1261,7 @@ derive_parser!( } ld.relaxed.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{, cache_policy} => { if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || cache_policy.is_some() { - state.push(PtxError::Todo); + state.errors.push(PtxError::Todo); } Instruction::Ld { data: LdDetails { @@ -1159,7 +1276,7 @@ derive_parser!( } ld.acquire.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{, cache_policy} => { if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || cache_policy.is_some() { - state.push(PtxError::Todo); + state.errors.push(PtxError::Todo); } Instruction::Ld { data: LdDetails { @@ -1173,7 +1290,7 @@ derive_parser!( } } ld.mmio.relaxed.sys{.global}.type d, [a] => { - state.push(PtxError::Todo); + state.errors.push(PtxError::Todo); Instruction::Ld { data: LdDetails { qualifier: ast::LdStQualifier::Relaxed(MemScope::Sys), @@ -1506,6 +1623,9 @@ derive_parser!( // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-bra bra <= { bra(stream) } + // 
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-call + call <= { call(stream) } + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret ret{.uni} => { Instruction::Ret { data: RetData { uniform: uni } } @@ -1558,7 +1678,7 @@ fn main() { println!("{:?}", &tokens); let stream = PtxParser { input: &tokens[..], - state: Vec::new(), + state: PtxParserState::new(), }; let _module = module.parse(stream).unwrap(); println!("{}", mem::size_of::()); @@ -1567,6 +1687,7 @@ fn main() { #[cfg(test)] mod tests { use super::target; + use super::PtxParserState; use super::Token; use logos::Logos; use winnow::prelude::*; @@ -1578,7 +1699,7 @@ mod tests { .unwrap(); let stream = super::PtxParser { input: &tokens[..], - state: Vec::new(), + state: PtxParserState::new(), }; assert_eq!(target.parse(stream).unwrap(), (11, None)); } @@ -1590,7 +1711,7 @@ mod tests { .unwrap(); let stream = super::PtxParser { input: &tokens[..], - state: Vec::new(), + state: PtxParserState::new(), }; assert_eq!(target.parse(stream).unwrap(), (90, Some('a'))); } @@ -1602,7 +1723,7 @@ mod tests { .unwrap(); let stream = super::PtxParser { input: &tokens[..], - state: Vec::new(), + state: PtxParserState::new(), }; assert!(target.parse(stream).is_err()); } From bc1074ed6723da9dd10e5847dca17069bc136847 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Tue, 20 Aug 2024 17:59:39 +0200 Subject: [PATCH 15/47] Add cvt --- gen_impl/src/lib.rs | 6 +- ptx_parser/src/ast.rs | 187 ++++++++++++++++++++++++++++++++++++++++- ptx_parser/src/main.rs | 47 ++++++++--- 3 files changed, 228 insertions(+), 12 deletions(-) diff --git a/gen_impl/src/lib.rs b/gen_impl/src/lib.rs index 08911ec1..4532964b 100644 --- a/gen_impl/src/lib.rs +++ b/gen_impl/src/lib.rs @@ -847,7 +847,11 @@ mod tests { assert_eq!("ScalarType :: U32", to_string(variant.type_.unwrap())); assert_eq!("StateSpace :: Global", to_string(variant.space.unwrap())); 
assert_eq!("LdDetails", to_string(variant.data.unwrap())); - let arguments = variant.arguments.unwrap(); + let arguments = if let Some(Arguments::Def(a)) = variant.arguments { + a + } else { + panic!() + }; assert_eq!("P", to_string(arguments.generic)); let mut fields = arguments.fields.into_iter(); let dst = fields.next().unwrap(); diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 0dabd5d3..daee9daa 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -1,4 +1,9 @@ -use super::{MemScope, RawSetpCompareOp, ScalarType, SetpBoolPostOp, StateSpace, VectorPrefix}; +use std::cmp::Ordering; + +use super::{ + MemScope, RawRoundingMode, RawSetpCompareOp, ScalarType, SetpBoolPostOp, StateSpace, + VectorPrefix, +}; use crate::{PtxError, PtxParserState}; use bitflags::bitflags; @@ -147,6 +152,19 @@ gen::generate_instruction_type!( visit_mut: arguments.visit_mut(data, visitor), map: Instruction::Call{ arguments: arguments.map(&data, visitor), data } }, + Cvt { + data: CvtDetails, + arguments: { + dst: { + repr: T, + type: { Type::Scalar(data.to) }, + }, + src: { + repr: T, + type: { Type::Scalar(data.from) }, + }, + } + }, Ret { data: RetData }, @@ -284,6 +302,28 @@ impl Type { } impl ScalarType { + pub fn size_of(self) -> u8 { + match self { + ScalarType::U8 | ScalarType::S8 | ScalarType::B8 => 1, + ScalarType::U16 + | ScalarType::S16 + | ScalarType::B16 + | ScalarType::F16 + | ScalarType::BF16 => 2, + ScalarType::U32 + | ScalarType::S32 + | ScalarType::B32 + | ScalarType::F32 + | ScalarType::U16x2 + | ScalarType::S16x2 + | ScalarType::F16x2 + | ScalarType::BF16x2 => 4, + ScalarType::U64 | ScalarType::S64 | ScalarType::B64 | ScalarType::F64 => 8, + ScalarType::B128 => 16, + ScalarType::Pred => 1, + } + } + pub fn kind(self) -> ScalarKind { match self { ScalarType::U8 => ScalarKind::Unsigned, @@ -758,3 +798,148 @@ impl CallArgs { } } } + +pub struct CvtDetails { + from: ScalarType, + to: ScalarType, + mode: CvtMode, +} + +pub enum CvtMode { + // int 
from int + ZeroExtend, + SignExtend, + Truncate, + Bitcast, + // float from float + FPExtend { + flush_to_zero: Option, + }, + FPTruncate { + // float rounding + rounding: RoundingMode, + flush_to_zero: Option, + }, + FPRound { + integer_rounding: Option, + flush_to_zero: Option, + }, + // int from float + SignedFromFP { + rounding: RoundingMode, + flush_to_zero: Option, + }, // integer rounding + UnsignedFromFP { + rounding: RoundingMode, + flush_to_zero: Option, + }, // integer rounding + // float from int, ftz is allowed in the grammar, but clearly nonsensical + FPFromSigned(RoundingMode), // float rounding + FPFromUnsigned(RoundingMode), // float rounding +} + +impl CvtDetails { + pub(crate) fn new( + errors: &mut Vec, + rnd: Option, + ftz: bool, + saturate: bool, + dst: ScalarType, + src: ScalarType, + ) -> Self { + if saturate { + errors.push(PtxError::Todo); + } + // Modifier .ftz can only be specified when either .dtype or .atype is .f32 and applies only to single precision (.f32) inputs and results. 
+ let flush_to_zero = match (dst, src) { + (ScalarType::F32, _) | (_, ScalarType::F32) => Some(ftz), + _ => { + if ftz { + errors.push(PtxError::NonF32Ftz); + } + None + } + }; + let rounding = rnd.map(Into::into); + let mut unwrap_rounding = || match rounding { + Some(rnd) => rnd, + None => { + errors.push(PtxError::SyntaxError); + RoundingMode::NearestEven + } + }; + let mode = match (dst.kind(), src.kind()) { + (ScalarKind::Float, ScalarKind::Float) => match dst.size_of().cmp(&src.size_of()) { + Ordering::Less => CvtMode::FPTruncate { + rounding: unwrap_rounding(), + flush_to_zero, + }, + Ordering::Equal => CvtMode::FPRound { + integer_rounding: rounding, + flush_to_zero, + }, + Ordering::Greater => { + if rounding.is_some() { + errors.push(PtxError::SyntaxError); + } + CvtMode::FPExtend { flush_to_zero } + } + }, + (ScalarKind::Unsigned, ScalarKind::Float) => CvtMode::UnsignedFromFP { + rounding: unwrap_rounding(), + flush_to_zero, + }, + (ScalarKind::Signed, ScalarKind::Float) => CvtMode::SignedFromFP { + rounding: unwrap_rounding(), + flush_to_zero, + }, + (ScalarKind::Float, ScalarKind::Unsigned) => CvtMode::FPFromUnsigned(unwrap_rounding()), + (ScalarKind::Float, ScalarKind::Signed) => CvtMode::FPFromSigned(unwrap_rounding()), + ( + ScalarKind::Unsigned | ScalarKind::Signed, + ScalarKind::Unsigned | ScalarKind::Signed, + ) => match dst.size_of().cmp(&src.size_of()) { + Ordering::Less => { + if dst.kind() != src.kind() { + errors.push(PtxError::Todo); + } + CvtMode::Truncate + } + Ordering::Equal => CvtMode::Bitcast, + Ordering::Greater => { + if dst.kind() != src.kind() { + errors.push(PtxError::Todo); + } + if src.kind() == ScalarKind::Signed { + CvtMode::SignExtend + } else { + CvtMode::ZeroExtend + } + } + }, + (_, _) => { + errors.push(PtxError::SyntaxError); + CvtMode::Bitcast + } + }; + CvtDetails { + mode, + to: dst, + from: src, + } + } +} + +pub struct CvtIntToIntDesc { + pub dst: ScalarType, + pub src: ScalarType, + pub saturate: bool, +} + +pub 
struct CvtDesc { + pub rounding: Option, + pub flush_to_zero: Option, + pub saturate: bool, + pub dst: ScalarType, + pub src: ScalarType, +} diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs index 2c602d54..68787dbc 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/main.rs @@ -60,13 +60,13 @@ impl From for ast::LdStQualifier { } } -impl From for ast::RoundingMode { - fn from(value: RawFloatRounding) -> Self { +impl From for ast::RoundingMode { + fn from(value: RawRoundingMode) -> Self { match value { - RawFloatRounding::Rn => ast::RoundingMode::NearestEven, - RawFloatRounding::Rz => ast::RoundingMode::Zero, - RawFloatRounding::Rm => ast::RoundingMode::NegativeInf, - RawFloatRounding::Rp => ast::RoundingMode::PositiveInf, + RawRoundingMode::Rn | RawRoundingMode::Rni => ast::RoundingMode::NearestEven, + RawRoundingMode::Rz | RawRoundingMode::Rzi => ast::RoundingMode::Zero, + RawRoundingMode::Rm | RawRoundingMode::Rmi => ast::RoundingMode::NegativeInf, + RawRoundingMode::Rp | RawRoundingMode::Rpi => ast::RoundingMode::PositiveInf, } } } @@ -1380,7 +1380,7 @@ derive_parser!( } } } - .rnd: RawFloatRounding = { .rn, .rz, .rm, .rp }; + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; ScalarType = { .f32, .f64 }; // https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-add @@ -1444,7 +1444,7 @@ derive_parser!( } } } - .rnd: RawFloatRounding = { .rn }; + .rnd: RawRoundingMode = { .rn }; ScalarType = { .f16, .f16x2, .bf16, .bf16x2 }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul @@ -1502,7 +1502,7 @@ derive_parser!( arguments: MulArgs { dst: d, src1: a, src2: b } } } - .rnd: RawFloatRounding = { .rn, .rz, .rm, .rp }; + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; ScalarType = { .f32, .f64 }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul @@ -1558,7 +1558,7 @@ derive_parser!( arguments: 
MulArgs { dst: d, src1: a, src2: b } } } - .rnd: RawFloatRounding = { .rn }; + .rnd: RawRoundingMode = { .rn }; ScalarType = { .f16, .f16x2, .bf16, .bf16x2 }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-setp @@ -1626,6 +1626,33 @@ derive_parser!( // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-call call <= { call(stream) } + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt + cvt{.ifrnd}{.ftz}{.sat}.dtype.atype d, a => { + let data = ast::CvtDetails::new(&mut state.errors, ifrnd, ftz, sat, dtype, atype); + let arguments = ast::CvtArgs { dst: d, src: a }; + ast::Instruction::Cvt { + data, arguments + } + } + // cvt.frnd2{.relu}{.satfinite}.f16.f32 d, a; + // cvt.frnd2{.relu}{.satfinite}.f16x2.f32 d, a, b; + // cvt.frnd2{.relu}{.satfinite}.bf16.f32 d, a; + // cvt.frnd2{.relu}{.satfinite}.bf16x2.f32 d, a, b; + // cvt.rna{.satfinite}.tf32.f32 d, a; + // cvt.frnd2{.relu}.tf32.f32 d, a; + // cvt.rn.satfinite{.relu}.f8x2type.f32 d, a, b; + // cvt.rn.satfinite{.relu}.f8x2type.f16x2 d, a; + // cvt.rn.{.relu}.f16x2.f8x2type d, a; + + .ifrnd: RawRoundingMode = { .rn, .rz, .rm, .rp, .rni, .rzi, .rmi, .rpi }; + .frnd2: RawRoundingMode = { .rn, .rz }; + .dtype: ScalarType = { .u8, .u16, .u32, .u64, + .s8, .s16, .s32, .s64, + .bf16, .f16, .f32, .f64 }; + .atype: ScalarType = { .u8, .u16, .u32, .u64, + .s8, .s16, .s32, .s64, + .bf16, .f16, .f32, .f64 }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret ret{.uni} => { Instruction::Ret { data: RetData { uniform: uni } } From 47f8314a5d436b9fbe5f4344cb2dac91b8427f1b Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Tue, 20 Aug 2024 19:33:45 +0200 Subject: [PATCH 16/47] Add shr, shl --- ptx_parser/src/ast.rs | 34 ++++++++++++++++++++++++++++++++++ ptx_parser/src/main.rs | 18 ++++++++++++++++++ 2 files changed, 
52 insertions(+) diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index daee9daa..7755c7fc 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -165,6 +165,30 @@ gen::generate_instruction_type!( }, } }, + Shr { + data: ShrData, + type: { Type::Scalar(data.type_.clone()) }, + arguments: { + dst: T, + src1: T, + src2: { + repr: T, + type: { Type::Scalar(ScalarType::U32) }, + }, + } + }, + Shl { + data: ScalarType, + type: { Type::Scalar(data.clone()) }, + arguments: { + dst: T, + src1: T, + src2: { + repr: T, + type: { Type::Scalar(ScalarType::U32) }, + }, + } + }, Ret { data: RetData }, @@ -943,3 +967,13 @@ pub struct CvtDesc { pub dst: ScalarType, pub src: ScalarType, } + +pub struct ShrData { + pub type_: ScalarType, + pub kind: RightShiftKind, +} + +pub enum RightShiftKind { + Arithmetic, + Logical, +} diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs index 68787dbc..6055c1d1 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/main.rs @@ -1652,6 +1652,24 @@ derive_parser!( .atype: ScalarType = { .u8, .u16, .u32, .u64, .s8, .s16, .s32, .s64, .bf16, .f16, .f32, .f64 }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shl + shl.type d, a, b => { + ast::Instruction::Shl { data: type_, arguments: ShlArgs { dst: d, src1: a, src2: b } } + } + .type: ScalarType = { .b16, .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shr + shr.type d, a, b => { + let kind = if type_.kind() == ast::ScalarKind::Signed { RightShiftKind::Arithmetic} else { RightShiftKind::Logical }; + ast::Instruction::Shr { + data: ast::ShrData { type_, kind }, + arguments: ShrArgs { dst: d, src1: a, src2: b } + } + } + + .type: ScalarType = { .b16, .b32, .b64, + .u16, .u32, .u64, + .s16, .s32, .s64 }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret ret{.uni} => { From 
588d66b236097c11ca0d11ba1cd6f2e1c1ae3448 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Tue, 20 Aug 2024 19:50:09 +0200 Subject: [PATCH 17/47] Add cvta --- ptx/src/ast.rs | 2 ++ ptx_parser/src/ast.rs | 24 +++++++++++++++++++++--- ptx_parser/src/main.rs | 39 +++++++++++++++++++++++++++++++++++++-- 3 files changed, 60 insertions(+), 5 deletions(-) diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index d308479b..f1323be4 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -16,6 +16,8 @@ pub enum PtxError { source: ParseFloatError, }, #[error("")] + Unsupported32Bit, + #[error("")] SyntaxError, #[error("")] NonF32Ftz, diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 7755c7fc..98583a85 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -192,6 +192,14 @@ gen::generate_instruction_type!( Ret { data: RetData }, + Cvta { + data: CvtaDetails, + type: { Type::Scalar(ScalarType::B64) }, + arguments: { + dst: T, + src: T, + } + }, Trap { } } ); @@ -824,9 +832,9 @@ impl CallArgs { } pub struct CvtDetails { - from: ScalarType, - to: ScalarType, - mode: CvtMode, + pub from: ScalarType, + pub to: ScalarType, + pub mode: CvtMode, } pub enum CvtMode { @@ -977,3 +985,13 @@ pub enum RightShiftKind { Arithmetic, Logical, } + +pub struct CvtaDetails { + pub state_space: StateSpace, + pub direction: CvtaDirection, +} + +pub enum CvtaDirection { + GenericToExplicit, + ExplicitToGeneric, +} diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs index 6055c1d1..03360f36 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/main.rs @@ -814,6 +814,8 @@ pub enum PtxError { #[error("")] NonF32Ftz, #[error("")] + Unsupported32Bit, + #[error("")] WrongType, #[error("")] UnknownFunction, @@ -1653,7 +1655,7 @@ derive_parser!( .s8, .s16, .s32, .s64, .bf16, .f16, .f32, .f64 }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shl - shl.type d, a, b => { + shl.type d, a, b => { ast::Instruction::Shl { data: type_, 
arguments: ShlArgs { dst: d, src1: a, src2: b } } } .type: ScalarType = { .b16, .b32, .b64 }; @@ -1666,11 +1668,44 @@ derive_parser!( arguments: ShrArgs { dst: d, src1: a, src2: b } } } - .type: ScalarType = { .b16, .b32, .b64, .u16, .u32, .u64, .s16, .s32, .s64 }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvta + cvta.space.size p, a => { + if size != ScalarType::U64 { + state.errors.push(PtxError::Unsupported32Bit); + } + let data = ast::CvtaDetails { + state_space: space, + direction: ast::CvtaDirection::ExplicitToGeneric + }; + let arguments = ast::CvtaArgs { + dst: p, src: a + }; + ast::Instruction::Cvta { + data, arguments + } + } + cvta.to.space.size p, a => { + if size != ScalarType::U64 { + state.errors.push(PtxError::Unsupported32Bit); + } + let data = ast::CvtaDetails { + state_space: space, + direction: ast::CvtaDirection::GenericToExplicit + }; + let arguments = ast::CvtaArgs { + dst: p, src: a + }; + ast::Instruction::Cvta { + data, arguments + } + } + .space: StateSpace = { .const, .global, .local, .shared{::cta, ::cluster}, .param{::entry} }; + .size: ScalarType = { .u32, .u64 }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret ret{.uni} => { Instruction::Ret { data: RetData { uniform: uni } } From 6cd18bfdb8926e374a7a060e4acc20bfadfacfd0 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Wed, 21 Aug 2024 02:45:52 +0200 Subject: [PATCH 18/47] Add abs, mad --- gen_impl/src/parser.rs | 4 +- ptx_parser/src/ast.rs | 69 +++++++++++++++- ptx_parser/src/main.rs | 173 ++++++++++++++++++++++++++++++++++++++++- 3 files changed, 237 insertions(+), 9 deletions(-) diff --git a/gen_impl/src/parser.rs b/gen_impl/src/parser.rs index ea5070d0..f1cd7383 100644 --- a/gen_impl/src/parser.rs +++ b/gen_impl/src/parser.rs @@ -73,7 +73,7 @@ pub struct OpcodeDecl(pub Instruction, pub Arguments); impl OpcodeDecl { fn peek(input: syn::parse::ParseStream) 
-> bool { - Instruction::peek(input) + Instruction::peek(input) && !input.peek2(Token![=]) } } @@ -106,7 +106,7 @@ impl Parse for CodeBlock { } else { return Err(lookahead.error()); }; - Ok(Self{special, code}) + Ok(Self { special, code }) } } diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 98583a85..248a6f32 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -200,6 +200,27 @@ gen::generate_instruction_type!( src: T, } }, + Abs { + data: AbsDetails, + type: { Type::Scalar(data.type_) }, + arguments: { + dst: T, + src: T, + } + }, + Mad { + type: { Type::from(data.type_()) }, + data: MadDetails, + arguments: { + dst: { + repr: T, + type: { Type::from(data.dst_type()) }, + }, + src1: T, + src2: T, + src3: T, + } + }, Trap { } } ); @@ -588,16 +609,14 @@ pub enum MulDetails { } impl MulDetails { - #[allow(unused)] // Used by generated code - fn type_(&self) -> ScalarType { + pub fn type_(&self) -> ScalarType { match self { MulDetails::Integer { type_, .. } => *type_, MulDetails::Float(arith) => arith.type_, } } - #[allow(unused)] // Used by generated code - fn dst_type(&self) -> ScalarType { + pub fn dst_type(&self) -> ScalarType { match self { MulDetails::Integer { type_, @@ -995,3 +1014,45 @@ pub enum CvtaDirection { GenericToExplicit, ExplicitToGeneric, } + +#[derive(Copy, Clone)] +pub struct AbsDetails { + pub flush_to_zero: Option, + pub type_: ScalarType, +} + +#[derive(Copy, Clone)] +pub enum MadDetails { + Integer { + control: MulIntControl, + saturate: bool, + type_: ScalarType, + }, + Float(ArithFloat), +} + +impl MadDetails { + pub fn dst_type(&self) -> ScalarType { + match self { + MadDetails::Integer { + type_, + control: MulIntControl::Wide, + .. 
+ } => match type_ { + ScalarType::U16 => ScalarType::U32, + ScalarType::S16 => ScalarType::S32, + ScalarType::U32 => ScalarType::U64, + ScalarType::S32 => ScalarType::S64, + _ => unreachable!(), + }, + _ => self.type_(), + } + } + + fn type_(&self) -> ScalarType { + match self { + MadDetails::Integer { type_, .. } => *type_, + MadDetails::Float(arith) => arith.type_, + } + } +} diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs index 03360f36..ce1f56dd 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/main.rs @@ -1450,6 +1450,8 @@ derive_parser!( ScalarType = { .f16, .f16x2, .bf16, .bf16x2 }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mul + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul mul.mode.type d, a, b => { ast::Instruction::Mul { data: ast::MulDetails::Integer { @@ -1476,8 +1478,6 @@ derive_parser!( .s16, .s32 }; RawMulIntControl = { .wide }; - - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mul mul{.rnd}{.ftz}{.sat}.f32 d, a, b => { ast::Instruction::Mul { data: ast::MulDetails::Float ( @@ -1507,7 +1507,6 @@ derive_parser!( .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; ScalarType = { .f32, .f64 }; - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul mul{.rnd}{.ftz}{.sat}.f16 d, a, b => { ast::Instruction::Mul { data: ast::MulDetails::Float ( @@ -1706,6 +1705,174 @@ derive_parser!( .space: StateSpace = { .const, .global, .local, .shared{::cta, ::cluster}, .param{::entry} }; .size: ScalarType = { .u32, .u64 }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-abs + // 
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-abs + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-abs + abs.type d, a => { + ast::Instruction::Abs { + data: ast::AbsDetails { + flush_to_zero: None, + type_ + }, + arguments: ast::AbsArgs { + dst: d, src: a + } + } + } + abs{.ftz}.f32 d, a => { + ast::Instruction::Abs { + data: ast::AbsDetails { + flush_to_zero: Some(ftz), + type_: f32 + }, + arguments: ast::AbsArgs { + dst: d, src: a + } + } + } + abs.f64 d, a => { + ast::Instruction::Abs { + data: ast::AbsDetails { + flush_to_zero: None, + type_: f64 + }, + arguments: ast::AbsArgs { + dst: d, src: a + } + } + } + abs{.ftz}.f16 d, a => { + ast::Instruction::Abs { + data: ast::AbsDetails { + flush_to_zero: Some(ftz), + type_: f16 + }, + arguments: ast::AbsArgs { + dst: d, src: a + } + } + } + abs{.ftz}.f16x2 d, a => { + ast::Instruction::Abs { + data: ast::AbsDetails { + flush_to_zero: Some(ftz), + type_: f16x2 + }, + arguments: ast::AbsArgs { + dst: d, src: a + } + } + } + abs.bf16 d, a => { + ast::Instruction::Abs { + data: ast::AbsDetails { + flush_to_zero: None, + type_: bf16 + }, + arguments: ast::AbsArgs { + dst: d, src: a + } + } + } + abs.bf16x2 d, a => { + ast::Instruction::Abs { + data: ast::AbsDetails { + flush_to_zero: None, + type_: bf16x2 + }, + arguments: ast::AbsArgs { + dst: d, src: a + } + } + } + .type: ScalarType = { .s16, .s32, .s64 }; + ScalarType = { .f32, .f64, .f16, .f16x2, .bf16, .bf16x2 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mad + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mad + mad.mode.type d, a, b, c => { + ast::Instruction::Mad { + data: ast::MadDetails::Integer { + type_, + control: mode.into(), + saturate: false + }, + arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } + } + } + .type: ScalarType 
= { .u16, .u32, .u64, + .s16, .s32, .s64 }; + .mode: RawMulIntControl = { .hi, .lo }; + + // The .wide suffix is supported only for 16-bit and 32-bit integer types. + mad.wide.type d, a, b, c => { + ast::Instruction::Mad { + data: ast::MadDetails::Integer { + type_, + control: wide.into(), + saturate: false + }, + arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } + } + } + .type: ScalarType = { .u16, .u32, + .s16, .s32 }; + RawMulIntControl = { .wide }; + + mad.hi.sat.s32 d, a, b, c => { + ast::Instruction::Mad { + data: ast::MadDetails::Integer { + type_: s32, + control: hi.into(), + saturate: true + }, + arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } + } + } + RawMulIntControl = { .hi }; + ScalarType = { .s32 }; + + mad{.ftz}{.sat}.f32 d, a, b, c => { + ast::Instruction::Mad { + data: ast::MadDetails::Float( + ArithFloat { + type_: f32, + rounding: None, + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } + } + } + mad.rnd{.ftz}{.sat}.f32 d, a, b, c => { + ast::Instruction::Mad { + data: ast::MadDetails::Float( + ArithFloat { + type_: f32, + rounding: Some(rnd.into()), + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } + } + } + mad.rnd.f64 d, a, b, c => { + ast::Instruction::Mad { + data: ast::MadDetails::Float( + ArithFloat { + type_: f64, + rounding: Some(rnd.into()), + flush_to_zero: None, + saturate: false + } + ), + arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } + }} + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret ret{.uni} => { Instruction::Ret { data: RetData { uniform: uni } } From 798bbf06e102892113224a20af952f34503a72b8 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Wed, 21 Aug 2024 03:02:41 +0200 Subject: [PATCH 19/47] Add fma and sub --- ptx_parser/src/ast.rs | 19 +++++ 
ptx_parser/src/main.rs | 181 ++++++++++++++++++++++++++++++++++++++--- 2 files changed, 190 insertions(+), 10 deletions(-) diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 248a6f32..e1725c88 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -221,6 +221,25 @@ gen::generate_instruction_type!( src3: T, } }, + Fma { + type: { Type::from(data.type_) }, + data: ArithFloat, + arguments: { + dst: T, + src1: T, + src2: T, + src3: T, + } + }, + Sub { + type: { Type::from(data.type_()) }, + data: ArithDetails, + arguments: { + dst: T, + src1: T, + src2: T, + } + }, Trap { } } ); diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs index ce1f56dd..9531f1c5 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/main.rs @@ -1481,7 +1481,7 @@ derive_parser!( mul{.rnd}{.ftz}{.sat}.f32 d, a, b => { ast::Instruction::Mul { data: ast::MulDetails::Float ( - ArithFloat { + ast::ArithFloat { type_: f32, rounding: rnd.map(Into::into), flush_to_zero: Some(ftz), @@ -1494,7 +1494,7 @@ derive_parser!( mul{.rnd}.f64 d, a, b => { ast::Instruction::Mul { data: ast::MulDetails::Float ( - ArithFloat { + ast::ArithFloat { type_: f64, rounding: rnd.map(Into::into), flush_to_zero: None, @@ -1510,7 +1510,7 @@ derive_parser!( mul{.rnd}{.ftz}{.sat}.f16 d, a, b => { ast::Instruction::Mul { data: ast::MulDetails::Float ( - ArithFloat { + ast::ArithFloat { type_: f16, rounding: rnd.map(Into::into), flush_to_zero: Some(ftz), @@ -1523,7 +1523,7 @@ derive_parser!( mul{.rnd}{.ftz}{.sat}.f16x2 d, a, b => { ast::Instruction::Mul { data: ast::MulDetails::Float ( - ArithFloat { + ast::ArithFloat { type_: f16x2, rounding: rnd.map(Into::into), flush_to_zero: Some(ftz), @@ -1536,7 +1536,7 @@ derive_parser!( mul{.rnd}.bf16 d, a, b => { ast::Instruction::Mul { data: ast::MulDetails::Float ( - ArithFloat { + ast::ArithFloat { type_: bf16, rounding: rnd.map(Into::into), flush_to_zero: None, @@ -1549,7 +1549,7 @@ derive_parser!( mul{.rnd}.bf16x2 d, a, b => { ast::Instruction::Mul 
{ data: ast::MulDetails::Float ( - ArithFloat { + ast::ArithFloat { type_: bf16x2, rounding: rnd.map(Into::into), flush_to_zero: None, @@ -1835,7 +1835,7 @@ derive_parser!( mad{.ftz}{.sat}.f32 d, a, b, c => { ast::Instruction::Mad { data: ast::MadDetails::Float( - ArithFloat { + ast::ArithFloat { type_: f32, rounding: None, flush_to_zero: Some(ftz), @@ -1848,7 +1848,7 @@ derive_parser!( mad.rnd{.ftz}{.sat}.f32 d, a, b, c => { ast::Instruction::Mad { data: ast::MadDetails::Float( - ArithFloat { + ast::ArithFloat { type_: f32, rounding: Some(rnd.into()), flush_to_zero: Some(ftz), @@ -1861,7 +1861,7 @@ derive_parser!( mad.rnd.f64 d, a, b, c => { ast::Instruction::Mad { data: ast::MadDetails::Float( - ArithFloat { + ast::ArithFloat { type_: f64, rounding: Some(rnd.into()), flush_to_zero: None, @@ -1869,10 +1869,171 @@ derive_parser!( } ), arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } - }} + } + } .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; ScalarType = { .f32, .f64 }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-fma + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-fma + fma.rnd{.ftz}{.sat}.f32 d, a, b, c => { + ast::Instruction::Fma { + data: ast::ArithFloat { + type_: f32, + rounding: Some(rnd.into()), + flush_to_zero: Some(ftz), + saturate: sat + }, + arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c } + } + } + fma.rnd.f64 d, a, b, c => { + ast::Instruction::Fma { + data: ast::ArithFloat { + type_: f64, + rounding: Some(rnd.into()), + flush_to_zero: None, + saturate: false + }, + arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c } + } + } + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + + fma.rnd{.ftz}{.sat}.f16 d, a, b, c => { + ast::Instruction::Fma { + data: ast::ArithFloat { + type_: f16, + rounding: Some(rnd.into()), + flush_to_zero: Some(ftz), + saturate: sat + }, + arguments: FmaArgs 
{ dst: d, src1: a, src2: b, src3: c } + } + } + //fma.rnd{.ftz}{.sat}.f16x2 d, a, b, c; + //fma.rnd{.ftz}.relu.f16 d, a, b, c; + //fma.rnd{.ftz}.relu.f16x2 d, a, b, c; + //fma.rnd{.relu}.bf16 d, a, b, c; + //fma.rnd{.relu}.bf16x2 d, a, b, c; + //fma.rnd.oob.{relu}.type d, a, b, c; + .rnd: RawRoundingMode = { .rn }; + ScalarType = { .f16 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-sub + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sub + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-sub + sub.type d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Integer( + ArithInteger { + type_, + saturate: false + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + sub.sat.s32 d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Integer( + ArithInteger { + type_: s32, + saturate: true + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + .type: ScalarType = { .u16, .u32, .u64, + .s16, .s32, .s64 }; + ScalarType = { .s32 }; + + sub{.rnd}{.ftz}{.sat}.f32 d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f32, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + sub{.rnd}.f64 d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f64, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + + sub{.rnd}{.ftz}{.sat}.f16 d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f16, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: 
sat + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + sub{.rnd}{.ftz}{.sat}.f16x2 d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f16x2, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + sub{.rnd}.bf16 d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: bf16, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + sub{.rnd}.bf16x2 d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: bf16x2, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + .rnd: RawRoundingMode = { .rn }; + ScalarType = { .f16, .f16x2, .bf16, .bf16x2 }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret ret{.uni} => { Instruction::Ret { data: RetData { uniform: uni } } From fc713f29308f96c9a9f424e65d277f4c41c5415c Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Wed, 21 Aug 2024 03:19:45 +0200 Subject: [PATCH 20/47] Add min, max --- ptx_parser/src/ast.rs | 42 +++++++++ ptx_parser/src/main.rs | 210 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 252 insertions(+) diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index e1725c88..4251d97d 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -240,6 +240,24 @@ gen::generate_instruction_type!( src2: T, } }, + Min { + type: { Type::from(data.type_()) }, + data: MinMaxDetails, + arguments: { + dst: T, + src1: T, + src2: T, + } + }, + Max { + type: { Type::from(data.type_()) }, + data: MinMaxDetails, + arguments: { + dst: T, + src1: T, + src2: T, + } + }, Trap { } } ); @@ -1075,3 +1093,27 @@ impl MadDetails { } } } + +#[derive(Copy, 
Clone)] +pub enum MinMaxDetails { + Signed(ScalarType), + Unsigned(ScalarType), + Float(MinMaxFloat), +} + +impl MinMaxDetails { + pub fn type_(&self) -> ScalarType { + match self { + MinMaxDetails::Signed(t) => *t, + MinMaxDetails::Unsigned(t) => *t, + MinMaxDetails::Float(float) => float.type_, + } + } +} + +#[derive(Copy, Clone)] +pub struct MinMaxFloat { + pub flush_to_zero: Option, + pub nan: bool, + pub type_: ScalarType, +} diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs index 9531f1c5..71d8dced 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/main.rs @@ -2034,6 +2034,216 @@ derive_parser!( .rnd: RawRoundingMode = { .rn }; ScalarType = { .f16, .f16x2, .bf16, .bf16x2 }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-min + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-min + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-min + min.atype d, a, b => { + ast::Instruction::Min { + data: if atype.kind() == ast::ScalarKind::Signed { + ast::MinMaxDetails::Signed(atype) + } else { + ast::MinMaxDetails::Unsigned(atype) + }, + arguments: MinArgs { dst: d, src1: a, src2: b } + } + } + //min{.relu}.btype d, a, b => { todo!() } + min.btype d, a, b => { + ast::Instruction::Min { + data: ast::MinMaxDetails::Signed(btype), + arguments: MinArgs { dst: d, src1: a, src2: b } + } + } + .atype: ScalarType = { .u16, .u32, .u64, + .u16x2, .s16, .s64 }; + .btype: ScalarType = { .s16x2, .s32 }; + + //min{.ftz}{.NaN}{.xorsign.abs}.f32 d, a, b; + min{.ftz}{.NaN}.f32 d, a, b => { + ast::Instruction::Min { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: Some(ftz), + nan, + type_: f32 + } + ), + arguments: MinArgs { dst: d, src1: a, src2: b } + } + } + min.f64 d, a, b => { + ast::Instruction::Min { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: None, + nan: 
false, + type_: f64 + } + ), + arguments: MinArgs { dst: d, src1: a, src2: b } + } + } + ScalarType = { .f32, .f64 }; + + //min{.ftz}{.NaN}{.xorsign.abs}.f16 d, a, b; + //min{.ftz}{.NaN}{.xorsign.abs}.f16x2 d, a, b; + //min{.NaN}{.xorsign.abs}.bf16 d, a, b; + //min{.NaN}{.xorsign.abs}.bf16x2 d, a, b; + min{.ftz}{.NaN}.f16 d, a, b => { + ast::Instruction::Min { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: Some(ftz), + nan, + type_: f16 + } + ), + arguments: MinArgs { dst: d, src1: a, src2: b } + } + } + min{.ftz}{.NaN}.f16x2 d, a, b => { + ast::Instruction::Min { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: Some(ftz), + nan, + type_: f16x2 + } + ), + arguments: MinArgs { dst: d, src1: a, src2: b } + } + } + min{.NaN}.bf16 d, a, b => { + ast::Instruction::Min { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: None, + nan, + type_: bf16 + } + ), + arguments: MinArgs { dst: d, src1: a, src2: b } + } + } + min{.NaN}.bf16x2 d, a, b => { + ast::Instruction::Min { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: None, + nan, + type_: bf16x2 + } + ), + arguments: MinArgs { dst: d, src1: a, src2: b } + } + } + ScalarType = { .f16, .f16x2, .bf16, .bf16x2 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-max + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-max + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-max + max.atype d, a, b => { + ast::Instruction::Max { + data: if atype.kind() == ast::ScalarKind::Signed { + ast::MinMaxDetails::Signed(atype) + } else { + ast::MinMaxDetails::Unsigned(atype) + }, + arguments: MaxArgs { dst: d, src1: a, src2: b } + } + } + //max{.relu}.btype d, a, b => { todo!() } + max.btype d, a, b => { + ast::Instruction::Max { + data: ast::MinMaxDetails::Signed(btype), + arguments: MaxArgs { dst: d, 
src1: a, src2: b } + } + } + .atype: ScalarType = { .u16, .u32, .u64, + .u16x2, .s16, .s64 }; + .btype: ScalarType = { .s16x2, .s32 }; + + //max{.ftz}{.NaN}{.xorsign.abs}.f32 d, a, b; + max{.ftz}{.NaN}.f32 d, a, b => { + ast::Instruction::Max { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: Some(ftz), + nan, + type_: f32 + } + ), + arguments: MaxArgs { dst: d, src1: a, src2: b } + } + } + max.f64 d, a, b => { + ast::Instruction::Max { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: None, + nan: false, + type_: f64 + } + ), + arguments: MaxArgs { dst: d, src1: a, src2: b } + } + } + ScalarType = { .f32, .f64 }; + + //max{.ftz}{.NaN}{.xorsign.abs}.f16 d, a, b; + //max{.ftz}{.NaN}{.xorsign.abs}.f16x2 d, a, b; + //max{.NaN}{.xorsign.abs}.bf16 d, a, b; + //max{.NaN}{.xorsign.abs}.bf16x2 d, a, b; + max{.ftz}{.NaN}.f16 d, a, b => { + ast::Instruction::Max { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: Some(ftz), + nan, + type_: f16 + } + ), + arguments: MaxArgs { dst: d, src1: a, src2: b } + } + } + max{.ftz}{.NaN}.f16x2 d, a, b => { + ast::Instruction::Max { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: Some(ftz), + nan, + type_: f16x2 + } + ), + arguments: MaxArgs { dst: d, src1: a, src2: b } + } + } + max{.NaN}.bf16 d, a, b => { + ast::Instruction::Max { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: None, + nan, + type_: bf16 + } + ), + arguments: MaxArgs { dst: d, src1: a, src2: b } + } + } + max{.NaN}.bf16x2 d, a, b => { + ast::Instruction::Max { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: None, + nan, + type_: bf16x2 + } + ), + arguments: MaxArgs { dst: d, src1: a, src2: b } + } + } + ScalarType = { .f16, .f16x2, .bf16, .bf16x2 }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret ret{.uni} => { Instruction::Ret { data: RetData { uniform: uni } } From 
c16bae32b5e6963d0efa8d3674b2501b5ec6f266 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Wed, 21 Aug 2024 03:38:43 +0200 Subject: [PATCH 21/47] Add rcp, sqrt, rsqrt --- ptx_parser/src/ast.rs | 50 ++++++++++++++++++++ ptx_parser/src/main.rs | 101 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 151 insertions(+) diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 4251d97d..944341c1 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -258,6 +258,30 @@ gen::generate_instruction_type!( src2: T, } }, + Rcp { + type: { Type::from(data.type_) }, + data: RcpData, + arguments: { + dst: T, + src: T, + } + }, + Sqrt { + type: { Type::from(data.type_) }, + data: RcpData, + arguments: { + dst: T, + src: T, + } + }, + Rsqrt { + type: { Type::from(data.type_) }, + data: RsqrtData, + arguments: { + dst: T, + src: T, + } + }, Trap { } } ); @@ -1117,3 +1141,29 @@ pub struct MinMaxFloat { pub nan: bool, pub type_: ScalarType, } + +#[derive(Copy, Clone, Eq, PartialEq)] +pub enum DivFloatKind { + Approx, + Full, + Rounding(RoundingMode), +} + +#[derive(Copy, Clone)] +pub struct RcpData { + pub kind: RcpKind, + pub flush_to_zero: Option, + pub type_: ScalarType, +} + +#[derive(Copy, Clone, Eq, PartialEq)] +pub enum RcpKind { + Approx, + Full(RoundingMode), +} + +#[derive(Copy, Clone)] +pub struct RsqrtData { + pub flush_to_zero: Option, + pub type_: ScalarType, +} diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs index 71d8dced..159b918e 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/main.rs @@ -2244,6 +2244,107 @@ derive_parser!( } ScalarType = { .f16, .f16x2, .bf16, .bf16x2 }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rcp + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rcp-approx-ftz-f64 + rcp.approx{.ftz}.type d, a => { + ast::Instruction::Rcp { + data: ast::RcpData { + kind: ast::RcpKind::Approx, + flush_to_zero: 
Some(ftz), + type_ + }, + arguments: RcpArgs { dst: d, src: a } + } + } + rcp.rnd{.ftz}.f32 d, a => { + ast::Instruction::Rcp { + data: ast::RcpData { + kind: ast::RcpKind::Full(rnd.into()), + flush_to_zero: Some(ftz), + type_: f32 + }, + arguments: RcpArgs { dst: d, src: a } + } + } + rcp.rnd.f64 d, a => { + ast::Instruction::Rcp { + data: ast::RcpData { + kind: ast::RcpKind::Full(rnd.into()), + flush_to_zero: None, + type_: f64 + }, + arguments: RcpArgs { dst: d, src: a } + } + } + .type: ScalarType = { .f32, .f64 }; + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sqrt + sqrt.approx{.ftz}.f32 d, a => { + ast::Instruction::Sqrt { + data: ast::RcpData { + kind: ast::RcpKind::Approx, + flush_to_zero: Some(ftz), + type_: f32 + }, + arguments: SqrtArgs { dst: d, src: a } + } + } + sqrt.rnd{.ftz}.f32 d, a => { + ast::Instruction::Sqrt { + data: ast::RcpData { + kind: ast::RcpKind::Full(rnd.into()), + flush_to_zero: Some(ftz), + type_: f32 + }, + arguments: SqrtArgs { dst: d, src: a } + } + } + sqrt.rnd.f64 d, a => { + ast::Instruction::Sqrt { + data: ast::RcpData { + kind: ast::RcpKind::Full(rnd.into()), + flush_to_zero: None, + type_: f64 + }, + arguments: SqrtArgs { dst: d, src: a } + } + } + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rsqrt + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rsqrt-approx-ftz-f64 + rsqrt.approx{.ftz}.f32 d, a => { + ast::Instruction::Rsqrt { + data: ast::RsqrtData { + flush_to_zero: Some(ftz), + type_: f32 + }, + arguments: RsqrtArgs { dst: d, src: a } + } + } + rsqrt.approx.f64 d, a => { + ast::Instruction::Rsqrt { + data: ast::RsqrtData { + flush_to_zero: None, + type_: f64 + }, + arguments: RsqrtArgs { dst: d, src: a } 
+ } + } + rsqrt.approx.ftz.f64 d, a => { + ast::Instruction::Rsqrt { + data: ast::RsqrtData { + flush_to_zero: Some(true), + type_: f64 + }, + arguments: RsqrtArgs { dst: d, src: a } + } + } + ScalarType = { .f32, .f64 }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret ret{.uni} => { Instruction::Ret { data: RetData { uniform: uni } } From 39faaa7214dbde018875db14d154d4a6a9fc1c98 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Wed, 21 Aug 2024 15:46:06 +0200 Subject: [PATCH 22/47] Add atom and atom.cas --- ptx_parser/src/ast.rs | 111 +++++++++++++++++++++++++++-- ptx_parser/src/main.rs | 154 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 261 insertions(+), 4 deletions(-) diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 944341c1..232fdfcd 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -1,8 +1,8 @@ use std::cmp::Ordering; use super::{ - MemScope, RawRoundingMode, RawSetpCompareOp, ScalarType, SetpBoolPostOp, StateSpace, - VectorPrefix, + AtomSemantics, MemScope, RawRoundingMode, RawSetpCompareOp, ScalarType, SetpBoolPostOp, + StateSpace, VectorPrefix, }; use crate::{PtxError, PtxParserState}; use bitflags::bitflags; @@ -282,6 +282,52 @@ gen::generate_instruction_type!( src: T, } }, + Selp { + type: { Type::Scalar(data.clone()) }, + data: ScalarType, + arguments: { + dst: T, + src1: T, + src2: T, + src3: { + repr: T, + type: Type::Scalar(ScalarType::Pred) + }, + } + }, + Bar { + type: Type::Scalar(ScalarType::U32), + data: BarData, + arguments: { + src1: T, + src2: Option, + } + }, + Atom { + type: &data.type_, + data: AtomDetails, + arguments: { + dst: T, + src1: { + repr: T, + space: { data.space }, + }, + src2: T, + } + }, + AtomCas { + type: Type::Scalar(data.type_), + data: AtomCasDetails, + arguments: { + dst: T, + src1: { + repr: T, + space: { data.space }, + }, + src2: T, + src3: T, + } + }, Trap { } } ); @@ -408,8 +454,7 @@ pub enum Type { impl Type { pub(crate)
fn maybe_vector(vector: Option, scalar: ScalarType) -> Self { match vector { - Some(VectorPrefix::V2) => Type::Vector(scalar, 2), - Some(VectorPrefix::V4) => Type::Vector(scalar, 4), + Some(prefix) => Type::Vector(scalar, prefix.len()), None => Type::Scalar(scalar), } } @@ -1167,3 +1212,63 @@ pub struct RsqrtData { pub flush_to_zero: Option, pub type_: ScalarType, } + +pub struct BarData { + pub aligned: bool, +} + +pub struct AtomDetails { + pub type_: Type, + pub semantics: AtomSemantics, + pub scope: MemScope, + pub space: StateSpace, + pub op: AtomicOp, +} + +#[derive(Copy, Clone)] +pub enum AtomicOp { + And, + Or, + Xor, + Exchange, + Add, + IncrementWrap, + DecrementWrap, + SignedMin, + UnsignedMin, + SignedMax, + UnsignedMax, + FloatAdd, + FloatMin, + FloatMax, +} + +impl AtomicOp { + pub(crate) fn new(op: super::RawAtomicOp, kind: ScalarKind) -> Self { + use super::RawAtomicOp; + match (op, kind) { + (RawAtomicOp::And, _) => Self::And, + (RawAtomicOp::Or, _) => Self::Or, + (RawAtomicOp::Xor, _) => Self::Xor, + (RawAtomicOp::Exch, _) => Self::Exchange, + // Floating-point adds need the dedicated variant, mirroring the Min/Max arms below. + (RawAtomicOp::Add, ScalarKind::Float) => Self::FloatAdd, + (RawAtomicOp::Add, _) => Self::Add, + (RawAtomicOp::Inc, _) => Self::IncrementWrap, + (RawAtomicOp::Dec, _) => Self::DecrementWrap, + (RawAtomicOp::Min, ScalarKind::Signed) => Self::SignedMin, + (RawAtomicOp::Min, ScalarKind::Float) => Self::FloatMin, + (RawAtomicOp::Min, _) => Self::UnsignedMin, + (RawAtomicOp::Max, ScalarKind::Signed) => Self::SignedMax, + (RawAtomicOp::Max, ScalarKind::Float) => Self::FloatMax, + (RawAtomicOp::Max, _) => Self::UnsignedMax, + } + } +} + +pub struct AtomCasDetails { + pub type_: ScalarType, + pub semantics: AtomSemantics, + pub scope: MemScope, + pub space: StateSpace, +} diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs index 159b918e..060c7679 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/main.rs @@ -71,6 +71,16 @@ impl From for ast::RoundingMode { } } +impl VectorPrefix { + pub(crate) fn len(self) -> u8 { + match self { + VectorPrefix::V2 => 2, +
VectorPrefix::V4 => 4, + VectorPrefix::V8 => 8, + } + } +} + struct PtxParserState<'input> { errors: Vec, function_declarations: @@ -1135,6 +1145,9 @@ derive_parser!( #[derive(Copy, Clone, PartialEq, Eq, Hash)] pub enum SetpBoolPostOp { } + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + pub enum AtomSemantics { } + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov mov{.vec}.type d, a => { Instruction::Mov { @@ -2345,6 +2358,147 @@ derive_parser!( } ScalarType = { .f32, .f64 }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-selp + selp.type d, a, b, c => { + ast::Instruction::Selp { + data: type_, + arguments: SelpArgs { dst: d, src1: a, src2: b, src3: c } + } + } + .type: ScalarType = { .b16, .b32, .b64, + .u16, .u32, .u64, + .s16, .s32, .s64, + .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar + barrier{.cta}.sync{.aligned} a{, b} => { + let _ = cta; + ast::Instruction::Bar { + data: ast::BarData { aligned }, + arguments: BarArgs { src1: a, src2: b } + } + } + //barrier{.cta}.arrive{.aligned} a, b; + //barrier{.cta}.red.popc{.aligned}.u32 d, a{, b}, {!}c; + //barrier{.cta}.red.op{.aligned}.pred p, a{, b}, {!}c; + bar{.cta}.sync a{, b} => { + let _ = cta; + ast::Instruction::Bar { + data: ast::BarData { aligned: true }, + arguments: BarArgs { src1: a, src2: b } + } + } + //bar{.cta}.arrive a, b; + //bar{.cta}.red.popc.u32 d, a{, b}, {!}c; + //bar{.cta}.red.op.pred p, a{, b}, {!}c; + //.op = { .and, .or }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-atom + atom{.sem}{.scope}{.space}.op{.level::cache_hint}.type d, [a], b{, cache_policy} => { + if level_cache_hint || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + ast::Instruction::Atom { + data: 
AtomDetails { + semantics: sem.map(Into::into).unwrap_or(AtomSemantics::Relaxed), + scope: scope.unwrap_or(MemScope::Gpu), + space: space.unwrap_or(StateSpace::Generic), + op: ast::AtomicOp::new(op, type_.kind()), + type_: type_.into() + }, + arguments: AtomArgs { dst: d, src1: a, src2: b } + } + } + atom{.sem}{.scope}{.space}.cas.cas_type d, [a], b, c => { + ast::Instruction::AtomCas { + data: AtomCasDetails { + semantics: sem.map(Into::into).unwrap_or(AtomSemantics::Relaxed), + scope: scope.unwrap_or(MemScope::Gpu), + space: space.unwrap_or(StateSpace::Generic), + type_: cas_type + }, + arguments: AtomCasArgs { dst: d, src1: a, src2: b, src3: c } + } + } + atom{.sem}{.scope}{.space}.exch{.level::cache_hint}.b128 d, [a], b{, cache_policy} => { + if level_cache_hint || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + ast::Instruction::Atom { + data: AtomDetails { + semantics: sem.map(Into::into).unwrap_or(AtomSemantics::Relaxed), + scope: scope.unwrap_or(MemScope::Gpu), + space: space.unwrap_or(StateSpace::Generic), + op: ast::AtomicOp::new(exch, b128.kind()), + type_: b128.into() + }, + arguments: AtomArgs { dst: d, src1: a, src2: b } + } + } + atom{.sem}{.scope}{.global}.float_op{.level::cache_hint}.vec_32_bit.f32 d, [a], b{, cache_policy} => { + if level_cache_hint || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + ast::Instruction::Atom { + data: AtomDetails { + semantics: sem.map(Into::into).unwrap_or(AtomSemantics::Relaxed), + scope: scope.unwrap_or(MemScope::Gpu), + space: global.unwrap_or(StateSpace::Generic), + op: ast::AtomicOp::new(float_op, f32.kind()), + type_: ast::Type::Vector(f32, vec_32_bit.len()) + }, + arguments: AtomArgs { dst: d, src1: a, src2: b } + } + } + atom{.sem}{.scope}{.global}.float_op.noftz{.level::cache_hint}{.vec_16_bit}.half_word_type d, [a], b{, cache_policy} => { + if level_cache_hint || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + ast::Instruction::Atom { + data: 
AtomDetails { + semantics: sem.map(Into::into).unwrap_or(AtomSemantics::Relaxed), + scope: scope.unwrap_or(MemScope::Gpu), + space: global.unwrap_or(StateSpace::Generic), + op: ast::AtomicOp::new(float_op, half_word_type.kind()), + type_: ast::Type::maybe_vector(vec_16_bit, half_word_type) + }, + arguments: AtomArgs { dst: d, src1: a, src2: b } + } + } + atom{.sem}{.scope}{.global}.float_op.noftz{.level::cache_hint}{.vec_32_bit}.packed_type d, [a], b{, cache_policy} => { + if level_cache_hint || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + ast::Instruction::Atom { + data: AtomDetails { + semantics: sem.map(Into::into).unwrap_or(AtomSemantics::Relaxed), + scope: scope.unwrap_or(MemScope::Gpu), + space: global.unwrap_or(StateSpace::Generic), + op: ast::AtomicOp::new(float_op, packed_type.kind()), + type_: ast::Type::maybe_vector(vec_32_bit, packed_type) + }, + arguments: AtomArgs { dst: d, src1: a, src2: b } + } + } + .space: StateSpace = { .global, .shared{::cta, ::cluster} }; + .sem: AtomSemantics = { .relaxed, .acquire, .release, .acq_rel }; + .scope: MemScope = { .cta, .cluster, .gpu, .sys }; + .op: RawAtomicOp = { .and, .or, .xor, + .exch, + .add, .inc, .dec, + .min, .max }; + .level::cache_hint = { .L2::cache_hint }; + .type: ScalarType = { .b32, .b64, .u32, .u64, .s32, .s64, .f32, .f64 }; + .cas_type: ScalarType = { .b32, .b64, .u32, .u64, .s32, .s64, .f32, .f64, .b16, .b128 }; + .half_word_type: ScalarType = { .f16, .bf16 }; + .packed_type: ScalarType = { .f16x2, .bf16x2 }; + .vec_16_bit: VectorPrefix = { .v2, .v4, .v8 }; + .vec_32_bit: VectorPrefix = { .v2, .v4 }; + .float_op: RawAtomicOp = { .add, .min, .max }; + ScalarType = { .b16, .b128, .f32 }; + StateSpace = { .global }; + RawAtomicOp = { .exch }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret ret{.uni} => { Instruction::Ret { data: RetData { uniform: uni } } From 0760c3d58f547cee471e0f9fc7532414c69c9193 Mon Sep 17 00:00:00 
2001 From: Andrzej Janik Date: Wed, 21 Aug 2024 16:57:33 +0200 Subject: [PATCH 23/47] Map remaining instructions --- ptx_parser/src/ast.rs | 216 ++++++++++++++++++++++--- ptx_parser/src/main.rs | 347 +++++++++++++++++++++++++++++++++++++++-- 2 files changed, 531 insertions(+), 32 deletions(-) diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 232fdfcd..1eead3ce 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -201,7 +201,7 @@ gen::generate_instruction_type!( } }, Abs { - data: AbsDetails, + data: TypeFtz, type: { Type::Scalar(data.type_) }, arguments: { dst: T, @@ -276,7 +276,7 @@ gen::generate_instruction_type!( }, Rsqrt { type: { Type::from(data.type_) }, - data: RsqrtData, + data: TypeFtz, arguments: { dst: T, src: T, @@ -328,6 +328,163 @@ gen::generate_instruction_type!( src3: T, } }, + Div { + type: Type::Scalar(data.type_()), + data: DivDetails, + arguments: { + dst: T, + src1: T, + src2: T, + } + }, + Neg { + type: Type::Scalar(data.type_), + data: TypeFtz, + arguments: { + dst: T, + src: T + } + }, + Sin { + type: Type::Scalar(ScalarType::F32), + data: FlushToZero, + arguments: { + dst: T, + src: T + } + }, + Cos { + type: Type::Scalar(ScalarType::F32), + data: FlushToZero, + arguments: { + dst: T, + src: T + } + }, + Lg2 { + type: Type::Scalar(ScalarType::F32), + data: FlushToZero, + arguments: { + dst: T, + src: T + } + }, + Ex2 { + type: Type::Scalar(ScalarType::F32), + data: TypeFtz, + arguments: { + dst: T, + src: T + } + }, + Clz { + type: Type::Scalar(data.clone()), + data: ScalarType, + arguments: { + dst: { + repr: T, + type: Type::Scalar(ScalarType::U32) + }, + src: T + } + }, + Brev { + type: Type::Scalar(data.clone()), + data: ScalarType, + arguments: { + dst: T, + src: T + } + }, + Popc { + type: Type::Scalar(data.clone()), + data: ScalarType, + arguments: { + dst: { + repr: T, + type: Type::Scalar(ScalarType::U32) + }, + src: T + } + }, + Xor { + type: Type::Scalar(data.clone()), + data: ScalarType, + arguments: 
{ + dst: T, + src1: T, + src2: T + } + }, + Rem { + type: Type::Scalar(data.clone()), + data: ScalarType, + arguments: { + dst: T, + src1: T, + src2: T + } + }, + Bfe { + type: Type::Scalar(data.clone()), + data: ScalarType, + arguments: { + dst: T, + src1: T, + src2: { + repr: T, + type: Type::Scalar(ScalarType::U32) + }, + src3: { + repr: T, + type: Type::Scalar(ScalarType::U32) + }, + } + }, + Bfi { + type: Type::Scalar(data.clone()), + data: ScalarType, + arguments: { + dst: T, + src1: T, + src2: T, + src3: { + repr: T, + type: Type::Scalar(ScalarType::U32) + }, + src4: { + repr: T, + type: Type::Scalar(ScalarType::U32) + }, + } + }, + PrmtSlow { + type: Type::Scalar(ScalarType::U32), + arguments: { + dst: T, + src1: T, + src2: T, + src3: T + } + }, + Prmt { + type: Type::Scalar(ScalarType::B32), + data: u16, + arguments: { + dst: T, + src1: T, + src2: T + } + }, + Activemask { + type: Type::Scalar(ScalarType::B32), + arguments: { + dst: T + } + }, + Membar { + data: MemScope + }, Trap { } } ); @@ -1121,8 +1278,8 @@ pub enum CvtaDirection { ExplicitToGeneric, } -#[derive(Copy, Clone)] -pub struct AbsDetails { +#[derive(Copy, Clone, PartialEq, Eq)] +pub struct TypeFtz { pub flush_to_zero: Option, pub type_: ScalarType, } @@ -1187,13 +1344,6 @@ pub struct MinMaxFloat { pub type_: ScalarType, } -#[derive(Copy, Clone, Eq, PartialEq)] -pub enum DivFloatKind { - Approx, - Full, - Rounding(RoundingMode), -} - #[derive(Copy, Clone)] pub struct RcpData { pub kind: RcpKind, @@ -1204,13 +1354,7 @@ pub struct RcpData { #[derive(Copy, Clone, Eq, PartialEq)] pub enum RcpKind { Approx, - Full(RoundingMode), -} - -#[derive(Copy, Clone)] -pub struct RsqrtData { - pub flush_to_zero: Option, - pub type_: ScalarType, + Compliant(RoundingMode), } pub struct BarData { @@ -1270,3 +1414,39 @@ pub struct AtomCasDetails { pub scope: MemScope, pub space: StateSpace, } + +#[derive(Copy, Clone)] +pub enum DivDetails { + Unsigned(ScalarType), + Signed(ScalarType), + Float(DivFloatDetails), 
+} + +impl DivDetails { + pub fn type_(&self) -> ScalarType { + match self { + DivDetails::Unsigned(t) => *t, + DivDetails::Signed(t) => *t, + DivDetails::Float(float) => float.type_, + } + } +} + +#[derive(Copy, Clone)] +pub struct DivFloatDetails { + pub type_: ScalarType, + pub flush_to_zero: Option, + pub kind: DivFloatKind, +} + +#[derive(Copy, Clone, Eq, PartialEq)] +pub enum DivFloatKind { + Approx, + ApproxFull, + Rounding(RoundingMode), +} + +#[derive(Copy, Clone, Eq, PartialEq)] +pub struct FlushToZero { + pub flush_to_zero: bool +} diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs index 060c7679..87b5e935 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/main.rs @@ -1723,7 +1723,7 @@ derive_parser!( // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-abs abs.type d, a => { ast::Instruction::Abs { - data: ast::AbsDetails { + data: ast::TypeFtz { flush_to_zero: None, type_ }, @@ -1734,7 +1734,7 @@ derive_parser!( } abs{.ftz}.f32 d, a => { ast::Instruction::Abs { - data: ast::AbsDetails { + data: ast::TypeFtz { flush_to_zero: Some(ftz), type_: f32 }, @@ -1745,7 +1745,7 @@ derive_parser!( } abs.f64 d, a => { ast::Instruction::Abs { - data: ast::AbsDetails { + data: ast::TypeFtz { flush_to_zero: None, type_: f64 }, @@ -1756,7 +1756,7 @@ derive_parser!( } abs{.ftz}.f16 d, a => { ast::Instruction::Abs { - data: ast::AbsDetails { + data: ast::TypeFtz { flush_to_zero: Some(ftz), type_: f16 }, @@ -1767,7 +1767,7 @@ derive_parser!( } abs{.ftz}.f16x2 d, a => { ast::Instruction::Abs { - data: ast::AbsDetails { + data: ast::TypeFtz { flush_to_zero: Some(ftz), type_: f16x2 }, @@ -1778,7 +1778,7 @@ derive_parser!( } abs.bf16 d, a => { ast::Instruction::Abs { - data: ast::AbsDetails { + data: ast::TypeFtz { flush_to_zero: None, type_: bf16 }, @@ -1789,7 +1789,7 @@ derive_parser!( } abs.bf16x2 d, a => { ast::Instruction::Abs { - data: ast::AbsDetails { + data: ast::TypeFtz { flush_to_zero: None, 
type_: bf16x2 }, @@ -2272,7 +2272,7 @@ derive_parser!( rcp.rnd{.ftz}.f32 d, a => { ast::Instruction::Rcp { data: ast::RcpData { - kind: ast::RcpKind::Full(rnd.into()), + kind: ast::RcpKind::Compliant(rnd.into()), flush_to_zero: Some(ftz), type_: f32 }, @@ -2282,7 +2282,7 @@ derive_parser!( rcp.rnd.f64 d, a => { ast::Instruction::Rcp { data: ast::RcpData { - kind: ast::RcpKind::Full(rnd.into()), + kind: ast::RcpKind::Compliant(rnd.into()), flush_to_zero: None, type_: f64 }, @@ -2307,7 +2307,7 @@ derive_parser!( sqrt.rnd{.ftz}.f32 d, a => { ast::Instruction::Sqrt { data: ast::RcpData { - kind: ast::RcpKind::Full(rnd.into()), + kind: ast::RcpKind::Compliant(rnd.into()), flush_to_zero: Some(ftz), type_: f32 }, @@ -2317,7 +2317,7 @@ derive_parser!( sqrt.rnd.f64 d, a => { ast::Instruction::Sqrt { data: ast::RcpData { - kind: ast::RcpKind::Full(rnd.into()), + kind: ast::RcpKind::Compliant(rnd.into()), flush_to_zero: None, type_: f64 }, @@ -2331,7 +2331,7 @@ derive_parser!( // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rsqrt-approx-ftz-f64 rsqrt.approx{.ftz}.f32 d, a => { ast::Instruction::Rsqrt { - data: ast::RsqrtData { + data: ast::TypeFtz { flush_to_zero: Some(ftz), type_: f32 }, @@ -2340,7 +2340,7 @@ derive_parser!( } rsqrt.approx.f64 d, a => { ast::Instruction::Rsqrt { - data: ast::RsqrtData { + data: ast::TypeFtz { flush_to_zero: None, type_: f64 }, @@ -2349,7 +2349,7 @@ derive_parser!( } rsqrt.approx.ftz.f64 d, a => { ast::Instruction::Rsqrt { - data: ast::RsqrtData { + data: ast::TypeFtz { flush_to_zero: None, type_: f64 }, @@ -2499,6 +2499,325 @@ derive_parser!( StateSpace = { .global }; RawAtomicOp = { .exch }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-div + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-div + div.type d, a, b => { + ast::Instruction::Div { + data: if type_.kind() == 
ast::ScalarKind::Signed { + ast::DivDetails::Signed(type_) + } else { + ast::DivDetails::Unsigned(type_) + }, + arguments: DivArgs { + dst: d, + src1: a, + src2: b, + }, + } + } + .type: ScalarType = { .u16, .u32, .u64, + .s16, .s32, .s64 }; + + div.approx{.ftz}.f32 d, a, b => { + ast::Instruction::Div { + data: ast::DivDetails::Float(ast::DivFloatDetails{ + type_: f32, + flush_to_zero: Some(ftz), + kind: ast::DivFloatKind::Approx + }), + arguments: DivArgs { + dst: d, + src1: a, + src2: b, + }, + } + } + div.full{.ftz}.f32 d, a, b => { + ast::Instruction::Div { + data: ast::DivDetails::Float(ast::DivFloatDetails{ + type_: f32, + flush_to_zero: Some(ftz), + kind: ast::DivFloatKind::ApproxFull + }), + arguments: DivArgs { + dst: d, + src1: a, + src2: b, + }, + } + } + div.rnd{.ftz}.f32 d, a, b => { + ast::Instruction::Div { + data: ast::DivDetails::Float(ast::DivFloatDetails{ + type_: f32, + flush_to_zero: Some(ftz), + kind: ast::DivFloatKind::Rounding(rnd.into()) + }), + arguments: DivArgs { + dst: d, + src1: a, + src2: b, + }, + } + } + div.rnd.f64 d, a, b => { + ast::Instruction::Div { + data: ast::DivDetails::Float(ast::DivFloatDetails{ + type_: f64, + flush_to_zero: None, + kind: ast::DivFloatKind::Rounding(rnd.into()) + }), + arguments: DivArgs { + dst: d, + src1: a, + src2: b, + }, + } + } + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-neg + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-neg + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-neg + neg.type d, a => { + ast::Instruction::Neg { + data: TypeFtz { + type_, + flush_to_zero: None + }, + arguments: NegArgs { dst: d, src: a, }, + } + } + .type: ScalarType = { .s16, .s32, .s64 }; + + neg{.ftz}.f32 d, a => { + ast::Instruction::Neg { + data: TypeFtz { + 
type_: f32, + flush_to_zero: Some(ftz) + }, + arguments: NegArgs { dst: d, src: a, }, + } + } + neg.f64 d, a => { + ast::Instruction::Neg { + data: TypeFtz { + type_: f64, + flush_to_zero: None + }, + arguments: NegArgs { dst: d, src: a, }, + } + } + neg{.ftz}.f16 d, a => { + ast::Instruction::Neg { + data: TypeFtz { + type_: f16, + flush_to_zero: Some(ftz) + }, + arguments: NegArgs { dst: d, src: a, }, + } + } + neg{.ftz}.f16x2 d, a => { + ast::Instruction::Neg { + data: TypeFtz { + type_: f16x2, + flush_to_zero: Some(ftz) + }, + arguments: NegArgs { dst: d, src: a, }, + } + } + neg.bf16 d, a => { + ast::Instruction::Neg { + data: TypeFtz { + type_: bf16, + flush_to_zero: None + }, + arguments: NegArgs { dst: d, src: a, }, + } + } + neg.bf16x2 d, a => { + ast::Instruction::Neg { + data: TypeFtz { + type_: bf16x2, + flush_to_zero: None + }, + arguments: NegArgs { dst: d, src: a, }, + } + } + ScalarType = { .f32, .f64, .f16, .f16x2, .bf16, .bf16x2 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sin + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-cos + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-lg2 + sin.approx{.ftz}.f32 d, a => { + ast::Instruction::Sin { + data: ast::FlushToZero { + flush_to_zero: ftz + }, + arguments: SinArgs { dst: d, src: a, }, + } + } + cos.approx{.ftz}.f32 d, a => { + ast::Instruction::Cos { + data: ast::FlushToZero { + flush_to_zero: ftz + }, + arguments: CosArgs { dst: d, src: a, }, + } + } + lg2.approx{.ftz}.f32 d, a => { + ast::Instruction::Lg2 { + data: ast::FlushToZero { + flush_to_zero: ftz + }, + arguments: Lg2Args { dst: d, src: a, }, + } + } + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-ex2 + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-ex2 + 
ex2.approx{.ftz}.f32 d, a => { + ast::Instruction::Ex2 { + data: ast::TypeFtz { + type_: f32, + flush_to_zero: Some(ftz) + }, + arguments: Ex2Args { dst: d, src: a, }, + } + } + ex2.approx.atype d, a => { + ast::Instruction::Ex2 { + data: ast::TypeFtz { + type_: atype, + flush_to_zero: None + }, + arguments: Ex2Args { dst: d, src: a, }, + } + } + ex2.approx.ftz.btype d, a => { + ast::Instruction::Ex2 { + data: ast::TypeFtz { + type_: btype, + flush_to_zero: Some(true) + }, + arguments: Ex2Args { dst: d, src: a, }, + } + } + .atype: ScalarType = { .f16, .f16x2 }; + .btype: ScalarType = { .bf16, .bf16x2 }; + ScalarType = { .f32 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-clz + clz.type d, a => { + ast::Instruction::Clz { + data: type_, + arguments: ClzArgs { dst: d, src: a, }, + } + } + .type: ScalarType = { .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-brev + brev.type d, a => { + ast::Instruction::Brev { + data: type_, + arguments: BrevArgs { dst: d, src: a, }, + } + } + .type: ScalarType = { .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-popc + popc.type d, a => { + ast::Instruction::Popc { + data: type_, + arguments: PopcArgs { dst: d, src: a, }, + } + } + .type: ScalarType = { .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-xor + xor.type d, a, b => { + ast::Instruction::Xor { + data: type_, + arguments: XorArgs { dst: d, src1: a, src2: b, }, + } + } + .type: ScalarType = { .pred, .b16, .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-rem + rem.type d, a, b => { + ast::Instruction::Rem { + data: type_, + arguments: RemArgs { dst: d, src1: a, src2: b, }, + } + } + .type: ScalarType = { .u16, .u32, .u64, .s16, .s32, .s64 }; 
+ + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-bfe + bfe.type d, a, b, c => { + ast::Instruction::Bfe { + data: type_, + arguments: BfeArgs { dst: d, src1: a, src2: b, src3: c }, + } + } + .type: ScalarType = { .u32, .u64, .s32, .s64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-bfi + bfi.type f, a, b, c, d => { + ast::Instruction::Bfi { + data: type_, + arguments: BfiArgs { dst: f, src1: a, src2: b, src3: c, src4: d }, + } + } + .type: ScalarType = { .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt + // prmt.b32{.mode} d, a, b, c; + // .mode = { .f4e, .b4e, .rc8, .ecl, .ecr, .rc16 }; + prmt.b32 d, a, b, c => { + match c { + ast::ParsedOperand::Imm(ImmediateValue::U64(control)) => ast::Instruction::Prmt { + data: control as u16, + arguments: PrmtArgs { + dst: d, src1: a, src2: b + } + }, + _ => ast::Instruction::PrmtSlow { + arguments: PrmtSlowArgs { + dst: d, src1: a, src2: b, src3: c + } + } + } + } + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-activemask + activemask.b32 d => { + ast::Instruction::Activemask { + arguments: ActivemaskArgs { dst: d } + } + } + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-membar + // fence{.sem}.scope; + // fence.op_restrict.release.cluster; + // fence.proxy.proxykind; + // fence.proxy.to_proxykind::from_proxykind.release.scope; + // fence.proxy.to_proxykind::from_proxykind.acquire.scope [addr], size; + //membar.proxy.proxykind; + //.sem = { .sc, .acq_rel }; + //.scope = { .cta, .cluster, .gpu, .sys }; + //.proxykind = { .alias, .async, async.global, .async.shared::{cta, cluster} }; + //.op_restrict = { .mbarrier_init }; + //.to_proxykind::from_proxykind = 
{.tensormap::generic}; + + membar.level => { + ast::Instruction::Membar { data: level } + } + membar.gl => { + ast::Instruction::Membar { data: MemScope::Gpu } + } + .level: MemScope = { .cta, .sys }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret ret{.uni} => { Instruction::Ret { data: RetData { uniform: uni } } From 71e025845ce11c3d68e8c13bd5aa7cd2c5b4b319 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Wed, 21 Aug 2024 20:00:22 +0200 Subject: [PATCH 24/47] Rename new crates --- Cargo.toml | 6 +++--- ptx_parser/Cargo.toml | 5 +++-- ptx_parser/src/ast.rs | 7 +++---- ptx_parser/src/main.rs | 2 +- {gen => ptx_parser_macros}/Cargo.toml | 7 ++++--- {gen => ptx_parser_macros}/src/lib.rs | 6 +++--- {gen_impl => ptx_parser_macros_impl}/Cargo.toml | 5 +++-- {gen_impl => ptx_parser_macros_impl}/src/lib.rs | 0 {gen_impl => ptx_parser_macros_impl}/src/parser.rs | 0 9 files changed, 20 insertions(+), 18 deletions(-) rename {gen => ptx_parser_macros}/Cargo.toml (54%) rename {gen => ptx_parser_macros}/src/lib.rs (96%) rename {gen_impl => ptx_parser_macros_impl}/Cargo.toml (64%) rename {gen_impl => ptx_parser_macros_impl}/src/lib.rs (100%) rename {gen_impl => ptx_parser_macros_impl}/src/parser.rs (100%) diff --git a/Cargo.toml b/Cargo.toml index 7f38976a..350d568f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,9 +17,9 @@ members = [ "zluda_redirect", "zluda_ml", "ptx", - "gen", - "gen_impl", - "ptx_parser" + "ptx_parser", + "ptx_parser_macros", + "ptx_parser_macros_impl", ] default-members = ["zluda_lib", "zluda_ml", "zluda_inject", "zluda_redirect"] diff --git a/ptx_parser/Cargo.toml b/ptx_parser/Cargo.toml index 35251ee8..af3058b2 100644 --- a/ptx_parser/Cargo.toml +++ b/ptx_parser/Cargo.toml @@ -1,12 +1,13 @@ [package] name = "ptx_parser" -version = "0.1.0" +version = "0.0.0" +authors = ["Andrzej Janik "] edition = "2021" [dependencies] logos = "0.14" winnow = { version = "0.6.18" } -gen = { path = "../gen" } 
+ptx_parser_macros = { path = "../ptx_parser_macros" } thiserror = "1.0" bitflags = "1.2" rustc-hash = "2.0.0" diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 1eead3ce..6cf12649 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -1,11 +1,10 @@ -use std::cmp::Ordering; - use super::{ AtomSemantics, MemScope, RawRoundingMode, RawSetpCompareOp, ScalarType, SetpBoolPostOp, StateSpace, VectorPrefix, }; use crate::{PtxError, PtxParserState}; use bitflags::bitflags; +use std::cmp::Ordering; pub enum Statement { Label(P::Ident), @@ -14,7 +13,7 @@ pub enum Statement { Block(Vec>), } -gen::generate_instruction_type!( +ptx_parser_macros::generate_instruction_type!( pub enum Instruction { Mov { type: { &data.typ }, @@ -1448,5 +1447,5 @@ pub enum DivFloatKind { #[derive(Copy, Clone, Eq, PartialEq)] pub struct FlushToZero { - pub flush_to_zero: bool + pub flush_to_zero: bool, } diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs index 87b5e935..5db94f23 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/main.rs @@ -1,5 +1,5 @@ -use gen::derive_parser; use logos::Logos; +use ptx_parser_macros::derive_parser; use rustc_hash::FxHashMap; use std::fmt::Debug; use std::mem; diff --git a/gen/Cargo.toml b/ptx_parser_macros/Cargo.toml similarity index 54% rename from gen/Cargo.toml rename to ptx_parser_macros/Cargo.toml index e26383da..62a5081b 100644 --- a/gen/Cargo.toml +++ b/ptx_parser_macros/Cargo.toml @@ -1,13 +1,14 @@ [package] -name = "gen" -version = "0.1.0" +name = "ptx_parser_macros" +version = "0.0.0" +authors = ["Andrzej Janik "] edition = "2021" [lib] proc-macro = true [dependencies] -gen_impl = { path = "../gen_impl" } +ptx_parser_macros_impl = { path = "../ptx_parser_macros_impl" } convert_case = "0.6.0" rustc-hash = "2.0.0" syn = "2.0.67" diff --git a/gen/src/lib.rs b/ptx_parser_macros/src/lib.rs similarity index 96% rename from gen/src/lib.rs rename to ptx_parser_macros/src/lib.rs index a110fdc4..a2f8396f 100644 --- 
a/gen/src/lib.rs +++ b/ptx_parser_macros/src/lib.rs @@ -1,5 +1,5 @@ use either::Either; -use gen_impl::parser; +use ptx_parser_macros_impl::parser; use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote, ToTokens}; use rustc_hash::{FxHashMap, FxHashSet}; @@ -359,7 +359,7 @@ fn gather_rules( #[proc_macro] pub fn derive_parser(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream { - let parse_definitions = parse_macro_input!(tokens as gen_impl::parser::ParseDefinitions); + let parse_definitions = parse_macro_input!(tokens as ptx_parser_macros_impl::parser::ParseDefinitions); let mut definitions = FxHashMap::default(); let mut special_definitions = FxHashMap::default(); let types = OpcodeDefinitions::get_enum_types(&parse_definitions.definitions); @@ -1012,7 +1012,7 @@ impl DotModifierRef { #[proc_macro] pub fn generate_instruction_type(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream { - let input = parse_macro_input!(tokens as gen_impl::GenerateInstructionType); + let input = parse_macro_input!(tokens as ptx_parser_macros_impl::GenerateInstructionType); let mut result = proc_macro2::TokenStream::new(); input.emit_arg_types(&mut result); input.emit_instruction_type(&mut result); diff --git a/gen_impl/Cargo.toml b/ptx_parser_macros_impl/Cargo.toml similarity index 64% rename from gen_impl/Cargo.toml rename to ptx_parser_macros_impl/Cargo.toml index ff93f98c..96f3b749 100644 --- a/gen_impl/Cargo.toml +++ b/ptx_parser_macros_impl/Cargo.toml @@ -1,6 +1,7 @@ [package] -name = "gen_impl" -version = "0.1.0" +name = "ptx_parser_macros_impl" +version = "0.0.0" +authors = ["Andrzej Janik "] edition = "2021" [lib] diff --git a/gen_impl/src/lib.rs b/ptx_parser_macros_impl/src/lib.rs similarity index 100% rename from gen_impl/src/lib.rs rename to ptx_parser_macros_impl/src/lib.rs diff --git a/gen_impl/src/parser.rs b/ptx_parser_macros_impl/src/parser.rs similarity index 100% rename from gen_impl/src/parser.rs rename to 
ptx_parser_macros_impl/src/parser.rs From 1ec1ca0c30257301a90caf8b394e491560c395e7 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 23 Aug 2024 02:19:36 +0200 Subject: [PATCH 25/47] Attempt #2 --- ptx/Cargo.toml | 8 +- ptx/src/ast.rs | 26 +- ptx/src/lib.rs | 2 + ptx/src/pass/mod.rs | 531 +++++++++++++++++++++++++++++ ptx/src/pass/normalize.rs | 83 +++++ ptx/src/translate2.rs | 60 ++++ ptx_parser/Cargo.toml | 3 + ptx_parser/src/ast.rs | 212 +++++++++--- ptx_parser/src/{main.rs => lib.rs} | 88 ++--- ptx_parser_macros/src/lib.rs | 2 +- ptx_parser_macros_impl/src/lib.rs | 36 +- 11 files changed, 903 insertions(+), 148 deletions(-) create mode 100644 ptx/src/pass/mod.rs create mode 100644 ptx/src/pass/normalize.rs create mode 100644 ptx/src/translate2.rs rename ptx_parser/src/{main.rs => lib.rs} (98%) diff --git a/ptx/Cargo.toml b/ptx/Cargo.toml index 2ac1f689..d4852862 100644 --- a/ptx/Cargo.toml +++ b/ptx/Cargo.toml @@ -7,7 +7,7 @@ edition = "2018" [lib] [dependencies] -lalrpop-util = "0.19" +ptx_parser = { path = "../ptx_parser" } regex = "1" rspirv = "0.7" spirv_headers = "1.5" @@ -17,8 +17,12 @@ bit-vec = "0.6" half ="1.6" bitflags = "1.2" +[dependencies.lalrpop-util] +version = "0.19.12" +features = ["lexer"] + [build-dependencies.lalrpop] -version = "0.19" +version = "0.19.12" features = ["lexer"] [dev-dependencies] diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index f1323be4..358b8cef 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -34,15 +34,9 @@ pub enum PtxError { #[error("")] NonExternPointer, #[error("{start}:{end}")] - UnrecognizedStatement { - start: usize, - end: usize, - }, + UnrecognizedStatement { start: usize, end: usize }, #[error("{start}:{end}")] - UnrecognizedDirective { - start: usize, - end: usize, - }, + UnrecognizedDirective { start: usize, end: usize }, } // For some weird reson this is illegal: @@ -578,11 +572,15 @@ impl CvtDetails { if saturate { if src.kind() == ScalarKind::Signed { if dst.kind() == ScalarKind::Signed && 
dst.size_of() >= src.size_of() { - err.push(ParseError::from(PtxError::SyntaxError)); + err.push(ParseError::User { + error: PtxError::SyntaxError, + }); } } else { if dst == src || dst.size_of() >= src.size_of() { - err.push(ParseError::from(PtxError::SyntaxError)); + err.push(ParseError::User { + error: PtxError::SyntaxError, + }); } } } @@ -598,7 +596,9 @@ impl CvtDetails { err: &'err mut Vec, PtxError>>, ) -> Self { if flush_to_zero && dst != ScalarType::F32 { - err.push(ParseError::from(PtxError::NonF32Ftz)); + err.push(ParseError::from(lalrpop_util::ParseError::User { + error: PtxError::NonF32Ftz, + })); } CvtDetails::FloatFromInt(CvtDesc { dst, @@ -618,7 +618,9 @@ impl CvtDetails { err: &'err mut Vec, PtxError>>, ) -> Self { if flush_to_zero && src != ScalarType::F32 { - err.push(ParseError::from(PtxError::NonF32Ftz)); + err.push(ParseError::from(lalrpop_util::ParseError::User { + error: PtxError::NonF32Ftz, + })); } CvtDetails::IntFromFloat(CvtDesc { dst, diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs index 1cb96308..b70019ea 100644 --- a/ptx/src/lib.rs +++ b/ptx/src/lib.rs @@ -24,9 +24,11 @@ lalrpop_mod!( ); pub mod ast; +mod pass; #[cfg(test)] mod test; mod translate; +mod translate2; use std::fmt; diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs new file mode 100644 index 00000000..7b794d6d --- /dev/null +++ b/ptx/src/pass/mod.rs @@ -0,0 +1,531 @@ +use ptx_parser as ast; +use std::{ + borrow::Cow, + cell::RefCell, + collections::{hash_map, HashMap}, + rc::Rc, +}; + +mod normalize; + +#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)] +enum PtxSpecialRegister { + Tid, + Ntid, + Ctaid, + Nctaid, + Clock, + LanemaskLt, +} + +impl PtxSpecialRegister { + fn try_parse(s: &str) -> Option { + match s { + "%tid" => Some(Self::Tid), + "%ntid" => Some(Self::Ntid), + "%ctaid" => Some(Self::Ctaid), + "%nctaid" => Some(Self::Nctaid), + "%clock" => Some(Self::Clock), + "%lanemask_lt" => Some(Self::LanemaskLt), + _ => None, + } + } + + fn get_type(self) -> 
ast::Type { + match self { + PtxSpecialRegister::Tid + | PtxSpecialRegister::Ntid + | PtxSpecialRegister::Ctaid + | PtxSpecialRegister::Nctaid => ast::Type::Vector(self.get_function_return_type(), 4), + _ => ast::Type::Scalar(self.get_function_return_type()), + } + } + + fn get_function_return_type(self) -> ast::ScalarType { + match self { + PtxSpecialRegister::Tid => ast::ScalarType::U32, + PtxSpecialRegister::Ntid => ast::ScalarType::U32, + PtxSpecialRegister::Ctaid => ast::ScalarType::U32, + PtxSpecialRegister::Nctaid => ast::ScalarType::U32, + PtxSpecialRegister::Clock => ast::ScalarType::U32, + PtxSpecialRegister::LanemaskLt => ast::ScalarType::U32, + } + } + + fn get_function_input_type(self) -> Option { + match self { + PtxSpecialRegister::Tid + | PtxSpecialRegister::Ntid + | PtxSpecialRegister::Ctaid + | PtxSpecialRegister::Nctaid => Some(ast::ScalarType::U8), + PtxSpecialRegister::Clock | PtxSpecialRegister::LanemaskLt => None, + } + } + + fn get_unprefixed_function_name(self) -> &'static str { + match self { + PtxSpecialRegister::Tid => "sreg_tid", + PtxSpecialRegister::Ntid => "sreg_ntid", + PtxSpecialRegister::Ctaid => "sreg_ctaid", + PtxSpecialRegister::Nctaid => "sreg_nctaid", + PtxSpecialRegister::Clock => "sreg_clock", + PtxSpecialRegister::LanemaskLt => "sreg_lanemask_lt", + } + } +} + +struct SpecialRegistersMap { + reg_to_id: HashMap, + id_to_reg: HashMap, +} + +impl SpecialRegistersMap { + fn new() -> Self { + SpecialRegistersMap { + reg_to_id: HashMap::new(), + id_to_reg: HashMap::new(), + } + } + + fn get(&self, id: SpirvWord) -> Option { + self.id_to_reg.get(&id).copied() + } + + fn get_or_add(&mut self, current_id: &mut SpirvWord, reg: PtxSpecialRegister) -> SpirvWord { + match self.reg_to_id.entry(reg) { + hash_map::Entry::Occupied(e) => *e.get(), + hash_map::Entry::Vacant(e) => { + let numeric_id = SpirvWord(current_id.0); + current_id.0 += 1; + e.insert(numeric_id); + self.id_to_reg.insert(numeric_id, reg); + numeric_id + } + } + } +} + 
+struct FnStringIdResolver<'input, 'b> { + current_id: &'b mut SpirvWord, + global_variables: &'b HashMap, SpirvWord>, + global_type_check: &'b HashMap>, + special_registers: &'b mut SpecialRegistersMap, + variables: Vec, SpirvWord>>, + type_check: HashMap>, +} + +impl<'a, 'b> FnStringIdResolver<'a, 'b> { + fn finish(self) -> NumericIdResolver<'b> { + NumericIdResolver { + current_id: self.current_id, + global_type_check: self.global_type_check, + type_check: self.type_check, + special_registers: self.special_registers, + } + } + + fn start_block(&mut self) { + self.variables.push(HashMap::new()) + } + + fn end_block(&mut self) { + self.variables.pop(); + } + + fn get_id(&mut self, id: &str) -> Result { + for scope in self.variables.iter().rev() { + match scope.get(id) { + Some(id) => return Ok(*id), + None => continue, + } + } + match self.global_variables.get(id) { + Some(id) => Ok(*id), + None => { + let sreg = PtxSpecialRegister::try_parse(id).ok_or_else(error_unknown_symbol)?; + Ok(self.special_registers.get_or_add(self.current_id, sreg)) + } + } + } + + fn add_def( + &mut self, + id: &'a str, + typ: Option<(ast::Type, ast::StateSpace)>, + is_variable: bool, + ) -> SpirvWord { + let numeric_id = *self.current_id; + self.variables + .last_mut() + .unwrap() + .insert(Cow::Borrowed(id), numeric_id); + self.type_check.insert( + numeric_id.0, + typ.map(|(typ, space)| (typ, space, is_variable)), + ); + self.current_id.0 += 1; + numeric_id + } + + #[must_use] + fn add_defs( + &mut self, + base_id: &'a str, + count: u32, + typ: ast::Type, + state_space: ast::StateSpace, + is_variable: bool, + ) -> impl Iterator { + let numeric_id = *self.current_id; + for i in 0..count { + self.variables.last_mut().unwrap().insert( + Cow::Owned(format!("{}{}", base_id, i)), + SpirvWord(numeric_id.0 + i), + ); + self.type_check.insert( + numeric_id.0 + i, + Some((typ.clone(), state_space, is_variable)), + ); + } + self.current_id.0 += count; + (0..count) + .into_iter() + .map(move |i| 
SpirvWord(i + numeric_id.0)) + } +} + +struct NumericIdResolver<'b> { + current_id: &'b mut SpirvWord, + global_type_check: &'b HashMap>, + type_check: HashMap>, + special_registers: &'b mut SpecialRegistersMap, +} + +impl<'b> NumericIdResolver<'b> { + fn finish(self) -> MutableNumericIdResolver<'b> { + MutableNumericIdResolver { base: self } + } + + fn get_typed( + &self, + id: SpirvWord, + ) -> Result<(ast::Type, ast::StateSpace, bool), TranslateError> { + match self.type_check.get(&id.0) { + Some(Some(x)) => Ok(x.clone()), + Some(None) => Err(TranslateError::UntypedSymbol), + None => match self.special_registers.get(id) { + Some(x) => Ok((x.get_type(), ast::StateSpace::Sreg, true)), + None => match self.global_type_check.get(&id.0) { + Some(Some(result)) => Ok(result.clone()), + Some(None) | None => Err(TranslateError::UntypedSymbol), + }, + }, + } + } + + // This is for identifiers which will be emitted later as OpVariable + // They are candidates for insertion of LoadVar/StoreVar + fn register_variable(&mut self, typ: ast::Type, state_space: ast::StateSpace) -> SpirvWord { + let new_id = *self.current_id; + self.type_check + .insert(new_id.0, Some((typ, state_space, true))); + self.current_id.0 += 1; + new_id + } + + fn register_intermediate(&mut self, typ: Option<(ast::Type, ast::StateSpace)>) -> SpirvWord { + let new_id = *self.current_id; + self.type_check + .insert(new_id.0, typ.map(|(t, space)| (t, space, false))); + self.current_id.0 += 1; + new_id + } +} + +struct MutableNumericIdResolver<'b> { + base: NumericIdResolver<'b>, +} + +impl<'b> MutableNumericIdResolver<'b> { + fn unmut(self) -> NumericIdResolver<'b> { + self.base + } + + fn get_typed(&self, id: SpirvWord) -> Result<(ast::Type, ast::StateSpace), TranslateError> { + self.base.get_typed(id).map(|(t, space, _)| (t, space)) + } + + fn register_intermediate(&mut self, typ: ast::Type, state_space: ast::StateSpace) -> SpirvWord { + self.base.register_intermediate(Some((typ, state_space))) + } +} + 
+quick_error! { + #[derive(Debug)] + pub enum TranslateError { + UnknownSymbol {} + UntypedSymbol {} + MismatchedType {} + Spirv(err: rspirv::dr::Error) { + from() + display("{}", err) + cause(err) + } + Unreachable {} + Todo {} + } +} + +#[cfg(debug_assertions)] +fn error_unreachable() -> TranslateError { + unreachable!() +} + +#[cfg(not(debug_assertions))] +fn error_unreachable() -> TranslateError { + TranslateError::Unreachable +} + +fn error_unknown_symbol() -> TranslateError { + TranslateError::UnknownSymbol +} + +pub struct GlobalFnDeclResolver<'input, 'a> { + fns: &'a HashMap>, +} + +impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> { + fn get_fn_sig_resolver(&self, id: SpirvWord) -> Result<&FnSigMapper<'input>, TranslateError> { + self.fns.get(&id).ok_or_else(error_unknown_symbol) + } +} + +struct FnSigMapper<'input> { + // true - stays as return argument + // false - is moved to input argument + return_param_args: Vec, + func_decl: Rc>>, +} + +impl<'input> FnSigMapper<'input> { + fn remap_to_spirv_repr(mut method: ast::MethodDeclaration<'input, SpirvWord>) -> Self { + let return_param_args = method + .return_arguments + .iter() + .map(|a| a.state_space != ast::StateSpace::Param) + .collect::>(); + let mut new_return_arguments = Vec::new(); + for arg in method.return_arguments.into_iter() { + if arg.state_space == ast::StateSpace::Param { + method.input_arguments.push(arg); + } else { + new_return_arguments.push(arg); + } + } + method.return_arguments = new_return_arguments; + FnSigMapper { + return_param_args, + func_decl: Rc::new(RefCell::new(method)), + } + } + + /* + fn resolve_in_spirv_repr( + &self, + call_inst: ast::CallInst, + ) -> Result, TranslateError> { + let func_decl = (*self.func_decl).borrow(); + let mut return_arguments = Vec::new(); + let mut input_arguments = call_inst + .param_list + .into_iter() + .zip(func_decl.input_arguments.iter()) + .map(|(id, var)| (id, var.v_type.clone(), var.state_space)) + .collect::>(); + let mut 
func_decl_return_iter = func_decl.return_arguments.iter(); + let mut func_decl_input_iter = func_decl.input_arguments[input_arguments.len()..].iter(); + for (idx, id) in call_inst.ret_params.iter().enumerate() { + let stays_as_return = match self.return_param_args.get(idx) { + Some(x) => *x, + None => return Err(TranslateError::MismatchedType), + }; + if stays_as_return { + if let Some(var) = func_decl_return_iter.next() { + return_arguments.push((*id, var.v_type.clone(), var.state_space)); + } else { + return Err(TranslateError::MismatchedType); + } + } else { + if let Some(var) = func_decl_input_iter.next() { + input_arguments.push(( + ast::Operand::Reg(*id), + var.v_type.clone(), + var.state_space, + )); + } else { + return Err(TranslateError::MismatchedType); + } + } + } + if return_arguments.len() != func_decl.return_arguments.len() + || input_arguments.len() != func_decl.input_arguments.len() + { + return Err(TranslateError::MismatchedType); + } + Ok(ResolvedCall { + return_arguments, + input_arguments, + uniform: call_inst.uniform, + name: call_inst.func, + }) + } + */ +} + +enum Statement { + Label(SpirvWord), + Variable(ast::Variable), + Instruction(I), + // SPIR-V compatible replacement for PTX predicates + Conditional(BrachCondition), + LoadVar(LoadVarDetails), + StoreVar(StoreVarDetails), + Conversion(ImplicitConversion), + Constant(ConstantDefinition), + RetValue(ast::RetData, SpirvWord), + PtrAccess(PtrAccess

), + RepackVector(RepackVectorDetails), + FunctionPointer(FunctionPointerDetails), +} + +struct BrachCondition { + predicate: SpirvWord, + if_true: SpirvWord, + if_false: SpirvWord, +} +struct LoadVarDetails { + arg: ast::LdArgs, + typ: ast::Type, + state_space: ast::StateSpace, + // (index, vector_width) + // HACK ALERT + // For some reason IGC explodes when you try to load from builtin vectors + // using OpInBoundsAccessChain, the one true way to do it is to + // OpLoad+OpCompositeExtract + member_index: Option<(u8, Option)>, +} + +struct StoreVarDetails { + arg: ast::StArgs, + typ: ast::Type, + member_index: Option, +} + +#[derive(Clone)] +struct ImplicitConversion { + src: SpirvWord, + dst: SpirvWord, + from_type: ast::Type, + to_type: ast::Type, + from_space: ast::StateSpace, + to_space: ast::StateSpace, + kind: ConversionKind, +} + +#[derive(PartialEq, Clone)] +enum ConversionKind { + Default, + // zero-extend/chop/bitcast depending on types + SignExtend, + BitToPtr, + PtrToPtr, + AddressOf, +} + +struct ConstantDefinition { + pub dst: SpirvWord, + pub typ: ast::ScalarType, + pub value: ast::ImmediateValue, +} + +pub struct PtrAccess { + underlying_type: ast::Type, + state_space: ast::StateSpace, + dst: SpirvWord, + ptr_src: SpirvWord, + offset_src: T, +} + +struct RepackVectorDetails { + is_extract: bool, + typ: ast::ScalarType, + packed: SpirvWord, + unpacked: Vec, + non_default_implicit_conversion: Option< + fn( + (ast::StateSpace, &ast::Type), + (ast::StateSpace, &ast::Type), + ) -> Result, TranslateError>, + >, +} + +struct FunctionPointerDetails { + dst: SpirvWord, + src: SpirvWord, +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +struct SpirvWord(spirv::Word); + +impl From for SpirvWord { + fn from(value: spirv::Word) -> Self { + Self(value) + } +} +impl From for spirv::Word { + fn from(value: SpirvWord) -> Self { + value.0 + } +} + +impl ast::Operand for SpirvWord { + type Ident = Self; +} + +fn pred_map_variable Result>( + this: ast::PredAt, + f: 
&mut F, +) -> Result, TranslateError> { + let new_label = f(this.label)?; + Ok(ast::PredAt { + not: this.not, + label: new_label, + }) +} + +impl Result, Err> ast::VisitorMap for X { + fn visit( + &mut self, + args: T, + type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + is_dst: bool, + ) -> U { + todo!() + } + + fn visit_ident( + &mut self, + args: ::Ident, + type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + is_dst: bool, + ) -> ::Ident { + todo!() + } +} + +fn op_map_variable<'a, F: FnMut(&str) -> Result>( + this: ast::Instruction>, + f: &mut F, +) -> Result>, TranslateError> { + ast::visit_map(this , f) +} diff --git a/ptx/src/pass/normalize.rs b/ptx/src/pass/normalize.rs new file mode 100644 index 00000000..38326852 --- /dev/null +++ b/ptx/src/pass/normalize.rs @@ -0,0 +1,83 @@ +use super::*; +use ptx_parser as ast; + +type NormalizedStatement = Statement< + ( + Option>, + ast::Instruction>, + ), + ast::ParsedOperand, +>; + +fn run<'input, 'b>( + id_defs: &mut FnStringIdResolver<'input, 'b>, + fn_defs: &GlobalFnDeclResolver<'input, 'b>, + func: Vec>>, +) -> Result, TranslateError> { + for s in func.iter() { + match s { + ast::Statement::Label(id) => { + id_defs.add_def(*id, None, false); + } + _ => (), + } + } + let mut result = Vec::new(); + for s in func { + expand_map_variables(id_defs, fn_defs, &mut result, s)?; + } + Ok(result) +} + +fn expand_map_variables<'a, 'b>( + id_defs: &mut FnStringIdResolver<'a, 'b>, + fn_defs: &GlobalFnDeclResolver<'a, 'b>, + result: &mut Vec, + s: ast::Statement>, +) -> Result<(), TranslateError> { + match s { + ast::Statement::Block(block) => { + id_defs.start_block(); + for s in block { + expand_map_variables(id_defs, fn_defs, result, s)?; + } + id_defs.end_block(); + } + ast::Statement::Label(name) => result.push(Statement::Label(id_defs.get_id(name)?)), + ast::Statement::Instruction(p, i) => result.push(Statement::Instruction(( + p.map(|p| pred_map_variable(p, &mut |id| id_defs.get_id(id))) + 
.transpose()?, + op_map_variable(i, &mut |id| id_defs.get_id(id))?, + ))), + ast::Statement::Variable(var) => { + let var_type = var.var.v_type.clone(); + match var.count { + Some(count) => { + for new_id in + id_defs.add_defs(var.var.name, count, var_type, var.var.state_space, true) + { + result.push(Statement::Variable(ast::Variable { + align: var.var.align, + v_type: var.var.v_type.clone(), + state_space: var.var.state_space, + name: new_id, + array_init: var.var.array_init.clone(), + })) + } + } + None => { + let new_id = + id_defs.add_def(var.var.name, Some((var_type, var.var.state_space)), true); + result.push(Statement::Variable(ast::Variable { + align: var.var.align, + v_type: var.var.v_type.clone(), + state_space: var.var.state_space, + name: new_id, + array_init: var.var.array_init, + })); + } + } + } + }; + Ok(()) +} diff --git a/ptx/src/translate2.rs b/ptx/src/translate2.rs new file mode 100644 index 00000000..4ac5dea7 --- /dev/null +++ b/ptx/src/translate2.rs @@ -0,0 +1,60 @@ +use std::collections::HashMap; +use half::f16; +use ptx_parser as ast; + +fn to_ssa<'input, 'b>( + ptx_impl_imports: &'b mut HashMap>, + mut id_defs: FnStringIdResolver<'input, 'b>, + fn_defs: GlobalFnDeclResolver<'input, 'b>, + func_decl: Rc>>, + f_body: Option>>>, + tuning: Vec, + linkage: ast::LinkingDirective, +) -> Result, TranslateError> { + //deparamize_function_decl(&func_decl)?; + let f_body = match f_body { + Some(vec) => vec, + None => { + return Ok(Function { + func_decl: func_decl, + body: None, + globals: Vec::new(), + import_as: None, + tuning, + linkage, + }) + } + }; + let normalized_ids = normalize_identifiers(&mut id_defs, &fn_defs, f_body)?; + /* + let mut numeric_id_defs = id_defs.finish(); + let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?; + let typed_statements = + convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?; + let typed_statements = + fix_special_registers2(ptx_impl_imports, 
typed_statements, &mut numeric_id_defs)?; + let (func_decl, typed_statements) = + convert_to_stateful_memory_access(func_decl, typed_statements, &mut numeric_id_defs)?; + let ssa_statements = insert_mem_ssa_statements( + typed_statements, + &mut numeric_id_defs, + &mut (*func_decl).borrow_mut(), + )?; + let mut numeric_id_defs = numeric_id_defs.finish(); + let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?; + let expanded_statements = + insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?; + let mut numeric_id_defs = numeric_id_defs.unmut(); + let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs); + let (f_body, globals) = + extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs)?; + Ok(Function { + func_decl: func_decl, + globals: globals, + body: Some(f_body), + import_as: None, + tuning, + linkage, + }) + */ +} \ No newline at end of file diff --git a/ptx_parser/Cargo.toml b/ptx_parser/Cargo.toml index af3058b2..a4df14f1 100644 --- a/ptx_parser/Cargo.toml +++ b/ptx_parser/Cargo.toml @@ -4,6 +4,8 @@ version = "0.0.0" authors = ["Andrzej Janik "] edition = "2021" +[lib] + [dependencies] logos = "0.14" winnow = { version = "0.6.18" } @@ -11,3 +13,4 @@ ptx_parser_macros = { path = "../ptx_parser_macros" } thiserror = "1.0" bitflags = "1.2" rustc-hash = "2.0.0" +derive_more = { version = "1", features = ["display"] } diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 6cf12649..87a2f6ba 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -147,9 +147,9 @@ ptx_parser_macros::generate_instruction_type!( Call { data: CallDetails, arguments: CallArgs, - visit: arguments.visit(data, visitor), - visit_mut: arguments.visit_mut(data, visitor), - map: Instruction::Call{ arguments: arguments.map(&data, visitor), data } + visit: arguments.visit(data, visitor)?, + visit_mut: arguments.visit_mut(data, visitor)?, + map: Instruction::Call{ arguments: 
arguments.map(&data, visitor)?, data } }, Cvt { data: CvtDetails, @@ -488,93 +488,185 @@ ptx_parser_macros::generate_instruction_type!( } ); -pub trait Visitor { - fn visit(&mut self, args: &T, type_space: Option<(&Type, StateSpace)>, is_dst: bool); - fn visit_ident(&self, args: &T::Ident, type_space: Option<(&Type, StateSpace)>, is_dst: bool); +pub trait Visitor { + fn visit( + &mut self, + args: &T, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + ) -> Result<(), Err>; + fn visit_ident( + &mut self, + args: &T::Ident, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + ) -> Result<(), Err>; +} + +impl, bool) -> Result<(), Err>> + Visitor for Fn +{ + fn visit( + &mut self, + args: &T, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + ) -> Result<(), Err> { + (self)(args, type_space, is_dst) + } + + fn visit_ident( + &mut self, + args: &T::Ident, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + ) -> Result<(), Err> { + (self)(&T::from_ident(*args), type_space, is_dst) + } } -pub trait VisitorMut { - fn visit(&mut self, args: &mut T, type_space: Option<(&Type, StateSpace)>, is_dst: bool); +pub trait VisitorMut { + fn visit( + &mut self, + args: &mut T, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + ) -> Result<(), Err>; fn visit_ident( &mut self, args: &mut T::Ident, type_space: Option<(&Type, StateSpace)>, is_dst: bool, - ); + ) -> Result<(), Err>; } -pub trait VisitorMap { - fn visit(&mut self, args: From, type_space: Option<(&Type, StateSpace)>, is_dst: bool) -> To; +pub trait VisitorMap { + fn visit( + &mut self, + args: From, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + ) -> Result; fn visit_ident( &mut self, args: From::Ident, type_space: Option<(&Type, StateSpace)>, is_dst: bool, - ) -> To::Ident; + ) -> Result; } -trait VisitOperand { +impl< + T: Operand, + U: Operand, + Err, + Fn: FnMut(T, Option<(&Type, StateSpace)>, bool) -> Result, + > VisitorMap for Fn +{ + fn visit( + &mut self, 
+ args: T, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + ) -> Result { + (self)(args, type_space, is_dst) + } + + fn visit_ident( + &mut self, + args: T::Ident, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + ) -> Result { + let value: U = (self)(T::from_ident(args), type_space, is_dst)?; + Ok(value) + } +} + +trait VisitOperand { type Operand: Operand; #[allow(unused)] // Used by generated code - fn visit(&self, fn_: impl FnMut(&Self::Operand)); + fn visit(&self, fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err>; #[allow(unused)] // Used by generated code - fn visit_mut(&mut self, fn_: impl FnMut(&mut Self::Operand)); + fn visit_mut( + &mut self, + fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>, + ) -> Result<(), Err>; } -impl VisitOperand for T { +impl VisitOperand for T { type Operand = Self; - fn visit(&self, mut fn_: impl FnMut(&Self::Operand)) { + fn visit(&self, mut fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err> { fn_(self) } - fn visit_mut(&mut self, mut fn_: impl FnMut(&mut Self::Operand)) { + fn visit_mut( + &mut self, + mut fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>, + ) -> Result<(), Err> { fn_(self) } } -impl VisitOperand for Option { +impl VisitOperand for Option { type Operand = T; - fn visit(&self, fn_: impl FnMut(&Self::Operand)) { - self.as_ref().map(fn_); + fn visit(&self, mut fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err> { + if let Some(x) = self { + fn_(x)?; + } + Ok(()) } - fn visit_mut(&mut self, fn_: impl FnMut(&mut Self::Operand)) { - self.as_mut().map(fn_); + fn visit_mut( + &mut self, + mut fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>, + ) -> Result<(), Err> { + if let Some(x) = self { + fn_(x)?; + } + Ok(()) } } -impl VisitOperand for Vec { +impl VisitOperand for Vec { type Operand = T; - fn visit(&self, mut fn_: impl FnMut(&Self::Operand)) { + fn visit(&self, mut fn_: impl FnMut(&Self::Operand) -> Result<(), 
Err>) -> Result<(), Err> { for o in self { - fn_(o) + fn_(o)?; } + Ok(()) } - fn visit_mut(&mut self, mut fn_: impl FnMut(&mut Self::Operand)) { + fn visit_mut( + &mut self, + mut fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>, + ) -> Result<(), Err> { for o in self { - fn_(o) + fn_(o)?; } + Ok(()) } } -trait MapOperand: Sized { +trait MapOperand: Sized { type Input; type Output; #[allow(unused)] // Used by generated code - fn map(self, fn_: impl FnOnce(Self::Input) -> U) -> Self::Output; + fn map( + self, + fn_: impl FnOnce(Self::Input) -> Result, + ) -> Result, Err>; } -impl MapOperand for T { +impl MapOperand for T { type Input = Self; type Output = U; - fn map(self, fn_: impl FnOnce(T) -> U) -> U { + fn map(self, fn_: impl FnOnce(T) -> Result) -> Result { fn_(self) } } -impl MapOperand for Option { +impl MapOperand for Option { type Input = T; type Output = Option; - fn map(self, fn_: impl FnOnce(T) -> U) -> Option { - self.map(|x| fn_(x)) + fn map(self, fn_: impl FnOnce(T) -> Result) -> Result, Err> { + self.map(|x| fn_(x)).transpose() } } @@ -715,10 +807,16 @@ pub enum ParsedOperand { impl Operand for ParsedOperand { type Ident = Ident; + + fn from_ident(ident: Self::Ident) -> Self { + ParsedOperand::Reg(ident) + } } -pub trait Operand { +pub trait Operand: Sized { type Ident: Copy; + + fn from_ident(ident: Self::Ident) -> Self; } #[derive(Copy, Clone)] @@ -1048,67 +1146,77 @@ pub struct CallArgs { impl CallArgs { #[allow(dead_code)] // Used by generated code - fn visit(&self, details: &CallDetails, visitor: &mut impl Visitor) { + fn visit( + &self, + details: &CallDetails, + visitor: &mut impl Visitor, + ) -> Result<(), Err> { for (param, (type_, space)) in self .return_arguments .iter() .zip(details.return_arguments.iter()) { - visitor.visit_ident(param, Some((type_, *space)), true); + visitor.visit_ident(param, Some((type_, *space)), true)?; } - visitor.visit_ident(&self.func, None, false); + visitor.visit_ident(&self.func, None, false)?; for 
(param, (type_, space)) in self .input_arguments .iter() .zip(details.input_arguments.iter()) { - visitor.visit(param, Some((type_, *space)), true); + visitor.visit(param, Some((type_, *space)), true)?; } + Ok(()) } #[allow(dead_code)] // Used by generated code - fn visit_mut(&mut self, details: &CallDetails, visitor: &mut impl VisitorMut) { + fn visit_mut( + &mut self, + details: &CallDetails, + visitor: &mut impl VisitorMut, + ) -> Result<(), Err> { for (param, (type_, space)) in self .return_arguments .iter_mut() .zip(details.return_arguments.iter()) { - visitor.visit_ident(param, Some((type_, *space)), true); + visitor.visit_ident(param, Some((type_, *space)), true)?; } - visitor.visit_ident(&mut self.func, None, false); + visitor.visit_ident(&mut self.func, None, false)?; for (param, (type_, space)) in self .input_arguments .iter_mut() .zip(details.input_arguments.iter()) { - visitor.visit(param, Some((type_, *space)), true); + visitor.visit(param, Some((type_, *space)), true)?; } + Ok(()) } #[allow(dead_code)] // Used by generated code - fn map( + fn map( self, details: &CallDetails, - visitor: &mut impl VisitorMap, - ) -> CallArgs { + visitor: &mut impl VisitorMap, + ) -> Result, Err> { let return_arguments = self .return_arguments .into_iter() .zip(details.return_arguments.iter()) .map(|(param, (type_, space))| visitor.visit_ident(param, Some((type_, *space)), true)) - .collect::>(); - let func = visitor.visit_ident(self.func, None, false); + .collect::, _>>()?; + let func = visitor.visit_ident(self.func, None, false)?; let input_arguments = self .input_arguments .into_iter() .zip(details.input_arguments.iter()) .map(|(param, (type_, space))| visitor.visit(param, Some((type_, *space)), true)) - .collect::>(); - CallArgs { + .collect::, _>>()?; + Ok(CallArgs { return_arguments, func, input_arguments, - } + }) } } diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/lib.rs similarity index 98% rename from ptx_parser/src/main.rs rename to ptx_parser/src/lib.rs 
index 5db94f23..cfb87939 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/lib.rs @@ -1,8 +1,8 @@ +use derive_more::Display; use logos::Logos; use ptx_parser_macros::derive_parser; use rustc_hash::FxHashMap; use std::fmt::Debug; -use std::mem; use std::num::{ParseFloatError, ParseIntError}; use winnow::ascii::dec_uint; use winnow::combinator::*; @@ -81,16 +81,16 @@ impl VectorPrefix { } } -struct PtxParserState<'input> { - errors: Vec, +struct PtxParserState<'a, 'input> { + errors: &'a mut Vec, function_declarations: FxHashMap<&'input str, (Vec<(ast::Type, StateSpace)>, Vec<(ast::Type, StateSpace)>)>, } -impl<'input> PtxParserState<'input> { - fn new() -> Self { +impl<'a, 'input> PtxParserState<'a, 'input> { + fn new(errors: &'a mut Vec) -> Self { Self { - errors: Vec::new(), + errors, function_declarations: FxHashMap::default(), } } @@ -115,7 +115,7 @@ impl<'input> PtxParserState<'input> { } } -impl<'input> Debug for PtxParserState<'input> { +impl<'a, 'input> Debug for PtxParserState<'a, 'input> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("PtxParserState") .field("errors", &self.errors) /* .field("function_decl", &self.function_decl) */ @@ -123,7 +123,7 @@ impl<'input> Debug for PtxParserState<'input> { } } -type PtxParser<'a, 'input> = Stateful<&'a [Token<'input>], PtxParserState<'input>>; +type PtxParser<'a, 'input> = Stateful<&'a [Token<'input>], PtxParserState<'a, 'input>>; fn ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input str> { any.verify_map(|t| { @@ -277,6 +277,18 @@ fn immediate_value<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult(text: &'input str) -> Option> { + let lexer = Token::lexer(text); + let input = lexer.collect::, _>>().ok()?; + let mut errors = Vec::new(); + let state = PtxParserState::new(&mut errors); + let parser = PtxParser { + state, + input: &input[..], + }; + module.parse(parser).ok() +} + fn module<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> 
PResult> { ( version, @@ -818,6 +830,8 @@ pub enum PtxError { source: ParseFloatError, }, #[error("")] + Lexer(#[from] TokenError), + #[error("")] Todo, #[error("")] SyntaxError, @@ -1042,9 +1056,15 @@ fn empty_call<'input>( type ParsedOperandStr<'input> = ast::ParsedOperand<&'input str>; +#[derive(Clone, PartialEq, Default, Debug, Display)] +pub struct TokenError; + +impl std::error::Error for TokenError {} + derive_parser!( #[derive(Logos, PartialEq, Eq, Debug, Clone, Copy)] #[logos(skip r"\s+")] + #[logos(error = TokenError)] enum Token<'input> { #[token(",")] Comma, @@ -1134,6 +1154,7 @@ derive_parser!( pub enum StateSpace { Reg, Generic, + Sreg, } #[derive(Copy, Clone, PartialEq, Eq, Hash)] @@ -2825,57 +2846,6 @@ derive_parser!( ); -fn main() { - use winnow::Parser; - - let lexer = Token::lexer( - " - .version 6.5 - .target sm_30 - .address_size 64 - - .const .align 8 .b32 constparams; - - .visible .entry const( - .param .u64 input, - .param .u64 output - ) - { - .reg .u64 in_addr; - .reg .u64 out_addr; - .reg .b16 temp1; - .reg .b16 temp2; - .reg .b16 temp3; - .reg .b16 temp4; - - ld.param.u64 in_addr, [input]; - ld.param.u64 out_addr, [output]; - - ld.const.b16 temp1, [constparams]; - ld.const.b16 temp2, [constparams+2]; - ld.const.b16 temp3, [constparams+4]; - ld.const.b16 temp4, [constparams+6]; - st.u16 [out_addr], temp1; - st.u16 [out_addr+2], temp2; - st.u16 [out_addr+4], temp3; - st.u16 [out_addr+6], temp4; - ret; - } - - ", - ); - let tokens = lexer.clone().collect::>(); - println!("{:?}", &tokens); - let tokens = lexer.map(|t| t.unwrap()).collect::>(); - println!("{:?}", &tokens); - let stream = PtxParser { - input: &tokens[..], - state: PtxParserState::new(), - }; - let _module = module.parse(stream).unwrap(); - println!("{}", mem::size_of::()); -} - #[cfg(test)] mod tests { use super::target; diff --git a/ptx_parser_macros/src/lib.rs b/ptx_parser_macros/src/lib.rs index a2f8396f..4502c953 100644 --- a/ptx_parser_macros/src/lib.rs +++ 
b/ptx_parser_macros/src/lib.rs @@ -1017,7 +1017,7 @@ pub fn generate_instruction_type(tokens: proc_macro::TokenStream) -> proc_macro: input.emit_arg_types(&mut result); input.emit_instruction_type(&mut result); input.emit_visit(&mut result); - input.emit_visit_mut(&mut result); + //input.emit_visit_mut(&mut result); input.emit_visit_map(&mut result); result.into() } diff --git a/ptx_parser_macros_impl/src/lib.rs b/ptx_parser_macros_impl/src/lib.rs index 4532964b..3e53607e 100644 --- a/ptx_parser_macros_impl/src/lib.rs +++ b/ptx_parser_macros_impl/src/lib.rs @@ -67,37 +67,29 @@ impl GenerateInstructionType { let visit_ref = kind.reference(); let visitor_type = format_ident!("Visitor{}", kind.type_suffix()); let visit_fn = format_ident!("visit{}", kind.fn_suffix()); - let visit_slice_fn = format_ident!("visit{}_slice", kind.fn_suffix()); let (type_parameters, visitor_parameters, return_type) = if kind == VisitKind::Map { ( - quote! { <#type_parameters, To: Operand> }, - quote! { <#short_parameters, To> }, - quote! { #type_name }, + quote! { <#type_parameters, To: Operand, Err> }, + quote! { <#short_parameters, To, Err> }, + quote! { std::result::Result<#type_name, Err> }, ) } else { ( - quote! { <#type_parameters> }, - quote! { <#short_parameters> }, - quote! { () }, + quote! { <#type_parameters, Err> }, + quote! { <#short_parameters, Err> }, + quote! { std::result::Result<(), Err> }, ) }; quote! { - fn #visit_fn #type_parameters (i: #visit_ref #type_name<#short_parameters>, visitor: &mut impl #visitor_type #visitor_parameters ) -> #return_type { - match i { + pub fn #visit_fn #type_parameters (i: #visit_ref #type_name<#short_parameters>, visitor: &mut impl #visitor_type #visitor_parameters ) -> #return_type { + Ok(match i { #inner_tokens - } + }) } }.to_tokens(tokens); if kind == VisitKind::Map { return; } - quote! 
{ - fn #visit_slice_fn #type_parameters (instructions: #visit_ref [#type_name<#short_parameters>], visitor: &mut impl #visitor_type #visitor_parameters) { - for i in instructions { - #visit_fn(i, visitor) - } - } - }.to_tokens(tokens); } } @@ -630,14 +622,14 @@ impl ArgumentField { quote! { { #type_space - visitor.visit_ident(&mut arguments.#name, type_space, #is_dst); + visitor.visit_ident(&mut arguments.#name, type_space, #is_dst)?; } } } else { quote! { { #type_space - visitor.visit_ident(& arguments.#name, type_space, #is_dst); + visitor.visit_ident(& arguments.#name, type_space, #is_dst)?; } } } @@ -663,7 +655,7 @@ impl ArgumentField { }; quote! {{ #type_space - #operand_fn(#arguments_name, |x| visitor.visit(x, type_space, #is_dst)); + #operand_fn(#arguments_name, |x| visitor.visit(x, type_space, #is_dst))?; }} } } @@ -701,11 +693,11 @@ impl ArgumentField { }; let map_call = if is_ident { quote! { - visitor.visit_ident(arguments.#name, type_space, #is_dst) + visitor.visit_ident(arguments.#name, type_space, #is_dst)? } } else { quote! { - MapOperand::map(arguments.#name, |x| visitor.visit(x, type_space, #is_dst)) + MapOperand::map(arguments.#name, |x| visitor.visit(x, type_space, #is_dst))? } }; quote! 
{ From 12ef8dbc9001451f4892c49903844826e2f06fcf Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 23 Aug 2024 03:03:57 +0200 Subject: [PATCH 26/47] Port first pass --- ptx/src/lib.rs | 3 +- ptx/src/pass/mod.rs | 396 ++++++++++++++++++++++++++++++++++---- ptx/src/pass/normalize.rs | 8 +- ptx/src/translate2.rs | 60 ------ ptx_parser/src/ast.rs | 62 +++++- 5 files changed, 421 insertions(+), 108 deletions(-) delete mode 100644 ptx/src/translate2.rs diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs index b70019ea..5e95dae2 100644 --- a/ptx/src/lib.rs +++ b/ptx/src/lib.rs @@ -24,11 +24,10 @@ lalrpop_mod!( ); pub mod ast; -mod pass; +pub(crate) mod pass; #[cfg(test)] mod test; mod translate; -mod translate2; use std::fmt; diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 7b794d6d..934a4726 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -1,12 +1,347 @@ use ptx_parser as ast; +use rspirv::{binary::Assemble, dr}; use std::{ borrow::Cow, cell::RefCell, collections::{hash_map, HashMap}, + ffi::CString, rc::Rc, }; -mod normalize; +pub(crate) mod normalize; + +static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv"); +static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc"); +const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl__"; + +pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result { + let mut id_defs = GlobalStringIdResolver::<'input>::new(SpirvWord(1)); + let mut ptx_impl_imports = HashMap::new(); + let directives = ast + .directives + .into_iter() + .filter_map(|directive| { + translate_directive(&mut id_defs, &mut ptx_impl_imports, directive).transpose() + }) + .collect::, _>>()?; + /* + let directives = hoist_function_globals(directives); + let must_link_ptx_impl = ptx_impl_imports.len() > 0; + let mut directives = ptx_impl_imports + .into_iter() + .map(|(_, v)| v) + .chain(directives.into_iter()) + .collect::>(); + let mut builder = dr::Builder::new(); + 
builder.reserve_ids(id_defs.current_id()); + let call_map = MethodsCallMap::new(&directives); + let mut directives = + convert_dynamic_shared_memory_usage(directives, &call_map, &mut || builder.id()); + normalize_variable_decls(&mut directives); + let denorm_information = compute_denorm_information(&directives); + // https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module + builder.set_version(1, 3); + emit_capabilities(&mut builder); + emit_extensions(&mut builder); + let opencl_id = emit_opencl_import(&mut builder); + emit_memory_model(&mut builder); + let mut map = TypeWordMap::new(&mut builder); + //emit_builtins(&mut builder, &mut map, &id_defs); + let mut kernel_info = HashMap::new(); + let (build_options, should_flush_denorms) = + emit_denorm_build_string(&call_map, &denorm_information); + let (directives, globals_use_map) = get_globals_use_map(directives); + emit_directives( + &mut builder, + &mut map, + &id_defs, + opencl_id, + should_flush_denorms, + &call_map, + globals_use_map, + directives, + &mut kernel_info, + )?; + let spirv = builder.module(); + Ok(Module { + spirv, + kernel_info, + should_link_ptx_impl: if must_link_ptx_impl { + Some((ZLUDA_PTX_IMPL_INTEL, ZLUDA_PTX_IMPL_AMD)) + } else { + None + }, + build_options, + }) + */ + todo!() +} + +fn translate_directive<'input, 'a>( + id_defs: &'a mut GlobalStringIdResolver<'input>, + ptx_impl_imports: &'a mut HashMap>, + d: ast::Directive<'input, ast::ParsedOperand<&'input str>>, +) -> Result>, TranslateError> { + Ok(match d { + ast::Directive::Variable(linking, var) => Some(Directive::Variable( + linking, + ast::Variable { + align: var.align, + v_type: var.v_type.clone(), + state_space: var.state_space, + name: id_defs.get_or_add_def_typed(var.name, var.v_type, var.state_space, true), + array_init: var.array_init, + }, + )), + ast::Directive::Method(linkage, f) => { + translate_function(id_defs, ptx_impl_imports, linkage, 
f)?.map(Directive::Method) + } + }) +} + +type ParsedFunction<'a> = ast::Function<'a, &'a str, ast::Statement>>; + +fn translate_function<'input, 'a>( + id_defs: &'a mut GlobalStringIdResolver<'input>, + ptx_impl_imports: &'a mut HashMap>, + linkage: ast::LinkingDirective, + f: ParsedFunction<'input>, +) -> Result>, TranslateError> { + let import_as = match &f.func_directive { + ast::MethodDeclaration { + name: ast::MethodName::Func(func_name), + .. + } if *func_name == "__assertfail" || *func_name == "vprintf" => { + Some([ZLUDA_PTX_PREFIX, func_name].concat()) + } + _ => None, + }; + let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive)?; + let mut func = to_ssa( + ptx_impl_imports, + str_resolver, + fn_resolver, + fn_decl, + f.body, + f.tuning, + linkage, + )?; + func.import_as = import_as; + if func.import_as.is_some() { + ptx_impl_imports.insert( + func.import_as.as_ref().unwrap().clone(), + Directive::Method(func), + ); + Ok(None) + } else { + Ok(Some(func)) + } +} + +fn to_ssa<'input, 'b>( + ptx_impl_imports: &'b mut HashMap>, + mut id_defs: FnStringIdResolver<'input, 'b>, + fn_defs: GlobalFnDeclResolver<'input, 'b>, + func_decl: Rc>>, + f_body: Option>>>, + tuning: Vec, + linkage: ast::LinkingDirective, +) -> Result, TranslateError> { + //deparamize_function_decl(&func_decl)?; + let f_body = match f_body { + Some(vec) => vec, + None => { + return Ok(Function { + func_decl: func_decl, + body: None, + globals: Vec::new(), + import_as: None, + tuning, + linkage, + }) + } + }; + let normalized_ids = normalize::run(&mut id_defs, &fn_defs, f_body)?; + todo!() + /* + let mut numeric_id_defs = id_defs.finish(); + let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?; + let typed_statements = + convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?; + let typed_statements = + fix_special_registers2(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?; + let (func_decl, 
typed_statements) = + convert_to_stateful_memory_access(func_decl, typed_statements, &mut numeric_id_defs)?; + let ssa_statements = insert_mem_ssa_statements( + typed_statements, + &mut numeric_id_defs, + &mut (*func_decl).borrow_mut(), + )?; + let mut numeric_id_defs = numeric_id_defs.finish(); + let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?; + let expanded_statements = + insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?; + let mut numeric_id_defs = numeric_id_defs.unmut(); + let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs); + let (f_body, globals) = + extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs)?; + Ok(Function { + func_decl: func_decl, + globals: globals, + body: Some(f_body), + import_as: None, + tuning, + linkage, + }) + */ +} + +pub struct Module { + pub spirv: dr::Module, + pub kernel_info: HashMap, + pub should_link_ptx_impl: Option<(&'static [u8], &'static [u8])>, + pub build_options: CString, +} + +impl Module { + pub fn assemble(&self) -> Vec { + self.spirv.assemble() + } +} + +struct GlobalStringIdResolver<'input> { + current_id: SpirvWord, + variables: HashMap, SpirvWord>, + reverse_variables: HashMap, + variables_type_check: HashMap>, + special_registers: SpecialRegistersMap, + fns: HashMap>, +} + +impl<'input> GlobalStringIdResolver<'input> { + fn new(start_id: SpirvWord) -> Self { + Self { + current_id: start_id, + variables: HashMap::new(), + reverse_variables: HashMap::new(), + variables_type_check: HashMap::new(), + special_registers: SpecialRegistersMap::new(), + fns: HashMap::new(), + } + } + + fn get_or_add_def(&mut self, id: &'input str) -> SpirvWord { + self.get_or_add_impl(id, None) + } + + fn get_or_add_def_typed( + &mut self, + id: &'input str, + typ: ast::Type, + state_space: ast::StateSpace, + is_variable: bool, + ) -> SpirvWord { + self.get_or_add_impl(id, Some((typ, state_space, is_variable))) + } + + fn 
get_or_add_impl( + &mut self, + id: &'input str, + typ: Option<(ast::Type, ast::StateSpace, bool)>, + ) -> SpirvWord { + let id = match self.variables.entry(Cow::Borrowed(id)) { + hash_map::Entry::Occupied(e) => *(e.get()), + hash_map::Entry::Vacant(e) => { + let numeric_id = self.current_id; + e.insert(numeric_id); + self.reverse_variables.insert(numeric_id, id); + self.current_id.0 += 1; + numeric_id + } + }; + self.variables_type_check.insert(id, typ); + id + } + + fn get_id(&self, id: &str) -> Result { + self.variables + .get(id) + .copied() + .ok_or_else(error_unknown_symbol) + } + + fn current_id(&self) -> SpirvWord { + self.current_id + } + + fn start_fn<'b>( + &'b mut self, + header: &'b ast::MethodDeclaration<'input, &'input str>, + ) -> Result< + ( + FnStringIdResolver<'input, 'b>, + GlobalFnDeclResolver<'input, 'b>, + Rc>>, + ), + TranslateError, + > { + // In case a function decl was inserted earlier we want to use its id + let name_id = self.get_or_add_def(header.name()); + let mut fn_resolver = FnStringIdResolver { + current_id: &mut self.current_id, + global_variables: &self.variables, + global_type_check: &self.variables_type_check, + special_registers: &mut self.special_registers, + variables: vec![HashMap::new(); 1], + type_check: HashMap::new(), + }; + let return_arguments = rename_fn_params(&mut fn_resolver, &header.return_arguments); + let input_arguments = rename_fn_params(&mut fn_resolver, &header.input_arguments); + let name = match header.name { + ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name), + ast::MethodName::Func(_) => ast::MethodName::Func(name_id), + }; + let fn_decl = ast::MethodDeclaration { + return_arguments, + name, + input_arguments, + shared_mem: None, + }; + let new_fn_decl = if !matches!(fn_decl.name, ast::MethodName::Kernel(_)) { + let resolver = FnSigMapper::remap_to_spirv_repr(fn_decl); + let new_fn_decl = resolver.func_decl.clone(); + self.fns.insert(name_id, resolver); + new_fn_decl + } else { + 
Rc::new(RefCell::new(fn_decl)) + }; + Ok(( + fn_resolver, + GlobalFnDeclResolver { fns: &self.fns }, + new_fn_decl, + )) + } +} + +fn rename_fn_params<'a, 'b>( + fn_resolver: &mut FnStringIdResolver<'a, 'b>, + args: &'b [ast::Variable<&'a str>], +) -> Vec> { + args.iter() + .map(|a| ast::Variable { + name: fn_resolver.add_def(a.name, Some((a.v_type.clone(), a.state_space)), true), + v_type: a.v_type.clone(), + state_space: a.state_space, + align: a.align, + array_init: a.array_init.clone(), + }) + .collect() +} + +pub struct KernelInfo { + pub arguments_sizes: Vec<(usize, bool)>, + pub uses_shared_mem: bool, +} #[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)] enum PtxSpecialRegister { @@ -108,10 +443,10 @@ impl SpecialRegistersMap { struct FnStringIdResolver<'input, 'b> { current_id: &'b mut SpirvWord, global_variables: &'b HashMap, SpirvWord>, - global_type_check: &'b HashMap>, + global_type_check: &'b HashMap>, special_registers: &'b mut SpecialRegistersMap, variables: Vec, SpirvWord>>, - type_check: HashMap>, + type_check: HashMap>, } impl<'a, 'b> FnStringIdResolver<'a, 'b> { @@ -160,7 +495,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { .unwrap() .insert(Cow::Borrowed(id), numeric_id); self.type_check.insert( - numeric_id.0, + numeric_id, typ.map(|(typ, space)| (typ, space, is_variable)), ); self.current_id.0 += 1; @@ -183,7 +518,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { SpirvWord(numeric_id.0 + i), ); self.type_check.insert( - numeric_id.0 + i, + SpirvWord(numeric_id.0 + i), Some((typ.clone(), state_space, is_variable)), ); } @@ -196,8 +531,8 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { struct NumericIdResolver<'b> { current_id: &'b mut SpirvWord, - global_type_check: &'b HashMap>, - type_check: HashMap>, + global_type_check: &'b HashMap>, + type_check: HashMap>, special_registers: &'b mut SpecialRegistersMap, } @@ -210,12 +545,12 @@ impl<'b> NumericIdResolver<'b> { &self, id: SpirvWord, ) -> Result<(ast::Type, ast::StateSpace, bool), 
TranslateError> { - match self.type_check.get(&id.0) { + match self.type_check.get(&id) { Some(Some(x)) => Ok(x.clone()), Some(None) => Err(TranslateError::UntypedSymbol), None => match self.special_registers.get(id) { Some(x) => Ok((x.get_type(), ast::StateSpace::Sreg, true)), - None => match self.global_type_check.get(&id.0) { + None => match self.global_type_check.get(&id) { Some(Some(result)) => Ok(result.clone()), Some(None) | None => Err(TranslateError::UntypedSymbol), }, @@ -228,7 +563,7 @@ impl<'b> NumericIdResolver<'b> { fn register_variable(&mut self, typ: ast::Type, state_space: ast::StateSpace) -> SpirvWord { let new_id = *self.current_id; self.type_check - .insert(new_id.0, Some((typ, state_space, true))); + .insert(new_id, Some((typ, state_space, true))); self.current_id.0 += 1; new_id } @@ -236,7 +571,7 @@ impl<'b> NumericIdResolver<'b> { fn register_intermediate(&mut self, typ: Option<(ast::Type, ast::StateSpace)>) -> SpirvWord { let new_id = *self.current_id; self.type_check - .insert(new_id.0, typ.map(|(t, space)| (t, space, false))); + .insert(new_id, typ.map(|(t, space)| (t, space, false))); self.current_id.0 += 1; new_id } @@ -490,6 +825,10 @@ impl From for spirv::Word { impl ast::Operand for SpirvWord { type Ident = Self; + + fn from_ident(ident: Self::Ident) -> Self { + ident + } } fn pred_map_variable Result>( @@ -503,29 +842,18 @@ fn pred_map_variable Result>( }) } -impl Result, Err> ast::VisitorMap for X { - fn visit( - &mut self, - args: T, - type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, - is_dst: bool, - ) -> U { - todo!() - } - - fn visit_ident( - &mut self, - args: ::Ident, - type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, - is_dst: bool, - ) -> ::Ident { - todo!() - } +pub(crate) enum Directive<'input> { + Variable(ast::LinkingDirective, ast::Variable), + Method(Function<'input>), } -fn op_map_variable<'a, F: FnMut(&str) -> Result>( - this: ast::Instruction>, - f: &mut F, -) -> Result>, 
TranslateError> { - ast::visit_map(this , f) +pub(crate) struct Function<'input> { + pub func_decl: Rc>>, + pub globals: Vec>, + pub body: Option>, + import_as: Option, + tuning: Vec, + linkage: ast::LinkingDirective, } + +type ExpandedStatement = Statement, SpirvWord>; \ No newline at end of file diff --git a/ptx/src/pass/normalize.rs b/ptx/src/pass/normalize.rs index 38326852..68ac26ea 100644 --- a/ptx/src/pass/normalize.rs +++ b/ptx/src/pass/normalize.rs @@ -9,7 +9,7 @@ type NormalizedStatement = Statement< ast::ParsedOperand, >; -fn run<'input, 'b>( +pub(crate) fn run<'input, 'b>( id_defs: &mut FnStringIdResolver<'input, 'b>, fn_defs: &GlobalFnDeclResolver<'input, 'b>, func: Vec>>, @@ -47,7 +47,11 @@ fn expand_map_variables<'a, 'b>( ast::Statement::Instruction(p, i) => result.push(Statement::Instruction(( p.map(|p| pred_map_variable(p, &mut |id| id_defs.get_id(id))) .transpose()?, - op_map_variable(i, &mut |id| id_defs.get_id(id))?, + ast::visit_map(i, &mut |id, + _: Option<(&ast::Type, ast::StateSpace)>, + _: bool| { + id_defs.get_id(id) + })?, ))), ast::Statement::Variable(var) => { let var_type = var.var.v_type.clone(); diff --git a/ptx/src/translate2.rs b/ptx/src/translate2.rs deleted file mode 100644 index 4ac5dea7..00000000 --- a/ptx/src/translate2.rs +++ /dev/null @@ -1,60 +0,0 @@ -use std::collections::HashMap; -use half::f16; -use ptx_parser as ast; - -fn to_ssa<'input, 'b>( - ptx_impl_imports: &'b mut HashMap>, - mut id_defs: FnStringIdResolver<'input, 'b>, - fn_defs: GlobalFnDeclResolver<'input, 'b>, - func_decl: Rc>>, - f_body: Option>>>, - tuning: Vec, - linkage: ast::LinkingDirective, -) -> Result, TranslateError> { - //deparamize_function_decl(&func_decl)?; - let f_body = match f_body { - Some(vec) => vec, - None => { - return Ok(Function { - func_decl: func_decl, - body: None, - globals: Vec::new(), - import_as: None, - tuning, - linkage, - }) - } - }; - let normalized_ids = normalize_identifiers(&mut id_defs, &fn_defs, f_body)?; - /* - let mut 
numeric_id_defs = id_defs.finish(); - let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?; - let typed_statements = - convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?; - let typed_statements = - fix_special_registers2(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?; - let (func_decl, typed_statements) = - convert_to_stateful_memory_access(func_decl, typed_statements, &mut numeric_id_defs)?; - let ssa_statements = insert_mem_ssa_statements( - typed_statements, - &mut numeric_id_defs, - &mut (*func_decl).borrow_mut(), - )?; - let mut numeric_id_defs = numeric_id_defs.finish(); - let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?; - let expanded_statements = - insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?; - let mut numeric_id_defs = numeric_id_defs.unmut(); - let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs); - let (f_body, globals) = - extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs)?; - Ok(Function { - func_decl: func_decl, - globals: globals, - body: Some(f_body), - import_as: None, - tuning, - linkage, - }) - */ -} \ No newline at end of file diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 87a2f6ba..ee9f9681 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -555,12 +555,46 @@ pub trait VisitorMap { ) -> Result; } -impl< - T: Operand, - U: Operand, - Err, - Fn: FnMut(T, Option<(&Type, StateSpace)>, bool) -> Result, - > VisitorMap for Fn +impl VisitorMap, ParsedOperand, Err> for Fn +where + Fn: FnMut(T, Option<(&Type, StateSpace)>, bool) -> Result, +{ + fn visit( + &mut self, + args: ParsedOperand, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + ) -> Result, Err> { + Ok(match args { + ParsedOperand::Reg(ident) => ParsedOperand::Reg((self)(ident, type_space, is_dst)?), + ParsedOperand::RegOffset(ident, imm) => { + 
ParsedOperand::RegOffset((self)(ident, type_space, is_dst)?, imm) + } + ParsedOperand::Imm(imm) => ParsedOperand::Imm(imm), + ParsedOperand::VecMember(ident, index) => { + ParsedOperand::VecMember((self)(ident, type_space, is_dst)?, index) + } + ParsedOperand::VecPack(vec) => ParsedOperand::VecPack( + vec.into_iter() + .map(|ident| (self)(ident, type_space, is_dst)) + .collect::, _>>()?, + ), + }) + } + + fn visit_ident( + &mut self, + args: T, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + ) -> Result { + (self)(args, type_space, is_dst) + } +} + +impl, U: Operand, Err, Fn> VisitorMap for Fn +where + Fn: FnMut(T, Option<(&Type, StateSpace)>, bool) -> Result, { fn visit( &mut self, @@ -573,12 +607,11 @@ impl< fn visit_ident( &mut self, - args: T::Ident, + args: T, type_space: Option<(&Type, StateSpace)>, is_dst: bool, - ) -> Result { - let value: U = (self)(T::from_ident(args), type_space, is_dst)?; - Ok(value) + ) -> Result { + (self)(args, type_space, is_dst) } } @@ -925,6 +958,15 @@ pub struct MethodDeclaration<'input, ID> { pub shared_mem: Option, } +impl<'input> MethodDeclaration<'input, &'input str> { + pub fn name(&self) -> &'input str { + match self.name { + MethodName::Kernel(n) => n, + MethodName::Func(n) => n, + } + } +} + #[derive(Hash, PartialEq, Eq, Copy, Clone)] pub enum MethodName<'input, ID> { Kernel(&'input str), From 7ea990edb72c0635c5e59a263fac8385a6208e8b Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 23 Aug 2024 16:26:28 +0200 Subject: [PATCH 27/47] Work on more passes --- ptx/src/pass/convert_to_typed.rs | 140 ++++++++++++++++++ ptx/src/pass/mod.rs | 45 +++++- ...{normalize.rs => normalize_identifiers.rs} | 8 - ptx/src/pass/normalize_predicates.rs | 44 ++++++ 4 files changed, 222 insertions(+), 15 deletions(-) create mode 100644 ptx/src/pass/convert_to_typed.rs rename ptx/src/pass/{normalize.rs => normalize_identifiers.rs} (91%) create mode 100644 ptx/src/pass/normalize_predicates.rs diff --git 
a/ptx/src/pass/convert_to_typed.rs b/ptx/src/pass/convert_to_typed.rs new file mode 100644 index 00000000..3dfef55b --- /dev/null +++ b/ptx/src/pass/convert_to_typed.rs @@ -0,0 +1,140 @@ +use super::*; +use ptx_parser as ast; + +pub(crate) fn run( + func: Vec, + fn_defs: &GlobalFnDeclResolver, + id_defs: &mut NumericIdResolver, +) -> Result, TranslateError> { + let mut result = Vec::::with_capacity(func.len()); + for s in func { + match s { + Statement::Instruction(inst) => match inst { + ast::Instruction::Mov { + data, + arguments: + ast::MovArgs { + dst: ast::ParsedOperand::Reg(dst_reg), + src: ast::ParsedOperand::Reg(src_reg), + }, + } if fn_defs.fns.contains_key(&src_reg) => { + if data.typ != ast::Type::Scalar(ast::ScalarType::U64) { + return Err(TranslateError::MismatchedType); + } + result.push(TypedStatement::FunctionPointer(FunctionPointerDetails { + dst: dst_reg, + src: src_reg, + })); + } + ast::Instruction::Call(call) => { + let resolver = fn_defs.get_fn_sig_resolver(call.func)?; + let resolved_call = resolver.resolve_in_spirv_repr(call)?; + let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); + let reresolved_call = resolved_call.visit(&mut visitor)?; + visitor.func.push(reresolved_call); + visitor.func.extend(visitor.post_stmts); + } + inst => { + let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); + let instruction = Statement::Instruction(inst.map(&mut visitor)?); + visitor.func.push(instruction); + visitor.func.extend(visitor.post_stmts); + } + }, + Statement::Label(i) => result.push(Statement::Label(i)), + Statement::Variable(v) => result.push(Statement::Variable(v)), + Statement::Conditional(c) => result.push(Statement::Conditional(c)), + _ => return Err(error_unreachable()), + } + } + Ok(result) +} + +struct VectorRepackVisitor<'a, 'b> { + func: &'b mut Vec, + id_def: &'b mut NumericIdResolver<'a>, + post_stmts: Option, +} + +impl<'a, 'b> VectorRepackVisitor<'a, 'b> { + fn new(func: &'b mut Vec, id_def: &'b mut 
NumericIdResolver<'a>) -> Self { + VectorRepackVisitor { + func, + id_def, + post_stmts: None, + } + } + + fn convert_vector( + &mut self, + is_dst: bool, + non_default_implicit_conversion: Option< + fn( + (ast::StateSpace, &ast::Type), + (ast::StateSpace, &ast::Type), + ) -> Result, TranslateError>, + >, + typ: &ast::Type, + state_space: ast::StateSpace, + idx: Vec, + ) -> Result { + // mov.u32 foobar, {a,b}; + let scalar_t = match typ { + ast::Type::Vector(scalar_t, _) => *scalar_t, + _ => return Err(TranslateError::MismatchedType), + }; + let temp_vec = self + .id_def + .register_intermediate(Some((typ.clone(), state_space))); + let statement = Statement::RepackVector(RepackVectorDetails { + is_extract: is_dst, + typ: scalar_t, + packed: temp_vec, + unpacked: idx, + non_default_implicit_conversion, + }); + if is_dst { + self.post_stmts = Some(statement); + } else { + self.func.push(statement); + } + Ok(temp_vec) + } +} + +impl<'a, 'b> ast::VisitorMap, TypedOperand, TranslateError> + for VectorRepackVisitor<'a, 'b> +{ + fn visit_ident( + &mut self, + ident: SpirvWord, + _: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + _: bool, + ) -> Result { + Ok(ident) + } + + fn visit( + &mut self, + op: ast::ParsedOperand, + type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + is_dst: bool, + ) -> Result { + Ok(match op { + ast::ParsedOperand::Reg(reg) => TypedOperand::Reg(reg), + ast::ParsedOperand::RegOffset(reg, offset) => TypedOperand::RegOffset(reg, offset), + ast::ParsedOperand::Imm(x) => TypedOperand::Imm(x), + ast::ParsedOperand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx), + ast::ParsedOperand::VecPack(vec) => { + let (type_, space) = type_space.ok_or(TranslateError::MismatchedType)?; + TypedOperand::Reg(self.convert_vector( + is_dst, + desc.non_default_implicit_conversion, + type_, + space, + vec, + )?) 
+ } + }) + } +} diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 934a4726..bedf46ab 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -8,7 +8,9 @@ use std::{ rc::Rc, }; -pub(crate) mod normalize; +mod convert_to_typed; +mod normalize_identifiers; +mod normalize_predicates; static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv"); static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc"); @@ -161,13 +163,13 @@ fn to_ssa<'input, 'b>( }) } }; - let normalized_ids = normalize::run(&mut id_defs, &fn_defs, f_body)?; - todo!() - /* + let normalized_ids = normalize_identifiers::run(&mut id_defs, &fn_defs, f_body)?; let mut numeric_id_defs = id_defs.finish(); - let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?; + let unadorned_statements = normalize_predicates::run(normalized_ids, &mut numeric_id_defs)?; let typed_statements = - convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?; + convert_to_typed::run(unadorned_statements, &fn_defs, &mut numeric_id_defs)?; + todo!() + /* let typed_statements = fix_special_registers2(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?; let (func_decl, typed_statements) = @@ -856,4 +858,33 @@ pub(crate) struct Function<'input> { linkage: ast::LinkingDirective, } -type ExpandedStatement = Statement, SpirvWord>; \ No newline at end of file +type ExpandedStatement = Statement, SpirvWord>; + +type NormalizedStatement = Statement< + ( + Option>, + ast::Instruction>, + ), + ast::ParsedOperand, +>; + +type UnconditionalStatement = + Statement>, ast::ParsedOperand>; + +type TypedStatement = Statement, TypedOperand>; + +#[derive(Copy, Clone)] +enum TypedOperand { + Reg(SpirvWord), + RegOffset(SpirvWord, i32), + Imm(ast::ImmediateValue), + VecMember(SpirvWord, u8), +} + +impl ast::Operand for TypedOperand { + type Ident = SpirvWord; + + fn from_ident(ident: Self::Ident) -> Self { + 
TypedOperand::Reg(ident) + } +} diff --git a/ptx/src/pass/normalize.rs b/ptx/src/pass/normalize_identifiers.rs similarity index 91% rename from ptx/src/pass/normalize.rs rename to ptx/src/pass/normalize_identifiers.rs index 68ac26ea..6588d637 100644 --- a/ptx/src/pass/normalize.rs +++ b/ptx/src/pass/normalize_identifiers.rs @@ -1,14 +1,6 @@ use super::*; use ptx_parser as ast; -type NormalizedStatement = Statement< - ( - Option>, - ast::Instruction>, - ), - ast::ParsedOperand, ->; - pub(crate) fn run<'input, 'b>( id_defs: &mut FnStringIdResolver<'input, 'b>, fn_defs: &GlobalFnDeclResolver<'input, 'b>, diff --git a/ptx/src/pass/normalize_predicates.rs b/ptx/src/pass/normalize_predicates.rs new file mode 100644 index 00000000..c971cfaa --- /dev/null +++ b/ptx/src/pass/normalize_predicates.rs @@ -0,0 +1,44 @@ +use super::*; +use ptx_parser as ast; + +pub(crate) fn run( + func: Vec, + id_def: &mut NumericIdResolver, +) -> Result, TranslateError> { + let mut result = Vec::with_capacity(func.len()); + for s in func { + match s { + Statement::Label(id) => result.push(Statement::Label(id)), + Statement::Instruction((pred, inst)) => { + if let Some(pred) = pred { + let if_true = id_def.register_intermediate(None); + let if_false = id_def.register_intermediate(None); + let folded_bra = match &inst { + ast::Instruction::Bra { arguments, .. 
} => Some(arguments.src), + _ => None, + }; + let mut branch = BrachCondition { + predicate: pred.label, + if_true: folded_bra.unwrap_or(if_true), + if_false, + }; + if pred.not { + std::mem::swap(&mut branch.if_true, &mut branch.if_false); + } + result.push(Statement::Conditional(branch)); + if folded_bra.is_none() { + result.push(Statement::Label(if_true)); + result.push(Statement::Instruction(inst)); + } + result.push(Statement::Label(if_false)); + } else { + result.push(Statement::Instruction(inst)); + } + } + Statement::Variable(var) => result.push(Statement::Variable(var)), + // Blocks are flattened when resolving ids + _ => return Err(error_unreachable()), + } + } + Ok(result) +} From 69175d27edbff4a102aad504b6cfbaf531b48fa1 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sat, 24 Aug 2024 02:51:46 +0200 Subject: [PATCH 28/47] Add relaxed type check information to visitors --- ptx_parser/src/ast.rs | 82 +++++++++++++++++++++---------- ptx_parser_macros/src/lib.rs | 2 +- ptx_parser_macros_impl/src/lib.rs | 28 +++++++---- 3 files changed, 76 insertions(+), 36 deletions(-) diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index ee9f9681..5175b2d2 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -27,7 +27,10 @@ ptx_parser_macros::generate_instruction_type!( type: { &data.typ }, data: LdDetails, arguments: { - dst: T, + dst: { + repr: T, + relaxed_type_check: true, + }, src: { repr: T, space: { data.state_space }, @@ -51,7 +54,10 @@ ptx_parser_macros::generate_instruction_type!( repr: T, space: { data.state_space }, }, - src2: T, + src2: { + repr: T, + relaxed_type_check: true, + } } }, Mul { @@ -157,10 +163,13 @@ ptx_parser_macros::generate_instruction_type!( dst: { repr: T, type: { Type::Scalar(data.to) }, + // TODO: double check + relaxed_type_check: true, }, src: { repr: T, type: { Type::Scalar(data.from) }, + relaxed_type_check: true, }, } }, @@ -494,16 +503,18 @@ pub trait Visitor { args: &T, type_space: Option<(&Type, 
StateSpace)>, is_dst: bool, + relaxed_type_check: bool, ) -> Result<(), Err>; fn visit_ident( &mut self, args: &T::Ident, type_space: Option<(&Type, StateSpace)>, is_dst: bool, + relaxed_type_check: bool, ) -> Result<(), Err>; } -impl, bool) -> Result<(), Err>> +impl, bool, bool) -> Result<(), Err>> Visitor for Fn { fn visit( @@ -511,8 +522,9 @@ impl, bool) -> Result args: &T, type_space: Option<(&Type, StateSpace)>, is_dst: bool, + relaxed_type_check: bool, ) -> Result<(), Err> { - (self)(args, type_space, is_dst) + (self)(args, type_space, is_dst, relaxed_type_check) } fn visit_ident( @@ -520,8 +532,14 @@ impl, bool) -> Result args: &T::Ident, type_space: Option<(&Type, StateSpace)>, is_dst: bool, + relaxed_type_check: bool, ) -> Result<(), Err> { - (self)(&T::from_ident(*args), type_space, is_dst) + (self)( + &T::from_ident(*args), + type_space, + is_dst, + relaxed_type_check, + ) } } @@ -531,12 +549,14 @@ pub trait VisitorMut { args: &mut T, type_space: Option<(&Type, StateSpace)>, is_dst: bool, + relaxed_type_check: bool, ) -> Result<(), Err>; fn visit_ident( &mut self, args: &mut T::Ident, type_space: Option<(&Type, StateSpace)>, is_dst: bool, + relaxed_type_check: bool, ) -> Result<(), Err>; } @@ -546,37 +566,44 @@ pub trait VisitorMap { args: From, type_space: Option<(&Type, StateSpace)>, is_dst: bool, + relaxed_type_check: bool, ) -> Result; fn visit_ident( &mut self, args: From::Ident, type_space: Option<(&Type, StateSpace)>, is_dst: bool, + relaxed_type_check: bool, ) -> Result; } impl VisitorMap, ParsedOperand, Err> for Fn where - Fn: FnMut(T, Option<(&Type, StateSpace)>, bool) -> Result, + Fn: FnMut(T, Option<(&Type, StateSpace)>, bool, bool) -> Result, { fn visit( &mut self, args: ParsedOperand, type_space: Option<(&Type, StateSpace)>, is_dst: bool, + relaxed_type_check: bool, ) -> Result, Err> { Ok(match args { - ParsedOperand::Reg(ident) => ParsedOperand::Reg((self)(ident, type_space, is_dst)?), - ParsedOperand::RegOffset(ident, imm) => { - 
ParsedOperand::RegOffset((self)(ident, type_space, is_dst)?, imm) + ParsedOperand::Reg(ident) => { + ParsedOperand::Reg((self)(ident, type_space, is_dst, relaxed_type_check)?) } + ParsedOperand::RegOffset(ident, imm) => ParsedOperand::RegOffset( + (self)(ident, type_space, is_dst, relaxed_type_check)?, + imm, + ), ParsedOperand::Imm(imm) => ParsedOperand::Imm(imm), - ParsedOperand::VecMember(ident, index) => { - ParsedOperand::VecMember((self)(ident, type_space, is_dst)?, index) - } + ParsedOperand::VecMember(ident, index) => ParsedOperand::VecMember( + (self)(ident, type_space, is_dst, relaxed_type_check)?, + index, + ), ParsedOperand::VecPack(vec) => ParsedOperand::VecPack( vec.into_iter() - .map(|ident| (self)(ident, type_space, is_dst)) + .map(|ident| (self)(ident, type_space, is_dst, relaxed_type_check)) .collect::, _>>()?, ), }) @@ -587,22 +614,24 @@ where args: T, type_space: Option<(&Type, StateSpace)>, is_dst: bool, + relaxed_type_check: bool, ) -> Result { - (self)(args, type_space, is_dst) + (self)(args, type_space, is_dst, relaxed_type_check) } } impl, U: Operand, Err, Fn> VisitorMap for Fn where - Fn: FnMut(T, Option<(&Type, StateSpace)>, bool) -> Result, + Fn: FnMut(T, Option<(&Type, StateSpace)>, bool, bool) -> Result, { fn visit( &mut self, args: T, type_space: Option<(&Type, StateSpace)>, is_dst: bool, + relaxed_type_check: bool, ) -> Result { - (self)(args, type_space, is_dst) + (self)(args, type_space, is_dst, relaxed_type_check) } fn visit_ident( @@ -610,8 +639,9 @@ where args: T, type_space: Option<(&Type, StateSpace)>, is_dst: bool, + relaxed_type_check: bool, ) -> Result { - (self)(args, type_space, is_dst) + (self)(args, type_space, is_dst, relaxed_type_check) } } @@ -1198,15 +1228,15 @@ impl CallArgs { .iter() .zip(details.return_arguments.iter()) { - visitor.visit_ident(param, Some((type_, *space)), true)?; + visitor.visit_ident(param, Some((type_, *space)), true, false)?; } - visitor.visit_ident(&self.func, None, false)?; + 
visitor.visit_ident(&self.func, None, false, false)?; for (param, (type_, space)) in self .input_arguments .iter() .zip(details.input_arguments.iter()) { - visitor.visit(param, Some((type_, *space)), true)?; + visitor.visit(param, Some((type_, *space)), true, false)?; } Ok(()) } @@ -1222,15 +1252,15 @@ impl CallArgs { .iter_mut() .zip(details.return_arguments.iter()) { - visitor.visit_ident(param, Some((type_, *space)), true)?; + visitor.visit_ident(param, Some((type_, *space)), true, false)?; } - visitor.visit_ident(&mut self.func, None, false)?; + visitor.visit_ident(&mut self.func, None, false, false)?; for (param, (type_, space)) in self .input_arguments .iter_mut() .zip(details.input_arguments.iter()) { - visitor.visit(param, Some((type_, *space)), true)?; + visitor.visit(param, Some((type_, *space)), true, false)?; } Ok(()) } @@ -1245,14 +1275,14 @@ impl CallArgs { .return_arguments .into_iter() .zip(details.return_arguments.iter()) - .map(|(param, (type_, space))| visitor.visit_ident(param, Some((type_, *space)), true)) + .map(|(param, (type_, space))| visitor.visit_ident(param, Some((type_, *space)), true, false)) .collect::, _>>()?; - let func = visitor.visit_ident(self.func, None, false)?; + let func = visitor.visit_ident(self.func, None, false, false)?; let input_arguments = self .input_arguments .into_iter() .zip(details.input_arguments.iter()) - .map(|(param, (type_, space))| visitor.visit(param, Some((type_, *space)), true)) + .map(|(param, (type_, space))| visitor.visit(param, Some((type_, *space)), true, false)) .collect::, _>>()?; Ok(CallArgs { return_arguments, diff --git a/ptx_parser_macros/src/lib.rs b/ptx_parser_macros/src/lib.rs index 4502c953..a2f8396f 100644 --- a/ptx_parser_macros/src/lib.rs +++ b/ptx_parser_macros/src/lib.rs @@ -1017,7 +1017,7 @@ pub fn generate_instruction_type(tokens: proc_macro::TokenStream) -> proc_macro: input.emit_arg_types(&mut result); input.emit_instruction_type(&mut result); input.emit_visit(&mut result); - 
//input.emit_visit_mut(&mut result); + input.emit_visit_mut(&mut result); input.emit_visit_map(&mut result); result.into() } diff --git a/ptx_parser_macros_impl/src/lib.rs b/ptx_parser_macros_impl/src/lib.rs index 3e53607e..2f2c87a0 100644 --- a/ptx_parser_macros_impl/src/lib.rs +++ b/ptx_parser_macros_impl/src/lib.rs @@ -512,12 +512,13 @@ pub struct ArgumentField { pub repr: Type, pub space: Option, pub type_: Option, + pub relaxed_type_check: bool, } impl ArgumentField { fn parse_block( input: syn::parse::ParseStream, - ) -> syn::Result<(Type, Option, Option, Option)> { + ) -> syn::Result<(Type, Option, Option, Option, bool)> { let content; braced!(content in input); let all_fields = @@ -531,6 +532,9 @@ impl ArgumentField { let name_ident = content.parse::()?; content.parse::()?; match &*name_ident.to_string() { + "relaxed_type_check" => { + ExprOrPath::RelaxedTypeCheck(content.parse::()?.value) + } "repr" => ExprOrPath::Repr(content.parse::()?), "space" => ExprOrPath::Space(content.parse::()?), "dst" => { @@ -552,15 +556,17 @@ impl ArgumentField { let mut type_ = None; let mut space = None; let mut is_dst = None; + let mut relaxed_type_check = false; for exp_or_path in all_fields { match exp_or_path { ExprOrPath::Repr(r) => repr = Some(r), ExprOrPath::Type(t) => type_ = Some(t), ExprOrPath::Space(s) => space = Some(s), ExprOrPath::Dst(x) => is_dst = Some(x), + ExprOrPath::RelaxedTypeCheck(relaxed) => relaxed_type_check = relaxed, } } - Ok((repr.unwrap(), type_, space, is_dst)) + Ok((repr.unwrap(), type_, space, is_dst, relaxed_type_check)) } fn parse_basic(input: &syn::parse::ParseBuffer) -> syn::Result { @@ -605,6 +611,7 @@ impl ArgumentField { .map(|space| quote! { #space }) .unwrap_or_else(|| quote! { StateSpace::Reg }); let is_dst = self.is_dst; + let relaxed_type_check = self.relaxed_type_check; let name = &self.name; let type_space = if is_typeless { quote! { @@ -622,14 +629,14 @@ impl ArgumentField { quote! 
{ { #type_space - visitor.visit_ident(&mut arguments.#name, type_space, #is_dst)?; + visitor.visit_ident(&mut arguments.#name, type_space, #is_dst, #relaxed_type_check)?; } } } else { quote! { { #type_space - visitor.visit_ident(& arguments.#name, type_space, #is_dst)?; + visitor.visit_ident(& arguments.#name, type_space, #is_dst, #relaxed_type_check)?; } } } @@ -655,7 +662,7 @@ impl ArgumentField { }; quote! {{ #type_space - #operand_fn(#arguments_name, |x| visitor.visit(x, type_space, #is_dst))?; + #operand_fn(#arguments_name, |x| visitor.visit(x, type_space, #is_dst, #relaxed_type_check))?; }} } } @@ -679,6 +686,7 @@ impl ArgumentField { .map(|space| quote! { #space }) .unwrap_or_else(|| quote! { StateSpace::Reg }); let is_dst = self.is_dst; + let relaxed_type_check = self.relaxed_type_check; let name = &self.name; let type_space = if is_typeless { quote! { @@ -693,11 +701,11 @@ impl ArgumentField { }; let map_call = if is_ident { quote! { - visitor.visit_ident(arguments.#name, type_space, #is_dst)? + visitor.visit_ident(arguments.#name, type_space, #is_dst, #relaxed_type_check)? } } else { quote! { - MapOperand::map(arguments.#name, |x| visitor.visit(x, type_space, #is_dst))? + MapOperand::map(arguments.#name, |x| visitor.visit(x, type_space, #is_dst, #relaxed_type_check))? } }; quote! { @@ -739,10 +747,10 @@ impl Parse for ArgumentField { input.parse::()?; let lookahead = input.lookahead1(); - let (repr, type_, space, is_dst) = if lookahead.peek(token::Brace) { + let (repr, type_, space, is_dst, relaxed_type_check) = if lookahead.peek(token::Brace) { Self::parse_block(input)? 
} else if lookahead.peek(syn::Ident) { - (Self::parse_basic(input)?, None, None, None) + (Self::parse_basic(input)?, None, None, None, false) } else { return Err(lookahead.error()); }; @@ -756,6 +764,7 @@ impl Parse for ArgumentField { repr, type_, space, + relaxed_type_check }) } } @@ -765,6 +774,7 @@ enum ExprOrPath { Type(Expr), Space(Expr), Dst(bool), + RelaxedTypeCheck(bool), } #[cfg(test)] From 4e6dc07a52a4928b8d254785266b17c73f734af6 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sat, 24 Aug 2024 03:10:41 +0200 Subject: [PATCH 29/47] Implement third pass --- ptx/src/pass/convert_to_typed.rs | 23 ++++++----------------- ptx/src/pass/mod.rs | 7 +------ ptx/src/pass/normalize_identifiers.rs | 1 + 3 files changed, 8 insertions(+), 23 deletions(-) diff --git a/ptx/src/pass/convert_to_typed.rs b/ptx/src/pass/convert_to_typed.rs index 3dfef55b..7ff52909 100644 --- a/ptx/src/pass/convert_to_typed.rs +++ b/ptx/src/pass/convert_to_typed.rs @@ -26,17 +26,9 @@ pub(crate) fn run( src: src_reg, })); } - ast::Instruction::Call(call) => { - let resolver = fn_defs.get_fn_sig_resolver(call.func)?; - let resolved_call = resolver.resolve_in_spirv_repr(call)?; - let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); - let reresolved_call = resolved_call.visit(&mut visitor)?; - visitor.func.push(reresolved_call); - visitor.func.extend(visitor.post_stmts); - } inst => { let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); - let instruction = Statement::Instruction(inst.map(&mut visitor)?); + let instruction = Statement::Instruction(ast::visit_map(inst, &mut visitor)?); visitor.func.push(instruction); visitor.func.extend(visitor.post_stmts); } @@ -68,12 +60,7 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> { fn convert_vector( &mut self, is_dst: bool, - non_default_implicit_conversion: Option< - fn( - (ast::StateSpace, &ast::Type), - (ast::StateSpace, &ast::Type), - ) -> Result, TranslateError>, - >, + relaxed_type_check: bool, typ: &ast::Type, state_space: 
ast::StateSpace, idx: Vec, @@ -91,7 +78,7 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> { typ: scalar_t, packed: temp_vec, unpacked: idx, - non_default_implicit_conversion, + relaxed_type_check, }); if is_dst { self.post_stmts = Some(statement); @@ -110,6 +97,7 @@ impl<'a, 'b> ast::VisitorMap, TypedOperand, Transl ident: SpirvWord, _: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, _: bool, + _: bool, ) -> Result { Ok(ident) } @@ -119,6 +107,7 @@ impl<'a, 'b> ast::VisitorMap, TypedOperand, Transl op: ast::ParsedOperand, type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, is_dst: bool, + relaxed_type_check: bool, ) -> Result { Ok(match op { ast::ParsedOperand::Reg(reg) => TypedOperand::Reg(reg), @@ -129,7 +118,7 @@ impl<'a, 'b> ast::VisitorMap, TypedOperand, Transl let (type_, space) = type_space.ok_or(TranslateError::MismatchedType)?; TypedOperand::Reg(self.convert_vector( is_dst, - desc.non_default_implicit_conversion, + relaxed_type_check, type_, space, vec, diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index bedf46ab..3968d3d5 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -798,12 +798,7 @@ struct RepackVectorDetails { typ: ast::ScalarType, packed: SpirvWord, unpacked: Vec, - non_default_implicit_conversion: Option< - fn( - (ast::StateSpace, &ast::Type), - (ast::StateSpace, &ast::Type), - ) -> Result, TranslateError>, - >, + relaxed_type_check: bool } struct FunctionPointerDetails { diff --git a/ptx/src/pass/normalize_identifiers.rs b/ptx/src/pass/normalize_identifiers.rs index 6588d637..b5983453 100644 --- a/ptx/src/pass/normalize_identifiers.rs +++ b/ptx/src/pass/normalize_identifiers.rs @@ -41,6 +41,7 @@ fn expand_map_variables<'a, 'b>( .transpose()?, ast::visit_map(i, &mut |id, _: Option<(&ast::Type, ast::StateSpace)>, + _: bool, _: bool| { id_defs.get_id(id) })?, From 107f1eb17f680dbdfccdebd3828b38b6ec0897aa Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Mon, 26 Aug 2024 15:27:14 +0200 Subject: [PATCH 30/47] Port 
sreg fix pass --- ptx/src/pass/fix_special_registers.rs | 183 +++++++++++++++++++ ptx/src/pass/mod.rs | 247 +++++++++++++++++++++++++- 2 files changed, 428 insertions(+), 2 deletions(-) create mode 100644 ptx/src/pass/fix_special_registers.rs diff --git a/ptx/src/pass/fix_special_registers.rs b/ptx/src/pass/fix_special_registers.rs new file mode 100644 index 00000000..d94786eb --- /dev/null +++ b/ptx/src/pass/fix_special_registers.rs @@ -0,0 +1,183 @@ +use super::*; +use std::collections::HashMap; + +fn run<'a, 'b, 'input>( + ptx_impl_imports: &'a mut HashMap>, + typed_statements: Vec, + numeric_id_defs: &'a mut NumericIdResolver<'b>, +) -> Result, TranslateError> { + let result = Vec::with_capacity(typed_statements.len()); + let mut sreg_sresolver = SpecialRegisterResolver { + ptx_impl_imports, + numeric_id_defs, + result, + }; + for statement in typed_statements { + let statement = statement.visit_map(&mut sreg_sresolver)?; + sreg_sresolver.result.push(statement); + } + Ok(sreg_sresolver.result) +} + +struct SpecialRegisterResolver<'a, 'b, 'input> { + ptx_impl_imports: &'a mut HashMap>, + numeric_id_defs: &'a mut NumericIdResolver<'b>, + result: Vec, +} + +impl<'a, 'b, 'input> ast::VisitorMap + for SpecialRegisterResolver<'a, 'b, 'input> +{ + fn visit( + &mut self, + operand: TypedOperand, + _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + is_dst: bool, + _relaxed_type_check: bool, + ) -> Result { + operand.map(|name, vector_index| self.replace_sreg(name, is_dst, vector_index)) + } + + fn visit_ident( + &mut self, + args: SpirvWord, + _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + is_dst: bool, + _relaxed_type_check: bool, + ) -> Result { + self.replace_sreg(args, is_dst, None) + } +} + +impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'b, 'input> { + fn replace_sreg( + &mut self, + name: SpirvWord, + is_dst: bool, + vector_index: Option, + ) -> Result { + if let Some(sreg) = self.numeric_id_defs.special_registers.get(name) 
{ + if is_dst { + return Err(TranslateError::MismatchedType); + } + let input_arguments = match (vector_index, sreg.get_function_input_type()) { + (Some(idx), Some(inp_type)) => { + if inp_type != ast::ScalarType::U8 { + return Err(TranslateError::Unreachable); + } + let constant = self.numeric_id_defs.register_intermediate(Some(( + ast::Type::Scalar(inp_type), + ast::StateSpace::Reg, + ))); + self.result.push(Statement::Constant(ConstantDefinition { + dst: constant, + typ: inp_type, + value: ast::ImmediateValue::U64(idx as u64), + })); + vec![( + TypedOperand::Reg(constant), + ast::Type::Scalar(inp_type), + ast::StateSpace::Reg, + )] + } + (None, None) => Vec::new(), + _ => return Err(TranslateError::MismatchedType), + }; + let ocl_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat(); + let return_type = sreg.get_function_return_type(); + let fn_result = self.numeric_id_defs.register_intermediate(Some(( + ast::Type::Scalar(return_type), + ast::StateSpace::Reg, + ))); + let return_arguments = vec![( + fn_result, + ast::Type::Scalar(return_type), + ast::StateSpace::Reg, + )]; + let fn_call = register_external_fn_call( + self.numeric_id_defs, + self.ptx_impl_imports, + ocl_fn_name.to_string(), + return_arguments.iter().map(|(_, typ, space)| (typ, *space)), + input_arguments.iter().map(|(_, typ, space)| (typ, *space)), + )?; + let data = ast::CallDetails { + uniform: false, + return_arguments: return_arguments + .iter() + .map(|(_, typ, space)| (typ.clone(), *space)) + .collect(), + input_arguments: input_arguments + .iter() + .map(|(_, typ, space)| (typ.clone(), *space)) + .collect(), + }; + let arguments = ast::CallArgs { + return_arguments: return_arguments.iter().map(|(name, _, _)| *name).collect(), + func: fn_call, + input_arguments: input_arguments.iter().map(|(name, _, _)| *name).collect(), + }; + self.result + .push(Statement::Instruction(ast::Instruction::Call { + data, + arguments, + })); + Ok(fn_result) + } else { + Ok(name) + } + } +} 
+ +fn register_external_fn_call<'a>( + id_defs: &mut NumericIdResolver, + ptx_impl_imports: &mut HashMap, + name: String, + return_arguments: impl Iterator, + input_arguments: impl Iterator, +) -> Result { + match ptx_impl_imports.entry(name) { + hash_map::Entry::Vacant(entry) => { + let fn_id = id_defs.register_intermediate(None); + let return_arguments = fn_arguments_to_variables(id_defs, return_arguments); + let input_arguments = fn_arguments_to_variables(id_defs, input_arguments); + let func_decl = ast::MethodDeclaration:: { + return_arguments, + name: ast::MethodName::Func(fn_id), + input_arguments, + shared_mem: None, + }; + let func = Function { + func_decl: Rc::new(RefCell::new(func_decl)), + globals: Vec::new(), + body: None, + import_as: Some(entry.key().clone()), + tuning: Vec::new(), + linkage: ast::LinkingDirective::EXTERN, + }; + entry.insert(Directive::Method(func)); + Ok(fn_id) + } + hash_map::Entry::Occupied(entry) => match entry.get() { + Directive::Method(Function { func_decl, .. 
}) => match (**func_decl).borrow().name { + ast::MethodName::Func(fn_id) => Ok(fn_id), + ast::MethodName::Kernel(_) => Err(error_unreachable()), + }, + _ => Err(error_unreachable()), + }, + } +} + +fn fn_arguments_to_variables<'a>( + id_defs: &mut NumericIdResolver, + args: impl Iterator, +) -> Vec> { + args.map(|(typ, space)| ast::Variable { + align: None, + v_type: typ.clone(), + state_space: space, + name: id_defs.register_intermediate(None), + array_init: Vec::new(), + }) + .collect::>() +} \ No newline at end of file diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 3968d3d5..b3bfa722 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -9,6 +9,7 @@ use std::{ }; mod convert_to_typed; +mod fix_special_registers; mod normalize_identifiers; mod normalize_predicates; @@ -735,6 +736,235 @@ enum Statement { FunctionPointer(FunctionPointerDetails), } +impl> Statement, T> { + fn visit_map, Err>( + self, + visitor: &mut impl ast::VisitorMap, + ) -> std::result::Result, T>, Err> { + Ok(match self { + Statement::Instruction(i) => { + return ast::visit_map(i, visitor).map(Statement::Instruction) + } + Statement::Label(label) => { + Statement::Label(visitor.visit_ident(label, None, false, false)?) 
+ } + Statement::Variable(var) => { + let name = visitor.visit_ident( + var.name, + Some((&var.v_type, var.state_space)), + true, + false, + )?; + Statement::Variable(ast::Variable { + align: var.align, + v_type: var.v_type, + state_space: var.state_space, + name, + array_init: var.array_init, + }) + } + Statement::Conditional(conditional) => { + let predicate = visitor.visit_ident(conditional.predicate, None, false, false)?; + let if_true = visitor.visit_ident(conditional.if_true, None, false, false)?; + let if_false = visitor.visit_ident(conditional.if_false, None, false, false)?; + Statement::Conditional(BrachCondition { + predicate, + if_true, + if_false, + }) + } + Statement::LoadVar(LoadVarDetails { + arg, + typ, + member_index, + }) => { + let dst = visitor.visit_ident( + arg.dst, + Some((&typ, ast::StateSpace::Reg)), + true, + false, + )?; + let src = visitor.visit_ident( + arg.src, + Some((&typ, ast::StateSpace::Local)), + false, + false, + )?; + Statement::LoadVar(LoadVarDetails { + arg: ast::LdArgs { dst, src }, + typ, + member_index, + }) + } + Statement::StoreVar(StoreVarDetails { + arg, + typ, + member_index, + }) => { + let src1 = visitor.visit_ident( + arg.src1, + Some((&typ, ast::StateSpace::Local)), + false, + false, + )?; + let src2 = visitor.visit_ident( + arg.src2, + Some((&typ, ast::StateSpace::Reg)), + false, + false, + )?; + Statement::StoreVar(StoreVarDetails { + arg: ast::StArgs { src1, src2 }, + typ, + member_index, + }) + } + Statement::Conversion(ImplicitConversion { + src, + dst, + from_type, + to_type, + from_space, + to_space, + kind, + }) => { + let dst = visitor.visit_ident( + dst, + Some((&to_type, ast::StateSpace::Reg)), + true, + false, + )?; + let src = visitor.visit_ident( + src, + Some((&from_type, ast::StateSpace::Reg)), + false, + false, + )?; + Statement::Conversion(ImplicitConversion { + src, + dst, + from_type, + to_type, + from_space, + to_space, + kind, + }) + } + Statement::Constant(ConstantDefinition { dst, typ, 
value }) => { + let dst = visitor.visit_ident( + dst, + Some((&typ.into(), ast::StateSpace::Reg)), + true, + false, + )?; + Statement::Constant(ConstantDefinition { dst, typ, value }) + } + Statement::RetValue(data, value) => { + // TODO: + // We should report type here + let value = visitor.visit_ident(value, None, false, false)?; + Statement::RetValue(data, value) + } + Statement::PtrAccess(PtrAccess { + underlying_type, + state_space, + dst, + ptr_src, + offset_src, + }) => { + let dst = + visitor.visit_ident(dst, Some((&underlying_type, state_space)), true, false)?; + let ptr_src = visitor.visit_ident( + ptr_src, + Some((&underlying_type, state_space)), + false, + false, + )?; + Statement::PtrAccess(PtrAccess { + underlying_type, + state_space, + dst, + ptr_src, + offset_src, + }) + } + Statement::RepackVector(RepackVectorDetails { + is_extract, + typ, + packed, + unpacked, + relaxed_type_check, + }) => { + let (packed, unpacked) = if is_extract { + let unpacked = unpacked + .into_iter() + .map(|ident| { + visitor.visit_ident( + ident, + Some((&typ.into(), ast::StateSpace::Reg)), + true, + relaxed_type_check, + ) + }) + .collect::, _>>()?; + let packed = visitor.visit_ident( + packed, + Some(( + &ast::Type::Vector(typ, unpacked.len() as u8), + ast::StateSpace::Reg, + )), + false, + false, + )?; + (packed, unpacked) + } else { + let packed = visitor.visit_ident( + packed, + Some(( + &ast::Type::Vector(typ, unpacked.len() as u8), + ast::StateSpace::Reg, + )), + true, + false, + )?; + let unpacked = unpacked + .into_iter() + .map(|ident| { + visitor.visit_ident( + ident, + Some((&typ.into(), ast::StateSpace::Reg)), + false, + relaxed_type_check, + ) + }) + .collect::, _>>()?; + (packed, unpacked) + }; + Statement::RepackVector(RepackVectorDetails { + is_extract, + typ, + packed, + unpacked, + relaxed_type_check, + }) + } + Statement::FunctionPointer(FunctionPointerDetails { dst, src }) => { + let dst = visitor.visit_ident( + dst, + Some(( + 
&ast::Type::Scalar(ast::ScalarType::U64), + ast::StateSpace::Reg, + )), + true, + false, + )?; + let src = visitor.visit_ident(src, None, false, false)?; + Statement::FunctionPointer(FunctionPointerDetails { dst, src }) + } + }) + } +} + struct BrachCondition { predicate: SpirvWord, if_true: SpirvWord, @@ -743,7 +973,6 @@ struct BrachCondition { struct LoadVarDetails { arg: ast::LdArgs, typ: ast::Type, - state_space: ast::StateSpace, // (index, vector_width) // HACK ALERT // For some reason IGC explodes when you try to load from builtin vectors @@ -798,7 +1027,7 @@ struct RepackVectorDetails { typ: ast::ScalarType, packed: SpirvWord, unpacked: Vec, - relaxed_type_check: bool + relaxed_type_check: bool, } struct FunctionPointerDetails { @@ -876,6 +1105,20 @@ enum TypedOperand { VecMember(SpirvWord, u8), } +impl TypedOperand { + fn map( + self, + fn_: impl FnOnce(SpirvWord, Option) -> Result, + ) -> Result { + Ok(match self { + TypedOperand::Reg(reg) => TypedOperand::Reg(fn_(reg, None)?), + TypedOperand::RegOffset(reg, off) => TypedOperand::RegOffset(fn_(reg, None)?, off), + TypedOperand::Imm(imm) => TypedOperand::Imm(imm), + TypedOperand::VecMember(reg, idx) => TypedOperand::VecMember(fn_(reg, Some(idx))?, idx), + }) + } +} + impl ast::Operand for TypedOperand { type Ident = SpirvWord; From 3e0a15ac845679b9ecd4f12c8bc84cf16b77081c Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Mon, 26 Aug 2024 18:31:06 +0200 Subject: [PATCH 31/47] Add stateless-to-stateful conversion --- .../pass/convert_to_stateful_memory_access.rs | 535 ++++++++++++++++++ ptx/src/pass/fix_special_registers.rs | 2 +- ptx/src/pass/mod.rs | 90 ++- ptx/src/translate.rs | 20 +- ptx_parser/src/ast.rs | 1 + 5 files changed, 627 insertions(+), 21 deletions(-) create mode 100644 ptx/src/pass/convert_to_stateful_memory_access.rs diff --git a/ptx/src/pass/convert_to_stateful_memory_access.rs b/ptx/src/pass/convert_to_stateful_memory_access.rs new file mode 100644 index 00000000..3060a704 --- /dev/null 
+++ b/ptx/src/pass/convert_to_stateful_memory_access.rs @@ -0,0 +1,535 @@ +use super::*; +use ptx_parser as ast; +use std::{ + collections::{BTreeSet, HashSet}, + iter, + rc::Rc, +}; + +/* + Our goal here is to transform + .visible .entry foobar(.param .u64 input) { + .reg .b64 in_addr; + .reg .b64 in_addr2; + ld.param.u64 in_addr, [input]; + cvta.to.global.u64 in_addr2, in_addr; + } + into: + .visible .entry foobar(.param .u8 input[]) { + .reg .u8 in_addr[]; + .reg .u8 in_addr2[]; + ld.param.u8[] in_addr, [input]; + mov.u8[] in_addr2, in_addr; + } + or: + .visible .entry foobar(.reg .u8 input[]) { + .reg .u8 in_addr[]; + .reg .u8 in_addr2[]; + mov.u8[] in_addr, input; + mov.u8[] in_addr2, in_addr; + } + or: + .visible .entry foobar(.param ptr input) { + .reg ptr in_addr; + .reg ptr in_addr2; + ld.param.ptr in_addr, [input]; + mov.ptr in_addr2, in_addr; + } +*/ +// TODO: detect more patterns (mov, call via reg, call via param) +// TODO: don't convert to ptr if the register is not ultimately used for ld/st +// TODO: once insert_mem_ssa_statements is moved to later, move this pass after +// argument expansion +// TODO: propagate out of calls and into calls +pub(super) fn run<'a, 'input>( + func_args: Rc>>, + func_body: Vec, + id_defs: &mut NumericIdResolver<'a>, +) -> Result< + ( + Rc>>, + Vec, + ), + TranslateError, +> { + let mut method_decl = func_args.borrow_mut(); + if !matches!(method_decl.name, ast::MethodName::Kernel(..)) { + drop(method_decl); + return Ok((func_args, func_body)); + } + if Rc::strong_count(&func_args) != 1 { + return Err(error_unreachable()); + } + let func_args_64bit = (*method_decl) + .input_arguments + .iter() + .filter_map(|arg| match arg.v_type { + ast::Type::Scalar(ast::ScalarType::U64) + | ast::Type::Scalar(ast::ScalarType::B64) + | ast::Type::Scalar(ast::ScalarType::S64) => Some(arg.name), + _ => None, + }) + .collect::>(); + let mut stateful_markers = Vec::new(); + let mut stateful_init_reg = HashMap::<_, Vec<_>>::new(); + for 
statement in func_body.iter() { + match statement { + Statement::Instruction(ast::Instruction::Cvta { + data: + ast::CvtaDetails { + state_space: ast::StateSpace::Global, + direction: ast::CvtaDirection::GenericToExplicit, + }, + arguments, + }) => { + if let (TypedOperand::Reg(dst), Some(src)) = + (arguments.dst, arguments.src.underlying_register()) + { + if is_64_bit_integer(id_defs, src) && is_64_bit_integer(id_defs, dst) { + stateful_markers.push((dst, src)); + } + } + } + Statement::Instruction(ast::Instruction::Ld { + data: + ast::LdDetails { + state_space: ast::StateSpace::Param, + typ: ast::Type::Scalar(ast::ScalarType::U64), + .. + }, + arguments, + }) + | Statement::Instruction(ast::Instruction::Ld { + data: + ast::LdDetails { + state_space: ast::StateSpace::Param, + typ: ast::Type::Scalar(ast::ScalarType::S64), + .. + }, + arguments, + }) + | Statement::Instruction(ast::Instruction::Ld { + data: + ast::LdDetails { + state_space: ast::StateSpace::Param, + typ: ast::Type::Scalar(ast::ScalarType::B64), + .. 
+ }, + arguments, + }) => { + if let (TypedOperand::Reg(dst), Some(src)) = + (arguments.dst, arguments.src.underlying_register()) + { + if func_args_64bit.contains(&src) { + multi_hash_map_append(&mut stateful_init_reg, dst, src); + } + } + } + _ => {} + } + } + if stateful_markers.len() == 0 { + drop(method_decl); + return Ok((func_args, func_body)); + } + let mut func_args_ptr = HashSet::new(); + let mut regs_ptr_current = HashSet::new(); + for (dst, src) in stateful_markers { + if let Some(func_args) = stateful_init_reg.get(&src) { + for a in func_args { + func_args_ptr.insert(*a); + regs_ptr_current.insert(src); + regs_ptr_current.insert(dst); + } + } + } + // BTreeSet here to have a stable order of iteration, + // unfortunately our tests rely on it + let mut regs_ptr_seen = BTreeSet::new(); + while regs_ptr_current.len() > 0 { + let mut regs_ptr_new = HashSet::new(); + for statement in func_body.iter() { + match statement { + Statement::Instruction(ast::Instruction::Add { + data: + ast::ArithDetails::Integer(ast::ArithInteger { + type_: ast::ScalarType::U64, + saturate: false, + }), + arguments, + }) + | Statement::Instruction(ast::Instruction::Add { + data: + ast::ArithDetails::Integer(ast::ArithInteger { + type_: ast::ScalarType::S64, + saturate: false, + }), + arguments, + }) => { + // TODO: don't mark result of double pointer sub or double + // pointer add as ptr result + if let (TypedOperand::Reg(dst), Some(src1)) = + (arguments.dst, arguments.src1.underlying_register()) + { + if regs_ptr_current.contains(&src1) && !regs_ptr_seen.contains(&src1) { + regs_ptr_new.insert(dst); + } + } else if let (TypedOperand::Reg(dst), Some(src2)) = + (arguments.dst, arguments.src2.underlying_register()) + { + if regs_ptr_current.contains(&src2) && !regs_ptr_seen.contains(&src2) { + regs_ptr_new.insert(dst); + } + } + } + + Statement::Instruction(ast::Instruction::Sub { + data: + ast::ArithDetails::Integer(ast::ArithInteger { + type_: ast::ScalarType::U64, + saturate: 
false, + }), + arguments, + }) + | Statement::Instruction(ast::Instruction::Sub { + data: + ast::ArithDetails::Integer(ast::ArithInteger { + type_: ast::ScalarType::S64, + saturate: false, + }), + arguments, + }) => { + // TODO: don't mark result of double pointer sub or double + // pointer add as ptr result + if let (TypedOperand::Reg(dst), Some(src1)) = + (arguments.dst, arguments.src1.underlying_register()) + { + if regs_ptr_current.contains(&src1) && !regs_ptr_seen.contains(&src1) { + regs_ptr_new.insert(dst); + } + } else if let (TypedOperand::Reg(dst), Some(src2)) = + (arguments.dst, arguments.src2.underlying_register()) + { + if regs_ptr_current.contains(&src2) && !regs_ptr_seen.contains(&src2) { + regs_ptr_new.insert(dst); + } + } + } + _ => {} + } + } + for id in regs_ptr_current { + regs_ptr_seen.insert(id); + } + regs_ptr_current = regs_ptr_new; + } + drop(regs_ptr_current); + let mut remapped_ids = HashMap::new(); + let mut result = Vec::with_capacity(regs_ptr_seen.len() + func_body.len()); + for reg in regs_ptr_seen { + let new_id = id_defs.register_variable( + ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), + ast::StateSpace::Reg, + ); + result.push(Statement::Variable(ast::Variable { + align: None, + name: new_id, + array_init: Vec::new(), + v_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), + state_space: ast::StateSpace::Reg, + })); + remapped_ids.insert(reg, new_id); + } + for arg in (*method_decl).input_arguments.iter_mut() { + if !func_args_ptr.contains(&arg.name) { + continue; + } + let new_id = id_defs.register_variable( + ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), + ast::StateSpace::Param, + ); + let old_name = arg.name; + arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global); + arg.name = new_id; + remapped_ids.insert(old_name, new_id); + } + for statement in func_body { + match statement { + l @ Statement::Label(_) => result.push(l), + c @ 
Statement::Conditional(_) => result.push(c), + c @ Statement::Constant(..) => result.push(c), + Statement::Variable(var) => { + if !remapped_ids.contains_key(&var.name) { + result.push(Statement::Variable(var)); + } + } + Statement::Instruction(ast::Instruction::Add { + data: + ast::ArithDetails::Integer(ast::ArithInteger { + type_: ast::ScalarType::U64, + saturate: false, + }), + arguments, + }) + | Statement::Instruction(ast::Instruction::Add { + data: + ast::ArithDetails::Integer(ast::ArithInteger { + type_: ast::ScalarType::S64, + saturate: false, + }), + arguments, + }) if is_add_ptr_direct(&remapped_ids, &arguments) => { + let (ptr, offset) = match arguments.src1.underlying_register() { + Some(src1) if remapped_ids.contains_key(&src1) => { + (remapped_ids.get(&src1).unwrap(), arguments.src2) + } + Some(src2) if remapped_ids.contains_key(&src2) => { + (remapped_ids.get(&src2).unwrap(), arguments.src1) + } + _ => return Err(error_unreachable()), + }; + let dst = arguments.dst.unwrap_reg()?; + result.push(Statement::PtrAccess(PtrAccess { + underlying_type: ast::Type::Scalar(ast::ScalarType::U8), + state_space: ast::StateSpace::Global, + dst: *remapped_ids.get(&dst).unwrap(), + ptr_src: *ptr, + offset_src: offset, + })) + } + Statement::Instruction(ast::Instruction::Sub { + data: + ast::ArithDetails::Integer(ast::ArithInteger { + type_: ast::ScalarType::U64, + saturate: false, + }), + arguments, + }) + | Statement::Instruction(ast::Instruction::Sub { + data: + ast::ArithDetails::Integer(ast::ArithInteger { + type_: ast::ScalarType::S64, + saturate: false, + }), + arguments, + }) if is_sub_ptr_direct(&remapped_ids, &arguments) => { + let (ptr, offset) = match arguments.src1.underlying_register() { + Some(ref src1) => (remapped_ids.get(src1).unwrap(), arguments.src2), + _ => return Err(error_unreachable()), + }; + let offset_neg = id_defs.register_intermediate(Some(( + ast::Type::Scalar(ast::ScalarType::S64), + ast::StateSpace::Reg, + ))); + 
result.push(Statement::Instruction(ast::Instruction::Neg { + data: ast::TypeFtz { + type_: ast::ScalarType::S64, + flush_to_zero: None, + }, + arguments: ast::NegArgs { + src: offset, + dst: TypedOperand::Reg(offset_neg), + }, + })); + let dst = arguments.dst.unwrap_reg()?; + result.push(Statement::PtrAccess(PtrAccess { + underlying_type: ast::Type::Scalar(ast::ScalarType::U8), + state_space: ast::StateSpace::Global, + dst: *remapped_ids.get(&dst).unwrap(), + ptr_src: *ptr, + offset_src: TypedOperand::Reg(offset_neg), + })) + } + inst @ Statement::Instruction(_) => { + let mut post_statements = Vec::new(); + let new_statement = inst.visit_map(&mut FnVisitor::new( + |operand, type_space, is_dst, relaxed_conversion| { + convert_to_stateful_memory_access_postprocess( + id_defs, + &remapped_ids, + &mut result, + &mut post_statements, + operand, + type_space, + is_dst, + relaxed_conversion, + ) + }, + ))?; + result.push(new_statement); + result.extend(post_statements); + } + repack @ Statement::RepackVector(_) => { + let mut post_statements = Vec::new(); + let new_statement = repack.visit_map(&mut FnVisitor::new( + |operand, type_space, is_dst, relaxed_conversion| { + convert_to_stateful_memory_access_postprocess( + id_defs, + &remapped_ids, + &mut result, + &mut post_statements, + operand, + type_space, + is_dst, + relaxed_conversion, + ) + }, + ))?; + result.push(new_statement); + result.extend(post_statements); + } + _ => return Err(error_unreachable()), + } + } + drop(method_decl); + Ok((func_args, result)) +} + +fn is_64_bit_integer(id_defs: &NumericIdResolver, id: SpirvWord) -> bool { + match id_defs.get_typed(id) { + Ok((ast::Type::Scalar(ast::ScalarType::U64), _, _)) + | Ok((ast::Type::Scalar(ast::ScalarType::S64), _, _)) + | Ok((ast::Type::Scalar(ast::ScalarType::B64), _, _)) => true, + _ => false, + } +} + +fn multi_hash_map_append< + K: Eq + std::hash::Hash, + V, + Collection: std::iter::Extend + std::default::Default, +>( + m: &mut HashMap, + key: K, + 
value: V, +) { + match m.entry(key) { + hash_map::Entry::Occupied(mut entry) => { + entry.get_mut().extend(iter::once(value)); + } + hash_map::Entry::Vacant(entry) => { + entry.insert(Default::default()).extend(iter::once(value)); + } + } +} + +fn is_add_ptr_direct( + remapped_ids: &HashMap, + arg: &ast::AddArgs, +) -> bool { + match arg.dst { + TypedOperand::Imm(..) | TypedOperand::RegOffset(..) | TypedOperand::VecMember(..) => { + return false + } + TypedOperand::Reg(dst) => { + if !remapped_ids.contains_key(&dst) { + return false; + } + if let Some(ref src1_reg) = arg.src1.underlying_register() { + if remapped_ids.contains_key(src1_reg) { + // don't trigger optimization when adding two pointers + if let Some(ref src2_reg) = arg.src2.underlying_register() { + return !remapped_ids.contains_key(src2_reg); + } + } + } + if let Some(ref src2_reg) = arg.src2.underlying_register() { + remapped_ids.contains_key(src2_reg) + } else { + false + } + } + } +} + +fn is_sub_ptr_direct( + remapped_ids: &HashMap, + arg: &ast::SubArgs, +) -> bool { + match arg.dst { + TypedOperand::Imm(..) | TypedOperand::RegOffset(..) | TypedOperand::VecMember(..) 
=> { + return false + } + TypedOperand::Reg(dst) => { + if !remapped_ids.contains_key(&dst) { + return false; + } + match arg.src1.underlying_register() { + Some(ref src1_reg) => { + if remapped_ids.contains_key(src1_reg) { + // don't trigger optimization when subtracting two pointers + arg.src2 + .underlying_register() + .map_or(true, |ref src2_reg| !remapped_ids.contains_key(src2_reg)) + } else { + false + } + } + None => false, + } + } + } +} + +fn convert_to_stateful_memory_access_postprocess( + id_defs: &mut NumericIdResolver, + remapped_ids: &HashMap, + result: &mut Vec, + post_statements: &mut Vec, + operand: TypedOperand, + type_space: Option<(&ast::Type, ast::StateSpace)>, + is_dst: bool, + relaxed_conversion: bool, +) -> Result { + operand.map(|operand, _| { + Ok(match remapped_ids.get(&operand) { + Some(new_id) => { + let (new_operand_type, new_operand_space, _) = id_defs.get_typed(*new_id)?; + // TODO: readd if required + if let Some(..) = type_space { + if relaxed_conversion { + return Ok(*new_id); + } + } + let (old_operand_type, old_operand_space, _) = id_defs.get_typed(operand)?; + let converting_id = id_defs + .register_intermediate(Some((old_operand_type.clone(), old_operand_space))); + let kind = if state_is_compatible(new_operand_space, ast::StateSpace::Reg) { + ConversionKind::Default + } else { + ConversionKind::PtrToPtr + }; + if is_dst { + post_statements.push(Statement::Conversion(ImplicitConversion { + src: converting_id, + dst: *new_id, + from_type: old_operand_type, + from_space: old_operand_space, + to_type: new_operand_type, + to_space: new_operand_space, + kind, + })); + converting_id + } else { + result.push(Statement::Conversion(ImplicitConversion { + src: *new_id, + dst: converting_id, + from_type: new_operand_type, + from_space: new_operand_space, + to_type: old_operand_type, + to_space: old_operand_space, + kind, + })); + converting_id + } + } + None => operand, + }) + }) +} + +fn state_is_compatible(this: ast::StateSpace, other: 
ast::StateSpace) -> bool { + this == other + || this == ast::StateSpace::Reg && other == ast::StateSpace::Sreg + || this == ast::StateSpace::Sreg && other == ast::StateSpace::Reg +} diff --git a/ptx/src/pass/fix_special_registers.rs b/ptx/src/pass/fix_special_registers.rs index d94786eb..871537d5 100644 --- a/ptx/src/pass/fix_special_registers.rs +++ b/ptx/src/pass/fix_special_registers.rs @@ -1,7 +1,7 @@ use super::*; use std::collections::HashMap; -fn run<'a, 'b, 'input>( +pub(super) fn run<'a, 'b, 'input>( ptx_impl_imports: &'a mut HashMap>, typed_statements: Vec, numeric_id_defs: &'a mut NumericIdResolver<'b>, diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index b3bfa722..439233a5 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -5,9 +5,11 @@ use std::{ cell::RefCell, collections::{hash_map, HashMap}, ffi::CString, + marker::PhantomData, rc::Rc, }; +mod convert_to_stateful_memory_access; mod convert_to_typed; mod fix_special_registers; mod normalize_identifiers; @@ -169,12 +171,12 @@ fn to_ssa<'input, 'b>( let unadorned_statements = normalize_predicates::run(normalized_ids, &mut numeric_id_defs)?; let typed_statements = convert_to_typed::run(unadorned_statements, &fn_defs, &mut numeric_id_defs)?; - todo!() - /* let typed_statements = - fix_special_registers2(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?; + fix_special_registers::run(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?; let (func_decl, typed_statements) = - convert_to_stateful_memory_access(func_decl, typed_statements, &mut numeric_id_defs)?; + convert_to_stateful_memory_access::run(func_decl, typed_statements, &mut numeric_id_defs)?; + todo!() + /* let ssa_statements = insert_mem_ssa_statements( typed_statements, &mut numeric_id_defs, @@ -1035,7 +1037,7 @@ struct FunctionPointerDetails { src: SpirvWord, } -#[derive(Copy, Clone, PartialEq, Eq, Hash)] +#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] struct SpirvWord(spirv::Word); impl From for 
SpirvWord { @@ -1117,6 +1119,20 @@ impl TypedOperand { TypedOperand::VecMember(reg, idx) => TypedOperand::VecMember(fn_(reg, Some(idx))?, idx), }) } + + fn underlying_register(&self) -> Option { + match self { + Self::Reg(r) | Self::RegOffset(r, _) | Self::VecMember(r, _) => Some(*r), + Self::Imm(_) => None, + } + } + + fn unwrap_reg(&self) -> Result { + match self { + TypedOperand::Reg(reg) => Ok(*reg), + _ => Err(error_unreachable()), + } + } } impl ast::Operand for TypedOperand { @@ -1126,3 +1142,67 @@ impl ast::Operand for TypedOperand { TypedOperand::Reg(ident) } } + +impl ast::VisitorMap + for FnVisitor +where + Fn: FnMut( + TypedOperand, + Option<(&ast::Type, ast::StateSpace)>, + bool, + bool, + ) -> Result, +{ + fn visit( + &mut self, + args: TypedOperand, + type_space: Option<(&ast::Type, ast::StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result { + (self.fn_)(args, type_space, is_dst, relaxed_type_check) + } + + fn visit_ident( + &mut self, + args: SpirvWord, + type_space: Option<(&ast::Type, ast::StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result { + match (self.fn_)( + TypedOperand::Reg(args), + type_space, + is_dst, + relaxed_type_check, + )? 
{ + TypedOperand::Reg(reg) => Ok(reg), + _ => Err(TranslateError::Unreachable), + } + } +} + +struct FnVisitor< + T, + U, + Err, + Fn: FnMut(T, Option<(&ast::Type, ast::StateSpace)>, bool, bool) -> Result, +> { + fn_: Fn, + _marker: PhantomData Result>, +} + +impl< + T, + U, + Err, + Fn: FnMut(T, Option<(&ast::Type, ast::StateSpace)>, bool, bool) -> Result, + > FnVisitor +{ + fn new(fn_: Fn) -> Self { + Self { + fn_, + _marker: PhantomData, + } + } +} diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index db1063b6..9b422fda 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1608,17 +1608,13 @@ fn extract_globals<'input, 'b>( for statement in sorted_statements { match statement { Statement::Variable( - var - @ - ast::Variable { + var @ ast::Variable { state_space: ast::StateSpace::Shared, .. }, ) | Statement::Variable( - var - @ - ast::Variable { + var @ ast::Variable { state_space: ast::StateSpace::Global, .. }, @@ -1660,9 +1656,7 @@ fn extract_globals<'input, 'b>( )?); } Statement::Instruction(ast::Instruction::Atom( - details - @ - ast::AtomDetails { + details @ ast::AtomDetails { inner: ast::AtomInnerDetails::Unsigned { op: ast::AtomUIntOp::Inc, @@ -1691,9 +1685,7 @@ fn extract_globals<'input, 'b>( )?); } Statement::Instruction(ast::Instruction::Atom( - details - @ - ast::AtomDetails { + details @ ast::AtomDetails { inner: ast::AtomInnerDetails::Unsigned { op: ast::AtomUIntOp::Dec, @@ -1722,9 +1714,7 @@ fn extract_globals<'input, 'b>( )?); } Statement::Instruction(ast::Instruction::Atom( - details - @ - ast::AtomDetails { + details @ ast::AtomDetails { inner: ast::AtomInnerDetails::Float { op: ast::AtomFloatOp::Add, diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 5175b2d2..59815f25 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -760,6 +760,7 @@ pub enum Type { Vector(ScalarType, u8), // .param.b32 foo[4]; Array(ScalarType, Vec), + Pointer(ScalarType, StateSpace) } impl Type { From 
cccd37f6ee4a14ed644a67a7d6f671a56e9ed8d1 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Mon, 26 Aug 2024 19:07:49 +0200 Subject: [PATCH 32/47] Port ssa conversion --- .../pass/convert_to_stateful_memory_access.rs | 6 - ptx/src/pass/insert_mem_ssa_statements.rs | 276 ++++++++++++++++++ ptx/src/pass/mod.rs | 13 +- 3 files changed, 286 insertions(+), 9 deletions(-) create mode 100644 ptx/src/pass/insert_mem_ssa_statements.rs diff --git a/ptx/src/pass/convert_to_stateful_memory_access.rs b/ptx/src/pass/convert_to_stateful_memory_access.rs index 3060a704..829e1e60 100644 --- a/ptx/src/pass/convert_to_stateful_memory_access.rs +++ b/ptx/src/pass/convert_to_stateful_memory_access.rs @@ -527,9 +527,3 @@ fn convert_to_stateful_memory_access_postprocess( }) }) } - -fn state_is_compatible(this: ast::StateSpace, other: ast::StateSpace) -> bool { - this == other - || this == ast::StateSpace::Reg && other == ast::StateSpace::Sreg - || this == ast::StateSpace::Sreg && other == ast::StateSpace::Reg -} diff --git a/ptx/src/pass/insert_mem_ssa_statements.rs b/ptx/src/pass/insert_mem_ssa_statements.rs new file mode 100644 index 00000000..6ab19bd8 --- /dev/null +++ b/ptx/src/pass/insert_mem_ssa_statements.rs @@ -0,0 +1,276 @@ +use super::*; +use ptx_parser as ast; + +/* + How do we handle arguments: + - input .params in kernels + .param .b64 in_arg + get turned into this SPIR-V: + %1 = OpFunctionParameter %ulong + %2 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %1 + We do this for two reasons. 
One, common treatment for argument-declared + .param variables and .param variables inside function (we assume that + at SPIR-V level every .param is a pointer in Function storage class) + - input .params in functions + .param .b64 in_arg + get turned into this SPIR-V: + %1 = OpFunctionParameter %_ptr_Function_ulong + - input .regs + .reg .b64 in_arg + get turned into the same SPIR-V as kernel .params: + %1 = OpFunctionParameter %ulong + %2 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %1 + - output .regs + .reg .b64 out_arg + get just a variable declaration: + %2 = OpVariable %%_ptr_Function_ulong Function + - output .params don't exist, they have been moved to input positions + by an earlier pass + Distinguishing betweem kernel .params and function .params is not the + cleanest solution. Alternatively, we could "deparamize" all kernel .param + arguments by turning them into .reg arguments like this: + .param .b64 arg -> .reg ptr<.b64,.param> arg + This has the massive downside that this transformation would have to run + very early and would muddy up already difficult code. It's simpler to just + have an if here +*/ +pub(super) fn run<'a, 'b>( + func: Vec, + id_def: &mut NumericIdResolver, + fn_decl: &'a mut ast::MethodDeclaration<'b, SpirvWord>, +) -> Result, TranslateError> { + let mut result = Vec::with_capacity(func.len()); + for arg in fn_decl.input_arguments.iter_mut() { + insert_mem_ssa_argument( + id_def, + &mut result, + arg, + matches!(fn_decl.name, ast::MethodName::Kernel(_)), + ); + } + for arg in fn_decl.return_arguments.iter() { + insert_mem_ssa_argument_reg_return(&mut result, arg); + } + for s in func { + match s { + Statement::Instruction(inst) => match inst { + ast::Instruction::Ret { data } => { + // TODO: handle multiple output args + match &fn_decl.return_arguments[..] 
{ + [return_reg] => { + let new_id = id_def.register_intermediate(Some(( + return_reg.v_type.clone(), + ast::StateSpace::Reg, + ))); + result.push(Statement::LoadVar(LoadVarDetails { + arg: ast::LdArgs { + dst: new_id, + src: return_reg.name, + }, + typ: return_reg.v_type.clone(), + member_index: None, + })); + result.push(Statement::RetValue(data, new_id)); + } + [] => result.push(Statement::Instruction(ast::Instruction::Ret { data })), + _ => unimplemented!(), + } + } + inst => insert_mem_ssa_statement_default( + id_def, + &mut result, + Statement::Instruction(inst), + )?, + }, + Statement::Conditional(bra) => { + insert_mem_ssa_statement_default(id_def, &mut result, Statement::Conditional(bra))? + } + Statement::Conversion(conv) => { + insert_mem_ssa_statement_default(id_def, &mut result, Statement::Conversion(conv))? + } + Statement::PtrAccess(ptr_access) => insert_mem_ssa_statement_default( + id_def, + &mut result, + Statement::PtrAccess(ptr_access), + )?, + Statement::RepackVector(repack) => insert_mem_ssa_statement_default( + id_def, + &mut result, + Statement::RepackVector(repack), + )?, + Statement::FunctionPointer(func_ptr) => insert_mem_ssa_statement_default( + id_def, + &mut result, + Statement::FunctionPointer(func_ptr), + )?, + s @ Statement::Variable(_) | s @ Statement::Label(_) | s @ Statement::Constant(..) 
=> { + result.push(s) + } + _ => return Err(error_unreachable()), + } + } + Ok(result) +} + +fn insert_mem_ssa_argument( + id_def: &mut NumericIdResolver, + func: &mut Vec, + arg: &mut ast::Variable, + is_kernel: bool, +) { + if !is_kernel && arg.state_space == ast::StateSpace::Param { + return; + } + let new_id = id_def.register_intermediate(Some((arg.v_type.clone(), arg.state_space))); + func.push(Statement::Variable(ast::Variable { + align: arg.align, + v_type: arg.v_type.clone(), + state_space: ast::StateSpace::Reg, + name: arg.name, + array_init: Vec::new(), + })); + func.push(Statement::StoreVar(StoreVarDetails { + arg: ast::StArgs { + src1: arg.name, + src2: new_id, + }, + typ: arg.v_type.clone(), + member_index: None, + })); + arg.name = new_id; +} + +fn insert_mem_ssa_argument_reg_return( + func: &mut Vec, + arg: &ast::Variable, +) { + func.push(Statement::Variable(ast::Variable { + align: arg.align, + v_type: arg.v_type.clone(), + state_space: arg.state_space, + name: arg.name, + array_init: arg.array_init.clone(), + })); +} + +fn insert_mem_ssa_statement_default<'a, 'input>( + id_def: &'a mut NumericIdResolver<'input>, + func: &'a mut Vec, + stmt: TypedStatement, +) -> Result<(), TranslateError> { + let mut visitor = InsertMemSSAVisitor { + id_def, + func, + post_statements: Vec::new(), + }; + let new_stmt = stmt.visit_map(&mut visitor)?; + visitor.func.push(new_stmt); + visitor.func.extend(visitor.post_statements); + Ok(()) +} + +struct InsertMemSSAVisitor<'a, 'input> { + id_def: &'a mut NumericIdResolver<'input>, + func: &'a mut Vec, + post_statements: Vec, +} + +impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { + fn symbol( + &mut self, + symbol: SpirvWord, + member_index: Option, + expected: Option<(&ast::Type, ast::StateSpace)>, + is_dst: bool, + ) -> Result { + if expected.is_none() { + return Ok(symbol); + }; + let (mut var_type, var_space, is_variable) = self.id_def.get_typed(symbol)?; + if !state_is_compatible(var_space, ast::StateSpace::Reg) 
|| !is_variable { + return Ok(symbol); + }; + let member_index = match member_index { + Some(idx) => { + let vector_width = match var_type { + ast::Type::Vector(scalar_t, width) => { + var_type = ast::Type::Scalar(scalar_t); + width + } + _ => return Err(TranslateError::MismatchedType), + }; + Some(( + idx, + if self.id_def.special_registers.get(symbol).is_some() { + Some(vector_width) + } else { + None + }, + )) + } + None => None, + }; + let generated_id = self + .id_def + .register_intermediate(Some((var_type.clone(), ast::StateSpace::Reg))); + if !is_dst { + self.func.push(Statement::LoadVar(LoadVarDetails { + arg: ast::LdArgs { + dst: generated_id, + src: symbol, + }, + typ: var_type, + member_index, + })); + } else { + self.post_statements + .push(Statement::StoreVar(StoreVarDetails { + arg: ast::StArgs { + src1: symbol, + src2: generated_id, + }, + typ: var_type, + member_index: member_index.map(|(idx, _)| idx), + })); + } + Ok(generated_id) + } +} + +impl<'a, 'input> ast::VisitorMap + for InsertMemSSAVisitor<'a, 'input> +{ + fn visit( + &mut self, + operand: TypedOperand, + type_space: Option<(&ast::Type, ast::StateSpace)>, + is_dst: bool, + _relaxed_type_check: bool, + ) -> Result { + Ok(match operand { + TypedOperand::Reg(reg) => { + TypedOperand::Reg(self.symbol(reg, None, type_space, is_dst)?) + } + TypedOperand::RegOffset(reg, offset) => { + TypedOperand::RegOffset(self.symbol(reg, None, type_space, is_dst)?, offset) + } + op @ TypedOperand::Imm(..) 
=> op, + TypedOperand::VecMember(symbol, index) => TypedOperand::VecMember( + self.symbol(symbol, Some(index), type_space, is_dst)?, + index, + ), + }) + } + + fn visit_ident( + &mut self, + args: SpirvWord, + type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result { + self.symbol(args, None, type_space, is_dst) + } +} diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 439233a5..f6b700b3 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -12,6 +12,7 @@ use std::{ mod convert_to_stateful_memory_access; mod convert_to_typed; mod fix_special_registers; +mod insert_mem_ssa_statements; mod normalize_identifiers; mod normalize_predicates; @@ -175,13 +176,13 @@ fn to_ssa<'input, 'b>( fix_special_registers::run(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?; let (func_decl, typed_statements) = convert_to_stateful_memory_access::run(func_decl, typed_statements, &mut numeric_id_defs)?; - todo!() - /* - let ssa_statements = insert_mem_ssa_statements( + let ssa_statements = insert_mem_ssa_statements::run( typed_statements, &mut numeric_id_defs, &mut (*func_decl).borrow_mut(), )?; + todo!() + /* let mut numeric_id_defs = numeric_id_defs.finish(); let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?; let expanded_statements = @@ -1206,3 +1207,9 @@ impl< } } } + +fn state_is_compatible(this: ast::StateSpace, other: ast::StateSpace) -> bool { + this == other + || this == ast::StateSpace::Reg && other == ast::StateSpace::Sreg + || this == ast::StateSpace::Sreg && other == ast::StateSpace::Reg +} From c088cc21716c4b18f9f98f42e5524333e392db33 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Mon, 26 Aug 2024 21:37:04 +0200 Subject: [PATCH 33/47] Port expand_arguments --- ptx/src/pass/expand_arguments.rs | 181 +++++++++++++++++++++++++++++++ ptx/src/pass/mod.rs | 13 ++- 2 files changed, 191 insertions(+), 3 deletions(-) create mode 100644 
ptx/src/pass/expand_arguments.rs diff --git a/ptx/src/pass/expand_arguments.rs b/ptx/src/pass/expand_arguments.rs new file mode 100644 index 00000000..eb03866d --- /dev/null +++ b/ptx/src/pass/expand_arguments.rs @@ -0,0 +1,181 @@ +use super::*; +use ptx_parser as ast; + +pub(super) fn run<'a, 'b>( + func: Vec, + id_def: &'b mut MutableNumericIdResolver<'a>, +) -> Result, TranslateError> { + let mut result = Vec::with_capacity(func.len()); + for s in func { + match s { + Statement::Label(id) => result.push(Statement::Label(id)), + Statement::Conditional(bra) => result.push(Statement::Conditional(bra)), + Statement::LoadVar(details) => result.push(Statement::LoadVar(details)), + Statement::StoreVar(details) => result.push(Statement::StoreVar(details)), + Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)), + Statement::Conversion(conv) => result.push(Statement::Conversion(conv)), + Statement::Constant(c) => result.push(Statement::Constant(c)), + Statement::FunctionPointer(d) => result.push(Statement::FunctionPointer(d)), + s => { + let (new_statement, post_stmts) = { + let mut visitor = FlattenArguments::new(&mut result, id_def); + (s.visit_map(&mut visitor)?, visitor.post_stmts) + }; + result.push(new_statement); + result.extend(post_stmts); + } + } + } + Ok(result) +} + +struct FlattenArguments<'a, 'b> { + func: &'b mut Vec, + id_def: &'b mut MutableNumericIdResolver<'a>, + post_stmts: Vec, +} + +impl<'a, 'b> FlattenArguments<'a, 'b> { + fn new( + func: &'b mut Vec, + id_def: &'b mut MutableNumericIdResolver<'a>, + ) -> Self { + FlattenArguments { + func, + id_def, + post_stmts: Vec::new(), + } + } + + fn reg(&mut self, name: SpirvWord) -> Result { + Ok(name) + } + + fn reg_offset( + &mut self, + reg: SpirvWord, + offset: i32, + type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + _is_dst: bool, + ) -> Result { + let (type_, state_space) = if let Some((type_, state_space)) = type_space { + (type_, state_space) + } else { + return 
Err(TranslateError::UntypedSymbol); + }; + if state_space != ast::StateSpace::Reg && state_space != ast::StateSpace::Sreg { + let (reg_type, reg_space) = self.id_def.get_typed(reg)?; + if !state_is_compatible(reg_space, ast::StateSpace::Reg) { + return Err(TranslateError::MismatchedType); + } + let reg_scalar_type = match reg_type { + ast::Type::Scalar(underlying_type) => underlying_type, + _ => return Err(TranslateError::MismatchedType), + }; + let id_constant_stmt = self + .id_def + .register_intermediate(reg_type.clone(), ast::StateSpace::Reg); + self.func.push(Statement::Constant(ConstantDefinition { + dst: id_constant_stmt, + typ: reg_scalar_type, + value: ast::ImmediateValue::S64(offset as i64), + })); + let arith_details = match reg_scalar_type.kind() { + ast::ScalarKind::Signed => ast::ArithDetails::Integer(ast::ArithInteger { + type_: reg_scalar_type, + saturate: false, + }), + ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => { + ast::ArithDetails::Integer(ast::ArithInteger { + type_: reg_scalar_type, + saturate: false, + }) + } + _ => return Err(error_unreachable()), + }; + let id_add_result = self.id_def.register_intermediate(reg_type, state_space); + self.func + .push(Statement::Instruction(ast::Instruction::Add { + data: arith_details, + arguments: ast::AddArgs { + dst: id_add_result, + src1: reg, + src2: id_constant_stmt, + }, + })); + Ok(id_add_result) + } else { + let id_constant_stmt = self.id_def.register_intermediate( + ast::Type::Scalar(ast::ScalarType::S64), + ast::StateSpace::Reg, + ); + self.func.push(Statement::Constant(ConstantDefinition { + dst: id_constant_stmt, + typ: ast::ScalarType::S64, + value: ast::ImmediateValue::S64(offset as i64), + })); + let dst = self + .id_def + .register_intermediate(type_.clone(), state_space); + self.func.push(Statement::PtrAccess(PtrAccess { + underlying_type: type_.clone(), + state_space: state_space, + dst, + ptr_src: reg, + offset_src: id_constant_stmt, + })); + Ok(dst) + } + } + + fn immediate( + 
&mut self, + value: ast::ImmediateValue, + type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + ) -> Result { + let (scalar_t, state_space) = + if let Some((ast::Type::Scalar(scalar), state_space)) = type_space { + (*scalar, state_space) + } else { + return Err(TranslateError::UntypedSymbol); + }; + let id = self + .id_def + .register_intermediate(ast::Type::Scalar(scalar_t), state_space); + self.func.push(Statement::Constant(ConstantDefinition { + dst: id, + typ: scalar_t, + value, + })); + Ok(id) + } +} + +impl<'a, 'b> ast::VisitorMap for FlattenArguments<'a, 'b> { + fn visit( + &mut self, + args: TypedOperand, + type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + is_dst: bool, + _relaxed_type_check: bool, + ) -> Result { + match args { + TypedOperand::Reg(r) => self.reg(r), + TypedOperand::Imm(x) => self.immediate(x, type_space), + TypedOperand::RegOffset(reg, offset) => { + self.reg_offset(reg, offset, type_space, is_dst) + } + TypedOperand::VecMember(..) 
=> Err(error_unreachable()), + } + } + + fn visit_ident( + &mut self, + name: ::Ident, + _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + _is_dst: bool, + _relaxed_type_check: bool, + ) -> Result<::Ident, TranslateError> { + self.reg(name) + } +} diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index f6b700b3..896a34aa 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -11,6 +11,7 @@ use std::{ mod convert_to_stateful_memory_access; mod convert_to_typed; +mod expand_arguments; mod fix_special_registers; mod insert_mem_ssa_statements; mod normalize_identifiers; @@ -181,10 +182,10 @@ fn to_ssa<'input, 'b>( &mut numeric_id_defs, &mut (*func_decl).borrow_mut(), )?; + let mut numeric_id_defs = numeric_id_defs.finish(); + let expanded_statements = expand_arguments::run(ssa_statements, &mut numeric_id_defs)?; todo!() /* - let mut numeric_id_defs = numeric_id_defs.finish(); - let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?; let expanded_statements = insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?; let mut numeric_id_defs = numeric_id_defs.unmut(); @@ -743,7 +744,7 @@ impl> Statement, T> { fn visit_map, Err>( self, visitor: &mut impl ast::VisitorMap, - ) -> std::result::Result, T>, Err> { + ) -> std::result::Result, To>, Err> { Ok(match self { Statement::Instruction(i) => { return ast::visit_map(i, visitor).map(Statement::Instruction) @@ -883,6 +884,12 @@ impl> Statement, T> { false, false, )?; + let offset_src = visitor.visit( + offset_src, + Some((&underlying_type, state_space)), + false, + false, + )?; Statement::PtrAccess(PtrAccess { underlying_type, state_space, From 144f8bd5ed0b049bb17a9a67ad0dc45c58c0df12 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Wed, 28 Aug 2024 01:52:54 +0200 Subject: [PATCH 34/47] Port remaining two passes --- ptx/src/pass/extract_globals.rs | 282 ++++++++++++++ ptx/src/pass/fix_special_registers.rs | 53 --- 
ptx/src/pass/insert_implicit_conversions.rs | 402 ++++++++++++++++++++ ptx/src/pass/mod.rs | 65 +++- ptx/src/pass/normalize_labels.rs | 48 +++ 5 files changed, 791 insertions(+), 59 deletions(-) create mode 100644 ptx/src/pass/extract_globals.rs create mode 100644 ptx/src/pass/insert_implicit_conversions.rs create mode 100644 ptx/src/pass/normalize_labels.rs diff --git a/ptx/src/pass/extract_globals.rs b/ptx/src/pass/extract_globals.rs new file mode 100644 index 00000000..680a5eee --- /dev/null +++ b/ptx/src/pass/extract_globals.rs @@ -0,0 +1,282 @@ +use super::*; + +pub(super) fn run<'input, 'b>( + sorted_statements: Vec, + ptx_impl_imports: &mut HashMap, + id_def: &mut NumericIdResolver, +) -> Result<(Vec, Vec>), TranslateError> { + let mut local = Vec::with_capacity(sorted_statements.len()); + let mut global = Vec::new(); + for statement in sorted_statements { + match statement { + Statement::Variable( + var @ ast::Variable { + state_space: ast::StateSpace::Shared, + .. + }, + ) + | Statement::Variable( + var @ ast::Variable { + state_space: ast::StateSpace::Global, + .. 
+ }, + ) => global.push(var), + Statement::Instruction(ast::Instruction::Bfe { data, arguments }) => { + let fn_name = [ZLUDA_PTX_PREFIX, "bfe_", scalar_to_ptx_name(data)].concat(); + local.push(instruction_to_fn_call( + id_def, + ptx_impl_imports, + ast::Instruction::Bfe { data, arguments }, + fn_name, + )?); + } + Statement::Instruction(ast::Instruction::Bfi { data, arguments }) => { + let fn_name = [ZLUDA_PTX_PREFIX, "bfi_", scalar_to_ptx_name(data)].concat(); + local.push(instruction_to_fn_call( + id_def, + ptx_impl_imports, + ast::Instruction::Bfi { data, arguments }, + fn_name, + )?); + } + Statement::Instruction(ast::Instruction::Brev { data, arguments }) => { + let fn_name: String = + [ZLUDA_PTX_PREFIX, "brev_", scalar_to_ptx_name(data)].concat(); + local.push(instruction_to_fn_call( + id_def, + ptx_impl_imports, + ast::Instruction::Brev { data, arguments }, + fn_name, + )?); + } + Statement::Instruction(ast::Instruction::Activemask { arguments }) => { + let fn_name = [ZLUDA_PTX_PREFIX, "activemask"].concat(); + local.push(instruction_to_fn_call( + id_def, + ptx_impl_imports, + ast::Instruction::Activemask { arguments }, + fn_name, + )?); + } + Statement::Instruction(ast::Instruction::Atom { + data: + data @ ast::AtomDetails { + op: ast::AtomicOp::IncrementWrap, + semantics, + scope, + space, + .. + }, + arguments, + }) => { + let fn_name = [ + ZLUDA_PTX_PREFIX, + "atom_", + semantics_to_ptx_name(semantics), + "_", + scope_to_ptx_name(scope), + "_", + space_to_ptx_name(space), + "_inc", + ] + .concat(); + local.push(instruction_to_fn_call( + id_def, + ptx_impl_imports, + ast::Instruction::Atom { data, arguments }, + fn_name, + )?); + } + Statement::Instruction(ast::Instruction::Atom { + data: + data @ ast::AtomDetails { + op: ast::AtomicOp::DecrementWrap, + semantics, + scope, + space, + .. 
+ }, + arguments, + }) => { + let fn_name = [ + ZLUDA_PTX_PREFIX, + "atom_", + semantics_to_ptx_name(semantics), + "_", + scope_to_ptx_name(scope), + "_", + space_to_ptx_name(space), + "_dec", + ] + .concat(); + local.push(instruction_to_fn_call( + id_def, + ptx_impl_imports, + ast::Instruction::Atom { data, arguments }, + fn_name, + )?); + } + Statement::Instruction(ast::Instruction::Atom { + data: + data @ ast::AtomDetails { + op: ast::AtomicOp::FloatAdd, + semantics, + scope, + space, + .. + }, + arguments, + }) => { + let scalar_type = match data.type_ { + ptx_parser::Type::Scalar(scalar) => scalar, + _ => return Err(error_unreachable()), + }; + let fn_name = [ + ZLUDA_PTX_PREFIX, + "atom_", + semantics_to_ptx_name(semantics), + "_", + scope_to_ptx_name(scope), + "_", + space_to_ptx_name(space), + "_add_", + scalar_to_ptx_name(scalar_type), + ] + .concat(); + local.push(instruction_to_fn_call( + id_def, + ptx_impl_imports, + ast::Instruction::Atom { data, arguments }, + fn_name, + )?); + } + s => local.push(s), + } + } + Ok((local, global)) +} + +fn instruction_to_fn_call( + id_defs: &mut NumericIdResolver, + ptx_impl_imports: &mut HashMap, + inst: ast::Instruction, + fn_name: String, +) -> Result { + let mut arguments = Vec::new(); + ast::visit_map(inst, &mut |operand, + type_space: Option<( + &ast::Type, + ast::StateSpace, + )>, + is_dst, + _| { + let (typ, space) = match type_space { + Some((typ, space)) => (typ.clone(), space), + None => return Err(error_unreachable()), + }; + arguments.push((operand, is_dst, typ, space)); + Ok(SpirvWord(0)) + })?; + let return_arguments_count = arguments + .iter() + .position(|(desc, is_dst, _, _)| !is_dst) + .unwrap_or(arguments.len()); + let (return_arguments, input_arguments) = arguments.split_at(return_arguments_count); + let fn_id = register_external_fn_call( + id_defs, + ptx_impl_imports, + fn_name, + return_arguments + .iter() + .map(|(_, _, typ, state)| (typ, *state)), + input_arguments + .iter() + .map(|(_, _, 
typ, state)| (typ, *state)), + )?; + Ok(Statement::Instruction(ast::Instruction::Call { + data: ast::CallDetails { + uniform: false, + return_arguments: return_arguments + .iter() + .map(|(_, _, typ, state)| (typ.clone(), *state)) + .collect::>(), + input_arguments: input_arguments + .iter() + .map(|(_, _, typ, state)| (typ.clone(), *state)) + .collect::>(), + }, + arguments: ast::CallArgs { + return_arguments: return_arguments + .iter() + .map(|(name, _, _, _)| *name) + .collect::>(), + func: fn_id, + input_arguments: input_arguments + .iter() + .map(|(name, _, _, _)| *name) + .collect::>(), + }, + })) +} + +fn scalar_to_ptx_name(this: ast::ScalarType) -> &'static str { + match this { + ast::ScalarType::B8 => "b8", + ast::ScalarType::B16 => "b16", + ast::ScalarType::B32 => "b32", + ast::ScalarType::B64 => "b64", + ast::ScalarType::B128 => "b128", + ast::ScalarType::U8 => "u8", + ast::ScalarType::U16 => "u16", + ast::ScalarType::U16x2 => "u16x2", + ast::ScalarType::U32 => "u32", + ast::ScalarType::U64 => "u64", + ast::ScalarType::S8 => "s8", + ast::ScalarType::S16 => "s16", + ast::ScalarType::S16x2 => "s16x2", + ast::ScalarType::S32 => "s32", + ast::ScalarType::S64 => "s64", + ast::ScalarType::F16 => "f16", + ast::ScalarType::F16x2 => "f16x2", + ast::ScalarType::F32 => "f32", + ast::ScalarType::F64 => "f64", + ast::ScalarType::BF16 => "bf16", + ast::ScalarType::BF16x2 => "bf16x2", + ast::ScalarType::Pred => "pred", + } +} + +fn semantics_to_ptx_name(this: ast::AtomSemantics) -> &'static str { + match this { + ast::AtomSemantics::Relaxed => "relaxed", + ast::AtomSemantics::Acquire => "acquire", + ast::AtomSemantics::Release => "release", + ast::AtomSemantics::AcqRel => "acq_rel", + } +} + +fn scope_to_ptx_name(this: ast::MemScope) -> &'static str { + match this { + ast::MemScope::Cta => "cta", + ast::MemScope::Gpu => "gpu", + ast::MemScope::Sys => "sys", + ast::MemScope::Cluster => "cluster", + } +} + +fn space_to_ptx_name(this: ast::StateSpace) -> &'static str { + 
match this { + ast::StateSpace::Generic => "generic", + ast::StateSpace::Global => "global", + ast::StateSpace::Shared => "shared", + ast::StateSpace::Reg => "reg", + ast::StateSpace::Const => "const", + ast::StateSpace::Local => "local", + ast::StateSpace::Param => "param", + ast::StateSpace::Sreg => "sreg", + ast::StateSpace::SharedCluster => "shared_cluster", + ast::StateSpace::ParamEntry => "param_entry", + ast::StateSpace::SharedCta => "shared_cta", + ast::StateSpace::ParamFunc => "param_func", + } +} diff --git a/ptx/src/pass/fix_special_registers.rs b/ptx/src/pass/fix_special_registers.rs index 871537d5..304bc611 100644 --- a/ptx/src/pass/fix_special_registers.rs +++ b/ptx/src/pass/fix_special_registers.rs @@ -128,56 +128,3 @@ impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'b, 'input> { } } } - -fn register_external_fn_call<'a>( - id_defs: &mut NumericIdResolver, - ptx_impl_imports: &mut HashMap, - name: String, - return_arguments: impl Iterator, - input_arguments: impl Iterator, -) -> Result { - match ptx_impl_imports.entry(name) { - hash_map::Entry::Vacant(entry) => { - let fn_id = id_defs.register_intermediate(None); - let return_arguments = fn_arguments_to_variables(id_defs, return_arguments); - let input_arguments = fn_arguments_to_variables(id_defs, input_arguments); - let func_decl = ast::MethodDeclaration:: { - return_arguments, - name: ast::MethodName::Func(fn_id), - input_arguments, - shared_mem: None, - }; - let func = Function { - func_decl: Rc::new(RefCell::new(func_decl)), - globals: Vec::new(), - body: None, - import_as: Some(entry.key().clone()), - tuning: Vec::new(), - linkage: ast::LinkingDirective::EXTERN, - }; - entry.insert(Directive::Method(func)); - Ok(fn_id) - } - hash_map::Entry::Occupied(entry) => match entry.get() { - Directive::Method(Function { func_decl, .. 
}) => match (**func_decl).borrow().name { - ast::MethodName::Func(fn_id) => Ok(fn_id), - ast::MethodName::Kernel(_) => Err(error_unreachable()), - }, - _ => Err(error_unreachable()), - }, - } -} - -fn fn_arguments_to_variables<'a>( - id_defs: &mut NumericIdResolver, - args: impl Iterator, -) -> Vec> { - args.map(|(typ, space)| ast::Variable { - align: None, - v_type: typ.clone(), - state_space: space, - name: id_defs.register_intermediate(None), - array_init: Vec::new(), - }) - .collect::>() -} \ No newline at end of file diff --git a/ptx/src/pass/insert_implicit_conversions.rs b/ptx/src/pass/insert_implicit_conversions.rs new file mode 100644 index 00000000..4a0dc8e7 --- /dev/null +++ b/ptx/src/pass/insert_implicit_conversions.rs @@ -0,0 +1,402 @@ +use std::mem; + +use super::*; +use ptx_parser as ast; + +/* + There are several kinds of implicit conversions in PTX: + * auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands + * special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size + - ld.param: not documented, but for instruction `ld.param. 
x, [y]`, + semantics are to first zext/chop/bitcast `y` as needed and then do + documented special ld/st/cvt conversion rules for destination operands + - st.param [x] y (used as function return arguments) same rule as above applies + - generic/global ld: for instruction `ld x, [y]`, y must be of type + b64/u64/s64, which is bitcast to a pointer, dereferenced and then + documented special ld/st/cvt conversion rules are applied to dst + - generic/global st: for instruction `st [x], y`, x must be of type + b64/u64/s64, which is bitcast to a pointer +*/ +pub(super) fn run( + func: Vec, + id_def: &mut MutableNumericIdResolver, +) -> Result, TranslateError> { + let mut result = Vec::with_capacity(func.len()); + for s in func.into_iter() { + match s { + Statement::Instruction(inst) => { + insert_implicit_conversions_impl( + &mut result, + id_def, + Statement::Instruction(inst), + )?; + } + Statement::PtrAccess(access) => { + insert_implicit_conversions_impl( + &mut result, + id_def, + Statement::PtrAccess(access), + )?; + } + Statement::RepackVector(repack) => { + insert_implicit_conversions_impl( + &mut result, + id_def, + Statement::RepackVector(repack), + )?; + } + s @ Statement::Conditional(_) + | s @ Statement::Conversion(_) + | s @ Statement::Label(_) + | s @ Statement::Constant(_) + | s @ Statement::Variable(_) + | s @ Statement::LoadVar(..) + | s @ Statement::StoreVar(..) + | s @ Statement::RetValue(..) + | s @ Statement::FunctionPointer(..) 
=> result.push(s), + } + } + Ok(result) +} + +fn insert_implicit_conversions_impl( + func: &mut Vec, + id_def: &mut MutableNumericIdResolver, + stmt: ExpandedStatement, +) -> Result<(), TranslateError> { + let mut post_conv = Vec::new(); + let statement = stmt.visit_map::( + &mut |operand, + type_state: Option<(&ast::Type, ast::StateSpace)>, + is_dst, + relaxed_type_check| { + let (instr_type, instruction_space) = match type_state { + None => return Ok(operand), + Some(t) => t, + }; + let (operand_type, operand_space) = id_def.get_typed(operand)?; + let conversion_fn = if relaxed_type_check { + if is_dst { + should_convert_relaxed_dst_wrapper + } else { + should_convert_relaxed_src_wrapper + } + } else { + default_implicit_conversion + }; + match conversion_fn( + (operand_space, &operand_type), + (instruction_space, instr_type), + )? { + Some(conv_kind) => { + let conv_output = if is_dst { &mut post_conv } else { &mut *func }; + let mut from_type = instr_type.clone(); + let mut from_space = instruction_space; + let mut to_type = operand_type; + let mut to_space = operand_space; + let mut src = + id_def.register_intermediate(instr_type.clone(), instruction_space); + let mut dst = operand; + let result = Ok::<_, TranslateError>(src); + if !is_dst { + mem::swap(&mut src, &mut dst); + mem::swap(&mut from_type, &mut to_type); + mem::swap(&mut from_space, &mut to_space); + } + conv_output.push(Statement::Conversion(ImplicitConversion { + src, + dst, + from_type, + from_space, + to_type, + to_space, + kind: conv_kind, + })); + result + } + None => Ok(operand), + } + }, + )?; + func.push(statement); + func.append(&mut post_conv); + Ok(()) +} + +fn default_implicit_conversion( + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), +) -> Result, TranslateError> { + if !state_is_compatible(instruction_space, operand_space) { + default_implicit_conversion_space( + (operand_space, operand_type), + 
(instruction_space, instruction_type), + ) + } else if instruction_type != operand_type { + default_implicit_conversion_type(instruction_space, operand_type, instruction_type) + } else { + Ok(None) + } +} + +// Space is different +fn default_implicit_conversion_space( + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), +) -> Result, TranslateError> { + if (instruction_space == ast::StateSpace::Generic && coerces_to_generic(operand_space)) + || (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space)) + { + Ok(Some(ConversionKind::PtrToPtr)) + } else if state_is_compatible(operand_space, ast::StateSpace::Reg) { + match operand_type { + ast::Type::Pointer(operand_ptr_type, operand_ptr_space) + if *operand_ptr_space == instruction_space => + { + if instruction_type != &ast::Type::Scalar(*operand_ptr_type) { + Ok(Some(ConversionKind::PtrToPtr)) + } else { + Ok(None) + } + } + // TODO: 32 bit + ast::Type::Scalar(ast::ScalarType::B64) + | ast::Type::Scalar(ast::ScalarType::U64) + | ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space { + ast::StateSpace::Global + | ast::StateSpace::Generic + | ast::StateSpace::Const + | ast::StateSpace::Local + | ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)), + _ => Err(TranslateError::MismatchedType), + }, + ast::Type::Scalar(ast::ScalarType::B32) + | ast::Type::Scalar(ast::ScalarType::U32) + | ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space { + ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => { + Ok(Some(ConversionKind::BitToPtr)) + } + _ => Err(TranslateError::MismatchedType), + }, + _ => Err(TranslateError::MismatchedType), + } + } else if state_is_compatible(instruction_space, ast::StateSpace::Reg) { + match instruction_type { + ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space) + if operand_space == *instruction_ptr_space => + { + if 
operand_type != &ast::Type::Scalar(*instruction_ptr_type) { + Ok(Some(ConversionKind::PtrToPtr)) + } else { + Ok(None) + } + } + _ => Err(TranslateError::MismatchedType), + } + } else { + Err(TranslateError::MismatchedType) + } +} + +// Space is same, but type is different +fn default_implicit_conversion_type( + space: ast::StateSpace, + operand_type: &ast::Type, + instruction_type: &ast::Type, +) -> Result, TranslateError> { + if state_is_compatible(space, ast::StateSpace::Reg) { + if should_bitcast(instruction_type, operand_type) { + Ok(Some(ConversionKind::Default)) + } else { + Err(TranslateError::MismatchedType) + } + } else { + Ok(Some(ConversionKind::PtrToPtr)) + } +} + +fn coerces_to_generic(this: ast::StateSpace) -> bool { + match this { + ast::StateSpace::Global + | ast::StateSpace::Const + | ast::StateSpace::Local + | ptx_parser::StateSpace::SharedCta + | ast::StateSpace::SharedCluster + | ast::StateSpace::Shared => true, + ast::StateSpace::Reg + | ast::StateSpace::Param + | ast::StateSpace::ParamEntry + | ast::StateSpace::ParamFunc + | ast::StateSpace::Generic + | ast::StateSpace::Sreg => false, + } +} + +fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool { + match (instr, operand) { + (ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => { + if inst.size_of() != operand.size_of() { + return false; + } + match inst.kind() { + ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit, + ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit, + ast::ScalarKind::Signed => { + operand.kind() == ast::ScalarKind::Bit + || operand.kind() == ast::ScalarKind::Unsigned + } + ast::ScalarKind::Unsigned => { + operand.kind() == ast::ScalarKind::Bit + || operand.kind() == ast::ScalarKind::Signed + } + ast::ScalarKind::Pred => false, + } + } + (ast::Type::Vector(inst, _), ast::Type::Vector(operand, _)) + | (ast::Type::Array(inst, _), ast::Type::Array(operand, _)) => { + should_bitcast(&ast::Type::Scalar(*inst), 
&ast::Type::Scalar(*operand)) + } + _ => false, + } +} + +fn should_convert_relaxed_dst_wrapper( + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), +) -> Result, TranslateError> { + if !state_is_compatible(operand_space, instruction_space) { + return Err(TranslateError::MismatchedType); + } + if operand_type == instruction_type { + return Ok(None); + } + match should_convert_relaxed_dst(operand_type, instruction_type) { + conv @ Some(_) => Ok(conv), + None => Err(TranslateError::MismatchedType), + } +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands +fn should_convert_relaxed_dst( + dst_type: &ast::Type, + instr_type: &ast::Type, +) -> Option { + if dst_type == instr_type { + return None; + } + match (dst_type, instr_type) { + (ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() { + ast::ScalarKind::Bit => { + if instr_type.size_of() <= dst_type.size_of() { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Signed => { + if dst_type.kind() != ast::ScalarKind::Float { + if instr_type.size_of() == dst_type.size_of() { + Some(ConversionKind::Default) + } else if instr_type.size_of() < dst_type.size_of() { + Some(ConversionKind::SignExtend) + } else { + None + } + } else { + None + } + } + ast::ScalarKind::Unsigned => { + if instr_type.size_of() <= dst_type.size_of() + && dst_type.kind() != ast::ScalarKind::Float + { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Float => { + if instr_type.size_of() <= dst_type.size_of() + && dst_type.kind() == ast::ScalarKind::Bit + { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Pred => None, + }, + (ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _)) + | (ast::Type::Array(dst_type, _), 
ast::Type::Array(instr_type, _)) => { + should_convert_relaxed_dst( + &ast::Type::Scalar(*dst_type), + &ast::Type::Scalar(*instr_type), + ) + } + _ => None, + } +} + +fn should_convert_relaxed_src_wrapper( + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), +) -> Result, TranslateError> { + if !state_is_compatible(operand_space, instruction_space) { + return Err(TranslateError::MismatchedType); + } + if operand_type == instruction_type { + return Ok(None); + } + match should_convert_relaxed_src(operand_type, instruction_type) { + conv @ Some(_) => Ok(conv), + None => Err(TranslateError::MismatchedType), + } +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands +fn should_convert_relaxed_src( + src_type: &ast::Type, + instr_type: &ast::Type, +) -> Option { + if src_type == instr_type { + return None; + } + match (src_type, instr_type) { + (ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() { + ast::ScalarKind::Bit => { + if instr_type.size_of() <= src_type.size_of() { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => { + if instr_type.size_of() <= src_type.size_of() + && src_type.kind() != ast::ScalarKind::Float + { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Float => { + if instr_type.size_of() <= src_type.size_of() + && src_type.kind() == ast::ScalarKind::Bit + { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Pred => None, + }, + (ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _)) + | (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => { + should_convert_relaxed_src( + &ast::Type::Scalar(*dst_type), + &ast::Type::Scalar(*instr_type), + ) + } + _ => None, + } +} diff --git 
a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 896a34aa..1fdf3a62 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -16,6 +16,9 @@ mod fix_special_registers; mod insert_mem_ssa_statements; mod normalize_identifiers; mod normalize_predicates; +mod insert_implicit_conversions; +mod normalize_labels; +mod extract_globals; static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv"); static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc"); @@ -184,14 +187,12 @@ fn to_ssa<'input, 'b>( )?; let mut numeric_id_defs = numeric_id_defs.finish(); let expanded_statements = expand_arguments::run(ssa_statements, &mut numeric_id_defs)?; - todo!() - /* let expanded_statements = - insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?; + insert_implicit_conversions::run(expanded_statements, &mut numeric_id_defs)?; let mut numeric_id_defs = numeric_id_defs.unmut(); - let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs); + let labeled_statements = normalize_labels::run(expanded_statements, &mut numeric_id_defs); let (f_body, globals) = - extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs)?; + extract_globals::run(labeled_statements, ptx_impl_imports, &mut numeric_id_defs)?; Ok(Function { func_decl: func_decl, globals: globals, @@ -200,7 +201,6 @@ fn to_ssa<'input, 'b>( tuning, linkage, }) - */ } pub struct Module { @@ -1220,3 +1220,56 @@ fn state_is_compatible(this: ast::StateSpace, other: ast::StateSpace) -> bool { || this == ast::StateSpace::Reg && other == ast::StateSpace::Sreg || this == ast::StateSpace::Sreg && other == ast::StateSpace::Reg } + +fn register_external_fn_call<'a>( + id_defs: &mut NumericIdResolver, + ptx_impl_imports: &mut HashMap, + name: String, + return_arguments: impl Iterator, + input_arguments: impl Iterator, +) -> Result { + match ptx_impl_imports.entry(name) { + hash_map::Entry::Vacant(entry) => { + let 
fn_id = id_defs.register_intermediate(None); + let return_arguments = fn_arguments_to_variables(id_defs, return_arguments); + let input_arguments = fn_arguments_to_variables(id_defs, input_arguments); + let func_decl = ast::MethodDeclaration:: { + return_arguments, + name: ast::MethodName::Func(fn_id), + input_arguments, + shared_mem: None, + }; + let func = Function { + func_decl: Rc::new(RefCell::new(func_decl)), + globals: Vec::new(), + body: None, + import_as: Some(entry.key().clone()), + tuning: Vec::new(), + linkage: ast::LinkingDirective::EXTERN, + }; + entry.insert(Directive::Method(func)); + Ok(fn_id) + } + hash_map::Entry::Occupied(entry) => match entry.get() { + Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name { + ast::MethodName::Func(fn_id) => Ok(fn_id), + ast::MethodName::Kernel(_) => Err(error_unreachable()), + }, + _ => Err(error_unreachable()), + }, + } +} + +fn fn_arguments_to_variables<'a>( + id_defs: &mut NumericIdResolver, + args: impl Iterator, +) -> Vec> { + args.map(|(typ, space)| ast::Variable { + align: None, + v_type: typ.clone(), + state_space: space, + name: id_defs.register_intermediate(None), + array_init: Vec::new(), + }) + .collect::>() +} diff --git a/ptx/src/pass/normalize_labels.rs b/ptx/src/pass/normalize_labels.rs new file mode 100644 index 00000000..097d87c7 --- /dev/null +++ b/ptx/src/pass/normalize_labels.rs @@ -0,0 +1,48 @@ +use std::{collections::HashSet, iter}; + +use super::*; + +pub(super) fn run( + func: Vec, + id_def: &mut NumericIdResolver, +) -> Vec { + let mut labels_in_use = HashSet::new(); + for s in func.iter() { + match s { + Statement::Instruction(i) => { + if let Some(target) = jump_target(i) { + labels_in_use.insert(target); + } + } + Statement::Conditional(cond) => { + labels_in_use.insert(cond.if_true); + labels_in_use.insert(cond.if_false); + } + Statement::Variable(..) + | Statement::LoadVar(..) + | Statement::StoreVar(..) + | Statement::RetValue(..) 
+ | Statement::Conversion(..) + | Statement::Constant(..) + | Statement::Label(..) + | Statement::PtrAccess { .. } + | Statement::RepackVector(..) + | Statement::FunctionPointer(..) => {} + } + } + iter::once(Statement::Label(id_def.register_intermediate(None))) + .chain(func.into_iter().filter(|s| match s { + Statement::Label(i) => labels_in_use.contains(i), + _ => true, + })) + .collect::>() +} + +fn jump_target>( + this: &ast::Instruction, +) -> Option { + match this { + ast::Instruction::Bra { arguments } => Some(arguments.src), + _ => None, + } +} From 790fe1857927e12ac7ad3da539f21f6bc70e34d7 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 30 Aug 2024 03:12:33 +0200 Subject: [PATCH 35/47] Emit most of SPIR-V --- .../convert_dynamic_shared_memory_usage.rs | 299 ++ .../pass/convert_to_stateful_memory_access.rs | 19 - ptx/src/pass/emit_spirv.rs | 2767 +++++++++++++++++ ptx/src/pass/mod.rs | 429 ++- ptx_parser/src/ast.rs | 99 +- 5 files changed, 3535 insertions(+), 78 deletions(-) create mode 100644 ptx/src/pass/convert_dynamic_shared_memory_usage.rs create mode 100644 ptx/src/pass/emit_spirv.rs diff --git a/ptx/src/pass/convert_dynamic_shared_memory_usage.rs b/ptx/src/pass/convert_dynamic_shared_memory_usage.rs new file mode 100644 index 00000000..1dac7fd7 --- /dev/null +++ b/ptx/src/pass/convert_dynamic_shared_memory_usage.rs @@ -0,0 +1,299 @@ +use std::collections::{BTreeMap, BTreeSet}; + +use super::*; + +/* + PTX represents dynamically allocated shared local memory as + .extern .shared .b32 shared_mem[]; + In SPIRV/OpenCL world this is expressed as an additional argument to the kernel + And in AMD compilation + This pass looks for all uses of .extern .shared and converts them to + an additional method argument + The question is how this artificial argument should be expressed. 
There are + several options: + * Straight conversion: + .shared .b32 shared_mem[] + * Introduce .param_shared statespace: + .param_shared .b32 shared_mem + or + .param_shared .b32 shared_mem[] + * Introduce .shared_ptr type: + .param .shared_ptr .b32 shared_mem + * Reuse .ptr hint: + .param .u64 .ptr shared_mem + This is the most tempting, but also the most nonsensical, .ptr is just a + hint, which has no semantical meaning (and the output of our + transformation has a semantical meaning - we emit additional + "OpFunctionParameter ..." with type "OpTypePointer Workgroup ...") +*/ +pub(super) fn run<'input>( + module: Vec>, + kernels_methods_call_map: &MethodsCallMap<'input>, + new_id: &mut impl FnMut() -> SpirvWord, +) -> Result>, TranslateError> { + let mut globals_shared = HashMap::new(); + for dir in module.iter() { + match dir { + Directive::Variable( + _, + ast::Variable { + state_space: ast::StateSpace::Shared, + name, + v_type, + .. + }, + ) => { + globals_shared.insert(*name, v_type.clone()); + } + _ => {} + } + } + if globals_shared.len() == 0 { + return Ok(module); + } + let mut methods_to_directly_used_shared_globals = HashMap::<_, HashSet>::new(); + let module = module + .into_iter() + .map(|directive| match directive { + Directive::Method(Function { + func_decl, + globals, + body: Some(statements), + import_as, + tuning, + linkage, + }) => { + let call_key = (*func_decl).borrow().name; + let statements = statements + .into_iter() + .map(|statement| { + statement.visit_map( + &mut |id, _: Option<(&ast::Type, ast::StateSpace)>, _, _| { + if let Some(_) = globals_shared.get(&id) { + methods_to_directly_used_shared_globals + .entry(call_key) + .or_insert_with(HashSet::new) + .insert(id); + } + Ok::<_, TranslateError>(id) + }, + ) + }) + .collect::, _>>()?; + Ok::<_, TranslateError>(Directive::Method(Function { + func_decl, + globals, + body: Some(statements), + import_as, + tuning, + linkage, + })) + } + directive => Ok(directive), + }) + .collect::, 
_>>()?; + // If there's a chain `kernel` -> `fn1` -> `fn2`, where only `fn2` uses extern shared, + // make sure it gets propagated to `fn1` and `kernel` + let methods_to_indirectly_used_shared_globals = resolve_indirect_uses_of_globals_shared( + methods_to_directly_used_shared_globals, + kernels_methods_call_map, + ); + // now visit every method declaration and inject those additional arguments + let mut directives = Vec::with_capacity(module.len()); + for directive in module.into_iter() { + match directive { + Directive::Method(Function { + func_decl, + globals, + body: Some(statements), + import_as, + tuning, + linkage, + }) => { + let statements = { + let func_decl_ref = &mut (*func_decl).borrow_mut(); + let method_name = func_decl_ref.name; + insert_arguments_remap_statements( + new_id, + kernels_methods_call_map, + &globals_shared, + &methods_to_indirectly_used_shared_globals, + method_name, + &mut directives, + func_decl_ref, + statements, + )? + }; + directives.push(Directive::Method(Function { + func_decl, + globals, + body: Some(statements), + import_as, + tuning, + linkage, + })); + } + directive => directives.push(directive), + } + } + Ok(directives) +} + +// We need to compute two kinds of information: +// * If it's a kernel -> size of .shared globals in use (direct or indirect) +// * If it's a function -> does it use .shared global (directly or indirectly) +fn resolve_indirect_uses_of_globals_shared<'input>( + methods_use_of_globals_shared: HashMap, HashSet>, + kernels_methods_call_map: &MethodsCallMap<'input>, +) -> HashMap, BTreeSet> { + let mut result = HashMap::new(); + for (method, callees) in kernels_methods_call_map.methods() { + let mut indirect_globals = methods_use_of_globals_shared + .get(&method) + .into_iter() + .flatten() + .copied() + .collect::>(); + for &callee in callees { + indirect_globals.extend( + methods_use_of_globals_shared + .get(&ast::MethodName::Func(callee)) + .into_iter() + .flatten() + .copied(), + ); + } + 
result.insert(method, indirect_globals); + } + result +} + +fn insert_arguments_remap_statements<'input>( + new_id: &mut impl FnMut() -> SpirvWord, + kernels_methods_call_map: &MethodsCallMap<'input>, + globals_shared: &HashMap, + methods_to_indirectly_used_shared_globals: &HashMap< + ast::MethodName<'input, SpirvWord>, + BTreeSet, + >, + method_name: ast::MethodName, + result: &mut Vec, + func_decl_ref: &mut std::cell::RefMut>, + statements: Vec, SpirvWord>>, +) -> Result, SpirvWord>>, TranslateError> { + let remapped_globals_in_method = + if let Some(method_globals) = methods_to_indirectly_used_shared_globals.get(&method_name) { + match method_name { + ast::MethodName::Func(..) => { + let remapped_globals = method_globals + .iter() + .map(|global| { + ( + *global, + ( + new_id(), + globals_shared + .get(&global) + .unwrap_or_else(|| todo!()) + .clone(), + ), + ) + }) + .collect::>(); + for (_, (new_shared_global_id, shared_global_type)) in remapped_globals.iter() { + func_decl_ref.input_arguments.push(ast::Variable { + align: None, + v_type: shared_global_type.clone(), + state_space: ast::StateSpace::Shared, + name: *new_shared_global_id, + array_init: Vec::new(), + }); + } + remapped_globals + } + ast::MethodName::Kernel(..) 
=> method_globals + .iter() + .map(|global| { + ( + *global, + ( + *global, + globals_shared + .get(&global) + .unwrap_or_else(|| todo!()) + .clone(), + ), + ) + }) + .collect::>(), + } + } else { + return Ok(statements); + }; + replace_uses_of_shared_memory( + new_id, + methods_to_indirectly_used_shared_globals, + statements, + remapped_globals_in_method, + ) +} + +fn replace_uses_of_shared_memory<'input>( + new_id: &mut impl FnMut() -> SpirvWord, + methods_to_indirectly_used_shared_globals: &HashMap< + ast::MethodName<'input, SpirvWord>, + BTreeSet, + >, + statements: Vec, + remapped_globals_in_method: BTreeMap, +) -> Result, TranslateError> { + let mut result = Vec::with_capacity(statements.len()); + for statement in statements { + match statement { + Statement::Instruction(ast::Instruction::Call { + mut data, + mut arguments, + }) => { + // We can safely skip checking call arguments, + // because there's simply no way to pass shared ptr + // without converting it to .b64 first + if let Some(shared_globals_used_by_callee) = + methods_to_indirectly_used_shared_globals + .get(&ast::MethodName::Func(arguments.func)) + { + for &shared_global_used_by_callee in shared_globals_used_by_callee { + let (remapped_shared_id, type_) = remapped_globals_in_method + .get(&shared_global_used_by_callee) + .unwrap_or_else(|| todo!()); + data.input_arguments + .push((type_.clone(), ast::StateSpace::Shared)); + arguments.input_arguments.push(*remapped_shared_id); + } + } + result.push(Statement::Instruction(ast::Instruction::Call { + data, + arguments, + })) + } + statement => { + let new_statement = + statement.visit_map(&mut |id, + _: Option<(&ast::Type, ast::StateSpace)>, + _, + _| { + Ok::<_, TranslateError>( + if let Some((remapped_shared_id, _)) = + remapped_globals_in_method.get(&id) + { + *remapped_shared_id + } else { + id + }, + ) + })?; + result.push(new_statement); + } + } + } + Ok(result) +} diff --git a/ptx/src/pass/convert_to_stateful_memory_access.rs 
b/ptx/src/pass/convert_to_stateful_memory_access.rs index 829e1e60..61b31ad3 100644 --- a/ptx/src/pass/convert_to_stateful_memory_access.rs +++ b/ptx/src/pass/convert_to_stateful_memory_access.rs @@ -394,25 +394,6 @@ fn is_64_bit_integer(id_defs: &NumericIdResolver, id: SpirvWord) -> bool { } } -fn multi_hash_map_append< - K: Eq + std::hash::Hash, - V, - Collection: std::iter::Extend + std::default::Default, ->( - m: &mut HashMap, - key: K, - value: V, -) { - match m.entry(key) { - hash_map::Entry::Occupied(mut entry) => { - entry.get_mut().extend(iter::once(value)); - } - hash_map::Entry::Vacant(entry) => { - entry.insert(Default::default()).extend(iter::once(value)); - } - } -} - fn is_add_ptr_direct( remapped_ids: &HashMap, arg: &ast::AddArgs, diff --git a/ptx/src/pass/emit_spirv.rs b/ptx/src/pass/emit_spirv.rs new file mode 100644 index 00000000..9dff12ee --- /dev/null +++ b/ptx/src/pass/emit_spirv.rs @@ -0,0 +1,2767 @@ +use super::*; +use half::f16; +use ptx_parser as ast; +use rspirv::{binary::Assemble, dr}; +use std::{ + collections::{HashMap, HashSet}, + ffi::CString, + mem, +}; + +pub(super) fn run<'input>( + mut builder: dr::Builder, + id_defs: &GlobalStringIdResolver<'input>, + call_map: MethodsCallMap<'input>, + denorm_information: HashMap< + ptx_parser::MethodName, + HashMap, + >, + directives: Vec>, +) -> Result<(), TranslateError> { + builder.set_version(1, 3); + emit_capabilities(&mut builder); + emit_extensions(&mut builder); + let opencl_id = emit_opencl_import(&mut builder); + emit_memory_model(&mut builder); + let mut map = TypeWordMap::new(&mut builder); + //emit_builtins(&mut builder, &mut map, &id_defs); + let mut kernel_info = HashMap::new(); + let (build_options, should_flush_denorms) = + emit_denorm_build_string(&call_map, &denorm_information); + let (directives, globals_use_map) = get_globals_use_map(directives); + emit_directives( + &mut builder, + &mut map, + &id_defs, + opencl_id, + should_flush_denorms, + &call_map, + globals_use_map, 
+ directives, + &mut kernel_info, + ) +} + +fn emit_capabilities(builder: &mut dr::Builder) { + builder.capability(spirv::Capability::GenericPointer); + builder.capability(spirv::Capability::Linkage); + builder.capability(spirv::Capability::Addresses); + builder.capability(spirv::Capability::Kernel); + builder.capability(spirv::Capability::Int8); + builder.capability(spirv::Capability::Int16); + builder.capability(spirv::Capability::Int64); + builder.capability(spirv::Capability::Float16); + builder.capability(spirv::Capability::Float64); + builder.capability(spirv::Capability::DenormFlushToZero); + // TODO: re-enable when Intel float control extension works + //builder.capability(spirv::Capability::FunctionFloatControlINTEL); +} + +// http://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/KHR/SPV_KHR_float_controls.html +fn emit_extensions(builder: &mut dr::Builder) { + // TODO: re-enable when Intel float control extension works + //builder.extension("SPV_INTEL_float_controls2"); + builder.extension("SPV_KHR_float_controls"); + builder.extension("SPV_KHR_no_integer_wrap_decoration"); +} + +fn emit_opencl_import(builder: &mut dr::Builder) -> spirv::Word { + builder.ext_inst_import("OpenCL.std") +} + +fn emit_memory_model(builder: &mut dr::Builder) { + builder.memory_model( + spirv::AddressingModel::Physical64, + spirv::MemoryModel::OpenCL, + ); +} + +struct TypeWordMap { + void: spirv::Word, + complex: HashMap, + constants: HashMap<(SpirvType, u64), SpirvWord>, +} + +impl TypeWordMap { + fn new(b: &mut dr::Builder) -> TypeWordMap { + let void = b.type_void(None); + TypeWordMap { + void: void, + complex: HashMap::::new(), + constants: HashMap::new(), + } + } + + fn void(&self) -> spirv::Word { + self.void + } + + fn get_or_add_scalar(&mut self, b: &mut dr::Builder, t: ast::ScalarType) -> SpirvWord { + let key: SpirvScalarKey = t.into(); + self.get_or_add_spirv_scalar(b, key) + } + + fn get_or_add_spirv_scalar(&mut self, 
b: &mut dr::Builder, key: SpirvScalarKey) -> SpirvWord { + *self.complex.entry(SpirvType::Base(key)).or_insert_with(|| { + SpirvWord(match key { + SpirvScalarKey::B8 => b.type_int(None, 8, 0), + SpirvScalarKey::B16 => b.type_int(None, 16, 0), + SpirvScalarKey::B32 => b.type_int(None, 32, 0), + SpirvScalarKey::B64 => b.type_int(None, 64, 0), + SpirvScalarKey::F16 => b.type_float(None, 16), + SpirvScalarKey::F32 => b.type_float(None, 32), + SpirvScalarKey::F64 => b.type_float(None, 64), + SpirvScalarKey::Pred => b.type_bool(None), + SpirvScalarKey::F16x2 => todo!(), + }) + }) + } + + fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> SpirvWord { + match t { + SpirvType::Base(key) => self.get_or_add_spirv_scalar(b, key), + SpirvType::Pointer(ref typ, storage) => { + let base = self.get_or_add(b, *typ.clone()); + *self + .complex + .entry(t) + .or_insert_with(|| SpirvWord(b.type_pointer(None, storage, base.0))) + } + SpirvType::Vector(typ, len) => { + let base = self.get_or_add_spirv_scalar(b, typ); + *self + .complex + .entry(t) + .or_insert_with(|| SpirvWord(b.type_vector(None, base.0, len as u32))) + } + SpirvType::Array(typ, array_dimensions) => { + let (base_type, length) = match &*array_dimensions { + &[] => { + return self.get_or_add(b, SpirvType::Base(typ)); + } + &[len] => { + let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32); + let base = self.get_or_add_spirv_scalar(b, typ); + let len_const = b.constant_u32(u32_type.0, None, len); + (base, len_const) + } + array_dimensions => { + let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32); + let base = self + .get_or_add(b, SpirvType::Array(typ, array_dimensions[1..].to_vec())); + let len_const = b.constant_u32(u32_type.0, None, array_dimensions[0]); + (base, len_const) + } + }; + *self + .complex + .entry(SpirvType::Array(typ, array_dimensions)) + .or_insert_with(|| SpirvWord(b.type_array(None, base_type.0, length))) + } + SpirvType::Func(ref out_params, ref in_params) => { + let 
out_t = match out_params { + Some(p) => self.get_or_add(b, *p.clone()), + None => SpirvWord(self.void()), + }; + let in_t = in_params + .iter() + .map(|t| self.get_or_add(b, t.clone()).0) + .collect::>(); + *self + .complex + .entry(t) + .or_insert_with(|| SpirvWord(b.type_function(None, out_t.0, in_t))) + } + SpirvType::Struct(ref underlying) => { + let underlying_ids = underlying + .iter() + .map(|t| self.get_or_add_spirv_scalar(b, *t).0) + .collect::>(); + *self + .complex + .entry(t) + .or_insert_with(|| SpirvWord(b.type_struct(None, underlying_ids))) + } + } + } + + fn get_or_add_fn( + &mut self, + b: &mut dr::Builder, + in_params: impl Iterator, + mut out_params: impl ExactSizeIterator, + ) -> (SpirvWord, SpirvWord) { + let (out_args, out_spirv_type) = if out_params.len() == 0 { + (None, SpirvWord(self.void())) + } else if out_params.len() == 1 { + let arg_as_key = out_params.next().unwrap(); + ( + Some(Box::new(arg_as_key.clone())), + self.get_or_add(b, arg_as_key), + ) + } else { + // TODO: support multiple return values + todo!() + }; + ( + out_spirv_type, + self.get_or_add(b, SpirvType::Func(out_args, in_params.collect::>())), + ) + } + + fn get_or_add_constant( + &mut self, + b: &mut dr::Builder, + typ: &ast::Type, + init: &[u8], + ) -> Result { + Ok(match typ { + ast::Type::Scalar(t) => match t { + ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => self + .get_or_add_constant_single::( + b, + *t, + init, + |v| v as u64, + |b, result_type, v| b.constant_u32(result_type, None, v as u32), + ), + ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => self + .get_or_add_constant_single::( + b, + *t, + init, + |v| v as u64, + |b, result_type, v| b.constant_u32(result_type, None, v as u32), + ), + ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => self + .get_or_add_constant_single::( + b, + *t, + init, + |v| v as u64, + |b, result_type, v| b.constant_u32(result_type, None, v), + ), + ast::ScalarType::B64 | 
ast::ScalarType::U64 | ast::ScalarType::S64 => self + .get_or_add_constant_single::( + b, + *t, + init, + |v| v, + |b, result_type, v| b.constant_u64(result_type, None, v), + ), + ast::ScalarType::F16 => self.get_or_add_constant_single::( + b, + *t, + init, + |v| unsafe { mem::transmute::<_, u16>(v) } as u64, + |b, result_type, v| b.constant_f32(result_type, None, v.to_f32()), + ), + ast::ScalarType::F32 => self.get_or_add_constant_single::( + b, + *t, + init, + |v| unsafe { mem::transmute::<_, u32>(v) } as u64, + |b, result_type, v| b.constant_f32(result_type, None, v), + ), + ast::ScalarType::F64 => self.get_or_add_constant_single::( + b, + *t, + init, + |v| unsafe { mem::transmute::<_, u64>(v) }, + |b, result_type, v| b.constant_f64(result_type, None, v), + ), + ast::ScalarType::F16x2 => return Err(TranslateError::Todo), + ast::ScalarType::Pred => self.get_or_add_constant_single::( + b, + *t, + init, + |v| v as u64, + |b, result_type, v| { + if v == 0 { + b.constant_false(result_type, None) + } else { + b.constant_true(result_type, None) + } + }, + ), + ast::ScalarType::S16x2 + | ast::ScalarType::U16x2 + | ast::ScalarType::BF16 + | ast::ScalarType::BF16x2 + | ast::ScalarType::B128 => todo!(), + }, + ast::Type::Vector(typ, len) => { + let result_type = + self.get_or_add(b, SpirvType::Vector(SpirvScalarKey::from(*typ), *len)); + let size_of_t = typ.size_of(); + let components = (0..*len) + .map(|x| { + Ok::<_, TranslateError>( + self.get_or_add_constant( + b, + &ast::Type::Scalar(*typ), + &init[((size_of_t as usize) * (x as usize))..], + )? 
+ .0, + ) + }) + .collect::, _>>()?; + SpirvWord(b.constant_composite(result_type.0, None, components.into_iter())) + } + ast::Type::Array(typ, dims) => match dims.as_slice() { + [] => return Err(error_unreachable()), + [dim] => { + let result_type = self + .get_or_add(b, SpirvType::Array(SpirvScalarKey::from(*typ), vec![*dim])); + let size_of_t = typ.size_of(); + let components = (0..*dim) + .map(|x| { + Ok::<_, TranslateError>( + self.get_or_add_constant( + b, + &ast::Type::Scalar(*typ), + &init[((size_of_t as usize) * (x as usize))..], + )? + .0, + ) + }) + .collect::, _>>()?; + SpirvWord(b.constant_composite(result_type.0, None, components.into_iter())) + } + [first_dim, rest @ ..] => { + let result_type = self.get_or_add( + b, + SpirvType::Array(SpirvScalarKey::from(*typ), rest.to_vec()), + ); + let size_of_t = rest + .iter() + .fold(typ.size_of() as u32, |x, y| (x as u32) * (*y)); + let components = (0..*first_dim) + .map(|x| { + Ok::<_, TranslateError>( + self.get_or_add_constant( + b, + &ast::Type::Array(*typ, rest.to_vec()), + &init[((size_of_t as usize) * (x as usize))..], + )? + .0, + ) + }) + .collect::, _>>()?; + SpirvWord(b.constant_composite(result_type.0, None, components.into_iter())) + } + }, + ast::Type::Pointer(..) 
=> return Err(error_unreachable()), + }) + } + + fn get_or_add_constant_single< + T: Copy, + CastAsU64: FnOnce(T) -> u64, + InsertConstant: FnOnce(&mut dr::Builder, spirv::Word, T) -> spirv::Word, + >( + &mut self, + b: &mut dr::Builder, + key: ast::ScalarType, + init: &[u8], + cast: CastAsU64, + f: InsertConstant, + ) -> SpirvWord { + let value = unsafe { *(init.as_ptr() as *const T) }; + let value_64 = cast(value); + let ht_key = (SpirvType::Base(SpirvScalarKey::from(key)), value_64); + match self.constants.get(&ht_key) { + Some(value) => *value, + None => { + let spirv_type = self.get_or_add_scalar(b, key); + let result = SpirvWord(f(b, spirv_type.0, value)); + self.constants.insert(ht_key, result); + result + } + } + } +} + +#[derive(PartialEq, Eq, Hash, Clone)] +enum SpirvType { + Base(SpirvScalarKey), + Vector(SpirvScalarKey, u8), + Array(SpirvScalarKey, Vec), + Pointer(Box, spirv::StorageClass), + Func(Option>, Vec), + Struct(Vec), +} + +impl SpirvType { + fn new(t: ast::Type) -> Self { + match t { + ast::Type::Scalar(t) => SpirvType::Base(t.into()), + ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len), + ast::Type::Array(t, len) => SpirvType::Array(t.into(), len), + ast::Type::Pointer(pointer_t, space) => SpirvType::Pointer( + Box::new(SpirvType::Base(pointer_t.into())), + space_to_spirv(space), + ), + } + } + + fn pointer_to(t: ast::Type, outer_space: spirv::StorageClass) -> Self { + let key = Self::new(t); + SpirvType::Pointer(Box::new(key), outer_space) + } +} + +impl From for SpirvType { + fn from(t: ast::ScalarType) -> Self { + SpirvType::Base(t.into()) + } +} +// SPIR-V integer type definitions are signless, more below: +// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_unsignedsigned_a_unsigned_versus_signed_integers +// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_validation_rules_for_kernel_a_href_capability_capabilities_a +#[derive(PartialEq, Eq, Hash, Clone, Copy)] +enum SpirvScalarKey 
{ + B8, + B16, + B32, + B64, + F16, + F32, + F64, + Pred, + F16x2, +} + +impl From for SpirvScalarKey { + fn from(t: ast::ScalarType) -> Self { + match t { + ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => SpirvScalarKey::B8, + ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => { + SpirvScalarKey::B16 + } + ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => { + SpirvScalarKey::B32 + } + ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => { + SpirvScalarKey::B64 + } + ast::ScalarType::F16 => SpirvScalarKey::F16, + ast::ScalarType::F32 => SpirvScalarKey::F32, + ast::ScalarType::F64 => SpirvScalarKey::F64, + ast::ScalarType::F16x2 => SpirvScalarKey::F16x2, + ast::ScalarType::Pred => SpirvScalarKey::Pred, + ast::ScalarType::S16x2 + | ast::ScalarType::U16x2 + | ast::ScalarType::BF16 + | ast::ScalarType::BF16x2 + | ast::ScalarType::B128 => todo!(), + } + } +} + +fn space_to_spirv(this: ast::StateSpace) -> spirv::StorageClass { + match this { + ast::StateSpace::Const => spirv::StorageClass::UniformConstant, + ast::StateSpace::Generic => spirv::StorageClass::Generic, + ast::StateSpace::Global => spirv::StorageClass::CrossWorkgroup, + ast::StateSpace::Local => spirv::StorageClass::Function, + ast::StateSpace::Shared => spirv::StorageClass::Workgroup, + ast::StateSpace::Param => spirv::StorageClass::Function, + ast::StateSpace::Reg => spirv::StorageClass::Function, + ast::StateSpace::Sreg => spirv::StorageClass::Input, + ast::StateSpace::ParamEntry + | ast::StateSpace::ParamFunc + | ast::StateSpace::SharedCluster + | ast::StateSpace::SharedCta => todo!(), + } +} + +// TODO: remove this once we have per-function support for denorms +fn emit_denorm_build_string<'input>( + call_map: &MethodsCallMap, + denorm_information: &HashMap< + ast::MethodName<'input, SpirvWord>, + HashMap, + >, +) -> (CString, bool) { + let denorm_counts = denorm_information + .iter() + .map(|(method, meth_denorm)| { + let
f16_count = meth_denorm + .get(&(mem::size_of::() as u8)) + .unwrap_or(&(spirv::FPDenormMode::FlushToZero, 0)) + .1; + let f32_count = meth_denorm + .get(&(mem::size_of::() as u8)) + .unwrap_or(&(spirv::FPDenormMode::FlushToZero, 0)) + .1; + (method, (f16_count + f32_count)) + }) + .collect::>(); + let mut flush_over_preserve = 0; + for (kernel, children) in call_map.kernels() { + flush_over_preserve += *denorm_counts + .get(&ast::MethodName::Kernel(kernel)) + .unwrap_or(&0); + for child_fn in children { + flush_over_preserve += *denorm_counts + .get(&ast::MethodName::Func(*child_fn)) + .unwrap_or(&0); + } + } + if flush_over_preserve > 0 { + ( + CString::new("-ze-take-global-address -ze-denorms-are-zero").unwrap(), + true, + ) + } else { + (CString::new("-ze-take-global-address").unwrap(), false) + } +} + +fn get_globals_use_map<'input>( + directives: Vec>, +) -> ( + Vec>, + HashMap, HashSet>, +) { + let mut known_globals = HashSet::new(); + for directive in directives.iter() { + match directive { + Directive::Variable(_, ast::Variable { name, .. }) => { + known_globals.insert(*name); + } + Directive::Method(..) => {} + } + } + let mut symbol_uses_map = HashMap::new(); + let directives = directives + .into_iter() + .map(|directive| match directive { + Directive::Variable(..) | Directive::Method(Function { body: None, .. 
}) => directive, + Directive::Method(Function { + func_decl, + body: Some(mut statements), + globals, + import_as, + tuning, + linkage, + }) => { + let method_name = func_decl.borrow().name; + statements = statements + .into_iter() + .map(|statement| { + statement.visit_map( + &mut |symbol, _: Option<(&ast::Type, ast::StateSpace)>, _, _| { + if known_globals.contains(&symbol) { + multi_hash_map_append( + &mut symbol_uses_map, + method_name, + symbol, + ); + } + Ok::<_, TranslateError>(symbol) + }, + ) + }) + .collect::, _>>() + .unwrap(); + Directive::Method(Function { + func_decl, + body: Some(statements), + globals, + import_as, + tuning, + linkage, + }) + } + }) + .collect::>(); + (directives, symbol_uses_map) +} + +fn emit_directives<'input>( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + id_defs: &GlobalStringIdResolver<'input>, + opencl_id: spirv::Word, + should_flush_denorms: bool, + call_map: &MethodsCallMap<'input>, + globals_use_map: HashMap, HashSet>, + directives: Vec>, + kernel_info: &mut HashMap, +) -> Result<(), TranslateError> { + let empty_body = Vec::new(); + for d in directives.iter() { + match d { + Directive::Variable(linking, var) => { + emit_variable(builder, map, id_defs, *linking, &var)?; + } + Directive::Method(f) => { + let f_body = match &f.body { + Some(f) => f, + None => { + if f.linkage.contains(ast::LinkingDirective::EXTERN) { + &empty_body + } else { + continue; + } + } + }; + for var in f.globals.iter() { + emit_variable(builder, map, id_defs, ast::LinkingDirective::NONE, var)?; + } + let func_decl = (*f.func_decl).borrow(); + let fn_id = emit_function_header( + builder, + map, + &id_defs, + &*func_decl, + call_map, + &globals_use_map, + kernel_info, + )?; + if matches!(func_decl.name, ast::MethodName::Kernel(_)) { + if should_flush_denorms { + builder.execution_mode( + fn_id.0, + spirv_headers::ExecutionMode::DenormFlushToZero, + [16], + ); + builder.execution_mode( + fn_id.0, + 
spirv_headers::ExecutionMode::DenormFlushToZero, + [32], + ); + builder.execution_mode( + fn_id.0, + spirv_headers::ExecutionMode::DenormFlushToZero, + [64], + ); + } + // FP contraction happens when compiling source -> PTX and is illegal at this stage (unless you force it in cuModuleLoadDataEx) + builder.execution_mode( + fn_id.0, + spirv_headers::ExecutionMode::ContractionOff, + [], + ); + for t in f.tuning.iter() { + match *t { + ast::TuningDirective::MaxNtid(nx, ny, nz) => { + builder.execution_mode( + fn_id.0, + spirv_headers::ExecutionMode::MaxWorkgroupSizeINTEL, + [nx, ny, nz], + ); + } + ast::TuningDirective::ReqNtid(nx, ny, nz) => { + builder.execution_mode( + fn_id.0, + spirv_headers::ExecutionMode::LocalSize, + [nx, ny, nz], + ); + } + // Too architecture specific + ast::TuningDirective::MaxNReg(..) + | ast::TuningDirective::MinNCtaPerSm(..) => {} + } + } + } + emit_function_body_ops(builder, map, id_defs, opencl_id, &f_body)?; + emit_function_linkage(builder, id_defs, f, fn_id)?; + builder.select_block(None)?; + builder.end_function()?; + } + } + } + Ok(()) +} + +fn emit_variable<'input>( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + id_defs: &GlobalStringIdResolver<'input>, + linking: ast::LinkingDirective, + var: &ast::Variable, +) -> Result<(), TranslateError> { + let (must_init, st_class) = match var.state_space { + ast::StateSpace::Reg | ast::StateSpace::Param | ast::StateSpace::Local => { + (false, spirv::StorageClass::Function) + } + ast::StateSpace::Global => (true, spirv::StorageClass::CrossWorkgroup), + ast::StateSpace::Shared => (false, spirv::StorageClass::Workgroup), + ast::StateSpace::Const => (false, spirv::StorageClass::UniformConstant), + ast::StateSpace::Generic => todo!(), + ast::StateSpace::Sreg => todo!(), + ast::StateSpace::ParamEntry + | ast::StateSpace::ParamFunc + | ast::StateSpace::SharedCluster + | ast::StateSpace::SharedCta => todo!(), + }; + let initalizer = if var.array_init.len() > 0 { + Some( + 
map.get_or_add_constant( + builder, + &ast::Type::from(var.v_type.clone()), + &*var.array_init, + )? + .0, + ) + } else if must_init { + let type_id = map.get_or_add(builder, SpirvType::new(var.v_type.clone())); + Some(builder.constant_null(type_id.0, None)) + } else { + None + }; + let ptr_type_id = map.get_or_add(builder, SpirvType::pointer_to(var.v_type.clone(), st_class)); + builder.variable(ptr_type_id.0, Some(var.name.0), st_class, initalizer); + if let Some(align) = var.align { + builder.decorate( + var.name.0, + spirv::Decoration::Alignment, + [dr::Operand::LiteralInt32(align)].iter().cloned(), + ); + } + if var.state_space != ast::StateSpace::Shared + || !linking.contains(ast::LinkingDirective::EXTERN) + { + emit_linking_decoration(builder, id_defs, None, var.name, linking); + } + Ok(()) +} + +fn emit_function_header<'input>( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + defined_globals: &GlobalStringIdResolver<'input>, + func_decl: &ast::MethodDeclaration<'input, SpirvWord>, + call_map: &MethodsCallMap<'input>, + globals_use_map: &HashMap, HashSet>, + kernel_info: &mut HashMap, +) -> Result { + if let ast::MethodName::Kernel(name) = func_decl.name { + let args_lens = func_decl + .input_arguments + .iter() + .map(|param| { + ( + type_size_of(¶m.v_type), + matches!(param.v_type, ast::Type::Pointer(..)), + ) + }) + .collect(); + kernel_info.insert( + name.to_string(), + KernelInfo { + arguments_sizes: args_lens, + uses_shared_mem: func_decl.shared_mem.is_some(), + }, + ); + } + let (ret_type, func_type) = get_function_type( + builder, + map, + effective_input_arguments(func_decl).map(|(_, typ)| typ), + &func_decl.return_arguments, + ); + let fn_id = match func_decl.name { + ast::MethodName::Kernel(name) => { + let fn_id = defined_globals.get_id(name)?; + let interface = globals_use_map + .get(&ast::MethodName::Kernel(name)) + .into_iter() + .flatten() + .copied() + .chain({ + call_map + .get_kernel_children(name) + .copied() + 
.flat_map(|subfunction| { + globals_use_map + .get(&ast::MethodName::Func(subfunction)) + .into_iter() + .flatten() + .copied() + }) + .into_iter() + }) + .map(|word| word.0) + .collect::>(); + builder.entry_point(spirv::ExecutionModel::Kernel, fn_id.0, name, interface); + fn_id + } + ast::MethodName::Func(name) => name, + }; + builder.begin_function( + ret_type.0, + Some(fn_id.0), + spirv::FunctionControl::NONE, + func_type.0, + )?; + for (name, typ) in effective_input_arguments(func_decl) { + let result_type = map.get_or_add(builder, typ); + builder.function_parameter(Some(name.0), result_type.0)?; + } + Ok(fn_id) +} + +pub fn type_size_of(this: &ast::Type) -> usize { + match this { + ast::Type::Scalar(typ) => typ.size_of() as usize, + ast::Type::Vector(typ, len) => (typ.size_of() as usize) * (*len as usize), + ast::Type::Array(typ, len) => len + .iter() + .fold(typ.size_of() as usize, |x, y| (x as usize) * (*y as usize)), + ast::Type::Pointer(..) => mem::size_of::(), + } +} +fn emit_function_body_ops<'input>( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + id_defs: &GlobalStringIdResolver<'input>, + opencl: spirv::Word, + func: &[ExpandedStatement], +) -> Result<(), TranslateError> { + for s in func { + match s { + Statement::Label(id) => { + if builder.selected_block().is_some() { + builder.branch(id.0)?; + } + builder.begin_block(Some(id.0))?; + } + _ => { + if builder.selected_block().is_none() && builder.selected_function().is_some() { + builder.begin_block(None)?; + } + } + } + match s { + Statement::Label(_) => (), + Statement::Variable(var) => { + emit_variable(builder, map, id_defs, ast::LinkingDirective::NONE, var)?; + } + Statement::Constant(cnst) => { + let typ_id = map.get_or_add_scalar(builder, cnst.typ); + match (cnst.typ, cnst.value) { + (ast::ScalarType::B8, ast::ImmediateValue::U64(value)) + | (ast::ScalarType::U8, ast::ImmediateValue::U64(value)) => { + builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u8 as u32); + } + 
(ast::ScalarType::B16, ast::ImmediateValue::U64(value)) + | (ast::ScalarType::U16, ast::ImmediateValue::U64(value)) => { + builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u16 as u32); + } + (ast::ScalarType::B32, ast::ImmediateValue::U64(value)) + | (ast::ScalarType::U32, ast::ImmediateValue::U64(value)) => { + builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u32); + } + (ast::ScalarType::B64, ast::ImmediateValue::U64(value)) + | (ast::ScalarType::U64, ast::ImmediateValue::U64(value)) => { + builder.constant_u64(typ_id.0, Some(cnst.dst.0), value); + } + (ast::ScalarType::S8, ast::ImmediateValue::U64(value)) => { + builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i8 as u32); + } + (ast::ScalarType::S16, ast::ImmediateValue::U64(value)) => { + builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i16 as u32); + } + (ast::ScalarType::S32, ast::ImmediateValue::U64(value)) => { + builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i32 as u32); + } + (ast::ScalarType::S64, ast::ImmediateValue::U64(value)) => { + builder.constant_u64(typ_id.0, Some(cnst.dst.0), value as i64 as u64); + } + (ast::ScalarType::B8, ast::ImmediateValue::S64(value)) + | (ast::ScalarType::U8, ast::ImmediateValue::S64(value)) => { + builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u8 as u32); + } + (ast::ScalarType::B16, ast::ImmediateValue::S64(value)) + | (ast::ScalarType::U16, ast::ImmediateValue::S64(value)) => { + builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u16 as u32); + } + (ast::ScalarType::B32, ast::ImmediateValue::S64(value)) + | (ast::ScalarType::U32, ast::ImmediateValue::S64(value)) => { + builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u32); + } + (ast::ScalarType::B64, ast::ImmediateValue::S64(value)) + | (ast::ScalarType::U64, ast::ImmediateValue::S64(value)) => { + builder.constant_u64(typ_id.0, Some(cnst.dst.0), value as u64); + } + (ast::ScalarType::S8, ast::ImmediateValue::S64(value)) => { + 
builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i8 as u32); + } + (ast::ScalarType::S16, ast::ImmediateValue::S64(value)) => { + builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i16 as u32); + } + (ast::ScalarType::S32, ast::ImmediateValue::S64(value)) => { + builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i32 as u32); + } + (ast::ScalarType::S64, ast::ImmediateValue::S64(value)) => { + builder.constant_u64(typ_id.0, Some(cnst.dst.0), value as u64); + } + (ast::ScalarType::F16, ast::ImmediateValue::F32(value)) => { + builder.constant_f32( + typ_id.0, + Some(cnst.dst.0), + f16::from_f32(value).to_f32(), + ); + } + (ast::ScalarType::F32, ast::ImmediateValue::F32(value)) => { + builder.constant_f32(typ_id.0, Some(cnst.dst.0), value); + } + (ast::ScalarType::F64, ast::ImmediateValue::F32(value)) => { + builder.constant_f64(typ_id.0, Some(cnst.dst.0), value as f64); + } + (ast::ScalarType::F16, ast::ImmediateValue::F64(value)) => { + builder.constant_f32( + typ_id.0, + Some(cnst.dst.0), + f16::from_f64(value).to_f32(), + ); + } + (ast::ScalarType::F32, ast::ImmediateValue::F64(value)) => { + builder.constant_f32(typ_id.0, Some(cnst.dst.0), value as f32); + } + (ast::ScalarType::F64, ast::ImmediateValue::F64(value)) => { + builder.constant_f64(typ_id.0, Some(cnst.dst.0), value); + } + (ast::ScalarType::Pred, ast::ImmediateValue::U64(value)) => { + let bool_type = map.get_or_add_scalar(builder, ast::ScalarType::Pred).0; + if value == 0 { + builder.constant_false(bool_type, Some(cnst.dst.0)); + } else { + builder.constant_true(bool_type, Some(cnst.dst.0)); + } + } + (ast::ScalarType::Pred, ast::ImmediateValue::S64(value)) => { + let bool_type = map.get_or_add_scalar(builder, ast::ScalarType::Pred).0; + if value == 0 { + builder.constant_false(bool_type, Some(cnst.dst.0)); + } else { + builder.constant_true(bool_type, Some(cnst.dst.0)); + } + } + _ => return Err(TranslateError::MismatchedType), + } + } + Statement::Conversion(cv) => 
emit_implicit_conversion(builder, map, cv)?, + Statement::Conditional(bra) => { + builder.branch_conditional( + bra.predicate.0, + bra.if_true.0, + bra.if_false.0, + iter::empty(), + )?; + } + Statement::FunctionPointer(FunctionPointerDetails { dst, src }) => { + // TODO: implement properly + let zero = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U64), + &vec_repr(0u64), + )?; + let result_type = map.get_or_add_scalar(builder, ast::ScalarType::U64); + builder.copy_object(result_type.0, Some(dst.0), zero.0)?; + } + Statement::Instruction(inst) => match inst { + ast::Instruction::PrmtSlow { .. } | ast::Instruction::Trap { .. } => todo!(), + ast::Instruction::Call { data, arguments } => { + let (result_type, result_id) = + match (&*data.return_arguments, &*arguments.return_arguments) { + ([(type_, space)], [id]) => { + if *space != ast::StateSpace::Reg { + return Err(error_unreachable()); + } + ( + map.get_or_add(builder, SpirvType::new(type_.clone())).0, + Some(id.0), + ) + } + ([], []) => (map.void(), None), + _ => todo!(), + }; + let arg_list = arguments + .input_arguments + .iter() + .map(|id| id.0) + .collect::>(); + builder.function_call(result_type, result_id, arguments.func.0, arg_list)?; + } + ast::Instruction::Abs { data, arguments } => { + emit_abs(builder, map, opencl, data, arguments)? + } + // SPIR-V does not support marking jumps as guaranteed-converged + ast::Instruction::Bra { arguments, .. 
} => { + builder.branch(arguments.src.0)?; + } + ast::Instruction::Ld { data, arguments } => { + let mem_access = match data.qualifier { + ast::LdStQualifier::Weak => spirv::MemoryAccess::NONE, + // ld.volatile does not match Volatile OpLoad nor Relaxed OpAtomicLoad + ast::LdStQualifier::Volatile => spirv::MemoryAccess::VOLATILE, + _ => return Err(TranslateError::Todo), + }; + let result_type = + map.get_or_add(builder, SpirvType::new(ast::Type::from(data.typ.clone()))); + builder.load( + result_type.0, + Some(arguments.dst.0), + arguments.src.0, + Some(mem_access | spirv::MemoryAccess::ALIGNED), + [dr::Operand::LiteralInt32( + type_size_of(&ast::Type::from(data.typ.clone())) as u32, + )] + .iter() + .cloned(), + )?; + } + ast::Instruction::St { data, arguments } => { + let mem_access = match data.qualifier { + ast::LdStQualifier::Weak => spirv::MemoryAccess::NONE, + // st.volatile does not match Volatile OpStore nor Relaxed OpAtomicStore + ast::LdStQualifier::Volatile => spirv::MemoryAccess::VOLATILE, + _ => return Err(TranslateError::Todo), + }; + builder.store( + arguments.src1.0, + arguments.src2.0, + Some(mem_access | spirv::MemoryAccess::ALIGNED), + [dr::Operand::LiteralInt32( + type_size_of(&ast::Type::from(data.typ.clone())) as u32, + )] + .iter() + .cloned(), + )?; + } + // SPIR-V does not support ret as guaranteed-converged + ast::Instruction::Ret { .. } => builder.ret()?, + ast::Instruction::Mov { data, arguments } => { + let result_type = + map.get_or_add(builder, SpirvType::new(ast::Type::from(data.typ.clone()))); + builder.copy_object(result_type.0, Some(arguments.dst.0), arguments.src.0)?; + } + ast::Instruction::Mul { data, arguments } => match data { + ast::MulDetails::Integer { type_, control } => { + emit_mul_int(builder, map, opencl, *type_, *control, arguments)? + } + ast::MulDetails::Float(ref ctr) => { + emit_mul_float(builder, map, ctr, arguments)? 
+ } + }, + ast::Instruction::Add { data, arguments } => match data { + ast::ArithDetails::Integer(desc) => { + emit_add_int(builder, map, desc.type_.into(), desc.saturate, arguments)? + } + ast::ArithDetails::Float(desc) => { + emit_add_float(builder, map, desc, arguments)? + } + }, + ast::Instruction::Setp { data, arguments } => { + if arguments.dst2.is_some() { + todo!() + } + emit_setp(builder, map, data, arguments)?; + } + ast::Instruction::Not { data, arguments } => { + let result_type = map.get_or_add(builder, SpirvType::from(*data)); + let result_id = Some(arguments.dst.0); + let operand = arguments.src; + match data { + ast::ScalarType::Pred => { + logical_not(builder, result_type.0, result_id, operand.0) + } + _ => builder.not(result_type.0, result_id, operand.0), + }?; + } + ast::Instruction::Shl { data, arguments } => { + let full_type = ast::Type::Scalar(*data); + let size_of = type_size_of(&full_type); + let result_type = map.get_or_add(builder, SpirvType::new(full_type)); + let offset_src = insert_shift_hack(builder, map, arguments.src2.0, size_of)?; + builder.shift_left_logical( + result_type.0, + Some(arguments.dst.0), + arguments.src1.0, + offset_src, + )?; + } + ast::Instruction::Shr { data, arguments } => { + let full_type = ast::ScalarType::from(data.type_); + let size_of = full_type.size_of(); + let result_type = map.get_or_add_scalar(builder, full_type).0; + let offset_src = + insert_shift_hack(builder, map, arguments.src2.0, size_of as usize)?; + match data.kind { + ptx_parser::RightShiftKind::Arithmetic => { + builder.shift_right_arithmetic( + result_type, + Some(arguments.dst.0), + arguments.src1.0, + offset_src, + )?; + } + ptx_parser::RightShiftKind::Logical => { + builder.shift_right_logical( + result_type, + Some(arguments.dst.0), + arguments.src1.0, + offset_src, + )?; + } + } + } + ast::Instruction::Cvt { data, arguments } => { + emit_cvt(builder, map, opencl, data, arguments)?; + } + ast::Instruction::Cvta { data, arguments } => { + 
// This would be only meaningful if const/slm/global pointers + // had a different format than generic pointers, but they don't pretty much by ptx definition + // Honestly, I have no idea why this instruction exists and is emitted by the compiler + let result_type = map.get_or_add_scalar(builder, ast::ScalarType::B64); + builder.copy_object(result_type.0, Some(arguments.dst.0), arguments.src.0)?; + } + ast::Instruction::SetpBool { .. } => todo!(), + ast::Instruction::Mad { data, arguments } => match data { + ast::MadDetails::Integer { + type_, + control, + saturate, + } => { + if *saturate { + todo!() + } + if type_.kind() == ast::ScalarKind::Signed { + emit_mad_sint(builder, map, opencl, *type_, *control, arguments)? + } else { + emit_mad_uint(builder, map, opencl, *type_, *control, arguments)? + } + } + ast::MadDetails::Float(desc) => { + emit_mad_float(builder, map, opencl, desc, arguments)? + } + }, + ast::Instruction::Fma { data, arguments } => { + emit_fma_float(builder, map, opencl, data, arguments)? 
+ } + ast::Instruction::Or { data, arguments } => { + let result_type = map.get_or_add_scalar(builder, *data).0; + if *data == ast::ScalarType::Pred { + builder.logical_or( + result_type, + Some(arguments.dst.0), + arguments.src1.0, + arguments.src2.0, + )?; + } else { + builder.bitwise_or( + result_type, + Some(arguments.dst.0), + arguments.src1.0, + arguments.src2.0, + )?; + } + } + ast::Instruction::Sub { data, arguments } => match data { + ast::ArithDetails::Integer(desc) => { + emit_sub_int(builder, map, desc.type_.into(), desc.saturate, arguments)?; + } + ast::ArithDetails::Float(desc) => { + emit_sub_float(builder, map, desc, arguments)?; + } + }, + ast::Instruction::Min { data, arguments } => { + emit_min(builder, map, opencl, data, arguments)?; + } + ast::Instruction::Max { data, arguments } => { + emit_max(builder, map, opencl, data, arguments)?; + } + ast::Instruction::Rcp { data, arguments } => { + emit_rcp(builder, map, opencl, data, arguments)?; + } + ast::Instruction::And { data, arguments } => { + let result_type = map.get_or_add_scalar(builder, *data); + if *data == ast::ScalarType::Pred { + builder.logical_and( + result_type.0, + Some(arguments.dst.0), + arguments.src1.0, + arguments.src2.0, + )?; + } else { + builder.bitwise_and( + result_type.0, + Some(arguments.dst.0), + arguments.src1.0, + arguments.src2.0, + )?; + } + } + ast::Instruction::Selp { data, arguments } => { + let result_type = map.get_or_add_scalar(builder, *data); + builder.select( + result_type.0, + Some(arguments.dst.0), + arguments.src3.0, + arguments.src1.0, + arguments.src2.0, + )?; + } + // TODO: implement named barriers + ast::Instruction::Bar { data, arguments } => { + let workgroup_scope = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(spirv::Scope::Workgroup as u32), + )?; + let barrier_semantics = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr( + 
spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY + | spirv::MemorySemantics::WORKGROUP_MEMORY + | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT, + ), + )?; + builder.control_barrier( + workgroup_scope.0, + workgroup_scope.0, + barrier_semantics.0, + )?; + } + ast::Instruction::Atom { data, arguments } => { + emit_atom(builder, map, data, arguments)?; + } + ast::Instruction::AtomCas { data, arguments } => { + let result_type = map.get_or_add_scalar(builder, data.type_); + let memory_const = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(scope_to_spirv(data.scope) as u32), + )?; + let semantics_const = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(semantics_to_spirv(data.semantics).bits()), + )?; + builder.atomic_compare_exchange( + result_type.0, + Some(arguments.dst.0), + arguments.src1.0, + memory_const.0, + semantics_const.0, + semantics_const.0, + arguments.src3.0, + arguments.src2.0, + )?; + } + ast::Instruction::Div { data, arguments } => match data { + ast::DivDetails::Unsigned(t) => { + let result_type = map.get_or_add_scalar(builder, (*t).into()); + builder.u_div( + result_type.0, + Some(arguments.dst.0), + arguments.src1.0, + arguments.src2.0, + )?; + } + ast::DivDetails::Signed(t) => { + let result_type = map.get_or_add_scalar(builder, (*t).into()); + builder.s_div( + result_type.0, + Some(arguments.dst.0), + arguments.src1.0, + arguments.src2.0, + )?; + } + ast::DivDetails::Float(t) => { + let result_type = map.get_or_add_scalar(builder, t.type_.into()); + builder.f_div( + result_type.0, + Some(arguments.dst.0), + arguments.src1.0, + arguments.src2.0, + )?; + emit_float_div_decoration(builder, arguments.dst, t.kind); + } + }, + ast::Instruction::Sqrt { data, arguments } => { + emit_sqrt(builder, map, opencl, data, arguments)?; + } + ast::Instruction::Rsqrt { data, arguments } => { + let result_type = map.get_or_add_scalar(builder, data.type_.into()); + 
builder.ext_inst( + result_type.0, + Some(arguments.dst.0), + opencl, + spirv::CLOp::rsqrt as spirv::Word, + [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), + )?; + } + ast::Instruction::Neg { data, arguments } => { + let result_type = map.get_or_add_scalar(builder, data.type_); + let negate_func = if data.type_.kind() == ast::ScalarKind::Float { + dr::Builder::f_negate + } else { + dr::Builder::s_negate + }; + negate_func( + builder, + result_type.0, + Some(arguments.dst.0), + arguments.src.0, + )?; + } + ast::Instruction::Sin { arguments, .. } => { + let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); + builder.ext_inst( + result_type.0, + Some(arguments.dst.0), + opencl, + spirv::CLOp::sin as u32, + [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), + )?; + } + ast::Instruction::Cos { arguments, .. } => { + let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); + builder.ext_inst( + result_type.0, + Some(arguments.dst.0), + opencl, + spirv::CLOp::cos as u32, + [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), + )?; + } + ast::Instruction::Lg2 { arguments, .. } => { + let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); + builder.ext_inst( + result_type.0, + Some(arguments.dst.0), + opencl, + spirv::CLOp::log2 as u32, + [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), + )?; + } + ast::Instruction::Ex2 { arguments, .. 
} => { + let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); + builder.ext_inst( + result_type.0, + Some(arguments.dst.0), + opencl, + spirv::CLOp::exp2 as u32, + [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), + )?; + } + ast::Instruction::Clz { data, arguments } => { + let result_type = map.get_or_add_scalar(builder, (*data).into()); + builder.ext_inst( + result_type.0, + Some(arguments.dst.0), + opencl, + spirv::CLOp::clz as u32, + [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), + )?; + } + ast::Instruction::Brev { data, arguments } => { + let result_type = map.get_or_add_scalar(builder, (*data).into()); + builder.bit_reverse(result_type.0, Some(arguments.dst.0), arguments.src.0)?; + } + ast::Instruction::Popc { data, arguments } => { + let result_type = map.get_or_add_scalar(builder, (*data).into()); + builder.bit_count(result_type.0, Some(arguments.dst.0), arguments.src.0)?; + } + ast::Instruction::Xor { data, arguments } => { + let builder_fn: fn( + &mut dr::Builder, + u32, + Option, + u32, + u32, + ) -> Result = match data { + ast::ScalarType::Pred => emit_logical_xor_spirv, + _ => dr::Builder::bitwise_xor, + }; + let result_type = map.get_or_add_scalar(builder, (*data).into()); + builder_fn( + builder, + result_type.0, + Some(arguments.dst.0), + arguments.src1.0, + arguments.src2.0, + )?; + } + ast::Instruction::Bfe { .. } + | ast::Instruction::Bfi { .. } + | ast::Instruction::Activemask { .. 
} => { + // Should have beeen replaced with a funciton call earlier + return Err(error_unreachable()); + } + + ast::Instruction::Rem { data, arguments } => { + let builder_fn = if data.kind() == ast::ScalarKind::Signed { + dr::Builder::s_mod + } else { + dr::Builder::u_mod + }; + let result_type = map.get_or_add_scalar(builder, (*data).into()); + builder_fn( + builder, + result_type.0, + Some(arguments.dst.0), + arguments.src1.0, + arguments.src2.0, + )?; + } + ast::Instruction::Prmt { data, arguments } => { + let control = *data as u32; + let components = [ + (control >> 0) & 0b1111, + (control >> 4) & 0b1111, + (control >> 8) & 0b1111, + (control >> 12) & 0b1111, + ]; + if components.iter().any(|&c| c > 7) { + return Err(TranslateError::Todo); + } + let vec4_b8_type = + map.get_or_add(builder, SpirvType::Vector(SpirvScalarKey::B8, 4)); + let b32_type = map.get_or_add_scalar(builder, ast::ScalarType::B32); + let src1_vector = builder.bitcast(vec4_b8_type.0, None, arguments.src1.0)?; + let src2_vector = builder.bitcast(vec4_b8_type.0, None, arguments.src2.0)?; + let dst_vector = builder.vector_shuffle( + vec4_b8_type.0, + None, + src1_vector, + src2_vector, + components, + )?; + builder.bitcast(b32_type.0, Some(arguments.dst.0), dst_vector)?; + } + ast::Instruction::Membar { data } => { + let (scope, semantics) = match data { + ast::MemScope::Cta => ( + spirv::Scope::Workgroup, + spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY + | spirv::MemorySemantics::WORKGROUP_MEMORY + | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT, + ), + ast::MemScope::Gpu => ( + spirv::Scope::Device, + spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY + | spirv::MemorySemantics::WORKGROUP_MEMORY + | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT, + ), + ast::MemScope::Sys => ( + spirv::Scope::CrossDevice, + spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY + | spirv::MemorySemantics::WORKGROUP_MEMORY + | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT, + ), + + ast::MemScope::Cluster => 
todo!(), + }; + let spirv_scope = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(scope as u32), + )?; + let spirv_semantics = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(semantics), + )?; + builder.memory_barrier(spirv_scope.0, spirv_semantics.0)?; + } + }, + Statement::LoadVar(details) => { + emit_load_var(builder, map, details)?; + } + Statement::StoreVar(details) => { + let dst_ptr = match details.member_index { + Some(index) => { + let result_ptr_type = map.get_or_add( + builder, + SpirvType::pointer_to( + details.typ.clone(), + spirv::StorageClass::Function, + ), + ); + let index_spirv = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(index as u32), + )?; + builder.in_bounds_access_chain( + result_ptr_type.0, + None, + details.arg.src1.0, + [index_spirv.0].iter().copied(), + )? + } + None => details.arg.src1.0, + }; + builder.store(dst_ptr, details.arg.src2.0, None, iter::empty())?; + } + Statement::RetValue(_, id) => { + builder.ret_value(id.0)?; + } + Statement::PtrAccess(PtrAccess { + underlying_type, + state_space, + dst, + ptr_src, + offset_src, + }) => { + let u8_pointer = map.get_or_add( + builder, + SpirvType::new(ast::Type::Pointer(ast::ScalarType::U8, *state_space)), + ); + let result_type = map.get_or_add( + builder, + SpirvType::pointer_to(underlying_type.clone(), space_to_spirv(*state_space)), + ); + let ptr_src_u8 = builder.bitcast(u8_pointer.0, None, ptr_src.0)?; + let temp = builder.in_bounds_ptr_access_chain( + u8_pointer.0, + None, + ptr_src_u8, + offset_src.0, + iter::empty(), + )?; + builder.bitcast(result_type.0, Some(dst.0), temp)?; + } + Statement::RepackVector(repack) => { + if repack.is_extract { + let scalar_type = map.get_or_add_scalar(builder, repack.typ); + for (index, dst_id) in repack.unpacked.iter().enumerate() { + builder.composite_extract( + scalar_type.0, + Some(dst_id.0), + repack.packed.0, 
+ [index as u32].iter().copied(), + )?; + } + } else { + let vector_type = map.get_or_add( + builder, + SpirvType::Vector( + SpirvScalarKey::from(repack.typ), + repack.unpacked.len() as u8, + ), + ); + let mut temp_vec = builder.undef(vector_type.0, None); + for (index, src_id) in repack.unpacked.iter().enumerate() { + temp_vec = builder.composite_insert( + vector_type.0, + None, + src_id.0, + temp_vec, + [index as u32].iter().copied(), + )?; + } + builder.copy_object(vector_type.0, Some(repack.packed.0), temp_vec)?; + } + } + } + } + Ok(()) +} + +fn emit_function_linkage<'input>( + builder: &mut dr::Builder, + id_defs: &GlobalStringIdResolver<'input>, + f: &Function, + fn_name: SpirvWord, +) -> Result<(), TranslateError> { + if f.linkage == ast::LinkingDirective::NONE { + return Ok(()); + }; + let linking_name = match f.func_decl.borrow().name { + // According to SPIR-V rules linkage attributes are invalid on kernels + ast::MethodName::Kernel(..) => return Ok(()), + ast::MethodName::Func(fn_id) => f.import_as.as_deref().map_or_else( + || match id_defs.reverse_variables.get(&fn_id) { + Some(fn_name) => Ok(fn_name), + None => Err(error_unknown_symbol()), + }, + Result::Ok, + )?, + }; + emit_linking_decoration(builder, id_defs, Some(linking_name), fn_name, f.linkage); + Ok(()) +} + +fn get_function_type( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + spirv_input: impl Iterator, + spirv_output: &[ast::Variable], +) -> (SpirvWord, SpirvWord) { + map.get_or_add_fn( + builder, + spirv_input, + spirv_output + .iter() + .map(|var| SpirvType::new(var.v_type.clone())), + ) +} + +fn emit_linking_decoration<'input>( + builder: &mut dr::Builder, + id_defs: &GlobalStringIdResolver<'input>, + name_override: Option<&str>, + name: SpirvWord, + linking: ast::LinkingDirective, +) { + if linking == ast::LinkingDirective::NONE { + return; + } + if linking.contains(ast::LinkingDirective::VISIBLE) { + let string_name = + name_override.unwrap_or_else(|| 
id_defs.reverse_variables.get(&name).unwrap()); + builder.decorate( + name.0, + spirv::Decoration::LinkageAttributes, + [ + dr::Operand::LiteralString(string_name.to_string()), + dr::Operand::LinkageType(spirv::LinkageType::Export), + ] + .iter() + .cloned(), + ); + } else if linking.contains(ast::LinkingDirective::EXTERN) { + let string_name = + name_override.unwrap_or_else(|| id_defs.reverse_variables.get(&name).unwrap()); + builder.decorate( + name.0, + spirv::Decoration::LinkageAttributes, + [ + dr::Operand::LiteralString(string_name.to_string()), + dr::Operand::LinkageType(spirv::LinkageType::Import), + ] + .iter() + .cloned(), + ); + } + // TODO: handle LinkingDirective::WEAK +} + +fn effective_input_arguments<'a>( + this: &'a ast::MethodDeclaration<'a, SpirvWord>, +) -> impl Iterator + 'a { + let is_kernel = matches!(this.name, ast::MethodName::Kernel(_)); + this.input_arguments.iter().map(move |arg| { + if !is_kernel && arg.state_space != ast::StateSpace::Reg { + let spirv_type = + SpirvType::pointer_to(arg.v_type.clone(), space_to_spirv(arg.state_space)); + (arg.name, spirv_type) + } else { + (arg.name, SpirvType::new(arg.v_type.clone())) + } + }) +} + +fn emit_implicit_conversion( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + cv: &ImplicitConversion, +) -> Result<(), TranslateError> { + let from_parts = to_parts(&cv.from_type); + let to_parts = to_parts(&cv.to_type); + match (from_parts.kind, to_parts.kind, &cv.kind) { + (_, _, &ConversionKind::BitToPtr) => { + let dst_type = map.get_or_add( + builder, + SpirvType::pointer_to(cv.to_type.clone(), space_to_spirv(cv.to_space)), + ); + builder.convert_u_to_ptr(dst_type.0, Some(cv.dst.0), cv.src.0)?; + } + (TypeKind::Scalar, TypeKind::Scalar, &ConversionKind::Default) => { + if from_parts.width == to_parts.width { + let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); + if from_parts.scalar_kind != ast::ScalarKind::Float + && to_parts.scalar_kind != ast::ScalarKind::Float + { + 
// It is noop, but another instruction expects result of this conversion + builder.copy_object(dst_type.0, Some(cv.dst.0), cv.src.0)?; + } else { + builder.bitcast(dst_type.0, Some(cv.dst.0), cv.src.0)?; + } + } else { + // This block is safe because it's illegal to implictly convert between floating point values + let same_width_bit_type = map.get_or_add( + builder, + SpirvType::new(type_from_parts(TypeParts { + scalar_kind: ast::ScalarKind::Bit, + ..from_parts + })), + ); + let same_width_bit_value = + builder.bitcast(same_width_bit_type.0, None, cv.src.0)?; + let wide_bit_type = type_from_parts(TypeParts { + scalar_kind: ast::ScalarKind::Bit, + ..to_parts + }); + let wide_bit_type_spirv = + map.get_or_add(builder, SpirvType::new(wide_bit_type.clone())); + if to_parts.scalar_kind == ast::ScalarKind::Unsigned + || to_parts.scalar_kind == ast::ScalarKind::Bit + { + builder.u_convert( + wide_bit_type_spirv.0, + Some(cv.dst.0), + same_width_bit_value, + )?; + } else { + let conversion_fn = if from_parts.scalar_kind == ast::ScalarKind::Signed + && to_parts.scalar_kind == ast::ScalarKind::Signed + { + dr::Builder::s_convert + } else { + dr::Builder::u_convert + }; + let wide_bit_value = + conversion_fn(builder, wide_bit_type_spirv.0, None, same_width_bit_value)?; + emit_implicit_conversion( + builder, + map, + &ImplicitConversion { + src: SpirvWord(wide_bit_value), + dst: cv.dst, + from_type: wide_bit_type, + from_space: cv.from_space, + to_type: cv.to_type.clone(), + to_space: cv.to_space, + kind: ConversionKind::Default, + }, + )?; + } + } + } + (TypeKind::Scalar, TypeKind::Scalar, &ConversionKind::SignExtend) => { + let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); + builder.s_convert(result_type.0, Some(cv.dst.0), cv.src.0)?; + } + (TypeKind::Vector, TypeKind::Scalar, &ConversionKind::Default) + | (TypeKind::Scalar, TypeKind::Array, &ConversionKind::Default) + | (TypeKind::Array, TypeKind::Scalar, &ConversionKind::Default) => { + let 
into_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); + builder.bitcast(into_type.0, Some(cv.dst.0), cv.src.0)?; + } + (_, _, &ConversionKind::PtrToPtr) => { + let result_type = map.get_or_add( + builder, + SpirvType::Pointer( + Box::new(SpirvType::new(cv.to_type.clone())), + space_to_spirv(cv.to_space), + ), + ); + if cv.to_space == ast::StateSpace::Generic && cv.from_space != ast::StateSpace::Generic + { + let src = if cv.from_type != cv.to_type { + let temp_type = map.get_or_add( + builder, + SpirvType::Pointer( + Box::new(SpirvType::new(cv.to_type.clone())), + space_to_spirv(cv.from_space), + ), + ); + builder.bitcast(temp_type.0, None, cv.src.0)? + } else { + cv.src.0 + }; + builder.ptr_cast_to_generic(result_type.0, Some(cv.dst.0), src)?; + } else if cv.from_space == ast::StateSpace::Generic + && cv.to_space != ast::StateSpace::Generic + { + let src = if cv.from_type != cv.to_type { + let temp_type = map.get_or_add( + builder, + SpirvType::Pointer( + Box::new(SpirvType::new(cv.to_type.clone())), + space_to_spirv(cv.from_space), + ), + ); + builder.bitcast(temp_type.0, None, cv.src.0)? 
+ } else { + cv.src.0 + }; + builder.generic_cast_to_ptr(result_type.0, Some(cv.dst.0), src)?; + } else { + builder.bitcast(result_type.0, Some(cv.dst.0), cv.src.0)?; + } + } + (_, _, &ConversionKind::AddressOf) => { + let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); + builder.convert_ptr_to_u(dst_type.0, Some(cv.dst.0), cv.src.0)?; + } + (TypeKind::Pointer, TypeKind::Scalar, &ConversionKind::Default) => { + let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); + builder.convert_ptr_to_u(result_type.0, Some(cv.dst.0), cv.src.0)?; + } + (TypeKind::Scalar, TypeKind::Pointer, &ConversionKind::Default) => { + let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); + builder.convert_u_to_ptr(result_type.0, Some(cv.dst.0), cv.src.0)?; + } + _ => unreachable!(), + } + Ok(()) +} + +fn vec_repr(t: T) -> Vec { + let mut result = vec![0; mem::size_of::()]; + unsafe { std::ptr::copy_nonoverlapping(&t, result.as_mut_ptr() as *mut _, 1) }; + result +} + +fn emit_abs( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + opencl: spirv::Word, + d: &ast::TypeFtz, + arg: &ast::AbsArgs, +) -> Result<(), dr::Error> { + let scalar_t = ast::ScalarType::from(d.type_); + let result_type = map.get_or_add(builder, SpirvType::from(scalar_t)); + let cl_abs = if scalar_t.kind() == ast::ScalarKind::Signed { + spirv::CLOp::s_abs + } else { + spirv::CLOp::fabs + }; + builder.ext_inst( + result_type.0, + Some(arg.dst.0), + opencl, + cl_abs as spirv::Word, + [dr::Operand::IdRef(arg.src.0)].iter().cloned(), + )?; + Ok(()) +} + +fn emit_mul_int( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + opencl: spirv::Word, + type_: ast::ScalarType, + control: ast::MulIntControl, + arg: &ast::MulArgs, +) -> Result<(), dr::Error> { + let inst_type = map.get_or_add(builder, SpirvType::from(type_)); + match control { + ast::MulIntControl::Low => { + builder.i_mul(inst_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?; + } + 
ast::MulIntControl::High => { + builder.ext_inst( + inst_type.0, + Some(arg.dst.0), + opencl, + spirv::CLOp::s_mul_hi as spirv::Word, + [ + dr::Operand::IdRef(arg.src1.0), + dr::Operand::IdRef(arg.src2.0), + ] + .iter() + .cloned(), + )?; + } + ast::MulIntControl::Wide => { + let instr_width = type_.size_of(); + let instr_kind = type_.kind(); + let dst_type = scalar_from_parts(instr_width * 2, instr_kind); + let dst_type_id = map.get_or_add_scalar(builder, dst_type); + let (src1, src2) = if type_.kind() == ast::ScalarKind::Signed { + let src1 = builder.s_convert(dst_type_id.0, None, arg.src1.0)?; + let src2 = builder.s_convert(dst_type_id.0, None, arg.src2.0)?; + (src1, src2) + } else { + let src1 = builder.u_convert(dst_type_id.0, None, arg.src1.0)?; + let src2 = builder.u_convert(dst_type_id.0, None, arg.src2.0)?; + (src1, src2) + }; + builder.i_mul(dst_type_id.0, Some(arg.dst.0), src1, src2)?; + builder.decorate(arg.dst.0, spirv::Decoration::NoSignedWrap, iter::empty()); + } + } + Ok(()) +} + +fn emit_mul_float( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + ctr: &ast::ArithFloat, + arg: &ast::MulArgs, +) -> Result<(), dr::Error> { + if ctr.saturate { + todo!() + } + let result_type = map.get_or_add_scalar(builder, ctr.type_.into()); + builder.f_mul(result_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?; + emit_rounding_decoration(builder, arg.dst, ctr.rounding); + Ok(()) +} + +fn scalar_from_parts(width: u8, kind: ast::ScalarKind) -> ast::ScalarType { + match kind { + ast::ScalarKind::Float => match width { + 2 => ast::ScalarType::F16, + 4 => ast::ScalarType::F32, + 8 => ast::ScalarType::F64, + _ => unreachable!(), + }, + ast::ScalarKind::Bit => match width { + 1 => ast::ScalarType::B8, + 2 => ast::ScalarType::B16, + 4 => ast::ScalarType::B32, + 8 => ast::ScalarType::B64, + _ => unreachable!(), + }, + ast::ScalarKind::Signed => match width { + 1 => ast::ScalarType::S8, + 2 => ast::ScalarType::S16, + 4 => ast::ScalarType::S32, + 8 => 
ast::ScalarType::S64, + _ => unreachable!(), + }, + ast::ScalarKind::Unsigned => match width { + 1 => ast::ScalarType::U8, + 2 => ast::ScalarType::U16, + 4 => ast::ScalarType::U32, + 8 => ast::ScalarType::U64, + _ => unreachable!(), + }, + ast::ScalarKind::Pred => ast::ScalarType::Pred, + } +} + +fn emit_rounding_decoration( + builder: &mut dr::Builder, + dst: SpirvWord, + rounding: Option, +) { + if let Some(rounding) = rounding { + builder.decorate( + dst.0, + spirv::Decoration::FPRoundingMode, + [rounding_to_spirv(rounding)].iter().cloned(), + ); + } +} + +fn rounding_to_spirv(this: ast::RoundingMode) -> rspirv::dr::Operand { + let mode = match this { + ast::RoundingMode::NearestEven => spirv::FPRoundingMode::RTE, + ast::RoundingMode::Zero => spirv::FPRoundingMode::RTZ, + ast::RoundingMode::PositiveInf => spirv::FPRoundingMode::RTP, + ast::RoundingMode::NegativeInf => spirv::FPRoundingMode::RTN, + }; + rspirv::dr::Operand::FPRoundingMode(mode) +} + +fn emit_add_int( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + typ: ast::ScalarType, + saturate: bool, + arg: &ast::AddArgs, +) -> Result<(), dr::Error> { + if saturate { + todo!() + } + let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(typ))); + builder.i_add(inst_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?; + Ok(()) +} + +fn emit_add_float( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + desc: &ast::ArithFloat, + arg: &ast::AddArgs, +) -> Result<(), dr::Error> { + let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_))); + builder.f_add(inst_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?; + emit_rounding_decoration(builder, arg.dst, desc.rounding); + Ok(()) +} + +fn emit_setp( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + setp: &ast::SetpData, + arg: &ast::SetpArgs, +) -> Result<(), dr::Error> { + let result_type = map + .get_or_add(builder, SpirvType::Base(SpirvScalarKey::Pred)) + .0; + let result_id = 
Some(arg.dst1.0); + let operand_1 = arg.src1.0; + let operand_2 = arg.src2.0; + match setp.cmp_op { + ast::SetpCompareOp::Integer(ast::SetpCompareInt::Eq) => { + builder.i_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::Eq) => { + builder.f_ord_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Integer(ast::SetpCompareInt::NotEq) => { + builder.i_not_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::NotEq) => { + builder.f_ord_not_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedLess) => { + builder.u_less_than(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedLess) => { + builder.s_less_than(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::Less) => { + builder.f_ord_less_than(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedLessOrEq) => { + builder.u_less_than_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedLessOrEq) => { + builder.s_less_than_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::LessOrEq) => { + builder.f_ord_less_than_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedGreater) => { + builder.u_greater_than(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedGreater) => { + builder.s_greater_than(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::Greater) => { + builder.f_ord_greater_than(result_type, result_id, operand_1, operand_2) + } + 
ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedGreaterOrEq) => { + builder.u_greater_than_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedGreaterOrEq) => { + builder.s_greater_than_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::GreaterOrEq) => { + builder.f_ord_greater_than_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanEq) => { + builder.f_unord_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanNotEq) => { + builder.f_unord_not_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanLess) => { + builder.f_unord_less_than(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanLessOrEq) => { + builder.f_unord_less_than_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanGreater) => { + builder.f_unord_greater_than(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanGreaterOrEq) => { + builder.f_unord_greater_than_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::IsAnyNan) => { + let temp1 = builder.is_nan(result_type, None, operand_1)?; + let temp2 = builder.is_nan(result_type, None, operand_2)?; + builder.logical_or(result_type, result_id, temp1, temp2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::IsNotNan) => { + let temp1 = builder.is_nan(result_type, None, operand_1)?; + let temp2 = builder.is_nan(result_type, None, operand_2)?; + let any_nan = builder.logical_or(result_type, None, temp1, temp2)?; + logical_not(builder, result_type, result_id, any_nan) + } + _ => todo!(), + }?; + Ok(()) +} + +// HACK ALERT +// Temporary workaround 
until IGC gets its shit together +// Currently IGC carries two copies of SPIRV-LLVM translator +// a new one in /llvm-spirv/ and old one in /IGC/AdaptorOCL/SPIRV/. +// Obviously, old and buggy one is used for compiling L0 SPIRV +// https://github.com/intel/intel-graphics-compiler/issues/148 +fn logical_not( + builder: &mut dr::Builder, + result_type: spirv::Word, + result_id: Option, + operand: spirv::Word, +) -> Result { + let const_true = builder.constant_true(result_type, None); + let const_false = builder.constant_false(result_type, None); + builder.select(result_type, result_id, operand, const_false, const_true) +} + +// HACK ALERT +// For some reason IGC fails linking if the value and shift size are of different type +fn insert_shift_hack( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + offset_var: spirv::Word, + size_of: usize, +) -> Result { + let result_type = match size_of { + 2 => map.get_or_add_scalar(builder, ast::ScalarType::B16), + 8 => map.get_or_add_scalar(builder, ast::ScalarType::B64), + 4 => return Ok(offset_var), + _ => return Err(error_unreachable()), + }; + Ok(builder.u_convert(result_type.0, None, offset_var)?) 
+} + +fn emit_cvt( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + opencl: spirv::Word, + dets: &ast::CvtDetails, + arg: &ast::CvtArgs, +) -> Result<(), TranslateError> { + match dets.mode { + ptx_parser::CvtMode::SignExtend => { + let cv = ImplicitConversion { + src: arg.src, + dst: arg.dst, + from_type: dets.from.into(), + from_space: ast::StateSpace::Reg, + to_type: dets.to.into(), + to_space: ast::StateSpace::Reg, + kind: ConversionKind::SignExtend, + }; + emit_implicit_conversion(builder, map, &cv)?; + } + ptx_parser::CvtMode::ZeroExtend + | ptx_parser::CvtMode::Truncate + | ptx_parser::CvtMode::Bitcast => { + let cv = ImplicitConversion { + src: arg.src, + dst: arg.dst, + from_type: dets.from.into(), + from_space: ast::StateSpace::Reg, + to_type: dets.to.into(), + to_space: ast::StateSpace::Reg, + kind: ConversionKind::Default, + }; + emit_implicit_conversion(builder, map, &cv)?; + } + ptx_parser::CvtMode::SaturateUnsignedToSigned => { + let result_type = map.get_or_add(builder, SpirvType::from(dets.to)); + builder.sat_convert_u_to_s(result_type.0, Some(arg.dst.0), arg.src.0)?; + } + ptx_parser::CvtMode::SaturateSignedToUnsigned => { + let result_type = map.get_or_add(builder, SpirvType::from(dets.to)); + builder.sat_convert_s_to_u(result_type.0, Some(arg.dst.0), arg.src.0)?; + } + ptx_parser::CvtMode::FPExtend { flush_to_zero } => { + if flush_to_zero == Some(true) { + todo!() + } + let result_type = map.get_or_add(builder, SpirvType::from(dets.to)); + builder.f_convert(result_type.0, Some(arg.dst.0), arg.src.0)?; + } + ptx_parser::CvtMode::FPTruncate { + rounding, + flush_to_zero, + } => { + if flush_to_zero == Some(true) { + todo!() + } + let result_type = map.get_or_add(builder, SpirvType::from(dets.to)); + builder.f_convert(result_type.0, Some(arg.dst.0), arg.src.0)?; + emit_rounding_decoration(builder, arg.dst, Some(rounding)); + } + ptx_parser::CvtMode::FPRound { + integer_rounding, + flush_to_zero, + } => { + if flush_to_zero == Some(true) { + 
todo!() + } + let result_type = map.get_or_add(builder, SpirvType::from(dets.to)); + match integer_rounding { + Some(ast::RoundingMode::NearestEven) => { + builder.ext_inst( + result_type.0, + Some(arg.dst.0), + opencl, + spirv::CLOp::rint as u32, + [dr::Operand::IdRef(arg.src.0)].iter().cloned(), + )?; + } + Some(ast::RoundingMode::Zero) => { + builder.ext_inst( + result_type.0, + Some(arg.dst.0), + opencl, + spirv::CLOp::trunc as u32, + [dr::Operand::IdRef(arg.src.0)].iter().cloned(), + )?; + } + Some(ast::RoundingMode::NegativeInf) => { + builder.ext_inst( + result_type.0, + Some(arg.dst.0), + opencl, + spirv::CLOp::floor as u32, + [dr::Operand::IdRef(arg.src.0)].iter().cloned(), + )?; + } + Some(ast::RoundingMode::PositiveInf) => { + builder.ext_inst( + result_type.0, + Some(arg.dst.0), + opencl, + spirv::CLOp::ceil as u32, + [dr::Operand::IdRef(arg.src.0)].iter().cloned(), + )?; + } + None => { + builder.copy_object(result_type.0, Some(arg.dst.0), arg.src.0)?; + } + } + } + ptx_parser::CvtMode::SignedFromFP { + rounding, + flush_to_zero, + } => { + if flush_to_zero == Some(true) { + todo!() + } + let dest_t: ast::ScalarType = dets.to.into(); + let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); + builder.convert_f_to_s(result_type.0, Some(arg.dst.0), arg.src.0)?; + emit_rounding_decoration(builder, arg.dst, Some(rounding)); + } + ptx_parser::CvtMode::UnsignedFromFP { + rounding, + flush_to_zero, + } => { + if flush_to_zero == Some(true) { + todo!() + } + let dest_t: ast::ScalarType = dets.to.into(); + let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); + builder.convert_f_to_u(result_type.0, Some(arg.dst.0), arg.src.0)?; + emit_rounding_decoration(builder, arg.dst, Some(rounding)); + } + ptx_parser::CvtMode::FPFromSigned(rounding) => { + let dest_t: ast::ScalarType = dets.to.into(); + let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); + builder.convert_s_to_f(result_type.0, Some(arg.dst.0), arg.src.0)?; + 
emit_rounding_decoration(builder, arg.dst, Some(rounding)); + } + ptx_parser::CvtMode::FPFromUnsigned(rounding) => { + let dest_t: ast::ScalarType = dets.to.into(); + let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); + builder.convert_u_to_f(result_type.0, Some(arg.dst.0), arg.src.0)?; + emit_rounding_decoration(builder, arg.dst, Some(rounding)); + } + } + Ok(()) +} + +fn emit_mad_uint( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + opencl: spirv::Word, + type_: ast::ScalarType, + control: ast::MulIntControl, + arg: &ast::MadArgs, +) -> Result<(), dr::Error> { + let inst_type = map + .get_or_add(builder, SpirvType::from(ast::ScalarType::from(type_))) + .0; + match control { + ast::MulIntControl::Low => { + let mul_result = builder.i_mul(inst_type, None, arg.src1.0, arg.src2.0)?; + builder.i_add(inst_type, Some(arg.dst.0), arg.src3.0, mul_result)?; + } + ast::MulIntControl::High => { + builder.ext_inst( + inst_type, + Some(arg.dst.0), + opencl, + spirv::CLOp::u_mad_hi as spirv::Word, + [ + dr::Operand::IdRef(arg.src1.0), + dr::Operand::IdRef(arg.src2.0), + dr::Operand::IdRef(arg.src3.0), + ] + .iter() + .cloned(), + )?; + } + ast::MulIntControl::Wide => todo!(), + }; + Ok(()) +} + +fn emit_mad_sint( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + opencl: spirv::Word, + type_: ast::ScalarType, + control: ast::MulIntControl, + arg: &ast::MadArgs, +) -> Result<(), dr::Error> { + let inst_type = map.get_or_add(builder, SpirvType::from(type_)).0; + match control { + ast::MulIntControl::Low => { + let mul_result = builder.i_mul(inst_type, None, arg.src1.0, arg.src2.0)?; + builder.i_add(inst_type, Some(arg.dst.0), arg.src3.0, mul_result)?; + } + ast::MulIntControl::High => { + builder.ext_inst( + inst_type, + Some(arg.dst.0), + opencl, + spirv::CLOp::s_mad_hi as spirv::Word, + [ + dr::Operand::IdRef(arg.src1.0), + dr::Operand::IdRef(arg.src2.0), + dr::Operand::IdRef(arg.src3.0), + ] + .iter() + .cloned(), + )?; + } + 
ast::MulIntControl::Wide => todo!(), + }; + Ok(()) +} + +fn emit_mad_float( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + opencl: spirv::Word, + desc: &ast::ArithFloat, + arg: &ast::MadArgs, +) -> Result<(), dr::Error> { + let inst_type = map + .get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_))) + .0; + builder.ext_inst( + inst_type, + Some(arg.dst.0), + opencl, + spirv::CLOp::mad as spirv::Word, + [ + dr::Operand::IdRef(arg.src1.0), + dr::Operand::IdRef(arg.src2.0), + dr::Operand::IdRef(arg.src3.0), + ] + .iter() + .cloned(), + )?; + Ok(()) +} + +fn emit_fma_float( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + opencl: spirv::Word, + desc: &ast::ArithFloat, + arg: &ast::FmaArgs, +) -> Result<(), dr::Error> { + let inst_type = map + .get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_))) + .0; + builder.ext_inst( + inst_type, + Some(arg.dst.0), + opencl, + spirv::CLOp::fma as spirv::Word, + [ + dr::Operand::IdRef(arg.src1.0), + dr::Operand::IdRef(arg.src2.0), + dr::Operand::IdRef(arg.src3.0), + ] + .iter() + .cloned(), + )?; + Ok(()) +} + +fn emit_sub_int( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + typ: ast::ScalarType, + saturate: bool, + arg: &ast::SubArgs, +) -> Result<(), dr::Error> { + if saturate { + todo!() + } + let inst_type = map + .get_or_add(builder, SpirvType::from(ast::ScalarType::from(typ))) + .0; + builder.i_sub(inst_type, Some(arg.dst.0), arg.src1.0, arg.src2.0)?; + Ok(()) +} + +fn emit_sub_float( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + desc: &ast::ArithFloat, + arg: &ast::SubArgs, +) -> Result<(), dr::Error> { + let inst_type = map + .get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_))) + .0; + builder.f_sub(inst_type, Some(arg.dst.0), arg.src1.0, arg.src2.0)?; + emit_rounding_decoration(builder, arg.dst, desc.rounding); + Ok(()) +} + +fn emit_min( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + opencl: spirv::Word, + desc: 
&ast::MinMaxDetails, + arg: &ast::MinArgs, +) -> Result<(), dr::Error> { + let cl_op = match desc { + ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_min, + ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_min, + ast::MinMaxDetails::Float(_) => spirv::CLOp::fmin, + }; + let inst_type = map.get_or_add(builder, SpirvType::from(desc.type_())); + builder.ext_inst( + inst_type.0, + Some(arg.dst.0), + opencl, + cl_op as spirv::Word, + [ + dr::Operand::IdRef(arg.src1.0), + dr::Operand::IdRef(arg.src2.0), + ] + .iter() + .cloned(), + )?; + Ok(()) +} + +fn emit_max( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + opencl: spirv::Word, + desc: &ast::MinMaxDetails, + arg: &ast::MaxArgs, +) -> Result<(), dr::Error> { + let cl_op = match desc { + ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_max, + ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_max, + ast::MinMaxDetails::Float(_) => spirv::CLOp::fmax, + }; + let inst_type = map.get_or_add(builder, SpirvType::from(desc.type_())); + builder.ext_inst( + inst_type.0, + Some(arg.dst.0), + opencl, + cl_op as spirv::Word, + [ + dr::Operand::IdRef(arg.src1.0), + dr::Operand::IdRef(arg.src2.0), + ] + .iter() + .cloned(), + )?; + Ok(()) +} + +fn emit_rcp( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + opencl: spirv::Word, + desc: &ast::RcpData, + arg: &ast::RcpArgs, +) -> Result<(), TranslateError> { + let is_f64 = desc.type_ == ast::ScalarType::F64; + let (instr_type, constant) = if is_f64 { + (ast::ScalarType::F64, vec_repr(1.0f64)) + } else { + (ast::ScalarType::F32, vec_repr(1.0f32)) + }; + let result_type = map.get_or_add_scalar(builder, instr_type); + let rounding = match desc.kind { + ptx_parser::RcpKind::Approx => { + builder.ext_inst( + result_type.0, + Some(arg.dst.0), + opencl, + spirv::CLOp::native_recip as u32, + [dr::Operand::IdRef(arg.src.0)].iter().cloned(), + )?; + return Ok(()); + } + ptx_parser::RcpKind::Compliant(rounding) => rounding, + }; + let one = map.get_or_add_constant(builder, 
&ast::Type::Scalar(instr_type), &constant)?; + builder.f_div(result_type.0, Some(arg.dst.0), one.0, arg.src.0)?; + emit_rounding_decoration(builder, arg.dst, Some(rounding)); + builder.decorate( + arg.dst.0, + spirv::Decoration::FPFastMathMode, + [dr::Operand::FPFastMathMode( + spirv::FPFastMathMode::ALLOW_RECIP, + )] + .iter() + .cloned(), + ); + Ok(()) +} + +fn emit_atom( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + details: &ast::AtomDetails, + arg: &ast::AtomArgs, +) -> Result<(), TranslateError> { + let spirv_op = match details.op { + ptx_parser::AtomicOp::And => dr::Builder::atomic_and, + ptx_parser::AtomicOp::Or => dr::Builder::atomic_or, + ptx_parser::AtomicOp::Xor => dr::Builder::atomic_xor, + ptx_parser::AtomicOp::Exchange => dr::Builder::atomic_exchange, + ptx_parser::AtomicOp::Add => dr::Builder::atomic_i_add, + ptx_parser::AtomicOp::IncrementWrap | ptx_parser::AtomicOp::DecrementWrap => { + return Err(error_unreachable()) + } + ptx_parser::AtomicOp::SignedMin => dr::Builder::atomic_s_min, + ptx_parser::AtomicOp::UnsignedMin => dr::Builder::atomic_u_min, + ptx_parser::AtomicOp::SignedMax => dr::Builder::atomic_s_max, + ptx_parser::AtomicOp::UnsignedMax => dr::Builder::atomic_u_max, + ptx_parser::AtomicOp::FloatAdd => dr::Builder::atomic_f_add_ext, + ptx_parser::AtomicOp::FloatMin => todo!(), + ptx_parser::AtomicOp::FloatMax => todo!(), + }; + let result_type = map.get_or_add(builder, SpirvType::new(details.type_.clone())); + let memory_const = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(scope_to_spirv(details.scope) as u32), + )?; + let semantics_const = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(semantics_to_spirv(details.semantics).bits()), + )?; + spirv_op( + builder, + result_type.0, + Some(arg.dst.0), + arg.src1.0, + memory_const.0, + semantics_const.0, + arg.src2.0, + )?; + Ok(()) +} + +fn scope_to_spirv(this: ast::MemScope) -> spirv::Scope 
{ + match this { + ast::MemScope::Cta => spirv::Scope::Workgroup, + ast::MemScope::Gpu => spirv::Scope::Device, + ast::MemScope::Sys => spirv::Scope::CrossDevice, + ptx_parser::MemScope::Cluster => todo!(), + } +} + +fn semantics_to_spirv(this: ast::AtomSemantics) -> spirv::MemorySemantics { + match this { + ast::AtomSemantics::Relaxed => spirv::MemorySemantics::RELAXED, + ast::AtomSemantics::Acquire => spirv::MemorySemantics::ACQUIRE, + ast::AtomSemantics::Release => spirv::MemorySemantics::RELEASE, + ast::AtomSemantics::AcqRel => spirv::MemorySemantics::ACQUIRE_RELEASE, + } +} + +fn emit_float_div_decoration(builder: &mut dr::Builder, dst: SpirvWord, kind: ast::DivFloatKind) { + match kind { + ast::DivFloatKind::Approx => { + builder.decorate( + dst.0, + spirv::Decoration::FPFastMathMode, + [dr::Operand::FPFastMathMode( + spirv::FPFastMathMode::ALLOW_RECIP, + )] + .iter() + .cloned(), + ); + } + ast::DivFloatKind::Rounding(rnd) => { + emit_rounding_decoration(builder, dst, Some(rnd)); + } + ast::DivFloatKind::ApproxFull => {} + } +} + +fn emit_sqrt( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + opencl: spirv::Word, + details: &ast::RcpData, + a: &ast::SqrtArgs, +) -> Result<(), TranslateError> { + let result_type = map.get_or_add_scalar(builder, details.type_.into()); + let (ocl_op, rounding) = match details.kind { + ast::RcpKind::Approx => (spirv::CLOp::sqrt, None), + ast::RcpKind::Compliant(rnd) => (spirv::CLOp::sqrt, Some(rnd)), + }; + builder.ext_inst( + result_type.0, + Some(a.dst.0), + opencl, + ocl_op as spirv::Word, + [dr::Operand::IdRef(a.src.0)].iter().cloned(), + )?; + emit_rounding_decoration(builder, a.dst, rounding); + Ok(()) +} + +// TODO: check what kind of assembly do we emit +fn emit_logical_xor_spirv( + builder: &mut dr::Builder, + result_type: spirv::Word, + result_id: Option, + op1: spirv::Word, + op2: spirv::Word, +) -> Result { + let temp_or = builder.logical_or(result_type, None, op1, op2)?; + let temp_and = 
builder.logical_and(result_type, None, op1, op2)?; + let temp_neg = logical_not(builder, result_type, None, temp_and)?; + builder.logical_and(result_type, result_id, temp_or, temp_neg) +} + +fn emit_load_var( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + details: &LoadVarDetails, +) -> Result<(), TranslateError> { + let result_type = map.get_or_add(builder, SpirvType::new(details.typ.clone())); + match details.member_index { + Some((index, Some(width))) => { + let vector_type = match details.typ { + ast::Type::Scalar(scalar_t) => ast::Type::Vector(scalar_t, width), + _ => return Err(TranslateError::MismatchedType), + }; + let vector_type_spirv = map.get_or_add(builder, SpirvType::new(vector_type)); + let vector_temp = builder.load( + vector_type_spirv.0, + None, + details.arg.src.0, + None, + iter::empty(), + )?; + builder.composite_extract( + result_type.0, + Some(details.arg.dst.0), + vector_temp, + [index as u32].iter().copied(), + )?; + } + Some((index, None)) => { + let result_ptr_type = map.get_or_add( + builder, + SpirvType::pointer_to(details.typ.clone(), spirv::StorageClass::Function), + ); + let index_spirv = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(index as u32), + )?; + let src = builder.in_bounds_access_chain( + result_ptr_type.0, + None, + details.arg.src.0, + [index_spirv.0].iter().copied(), + )?; + builder.load( + result_type.0, + Some(details.arg.dst.0), + src, + None, + iter::empty(), + )?; + } + None => { + builder.load( + result_type.0, + Some(details.arg.dst.0), + details.arg.src.0, + None, + iter::empty(), + )?; + } + }; + Ok(()) +} + +fn to_parts(this: &ast::Type) -> TypeParts { + match this { + ast::Type::Scalar(scalar) => TypeParts { + kind: TypeKind::Scalar, + state_space: ast::StateSpace::Reg, + scalar_kind: scalar.kind(), + width: scalar.size_of(), + components: Vec::new(), + }, + ast::Type::Vector(scalar, components) => TypeParts { + kind: TypeKind::Vector, + state_space: 
ast::StateSpace::Reg, + scalar_kind: scalar.kind(), + width: scalar.size_of(), + components: vec![*components as u32], + }, + ast::Type::Array(scalar, components) => TypeParts { + kind: TypeKind::Array, + state_space: ast::StateSpace::Reg, + scalar_kind: scalar.kind(), + width: scalar.size_of(), + components: components.clone(), + }, + ast::Type::Pointer(scalar, space) => TypeParts { + kind: TypeKind::Pointer, + state_space: *space, + scalar_kind: scalar.kind(), + width: scalar.size_of(), + components: Vec::new(), + }, + } +} + +fn type_from_parts(t: TypeParts) -> ast::Type { + match t.kind { + TypeKind::Scalar => ast::Type::Scalar(scalar_from_parts(t.width, t.scalar_kind)), + TypeKind::Vector => ast::Type::Vector( + scalar_from_parts(t.width, t.scalar_kind), + t.components[0] as u8, + ), + TypeKind::Array => { + ast::Type::Array(scalar_from_parts(t.width, t.scalar_kind), t.components) + } + TypeKind::Pointer => { + ast::Type::Pointer(scalar_from_parts(t.width, t.scalar_kind), t.state_space) + } + } +} + +#[derive(Eq, PartialEq, Clone)] +struct TypeParts { + kind: TypeKind, + scalar_kind: ast::ScalarKind, + width: u8, + state_space: ast::StateSpace, + components: Vec, +} + +#[derive(Eq, PartialEq, Copy, Clone)] +enum TypeKind { + Scalar, + Vector, + Array, + Pointer, +} diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 1fdf3a62..8923718c 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -3,22 +3,27 @@ use rspirv::{binary::Assemble, dr}; use std::{ borrow::Cow, cell::RefCell, - collections::{hash_map, HashMap}, + collections::{hash_map, HashMap, HashSet}, ffi::CString, + iter, marker::PhantomData, + mem, rc::Rc, }; +use std::hash::Hash; +mod convert_dynamic_shared_memory_usage; mod convert_to_stateful_memory_access; mod convert_to_typed; mod expand_arguments; +mod extract_globals; mod fix_special_registers; +mod insert_implicit_conversions; mod insert_mem_ssa_statements; mod normalize_identifiers; -mod normalize_predicates; -mod 
insert_implicit_conversions; mod normalize_labels; -mod extract_globals; +mod normalize_predicates; +mod emit_spirv; static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv"); static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc"); @@ -34,7 +39,6 @@ pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result, _>>()?; - /* let directives = hoist_function_globals(directives); let must_link_ptx_impl = ptx_impl_imports.len() > 0; let mut directives = ptx_impl_imports @@ -43,21 +47,19 @@ pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result>(); let mut builder = dr::Builder::new(); - builder.reserve_ids(id_defs.current_id()); + builder.reserve_ids(id_defs.current_id().0); let call_map = MethodsCallMap::new(&directives); let mut directives = - convert_dynamic_shared_memory_usage(directives, &call_map, &mut || builder.id()); + convert_dynamic_shared_memory_usage::run(directives, &call_map, &mut || { + SpirvWord(builder.id()) + })?; normalize_variable_decls(&mut directives); let denorm_information = compute_denorm_information(&directives); + emit_spirv::run(builder, &id_defs, call_map, denorm_information, directives); // https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module - builder.set_version(1, 3); - emit_capabilities(&mut builder); - emit_extensions(&mut builder); - let opencl_id = emit_opencl_import(&mut builder); - emit_memory_model(&mut builder); - let mut map = TypeWordMap::new(&mut builder); - //emit_builtins(&mut builder, &mut map, &id_defs); - let mut kernel_info = HashMap::new(); + + todo!() + /* let (build_options, should_flush_denorms) = emit_denorm_build_string(&call_map, &denorm_information); let (directives, globals_use_map) = get_globals_use_map(directives); @@ -84,7 +86,6 @@ pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result( @@ -1273,3 +1274,399 @@ fn fn_arguments_to_variables<'a>( }) 
.collect::>() } + +fn hoist_function_globals(directives: Vec) -> Vec { + let mut result = Vec::with_capacity(directives.len()); + for directive in directives { + match directive { + Directive::Method(method) => { + for variable in method.globals { + result.push(Directive::Variable(ast::LinkingDirective::NONE, variable)); + } + result.push(Directive::Method(Function { + globals: Vec::new(), + ..method + })) + } + _ => result.push(directive), + } + } + result +} + +struct MethodsCallMap<'input> { + map: HashMap, HashSet>, +} + +impl<'input> MethodsCallMap<'input> { + fn new(module: &[Directive<'input>]) -> Self { + let mut directly_called_by = HashMap::new(); + for directive in module { + match directive { + Directive::Method(Function { + func_decl, + body: Some(statements), + .. + }) => { + let call_key: ast::MethodName<_> = (**func_decl).borrow().name; + if let hash_map::Entry::Vacant(entry) = directly_called_by.entry(call_key) { + entry.insert(Vec::new()); + } + for statement in statements { + match statement { + Statement::Instruction(ast::Instruction::Call { data, arguments }) => { + multi_hash_map_append( + &mut directly_called_by, + call_key, + arguments.func, + ); + } + _ => {} + } + } + } + _ => {} + } + } + let mut result = HashMap::new(); + for (&method_key, children) in directly_called_by.iter() { + let mut visited = HashSet::new(); + for child in children { + Self::add_call_map_single(&directly_called_by, &mut visited, *child); + } + result.insert(method_key, visited); + } + MethodsCallMap { map: result } + } + + fn add_call_map_single( + directly_called_by: &HashMap, Vec>, + visited: &mut HashSet, + current: SpirvWord, + ) { + if !visited.insert(current) { + return; + } + if let Some(children) = directly_called_by.get(&ast::MethodName::Func(current)) { + for child in children { + Self::add_call_map_single(directly_called_by, visited, *child); + } + } + } + + fn get_kernel_children(&self, name: &'input str) -> impl Iterator { + self.map + 
.get(&ast::MethodName::Kernel(name)) + .into_iter() + .flatten() + } + + fn kernels(&self) -> impl Iterator)> { + self.map + .iter() + .filter_map(|(method, children)| match method { + ast::MethodName::Kernel(kernel) => Some((*kernel, children)), + ast::MethodName::Func(..) => None, + }) + } + + fn methods( + &self, + ) -> impl Iterator, &HashSet)> { + self.map + .iter() + .map(|(method, children)| (*method, children)) + } + + fn visit_callees(&self, method: ast::MethodName<'input, SpirvWord>, f: impl FnMut(SpirvWord)) { + self.map + .get(&method) + .into_iter() + .flatten() + .copied() + .for_each(f); + } +} + +fn multi_hash_map_append< + K: Eq + std::hash::Hash, + V, + Collection: std::iter::Extend + std::default::Default, +>( + m: &mut HashMap, + key: K, + value: V, +) { + match m.entry(key) { + hash_map::Entry::Occupied(mut entry) => { + entry.get_mut().extend(iter::once(value)); + } + hash_map::Entry::Vacant(entry) => { + entry.insert(Default::default()).extend(iter::once(value)); + } + } +} + +fn normalize_variable_decls(directives: &mut Vec) { + for directive in directives { + match directive { + Directive::Method(Function { + body: Some(func), .. + }) => { + func[1..].sort_by_key(|s| match s { + Statement::Variable(_) => 0, + _ => 1, + }); + } + _ => (), + } + } +} + +// HACK ALERT! +// This function is a "good enough" heuristic of whetever to mark f16/f32 operations +// in the kernel as flushing denorms to zero or preserving them +// PTX support per-instruction ftz information. Unfortunately SPIR-V has no +// such capability, so instead we guesstimate which use is more common in the kernel +// and emit suitable execution mode +fn compute_denorm_information<'input>( + module: &[Directive<'input>], +) -> HashMap, HashMap> { + let mut denorm_methods = HashMap::new(); + for directive in module { + match directive { + Directive::Variable(..) | Directive::Method(Function { body: None, .. 
}) => {} + Directive::Method(Function { + func_decl, + body: Some(statements), + .. + }) => { + let mut flush_counter = DenormCountMap::new(); + let method_key = (**func_decl).borrow().name; + for statement in statements { + match statement { + Statement::Instruction(inst) => { + if let Some((flush, width)) = flush_to_zero(inst) { + denorm_count_map_update(&mut flush_counter, width, flush); + } + } + Statement::LoadVar(..) => {} + Statement::StoreVar(..) => {} + Statement::Conditional(_) => {} + Statement::Conversion(_) => {} + Statement::Constant(_) => {} + Statement::RetValue(_, _) => {} + Statement::Label(_) => {} + Statement::Variable(_) => {} + Statement::PtrAccess { .. } => {} + Statement::RepackVector(_) => {} + Statement::FunctionPointer(_) => {} + } + } + denorm_methods.insert(method_key, flush_counter); + } + } + } + denorm_methods + .into_iter() + .map(|(name, v)| { + let width_to_denorm = v + .into_iter() + .map(|(k, flush_over_preserve)| { + let mode = if flush_over_preserve > 0 { + spirv::FPDenormMode::FlushToZero + } else { + spirv::FPDenormMode::Preserve + }; + (k, (mode, flush_over_preserve)) + }) + .collect(); + (name, width_to_denorm) + }) + .collect() +} + +fn flush_to_zero(this: &ast::Instruction) -> Option<(bool, u8)> { + match this { + ast::Instruction::Ld { .. } => None, + ast::Instruction::St { .. } => None, + ast::Instruction::Mov { .. } => None, + ast::Instruction::Not { .. } => None, + ast::Instruction::Bra { .. } => None, + ast::Instruction::Shl { .. } => None, + ast::Instruction::Shr { .. } => None, + ast::Instruction::Ret { .. } => None, + ast::Instruction::Call { .. } => None, + ast::Instruction::Or { .. } => None, + ast::Instruction::And { .. } => None, + ast::Instruction::Cvta { .. } => None, + ast::Instruction::Selp { .. } => None, + ast::Instruction::Bar { .. } => None, + ast::Instruction::Atom { .. } => None, + ast::Instruction::AtomCas { .. } => None, + ast::Instruction::Sub { + data: ast::ArithDetails::Integer(_), + .. 
+ } => None, + ast::Instruction::Add { + data: ast::ArithDetails::Integer(_), + .. + } => None, + ast::Instruction::Mul { + data: ast::MulDetails::Integer { .. }, + .. + } => None, + ast::Instruction::Mad { + data: ast::MadDetails::Integer { .. }, + .. + } => None, + ast::Instruction::Min { + data: ast::MinMaxDetails::Signed(_), + .. + } => None, + ast::Instruction::Min { + data: ast::MinMaxDetails::Unsigned(_), + .. + } => None, + ast::Instruction::Max { + data: ast::MinMaxDetails::Signed(_), + .. + } => None, + ast::Instruction::Max { + data: ast::MinMaxDetails::Unsigned(_), + .. + } => None, + ast::Instruction::Cvt { + data: + ast::CvtDetails { + mode: + ast::CvtMode::ZeroExtend + | ast::CvtMode::SignExtend + | ast::CvtMode::Truncate + | ast::CvtMode::Bitcast + | ast::CvtMode::SaturateUnsignedToSigned + | ast::CvtMode::SaturateSignedToUnsigned + | ast::CvtMode::FPFromSigned(_) + | ast::CvtMode::FPFromUnsigned(_), + .. + }, + .. + } => None, + ast::Instruction::Div { + data: ast::DivDetails::Unsigned(_), + .. + } => None, + ast::Instruction::Div { + data: ast::DivDetails::Signed(_), + .. + } => None, + ast::Instruction::Clz { .. } => None, + ast::Instruction::Brev { .. } => None, + ast::Instruction::Popc { .. } => None, + ast::Instruction::Xor { .. } => None, + ast::Instruction::Bfe { .. } => None, + ast::Instruction::Bfi { .. } => None, + ast::Instruction::Rem { .. } => None, + ast::Instruction::Prmt { .. } => None, + ast::Instruction::Activemask { .. } => None, + ast::Instruction::Membar { .. } => None, + ast::Instruction::Sub { + data: ast::ArithDetails::Float(float_control), + .. + } + | ast::Instruction::Add { + data: ast::ArithDetails::Float(float_control), + .. + } + | ast::Instruction::Mul { + data: ast::MulDetails::Float(float_control), + .. + } + | ast::Instruction::Mad { + data: ast::MadDetails::Float(float_control), + .. + } => float_control + .flush_to_zero + .map(|ftz| (ftz, float_control.type_.size_of())), + ast::Instruction::Fma { data, .. 
} => data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of())), + ast::Instruction::Setp { data, .. } => { + data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of())) + } + ast::Instruction::SetpBool { data, .. } => data + .base + .flush_to_zero + .map(|ftz| (ftz, data.base.type_.size_of())), + ast::Instruction::Abs { data, .. } + | ast::Instruction::Rsqrt { data, .. } + | ast::Instruction::Neg { data, .. } + | ast::Instruction::Ex2 { data, .. } => { + data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of())) + } + ast::Instruction::Min { + data: ast::MinMaxDetails::Float(float_control), + .. + } + | ast::Instruction::Max { + data: ast::MinMaxDetails::Float(float_control), + .. + } => float_control + .flush_to_zero + .map(|ftz| (ftz, ast::ScalarType::from(float_control.type_).size_of())), + ast::Instruction::Sqrt { data, .. } | ast::Instruction::Rcp { data, .. } => { + data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of())) + } + // Modifier .ftz can only be specified when either .dtype or .atype + // is .f32 and applies only to single precision (.f32) inputs and results. + ast::Instruction::Cvt { + data: + ast::CvtDetails { + mode: + ast::CvtMode::FPExtend { flush_to_zero } + | ast::CvtMode::FPTruncate { flush_to_zero, .. } + | ast::CvtMode::FPRound { flush_to_zero, .. } + | ast::CvtMode::SignedFromFP { flush_to_zero, .. } + | ast::CvtMode::UnsignedFromFP { flush_to_zero, .. }, + .. + }, + .. + } => flush_to_zero.map(|ftz| (ftz, 4)), + ast::Instruction::Div { + data: + ast::DivDetails::Float(ast::DivFloatDetails { + type_, + flush_to_zero, + .. + }), + .. + } => flush_to_zero.map(|ftz| (ftz, type_.size_of())), + ast::Instruction::Sin { data, .. } + | ast::Instruction::Cos { data, .. } + | ast::Instruction::Lg2 { data, .. } => { + Some((data.flush_to_zero, mem::size_of::() as u8)) + } + ptx_parser::Instruction::PrmtSlow { .. 
} => None, + ptx_parser::Instruction::Trap {} => None, + } +} + +type DenormCountMap = HashMap; + +fn denorm_count_map_update(map: &mut DenormCountMap, key: T, value: bool) { + let num_value = if value { 1 } else { -1 }; + denorm_count_map_update_impl(map, key, num_value); +} + +fn denorm_count_map_update_impl( + map: &mut DenormCountMap, + key: T, + num_value: isize, +) { + match map.entry(key) { + hash_map::Entry::Occupied(mut counter) => { + *(counter.get_mut()) += num_value; + } + hash_map::Entry::Vacant(entry) => { + entry.insert(num_value); + } + } +} diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 59815f25..39b464e3 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -514,8 +514,11 @@ pub trait Visitor { ) -> Result<(), Err>; } -impl, bool, bool) -> Result<(), Err>> - Visitor for Fn +impl< + T: Operand, + Err, + Fn: FnMut(&T, Option<(&Type, StateSpace)>, bool, bool) -> Result<(), Err>, + > Visitor for Fn { fn visit( &mut self, @@ -760,7 +763,7 @@ pub enum Type { Vector(ScalarType, u8), // .param.b32 foo[4]; Array(ScalarType, Vec), - Pointer(ScalarType, StateSpace) + Pointer(ScalarType, StateSpace), } impl Type { @@ -1097,7 +1100,7 @@ impl SetpData { let cmp_op = if type_kind == ScalarKind::Float { SetpCompareOp::Float(SetpCompareFloat::from(cmp_op)) } else { - match SetpCompareInt::try_from(cmp_op) { + match SetpCompareInt::try_from((cmp_op, type_kind)) { Ok(op) => SetpCompareOp::Integer(op), Err(err) => { state.errors.push(err); @@ -1129,10 +1132,14 @@ pub enum SetpCompareOp { pub enum SetpCompareInt { Eq, NotEq, - Less, - LessOrEq, - Greater, - GreaterOrEq, + UnsignedLess, + UnsignedLessOrEq, + UnsignedGreater, + UnsignedGreaterOrEq, + SignedLess, + SignedLessOrEq, + SignedGreater, + SignedGreaterOrEq, } #[derive(PartialEq, Eq, Copy, Clone)] @@ -1153,29 +1160,41 @@ pub enum SetpCompareFloat { IsAnyNan, } -impl TryFrom for SetpCompareInt { +impl TryFrom<(RawSetpCompareOp, ScalarKind)> for SetpCompareInt { type Error = 
PtxError; - fn try_from(value: RawSetpCompareOp) -> Result { - match value { - RawSetpCompareOp::Eq => Ok(SetpCompareInt::Eq), - RawSetpCompareOp::Ne => Ok(SetpCompareInt::NotEq), - RawSetpCompareOp::Lt => Ok(SetpCompareInt::Less), - RawSetpCompareOp::Le => Ok(SetpCompareInt::LessOrEq), - RawSetpCompareOp::Gt => Ok(SetpCompareInt::Greater), - RawSetpCompareOp::Ge => Ok(SetpCompareInt::GreaterOrEq), - RawSetpCompareOp::Lo => Ok(SetpCompareInt::Less), - RawSetpCompareOp::Ls => Ok(SetpCompareInt::LessOrEq), - RawSetpCompareOp::Hi => Ok(SetpCompareInt::Greater), - RawSetpCompareOp::Hs => Ok(SetpCompareInt::GreaterOrEq), - RawSetpCompareOp::Equ => Err(PtxError::WrongType), - RawSetpCompareOp::Neu => Err(PtxError::WrongType), - RawSetpCompareOp::Ltu => Err(PtxError::WrongType), - RawSetpCompareOp::Leu => Err(PtxError::WrongType), - RawSetpCompareOp::Gtu => Err(PtxError::WrongType), - RawSetpCompareOp::Geu => Err(PtxError::WrongType), - RawSetpCompareOp::Num => Err(PtxError::WrongType), - RawSetpCompareOp::Nan => Err(PtxError::WrongType), + fn try_from((value, kind): (RawSetpCompareOp, ScalarKind)) -> Result { + match (value, kind) { + (RawSetpCompareOp::Eq, _) => Ok(SetpCompareInt::Eq), + (RawSetpCompareOp::Ne, _) => Ok(SetpCompareInt::NotEq), + (RawSetpCompareOp::Lt | RawSetpCompareOp::Lo, ScalarKind::Signed) => { + Ok(SetpCompareInt::SignedLess) + } + (RawSetpCompareOp::Lt | RawSetpCompareOp::Lo, _) => Ok(SetpCompareInt::UnsignedLess), + (RawSetpCompareOp::Le | RawSetpCompareOp::Ls, ScalarKind::Signed) => { + Ok(SetpCompareInt::SignedLessOrEq) + } + (RawSetpCompareOp::Le | RawSetpCompareOp::Ls, _) => { + Ok(SetpCompareInt::UnsignedLessOrEq) + } + (RawSetpCompareOp::Gt | RawSetpCompareOp::Hi, ScalarKind::Signed) => { + Ok(SetpCompareInt::SignedGreater) + } + (RawSetpCompareOp::Gt | RawSetpCompareOp::Hi, _) => Ok(SetpCompareInt::UnsignedGreater), + (RawSetpCompareOp::Ge | RawSetpCompareOp::Hs, ScalarKind::Signed) => { + Ok(SetpCompareInt::SignedGreaterOrEq) + } + 
(RawSetpCompareOp::Ge | RawSetpCompareOp::Hs, _) => { + Ok(SetpCompareInt::UnsignedGreaterOrEq) + } + (RawSetpCompareOp::Equ, _) => Err(PtxError::WrongType), + (RawSetpCompareOp::Neu, _) => Err(PtxError::WrongType), + (RawSetpCompareOp::Ltu, _) => Err(PtxError::WrongType), + (RawSetpCompareOp::Leu, _) => Err(PtxError::WrongType), + (RawSetpCompareOp::Gtu, _) => Err(PtxError::WrongType), + (RawSetpCompareOp::Geu, _) => Err(PtxError::WrongType), + (RawSetpCompareOp::Num, _) => Err(PtxError::WrongType), + (RawSetpCompareOp::Nan, _) => Err(PtxError::WrongType), } } } @@ -1276,7 +1295,9 @@ impl CallArgs { .return_arguments .into_iter() .zip(details.return_arguments.iter()) - .map(|(param, (type_, space))| visitor.visit_ident(param, Some((type_, *space)), true, false)) + .map(|(param, (type_, space))| { + visitor.visit_ident(param, Some((type_, *space)), true, false) + }) .collect::, _>>()?; let func = visitor.visit_ident(self.func, None, false, false)?; let input_arguments = self @@ -1305,6 +1326,8 @@ pub enum CvtMode { SignExtend, Truncate, Bitcast, + SaturateUnsignedToSigned, + SaturateSignedToUnsigned, // float from float FPExtend { flush_to_zero: Option, @@ -1389,21 +1412,11 @@ impl CvtDetails { }, (ScalarKind::Float, ScalarKind::Unsigned) => CvtMode::FPFromUnsigned(unwrap_rounding()), (ScalarKind::Float, ScalarKind::Signed) => CvtMode::FPFromSigned(unwrap_rounding()), - ( - ScalarKind::Unsigned | ScalarKind::Signed, - ScalarKind::Unsigned | ScalarKind::Signed, - ) => match dst.size_of().cmp(&src.size_of()) { - Ordering::Less => { - if dst.kind() != src.kind() { - errors.push(PtxError::Todo); - } - CvtMode::Truncate - } + (ScalarKind::Unsigned, ScalarKind::Unsigned) + | (ScalarKind::Signed, ScalarKind::Signed) => match dst.size_of().cmp(&src.size_of()) { + Ordering::Less => CvtMode::Truncate, Ordering::Equal => CvtMode::Bitcast, Ordering::Greater => { - if dst.kind() != src.kind() { - errors.push(PtxError::Todo); - } if src.kind() == ScalarKind::Signed { 
CvtMode::SignExtend } else { From 2e5ad8ebdfe505b5f040bad3b47a4b39268cb492 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 30 Aug 2024 17:01:47 +0200 Subject: [PATCH 36/47] Wire new parser into spvtxt tests --- ptx/src/pass/convert_to_typed.rs | 6 ++-- ptx/src/pass/emit_spirv.rs | 9 +++--- ptx/src/pass/expand_arguments.rs | 4 +-- ptx/src/pass/fix_special_registers.rs | 4 +-- ptx/src/pass/insert_implicit_conversions.rs | 20 ++++++------ ptx/src/pass/insert_mem_ssa_statements.rs | 2 +- ptx/src/pass/mod.rs | 36 +++++++++------------ ptx/src/test/spirv_run/mod.rs | 7 ++-- 8 files changed, 41 insertions(+), 47 deletions(-) diff --git a/ptx/src/pass/convert_to_typed.rs b/ptx/src/pass/convert_to_typed.rs index 7ff52909..2342ad51 100644 --- a/ptx/src/pass/convert_to_typed.rs +++ b/ptx/src/pass/convert_to_typed.rs @@ -19,7 +19,7 @@ pub(crate) fn run( }, } if fn_defs.fns.contains_key(&src_reg) => { if data.typ != ast::Type::Scalar(ast::ScalarType::U64) { - return Err(TranslateError::MismatchedType); + return Err(error_mismatched_type()); } result.push(TypedStatement::FunctionPointer(FunctionPointerDetails { dst: dst_reg, @@ -68,7 +68,7 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> { // mov.u32 foobar, {a,b}; let scalar_t = match typ { ast::Type::Vector(scalar_t, _) => *scalar_t, - _ => return Err(TranslateError::MismatchedType), + _ => return Err(error_mismatched_type()), }; let temp_vec = self .id_def @@ -115,7 +115,7 @@ impl<'a, 'b> ast::VisitorMap, TypedOperand, Transl ast::ParsedOperand::Imm(x) => TypedOperand::Imm(x), ast::ParsedOperand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx), ast::ParsedOperand::VecPack(vec) => { - let (type_, space) = type_space.ok_or(TranslateError::MismatchedType)?; + let (type_, space) = type_space.ok_or(error_mismatched_type())?; TypedOperand::Reg(self.convert_vector( is_dst, relaxed_type_check, diff --git a/ptx/src/pass/emit_spirv.rs b/ptx/src/pass/emit_spirv.rs index 9dff12ee..e2e6a3b4 100644 --- 
a/ptx/src/pass/emit_spirv.rs +++ b/ptx/src/pass/emit_spirv.rs @@ -17,7 +17,7 @@ pub(super) fn run<'input>( HashMap, >, directives: Vec>, -) -> Result<(), TranslateError> { +) -> Result<(dr::Module, HashMap, CString), TranslateError> { builder.set_version(1, 3); emit_capabilities(&mut builder); emit_extensions(&mut builder); @@ -39,7 +39,8 @@ pub(super) fn run<'input>( globals_use_map, directives, &mut kernel_info, - ) + )?; + Ok((builder.module(), kernel_info, build_options)) } fn emit_capabilities(builder: &mut dr::Builder) { @@ -942,7 +943,7 @@ fn emit_function_body_ops<'input>( builder.constant_true(bool_type, Some(cnst.dst.0)); } } - _ => return Err(TranslateError::MismatchedType), + _ => return Err(error_mismatched_type()), } } Statement::Conversion(cv) => emit_implicit_conversion(builder, map, cv)?, @@ -2646,7 +2647,7 @@ fn emit_load_var( Some((index, Some(width))) => { let vector_type = match details.typ { ast::Type::Scalar(scalar_t) => ast::Type::Vector(scalar_t, width), - _ => return Err(TranslateError::MismatchedType), + _ => return Err(error_mismatched_type()), }; let vector_type_spirv = map.get_or_add(builder, SpirvType::new(vector_type)); let vector_temp = builder.load( diff --git a/ptx/src/pass/expand_arguments.rs b/ptx/src/pass/expand_arguments.rs index eb03866d..36800051 100644 --- a/ptx/src/pass/expand_arguments.rs +++ b/ptx/src/pass/expand_arguments.rs @@ -66,11 +66,11 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { if state_space != ast::StateSpace::Reg && state_space != ast::StateSpace::Sreg { let (reg_type, reg_space) = self.id_def.get_typed(reg)?; if !state_is_compatible(reg_space, ast::StateSpace::Reg) { - return Err(TranslateError::MismatchedType); + return Err(error_mismatched_type()); } let reg_scalar_type = match reg_type { ast::Type::Scalar(underlying_type) => underlying_type, - _ => return Err(TranslateError::MismatchedType), + _ => return Err(error_mismatched_type()), }; let id_constant_stmt = self .id_def diff --git 
a/ptx/src/pass/fix_special_registers.rs b/ptx/src/pass/fix_special_registers.rs index 304bc611..c0290167 100644 --- a/ptx/src/pass/fix_special_registers.rs +++ b/ptx/src/pass/fix_special_registers.rs @@ -58,7 +58,7 @@ impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'b, 'input> { ) -> Result { if let Some(sreg) = self.numeric_id_defs.special_registers.get(name) { if is_dst { - return Err(TranslateError::MismatchedType); + return Err(error_mismatched_type()); } let input_arguments = match (vector_index, sreg.get_function_input_type()) { (Some(idx), Some(inp_type)) => { @@ -81,7 +81,7 @@ impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'b, 'input> { )] } (None, None) => Vec::new(), - _ => return Err(TranslateError::MismatchedType), + _ => return Err(error_mismatched_type()), }; let ocl_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat(); let return_type = sreg.get_function_return_type(); diff --git a/ptx/src/pass/insert_implicit_conversions.rs b/ptx/src/pass/insert_implicit_conversions.rs index 4a0dc8e7..baf34536 100644 --- a/ptx/src/pass/insert_implicit_conversions.rs +++ b/ptx/src/pass/insert_implicit_conversions.rs @@ -168,7 +168,7 @@ fn default_implicit_conversion_space( | ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)), - _ => Err(TranslateError::MismatchedType), + _ => Err(error_mismatched_type()), }, ast::Type::Scalar(ast::ScalarType::B32) | ast::Type::Scalar(ast::ScalarType::U32) @@ -176,9 +176,9 @@ fn default_implicit_conversion_space( ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => { Ok(Some(ConversionKind::BitToPtr)) } - _ => Err(TranslateError::MismatchedType), + _ => Err(error_mismatched_type()), }, - _ => Err(TranslateError::MismatchedType), + _ => Err(error_mismatched_type()), } } else if state_is_compatible(instruction_space, ast::StateSpace::Reg) { match instruction_type { @@ -191,10 +191,10 @@ fn default_implicit_conversion_space( 
Ok(None) } } - _ => Err(TranslateError::MismatchedType), + _ => Err(error_mismatched_type()), } } else { - Err(TranslateError::MismatchedType) + Err(error_mismatched_type()) } } @@ -208,7 +208,7 @@ fn default_implicit_conversion_type( if should_bitcast(instruction_type, operand_type) { Ok(Some(ConversionKind::Default)) } else { - Err(TranslateError::MismatchedType) + Err(error_mismatched_type()) } } else { Ok(Some(ConversionKind::PtrToPtr)) @@ -265,14 +265,14 @@ fn should_convert_relaxed_dst_wrapper( (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), ) -> Result, TranslateError> { if !state_is_compatible(operand_space, instruction_space) { - return Err(TranslateError::MismatchedType); + return Err(error_mismatched_type()); } if operand_type == instruction_type { return Ok(None); } match should_convert_relaxed_dst(operand_type, instruction_type) { conv @ Some(_) => Ok(conv), - None => Err(TranslateError::MismatchedType), + None => Err(error_mismatched_type()), } } @@ -342,14 +342,14 @@ fn should_convert_relaxed_src_wrapper( (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), ) -> Result, TranslateError> { if !state_is_compatible(operand_space, instruction_space) { - return Err(TranslateError::MismatchedType); + return Err(error_mismatched_type()); } if operand_type == instruction_type { return Ok(None); } match should_convert_relaxed_src(operand_type, instruction_type) { conv @ Some(_) => Ok(conv), - None => Err(TranslateError::MismatchedType), + None => Err(error_mismatched_type()), } } diff --git a/ptx/src/pass/insert_mem_ssa_statements.rs b/ptx/src/pass/insert_mem_ssa_statements.rs index 6ab19bd8..7369cdb9 100644 --- a/ptx/src/pass/insert_mem_ssa_statements.rs +++ b/ptx/src/pass/insert_mem_ssa_statements.rs @@ -199,7 +199,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { var_type = ast::Type::Scalar(scalar_t); width } - _ => return Err(TranslateError::MismatchedType), + _ => return Err(error_mismatched_type()), }; Some(( 
idx, diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 8923718c..28250176 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -55,26 +55,7 @@ pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result(ast: ast::Module<'input>) -> Result( @@ -629,10 +609,24 @@ fn error_unreachable() -> TranslateError { TranslateError::Unreachable } +fn error_unknown_symbol() -> TranslateError { + panic!() +} + +#[cfg(not(debug_assertions))] fn error_unknown_symbol() -> TranslateError { TranslateError::UnknownSymbol } +fn error_mismatched_type() -> TranslateError { + panic!() +} + +#[cfg(not(debug_assertions))] +fn error_mismatched_type() -> TranslateError { + TranslateError::MismatchedType +} + pub struct GlobalFnDeclResolver<'input, 'a> { fns: &'a HashMap>, } diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index f5dfa640..62dba04f 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -1,3 +1,4 @@ +use crate::pass; use crate::ptx; use crate::translate; use hip_runtime_sys::hipError_t; @@ -385,10 +386,8 @@ fn test_spvtxt_assert<'a>( spirv_txt: &'a [u8], spirv_file_name: &'a str, ) -> Result<(), Box> { - let mut errors = Vec::new(); - let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_txt)?; - assert!(errors.len() == 0); - let spirv_module = translate::to_spirv_module(ast)?; + let ast = ptx_parser::parse_module_unchecked(ptx_txt).unwrap(); + let spirv_module = pass::to_spirv_module(ast)?; let spv_context = unsafe { spirv_tools::spvContextCreate(spv_target_env::SPV_ENV_UNIVERSAL_1_3) }; assert!(spv_context != ptr::null_mut()); From 32b62626ffce1e7466b62cda8a5cca99b22ff5ef Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 30 Aug 2024 17:47:47 +0200 Subject: [PATCH 37/47] Fix PtrAdd --- ptx/src/pass/expand_arguments.rs | 2 +- ptx/src/pass/mod.rs | 16 +++++++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/ptx/src/pass/expand_arguments.rs 
b/ptx/src/pass/expand_arguments.rs index 36800051..bc01ab07 100644 --- a/ptx/src/pass/expand_arguments.rs +++ b/ptx/src/pass/expand_arguments.rs @@ -63,7 +63,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { } else { return Err(TranslateError::UntypedSymbol); }; - if state_space != ast::StateSpace::Reg && state_space != ast::StateSpace::Sreg { + if state_space == ast::StateSpace::Reg || state_space == ast::StateSpace::Sreg { let (reg_type, reg_space) = self.id_def.get_typed(reg)?; if !state_is_compatible(reg_space, ast::StateSpace::Reg) { return Err(error_mismatched_type()); diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 28250176..d0f4dfb5 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -1,5 +1,6 @@ use ptx_parser as ast; use rspirv::{binary::Assemble, dr}; +use std::hash::Hash; use std::{ borrow::Cow, cell::RefCell, @@ -10,11 +11,11 @@ use std::{ mem, rc::Rc, }; -use std::hash::Hash; mod convert_dynamic_shared_memory_usage; mod convert_to_stateful_memory_access; mod convert_to_typed; +mod emit_spirv; mod expand_arguments; mod extract_globals; mod fix_special_registers; @@ -23,7 +24,6 @@ mod insert_mem_ssa_statements; mod normalize_identifiers; mod normalize_labels; mod normalize_predicates; -mod emit_spirv; static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv"); static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc"); @@ -55,7 +55,8 @@ pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result> Statement, T> { )?; let offset_src = visitor.visit( offset_src, - Some((&underlying_type, state_space)), + Some(( + &ast::Type::Scalar(ast::ScalarType::S64), + ast::StateSpace::Reg, + )), false, false, )?; @@ -1582,7 +1586,9 @@ fn flush_to_zero(this: &ast::Instruction) -> Option<(bool, u8)> { } => float_control .flush_to_zero .map(|ftz| (ftz, float_control.type_.size_of())), - ast::Instruction::Fma { data, .. 
} => data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of())), + ast::Instruction::Fma { data, .. } => { + data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of())) + } ast::Instruction::Setp { data, .. } => { data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of())) } From 16fafe553f5a7f1a504c89abdd27b39e746f8cad Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 30 Aug 2024 20:13:43 +0200 Subject: [PATCH 38/47] Parse comments and vector members correctly --- ptx/src/test/spirv_run/mod.rs | 2 +- ptx/src/test/spirv_run/vector.ptx | 2 +- ptx_parser/src/lib.rs | 59 ++++++++++++++++++++++++++----- 3 files changed, 53 insertions(+), 10 deletions(-) diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 62dba04f..a798720b 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -386,7 +386,7 @@ fn test_spvtxt_assert<'a>( spirv_txt: &'a [u8], spirv_file_name: &'a str, ) -> Result<(), Box> { - let ast = ptx_parser::parse_module_unchecked(ptx_txt).unwrap(); + let ast = ptx_parser::parse_module_checked(ptx_txt).unwrap(); let spirv_module = pass::to_spirv_module(ast)?; let spv_context = unsafe { spirv_tools::spvContextCreate(spv_target_env::SPV_ENV_UNIVERSAL_1_3) }; diff --git a/ptx/src/test/spirv_run/vector.ptx b/ptx/src/test/spirv_run/vector.ptx index 90b8ad30..ba07e15e 100644 --- a/ptx/src/test/spirv_run/vector.ptx +++ b/ptx/src/test/spirv_run/vector.ptx @@ -1,4 +1,4 @@ -// Excersise as many features of vector types as possible +// Exercise as many features of vector types as possible .version 6.5 .target sm_60 diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index cfb87939..3a9ece5b 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -289,6 +289,43 @@ pub fn parse_module_unchecked<'input>(text: &'input str) -> Option( + text: &'input str, +) -> Result, Vec> { + let mut lexer = Token::lexer(text); + let mut errors = Vec::new(); + let mut tokens = Vec::new(); + loop { + let maybe_token = match 
lexer.next() { + Some(maybe_token) => maybe_token, + None => break, + }; + match maybe_token { + Ok(token) => tokens.push(token), + Err(mut err) => { + err.0 = lexer.span(); + errors.push(PtxError::from(err)) + } + } + } + if !errors.is_empty() { + return Err(errors); + } + let parse_error = { + let state = PtxParserState::new(&mut errors); + let parser = PtxParser { + state, + input: &tokens[..], + }; + match module.parse(parser) { + Ok(ast) => return Ok(ast), + Err(err) => PtxError::Parser(err.into_inner()), + } + }; + errors.push(parse_error); + Err(errors) +} + fn module<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult> { ( version, @@ -773,10 +810,10 @@ impl ast::ParsedOperand { use winnow::token::any; fn vector_index<'input>(inp: &'input str) -> Result { match inp { - "x" | "r" => Ok(0), - "y" | "g" => Ok(1), - "z" | "b" => Ok(2), - "w" | "a" => Ok(3), + ".x" | ".r" => Ok(0), + ".y" | ".g" => Ok(1), + ".z" | ".b" => Ok(2), + ".w" | ".a" => Ok(3), _ => Err(PtxError::WrongVectorElement), } } @@ -787,7 +824,7 @@ impl ast::ParsedOperand { alt(( preceded(Token::Plus, s32) .map(move |offset| ast::ParsedOperand::RegOffset(main_ident, offset)), - take_error(preceded(Token::Dot, ident).map(move |suffix| { + take_error(dot_ident.map(move |suffix| { let vector_index = vector_index(suffix) .map_err(move |e| (ast::ParsedOperand::VecMember(main_ident, 0), e))?; Ok(ast::ParsedOperand::VecMember(main_ident, vector_index)) @@ -829,8 +866,13 @@ pub enum PtxError { #[from] source: ParseFloatError, }, + #[error("{source}")] + Lexer { + #[from] + source: TokenError, + }, #[error("")] - Lexer(#[from] TokenError), + Parser(ContextError), #[error("")] Todo, #[error("")] @@ -1057,13 +1099,14 @@ fn empty_call<'input>( type ParsedOperandStr<'input> = ast::ParsedOperand<&'input str>; #[derive(Clone, PartialEq, Default, Debug, Display)] -pub struct TokenError; +#[display("({}:{})", _0.start, _0.end)] +pub struct TokenError(std::ops::Range); impl std::error::Error for TokenError 
{} derive_parser!( #[derive(Logos, PartialEq, Eq, Debug, Clone, Copy)] - #[logos(skip r"\s+")] + #[logos(skip r"(?:\s+)|(?://[^\n\r]*[\n\r]*)|(?:/\*[^*]*\*+(?:[^/*][^*]*\*+)*/)")] #[logos(error = TokenError)] enum Token<'input> { #[token(",")] From aebf06a8c544cd8a1af8f4107be30313ab5a385f Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 30 Aug 2024 21:27:01 +0200 Subject: [PATCH 39/47] Improve implicit conversion and handling of vectors --- .../pass/convert_to_stateful_memory_access.rs | 2 +- ptx/src/pass/expand_arguments.rs | 2 +- ptx/src/pass/insert_implicit_conversions.rs | 42 ++++++++++++++++--- ptx/src/pass/insert_mem_ssa_statements.rs | 9 ++-- ptx/src/pass/mod.rs | 2 +- 5 files changed, 43 insertions(+), 14 deletions(-) diff --git a/ptx/src/pass/convert_to_stateful_memory_access.rs b/ptx/src/pass/convert_to_stateful_memory_access.rs index 61b31ad3..ad4b473d 100644 --- a/ptx/src/pass/convert_to_stateful_memory_access.rs +++ b/ptx/src/pass/convert_to_stateful_memory_access.rs @@ -475,7 +475,7 @@ fn convert_to_stateful_memory_access_postprocess( let (old_operand_type, old_operand_space, _) = id_defs.get_typed(operand)?; let converting_id = id_defs .register_intermediate(Some((old_operand_type.clone(), old_operand_space))); - let kind = if state_is_compatible(new_operand_space, ast::StateSpace::Reg) { + let kind = if space_is_compatible(new_operand_space, ast::StateSpace::Reg) { ConversionKind::Default } else { ConversionKind::PtrToPtr diff --git a/ptx/src/pass/expand_arguments.rs b/ptx/src/pass/expand_arguments.rs index bc01ab07..d0c7c981 100644 --- a/ptx/src/pass/expand_arguments.rs +++ b/ptx/src/pass/expand_arguments.rs @@ -65,7 +65,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { }; if state_space == ast::StateSpace::Reg || state_space == ast::StateSpace::Sreg { let (reg_type, reg_space) = self.id_def.get_typed(reg)?; - if !state_is_compatible(reg_space, ast::StateSpace::Reg) { + if !space_is_compatible(reg_space, ast::StateSpace::Reg) { return 
Err(error_mismatched_type()); } let reg_scalar_type = match reg_type { diff --git a/ptx/src/pass/insert_implicit_conversions.rs b/ptx/src/pass/insert_implicit_conversions.rs index baf34536..0dce5984 100644 --- a/ptx/src/pass/insert_implicit_conversions.rs +++ b/ptx/src/pass/insert_implicit_conversions.rs @@ -127,7 +127,22 @@ fn default_implicit_conversion( (operand_space, operand_type): (ast::StateSpace, &ast::Type), (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), ) -> Result, TranslateError> { - if !state_is_compatible(instruction_space, operand_space) { + if instruction_space == ast::StateSpace::Reg { + if space_is_compatible(operand_space, ast::StateSpace::Reg) { + if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) = + (operand_type, instruction_type) + { + if scalar.kind() == ast::ScalarKind::Bit + && scalar.size_of() == (vec_underlying_type.size_of() * vec_len) + { + return Ok(Some(ConversionKind::Default)); + } + } + } else if is_addressable(operand_space) { + return Ok(Some(ConversionKind::AddressOf)); + } + } + if !space_is_compatible(instruction_space, operand_space) { default_implicit_conversion_space( (operand_space, operand_type), (instruction_space, instruction_type), @@ -139,6 +154,21 @@ fn default_implicit_conversion( } } +fn is_addressable(this: ast::StateSpace) -> bool { + match this { + ast::StateSpace::Const + | ast::StateSpace::Generic + | ast::StateSpace::Global + | ast::StateSpace::Local + | ast::StateSpace::Shared => true, + ast::StateSpace::Param | ast::StateSpace::Reg | ast::StateSpace::Sreg => false, + ast::StateSpace::SharedCluster + | ast::StateSpace::SharedCta + | ast::StateSpace::ParamEntry + | ast::StateSpace::ParamFunc => todo!(), + } +} + // Space is different fn default_implicit_conversion_space( (operand_space, operand_type): (ast::StateSpace, &ast::Type), @@ -148,7 +178,7 @@ fn default_implicit_conversion_space( || (operand_space == ast::StateSpace::Generic && 
coerces_to_generic(instruction_space)) { Ok(Some(ConversionKind::PtrToPtr)) - } else if state_is_compatible(operand_space, ast::StateSpace::Reg) { + } else if space_is_compatible(operand_space, ast::StateSpace::Reg) { match operand_type { ast::Type::Pointer(operand_ptr_type, operand_ptr_space) if *operand_ptr_space == instruction_space => @@ -180,7 +210,7 @@ fn default_implicit_conversion_space( }, _ => Err(error_mismatched_type()), } - } else if state_is_compatible(instruction_space, ast::StateSpace::Reg) { + } else if space_is_compatible(instruction_space, ast::StateSpace::Reg) { match instruction_type { ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space) if operand_space == *instruction_ptr_space => @@ -204,7 +234,7 @@ fn default_implicit_conversion_type( operand_type: &ast::Type, instruction_type: &ast::Type, ) -> Result, TranslateError> { - if state_is_compatible(space, ast::StateSpace::Reg) { + if space_is_compatible(space, ast::StateSpace::Reg) { if should_bitcast(instruction_type, operand_type) { Ok(Some(ConversionKind::Default)) } else { @@ -264,7 +294,7 @@ fn should_convert_relaxed_dst_wrapper( (operand_space, operand_type): (ast::StateSpace, &ast::Type), (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), ) -> Result, TranslateError> { - if !state_is_compatible(operand_space, instruction_space) { + if !space_is_compatible(operand_space, instruction_space) { return Err(error_mismatched_type()); } if operand_type == instruction_type { @@ -341,7 +371,7 @@ fn should_convert_relaxed_src_wrapper( (operand_space, operand_type): (ast::StateSpace, &ast::Type), (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), ) -> Result, TranslateError> { - if !state_is_compatible(operand_space, instruction_space) { + if !space_is_compatible(operand_space, instruction_space) { return Err(error_mismatched_type()); } if operand_type == instruction_type { diff --git a/ptx/src/pass/insert_mem_ssa_statements.rs 
b/ptx/src/pass/insert_mem_ssa_statements.rs index 7369cdb9..c1e30b0a 100644 --- a/ptx/src/pass/insert_mem_ssa_statements.rs +++ b/ptx/src/pass/insert_mem_ssa_statements.rs @@ -189,7 +189,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { return Ok(symbol); }; let (mut var_type, var_space, is_variable) = self.id_def.get_typed(symbol)?; - if !state_is_compatible(var_space, ast::StateSpace::Reg) || !is_variable { + if !space_is_compatible(var_space, ast::StateSpace::Reg) || !is_variable { return Ok(symbol); }; let member_index = match member_index { @@ -257,10 +257,9 @@ impl<'a, 'input> ast::VisitorMap TypedOperand::RegOffset(self.symbol(reg, None, type_space, is_dst)?, offset) } op @ TypedOperand::Imm(..) => op, - TypedOperand::VecMember(symbol, index) => TypedOperand::VecMember( - self.symbol(symbol, Some(index), type_space, is_dst)?, - index, - ), + TypedOperand::VecMember(symbol, index) => { + TypedOperand::Reg(self.symbol(symbol, Some(index), type_space, is_dst)?) + } }) } diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index d0f4dfb5..4ca2f025 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -1214,7 +1214,7 @@ impl< } } -fn state_is_compatible(this: ast::StateSpace, other: ast::StateSpace) -> bool { +fn space_is_compatible(this: ast::StateSpace, other: ast::StateSpace) -> bool { this == other || this == ast::StateSpace::Reg && other == ast::StateSpace::Sreg || this == ast::StateSpace::Sreg && other == ast::StateSpace::Reg From 0c9339325e693a383769533e0d6a9f8795689a70 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sat, 31 Aug 2024 03:42:27 +0200 Subject: [PATCH 40/47] Correctly report dst in call instructions --- ptx_parser/src/ast.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 39b464e3..2a6bb537 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -1256,7 +1256,7 @@ impl CallArgs { .iter() .zip(details.input_arguments.iter()) { - 
visitor.visit(param, Some((type_, *space)), true, false)?; + visitor.visit(param, Some((type_, *space)), false, false)?; } Ok(()) } @@ -1280,7 +1280,7 @@ impl CallArgs { .iter_mut() .zip(details.input_arguments.iter()) { - visitor.visit(param, Some((type_, *space)), true, false)?; + visitor.visit(param, Some((type_, *space)), false, false)?; } Ok(()) } @@ -1304,7 +1304,7 @@ impl CallArgs { .input_arguments .into_iter() .zip(details.input_arguments.iter()) - .map(|(param, (type_, space))| visitor.visit(param, Some((type_, *space)), true, false)) + .map(|(param, (type_, space))| visitor.visit(param, Some((type_, *space)), false, false)) .collect::, _>>()?; Ok(CallArgs { return_arguments, From 8d15499acc331f8e38787f4e44d9d536e6ecb7de Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Tue, 3 Sep 2024 02:19:27 +0200 Subject: [PATCH 41/47] More fixes --- .../pass/convert_to_stateful_memory_access.rs | 18 +- ptx/src/pass/convert_to_typed.rs | 2 +- ptx/src/pass/emit_spirv.rs | 37 +- ptx/src/pass/insert_implicit_conversions.rs | 20 +- ptx/src/pass/insert_mem_ssa_statements.rs | 2 +- ptx/src/pass/mod.rs | 14 +- ptx_parser/src/ast.rs | 30 +- ptx_parser/src/check_args.py | 69 +++ ptx_parser/src/lib.rs | 395 +++++++++++++++--- 9 files changed, 485 insertions(+), 102 deletions(-) create mode 100644 ptx_parser/src/check_args.py diff --git a/ptx/src/pass/convert_to_stateful_memory_access.rs b/ptx/src/pass/convert_to_stateful_memory_access.rs index ad4b473d..455a8c2e 100644 --- a/ptx/src/pass/convert_to_stateful_memory_access.rs +++ b/ptx/src/pass/convert_to_stateful_memory_access.rs @@ -467,8 +467,22 @@ fn convert_to_stateful_memory_access_postprocess( Some(new_id) => { let (new_operand_type, new_operand_space, _) = id_defs.get_typed(*new_id)?; // TODO: readd if required - if let Some(..) 
= type_space { - if relaxed_conversion { + if let Some((expected_type, expected_space)) = type_space { + let implicit_conversion = if relaxed_conversion { + if is_dst { + super::insert_implicit_conversions::should_convert_relaxed_dst_wrapper + } else { + super::insert_implicit_conversions::should_convert_relaxed_src_wrapper + } + } else { + super::insert_implicit_conversions::default_implicit_conversion + }; + if implicit_conversion( + (new_operand_space, &new_operand_type), + (expected_space, expected_type), + ) + .is_ok() + { return Ok(*new_id); } } diff --git a/ptx/src/pass/convert_to_typed.rs b/ptx/src/pass/convert_to_typed.rs index 2342ad51..c2af204a 100644 --- a/ptx/src/pass/convert_to_typed.rs +++ b/ptx/src/pass/convert_to_typed.rs @@ -67,7 +67,7 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> { ) -> Result { // mov.u32 foobar, {a,b}; let scalar_t = match typ { - ast::Type::Vector(scalar_t, _) => *scalar_t, + ast::Type::Vector(_, scalar_t) => *scalar_t, _ => return Err(error_mismatched_type()), }; let temp_vec = self diff --git a/ptx/src/pass/emit_spirv.rs b/ptx/src/pass/emit_spirv.rs index e2e6a3b4..8aa4576d 100644 --- a/ptx/src/pass/emit_spirv.rs +++ b/ptx/src/pass/emit_spirv.rs @@ -291,7 +291,7 @@ impl TypeWordMap { | ast::ScalarType::BF16x2 | ast::ScalarType::B128 => todo!(), }, - ast::Type::Vector(typ, len) => { + ast::Type::Vector(len, typ) => { let result_type = self.get_or_add(b, SpirvType::Vector(SpirvScalarKey::from(*typ), *len)); let size_of_t = typ.size_of(); @@ -309,7 +309,7 @@ impl TypeWordMap { .collect::, _>>()?; SpirvWord(b.constant_composite(result_type.0, None, components.into_iter())) } - ast::Type::Array(typ, dims) => match dims.as_slice() { + ast::Type::Array(_, typ, dims) => match dims.as_slice() { [] => return Err(error_unreachable()), [dim] => { let result_type = self @@ -342,7 +342,7 @@ impl TypeWordMap { Ok::<_, TranslateError>( self.get_or_add_constant( b, - &ast::Type::Array(*typ, rest.to_vec()), + &ast::Type::Array(None, *typ, 
rest.to_vec()), &init[((size_of_t as usize) * (x as usize))..], )? .0, @@ -397,8 +397,8 @@ impl SpirvType { fn new(t: ast::Type) -> Self { match t { ast::Type::Scalar(t) => SpirvType::Base(t.into()), - ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len), - ast::Type::Array(t, len) => SpirvType::Array(t.into(), len), + ast::Type::Vector(len, typ) => SpirvType::Vector(typ.into(), len), + ast::Type::Array(_, t, len) => SpirvType::Array(t.into(), len), ast::Type::Pointer(pointer_t, space) => SpirvType::Pointer( Box::new(SpirvType::Base(pointer_t.into())), space_to_spirv(space), @@ -809,8 +809,8 @@ fn emit_function_header<'input>( pub fn type_size_of(this: &ast::Type) -> usize { match this { ast::Type::Scalar(typ) => typ.size_of() as usize, - ast::Type::Vector(typ, len) => (typ.size_of() as usize) * (*len as usize), - ast::Type::Array(typ, len) => len + ast::Type::Vector(len, typ) => (typ.size_of() as usize) * (*len as usize), + ast::Type::Array(_, typ, len) => len .iter() .fold(typ.size_of() as usize, |x, y| (x as usize) * (*y as usize)), ast::Type::Pointer(..) 
=> mem::size_of::(), @@ -1853,11 +1853,16 @@ fn emit_mul_int( builder.i_mul(inst_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?; } ast::MulIntControl::High => { + let opencl_inst = if type_.kind() == ast::ScalarKind::Signed { + spirv::CLOp::s_mul_hi + } else { + spirv::CLOp::u_mul_hi + }; builder.ext_inst( inst_type.0, Some(arg.dst.0), opencl, - spirv::CLOp::s_mul_hi as spirv::Word, + opencl_inst as spirv::Word, [ dr::Operand::IdRef(arg.src1.0), dr::Operand::IdRef(arg.src2.0), @@ -2646,7 +2651,7 @@ fn emit_load_var( match details.member_index { Some((index, Some(width))) => { let vector_type = match details.typ { - ast::Type::Scalar(scalar_t) => ast::Type::Vector(scalar_t, width), + ast::Type::Scalar(scalar_t) => ast::Type::Vector(width, scalar_t), _ => return Err(error_mismatched_type()), }; let vector_type_spirv = map.get_or_add(builder, SpirvType::new(vector_type)); @@ -2710,14 +2715,14 @@ fn to_parts(this: &ast::Type) -> TypeParts { width: scalar.size_of(), components: Vec::new(), }, - ast::Type::Vector(scalar, components) => TypeParts { + ast::Type::Vector(components, scalar) => TypeParts { kind: TypeKind::Vector, state_space: ast::StateSpace::Reg, scalar_kind: scalar.kind(), width: scalar.size_of(), components: vec![*components as u32], }, - ast::Type::Array(scalar, components) => TypeParts { + ast::Type::Array(_, scalar, components) => TypeParts { kind: TypeKind::Array, state_space: ast::StateSpace::Reg, scalar_kind: scalar.kind(), @@ -2738,12 +2743,14 @@ fn type_from_parts(t: TypeParts) -> ast::Type { match t.kind { TypeKind::Scalar => ast::Type::Scalar(scalar_from_parts(t.width, t.scalar_kind)), TypeKind::Vector => ast::Type::Vector( - scalar_from_parts(t.width, t.scalar_kind), t.components[0] as u8, + scalar_from_parts(t.width, t.scalar_kind), + ), + TypeKind::Array => ast::Type::Array( + None, + scalar_from_parts(t.width, t.scalar_kind), + t.components, ), - TypeKind::Array => { - ast::Type::Array(scalar_from_parts(t.width, t.scalar_kind), 
t.components) - } TypeKind::Pointer => { ast::Type::Pointer(scalar_from_parts(t.width, t.scalar_kind), t.state_space) } diff --git a/ptx/src/pass/insert_implicit_conversions.rs b/ptx/src/pass/insert_implicit_conversions.rs index 0dce5984..2857551a 100644 --- a/ptx/src/pass/insert_implicit_conversions.rs +++ b/ptx/src/pass/insert_implicit_conversions.rs @@ -123,13 +123,13 @@ fn insert_implicit_conversions_impl( Ok(()) } -fn default_implicit_conversion( +pub(crate) fn default_implicit_conversion( (operand_space, operand_type): (ast::StateSpace, &ast::Type), (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), ) -> Result, TranslateError> { if instruction_space == ast::StateSpace::Reg { if space_is_compatible(operand_space, ast::StateSpace::Reg) { - if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) = + if let (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) = (operand_type, instruction_type) { if scalar.kind() == ast::ScalarKind::Bit @@ -282,15 +282,15 @@ fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool { ast::ScalarKind::Pred => false, } } - (ast::Type::Vector(inst, _), ast::Type::Vector(operand, _)) - | (ast::Type::Array(inst, _), ast::Type::Array(operand, _)) => { + (ast::Type::Vector(_, inst), ast::Type::Vector(_, operand)) + | (ast::Type::Array(_, inst, _), ast::Type::Array(_, operand, _)) => { should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand)) } _ => false, } } -fn should_convert_relaxed_dst_wrapper( +pub(crate) fn should_convert_relaxed_dst_wrapper( (operand_space, operand_type): (ast::StateSpace, &ast::Type), (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), ) -> Result, TranslateError> { @@ -356,8 +356,8 @@ fn should_convert_relaxed_dst( } ast::ScalarKind::Pred => None, }, - (ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _)) - | (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => { + 
(ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type)) + | (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => { should_convert_relaxed_dst( &ast::Type::Scalar(*dst_type), &ast::Type::Scalar(*instr_type), @@ -367,7 +367,7 @@ fn should_convert_relaxed_dst( } } -fn should_convert_relaxed_src_wrapper( +pub(crate) fn should_convert_relaxed_src_wrapper( (operand_space, operand_type): (ast::StateSpace, &ast::Type), (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), ) -> Result, TranslateError> { @@ -420,8 +420,8 @@ fn should_convert_relaxed_src( } ast::ScalarKind::Pred => None, }, - (ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _)) - | (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => { + (ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type)) + | (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => { should_convert_relaxed_src( &ast::Type::Scalar(*dst_type), &ast::Type::Scalar(*instr_type), diff --git a/ptx/src/pass/insert_mem_ssa_statements.rs b/ptx/src/pass/insert_mem_ssa_statements.rs index c1e30b0a..e314b05d 100644 --- a/ptx/src/pass/insert_mem_ssa_statements.rs +++ b/ptx/src/pass/insert_mem_ssa_statements.rs @@ -195,7 +195,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { let member_index = match member_index { Some(idx) => { let vector_width = match var_type { - ast::Type::Vector(scalar_t, width) => { + ast::Type::Vector(width, scalar_t) => { var_type = ast::Type::Scalar(scalar_t); width } diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 4ca2f025..92d1bf40 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -1,6 +1,7 @@ use ptx_parser as ast; use rspirv::{binary::Assemble, dr}; use std::hash::Hash; +use std::num::NonZeroU8; use std::{ borrow::Cow, cell::RefCell, @@ -360,7 +361,7 @@ impl PtxSpecialRegister { PtxSpecialRegister::Tid | PtxSpecialRegister::Ntid | PtxSpecialRegister::Ctaid - | PtxSpecialRegister::Nctaid => 
ast::Type::Vector(self.get_function_return_type(), 4), + | PtxSpecialRegister::Nctaid => ast::Type::Vector(4, self.get_function_return_type()), _ => ast::Type::Scalar(self.get_function_return_type()), } } @@ -764,7 +765,12 @@ impl> Statement, T> { }) } Statement::Conditional(conditional) => { - let predicate = visitor.visit_ident(conditional.predicate, None, false, false)?; + let predicate = visitor.visit_ident( + conditional.predicate, + Some((&ast::ScalarType::Pred.into(), ast::StateSpace::Reg)), + false, + false, + )?; let if_true = visitor.visit_ident(conditional.if_true, None, false, false)?; let if_false = visitor.visit_ident(conditional.if_false, None, false, false)?; Statement::Conditional(BrachCondition { @@ -919,7 +925,7 @@ impl> Statement, T> { let packed = visitor.visit_ident( packed, Some(( - &ast::Type::Vector(typ, unpacked.len() as u8), + &ast::Type::Vector(unpacked.len() as u8, typ), ast::StateSpace::Reg, )), false, @@ -930,7 +936,7 @@ impl> Statement, T> { let packed = visitor.visit_ident( packed, Some(( - &ast::Type::Vector(typ, unpacked.len() as u8), + &ast::Type::Vector(unpacked.len() as u8, typ), ast::StateSpace::Reg, )), true, diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 2a6bb537..c2669472 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -4,7 +4,7 @@ use super::{ }; use crate::{PtxError, PtxParserState}; use bitflags::bitflags; -use std::cmp::Ordering; +use std::{cmp::Ordering, num::NonZeroU8}; pub enum Statement { Label(P::Ident), @@ -760,19 +760,37 @@ pub enum Type { // .param.b32 foo; Scalar(ScalarType), // .param.v2.b32 foo; - Vector(ScalarType, u8), + Vector(u8, ScalarType), // .param.b32 foo[4]; - Array(ScalarType, Vec), + Array(Option, ScalarType, Vec), Pointer(ScalarType, StateSpace), } impl Type { pub(crate) fn maybe_vector(vector: Option, scalar: ScalarType) -> Self { match vector { - Some(prefix) => Type::Vector(scalar, prefix.len()), + Some(prefix) => Type::Vector(prefix.len().get(), scalar), 
None => Type::Scalar(scalar), } } + + pub(crate) fn maybe_vector_parsed(prefix: Option, scalar: ScalarType) -> Self { + match prefix { + Some(prefix) => Type::Vector(prefix.get(), scalar), + None => Type::Scalar(scalar), + } + } + + pub(crate) fn maybe_array( + prefix: Option, + scalar: ScalarType, + array: Option>, + ) -> Self { + match array { + Some(dimensions) => Type::Array(prefix, scalar, dimensions), + None => Self::maybe_vector_parsed(prefix, scalar), + } + } } impl ScalarType { @@ -1304,7 +1322,9 @@ impl CallArgs { .input_arguments .into_iter() .zip(details.input_arguments.iter()) - .map(|(param, (type_, space))| visitor.visit(param, Some((type_, *space)), false, false)) + .map(|(param, (type_, space))| { + visitor.visit(param, Some((type_, *space)), false, false) + }) .collect::, _>>()?; Ok(CallArgs { return_arguments, diff --git a/ptx_parser/src/check_args.py b/ptx_parser/src/check_args.py new file mode 100644 index 00000000..04ffdb91 --- /dev/null +++ b/ptx_parser/src/check_args.py @@ -0,0 +1,69 @@ +import os, sys, subprocess + + +SPACE = [".reg", ".sreg", ".param", ".param::entry", ".param::func", ".local", ".global", ".const", ".shared", ".shared::cta", ".shared::cluster"] +TYPE_AND_INIT = ["", " = 1", "[1]", "[1] = {1}"] +MULTIVAR = ["", "<1>" ] +VECTOR = ["", ".v2" ] + +HEADER = """ + .version 8.5 + .target sm_90 + .address_size 64 +""" + + +def directive(space, variable, multivar, vector): + return """{3} + {0} {4} .b32 variable{2} {1}; + """.format(space, variable, multivar, HEADER, vector) + +def entry_arg(space, variable, multivar, vector): + return """{3} + .entry foobar ( {0} {4} .b32 variable{2} {1}) + {{ + ret; + }} + """.format(space, variable, multivar, HEADER, vector) + + +def fn_arg(space, variable, multivar, vector): + return """{3} + .func foobar ( {0} {4} .b32 variable{2} {1}) + {{ + ret; + }} + """.format(space, variable, multivar, HEADER, vector) + + +def fn_body(space, variable, multivar, vector): + return """{3} + .func foobar () 
+ {{ + {0} {4} .b32 variable{2} {1}; + ret; + }} + """.format(space, variable, multivar, HEADER, vector) + + +def generate(generator): + legal = [] + for space in SPACE: + for init in TYPE_AND_INIT: + for multi in MULTIVAR: + for vector in VECTOR: + ptx = generator(space, init, multi, vector) + if 0 == subprocess.call(["C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.6\\bin\\ptxas.exe", "-arch", "sm_90", "-ias", ptx], stdout = subprocess.DEVNULL): # + legal.append((space, vector, init, multi)) + print(generator.__name__) + print(legal) + + +def main(): + generate(directive) + generate(entry_arg) + generate(fn_arg) + generate(fn_body) + +if __name__ == "__main__": + main() diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index 3a9ece5b..dfe78eec 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -3,9 +3,10 @@ use logos::Logos; use ptx_parser_macros::derive_parser; use rustc_hash::FxHashMap; use std::fmt::Debug; -use std::num::{ParseFloatError, ParseIntError}; +use std::num::{NonZeroU8, ParseFloatError, ParseIntError}; use winnow::ascii::dec_uint; use winnow::combinator::*; +use winnow::error::{ErrMode, ErrorKind}; use winnow::stream::Accumulate; use winnow::token::any; use winnow::{ @@ -72,11 +73,13 @@ impl From for ast::RoundingMode { } impl VectorPrefix { - pub(crate) fn len(self) -> u8 { - match self { - VectorPrefix::V2 => 2, - VectorPrefix::V4 => 4, - VectorPrefix::V8 => 8, + pub(crate) fn len(self) -> NonZeroU8 { + unsafe { + match self { + VectorPrefix::V2 => NonZeroU8::new_unchecked(2), + VectorPrefix::V4 => NonZeroU8::new_unchecked(4), + VectorPrefix::V8 => NonZeroU8::new_unchecked(8), + } } } } @@ -386,22 +389,14 @@ fn module_variable<'a, 'input>( ) -> PResult<(ast::LinkingDirective, ast::Variable<&'input str>)> { ( linking_directives, - module_variable_state_space.flat_map(variable_scalar_or_vector), + global_space + .flat_map(multi_variable) + // TODO: support multi var in globals + .map(|multi_var| multi_var.var), ) 
.parse_next(stream) } -fn module_variable_state_space<'a, 'input>( - stream: &mut PtxParser<'a, 'input>, -) -> PResult { - alt(( - Token::DotConst.value(StateSpace::Const), - Token::DotGlobal.value(StateSpace::Global), - Token::DotShared.value(StateSpace::Shared), - )) - .parse_next(stream) -} - fn file<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { ( Token::DotFile, @@ -547,17 +542,13 @@ fn kernel_arguments<'a, 'input>( fn kernel_input<'a, 'input>( stream: &mut PtxParser<'a, 'input>, ) -> PResult> { - preceded( - Token::DotParam, - variable_scalar_or_vector(StateSpace::Param), - ) - .parse_next(stream) + preceded(Token::DotParam, method_parameter(StateSpace::Param)).parse_next(stream) } fn fn_input<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult> { dispatch! { any; - Token::DotParam => variable_scalar_or_vector(StateSpace::Param), - Token::DotReg => variable_scalar_or_vector(StateSpace::Reg), + Token::DotParam => method_parameter(StateSpace::Param), + Token::DotReg => method_parameter(StateSpace::Reg), _ => fail } .parse_next(stream) @@ -596,7 +587,7 @@ fn tuple1to3_u32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(u32 } } - separated::<_, _, Tuple3AccumulateU32, _, _, _, _>(1..3, u32, Token::Comma) + separated::<_, _, Tuple3AccumulateU32, _, _, _, _>(1..=3, u32, Token::Comma) .map(|acc| acc.value) .parse_next(stream) } @@ -618,7 +609,12 @@ fn statement<'a, 'input>( alt(( label.map(Some), debug_directive.map(|_| None), - multi_variable.map(Some), + terminated( + method_space + .flat_map(multi_variable) + .map(|var| Some(Statement::Variable(var))), + Token::Semicolon, + ), predicated_instruction.map(Some), pragma.map(|_| None), block_statement.map(Some), @@ -632,59 +628,328 @@ fn pragma<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { .parse_next(stream) } -fn multi_variable<'a, 'input>( +fn method_parameter<'a, 'input: 'a>( + state_space: StateSpace, +) -> impl Parser, Variable<&'input str>, ContextError> 
{ + move |stream: &mut PtxParser<'a, 'input>| { + let (align, vector, type_, name) = variable_declaration.parse_next(stream)?; + let array_dimensions = if state_space != StateSpace::Reg { + opt(array_dimensions).parse_next(stream)? + } else { + None + }; + // TODO: push this check into array_dimensions(...) + if let Some(ref dims) = array_dimensions { + if dims[0] == 0 { + return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)); + } + } + Ok(Variable { + align, + v_type: Type::maybe_array(vector, type_, array_dimensions), + state_space, + name, + array_init: Vec::new(), + }) + } +} + +// TODO: split to a separate type +fn variable_declaration<'a, 'input>( stream: &mut PtxParser<'a, 'input>, -) -> PResult>> { +) -> PResult<(Option, Option, ScalarType, &'input str)> { ( - variable, - opt(delimited(Token::Lt, u32, Token::Gt)), - Token::Semicolon, + opt(align.verify(|x| x.count_ones() == 1)), + vector_prefix, + scalar_type, + ident, ) - .map(|(var, count, _)| ast::Statement::Variable(ast::MultiVariable { var, count })) .parse_next(stream) } -fn variable<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult> { - dispatch! 
{any; - Token::DotReg => variable_scalar_or_vector(StateSpace::Reg), - Token::DotLocal => variable_scalar_or_vector(StateSpace::Local), - Token::DotParam => variable_scalar_or_vector(StateSpace::Param), - Token::DotShared => variable_scalar_or_vector(StateSpace::Shared), - _ => fail - } - .parse_next(stream) -} - -fn variable_scalar_or_vector<'a, 'input: 'a>( +fn multi_variable<'a, 'input: 'a>( state_space: StateSpace, -) -> impl Parser, ast::Variable<&'input str>, ContextError> { +) -> impl Parser, MultiVariable<&'input str>, ContextError> { move |stream: &mut PtxParser<'a, 'input>| { - (opt(align), scalar_vector_type, ident) - .map(|(align, v_type, name)| ast::Variable { + let ((align, vector, type_, name), count) = ( + variable_declaration, + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameterized-variable-names + opt(delimited(Token::Lt, u32.verify(|x| *x != 0), Token::Gt)), + ) + .parse_next(stream)?; + if count.is_some() { + return Ok(MultiVariable { + var: Variable { + align, + v_type: Type::maybe_vector_parsed(vector, type_), + state_space, + name, + array_init: Vec::new(), + }, + count, + }); + } + let mut array_dimensions = if state_space != StateSpace::Reg { + opt(array_dimensions).parse_next(stream)? + } else { + None + }; + let initializer = match state_space { + StateSpace::Global | StateSpace::Const => match array_dimensions { + Some(ref mut dimensions) => { + opt(array_initializer(vector, type_, dimensions)).parse_next(stream)? 
+ } + None => opt(value_initializer(vector, type_)).parse_next(stream)?, + }, + _ => None, + }; + if let Some(ref dims) = array_dimensions { + if dims[0] == 0 { + return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)); + } + } + Ok(MultiVariable { + var: Variable { align, - v_type, + v_type: Type::maybe_array(vector, type_, array_dimensions), state_space, name, - array_init: Vec::new(), - }) - .parse_next(stream) + array_init: initializer.unwrap_or(Vec::new()), + }, + count, + }) + } +} + +fn array_initializer<'a, 'input: 'a>( + vector: Option, + type_: ScalarType, + array_dimensions: &mut Vec, +) -> impl Parser, Vec, ContextError> + '_ { + move |stream: &mut PtxParser<'a, 'input>| { + Token::Eq.parse_next(stream)?; + let mut result = Vec::new(); + // TODO: vector constants and multi dim arrays + if vector.is_some() || array_dimensions[0] == 0 || array_dimensions.len() > 1 { + return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)); + } + delimited( + Token::LBracket, + separated( + array_dimensions[0] as usize..=array_dimensions[0] as usize, + single_value_append(&mut result, type_), + Token::Comma, + ), + Token::RBracket, + ) + .parse_next(stream)?; + Ok(result) } } +fn value_initializer<'a, 'input: 'a>( + vector: Option, + type_: ScalarType, +) -> impl Parser, Vec, ContextError> { + move |stream: &mut PtxParser<'a, 'input>| { + Token::Eq.parse_next(stream)?; + let mut result = Vec::new(); + // TODO: vector constants + if vector.is_some() { + return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)); + } + single_value_append(&mut result, type_).parse_next(stream)?; + Ok(result) + } +} + +fn single_value_append<'a, 'input: 'a>( + accumulator: &mut Vec, + type_: ScalarType, +) -> impl Parser, (), ContextError> + '_ { + move |stream: &mut PtxParser<'a, 'input>| { + let value = immediate_value.parse_next(stream)?; + match (type_, value) { + (ScalarType::U8, ImmediateValue::U64(x)) => accumulator.extend_from_slice( + &u8::try_from(x) + 
.map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? + .to_le_bytes(), + ), + (ScalarType::U8, ImmediateValue::S64(x)) => accumulator.extend_from_slice( + &u8::try_from(x) + .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? + .to_le_bytes(), + ), + (ScalarType::U16, ImmediateValue::U64(x)) => accumulator.extend_from_slice( + &u16::try_from(x) + .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? + .to_le_bytes(), + ), + (ScalarType::U16, ImmediateValue::S64(x)) => accumulator.extend_from_slice( + &u16::try_from(x) + .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? + .to_le_bytes(), + ), + (ScalarType::U32, ImmediateValue::U64(x)) => accumulator.extend_from_slice( + &u32::try_from(x) + .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? + .to_le_bytes(), + ), + (ScalarType::U32, ImmediateValue::S64(x)) => accumulator.extend_from_slice( + &u32::try_from(x) + .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? + .to_le_bytes(), + ), + (ScalarType::U64, ImmediateValue::U64(x)) => accumulator.extend_from_slice( + &u64::try_from(x) + .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? + .to_le_bytes(), + ), + (ScalarType::U64, ImmediateValue::S64(x)) => accumulator.extend_from_slice( + &u64::try_from(x) + .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? + .to_le_bytes(), + ), + (ScalarType::S8, ImmediateValue::U64(x)) => accumulator.extend_from_slice( + &i8::try_from(x) + .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? + .to_le_bytes(), + ), + (ScalarType::S8, ImmediateValue::S64(x)) => accumulator.extend_from_slice( + &i8::try_from(x) + .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? + .to_le_bytes(), + ), + (ScalarType::S16, ImmediateValue::U64(x)) => accumulator.extend_from_slice( + &i16::try_from(x) + .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? 
+ .to_le_bytes(), + ), + (ScalarType::S16, ImmediateValue::S64(x)) => accumulator.extend_from_slice( + &i16::try_from(x) + .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? + .to_le_bytes(), + ), + (ScalarType::S32, ImmediateValue::U64(x)) => accumulator.extend_from_slice( + &i32::try_from(x) + .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? + .to_le_bytes(), + ), + (ScalarType::S32, ImmediateValue::S64(x)) => accumulator.extend_from_slice( + &i32::try_from(x) + .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? + .to_le_bytes(), + ), + (ScalarType::S64, ImmediateValue::U64(x)) => accumulator.extend_from_slice( + &i64::try_from(x) + .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? + .to_le_bytes(), + ), + (ScalarType::S64, ImmediateValue::S64(x)) => accumulator.extend_from_slice( + &i64::try_from(x) + .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? + .to_le_bytes(), + ), + (ScalarType::F32, ImmediateValue::F32(x)) => { + accumulator.extend_from_slice(&x.to_le_bytes()) + } + (ScalarType::F64, ImmediateValue::F64(x)) => { + accumulator.extend_from_slice(&x.to_le_bytes()) + } + _ => return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)), + } + Ok(()) + } +} + +fn array_dimensions<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult> { + let dimension = delimited( + Token::LBracket, + opt(u32).verify(|dim| *dim != Some(0)), + Token::RBracket, + ) + .parse_next(stream)?; + let result = vec![dimension.unwrap_or(0)]; + repeat_fold_0_or_more( + delimited( + Token::LBracket, + u32.verify(|dim| *dim != 0), + Token::RBracket, + ), + move || result, + |mut result: Vec, x| { + result.push(x); + result + }, + stream, + ) +} + +// Copied and fixed from Winnow sources (fold_repeat0_) +// Winnow Repeat::fold takes FnMut() -> Result to initalize accumulator, +// this really should be FnOnce() -> Result +fn repeat_fold_0_or_more( + mut f: F, + init: H, + mut g: G, + input: &mut I, 
+) -> PResult +where + I: Stream, + F: Parser, + G: FnMut(R, O) -> R, + H: FnOnce() -> R, + E: ParserError, +{ + use winnow::error::ErrMode; + let mut res = init(); + loop { + let start = input.checkpoint(); + match f.parse_next(input) { + Ok(o) => { + res = g(res, o); + } + Err(ErrMode::Backtrack(_)) => { + input.reset(&start); + return Ok(res); + } + Err(e) => { + return Err(e); + } + } + } +} + +fn global_space<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { + alt(( + Token::DotGlobal.value(StateSpace::Global), + Token::DotConst.value(StateSpace::Const), + Token::DotShared.value(StateSpace::Shared), + )) + .parse_next(stream) +} + +fn method_space<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { + alt(( + Token::DotReg.value(StateSpace::Reg), + Token::DotLocal.value(StateSpace::Local), + Token::DotParam.value(StateSpace::Param), + global_space, + )) + .parse_next(stream) +} + fn align<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { preceded(Token::DotAlign, u32).parse_next(stream) } -fn scalar_vector_type<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { - ( - opt(alt(( - Token::DotV2.value(VectorPrefix::V2), - Token::DotV4.value(VectorPrefix::V4), - ))), - scalar_type, - ) - .map(|(prefix, scalar)| ast::Type::maybe_vector(prefix, scalar)) - .parse_next(stream) +fn vector_prefix<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult> { + opt(alt(( + Token::DotV2.value(unsafe { NonZeroU8::new_unchecked(2) }), + Token::DotV4.value(unsafe { NonZeroU8::new_unchecked(4) }), + Token::DotV8.value(unsafe { NonZeroU8::new_unchecked(8) }), + ))) + .parse_next(stream) } fn scalar_type<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { @@ -1157,6 +1422,8 @@ derive_parser!( Minus, #[token("+")] Plus, + #[token("=")] + Eq, #[token(".version")] DotVersion, #[token(".loc")] @@ -2509,7 +2776,7 @@ derive_parser!( scope: scope.unwrap_or(MemScope::Gpu), space: global.unwrap_or(StateSpace::Generic), op: 
ast::AtomicOp::new(float_op, f32.kind()), - type_: ast::Type::Vector(f32, vec_32_bit.len()) + type_: ast::Type::Vector(vec_32_bit.len().get(), f32) }, arguments: AtomArgs { dst: d, src1: a, src2: b } } @@ -2840,7 +3107,7 @@ derive_parser!( // .mode = { .f4e, .b4e, .rc8, .ecl, .ecr, .rc16 }; prmt.b32 d, a, b, c => { match c { - ast::ParsedOperand::Imm(ImmediateValue::U64(control)) => ast::Instruction::Prmt { + ast::ParsedOperand::Imm(ImmediateValue::S64(control)) => ast::Instruction::Prmt { data: control as u16, arguments: PrmtArgs { dst: d, src1: a, src2: b From 340ad86d56fbd454418399ce90bdb7d246418a46 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Tue, 3 Sep 2024 05:25:31 +0200 Subject: [PATCH 42/47] Emit correct float add --- ptx_parser/src/ast.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index c2669472..f0d7f9f7 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -1608,6 +1608,7 @@ impl AtomicOp { (RawAtomicOp::Or, _) => Self::Or, (RawAtomicOp::Xor, _) => Self::Xor, (RawAtomicOp::Exch, _) => Self::Exchange, + (RawAtomicOp::Add, ScalarKind::Float) => Self::FloatAdd, (RawAtomicOp::Add, _) => Self::Add, (RawAtomicOp::Inc, _) => Self::IncrementWrap, (RawAtomicOp::Dec, _) => Self::DecrementWrap, From 7a45b44854e310cbb3a17b0051fb052ca877208c Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Tue, 3 Sep 2024 16:24:50 +0200 Subject: [PATCH 43/47] Fix more failing tests --- ptx/src/pass/convert_to_typed.rs | 9 ++++++ ptx/src/pass/mod.rs | 53 ++++++++++++++++---------------- ptx_parser/src/lib.rs | 32 +++++++++++++++++++ 3 files changed, 67 insertions(+), 27 deletions(-) diff --git a/ptx/src/pass/convert_to_typed.rs b/ptx/src/pass/convert_to_typed.rs index c2af204a..ab5b2463 100644 --- a/ptx/src/pass/convert_to_typed.rs +++ b/ptx/src/pass/convert_to_typed.rs @@ -26,6 +26,15 @@ pub(crate) fn run( src: src_reg, })); } + ast::Instruction::Call { data, arguments } => { + let resolver = 
fn_defs.get_fn_sig_resolver(arguments.func)?; + let resolved_call = resolver.resolve_in_spirv_repr(data, arguments)?; + let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); + let reresolved_call = + Statement::Instruction(ast::visit_map(resolved_call, &mut visitor)?); + visitor.func.push(reresolved_call); + visitor.func.extend(visitor.post_stmts); + } inst => { let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); let instruction = Statement::Instruction(ast::visit_map(inst, &mut visitor)?); diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 92d1bf40..2be6297a 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -668,57 +668,56 @@ impl<'input> FnSigMapper<'input> { } } - /* fn resolve_in_spirv_repr( &self, - call_inst: ast::CallInst, - ) -> Result, TranslateError> { + data: ast::CallDetails, + arguments: ast::CallArgs>, + ) -> Result>, TranslateError> { let func_decl = (*self.func_decl).borrow(); - let mut return_arguments = Vec::new(); - let mut input_arguments = call_inst - .param_list - .into_iter() - .zip(func_decl.input_arguments.iter()) - .map(|(id, var)| (id, var.v_type.clone(), var.state_space)) - .collect::>(); + let mut data_return = Vec::new(); + let mut arguments_return = Vec::new(); + let mut data_input = data.input_arguments; + let mut arguments_input = arguments.input_arguments; let mut func_decl_return_iter = func_decl.return_arguments.iter(); - let mut func_decl_input_iter = func_decl.input_arguments[input_arguments.len()..].iter(); - for (idx, id) in call_inst.ret_params.iter().enumerate() { + let mut func_decl_input_iter = func_decl.input_arguments[arguments_input.len()..].iter(); + for (idx, id) in arguments.return_arguments.iter().enumerate() { let stays_as_return = match self.return_param_args.get(idx) { Some(x) => *x, None => return Err(TranslateError::MismatchedType), }; if stays_as_return { if let Some(var) = func_decl_return_iter.next() { - return_arguments.push((*id, var.v_type.clone(), 
var.state_space)); + data_return.push((var.v_type.clone(), var.state_space)); + arguments_return.push(*id); } else { return Err(TranslateError::MismatchedType); } } else { if let Some(var) = func_decl_input_iter.next() { - input_arguments.push(( - ast::Operand::Reg(*id), - var.v_type.clone(), - var.state_space, - )); + data_input.push((var.v_type.clone(), var.state_space)); + arguments_input.push(ast::ParsedOperand::Reg(*id)); } else { return Err(TranslateError::MismatchedType); } } } - if return_arguments.len() != func_decl.return_arguments.len() - || input_arguments.len() != func_decl.input_arguments.len() + if arguments_return.len() != func_decl.return_arguments.len() + || arguments_input.len() != func_decl.input_arguments.len() { return Err(TranslateError::MismatchedType); } - Ok(ResolvedCall { - return_arguments, - input_arguments, - uniform: call_inst.uniform, - name: call_inst.func, - }) + let data = ast::CallDetails { + uniform: data.uniform, + return_arguments: data_return, + input_arguments: data_input, + }; + let arguments = ast::CallArgs { + func: arguments.func, + return_arguments: arguments_return, + input_arguments: arguments_input, + }; + Ok(ast::Instruction::Call { data, arguments }) } - */ } enum Statement { diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index dfe78eec..3d095113 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -1663,6 +1663,38 @@ derive_parser!( RawLdStQualifier = { .weak, .volatile }; StateSpace = { .global }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld-global-nc + ld.global{.cop}.nc{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{, cache_policy} => { + if cop.is_some() && level_eviction_priority.is_some() { + state.errors.push(PtxError::SyntaxError); + } + if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || cache_policy.is_some() { + 
state.errors.push(PtxError::Todo); + } + Instruction::Ld { + data: LdDetails { + qualifier: ast::LdStQualifier::Weak, + state_space: global, + caching: cop.unwrap_or(RawLdCacheOperator::Ca).into(), + typ: Type::maybe_vector(vec, type_), + non_coherent: true + }, + arguments: LdArgs { dst:d, src:a } + } + } + .cop: RawLdCacheOperator = { .ca, .cg, .cs }; + .level::eviction_priority: EvictionPriority = + { .L1::evict_normal, .L1::evict_unchanged, + .L1::evict_first, .L1::evict_last, .L1::no_allocate}; + .level::cache_hint = { .L2::cache_hint }; + .level::prefetch_size: PrefetchSize = { .L2::64B, .L2::128B, .L2::256B }; + .vec: VectorPrefix = { .v2, .v4 }; + .type: ScalarType = { .b8, .b16, .b32, .b64, .b128, + .u8, .u16, .u32, .u64, + .s8, .s16, .s32, .s64, + .f32, .f64 }; + StateSpace = { .global }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-add add.type d, a, b => { Instruction::Add { From 6a7c871b252ee3b9d9ec49000ac779487d8ee8b8 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Tue, 3 Sep 2024 16:53:06 +0200 Subject: [PATCH 44/47] Fix array initializers --- ptx_parser/src/lib.rs | 134 ++++++++++++++++-------------------------- 1 file changed, 51 insertions(+), 83 deletions(-) diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index 3d095113..357304b2 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -734,13 +734,13 @@ fn array_initializer<'a, 'input: 'a>( return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)); } delimited( - Token::LBracket, + Token::LBrace, separated( array_dimensions[0] as usize..=array_dimensions[0] as usize, single_value_append(&mut result, type_), Token::Comma, ), - Token::RBracket, + Token::RBrace, ) .parse_next(stream)?; Ok(result) @@ -770,86 +770,54 @@ fn single_value_append<'a, 'input: 'a>( move |stream: &mut PtxParser<'a, 'input>| { let value = immediate_value.parse_next(stream)?; match (type_, value) { - (ScalarType::U8, ImmediateValue::U64(x)) => 
accumulator.extend_from_slice( - &u8::try_from(x) - .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? - .to_le_bytes(), - ), - (ScalarType::U8, ImmediateValue::S64(x)) => accumulator.extend_from_slice( - &u8::try_from(x) - .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? - .to_le_bytes(), - ), - (ScalarType::U16, ImmediateValue::U64(x)) => accumulator.extend_from_slice( - &u16::try_from(x) - .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? - .to_le_bytes(), - ), - (ScalarType::U16, ImmediateValue::S64(x)) => accumulator.extend_from_slice( - &u16::try_from(x) - .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? - .to_le_bytes(), - ), - (ScalarType::U32, ImmediateValue::U64(x)) => accumulator.extend_from_slice( - &u32::try_from(x) - .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? - .to_le_bytes(), - ), - (ScalarType::U32, ImmediateValue::S64(x)) => accumulator.extend_from_slice( - &u32::try_from(x) - .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? - .to_le_bytes(), - ), - (ScalarType::U64, ImmediateValue::U64(x)) => accumulator.extend_from_slice( - &u64::try_from(x) - .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? - .to_le_bytes(), - ), - (ScalarType::U64, ImmediateValue::S64(x)) => accumulator.extend_from_slice( - &u64::try_from(x) - .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? - .to_le_bytes(), - ), - (ScalarType::S8, ImmediateValue::U64(x)) => accumulator.extend_from_slice( - &i8::try_from(x) - .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? - .to_le_bytes(), - ), - (ScalarType::S8, ImmediateValue::S64(x)) => accumulator.extend_from_slice( - &i8::try_from(x) - .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? 
- .to_le_bytes(), - ), - (ScalarType::S16, ImmediateValue::U64(x)) => accumulator.extend_from_slice( - &i16::try_from(x) - .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? - .to_le_bytes(), - ), - (ScalarType::S16, ImmediateValue::S64(x)) => accumulator.extend_from_slice( - &i16::try_from(x) - .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? - .to_le_bytes(), - ), - (ScalarType::S32, ImmediateValue::U64(x)) => accumulator.extend_from_slice( - &i32::try_from(x) - .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? - .to_le_bytes(), - ), - (ScalarType::S32, ImmediateValue::S64(x)) => accumulator.extend_from_slice( - &i32::try_from(x) - .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? - .to_le_bytes(), - ), - (ScalarType::S64, ImmediateValue::U64(x)) => accumulator.extend_from_slice( - &i64::try_from(x) - .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? - .to_le_bytes(), - ), - (ScalarType::S64, ImmediateValue::S64(x)) => accumulator.extend_from_slice( - &i64::try_from(x) - .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? 
- .to_le_bytes(), - ), + (ScalarType::U8 | ScalarType::B8, ImmediateValue::U64(x)) => { + accumulator.extend_from_slice(&(x as u8).to_le_bytes()) + } + (ScalarType::U8 | ScalarType::B8, ImmediateValue::S64(x)) => { + accumulator.extend_from_slice(&(x as u8).to_le_bytes()) + } + (ScalarType::U16 | ScalarType::B16, ImmediateValue::U64(x)) => { + accumulator.extend_from_slice(&(x as u16).to_le_bytes()) + } + (ScalarType::U16 | ScalarType::B16, ImmediateValue::S64(x)) => { + accumulator.extend_from_slice(&(x as u16).to_le_bytes()) + } + (ScalarType::U32 | ScalarType::B32, ImmediateValue::U64(x)) => { + accumulator.extend_from_slice(&(x as u32).to_le_bytes()) + } + (ScalarType::U32 | ScalarType::B32, ImmediateValue::S64(x)) => { + accumulator.extend_from_slice(&(x as u32).to_le_bytes()) + } + (ScalarType::U64 | ScalarType::B64, ImmediateValue::U64(x)) => { + accumulator.extend_from_slice(&(x as u64).to_le_bytes()) + } + (ScalarType::U64 | ScalarType::B64, ImmediateValue::S64(x)) => { + accumulator.extend_from_slice(&(x as u64).to_le_bytes()) + } + (ScalarType::S8, ImmediateValue::U64(x)) => { + accumulator.extend_from_slice(&(x as i8).to_le_bytes()) + } + (ScalarType::S8, ImmediateValue::S64(x)) => { + accumulator.extend_from_slice(&(x as i8).to_le_bytes()) + } + (ScalarType::S16, ImmediateValue::U64(x)) => { + accumulator.extend_from_slice(&(x as i16).to_le_bytes()) + } + (ScalarType::S16, ImmediateValue::S64(x)) => { + accumulator.extend_from_slice(&(x as i16).to_le_bytes()) + } + (ScalarType::S32, ImmediateValue::U64(x)) => { + accumulator.extend_from_slice(&(x as i32).to_le_bytes()) + } + (ScalarType::S32, ImmediateValue::S64(x)) => { + accumulator.extend_from_slice(&(x as i32).to_le_bytes()) + } + (ScalarType::S64, ImmediateValue::U64(x)) => { + accumulator.extend_from_slice(&(x as i64).to_le_bytes()) + } + (ScalarType::S64, ImmediateValue::S64(x)) => { + accumulator.extend_from_slice(&(x as i64).to_le_bytes()) + } (ScalarType::F32, ImmediateValue::F32(x)) => { 
accumulator.extend_from_slice(&x.to_le_bytes()) } @@ -1683,7 +1651,7 @@ derive_parser!( } } .cop: RawLdCacheOperator = { .ca, .cg, .cs }; - .level::eviction_priority: EvictionPriority = + .level::eviction_priority: EvictionPriority = { .L1::evict_normal, .L1::evict_unchanged, .L1::evict_first, .L1::evict_last, .L1::no_allocate}; .level::cache_hint = { .L2::cache_hint }; From 3f31069e1bcd68bee2c0761dc2e817b9fc65579d Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Tue, 3 Sep 2024 18:11:09 +0200 Subject: [PATCH 45/47] Allow ftz and saturated conversions --- ptx/src/pass/emit_spirv.rs | 12 ------------ ptx_parser/src/ast.rs | 17 +++++++++++++++-- ptx_parser/src/lib.rs | 26 ++++++++++++++++++-------- 3 files changed, 33 insertions(+), 22 deletions(-) diff --git a/ptx/src/pass/emit_spirv.rs b/ptx/src/pass/emit_spirv.rs index 8aa4576d..5147b79f 100644 --- a/ptx/src/pass/emit_spirv.rs +++ b/ptx/src/pass/emit_spirv.rs @@ -2163,9 +2163,6 @@ fn emit_cvt( builder.sat_convert_s_to_u(result_type.0, Some(arg.dst.0), arg.src.0)?; } ptx_parser::CvtMode::FPExtend { flush_to_zero } => { - if flush_to_zero == Some(true) { - todo!() - } let result_type = map.get_or_add(builder, SpirvType::from(dets.to)); builder.f_convert(result_type.0, Some(arg.dst.0), arg.src.0)?; } @@ -2173,9 +2170,6 @@ fn emit_cvt( rounding, flush_to_zero, } => { - if flush_to_zero == Some(true) { - todo!() - } let result_type = map.get_or_add(builder, SpirvType::from(dets.to)); builder.f_convert(result_type.0, Some(arg.dst.0), arg.src.0)?; emit_rounding_decoration(builder, arg.dst, Some(rounding)); @@ -2234,9 +2228,6 @@ fn emit_cvt( rounding, flush_to_zero, } => { - if flush_to_zero == Some(true) { - todo!() - } let dest_t: ast::ScalarType = dets.to.into(); let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); builder.convert_f_to_s(result_type.0, Some(arg.dst.0), arg.src.0)?; @@ -2246,9 +2237,6 @@ fn emit_cvt( rounding, flush_to_zero, } => { - if flush_to_zero == Some(true) { - todo!() - } let 
dest_t: ast::ScalarType = dets.to.into(); let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); builder.convert_f_to_u(result_type.0, Some(arg.dst.0), arg.src.0)?; diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index f0d7f9f7..f5e65b4b 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -1384,8 +1384,8 @@ impl CvtDetails { dst: ScalarType, src: ScalarType, ) -> Self { - if saturate { - errors.push(PtxError::Todo); + if saturate && dst.kind() == ScalarKind::Float { + errors.push(PtxError::SyntaxError); } // Modifier .ftz can only be specified when either .dtype or .atype is .f32 and applies only to single precision (.f32) inputs and results. let flush_to_zero = match (dst, src) { @@ -1432,6 +1432,18 @@ impl CvtDetails { }, (ScalarKind::Float, ScalarKind::Unsigned) => CvtMode::FPFromUnsigned(unwrap_rounding()), (ScalarKind::Float, ScalarKind::Signed) => CvtMode::FPFromSigned(unwrap_rounding()), + (ScalarKind::Signed, ScalarKind::Unsigned) if saturate => { + CvtMode::SaturateUnsignedToSigned + } + (ScalarKind::Unsigned, ScalarKind::Signed) if saturate => { + CvtMode::SaturateSignedToUnsigned + } + (ScalarKind::Unsigned, ScalarKind::Signed) + | (ScalarKind::Signed, ScalarKind::Unsigned) + if dst.size_of() == src.size_of() => + { + CvtMode::Bitcast + } (ScalarKind::Unsigned, ScalarKind::Unsigned) | (ScalarKind::Signed, ScalarKind::Signed) => match dst.size_of().cmp(&src.size_of()) { Ordering::Less => CvtMode::Truncate, @@ -1444,6 +1456,7 @@ impl CvtDetails { } } }, + (ScalarKind::Unsigned, ScalarKind::Signed) => CvtMode::SaturateSignedToUnsigned, (_, _) => { errors.push(PtxError::SyntaxError); CvtMode::Bitcast diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index 357304b2..b81d8269 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -289,7 +289,12 @@ pub fn parse_module_unchecked<'input>(text: &'input str) -> Option( @@ -314,19 +319,24 @@ pub fn parse_module_checked<'input>( if !errors.is_empty() { return 
Err(errors); } - let parse_error = { + let parse_result = { let state = PtxParserState::new(&mut errors); let parser = PtxParser { state, input: &tokens[..], }; - match module.parse(parser) { - Ok(ast) => return Ok(ast), - Err(err) => PtxError::Parser(err.into_inner()), - } + module + .parse(parser) + .map_err(|err| PtxError::Parser(err.into_inner())) }; - errors.push(parse_error); - Err(errors) + match parse_result { + Ok(result) if errors.is_empty() => Ok(result), + Ok(_) => Err(errors), + Err(err) => { + errors.push(err); + Err(errors) + } + } } fn module<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult> { From aa98ab9e03c37094d745429a8114ee071676f7a7 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Tue, 3 Sep 2024 19:11:25 +0200 Subject: [PATCH 46/47] Fix all remaining problems --- ptx/src/pass/convert_to_typed.rs | 2 +- ptx/src/pass/insert_implicit_conversions.rs | 6 ++-- ptx/src/test/spirv_run/clz.spvtxt | 19 +++++++----- ptx/src/test/spirv_run/cvt_s16_s8.spvtxt | 7 +++-- ptx/src/test/spirv_run/cvt_s64_s32.spvtxt | 8 +++-- ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt | 6 +++- ptx/src/test/spirv_run/popc.spvtxt | 19 +++++++----- ptx_parser/Cargo.toml | 1 + ptx_parser/src/ast.rs | 3 +- ptx_parser/src/lib.rs | 34 ++++++++++++--------- 10 files changed, 64 insertions(+), 41 deletions(-) diff --git a/ptx/src/pass/convert_to_typed.rs b/ptx/src/pass/convert_to_typed.rs index ab5b2463..550c662e 100644 --- a/ptx/src/pass/convert_to_typed.rs +++ b/ptx/src/pass/convert_to_typed.rs @@ -124,7 +124,7 @@ impl<'a, 'b> ast::VisitorMap, TypedOperand, Transl ast::ParsedOperand::Imm(x) => TypedOperand::Imm(x), ast::ParsedOperand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx), ast::ParsedOperand::VecPack(vec) => { - let (type_, space) = type_space.ok_or(error_mismatched_type())?; + let (type_, space) = type_space.ok_or_else(|| error_mismatched_type())?; TypedOperand::Reg(self.convert_vector( is_dst, relaxed_type_check, diff --git 
a/ptx/src/pass/insert_implicit_conversions.rs b/ptx/src/pass/insert_implicit_conversions.rs index 2857551a..25e80f05 100644 --- a/ptx/src/pass/insert_implicit_conversions.rs +++ b/ptx/src/pass/insert_implicit_conversions.rs @@ -238,7 +238,7 @@ fn default_implicit_conversion_type( if should_bitcast(instruction_type, operand_type) { Ok(Some(ConversionKind::Default)) } else { - Err(error_mismatched_type()) + Err(TranslateError::MismatchedType) } } else { Ok(Some(ConversionKind::PtrToPtr)) @@ -295,14 +295,14 @@ pub(crate) fn should_convert_relaxed_dst_wrapper( (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), ) -> Result, TranslateError> { if !space_is_compatible(operand_space, instruction_space) { - return Err(error_mismatched_type()); + return Err(TranslateError::MismatchedType); } if operand_type == instruction_type { return Ok(None); } match should_convert_relaxed_dst(operand_type, instruction_type) { conv @ Some(_) => Ok(conv), - None => Err(error_mismatched_type()), + None => Err(TranslateError::MismatchedType), } } diff --git a/ptx/src/test/spirv_run/clz.spvtxt b/ptx/src/test/spirv_run/clz.spvtxt index 9a7f2542..1feb5a0a 100644 --- a/ptx/src/test/spirv_run/clz.spvtxt +++ b/ptx/src/test/spirv_run/clz.spvtxt @@ -7,20 +7,24 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %21 = OpExtInstImport "OpenCL.std" + OpCapability DenormFlushToZero + OpExtension "SPV_KHR_float_controls" + OpExtension "SPV_KHR_no_integer_wrap_decoration" + %22 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "clz" + OpExecutionMode %1 ContractionOff %void = OpTypeVoid %ulong = OpTypeInt 64 0 - %24 = OpTypeFunction %void %ulong %ulong + %25 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %uint = OpTypeInt 32 0 %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint - %1 = OpFunction %void None %24 + %1 = OpFunction %void None %25 %7 = 
OpFunctionParameter %ulong %8 = OpFunctionParameter %ulong - %19 = OpLabel + %20 = OpLabel %2 = OpVariable %_ptr_Function_ulong Function %3 = OpVariable %_ptr_Function_ulong Function %4 = OpVariable %_ptr_Function_ulong Function @@ -37,11 +41,12 @@ %11 = OpLoad %uint %17 Aligned 4 OpStore %6 %11 %14 = OpLoad %uint %6 - %13 = OpExtInst %uint %21 clz %14 + %18 = OpExtInst %uint %22 clz %14 + %13 = OpCopyObject %uint %18 OpStore %6 %13 %15 = OpLoad %ulong %5 %16 = OpLoad %uint %6 - %18 = OpConvertUToPtr %_ptr_Generic_uint %15 - OpStore %18 %16 Aligned 4 + %19 = OpConvertUToPtr %_ptr_Generic_uint %15 + OpStore %19 %16 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/cvt_s16_s8.spvtxt b/ptx/src/test/spirv_run/cvt_s16_s8.spvtxt index 5f4b050a..92322ecc 100644 --- a/ptx/src/test/spirv_run/cvt_s16_s8.spvtxt +++ b/ptx/src/test/spirv_run/cvt_s16_s8.spvtxt @@ -7,6 +7,9 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 + OpCapability DenormFlushToZero + OpExtension "SPV_KHR_float_controls" + OpExtension "SPV_KHR_no_integer_wrap_decoration" %24 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "cvt_s16_s8" @@ -45,9 +48,7 @@ %32 = OpBitcast %uint %15 %34 = OpUConvert %uchar %32 %20 = OpCopyObject %uchar %34 - %35 = OpBitcast %uchar %20 - %37 = OpSConvert %ushort %35 - %19 = OpCopyObject %ushort %37 + %19 = OpSConvert %ushort %20 %14 = OpSConvert %uint %19 OpStore %6 %14 %16 = OpLoad %ulong %5 diff --git a/ptx/src/test/spirv_run/cvt_s64_s32.spvtxt b/ptx/src/test/spirv_run/cvt_s64_s32.spvtxt index 3f461034..11652905 100644 --- a/ptx/src/test/spirv_run/cvt_s64_s32.spvtxt +++ b/ptx/src/test/spirv_run/cvt_s64_s32.spvtxt @@ -7,9 +7,13 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 + OpCapability DenormFlushToZero + OpExtension "SPV_KHR_float_controls" + OpExtension "SPV_KHR_no_integer_wrap_decoration" %24 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 
"cvt_s64_s32" + OpExecutionMode %1 ContractionOff %void = OpTypeVoid %ulong = OpTypeInt 64 0 %27 = OpTypeFunction %void %ulong %ulong @@ -40,9 +44,7 @@ %12 = OpCopyObject %uint %18 OpStore %6 %12 %15 = OpLoad %uint %6 - %32 = OpBitcast %uint %15 - %33 = OpSConvert %ulong %32 - %14 = OpCopyObject %ulong %33 + %14 = OpSConvert %ulong %15 OpStore %7 %14 %16 = OpLoad %ulong %5 %17 = OpLoad %ulong %7 diff --git a/ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt b/ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt index b6760499..07b228e8 100644 --- a/ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt +++ b/ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt @@ -7,9 +7,13 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 + OpCapability DenormFlushToZero + OpExtension "SPV_KHR_float_controls" + OpExtension "SPV_KHR_no_integer_wrap_decoration" %25 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "cvt_sat_s_u" + OpExecutionMode %1 ContractionOff %void = OpTypeVoid %ulong = OpTypeInt 64 0 %28 = OpTypeFunction %void %ulong %ulong @@ -42,7 +46,7 @@ %15 = OpSatConvertSToU %uint %16 OpStore %7 %15 %18 = OpLoad %uint %7 - %17 = OpBitcast %uint %18 + %17 = OpCopyObject %uint %18 OpStore %8 %17 %19 = OpLoad %ulong %5 %20 = OpLoad %uint %8 diff --git a/ptx/src/test/spirv_run/popc.spvtxt b/ptx/src/test/spirv_run/popc.spvtxt index 845add7a..c41e7926 100644 --- a/ptx/src/test/spirv_run/popc.spvtxt +++ b/ptx/src/test/spirv_run/popc.spvtxt @@ -7,20 +7,24 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %21 = OpExtInstImport "OpenCL.std" + OpCapability DenormFlushToZero + OpExtension "SPV_KHR_float_controls" + OpExtension "SPV_KHR_no_integer_wrap_decoration" + %22 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "popc" + OpExecutionMode %1 ContractionOff %void = OpTypeVoid %ulong = OpTypeInt 64 0 - %24 = OpTypeFunction %void %ulong %ulong + %25 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer 
Function %ulong %uint = OpTypeInt 32 0 %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint - %1 = OpFunction %void None %24 + %1 = OpFunction %void None %25 %7 = OpFunctionParameter %ulong %8 = OpFunctionParameter %ulong - %19 = OpLabel + %20 = OpLabel %2 = OpVariable %_ptr_Function_ulong Function %3 = OpVariable %_ptr_Function_ulong Function %4 = OpVariable %_ptr_Function_ulong Function @@ -37,11 +41,12 @@ %11 = OpLoad %uint %17 Aligned 4 OpStore %6 %11 %14 = OpLoad %uint %6 - %13 = OpBitCount %uint %14 + %18 = OpBitCount %uint %14 + %13 = OpCopyObject %uint %18 OpStore %6 %13 %15 = OpLoad %ulong %5 %16 = OpLoad %uint %6 - %18 = OpConvertUToPtr %_ptr_Generic_uint %15 - OpStore %18 %16 Aligned 4 + %19 = OpConvertUToPtr %_ptr_Generic_uint %15 + OpStore %19 %16 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx_parser/Cargo.toml b/ptx_parser/Cargo.toml index a4df14f1..9032de5c 100644 --- a/ptx_parser/Cargo.toml +++ b/ptx_parser/Cargo.toml @@ -9,6 +9,7 @@ edition = "2021" [dependencies] logos = "0.14" winnow = { version = "0.6.18" } +#winnow = { version = "0.6.18", features = ["debug"] } ptx_parser_macros = { path = "../ptx_parser_macros" } thiserror = "1.0" bitflags = "1.2" diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index f5e65b4b..ad44ab7a 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -1109,10 +1109,11 @@ impl SetpData { ) -> Self { let flush_to_zero = match (ftz, type_) { (_, ScalarType::F32) => Some(ftz), - _ => { + (true, _) => { state.errors.push(PtxError::NonF32Ftz); None } + _ => None }; let type_kind = type_.kind(); let cmp_op = if type_kind == ScalarKind::Float { diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index b81d8269..ed2cf2ae 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -3,6 +3,7 @@ use logos::Logos; use ptx_parser_macros::derive_parser; use rustc_hash::FxHashMap; use std::fmt::Debug; +use std::iter; use std::num::{NonZeroU8, 
ParseFloatError, ParseIntError}; use winnow::ascii::dec_uint; use winnow::combinator::*; @@ -397,14 +398,13 @@ fn directive<'a, 'input>( fn module_variable<'a, 'input>( stream: &mut PtxParser<'a, 'input>, ) -> PResult<(ast::LinkingDirective, ast::Variable<&'input str>)> { - ( - linking_directives, - global_space - .flat_map(multi_variable) - // TODO: support multi var in globals - .map(|multi_var| multi_var.var), - ) - .parse_next(stream) + let linking = linking_directives.parse_next(stream)?; + let var = global_space + .flat_map(|space| multi_variable(linking.contains(LinkingDirective::EXTERN), space)) + // TODO: support multi var in globals + .map(|multi_var| multi_var.var) + .parse_next(stream)?; + Ok((linking, var)) } fn file<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { @@ -621,7 +621,7 @@ fn statement<'a, 'input>( debug_directive.map(|_| None), terminated( method_space - .flat_map(multi_variable) + .flat_map(|space| multi_variable(false, space)) .map(|var| Some(Statement::Variable(var))), Token::Semicolon, ), @@ -678,6 +678,7 @@ fn variable_declaration<'a, 'input>( } fn multi_variable<'a, 'input: 'a>( + extern_: bool, state_space: StateSpace, ) -> impl Parser, MultiVariable<&'input str>, ContextError> { move |stream: &mut PtxParser<'a, 'input>| { @@ -714,7 +715,7 @@ fn multi_variable<'a, 'input: 'a>( _ => None, }; if let Some(ref dims) = array_dimensions { - if dims[0] == 0 { + if !extern_ && dims[0] == 0 { return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)); } } @@ -746,13 +747,16 @@ fn array_initializer<'a, 'input: 'a>( delimited( Token::LBrace, separated( - array_dimensions[0] as usize..=array_dimensions[0] as usize, + 0..=array_dimensions[0] as usize, single_value_append(&mut result, type_), Token::Comma, ), Token::RBrace, ) .parse_next(stream)?; + // pad with zeros + let result_size = type_.size_of() as usize * array_dimensions[0] as usize; + result.extend(iter::repeat(0u8).take(result_size - result.len())); Ok(result) } } 
@@ -1079,11 +1083,11 @@ impl ast::ParsedOperand { fn vector_operand<'a, 'input>( stream: &mut PtxParser<'a, 'input>, ) -> PResult> { - let (_, r1, _, r2) = - (Token::LBracket, ident, Token::Comma, ident).parse_next(stream)?; + let (_, r1, _, r2) = (Token::LBrace, ident, Token::Comma, ident).parse_next(stream)?; + // TODO: parse .v8 literals dispatch! {any; - Token::LBracket => empty.map(|_| vec![r1, r2]), - Token::Comma => (ident, Token::Comma, ident, Token::LBracket).map(|(r3, _, r4, _)| vec![r1, r2, r3, r4]), + Token::RBrace => empty.map(|_| vec![r1, r2]), + Token::Comma => (ident, Token::Comma, ident, Token::RBrace).map(|(r3, _, r4, _)| vec![r1, r2, r3, r4]), _ => fail } .parse_next(stream) From 061312cf8f3762a2ca07b938748a32aa8a1159b1 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Wed, 4 Sep 2024 15:32:12 +0200 Subject: [PATCH 47/47] Document wtf is going on with parsing macros --- ptx_parser/src/ast.rs | 15 ++++++++ ptx_parser/src/lib.rs | 69 +++++++++++++++++++++++++++++++----- ptx_parser_macros/src/lib.rs | 2 +- 3 files changed, 77 insertions(+), 9 deletions(-) diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index ad44ab7a..d0dc303c 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -13,6 +13,21 @@ pub enum Statement { Block(Vec>), } +// We define the instruction enum through the macro instead of normally, because of how +// we use this type in the compiler. Each instruction can be logically split into two parts: +// properties that define instruction semantics (e.g. is memory load volatile?) that don't change +// during compilation and arguments (e.g. memory load source and destination) that evolve during +// compilation. To support compilation passes we need to be able to visit (and change) every +// argument in a generic way. This macro has visibility over all the fields. Consequently, we use it +// to generate visitor functions.
There are three functions to support three different semantics: +// visit-by-ref, visit-by-mutable-ref, visit-and-map. In a previous version of the compiler it was +// done by hand and was very limiting (we supported only visit-and-map). +// The visitor must implement the appropriate visitor trait defined below this macro. For convenience, +// we implemented visitors for some corresponding FnMut(...) types. +// Properties in this macro are used to encode information about the instruction arguments (what +// Rust type is used for it post-parsing, what PTX type does it expect, what PTX address space does +// it expect, etc.). +// This information is then available to a visitor. ptx_parser_macros::generate_instruction_type!( pub enum Instruction { Mov { diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index ed2cf2ae..f842ace6 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -1335,14 +1335,6 @@ fn empty_call<'input>( } } -// Modifiers are turned into arguments to the blocks, with type: -// * If it is an alternative: -// * If it is mandatory then its type is Foo (as defined by the relevant rule) -// * If it is optional then its type is Option -// * Otherwise: -// * If it is mandatory then it is skipped -// * If it is optional then its type is `bool` - type ParsedOperandStr<'input> = ast::ParsedOperand<&'input str>; #[derive(Clone, PartialEq, Default, Debug, Display)] @@ -1351,6 +1343,67 @@ pub struct TokenError(std::ops::Range); impl std::error::Error for TokenError {} +// This macro is responsible for generating parser code for instruction parser. +// Instruction parsing is by far the most complex part of parsing PTX code: +// * There are tens of instruction kinds, each with slightly different parsing rules +// * After parsing, each instruction needs to do some early validation and generate a specific, +// strongly-typed object.
We want strong-typing because we have a single PTX parser frontend, but +// there can be multiple different code emitter backends +// * Most importantly, instruction modifiers can come in any order, so e.g. both +// `ld.relaxed.global.u32 a, b` and `ld.global.relaxed.u32 a, b` are equally valid. This makes +// classic parsing generators fail: if we tried to generate parsing rules that cover every possible +// ordering we'd need thousands of rules. This is not a purely theoretical problem. NVCC and Clang +// will always emit modifiers in the correct order, but people who write inline assembly usually +// get it wrong (even first party developers) +// +// This macro exists purely to generate repetitive code for parsing each instruction. It is +// _not_ self-contained and is _not_ general-purpose: it relies on certain types and functions from +// the enclosing module +// +// derive_parser!(...) input is split into three parts: +// * Token type definition +// * Partial enums +// * Parsing definitions +// +// Token type definition: +// This is the enum type that will be used by the instruction parser. For every instruction and +// modifier, derive_parser!(...) will add the appropriate variant to this type. So e.g. if there is a +// rule for `bar.sync` then those two variants will be appended to the Token enum: +// #[token("bar")] Bar, +// #[token(".sync")] DotSync, +// +// Partial enums: +// With proper annotations, derive_parser!(...) parsing definitions are able to interpret +// instruction modifiers as variants of a single enum type. So e.g. for definitions `ld.u32` and +// `ld.u64` the macro can generate `enum ScalarType { U32, U64 }`. The problem is that for some +// (but not all) of those generated enum types we want to add some attributes and additional +// variants. In order to do so, you need to define this enum and derive_parser!(...) will append to +// the type instead of creating a new type.
This is a sort of replacement for partial classes known +// from C# +// +// Parsing definitions: +// Parsing definitions consist of a list of patterns and rules: +// * Pattern consists of: +// * Opcode: `ld` +// * Modifiers, always start with a dot: `.global`, `.relaxed`. Optionals are enclosed in braces +// * Arguments: `a`, `b`. Optionals are enclosed in braces +// * Code block: => { }. Code blocks implicitly take all modifiers and arguments +// as parameters. All modifiers and arguments are passed to the code block: +// * If it is an alternative (as defined in rules list later): +// * If it is mandatory then its type is Foo (as defined by the relevant rule) +// * If it is optional then its type is Option<Foo> +// * Otherwise: +// * If it is mandatory then it is skipped +// * If it is optional then its type is `bool` +// * List of rules. They are associated with the preceding patterns (until different opcode or +// different rules). Rules are used to resolve modifiers. There are two types of rules: +// * Normal rule: `.foobar: FoobarEnum => { .a, .b, .c }`. This means that instead of `.foobar` we +// expect one of `.a`, `.b`, `.c` and will emit value FoobarEnum::DotA, FoobarEnum::DotB, +// FoobarEnum::DotC appropriately +// * Type-only rule: `FoobarEnum => { .a, .b, .c }` this means that all the occurrences of `.a` will +// emit FoobarEnum::DotA to the code block. This helps to avoid copy-paste errors +// Additionally, you can opt out from the usual parsing rule generation with a special `<=` pattern.
+// See `call` instruction to see it in action derive_parser!( #[derive(Logos, PartialEq, Eq, Debug, Clone, Copy)] #[logos(skip r"(?:\s+)|(?://[^\n\r]*[\n\r]*)|(?:/\*[^*]*\*+(?:[^/*][^*]*\*+)*/)")] diff --git a/ptx_parser_macros/src/lib.rs b/ptx_parser_macros/src/lib.rs index a2f8396f..5f47fac7 100644 --- a/ptx_parser_macros/src/lib.rs +++ b/ptx_parser_macros/src/lib.rs @@ -16,7 +16,7 @@ use syn::{ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#packed-integer-data-types #[rustfmt::skip] static POSTFIX_MODIFIERS: &[&str] = &[ - ".v2", ".v4", + ".v2", ".v4", ".v8", ".s8", ".s16", ".s16x2", ".s32", ".s64", ".u8", ".u16", ".u16x2", ".u32", ".u64", ".f16", ".f16x2", ".f32", ".f64",