Skip to content
This repository has been archived by the owner on Jan 29, 2025. It is now read-only.

Commit

Permalink
[hlsl-out] Implement switch statement (#1265)
Browse files Browse the repository at this point in the history
* [hlsl-out] Implement switch statement

* [hlsl-out] Implement switch statement

* Add switch tests to control-flow snapshot
  • Loading branch information
Gordon-F authored and kvark committed Aug 24, 2021
1 parent 7d88637 commit 02c74b5
Show file tree
Hide file tree
Showing 8 changed files with 198 additions and 12 deletions.
5 changes: 0 additions & 5 deletions src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1441,11 +1441,6 @@ impl<'a, W: Write> Writer<'a, W> {
for sta in case.body.iter() {
self.write_stmt(sta, ctx, indent + 2)?;
}

// Write `break;` if the block isn't fallthrough
if !case.fall_through {
writeln!(self.out, "{}break;", INDENT.repeat(indent + 2))?;
}
}

// Only write the default block if the block isn't empty
Expand Down
57 changes: 55 additions & 2 deletions src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1204,8 +1204,61 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.temp_access_chain = chain;
self.named_expressions.insert(result, res_name);
}
Statement::Switch { .. } => {
return Err(Error::Unimplemented(format!("write_stmt {:?}", stmt)))
Statement::Switch {
selector,
ref cases,
ref default,
} => {
// Start the switch
write!(self.out, "{}", INDENT.repeat(indent))?;
write!(self.out, "switch(")?;
self.write_expr(module, selector, func_ctx)?;
writeln!(self.out, ") {{")?;

// Write all cases
let indent_str_1 = INDENT.repeat(indent + 1);
let indent_str_2 = INDENT.repeat(indent + 2);

for case in cases {
writeln!(self.out, "{}case {}: {{", &indent_str_1, case.value)?;

if case.fall_through {
// Generate each fallthrough case statement in a new block. This is done to
// prevent symbol collision of variables declared in these cases statements.
writeln!(self.out, "{}/* fallthrough */", &indent_str_2)?;
writeln!(self.out, "{}{{", &indent_str_2)?;
}
for sta in case.body.iter() {
self.write_stmt(
module,
sta,
func_ctx,
indent + 2 + usize::from(case.fall_through),
)?;
}

if case.fall_through {
writeln!(self.out, "{}}}", &indent_str_2)?;
} else {
writeln!(self.out, "{}break;", &indent_str_2)?;
}

writeln!(self.out, "{}}}", &indent_str_1)?;
}

// Only write the default block if the block isn't empty
// Writing default without a block is valid but it's more readable this way
if !default.is_empty() {
writeln!(self.out, "{}default: {{", &indent_str_1)?;

for sta in default {
self.write_stmt(module, sta, func_ctx, indent + 2)?;
}

writeln!(self.out, "{}}}", &indent_str_1)?;
}

writeln!(self.out, "{}}}", INDENT.repeat(indent))?
}
}

Expand Down
26 changes: 26 additions & 0 deletions tests/in/control-flow.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,30 @@ fn main([[builtin(global_invocation_id)]] global_id: vec3<u32>) {
//TODO: execution-only barrier?
storageBarrier();
workgroupBarrier();

var pos: i32;
// switch without cases
switch (1) {
default: {
pos = 1;
}
}

switch (pos) {
case 1: {
pos = 0;
break;
}
case 2: {
pos = 1;
}
case 3: {
pos = 2;
fallthrough;
}
case 4: {}
default: {
pos = 3;
}
}
}
22 changes: 21 additions & 1 deletion tests/out/glsl/control-flow.main.Compute.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,28 @@ layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;

void main() {
uvec3 global_id = gl_GlobalInvocationID;
int pos = 0;
groupMemoryBarrier();
groupMemoryBarrier();
return;
switch(1) {
default:
pos = 1;
}
int _e4 = pos;
switch(_e4) {
case 1:
pos = 0;
break;
case 2:
pos = 1;
return;
case 3:
pos = 2;
case 4:
return;
default:
pos = 3;
return;
}
}

35 changes: 34 additions & 1 deletion tests/out/hlsl/control-flow.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,40 @@ struct ComputeInput_main {
[numthreads(1, 1, 1)]
void main(ComputeInput_main computeinput_main)
{
int pos = (int)0;

DeviceMemoryBarrierWithGroupSync();
GroupMemoryBarrierWithGroupSync();
return;
switch(1) {
default: {
pos = 1;
}
}
int _expr4 = pos;
switch(_expr4) {
case 1: {
pos = 0;
break;
break;
}
case 2: {
pos = 1;
return;
break;
}
case 3: {
/* fallthrough */
{
pos = 2;
}
}
case 4: {
return;
break;
}
default: {
pos = 3;
return;
}
}
}
31 changes: 30 additions & 1 deletion tests/out/msl/control-flow.msl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,36 @@ struct main1Input {
kernel void main1(
metal::uint3 global_id [[thread_position_in_grid]]
) {
int pos;
metal::threadgroup_barrier(metal::mem_flags::mem_device);
metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
return;
switch(1) {
default: {
pos = 1;
}
}
int _e4 = pos;
switch(_e4) {
case 1: {
pos = 0;
break;
break;
}
case 2: {
pos = 1;
return;
break;
}
case 3: {
pos = 2;
}
case 4: {
return;
break;
}
default: {
pos = 3;
return;
}
}
}
30 changes: 29 additions & 1 deletion tests/out/wgsl/control-flow.wgsl
Original file line number Diff line number Diff line change
@@ -1,6 +1,34 @@
[[stage(compute), workgroup_size(1, 1, 1)]]
fn main([[builtin(global_invocation_id)]] global_id: vec3<u32>) {
var pos: i32;

storageBarrier();
workgroupBarrier();
return;
switch(1) {
default: {
pos = 1;
}
}
let _e4: i32 = pos;
switch(_e4) {
case 1: {
pos = 0;
break;
}
case 2: {
pos = 1;
return;
}
case 3: {
pos = 2;
fallthrough;
}
case 4: {
return;
}
default: {
pos = 3;
return;
}
}
}
4 changes: 3 additions & 1 deletion tests/snapshots.rs
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,9 @@ fn convert_wgsl() {
),
(
"control-flow",
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
// TODO: SPIRV https://github.com/gfx-rs/naga/issues/1017
//Targets::SPIRV |
Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
),
(
"standard",
Expand Down

0 comments on commit 02c74b5

Please sign in to comment.