Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(auth): add subscriptions table to auth, add rds quota to claim limits #1529

Merged
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion auth/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ async-trait = { workspace = true }
axum = { workspace = true, features = ["headers"] }
axum-sessions = { workspace = true }
base64 = { workspace = true }
chrono = { workspace = true }
clap = { workspace = true }
http = { workspace = true }
jsonwebtoken = { workspace = true }
Expand All @@ -26,7 +27,7 @@ pem = "2"
rand = { workspace = true }
ring = { workspace = true }
serde = { workspace = true, features = ["derive"] }
sqlx = { workspace = true, features = ["postgres", "json", "migrate"] }
sqlx = { workspace = true, features = ["postgres", "json", "migrate", "chrono"] }
strum = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true, features = ["full"] }
Expand All @@ -40,6 +41,7 @@ ctor = { workspace = true }
hyper = { workspace = true }
once_cell = { workspace = true }
portpicker = { workspace = true }
pretty_assertions = { workspace = true }
serde_json = { workspace = true }
shuttle-common-tests = { workspace = true }
tower = { workspace = true, features = ["util"] }
Expand Down
10 changes: 10 additions & 0 deletions auth/migrations/0001_sync_updated_at.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
-- Create a function (that can be registered on triggers) to automatically set updated_at to current_timestamp
CREATE OR REPLACE FUNCTION sync_updated_at()
RETURNS TRIGGER
LANGUAGE PLPGSQL
AS $$
BEGIN
NEW.updated_at = current_timestamp;
RETURN NEW;
END;
$$
27 changes: 27 additions & 0 deletions auth/migrations/0002_subscriptions.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
CREATE TABLE IF NOT EXISTS subscriptions (
subscription_id TEXT PRIMARY KEY,
account_name TEXT NOT NULL,
type TEXT NOT NULL,
quantity INT DEFAULT 1,
created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
Kazy marked this conversation as resolved.
Show resolved Hide resolved
UNIQUE (account_name, type),
FOREIGN KEY (account_name) REFERENCES users(account_name)
);

-- Create a trigger to automatically update the updated_at column
CREATE TRIGGER sync_users_updated_at
BEFORE UPDATE
ON subscriptions
FOR EACH ROW
EXECUTE PROCEDURE sync_updated_at();

-- Insert existing subscriptions into the new subscriptions table, all of which are of the pro type
INSERT INTO subscriptions (subscription_id, account_name, type)
SELECT subscription_id, account_name, 'pro'
FROM users
WHERE subscription_id IS NOT NULL;

-- Drop the subscription_id column from the users table
ALTER TABLE users
DROP COLUMN subscription_id;
chesedo marked this conversation as resolved.
Show resolved Hide resolved
5 changes: 4 additions & 1 deletion auth/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,10 @@ pub async fn pgpool_init(db_uri: &str) -> io::Result<PgPool> {
let pool = PgPool::connect_with(opts)
.await
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
MIGRATIONS.run(&pool).await.unwrap();
MIGRATIONS
.run(&pool)
.await
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;

Ok(pool)
}
144 changes: 115 additions & 29 deletions auth/src/user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ use axum::{
http::request::Parts,
TypedHeader,
};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Deserializer, Serialize};
use shuttle_common::{backends::headers::XShuttleAdminSecret, claims::AccountTier, ApiKey, Secret};
use sqlx::{postgres::PgRow, query, FromRow, PgPool, Row};
use strum::EnumString;
use tracing::{debug, error, trace, Span};

use crate::{api::UserManagerState, error::Error};
Expand Down Expand Up @@ -49,7 +51,7 @@ impl UserManagement for UserManager {
.execute(&self.pool)
.await?;

Ok(User::new(name, key, tier, None))
Ok(User::new(name, key, tier, vec![]))
}

// Update user tier to pro and update the subscription id.
Expand All @@ -75,16 +77,34 @@ impl UserManagement for UserManager {
})
.ok_or(Error::MissingSubscriptionId)?;

