From 77795b90f029d5c00a99037f4306724176e1dfbd Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 2 Aug 2024 13:52:36 +0100 Subject: [PATCH] feat(hugr-cli)!: move mermaid to own sub-command (#1390) drive-by replace `clap_stdin` with `clio` as that's that `clap` recommends, and it comes with stdout too. BREAKING CHANGE: Cli validate command no longer has a mermaid option, use `mermaid` sub-command instead. --- Cargo.toml | 2 +- hugr-cli/Cargo.toml | 2 +- hugr-cli/src/lib.rs | 79 ++++++++++++++++++++-- hugr-cli/src/main.rs | 3 +- hugr-cli/src/mermaid.rs | 43 ++++++++++++ hugr-cli/src/validate.rs | 115 +++++++++----------------------- hugr-cli/tests/validate.rs | 131 ++++++++++++++++++++++--------------- hugr-py/tests/conftest.py | 27 +++++--- 8 files changed, 248 insertions(+), 154 deletions(-) create mode 100644 hugr-cli/src/mermaid.rs diff --git a/Cargo.toml b/Cargo.toml index d3cf7c95b..a84030929 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,7 +57,7 @@ typetag = "0.2.7" urlencoding = "2.1.2" webbrowser = "1.0.0" clap = { version = "4.5.4" } -clap-stdin = "0.5.0" +clio = "0.3.5" clap-verbosity-flag = "2.2.0" assert_cmd = "2.0.14" assert_fs = "1.1.1" diff --git a/hugr-cli/Cargo.toml b/hugr-cli/Cargo.toml index 3843db9e9..5c72d50d9 100644 --- a/hugr-cli/Cargo.toml +++ b/hugr-cli/Cargo.toml @@ -14,12 +14,12 @@ categories = ["compilers"] [dependencies] clap = { workspace = true, features = ["derive"] } -clap-stdin.workspace = true clap-verbosity-flag.workspace = true hugr-core = { path = "../hugr-core", version = "0.7.0" } serde_json.workspace = true serde.workspace = true thiserror.workspace = true +clio = { workspace = true, features = ["clap-parse"] } [lints] workspace = true diff --git a/hugr-cli/src/lib.rs b/hugr-cli/src/lib.rs index 31be79b30..8a421aa70 100644 --- a/hugr-cli/src/lib.rs +++ b/hugr-cli/src/lib.rs @@ -1,11 +1,14 @@ //! Standard command line tools, used by the hugr binary. -use std::ffi::OsString; +use clap::Parser; +use clap_verbosity_flag::{InfoLevel, Verbosity}; +use clio::Input; +use hugr_core::{Extension, Hugr}; +use std::{ffi::OsString, path::PathBuf}; use thiserror::Error; -/// We reexport some clap types that are used in the public API. -pub use {clap::Parser, clap_verbosity_flag::Level}; pub mod extensions; +pub mod mermaid; pub mod validate; /// CLI arguments. @@ -16,9 +19,11 @@ pub mod validate; #[non_exhaustive] pub enum CliArgs { /// Validate and visualize a HUGR file. - Validate(validate::CliArgs), + Validate(validate::ValArgs), /// Write standard extensions out in serialized form. GenExtensions(extensions::ExtArgs), + /// Write HUGR as mermaid diagrams. + Mermaid(mermaid::MermaidArgs), /// External commands #[command(external_subcommand)] External(Vec), @@ -29,6 +34,70 @@ pub enum CliArgs { #[error(transparent)] #[non_exhaustive] pub enum CliError { + /// Error reading input. + #[error("Error reading from path: {0}")] + InputFile(#[from] std::io::Error), + /// Error parsing input. + #[error("Error parsing input: {0}")] + Parse(#[from] serde_json::Error), /// Errors produced by the `validate` subcommand. - Validate(#[from] validate::CliError), + Validate(#[from] validate::ValError), +} + +/// Validate and visualise a HUGR file. +#[derive(Parser, Debug)] +pub struct HugrArgs { + /// Input HUGR file, use '-' for stdin + #[clap(value_parser, default_value = "-")] + pub input: Input, + /// Verbosity. + #[command(flatten)] + pub verbose: Verbosity, + /// No standard extensions. + #[arg( + long, + help = "Don't use standard extensions when validating. Prelude is still used." + )] + pub no_std: bool, + /// Extensions paths. + #[arg( + short, + long, + help = "Paths to serialised extensions to validate against." + )] + pub extensions: Vec, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +/// Package of module HUGRs and extensions. +/// The HUGRs are validated against the extensions. +pub struct Package { + /// Module HUGRs included in the package. + pub modules: Vec, + /// Extensions to validate against. + pub extensions: Vec, +} + +impl Package { + /// Create a new package. + pub fn new(modules: Vec, extensions: Vec) -> Self { + Self { + modules, + extensions, + } + } +} + +impl HugrArgs { + /// Read either a package or a single hugr from the input. + pub fn get_package(&mut self) -> Result { + let val: serde_json::Value = serde_json::from_reader(&mut self.input)?; + // read either a package or a single hugr + if let Ok(p) = serde_json::from_value::(val.clone()) { + Ok(p) + } else { + let hugr: Hugr = serde_json::from_value(val)?; + Ok(Package::new(vec![hugr], vec![])) + } + } } diff --git a/hugr-cli/src/main.rs b/hugr-cli/src/main.rs index cf5704376..c8bb0b56e 100644 --- a/hugr-cli/src/main.rs +++ b/hugr-cli/src/main.rs @@ -10,6 +10,7 @@ fn main() { match CliArgs::parse() { CliArgs::Validate(args) => run_validate(args), CliArgs::GenExtensions(args) => args.run_dump(), + CliArgs::Mermaid(mut args) => args.run_print().unwrap(), CliArgs::External(_) => { // TODO: Implement support for external commands. // Running `hugr COMMAND` would look for `hugr-COMMAND` in the path @@ -25,7 +26,7 @@ fn main() { } /// Run the `validate` subcommand. -fn run_validate(args: validate::CliArgs) { +fn run_validate(mut args: validate::ValArgs) { let result = args.run(); if let Err(e) = result { diff --git a/hugr-cli/src/mermaid.rs b/hugr-cli/src/mermaid.rs new file mode 100644 index 000000000..0a59d875d --- /dev/null +++ b/hugr-cli/src/mermaid.rs @@ -0,0 +1,43 @@ +//! Render mermaid diagrams. +use std::io::Write; + +use clap::Parser; +use clio::Output; +use hugr_core::HugrView; + +/// Dump the standard extensions. +#[derive(Parser, Debug)] +#[clap(version = "1.0", long_about = None)] +#[clap(about = "Render mermaid diagrams..")] +#[group(id = "hugr")] +#[non_exhaustive] +pub struct MermaidArgs { + /// Common arguments + #[command(flatten)] + pub hugr_args: crate::HugrArgs, + /// Validate package. + #[arg( + long, + help = "Validate before rendering, includes extension inference." + )] + pub validate: bool, + /// Output file '-' for stdout + #[clap(long, short, value_parser, default_value = "-")] + output: Output, +} + +impl MermaidArgs { + /// Write the mermaid diagram to the output. + pub fn run_print(&mut self) -> Result<(), crate::CliError> { + let hugrs = if self.validate { + self.hugr_args.validate()? + } else { + self.hugr_args.get_package()?.modules + }; + + for hugr in hugrs { + write!(self.output, "{}", hugr.mermaid_string())?; + } + Ok(()) + } +} diff --git a/hugr-cli/src/validate.rs b/hugr-cli/src/validate.rs index 4488b3716..b603365fe 100644 --- a/hugr-cli/src/validate.rs +++ b/hugr-cli/src/validate.rs @@ -1,58 +1,28 @@ //! The `validate` subcommand. -use std::path::PathBuf; use clap::Parser; -use clap_stdin::FileOrStdin; -use clap_verbosity_flag::{InfoLevel, Level, Verbosity}; -use hugr_core::{extension::ExtensionRegistry, Extension, Hugr, HugrView as _}; +use clap_verbosity_flag::Level; +use hugr_core::{extension::ExtensionRegistry, Extension, Hugr}; use thiserror::Error; +use crate::{CliError, HugrArgs, Package}; + /// Validate and visualise a HUGR file. #[derive(Parser, Debug)] #[clap(version = "1.0", long_about = None)] #[clap(about = "Validate a HUGR.")] #[group(id = "hugr")] #[non_exhaustive] -pub struct CliArgs { - /// The input hugr to parse. - pub input: FileOrStdin, - /// Visualise with mermaid. - #[arg(short, long, value_name = "MERMAID", help = "Visualise with mermaid.")] - pub mermaid: bool, - /// Skip validation. - #[arg(short, long, help = "Skip validation.")] - pub no_validate: bool, - /// Verbosity. +pub struct ValArgs { #[command(flatten)] - pub verbose: Verbosity, - /// No standard extensions. - #[arg( - long, - help = "Don't use standard extensions when validating. Prelude is still used." - )] - pub no_std: bool, - /// Extensions paths. - #[arg( - short, - long, - help = "Paths to serialised extensions to validate against." - )] - pub extensions: Vec, + /// common arguments + pub hugr_args: HugrArgs, } /// Error type for the CLI. #[derive(Error, Debug)] #[non_exhaustive] -pub enum CliError { - /// Error reading input. - #[error("Error reading input: {0}")] - Input(#[from] clap_stdin::StdinError), - /// Error reading input. - #[error("Error reading from path: {0}")] - InputFile(#[from] std::io::Error), - /// Error parsing input. - #[error("Error parsing input: {0}")] - Parse(#[from] serde_json::Error), +pub enum ValError { /// Error validating HUGR. #[error("Error validating HUGR: {0}")] Validate(#[from] hugr_core::hugr::ValidationError), @@ -61,45 +31,28 @@ pub enum CliError { ExtReg(#[from] hugr_core::extension::ExtensionRegistryError), } -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -/// Package of module HUGRs and extensions. -/// The HUGRs are validated against the extensions. -pub struct Package { - /// Module HUGRs included in the package. - pub modules: Vec, - /// Extensions to validate against. - pub extensions: Vec, -} - -impl Package { - /// Create a new package. - pub fn new(modules: Vec, extensions: Vec) -> Self { - Self { - modules, - extensions, - } - } -} - /// String to print when validation is successful. pub const VALID_PRINT: &str = "HUGR valid!"; -impl CliArgs { +impl ValArgs { /// Run the HUGR cli and validate against an extension registry. - pub fn run(&self) -> Result, CliError> { - let rdr = self.input.clone().into_reader()?; - let val: serde_json::Value = serde_json::from_reader(rdr)?; - // read either a package or a single hugr - let (mut modules, packed_exts) = if let Ok(Package { - modules, - extensions, - }) = serde_json::from_value::(val.clone()) - { - (modules, extensions) - } else { - let hugr: Hugr = serde_json::from_value(val)?; - (vec![hugr], vec![]) - }; + pub fn run(&mut self) -> Result, CliError> { + self.hugr_args.validate() + } + + /// Test whether a `level` message should be output. + pub fn verbosity(&self, level: Level) -> bool { + self.hugr_args.verbosity(level) + } +} + +impl HugrArgs { + /// Load the package and validate against an extension registry. + pub fn validate(&mut self) -> Result, CliError> { + let Package { + mut modules, + extensions: packed_exts, + } = self.get_package()?; let mut reg: ExtensionRegistry = if self.no_std { hugr_core::extension::PRELUDE_REGISTRY.to_owned() @@ -109,26 +62,20 @@ impl CliArgs { // register packed extensions for ext in packed_exts { - reg.register_updated(ext)?; + reg.register_updated(ext).map_err(ValError::ExtReg)?; } // register external extensions for ext in &self.extensions { let f = std::fs::File::open(ext)?; let ext: Extension = serde_json::from_reader(f)?; - reg.register_updated(ext)?; + reg.register_updated(ext).map_err(ValError::ExtReg)?; } for hugr in modules.iter_mut() { - if self.mermaid { - println!("{}", hugr.mermaid_string()); - } - - if !self.no_validate { - hugr.update_validate(®)?; - if self.verbosity(Level::Info) { - eprintln!("{}", VALID_PRINT); - } + hugr.update_validate(®).map_err(ValError::Validate)?; + if self.verbosity(Level::Info) { + eprintln!("{}", VALID_PRINT); } } Ok(modules) diff --git a/hugr-cli/tests/validate.rs b/hugr-cli/tests/validate.rs index 22a8fb059..d885ab6cd 100644 --- a/hugr-cli/tests/validate.rs +++ b/hugr-cli/tests/validate.rs @@ -6,7 +6,7 @@ use assert_cmd::Command; use assert_fs::{fixture::FileWriteStr, NamedTempFile}; -use hugr_cli::validate::{Package, VALID_PRINT}; +use hugr_cli::{validate::VALID_PRINT, Package}; use hugr_core::builder::DFGBuilder; use hugr_core::types::Type; use hugr_core::{ @@ -22,7 +22,11 @@ use rstest::{fixture, rstest}; #[fixture] fn cmd() -> Command { - let mut cmd = Command::cargo_bin("hugr").unwrap(); + Command::cargo_bin("hugr").unwrap() +} + +#[fixture] +fn val_cmd(mut cmd: Command) -> Command { cmd.arg("validate"); cmd } @@ -48,86 +52,104 @@ fn test_hugr_file(test_hugr_string: String) -> NamedTempFile { } #[rstest] -fn test_doesnt_exist(mut cmd: Command) { - cmd.arg("foobar"); - cmd.assert() +fn test_doesnt_exist(mut val_cmd: Command) { + val_cmd.arg("foobar"); + val_cmd + .assert() .failure() - .stderr(contains("No such file or directory").and(contains("Error reading input"))); + .stderr(contains("No such file or directory")); } #[rstest] -fn test_validate(test_hugr_file: NamedTempFile, mut cmd: Command) { - cmd.arg(test_hugr_file.path()); - cmd.assert().success().stderr(contains(VALID_PRINT)); +fn test_validate(test_hugr_file: NamedTempFile, mut val_cmd: Command) { + val_cmd.arg(test_hugr_file.path()); + val_cmd.assert().success().stderr(contains(VALID_PRINT)); } #[rstest] -fn test_stdin(test_hugr_string: String, mut cmd: Command) { - cmd.write_stdin(test_hugr_string); - cmd.arg("-"); +fn test_stdin(test_hugr_string: String, mut val_cmd: Command) { + val_cmd.write_stdin(test_hugr_string); + val_cmd.arg("-"); - cmd.assert().success().stderr(contains(VALID_PRINT)); + val_cmd.assert().success().stderr(contains(VALID_PRINT)); } #[rstest] -fn test_stdin_silent(test_hugr_string: String, mut cmd: Command) { - cmd.args(["-", "-q"]); - cmd.write_stdin(test_hugr_string); +fn test_stdin_silent(test_hugr_string: String, mut val_cmd: Command) { + val_cmd.args(["-", "-q"]); + val_cmd.write_stdin(test_hugr_string); - cmd.assert().success().stderr(contains(VALID_PRINT).not()); + val_cmd + .assert() + .success() + .stderr(contains(VALID_PRINT).not()); } #[rstest] fn test_mermaid(test_hugr_file: NamedTempFile, mut cmd: Command) { const MERMAID: &str = "graph LR\n subgraph 0 [\"(0) DFG\"]"; + cmd.arg("mermaid"); cmd.arg(test_hugr_file.path()); - cmd.arg("--mermaid"); - cmd.arg("--no-validate"); cmd.assert().success().stdout(contains(MERMAID)); } -#[rstest] -fn test_bad_hugr(mut cmd: Command) { +#[fixture] +fn bad_hugr_string() -> String { let df = DFGBuilder::new(Signature::new_endo(type_row![QB_T])).unwrap(); let bad_hugr = df.hugr().clone(); - let bad_hugr_string = serde_json::to_string(&bad_hugr).unwrap(); + serde_json::to_string(&bad_hugr).unwrap() +} + +#[rstest] +fn test_mermaid_invalid(bad_hugr_string: String, mut cmd: Command) { + cmd.arg("mermaid"); + cmd.arg("--validate"); cmd.write_stdin(bad_hugr_string); - cmd.arg("-"); + cmd.assert().failure().stderr(contains("UnconnectedPort")); +} + +#[rstest] +fn test_bad_hugr(bad_hugr_string: String, mut val_cmd: Command) { + val_cmd.write_stdin(bad_hugr_string); + val_cmd.arg("-"); - cmd.assert() + val_cmd + .assert() .failure() .stderr(contains("Error validating HUGR").and(contains("unconnected port"))); } #[rstest] -fn test_bad_json(mut cmd: Command) { - cmd.write_stdin(r#"{"foo": "bar"}"#); - cmd.arg("-"); +fn test_bad_json(mut val_cmd: Command) { + val_cmd.write_stdin(r#"{"foo": "bar"}"#); + val_cmd.arg("-"); - cmd.assert() + val_cmd + .assert() .failure() .stderr(contains("Error parsing input")); } #[rstest] -fn test_bad_json_silent(mut cmd: Command) { - cmd.write_stdin(r#"{"foo": "bar"}"#); - cmd.args(["-", "-qqq"]); +fn test_bad_json_silent(mut val_cmd: Command) { + val_cmd.write_stdin(r#"{"foo": "bar"}"#); + val_cmd.args(["-", "-qqq"]); - cmd.assert() + val_cmd + .assert() .failure() .stderr(contains("Error parsing input").not()); } #[rstest] -fn test_no_std(test_hugr_string: String, mut cmd: Command) { - cmd.write_stdin(test_hugr_string); - cmd.arg("-"); - cmd.arg("--no-std"); +fn test_no_std(test_hugr_string: String, mut val_cmd: Command) { + val_cmd.write_stdin(test_hugr_string); + val_cmd.arg("-"); + val_cmd.arg("--no-std"); // test hugr doesn't have any standard extensions, so this should succceed - cmd.assert().success().stderr(contains(VALID_PRINT)); + val_cmd.assert().success().stderr(contains(VALID_PRINT)); } #[fixture] @@ -136,12 +158,13 @@ fn float_hugr_string(#[with(FLOAT64_TYPE)] test_hugr: Hugr) -> String { } #[rstest] -fn test_no_std_fail(float_hugr_string: String, mut cmd: Command) { - cmd.write_stdin(float_hugr_string); - cmd.arg("-"); - cmd.arg("--no-std"); +fn test_no_std_fail(float_hugr_string: String, mut val_cmd: Command) { + val_cmd.write_stdin(float_hugr_string); + val_cmd.arg("-"); + val_cmd.arg("--no-std"); - cmd.assert() + val_cmd + .assert() .failure() .stderr(contains(" Extension 'arithmetic.float.types' not found")); } @@ -153,14 +176,14 @@ const FLOAT_EXT_FILE: &str = concat!( ); #[rstest] -fn test_float_extension(float_hugr_string: String, mut cmd: Command) { - cmd.write_stdin(float_hugr_string); - cmd.arg("-"); - cmd.arg("--no-std"); - cmd.arg("--extensions"); - cmd.arg(FLOAT_EXT_FILE); +fn test_float_extension(float_hugr_string: String, mut val_cmd: Command) { + val_cmd.write_stdin(float_hugr_string); + val_cmd.arg("-"); + val_cmd.arg("--no-std"); + val_cmd.arg("--extensions"); + val_cmd.arg(FLOAT_EXT_FILE); - cmd.assert().success().stderr(contains(VALID_PRINT)); + val_cmd.assert().success().stderr(contains(VALID_PRINT)); } #[fixture] fn package_string(#[with(FLOAT64_TYPE)] test_hugr: Hugr) -> String { @@ -171,11 +194,11 @@ fn package_string(#[with(FLOAT64_TYPE)] test_hugr: Hugr) -> String { } #[rstest] -fn test_package(package_string: String, mut cmd: Command) { +fn test_package(package_string: String, mut val_cmd: Command) { // package with float extension and hugr that uses floats can validate - cmd.write_stdin(package_string); - cmd.arg("-"); - cmd.arg("--no-std"); + val_cmd.write_stdin(package_string); + val_cmd.arg("-"); + val_cmd.arg("--no-std"); - cmd.assert().success().stderr(contains(VALID_PRINT)); + val_cmd.assert().success().stderr(contains(VALID_PRINT)); } diff --git a/hugr-py/tests/conftest.py b/hugr-py/tests/conftest.py index 075b70efb..94698f9b9 100644 --- a/hugr-py/tests/conftest.py +++ b/hugr-py/tests/conftest.py @@ -114,21 +114,32 @@ def __call__(self, q: ComWire, fl_wire: ComWire) -> Command: Rz = RzDef() -def validate(h: Hugr, mermaid: bool = False, roundtrip: bool = True): +def _base_command() -> list[str]: workspace_dir = pathlib.Path(__file__).parent.parent.parent # use the HUGR_BIN environment variable if set, otherwise use the debug build bin_loc = os.environ.get("HUGR_BIN", str(workspace_dir / "target/debug/hugr")) - cmd = [bin_loc, "validate", "-"] + return [bin_loc] - if mermaid: - cmd.append("--mermaid") + +def mermaid(h: Hugr): + """Render the Hugr as a mermaid diagram for debugging.""" + cmd = [*_base_command(), "mermaid", "-"] + _run_hugr_cmd(h.to_serial().to_json(), cmd) + + +def validate(h: Hugr, roundtrip: bool = True): + cmd = [*_base_command(), "validate", "-"] serial = h.to_serial().to_json() + _run_hugr_cmd(serial, cmd) + + if roundtrip: + h2 = Hugr.from_serial(SerialHugr.load_json(json.loads(serial))) + assert serial == h2.to_serial().to_json() + + +def _run_hugr_cmd(serial: str, cmd: list[str]): try: subprocess.run(cmd, check=True, input=serial.encode(), capture_output=True) # noqa: S603 except subprocess.CalledProcessError as e: error = e.stderr.decode() raise RuntimeError(error) from e - - if roundtrip: - h2 = Hugr.from_serial(SerialHugr.load_json(json.loads(serial))) - assert serial == h2.to_serial().to_json()