Skip to content
This repository has been archived by the owner on Jan 29, 2025. It is now read-only.

[wgsl-in] Handle modf and frexp #2454

Merged
merged 3 commits into from
Sep 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/back/glsl/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -477,4 +477,7 @@ pub const RESERVED_KEYWORDS: &[&str] = &[
// entry point name (should not be shadowed)
//
"main",
// Naga utilities:
super::MODF_FUNCTION,
super::FREXP_FUNCTION,
];
54 changes: 52 additions & 2 deletions src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ pub const SUPPORTED_ES_VERSIONS: &[u16] = &[300, 310, 320];
/// of detail for bounds checking in `ImageLoad`
const CLAMPED_LOD_SUFFIX: &str = "_clamped_lod";

pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";

/// Mapping between resources and bindings.
pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, u8>;

Expand Down Expand Up @@ -631,6 +634,53 @@ impl<'a, W: Write> Writer<'a, W> {
}
}

// Write functions to create special types.
for (type_key, struct_ty) in self.module.special_types.predeclared_types.iter() {
match type_key {
&crate::PredeclaredType::ModfResult { size, width }
| &crate::PredeclaredType::FrexpResult { size, width } => {
let arg_type_name_owner;
let arg_type_name = if let Some(size) = size {
arg_type_name_owner =
format!("{}vec{}", if width == 8 { "d" } else { "" }, size as u8);
&arg_type_name_owner
} else if width == 8 {
"double"
} else {
"float"
};

let other_type_name_owner;
let (defined_func_name, called_func_name, other_type_name) =
if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
(MODF_FUNCTION, "modf", arg_type_name)
} else {
let other_type_name = if let Some(size) = size {
other_type_name_owner = format!("ivec{}", size as u8);
&other_type_name_owner
} else {
"int"
};
(FREXP_FUNCTION, "frexp", other_type_name)
};

let struct_name = &self.names[&NameKey::Type(*struct_ty)];

writeln!(self.out)?;
writeln!(
self.out,
"{} {defined_func_name}({arg_type_name} arg) {{
{other_type_name} other;
{arg_type_name} fract = {called_func_name}(arg, other);
return {}(fract, other);
}}",
struct_name, struct_name
)?;
}
&crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {}
}
}

// Write all named constants
let mut constants = self
.module
Expand Down Expand Up @@ -2997,8 +3047,8 @@ impl<'a, W: Write> Writer<'a, W> {
Mf::Round => "roundEven",
Mf::Fract => "fract",
Mf::Trunc => "trunc",
Mf::Modf => "modf",
Mf::Frexp => "frexp",
Mf::Modf => MODF_FUNCTION,
Mf::Frexp => FREXP_FUNCTION,
Mf::Ldexp => "ldexp",
// exponent
Mf::Exp => "exp",
Expand Down
53 changes: 53 additions & 0 deletions src/back/hlsl/help.rs
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,59 @@ impl<'a, W: Write> super::Writer<'a, W> {
Ok(())
}

/// Write functions to create special types.
pub(super) fn write_special_functions(&mut self, module: &crate::Module) -> BackendResult {
for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
match type_key {
&crate::PredeclaredType::ModfResult { size, width }
| &crate::PredeclaredType::FrexpResult { size, width } => {
let arg_type_name_owner;
let arg_type_name = if let Some(size) = size {
arg_type_name_owner = format!(
"{}{}",
if width == 8 { "double" } else { "float" },
size as u8
);
&arg_type_name_owner
} else if width == 8 {
"double"
} else {
"float"
};

let (defined_func_name, called_func_name, second_field_name, sign_multiplier) =
if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
(super::writer::MODF_FUNCTION, "modf", "whole", "")
} else {
(
super::writer::FREXP_FUNCTION,
"frexp",
"exp_",
"sign(arg) * ",
teoxoy marked this conversation as resolved.
Show resolved Hide resolved
)
};

let struct_name = &self.names[&NameKey::Type(*struct_ty)];

writeln!(
self.out,
"{struct_name} {defined_func_name}({arg_type_name} arg) {{
{arg_type_name} other;
{struct_name} result;
result.fract = {sign_multiplier}{called_func_name}(arg, other);
result.{second_field_name} = other;
return result;
}}"
)?;
writeln!(self.out)?;
}
&crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {}
}
}

Ok(())
}

