Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: [spv-front] Support for OpAtomicCompareExchange, fixes #6296 #6590

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ By @ErichDonGubler in [#6456](https://github.com/gfx-rs/wgpu/pull/6456), [#6148]
- Add support for GLSL `usampler*` and `isampler*`. By @DavidPeicho in [#6513](https://github.com/gfx-rs/wgpu/pull/6513).
- Expose Ray Query flags as constants in WGSL. Implement candidate intersections. By @kvark in [#5429](https://github.com/gfx-rs/wgpu/pull/5429)
- Allow for override-expressions in `workgroup_size`. By @KentSlaney in [#6635](https://github.com/gfx-rs/wgpu/pull/6635).
- Add support for OpAtomicCompareExchange in SPIR-V frontend. By @schell in [#6590](https://github.com/gfx-rs/wgpu/pull/6590).

#### General

Expand Down
2 changes: 2 additions & 0 deletions naga/src/front/atomic_upgrade.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ pub enum Error {
GlobalInitUnsupported,
#[error("expected to find a global variable")]
GlobalVariableMissing,
#[error("atomic compare exchange requires a scalar base type")]
CompareExchangeNonScalarBaseType,
}

#[derive(Clone, Default)]
Expand Down
119 changes: 109 additions & 10 deletions naga/src/front/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4216,6 +4216,102 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
self.upgrade_atomics
.insert(ctx.get_contained_global_variable(p_exp_h)?);
}
Op::AtomicCompareExchange => {
inst.expect(9)?;

let start = self.data_offset;
let span = self.span_from_with_op(start);
let result_type_id = self.next()?;
let result_id = self.next()?;
let pointer_id = self.next()?;
let _memory_scope_id = self.next()?;
let _equal_memory_semantics_id = self.next()?;
let _unequal_memory_semantics_id = self.next()?;
let value_id = self.next()?;
let comparator_id = self.next()?;

let (p_exp_h, p_base_ty_h) = self.get_exp_and_base_ty_handles(
pointer_id,
ctx,
&mut emitter,
&mut block,
body_idx,
)?;

log::trace!("\t\t\tlooking up value expr {:?}", value_id);
let v_lexp_handle =
get_expr_handle!(value_id, self.lookup_expression.lookup(value_id)?);

log::trace!("\t\t\tlooking up comparator expr {:?}", value_id);
let c_lexp_handle = get_expr_handle!(
comparator_id,
self.lookup_expression.lookup(comparator_id)?
);

// We know from the SPIR-V spec that the result type must be an integer
// scalar, and we'll need the type itself to get a handle to the atomic
// result struct.
let crate::TypeInner::Scalar(scalar) = ctx.module.types[p_base_ty_h].inner
else {
return Err(
crate::front::atomic_upgrade::Error::CompareExchangeNonScalarBaseType
.into(),
);
};

// Get a handle to the atomic result struct type.
let atomic_result_struct_ty_h = ctx.module.generate_predeclared_type(
crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar),
);

block.extend(emitter.finish(ctx.expressions));

// Create an expression for our atomic result
let atomic_lexp_handle = {
let expr = crate::Expression::AtomicResult {
ty: atomic_result_struct_ty_h,
comparison: true,
};
ctx.expressions.append(expr, span)
};

// Create an dot accessor to extract the value from the
// result struct __atomic_compare_exchange_result<T> and use that
// as the expression for the result_id
{
let expr = crate::Expression::AccessIndex {
base: atomic_lexp_handle,
index: 0,
};
let handle = ctx.expressions.append(expr, span);
// Use this dot accessor as the result id's expression
let _ = self.lookup_expression.insert(
result_id,
LookupExpression {
handle,
type_id: result_type_id,
block_id,
},
);
}

emitter.start(ctx.expressions);

// Create a statement for the op itself
let stmt = crate::Statement::Atomic {
pointer: p_exp_h,
fun: crate::AtomicFunction::Exchange {
compare: Some(c_lexp_handle),
},
value: v_lexp_handle,
result: Some(atomic_lexp_handle),
};
block.push(stmt, span);

// Store any associated global variables so we can upgrade their types later
self.upgrade_atomics
.insert(ctx.get_contained_global_variable(p_exp_h)?);
}
Op::AtomicExchange
| Op::AtomicIAdd
| Op::AtomicISub
Expand Down Expand Up @@ -5912,17 +6008,18 @@ mod test_atomic {
let m = crate::front::spv::parse_u8_slice(bytes, &Default::default()).unwrap();

let mut wgsl = String::new();
let mut should_panic = false;

for vflags in [
crate::valid::ValidationFlags::all(),
crate::valid::ValidationFlags::empty(),
for (vflags, name) in [
(crate::valid::ValidationFlags::empty(), "empty"),
(crate::valid::ValidationFlags::all(), "all"),
] {
log::info!("validating with flags - {name}");
let mut validator = crate::valid::Validator::new(vflags, Default::default());
match validator.validate(&m) {
Err(e) => {
log::error!("SPIR-V validation {}", e.emit_to_string(""));
should_panic = true;
log::info!("types: {:#?}", m.types);
panic!("validation error");
}
Ok(i) => {
wgsl = crate::back::wgsl::write_string(
Expand All @@ -5932,15 +6029,10 @@ mod test_atomic {
)
.unwrap();
log::info!("wgsl-out:\n{wgsl}");
break;
}
};
}

if should_panic {
panic!("validation error");
}

let m = match crate::front::wgsl::parse_str(&wgsl) {
Ok(m) => m,
Err(e) => {
Expand Down Expand Up @@ -5975,6 +6067,13 @@ mod test_atomic {
atomic_test(include_bytes!("../../../tests/in/spv/atomic_exchange.spv"));
}

#[test]
fn atomic_compare_exchange() {
atomic_test(include_bytes!(
"../../../tests/in/spv/atomic_compare_exchange.spv"
));
}

#[test]
fn atomic_i_decrement() {
atomic_test(include_bytes!(
Expand Down
Binary file added naga/tests/in/spv/atomic_compare_exchange.spv
Binary file not shown.
89 changes: 89 additions & 0 deletions naga/tests/in/spv/atomic_compare_exchange.spvasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
; SPIR-V
; Version: 1.5
; Generator: Google rspirv; 0
; Bound: 65
; Schema: 0
OpCapability Shader
OpCapability VulkanMemoryModel
OpMemoryModel Logical Vulkan
OpEntryPoint GLCompute %1 "stage::test_atomic_compare_exchange" %2 %3
OpExecutionMode %1 LocalSize 32 1 1
OpMemberDecorate %_struct_9 0 Offset 0
OpMemberDecorate %_struct_9 1 Offset 4
OpDecorate %_struct_10 Block
OpMemberDecorate %_struct_10 0 Offset 0
OpDecorate %2 Binding 0
OpDecorate %2 DescriptorSet 0
OpDecorate %3 NonWritable
OpDecorate %3 Binding 1
OpDecorate %3 DescriptorSet 0
%uint = OpTypeInt 32 0
%void = OpTypeVoid
%13 = OpTypeFunction %void
%bool = OpTypeBool
%uint_0 = OpConstant %uint 0
%uint_2 = OpConstant %uint 2
%false = OpConstantFalse %bool
%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint
%uint_1 = OpConstant %uint 1
%_struct_9 = OpTypeStruct %uint %uint
%20 = OpUndef %_struct_9
%uint_3 = OpConstant %uint 3
%int = OpTypeInt 32 1
%23 = OpUndef %bool
%true = OpConstantTrue %bool
%_struct_10 = OpTypeStruct %uint
%_ptr_StorageBuffer__struct_10 = OpTypePointer StorageBuffer %_struct_10
%2 = OpVariable %_ptr_StorageBuffer__struct_10 StorageBuffer
%3 = OpVariable %_ptr_StorageBuffer__struct_10 StorageBuffer
%uint_256 = OpConstant %uint 256
%1 = OpFunction %void None %13
%27 = OpLabel
%28 = OpInBoundsAccessChain %_ptr_StorageBuffer_uint %2 %uint_0
%29 = OpInBoundsAccessChain %_ptr_StorageBuffer_uint %3 %uint_0
%30 = OpLoad %uint %29
%31 = OpCompositeConstruct %_struct_9 %uint_0 %30
OpBranch %32
%32 = OpLabel
%33 = OpPhi %_struct_9 %31 %27 %34 %35
OpLoopMerge %36 %35 None
OpBranch %37
%37 = OpLabel
%38 = OpCompositeExtract %uint %33 0
%39 = OpCompositeExtract %uint %33 1
%40 = OpULessThan %bool %38 %39
OpSelectionMerge %41 None
OpBranchConditional %40 %42 %43
%42 = OpLabel
%45 = OpIAdd %uint %38 %uint_1
%46 = OpCompositeInsert %_struct_9 %45 %33 0
%47 = OpCompositeConstruct %_struct_9 %uint_1 %38
OpBranch %41
%43 = OpLabel
%48 = OpCompositeInsert %_struct_9 %uint_0 %20 0
OpBranch %41
%41 = OpLabel
%34 = OpPhi %_struct_9 %46 %42 %33 %43
%49 = OpPhi %_struct_9 %47 %42 %48 %43
%50 = OpCompositeExtract %uint %49 0
%51 = OpCompositeExtract %uint %49 1
%52 = OpBitcast %int %50
OpSelectionMerge %53 None
OpSwitch %52 %54 0 %55 1 %56
%54 = OpLabel
OpBranch %53
%55 = OpLabel
OpBranch %53
%56 = OpLabel
%57 = OpAtomicCompareExchange %uint %28 %uint_2 %uint_256 %uint_256 %51 %uint_3
%58 = OpIEqual %bool %57 %uint_3
%64 = OpSelect %bool %58 %false %true
OpBranch %53
%53 = OpLabel
%63 = OpPhi %bool %23 %54 %false %55 %64 %56
OpBranch %35
%35 = OpLabel
OpBranchConditional %63 %32 %36
%36 = OpLabel
OpReturn
OpFunctionEnd
Loading