From 409d461bd8f3b472ad52debb70288b5d4e026d1b Mon Sep 17 00:00:00 2001 From: Mat Wood Date: Sat, 6 Jul 2024 07:38:42 -0700 Subject: [PATCH] Check for duplicate stdin usage on read instead of arg parsing (#10) This change relaxes the check for duplicate usage of `stdin` on arg declarations (error at runtime if any two args use `MaybeStdin` or `FileOrStdin`) and instead only check for duplicated `stdin` usage when these args are accessed. This allows usages of args that may be mutually exclusive (E.g. under different subcommands) or use the `global=true` clap option (as reported in #9) If a tool happens to accept multiple args that can be Stdin, the CLI user will only see an error if they actually try to use `stdin` for values twice. --- src/file_or_stdin.rs | 23 +++++---- src/lib.rs | 51 +++++++++++++++---- src/maybe_stdin.rs | 19 ++----- .../fixtures/file_or_stdin_positional_arg.rs | 7 +-- tests/fixtures/file_or_stdin_twice.rs | 10 ++-- 5 files changed, 66 insertions(+), 44 deletions(-) diff --git a/src/file_or_stdin.rs b/src/file_or_stdin.rs index 8e16a40..de0e9bd 100644 --- a/src/file_or_stdin.rs +++ b/src/file_or_stdin.rs @@ -37,11 +37,21 @@ use super::{Source, StdinError}; /// ``` #[derive(Debug, Clone)] pub struct FileOrStdin { - pub source: Source, + source: Source, _type: PhantomData, } impl FileOrStdin { + /// Was this value read from stdin + pub fn is_stdin(&self) -> bool { + matches!(self.source, Source::Stdin) + } + + /// Was this value read from a file (path passed in from argument values) + pub fn is_file(&self) -> bool { + !self.is_stdin() + } + /// Read the entire contents from the input source, returning T::from_str pub fn contents(self) -> Result where @@ -77,15 +87,8 @@ impl FileOrStdin { /// # Ok(()) /// # } /// ``` - pub fn into_reader(&self) -> Result { - let input: Box = match &self.source { - Source::Stdin => Box::new(std::io::stdin()), - Source::Arg(filepath) => { - let f = std::fs::File::open(filepath)?; - Box::new(f) - } - }; - Ok(input) + pub fn into_reader(self) -> Result { + self.source.into_reader() } #[cfg(feature = "tokio")] diff --git a/src/lib.rs b/src/lib.rs index 035dbab..ecd3862 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,6 @@ #![doc = include_str!("../README.md")] -use std::io; +use std::io::{self, Read}; use std::str::FromStr; use std::sync::atomic::AtomicBool; @@ -9,11 +9,11 @@ pub use maybe_stdin::MaybeStdin; mod file_or_stdin; pub use file_or_stdin::FileOrStdin; -static STDIN_HAS_BEEN_USED: AtomicBool = AtomicBool::new(false); +static STDIN_HAS_BEEN_READ: AtomicBool = AtomicBool::new(false); #[derive(Debug, thiserror::Error)] pub enum StdinError { - #[error("stdin argument used more than once")] + #[error("stdin read from more than once")] StdInRepeatedUse, #[error(transparent)] StdIn(#[from] io::Error), @@ -23,23 +23,52 @@ pub enum StdinError { /// Source of the value contents will be either from `stdin` or a CLI arg provided value #[derive(Clone)] -pub enum Source { +pub(crate) enum Source { Stdin, Arg(String), } +impl Source { + pub(crate) fn into_reader(self) -> Result { + let input: Box = match self { + Source::Stdin => { + if STDIN_HAS_BEEN_READ.load(std::sync::atomic::Ordering::Acquire) { + return Err(StdinError::StdInRepeatedUse); + } + STDIN_HAS_BEEN_READ.store(true, std::sync::atomic::Ordering::SeqCst); + Box::new(std::io::stdin()) + } + Source::Arg(filepath) => { + let f = std::fs::File::open(filepath)?; + Box::new(f) + } + }; + Ok(input) + } + + pub(crate) fn get_value(self) -> Result { + match self { + Source::Stdin => { + if STDIN_HAS_BEEN_READ.load(std::sync::atomic::Ordering::Acquire) { + return Err(StdinError::StdInRepeatedUse); + } + STDIN_HAS_BEEN_READ.store(true, std::sync::atomic::Ordering::SeqCst); + let stdin = io::stdin(); + let mut input = String::new(); + stdin.lock().read_to_string(&mut input)?; + Ok(input) + } + Source::Arg(value) => Ok(value), + } + } +} + impl FromStr for Source { type Err = StdinError; fn from_str(s: &str) -> Result { match s { - "-" => { - if STDIN_HAS_BEEN_USED.load(std::sync::atomic::Ordering::Acquire) { - return Err(StdinError::StdInRepeatedUse); - } - STDIN_HAS_BEEN_USED.store(true, std::sync::atomic::Ordering::SeqCst); - Ok(Self::Stdin) - } + "-" => Ok(Self::Stdin), arg => Ok(Self::Arg(arg.to_owned())), } } diff --git a/src/maybe_stdin.rs b/src/maybe_stdin.rs index 712b247..76cbb81 100644 --- a/src/maybe_stdin.rs +++ b/src/maybe_stdin.rs @@ -1,4 +1,3 @@ -use std::io::{self, Read}; use std::str::FromStr; use super::{Source, StdinError}; @@ -27,8 +26,6 @@ use super::{Source, StdinError}; /// ``` #[derive(Clone)] pub struct MaybeStdin { - /// Source of the contents - pub source: Source, inner: T, } @@ -41,19 +38,9 @@ where fn from_str(s: &str) -> Result { let source = Source::from_str(s)?; - match &source { - Source::Stdin => { - let stdin = io::stdin(); - let mut input = String::new(); - stdin.lock().read_to_string(&mut input)?; - Ok(T::from_str(input.trim_end()) - .map_err(|e| StdinError::FromStr(format!("{e}"))) - .map(|val| Self { source, inner: val })?) - } - Source::Arg(value) => Ok(T::from_str(value) - .map_err(|e| StdinError::FromStr(format!("{e}"))) - .map(|val| Self { source, inner: val })?), - } + T::from_str(source.get_value()?.trim()) + .map_err(|e| StdinError::FromStr(format!("{e}"))) + .map(|val| Self { inner: val }) } } diff --git a/tests/fixtures/file_or_stdin_positional_arg.rs b/tests/fixtures/file_or_stdin_positional_arg.rs index 8a414f7..05af24e 100644 --- a/tests/fixtures/file_or_stdin_positional_arg.rs +++ b/tests/fixtures/file_or_stdin_positional_arg.rs @@ -11,13 +11,14 @@ struct Args { } #[cfg(feature = "test_bin")] -fn main() { +fn main() -> Result<(), String> { let args = Args::parse(); println!( "FIRST: {}; SECOND: {:?}", - args.first.contents().unwrap(), + args.first.contents().map_err(|e| format!("{e}"))?, args.second ); + Ok(()) } #[cfg(feature = "test_bin_tokio")] @@ -26,7 +27,7 @@ async fn main() -> anyhow::Result<()> { let args = Args::parse(); println!( "FIRST: {}; SECOND: {:?}", - args.first.contents_async().await.unwrap(), + args.first.contents_async().await?, args.second ); } diff --git a/tests/fixtures/file_or_stdin_twice.rs b/tests/fixtures/file_or_stdin_twice.rs index 1c0b036..02e73e5 100644 --- a/tests/fixtures/file_or_stdin_twice.rs +++ b/tests/fixtures/file_or_stdin_twice.rs @@ -9,13 +9,15 @@ struct Args { } #[cfg(feature = "test_bin")] -fn main() { +fn main() -> Result<(), String> { let args = Args::parse(); println!( - "FIRST: {}; SECOND: {}", - args.first.contents().unwrap(), + "FIRST: {}; SECOND: {:?}", + args.first.contents().map_err(|e| format!("{e}"))?, args.second ); + + Ok(()) } #[cfg(feature = "test_bin_tokio")] @@ -24,7 +26,7 @@ async fn main() -> anyhow::Result<()> { let args = Args::parse(); println!( "FIRST: {}; SECOND: {}", - args.first.contents_async().unwrap(), + args.first.contents_async()?, args.second ); }