Skip to content

Commit

Permalink
[naga spv-out] Spill arrays and matrices for runtime indexing.
Browse files Browse the repository at this point in the history
Improve handling of `Access` expressions whose base is an array or
matrix (not a pointer to such), and whose index is not known at
compile time. SPIR-V does not have instructions that can do this
directly, so spill such values to temporary variables, and perform the
accesses using `OpAccessChain` instructions applied to the
temporaries.

When performing chains of accesses like `a[i].x[j]`, do not reify
intermediate values; generate a single `OpAccessIndex` for the entire
thing.

Remove special cases for arrays; the same code now handles arrays and
matrices.

Update validation to permit dynamic indexing of matrices.

For details, see the comments on the new tracking structures in
`naga::back::spv::Function`.

Add snapshot test `index-by-value.wgsl`.

Fixes #6358.
Alternative to #6362.
  • Loading branch information
jimblandy committed Oct 10, 2024
1 parent ebd0ed5 commit e479b15
Show file tree
Hide file tree
Showing 12 changed files with 846 additions and 113 deletions.
6 changes: 6 additions & 0 deletions naga/src/arena/handle_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ impl<T> HandleSet<T> {
}
}

impl<T> Default for HandleSet<T> {
fn default() -> Self {
Self::new()
}
}

pub trait ArenaType<T> {
fn len(&self) -> usize;
}
Expand Down
153 changes: 123 additions & 30 deletions naga/src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,7 @@ use super::{
index::BoundsCheckResult, selection::Selection, Block, BlockContext, Dimension, Error,
Instruction, LocalType, LookupType, NumericType, ResultMember, Writer, WriterFlags,
};
use crate::{
arena::Handle,
proc::{index::GuardedIndex, TypeResolution},
Statement,
};
use crate::{arena::Handle, proc::index::GuardedIndex, Statement};
use spirv::Word;

