Skip to content

Commit

Permalink
get our own certificate
Browse files Browse the repository at this point in the history
  • Loading branch information
brokad committed Nov 17, 2022
1 parent 7ca9c59 commit 2520411
Show file tree
Hide file tree
Showing 9 changed files with 184 additions and 66 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions gateway/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ axum = { version = "0.5.8", features = [ "headers" ] }
axum-server = { version = "0.4.4", features = [ "tls-rustls" ] }
rustls = { version = "0.20.6" }
rustls-pemfile = { version = "1.0.1" }
pem = "1.1.0"

base64 = "0.13"
bollard = "0.13"
Expand Down
19 changes: 9 additions & 10 deletions gateway/src/acme.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use std::time::Duration;

use axum::body::boxed;
use axum::response::Response;
use fqdn::Fqdn;
use futures::future::BoxFuture;
use hyper::server::conn::AddrStream;
use hyper::{Body, Request};
Expand Down Expand Up @@ -88,14 +87,14 @@ impl AcmeClient {
Ok(credentials)
}

/// Create a certificate and return it with the keys used to sign it
/// Create an ACME-signed certificate and return it and its
/// associated PEM-encoded private key
pub async fn create_certificate(
&self,
fqdn: &Fqdn,
identifier: &str,
credentials: AccountCredentials<'_>,
) -> Result<(String, Certificate), AcmeClientError> {
let fqdn = fqdn.to_string();
trace!(fqdn, "requesting acme certificate");
) -> Result<(String, String), AcmeClientError> {
trace!(identifier, "requesting acme certificate");

let account = Account::from_credentials(credentials).map_err(|error| {
error!(
Expand All @@ -107,7 +106,7 @@ impl AcmeClient {

let (mut order, state) = account
.new_order(&NewOrder {
identifiers: &[Identifier::Dns(fqdn.to_string())],
identifiers: &[Identifier::Dns(identifier.to_string())],
})
.await
.map_err(|error| {
Expand Down Expand Up @@ -155,7 +154,7 @@ impl AcmeClient {
AcmeClientError::OrderFinalizing
})?;

Ok((certificate_chain, certificate))
Ok((certificate_chain, certificate.serialize_private_key_pem()))
}

async fn complete_challenge(
Expand Down Expand Up @@ -282,12 +281,12 @@ pub struct ChallengeResponder<S> {

impl<'r, S> AsResponderTo<&'r AddrStream> for ChallengeResponder<S>
where
S: AsResponderTo<&'r AddrStream>
S: AsResponderTo<&'r AddrStream>,
{
fn as_responder_to(&self, req: &'r AddrStream) -> Self {
Self {
client: self.client.clone(),
inner: self.inner.as_responder_to(req)
inner: self.inner.as_responder_to(req),
}
}
}
Expand Down
17 changes: 10 additions & 7 deletions gateway/src/api/latest.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::io::Cursor;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
Expand All @@ -22,7 +23,7 @@ use tracing::{debug, debug_span, field, Span};
use crate::acme::AcmeClient;
use crate::auth::{Admin, ScopedUser, User};
use crate::task::{self, BoxedTask};
use crate::tls::GatewayCertResolver;
use crate::tls::{GatewayCertResolver};
use crate::worker::WORKER_QUEUE_SIZE;
use crate::{AccountName, Error, GatewayService, ProjectName};

Expand Down Expand Up @@ -215,15 +216,17 @@ async fn request_acme_certificate(
.parse()
.map_err(|_err| Error::from(ErrorKind::InvalidCustomDomain))?;

let (chain, async_keys) = acme_client.create_certificate(&fqdn, credentials).await?;
let private_key = async_keys.serialize_private_key_pem();
let (certs, private_key) = acme_client.create_certificate(&fqdn.to_string(), credentials).await?;

resolver
.serve_pem(&fqdn.to_string(), chain.as_bytes(), private_key.as_bytes())
service
.create_custom_domain(project_name, &fqdn, &certs, &private_key)
.await?;

service
.create_custom_domain(project_name, &fqdn, &chain, &private_key)
let mut buf = Vec::new();
buf.extend(certs.as_bytes());
buf.extend(private_key.as_bytes());
resolver
.serve_pem(&fqdn.to_string(), Cursor::new(buf))
.await?;

Ok("certificate created".to_string())
Expand Down
8 changes: 4 additions & 4 deletions gateway/src/args.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::net::SocketAddr;
use std::{net::SocketAddr, path::PathBuf};

use clap::{Parser, Subcommand, ValueEnum};
use fqdn::FQDN;
Expand All @@ -7,9 +7,9 @@ use crate::auth::Key;

#[derive(Parser, Debug)]
pub struct Args {
/// Uri to the `.sqlite` file used to store state
#[arg(long, default_value = "./gateway.sqlite")]
pub state: String,
/// Where to store gateway state (such as sqlite state, and certs)
#[arg(long, default_value = "./")]
pub state: PathBuf,

#[command(subcommand)]
pub command: Commands,
Expand Down
6 changes: 6 additions & 0 deletions gateway/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ impl<T> From<SendError<T>> for Error {
}
}

impl From<io::Error> for Error {
fn from(_: io::Error) -> Self {
Self::from(ErrorKind::Internal)
}
}

impl From<AcmeClientError> for Error {
fn from(error: AcmeClientError) -> Self {
Self::source(ErrorKind::Internal, error)
Expand Down
70 changes: 60 additions & 10 deletions gateway/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use clap::Parser;
use fqdn::FQDN;
use futures::prelude::*;
use instant_acme::AccountCredentials;
use opentelemetry::global;
use shuttle_gateway::acme::AcmeClient;
use shuttle_gateway::api::latest::ApiBuilder;
Expand All @@ -9,12 +11,12 @@ use shuttle_gateway::auth::Key;
use shuttle_gateway::proxy::UserServiceBuilder;
use shuttle_gateway::service::{GatewayService, MIGRATIONS};
use shuttle_gateway::task;
use shuttle_gateway::tls::make_tls_acceptor;
use shuttle_gateway::tls::{make_tls_acceptor, ChainAndPrivateKey};
use shuttle_gateway::worker::Worker;
use sqlx::migrate::MigrateDatabase;
use sqlx::{query, Sqlite, SqlitePool};
use std::io;
use std::path::Path;
use std::io::{self, Cursor};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use tracing::{debug, error, info, trace, warn};
Expand Down Expand Up @@ -45,8 +47,11 @@ async fn main() -> io::Result<()> {
.with(opentelemetry)
.init();

if !Path::new(&args.state).exists() {
Sqlite::create_database(&args.state).await.unwrap();
let db_path = args.state.join("gateway.sqlite");
let db_uri = db_path.to_str().unwrap();

if !db_path.exists() {
Sqlite::create_database(db_uri).await.unwrap();
}

info!(
Expand All @@ -55,17 +60,17 @@ async fn main() -> io::Result<()> {
.unwrap()
.to_string_lossy()
);
let db = SqlitePool::connect(&args.state).await.unwrap();
let db = SqlitePool::connect(db_uri).await.unwrap();

MIGRATIONS.run(&db).await.unwrap();

match args.command {
Commands::Start(start_args) => start(db, start_args).await,
Commands::Start(start_args) => start(db, args.state, start_args).await,
Commands::Init(init_args) => init(db, init_args).await,
}
}

async fn start(db: SqlitePool, args: StartArgs) -> io::Result<()> {
async fn start(db: SqlitePool, fs: PathBuf, args: StartArgs) -> io::Result<()> {
let gateway = Arc::new(GatewayService::init(args.context.clone(), db).await);

let worker = Worker::new();
Expand Down Expand Up @@ -128,7 +133,7 @@ async fn start(db: SqlitePool, args: StartArgs) -> io::Result<()> {

let mut user_builder = UserServiceBuilder::new()
.with_service(Arc::clone(&gateway))
.with_public(args.context.proxy_fqdn)
.with_public(args.context.proxy_fqdn.clone())
.with_user_proxy_binding_to(args.user)
.with_bouncer(args.bouncer);

Expand All @@ -139,7 +144,13 @@ async fn start(db: SqlitePool, args: StartArgs) -> io::Result<()> {
.with_acme(acme_client.clone())
.with_tls(tls_acceptor);

api_builder = api_builder.with_acme(acme_client, resolver);
api_builder = api_builder.with_acme(acme_client.clone(), resolver.clone());

tokio::spawn(async move {
// make sure we have a certificate for ourselves
let certs = init_certs(fs, args.context.proxy_fqdn.clone(), acme_client.clone()).await;
resolver.serve_default_der(certs).await.unwrap();
});
} else {
warn!("TLS is disabled in the proxy service. This is only acceptable in testing, and should *never* be used in deployments.");
};
Expand Down Expand Up @@ -179,3 +190,42 @@ async fn init(db: SqlitePool, args: InitArgs) -> io::Result<()> {
println!("`{}` created as super user with key: {key}", args.name);
Ok(())
}

async fn init_certs<P: AsRef<Path>>(fs: P, public: FQDN, acme: AcmeClient) -> ChainAndPrivateKey {
let tls_path = fs.as_ref().join("ssl.pem").canonicalize().unwrap();

match ChainAndPrivateKey::load_pem(&tls_path) {
Ok(valid) => valid,
Err(_) => {
let creds_path = fs.as_ref().join("acme.json").canonicalize().unwrap();
warn!(
"no valid certificate found at {}, creating one...",
tls_path.display()
);

if !creds_path.exists() {
panic!(
"no ACME credentials found at {}, cannot continue with certificate creation",
creds_path.display()
);
}

let creds = std::fs::File::open(creds_path).unwrap();
let creds: AccountCredentials = serde_json::from_reader(&creds).unwrap();

let identifier = format!("*.{public}");

let (chain, private_key) = acme.create_certificate(&identifier, creds).await.unwrap();

let mut buf = Vec::new();
buf.extend(chain.as_bytes());
buf.extend(private_key.as_bytes());

let certs = ChainAndPrivateKey::parse_pem(Cursor::new(buf)).unwrap();

certs.clone().save_pem(&tls_path).unwrap();

certs
}
}
}
6 changes: 3 additions & 3 deletions gateway/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -454,13 +454,13 @@ impl GatewayService {
&self,
project_name: ProjectName,
fqdn: &Fqdn,
certificate: &str,
private_key: &str,
certs: &str,
private_key: &str
) -> Result<(), Error> {
query("INSERT INTO custom_domains (fqdn, project_name, certificate, private_key) VALUES (?1, ?2, ?3, ?4)")
.bind(fqdn.to_string())
.bind(&project_name)
.bind(certificate)
.bind(certs)
.bind(private_key)
.execute(&self.db)
.await
Expand Down
Loading

0 comments on commit 2520411

Please sign in to comment.