Skip to content
This repository has been archived by the owner on Jan 29, 2025. It is now read-only.

Commit

Permalink
fix zero initialization of workgroup memory (#2259)
Browse files Browse the repository at this point in the history
Use the local (not global) invocation id to decide which invocation should do the initialization, so that every workgroup gets initialized, not just the first.
  • Loading branch information
teoxoy authored Feb 21, 2023
1 parent cde457c commit 9742f16
Show file tree
Hide file tree
Showing 19 changed files with 44 additions and 50 deletions.
5 changes: 1 addition & 4 deletions src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1583,10 +1583,7 @@ impl<'a, W: Write> Writer<'a, W> {
if vars.peek().is_some() {
let level = back::Level(1);

writeln!(
self.out,
"{level}if (gl_GlobalInvocationID == uvec3(0u)) {{"
)?;
writeln!(self.out, "{level}if (gl_LocalInvocationID == uvec3(0u)) {{")?;

for (handle, var) in vars {
let name = &self.names[&NameKey::GlobalVariable(handle)];
Expand Down
7 changes: 2 additions & 5 deletions src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1176,10 +1176,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
if !func.arguments.is_empty() {
write!(self.out, ", ")?;
}
write!(
self.out,
"uint3 __global_invocation_id : SV_DispatchThreadID"
)?;
write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?;
}
}
}
Expand Down Expand Up @@ -1281,7 +1278,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {

writeln!(
self.out,
"{level}if (all(__global_invocation_id == uint3(0u, 0u, 0u))) {{"
"{level}if (all(__local_invocation_id == uint3(0u, 0u, 0u))) {{"
)?;

let vars = module.global_variables.iter().filter(|&(handle, var)| {
Expand Down
18 changes: 9 additions & 9 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3633,7 +3633,7 @@ impl<W: Write> Writer<W> {
is_first_argument = false;
}

let mut global_invocation_id = None;
let mut local_invocation_id = None;

// Then pass the remaining arguments not included in the varyings
// struct.
Expand All @@ -3660,8 +3660,8 @@ impl<W: Write> Writer<W> {
&self.names[name_key]
};

if binding == &crate::Binding::BuiltIn(crate::BuiltIn::GlobalInvocationId) {
global_invocation_id = Some(name_key);
if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId) {
local_invocation_id = Some(name_key);
}

let ty_name = TypeContext {
Expand All @@ -3687,7 +3687,7 @@ impl<W: Write> Writer<W> {
let need_workgroup_variables_initialization =
self.need_workgroup_variables_initialization(options, ep, module, fun_info);

if need_workgroup_variables_initialization && global_invocation_id.is_none() {
if need_workgroup_variables_initialization && local_invocation_id.is_none() {
let separator = if is_first_argument {
is_first_argument = false;
' '
Expand All @@ -3696,7 +3696,7 @@ impl<W: Write> Writer<W> {
};
writeln!(
self.out,
"{separator} {NAMESPACE}::uint3 __global_invocation_id [[thread_position_in_grid]]"
"{separator} {NAMESPACE}::uint3 __local_invocation_id [[thread_position_in_threadgroup]]"
)?;
}

Expand Down Expand Up @@ -3786,7 +3786,7 @@ impl<W: Write> Writer<W> {
module,
mod_info,
fun_info,
global_invocation_id,
local_invocation_id,
)?;
}

Expand Down Expand Up @@ -4075,7 +4075,7 @@ mod workgroup_mem_init {
module: &crate::Module,
module_info: &valid::ModuleInfo,
fun_info: &valid::FunctionInfo,
global_invocation_id: Option<&NameKey>,
local_invocation_id: Option<&NameKey>,
) -> BackendResult {
let level = back::Level(1);

Expand All @@ -4084,9 +4084,9 @@ mod workgroup_mem_init {
"{}if ({}::all({} == {}::uint3(0u))) {{",
level,
NAMESPACE,
global_invocation_id
local_invocation_id
.map(|name_key| self.names[name_key].as_str())
.unwrap_or("__global_invocation_id"),
.unwrap_or("__local_invocation_id"),
NAMESPACE,
)?;

Expand Down
20 changes: 10 additions & 10 deletions src/back/spv/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ impl Writer {
results: Vec::new(),
};

let mut global_invocation_id = None;
let mut local_invocation_id = None;

let mut parameter_type_ids = Vec::with_capacity(ir_function.arguments.len());
for argument in ir_function.arguments.iter() {
Expand Down Expand Up @@ -400,8 +400,8 @@ impl Writer {
.body
.push(Instruction::load(argument_type_id, id, varying_id, None));

if binding == &crate::Binding::BuiltIn(crate::BuiltIn::GlobalInvocationId) {
global_invocation_id = Some(id);
if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId) {
local_invocation_id = Some(id);
}

id
Expand Down Expand Up @@ -430,7 +430,7 @@ impl Writer {
constituent_ids.push(id);

if binding == &crate::Binding::BuiltIn(crate::BuiltIn::GlobalInvocationId) {
global_invocation_id = Some(id);
local_invocation_id = Some(id);
}
}
prelude.body.push(Instruction::composite_construct(
Expand Down Expand Up @@ -673,7 +673,7 @@ impl Writer {
next_id,
ir_module,
info,
global_invocation_id,
local_invocation_id,
interface,
context.function,
),
Expand Down Expand Up @@ -1253,7 +1253,7 @@ impl Writer {
entry_id: Word,
ir_module: &crate::Module,
info: &FunctionInfo,
global_invocation_id: Option<Word>,
local_invocation_id: Option<Word>,
interface: &mut FunctionInterface,
function: &mut Function,
) -> Option<Word> {
Expand Down Expand Up @@ -1282,8 +1282,8 @@ impl Writer {

let mut pre_if_block = Block::new(entry_id);

let global_invocation_id = if let Some(global_invocation_id) = global_invocation_id {
global_invocation_id
let local_invocation_id = if let Some(local_invocation_id) = local_invocation_id {
local_invocation_id
} else {
let varying_id = self.id_gen.next();
let class = spirv::StorageClass::Input;
Expand All @@ -1295,7 +1295,7 @@ impl Writer {
self.decorate(
varying_id,
spirv::Decoration::BuiltIn,
&[spirv::BuiltIn::GlobalInvocationId as u32],
&[spirv::BuiltIn::LocalInvocationId as u32],
);

interface.varying_ids.push(varying_id);
Expand All @@ -1315,7 +1315,7 @@ impl Writer {
spirv::Op::IEqual,
bool3_type_id,
eq_id,
global_invocation_id,
local_invocation_id,
zero_id,
));

Expand Down
2 changes: 1 addition & 1 deletion tests/out/glsl/access.assign_through_ptr.Compute.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ void assign_array_through_ptr_fn(inout vec4 foo_2[2]) {
}

void main() {
if (gl_GlobalInvocationID == uvec3(0u)) {
if (gl_LocalInvocationID == uvec3(0u)) {
val = 0u;
}
memoryBarrierShared();
Expand Down
2 changes: 1 addition & 1 deletion tests/out/glsl/globals.main.Compute.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ void test_msl_packed_vec3_() {
}

void main() {
if (gl_GlobalInvocationID == uvec3(0u)) {
if (gl_LocalInvocationID == uvec3(0u)) {
wg = float[10](0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0);
at_1 = 0u;
}
Expand Down
2 changes: 1 addition & 1 deletion tests/out/glsl/workgroup-var-init.main.Compute.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ layout(std430) buffer type_1_block_0Compute { uint _group_0_binding_0_cs[512]; }


void main() {
if (gl_GlobalInvocationID == uvec3(0u)) {
if (gl_LocalInvocationID == uvec3(0u)) {
w_mem = WStruct(uint[512](0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u), 0, int[8][8](int[8](0, 0, 0, 0, 0, 0, 0, 0), int[8](0, 0, 0, 0, 0, 0, 0, 0), int[8](0, 0, 0, 0, 0, 0, 0, 0), int[8](0, 0, 0, 0, 0, 0, 0, 0), int[8](0, 0, 0, 0, 0, 0, 0, 0), int[8](0, 0, 0, 0, 0, 0, 0, 0), int[8](0, 0, 0, 0, 0, 0, 0, 0), int[8](0, 0, 0, 0, 0, 0, 0, 0)));
}
memoryBarrierShared();
Expand Down
4 changes: 2 additions & 2 deletions tests/out/hlsl/access.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -338,9 +338,9 @@ void atomics()
}

[numthreads(1, 1, 1)]
void assign_through_ptr(uint3 __global_invocation_id : SV_DispatchThreadID)
void assign_through_ptr(uint3 __local_invocation_id : SV_GroupThreadID)
{
if (all(__global_invocation_id == uint3(0u, 0u, 0u))) {
if (all(__local_invocation_id == uint3(0u, 0u, 0u))) {
val = (uint)0;
}
GroupMemoryBarrierWithGroupSync();
Expand Down
4 changes: 2 additions & 2 deletions tests/out/hlsl/globals.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ uint NagaBufferLength(ByteAddressBuffer buffer)
}

[numthreads(1, 1, 1)]
void main(uint3 __global_invocation_id : SV_DispatchThreadID)
void main(uint3 __local_invocation_id : SV_GroupThreadID)
{
if (all(__global_invocation_id == uint3(0u, 0u, 0u))) {
if (all(__local_invocation_id == uint3(0u, 0u, 0u))) {
wg = (float[10])0;
at_1 = (uint)0;
}
Expand Down
4 changes: 2 additions & 2 deletions tests/out/hlsl/interface.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ FragmentOutput fragment(FragmentInput_fragment fragmentinput_fragment)
}

[numthreads(1, 1, 1)]
void compute(uint3 global_id : SV_DispatchThreadID, uint3 local_id : SV_GroupThreadID, uint local_index : SV_GroupIndex, uint3 wg_id : SV_GroupID, uint3 num_wgs : SV_GroupID, uint3 __global_invocation_id : SV_DispatchThreadID)
void compute(uint3 global_id : SV_DispatchThreadID, uint3 local_id : SV_GroupThreadID, uint local_index : SV_GroupIndex, uint3 wg_id : SV_GroupID, uint3 num_wgs : SV_GroupID, uint3 __local_invocation_id : SV_GroupThreadID)
{
if (all(__global_invocation_id == uint3(0u, 0u, 0u))) {
if (all(__local_invocation_id == uint3(0u, 0u, 0u))) {
output = (uint[1])0;
}
GroupMemoryBarrierWithGroupSync();
Expand Down
4 changes: 2 additions & 2 deletions tests/out/hlsl/workgroup-var-init.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ groupshared WStruct w_mem;
RWByteAddressBuffer output : register(u0);

[numthreads(1, 1, 1)]
void main(uint3 __global_invocation_id : SV_DispatchThreadID)
void main(uint3 __local_invocation_id : SV_GroupThreadID)
{
if (all(__global_invocation_id == uint3(0u, 0u, 0u))) {
if (all(__local_invocation_id == uint3(0u, 0u, 0u))) {
w_mem = (WStruct)0;
}
GroupMemoryBarrierWithGroupSync();
Expand Down
4 changes: 2 additions & 2 deletions tests/out/msl/access.msl
Original file line number Diff line number Diff line change
Expand Up @@ -250,10 +250,10 @@ kernel void atomics(


kernel void assign_through_ptr(
metal::uint3 __global_invocation_id [[thread_position_in_grid]]
metal::uint3 __local_invocation_id [[thread_position_in_threadgroup]]
, threadgroup uint& val
) {
if (metal::all(__global_invocation_id == metal::uint3(0u))) {
if (metal::all(__local_invocation_id == metal::uint3(0u))) {
val = {};
}
metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
Expand Down
4 changes: 2 additions & 2 deletions tests/out/msl/globals.msl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ void test_msl_packed_vec3_(
}

kernel void main_(
metal::uint3 __global_invocation_id [[thread_position_in_grid]]
metal::uint3 __local_invocation_id [[thread_position_in_threadgroup]]
, threadgroup type_2& wg
, threadgroup metal::atomic_uint& at_1
, device FooStruct& alignment [[user(fake0)]]
Expand All @@ -74,7 +74,7 @@ kernel void main_(
, constant type_15& global_nested_arrays_of_matrices_4x2_ [[user(fake0)]]
, constant _mslBufferSizes& _buffer_sizes [[user(fake0)]]
) {
if (metal::all(__global_invocation_id == metal::uint3(0u))) {
if (metal::all(__local_invocation_id == metal::uint3(0u))) {
wg = {};
metal::atomic_store_explicit(&at_1, 0, metal::memory_order_relaxed);
}
Expand Down
2 changes: 1 addition & 1 deletion tests/out/msl/interface.msl
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ kernel void compute_(
, metal::uint3 num_wgs [[threadgroups_per_grid]]
, threadgroup type_4& output
) {
if (metal::all(global_id == metal::uint3(0u))) {
if (metal::all(local_id == metal::uint3(0u))) {
output = {};
}
metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
Expand Down
4 changes: 2 additions & 2 deletions tests/out/msl/workgroup-var-init.msl
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ struct WStruct {
};

kernel void main_(
metal::uint3 __global_invocation_id [[thread_position_in_grid]]
metal::uint3 __local_invocation_id [[thread_position_in_threadgroup]]
, threadgroup WStruct& w_mem
, device type_1& output [[buffer(0)]]
) {
if (metal::all(__global_invocation_id == metal::uint3(0u))) {
if (metal::all(__local_invocation_id == metal::uint3(0u))) {
w_mem.arr = {};
metal::atomic_store_explicit(&w_mem.atom, 0, metal::memory_order_relaxed);
for (int __i0 = 0; __i0 < 8; __i0++) {
Expand Down
2 changes: 1 addition & 1 deletion tests/out/spv/access.spvasm
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ OpMemberDecorate %83 0 Offset 0
OpDecorate %242 BuiltIn VertexIndex
OpDecorate %245 BuiltIn Position
OpDecorate %288 Location 0
OpDecorate %337 BuiltIn GlobalInvocationId
OpDecorate %337 BuiltIn LocalInvocationId
%2 = OpTypeVoid
%4 = OpTypeInt 32 0
%3 = OpConstant %4 0
Expand Down
2 changes: 1 addition & 1 deletion tests/out/spv/globals.spvasm
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ OpDecorate %64 DescriptorSet 0
OpDecorate %64 Binding 7
OpDecorate %65 Block
OpMemberDecorate %65 0 Offset 0
OpDecorate %129 BuiltIn GlobalInvocationId
OpDecorate %129 BuiltIn LocalInvocationId
%2 = OpTypeVoid
%4 = OpTypeBool
%3 = OpConstantTrue %4
Expand Down
2 changes: 1 addition & 1 deletion tests/out/spv/interface.compute.spvasm
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ OpDecorate %33 BuiltIn NumWorkgroups
%34 = OpLoad %17 %33
OpBranch %37
%37 = OpLabel
%41 = OpIEqual %40 %25 %39
%41 = OpIEqual %40 %27 %39
%42 = OpAll %15 %41
OpSelectionMerge %43 None
OpBranchConditional %42 %44 %43
Expand Down
2 changes: 1 addition & 1 deletion tests/out/spv/workgroup-var-init.spvasm
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ OpDecorate %13 DescriptorSet 0
OpDecorate %13 Binding 0
OpDecorate %14 Block
OpMemberDecorate %14 0 Offset 0
OpDecorate %25 BuiltIn GlobalInvocationId
OpDecorate %25 BuiltIn LocalInvocationId
%2 = OpTypeVoid
%4 = OpTypeInt 32 1
%3 = OpConstant %4 512
Expand Down

0 comments on commit 9742f16

Please sign in to comment.