Skip to content

Commit

Permalink
feat: Add AppContextWeak to prevent reference cycles (#529)
Browse files Browse the repository at this point in the history
This is useful for preventing reference cycles between things that are
held in the `AppContext` and also need a reference to the `AppContext`;
for example, `HealthCheck`s.

Closes #526
  • Loading branch information
spencewenski authored Dec 6, 2024
1 parent ef4508b commit 33e8e65
Show file tree
Hide file tree
Showing 8 changed files with 109 additions and 37 deletions.
6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ license = "MIT OR Apache-2.0"
keywords = ["web", "framework"]
categories = ["web-programming", "web-programming::http-server"]
# Determined using `cargo msrv` -- https://github.com/foresterre/cargo-msrv
rust-version = "1.77.2"
rust-version = "1.81"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

Expand Down Expand Up @@ -153,8 +153,8 @@ schemars = "0.8.16"
mime = "0.3.17"

# DB
sea-orm = { version = "1.1.0" }
sea-orm-migration = { version = "1.1.0" }
sea-orm = { version = "1.1.2" }
sea-orm-migration = { version = "1.1.2" }

# Email
lettre = "0.11.0"
Expand Down
24 changes: 23 additions & 1 deletion src/app/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use anyhow::anyhow;
use axum_core::extract::FromRef;
#[cfg(feature = "db-sql")]
use sea_orm::DatabaseConnection;
use std::sync::{Arc, OnceLock};
use std::sync::{Arc, OnceLock, Weak};

#[cfg(not(test))]
type Inner = AppContextInner;
Expand All @@ -20,6 +20,21 @@ pub struct AppContext {
inner: Arc<Inner>,
}

/// A version of [`AppContext`] that holds a [`Weak`] pointer to the inner context. Useful for
/// preventing reference cycles between things that are held in the [`AppContext`] and also
/// need a reference to the [`AppContext`]; for example, [`HealthCheck`]s.
#[derive(Clone)]
pub struct AppContextWeak {
inner: Weak<Inner>,
}

impl AppContextWeak {
/// Get an [`AppContext`] from [`Self`].
pub fn upgrade(&self) -> Option<AppContext> {
self.inner.upgrade().map(|inner| AppContext { inner })
}
}

impl AppContext {
// This method isn't used when running tests; only the mocked version is used.
#[cfg_attr(test, allow(dead_code))]
Expand Down Expand Up @@ -118,6 +133,13 @@ impl AppContext {
Ok(context)
}

/// Get an [`AppContextWeak`] from [`Self`].
pub fn downgrade(&self) -> AppContextWeak {
AppContextWeak {
inner: Arc::downgrade(&self.inner),
}
}

#[cfg(test)]
pub(crate) fn test(
config: Option<AppConfig>,
Expand Down
18 changes: 13 additions & 5 deletions src/health_check/database.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use crate::api::core::health::db_health;
use crate::app::context::AppContext;
use crate::app::context::{AppContext, AppContextWeak};
use crate::error::RoadsterResult;
use crate::health_check::{CheckResponse, HealthCheck};
use crate::health_check::{missing_context_response, CheckResponse, HealthCheck};
use async_trait::async_trait;
use tracing::instrument;

pub struct DatabaseHealthCheck {
pub(crate) context: AppContext,
pub(crate) context: AppContextWeak,
}

#[async_trait]
Expand All @@ -16,12 +16,20 @@ impl HealthCheck for DatabaseHealthCheck {
}

fn enabled(&self) -> bool {
enabled(&self.context)
self.context
.upgrade()
.map(|context| enabled(&context))
.unwrap_or_default()
}

#[instrument(skip_all)]
async fn check(&self) -> RoadsterResult<CheckResponse> {
Ok(db_health(&self.context, None).await)
let context = self.context.upgrade();
let response = match context {
Some(context) => db_health(&context, None).await,
None => missing_context_response(),
};
Ok(response)
}
}

Expand Down
8 changes: 4 additions & 4 deletions src/health_check/default.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,19 @@ pub fn default_health_checks(
let health_checks: Vec<Arc<dyn HealthCheck>> = vec![
#[cfg(feature = "db-sql")]
Arc::new(DatabaseHealthCheck {
context: context.clone(),
context: context.downgrade(),
}),
#[cfg(feature = "sidekiq")]
Arc::new(SidekiqEnqueueHealthCheck {
context: context.clone(),
context: context.downgrade(),
}),
#[cfg(feature = "sidekiq")]
Arc::new(SidekiqFetchHealthCheck {
context: context.clone(),
context: context.downgrade(),
}),
#[cfg(feature = "email-smtp")]
Arc::new(SmtpHealthCheck {
context: context.clone(),
context: context.downgrade(),
}),
];

Expand Down
18 changes: 13 additions & 5 deletions src/health_check/email/smtp.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use crate::api::core::health::smtp_health;
use crate::app::context::AppContext;
use crate::app::context::{AppContext, AppContextWeak};
use crate::error::RoadsterResult;
use crate::health_check::{CheckResponse, HealthCheck};
use crate::health_check::{missing_context_response, CheckResponse, HealthCheck};
use async_trait::async_trait;
use tracing::instrument;

pub struct SmtpHealthCheck {
pub(crate) context: AppContext,
pub(crate) context: AppContextWeak,
}

#[async_trait]
Expand All @@ -16,12 +16,20 @@ impl HealthCheck for SmtpHealthCheck {
}

fn enabled(&self) -> bool {
enabled(&self.context)
self.context
.upgrade()
.map(|context| enabled(&context))
.unwrap_or_default()
}

#[instrument(skip_all)]
async fn check(&self) -> RoadsterResult<CheckResponse> {
Ok(smtp_health(&self.context, None).await)
let context = self.context.upgrade();
let response = match context {
Some(context) => smtp_health(&context, None).await,
None => missing_context_response(),
};
Ok(response)
}
}

Expand Down
23 changes: 21 additions & 2 deletions src/health_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ use schemars::JsonSchema;
use serde_derive::{Deserialize, Serialize};
use serde_json::Value;
use serde_with::{serde_as, skip_serializing_none};
use std::time::Duration;
use tracing::error;
use typed_builder::TypedBuilder;

#[serde_as]
Expand All @@ -27,11 +29,14 @@ use typed_builder::TypedBuilder;
pub struct CheckResponse {
pub status: Status,
/// Total latency of checking the health of the resource in milliseconds.
#[builder(setter(transform = |duration: std::time::Duration| duration.as_millis() ))]
#[builder(setter(transform = |duration: std::time::Duration| duration.as_millis()))]
pub latency: u128,
/// Custom health data, for example, separate latency measurements for acquiring a connection
/// from a resource pool vs making a request with the connection.
#[builder(default, setter(transform = |custom: impl serde::Serialize| serialize_custom(custom) ))]
#[builder(
default,
setter(transform = |custom: impl serde::Serialize| serialize_custom(custom))
)]
pub custom: Option<Value>,
}

