Skip to content

Commit

Permalink
feat(hugr-cli): allow validation of (hugr+module) package
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Aug 1, 2024
1 parent ee1133a commit 6bc6a80
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 16 deletions.
1 change: 1 addition & 0 deletions hugr-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ clap-stdin.workspace = true
clap-verbosity-flag.workspace = true
hugr-core = { path = "../hugr-core", version = "0.7.0" }
serde_json.workspace = true
serde.workspace = true
thiserror.workspace = true

[lints]
Expand Down
68 changes: 58 additions & 10 deletions hugr-cli/src/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,37 +54,85 @@ pub enum CliError {
ExtReg(#[from] hugr_core::extension::ExtensionRegistryError),
}

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
/// Package of module HUGRs and extensions.
/// The HUGRs are validated against the extensions.
pub struct Package {
modules: Vec<Hugr>,
extensions: Vec<Extension>,
}

impl Package {
/// Create a new package.
pub fn new(modules: Vec<Hugr>, extensions: Vec<Extension>) -> Self {
Self {
modules,
extensions,
}
}

/// Modules in the package.
pub fn modules(&self) -> &[Hugr] {
&self.modules
}

/// Extensions in the package.
pub fn extensions(&self) -> &[Extension] {
&self.extensions
}
}

/// String to print when validation is successful.
pub const VALID_PRINT: &str = "HUGR valid!";

impl CliArgs {
/// Run the HUGR cli and validate against an extension registry.
pub fn run(&self) -> Result<Hugr, CliError> {
let mut hugr: Hugr = serde_json::from_reader(self.input.clone().into_reader()?)?;
pub fn run(&self) -> Result<Vec<Hugr>, CliError> {
let rdr = self.input.clone().into_reader()?;
let val: serde_json::Value = serde_json::from_reader(rdr)?;
// read either a package or a single hugr
let (mut modules, packed_exts) = if let Ok(Package {
modules,
extensions,
}) = serde_json::from_value::<Package>(val.clone())
{
(modules, extensions)
} else {
let hugr: Hugr = serde_json::from_value(val)?;
(vec![hugr], vec![])
};

let mut reg: ExtensionRegistry = if self.no_std {
hugr_core::extension::PRELUDE_REGISTRY.to_owned()
} else {
hugr_core::std_extensions::std_reg()
};

// register packed extensions
for ext in packed_exts {
reg.register_updated(ext)?;
}

// register external extensions
for ext in &self.extensions {
let f = std::fs::File::open(ext)?;
let ext: Extension = serde_json::from_reader(f)?;
reg.register_updated(ext)?;
}

if self.mermaid {
println!("{}", hugr.mermaid_string());
}
for hugr in modules.iter_mut() {
if self.mermaid {
println!("{}", hugr.mermaid_string());
}

if !self.no_validate {
hugr.update_validate(&reg)?;
if self.verbosity(Level::Info) {
eprintln!("{}", VALID_PRINT);
if !self.no_validate {
hugr.update_validate(&reg)?;
if self.verbosity(Level::Info) {
eprintln!("{}", VALID_PRINT);
}
}
}
Ok(hugr)
Ok(modules)
}

/// Test whether a `level` message should be output.
Expand Down
31 changes: 25 additions & 6 deletions hugr-cli/tests/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

use assert_cmd::Command;
use assert_fs::{fixture::FileWriteStr, NamedTempFile};
use hugr_cli::validate::VALID_PRINT;
use hugr_cli::validate::{Package, VALID_PRINT};
use hugr_core::builder::DFGBuilder;
use hugr_core::types::Type;
use hugr_core::{
Expand Down Expand Up @@ -146,17 +146,36 @@ fn test_no_std_fail(float_hugr_string: String, mut cmd: Command) {
.stderr(contains(" Extension 'arithmetic.float.types' not found"));
}

// path to the fully serialized float extension
const FLOAT_EXT_FILE: &str = concat!(
env!("CARGO_MANIFEST_DIR"),
"/../specification/std_extensions/arithmetic/float/types.json"
);

#[rstest]
fn test_float_extension(float_hugr_string: String, mut cmd: Command) {
cmd.write_stdin(float_hugr_string);
cmd.arg("-");
cmd.arg("--no-std");
cmd.arg("--extensions");
// path to the fully serialized float extension
cmd.arg(concat!(
env!("CARGO_MANIFEST_DIR"),
"/../specification/std_extensions/arithmetic/float/types.json"
));
cmd.arg(FLOAT_EXT_FILE);

cmd.assert().success().stderr(contains(VALID_PRINT));
}
#[fixture]
fn package_string(#[with(FLOAT64_TYPE)] test_hugr: Hugr) -> String {
let rdr = std::fs::File::open(FLOAT_EXT_FILE).unwrap();
let float_ext: hugr_core::Extension = serde_json::from_reader(rdr).unwrap();
let package = Package::new(vec![test_hugr], vec![float_ext]);
serde_json::to_string(&package).unwrap()
}

#[rstest]
fn test_package(package_string: String, mut cmd: Command) {
// package with float extension and hugr that uses floats can validate
cmd.write_stdin(package_string);
cmd.arg("-");
cmd.arg("--no-std");

cmd.assert().success().stderr(contains(VALID_PRINT));
}

0 comments on commit 6bc6a80

Please sign in to comment.