From f86223908dee7e59c8ba05e1a332af12bd18cab3 Mon Sep 17 00:00:00 2001 From: Vecvec Date: Sun, 15 Dec 2024 17:50:52 +1300 Subject: [PATCH 01/11] add test --- tests/tests/ray_tracing/as_build.rs | 85 +-------------- tests/tests/ray_tracing/as_use_after_free.rs | 2 +- tests/tests/ray_tracing/mod.rs | 82 ++++++++++++++ tests/tests/ray_tracing/shader.rs | 107 +++++++++++++++++++ tests/tests/ray_tracing/shader.wgsl | 42 +++++++- 5 files changed, 235 insertions(+), 83 deletions(-) create mode 100644 tests/tests/ray_tracing/shader.rs diff --git a/tests/tests/ray_tracing/as_build.rs b/tests/tests/ray_tracing/as_build.rs index 1b52678128..087b667929 100644 --- a/tests/tests/ray_tracing/as_build.rs +++ b/tests/tests/ray_tracing/as_build.rs @@ -1,85 +1,8 @@ -use std::{iter, mem}; +use std::iter; -use wgpu::{ - util::{BufferInitDescriptor, DeviceExt}, - *, -}; +use wgpu::*; use wgpu_test::{fail, gpu_test, GpuTestConfiguration, TestParameters, TestingContext}; - -struct AsBuildContext { - vertices: Buffer, - blas_size: BlasTriangleGeometrySizeDescriptor, - blas: Blas, - // Putting this last, forces the BLAS to die before the TLAS. - tlas_package: TlasPackage, -} - -impl AsBuildContext { - fn new(ctx: &TestingContext) -> Self { - let vertices = ctx.device.create_buffer_init(&BufferInitDescriptor { - label: None, - contents: &[0; mem::size_of::<[[f32; 3]; 3]>()], - usage: BufferUsages::BLAS_INPUT, - }); - - let blas_size = BlasTriangleGeometrySizeDescriptor { - vertex_format: VertexFormat::Float32x3, - vertex_count: 3, - index_format: None, - index_count: None, - flags: AccelerationStructureGeometryFlags::empty(), - }; - - let blas = ctx.device.create_blas( - &CreateBlasDescriptor { - label: Some("BLAS"), - flags: AccelerationStructureFlags::PREFER_FAST_TRACE, - update_mode: AccelerationStructureUpdateMode::Build, - }, - BlasGeometrySizeDescriptors::Triangles { - descriptors: vec![blas_size.clone()], - }, - ); - - let tlas = ctx.device.create_tlas(&CreateTlasDescriptor { - label: Some("TLAS"), - max_instances: 1, - flags: AccelerationStructureFlags::PREFER_FAST_TRACE, - update_mode: AccelerationStructureUpdateMode::Build, - }); - - let mut tlas_package = TlasPackage::new(tlas); - tlas_package[0] = Some(TlasInstance::new( - &blas, - [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], - 0, - 0xFF, - )); - - Self { - vertices, - blas_size, - blas, - tlas_package, - } - } - - fn blas_build_entry(&self) -> BlasBuildEntry { - BlasBuildEntry { - blas: &self.blas, - geometry: BlasGeometries::TriangleGeometries(vec![BlasTriangleGeometry { - size: &self.blas_size, - vertex_buffer: &self.vertices, - first_vertex: 0, - vertex_stride: mem::size_of::<[f32; 3]>() as BufferAddress, - index_buffer: None, - index_buffer_offset: None, - transform_buffer: None, - transform_buffer_offset: None, - }]), - } - } -} +use crate::ray_tracing::AsBuildContext; #[gpu_test] static UNBUILT_BLAS: GpuTestConfiguration = GpuTestConfiguration::new() @@ -244,7 +167,7 @@ fn out_of_order_as_build_use(ctx: TestingContext) { label: None, layout: None, module: &shader, - entry_point: Some("comp_main"), + entry_point: Some("basic_usage"), compilation_options: Default::default(), cache: None, }); diff --git a/tests/tests/ray_tracing/as_use_after_free.rs b/tests/tests/ray_tracing/as_use_after_free.rs index c0df9d385e..c3a1545f7e 100644 --- a/tests/tests/ray_tracing/as_use_after_free.rs +++ b/tests/tests/ray_tracing/as_use_after_free.rs @@ -108,7 +108,7 @@ fn acceleration_structure_use_after_free(ctx: TestingContext) { label: None, layout: None, module: &shader, - entry_point: Some("comp_main"), + entry_point: Some("basic_usage"), compilation_options: Default::default(), cache: None, }); diff --git a/tests/tests/ray_tracing/mod.rs b/tests/tests/ray_tracing/mod.rs index e204392d2e..c58ef47567 100644 --- a/tests/tests/ray_tracing/mod.rs +++ b/tests/tests/ray_tracing/mod.rs @@ -1,4 +1,86 @@ +use std::mem; +use wgpu::{Blas, BlasBuildEntry, BlasGeometries, BlasGeometrySizeDescriptors, BlasTriangleGeometry, BlasTriangleGeometrySizeDescriptor, Buffer, CreateBlasDescriptor, CreateTlasDescriptor, TlasInstance, TlasPackage, util::DeviceExt}; +use wgpu::util::BufferInitDescriptor; +use wgpu_test::TestingContext; +use wgt::{AccelerationStructureFlags, AccelerationStructureGeometryFlags, AccelerationStructureUpdateMode, BufferAddress, BufferUsages, VertexFormat}; + mod as_build; mod as_create; mod as_use_after_free; mod scene; +mod shader; + +pub struct AsBuildContext { + vertices: Buffer, + blas_size: BlasTriangleGeometrySizeDescriptor, + blas: Blas, + // Putting this last, forces the BLAS to die before the TLAS. + tlas_package: TlasPackage, +} + +impl AsBuildContext { + pub fn new(ctx: &TestingContext) -> Self { + let vertices = ctx.device.create_buffer_init(&BufferInitDescriptor { + label: None, + contents: &[0; mem::size_of::<[[f32; 3]; 3]>()], + usage: BufferUsages::BLAS_INPUT, + }); + + let blas_size = BlasTriangleGeometrySizeDescriptor { + vertex_format: VertexFormat::Float32x3, + vertex_count: 3, + index_format: None, + index_count: None, + flags: AccelerationStructureGeometryFlags::empty(), + }; + + let blas = ctx.device.create_blas( + &CreateBlasDescriptor { + label: Some("BLAS"), + flags: AccelerationStructureFlags::PREFER_FAST_TRACE, + update_mode: AccelerationStructureUpdateMode::Build, + }, + BlasGeometrySizeDescriptors::Triangles { + descriptors: vec![blas_size.clone()], + }, + ); + + let tlas = ctx.device.create_tlas(&CreateTlasDescriptor { + label: Some("TLAS"), + max_instances: 1, + flags: AccelerationStructureFlags::PREFER_FAST_TRACE, + update_mode: AccelerationStructureUpdateMode::Build, + }); + + let mut tlas_package = TlasPackage::new(tlas); + tlas_package[0] = Some(TlasInstance::new( + &blas, + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], + 0, + 0xFF, + )); + + Self { + vertices, + blas_size, + blas, + tlas_package, + } + } + + pub fn blas_build_entry(&self) -> BlasBuildEntry { + BlasBuildEntry { + blas: &self.blas, + geometry: BlasGeometries::TriangleGeometries(vec![BlasTriangleGeometry { + size: &self.blas_size, + vertex_buffer: &self.vertices, + first_vertex: 0, + vertex_stride: mem::size_of::<[f32; 3]>() as BufferAddress, + index_buffer: None, + index_buffer_offset: None, + transform_buffer: None, + transform_buffer_offset: None, + }]), + } + } +} \ No newline at end of file diff --git a/tests/tests/ray_tracing/shader.rs b/tests/tests/ray_tracing/shader.rs new file mode 100644 index 0000000000..60e5535872 --- /dev/null +++ b/tests/tests/ray_tracing/shader.rs @@ -0,0 +1,107 @@ +use wgpu::{BindGroupDescriptor, BindGroupEntry, BindingResource, BufferDescriptor, CommandEncoderDescriptor, ComputePassDescriptor, ComputePipelineDescriptor, include_wgsl}; +use wgpu_macros::gpu_test; +use wgpu_test::{GpuTestConfiguration, TestingContext, TestParameters}; +use wgt::BufferUsages; +use crate::ray_tracing::AsBuildContext; + +const STRUCT_SIZE: wgt::BufferAddress = 176; + +#[gpu_test] +static ACCESS_ALL_STRUCT_MEMBERS: GpuTestConfiguration = GpuTestConfiguration::new() + .parameters( + TestParameters::default() + .test_features_limits() + .features(wgpu::Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE | wgpu::Features::EXPERIMENTAL_RAY_QUERY), + ) + .run_sync(access_all_struct_members); + +fn access_all_struct_members(ctx: TestingContext) { + let buf = ctx.device.create_buffer(&BufferDescriptor { + label: None, + size: STRUCT_SIZE, + usage: BufferUsages::STORAGE, + mapped_at_creation: false, + }); + // + // Create a clean `AsBuildContext` + // + + let as_ctx = AsBuildContext::new(&ctx); + + // + // Build in the right order, then rebuild the BLAS so the TLAS is invalid, then use the TLAS. + // + + let mut encoder_blas = ctx + .device + .create_command_encoder(&CommandEncoderDescriptor { + label: Some("BLAS 1"), + }); + + encoder_blas.build_acceleration_structures([&as_ctx.blas_build_entry()], []); + + let mut encoder_tlas = ctx + .device + .create_command_encoder(&CommandEncoderDescriptor { + label: Some("TLAS 1"), + }); + + encoder_tlas.build_acceleration_structures([], [&as_ctx.tlas_package]); + + ctx.queue.submit([ + encoder_blas.finish(), + encoder_tlas.finish(), + ]); + + // + // Create shader to use tlas with + // + + let shader = ctx + .device + .create_shader_module(include_wgsl!("shader.wgsl")); + let compute_pipeline = ctx + .device + .create_compute_pipeline(&ComputePipelineDescriptor { + label: None, + layout: None, + module: &shader, + entry_point: Some("all_of_struct"), + compilation_options: Default::default(), + cache: None, + }); + + let bind_group = ctx.device.create_bind_group(&BindGroupDescriptor { + label: None, + layout: &compute_pipeline.get_bind_group_layout(0), + entries: &[ + BindGroupEntry { + binding: 0, + resource: BindingResource::AccelerationStructure(as_ctx.tlas_package.tlas()), + }, + BindGroupEntry { + binding: 1, + resource: BindingResource::Buffer(buf.as_entire_buffer_binding()), + } + ], + }); + + // + // Submit once to check for no issues + // + + let mut encoder_compute = ctx + .device + .create_command_encoder(&CommandEncoderDescriptor::default()); + { + let mut pass = encoder_compute.begin_compute_pass(&ComputePassDescriptor { + label: None, + timestamp_writes: None, + }); + pass.set_pipeline(&compute_pipeline); + pass.set_bind_group(0, Some(&bind_group), &[]); + pass.dispatch_workgroups(1, 2, 1) + } + + ctx.queue.submit([encoder_compute.finish()]); +} \ No newline at end of file diff --git a/tests/tests/ray_tracing/shader.wgsl b/tests/tests/ray_tracing/shader.wgsl index 370d69e1c3..ddb3505e75 100644 --- a/tests/tests/ray_tracing/shader.wgsl +++ b/tests/tests/ray_tracing/shader.wgsl @@ -1,11 +1,51 @@ @group(0) @binding(0) var acc_struct: acceleration_structure; +struct Intersection { + kind: u32, + t: f32, + instance_custom_index: u32, + instance_id: u32, + sbt_record_offset: u32, + geometry_index: u32, + primitive_index: u32, + barycentrics: vec2, + front_face: u32, + object_to_world: mat4x3, + world_to_object: mat4x3, +} + +@group(0) @binding(1) +var out: Intersection; + @workgroup_size(1) @compute -fn comp_main() { +fn basic_usage() { var rq: ray_query; rayQueryInitialize(&rq, acc_struct, RayDesc(0u, 0xFFu, 0.001, 100000.0, vec3f(0.0, 0.0, 0.0), vec3f(0.0, 0.0, 1.0))); rayQueryProceed(&rq); let intersection = rayQueryGetCommittedIntersection(&rq); +} + +@workgroup_size(1) +@compute +fn all_of_struct() { + var rq: ray_query; + rayQueryInitialize(&rq, acc_struct, RayDesc(0u, 0xFFu, 0.0, 0.0, vec3f(0.0, 0.0, 1.0), vec3f(0.0, 0.0, 1.0))); + rayQueryProceed(&rq); + let intersection = rayQueryGetCommittedIntersection(&rq); + // this prevents optimisation as we use the fields + out = Intersection( + intersection.kind, + intersection.t, + intersection.instance_custom_index, + intersection.instance_id, + intersection.sbt_record_offset, + intersection.geometry_index, + intersection.primitive_index, + intersection.barycentrics, + u32(intersection.front_face), + intersection.world_to_object, + intersection.object_to_world, + ); } \ No newline at end of file From 9c4b6cddd9dc4b668d9af9aa8f80c847a1528a4d Mon Sep 17 00:00:00 2001 From: Vecvec Date: Sun, 15 Dec 2024 18:06:15 +1300 Subject: [PATCH 02/11] format --- tests/tests/ray_tracing/as_build.rs | 2 +- tests/tests/ray_tracing/mod.rs | 13 ++++++++--- tests/tests/ray_tracing/shader.rs | 34 ++++++++++++++--------------- 3 files changed, 28 insertions(+), 21 deletions(-) diff --git a/tests/tests/ray_tracing/as_build.rs b/tests/tests/ray_tracing/as_build.rs index 087b667929..4b92d18c71 100644 --- a/tests/tests/ray_tracing/as_build.rs +++ b/tests/tests/ray_tracing/as_build.rs @@ -1,8 +1,8 @@ use std::iter; +use crate::ray_tracing::AsBuildContext; use wgpu::*; use wgpu_test::{fail, gpu_test, GpuTestConfiguration, TestParameters, TestingContext}; -use crate::ray_tracing::AsBuildContext; #[gpu_test] static UNBUILT_BLAS: GpuTestConfiguration = GpuTestConfiguration::new() diff --git a/tests/tests/ray_tracing/mod.rs b/tests/tests/ray_tracing/mod.rs index c58ef47567..0c54ef4749 100644 --- a/tests/tests/ray_tracing/mod.rs +++ b/tests/tests/ray_tracing/mod.rs @@ -1,8 +1,15 @@ use std::mem; -use wgpu::{Blas, BlasBuildEntry, BlasGeometries, BlasGeometrySizeDescriptors, BlasTriangleGeometry, BlasTriangleGeometrySizeDescriptor, Buffer, CreateBlasDescriptor, CreateTlasDescriptor, TlasInstance, TlasPackage, util::DeviceExt}; use wgpu::util::BufferInitDescriptor; +use wgpu::{ + util::DeviceExt, Blas, BlasBuildEntry, BlasGeometries, BlasGeometrySizeDescriptors, + BlasTriangleGeometry, BlasTriangleGeometrySizeDescriptor, Buffer, CreateBlasDescriptor, + CreateTlasDescriptor, TlasInstance, TlasPackage, +}; use wgpu_test::TestingContext; -use wgt::{AccelerationStructureFlags, AccelerationStructureGeometryFlags, AccelerationStructureUpdateMode, BufferAddress, BufferUsages, VertexFormat}; +use wgt::{ + AccelerationStructureFlags, AccelerationStructureGeometryFlags, + AccelerationStructureUpdateMode, BufferAddress, BufferUsages, VertexFormat, +}; mod as_build; mod as_create; @@ -83,4 +90,4 @@ impl AsBuildContext { }]), } } -} \ No newline at end of file +} diff --git a/tests/tests/ray_tracing/shader.rs b/tests/tests/ray_tracing/shader.rs index 60e5535872..9f07365853 100644 --- a/tests/tests/ray_tracing/shader.rs +++ b/tests/tests/ray_tracing/shader.rs @@ -1,18 +1,20 @@ -use wgpu::{BindGroupDescriptor, BindGroupEntry, BindingResource, BufferDescriptor, CommandEncoderDescriptor, ComputePassDescriptor, ComputePipelineDescriptor, include_wgsl}; +use crate::ray_tracing::AsBuildContext; +use wgpu::{ + include_wgsl, BindGroupDescriptor, BindGroupEntry, BindingResource, BufferDescriptor, + CommandEncoderDescriptor, ComputePassDescriptor, ComputePipelineDescriptor, +}; use wgpu_macros::gpu_test; -use wgpu_test::{GpuTestConfiguration, TestingContext, TestParameters}; +use wgpu_test::{GpuTestConfiguration, TestParameters, TestingContext}; use wgt::BufferUsages; -use crate::ray_tracing::AsBuildContext; const STRUCT_SIZE: wgt::BufferAddress = 176; #[gpu_test] static ACCESS_ALL_STRUCT_MEMBERS: GpuTestConfiguration = GpuTestConfiguration::new() - .parameters( - TestParameters::default() - .test_features_limits() - .features(wgpu::Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE | wgpu::Features::EXPERIMENTAL_RAY_QUERY), - ) + .parameters(TestParameters::default().test_features_limits().features( + wgpu::Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE + | wgpu::Features::EXPERIMENTAL_RAY_QUERY, + )) .run_sync(access_all_struct_members); fn access_all_struct_members(ctx: TestingContext) { @@ -48,10 +50,8 @@ fn access_all_struct_members(ctx: TestingContext) { encoder_tlas.build_acceleration_structures([], [&as_ctx.tlas_package]); - ctx.queue.submit([ - encoder_blas.finish(), - encoder_tlas.finish(), - ]); + ctx.queue + .submit([encoder_blas.finish(), encoder_tlas.finish()]); // // Create shader to use tlas with @@ -76,13 +76,13 @@ fn access_all_struct_members(ctx: TestingContext) { layout: &compute_pipeline.get_bind_group_layout(0), entries: &[ BindGroupEntry { - binding: 0, - resource: BindingResource::AccelerationStructure(as_ctx.tlas_package.tlas()), - }, + binding: 0, + resource: BindingResource::AccelerationStructure(as_ctx.tlas_package.tlas()), + }, BindGroupEntry { binding: 1, resource: BindingResource::Buffer(buf.as_entire_buffer_binding()), - } + }, ], }); @@ -104,4 +104,4 @@ fn access_all_struct_members(ctx: TestingContext) { } ctx.queue.submit([encoder_compute.finish()]); -} \ No newline at end of file +} From d726398b87206d919728b8225265c301a382709e Mon Sep 17 00:00:00 2001 From: Vecvec Date: Sun, 15 Dec 2024 18:07:15 +1300 Subject: [PATCH 03/11] move ray query get intersection to its own function --- naga/src/back/spv/block.rs | 15 +- naga/src/back/spv/mod.rs | 2 + naga/src/back/spv/ray.rs | 512 +++++++++++++++++++++++++++++++++++- naga/src/back/spv/writer.rs | 6 +- 4 files changed, 531 insertions(+), 4 deletions(-) diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index 0fbba5c737..4c721b3e6d 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -1736,7 +1736,20 @@ impl BlockContext<'_> { } crate::Expression::ArrayLength(expr) => self.write_runtime_array_length(expr, block)?, crate::Expression::RayQueryGetIntersection { query, committed } => { - self.write_ray_query_get_intersection(query, block, committed) + let query_id = self.cached[query]; + let func_id = self + .writer + .write_ray_query_get_intersection_function(committed, self.ir_module); + let ray_intersection = self.ir_module.special_types.ray_intersection.unwrap(); + let intersection_type_id = self.get_type_id(LookupType::Handle(ray_intersection)); + let id = self.gen_id(); + block.body.push(Instruction::function_call( + intersection_type_id, + id, + func_id, + &[query_id], + )); + id } }; diff --git a/naga/src/back/spv/mod.rs b/naga/src/back/spv/mod.rs index 6385311c73..e6164098d3 100644 --- a/naga/src/back/spv/mod.rs +++ b/naga/src/back/spv/mod.rs @@ -766,6 +766,8 @@ pub struct Writer { // Just a temporary list of SPIR-V ids temp_list: Vec, + + ray_get_intersection_function: Option, } bitflags::bitflags! { diff --git a/naga/src/back/spv/ray.rs b/naga/src/back/spv/ray.rs index d0076649a3..b2ad77eb3f 100644 --- a/naga/src/back/spv/ray.rs +++ b/naga/src/back/spv/ray.rs @@ -2,8 +2,518 @@ Generating SPIR-V for ray query operations. */ -use super::{Block, BlockContext, Instruction, LocalType, LookupType, NumericType}; +use super::{ + Block, BlockContext, Function, FunctionArgument, Instruction, LocalType, LookupFunctionType, + LookupType, NumericType, Writer, +}; use crate::arena::Handle; +use crate::{Type, TypeInner}; + +impl Writer { + pub(super) fn write_ray_query_get_intersection_function( + &mut self, + is_committed: bool, + ir_module: &crate::Module, + ) -> spirv::Word { + if let Some(func_id) = self.ray_get_intersection_function { + return func_id; + } + let ray_intersection = ir_module.special_types.ray_intersection.unwrap(); + let intersection_type_id = self.get_type_id(LookupType::Handle(ray_intersection)); + let intersection_pointer_type_id = + self.get_type_id(LookupType::Local(LocalType::Pointer { + base: ray_intersection, + class: spirv::StorageClass::Function, + })); + + let flag_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric( + NumericType::Scalar(crate::Scalar::U32), + ))); + let flag_type = ir_module + .types + .get(&Type { + name: None, + inner: TypeInner::Scalar(crate::Scalar::U32), + }) + .unwrap(); + let flag_pointer_type_id = self.get_type_id(LookupType::Local(LocalType::Pointer { + base: flag_type, + class: spirv::StorageClass::Function, + })); + + let transform_type_id = + self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Matrix { + columns: crate::VectorSize::Quad, + rows: crate::VectorSize::Tri, + scalar: crate::Scalar::F32, + }))); + let transform_type = ir_module + .types + .get(&Type { + name: None, + inner: TypeInner::Matrix { + columns: crate::VectorSize::Quad, + rows: crate::VectorSize::Tri, + scalar: crate::Scalar::F32, + }, + }) + .unwrap(); + let transform_pointer_type_id = self.get_type_id(LookupType::Local(LocalType::Pointer { + base: transform_type, + class: spirv::StorageClass::Function, + })); + + let barycentrics_type_id = + self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector { + size: crate::VectorSize::Bi, + scalar: crate::Scalar::F32, + }))); + let barycentrics_type = ir_module + .types + .get(&Type { + name: None, + inner: TypeInner::Vector { + size: crate::VectorSize::Bi, + scalar: crate::Scalar::F32, + }, + }) + .unwrap(); + let barycentrics_pointer_type_id = + self.get_type_id(LookupType::Local(LocalType::Pointer { + base: barycentrics_type, + class: spirv::StorageClass::Function, + })); + + let bool_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric( + NumericType::Scalar(crate::Scalar::BOOL), + ))); + let bool_type = ir_module + .types + .get(&Type { + name: None, + inner: TypeInner::Scalar(crate::Scalar::BOOL), + }) + .unwrap(); + let bool_pointer_type_id = self.get_type_id(LookupType::Local(LocalType::Pointer { + base: bool_type, + class: spirv::StorageClass::Function, + })); + + let scalar_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric( + NumericType::Scalar(crate::Scalar::F32), + ))); + let float_type = ir_module + .types + .get(&Type { + name: None, + inner: TypeInner::Scalar(crate::Scalar::F32), + }) + .unwrap(); + let float_pointer_type_id = self.get_type_id(LookupType::Local(LocalType::Pointer { + base: float_type, + class: spirv::StorageClass::Function, + })); + + let rq_ty = ir_module + .types + .get(&Type { + name: None, + inner: TypeInner::RayQuery, + }) + .expect("ray_query type should have been populated by the variable passed into this!"); + let argument_type_id = self.get_type_id(LookupType::Local(LocalType::Pointer { + base: rq_ty, + class: spirv::StorageClass::Function, + })); + let func_ty = self.get_function_type(LookupFunctionType { + parameter_type_ids: vec![argument_type_id], + return_type_id: intersection_type_id, + }); + + let mut function = Function::default(); + let func_id = self.id_gen.next(); + function.signature = Some(Instruction::function( + intersection_type_id, + func_id, + spirv::FunctionControl::empty(), + func_ty, + )); + let blank_intersection = self.get_constant_null(intersection_type_id); + let query_id = self.id_gen.next(); + let instruction = Instruction::function_parameter(argument_type_id, query_id); + function.parameters.push(FunctionArgument { + instruction, + handle_id: 0, + }); + + let label_id = self.id_gen.next(); + let mut block = Block::new(label_id); + + let blank_intersection_id = self.id_gen.next(); + block.body.push(Instruction::variable( + intersection_pointer_type_id, + blank_intersection_id, + spirv::StorageClass::Function, + Some(blank_intersection), + )); + + let intersection_id = self.get_constant_scalar(crate::Literal::U32(if is_committed { + spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR + } else { + spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR + } as _)); + let raw_kind_id = self.id_gen.next(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionTypeKHR, + flag_type_id, + raw_kind_id, + query_id, + intersection_id, + )); + let kind_id = if is_committed { + // Nothing to do: the IR value matches `spirv::RayQueryCommittedIntersectionType` + raw_kind_id + } else { + // Remap from the candidate kind to IR + let condition_id = self.id_gen.next(); + let committed_triangle_kind_id = self.get_constant_scalar(crate::Literal::U32( + spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR + as _, + )); + block.body.push(Instruction::binary( + spirv::Op::IEqual, + self.get_bool_type_id(), + condition_id, + raw_kind_id, + committed_triangle_kind_id, + )); + let kind_id = self.id_gen.next(); + block.body.push(Instruction::select( + flag_type_id, + kind_id, + condition_id, + self.get_constant_scalar(crate::Literal::U32( + crate::RayQueryIntersection::Triangle as _, + )), + self.get_constant_scalar(crate::Literal::U32( + crate::RayQueryIntersection::Aabb as _, + )), + )); + kind_id + }; + let idx_id = self.get_index_constant(0); + let access_idx = self.id_gen.next(); + block.body.push(Instruction::access_chain( + flag_pointer_type_id, + access_idx, + blank_intersection_id, + &[idx_id], + )); + block + .body + .push(Instruction::store(access_idx, kind_id, None)); + + let not_none_comp_id = self.id_gen.next(); + let none_id = + self.get_constant_scalar(crate::Literal::U32(crate::RayQueryIntersection::None as _)); + block.body.push(Instruction::binary( + spirv::Op::INotEqual, + self.get_bool_type_id(), + not_none_comp_id, + kind_id, + none_id, + )); + + let not_none_label_id = self.id_gen.next(); + let mut not_none_block = Block::new(not_none_label_id); + + let final_label_id = self.id_gen.next(); + let mut final_block = Block::new(final_label_id); + + block.body.push(Instruction::selection_merge( + final_label_id, + spirv::SelectionControl::NONE, + )); + function.consume( + block, + Instruction::branch_conditional(not_none_comp_id, not_none_label_id, final_label_id), + ); + + let instance_custom_index_id = self.id_gen.next(); + not_none_block + .body + .push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionInstanceCustomIndexKHR, + flag_type_id, + instance_custom_index_id, + query_id, + intersection_id, + )); + let instance_id = self.id_gen.next(); + not_none_block + .body + .push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionInstanceIdKHR, + flag_type_id, + instance_id, + query_id, + intersection_id, + )); + let sbt_record_offset_id = self.id_gen.next(); + not_none_block + .body + .push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR, + flag_type_id, + sbt_record_offset_id, + query_id, + intersection_id, + )); + let geometry_index_id = self.id_gen.next(); + not_none_block + .body + .push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionGeometryIndexKHR, + flag_type_id, + geometry_index_id, + query_id, + intersection_id, + )); + let primitive_index_id = self.id_gen.next(); + not_none_block + .body + .push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionPrimitiveIndexKHR, + flag_type_id, + primitive_index_id, + query_id, + intersection_id, + )); + + let t_id = self.id_gen.next(); + not_none_block + .body + .push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionTKHR, + scalar_type_id, + t_id, + query_id, + intersection_id, + )); + + //Note: there is also `OpRayQueryGetIntersectionCandidateAABBOpaqueKHR`, + // but it's not a property of an intersection. + + let object_to_world_id = self.id_gen.next(); + not_none_block + .body + .push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionObjectToWorldKHR, + transform_type_id, + object_to_world_id, + query_id, + intersection_id, + )); + let world_to_object_id = self.id_gen.next(); + not_none_block + .body + .push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionWorldToObjectKHR, + transform_type_id, + world_to_object_id, + query_id, + intersection_id, + )); + + // t + let idx_id = self.get_index_constant(1); + let access_idx = self.id_gen.next(); + not_none_block.body.push(Instruction::access_chain( + float_pointer_type_id, + access_idx, + blank_intersection_id, + &[idx_id], + )); + not_none_block + .body + .push(Instruction::store(access_idx, t_id, None)); + + // instance custom index + let idx_id = self.get_index_constant(2); + let access_idx = self.id_gen.next(); + not_none_block.body.push(Instruction::access_chain( + flag_pointer_type_id, + access_idx, + blank_intersection_id, + &[idx_id], + )); + not_none_block.body.push(Instruction::store( + access_idx, + instance_custom_index_id, + None, + )); + + // instance + let idx_id = self.get_index_constant(3); + let access_idx = self.id_gen.next(); + not_none_block.body.push(Instruction::access_chain( + flag_pointer_type_id, + access_idx, + blank_intersection_id, + &[idx_id], + )); + not_none_block + .body + .push(Instruction::store(access_idx, instance_id, None)); + + let idx_id = self.get_index_constant(4); + let access_idx = self.id_gen.next(); + not_none_block.body.push(Instruction::access_chain( + flag_pointer_type_id, + access_idx, + blank_intersection_id, + &[idx_id], + )); + not_none_block + .body + .push(Instruction::store(access_idx, sbt_record_offset_id, None)); + + let idx_id = self.get_index_constant(5); + let access_idx = self.id_gen.next(); + not_none_block.body.push(Instruction::access_chain( + flag_pointer_type_id, + access_idx, + blank_intersection_id, + &[idx_id], + )); + not_none_block + .body + .push(Instruction::store(access_idx, geometry_index_id, None)); + + let idx_id = self.get_index_constant(6); + let access_idx = self.id_gen.next(); + not_none_block.body.push(Instruction::access_chain( + flag_pointer_type_id, + access_idx, + blank_intersection_id, + &[idx_id], + )); + not_none_block + .body + .push(Instruction::store(access_idx, primitive_index_id, None)); + + let idx_id = self.get_index_constant(9); + let access_idx = self.id_gen.next(); + not_none_block.body.push(Instruction::access_chain( + transform_pointer_type_id, + access_idx, + blank_intersection_id, + &[idx_id], + )); + not_none_block + .body + .push(Instruction::store(access_idx, object_to_world_id, None)); + + let idx_id = self.get_index_constant(10); + let access_idx = self.id_gen.next(); + not_none_block.body.push(Instruction::access_chain( + transform_pointer_type_id, + access_idx, + blank_intersection_id, + &[idx_id], + )); + not_none_block + .body + .push(Instruction::store(access_idx, world_to_object_id, None)); + + let tri_comp_id = self.id_gen.next(); + let tri_id = self.get_constant_scalar(crate::Literal::U32( + crate::RayQueryIntersection::Triangle as _, + )); + not_none_block.body.push(Instruction::binary( + spirv::Op::IEqual, + self.get_bool_type_id(), + tri_comp_id, + kind_id, + tri_id, + )); + + let tri_label_id = self.id_gen.next(); + let mut tri_block = Block::new(tri_label_id); + + let merge_label_id = self.id_gen.next(); + let merge_block = Block::new(merge_label_id); + not_none_block.body.push(Instruction::selection_merge( + merge_label_id, + spirv::SelectionControl::NONE, + )); + function.consume( + not_none_block, + Instruction::branch_conditional(not_none_comp_id, tri_label_id, merge_label_id), + ); + + let barycentrics_id = self.id_gen.next(); + tri_block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionBarycentricsKHR, + barycentrics_type_id, + barycentrics_id, + query_id, + intersection_id, + )); + + let front_face_id = self.id_gen.next(); + tri_block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionFrontFaceKHR, + bool_type_id, + front_face_id, + query_id, + intersection_id, + )); + + let idx_id = self.get_index_constant(7); + let access_idx = self.id_gen.next(); + tri_block.body.push(Instruction::access_chain( + barycentrics_pointer_type_id, + access_idx, + blank_intersection_id, + &[idx_id], + )); + tri_block + .body + .push(Instruction::store(access_idx, barycentrics_id, None)); + + let idx_id = self.get_index_constant(8); + let access_idx = self.id_gen.next(); + tri_block.body.push(Instruction::access_chain( + bool_pointer_type_id, + access_idx, + blank_intersection_id, + &[idx_id], + )); + tri_block + .body + .push(Instruction::store(access_idx, front_face_id, None)); + function.consume(tri_block, Instruction::branch(merge_label_id)); + function.consume(merge_block, Instruction::branch(final_label_id)); + + let loaded_blank_intersection_id = self.id_gen.next(); + final_block.body.push(Instruction::load( + intersection_type_id, + loaded_blank_intersection_id, + blank_intersection_id, + None, + )); + function.consume( + final_block, + Instruction::return_value(loaded_blank_intersection_id), + ); + + function.to_words(&mut self.logical_layout.function_definitions); + Instruction::function_end().to_words(&mut self.logical_layout.function_definitions); + self.ray_get_intersection_function = Some(func_id); + func_id + } +} impl BlockContext<'_> { pub(super) fn write_ray_query_function( diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index 47f3ec513b..c41b3ab274 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -21,7 +21,7 @@ struct FunctionInterface<'a> { } impl Function { - fn to_words(&self, sink: &mut impl Extend) { + pub(super) fn to_words(&self, sink: &mut impl Extend) { self.signature.as_ref().unwrap().to_words(sink); for argument in self.parameters.iter() { argument.instruction.to_words(sink); @@ -81,6 +81,7 @@ impl Writer { saved_cached: CachedExpressions::default(), gl450_ext_inst_id, temp_list: Vec::new(), + ray_get_intersection_function: None, }) } @@ -131,6 +132,7 @@ impl Writer { global_variables: take(&mut self.global_variables).recycle(), saved_cached: take(&mut self.saved_cached).recycle(), temp_list: take(&mut self.temp_list).recycle(), + ray_get_intersection_function: None, }; *self = fresh; @@ -1833,7 +1835,7 @@ impl Writer { Ok(()) } - fn get_function_type(&mut self, lookup_function_type: LookupFunctionType) -> Word { + pub(super) fn get_function_type(&mut self, lookup_function_type: LookupFunctionType) -> Word { match self .lookup_function_type .entry(lookup_function_type.clone()) From c1db49b4a6f36b0e90eae45c01a777a315f7b023 Mon Sep 17 00:00:00 2001 From: Vecvec Date: Mon, 16 Dec 2024 16:27:16 +1300 Subject: [PATCH 04/11] regen snapshots --- naga/tests/out/spv/ray-query.spvasm | 243 ++++++++++++++++------------ 1 file changed, 143 insertions(+), 100 deletions(-) diff --git a/naga/tests/out/spv/ray-query.spvasm b/naga/tests/out/spv/ray-query.spvasm index 5279bfc2e1..b3d01aae76 100644 --- a/naga/tests/out/spv/ray-query.spvasm +++ b/naga/tests/out/spv/ray-query.spvasm @@ -1,16 +1,16 @@ ; SPIR-V ; Version: 1.4 ; Generator: rspirv -; Bound: 136 +; Bound: 160 OpCapability Shader OpCapability RayQueryKHR OpExtension "SPV_KHR_ray_query" %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint GLCompute %84 "main" %15 %17 -OpEntryPoint GLCompute %105 "main_candidate" %15 %17 -OpExecutionMode %84 LocalSize 1 1 1 -OpExecutionMode %105 LocalSize 1 1 1 +OpEntryPoint GLCompute %123 "main" %15 %17 +OpEntryPoint GLCompute %143 "main_candidate" %15 %17 +OpExecutionMode %123 LocalSize 1 1 1 +OpExecutionMode %143 LocalSize 1 1 1 OpMemberDecorate %10 0 Offset 0 OpMemberDecorate %10 1 Offset 4 OpMemberDecorate %10 2 Offset 8 @@ -64,20 +64,87 @@ OpMemberDecorate %18 0 Offset 0 %29 = OpConstant %3 0.1 %30 = OpConstant %3 100.0 %32 = OpTypePointer Function %11 -%50 = OpConstant %6 1 -%67 = OpTypeFunction %4 %4 %10 -%68 = OpConstant %3 1.0 -%69 = OpConstant %3 2.4 -%70 = OpConstant %3 0.0 -%85 = OpTypeFunction %2 -%87 = OpTypePointer StorageBuffer %13 -%88 = OpConstant %6 0 -%90 = OpConstantComposite %4 %70 %70 %70 -%91 = OpConstantComposite %4 %70 %68 %70 -%94 = OpTypePointer StorageBuffer %6 -%99 = OpTypePointer StorageBuffer %4 -%108 = OpConstantComposite %12 %27 %28 %29 %30 %90 %91 -%109 = OpConstant %6 3 +%50 = OpTypePointer Function %10 +%51 = OpTypePointer Function %6 +%52 = OpTypePointer Function %9 +%53 = OpTypePointer Function %7 +%54 = OpTypePointer Function %8 +%55 = OpTypePointer Function %3 +%56 = OpTypeFunction %10 %32 +%58 = OpConstantNull %10 +%62 = OpConstant %6 1 +%64 = OpConstant %6 0 +%78 = OpConstant %6 2 +%80 = OpConstant %6 3 +%83 = OpConstant %6 5 +%85 = OpConstant %6 6 +%87 = OpConstant %6 9 +%89 = OpConstant %6 10 +%96 = OpConstant %6 7 +%98 = OpConstant %6 8 +%106 = OpTypeFunction %4 %4 %10 +%107 = OpConstant %3 1.0 +%108 = OpConstant %3 2.4 +%109 = OpConstant %3 0.0 +%124 = OpTypeFunction %2 +%126 = OpTypePointer StorageBuffer %13 +%128 = OpConstantComposite %4 %109 %109 %109 +%129 = OpConstantComposite %4 %109 %107 %109 +%132 = OpTypePointer StorageBuffer %6 +%137 = OpTypePointer StorageBuffer %4 +%146 = OpConstantComposite %12 %27 %28 %29 %30 %128 %129 +%57 = OpFunction %10 None %56 +%59 = OpFunctionParameter %32 +%60 = OpLabel +%61 = OpVariable %50 Function %58 +%63 = OpRayQueryGetIntersectionTypeKHR %6 %59 %62 +%65 = OpAccessChain %51 %61 %64 +OpStore %65 %63 +%66 = OpINotEqual %8 %63 %64 +OpSelectionMerge %68 None +OpBranchConditional %66 %67 %68 +%67 = OpLabel +%69 = OpRayQueryGetIntersectionInstanceCustomIndexKHR %6 %59 %62 +%70 = OpRayQueryGetIntersectionInstanceIdKHR %6 %59 %62 +%71 = OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR %6 %59 %62 +%72 = OpRayQueryGetIntersectionGeometryIndexKHR %6 %59 %62 +%73 = OpRayQueryGetIntersectionPrimitiveIndexKHR %6 %59 %62 +%74 = OpRayQueryGetIntersectionTKHR %3 %59 %62 +%75 = OpRayQueryGetIntersectionObjectToWorldKHR %9 %59 %62 +%76 = OpRayQueryGetIntersectionWorldToObjectKHR %9 %59 %62 +%77 = OpAccessChain %55 %61 %62 +OpStore %77 %74 +%79 = OpAccessChain %51 %61 %78 +OpStore %79 %69 +%81 = OpAccessChain %51 %61 %80 +OpStore %81 %70 +%82 = OpAccessChain %51 %61 %27 +OpStore %82 %71 +%84 = OpAccessChain %51 %61 %83 +OpStore %84 %72 +%86 = OpAccessChain %51 %61 %85 +OpStore %86 %73 +%88 = OpAccessChain %52 %61 %87 +OpStore %88 %75 +%90 = OpAccessChain %52 %61 %89 +OpStore %90 %76 +%91 = OpIEqual %8 %63 %62 +OpSelectionMerge %93 None +OpBranchConditional %66 %92 %93 +%92 = OpLabel +%94 = OpRayQueryGetIntersectionBarycentricsKHR %7 %59 %62 +%95 = OpRayQueryGetIntersectionFrontFaceKHR %8 %59 %62 +%97 = OpAccessChain %53 %61 %96 +OpStore %97 %94 +%99 = OpAccessChain %54 %61 %98 +OpStore %99 %95 +OpBranch %93 +%93 = OpLabel +OpBranch %68 +%68 = OpLabel +%100 = OpLoad %10 %61 +OpReturnValue %100 +OpFunctionEnd %25 = OpFunction %10 None %26 %21 = OpFunctionParameter %4 %22 = OpFunctionParameter %4 @@ -114,90 +181,66 @@ OpBranch %44 %44 = OpLabel OpBranch %41 %42 = OpLabel -%51 = OpRayQueryGetIntersectionTypeKHR %6 %31 %50 -%52 = OpRayQueryGetIntersectionInstanceCustomIndexKHR %6 %31 %50 -%53 = OpRayQueryGetIntersectionInstanceIdKHR %6 %31 %50 -%54 = OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR %6 %31 %50 -%55 = OpRayQueryGetIntersectionGeometryIndexKHR %6 %31 %50 -%56 = OpRayQueryGetIntersectionPrimitiveIndexKHR %6 %31 %50 -%57 = OpRayQueryGetIntersectionTKHR %3 %31 %50 -%58 = OpRayQueryGetIntersectionBarycentricsKHR %7 %31 %50 -%59 = OpRayQueryGetIntersectionFrontFaceKHR %8 %31 %50 -%60 = OpRayQueryGetIntersectionObjectToWorldKHR %9 %31 %50 -%61 = OpRayQueryGetIntersectionWorldToObjectKHR %9 %31 %50 -%62 = OpCompositeConstruct %10 %51 %57 %52 %53 %54 %55 %56 %58 %59 %60 %61 -OpReturnValue %62 +%101 = OpFunctionCall %10 %57 %31 +OpReturnValue %101 OpFunctionEnd -%66 = OpFunction %4 None %67 -%64 = OpFunctionParameter %4 -%65 = OpFunctionParameter %10 -%63 = OpLabel -OpBranch %71 -%71 = OpLabel -%72 = OpCompositeExtract %9 %65 10 -%73 = OpCompositeConstruct %14 %64 %68 -%74 = OpMatrixTimesVector %4 %72 %73 -%75 = OpVectorShuffle %7 %74 %74 0 1 -%76 = OpExtInst %7 %1 Normalize %75 -%77 = OpVectorTimesScalar %7 %76 %69 -%78 = OpCompositeExtract %9 %65 9 -%79 = OpCompositeConstruct %14 %77 %70 %68 -%80 = OpMatrixTimesVector %4 %78 %79 -%81 = OpFSub %4 %64 %80 -%82 = OpExtInst %4 %1 Normalize %81 -OpReturnValue %82 +%105 = OpFunction %4 None %106 +%103 = OpFunctionParameter %4 +%104 = OpFunctionParameter %10 +%102 = OpLabel +OpBranch %110 +%110 = OpLabel +%111 = OpCompositeExtract %9 %104 10 +%112 = OpCompositeConstruct %14 %103 %107 +%113 = OpMatrixTimesVector %4 %111 %112 +%114 = OpVectorShuffle %7 %113 %113 0 1 +%115 = OpExtInst %7 %1 Normalize %114 +%116 = OpVectorTimesScalar %7 %115 %108 +%117 = OpCompositeExtract %9 %104 9 +%118 = OpCompositeConstruct %14 %116 %109 %107 +%119 = OpMatrixTimesVector %4 %117 %118 +%120 = OpFSub %4 %103 %119 +%121 = OpExtInst %4 %1 Normalize %120 +OpReturnValue %121 OpFunctionEnd -%84 = OpFunction %2 None %85 -%83 = OpLabel -%86 = OpLoad %5 %15 -%89 = OpAccessChain %87 %17 %88 -OpBranch %92 -%92 = OpLabel -%93 = OpFunctionCall %10 %25 %90 %91 %15 -%95 = OpCompositeExtract %6 %93 0 -%96 = OpIEqual %8 %95 %88 -%97 = OpSelect %6 %96 %50 %88 -%98 = OpAccessChain %94 %89 %88 -OpStore %98 %97 -%100 = OpCompositeExtract %3 %93 1 -%101 = OpVectorTimesScalar %4 %91 %100 -%102 = OpFunctionCall %4 %66 %101 %93 -%103 = OpAccessChain %99 %89 %50 -OpStore %103 %102 +%123 = OpFunction %2 None %124 +%122 = OpLabel +%125 = OpLoad %5 %15 +%127 = OpAccessChain %126 %17 %64 +OpBranch %130 +%130 = OpLabel +%131 = OpFunctionCall %10 %25 %128 %129 %15 +%133 = OpCompositeExtract %6 %131 0 +%134 = OpIEqual %8 %133 %64 +%135 = OpSelect %6 %134 %62 %64 +%136 = OpAccessChain %132 %127 %64 +OpStore %136 %135 +%138 = OpCompositeExtract %3 %131 1 +%139 = OpVectorTimesScalar %4 %129 %138 +%140 = OpFunctionCall %4 %105 %139 %131 +%141 = OpAccessChain %137 %127 %62 +OpStore %141 %140 OpReturn OpFunctionEnd -%105 = OpFunction %2 None %85 -%104 = OpLabel -%110 = OpVariable %32 Function -%106 = OpLoad %5 %15 -%107 = OpAccessChain %87 %17 %88 -OpBranch %111 -%111 = OpLabel -%112 = OpCompositeExtract %6 %108 0 -%113 = OpCompositeExtract %6 %108 1 -%114 = OpCompositeExtract %3 %108 2 -%115 = OpCompositeExtract %3 %108 3 -%116 = OpCompositeExtract %4 %108 4 -%117 = OpCompositeExtract %4 %108 5 -OpRayQueryInitializeKHR %110 %106 %112 %113 %116 %114 %117 %115 -%118 = OpRayQueryGetIntersectionTypeKHR %6 %110 %88 -%119 = OpIEqual %8 %118 %88 -%120 = OpSelect %6 %119 %50 %109 -%121 = OpRayQueryGetIntersectionInstanceCustomIndexKHR %6 %110 %88 -%122 = OpRayQueryGetIntersectionInstanceIdKHR %6 %110 %88 -%123 = OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR %6 %110 %88 -%124 = OpRayQueryGetIntersectionGeometryIndexKHR %6 %110 %88 -%125 = OpRayQueryGetIntersectionPrimitiveIndexKHR %6 %110 %88 -%126 = OpRayQueryGetIntersectionTKHR %3 %110 %88 -%127 = OpRayQueryGetIntersectionBarycentricsKHR %7 %110 %88 -%128 = OpRayQueryGetIntersectionFrontFaceKHR %8 %110 %88 -%129 = OpRayQueryGetIntersectionObjectToWorldKHR %9 %110 %88 -%130 = OpRayQueryGetIntersectionWorldToObjectKHR %9 %110 %88 -%131 = OpCompositeConstruct %10 %120 %126 %121 %122 %123 %124 %125 %127 %128 %129 %130 -%132 = OpCompositeExtract %6 %131 0 -%133 = OpIEqual %8 %132 %109 -%134 = OpSelect %6 %133 %50 %88 -%135 = OpAccessChain %94 %107 %88 -OpStore %135 %134 +%143 = OpFunction %2 None %124 +%142 = OpLabel +%147 = OpVariable %32 Function +%144 = OpLoad %5 %15 +%145 = OpAccessChain %126 %17 %64 +OpBranch %148 +%148 = OpLabel +%149 = OpCompositeExtract %6 %146 0 +%150 = OpCompositeExtract %6 %146 1 +%151 = OpCompositeExtract %3 %146 2 +%152 = OpCompositeExtract %3 %146 3 +%153 = OpCompositeExtract %4 %146 4 +%154 = OpCompositeExtract %4 %146 5 +OpRayQueryInitializeKHR %147 %144 %149 %150 %153 %151 %154 %152 +%155 = OpFunctionCall %10 %57 %147 +%156 = OpCompositeExtract %6 %155 0 +%157 = OpIEqual %8 %156 %80 +%158 = OpSelect %6 %157 %62 %64 +%159 = OpAccessChain %132 %145 %64 +OpStore %159 %158 OpReturn OpFunctionEnd \ No newline at end of file From a1f821b87b1b87f817bbb24903b91c5aafa92d76 Mon Sep 17 00:00:00 2001 From: Vecvec Date: Mon, 16 Dec 2024 17:56:10 +1300 Subject: [PATCH 05/11] remove old (and now unused) function --- naga/src/back/spv/ray.rs | 187 --------------------------------------- 1 file changed, 187 deletions(-) diff --git a/naga/src/back/spv/ray.rs b/naga/src/back/spv/ray.rs index b2ad77eb3f..d76bb9ecd8 100644 --- a/naga/src/back/spv/ray.rs +++ b/naga/src/back/spv/ray.rs @@ -611,191 +611,4 @@ impl BlockContext<'_> { crate::RayQueryFunction::Terminate => {} } } - - pub(super) fn write_ray_query_get_intersection( - &mut self, - query: Handle, - block: &mut Block, - is_committed: bool, - ) -> spirv::Word { - let query_id = self.cached[query]; - let intersection_id = - self.writer - .get_constant_scalar(crate::Literal::U32(if is_committed { - spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR - } else { - spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR - } as _)); - - let flag_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric( - NumericType::Scalar(crate::Scalar::U32), - ))); - let raw_kind_id = self.gen_id(); - block.body.push(Instruction::ray_query_get_intersection( - spirv::Op::RayQueryGetIntersectionTypeKHR, - flag_type_id, - raw_kind_id, - query_id, - intersection_id, - )); - let kind_id = if is_committed { - // Nothing to do: the IR value matches `spirv::RayQueryCommittedIntersectionType` - raw_kind_id - } else { - // Remap from the candidate kind to IR - let condition_id = self.gen_id(); - let committed_triangle_kind_id = self.writer.get_constant_scalar(crate::Literal::U32( - spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR - as _, - )); - block.body.push(Instruction::binary( - spirv::Op::IEqual, - self.writer.get_bool_type_id(), - condition_id, - raw_kind_id, - committed_triangle_kind_id, - )); - let kind_id = self.gen_id(); - block.body.push(Instruction::select( - flag_type_id, - kind_id, - condition_id, - self.writer.get_constant_scalar(crate::Literal::U32( - crate::RayQueryIntersection::Triangle as _, - )), - self.writer.get_constant_scalar(crate::Literal::U32( - crate::RayQueryIntersection::Aabb as _, - )), - )); - kind_id - }; - - let instance_custom_index_id = self.gen_id(); - block.body.push(Instruction::ray_query_get_intersection( - spirv::Op::RayQueryGetIntersectionInstanceCustomIndexKHR, - flag_type_id, - instance_custom_index_id, - query_id, - intersection_id, - )); - let instance_id = self.gen_id(); - block.body.push(Instruction::ray_query_get_intersection( - spirv::Op::RayQueryGetIntersectionInstanceIdKHR, - flag_type_id, - instance_id, - query_id, - intersection_id, - )); - let sbt_record_offset_id = self.gen_id(); - block.body.push(Instruction::ray_query_get_intersection( - spirv::Op::RayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR, - flag_type_id, - sbt_record_offset_id, - query_id, - intersection_id, - )); - let geometry_index_id = self.gen_id(); - block.body.push(Instruction::ray_query_get_intersection( - spirv::Op::RayQueryGetIntersectionGeometryIndexKHR, - flag_type_id, - geometry_index_id, - query_id, - intersection_id, - )); - let primitive_index_id = self.gen_id(); - block.body.push(Instruction::ray_query_get_intersection( - spirv::Op::RayQueryGetIntersectionPrimitiveIndexKHR, - flag_type_id, - primitive_index_id, - query_id, - intersection_id, - )); - - let scalar_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric( - NumericType::Scalar(crate::Scalar::F32), - ))); - let t_id = self.gen_id(); - block.body.push(Instruction::ray_query_get_intersection( - spirv::Op::RayQueryGetIntersectionTKHR, - scalar_type_id, - t_id, - query_id, - intersection_id, - )); - - let barycentrics_type_id = - self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector { - size: crate::VectorSize::Bi, - scalar: crate::Scalar::F32, - }))); - let barycentrics_id = self.gen_id(); - block.body.push(Instruction::ray_query_get_intersection( - spirv::Op::RayQueryGetIntersectionBarycentricsKHR, - barycentrics_type_id, - barycentrics_id, - query_id, - intersection_id, - )); - - let bool_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric( - NumericType::Scalar(crate::Scalar::BOOL), - ))); - let front_face_id = self.gen_id(); - block.body.push(Instruction::ray_query_get_intersection( - spirv::Op::RayQueryGetIntersectionFrontFaceKHR, - bool_type_id, - front_face_id, - query_id, - intersection_id, - )); - //Note: there is also `OpRayQueryGetIntersectionCandidateAABBOpaqueKHR`, - // but it's not a property of an intersection. - - let transform_type_id = - self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Matrix { - columns: crate::VectorSize::Quad, - rows: crate::VectorSize::Tri, - scalar: crate::Scalar::F32, - }))); - let object_to_world_id = self.gen_id(); - block.body.push(Instruction::ray_query_get_intersection( - spirv::Op::RayQueryGetIntersectionObjectToWorldKHR, - transform_type_id, - object_to_world_id, - query_id, - intersection_id, - )); - let world_to_object_id = self.gen_id(); - block.body.push(Instruction::ray_query_get_intersection( - spirv::Op::RayQueryGetIntersectionWorldToObjectKHR, - transform_type_id, - world_to_object_id, - query_id, - intersection_id, - )); - - let id = self.gen_id(); - let intersection_type_id = self.get_type_id(LookupType::Handle( - self.ir_module.special_types.ray_intersection.unwrap(), - )); - //Note: the arguments must match `generate_ray_intersection_type` layout - block.body.push(Instruction::composite_construct( - intersection_type_id, - id, - &[ - kind_id, - t_id, - instance_custom_index_id, - instance_id, - sbt_record_offset_id, - geometry_index_id, - primitive_index_id, - barycentrics_id, - front_face_id, - object_to_world_id, - world_to_object_id, - ], - )); - id - } } From ffb73fe1d558901bdd371763ff0306ab4959e7d1 Mon Sep 17 00:00:00 2001 From: Vecvec Date: Mon, 16 Dec 2024 18:14:44 +1300 Subject: [PATCH 06/11] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index beae54b7a7..c1c2b7aea6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -171,6 +171,7 @@ By @ErichDonGubler in [#6456](https://github.com/gfx-rs/wgpu/pull/6456), [#6148] - Fix crash when a texture argument is missing. By @aedm in [#6486](https://github.com/gfx-rs/wgpu/pull/6486) - Emit an error in constant evaluation, rather than crash, in certain cases where `vecN` constructors have less than N arguments. By @ErichDonGubler in [#6508](https://github.com/gfx-rs/wgpu/pull/6508). +- Stop naga causing undefined behavior when a ray query misses. By @Vecvec in [#6752](https://github.com/gfx-rs/wgpu/pull/6752). ### Testing From 37260ca2aadd5b7f594635cda016c839d9beac23 Mon Sep 17 00:00:00 2001 From: Vecvec Date: Tue, 17 Dec 2024 14:21:45 +1300 Subject: [PATCH 07/11] format --- tests/tests/ray_tracing/as_build.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/tests/ray_tracing/as_build.rs b/tests/tests/ray_tracing/as_build.rs index dce342e982..427f463795 100644 --- a/tests/tests/ray_tracing/as_build.rs +++ b/tests/tests/ray_tracing/as_build.rs @@ -1,8 +1,10 @@ use std::iter; use crate::ray_tracing::AsBuildContext; -use wgpu::{*}; -use wgpu_test::{fail, gpu_test, FailureCase, GpuTestConfiguration, TestParameters, TestingContext}; +use wgpu::*; +use wgpu_test::{ + fail, gpu_test, FailureCase, GpuTestConfiguration, TestParameters, TestingContext, +}; #[gpu_test] static UNBUILT_BLAS: GpuTestConfiguration = GpuTestConfiguration::new() From 7251a36434c5ffa16f461787058d11587099da70 Mon Sep 17 00:00:00 2001 From: Vecvec Date: Wed, 18 Dec 2024 15:31:34 +1300 Subject: [PATCH 08/11] Make what block ray t is in dependent on whether it is committed or candidate. This is because the requirements of spirv undefined behaviour require it like this. --- naga/src/back/spv/ray.rs | 53 ++++++++++++++++++++++------------------ 1 file changed, 29 insertions(+), 24 deletions(-) diff --git a/naga/src/back/spv/ray.rs b/naga/src/back/spv/ray.rs index d76bb9ecd8..e1793f5225 100644 --- a/naga/src/back/spv/ray.rs +++ b/naga/src/back/spv/ray.rs @@ -290,17 +290,6 @@ impl Writer { intersection_id, )); - let t_id = self.id_gen.next(); - not_none_block - .body - .push(Instruction::ray_query_get_intersection( - spirv::Op::RayQueryGetIntersectionTKHR, - scalar_type_id, - t_id, - query_id, - intersection_id, - )); - //Note: there is also `OpRayQueryGetIntersectionCandidateAABBOpaqueKHR`, // but it's not a property of an intersection. @@ -325,19 +314,6 @@ impl Writer { intersection_id, )); - // t - let idx_id = self.get_index_constant(1); - let access_idx = self.id_gen.next(); - not_none_block.body.push(Instruction::access_chain( - float_pointer_type_id, - access_idx, - blank_intersection_id, - &[idx_id], - )); - not_none_block - .body - .push(Instruction::store(access_idx, t_id, None)); - // instance custom index let idx_id = self.get_index_constant(2); let access_idx = self.id_gen.next(); @@ -443,6 +419,35 @@ impl Writer { let merge_label_id = self.id_gen.next(); let merge_block = Block::new(merge_label_id); + // t + { + let block = if is_committed { + &mut not_none_block + } else { + &mut tri_block + }; + let t_id = self.id_gen.next(); + block + .body + .push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionTKHR, + scalar_type_id, + t_id, + query_id, + intersection_id, + )); + let idx_id = self.get_index_constant(1); + let access_idx = self.id_gen.next(); + block.body.push(Instruction::access_chain( + float_pointer_type_id, + access_idx, + blank_intersection_id, + &[idx_id], + )); + block + .body + .push(Instruction::store(access_idx, t_id, None)); + } not_none_block.body.push(Instruction::selection_merge( merge_label_id, spirv::SelectionControl::NONE, From c4fd478483bf451a82b26bf7eead6542ec42678b Mon Sep 17 00:00:00 2001 From: Vecvec Date: Wed, 18 Dec 2024 15:34:17 +1300 Subject: [PATCH 09/11] regen snapshots --- naga/tests/out/spv/ray-query.spvasm | 56 ++++++++++++++--------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/naga/tests/out/spv/ray-query.spvasm b/naga/tests/out/spv/ray-query.spvasm index b3d01aae76..d9a9edc984 100644 --- a/naga/tests/out/spv/ray-query.spvasm +++ b/naga/tests/out/spv/ray-query.spvasm @@ -74,12 +74,12 @@ OpMemberDecorate %18 0 Offset 0 %58 = OpConstantNull %10 %62 = OpConstant %6 1 %64 = OpConstant %6 0 -%78 = OpConstant %6 2 -%80 = OpConstant %6 3 -%83 = OpConstant %6 5 -%85 = OpConstant %6 6 -%87 = OpConstant %6 9 -%89 = OpConstant %6 10 +%76 = OpConstant %6 2 +%78 = OpConstant %6 3 +%81 = OpConstant %6 5 +%83 = OpConstant %6 6 +%85 = OpConstant %6 9 +%87 = OpConstant %6 10 %96 = OpConstant %6 7 %98 = OpConstant %6 8 %106 = OpTypeFunction %4 %4 %10 @@ -109,37 +109,37 @@ OpBranchConditional %66 %67 %68 %71 = OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR %6 %59 %62 %72 = OpRayQueryGetIntersectionGeometryIndexKHR %6 %59 %62 %73 = OpRayQueryGetIntersectionPrimitiveIndexKHR %6 %59 %62 -%74 = OpRayQueryGetIntersectionTKHR %3 %59 %62 -%75 = OpRayQueryGetIntersectionObjectToWorldKHR %9 %59 %62 -%76 = OpRayQueryGetIntersectionWorldToObjectKHR %9 %59 %62 -%77 = OpAccessChain %55 %61 %62 -OpStore %77 %74 +%74 = OpRayQueryGetIntersectionObjectToWorldKHR %9 %59 %62 +%75 = OpRayQueryGetIntersectionWorldToObjectKHR %9 %59 %62 +%77 = OpAccessChain %51 %61 %76 +OpStore %77 %69 %79 = OpAccessChain %51 %61 %78 -OpStore %79 %69 -%81 = OpAccessChain %51 %61 %80 -OpStore %81 %70 -%82 = OpAccessChain %51 %61 %27 -OpStore %82 %71 +OpStore %79 %70 +%80 = OpAccessChain %51 %61 %27 +OpStore %80 %71 +%82 = OpAccessChain %51 %61 %81 +OpStore %82 %72 %84 = OpAccessChain %51 %61 %83 -OpStore %84 %72 -%86 = OpAccessChain %51 %61 %85 -OpStore %86 %73 +OpStore %84 %73 +%86 = OpAccessChain %52 %61 %85 +OpStore %86 %74 %88 = OpAccessChain %52 %61 %87 OpStore %88 %75 -%90 = OpAccessChain %52 %61 %89 -OpStore %90 %76 -%91 = OpIEqual %8 %63 %62 -OpSelectionMerge %93 None -OpBranchConditional %66 %92 %93 -%92 = OpLabel +%89 = OpIEqual %8 %63 %62 +%92 = OpRayQueryGetIntersectionTKHR %3 %59 %62 +%93 = OpAccessChain %55 %61 %62 +OpStore %93 %92 +OpSelectionMerge %91 None +OpBranchConditional %66 %90 %91 +%90 = OpLabel %94 = OpRayQueryGetIntersectionBarycentricsKHR %7 %59 %62 %95 = OpRayQueryGetIntersectionFrontFaceKHR %8 %59 %62 %97 = OpAccessChain %53 %61 %96 OpStore %97 %94 %99 = OpAccessChain %54 %61 %98 OpStore %99 %95 -OpBranch %93 -%93 = OpLabel +OpBranch %91 +%91 = OpLabel OpBranch %68 %68 = OpLabel %100 = OpLoad %10 %61 @@ -238,7 +238,7 @@ OpBranch %148 OpRayQueryInitializeKHR %147 %144 %149 %150 %153 %151 %154 %152 %155 = OpFunctionCall %10 %57 %147 %156 = OpCompositeExtract %6 %155 0 -%157 = OpIEqual %8 %156 %80 +%157 = OpIEqual %8 %156 %78 %158 = OpSelect %6 %157 %62 %64 %159 = OpAccessChain %132 %145 %64 OpStore %159 %158 From 7ff13b9794965933d8e122e30d92b4ca178298ab Mon Sep 17 00:00:00 2001 From: Vecvec Date: Thu, 19 Dec 2024 07:31:10 +1300 Subject: [PATCH 10/11] rename --- etc/specs/ray_tracing.md | 6 +++--- examples/src/ray_cube_compute/shader.wgsl | 4 ++-- examples/src/ray_scene/shader.wgsl | 8 ++++---- naga/src/back/spv/ray.rs | 12 ++++++------ naga/src/front/type_gen.rs | 4 ++-- naga/tests/in/ray-query.wgsl | 4 ++-- player/src/lib.rs | 2 +- tests/tests/ray_tracing/shader.wgsl | 8 ++++---- wgpu-core/src/command/ray_tracing.rs | 8 ++++---- wgpu-core/src/ray_tracing.rs | 4 ++-- wgpu-hal/examples/ray-traced-triangle/main.rs | 2 +- wgpu-hal/src/lib.rs | 2 +- wgpu-hal/src/vulkan/device.rs | 2 +- wgpu-hal/src/vulkan/mod.rs | 2 +- wgpu/src/api/blas.rs | 8 ++++---- wgpu/src/backend/wgpu_core.rs | 2 +- 16 files changed, 39 insertions(+), 39 deletions(-) diff --git a/etc/specs/ray_tracing.md b/etc/specs/ray_tracing.md index b0b50fce56..e5969babc6 100644 --- a/etc/specs/ray_tracing.md +++ b/etc/specs/ray_tracing.md @@ -109,11 +109,11 @@ struct RayIntersection { kind: u32, // Distance from starting point, measured in units of `RayDesc::dir`. t: f32, - // Corresponds to `instance.custom_index` where `instance` is the `TlasInstance` + // Corresponds to `instance.custom_data` where `instance` is the `TlasInstance` // that the intersected object was contained in. - instance_custom_index: u32, + instance_custom_data: u32, // The index into the `TlasPackage` to get the `TlasInstance` that the hit object is in - instance_id: u32, + instance_index: u32, // The offset into the shader binding table. Currently, this value is always 0. sbt_record_offset: u32, // The index into the `Blas`'s build descriptor (e.g. if `BlasBuildEntry::geometry` is diff --git a/examples/src/ray_cube_compute/shader.wgsl b/examples/src/ray_cube_compute/shader.wgsl index 79ee7ad7e5..cba6e1f848 100644 --- a/examples/src/ray_cube_compute/shader.wgsl +++ b/examples/src/ray_cube_compute/shader.wgsl @@ -29,8 +29,8 @@ struct RayDesc { struct RayIntersection { kind: u32, t: f32, - instance_custom_index: u32, - instance_id: u32, + instance_custom_data: u32, + instance_index: u32, sbt_record_offset: u32, geometry_index: u32, primitive_index: u32, diff --git a/examples/src/ray_scene/shader.wgsl b/examples/src/ray_scene/shader.wgsl index 4e16bd9453..496125ea5c 100644 --- a/examples/src/ray_scene/shader.wgsl +++ b/examples/src/ray_scene/shader.wgsl @@ -52,8 +52,8 @@ struct RayDesc { struct RayIntersection { kind: u32, t: f32, - instance_custom_index: u32, - instance_id: u32, + instance_custom_data: u32, + instance_index: u32, sbt_record_offset: u32, geometry_index: u32, primitive_index: u32, @@ -131,7 +131,7 @@ fn fs_main(vertex: VertexOutput) -> @location(0) vec4 { let intersection = rayQueryGetCommittedIntersection(&rq); if (intersection.kind != RAY_QUERY_INTERSECTION_NONE) { - let instance = instances[intersection.instance_custom_index]; + let instance = instances[intersection.instance_custom_data]; let geometry = geometries[intersection.geometry_index + instance.first_geometry]; let index_offset = geometry.first_index; @@ -155,7 +155,7 @@ fn fs_main(vertex: VertexOutput) -> @location(0) vec4 { color = vec4(material.albedo, 1.0); - if(intersection.instance_custom_index == 1u){ + if(intersection.instance_custom_data == 1u){ color = vec4(normal, 1.0); } } diff --git a/naga/src/back/spv/ray.rs b/naga/src/back/spv/ray.rs index e1793f5225..907f4f9c2e 100644 --- a/naga/src/back/spv/ray.rs +++ b/naga/src/back/spv/ray.rs @@ -239,23 +239,23 @@ impl Writer { Instruction::branch_conditional(not_none_comp_id, not_none_label_id, final_label_id), ); - let instance_custom_index_id = self.id_gen.next(); + let instance_custom_data_id = self.id_gen.next(); not_none_block .body .push(Instruction::ray_query_get_intersection( spirv::Op::RayQueryGetIntersectionInstanceCustomIndexKHR, flag_type_id, - instance_custom_index_id, + instance_custom_data_id, query_id, intersection_id, )); - let instance_id = self.id_gen.next(); + let instance_index_id = self.id_gen.next(); not_none_block .body .push(Instruction::ray_query_get_intersection( spirv::Op::RayQueryGetIntersectionInstanceIdKHR, flag_type_id, - instance_id, + instance_index_id, query_id, intersection_id, )); @@ -325,7 +325,7 @@ impl Writer { )); not_none_block.body.push(Instruction::store( access_idx, - instance_custom_index_id, + instance_custom_data_id, None, )); @@ -340,7 +340,7 @@ impl Writer { )); not_none_block .body - .push(Instruction::store(access_idx, instance_id, None)); + .push(Instruction::store(access_idx, instance_index_id, None)); let idx_id = self.get_index_constant(4); let access_idx = self.id_gen.next(); diff --git a/naga/src/front/type_gen.rs b/naga/src/front/type_gen.rs index 1cd9f7f378..26c1099514 100644 --- a/naga/src/front/type_gen.rs +++ b/naga/src/front/type_gen.rs @@ -180,13 +180,13 @@ impl crate::Module { offset: 4, }, crate::StructMember { - name: Some("instance_custom_index".to_string()), + name: Some("instance_custom_data".to_string()), ty: ty_flag, binding: None, offset: 8, }, crate::StructMember { - name: Some("instance_id".to_string()), + name: Some("instance_index".to_string()), ty: ty_flag, binding: None, offset: 12, diff --git a/naga/tests/in/ray-query.wgsl b/naga/tests/in/ray-query.wgsl index 9f94356b83..0ed1606c05 100644 --- a/naga/tests/in/ray-query.wgsl +++ b/naga/tests/in/ray-query.wgsl @@ -28,8 +28,8 @@ struct RayDesc { struct RayIntersection { kind: u32, t: f32, - instance_custom_index: u32, - instance_id: u32, + instance_custom_data: u32, + instance_index: u32, sbt_record_offset: u32, geometry_index: u32, primitive_index: u32, diff --git a/player/src/lib.rs b/player/src/lib.rs index af82168ae4..e4cf5cca94 100644 --- a/player/src/lib.rs +++ b/player/src/lib.rs @@ -191,7 +191,7 @@ impl GlobalPlay for wgc::global::Global { .map(|instance| wgc::ray_tracing::TlasInstance { blas_id: instance.blas_id, transform: &instance.transform, - custom_index: instance.custom_index, + custom_data: instance.custom_data, mask: instance.mask, }) }); diff --git a/tests/tests/ray_tracing/shader.wgsl b/tests/tests/ray_tracing/shader.wgsl index ddb3505e75..2130b8d9ae 100644 --- a/tests/tests/ray_tracing/shader.wgsl +++ b/tests/tests/ray_tracing/shader.wgsl @@ -4,8 +4,8 @@ var acc_struct: acceleration_structure; struct Intersection { kind: u32, t: f32, - instance_custom_index: u32, - instance_id: u32, + instance_custom_data: u32, + instance_index: u32, sbt_record_offset: u32, geometry_index: u32, primitive_index: u32, @@ -38,8 +38,8 @@ fn all_of_struct() { out = Intersection( intersection.kind, intersection.t, - intersection.instance_custom_index, - intersection.instance_id, + intersection.instance_custom_data, + intersection.instance_index, intersection.sbt_record_offset, intersection.geometry_index, intersection.primitive_index, diff --git a/wgpu-core/src/command/ray_tracing.rs b/wgpu-core/src/command/ray_tracing.rs index 01fe891575..093e1dcc2b 100644 --- a/wgpu-core/src/command/ray_tracing.rs +++ b/wgpu-core/src/command/ray_tracing.rs @@ -430,7 +430,7 @@ impl Global { instance.map(|instance| TraceTlasInstance { blas_id: instance.blas_id, transform: *instance.transform, - custom_index: instance.custom_index, + custom_data: instance.custom_data, mask: instance.mask, }) }) @@ -478,7 +478,7 @@ impl Global { instance.as_ref().map(|instance| TlasInstance { blas_id: instance.blas_id, transform: &instance.transform, - custom_index: instance.custom_index, + custom_data: instance.custom_data, mask: instance.mask, }) }); @@ -553,7 +553,7 @@ impl Global { let mut instance_count = 0; for instance in package.instances.flatten() { - if instance.custom_index >= (1u32 << 24u32) { + if instance.custom_data >= (1u32 << 24u32) { return Err(BuildAccelerationStructureError::TlasInvalidCustomIndex( tlas.error_ident(), )); @@ -570,7 +570,7 @@ impl Global { instance_buffer_staging_source.extend(device.raw().tlas_instance_to_bytes( hal::TlasInstance { transform: *instance.transform, - custom_index: instance.custom_index, + custom_data: instance.custom_data, mask: instance.mask, blas_address: blas.handle, }, diff --git a/wgpu-core/src/ray_tracing.rs b/wgpu-core/src/ray_tracing.rs index 9f4a11946d..4b15dc91dd 100644 --- a/wgpu-core/src/ray_tracing.rs +++ b/wgpu-core/src/ray_tracing.rs @@ -197,7 +197,7 @@ pub struct TlasBuildEntry { pub struct TlasInstance<'a> { pub blas_id: BlasId, pub transform: &'a [f32; 12], - pub custom_index: u32, + pub custom_data: u32, pub mask: u8, } @@ -265,7 +265,7 @@ pub struct TraceBlasBuildEntry { pub struct TraceTlasInstance { pub blas_id: BlasId, pub transform: [f32; 12], - pub custom_index: u32, + pub custom_data: u32, pub mask: u8, } diff --git a/wgpu-hal/examples/ray-traced-triangle/main.rs b/wgpu-hal/examples/ray-traced-triangle/main.rs index b81ef86525..b736c7d06b 100644 --- a/wgpu-hal/examples/ray-traced-triangle/main.rs +++ b/wgpu-hal/examples/ray-traced-triangle/main.rs @@ -32,7 +32,7 @@ impl std::fmt::Debug for AccelerationStructureInstance { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Instance") .field("transform", &self.transform) - .field("custom_index()", &self.custom_index()) + .field("custom_data()", &self.custom_index()) .field("mask()", &self.mask()) .field( "shader_binding_table_record_offset()", diff --git a/wgpu-hal/src/lib.rs b/wgpu-hal/src/lib.rs index 12234d6364..17c5db5acc 100644 --- a/wgpu-hal/src/lib.rs +++ b/wgpu-hal/src/lib.rs @@ -2528,7 +2528,7 @@ pub struct AccelerationStructureBarrier { #[derive(Debug, Copy, Clone)] pub struct TlasInstance { pub transform: [f32; 12], - pub custom_index: u32, + pub custom_data: u32, pub mask: u8, pub blas_address: u64, } diff --git a/wgpu-hal/src/vulkan/device.rs b/wgpu-hal/src/vulkan/device.rs index 4a342fcfa1..ed303a84bb 100644 --- a/wgpu-hal/src/vulkan/device.rs +++ b/wgpu-hal/src/vulkan/device.rs @@ -2553,7 +2553,7 @@ impl crate::Device for super::Device { const MAX_U24: u32 = (1u32 << 24u32) - 1u32; let temp = RawTlasInstance { transform: instance.transform, - custom_index_and_mask: (instance.custom_index & MAX_U24) + custom_data_and_mask: (instance.custom_data & MAX_U24) | (u32::from(instance.mask) << 24), shader_binding_table_record_offset_and_flags: 0, acceleration_structure_reference: instance.blas_address, diff --git a/wgpu-hal/src/vulkan/mod.rs b/wgpu-hal/src/vulkan/mod.rs index 83a6b7e903..0b9da3ea9a 100644 --- a/wgpu-hal/src/vulkan/mod.rs +++ b/wgpu-hal/src/vulkan/mod.rs @@ -1452,7 +1452,7 @@ fn get_lost_err() -> crate::DeviceError { #[repr(C)] struct RawTlasInstance { transform: [f32; 12], - custom_index_and_mask: u32, + custom_data_and_mask: u32, shader_binding_table_record_offset_and_flags: u32, acceleration_structure_reference: u64, } diff --git a/wgpu/src/api/blas.rs b/wgpu/src/api/blas.rs index b64c01ba8f..409e90eb77 100644 --- a/wgpu/src/api/blas.rs +++ b/wgpu/src/api/blas.rs @@ -51,7 +51,7 @@ pub struct TlasInstance { /// /// This must only use the lower 24 bits, if any bits are outside that range (byte 4 does not equal 0) the TlasInstance becomes /// invalid and generates a validation error when built - pub custom_index: u32, + pub custom_data: u32, /// Mask for the instance used inside the shader to filter instances. /// Reports hit only if `(shader_cull_mask & tlas_instance.mask) != 0u`. pub mask: u8, @@ -61,7 +61,7 @@ impl TlasInstance { /// Construct TlasInstance. /// - blas: Reference to the bottom level acceleration structure /// - transform: Transform buffer offset in bytes (optional, required if transform buffer is present) - /// - custom_index: Custom index for the instance used inside the shader (max 24 bits) + /// - custom_data: Custom index for the instance used inside the shader (max 24 bits) /// - mask: Mask for the instance used inside the shader to filter instances /// /// Note: while one of these contains a reference to a BLAS that BLAS will not be dropped, @@ -69,11 +69,11 @@ impl TlasInstance { /// TlasInstance(s) will immediately make them invalid. If one or more of those invalid /// TlasInstances is inside a TlasPackage that is attempted to be built, the build will /// generate a validation error. - pub fn new(blas: &Blas, transform: [f32; 12], custom_index: u32, mask: u8) -> Self { + pub fn new(blas: &Blas, transform: [f32; 12], custom_data: u32, mask: u8) -> Self { Self { blas: blas.shared.clone(), transform, - custom_index, + custom_data, mask, } } diff --git a/wgpu/src/backend/wgpu_core.rs b/wgpu/src/backend/wgpu_core.rs index 8a7c8e1c31..a190c5f3b5 100644 --- a/wgpu/src/backend/wgpu_core.rs +++ b/wgpu/src/backend/wgpu_core.rs @@ -2530,7 +2530,7 @@ impl dispatch::CommandEncoderInterface for CoreCommandEncoder { .map(|instance| wgc::ray_tracing::TlasInstance { blas_id: instance.blas.inner.as_core().id, transform: &instance.transform, - custom_index: instance.custom_index, + custom_data: instance.custom_data, mask: instance.mask, }) }); From b3a555d779c394ee2d2fab680adcacacfd778b86 Mon Sep 17 00:00:00 2001 From: Vecvec Date: Thu, 19 Dec 2024 08:32:23 +1300 Subject: [PATCH 11/11] regen snapshots & fmt --- naga/src/back/spv/ray.rs | 20 ++++++++------------ naga/tests/out/msl/ray-query.msl | 4 ++-- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/naga/src/back/spv/ray.rs b/naga/src/back/spv/ray.rs index 907f4f9c2e..f0b7bacd1e 100644 --- a/naga/src/back/spv/ray.rs +++ b/naga/src/back/spv/ray.rs @@ -427,15 +427,13 @@ impl Writer { &mut tri_block }; let t_id = self.id_gen.next(); - block - .body - .push(Instruction::ray_query_get_intersection( - spirv::Op::RayQueryGetIntersectionTKHR, - scalar_type_id, - t_id, - query_id, - intersection_id, - )); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionTKHR, + scalar_type_id, + t_id, + query_id, + intersection_id, + )); let idx_id = self.get_index_constant(1); let access_idx = self.id_gen.next(); block.body.push(Instruction::access_chain( @@ -444,9 +442,7 @@ impl Writer { blank_intersection_id, &[idx_id], )); - block - .body - .push(Instruction::store(access_idx, t_id, None)); + block.body.push(Instruction::store(access_idx, t_id, None)); } not_none_block.body.push(Instruction::selection_merge( merge_label_id, diff --git a/naga/tests/out/msl/ray-query.msl b/naga/tests/out/msl/ray-query.msl index b8230fb2e8..2bb6fab28a 100644 --- a/naga/tests/out/msl/ray-query.msl +++ b/naga/tests/out/msl/ray-query.msl @@ -16,8 +16,8 @@ constexpr metal::uint _map_intersection_type(const metal::raytracing::intersecti struct RayIntersection { uint kind; float t; - uint instance_custom_index; - uint instance_id; + uint instance_custom_data; + uint instance_index; uint sbt_record_offset; uint geometry_index; uint primitive_index;