Expand Down Expand Up @@ -86,3 +91,17 @@ pub trait HealthCheck: Send + Sync {
/// Run the [`HealthCheck`].
async fn check(&self) -> RoadsterResult<CheckResponse>;
}

// This method is not used in all feature configurations.
#[allow(dead_code)]
fn missing_context_response() -> CheckResponse {
error!("AppContext missing");
CheckResponse::builder()
.status(Status::Err(
ErrorData::builder()
.msg("Unknown error".to_string())
.build(),
))
.latency(Duration::from_secs(0))
.build()
}
18 changes: 13 additions & 5 deletions src/health_check/sidekiq_enqueue.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use crate::api::core::health::redis_health;
use crate::app::context::AppContext;
use crate::app::context::{AppContext, AppContextWeak};
use crate::error::RoadsterResult;
use crate::health_check::{CheckResponse, HealthCheck};
use crate::health_check::{missing_context_response, CheckResponse, HealthCheck};
use async_trait::async_trait;
use tracing::instrument;

pub struct SidekiqEnqueueHealthCheck {
pub(crate) context: AppContext,
pub(crate) context: AppContextWeak,
}

#[async_trait]
Expand All @@ -16,12 +16,20 @@ impl HealthCheck for SidekiqEnqueueHealthCheck {
}

fn enabled(&self) -> bool {
enabled(&self.context)
self.context
.upgrade()
.map(|context| enabled(&context))
.unwrap_or_default()
}

#[instrument(skip_all)]
async fn check(&self) -> RoadsterResult<CheckResponse> {
Ok(redis_health(self.context.redis_enqueue(), None).await)
let context = self.context.upgrade();
let response = match context {
Some(context) => redis_health(context.redis_enqueue(), None).await,
None => missing_context_response(),
};
Ok(response)
}
}

Expand Down
31 changes: 19 additions & 12 deletions src/health_check/sidekiq_fetch.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use crate::api::core::health::redis_health;
use crate::app::context::AppContext;
use crate::app::context::{AppContext, AppContextWeak};
use crate::error::RoadsterResult;
use crate::health_check::{CheckResponse, HealthCheck};
use crate::health_check::{missing_context_response, CheckResponse, HealthCheck};
use anyhow::anyhow;
use async_trait::async_trait;
use tracing::instrument;

pub struct SidekiqFetchHealthCheck {
pub(crate) context: AppContext,
pub(crate) context: AppContextWeak,
}

#[async_trait]
Expand All @@ -17,19 +17,26 @@ impl HealthCheck for SidekiqFetchHealthCheck {
}

fn enabled(&self) -> bool {
enabled(&self.context)
self.context
.upgrade()
.map(|context| enabled(&context))
.unwrap_or_default()
}

#[instrument(skip_all)]
async fn check(&self) -> RoadsterResult<CheckResponse> {
Ok(redis_health(
self.context
.redis_fetch()
.as_ref()
.ok_or_else(|| anyhow!("Redis fetch connection pool is not present"))?,
None,
)
.await)
let context = self.context.upgrade();
let response = match context {
Some(context) => {
let redis = context
.redis_fetch()
.as_ref()
.ok_or_else(|| anyhow!("Redis fetch connection pool is not present"))?;
redis_health(redis, None).await
}
None => missing_context_response(),
};
Ok(response)
}
}

Expand Down

0 comments on commit 33e8e65

Please sign in to comment.