From 6bc6a806a0bf88d28a01b7652ed6ce684c7ccab8 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 1 Aug 2024 17:34:14 +0100 Subject: [PATCH] feat(hugr-cli): allow validation of (hugr+module) package --- hugr-cli/Cargo.toml | 1 + hugr-cli/src/validate.rs | 68 ++++++++++++++++++++++++++++++++------ hugr-cli/tests/validate.rs | 31 +++++++++++++---- 3 files changed, 84 insertions(+), 16 deletions(-) diff --git a/hugr-cli/Cargo.toml b/hugr-cli/Cargo.toml index d38af25798..3843db9e91 100644 --- a/hugr-cli/Cargo.toml +++ b/hugr-cli/Cargo.toml @@ -18,6 +18,7 @@ clap-stdin.workspace = true clap-verbosity-flag.workspace = true hugr-core = { path = "../hugr-core", version = "0.7.0" } serde_json.workspace = true +serde.workspace = true thiserror.workspace = true [lints] diff --git a/hugr-cli/src/validate.rs b/hugr-cli/src/validate.rs index 9e490d289a..4d33a227e9 100644 --- a/hugr-cli/src/validate.rs +++ b/hugr-cli/src/validate.rs @@ -54,13 +54,53 @@ pub enum CliError { ExtReg(#[from] hugr_core::extension::ExtensionRegistryError), } +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +/// Package of module HUGRs and extensions. +/// The HUGRs are validated against the extensions. +pub struct Package { + modules: Vec, + extensions: Vec, +} + +impl Package { + /// Create a new package. + pub fn new(modules: Vec, extensions: Vec) -> Self { + Self { + modules, + extensions, + } + } + + /// Modules in the package. + pub fn modules(&self) -> &[Hugr] { + &self.modules + } + + /// Extensions in the package. + pub fn extensions(&self) -> &[Extension] { + &self.extensions + } +} + /// String to print when validation is successful. pub const VALID_PRINT: &str = "HUGR valid!"; impl CliArgs { /// Run the HUGR cli and validate against an extension registry. - pub fn run(&self) -> Result { - let mut hugr: Hugr = serde_json::from_reader(self.input.clone().into_reader()?)?; + pub fn run(&self) -> Result, CliError> { + let rdr = self.input.clone().into_reader()?; + let val: serde_json::Value = serde_json::from_reader(rdr)?; + // read either a package or a single hugr + let (mut modules, packed_exts) = if let Ok(Package { + modules, + extensions, + }) = serde_json::from_value::(val.clone()) + { + (modules, extensions) + } else { + let hugr: Hugr = serde_json::from_value(val)?; + (vec![hugr], vec![]) + }; let mut reg: ExtensionRegistry = if self.no_std { hugr_core::extension::PRELUDE_REGISTRY.to_owned() @@ -68,23 +108,31 @@ impl CliArgs { hugr_core::std_extensions::std_reg() }; + // register packed extensions + for ext in packed_exts { + reg.register_updated(ext)?; + } + + // register external extensions for ext in &self.extensions { let f = std::fs::File::open(ext)?; let ext: Extension = serde_json::from_reader(f)?; reg.register_updated(ext)?; } - if self.mermaid { - println!("{}", hugr.mermaid_string()); - } + for hugr in modules.iter_mut() { + if self.mermaid { + println!("{}", hugr.mermaid_string()); + } - if !self.no_validate { - hugr.update_validate(®)?; - if self.verbosity(Level::Info) { - eprintln!("{}", VALID_PRINT); + if !self.no_validate { + hugr.update_validate(®)?; + if self.verbosity(Level::Info) { + eprintln!("{}", VALID_PRINT); + } } } - Ok(hugr) + Ok(modules) } /// Test whether a `level` message should be output. diff --git a/hugr-cli/tests/validate.rs b/hugr-cli/tests/validate.rs index c5f09e45aa..22a8fb0591 100644 --- a/hugr-cli/tests/validate.rs +++ b/hugr-cli/tests/validate.rs @@ -6,7 +6,7 @@ use assert_cmd::Command; use assert_fs::{fixture::FileWriteStr, NamedTempFile}; -use hugr_cli::validate::VALID_PRINT; +use hugr_cli::validate::{Package, VALID_PRINT}; use hugr_core::builder::DFGBuilder; use hugr_core::types::Type; use hugr_core::{ @@ -146,17 +146,36 @@ fn test_no_std_fail(float_hugr_string: String, mut cmd: Command) { .stderr(contains(" Extension 'arithmetic.float.types' not found")); } +// path to the fully serialized float extension +const FLOAT_EXT_FILE: &str = concat!( + env!("CARGO_MANIFEST_DIR"), + "/../specification/std_extensions/arithmetic/float/types.json" +); + #[rstest] fn test_float_extension(float_hugr_string: String, mut cmd: Command) { cmd.write_stdin(float_hugr_string); cmd.arg("-"); cmd.arg("--no-std"); cmd.arg("--extensions"); - // path to the fully serialized float extension - cmd.arg(concat!( - env!("CARGO_MANIFEST_DIR"), - "/../specification/std_extensions/arithmetic/float/types.json" - )); + cmd.arg(FLOAT_EXT_FILE); + + cmd.assert().success().stderr(contains(VALID_PRINT)); +} +#[fixture] +fn package_string(#[with(FLOAT64_TYPE)] test_hugr: Hugr) -> String { + let rdr = std::fs::File::open(FLOAT_EXT_FILE).unwrap(); + let float_ext: hugr_core::Extension = serde_json::from_reader(rdr).unwrap(); + let package = Package::new(vec![test_hugr], vec![float_ext]); + serde_json::to_string(&package).unwrap() +} + +#[rstest] +fn test_package(package_string: String, mut cmd: Command) { + // package with float extension and hugr that uses floats can validate + cmd.write_stdin(package_string); + cmd.arg("-"); + cmd.arg("--no-std"); cmd.assert().success().stderr(contains(VALID_PRINT)); }