Skip to content

Commit

Permalink
Check for duplicate stdin usage on read instead of arg parsing
Browse files Browse the repository at this point in the history
This change relaxes the check for duplicate usage of `stdin` on arg
declarations (error at runtime if any two args use `MaybeStdin` or
`FileOrStdin`) and instead only check for duplicated `stdin` usage
when these args are accessed.

This allows usages of args that may be mutually exclusive (E.g. under
different subcommands) or use the `global=true` clap option (as reported
in #9)

If a tool happens to accept multiple args that can be Stdin, the CLI user
will only see an error if they actually try to use `stdin` for values
twice.
  • Loading branch information
thepacketgeek committed Jul 6, 2024
1 parent fb6d12f commit b52069e
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 44 deletions.
23 changes: 13 additions & 10 deletions src/file_or_stdin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,21 @@ use super::{Source, StdinError};
/// ```
#[derive(Debug, Clone)]
pub struct FileOrStdin<T = String> {
pub source: Source,
source: Source,
_type: PhantomData<T>,
}

impl<T> FileOrStdin<T> {
/// Was this value read from stdin
pub fn is_stdin(&self) -> bool {
matches!(self.source, Source::Stdin)
}

/// Was this value read from a file (path passed in from argument values)
pub fn is_file(&self) -> bool {
!self.is_stdin()
}

/// Read the entire contents from the input source, returning T::from_str
pub fn contents(self) -> Result<T, StdinError>
where
Expand Down Expand Up @@ -77,15 +87,8 @@ impl<T> FileOrStdin<T> {
/// # Ok(())
/// # }
/// ```
pub fn into_reader(&self) -> Result<impl std::io::Read, StdinError> {
let input: Box<dyn std::io::Read + 'static> = match &self.source {
Source::Stdin => Box::new(std::io::stdin()),
Source::Arg(filepath) => {
let f = std::fs::File::open(filepath)?;
Box::new(f)
}
};
Ok(input)
pub fn into_reader(self) -> Result<impl std::io::Read, StdinError> {
self.source.into_reader()
}

#[cfg(feature = "tokio")]
Expand Down
51 changes: 40 additions & 11 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#![doc = include_str!("../README.md")]

use std::io;
use std::io::{self, Read};
use std::str::FromStr;
use std::sync::atomic::AtomicBool;

Expand All @@ -9,11 +9,11 @@ pub use maybe_stdin::MaybeStdin;
mod file_or_stdin;
pub use file_or_stdin::FileOrStdin;

static STDIN_HAS_BEEN_USED: AtomicBool = AtomicBool::new(false);
static STDIN_HAS_BEEN_READ: AtomicBool = AtomicBool::new(false);

#[derive(Debug, thiserror::Error)]
pub enum StdinError {
#[error("stdin argument used more than once")]
#[error("stdin read from more than once")]
StdInRepeatedUse,
#[error(transparent)]
StdIn(#[from] io::Error),
Expand All @@ -23,23 +23,52 @@ pub enum StdinError {

/// Source of the value contents will be either from `stdin` or a CLI arg provided value
#[derive(Clone)]
pub enum Source {
pub(crate) enum Source {
Stdin,
Arg(String),
}

impl Source {
pub(crate) fn into_reader(self) -> Result<impl std::io::Read, StdinError> {
let input: Box<dyn std::io::Read + 'static> = match self {
Source::Stdin => {
if STDIN_HAS_BEEN_READ.load(std::sync::atomic::Ordering::Acquire) {
return Err(StdinError::StdInRepeatedUse);
}
STDIN_HAS_BEEN_READ.store(true, std::sync::atomic::Ordering::SeqCst);
Box::new(std::io::stdin())
}
Source::Arg(filepath) => {
let f = std::fs::File::open(filepath)?;
Box::new(f)
}
};
Ok(input)
}

pub(crate) fn get_value(self) -> Result<String, StdinError> {
match self {
Source::Stdin => {
if STDIN_HAS_BEEN_READ.load(std::sync::atomic::Ordering::Acquire) {
return Err(StdinError::StdInRepeatedUse);
}
STDIN_HAS_BEEN_READ.store(true, std::sync::atomic::Ordering::SeqCst);
let stdin = io::stdin();
let mut input = String::new();
stdin.lock().read_to_string(&mut input)?;
Ok(input)
}
Source::Arg(value) => Ok(value),
}
}
}

impl FromStr for Source {
type Err = StdinError;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"-" => {
if STDIN_HAS_BEEN_USED.load(std::sync::atomic::Ordering::Acquire) {
return Err(StdinError::StdInRepeatedUse);
}
STDIN_HAS_BEEN_USED.store(true, std::sync::atomic::Ordering::SeqCst);
Ok(Self::Stdin)
}
"-" => Ok(Self::Stdin),
arg => Ok(Self::Arg(arg.to_owned())),
}
}
Expand Down
19 changes: 3 additions & 16 deletions src/maybe_stdin.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::io::{self, Read};
use std::str::FromStr;

use super::{Source, StdinError};
Expand Down Expand Up @@ -27,8 +26,6 @@ use super::{Source, StdinError};
/// ```
#[derive(Clone)]
pub struct MaybeStdin<T> {
/// Source of the contents
pub source: Source,
inner: T,
}

Expand All @@ -41,19 +38,9 @@ where

fn from_str(s: &str) -> Result<Self, Self::Err> {
let source = Source::from_str(s)?;
match &source {
Source::Stdin => {
let stdin = io::stdin();
let mut input = String::new();
stdin.lock().read_to_string(&mut input)?;
Ok(T::from_str(input.trim_end())
.map_err(|e| StdinError::FromStr(format!("{e}")))
.map(|val| Self { source, inner: val })?)
}
Source::Arg(value) => Ok(T::from_str(value)
.map_err(|e| StdinError::FromStr(format!("{e}")))
.map(|val| Self { source, inner: val })?),
}
T::from_str(source.get_value()?.trim())
.map_err(|e| StdinError::FromStr(format!("{e}")))
.map(|val| Self { inner: val })
}
}

Expand Down
7 changes: 4 additions & 3 deletions tests/fixtures/file_or_stdin_positional_arg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@ struct Args {
}

#[cfg(feature = "test_bin")]
fn main() {
fn main() -> Result<(), String> {
let args = Args::parse();
println!(
"FIRST: {}; SECOND: {:?}",
args.first.contents().unwrap(),
args.first.contents().map_err(|e| format!("{e}"))?,
args.second
);
Ok(())
}

#[cfg(feature = "test_bin_tokio")]
Expand All @@ -26,7 +27,7 @@ async fn main() -> anyhow::Result<()> {
let args = Args::parse();
println!(
"FIRST: {}; SECOND: {:?}",
args.first.contents_async().await.unwrap(),
args.first.contents_async().await?,
args.second
);
}
10 changes: 6 additions & 4 deletions tests/fixtures/file_or_stdin_twice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ struct Args {
}

#[cfg(feature = "test_bin")]
fn main() {
fn main() -> Result<(), String> {
let args = Args::parse();
println!(
"FIRST: {}; SECOND: {}",
args.first.contents().unwrap(),
"FIRST: {}; SECOND: {:?}",
args.first.contents().map_err(|e| format!("{e}"))?,
args.second
);

Ok(())
}

#[cfg(feature = "test_bin_tokio")]
Expand All @@ -24,7 +26,7 @@ async fn main() -> anyhow::Result<()> {
let args = Args::parse();
println!(
"FIRST: {}; SECOND: {}",
args.first.contents_async().unwrap(),
args.first.contents_async()?,
args.second
);
}

0 comments on commit b52069e

Please sign in to comment.