Skip to content

Commit

Permalink
Nest authored test functions for safety (#1951)
Browse files Browse the repository at this point in the history
Like `tokio::test` et. al., nest the authored test functions and call them, which helps maintain code safety by separating contexts as recommended for all macros.
  • Loading branch information
heaths authored Dec 10, 2024
1 parent 4f1384a commit bae38b3
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 23 deletions.
19 changes: 17 additions & 2 deletions sdk/core/azure_core_test/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,17 @@ pub use azure_core::test::TestMode;
#[derive(Clone, Debug)]
pub struct TestContext {
test_mode: TestMode,
crate_dir: &'static str,
test_name: &'static str,
}

impl TestContext {
/// Not intended for use outside the `azure_core` crates.
#[doc(hidden)]
pub fn new(test_mode: TestMode, test_name: &'static str) -> Self {
pub fn new(test_mode: TestMode, crate_dir: &'static str, test_name: &'static str) -> Self {
Self {
test_mode,
crate_dir,
test_name,
}
}
Expand All @@ -35,6 +37,11 @@ impl TestContext {
self.test_mode
}

/// Gets the root directory of the crate under test.
pub fn crate_dir(&self) -> &'static str {
self.crate_dir
}

/// Gets the current test function name.
pub fn test_name(&self) -> &'static str {
self.test_name
Expand All @@ -47,8 +54,16 @@ mod tests {

#[test]
fn test_content_new() {
let ctx = TestContext::new(TestMode::default(), "test_content_new");
let ctx = TestContext::new(
TestMode::default(),
env!("CARGO_MANIFEST_DIR"),
"test_content_new",
);
assert_eq!(ctx.test_mode(), TestMode::Playback);
assert!(ctx
.crate_dir()
.replace("\\", "/")
.ends_with("sdk/core/azure_core_test"));
assert_eq!(ctx.test_name(), "test_content_new");
}
}
50 changes: 30 additions & 20 deletions sdk/core/azure_core_test_macros/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,27 @@ use syn::{parse::Parse, spanned::Spanned, FnArg, ItemFn, Meta, PatType, Result,

const INVALID_RECORDED_ATTRIBUTE_MESSAGE: &str =
"expected `#[recorded::test]` or `#[recorded::test(live)]`";
const INVALID_RECORDED_FUNCTION_MESSAGE: &str = "expected `fn(TestContext)` function signature";
const INVALID_RECORDED_FUNCTION_MESSAGE: &str =
"expected `async fn(TestContext)` function signature with optional `Result<T, E>` return";

// cspell:ignore asyncness
pub fn parse_test(attr: TokenStream, item: TokenStream) -> Result<TokenStream> {
let recorded_attrs: Attributes = syn::parse2(attr)?;
let ItemFn {
attrs,
vis,
mut sig,
sig: original_sig,
block,
} = syn::parse2(item)?;

// Use #[tokio::test] for async functions; otherwise, #[test].
let mut test_attr: TokenStream = if sig.asyncness.is_some() {
quote! { #[::tokio::test] }
} else {
quote! { #[::core::prelude::v1::test] }
let mut test_attr: TokenStream = match original_sig.asyncness {
Some(_) => quote! { #[::tokio::test] },
None => {
return Err(syn::Error::new(
original_sig.span(),
INVALID_RECORDED_FUNCTION_MESSAGE,
))
}
};

// Ignore live-only tests if not running live tests.
Expand All @@ -36,21 +40,23 @@ pub fn parse_test(attr: TokenStream, item: TokenStream) -> Result<TokenStream> {
});
}

let mut inputs = sig.inputs.iter();
let preamble = match inputs.next() {
None if recorded_attrs.live => TokenStream::new(),
Some(FnArg::Typed(PatType { pat, ty, .. })) if is_test_context(ty.as_ref()) => {
let fn_name = &original_sig.ident;
let mut inputs = original_sig.inputs.iter();
let setup = match inputs.next() {
None if recorded_attrs.live => quote! {
#fn_name().await
},
Some(FnArg::Typed(PatType { ty, .. })) if is_test_context(ty.as_ref()) => {
let test_mode = test_mode_to_tokens(test_mode);
let fn_name = &sig.ident;

quote! {
#[allow(dead_code)]
let #pat = #ty::new(#test_mode, stringify!(#fn_name));
let ctx = #ty::new(#test_mode, env!("CARGO_MANIFEST_DIR"), stringify!(#fn_name));
#fn_name(ctx).await
}
}
_ => {
return Err(syn::Error::new(
sig.ident.span(),
original_sig.ident.span(),
INVALID_RECORDED_FUNCTION_MESSAGE,
))
}
Expand All @@ -63,14 +69,18 @@ pub fn parse_test(attr: TokenStream, item: TokenStream) -> Result<TokenStream> {
));
}

// Empty the parameters and return our rewritten test function.
sig.inputs.clear();
// Clear the actual test method parameters.
let mut outer_sig = original_sig.clone();
outer_sig.inputs.clear();

Ok(quote! {
#test_attr
#(#attrs)*
#vis #sig {
#preamble
#block
#vis #outer_sig {
#original_sig {
#block
}
#setup
}
})
}
Expand Down
2 changes: 1 addition & 1 deletion sdk/eventhubs/azure_messaging_eventhubs/tests/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ async fn test_get_partition_properties() {
}

#[recorded::test(live)]
fn test_create_eventdata() {
async fn test_create_eventdata() {
common::setup();
let data = b"hello world";
let ed1 = azure_messaging_eventhubs::models::EventData::builder()
Expand Down

0 comments on commit bae38b3

Please sign in to comment.