Skip to content

Commit

Permalink
feat(hugr-cli): validate with extra extensions and packages (#1389)
Browse files Browse the repository at this point in the history
Closes #1359
  • Loading branch information
ss2165 authored Aug 2, 2024
1 parent fd609e0 commit f78c0cc
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 20 deletions.
1 change: 1 addition & 0 deletions hugr-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ clap-stdin.workspace = true
clap-verbosity-flag.workspace = true
hugr-core = { path = "../hugr-core", version = "0.7.0" }
serde_json.workspace = true
serde.workspace = true
thiserror.workspace = true

[lints]
Expand Down
5 changes: 1 addition & 4 deletions hugr-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,7 @@ fn main() {

/// Run the `validate` subcommand.
fn run_validate(args: validate::CliArgs) {
// validate with all std extensions
let reg = hugr_core::std_extensions::std_reg();

let result = args.run(&reg);
let result = args.run();

if let Err(e) = result {
if args.verbosity(Level::Error) {
Expand Down
95 changes: 84 additions & 11 deletions hugr-cli/src/validate.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
//! The `validate` subcommand.
use std::path::PathBuf;

use clap::Parser;
use clap_stdin::FileOrStdin;
use clap_verbosity_flag::{InfoLevel, Level, Verbosity};
use hugr_core::{extension::ExtensionRegistry, Hugr, HugrView as _};
use hugr_core::{extension::ExtensionRegistry, Extension, Hugr, HugrView as _};
use thiserror::Error;

/// Validate and visualise a HUGR file.
Expand All @@ -23,7 +25,19 @@ pub struct CliArgs {
/// Verbosity.
#[command(flatten)]
pub verbose: Verbosity<InfoLevel>,
// TODO YAML extensions
/// No standard extensions.
#[arg(
long,
help = "Don't use standard extensions when validating. Prelude is still used."
)]
pub no_std: bool,
/// Extensions paths.
#[arg(
short,
long,
help = "Paths to serialised extensions to validate against."
)]
pub extensions: Vec<PathBuf>,
}

/// Error type for the CLI.
Expand All @@ -33,32 +47,91 @@ pub enum CliError {
/// Error reading input.
#[error("Error reading input: {0}")]
Input(#[from] clap_stdin::StdinError),
/// Error reading input.
#[error("Error reading from path: {0}")]
InputFile(#[from] std::io::Error),
/// Error parsing input.
#[error("Error parsing input: {0}")]
Parse(#[from] serde_json::Error),
/// Error validating HUGR.
#[error("Error validating HUGR: {0}")]
Validate(#[from] hugr_core::hugr::ValidationError),
/// Error registering extension.
#[error("Error registering extension: {0}")]
ExtReg(#[from] hugr_core::extension::ExtensionRegistryError),
}

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
/// Package of module HUGRs and extensions.
/// The HUGRs are validated against the extensions.
pub struct Package {
/// Module HUGRs included in the package.
pub modules: Vec<Hugr>,
/// Extensions to validate against.
pub extensions: Vec<Extension>,
}

impl Package {
/// Create a new package.
pub fn new(modules: Vec<Hugr>, extensions: Vec<Extension>) -> Self {
Self {
modules,
extensions,
}
}
}

/// String to print when validation is successful.
pub const VALID_PRINT: &str = "HUGR valid!";

impl CliArgs {
/// Run the HUGR cli and validate against an extension registry.
pub fn run(&self, registry: &ExtensionRegistry) -> Result<Hugr, CliError> {
let mut hugr: Hugr = serde_json::from_reader(self.input.clone().into_reader()?)?;
if self.mermaid {
println!("{}", hugr.mermaid_string());
pub fn run(&self) -> Result<Vec<Hugr>, CliError> {
let rdr = self.input.clone().into_reader()?;
let val: serde_json::Value = serde_json::from_reader(rdr)?;
// read either a package or a single hugr
let (mut modules, packed_exts) = if let Ok(Package {
modules,
extensions,
}) = serde_json::from_value::<Package>(val.clone())
{
(modules, extensions)
} else {
let hugr: Hugr = serde_json::from_value(val)?;
(vec![hugr], vec![])
};

let mut reg: ExtensionRegistry = if self.no_std {
hugr_core::extension::PRELUDE_REGISTRY.to_owned()
} else {
hugr_core::std_extensions::std_reg()
};

// register packed extensions
for ext in packed_exts {
reg.register_updated(ext)?;
}

if !self.no_validate {
hugr.update_validate(registry)?;
if self.verbosity(Level::Info) {
eprintln!("{}", VALID_PRINT);
// register external extensions
for ext in &self.extensions {
let f = std::fs::File::open(ext)?;
let ext: Extension = serde_json::from_reader(f)?;
reg.register_updated(ext)?;
}

for hugr in modules.iter_mut() {
if self.mermaid {
println!("{}", hugr.mermaid_string());
}

if !self.no_validate {
hugr.update_validate(&reg)?;
if self.verbosity(Level::Info) {
eprintln!("{}", VALID_PRINT);
}
}
}
Ok(hugr)
Ok(modules)
}

/// Test whether a `level` message should be output.
Expand Down
73 changes: 68 additions & 5 deletions hugr-cli/tests/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@

use assert_cmd::Command;
use assert_fs::{fixture::FileWriteStr, NamedTempFile};
use hugr_cli::validate::VALID_PRINT;
use hugr_cli::validate::{Package, VALID_PRINT};
use hugr_core::builder::DFGBuilder;
use hugr_core::types::Type;
use hugr_core::{
builder::{Container, Dataflow, DataflowHugr},
builder::{Container, Dataflow},
extension::prelude::{BOOL_T, QB_T},
std_extensions::arithmetic::float_types::FLOAT64_TYPE,
type_row,
types::Signature,
Hugr,
Expand All @@ -26,10 +28,11 @@ fn cmd() -> Command {
}

#[fixture]
fn test_hugr() -> Hugr {
let df = DFGBuilder::new(Signature::new_endo(type_row![BOOL_T])).unwrap();
fn test_hugr(#[default(BOOL_T)] id_type: Type) -> Hugr {
let mut df = DFGBuilder::new(Signature::new_endo(id_type)).unwrap();
let [i] = df.input_wires_arr();
df.finish_prelude_hugr_with_outputs([i]).unwrap()
df.set_outputs([i]).unwrap();
df.hugr().clone() // unvalidated
}

#[fixture]
Expand Down Expand Up @@ -116,3 +119,63 @@ fn test_bad_json_silent(mut cmd: Command) {
.failure()
.stderr(contains("Error parsing input").not());
}

#[rstest]
fn test_no_std(test_hugr_string: String, mut cmd: Command) {
cmd.write_stdin(test_hugr_string);
cmd.arg("-");
cmd.arg("--no-std");
// test hugr doesn't have any standard extensions, so this should succceed

cmd.assert().success().stderr(contains(VALID_PRINT));
}

#[fixture]
fn float_hugr_string(#[with(FLOAT64_TYPE)] test_hugr: Hugr) -> String {
serde_json::to_string(&test_hugr).unwrap()
}

#[rstest]
fn test_no_std_fail(float_hugr_string: String, mut cmd: Command) {
cmd.write_stdin(float_hugr_string);
cmd.arg("-");
cmd.arg("--no-std");

cmd.assert()
.failure()
.stderr(contains(" Extension 'arithmetic.float.types' not found"));
}

// path to the fully serialized float extension
const FLOAT_EXT_FILE: &str = concat!(
env!("CARGO_MANIFEST_DIR"),
"/../specification/std_extensions/arithmetic/float/types.json"
);

#[rstest]
fn test_float_extension(float_hugr_string: String, mut cmd: Command) {
cmd.write_stdin(float_hugr_string);
cmd.arg("-");
cmd.arg("--no-std");
cmd.arg("--extensions");
cmd.arg(FLOAT_EXT_FILE);

cmd.assert().success().stderr(contains(VALID_PRINT));
}
#[fixture]
fn package_string(#[with(FLOAT64_TYPE)] test_hugr: Hugr) -> String {
let rdr = std::fs::File::open(FLOAT_EXT_FILE).unwrap();
let float_ext: hugr_core::Extension = serde_json::from_reader(rdr).unwrap();
let package = Package::new(vec![test_hugr], vec![float_ext]);
serde_json::to_string(&package).unwrap()
}

#[rstest]
fn test_package(package_string: String, mut cmd: Command) {
// package with float extension and hugr that uses floats can validate
cmd.write_stdin(package_string);
cmd.arg("-");
cmd.arg("--no-std");

cmd.assert().success().stderr(contains(VALID_PRINT));
}

0 comments on commit f78c0cc

Please sign in to comment.