Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add symlink subcommand and support for multiple query types #423

Merged
merged 1 commit into from
Nov 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ pub struct Cli {

impl Cli {
pub fn parse() -> Self {
#[cfg(feature = "resolve-cli")]
if ResolveCommand::is_resolve_cli() {
return ResolveCommand::parse().into();
}

match Self::try_parse() {
Ok(cli) => cli,
Err(e) => {
Expand Down Expand Up @@ -96,10 +101,17 @@ pub enum Commands {
command: ServiceCommands,
},

/// Perform DNS resolution. Can be used in place of the standard OS resolution facilities.
/// Perform DNS resolution.
#[cfg(feature = "resolve-cli")]
Resolve(ResolveCommand),

/// Create a symbolic link to the Smart-DNS binary (drop-in replacement for `dig`, `nslookup`, `resolve` etc.)
#[cfg(feature = "resolve-cli")]
Symlink {
/// The path to the symlink to create.
link: std::path::PathBuf,
},

/// Test configuration and exit
Test {
/// Config file
Expand Down
19 changes: 19 additions & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,25 @@ impl Cli {
drop(_guard);
command.execute();
}
#[cfg(all(feature = "resolve-cli", any(unix, windows)))]
Commands::Symlink { link } => {
let original = std::env::current_exe().expect("failed to get current exe path");
if link.exists() {
println!("link already exists");
return;
}

#[cfg(unix)]
let res = std::os::unix::fs::symlink(original, link);

#[cfg(windows)]
let res = std::os::windows::fs::symlink_file(original, link);

match res {
Ok(()) => println!("symlink created"),
Err(err) => println!("failed to create symlink, {}", err),
}
}
#[allow(unreachable_patterns)]
_ => {
unimplemented!()
Expand Down
117 changes: 90 additions & 27 deletions src/resolver.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::collections::HashMap;
use std::path::Path;
use std::{ops::Deref, str::FromStr, time::Duration};

use clap::Parser;
Expand All @@ -7,10 +8,7 @@ use console::{style, StyledObject};

use crate::libdns::proto::{
op::Message,
rr::{
DNSClass as QueryClass, Name as Domain, Record, RecordData, RecordType as QueryType,
RecordType,
},
rr::{DNSClass as QueryClass, Name as Domain, Record, RecordData, RecordType},
xfer::Protocol as DnsOverProtocol,
};

Expand All @@ -30,7 +28,7 @@ impl ResolveCommand {
}
}
let domain = self.domain().clone();
let query_type = self.q_type();
let query_types = self.q_type();

let palette = Colours::pretty();

Expand All @@ -50,17 +48,19 @@ impl ResolveCommand {
DnsClient::builder().build().await
};

let options = LookupOptions {
record_type: query_type,
..Default::default()
};

match dns_client.lookup(domain, options).await {
Ok(res) => {
print(&res, &palette);
}
Err(err) => {
println!("{}", err);
for query_type in query_types {
let options = LookupOptions {
record_type: *query_type,
..Default::default()
};

match dns_client.lookup(domain.clone(), options).await {
Ok(res) => {
print(&res, &palette);
}
Err(err) => {
println!("{}", err);
}
}
}
});
Expand Down Expand Up @@ -100,7 +100,7 @@ pub struct ResolveCommand {

/// is one of (a,any,mx,ns,soa,hinfo,axfr,txt,...)
#[arg(value_name = "q-type", default_value = "a", value_parser = Self::parse_query_type)]
q_type: QueryType,
q_type: QueryTypes,

/// is one of (in,hs,ch,...)
#[arg(value_name = "q-class", default_value = "in", value_parser = Self::parse_query_class)]
Expand All @@ -112,13 +112,26 @@ pub struct ResolveCommand {
}

impl ResolveCommand {
pub fn parse() -> Self {
match Parser::try_parse() {
Ok(cli) => cli,
Err(e) => {
if let Ok(resolve_command) = ResolveCommand::try_parse() {
return resolve_command;
}
e.exit()
}
}
}

pub fn try_parse() -> Result<Self, String> {
use DnsOverProtocol::*;
let mut proto = None;
let mut q_type = None;
let mut q_types = vec![];
let mut q_class = None;
let mut domain = None;
let mut global_server = None;
let mut prev_parsing_qtype = false;

for arg in std::env::args().skip(1) {
if arg == "resolve" {
Expand Down Expand Up @@ -159,11 +172,18 @@ impl ResolveCommand {
continue;
}

if q_type.is_none() {
if let Ok(t) = Self::parse_query_type(arg.as_str()) {
q_type = Some(t);
if q_types.is_empty() {
if let Ok(t) = Self::parse_query_type(&arg) {
q_types = t.0;
prev_parsing_qtype = true;
continue;
}
} else if prev_parsing_qtype {
if let Ok(t) = Self::parse_query_type(&arg) {
q_types.extend(t.0);
continue;
}
prev_parsing_qtype = false;
}

if q_class.is_none() {
Expand All @@ -179,14 +199,17 @@ impl ResolveCommand {
continue;
}
}

return Err(format!("Invalid argument {arg}"));
}

let Some(domain) = domain else {
return Err("domain is required".to_string());
};

let q_type = q_type.unwrap_or(QueryType::A);
if q_types.is_empty() {
q_types.push(RecordType::A);
}
let q_class = q_class.unwrap_or(QueryClass::IN);

Ok(Self {
Expand All @@ -198,11 +221,27 @@ impl ResolveCommand {
h3: matches!(proto, Some(H3)),
global_server,
domain,
q_type,
q_type: QueryTypes(q_types),
q_class,
})
}

pub fn is_resolve_cli() -> bool {
std::env::args()
.next()
.as_deref()
.map(Path::new)
.and_then(|s| s.file_stem())
.and_then(|s| s.to_str())
.map(|s| match s {
"dig" => true,
"nslookup" => true,
"resolve" => true,
_ => false,
})
.unwrap_or_default()
}

pub fn proto(&self) -> Option<DnsOverProtocol> {
use DnsOverProtocol::*;
if self.udp {
Expand Down Expand Up @@ -230,8 +269,8 @@ impl ResolveCommand {
&self.domain
}

pub fn q_type(&self) -> QueryType {
self.q_type
pub fn q_type(&self) -> &[RecordType] {
&self.q_type.0
}

pub fn q_class(&self) -> QueryClass {
Expand All @@ -245,14 +284,38 @@ impl ResolveCommand {
Err(format!("Invalid global server: {}", s))
}
}
fn parse_query_type(s: &str) -> Result<QueryType, String> {
QueryType::from_str(s.to_uppercase().as_str()).map_err(|e| e.to_string())
fn parse_query_type(s: &str) -> Result<QueryTypes, String> {
if s.contains("+") {
let mut types = Vec::new();
let mut last_err = None;
for t in s.split('+') {
match RecordType::from_str(t.to_uppercase().as_str()) {
Ok(t) => types.push(t),
Err(err) => last_err = Some(err),
}
}

if types.is_empty() {
return Err(last_err
.map(|e| e.to_string())
.unwrap_or("Invalid query type".to_string()));
}

Ok(QueryTypes(types))
} else {
RecordType::from_str(s.to_uppercase().as_str())
.map(|q| QueryTypes(vec![q]))
.map_err(|e| e.to_string())
}
}
fn parse_query_class(s: &str) -> Result<QueryClass, String> {
QueryClass::from_str(s.to_uppercase().as_str()).map_err(|e| e.to_string())
}
}

#[derive(Debug, Clone)]
struct QueryTypes(Vec<RecordType>);

fn print(message: &Message, palette: &Colours) {
for r in message.answers() {
print_record(&r, palette);
Expand Down