/// Helper function that writes compose wrapped functions
pub(super) fn write_wrapped_compose_functions(
&mut self,
Expand Down
3 changes: 3 additions & 0 deletions src/back/hlsl/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,9 @@ pub const RESERVED: &[&str] = &[
"TextureBuffer",
"ConstantBuffer",
"RayQuery",
// Naga utilities
super::writer::MODF_FUNCTION,
super::writer::FREXP_FUNCTION,
];

// DXC scalar types, from https://github.com/microsoft/DirectXShaderCompiler/blob/18c9e114f9c314f93e68fbc72ce207d4ed2e65ae/tools/clang/lib/AST/ASTContextHLSL.cpp#L48-L254
Expand Down
9 changes: 7 additions & 2 deletions src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ const SPECIAL_BASE_VERTEX: &str = "base_vertex";
const SPECIAL_BASE_INSTANCE: &str = "base_instance";
const SPECIAL_OTHER: &str = "other";

pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";

struct EpStructMember {
name: String,
ty: Handle<crate::Type>,
Expand Down Expand Up @@ -244,6 +247,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
}

self.write_special_functions(module)?;

self.write_wrapped_compose_functions(module, &module.const_expressions)?;

// Write all named constants
Expand Down Expand Up @@ -2675,8 +2680,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Mf::Round => Function::Regular("round"),
Mf::Fract => Function::Regular("frac"),
Mf::Trunc => Function::Regular("trunc"),
Mf::Modf => Function::Regular("modf"),
Mf::Frexp => Function::Regular("frexp"),
Mf::Modf => Function::Regular(MODF_FUNCTION),
Mf::Frexp => Function::Regular(FREXP_FUNCTION),
Mf::Ldexp => Function::Regular("ldexp"),
// exponent
Mf::Exp => Function::Regular("exp"),
Expand Down
2 changes: 2 additions & 0 deletions src/back/msl/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,4 +214,6 @@ pub const RESERVED: &[&str] = &[
// Naga utilities
"DefaultConstructible",
"clamped_lod_e",
super::writer::FREXP_FUNCTION,
super::writer::MODF_FUNCTION,
];
61 changes: 59 additions & 2 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection";
const RAY_QUERY_FIELD_READY: &str = "ready";
const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";

pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";

/// Write the Metal name for a Naga numeric type: scalar, vector, or matrix.
///
/// The `sizes` slice determines whether this function writes a
Expand Down Expand Up @@ -1678,8 +1681,8 @@ impl<W: Write> Writer<W> {
Mf::Round => "rint",
Mf::Fract => "fract",
Mf::Trunc => "trunc",
Mf::Modf => "modf",
Mf::Frexp => "frexp",
Mf::Modf => MODF_FUNCTION,
Mf::Frexp => FREXP_FUNCTION,
Mf::Ldexp => "ldexp",
// exponent
Mf::Exp => "exp",
Expand Down Expand Up @@ -1813,6 +1816,9 @@ impl<W: Write> Writer<W> {
write!(self.out, "((")?;
self.put_expression(arg, context, false)?;
write!(self.out, ") * 57.295779513082322865)")?;
} else if fun == Mf::Modf || fun == Mf::Frexp {
write!(self.out, "{fun_name}")?;
self.put_call_parameters(iter::once(arg), context)?;
} else {
write!(self.out, "{NAMESPACE}::{fun_name}")?;
self.put_call_parameters(
Expand Down Expand Up @@ -3236,6 +3242,57 @@ impl<W: Write> Writer<W> {
}
}
}

// Write functions to create special types.
for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
match type_key {
&crate::PredeclaredType::ModfResult { size, width }
| &crate::PredeclaredType::FrexpResult { size, width } => {
let arg_type_name_owner;
let arg_type_name = if let Some(size) = size {
arg_type_name_owner = format!(
"{NAMESPACE}::{}{}",
if width == 8 { "double" } else { "float" },
size as u8
);
&arg_type_name_owner
} else if width == 8 {
"double"
} else {
"float"
};

let other_type_name_owner;
let (defined_func_name, called_func_name, other_type_name) =
if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
(MODF_FUNCTION, "modf", arg_type_name)
} else {
let other_type_name = if let Some(size) = size {
other_type_name_owner = format!("int{}", size as u8);
&other_type_name_owner
} else {
"int"
};
(FREXP_FUNCTION, "frexp", other_type_name)
};

let struct_name = &self.names[&NameKey::Type(*struct_ty)];

writeln!(self.out)?;
writeln!(
self.out,
"{} {defined_func_name}({arg_type_name} arg) {{
{other_type_name} other;
{arg_type_name} fract = {NAMESPACE}::{called_func_name}(arg, other);
return {}{{ fract, other }};
}}",
struct_name, struct_name
)?;
}
&crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {}
}
}

Ok(())
}

Expand Down
4 changes: 2 additions & 2 deletions src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -787,8 +787,8 @@ impl<'w> BlockContext<'w> {
Mf::Floor => MathOp::Ext(spirv::GLOp::Floor),
Mf::Fract => MathOp::Ext(spirv::GLOp::Fract),
Mf::Trunc => MathOp::Ext(spirv::GLOp::Trunc),
Mf::Modf => MathOp::Ext(spirv::GLOp::Modf),
Mf::Frexp => MathOp::Ext(spirv::GLOp::Frexp),
Mf::Modf => MathOp::Ext(spirv::GLOp::ModfStruct),
Mf::Frexp => MathOp::Ext(spirv::GLOp::FrexpStruct),
Mf::Ldexp => MathOp::Ext(spirv::GLOp::Ldexp),
// geometry
Mf::Dot => match *self.fun_info[arg].ty.inner_with(&self.ir_module.types) {
Expand Down
22 changes: 15 additions & 7 deletions src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,14 @@ impl<W: Write> Writer<W> {
self.ep_results.clear();
}

fn is_builtin_wgsl_struct(&self, module: &Module, handle: Handle<crate::Type>) -> bool {
module
.special_types
.predeclared_types
.values()
.any(|t| *t == handle)
}

pub fn write(&mut self, module: &Module, info: &valid::ModuleInfo) -> BackendResult {
self.reset(module);

Expand All @@ -109,13 +117,13 @@ impl<W: Write> Writer<W> {

// Write all structs
for (handle, ty) in module.types.iter() {
if let TypeInner::Struct {
ref members,
span: _,
} = ty.inner
{
self.write_struct(module, handle, members)?;
writeln!(self.out)?;
if let TypeInner::Struct { ref members, .. } = ty.inner {
{
if !self.is_builtin_wgsl_struct(module, handle) {
self.write_struct(module, handle, members)?;
writeln!(self.out)?;
}
}
}
}

Expand Down
Loading