fn get_dimension(type_inner: &crate::TypeInner) -> Dimension {
Expand Down Expand Up @@ -345,34 +341,48 @@ impl<'w> BlockContext<'w> {
// that actually dereferences the pointer.
0
}
_ if self.function.spilled_accesses.contains(base) => {
// As far as Naga IR is concerned, this expression does not yield
// a pointer, but we spilled it to a temporary variable.

// The base expression is something we spilled to a temporary
// variable, so mark this access as spilled as well.
self.function.spilled_accesses.insert(expr_handle);
self.maybe_access_spilled_composite(expr_handle, block, result_type_id)?
}
crate::TypeInner::Vector { .. } => {
self.write_vector_access(expr_handle, base, index, block)?
}
crate::TypeInner::Array {
base: ty_element, ..
} => {
let index_id = self.cached[index];
let base_id = self.cached[base];
let base_ty = match self.fun_info[base].ty {
TypeResolution::Handle(handle) => handle,
TypeResolution::Value(_) => {
return Err(Error::Validation(
"Array types should always be in the arena",
))
crate::TypeInner::Array { .. } | crate::TypeInner::Matrix { .. } => {
// See if `index` is known at compile time.
match GuardedIndex::from_expression(index, self.ir_function, self.ir_module)
{
GuardedIndex::Known(value) => {
// If `index` is known, we can just use `OpCompositeExtract`.
//
// We never need bounds checks for these cases: everything
// size is statically known and checked in validation.
let id = self.gen_id();
let base_id = self.cached[base];
block.body.push(Instruction::composite_extract(
result_type_id,
id,
base_id,
&[value],
));
id
}
};
let (id, variable) = self.writer.promote_access_expression_to_variable(
result_type_id,
base_id,
base_ty,
index_id,
ty_element,
block,
)?;
self.function.internal_variables.push(variable);
id
GuardedIndex::Expression(_) => {
self.spill_to_internal_variable(base, block);
self.function.spilled_accesses.insert(expr_handle);
self.maybe_access_spilled_composite(
expr_handle,
block,
result_type_id,
)?
}
}
}
// wgpu#4337: Support `crate::TypeInner::Matrix`
crate::TypeInner::BindingArray {
base: binding_type, ..
} => {
Expand Down Expand Up @@ -435,6 +445,15 @@ impl<'w> BlockContext<'w> {
// that actually dereferences the pointer.
0
}
_ if self.function.spilled_accesses.contains(base) => {
// As far as Naga IR is concerned, this expression does not yield
// a pointer, but we spilled it to a temporary variable.

// The base expression is something we spilled to a temporary
// variable, so mark this access as spilled as well.
self.function.spilled_accesses.insert(expr_handle);
self.maybe_access_spilled_composite(expr_handle, block, result_type_id)?
}
crate::TypeInner::Vector { .. }
| crate::TypeInner::Matrix { .. }
| crate::TypeInner::Array { .. }
Expand Down Expand Up @@ -1390,7 +1409,7 @@ impl<'w> BlockContext<'w> {
}
crate::Expression::LocalVariable(variable) => self.function.variables[&variable].id,
crate::Expression::Load { pointer } => {
self.write_checked_load(pointer, block, result_type_id)?
self.write_checked_load(pointer, block, AccessTypeAdjustment::None, result_type_id)?
}
crate::Expression::FunctionArgument(index) => self.function.parameter_id(index),
crate::Expression::CallResult(_)
Expand Down Expand Up @@ -1737,6 +1756,14 @@ impl<'w> BlockContext<'w> {

self.temp_list.clear();
let root_id = loop {
// If `expr_handle` was spilled, then the temporary variable has exactly
// the value we want to start from.
if let Some(spilled) = self.function.spilled_composites.get(&expr_handle) {
// The root id of the `OpAccessChain` instruction is the temporary
// variable we spilled the composite to.
break spilled.id;
}

expr_handle = match self.ir_function.expressions[expr_handle] {
crate::Expression::Access { base, index } => {
is_non_uniform_binding_array |=
Expand Down Expand Up @@ -1931,9 +1958,10 @@ impl<'w> BlockContext<'w> {
&mut self,
pointer: Handle<crate::Expression>,
block: &mut Block,
access_type_adjustment: AccessTypeAdjustment,
result_type_id: Word,
) -> Result<Word, Error> {
match self.write_expression_pointer(pointer, block, AccessTypeAdjustment::None)? {
match self.write_expression_pointer(pointer, block, access_type_adjustment)? {
ExpressionPointer::Ready { pointer_id } => {
let id = self.gen_id();
let atomic_space =
Expand Down Expand Up @@ -1988,6 +2016,71 @@ impl<'w> BlockContext<'w> {
}
}

fn spill_to_internal_variable(&mut self, base: Handle<crate::Expression>, block: &mut Block) {
// Generate an internal variable of the appropriate type for `base`.
let variable_id = self.writer.id_gen.next();
let pointer_type_id = self
.writer
.get_resolution_pointer_id(&self.fun_info[base].ty, spirv::StorageClass::Function);
let variable = super::LocalVariable {
id: variable_id,
instruction: Instruction::variable(
pointer_type_id,
variable_id,
spirv::StorageClass::Function,
None,
),
};

let base_id = self.cached[base];
block
.body
.push(Instruction::store(variable.id, base_id, None));
self.function.spilled_composites.insert(base, variable);
}

/// Generate an access to a spilled temporary, if necessary.
///
/// Given `access`, an [`Access`] or [`AccessIndex`] expression that refers
/// to a component of a composite value that has been spilled to a temporary
/// variable, determine whether other expressions are going to use
/// `access`'s value:
///
/// - If so, perform the access and cache that as the value of `access`.
///
/// - Otherwise, generate no code and cache no value for `access`.
///
/// Return `Ok(0)` if no value was fetched, or `Ok(id)` if we loaded it into
/// the instruction given by `id`.
///
/// [`Access`]: crate::Expression::Access
/// [`AccessIndex`]: crate::Expression::AccessIndex
fn maybe_access_spilled_composite(
&mut self,
access: Handle<crate::Expression>,
block: &mut Block,
result_type_id: Word,
) -> Result<Word, Error> {
let access_uses = self.function.access_uses.get(&access).map_or(0, |r| *r);
if access_uses == self.fun_info[access].ref_count {
// This expression is only used by other `Access` and
// `AccessIndex` expressions, so we don't need to cache a
// value for it yet.
Ok(0)
} else {
// There are other expressions that are going to expect this
// expression's value to be cached, not just other `Access` or
// `AccessIndex` expressions. We must actually perform the
// access on the spill variable now.
self.write_checked_load(
access,
block,
AccessTypeAdjustment::IntroducePointer(spirv::StorageClass::Function),
result_type_id,
)
}
}

/// Build the instructions for matrix - matrix column operations
#[allow(clippy::too_many_arguments)]
fn write_matrix_matrix_column_op(
Expand Down
36 changes: 35 additions & 1 deletion naga/src/back/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,41 @@ struct Function {
signature: Option<Instruction>,
parameters: Vec<FunctionArgument>,
variables: crate::FastHashMap<Handle<crate::LocalVariable>, LocalVariable>,
internal_variables: Vec<LocalVariable>,

/// A map taking an expression that yields a composite value (array, matrix)
/// to the temporary variables we have spilled it to, if any. Spilling
/// allows us to render an arbitrary chain of [`Access`] and [`AccessIndex`]
/// expressions as an `OpAccessChain` and an `OpLoad` (plus bounds checks).
/// This supports dynamic indexing of by-value arrays and matrices, which
/// SPIR-V does not.
///
/// [`Access`]: crate::Expression::Access
/// [`AccessIndex`]: crate::Expression::AccessIndex
spilled_composites: crate::FastIndexMap<Handle<crate::Expression>, LocalVariable>,

/// A set of expressions that are either in [`spilled_composites`] or refer
/// to some component/element of such.
///
/// [`spilled_composites`]: Function::spilled_composites
spilled_accesses: crate::arena::HandleSet<crate::Expression>,

/// A map taking each expression to the number of [`Access`] and
/// [`AccessIndex`] expressions that uses it as a base value. If an
/// expression has no entry, its count is zero: it is never used as a
/// [`Access`] or [`AccessIndex`] base.
///
/// We use this, together with [`ExpressionInfo::ref_count`], to recognize
/// the tips of chains of [`Access`] and [`AccessIndex`] expressions that
/// access spilled values --- expressions in [`spilled_composites`]. We
/// defer generating code for the chain until we reach its tip, so we can
/// handle it with a single instruction.
///
/// [`Access`]: crate::Expression::Access
/// [`AccessIndex`]: crate::Expression::AccessIndex
/// [`ExpressionInfo::ref_count`]: crate::valid::ExpressionInfo
/// [`spilled_composites`]: Function::spilled_composites
access_uses: crate::FastHashMap<Handle<crate::Expression>, usize>,

blocks: Vec<TerminatedBlock>,
entry_point_context: Option<EntryPointContext>,
}
Expand Down
66 changes: 14 additions & 52 deletions naga/src/back/spv/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ impl Function {
for local_var in self.variables.values() {
local_var.instruction.to_words(sink);
}
for internal_var in self.internal_variables.iter() {
for internal_var in self.spilled_composites.values() {
internal_var.instruction.to_words(sink);
}
}
Expand Down Expand Up @@ -138,54 +138,6 @@ impl Writer {
self.capabilities_used.insert(spirv::Capability::Shader);
}

#[allow(clippy::too_many_arguments)]
pub(super) fn promote_access_expression_to_variable(
&mut self,
result_type_id: Word,
container_id: Word,
container_ty: Handle<crate::Type>,
index_id: Word,
element_ty: Handle<crate::Type>,
block: &mut Block,
) -> Result<(Word, LocalVariable), Error> {
let pointer_type_id = self.get_pointer_id(container_ty, spirv::StorageClass::Function);

let variable = {
let id = self.id_gen.next();
LocalVariable {
id,
instruction: Instruction::variable(
pointer_type_id,
id,
spirv::StorageClass::Function,
None,
),
}
};
block
.body
.push(Instruction::store(variable.id, container_id, None));

let element_pointer_id = self.id_gen.next();
let element_pointer_type_id =
self.get_pointer_id(element_ty, spirv::StorageClass::Function);
block.body.push(Instruction::access_chain(
element_pointer_type_id,
element_pointer_id,
variable.id,
&[index_id],
));
let id = self.id_gen.next();
block.body.push(Instruction::load(
result_type_id,
id,
element_pointer_id,
None,
));

Ok((id, variable))
}

/// Indicate that the code requires any one of the listed capabilities.
///
/// If nothing in `capabilities` appears in the available capabilities
Expand Down Expand Up @@ -683,10 +635,20 @@ impl Writer {
.insert(handle, LocalVariable { id, instruction });
}

// cache local variable expressions
for (handle, expr) in ir_function.expressions.iter() {
if matches!(*expr, crate::Expression::LocalVariable(_)) {
context.cache_expression_value(handle, &mut prelude)?;
match *expr {
crate::Expression::LocalVariable(_) => {
// Cache the `OpVariable` instruction we generated above as
// the value of this expression.
context.cache_expression_value(handle, &mut prelude)?;
}
crate::Expression::Access { base, .. }
| crate::Expression::AccessIndex { base, .. } => {
// Count references to `base` by `Access` and `AccessIndex`
// instructions. See `access_uses` for details.
*context.function.access_uses.entry(base).or_insert(0) += 1;
}
_ => {}
}
}

Expand Down
12 changes: 3 additions & 9 deletions naga/src/valid/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ pub enum ExpressionError {
NegativeIndex(Handle<crate::Expression>),
#[error("Accessing index {1} is out of {0:?} bounds")]
IndexOutOfBounds(Handle<crate::Expression>, u32),
#[error("The expression {0:?} may only be indexed by a constant")]
IndexMustBeConstant(Handle<crate::Expression>),
#[error("Function argument {0:?} doesn't exist")]
FunctionArgumentDoesntExist(u32),
#[error("Loading of {0:?} can't be done")]
Expand Down Expand Up @@ -238,10 +236,9 @@ impl super::Validator {
let stages = match *expression {
E::Access { base, index } => {
let base_type = &resolver[base];
// See the documentation for `Expression::Access`.
let dynamic_indexing_restricted = match *base_type {
Ti::Matrix { .. } => true,
Ti::Vector { .. }
match *base_type {
Ti::Matrix { .. }
| Ti::Vector { .. }
| Ti::Array { .. }
| Ti::Pointer { .. }
| Ti::ValuePointer { size: Some(_), .. }
Expand All @@ -262,9 +259,6 @@ impl super::Validator {
return Err(ExpressionError::InvalidIndexType(index));
}
}
if dynamic_indexing_restricted && function.expressions[index].is_dynamic_index() {
return Err(ExpressionError::IndexMustBeConstant(base));
}

// If we know both the length and the index, we can do the
// bounds check now.
Expand Down
13 changes: 13 additions & 0 deletions naga/tests/in/index-by-value.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
fn index_arg_array(a: array<i32, 5>, i: i32) -> i32 {
return a[i];
}

fn index_let_array(i: i32, j: i32) -> i32 {
let a = array<array<i32, 2>, 2>(array(1, 2), array(3, 4));
return a[i][j];
}

fn index_let_matrix(i: i32, j: i32) -> f32 {
let a = mat2x2<f32>(1, 2, 3, 4);
return a[i][j];
}
Loading

0 comments on commit e479b15

Please sign in to comment.