// Update the user account tier and subscription_id.
let rows_affected = query(
"UPDATE users SET account_tier = $1, subscription_id = $2 WHERE account_name = $3",
// Update the user account tier and insert or update their pro subscription.
let mut transaction = self.pool.begin().await?;

let rows_affected = query("UPDATE users SET account_tier = $1 WHERE account_name = $2")
.bind(AccountTier::Pro.to_string())
.bind(name)
.execute(&mut *transaction)
.await?
.rows_affected();

// Insert a new pro subscription. If a pro subscription already exists, update the
// subscription id.
// NOTE: we do not increase the quantity if a pro subscription exists, because it
// should never be increased.
query(
r#"INSERT INTO subscriptions (subscription_id, account_name, type)
VALUES ($1, $2, $3)
ON CONFLICT (account_name, type)
DO UPDATE SET subscription_id = EXCLUDED.subscription_id
"#,
)
.bind(AccountTier::Pro.to_string())
.bind(subscription_id)
.bind(&subscription_id)
.bind(name)
.execute(&self.pool)
.await?
.rows_affected();
.bind(ShuttleSubscriptionType::Pro.to_string())
.execute(&mut *transaction)
.await?;

transaction.commit().await?;

// In case no rows were updated, this means the account doesn't exist.
if rows_affected > 0 {
Expand Down Expand Up @@ -114,12 +134,24 @@ impl UserManagement for UserManager {
}

async fn get_user(&self, name: AccountName) -> Result<User, Error> {
let mut user: User =
sqlx::query_as("SELECT account_name, key, account_tier, subscription_id FROM users WHERE account_name = $1")
.bind(&name)
.fetch_optional(&self.pool)
.await?
.ok_or(Error::UserNotFound)?;
let mut user: User = sqlx::query_as(
"SELECT account_name, key, account_tier FROM users WHERE account_name = $1",
)
.bind(&name)
.fetch_optional(&self.pool)
.await?
.ok_or(Error::UserNotFound)?;

let subscriptions: Vec<Subscription> = sqlx::query_as(
"SELECT subscription_id, type, quantity, created_at, updated_at FROM subscriptions WHERE account_name = $1",
)
.bind(&user.name.to_string())
.fetch_all(&self.pool)
.await?;

if !subscriptions.is_empty() {
user.subscriptions = subscriptions;
}

// Sync the user tier based on the subscription validity, if any.
if let Err(err) = user.sync_tier(self).await {
Expand All @@ -133,13 +165,23 @@ impl UserManagement for UserManager {
}

async fn get_user_by_key(&self, key: ApiKey) -> Result<User, Error> {
let mut user: User = sqlx::query_as(
"SELECT account_name, key, account_tier, subscription_id FROM users WHERE key = $1",
let mut user: User =
sqlx::query_as("SELECT account_name, key, account_tier FROM users WHERE key = $1")
.bind(&key)
.fetch_optional(&self.pool)
.await?
.ok_or(Error::UserNotFound)?;

let subscriptions: Vec<Subscription> = sqlx::query_as(
"SELECT subscription_id, type, quantity, created_at, updated_at FROM subscriptions WHERE account_name = $1",
)
.bind(&key)
.fetch_optional(&self.pool)
.await?
.ok_or(Error::UserNotFound)?;
.bind(&user.name.to_string())
.fetch_all(&self.pool)
.await?;

if !subscriptions.is_empty() {
user.subscriptions = subscriptions;
}

// Sync the user tier based on the subscription validity, if any.
if user.sync_tier(self).await? {
Expand Down Expand Up @@ -167,36 +209,59 @@ impl UserManagement for UserManager {
}
}

#[derive(Clone, Deserialize, PartialEq, Eq, Serialize, Debug)]
#[derive(Clone, Debug)]
pub struct User {
pub name: AccountName,
pub key: Secret<ApiKey>,
pub account_tier: AccountTier,
pub subscription_id: Option<SubscriptionId>,
pub subscriptions: Vec<Subscription>,
}

#[derive(Clone, Debug)]
pub struct Subscription {
pub id: stripe::SubscriptionId,
pub r#type: ShuttleSubscriptionType,
pub quantity: i32,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}

#[derive(Clone, Debug, EnumString, strum::Display)]
#[strum(serialize_all = "lowercase")]
pub enum ShuttleSubscriptionType {
Pro,
Rds,
}

impl User {
pub fn is_admin(&self) -> bool {
self.account_tier == AccountTier::Admin
}

pub fn pro_subscription_id(&self) -> Option<&stripe::SubscriptionId> {
self.subscriptions
.iter()
.find(|sub| matches!(sub.r#type, ShuttleSubscriptionType::Pro))
.map(|sub| &sub.id)
}

pub fn new(
name: AccountName,
key: ApiKey,
account_tier: AccountTier,
subscription_id: Option<SubscriptionId>,
subscriptions: Vec<Subscription>,
) -> Self {
Self {
name,
key: Secret::new(key),
account_tier,
subscription_id,
subscriptions,
}
}

/// In case of an existing subscription, check if valid.
async fn subscription_is_valid(&self, client: &stripe::Client) -> Result<bool, Error> {
if let Some(subscription_id) = self.subscription_id.as_ref() {
if let Some(subscription_id) = self.pro_subscription_id() {
let subscription = stripe::Subscription::retrieve(client, subscription_id, &[]).await?;
debug!("subscription: {:#?}", subscription);
return Ok(subscription.status == SubscriptionStatus::Active
Expand Down Expand Up @@ -259,10 +324,28 @@ impl FromRow<'_, PgRow> for User {
source: Box::new(std::io::Error::new(ErrorKind::Other, err.to_string())),
},
)?,
subscription_id: row
subscriptions: vec![],
})
}
}

impl FromRow<'_, PgRow> for Subscription {
fn from_row(row: &PgRow) -> Result<Self, sqlx::Error> {
Ok(Subscription {
id: row
.try_get("subscription_id")
.ok()
.and_then(|inner| SubscriptionId::from_str(inner).ok()),
.and_then(|inner| SubscriptionId::from_str(inner).ok())
.unwrap(),
r#type: ShuttleSubscriptionType::from_str(row.try_get("type").unwrap()).map_err(
|err| sqlx::Error::ColumnDecode {
index: "type".to_string(),
source: Box::new(std::io::Error::new(ErrorKind::Other, err.to_string())),
},
)?,
quantity: row.try_get("quantity").unwrap(),
created_at: row.try_get("created_at").unwrap(),
updated_at: row.try_get("updated_at").unwrap(),
})
}
}
Expand Down Expand Up @@ -299,7 +382,10 @@ impl From<User> for shuttle_common::models::user::Response {
name: user.name.to_string(),
key: user.key.expose().as_ref().to_owned(),
account_tier: user.account_tier.to_string(),
subscription_id: user.subscription_id.map(|inner| inner.to_string()),
// TODO: this just returns what was always returned, the id of the pro subscription.
// We will need to update this when we want to also return the rds subscription. We
// can return a vec of IDs, but it will be a breaking change for the console (I don't believe it is used anywhere else).
subscription_id: user.pro_subscription_id().map(|id| id.to_string()),
}
}
}
Expand Down
1 change: 1 addition & 0 deletions auth/tests/api/users.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ mod needs_docker {
};
use axum::body::Body;
use hyper::http::{header::AUTHORIZATION, Request, StatusCode};
use pretty_assertions::assert_eq;
use serde_json::{self, Value};

#[tokio::test]
Expand Down
24 changes: 18 additions & 6 deletions common/src/limits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,32 +10,44 @@ pub struct Limits {
/// The amount of projects this user can create.
project_limit: u32,
/// Whether this user has permission to provision RDS instances.
#[deprecated(
since = "0.38.0",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Should it be 0.37? The current version is 0.36.

Suggested change
since = "0.38.0",
since = "0.37.0",

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assumed this PR will only make it into the next release (if we release today) 🤷‍♂️

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can update this if it is released earlier 😄

note = "This was replaced with rds_quota, but old runtimes might still try to deserialize a claim expecting this field"
)]
#[serde(skip_deserializing)]
rds_access: bool,
/// The quantity of RDS instances this user can provision.
rds_quota: u32,
}

impl Default for Limits {
fn default() -> Self {
#[allow(deprecated)]
Self {
project_limit: MAX_PROJECTS_DEFAULT,
rds_access: false,
rds_quota: 0,
chesedo marked this conversation as resolved.
Show resolved Hide resolved
}
}
}

impl Limits {
pub fn new(project_limit: u32, rds_limit: bool) -> Self {
pub fn new(project_limit: u32, rds_quota: u32) -> Self {
#[allow(deprecated)]
Self {
project_limit,
rds_access: rds_limit,
rds_access: false,
rds_quota,
}
}

pub fn project_limit(&self) -> u32 {
self.project_limit
}

pub fn rds_access(&self) -> bool {
self.rds_access
/// Use the subscription quantity to set the RDS quota for this claim.
pub fn rds_quota(&mut self, quantity: u32) {
self.rds_quota = quantity;
}
}

Expand All @@ -47,7 +59,7 @@ impl From<AccountTier> for Limits {
| AccountTier::PendingPaymentPro
| AccountTier::Deployer => Self::default(),
AccountTier::Pro | AccountTier::CancelledPro | AccountTier::Team => {
Self::new(MAX_PROJECTS_EXTRA, true)
Self::new(MAX_PROJECTS_EXTRA, 1)
}
}
}
Expand All @@ -72,6 +84,6 @@ impl ClaimExt for Claim {
}

fn can_provision_rds(&self) -> bool {
self.is_admin() || self.limits.rds_access
self.is_admin() || self.limits.rds_quota > 0
}
}