From a6e28349631bcd0237d8675d3a85dfbd1aa06a74 Mon Sep 17 00:00:00 2001 From: James Newton Date: Wed, 4 Dec 2024 11:53:42 -0800 Subject: [PATCH] Add support for namespace to the remote connection builder --- libsql-server/src/http/user/db_factory.rs | 7 ++++ libsql-server/tests/embedded_replica/mod.rs | 40 +++++++++++++++++++++ libsql/src/database.rs | 7 +++- libsql/src/database/builder.rs | 22 +++++++++--- libsql/src/hrana/hyper.rs | 27 +++++++++++--- libsql/src/local/database.rs | 3 +- 6 files changed, 95 insertions(+), 11 deletions(-) diff --git a/libsql-server/src/http/user/db_factory.rs b/libsql-server/src/http/user/db_factory.rs index 2a36024d5c..2a7c4c5752 100644 --- a/libsql-server/src/http/user/db_factory.rs +++ b/libsql-server/src/http/user/db_factory.rs @@ -50,6 +50,8 @@ pub fn namespace_from_headers( if let Some(from_metadata) = headers.get(NAMESPACE_METADATA_KEY) { try_namespace_from_metadata(from_metadata) + } else if let Some(from_ns_header) = headers.get("x-namespace") { + try_namespace_from_header(from_ns_header) } else if let Some(from_host) = headers.get("host") { try_namespace_from_host(from_host, disable_default_namespace) } else if !disable_default_namespace { @@ -59,6 +61,11 @@ pub fn namespace_from_headers( } } +fn try_namespace_from_header(header: &axum::http::HeaderValue) -> Result { + NamespaceName::from_bytes(header.as_bytes().to_vec().into()) + .map_err(|_| Error::InvalidNamespace) +} + fn try_namespace_from_host( from_host: &axum::http::HeaderValue, disable_default_namespace: bool, diff --git a/libsql-server/tests/embedded_replica/mod.rs b/libsql-server/tests/embedded_replica/mod.rs index e7b4b9f7f0..0c4a1c00a0 100644 --- a/libsql-server/tests/embedded_replica/mod.rs +++ b/libsql-server/tests/embedded_replica/mod.rs @@ -1696,3 +1696,43 @@ fn schema_db() { sim.run().unwrap(); } + +#[test] +fn remote_namespace_header_support() { + let tmp_host = tempdir().unwrap(); + let tmp_host_path = tmp_host.path().to_owned(); + + let mut sim = Builder::new() + .simulation_duration(Duration::from_secs(1000)) + .build(); + + make_primary(&mut sim, tmp_host_path.clone()); + + sim.client("client", async move { + let client = Client::new(); + + client + .post("http://primary:9090/v1/namespaces/foo/create", json!({})) + .await?; + + let db_url = "http://primary:8080"; + + let remote = libsql::Builder::new_remote(db_url.to_string(), String::new()) + .namespace("foo") + .connector(TurmoilConnector) + .build() + .await + .unwrap(); + + let conn = remote.connect().unwrap(); + + conn.execute("CREATE TABLE user (id INTEGER NOT NULL PRIMARY KEY)", ()) + .await?; + + conn.execute("INSERT into user(id) values (1);", ()).await?; + + Ok(()) + }); + + sim.run().unwrap(); +} diff --git a/libsql/src/database.rs b/libsql/src/database.rs index 5e0fa328eb..5bf7067442 100644 --- a/libsql/src/database.rs +++ b/libsql/src/database.rs @@ -95,6 +95,7 @@ enum DbType { auth_token: String, connector: crate::util::ConnectorService, version: Option, + namespace: Option, }, } @@ -238,6 +239,7 @@ cfg_replication! { OpenFlags::default(), encryption_config.clone(), None, + None, ).await?; Ok(Database { @@ -522,6 +524,7 @@ cfg_remote! { auth_token: auth_token.into(), connector: crate::util::ConnectorService::new(svc), version, + namespace: None, }, max_write_replication_index: Default::default(), }) @@ -672,7 +675,7 @@ impl Database { remote: HttpConnection::new( url.clone(), auth_token.clone(), - HttpSender::new(connector.clone(), None), + HttpSender::new(connector.clone(), None, None), ), read_your_writes: *read_your_writes, context: db.sync_ctx.clone().unwrap(), @@ -693,6 +696,7 @@ impl Database { auth_token, connector, version, + namespace, } => { let conn = std::sync::Arc::new( crate::hrana::connection::HttpConnection::new_with_connector( @@ -700,6 +704,7 @@ impl Database { auth_token, connector.clone(), version.as_ref().map(|s| s.as_str()), + namespace.as_ref().map(|s| s.as_str()), ), ); diff --git a/libsql/src/database/builder.rs b/libsql/src/database/builder.rs index 14df79ad6d..06c8e523e1 100644 --- a/libsql/src/database/builder.rs +++ b/libsql/src/database/builder.rs @@ -60,12 +60,12 @@ impl Builder<()> { auth_token, connector: None, version: None, + namespace: None, }, encryption_config: None, read_your_writes: true, sync_interval: None, http_request_callback: None, - namespace: None, skip_safety_assert: false, }, } @@ -101,6 +101,7 @@ impl Builder<()> { auth_token, connector: None, version: None, + namespace: None, }, connector: None, read_your_writes: true, @@ -119,6 +120,7 @@ impl Builder<()> { auth_token, connector: None, version: None, + namespace: None, }, } } @@ -132,6 +134,7 @@ cfg_replication_or_remote_or_sync! { auth_token: String, connector: Option, version: Option, + namespace: Option, } } @@ -220,7 +223,6 @@ cfg_replication! { read_your_writes: bool, sync_interval: Option, http_request_callback: Option, - namespace: Option, skip_safety_assert: bool, } @@ -286,7 +288,7 @@ cfg_replication! { /// Set the namespace that will be communicated to remote replica in the http header. pub fn namespace(mut self, namespace: impl Into) -> Builder { - self.inner.namespace = Some(namespace.into()); + self.inner.remote.namespace = Some(namespace.into()); self } @@ -320,12 +322,12 @@ cfg_replication! { auth_token, connector, version, + namespace, }, encryption_config, read_your_writes, sync_interval, http_request_callback, - namespace, skip_safety_assert } = self.inner; @@ -420,6 +422,7 @@ cfg_replication! { auth_token, connector, version, + namespace, }) = remote { let connector = if let Some(connector) = connector { @@ -444,6 +447,7 @@ cfg_replication! { flags, encryption_config.clone(), http_request_callback, + namespace, ) .await? } else { @@ -509,6 +513,7 @@ cfg_sync! { auth_token, connector: _, version: _, + namespace: _, }, connector, remote_writes, @@ -574,6 +579,13 @@ cfg_remote! { self } + /// Set the namespace that will be communicated to the remote in the http header. + pub fn namespace(mut self, namespace: impl Into) -> Builder + { + self.inner.namespace = Some(namespace.into()); + self + } + /// Build the remote database client. pub async fn build(self) -> Result { let Remote { @@ -581,6 +593,7 @@ cfg_remote! { auth_token, connector, version, + namespace, } = self.inner; let connector = if let Some(connector) = connector { @@ -602,6 +615,7 @@ cfg_remote! { auth_token, connector, version, + namespace, }, max_write_replication_index: Default::default(), }) diff --git a/libsql/src/hrana/hyper.rs b/libsql/src/hrana/hyper.rs index 536a371765..a433383212 100644 --- a/libsql/src/hrana/hyper.rs +++ b/libsql/src/hrana/hyper.rs @@ -25,17 +25,27 @@ pub type ByteStream = Box> + Send + Syn pub struct HttpSender { inner: hyper::Client, version: HeaderValue, + namespace: Option, } impl HttpSender { - pub fn new(connector: ConnectorService, version: Option<&str>) -> Self { + pub fn new( + connector: ConnectorService, + version: Option<&str>, + namespace: Option<&str>, + ) -> Self { let ver = version.unwrap_or(env!("CARGO_PKG_VERSION")); let version = HeaderValue::try_from(format!("libsql-remote-{ver}")).unwrap(); + let namespace = namespace.map(|v| HeaderValue::try_from(v).unwrap()); let inner = hyper::Client::builder().build(connector); - Self { inner, version } + Self { + inner, + version, + namespace, + } } async fn send( @@ -44,9 +54,15 @@ impl HttpSender { auth: Arc, body: String, ) -> Result> { - let req = hyper::Request::post(url.as_ref()) + let mut req_builder = hyper::Request::post(url.as_ref()) .header(AUTHORIZATION, auth.as_ref()) - .header("x-libsql-client-version", self.version.clone()) + .header("x-libsql-client-version", self.version.clone()); + + if let Some(namespace) = self.namespace { + req_builder = req_builder.header("x-namespace", namespace); + } + + let req = req_builder .body(hyper::Body::from(body)) .map_err(|err| HranaError::Http(format!("{:?}", err)))?; @@ -108,8 +124,9 @@ impl HttpConnection { token: impl Into, connector: ConnectorService, version: Option<&str>, + namespace: Option<&str>, ) -> Self { - let inner = HttpSender::new(connector, version); + let inner = HttpSender::new(connector, version, namespace); Self::new(url.into(), token.into(), inner) } } diff --git a/libsql/src/local/database.rs b/libsql/src/local/database.rs index 3b157e715d..48d1f62514 100644 --- a/libsql/src/local/database.rs +++ b/libsql/src/local/database.rs @@ -266,6 +266,7 @@ impl Database { flags: OpenFlags, encryption_config: Option, http_request_callback: Option, + namespace: Option, ) -> Result { use std::path::PathBuf; @@ -284,7 +285,7 @@ impl Database { auth_token, version.as_deref(), http_request_callback, - None, + namespace, ) .map_err(|e| crate::Error::Replication(e.into()))?;