From 583cc6ab04796b8ab29a84920ff5bd42d9736e7a Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Mon, 3 Jun 2024 07:59:25 -0700 Subject: [PATCH] [naga] Ensure that `FooResult` expressions are correctly populated. Make Naga module validation require that `CallResult` and `AtomicResult` expressions are indeed visited by exactly one `Call` / `Atomic` statement. --- naga/src/valid/function.rs | 41 ++++++++++- naga/src/valid/mod.rs | 21 ++++++ naga/tests/validation.rs | 138 +++++++++++++++++++++++-------------- 3 files changed, 145 insertions(+), 55 deletions(-) diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index f112861ff6..d92cda87f9 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -22,6 +22,8 @@ pub enum CallError { }, #[error("Result expression {0:?} has already been introduced earlier")] ResultAlreadyInScope(Handle), + #[error("Result expression {0:?} is populated by multiple `Call` statements")] + ResultAlreadyPopulated(Handle), #[error("Result value is invalid")] ResultValue(#[source] ExpressionError), #[error("Requires {required} arguments, but {seen} are provided")] @@ -45,6 +47,8 @@ pub enum AtomicError { InvalidOperand(Handle), #[error("Result type for {0:?} doesn't match the statement")] ResultTypeMismatch(Handle), + #[error("Result expression {0:?} is populated by multiple `Atomic` statements")] + ResultAlreadyPopulated(Handle), } #[derive(Clone, Debug, thiserror::Error)] @@ -174,6 +178,8 @@ pub enum FunctionError { InvalidSubgroup(#[from] SubgroupError), #[error("Emit statement should not cover \"result\" expressions like {0:?}")] EmitResult(Handle), + #[error("Expression not visited by the appropriate statement")] + UnvisitedExpression(Handle), } bitflags::bitflags! { @@ -305,7 +311,13 @@ impl super::Validator { } match context.expressions[expr] { crate::Expression::CallResult(callee) - if fun.result.is_some() && callee == function => {} + if fun.result.is_some() && callee == function => + { + if !self.needs_visit.remove(expr.index()) { + return Err(CallError::ResultAlreadyPopulated(expr) + .with_span_handle(expr, context.expressions)); + } + } _ => { return Err(CallError::ExpressionMismatch(result) .with_span_handle(expr, context.expressions)) @@ -397,7 +409,14 @@ impl super::Validator { } _ => false, } - } => {} + } => + { + if !self.needs_visit.remove(result.index()) { + return Err(AtomicError::ResultAlreadyPopulated(result) + .with_span_handle(result, context.expressions) + .into_other()); + } + } _ => { return Err(AtomicError::ResultTypeMismatch(result) .with_span_handle(result, context.expressions) @@ -1290,11 +1309,20 @@ impl super::Validator { self.valid_expression_set.clear(); self.valid_expression_list.clear(); + self.needs_visit.clear(); for (handle, expr) in fun.expressions.iter() { if expr.needs_pre_emit() { self.valid_expression_set.insert(handle.index()); } if self.flags.contains(super::ValidationFlags::EXPRESSIONS) { + // Mark expressions that need to be visited by a particular kind of + // statement. + if let crate::Expression::CallResult(_) | crate::Expression::AtomicResult { .. } = + *expr + { + self.needs_visit.insert(handle.index()); + } + match self.validate_expression( handle, expr, @@ -1321,6 +1349,15 @@ impl super::Validator { )? .stages; info.available_stages &= stages; + + if self.flags.contains(super::ValidationFlags::EXPRESSIONS) { + if let Some(unvisited) = self.needs_visit.iter().next() { + let index = std::num::NonZeroU32::new(unvisited as u32 + 1).unwrap(); + let handle = Handle::new(index); + return Err(FunctionError::UnvisitedExpression(handle) + .with_span_handle(handle, &fun.expressions)); + } + } } Ok(info) } diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index 9e566a8754..d86c23c1e9 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -246,6 +246,26 @@ pub struct Validator { valid_expression_set: BitSet, override_ids: FastHashSet, allow_overrides: bool, + + /// A checklist of expressions that must be visited by a specific kind of + /// statement. + /// + /// For example: + /// + /// - [`CallResult`] expressions must be visited by a [`Call`] statement. + /// - [`AtomicResult`] expressions must be visited by an [`Atomic`] statement. + /// + /// Be sure not to remove any [`Expression`] handle from this set unless + /// you've explicitly checked that it is the right kind of expression for + /// the visiting [`Statement`]. + /// + /// [`CallResult`]: crate::Expression::CallResult + /// [`Call`]: crate::Statement::Call + /// [`AtomicResult`]: crate::Expression::AtomicResult + /// [`Atomic`]: crate::Statement::Atomic + /// [`Expression`]: crate::Expression + /// [`Statement`]: crate::Statement + needs_visit: BitSet, } #[derive(Clone, Debug, thiserror::Error)] @@ -398,6 +418,7 @@ impl Validator { valid_expression_set: BitSet::new(), override_ids: FastHashSet::default(), allow_overrides: true, + needs_visit: BitSet::new(), } } diff --git a/naga/tests/validation.rs b/naga/tests/validation.rs index 2b632daeb6..7491fd262a 100644 --- a/naga/tests/validation.rs +++ b/naga/tests/validation.rs @@ -1,18 +1,30 @@ use naga::{valid, Expression, Function, Scalar}; +/// Validation should fail if `AtomicResult` expressions are not +/// populated by `Atomic` statements. #[test] -fn emit_atomic_result() { +fn populate_atomic_result() { use naga::{Module, Type, TypeInner}; - // We want to ensure that the *only* problem with the code is the - // use of an `Emit` statement instead of an `Atomic` statement. So - // validate two versions of the module varying only in that - // aspect. - // - // Looking at uses of the `atomic` makes it easy to identify the - // differences between the two variants. - fn variant( - atomic: bool, + /// Different variants of the test case that we want to exercise. + enum Variant { + /// An `AtomicResult` expression with an `Atomic` statement + /// that populates it: valid. + Atomic, + + /// An `AtomicResult` expression visited by an `Emit` + /// statement: invalid. + Emit, + + /// An `AtomicResult` expression visited by no statement at + /// all: invalid + None, + } + + // Looking at uses of `variant` should make it easy to identify + // the differences between the test cases. + fn try_variant( + variant: Variant, ) -> Result> { let span = naga::Span::default(); let mut module = Module::default(); @@ -56,21 +68,25 @@ fn emit_atomic_result() { span, ); - if atomic { - fun.body.push( - naga::Statement::Atomic { - pointer: ex_global, - fun: naga::AtomicFunction::Add, - value: ex_42, - result: ex_result, - }, - span, - ); - } else { - fun.body.push( - naga::Statement::Emit(naga::Range::new_from_bounds(ex_result, ex_result)), - span, - ); + match variant { + Variant::Atomic => { + fun.body.push( + naga::Statement::Atomic { + pointer: ex_global, + fun: naga::AtomicFunction::Add, + value: ex_42, + result: ex_result, + }, + span, + ); + } + Variant::Emit => { + fun.body.push( + naga::Statement::Emit(naga::Range::new_from_bounds(ex_result, ex_result)), + span, + ); + } + Variant::None => {} } module.functions.append(fun, span); @@ -82,23 +98,34 @@ fn emit_atomic_result() { .validate(&module) } - variant(true).expect("module should validate"); - assert!(variant(false).is_err()); + try_variant(Variant::Atomic).expect("module should validate"); + assert!(try_variant(Variant::Emit).is_err()); + assert!(try_variant(Variant::None).is_err()); } #[test] -fn emit_call_result() { +fn populate_call_result() { use naga::{Module, Type, TypeInner}; - // We want to ensure that the *only* problem with the code is the - // use of an `Emit` statement instead of a `Call` statement. So - // validate two versions of the module varying only in that - // aspect. - // - // Looking at uses of the `call` makes it easy to identify the - // differences between the two variants. - fn variant( - call: bool, + /// Different variants of the test case that we want to exercise. + enum Variant { + /// A `CallResult` expression with an `Call` statement that + /// populates it: valid. + Call, + + /// A `CallResult` expression visited by an `Emit` statement: + /// invalid. + Emit, + + /// A `CallResult` expression visited by no statement at all: + /// invalid + None, + } + + // Looking at uses of `variant` should make it easy to identify + // the differences between the test cases. + fn try_variant( + variant: Variant, ) -> Result> { let span = naga::Span::default(); let mut module = Module::default(); @@ -130,20 +157,24 @@ fn emit_call_result() { .expressions .append(Expression::CallResult(fun_callee), span); - if call { - fun_caller.body.push( - naga::Statement::Call { - function: fun_callee, - arguments: vec![], - result: Some(ex_result), - }, - span, - ); - } else { - fun_caller.body.push( - naga::Statement::Emit(naga::Range::new_from_bounds(ex_result, ex_result)), - span, - ); + match variant { + Variant::Call => { + fun_caller.body.push( + naga::Statement::Call { + function: fun_callee, + arguments: vec![], + result: Some(ex_result), + }, + span, + ); + } + Variant::Emit => { + fun_caller.body.push( + naga::Statement::Emit(naga::Range::new_from_bounds(ex_result, ex_result)), + span, + ); + } + Variant::None => {} } module.functions.append(fun_caller, span); @@ -155,8 +186,9 @@ fn emit_call_result() { .validate(&module) } - variant(true).expect("should validate"); - assert!(variant(false).is_err()); + try_variant(Variant::Call).expect("should validate"); + assert!(try_variant(Variant::Emit).is_err()); + assert!(try_variant(Variant::None).is_err()); } #[test]