diff --git a/.circleci/config.yml b/.circleci/config.yml index 6e50aecaf..95ea79bd6 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -420,6 +420,9 @@ jobs: stripe-secret-key: description: "Stripe secret key used to connect a client to Stripe backend" type: string + stripe-rds-price-id: + description: "Stripe price id of Shuttle AWS RDS product." + type: string jwt-signing-private-key: description: "Auth private key used for JWT signing" type: string @@ -457,6 +460,7 @@ jobs: DEPLOYS_API_KEY=${<< parameters.deploys-api-key >>} \ LOGGER_POSTGRES_URI=${<< parameters.logger-postgres-uri >>} \ STRIPE_SECRET_KEY=${<< parameters.stripe-secret-key >>} \ + STRIPE_RDS_PRICE_ID=${<< parameters.stripe-rds-price-id >>} \ AUTH_JWTSIGNING_PRIVATE_KEY=${<< parameters.jwt-signing-private-key >>} \ CONTROL_DB_POSTGRES_URI=${<< parameters.control-db-postgres-uri >>} \ make deploy @@ -834,9 +838,9 @@ workflows: only: main - approve-build-and-push-images-unstable: type: approval - filters: - branches: - only: main + # filters: + # branches: + # only: main - build-and-push: name: build-and-push-unstable aws-access-key-id: DEV_AWS_ACCESS_KEY_ID @@ -852,6 +856,7 @@ workflows: deploys-api-key: DEV_DEPLOYS_API_KEY logger-postgres-uri: DEV_LOGGER_POSTGRES_URI stripe-secret-key: DEV_STRIPE_SECRET_KEY + stripe-rds-price-id: DEV_STRIPE_RDS_PRICE_ID jwt-signing-private-key: DEV_AUTH_JWTSIGNING_PRIVATE_KEY control-db-postgres-uri: DEV_CONTROL_DB_POSTGRES_URI requires: @@ -933,6 +938,7 @@ workflows: deploys-api-key: PROD_DEPLOYS_API_KEY logger-postgres-uri: PROD_LOGGER_POSTGRES_URI stripe-secret-key: PROD_STRIPE_SECRET_KEY + stripe-rds-price-id: PROD_STRIPE_RDS_PRICE_ID jwt-signing-private-key: PROD_AUTH_JWTSIGNING_PRIVATE_KEY control-db-postgres-uri: PROD_CONTROL_DB_POSTGRES_URI ssh-fingerprint: 6a:c5:33:fe:5b:c9:06:df:99:64:ca:17:0d:32:18:2e diff --git a/Cargo.lock b/Cargo.lock index e9779f68d..01e56fe73 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -189,6 +189,16 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "assert-json-diff" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "assert_cmd" version = "2.0.12" @@ -1775,6 +1785,25 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7e962a19be5cfc3f3bf6dd8f61eb50107f356ad6270fbb3ed41476571db78be5" +[[package]] +name = "deadpool" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "421fe0f90f2ab22016f32a9881be5134fdd71c65298917084b0c7477cbc3856e" +dependencies = [ + "async-trait", + "deadpool-runtime", + "num_cpus", + "retain_mut", + "tokio", +] + +[[package]] +name = "deadpool-runtime" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63dfa964fe2a66f3fde91fc70b267fe193d822c7e603e2a675a49a7f46ad3f49" + [[package]] name = "debugid" version = "0.8.0" @@ -2277,6 +2306,12 @@ version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "efd193069b0ddadc69c46389b740bbccdd97203899b48d09c5f7969591d6bae2" +[[package]] +name = "futures-timer" +version = "3.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c" + [[package]] name = "futures-util" version = "0.3.29" @@ -5164,6 +5199,12 @@ dependencies = [ "quick-error", ] +[[package]] +name = "retain_mut" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4389f1d5789befaf6029ebd9f7dac4af7f7e3d61b69d4f30e2ac02b57e7712b0" + [[package]] name = "retry-policies" version = "0.2.1" @@ -5799,6 +5840,7 @@ dependencies = [ "tracing", "tracing-opentelemetry", "tracing-subscriber", + "wiremock", ] [[package]] @@ -6089,6 +6131,7 @@ dependencies = [ "portpicker", "prost 0.12.3", "rand 0.8.5", + "reqwest", "serde_json", "shuttle-common", "shuttle-proto", @@ -6098,6 +6141,7 @@ dependencies = [ "tonic 0.10.2", "tracing", "tracing-subscriber", + "wiremock", ] [[package]] @@ -8663,6 +8707,28 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "wiremock" +version = "0.5.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13a3a53eaf34f390dd30d7b1b078287dd05df2aa2e21a589ccb80f5c7253c2e9" +dependencies = [ + "assert-json-diff", + "async-trait", + "base64 0.21.5", + "deadpool", + "futures", + "futures-timer", + "http-types", + "hyper", + "log", + "once_cell", + "regex", + "serde", + "serde_json", + "tokio", +] + [[package]] name = "wit-parser" version = "0.11.3" diff --git a/Makefile b/Makefile index b2ea5b69b..a9b1100bb 100644 --- a/Makefile +++ b/Makefile @@ -47,6 +47,7 @@ MONGO_INITDB_ROOT_USERNAME?=mongodb MONGO_INITDB_ROOT_PASSWORD?=password STRIPE_SECRET_KEY?="" AUTH_JWTSIGNING_PRIVATE_KEY?="" +STRIPE_RDS_PRICE_ID?="" DD_ENV=$(SHUTTLE_ENV) ifeq ($(SHUTTLE_ENV),production) @@ -136,6 +137,7 @@ DOCKER_COMPOSE_ENV=\ MONGO_INITDB_ROOT_USERNAME=$(MONGO_INITDB_ROOT_USERNAME)\ MONGO_INITDB_ROOT_PASSWORD=$(MONGO_INITDB_ROOT_PASSWORD)\ STRIPE_SECRET_KEY=$(STRIPE_SECRET_KEY)\ + STRIPE_RDS_PRICE_ID=$(STRIPE_RDS_PRICE_ID)\ AUTH_JWTSIGNING_PRIVATE_KEY=$(AUTH_JWTSIGNING_PRIVATE_KEY)\ DD_ENV=$(DD_ENV)\ USE_TLS=$(USE_TLS)\ diff --git a/auth/Cargo.toml b/auth/Cargo.toml index 01a47d3d7..507299212 100644 --- a/auth/Cargo.toml +++ b/auth/Cargo.toml @@ -26,6 +26,7 @@ pem = "2" rand = { workspace = true } ring = { workspace = true } serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } sqlx = { workspace = true, features = ["postgres", "json", "migrate"] } strum = { workspace = true } thiserror = { workspace = true } @@ -43,3 +44,4 @@ portpicker = { workspace = true } serde_json = { workspace = true } shuttle-common-tests = { workspace = true } tower = { workspace = true, features = ["util"] } +wiremock = "0.5" diff --git a/auth/README b/auth/README deleted file mode 100644 index 3aeaadae7..000000000 --- a/auth/README +++ /dev/null @@ -1,13 +0,0 @@ -# Auth service considerations - -## JWT signing private key - -Starting the service locally requires provisioning of a base64 encoded PEM encoded PKCS#8 v1 unencrypted private key. -The service was tested with keys generated as follows: - -```bash -openssl genpkey -algorithm ED25519 -out auth_jwtsigning_private_key.pem -base64 < auth_jwtsigning_private_key.pem -``` - -Used `OpenSSL 3.1.2 1 Aug 2023 (Library: OpenSSL 3.1.2 1 Aug 2023)` and `FreeBSD base64`, on a `macOS Sonoma 14.1.1`. \ No newline at end of file diff --git a/auth/README.md b/auth/README.md new file mode 100644 index 000000000..1e000d423 --- /dev/null +++ b/auth/README.md @@ -0,0 +1,40 @@ +# Auth service considerations + +## JWT signing private key + +Starting the service locally requires provisioning of a base64 encoded PEM encoded PKCS#8 v1 unencrypted private key. +The service was tested with keys generated as follows: + +```bash +openssl genpkey -algorithm ED25519 -out auth_jwtsigning_private_key.pem +base64 < auth_jwtsigning_private_key.pem +``` + +Used `OpenSSL 3.1.2 1 Aug 2023 (Library: OpenSSL 3.1.2 1 Aug 2023)` and `FreeBSD base64`, on a `macOS Sonoma 14.1.1`. + +## Running the binary on it's own + +**The below commands are ran from the root of the repo** + +- First, start the control db container: + +``` +docker compose -f docker-compose.rendered.yml up control-db +``` + +- Then insert an admin user into the database: + +``` +cargo run --bin shuttle-auth -- --db-connection-uri postgres://postgres:postgres@localhost:5434/postgres init-admin --name admin +``` + +- Then start the service, you can get a stripe-secret-key from the Stripe dashboard. **Always use the test Stripe API for this**. See instructions above for generating a jwt-signing-private-key. + +``` +cargo run --bin shuttle-auth -- \ + --db-connection-uri postgres://postgres:postgres@localhost:5434/postgres \ + start \ + --stripe-secret-key sk_test_ \ + --jwt-signing-private-key \ + --stripe-rds-price-id price_1OIS06FrN7EDaGOjaV0GXD7P +``` diff --git a/auth/src/api/builder.rs b/auth/src/api/builder.rs index 282cd573a..9dba96b73 100644 --- a/auth/src/api/builder.rs +++ b/auth/src/api/builder.rs @@ -2,6 +2,7 @@ use std::{net::SocketAddr, sync::Arc}; use axum::{ extract::FromRef, + handler::Handler, middleware::from_extractor, routing::{get, post, put}, Router, Server, @@ -9,7 +10,11 @@ use axum::{ use axum_sessions::{async_session::MemoryStore, SessionLayer}; use rand::RngCore; use shuttle_common::{ - backends::metrics::{Metrics, TraceLayer}, + backends::{ + auth::{JwtAuthenticationLayer, ScopedLayer}, + metrics::{Metrics, TraceLayer}, + }, + claims::Scope, request_span, }; use sqlx::PgPool; @@ -22,8 +27,8 @@ use crate::{ }; use super::handlers::{ - convert_cookie, convert_key, get_public_key, get_user, health_check, logout, post_user, - put_user_reset_key, refresh_token, update_user_tier, + add_subscription_items, convert_cookie, convert_key, get_public_key, get_user, health_check, + logout, post_user, put_user_reset_key, refresh_token, update_user_tier, }; pub type UserManagerState = Arc>; @@ -33,6 +38,7 @@ pub type KeyManagerState = Arc>; pub struct RouterState { pub user_manager: UserManagerState, pub key_manager: KeyManagerState, + pub rds_price_id: String, } // Allow getting a user management state directly @@ -54,17 +60,16 @@ pub struct ApiBuilder { pool: Option, session_layer: Option>, stripe_client: Option, - jwt_signing_private_key: Option, -} - -impl Default for ApiBuilder { - fn default() -> Self { - Self::new() - } + rds_price_id: Option, + key_manager: EdDsaManager, } impl ApiBuilder { - pub fn new() -> Self { + pub fn new(jwt_signing_private_key: String) -> Self { + let key_manager = EdDsaManager::new(jwt_signing_private_key); + + let public_key = key_manager.public_key().to_vec(); + let router = Router::new() .route("/", get(health_check)) .route("/logout", post(logout)) @@ -73,6 +78,17 @@ impl ApiBuilder { .route("/auth/refresh", post(refresh_token)) .route("/public-key", get(get_public_key)) .route("/users/:account_name", get(get_user)) + .route( + "/users/subscription/items", + post( + add_subscription_items + .layer(ScopedLayer::new(vec![Scope::ResourcesWrite])) + .layer(JwtAuthenticationLayer::new(move || { + let public_key = public_key.clone(); + async move { public_key.clone() } + })), + ), + ) .route( "/users/:account_name/:account_tier", post(post_user).put(update_user_tier), @@ -96,7 +112,8 @@ impl ApiBuilder { pool: None, session_layer: None, stripe_client: None, - jwt_signing_private_key: None, + rds_price_id: None, + key_manager, } } @@ -124,8 +141,8 @@ impl ApiBuilder { self } - pub fn with_jwt_signing_private_key(mut self, private_key: String) -> Self { - self.jwt_signing_private_key = Some(private_key); + pub fn with_rds_price_id(mut self, price_id: String) -> Self { + self.rds_price_id = Some(price_id); self } @@ -133,18 +150,17 @@ impl ApiBuilder { let pool = self.pool.expect("an sqlite pool is required"); let session_layer = self.session_layer.expect("a session layer is required"); let stripe_client = self.stripe_client.expect("a stripe client is required"); - let jwt_signing_private_key = self - .jwt_signing_private_key - .expect("a jwt signing private key"); + let rds_price_id = self.rds_price_id.expect("rds price id is required"); + let user_manager = UserManager { pool, stripe_client, }; - let key_manager = EdDsaManager::new(jwt_signing_private_key); let state = RouterState { user_manager: Arc::new(Box::new(user_manager)), - key_manager: Arc::new(Box::new(key_manager)), + key_manager: Arc::new(Box::new(self.key_manager)), + rds_price_id, }; self.router.layer(session_layer).with_state(state) diff --git a/auth/src/api/handlers.rs b/auth/src/api/handlers.rs index b65b17e01..2bda45b85 100644 --- a/auth/src/api/handlers.rs +++ b/auth/src/api/handlers.rs @@ -1,25 +1,22 @@ use crate::{ error::Error, user::{AccountName, Admin, Key, User}, + NewSubscriptionItemExtractor, }; use axum::{ extract::{Path, State}, - Json, + Extension, Json, }; use axum_sessions::extractors::{ReadableSession, WritableSession}; use http::StatusCode; -use serde::{Deserialize, Serialize}; use shuttle_common::{ claims::{AccountTier, Claim}, models::user, }; use stripe::CheckoutSession; -use tracing::instrument; +use tracing::{error, instrument}; -use super::{ - builder::{KeyManagerState, UserManagerState}, - RouterState, -}; +use super::{builder::KeyManagerState, RouterState, UserManagerState}; #[instrument(skip(user_manager))] pub(crate) async fn get_user( @@ -68,6 +65,32 @@ pub(crate) async fn update_user_tier( Ok(()) } +#[instrument(skip(claim, user_manager), fields(account_name = claim.sub, account_tier = %claim.tier))] +pub(crate) async fn add_subscription_items( + Extension(claim): Extension, + State(user_manager): State, + NewSubscriptionItemExtractor(update_subscription_items): NewSubscriptionItemExtractor, +) -> Result<(), Error> { + // Fetching the user will also sync their subscription. This means we can verify that the + // caller still has the required tier after the sync. + let user = user_manager.get_user(AccountName::from(claim.sub)).await?; + + let Some(ref subscription_id) = user.subscription_id else { + return Err(Error::MissingSubscriptionId); + }; + + if !matches![user.account_tier, AccountTier::Pro | AccountTier::Admin] { + error!("account was downgraded from pro in sync, denying the addition of new items"); + return Err(Error::Unauthorized); + } + + user_manager + .add_subscription_items(subscription_id, update_subscription_items) + .await?; + + Ok(()) +} + pub(crate) async fn put_user_reset_key( session: ReadableSession, State(user_manager): State, @@ -124,6 +147,7 @@ pub(crate) async fn convert_key( State(RouterState { key_manager, user_manager, + .. }): State, key: Key, ) -> Result, StatusCode> { @@ -153,8 +177,3 @@ pub(crate) async fn refresh_token() {} pub(crate) async fn get_public_key(State(key_manager): State) -> Vec { key_manager.public_key().to_vec() } - -#[derive(Deserialize, Serialize)] -pub struct LoginRequest { - account_name: AccountName, -} diff --git a/auth/src/args.rs b/auth/src/args.rs index 9023c3676..49355a7d5 100644 --- a/auth/src/args.rs +++ b/auth/src/args.rs @@ -31,8 +31,12 @@ pub struct StartArgs { /// Auth JWT signing private key, as a base64 encoding of /// a PEM encoded PKCS#8 v1 formatted unencrypted private key. - #[arg(long, default_value = "")] + #[arg(long)] pub jwt_signing_private_key: String, + + /// The price id of the AWS RDS product. + #[arg(long, default_value = "price_1OIS06FrN7EDaGOjaV0GXD7P")] + pub stripe_rds_price_id: String, } #[derive(clap::Args, Debug, Clone)] diff --git a/auth/src/error.rs b/auth/src/error.rs index 8efb237f0..3b5752516 100644 --- a/auth/src/error.rs +++ b/auth/src/error.rs @@ -20,7 +20,7 @@ pub enum Error { #[error("Database error: {0}")] Database(#[from] sqlx::Error), #[error(transparent)] - UnexpectedError(#[from] anyhow::Error), + Internal(#[from] anyhow::Error), #[error("Missing checkout session.")] MissingCheckoutSession, #[error("Incomplete checkout session.")] diff --git a/auth/src/lib.rs b/auth/src/lib.rs index d2661ff80..fdf586b4f 100644 --- a/auth/src/lib.rs +++ b/auth/src/lib.rs @@ -2,6 +2,7 @@ mod api; mod args; mod error; mod secrets; +mod subscription; mod user; use std::{io, time::Duration}; @@ -12,22 +13,27 @@ use sqlx::{migrate::Migrator, query, PgPool}; use tracing::info; use crate::api::serve; -pub use api::ApiBuilder; +pub use api::{ApiBuilder, RouterState}; pub use args::{Args, Commands, InitArgs}; +pub use subscription::NewSubscriptionItemExtractor; pub const COOKIE_EXPIRATION: Duration = Duration::from_secs(60 * 60 * 24); // One day pub static MIGRATIONS: Migrator = sqlx::migrate!("./migrations"); pub async fn start(pool: PgPool, args: StartArgs) -> io::Result<()> { - let router = api::ApiBuilder::new() + let router = api::ApiBuilder::new(args.jwt_signing_private_key) .with_pg_pool(pool) .with_sessions() .with_stripe_client(stripe::Client::new(args.stripe_secret_key)) - .with_jwt_signing_private_key(args.jwt_signing_private_key) + .with_rds_price_id(args.stripe_rds_price_id.clone()) .into_router(); info!(address=%args.address, "Binding to and listening at address"); + info!( + "starting auth service with RDS price id: {}", + args.stripe_rds_price_id + ); serve(router, args.address).await; diff --git a/auth/src/subscription.rs b/auth/src/subscription.rs new file mode 100644 index 000000000..adf2103ad --- /dev/null +++ b/auth/src/subscription.rs @@ -0,0 +1,47 @@ +use async_trait::async_trait; +use axum::extract::{FromRef, FromRequest}; +use axum::response::{IntoResponse, Response}; +use axum::BoxError; +use http::Request; +use shuttle_common::backends::subscription::{NewSubscriptionItem, SubscriptionItem}; + +use crate::RouterState; + +/// A wrapper for [stripe::UpdateSubscriptionItems] so we can implement [FromRequest] for it. +pub struct NewSubscriptionItemExtractor(pub stripe::UpdateSubscriptionItems); + +#[async_trait] +impl FromRequest for NewSubscriptionItemExtractor +where + B: axum::body::HttpBody + Send + 'static, + B::Data: Send, + B::Error: Into, + RouterState: FromRef, + S: Send + Sync, +{ + type Rejection = Response; + + async fn from_request(req: Request, state: &S) -> Result { + // Extract the NewSubscriptionItem, the struct that other services should use when calling + // the endpoint to add subscription items. + let NewSubscriptionItem { quantity, item } = axum::Json::from_request(req, state) + .await + .map_err(IntoResponse::into_response)? + .0; + + // Access the router state to extract price IDs. + let state = RouterState::from_ref(state); + + let price_id = match item { + SubscriptionItem::AwsRds => state.rds_price_id, + }; + + let update_subscription_items = stripe::UpdateSubscriptionItems { + price: Some(price_id), + quantity: Some(quantity), + ..Default::default() + }; + + Ok(Self(update_subscription_items)) + } +} diff --git a/auth/src/user.rs b/auth/src/user.rs index b2fa83758..4bdf08b69 100644 --- a/auth/src/user.rs +++ b/auth/src/user.rs @@ -8,7 +8,7 @@ use axum::{ TypedHeader, }; use serde::{Deserialize, Deserializer, Serialize}; -use shuttle_common::{claims::AccountTier, secrets::Secret, ApiKey}; +use shuttle_common::{claims::AccountTier, ApiKey, Secret}; use sqlx::{postgres::PgRow, query, FromRow, PgPool, Row}; use tracing::{debug, error, trace, Span}; @@ -29,6 +29,11 @@ pub trait UserManagement: Send + Sync { async fn get_user(&self, name: AccountName) -> Result; async fn get_user_by_key(&self, key: ApiKey) -> Result; async fn reset_key(&self, name: AccountName) -> Result<(), Error>; + async fn add_subscription_items( + &self, + subscription_id: &SubscriptionId, + subscription_item: stripe::UpdateSubscriptionItems, + ) -> Result<(), Error>; } #[derive(Clone)] @@ -149,6 +154,22 @@ impl UserManagement for UserManager { Ok(user) } + async fn add_subscription_items( + &self, + subscription_id: &SubscriptionId, + subscription_items: stripe::UpdateSubscriptionItems, + ) -> Result<(), Error> { + let subscription_update = stripe::UpdateSubscription { + items: Some(vec![subscription_items]), + ..Default::default() + }; + + stripe::Subscription::update(&self.stripe_client, subscription_id, subscription_update) + .await?; + + Ok(()) + } + async fn reset_key(&self, name: AccountName) -> Result<(), Error> { let key = ApiKey::generate(); @@ -206,7 +227,7 @@ impl User { Ok(false) } - // Synchronize the tiers with the subscription validity. + /// Synchronize the tiers with the subscription validity. async fn sync_tier(&mut self, user_manager: &UserManager) -> Result { let has_pro_access = self.account_tier == AccountTier::Pro || self.account_tier == AccountTier::CancelledPro diff --git a/auth/tests/api/helpers.rs b/auth/tests/api/helpers.rs index 3994ac492..11bea0c8b 100644 --- a/auth/tests/api/helpers.rs +++ b/auth/tests/api/helpers.rs @@ -1,24 +1,24 @@ -use crate::stripe::MOCKED_SUBSCRIPTIONS; -use axum::{body::Body, extract::Path, extract::State, response::Response, routing::get, Router}; +use axum::{body::Body, response::Response, Router}; use http::{header::CONTENT_TYPE, StatusCode}; -use hyper::{ - http::{header::AUTHORIZATION, Request}, - Server, -}; +use hyper::http::{header::AUTHORIZATION, Request}; use once_cell::sync::Lazy; use serde_json::Value; use shuttle_auth::{pgpool_init, ApiBuilder}; use shuttle_common::claims::{AccountTier, Claim}; use shuttle_common_tests::postgres::DockerInstance; use sqlx::query; -use std::{ - net::SocketAddr, - str::FromStr, - sync::{Arc, Mutex}, -}; use tower::ServiceExt; +use wiremock::{ + matchers::{bearer_token, method, path}, + Mock, MockServer, ResponseTemplate, +}; +/// Admin user API key. pub(crate) const ADMIN_KEY: &str = "ndh9z58jttoes3qv"; +/// Stripe test API key. +pub(crate) const STRIPE_TEST_KEY: &str = "sk_test_123"; +/// Stripe test RDS price id. +pub(crate) const STRIPE_TEST_RDS_PRICE_ID: &str = "price_1OIS06FrN7EDaGOjaV0GXD7P"; static PG: Lazy = Lazy::new(DockerInstance::default); #[ctor::dtor] @@ -28,14 +28,15 @@ fn cleanup() { pub(crate) struct TestApp { pub router: Router, - pub mocked_stripe_server: MockedStripeServer, + pub mock_server: MockServer, } /// Initialize a router with an in-memory sqlite database for each test. pub(crate) async fn app() -> TestApp { let pg_pool = pgpool_init(PG.get_unique_uri().as_str()).await.unwrap(); - let mocked_stripe_server = MockedStripeServer::default(); + let mock_server = MockServer::start().await; + // Insert an admin user for the tests. query("INSERT INTO users (account_name, key, account_tier) VALUES ($1, $2, $3)") .bind("admin") @@ -45,19 +46,19 @@ pub(crate) async fn app() -> TestApp { .await .unwrap(); - let router = ApiBuilder::new() + let router = ApiBuilder::new("LS0tLS1CRUdJTiBQUklWQVRFIEtFWS0tLS0tCk1DNENBUUF3QlFZREsyVndCQ0lFSUR5V0ZFYzhKYm05NnA0ZGNLTEwvQWNvVUVsbUF0MVVKSTU4WTc4d1FpWk4KLS0tLS1FTkQgUFJJVkFURSBLRVktLS0tLQo=".to_string()) .with_pg_pool(pg_pool) .with_sessions() .with_stripe_client(stripe::Client::from_url( - mocked_stripe_server.uri.to_string().as_str(), - "", + mock_server.uri().as_str(), + STRIPE_TEST_KEY, )) - .with_jwt_signing_private_key("LS0tLS1CRUdJTiBQUklWQVRFIEtFWS0tLS0tCk1DNENBUUF3QlFZREsyVndCQ0lFSUR5V0ZFYzhKYm05NnA0ZGNLTEwvQWNvVUVsbUF0MVVKSTU4WTc4d1FpWk4KLS0tLS1FTkQgUFJJVkFURSBLRVktLS0tLQo=".to_string()) + .with_rds_price_id(STRIPE_TEST_RDS_PRICE_ID.to_string()) .into_router(); TestApp { router, - mocked_stripe_server, + mock_server, } } @@ -136,91 +137,27 @@ impl TestApp { Claim::from_token(token, &public_key).unwrap() } -} - -#[derive(Clone)] -pub(crate) struct MockedStripeServer { - uri: http::Uri, - router: Router, -} - -#[derive(Clone)] -pub(crate) struct RouterState { - subscription_cancel_side_effect_toggle: Arc>, -} - -impl MockedStripeServer { - async fn subscription_retrieve_handler( - Path(subscription_id): Path, - State(state): State, - ) -> axum::response::Response { - let is_sub_cancelled = state - .subscription_cancel_side_effect_toggle - .lock() - .unwrap() - .to_owned(); - - if subscription_id == "sub_123" { - if is_sub_cancelled { - return Response::new(MOCKED_SUBSCRIPTIONS[3].to_string()); - } else { - let mut toggle = state.subscription_cancel_side_effect_toggle.lock().unwrap(); - *toggle = true; - return Response::new(MOCKED_SUBSCRIPTIONS[2].to_string()); - } - } - - let sessions = MOCKED_SUBSCRIPTIONS - .iter() - .filter(|sub| sub.contains(format!("\"id\": \"{}\"", subscription_id).as_str())) - .map(|sub| serde_json::from_str(sub).unwrap()) - .collect::>(); - if sessions.len() == 1 { - return Response::new(sessions[0].to_string()); - } - - Response::builder() - .status(http::StatusCode::NOT_FOUND) - .body("subscription id not found".to_string()) - .unwrap() - } - pub(crate) async fn serve(self) { - let address = &SocketAddr::from_str( - format!("{}:{}", self.uri.host().unwrap(), self.uri.port().unwrap()).as_str(), - ) - .unwrap(); - println!("serving on: {}", address); - Server::bind(address) - .serve(self.router.into_make_service()) - .await - .unwrap_or_else(|_| panic!("Failed to bind to address: {}", self.uri)); - } -} - -impl Default for MockedStripeServer { - fn default() -> MockedStripeServer { - let router_state = RouterState { - subscription_cancel_side_effect_toggle: Arc::new(Mutex::new(false)), - }; - - let router = Router::new() - .route( - "/v1/subscriptions/:subscription_id", - get(MockedStripeServer::subscription_retrieve_handler), - ) - .with_state(router_state); - - MockedStripeServer { - uri: http::Uri::from_str( - format!( - "http://127.0.0.1:{}", - portpicker::pick_unused_port().unwrap() - ) - .as_str(), + /// A test util to get a user with a subscription, mocking the response from Stripe. A key part + /// of this util is `mount_as_scoped`, since getting a user with a subscription can be done + /// several times in a test, if they're not scoped the first mock would always apply. + pub async fn get_user_with_mocked_stripe( + &self, + subscription_id: &str, + response_body: &str, + account_name: &str, + ) -> Response { + // This mock will apply until the end of this function scope. + let _mock_guard = Mock::given(method("GET")) + .and(bearer_token(STRIPE_TEST_KEY)) + .and(path(format!("/v1/subscriptions/{subscription_id}"))) + .respond_with( + ResponseTemplate::new(200) + .set_body_json(serde_json::from_str::(response_body).unwrap()), ) - .unwrap(), - router, - } + .mount_as_scoped(&self.mock_server) + .await; + + self.get_user(account_name).await } } diff --git a/auth/tests/api/users.rs b/auth/tests/api/users.rs index d5062aca3..57b5885fc 100644 --- a/auth/tests/api/users.rs +++ b/auth/tests/api/users.rs @@ -1,13 +1,17 @@ mod needs_docker { - use std::time::Duration; - use crate::{ - helpers::{self, app}, + helpers::{self, app, ADMIN_KEY, STRIPE_TEST_KEY, STRIPE_TEST_RDS_PRICE_ID}, stripe::{MOCKED_CHECKOUT_SESSIONS, MOCKED_SUBSCRIPTIONS}, }; use axum::body::Body; + use http::header::CONTENT_TYPE; use hyper::http::{header::AUTHORIZATION, Request, StatusCode}; use serde_json::{self, Value}; + use shuttle_common::backends::subscription::NewSubscriptionItem; + use wiremock::{ + matchers::{bearer_token, body_string_contains, method, path}, + Mock, ResponseTemplate, + }; #[tokio::test] async fn post_user() { @@ -114,10 +118,6 @@ mod needs_docker { async fn successful_upgrade_to_pro() { let app = app().await; - // Wait for the mocked Stripe server to start. - tokio::task::spawn(app.mocked_stripe_server.clone().serve()); - tokio::time::sleep(Duration::from_secs(1)).await; - // POST user first so one exists in the database. let response = app.post_user("test-user", "basic").await; @@ -126,12 +126,24 @@ mod needs_docker { let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); let expected_user: Value = serde_json::from_slice(&body).unwrap(); + // PUT /users/test-user/pro with a completed checkout session to upgrade a user to pro. let response = app .put_user("test-user", "pro", MOCKED_CHECKOUT_SESSIONS[0]) .await; assert_eq!(response.status(), StatusCode::OK); - let response = app.get_user("test-user").await; + // Next we're going to fetch the user, which will trigger a sync of the users tier. It will + // fetch the subscription from stripe using the subscription ID from the previous checkout + // session. This should return an active subscription, meaning the users tier should remain + // pro. + let response = app + .get_user_with_mocked_stripe( + "sub_1Nw8xOD8t1tt0S3DtwAuOVp6", + MOCKED_SUBSCRIPTIONS[0], + "test-user", + ) + .await; + assert_eq!(response.status(), StatusCode::OK); let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); let actual_user: Value = serde_json::from_slice(&body).unwrap(); @@ -174,10 +186,6 @@ mod needs_docker { async fn unsuccessful_upgrade_to_pro() { let app = app().await; - // Wait for the mocked Stripe server to start. - tokio::task::spawn(app.mocked_stripe_server.clone().serve()); - tokio::time::sleep(Duration::from_secs(1)).await; - // POST user first so one exists in the database. let response = app.post_user("test-user", "basic").await; assert_eq!(response.status(), StatusCode::OK); @@ -197,10 +205,6 @@ mod needs_docker { async fn downgrade_in_case_subscription_due_payment() { let app = app().await; - // Wait for the mocked Stripe server to start. - tokio::task::spawn(app.mocked_stripe_server.clone().serve()); - tokio::time::sleep(Duration::from_secs(1)).await; - // POST user first so one exists in the database. let response = app.post_user("test-user", "basic").await; assert_eq!(response.status(), StatusCode::OK); @@ -211,9 +215,18 @@ mod needs_docker { .await; assert_eq!(response.status(), StatusCode::OK); - // This get_user request should check the subscription status and return an accurate tier. - let response = app.get_user("test-user").await; + // The auth service should call stripe to fetch the subscription with the sub id from the + // checkout session, and return a subscription that is pending payment. + let response = app + .get_user_with_mocked_stripe( + "sub_1NwObED8t1tt0S3Dq0IYOEsa", + MOCKED_SUBSCRIPTIONS[1], + "test-user", + ) + .await; + assert_eq!(response.status(), StatusCode::OK); + let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); let actual_user: Value = serde_json::from_slice(&body).unwrap(); @@ -227,6 +240,110 @@ mod needs_docker { ); } + #[tokio::test] + async fn update_subscription_endpoint_requires_jwt() { + let app = app().await; + + let subscription_item = serde_json::to_string(&NewSubscriptionItem::new( + shuttle_common::backends::subscription::SubscriptionItem::AwsRds, + 1, + )) + .unwrap(); + + // POST /users/subscription/items without bearer JWT. + let request = Request::builder() + .uri("/users/subscription/items") + .method("POST") + .header(CONTENT_TYPE, "application/json") + .body(Body::from(subscription_item.clone())) + .unwrap(); + + let response = app.send_request(request).await; + + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + + // POST /users/subscription/items with invalid bearer JWT. + let request = Request::builder() + .uri("/users/subscription/items") + .method("POST") + .header(CONTENT_TYPE, "application/json") + .header(AUTHORIZATION, "invalid token") + .body(Body::from(subscription_item.clone())) + .unwrap(); + + let response = app.send_request(request).await; + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + + // GET /auth/key with the api key of the admin user to get their jwt. + let response = app.get_jwt_from_api_key(ADMIN_KEY).await; + + assert_eq!(response.status(), StatusCode::OK); + + // Extract the token. + let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); + let convert: Value = serde_json::from_slice(&body).unwrap(); + let token = convert["token"].as_str().unwrap(); + + // POST /users/:account_name with valid JWT. + let request = || { + Request::builder() + .uri("/users/subscription/items") + .method("POST") + .header(CONTENT_TYPE, "application/json") + .header(AUTHORIZATION, format!("Bearer {}", token)) + .body(Body::from(subscription_item.clone())) + .unwrap() + }; + + let response = app.send_request(request()).await; + + // The test user (claim subject) does not have a subscription ID. + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + + // Upgrade the user to pro so they have subscription ID. + let response = app + .put_user("admin", "pro", MOCKED_CHECKOUT_SESSIONS[0]) + .await; + + assert_eq!(response.status(), StatusCode::OK); + + // We now want to retry the request. + // In the process the auth service will try to sync the tier, fetching the subscription + // from stripe. + Mock::given(method("GET")) + .and(bearer_token(STRIPE_TEST_KEY)) + .and(path("/v1/subscriptions/sub_1Nw8xOD8t1tt0S3DtwAuOVp6")) + .respond_with( + ResponseTemplate::new(200) + .set_body_json(serde_json::from_str::(MOCKED_SUBSCRIPTIONS[0]).unwrap()), + ) + .mount(&app.mock_server) + .await; + + // We just return a mocked active subscription without the RDS items, our logic doesn't check + // the subscription after updating, if it receives a 200 and a correctly formed subscription + // response we know that the update succeeded. + // We also want to ensure it's called with the correct price_id, the one the auth serviec was + // started with, as well as the quantity field. + Mock::given(method("POST")) + .and(bearer_token(STRIPE_TEST_KEY)) + .and(path("/v1/subscriptions/sub_1Nw8xOD8t1tt0S3DtwAuOVp6")) + .and(body_string_contains(STRIPE_TEST_RDS_PRICE_ID)) + .and(body_string_contains("quantity")) + .respond_with( + ResponseTemplate::new(200) + .set_body_json(serde_json::from_str::(MOCKED_SUBSCRIPTIONS[0]).unwrap()), + ) + .mount(&app.mock_server) + .await; + + // POST /users/:account_name with valid JWT and the user upgraded to pro. + let response = app.send_request(request()).await; + + assert_eq!(response.status(), StatusCode::OK); + } + #[tokio::test] async fn test_reset_key() { let app = app().await; @@ -255,10 +372,6 @@ mod needs_docker { async fn downgrade_from_cancelledpro() { let app = app().await; - // Wait for the mocked Stripe server to start. - tokio::task::spawn(app.mocked_stripe_server.clone().serve()); - tokio::time::sleep(Duration::from_secs(1)).await; - // Create user with basic tier let response = app.post_user("test-user", "basic").await; assert_eq!(response.status(), StatusCode::OK); @@ -269,68 +382,39 @@ mod needs_docker { .await; assert_eq!(response.status(), StatusCode::OK); - // Cancel subscription + // Cancel subscription, this will be called by the console. let response = app.put_user("test-user", "cancelledpro", "").await; assert_eq!(response.status(), StatusCode::OK); - // Trigger status change to canceled. This call has a side effect because the user has a - // subscription that is handled in a specific way by the MockedStripeServer, which changes - // the subscription state to cancelled. - let response = app.get_user("test-user").await; - assert_eq!(response.status(), StatusCode::OK); - - let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); - let user: Value = serde_json::from_slice(&body).unwrap(); - assert_eq!( - user.as_object().unwrap().get("account_tier").unwrap(), - "cancelledpro" - ); + // Fetch the user to trigger a sync of the account tier to cancelled. The account should not + // be downgraded to basic right away, since when we cancel subscriptions we pass in the + // "cancel_at_period_end" end flag. + let response = app + .get_user_with_mocked_stripe("sub_123", MOCKED_SUBSCRIPTIONS[2], "test-user") + .await; - // Check if user is downgraded to basic - let response = app.get_user("test-user").await; assert_eq!(response.status(), StatusCode::OK); let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); let user: Value = serde_json::from_slice(&body).unwrap(); - assert_eq!( user.as_object().unwrap().get("account_tier").unwrap(), - "basic" + "cancelledpro" ); - } - - #[tokio::test] - async fn retain_cancelledpro_status() { - let app = app().await; - - // Wait for the mocked Stripe server to start. - tokio::task::spawn(app.mocked_stripe_server.clone().serve()); - tokio::time::sleep(Duration::from_secs(1)).await; - - // Create user with basic tier - let response = app.post_user("test-user", "basic").await; - assert_eq!(response.status(), StatusCode::OK); - // Upgrade user to pro + // When called again at some later time, the subscription returned from stripe should be + // cancelled. let response = app - .put_user("test-user", "pro", MOCKED_CHECKOUT_SESSIONS[3]) + .get_user_with_mocked_stripe("sub_123", MOCKED_SUBSCRIPTIONS[3], "test-user") .await; assert_eq!(response.status(), StatusCode::OK); - // Cancel subscription - let response = app.put_user("test-user", "cancelledpro", "").await; - assert_eq!(response.status(), StatusCode::OK); - - // Check if user has cancelledpro status - let response = app.get_user("test-user").await; - assert_eq!(response.status(), StatusCode::OK); - let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); let user: Value = serde_json::from_slice(&body).unwrap(); assert_eq!( user.as_object().unwrap().get("account_tier").unwrap(), - "cancelledpro" + "basic" ); } } diff --git a/common/src/backends/auth.rs b/common/src/backends/auth.rs index 935ecadfd..3e0b2734a 100644 --- a/common/src/backends/auth.rs +++ b/common/src/backends/auth.rs @@ -432,7 +432,8 @@ pub trait VerifyClaim { fn verify(&self, required_scope: Scope) -> Result<(), Self::Error>; - fn verify_rds_access(&self) -> Result<(), Self::Error>; + /// Verify the claim subject has permission to provision RDS, and if they do, return their claim. + fn verify_rds_access(&self) -> Result; } #[cfg(feature = "tonic")] @@ -459,14 +460,14 @@ impl VerifyClaim for tonic::Request { } } - fn verify_rds_access(&self) -> Result<(), Self::Error> { + fn verify_rds_access(&self) -> Result { let claim = self .extensions() .get::() .ok_or_else(|| tonic::Status::internal("could not get claim"))?; if claim.can_provision_rds() { - Ok(()) + Ok(claim.clone()) } else { Err(tonic::Status::permission_denied( "don't have permission to provision rds instances", diff --git a/common/src/backends/mod.rs b/common/src/backends/mod.rs index 7cc92c0c2..ff9ed2a00 100644 --- a/common/src/backends/mod.rs +++ b/common/src/backends/mod.rs @@ -4,4 +4,5 @@ mod future; pub mod headers; pub mod metrics; mod otlp_tracing_bridge; +pub mod subscription; pub mod tracing; diff --git a/common/src/backends/subscription.rs b/common/src/backends/subscription.rs new file mode 100644 index 000000000..4e296c284 --- /dev/null +++ b/common/src/backends/subscription.rs @@ -0,0 +1,19 @@ +use serde::{Deserialize, Serialize}; + +/// Used when sending requests to the Auth service to add a new item to a user's subscription. +#[derive(Debug, Deserialize, Serialize)] +pub struct NewSubscriptionItem { + pub item: SubscriptionItem, + pub quantity: u64, +} + +impl NewSubscriptionItem { + pub fn new(item: SubscriptionItem, quantity: u64) -> NewSubscriptionItem { + NewSubscriptionItem { item, quantity } + } +} + +#[derive(Deserialize, Debug, Serialize)] +pub enum SubscriptionItem { + AwsRds, +} diff --git a/common/src/claims.rs b/common/src/claims.rs index 8fb247a52..d96c20c5d 100644 --- a/common/src/claims.rs +++ b/common/src/claims.rs @@ -316,6 +316,10 @@ impl Claim { Ok(claim) } + + pub fn token(&self) -> Option<&str> { + self.token.as_deref() + } } // Future for layers that just return the inner response diff --git a/common/src/models/deployment.rs b/common/src/models/deployment.rs index cc0f7945e..2ee22aa45 100644 --- a/common/src/models/deployment.rs +++ b/common/src/models/deployment.rs @@ -36,6 +36,7 @@ pub struct Response { pub git_commit_msg: Option, pub git_branch: Option, pub git_dirty: Option, + pub message: Option, } impl Display for Response { @@ -51,7 +52,7 @@ impl Display for Response { self.state .to_string() // Unwrap is safe because Color::from_str returns the color white if the argument is not a Color. - .with(crossterm::style::Color::from_str(self.state.get_color()).unwrap()) + .with(crossterm::style::Color::from_str(self.state.get_color()).unwrap()), ) } } diff --git a/common/src/models/service.rs b/common/src/models/service.rs index a43013184..58f98cd5a 100644 --- a/common/src/models/service.rs +++ b/common/src/models/service.rs @@ -2,7 +2,7 @@ use crate::ulid_type; use crossterm::style::{Color, Stylize}; use serde::{Deserialize, Serialize}; -use std::fmt::Display; +use std::fmt::{Display, Write}; use std::str::FromStr; #[cfg(feature = "openapi")] use utoipa::ToSchema; @@ -31,7 +31,7 @@ pub struct Summary { impl Display for Summary { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let deployment = if let Some(ref deployment) = self.deployment { - format!( + let mut summary = format!( r#" Service Name: {} Deployment ID: {} @@ -47,7 +47,12 @@ URI: {} ), deployment.last_update.format("%Y-%m-%dT%H:%M:%SZ"), self.uri, - ) + ); + // If any message is associated with the deployment, append it to the summary. + if let Some(ref message) = deployment.message { + write!(summary, "Message: {message}")?; + } + summary } else { format!( "{}\n\n", diff --git a/deployer/migrations/0006_deployment_message.sql b/deployer/migrations/0006_deployment_message.sql new file mode 100644 index 000000000..bfb4a8406 --- /dev/null +++ b/deployer/migrations/0006_deployment_message.sql @@ -0,0 +1,2 @@ +ALTER TABLE deployments +ADD COLUMN message TEXT; diff --git a/deployer/src/deployment/run.rs b/deployer/src/deployment/run.rs index 049635e57..3cacad92e 100644 --- a/deployer/src/deployment/run.rs +++ b/deployer/src/deployment/run.rs @@ -287,9 +287,11 @@ impl Built { load( self.service_name.clone(), self.service_id, + self.id, executable_path.clone(), resource_manager, runtime_client.clone(), + deployment_updater.clone(), self.claim, self.secrets, ) @@ -308,18 +310,21 @@ impl Built { } } +#[allow(clippy::too_many_arguments)] async fn load( service_name: String, service_id: Ulid, + deployment_id: Uuid, executable_path: PathBuf, mut resource_manager: impl ResourceManager, mut runtime_client: RuntimeClient>>, + deployment_updater: impl DeploymentUpdater, claim: Claim, mut secrets: HashMap, ) -> Result<()> { info!("Loading resources"); - let resources = resource_manager + let resources: Vec<_> = resource_manager .get_resources(&service_id, claim.clone()) .await .map_err(|err| Error::Load(err.to_string()))? @@ -352,16 +357,26 @@ async fn load( } } }) - .map(resource::Response::into_bytes) .collect(); + // Check whether or not any rds instances are cached (already provisioned). + let cached_resources_contains_rds = resources.iter().any(|resource| { + matches!( + resource.r#type, + shuttle_common::resource::Type::Database(shuttle_common::database::Type::AwsRds(_)) + ) + }); + let mut load_request = tonic::Request::new(LoadRequest { path: executable_path .into_os_string() .into_string() .unwrap_or_default(), service_name: service_name.clone(), - resources, + resources: resources + .into_iter() + .map(resource::Response::into_bytes) + .collect(), secrets, }); @@ -379,23 +394,54 @@ async fn load( info!("successfully loaded service"); } - let resources = response + let resources: Vec<_> = response .resources .into_iter() .map(|res| { let resource: resource::Response = serde_json::from_slice(&res).unwrap(); - record_request::Resource { - r#type: resource.r#type.to_string(), - config: resource.config.to_string().into_bytes(), - data: resource.data.to_string().into_bytes(), - } + resource }) .collect(); + + // Check whether any rds instances were included in the service resources after loading. + let loaded_resources_contains_rds = resources.iter().any(|resource| { + matches!( + resource.r#type, + shuttle_common::resource::Type::Database( + shuttle_common::database::Type::AwsRds(_) + ) + ) + }); + + let resources: Vec<_> = resources + .into_iter() + .map(|resource| record_request::Resource { + r#type: resource.r#type.to_string(), + config: resource.config.to_string().into_bytes(), + data: resource.data.to_string().into_bytes(), + }) + .collect(); + resource_manager .insert_resources(resources, &service_id, claim.clone()) .await .expect("to add resource to persistence"); + // If any rds instances were not cached, and they were returned from the runtime load, + // we know that these rds instances were provisioned for the first time. + if !cached_resources_contains_rds && loaded_resources_contains_rds { + deployment_updater + .set_message( + &deployment_id, + "This deployment increased the cost of your subscription. To check your total, log in to the Shuttle web console.", + ) + .await + .map_err(|err| { + error!(error = %err, "failed to set deployment message"); + Error::Load("failed to set deployment message".to_string()) + })?; + } + if response.success { Ok(()) } else { diff --git a/deployer/src/deployment/state_change_layer.rs b/deployer/src/deployment/state_change_layer.rs index 42f186e37..42236dfb7 100644 --- a/deployer/src/deployment/state_change_layer.rs +++ b/deployer/src/deployment/state_change_layer.rs @@ -386,6 +386,10 @@ mod tests { async fn set_is_next(&self, _id: &Uuid, _is_next: bool) -> Result<(), Self::Err> { Ok(()) } + + async fn set_message(&self, _id: &Uuid, _message: &str) -> Result<(), Self::Err> { + Ok(()) + } } #[derive(Clone)] diff --git a/deployer/src/handlers/mod.rs b/deployer/src/handlers/mod.rs index fa9faae8c..e22d67ea6 100644 --- a/deployer/src/handlers/mod.rs +++ b/deployer/src/handlers/mod.rs @@ -442,6 +442,7 @@ pub async fn create_service( .git_branch .map(|s| s.chars().take(GIT_STRINGS_MAX_LENGTH).collect()), git_dirty: deployment_req.git_dirty, + message: None, }; persistence.insert_deployment(&deployment).await?; diff --git a/deployer/src/persistence/deployment.rs b/deployer/src/persistence/deployment.rs index ff8fb7907..b10c6c2f9 100644 --- a/deployer/src/persistence/deployment.rs +++ b/deployer/src/persistence/deployment.rs @@ -23,6 +23,7 @@ pub struct Deployment { pub git_commit_msg: Option, pub git_branch: Option, pub git_dirty: Option, + pub message: Option, } impl FromRow<'_, SqliteRow> for Deployment { @@ -51,6 +52,7 @@ impl FromRow<'_, SqliteRow> for Deployment { git_commit_msg: row.try_get("git_commit_msg")?, git_branch: row.try_get("git_branch")?, git_dirty: row.try_get("git_dirty")?, + message: row.try_get("message")?, }) } } @@ -66,6 +68,7 @@ impl From for shuttle_common::models::deployment::Response { git_commit_msg: deployment.git_commit_msg, git_branch: deployment.git_branch, git_dirty: deployment.git_dirty, + message: deployment.message, } } } @@ -80,6 +83,9 @@ pub trait DeploymentUpdater: Clone + Send + Sync + 'static { /// Set if a deployment is build on shuttle-next async fn set_is_next(&self, id: &Uuid, is_next: bool) -> Result<(), Self::Err>; + + /// Associate messages with a deployment. + async fn set_message(&self, id: &Uuid, message: &str) -> Result<(), Self::Err>; } #[derive(Debug, PartialEq, Eq)] diff --git a/deployer/src/persistence/mod.rs b/deployer/src/persistence/mod.rs index 35a182423..7b8aeb466 100644 --- a/deployer/src/persistence/mod.rs +++ b/deployer/src/persistence/mod.rs @@ -222,7 +222,7 @@ impl Persistence { pub async fn insert_deployment(&self, deployment: impl Into<&Deployment>) -> Result<()> { let deployment: &Deployment = deployment.into(); - sqlx::query("INSERT INTO deployments VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)") + sqlx::query("INSERT INTO deployments VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)") .bind(deployment.id) .bind(deployment.service_id.to_string()) .bind(deployment.state) @@ -233,6 +233,7 @@ impl Persistence { .bind(deployment.git_commit_msg.as_ref()) .bind(deployment.git_branch.as_ref()) .bind(deployment.git_dirty) + .bind(deployment.message.as_ref()) .execute(&self.pool) .await .map(|_| ()) @@ -637,6 +638,16 @@ impl DeploymentUpdater for Persistence { .map(|_| ()) .map_err(Error::from) } + + async fn set_message(&self, id: &Uuid, message: &str) -> Result<()> { + sqlx::query("UPDATE deployments SET message = ? WHERE id = ?") + .bind(message) + .bind(id) + .execute(&self.pool) + .await + .map(|_| ()) + .map_err(Error::from) + } } #[async_trait::async_trait] @@ -716,11 +727,13 @@ mod tests { p.set_address(&id, &address).await.unwrap(); p.set_is_next(&id, true).await.unwrap(); + p.set_message(&id, "dummy message").await.unwrap(); let update = p.get_deployment(&id).await.unwrap().unwrap(); assert_eq!(update.state, State::Built); assert_eq!(update.address, Some(address)); assert!(update.is_next); + assert_eq!(update.message, Some("dummy message".to_string())); assert_ne!( update.last_update, Utc.with_ymd_and_hms(2022, 4, 25, 4, 43, 33).unwrap() @@ -744,6 +757,7 @@ mod tests { git_commit_msg: None, git_branch: None, git_dirty: None, + message: None, }) .collect(); diff --git a/deployer/tests/integration_run.rs b/deployer/tests/integration_run.rs index 9a374cb46..f881bdde3 100644 --- a/deployer/tests/integration_run.rs +++ b/deployer/tests/integration_run.rs @@ -153,6 +153,10 @@ impl DeploymentUpdater for StubDeploymentUpdater { async fn set_is_next(&self, _id: &Uuid, _is_next: bool) -> Result<(), Self::Err> { Ok(()) } + + async fn set_message(&self, _id: &Uuid, _message: &str) -> Result<(), Self::Err> { + Ok(()) + } } // This test uses the kill signal to make sure a service does stop when asked to diff --git a/docker-compose.yml b/docker-compose.yml index 6fb5b0a58..4d8874145 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -55,6 +55,7 @@ services: - "--address=0.0.0.0:8000" - "--stripe-secret-key=${STRIPE_SECRET_KEY}" - "--jwt-signing-private-key=${AUTH_JWTSIGNING_PRIVATE_KEY}" + - "--stripe-rds-price-id=${STRIPE_RDS_PRICE_ID}" healthcheck: test: curl --fail http://localhost:8000/ || exit 1 interval: 1m diff --git a/provisioner/Cargo.toml b/provisioner/Cargo.toml index 93810f9f6..95a9a2224 100644 --- a/provisioner/Cargo.toml +++ b/provisioner/Cargo.toml @@ -17,6 +17,8 @@ fqdn = { workspace = true } mongodb = "2.4.0" prost = { workspace = true } rand = { workspace = true } +reqwest = { workspace = true, features = ["json"] } +serde_json = { workspace = true } sqlx = { workspace = true, features = ["postgres"] } thiserror = { workspace = true } tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } @@ -28,4 +30,4 @@ tracing-subscriber = { workspace = true, features = ["default", "fmt"] } ctor = { workspace = true } once_cell = { workspace = true } portpicker = { workspace = true } -serde_json = { workspace = true } +wiremock = "0.5" diff --git a/provisioner/src/error.rs b/provisioner/src/error.rs index 3eb336e2c..a5405ef0d 100644 --- a/provisioner/src/error.rs +++ b/provisioner/src/error.rs @@ -35,7 +35,13 @@ pub enum Error { Plain(String), } -unsafe impl Send for Error {} +#[derive(Error, Debug)] +pub enum AuthClientError { + #[error["token sent to auth service was expired, retry the request"]] + ExpiredJwt, + #[error["failed to request subscription update from auth service: {0}"]] + Internal(String), +} impl From for Status { fn from(err: Error) -> Self { diff --git a/provisioner/src/lib.rs b/provisioner/src/lib.rs index 1f2f31c54..dfbdc44c5 100644 --- a/provisioner/src/lib.rs +++ b/provisioner/src/lib.rs @@ -7,11 +7,13 @@ use aws_sdk_rds::{ error::SdkError, operation::modify_db_instance::ModifyDBInstanceError, types::DbInstance, Client, }; +use error::AuthClientError; pub use error::Error; use mongodb::{bson::doc, options::ClientOptions}; use rand::Rng; use shuttle_common::backends::auth::VerifyClaim; -use shuttle_common::claims::Scope; +use shuttle_common::backends::subscription::{NewSubscriptionItem, SubscriptionItem}; +use shuttle_common::claims::{AccountTier, Scope}; use shuttle_common::models::project::ProjectName; pub use shuttle_proto::provisioner::provisioner_server::ProvisionerServer; use shuttle_proto::provisioner::{ @@ -21,8 +23,9 @@ use shuttle_proto::provisioner::{provisioner_server::Provisioner, DatabaseDeleti use shuttle_proto::provisioner::{Ping, Pong}; use sqlx::{postgres::PgPoolOptions, ConnectOptions, Executor, PgPool}; use tokio::time::sleep; +use tonic::transport::Uri; use tonic::{Request, Response, Status}; -use tracing::{debug, info, warn}; +use tracing::{debug, error, info, warn}; mod args; mod error; @@ -38,6 +41,8 @@ pub struct MyProvisioner { fqdn: String, internal_pg_address: String, internal_mongodb_address: String, + auth_client: reqwest::Client, + auth_uri: Uri, } impl MyProvisioner { @@ -47,6 +52,7 @@ impl MyProvisioner { fqdn: String, internal_pg_address: String, internal_mongodb_address: String, + auth_uri: Uri, ) -> Result { let pool = PgPoolOptions::new() .min_connections(4) @@ -77,6 +83,8 @@ impl MyProvisioner { fqdn, internal_pg_address, internal_mongodb_address, + auth_client: reqwest::Client::new(), + auth_uri, }) } @@ -249,7 +257,7 @@ impl MyProvisioner { async fn request_aws_rds( &self, project_name: &str, - engine: aws_rds::Engine, + engine: &aws_rds::Engine, ) -> Result { let client = &self.rds_client; @@ -338,6 +346,43 @@ impl MyProvisioner { }) } + /// Send a request to the auth service with new subscription items that should be added to + /// the subscription of the [Claim] subject. + pub async fn add_subscription_items( + &self, + jwt: &str, + subscription_item: NewSubscriptionItem, + ) -> Result<(), AuthClientError> { + let response = self + .auth_client + .post(format!("{}users/subscription/items", self.auth_uri)) + .bearer_auth(jwt) + .json(&subscription_item) + .send() + .await + .map_err(|err| { + error!(error = %err, "failed to connect to auth service"); + AuthClientError::Internal("failed to connect to auth service".to_string()) + })?; + + match response.status().as_u16() { + 200 => Ok(()), + 499 => { + error!( + status_code = 499, + "failed to update subscription due to expired jwt" + ); + Err(AuthClientError::ExpiredJwt) + } + status_code => { + error!(status_code = status_code, "failed to update subscription"); + Err(AuthClientError::Internal( + "failed to update subscription".to_string(), + )) + } + } + } + async fn delete_shared_db( &self, project_name: &str, @@ -428,7 +473,7 @@ impl MyProvisioner { async fn delete_aws_rds( &self, project_name: &str, - engine: aws_rds::Engine, + engine: &aws_rds::Engine, ) -> Result { let client = &self.rds_client; let instance_name = format!("{project_name}-{engine}"); @@ -470,9 +515,31 @@ impl Provisioner for MyProvisioner { .await? } DbType::AwsRds(AwsRds { engine }) => { - can_provision_rds?; - self.request_aws_rds(&request.project_name, engine.expect("engine to be set")) - .await? + let claim = can_provision_rds?; + + let engine = engine.expect("engine should be set"); + + let response = self.request_aws_rds(&request.project_name, &engine).await?; + + // Skip updating subscriptions for admin users. + if claim.tier != AccountTier::Admin { + // If the subscription update fails, e.g. due to a JWT expiring or the subject's + // subscription expiring, delete the instance immediately. + if let Err(err) = self + .add_subscription_items( + // The token should be set on the claim in the JWT auth layer. + claim.token().expect("claim should have a token"), + NewSubscriptionItem::new(SubscriptionItem::AwsRds, 1), + ) + .await + { + self.delete_aws_rds(&request.project_name, &engine).await?; + + return Err(Status::internal(err.to_string())); + }; + } + + response } }; @@ -498,7 +565,7 @@ impl Provisioner for MyProvisioner { .await? } DbType::AwsRds(AwsRds { engine }) => { - self.delete_aws_rds(&request.project_name, engine.expect("engine to be set")) + self.delete_aws_rds(&request.project_name, &engine.expect("engine to be set")) .await? } }; @@ -552,7 +619,7 @@ async fn wait_for_instance( } } -fn engine_to_port(engine: aws_rds::Engine) -> String { +fn engine_to_port(engine: &aws_rds::Engine) -> String { match engine { aws_rds::Engine::Postgres(_) => "5432".to_string(), aws_rds::Engine::Mariadb(_) => "3306".to_string(), diff --git a/provisioner/src/main.rs b/provisioner/src/main.rs index 99886134f..844b4ea49 100644 --- a/provisioner/src/main.rs +++ b/provisioner/src/main.rs @@ -33,6 +33,7 @@ async fn main() -> Result<(), Box> { fqdn.to_string(), internal_pg_address, internal_mongodb_address, + auth_uri.clone(), ) .await .unwrap(); diff --git a/provisioner/tests/provisioner.rs b/provisioner/tests/provisioner.rs index 98b968e29..be7f1007b 100644 --- a/provisioner/tests/provisioner.rs +++ b/provisioner/tests/provisioner.rs @@ -2,9 +2,16 @@ mod helpers; use ctor::dtor; use helpers::{exec_mongosh, exec_psql, DbType, DockerInstance}; use once_cell::sync::Lazy; +use reqwest::header::{AUTHORIZATION, CONTENT_TYPE}; use serde_json::Value; +use shuttle_common::backends::subscription::{NewSubscriptionItem, SubscriptionItem}; use shuttle_proto::provisioner::shared; use shuttle_provisioner::MyProvisioner; +use tonic::transport::Uri; +use wiremock::{ + matchers::{body_json, header, header_exists, method, path}, + MockServer, ResponseTemplate, +}; static PG: Lazy = Lazy::new(|| DockerInstance::new(DbType::Postgres)); static MONGODB: Lazy = Lazy::new(|| DockerInstance::new(DbType::MongoDb)); @@ -15,6 +22,41 @@ fn cleanup() { MONGODB.cleanup(); } +#[tokio::test] +async fn correctly_calls_auth_service_to_add_rds_subscription_item() { + let mock_server = MockServer::start().await; + + let provisioner = MyProvisioner::new( + &PG.uri, + &MONGODB.uri, + "fqdn".to_string(), + "pg".to_string(), + "mongodb".to_string(), + // Pass in the mock server's URI as the auth URI. + mock_server.uri().parse::().unwrap(), + ) + .await + .unwrap(); + + let subscription_item = || NewSubscriptionItem::new(SubscriptionItem::AwsRds, 1); + + // Respond with a 200 for a correctly formed request. + wiremock::Mock::given(method("POST")) + .and(path("/users/subscription/items")) + .and(header(CONTENT_TYPE, "application/json")) + .and(header_exists(AUTHORIZATION)) + .and(body_json(subscription_item())) + .respond_with(ResponseTemplate::new(200)) + .mount(&mock_server) + .await; + + let res = provisioner + .add_subscription_items("jwt", subscription_item()) + .await; + + assert!(res.is_ok()); +} + mod needs_docker { use super::*; @@ -26,6 +68,7 @@ mod needs_docker { "fqdn".to_string(), "pg".to_string(), "mongodb".to_string(), + Uri::from_static("http://127.0.0.1:8008"), ) .await .unwrap(); @@ -54,6 +97,7 @@ mod needs_docker { "fqdn".to_string(), "pg".to_string(), "mongodb".to_string(), + Uri::from_static("http://127.0.0.1:8008"), ) .await .unwrap(); @@ -84,6 +128,7 @@ mod needs_docker { "fqdn".to_string(), "pg".to_string(), "mongodb".to_string(), + Uri::from_static("http://127.0.0.1:8008"), ) .await .unwrap(); @@ -105,6 +150,7 @@ mod needs_docker { "fqdn".to_string(), "pg".to_string(), "mongodb".to_string(), + Uri::from_static("http://127.0.0.1:8008"), ) .await .unwrap(); @@ -133,6 +179,7 @@ mod needs_docker { "fqdn".to_string(), "pg".to_string(), "mongodb".to_string(), + Uri::from_static("http://127.0.0.1:8008"), ) .await .unwrap(); @@ -163,6 +210,7 @@ mod needs_docker { "fqdn".to_string(), "pg".to_string(), "mongodb".to_string(), + Uri::from_static("http://127.0.0.1:8008"), ) .await .unwrap(); @@ -187,6 +235,7 @@ mod needs_docker { "fqdn".to_string(), "pg".to_string(), "mongodb".to_string(), + Uri::from_static("http://127.0.0.1:8008"), ) .await .unwrap();