Skip to content
This repository has been archived by the owner on Jan 29, 2025. It is now read-only.

spv-out: implement OpArrayLength on array buffer bindings #2372

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 99 additions & 17 deletions src/back/spv/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ Bounds-checking for SPIR-V output.
*/

use super::{
helpers::global_needs_wrapper, selection::Selection, Block, BlockContext, Error, IdGenerator,
Instruction, Word,
helpers::{global_needs_wrapper, map_storage_class},
selection::Selection,
Block, BlockContext, Error, IdGenerator, Instruction, Word,
};
use crate::{arena::Handle, proc::BoundsCheckPolicy};

Expand Down Expand Up @@ -42,32 +43,113 @@ impl<'w> BlockContext<'w> {
array: Handle<crate::Expression>,
block: &mut Block,
) -> Result<Word, Error> {
// Naga IR permits runtime-sized arrays as global variables or as the
// final member of a struct that is a global variable. SPIR-V permits
// only the latter, so this back end wraps bare runtime-sized arrays
// in a made-up struct; see `helpers::global_needs_wrapper` and its uses.
// This code must handle both cases.
let (structure_id, last_member_index) = match self.ir_function.expressions[array] {
// Naga IR permits runtime-sized arrays as global variables, or as the
// final member of a struct that is a global variable, or one of these
// inside a buffer that is itself an element in a buffer bindings array.
// SPIR-V requires that runtime-sized arrays are wrapped in structs.
// See `helpers::global_needs_wrapper` and its uses.
let (opt_array_index_id, global_handle, opt_last_member_index) = match self
.ir_function
.expressions[array]
{
crate::Expression::AccessIndex { base, index } => {
match self.ir_function.expressions[base] {
crate::Expression::GlobalVariable(handle) => (
self.writer.global_variables[handle.index()].access_id,
index,
),
_ => return Err(Error::Validation("array length expression")),
// The global variable is an array of buffer bindings of structs,
// we are accessing one of them with a static index,
// and the last member of it.
crate::Expression::AccessIndex {
base: base_outer,
index: index_outer,
} => match self.ir_function.expressions[base_outer] {
crate::Expression::GlobalVariable(handle) => {
let index_id = self.get_index_constant(index_outer);
(Some(index_id), handle, Some(index))
}
_ => return Err(Error::Validation("array length expression case-1a")),
},
// The global variable is an array of buffer bindings of structs,
// we are accessing one of them with a dynamic index,
// and the last member of it.
crate::Expression::Access {
base: base_outer,
index: index_outer,
} => match self.ir_function.expressions[base_outer] {
crate::Expression::GlobalVariable(handle) => {
let index_id = self.cached[index_outer];
(Some(index_id), handle, Some(index))
}
_ => return Err(Error::Validation("array length expression case-1b")),
},
// The global variable is a buffer, and we are accessing the last member.
crate::Expression::GlobalVariable(handle) => {
let global = &self.ir_module.global_variables[handle];
match self.ir_module.types[global.ty].inner {
// The global variable is an array of buffer bindings of run-time arrays.
crate::TypeInner::BindingArray { .. } => (Some(index), handle, None),
// The global variable is a struct, and we are accessing the last member
_ => (None, handle, Some(index)),
}
}
_ => return Err(Error::Validation("array length expression case-1c")),
}
}
// The global variable is an array of buffer bindings of arrays.
crate::Expression::Access { base, index } => match self.ir_function.expressions[base] {
crate::Expression::GlobalVariable(handle) => {
let index_id = self.cached[index];
let global = &self.ir_module.global_variables[handle];
match self.ir_module.types[global.ty].inner {
crate::TypeInner::BindingArray { .. } => (Some(index_id), handle, None),
_ => return Err(Error::Validation("array length expression case-2a")),
}
}
_ => return Err(Error::Validation("array length expression case-2b")),
},
// The global variable is a run-time array.
crate::Expression::GlobalVariable(handle) => {
let global = &self.ir_module.global_variables[handle];
if !global_needs_wrapper(self.ir_module, global) {
return Err(Error::Validation("array length expression"));
return Err(Error::Validation("array length expression case-3"));
}

(self.writer.global_variables[handle.index()].var_id, 0)
(None, handle, None)
}
_ => return Err(Error::Validation("array length expression")),
_ => return Err(Error::Validation("array length expression case-4")),
};

let gvar = self.writer.global_variables[global_handle.index()].clone();
let global = &self.ir_module.global_variables[global_handle];
let (last_member_index, gvar_id) = match opt_last_member_index {
Some(index) => (index, gvar.access_id),
None => {
if !global_needs_wrapper(self.ir_module, global) {
return Err(Error::Validation(
"pointer to a global that is not a wrapped array",
));
}
(0, gvar.var_id)
}
};
let structure_id = match opt_array_index_id {
// We are indexing inside a binding array, generate the access op.
Some(index_id) => {
let element_type_id = match self.ir_module.types[global.ty].inner {
crate::TypeInner::BindingArray { base, size: _ } => {
let class = map_storage_class(global.space);
self.get_pointer_id(base, class)?
}
_ => return Err(Error::Validation("array length expression case-5")),
};
let structure_id = self.gen_id();
block.body.push(Instruction::access_chain(
element_type_id,
structure_id,
gvar_id,
&[index_id],
));
structure_id
}
None => gvar_id,
};
let length_id = self.gen_id();
block.body.push(Instruction::array_length(
self.writer.get_uint_type_id(),
Expand Down
9 changes: 9 additions & 0 deletions src/back/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,15 @@ impl BlockContext<'_> {
self.writer
.get_constant_scalar(crate::Literal::I32(scope as _))
}

fn get_pointer_id(
&mut self,
handle: Handle<crate::Type>,
class: spirv::StorageClass,
) -> Result<Word, Error> {
self.writer
.get_pointer_id(&self.ir_module.types, handle, class)
}
}

#[derive(Clone, Copy, Default)]
Expand Down
62 changes: 32 additions & 30 deletions src/back/spv/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -599,36 +599,38 @@ impl Writer {
// Handle globals are pre-emitted and should be loaded automatically.
//
// Any that are binding arrays we skip as we cannot load the array, we must load the result after indexing.
let is_binding_array = match ir_module.types[var.ty].inner {
crate::TypeInner::BindingArray { .. } => true,
_ => false,
};

if var.space == crate::AddressSpace::Handle && !is_binding_array {
let var_type_id = self.get_type_id(LookupType::Handle(var.ty));
let id = self.id_gen.next();
prelude
.body
.push(Instruction::load(var_type_id, id, gv.var_id, None));
gv.access_id = gv.var_id;
gv.handle_id = id;
} else if global_needs_wrapper(ir_module, var) {
let class = map_storage_class(var.space);
let pointer_type_id = self.get_pointer_id(&ir_module.types, var.ty, class)?;
let index_id = self.get_index_constant(0);

let id = self.id_gen.next();
prelude.body.push(Instruction::access_chain(
pointer_type_id,
id,
gv.var_id,
&[index_id],
));
gv.access_id = id;
} else {
// by default, the variable ID is accessed as is
gv.access_id = gv.var_id;
};
match ir_module.types[var.ty].inner {
crate::TypeInner::BindingArray { .. } => {
gv.access_id = gv.var_id;
}
_ => {
if var.space == crate::AddressSpace::Handle {
let var_type_id = self.get_type_id(LookupType::Handle(var.ty));
let id = self.id_gen.next();
prelude
.body
.push(Instruction::load(var_type_id, id, gv.var_id, None));
gv.access_id = gv.var_id;
gv.handle_id = id;
} else if global_needs_wrapper(ir_module, var) {
let class = map_storage_class(var.space);
let pointer_type_id =
self.get_pointer_id(&ir_module.types, var.ty, class)?;
let index_id = self.get_index_constant(0);
let id = self.id_gen.next();
prelude.body.push(Instruction::access_chain(
pointer_type_id,
id,
gv.var_id,
&[index_id],
));
gv.access_id = id;
} else {
// by default, the variable ID is accessed as is
gv.access_id = gv.var_id;
};
}
}

// work around borrow checking in the presence of `self.xxx()` calls
self.global_variables[handle.index()] = gv;
Expand Down
2 changes: 1 addition & 1 deletion src/valid/type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,6 @@ impl super::Validator {
ti.uniform_layout = Ok(Alignment::MIN_UNIFORM);

let mut min_offset = 0;

let mut prev_struct_data: Option<(u32, u32)> = None;

for (i, member) in members.iter().enumerate() {
Expand Down Expand Up @@ -585,6 +584,7 @@ impl super::Validator {
// Currently Naga only supports binding arrays of structs for non-handle types.
match gctx.types[base].inner {
crate::TypeInner::Struct { .. } => {}
crate::TypeInner::Array { .. } => {}
_ => return Err(TypeError::BindingArrayBaseTypeNotStruct(base)),
};
}
Expand Down
6 changes: 5 additions & 1 deletion tests/in/binding-buffer-arrays.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ struct UniformIndex {
index: u32
}

struct Foo { x: u32 }
struct Foo { x: u32, far: array<i32> }
@group(0) @binding(0)
var<storage, read> storage_array: binding_array<Foo, 1>;
@group(0) @binding(10)
Expand All @@ -23,5 +23,9 @@ fn main(fragment_in: FragmentIn) -> @location(0) u32 {
u1 += storage_array[uniform_index].x;
u1 += storage_array[non_uniform_index].x;

u1 += arrayLength(&storage_array[0].far);
kvark marked this conversation as resolved.
Show resolved Hide resolved
u1 += arrayLength(&storage_array[uniform_index].far);
u1 += arrayLength(&storage_array[non_uniform_index].far);

return u1;
}
Loading