From f50ee3bd8c968ea7a308d5ac3a59a9d5c0334a71 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Fri, 5 Jan 2024 18:24:19 +0000 Subject: [PATCH] refactor: Put extension inference behind a feature gate --- .github/workflows/ci.yml | 4 ++-- Cargo.toml | 3 +++ src/extension.rs | 5 ++++- src/extension/infer/test.rs | 14 ++++++++++++-- src/hugr.rs | 23 ++++++++++++++++++++--- src/hugr/validate.rs | 12 ++++++++++-- src/hugr/validate/test.rs | 13 ++++++++++++- 7 files changed, 63 insertions(+), 11 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5514dedee..8ed7a545f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -3,7 +3,7 @@ name: Continuous integration on: push: branches: - - main + - main pull_request: branches: - main @@ -33,7 +33,7 @@ jobs: - name: Check formatting run: cargo fmt -- --check - name: Run clippy - run: cargo clippy --all-targets -- -D warnings + run: cargo clippy --all-targets --all-features -- -D warnings - name: Build docs run: cargo doc --no-deps --all-features env: diff --git a/Cargo.toml b/Cargo.toml index c633a1222..70c5dab08 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,9 @@ name = "hugr" bench = false path = "src/lib.rs" +[features] +extension_inference = [] + [dependencies] thiserror = "1.0.28" portgraph = { version = "0.11.0", features = ["serde", "petgraph"] } diff --git a/src/extension.rs b/src/extension.rs index 5519e456b..237dc3f4f 100644 --- a/src/extension.rs +++ b/src/extension.rs @@ -18,8 +18,11 @@ use crate::types::type_param::{check_type_args, TypeArgError}; use crate::types::type_param::{TypeArg, TypeParam}; use crate::types::{check_typevar_decl, CustomType, PolyFuncType, Substitution, TypeBound}; +#[allow(dead_code)] mod infer; -pub use infer::{infer_extensions, ExtensionSolution, InferExtensionError}; +#[cfg(feature = "extension_inference")] +pub use infer::infer_extensions; +pub use infer::{ExtensionSolution, InferExtensionError}; mod op_def; pub use op_def::{ diff --git a/src/extension/infer/test.rs b/src/extension/infer/test.rs index 1ac40455a..e89de2dea 100644 --- a/src/extension/infer/test.rs +++ b/src/extension/infer/test.rs @@ -1,6 +1,7 @@ use std::error::Error; use super::*; +#[cfg(feature = "extension_inference")] use crate::builder::test::closed_dfg_root_hugr; use crate::builder::{ Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, ModuleBuilder, @@ -8,10 +9,14 @@ use crate::builder::{ use crate::extension::prelude::QB_T; use crate::extension::ExtensionId; use crate::extension::{prelude::PRELUDE_REGISTRY, ExtensionSet}; -use crate::hugr::{validate::ValidationError, Hugr, HugrMut, HugrView, NodeType}; +#[cfg(feature = "extension_inference")] +use crate::hugr::validate::ValidationError; +use crate::hugr::{Hugr, HugrMut, HugrView, NodeType}; use crate::macros::const_extension_ids; use crate::ops::custom::{ExternalOp, OpaqueOp}; -use crate::ops::{self, dataflow::IOTrait, handle::NodeHandle}; +#[cfg(feature = "extension_inference")] +use crate::ops::handle::NodeHandle; +use crate::ops::{self, dataflow::IOTrait}; use crate::ops::{LeafOp, OpType}; use crate::type_row; @@ -153,6 +158,7 @@ fn plus() -> Result<(), InferExtensionError> { Ok(()) } +#[cfg(feature = "extension_inference")] #[test] // This generates a solution that causes validation to fail // because of a missing lift node @@ -214,6 +220,7 @@ fn open_variables() -> Result<(), InferExtensionError> { Ok(()) } +#[cfg(feature = "extension_inference")] #[test] // Infer the extensions on a child node with no inputs fn dangling_src() -> Result<(), Box> { @@ -305,6 +312,7 @@ fn create_with_io( Ok([node, input, output]) } +#[cfg(feature = "extension_inference")] #[test] fn test_conditional_inference() -> Result<(), Box> { fn build_case( @@ -967,6 +975,7 @@ fn simple_funcdefn() -> Result<(), Box> { Ok(()) } +#[cfg(feature = "extension_inference")] #[test] fn funcdefn_signature_mismatch() -> Result<(), Box> { let mut builder = ModuleBuilder::new(); @@ -997,6 +1006,7 @@ fn funcdefn_signature_mismatch() -> Result<(), Box> { Ok(()) } +#[cfg(feature = "extension_inference")] #[test] // Test that the difference between a FuncDefn's input and output nodes is being // constrained to be the same as the extension delta in the FuncDefn signature. diff --git a/src/hugr.rs b/src/hugr.rs index 9672f3dbb..a13d1f601 100644 --- a/src/hugr.rs +++ b/src/hugr.rs @@ -8,6 +8,8 @@ pub mod serialize; pub mod validate; pub mod views; +#[cfg(not(feature = "extension_inference"))] +use std::collections::HashMap; use std::collections::VecDeque; use std::iter; @@ -23,9 +25,9 @@ use thiserror::Error; pub use self::views::{HugrView, RootTagged}; use crate::core::NodeIndex; -use crate::extension::{ - infer_extensions, ExtensionRegistry, ExtensionSet, ExtensionSolution, InferExtensionError, -}; +#[cfg(feature = "extension_inference")] +use crate::extension::infer_extensions; +use crate::extension::{ExtensionRegistry, ExtensionSet, ExtensionSolution, InferExtensionError}; use crate::ops::custom::resolve_extension_ops; use crate::ops::{OpTag, OpTrait, OpType, DEFAULT_OPTYPE}; use crate::types::FunctionType; @@ -197,12 +199,19 @@ impl Hugr { /// Infer extension requirements and add new information to `op_types` field /// /// See [`infer_extensions`] for details on the "closure" value + #[cfg(feature = "extension_inference")] pub fn infer_extensions(&mut self) -> Result { let (solution, extension_closure) = infer_extensions(self)?; self.instantiate_extensions(solution); Ok(extension_closure) } + /// Do nothing - this functionality is gated by the feature "extension_inference" + #[cfg(not(feature = "extension_inference"))] + pub fn infer_extensions(&mut self) -> Result { + Ok(HashMap::new()) + } + #[allow(dead_code)] /// Add extension requirement information to the hugr in place. fn instantiate_extensions(&mut self, solution: ExtensionSolution) { // We only care about inferred _input_ extensions, because `NodeType` @@ -345,13 +354,20 @@ pub enum HugrError { #[cfg(test)] mod test { use super::{Hugr, HugrView}; + #[cfg(feature = "extension_inference")] use crate::builder::test::closed_dfg_root_hugr; + #[cfg(feature = "extension_inference")] use crate::extension::ExtensionSet; + #[cfg(feature = "extension_inference")] use crate::hugr::HugrMut; + #[cfg(feature = "extension_inference")] use crate::ops; + #[cfg(feature = "extension_inference")] use crate::type_row; + #[cfg(feature = "extension_inference")] use crate::types::{FunctionType, Type}; + #[cfg(feature = "extension_inference")] use std::error::Error; #[test] @@ -371,6 +387,7 @@ mod test { assert_matches!(hugr.get_io(hugr.root()), Some(_)); } + #[cfg(feature = "extension_inference")] #[test] fn extension_instantiation() -> Result<(), Box> { const BIT: Type = crate::extension::prelude::USIZE_T; diff --git a/src/hugr/validate.rs b/src/hugr/validate.rs index 295ef76ce..df982816c 100644 --- a/src/hugr/validate.rs +++ b/src/hugr/validate.rs @@ -9,10 +9,11 @@ use petgraph::visit::{Topo, Walker}; use portgraph::{LinkView, PortView}; use thiserror::Error; +#[cfg(feature = "extension_inference")] +use crate::extension::validate::ExtensionValidator; use crate::extension::SignatureError; use crate::extension::{ - validate::{ExtensionError, ExtensionValidator}, - ExtensionRegistry, ExtensionSolution, InferExtensionError, + validate::ExtensionError, ExtensionRegistry, ExtensionSolution, InferExtensionError, }; use crate::ops::custom::CustomOpError; @@ -36,6 +37,7 @@ struct ValidationContext<'a, 'b> { /// Dominator tree for each CFG region, using the container node as index. dominators: HashMap>, /// Context for the extension validation. + #[cfg(feature = "extension_inference")] extension_validator: ExtensionValidator, /// Registry of available Extensions extension_registry: &'b ExtensionRegistry, @@ -64,6 +66,9 @@ impl Hugr { impl<'a, 'b> ValidationContext<'a, 'b> { /// Create a new validation context. + // Allow unused "extension_closure" variable for when + // the "extension_inference" feature is disabled. + #[allow(unused_variables)] pub fn new( hugr: &'a Hugr, extension_closure: ExtensionSolution, @@ -72,6 +77,7 @@ impl<'a, 'b> ValidationContext<'a, 'b> { Self { hugr, dominators: HashMap::new(), + #[cfg(feature = "extension_inference")] extension_validator: ExtensionValidator::new(hugr, extension_closure), extension_registry, } @@ -163,6 +169,7 @@ impl<'a, 'b> ValidationContext<'a, 'b> { // FuncDefns have no resources since they're static nodes, but the // functions they define can have any extension delta. + #[cfg(feature = "extension_inference")] if node_type.tag() != OpTag::FuncDefn { // If this is a container with I/O nodes, check that the extension they // define match the extensions of the container. @@ -240,6 +247,7 @@ impl<'a, 'b> ValidationContext<'a, 'b> { let other_node: Node = self.hugr.graph.port_node(link).unwrap().into(); let other_offset = self.hugr.graph.port_offset(link).unwrap().into(); + #[cfg(feature = "extension_inference")] self.extension_validator .check_extensions_compatible(&(node, port), &(other_node, other_offset))?; diff --git a/src/hugr/validate/test.rs b/src/hugr/validate/test.rs index 520359a3a..9803ed985 100644 --- a/src/hugr/validate/test.rs +++ b/src/hugr/validate/test.rs @@ -2,9 +2,10 @@ use cool_asserts::assert_matches; use super::*; use crate::builder::test::closed_dfg_root_hugr; +#[cfg(feature = "extension_inference")] +use crate::builder::ModuleBuilder; use crate::builder::{ BuildError, Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, - ModuleBuilder, }; use crate::extension::prelude::{BOOL_T, PRELUDE, USIZE_T}; use crate::extension::{ @@ -12,6 +13,7 @@ use crate::extension::{ }; use crate::hugr::hugrmut::sealed::HugrMutInternals; use crate::hugr::{HugrError, HugrMut, NodeType}; +#[cfg(feature = "extension_inference")] use crate::macros::const_extension_ids; use crate::ops::dataflow::IOTrait; use crate::ops::{self, Const, LeafOp, OpType}; @@ -23,6 +25,7 @@ use crate::values::Value; use crate::{type_row, Direction, IncomingPort, Node}; const NAT: Type = crate::extension::prelude::USIZE_T; +#[cfg(feature = "infer_extensions")] const Q: Type = crate::extension::prelude::QB_T; /// Creates a hugr with a single function definition that copies a bit `copies` times. @@ -71,6 +74,7 @@ fn add_df_children(b: &mut Hugr, parent: Node, copies: usize) -> (Node, Node, No /// Intended to be used to populate a BasicBlock node in a CFG. /// /// Returns the node indices of each of the operations. +#[cfg(feature = "infer_extensions")] fn add_block_children( b: &mut Hugr, parent: Node, @@ -257,6 +261,7 @@ fn df_children_restrictions() { ); } +#[cfg(feature = "extension_inference")] #[test] /// Validation errors in a dataflow subgraph. fn cfg_children_restrictions() { @@ -404,6 +409,7 @@ fn test_ext_edge() -> Result<(), HugrError> { Ok(()) } +#[cfg(feature = "extension_inference")] const_extension_ids! { const XA: ExtensionId = "A"; const XB: ExtensionId = "BOOL_EXT"; @@ -441,6 +447,7 @@ fn test_local_const() -> Result<(), HugrError> { Ok(()) } +#[cfg(feature = "extension_inference")] #[test] /// A wire with no extension requirements is wired into a node which has /// [A,BOOL_T] extensions required on its inputs and outputs. This could be fixed @@ -474,6 +481,7 @@ fn missing_lift_node() -> Result<(), BuildError> { Ok(()) } +#[cfg(feature = "extension_inference")] #[test] /// A wire with extension requirement `[A]` is wired into a an output with no /// extension req. In the validation extension typechecking, we don't do any @@ -505,6 +513,7 @@ fn too_many_extension() -> Result<(), BuildError> { Ok(()) } +#[cfg(feature = "extension_inference")] #[test] /// A wire with extension requirements `[A]` and another with requirements /// `[BOOL_T]` are both wired into a node which requires its inputs to have @@ -558,6 +567,7 @@ fn extensions_mismatch() -> Result<(), BuildError> { Ok(()) } +#[cfg(feature = "extension_inference")] #[test] fn parent_signature_mismatch() -> Result<(), BuildError> { let rs = ExtensionSet::singleton(&XA); @@ -740,6 +750,7 @@ fn invalid_types() { ); } +#[cfg(feature = "extension_inference")] #[test] fn parent_io_mismatch() { // The DFG node declares that it has an empty extension delta,