diff --git a/Cargo.lock b/Cargo.lock index 882e30c199b1..7bdcf76db200 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4164,6 +4164,9 @@ dependencies = [ "file-per-thread-logger", "humantime", "rayon", + "serde", + "serde_derive", + "toml", "tracing-subscriber", "wasmtime", ] diff --git a/crates/cli-flags/Cargo.toml b/crates/cli-flags/Cargo.toml index c1c84ce39dd7..578067d12acf 100644 --- a/crates/cli-flags/Cargo.toml +++ b/crates/cli-flags/Cargo.toml @@ -20,6 +20,9 @@ tracing-subscriber = { workspace = true, optional = true } rayon = { version = "1.5.0", optional = true } wasmtime = { workspace = true, features = ["gc"] } humantime = { workspace = true } +serde = { workspace = true } +serde_derive = { workspace = true } +toml = { workspace = true } [features] async = ["wasmtime/async"] diff --git a/crates/cli-flags/src/lib.rs b/crates/cli-flags/src/lib.rs index f819f8d544ca..efdbe50eb49b 100644 --- a/crates/cli-flags/src/lib.rs +++ b/crates/cli-flags/src/lib.rs @@ -1,8 +1,13 @@ //! Contains the common Wasmtime command line interface (CLI) flags. -use anyhow::Result; +use anyhow::{Context, Result}; use clap::Parser; -use std::time::Duration; +use serde::Deserialize; +use std::{ + fs, + path::{Path, PathBuf}, + time::Duration, +}; use wasmtime::Config; pub mod opt; @@ -37,12 +42,17 @@ fn init_file_per_thread_logger(prefix: &'static str) { } wasmtime_option_group! { - #[derive(PartialEq, Clone)] + #[derive(PartialEq, Clone, Deserialize)] + #[serde(deny_unknown_fields)] pub struct OptimizeOptions { /// Optimization level of generated code (0-2, s; default: 2) + #[serde(default)] + #[serde(deserialize_with = "crate::opt::cli_parse_wrapper")] pub opt_level: Option, /// Register allocator algorithm choice. + #[serde(default)] + #[serde(deserialize_with = "crate::opt::cli_parse_wrapper")] pub regalloc_algorithm: Option, /// Do not allow Wasm linear memories to move in the host process's @@ -189,12 +199,15 @@ wasmtime_option_group! { } wasmtime_option_group! { - #[derive(PartialEq, Clone)] + #[derive(PartialEq, Clone, Deserialize)] + #[serde(deny_unknown_fields)] pub struct CodegenOptions { /// Either `cranelift` or `winch`. /// /// Currently only `cranelift` and `winch` are supported, but not all /// builds of Wasmtime have both built in. + #[serde(default)] + #[serde(deserialize_with = "crate::opt::cli_parse_wrapper")] pub compiler: Option, /// Which garbage collector to use: `drc` or `null`. /// @@ -205,6 +218,8 @@ wasmtime_option_group! { /// /// Note that not all builds of Wasmtime will have support for garbage /// collection included. + #[serde(default)] + #[serde(deserialize_with = "crate::opt::cli_parse_wrapper")] pub collector: Option, /// Enable Cranelift's internal debug verifier (expensive) pub cranelift_debug_verifier: Option, @@ -221,6 +236,7 @@ wasmtime_option_group! { pub native_unwind_info: Option, #[prefixed = "cranelift"] + #[serde(default)] /// Set a cranelift-specific option. Use `wasmtime settings` to see /// all. pub cranelift: Vec<(String, Option)>, @@ -232,7 +248,8 @@ wasmtime_option_group! { } wasmtime_option_group! { - #[derive(PartialEq, Clone)] + #[derive(PartialEq, Clone, Deserialize)] + #[serde(deny_unknown_fields)] pub struct DebugOptions { /// Enable generation of DWARF debug information in compiled code. pub debug_info: Option, @@ -252,7 +269,8 @@ wasmtime_option_group! { } wasmtime_option_group! { - #[derive(PartialEq, Clone)] + #[derive(PartialEq, Clone, Deserialize)] + #[serde(deny_unknown_fields)] pub struct WasmOptions { /// Enable canonicalization of all NaN values. pub nan_canonicalization: Option, @@ -361,7 +379,8 @@ wasmtime_option_group! { } wasmtime_option_group! { - #[derive(PartialEq, Clone)] + #[derive(PartialEq, Clone, Deserialize)] + #[serde(deny_unknown_fields)] pub struct WasiOptions { /// Enable support for WASI CLI APIs, including filesystems, sockets, clocks, and random. pub cli: Option, @@ -390,6 +409,7 @@ wasmtime_option_group! { /// systemd listen fd specification (UNIX only) pub listenfd: Option, /// Grant access to the given TCP listen socket + #[serde(default)] pub tcplisten: Vec, /// Implement WASI Preview1 using new Preview2 implementation (true, default) or legacy /// implementation (false) @@ -402,6 +422,7 @@ wasmtime_option_group! { /// an OpenVINO model named `bar`. Note that which model encodings are /// available is dependent on the backends implemented in the /// `wasmtime_wasi_nn` crate. + #[serde(skip)] pub nn_graph: Vec, /// Flag for WASI preview2 to inherit the host's network within the /// guest so it has full access to all addresses/ports/etc. @@ -421,8 +442,10 @@ wasmtime_option_group! { /// This option can be further overwritten with `--env` flags. pub inherit_env: Option, /// Pass a wasi config variable to the program. + #[serde(skip)] pub config_var: Vec, /// Preset data for the In-Memory provider of WASI key-value API. + #[serde(skip)] pub keyvalue_in_memory_data: Vec, } @@ -444,7 +467,8 @@ pub struct KeyValuePair { } /// Common options for commands that translate WebAssembly modules -#[derive(Parser, Clone)] +#[derive(Parser, Clone, Debug, Deserialize)] +#[serde(deny_unknown_fields)] pub struct CommonOptions { // These options groups are used to parse `-O` and such options but aren't // the raw form consumed by the CLI. Instead they're pushed into the `pub` @@ -456,43 +480,70 @@ pub struct CommonOptions { /// Optimization and tuning related options for wasm performance, `-O help` to /// see all. #[arg(short = 'O', long = "optimize", value_name = "KEY[=VAL[,..]]")] + #[serde(skip)] opts_raw: Vec>, /// Codegen-related configuration options, `-C help` to see all. #[arg(short = 'C', long = "codegen", value_name = "KEY[=VAL[,..]]")] + #[serde(skip)] codegen_raw: Vec>, /// Debug-related configuration options, `-D help` to see all. #[arg(short = 'D', long = "debug", value_name = "KEY[=VAL[,..]]")] + #[serde(skip)] debug_raw: Vec>, /// Options for configuring semantic execution of WebAssembly, `-W help` to see /// all. #[arg(short = 'W', long = "wasm", value_name = "KEY[=VAL[,..]]")] + #[serde(skip)] wasm_raw: Vec>, /// Options for configuring WASI and its proposals, `-S help` to see all. #[arg(short = 'S', long = "wasi", value_name = "KEY[=VAL[,..]]")] + #[serde(skip)] wasi_raw: Vec>, // These fields are filled in by the `configure` method below via the // options parsed from the CLI above. This is what the CLI should use. #[arg(skip)] + #[serde(skip)] configured: bool, + #[arg(skip)] + #[serde(rename = "optimize", default)] pub opts: OptimizeOptions, + #[arg(skip)] + #[serde(rename = "codegen", default)] pub codegen: CodegenOptions, + #[arg(skip)] + #[serde(rename = "debug", default)] pub debug: DebugOptions, + #[arg(skip)] + #[serde(rename = "wasm", default)] pub wasm: WasmOptions, + #[arg(skip)] + #[serde(rename = "wasi", default)] pub wasi: WasiOptions, /// The target triple; default is the host triple #[arg(long, value_name = "TARGET")] + #[serde(skip)] pub target: Option, + + /// Use the specified TOML configuration file. + /// This TOML configuration file can provide same configuration options as the + /// `--optimize`, `--codgen`, `--debug`, `--wasm`, `--wasi` CLI options, with a couple exceptions. + /// + /// Additional options specified on the command line will take precedent over options loaded from + /// this TOML file. + #[arg(long = "config", value_name = "FILE")] + #[serde(skip)] + pub config: Option, } macro_rules! match_feature { @@ -517,20 +568,29 @@ macro_rules! match_feature { } impl CommonOptions { - fn configure(&mut self) { + fn configure(&mut self) -> Result<()> { if self.configured { - return; + return Ok(()); } self.configured = true; + if let Some(toml_config_path) = &self.config { + let toml_options = CommonOptions::from_file(toml_config_path)?; + self.opts = toml_options.opts; + self.codegen = toml_options.codegen; + self.debug = toml_options.debug; + self.wasm = toml_options.wasm; + self.wasi = toml_options.wasi; + } self.opts.configure_with(&self.opts_raw); self.codegen.configure_with(&self.codegen_raw); self.debug.configure_with(&self.debug_raw); self.wasm.configure_with(&self.wasm_raw); self.wasi.configure_with(&self.wasi_raw); + Ok(()) } pub fn init_logging(&mut self) -> Result<()> { - self.configure(); + self.configure()?; if self.debug.logging == Some(false) { return Ok(()); } @@ -555,7 +615,7 @@ impl CommonOptions { } pub fn config(&mut self, pooling_allocator_default: Option) -> Result { - self.configure(); + self.configure()?; let mut config = Config::new(); match_feature! { @@ -923,4 +983,135 @@ impl CommonOptions { } Ok(()) } + + pub fn from_file>(path: P) -> Result { + let path_ref = path.as_ref(); + let file_contents = fs::read_to_string(path_ref) + .with_context(|| format!("failed to read config file: {path_ref:?}"))?; + toml::from_str::(&file_contents) + .with_context(|| format!("failed to parse TOML config file {path_ref:?}")) + } +} + +#[cfg(test)] +mod tests { + use wasmtime::{OptLevel, RegallocAlgorithm}; + + use super::*; + + #[test] + fn from_toml() { + // empty toml + let empty_toml = ""; + let mut common_options: CommonOptions = toml::from_str(empty_toml).unwrap(); + common_options.config(None).unwrap(); + + // basic toml + let basic_toml = r#" + [optimize] + [codegen] + [debug] + [wasm] + [wasi] + "#; + let mut common_options: CommonOptions = toml::from_str(basic_toml).unwrap(); + common_options.config(None).unwrap(); + + // toml with custom deserialization to match CLI flag parsing + for (opt_value, expected) in [ + ("0", Some(OptLevel::None)), + ("1", Some(OptLevel::Speed)), + ("2", Some(OptLevel::Speed)), + ("\"s\"", Some(OptLevel::SpeedAndSize)), + ("\"hello\"", None), // should fail + ("3", None), // should fail + ] { + let toml = format!( + r#" + [optimize] + opt_level = {opt_value} + "#, + ); + let parsed_opt_level = toml::from_str::(&toml) + .ok() + .and_then(|common_options| common_options.opts.opt_level); + + assert_eq!( + parsed_opt_level, expected, + "Mismatch for input '{opt_value}'. Parsed: {parsed_opt_level:?}, Expected: {expected:?}" + ); + } + + // Regalloc algorithm + for (regalloc_value, expected) in [ + ("\"backtracking\"", Some(RegallocAlgorithm::Backtracking)), + ("\"single-pass\"", Some(RegallocAlgorithm::SinglePass)), + ("\"hello\"", None), // should fail + ("3", None), // should fail + ("true", None), // should fail + ] { + let toml = format!( + r#" + [optimize] + regalloc_algorithm = {regalloc_value} + "#, + ); + let parsed_regalloc_algorithm = toml::from_str::(&toml) + .ok() + .and_then(|common_options| common_options.opts.regalloc_algorithm); + assert_eq!( + parsed_regalloc_algorithm, expected, + "Mismatch for input '{regalloc_value}'. Parsed: {parsed_regalloc_algorithm:?}, Expected: {expected:?}" + ); + } + + // Strategy + for (strategy_value, expected) in [ + ("\"cranelift\"", Some(wasmtime::Strategy::Cranelift)), + ("\"winch\"", Some(wasmtime::Strategy::Winch)), + ("\"hello\"", None), // should fail + ("5", None), // should fail + ("true", None), // should fail + ] { + let toml = format!( + r#" + [codegen] + compiler = {strategy_value} + "#, + ); + let parsed_strategy = toml::from_str::(&toml) + .ok() + .and_then(|common_options| common_options.codegen.compiler); + assert_eq!( + parsed_strategy, expected, + "Mismatch for input '{strategy_value}'. Parsed: {parsed_strategy:?}, Expected: {expected:?}", + ); + } + + // Collector + for (collector_value, expected) in [ + ( + "\"drc\"", + Some(wasmtime::Collector::DeferredReferenceCounting), + ), + ("\"null\"", Some(wasmtime::Collector::Null)), + ("\"hello\"", None), // should fail + ("5", None), // should fail + ("true", None), // should fail + ] { + let toml = format!( + r#" + [codegen] + collector = {collector_value} + "#, + ); + let parsed_collector = toml::from_str::(&toml) + .ok() + .and_then(|common_options| common_options.codegen.collector); + assert_eq!( + parsed_collector, expected, + "Mismatch for input '{collector_value}'. Parsed: {parsed_collector:?}, Expected: {expected:?}", + ); + } + } } diff --git a/crates/cli-flags/src/opt.rs b/crates/cli-flags/src/opt.rs index 9814ef3169fb..70e2ab8e66e7 100644 --- a/crates/cli-flags/src/opt.rs +++ b/crates/cli-flags/src/opt.rs @@ -9,8 +9,9 @@ use crate::{KeyValuePair, WasiNnGraph}; use anyhow::{bail, Result}; use clap::builder::{StringValueParser, TypedValueParser, ValueParserFactory}; use clap::error::{Error, ErrorKind}; -use std::marker; +use serde::de::{self, Visitor}; use std::time::Duration; +use std::{fmt, marker}; /// Characters which can be safely ignored while parsing numeric options to wasmtime const IGNORED_NUMBER_CHARS: [char; 1] = ['_']; @@ -22,11 +23,13 @@ macro_rules! wasmtime_option_group { pub struct $opts:ident { $( $(#[doc = $doc:tt])* + $(#[serde($serde_attr:meta)])* pub $opt:ident: $container:ident<$payload:ty>, )+ $( #[prefixed = $prefix:tt] + $(#[serde($serde_attr2:meta)])* $(#[doc = $prefixed_doc:tt])* pub $prefixed:ident: Vec<(String, Option)>, )? @@ -39,9 +42,11 @@ macro_rules! wasmtime_option_group { $(#[$attr])* pub struct $opts { $( + $(#[serde($serde_attr)])* pub $opt: $container<$payload>, )+ $( + $(#[serde($serde_attr2)])* pub $prefixed: Vec<(String, Option)>, )? } @@ -440,6 +445,55 @@ impl WasmtimeOptionValue for KeyValuePair { } } +// Used to parse toml values into string so that we can reuse the `WasmtimeOptionValue::parse` +// for parsing toml values the same way we parse command line values. +// +// Used for wasmtime::Strategy, wasmtime::Collector, wasmtime::OptLevel, wasmtime::RegallocAlgorithm +struct ToStringVisitor {} + +impl<'de> Visitor<'de> for ToStringVisitor { + type Value = String; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "&str, u64, or i64") + } + + fn visit_str(self, s: &str) -> Result + where + E: de::Error, + { + Ok(s.to_owned()) + } + + fn visit_u64(self, v: u64) -> Result + where + E: de::Error, + { + Ok(v.to_string()) + } + + fn visit_i64(self, v: i64) -> Result + where + E: de::Error, + { + Ok(v.to_string()) + } +} + +// Deserializer that uses the `WasmtimeOptionValue::parse` to parse toml values +pub(crate) fn cli_parse_wrapper<'de, D, T>(deserializer: D) -> Result, D::Error> +where + T: WasmtimeOptionValue, + D: serde::Deserializer<'de>, +{ + let to_string_visitor = ToStringVisitor {}; + let str = deserializer.deserialize_any(to_string_visitor)?; + + T::parse(Some(&str)) + .map(Some) + .map_err(serde::de::Error::custom) +} + #[cfg(test)] mod tests { use super::WasmtimeOptionValue; diff --git a/crates/wasmtime/src/config.rs b/crates/wasmtime/src/config.rs index 951a49058b2f..f07584a9c6b1 100644 --- a/crates/wasmtime/src/config.rs +++ b/crates/wasmtime/src/config.rs @@ -2547,7 +2547,7 @@ impl fmt::Debug for Config { /// /// This is used as an argument to the [`Config::strategy`] method. #[non_exhaustive] -#[derive(PartialEq, Eq, Clone, Debug, Copy)] +#[derive(PartialEq, Eq, Clone, Debug, Copy, Deserialize)] pub enum Strategy { /// An indicator that the compilation strategy should be automatically /// selected. @@ -2617,7 +2617,7 @@ impl Strategy { /// additional objects. Reference counts are larger than mark bits and /// free lists are larger than bump pointers, for example. #[non_exhaustive] -#[derive(PartialEq, Eq, Clone, Debug, Copy)] +#[derive(PartialEq, Eq, Clone, Debug, Copy, Deserialize)] pub enum Collector { /// An indicator that the garbage collector should be automatically /// selected. diff --git a/tests/all/cli_tests.rs b/tests/all/cli_tests.rs index 564542b16fcc..a39bb4e23827 100644 --- a/tests/all/cli_tests.rs +++ b/tests/all/cli_tests.rs @@ -2078,3 +2078,95 @@ fn unreachable_without_wasi() -> Result<()> { assert_trap_code(&output.status); Ok(()) } + +#[test] +fn config_cli_flag() -> Result<()> { + let wasm = build_wasm("tests/all/cli_tests/simple.wat")?; + + // Test some valid TOML values + let (mut cfg, cfg_path) = tempfile::NamedTempFile::new()?.into_parts(); + cfg.write_all( + br#" + [optimize] + opt_level = 2 + regalloc_algorithm = "single-pass" + signals_based_traps = false + + [codegen] + collector = "null" + + [debug] + debug_info = true + + [wasm] + max_wasm_stack = 65536 + + [wasi] + cli = true + "#, + )?; + let output = run_wasmtime(&[ + "run", + "--config", + cfg_path.to_str().unwrap(), + "--invoke", + "get_f64", + wasm.path().to_str().unwrap(), + ])?; + assert_eq!(output, "100\n"); + + // Make sure CLI flags overrides TOML values + let output = run_wasmtime(&[ + "run", + "--config", + cfg_path.to_str().unwrap(), + "--invoke", + "get_f64", + "-W", + "max-wasm-stack=1", // should override TOML value 65536 specified above and execution should fail + wasm.path().to_str().unwrap(), + ]); + assert!(output + .unwrap_err() + .to_string() + .contains("call stack exhausted")); + + // Test invalid TOML key + let (mut cfg, cfg_path) = tempfile::NamedTempFile::new()?.into_parts(); + cfg.write_all( + br#" + [optimize] + this_key_does_not_exist = true + "#, + )?; + let output = run_wasmtime(&[ + "run", + "--config", + cfg_path.to_str().unwrap(), + wasm.path().to_str().unwrap(), + ]); + assert!(output + .unwrap_err() + .to_string() + .contains("unknown field `this_key_does_not_exist`")); + + // Test invalid TOML table + let (mut cfg, cfg_path) = tempfile::NamedTempFile::new()?.into_parts(); + cfg.write_all( + br#" + [invalid_table] + "#, + )?; + let output = run_wasmtime(&[ + "run", + "--config", + cfg_path.to_str().unwrap(), + wasm.path().to_str().unwrap(), + ]); + + assert!(output.unwrap_err().to_string().contains( + "unknown field `invalid_table`, expected one of `optimize`, `codegen`, `debug`, `wasm`, `wasi`" + )); + + Ok(()) +}