diff --git a/.factory/automation.yml b/.factory/automation.yml index 507c3e93..898de4cf 100644 --- a/.factory/automation.yml +++ b/.factory/automation.yml @@ -23,7 +23,9 @@ config: version-candidate: VERSION dependencies: dependencies: [build] + typedb-common: [build] typedb-protocol: [build, release] + typeql: [build, release] build: quality: @@ -53,7 +55,7 @@ build: bazel run @vaticle_dependencies//distribution/artifact:create-netrc bazel build //... tools/start-core-server.sh - bazel test //tests:queries_core --test_arg=-- --test_arg=--test-threads=1 --test_output=streamed && export TEST_SUCCESS=0 || export TEST_SUCCESS=1 + bazel test //tests:queries --test_arg=-- --test_arg=core --test_arg=--test-threads=1 --test_output=streamed && export TEST_SUCCESS=0 || export TEST_SUCCESS=1 tools/stop-core-server.sh exit $TEST_SUCCESS test-integration-cluster: @@ -65,9 +67,20 @@ build: bazel run @vaticle_dependencies//distribution/artifact:create-netrc bazel build //... source tools/start-cluster-servers.sh # use source to receive export vars - bazel test //tests:queries_cluster --test_env=ROOT_CA=$ROOT_CA --test_arg=-- --test_arg=--test-threads=1 --test_output=streamed && export TEST_SUCCESS=0 || export TEST_SUCCESS=1 + bazel test //tests:queries --test_env=ROOT_CA=$ROOT_CA --test_arg=-- --test_arg=cluster --test_arg=--test-threads=1 --test_output=streamed && export TEST_SUCCESS=0 || export TEST_SUCCESS=1 tools/stop-cluster-servers.sh exit $TEST_SUCCESS + test-integration-runtimes: + image: vaticle-ubuntu-22.04 + command: | + export ARTIFACT_USERNAME=$REPO_VATICLE_USERNAME + export ARTIFACT_PASSWORD=$REPO_VATICLE_PASSWORD + bazel run @vaticle_dependencies//distribution/artifact:create-netrc + bazel build //... + tools/start-core-server.sh + bazel test //tests:runtimes --test_arg=-- --test_arg=--test-threads=1 --test_output=streamed && export TEST_SUCCESS=0 || export TEST_SUCCESS=1 + tools/stop-core-server.sh + exit $TEST_SUCCESS deploy-crate-snapshot: filter: owner: vaticle diff --git a/BUILD b/BUILD index 5de007c8..c192ac19 100644 --- a/BUILD +++ b/BUILD @@ -31,45 +31,47 @@ load("//:deployment.bzl", deployment_github = "deployment") rust_library( name = "typedb_client", srcs = glob(["src/**/*.rs"]), + tags = ["crate-name=typedb-client"], deps = [ + "@crates//:chrono", + "@crates//:crossbeam", + "@crates//:futures", + "@crates//:http", + "@crates//:itertools", + "@crates//:log", + "@crates//:prost", + "@crates//:tokio", + "@crates//:tokio-stream", + "@crates//:tonic", + "@crates//:uuid", "@vaticle_typedb_protocol//grpc/rust:typedb_protocol", "@vaticle_typeql//rust:typeql_lang", - - "@vaticle_dependencies//library/crates:chrono", - "@vaticle_dependencies//library/crates:crossbeam", - "@vaticle_dependencies//library/crates:futures", - "@vaticle_dependencies//library/crates:log", - "@vaticle_dependencies//library/crates:prost", - "@vaticle_dependencies//library/crates:tokio", - "@vaticle_dependencies//library/crates:tonic", - "@vaticle_dependencies//library/crates:uuid", ], - tags = ["crate-name=typedb-client"], ) assemble_crate( name = "assemble_crate", - target = "typedb_client", description = "TypeDB Client API for Rust", homepage = "https://github.com/vaticle/typedb-client-rust", license = "Apache-2.0", repository = "https://github.com/vaticle/typedb-client-rust", + target = "typedb_client", ) deploy_crate( name = "deploy_crate", - target = ":assemble_crate", + release = deployment["crate.release"], snapshot = deployment["crate.snapshot"], - release = deployment["crate.release"] + target = ":assemble_crate", ) deploy_github( name = "deploy_github", draft = True, - title = "TypeDB Client Rust", - release_description = "//:RELEASE_TEMPLATE.md", organisation = deployment_github["github.organisation"], + release_description = "//:RELEASE_TEMPLATE.md", repository = deployment_github["github.repository"], + title = "TypeDB Client Rust", title_append_version = True, ) @@ -105,13 +107,13 @@ filegroup( rustfmt_test( name = "client_rustfmt_test", - targets = ["typedb_client"] + targets = ["typedb_client"], ) # CI targets that are not declared in any BUILD file, but are called externally filegroup( name = "ci", data = [ - "@vaticle_dependencies//ide/rust:sync" + "@vaticle_dependencies//tool/cargo:sync", ], ) diff --git a/WORKSPACE b/WORKSPACE index 86111ef2..dbb04ef5 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -49,8 +49,10 @@ load("@rules_rust//rust:repositories.bzl", "rules_rust_dependencies", "rust_regi rules_rust_dependencies() rust_register_toolchains(edition = "2021", include_rustc_srcs = True) -load("@vaticle_dependencies//library/crates:crates.bzl", "raze_fetch_remote_crates") -raze_fetch_remote_crates() +load("@vaticle_dependencies//library/crates:crates.bzl", "fetch_crates") +fetch_crates() +load("@crates//:defs.bzl", "crate_repositories") +crate_repositories() # Load //builder/python load("@vaticle_dependencies//builder/python:deps.bzl", python_deps = "deps") diff --git a/dependencies/ide/BUILD b/dependencies/ide/rust/BUILD similarity index 100% rename from dependencies/ide/BUILD rename to dependencies/ide/rust/BUILD diff --git a/dependencies/ide/sync.sh b/dependencies/ide/rust/sync.sh similarity index 94% rename from dependencies/ide/sync.sh rename to dependencies/ide/rust/sync.sh index f167a274..0603f4a8 100755 --- a/dependencies/ide/sync.sh +++ b/dependencies/ide/rust/sync.sh @@ -20,4 +20,4 @@ # under the License. # -bazel run @vaticle_dependencies//ide/rust:sync +bazel run @vaticle_dependencies//tool/cargo:sync diff --git a/dependencies/vaticle/repositories.bzl b/dependencies/vaticle/repositories.bzl index 3f1d763e..72a49da2 100644 --- a/dependencies/vaticle/repositories.bzl +++ b/dependencies/vaticle/repositories.bzl @@ -25,26 +25,26 @@ def vaticle_dependencies(): git_repository( name = "vaticle_dependencies", remote = "https://github.com/vaticle/dependencies", - commit = "d76a7b935cd6452615f78772539fbc2e1228f503", # sync-marker: do not remove this comment, this is used for sync-dependencies by @vaticle_dependencies + commit = "76636b1672b04e9880439395b8913231724ae459", # sync-marker: do not remove this comment, this is used for sync-dependencies by @vaticle_dependencies ) def vaticle_typedb_common(): git_repository( name = "vaticle_typedb_common", remote = "https://github.com/vaticle/typedb-common", - tag = "2.12.0" # sync-marker: do not remove this comment, this is used for sync-dependencies by @vaticle_typedb_common + commit = "aa03cb5f6a57ec2a51291b7a0510734ca1f41479" # sync-marker: do not remove this comment, this is used for sync-dependencies by @vaticle_typedb_common ) def vaticle_typedb_protocol(): git_repository( name = "vaticle_typedb_protocol", remote = "https://github.com/vaticle/typedb-protocol", - commit = "16d1fb6749c0fee85843ca67f470015dda9fc497", # sync-marker: do not remove this comment, this is used for sync-dependencies by @vaticle_dependencies + commit = "b1c19e02054c1a1d354b42875e6ccd67602a546f", # sync-marker: do not remove this comment, this is used for sync-dependencies by @vaticle_dependencies ) def vaticle_typeql(): git_repository( name = "vaticle_typeql", remote = "https://github.com/vaticle/typeql", - commit = "776643fb6f0c754730e55230733fd2326f32cd39", # sync-marker: do not remove this comment, this is used for sync-dependencies by @vaticle_dependencies + commit = "7a63699b3879296ae3039577ba3f5220bbf6d33d", # sync-marker: do not remove this comment, this is used for sync-dependencies by @vaticle_dependencies ) diff --git a/rustfmt.toml b/rustfmt.toml index 1fe0cdaa..2bfc3860 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -23,3 +23,4 @@ imports_granularity = "Crate" group_imports = "StdExternalCrate" use_small_heuristics = "Max" +max_width = 120 diff --git a/src/answer/concept_map.rs b/src/answer/concept_map.rs index bebc9376..8d665ba7 100644 --- a/src/answer/concept_map.rs +++ b/src/answer/concept_map.rs @@ -24,7 +24,7 @@ use std::{ ops::Index, }; -use crate::{common::Result, concept::Concept}; +use crate::concept::Concept; #[derive(Debug)] pub struct ConceptMap { @@ -32,14 +32,6 @@ pub struct ConceptMap { } impl ConceptMap { - pub(crate) fn from_proto(proto: typedb_protocol::ConceptMap) -> Result { - let mut map = HashMap::with_capacity(proto.map.len()); - for (k, v) in proto.map { - map.insert(k, Concept::from_proto(v)?); - } - Ok(Self { map }) - } - pub fn get(&self, var_name: &str) -> Option<&Concept> { self.map.get(var_name) } diff --git a/src/answer/numeric.rs b/src/answer/numeric.rs index 53d7b54e..c1771b1d 100644 --- a/src/answer/numeric.rs +++ b/src/answer/numeric.rs @@ -19,10 +19,6 @@ * under the License. */ -use typedb_protocol::numeric::Value; - -use crate::common::{Error, Result}; - #[derive(Clone, Debug)] pub enum Numeric { Long(i64), @@ -48,18 +44,6 @@ impl Numeric { } } -impl TryFrom for Numeric { - type Error = Error; - - fn try_from(value: typedb_protocol::Numeric) -> Result { - match value.value.unwrap() { - Value::LongValue(long) => Ok(Numeric::Long(long)), - Value::DoubleValue(double) => Ok(Numeric::Double(double)), - Value::Nan(_) => Ok(Numeric::NaN), - } - } -} - impl From for i64 { fn from(n: Numeric) -> Self { n.into_i64() diff --git a/src/common/address.rs b/src/common/address.rs index 8266d519..64201efc 100644 --- a/src/common/address.rs +++ b/src/common/address.rs @@ -21,12 +21,12 @@ use std::{fmt, str::FromStr}; -use tonic::transport::Uri; +use http::Uri; use crate::common::{Error, Result}; #[derive(Clone, Debug, Hash, PartialEq, Eq)] -pub struct Address { +pub(crate) struct Address { uri: Uri, } @@ -43,7 +43,7 @@ impl FromStr for Address { let uri = if address.contains("://") { address.parse::()? } else { - format!("http://{}", address).parse::()? + format!("http://{address}").parse::()? }; Ok(Self { uri }) } diff --git a/src/common/credential.rs b/src/common/credential.rs index 14d90d8a..1f90bdbb 100644 --- a/src/common/credential.rs +++ b/src/common/credential.rs @@ -19,35 +19,34 @@ * under the License. */ -use std::{ - fs, - path::{Path, PathBuf}, - sync::RwLock, -}; +use std::{fmt, fs, path::Path}; -use tonic::{ - transport::{Certificate, ClientTlsConfig}, - Request, -}; +use tonic::transport::{Certificate, ClientTlsConfig}; use crate::Result; -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct Credential { username: String, password: String, is_tls_enabled: bool, - tls_root_ca: Option, + tls_config: Option, } impl Credential { - pub fn with_tls(username: &str, password: &str, tls_root_ca: Option<&Path>) -> Self { - Credential { + pub fn with_tls(username: &str, password: &str, tls_root_ca: Option<&Path>) -> Result { + let tls_config = Some(if let Some(tls_root_ca) = tls_root_ca { + ClientTlsConfig::new().ca_certificate(Certificate::from_pem(fs::read_to_string(tls_root_ca)?)) + } else { + ClientTlsConfig::new() + }); + + Ok(Credential { username: username.to_owned(), password: password.to_owned(), is_tls_enabled: true, - tls_root_ca: tls_root_ca.map(Path::to_owned), - } + tls_config, + }) } pub fn without_tls(username: &str, password: &str) -> Self { @@ -55,7 +54,7 @@ impl Credential { username: username.to_owned(), password: password.to_owned(), is_tls_enabled: false, - tls_root_ca: None, + tls_config: None, } } @@ -71,51 +70,17 @@ impl Credential { self.is_tls_enabled } - pub fn tls_config(&self) -> Result { - if let Some(ref tls_root_ca) = self.tls_root_ca { - Ok(ClientTlsConfig::new() - .ca_certificate(Certificate::from_pem(fs::read_to_string(tls_root_ca)?))) - } else { - Ok(ClientTlsConfig::new()) - } + pub fn tls_config(&self) -> &Option { + &self.tls_config } } -#[derive(Debug)] -pub(crate) struct CallCredentials { - credential: Credential, - token: RwLock>, -} - -impl CallCredentials { - pub(super) fn new(credential: Credential) -> Self { - Self { credential, token: RwLock::new(None) } - } - - pub(super) fn username(&self) -> &str { - self.credential.username() - } - - pub(super) fn password(&self) -> &str { - self.credential.password() - } - - pub(super) fn set_token(&self, token: String) { - *self.token.write().unwrap() = Some(token); - } - - pub(super) fn reset_token(&self) { - *self.token.write().unwrap() = None; - } - - pub(super) fn inject(&self, mut request: Request<()>) -> Request<()> { - request.metadata_mut().insert("username", self.credential.username().try_into().unwrap()); - match &*self.token.read().unwrap() { - Some(token) => request.metadata_mut().insert("token", token.try_into().unwrap()), - None => request - .metadata_mut() - .insert("password", self.credential.password().try_into().unwrap()), - }; - request +impl fmt::Debug for Credential { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Credential") + .field("username", &self.username) + .field("is_tls_enabled", &self.is_tls_enabled) + .field("tls_config", &self.tls_config) + .finish() } } diff --git a/src/common/error.rs b/src/common/error.rs index bfadbbb1..d345243e 100644 --- a/src/common/error.rs +++ b/src/common/error.rs @@ -24,8 +24,12 @@ use std::{error::Error as StdError, fmt}; use tonic::{Code, Status}; use typeql_lang::error_messages; -error_messages! { ClientError - code: "CLI", type: "Client Error", +use crate::common::RequestID; + +error_messages! { ConnectionError + code: "CXN", type: "Connection Error", + ConnectionIsClosed() = + 1: "The connection has been closed and no further operation is allowed.", SessionIsClosed() = 2: "The session is closed and no further operation is allowed.", TransactionIsClosed() = @@ -38,7 +42,7 @@ error_messages! { ClientError 8: "The database '{}' does not exist.", MissingResponseField(&'static str) = 9: "Missing field in message received from server: '{}'.", - UnknownRequestId(String) = + UnknownRequestId(RequestID) = 10: "Received a response with unknown request id '{}'", ClusterUnableToConnect(String) = 12: "Unable to connect to TypeDB Cluster. Attempted connecting to the cluster members, but none are available: '{}'.", @@ -52,23 +56,35 @@ error_messages! { ClientError 17: "Failed to close session. It may still be open on the server: or it may already have been closed previously.", } -#[derive(Debug, PartialEq, Eq)] -pub enum Error { - Client(ClientError), - Other(String), +error_messages! { InternalError + code: "INT", type: "Internal Error", + RecvError() = + 1: "Channel is closed.", + SendError() = + 2: "Channel is closed.", + UnexpectedRequestType(String) = + 3: "Unexpected request type for remote procedure call: {}.", + UnexpectedResponseType(String) = + 4: "Unexpected response type for remote procedure call: {}.", + UnknownConnectionAddress(String) = + 5: "Received unrecognized address from the server: {}.", + EnumOutOfBounds(i32, &'static str) = + 6: "Value '{}' is out of bounds for enum '{}'.", } -impl Error { - pub(crate) fn new(msg: String) -> Self { - Error::Other(msg) - } +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Error { + Connection(ConnectionError), + Internal(InternalError), + Other(String), } impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Error::Client(error) => write!(f, "{}", error), - Error::Other(message) => write!(f, "{}", message), + Error::Connection(error) => write!(f, "{error}"), + Error::Internal(error) => write!(f, "{error}"), + Error::Other(message) => write!(f, "{message}"), } } } @@ -76,15 +92,36 @@ impl fmt::Display for Error { impl StdError for Error { fn source(&self) -> Option<&(dyn StdError + 'static)> { match self { - Error::Client(error) => Some(error), + Error::Connection(error) => Some(error), + Error::Internal(error) => Some(error), Error::Other(_) => None, } } } -impl From for Error { - fn from(error: ClientError) -> Self { - Error::Client(error) +impl From for Error { + fn from(error: ConnectionError) -> Self { + Error::Connection(error) + } +} + +impl From for Error { + fn from(error: InternalError) -> Self { + Error::Internal(error) + } +} + +impl From for Error { + fn from(status: Status) -> Self { + if is_rst_stream(&status) { + Self::Connection(ConnectionError::UnableToConnect()) + } else if is_replica_not_primary(&status) { + Self::Connection(ConnectionError::ClusterReplicaNotPrimary()) + } else if is_token_credential_invalid(&status) { + Self::Connection(ConnectionError::ClusterTokenCredentialInvalid()) + } else { + Self::Other(status.message().to_string()) + } } } @@ -103,22 +140,14 @@ fn is_token_credential_invalid(status: &Status) -> bool { status.code() == Code::Unauthenticated && status.message().contains("[CLS08]") } -impl From for Error { - fn from(status: Status) -> Self { - if is_rst_stream(&status) { - Self::Client(ClientError::UnableToConnect()) - } else if is_replica_not_primary(&status) { - Self::Client(ClientError::ClusterReplicaNotPrimary()) - } else if is_token_credential_invalid(&status) { - Self::Client(ClientError::ClusterTokenCredentialInvalid()) - } else { - Self::Other(status.message().to_string()) - } +impl From for Error { + fn from(err: http::uri::InvalidUri) -> Self { + Error::Other(err.to_string()) } } -impl From for Error { - fn from(err: futures::channel::mpsc::SendError) -> Self { +impl From for Error { + fn from(err: tonic::transport::Error) -> Self { Error::Other(err.to_string()) } } @@ -129,15 +158,27 @@ impl From> for Error { } } -impl From for Error { - fn from(err: tonic::codegen::http::uri::InvalidUri) -> Self { - Error::Other(err.to_string()) +impl From for Error { + fn from(_err: tokio::sync::oneshot::error::RecvError) -> Self { + Error::Internal(InternalError::RecvError()) } } -impl From for Error { - fn from(err: tonic::transport::Error) -> Self { - Error::Other(err.to_string()) +impl From for Error { + fn from(_err: crossbeam::channel::RecvError) -> Self { + Error::Internal(InternalError::RecvError()) + } +} + +impl From> for Error { + fn from(_err: crossbeam::channel::SendError) -> Self { + Error::Internal(InternalError::SendError()) + } +} + +impl From for Error { + fn from(err: String) -> Self { + Error::Other(err) } } diff --git a/src/common/macros.rs b/src/common/id.rs similarity index 56% rename from src/common/macros.rs rename to src/common/id.rs index 52fff1c3..0c5abf05 100644 --- a/src/common/macros.rs +++ b/src/common/id.rs @@ -19,18 +19,39 @@ * under the License. */ -#[macro_export] -macro_rules! async_enum_dispatch { - { - $variants:tt - $($vis:vis async fn $name:ident(&mut self, $arg:ident : $arg_type:ty $(,)?) -> $res:ty);+ $(;)? - } => { $(async_enum_dispatch!(@impl $variants, $vis, $name, $arg, $arg_type, $res);)+ }; - - (@impl {$($variant:ident),+}, $vis:vis, $name:ident, $arg:ident, $arg_type:ty, $res:ty) => { - $vis async fn $name(&mut self, $arg: $arg_type) -> $res { - match self { - $(Self::$variant(inner) => inner.$name($arg).await,)+ - } - } +use std::fmt; + +use uuid::Uuid; + +#[derive(Clone, Eq, Hash, PartialEq)] +pub struct ID(Vec); + +impl ID { + pub(crate) fn generate() -> Self { + Uuid::new_v4().as_bytes().to_vec().into() + } +} + +impl From for Vec { + fn from(id: ID) -> Self { + id.0 + } +} + +impl From> for ID { + fn from(vec: Vec) -> Self { + Self(vec) + } +} + +impl fmt::Debug for ID { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "ID[{self}]") + } +} + +impl fmt::Display for ID { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.iter().try_for_each(|byte| write!(f, "{byte:02x}")) } } diff --git a/src/common/info.rs b/src/common/info.rs new file mode 100644 index 00000000..b373b30c --- /dev/null +++ b/src/common/info.rs @@ -0,0 +1,45 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::time::Duration; + +use super::{address::Address, SessionID}; + +#[derive(Clone, Debug)] +pub(crate) struct SessionInfo { + pub(crate) address: Address, + pub(crate) session_id: SessionID, + pub(crate) network_latency: Duration, +} + +#[derive(Debug)] +pub(crate) struct DatabaseInfo { + pub(crate) name: String, + pub(crate) replicas: Vec, +} + +#[derive(Debug)] +pub(crate) struct ReplicaInfo { + pub(crate) address: Address, + pub(crate) is_primary: bool, + pub(crate) is_preferred: bool, + pub(crate) term: i64, +} diff --git a/src/common/mod.rs b/src/common/mod.rs index 6fa466c8..35a175a9 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -19,51 +19,29 @@ * under the License. */ -mod address; -pub mod credential; +pub(crate) mod address; +mod credential; pub mod error; -mod macros; -pub(crate) mod rpc; +mod id; +pub(crate) mod info; +mod options; -use tonic::{Response, Status}; -use typedb_protocol::{session as session_proto, transaction as transaction_proto}; - -pub(crate) use self::rpc::{ClusterRPC, ClusterServerRPC, CoreRPC, ServerRPC, TransactionRPC}; -pub use self::{address::Address, credential::Credential, error::Error}; +pub use self::{credential::Credential, error::Error, options::Options}; pub(crate) type StdResult = std::result::Result; pub type Result = StdResult; -pub(crate) type TonicResult = StdResult, Status>; -pub(crate) type TonicChannel = tonic::transport::Channel; -pub(crate) type Executor = futures::executor::ThreadPool; +pub(crate) type RequestID = id::ID; +pub(crate) type SessionID = id::ID; -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, Eq, PartialEq)] pub enum SessionType { Data = 0, Schema = 1, } -impl SessionType { - pub(crate) fn to_proto(self) -> session_proto::Type { - match self { - SessionType::Data => session_proto::Type::Data, - SessionType::Schema => session_proto::Type::Schema, - } - } -} - -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, Eq, PartialEq)] pub enum TransactionType { Read = 0, Write = 1, } - -impl TransactionType { - pub(crate) fn to_proto(self) -> transaction_proto::Type { - match self { - TransactionType::Read => transaction_proto::Type::Read, - TransactionType::Write => transaction_proto::Type::Write, - } - } -} diff --git a/src/common/options.rs b/src/common/options.rs new file mode 100644 index 00000000..176f1e36 --- /dev/null +++ b/src/common/options.rs @@ -0,0 +1,82 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::time::Duration; + +#[derive(Clone, Debug, Default)] +pub struct Options { + pub infer: Option, + pub trace_inference: Option, + pub explain: Option, + pub parallel: Option, + pub prefetch: Option, + pub prefetch_size: Option, + pub session_idle_timeout: Option, + pub transaction_timeout: Option, + pub schema_lock_acquire_timeout: Option, + pub read_any_replica: Option, +} + +impl Options { + pub fn new() -> Self { + Default::default() + } + + pub fn infer(self, infer: bool) -> Self { + Self { infer: Some(infer), ..self } + } + + pub fn trace_inference(self, trace_inference: bool) -> Self { + Self { trace_inference: Some(trace_inference), ..self } + } + + pub fn explain(self, explain: bool) -> Self { + Self { explain: Some(explain), ..self } + } + + pub fn parallel(self, parallel: bool) -> Self { + Self { parallel: Some(parallel), ..self } + } + + pub fn prefetch(self, prefetch: bool) -> Self { + Self { prefetch: Some(prefetch), ..self } + } + + pub fn prefetch_size(self, prefetch_size: i32) -> Self { + Self { prefetch_size: Some(prefetch_size), ..self } + } + + pub fn session_idle_timeout(self, timeout: Duration) -> Self { + Self { session_idle_timeout: Some(timeout), ..self } + } + + pub fn transaction_timeout(self, timeout: Duration) -> Self { + Self { transaction_timeout: Some(timeout), ..self } + } + + pub fn schema_lock_acquire_timeout(self, timeout: Duration) -> Self { + Self { schema_lock_acquire_timeout: Some(timeout), ..self } + } + + pub fn read_any_replica(self, read_any_replica: bool) -> Self { + Self { read_any_replica: Some(read_any_replica), ..self } + } +} diff --git a/src/common/rpc/builder.rs b/src/common/rpc/builder.rs deleted file mode 100644 index 574bb611..00000000 --- a/src/common/rpc/builder.rs +++ /dev/null @@ -1,312 +0,0 @@ -/* - * Copyright (C) 2022 Vaticle - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -pub(crate) mod core { - pub(crate) mod database_manager { - use typedb_protocol::core_database_manager::{all, contains, create}; - - pub(crate) fn contains_req(name: &str) -> contains::Req { - contains::Req { name: name.into() } - } - - pub(crate) fn create_req(name: &str) -> create::Req { - create::Req { name: name.into() } - } - - pub(crate) fn all_req() -> all::Req { - all::Req {} - } - } - - pub(crate) mod database { - use typedb_protocol::core_database::{delete, rule_schema, schema, type_schema}; - - pub(crate) fn delete_req(name: &str) -> delete::Req { - delete::Req { name: name.into() } - } - - pub(crate) fn rule_schema_req(name: &str) -> rule_schema::Req { - rule_schema::Req { name: name.into() } - } - - pub(crate) fn schema_req(name: &str) -> schema::Req { - schema::Req { name: name.into() } - } - - pub(crate) fn type_schema_req(name: &str) -> type_schema::Req { - type_schema::Req { name: name.into() } - } - } -} - -pub(crate) mod cluster { - pub(crate) mod server_manager { - use typedb_protocol::server_manager::all; - - pub(crate) fn all_req() -> all::Req { - all::Req {} - } - } - - pub(crate) mod user_manager { - use typedb_protocol::cluster_user_manager::{all, contains, create}; - - pub(crate) fn contains_req(username: &str) -> contains::Req { - contains::Req { username: username.into() } - } - - pub(crate) fn create_req(username: &str, password: &str) -> create::Req { - create::Req { username: username.into(), password: password.into() } - } - - pub(crate) fn all_req() -> all::Req { - all::Req {} - } - } - - pub(crate) mod user { - use typedb_protocol::cluster_user::{delete, password, token}; - - pub(crate) fn password_req(username: &str, password: &str) -> password::Req { - password::Req { username: username.into(), password: password.into() } - } - - pub(crate) fn token_req(username: &str) -> token::Req { - token::Req { username: username.into() } - } - - pub(crate) fn delete_req(username: &str) -> delete::Req { - delete::Req { username: username.into() } - } - } - - pub(crate) mod database_manager { - use typedb_protocol::cluster_database_manager::{all, get}; - - pub(crate) fn get_req(name: &str) -> get::Req { - get::Req { name: name.into() } - } - - pub(crate) fn all_req() -> all::Req { - all::Req {} - } - } -} - -pub(crate) mod session { - use typedb_protocol::{ - session, - session::{close, open}, - Options, - }; - - pub(crate) fn close_req(session_id: Vec) -> close::Req { - close::Req { session_id } - } - - pub(crate) fn open_req( - database: &str, - session_type: session::Type, - options: Options, - ) -> open::Req { - open::Req { - database: database.into(), - r#type: session_type.into(), - options: options.into(), - } - } -} - -pub(crate) mod transaction { - use typedb_protocol::{ - transaction, - transaction::{commit, open, rollback, stream}, - Options, - }; - use uuid::Uuid; - - pub(crate) fn client_msg(reqs: Vec) -> transaction::Client { - transaction::Client { reqs } - } - - pub(crate) fn stream_req(req_id: Vec) -> transaction::Req { - req_with_id(transaction::req::Req::StreamReq(stream::Req {}), req_id) - } - - pub(crate) fn open_req( - session_id: Vec, - transaction_type: transaction::Type, - options: Options, - network_latency_millis: i32, - ) -> transaction::Req { - req(transaction::req::Req::OpenReq(open::Req { - session_id, - r#type: transaction_type.into(), - options: options.into(), - network_latency_millis, - })) - } - - pub(crate) fn commit_req() -> transaction::Req { - req(transaction::req::Req::CommitReq(commit::Req {})) - } - - pub(crate) fn rollback_req() -> transaction::Req { - req(transaction::req::Req::RollbackReq(rollback::Req {})) - } - - pub(super) fn req(req: transaction::req::Req) -> transaction::Req { - transaction::Req { req_id: new_req_id(), metadata: Default::default(), req: req.into() } - } - - pub(super) fn req_with_id(req: transaction::req::Req, req_id: Vec) -> transaction::Req { - transaction::Req { req_id, metadata: Default::default(), req: req.into() } - } - - fn new_req_id() -> Vec { - Uuid::new_v4().as_bytes().to_vec() - } -} - -#[allow(dead_code)] -pub(crate) mod query_manager { - use typedb_protocol::{ - query_manager, - query_manager::{ - define, delete, explain, insert, match_aggregate, match_group, match_group_aggregate, - r#match, undefine, update, - }, - transaction, - transaction::req::Req::QueryManagerReq, - Options, - }; - - fn query_manager_req(req: query_manager::Req) -> transaction::Req { - super::transaction::req(QueryManagerReq(req)) - } - - pub(crate) fn define_req(query: &str, options: Option) -> transaction::Req { - query_manager_req(query_manager::Req { - options, - req: query_manager::req::Req::DefineReq(define::Req { query: query.to_string() }) - .into(), - }) - } - - pub(crate) fn undefine_req(query: &str, options: Option) -> transaction::Req { - query_manager_req(query_manager::Req { - options, - req: query_manager::req::Req::UndefineReq(undefine::Req { query: query.to_string() }) - .into(), - }) - } - - pub(crate) fn match_req(query: &str, options: Option) -> transaction::Req { - query_manager_req(query_manager::Req { - options, - req: query_manager::req::Req::MatchReq(r#match::Req { query: query.to_string() }) - .into(), - }) - } - - pub(crate) fn match_aggregate_req(query: &str, options: Option) -> transaction::Req { - query_manager_req(query_manager::Req { - options, - req: query_manager::req::Req::MatchAggregateReq(match_aggregate::Req { - query: query.to_string(), - }) - .into(), - }) - } - - pub(crate) fn match_group_req(query: &str, options: Option) -> transaction::Req { - query_manager_req(query_manager::Req { - options, - req: query_manager::req::Req::MatchGroupReq(match_group::Req { - query: query.to_string(), - }) - .into(), - }) - } - - pub(crate) fn match_group_aggregate_req( - query: &str, - options: Option, - ) -> transaction::Req { - query_manager_req(query_manager::Req { - options, - req: query_manager::req::Req::MatchGroupAggregateReq(match_group_aggregate::Req { - query: query.to_string(), - }) - .into(), - }) - } - - pub(crate) fn insert_req(query: &str, options: Option) -> transaction::Req { - query_manager_req(query_manager::Req { - options, - req: query_manager::req::Req::InsertReq(insert::Req { query: query.to_string() }) - .into(), - }) - } - - pub(crate) fn delete_req(query: &str, options: Option) -> transaction::Req { - query_manager_req(query_manager::Req { - options, - req: query_manager::req::Req::DeleteReq(delete::Req { query: query.to_string() }) - .into(), - }) - } - - pub(crate) fn update_req(query: &str, options: Option) -> transaction::Req { - query_manager_req(query_manager::Req { - options, - req: query_manager::req::Req::UpdateReq(update::Req { query: query.to_string() }) - .into(), - }) - } - - pub(crate) fn explain_req(id: i64) -> transaction::Req { - query_manager_req(query_manager::Req { - options: None, - req: query_manager::req::Req::ExplainReq(explain::Req { explainable_id: id }).into(), - }) - } -} - -#[allow(dead_code)] -pub(crate) mod thing { - use typedb_protocol::{ - attribute, thing, thing::req::Req::AttributeGetOwnersReq, transaction, - transaction::req::Req::ThingReq, - }; - - fn thing_req(req: thing::Req) -> transaction::Req { - super::transaction::req(ThingReq(req)) - } - - pub(crate) fn attribute_get_owners_req(iid: &[u8]) -> transaction::Req { - thing_req(thing::Req { - iid: iid.to_vec(), - req: AttributeGetOwnersReq(attribute::get_owners::Req { filter: None }).into(), - }) - } -} diff --git a/src/common/rpc/channel.rs b/src/common/rpc/channel.rs deleted file mode 100644 index a320892f..00000000 --- a/src/common/rpc/channel.rs +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Copyright (C) 2022 Vaticle - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::sync::Arc; - -use tonic::{codegen::InterceptedService, service::Interceptor, Request, Status}; - -use crate::{ - common::{credential::CallCredentials, Address, Credential, TonicChannel}, - Result, -}; - -pub(crate) type CallCredChannel = InterceptedService; - -#[derive(Clone, Debug)] -pub(crate) enum Channel { - Plaintext(TonicChannel), - Encrypted(CallCredChannel), -} - -impl Channel { - pub(crate) fn open_plaintext(address: Address) -> Result { - Ok(Self::Plaintext(TonicChannel::builder(address.into_uri()).connect_lazy())) - } - - pub(crate) fn open_encrypted( - address: Address, - credential: Credential, - ) -> Result<(Self, Arc)> { - let mut builder = TonicChannel::builder(address.into_uri()); - if credential.is_tls_enabled() { - builder = builder.tls_config(credential.tls_config()?)?; - } - - let channel = builder.connect_lazy(); - let call_credentials = Arc::new(CallCredentials::new(credential)); - Ok(( - Self::Encrypted(InterceptedService::new( - channel, - CredentialInjector::new(call_credentials.clone()), - )), - call_credentials, - )) - } -} - -impl From for TonicChannel { - fn from(channel: Channel) -> Self { - match channel { - Channel::Plaintext(channel) => channel, - _ => panic!(), - } - } -} - -impl From for CallCredChannel { - fn from(channel: Channel) -> Self { - match channel { - Channel::Encrypted(channel) => channel, - _ => panic!(), - } - } -} - -#[derive(Clone, Debug)] -pub(crate) struct CredentialInjector { - call_credentials: Arc, -} - -impl CredentialInjector { - fn new(call_credentials: Arc) -> Self { - Self { call_credentials } - } -} - -impl Interceptor for CredentialInjector { - fn call(&mut self, request: Request<()>) -> std::result::Result, Status> { - Ok(self.call_credentials.inject(request)) - } -} diff --git a/src/common/rpc/cluster.rs b/src/common/rpc/cluster.rs deleted file mode 100644 index 0ec45a7e..00000000 --- a/src/common/rpc/cluster.rs +++ /dev/null @@ -1,280 +0,0 @@ -/* - * Copyright (C) 2022 Vaticle - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::{ - collections::{HashMap, HashSet}, - sync::Arc, -}; - -use futures::{channel::mpsc, future::BoxFuture, FutureExt}; -use tonic::Streaming; -use typedb_protocol::{ - cluster_database_manager, cluster_user, core_database, core_database_manager, session, - transaction, type_db_cluster_client::TypeDbClusterClient as ClusterGRPC, -}; - -use crate::common::{ - credential::CallCredentials, - error::ClientError, - rpc::{ - builder::{cluster, cluster::user::token_req}, - channel::CallCredChannel, - Channel, CoreRPC, - }, - Address, Credential, Error, Executor, Result, -}; - -#[derive(Debug, Clone)] -pub(crate) struct ClusterRPC { - server_rpcs: HashMap, -} - -impl ClusterRPC { - pub(crate) fn new(addresses: HashSet
, credential: Credential) -> Result> { - let cluster_clients = addresses - .into_iter() - .map(|address| { - Ok((address.clone(), ClusterServerRPC::new(address, credential.clone())?)) - }) - .collect::>()?; - Ok(Arc::new(Self { server_rpcs: cluster_clients })) - } - - pub(crate) async fn fetch_current_addresses>( - addresses: &[T], - credential: &Credential, - ) -> Result> { - for address in addresses { - match ClusterServerRPC::new(address.as_ref().parse()?, credential.clone())? - .validated() - .await - { - Ok(mut client) => { - let servers = client.servers_all().await?.servers; - return servers.into_iter().map(|server| server.address.parse()).collect(); - } - Err(Error::Client(ClientError::UnableToConnect())) => (), - Err(err) => Err(err)?, - } - } - Err(ClientError::UnableToConnect())? - } - - pub(crate) fn server_rpc_count(&self) -> usize { - self.server_rpcs.len() - } - - pub(crate) fn addresses(&self) -> impl Iterator { - self.server_rpcs.keys() - } - - pub(crate) fn get_server_rpc(&self, address: &Address) -> ClusterServerRPC { - self.server_rpcs.get(address).unwrap().clone() - } - - pub(crate) fn get_any_server_rpc(&self) -> ClusterServerRPC { - // TODO round robin? - self.server_rpcs.values().next().unwrap().clone() - } - - pub(crate) fn iter_server_rpcs_cloned(&self) -> impl Iterator + '_ { - self.server_rpcs.values().cloned() - } - - pub(crate) fn unable_to_connect(&self) -> Error { - Error::Client(ClientError::ClusterUnableToConnect( - self.addresses().map(Address::to_string).collect::>().join(","), - )) - } -} - -#[derive(Clone, Debug)] -pub(crate) struct ClusterServerRPC { - address: Address, - core_rpc: CoreRPC, - cluster_grpc: ClusterGRPC, - call_credentials: Arc, - pub(crate) executor: Arc, -} - -impl ClusterServerRPC { - pub(crate) fn new(address: Address, credential: Credential) -> Result { - let (channel, call_credentials) = Channel::open_encrypted(address.clone(), credential)?; - Ok(Self { - address, - core_rpc: CoreRPC::new(channel.clone())?, - cluster_grpc: ClusterGRPC::new(channel.into()), - executor: Arc::new(Executor::new().expect("Failed to create Executor")), - call_credentials, - }) - } - - async fn validated(mut self) -> Result { - self.cluster_grpc.databases_all(cluster::database_manager::all_req()).await?; - Ok(self) - } - - pub(crate) fn address(&self) -> &Address { - &self.address - } - - async fn call_with_auto_renew_token(&mut self, call: F) -> Result - where - for<'a> F: Fn(&'a mut Self) -> BoxFuture<'a, Result>, - { - match call(self).await { - Err(Error::Client(ClientError::ClusterTokenCredentialInvalid())) => { - self.renew_token().await?; - call(self).await - } - res => res, - } - } - - async fn renew_token(&mut self) -> Result { - self.call_credentials.reset_token(); - let req = token_req(self.call_credentials.username()); - let token = self.user_token(req).await?.token; - self.call_credentials.set_token(token); - Ok(()) - } - - async fn user_token( - &mut self, - username: cluster_user::token::Req, - ) -> Result { - Ok(self.cluster_grpc.user_token(username).await?.into_inner()) - } - - pub(crate) async fn servers_all( - &mut self, - ) -> Result { - self.call_with_auto_renew_token(|this| { - Box::pin( - this.cluster_grpc - .servers_all(cluster::server_manager::all_req()) - .map(|res| Ok(res?.into_inner())), - ) - }) - .await - } - - pub(crate) async fn databases_get( - &mut self, - req: cluster_database_manager::get::Req, - ) -> Result { - self.call_with_auto_renew_token(|this| { - Box::pin(this.cluster_grpc.databases_get(req.clone()).map(|res| Ok(res?.into_inner()))) - }) - .await - } - - pub(crate) async fn databases_all( - &mut self, - req: cluster_database_manager::all::Req, - ) -> Result { - self.call_with_auto_renew_token(|this| { - Box::pin(this.cluster_grpc.databases_all(req.clone()).map(|res| Ok(res?.into_inner()))) - }) - .await - } - - // server client pass-through - pub(crate) async fn databases_contains( - &mut self, - req: core_database_manager::contains::Req, - ) -> Result { - self.call_with_auto_renew_token(|this| { - Box::pin(this.core_rpc.databases_contains(req.clone())) - }) - .await - } - - pub(crate) async fn databases_create( - &mut self, - req: core_database_manager::create::Req, - ) -> Result { - self.call_with_auto_renew_token(|this| { - Box::pin(this.core_rpc.databases_create(req.clone())) - }) - .await - } - - pub(crate) async fn database_delete( - &mut self, - req: core_database::delete::Req, - ) -> Result { - self.call_with_auto_renew_token(|this| Box::pin(this.core_rpc.database_delete(req.clone()))) - .await - } - - pub(crate) async fn database_schema( - &mut self, - req: core_database::schema::Req, - ) -> Result { - self.call_with_auto_renew_token(|this| Box::pin(this.core_rpc.database_schema(req.clone()))) - .await - } - - pub(crate) async fn database_rule_schema( - &mut self, - req: core_database::rule_schema::Req, - ) -> Result { - self.call_with_auto_renew_token(|this| { - Box::pin(this.core_rpc.database_rule_schema(req.clone())) - }) - .await - } - - pub(crate) async fn database_type_schema( - &mut self, - req: core_database::type_schema::Req, - ) -> Result { - self.call_with_auto_renew_token(|this| { - Box::pin(this.core_rpc.database_type_schema(req.clone())) - }) - .await - } - - pub(crate) async fn session_open( - &mut self, - req: session::open::Req, - ) -> Result { - self.call_with_auto_renew_token(|this| Box::pin(this.core_rpc.session_open(req.clone()))) - .await - } - - pub(crate) async fn session_close( - &mut self, - req: session::close::Req, - ) -> Result { - self.call_with_auto_renew_token(|this| Box::pin(this.core_rpc.session_close(req.clone()))) - .await - } - - pub(crate) async fn transaction( - &mut self, - req: transaction::Req, - ) -> Result<(mpsc::Sender, Streaming)> { - self.call_with_auto_renew_token(|this| Box::pin(this.core_rpc.transaction(req.clone()))) - .await - } -} diff --git a/src/common/rpc/core.rs b/src/common/rpc/core.rs deleted file mode 100644 index 59715ed1..00000000 --- a/src/common/rpc/core.rs +++ /dev/null @@ -1,224 +0,0 @@ -/* - * Copyright (C) 2022 Vaticle - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::{future::Future, sync::Arc}; - -use futures::{channel::mpsc, SinkExt}; -use tonic::{Response, Status, Streaming}; -use typedb_protocol::{ - core_database, core_database_manager, session, transaction, - type_db_client::TypeDbClient as RawCoreGRPC, -}; - -use crate::{ - async_enum_dispatch, - common::{ - rpc::{ - builder::{core, transaction::client_msg}, - channel::CallCredChannel, - Channel, - }, - Address, Executor, Result, StdResult, TonicChannel, - }, -}; - -#[derive(Clone, Debug)] -enum CoreGRPC { - Plaintext(RawCoreGRPC), - Encrypted(RawCoreGRPC), -} - -impl CoreGRPC { - pub fn new(channel: Channel) -> Self { - match channel { - Channel::Plaintext(channel) => Self::Plaintext(RawCoreGRPC::new(channel)), - Channel::Encrypted(channel) => Self::Encrypted(RawCoreGRPC::new(channel)), - } - } - - async_enum_dispatch! { { Plaintext, Encrypted } - pub async fn databases_contains( - &mut self, - request: core_database_manager::contains::Req, - ) -> StdResult, Status>; - - pub async fn databases_create( - &mut self, - request: core_database_manager::create::Req, - ) -> StdResult, Status>; - - pub async fn databases_all( - &mut self, - request: core_database_manager::all::Req, - ) -> StdResult, Status>; - - pub async fn database_schema( - &mut self, - request: core_database::schema::Req, - ) -> StdResult, Status>; - - pub async fn database_type_schema( - &mut self, - request: core_database::type_schema::Req, - ) -> StdResult, Status>; - - pub async fn database_rule_schema( - &mut self, - request: core_database::rule_schema::Req, - ) -> StdResult, Status>; - - pub async fn database_delete( - &mut self, - request: core_database::delete::Req, - ) -> StdResult, Status>; - - pub async fn session_open( - &mut self, - request: session::open::Req, - ) -> StdResult, Status>; - - pub async fn session_close( - &mut self, - request: session::close::Req, - ) -> StdResult, Status>; - - pub async fn session_pulse( - &mut self, - request: session::pulse::Req, - ) -> StdResult, Status>; - - pub async fn transaction( - &mut self, - request: impl tonic::IntoStreamingRequest, - ) -> StdResult>, Status>; - } -} - -#[derive(Clone, Debug)] -pub(crate) struct CoreRPC { - core_grpc: CoreGRPC, - pub(crate) executor: Arc, -} - -impl CoreRPC { - pub(crate) fn new(channel: Channel) -> Result { - Ok(Self { - core_grpc: CoreGRPC::new(channel), - executor: Arc::new(Executor::new().expect("Failed to create Executor")), - }) - } - - pub(crate) async fn connect(address: Address) -> Result { - Self::new(Channel::open_plaintext(address)?)?.validated().await - } - - async fn validated(mut self) -> Result { - // TODO: temporary hack to validate connection until we have client pulse - self.core_grpc.databases_all(core::database_manager::all_req()).await?; - Ok(self) - } - - pub(crate) async fn databases_contains( - &mut self, - req: core_database_manager::contains::Req, - ) -> Result { - single(self.core_grpc.databases_contains(req)).await - } - - pub(crate) async fn databases_create( - &mut self, - req: core_database_manager::create::Req, - ) -> Result { - single(self.core_grpc.databases_create(req)).await - } - - pub(crate) async fn databases_all( - &mut self, - req: core_database_manager::all::Req, - ) -> Result { - single(self.core_grpc.databases_all(req)).await - } - - pub(crate) async fn database_delete( - &mut self, - req: core_database::delete::Req, - ) -> Result { - single(self.core_grpc.database_delete(req)).await - } - - pub(crate) async fn database_schema( - &mut self, - req: core_database::schema::Req, - ) -> Result { - single(self.core_grpc.database_schema(req)).await - } - - pub(crate) async fn database_type_schema( - &mut self, - req: core_database::type_schema::Req, - ) -> Result { - single(self.core_grpc.database_type_schema(req)).await - } - - pub(crate) async fn database_rule_schema( - &mut self, - req: core_database::rule_schema::Req, - ) -> Result { - single(self.core_grpc.database_rule_schema(req)).await - } - - pub(crate) async fn session_open( - &mut self, - req: session::open::Req, - ) -> Result { - single(self.core_grpc.session_open(req)).await - } - - pub(crate) async fn session_close( - &mut self, - req: session::close::Req, - ) -> Result { - single(self.core_grpc.session_close(req)).await - } - - pub(crate) async fn transaction( - &mut self, - open_req: transaction::Req, - ) -> Result<(mpsc::Sender, Streaming)> { - // TODO: refactor to crossbeam channel - let (mut sender, receiver) = mpsc::channel::(256); - sender.send(client_msg(vec![open_req])).await.unwrap(); - bidi_stream(sender, self.core_grpc.transaction(receiver)).await - } -} - -pub(crate) async fn single( - res: impl Future, Status>>, -) -> Result { - Ok(res.await?.into_inner()) -} - -pub(crate) async fn bidi_stream( - req_sink: mpsc::Sender, - res: impl Future>, Status>>, -) -> Result<(mpsc::Sender, Streaming)> { - Ok((req_sink, res.await?.into_inner())) -} diff --git a/src/common/rpc/server.rs b/src/common/rpc/server.rs deleted file mode 100644 index c7a281fd..00000000 --- a/src/common/rpc/server.rs +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Copyright (C) 2022 Vaticle - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::sync::Arc; - -use futures::channel::mpsc; -use tonic::Streaming; -use typedb_protocol::{core_database, core_database_manager, session, transaction}; - -use crate::{ - async_enum_dispatch, - common::{ - rpc::{core::CoreRPC, ClusterServerRPC}, - Executor, Result, - }, -}; - -#[derive(Clone, Debug)] -pub(crate) enum ServerRPC { - Core(CoreRPC), - Cluster(ClusterServerRPC), -} - -impl From for ServerRPC { - fn from(server_client: CoreRPC) -> Self { - ServerRPC::Core(server_client) - } -} - -impl From for ServerRPC { - fn from(cluster_client: ClusterServerRPC) -> Self { - ServerRPC::Cluster(cluster_client) - } -} - -impl ServerRPC { - pub(crate) fn executor(&self) -> &Arc { - match self { - Self::Core(client) => &client.executor, - Self::Cluster(client) => &client.executor, - } - } - - async_enum_dispatch! { { Core, Cluster } - pub(crate) async fn databases_contains( - &mut self, - req: core_database_manager::contains::Req, - ) -> Result; - - pub(crate) async fn databases_create( - &mut self, - req: core_database_manager::create::Req, - ) -> Result; - - pub(crate) async fn database_delete( - &mut self, - req: core_database::delete::Req, - ) -> Result; - - pub(crate) async fn database_schema( - &mut self, - req: core_database::schema::Req, - ) -> Result; - - pub(crate) async fn database_type_schema( - &mut self, - req: core_database::type_schema::Req, - ) -> Result; - - pub(crate) async fn database_rule_schema( - &mut self, - req: core_database::rule_schema::Req, - ) -> Result; - - pub(crate) async fn session_open( - &mut self, - req: session::open::Req, - ) -> Result; - - pub(crate) async fn session_close( - &mut self, - req: session::close::Req, - ) -> Result; - - pub(crate) async fn transaction( - &mut self, - req: transaction::Req, - ) -> Result<(mpsc::Sender, Streaming)>; - } -} diff --git a/src/common/rpc/transaction.rs b/src/common/rpc/transaction.rs deleted file mode 100644 index bfd46c25..00000000 --- a/src/common/rpc/transaction.rs +++ /dev/null @@ -1,430 +0,0 @@ -/* - * Copyright (C) 2022 Vaticle - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::{ - collections::HashMap, - mem, - pin::Pin, - sync::{Arc, Mutex}, - task::{Context, Poll}, - thread::sleep, - time::Duration, -}; - -use crossbeam::atomic::AtomicCell; -use futures::{ - channel::{mpsc, oneshot}, - SinkExt, Stream, StreamExt, -}; -use tonic::Streaming; -use typedb_protocol::{ - transaction, - transaction::{res::Res, res_part, server::Server, stream::State}, -}; - -use crate::common::{ - error::{ClientError, Error}, - rpc::{ - builder::transaction::{client_msg, stream_req}, - ServerRPC, - }, - Executor, Result, -}; - -// TODO: This structure has become pretty messy - review -#[derive(Clone, Debug)] -pub(crate) struct TransactionRPC { - rpc_client: ServerRPC, - sender: Sender, - receiver: Receiver, -} - -impl TransactionRPC { - pub(crate) async fn new(rpc_client: &ServerRPC, open_req: transaction::Req) -> Result { - let mut rpc_client_clone = rpc_client.clone(); - let (req_sink, streaming_res): ( - mpsc::Sender, - Streaming, - ) = rpc_client_clone.transaction(open_req).await?; - let (close_signal_sink, close_signal_receiver) = oneshot::channel::>(); - Ok(TransactionRPC { - rpc_client: rpc_client_clone.clone(), - sender: Sender::new( - req_sink, - rpc_client_clone.executor().clone(), - close_signal_receiver, - ), - receiver: Receiver::new(streaming_res, rpc_client_clone.executor(), close_signal_sink) - .await, - }) - } - - pub(crate) async fn single(&mut self, req: transaction::Req) -> Result { - if !self.is_open() { - todo!() - } - let (res_sink, res_receiver) = oneshot::channel::>(); - self.receiver.add_single(&req.req_id, res_sink); - self.sender.submit_message(req); - match res_receiver.await { - Ok(result) => result, - Err(err) => Err(Error::new(err.to_string())), - } - } - - pub(crate) fn stream(&mut self, req: transaction::Req) -> ResPartStream { - const BUFFER_SIZE: usize = 1024; - let (res_part_sink, res_part_receiver) = - mpsc::channel::>(BUFFER_SIZE); - let (stream_req_sink, stream_req_receiver) = std::sync::mpsc::channel::(); - self.receiver.add_stream(&req.req_id, res_part_sink); - let res_part_stream = - ResPartStream::new(res_part_receiver, stream_req_sink, req.req_id.clone()); - self.sender.add_message_provider(stream_req_receiver); - self.sender.submit_message(req); - res_part_stream - } - - pub(crate) fn is_open(&self) -> bool { - self.sender.is_open() - } - - pub(crate) async fn close(&self) { - self.sender.close(None).await; - } -} - -#[derive(Clone, Debug)] -struct Sender { - state: Arc, - executor: Arc, -} - -#[derive(Debug)] -struct SenderState { - req_sink: mpsc::Sender, - // TODO: refactor to crossbeam_queue::ArrayQueue? - queued_messages: Mutex>, - // TODO: refactor to message passing for these atomics - ongoing_task_count: AtomicCell, - is_open: AtomicCell, -} - -type ReqId = Vec; - -impl SenderState { - fn new(req_sink: mpsc::Sender) -> Self { - SenderState { - req_sink, - queued_messages: Mutex::new(Vec::new()), - ongoing_task_count: AtomicCell::new(0), - is_open: AtomicCell::new(true), - } - } - - fn submit_message(&self, req: transaction::Req) { - self.queued_messages.lock().unwrap().push(req); - } - - async fn dispatch_loop(&self) { - while self.is_open.load() { - const DISPATCH_INTERVAL: Duration = Duration::from_millis(3); - sleep(DISPATCH_INTERVAL); - self.dispatch_messages().await; - } - } - - async fn dispatch_messages(&self) { - self.ongoing_task_count.fetch_add(1); - let messages = mem::take(&mut *self.queued_messages.lock().unwrap()); - if !messages.is_empty() { - self.req_sink.clone().send(client_msg(messages)).await.unwrap(); - } - self.ongoing_task_count.fetch_sub(1); - } - - async fn await_close_signal(&self, close_signal_receiver: CloseSignalReceiver) { - match close_signal_receiver.await { - Ok(close_signal) => { - self.close(close_signal).await; - } - Err(err) => { - self.close(Some(Error::new(err.to_string()))).await; - } - } - } - - async fn close(&self, error: Option) { - if let Ok(true) = self.is_open.compare_exchange(true, false) { - if error.is_none() { - self.dispatch_messages().await; - } - // TODO: refactor to non-busy wait? - // TODO: this loop should have a timeout - loop { - if self.ongoing_task_count.load() == 0 { - self.req_sink.clone().close().await.unwrap(); - break; - } - } - } - } -} - -impl Sender { - pub(crate) fn new( - req_sink: mpsc::Sender, - executor: Arc, - close_signal_receiver: CloseSignalReceiver, - ) -> Self { - let state = Arc::new(SenderState::new(req_sink)); - // // TODO: clarify lifetimes of these threads - executor.spawn_ok({ - let state = state.clone(); - async move { - state.await_close_signal(close_signal_receiver).await; - } - }); - - executor.spawn_ok({ - let state = state.clone(); - async move { - state.dispatch_loop().await; - } - }); - - Sender { state, executor } - } - - fn submit_message(&self, req: transaction::Req) { - self.state.submit_message(req); - } - - fn add_message_provider(&self, provider: std::sync::mpsc::Receiver) { - let cloned_state = self.state.clone(); - self.executor.spawn_ok(async move { - for req in provider.iter() { - cloned_state.submit_message(req); - } - }); - } - - fn is_open(&self) -> bool { - self.state.is_open.load() - } - - async fn close(&self, error: Option) { - self.state.close(error).await - } -} - -#[derive(Clone, Debug)] -struct Receiver { - state: Arc, -} - -#[derive(Debug)] -struct ReceiverState { - res_collectors: Mutex>, - res_part_collectors: Mutex>, - is_open: AtomicCell, -} - -impl ReceiverState { - fn new() -> Self { - ReceiverState { - res_collectors: Mutex::new(HashMap::new()), - res_part_collectors: Mutex::new(HashMap::new()), - is_open: AtomicCell::new(true), - } - } - - async fn listen( - self: &Arc, - mut grpc_stream: Streaming, - close_signal_sink: CloseSignalSink, - ) { - loop { - match grpc_stream.next().await { - Some(Ok(message)) => { - self.clone().on_receive(message).await; - } - Some(Err(err)) => { - self.close(Some(err.into()), close_signal_sink).await; - break; - } - None => { - self.close(None, close_signal_sink).await; - break; - } - } - } - } - - async fn on_receive(&self, message: transaction::Server) { - // TODO: If an error occurs here (or in some other background process), resources are not - // properly cleaned up, and the application may hang. - match message.server { - Some(Server::Res(res)) => self.collect_res(res), - Some(Server::ResPart(res_part)) => { - self.collect_res_part(res_part).await; - } - None => println!("{}", ClientError::MissingResponseField("server")), - } - } - - fn collect_res(&self, res: transaction::Res) { - match self.res_collectors.lock().unwrap().remove(&res.req_id) { - Some(collector) => collector.send(Ok(res)).unwrap(), - None => { - if let Res::OpenRes(_) = res.res.unwrap() { - // ignore open_res - } else { - println!("{}", ClientError::UnknownRequestId(format!("{:?}", &res.req_id))) - // println!("{}", MESSAGES.client.unknown_request_id.to_err( - // vec![std::str::from_utf8(&res.req_id).unwrap()]) - // ) - } - } - } - } - - async fn collect_res_part(&self, res_part: transaction::ResPart) { - let value = self.res_part_collectors.lock().unwrap().remove(&res_part.req_id); - match value { - Some(mut collector) => { - let req_id = res_part.req_id.clone(); - if collector.send(Ok(res_part)).await.is_ok() { - self.res_part_collectors.lock().unwrap().insert(req_id, collector); - } - } - None => { - let req_id_str = hex_string(&res_part.req_id); - println!("{}", ClientError::UnknownRequestId(req_id_str)); - } - } - } - - async fn close(&self, error: Option, close_signal_sink: CloseSignalSink) { - if let Ok(true) = self.is_open.compare_exchange(true, false) { - let error_str = error.map(|err| err.to_string()); - for (_, collector) in self.res_collectors.lock().unwrap().drain() { - collector.send(Err(close_reason(&error_str))).ok(); - } - let mut res_part_collectors = Vec::new(); - for (_, res_part_collector) in self.res_part_collectors.lock().unwrap().drain() { - res_part_collectors.push(res_part_collector) - } - for mut collector in res_part_collectors { - collector.send(Err(close_reason(&error_str))).await.ok(); - } - close_signal_sink.send(Some(close_reason(&error_str))).unwrap(); - } - } -} - -fn hex_string(v: &[u8]) -> String { - v.iter().map(|b| format!("{:02X}", b)).collect::() -} - -fn close_reason(error_str: &Option) -> Error { - match error_str { - None => ClientError::TransactionIsClosed(), - Some(value) => ClientError::TransactionIsClosedWithErrors(value.clone()), - } - .into() -} - -impl Receiver { - async fn new( - grpc_stream: Streaming, - executor: &Executor, - close_signal_sink: CloseSignalSink, - ) -> Self { - let state = Arc::new(ReceiverState::new()); - executor.spawn_ok({ - let state = state.clone(); - async move { - state.listen(grpc_stream, close_signal_sink).await; - } - }); - Receiver { state } - } - - fn add_single(&mut self, req_id: &ReqId, res_collector: ResCollector) { - self.state.res_collectors.lock().unwrap().insert(req_id.clone(), res_collector); - } - - fn add_stream(&mut self, req_id: &ReqId, res_part_sink: ResPartCollector) { - self.state.res_part_collectors.lock().unwrap().insert(req_id.clone(), res_part_sink); - } -} - -type ResCollector = oneshot::Sender>; -type ResPartCollector = mpsc::Sender>; -type CloseSignalSink = oneshot::Sender>; -type CloseSignalReceiver = oneshot::Receiver>; - -#[derive(Debug)] -pub(crate) struct ResPartStream { - source: mpsc::Receiver>, - stream_req_sink: std::sync::mpsc::Sender, - req_id: ReqId, -} - -impl ResPartStream { - fn new( - source: mpsc::Receiver>, - stream_req_sink: std::sync::mpsc::Sender, - req_id: ReqId, - ) -> Self { - ResPartStream { source, stream_req_sink, req_id } - } -} - -impl Stream for ResPartStream { - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll> { - let poll = Pin::new(&mut self.source).poll_next(ctx); - match poll { - Poll::Ready(Some(Ok(ref res_part))) => { - match &res_part.res { - Some(res_part::Res::StreamResPart(stream_res_part)) => { - // TODO: unwrap -> expect("enum out of range") - match State::from_i32(stream_res_part.state).unwrap() { - State::Done => Poll::Ready(None), - State::Continue => { - let req_id = self.req_id.clone(); - self.stream_req_sink.send(stream_req(req_id)).unwrap(); - ctx.waker().wake_by_ref(); - Poll::Pending - } - } - } - Some(_other) => poll, - None => panic!("{}", ClientError::MissingResponseField("res_part.res")), - } - } - poll => poll, - } - } -} diff --git a/src/concept/mod.rs b/src/concept/mod.rs index e3e1d450..98dd046d 100644 --- a/src/concept/mod.rs +++ b/src/concept/mod.rs @@ -30,12 +30,8 @@ use std::{ use chrono::NaiveDateTime; use futures::{FutureExt, Stream, StreamExt}; -use typedb_protocol::{ - attribute as attribute_proto, attribute_type as attribute_type_proto, - attribute_type::ValueType, concept as concept_proto, r#type as type_proto, r#type::Encoding, -}; -use crate::common::{error::ClientError, Result}; +use crate::common::{error::ConnectionError, Result}; #[derive(Clone, Debug)] pub enum Concept { @@ -43,41 +39,12 @@ pub enum Concept { Thing(Thing), } -impl Concept { - pub(crate) fn from_proto(mut proto: typedb_protocol::Concept) -> Result { - let concept = proto.concept.ok_or(ClientError::MissingResponseField("concept"))?; - match concept { - concept_proto::Concept::Thing(thing) => Ok(Self::Thing(Thing::from_proto(thing)?)), - concept_proto::Concept::Type(type_) => Ok(Self::Type(Type::from_proto(type_)?)), - } - } -} - #[derive(Clone, Debug)] pub enum Type { Thing(ThingType), Role(RoleType), } -impl Type { - pub(crate) fn from_proto(proto: typedb_protocol::Type) -> Result { - // TODO: replace unwrap() with ok_or(custom_error) throughout the module - match type_proto::Encoding::from_i32(proto.encoding).unwrap() { - Encoding::ThingType => Ok(Self::Thing(ThingType::Root(RootThingType::default()))), - Encoding::EntityType => { - Ok(Self::Thing(ThingType::Entity(EntityType::from_proto(proto)))) - } - Encoding::RelationType => { - Ok(Self::Thing(ThingType::Relation(RelationType::from_proto(proto)))) - } - Encoding::AttributeType => { - Ok(Self::Thing(ThingType::Attribute(AttributeType::from_proto(proto)?))) - } - Encoding::RoleType => Ok(Self::Role(RoleType::from_proto(proto))), - } - } -} - #[derive(Clone, Debug)] pub enum ThingType { Root(RootThingType), @@ -120,10 +87,6 @@ impl EntityType { pub fn new(label: String) -> Self { Self { label } } - - fn from_proto(proto: typedb_protocol::Type) -> Self { - Self::new(proto.label) - } } #[derive(Clone, Debug)] @@ -135,10 +98,6 @@ impl RelationType { pub fn new(label: String) -> Self { Self { label } } - - fn from_proto(proto: typedb_protocol::Type) -> Self { - Self::new(proto.label) - } } #[derive(Clone, Debug)] @@ -151,19 +110,6 @@ pub enum AttributeType { DateTime(DateTimeAttributeType), } -impl AttributeType { - pub(crate) fn from_proto(mut proto: typedb_protocol::Type) -> Result { - match attribute_type_proto::ValueType::from_i32(proto.value_type).unwrap() { - ValueType::Object => Ok(Self::Root(RootAttributeType::default())), - ValueType::Boolean => Ok(Self::Boolean(BooleanAttributeType { label: proto.label })), - ValueType::Long => Ok(Self::Long(LongAttributeType { label: proto.label })), - ValueType::Double => Ok(Self::Double(DoubleAttributeType { label: proto.label })), - ValueType::String => Ok(Self::String(StringAttributeType { label: proto.label })), - ValueType::Datetime => Ok(Self::DateTime(DateTimeAttributeType { label: proto.label })), - } - } -} - #[derive(Clone, Debug)] pub struct RootAttributeType { pub label: String, @@ -244,10 +190,6 @@ pub struct RoleType { } impl RoleType { - fn from_proto(proto: typedb_protocol::Type) -> Self { - Self::new(ScopedLabel::new(proto.scope, proto.label)) - } - pub fn new(label: ScopedLabel) -> Self { Self { label } } @@ -261,23 +203,6 @@ pub enum Thing { Attribute(Attribute), } -impl Thing { - pub(crate) fn from_proto(mut proto: typedb_protocol::Thing) -> Result { - match typedb_protocol::r#type::Encoding::from_i32(proto.r#type.clone().unwrap().encoding) - .unwrap() - { - type_proto::Encoding::EntityType => Ok(Self::Entity(Entity::from_proto(proto)?)), - type_proto::Encoding::RelationType => Ok(Self::Relation(Relation::from_proto(proto)?)), - type_proto::Encoding::AttributeType => { - Ok(Self::Attribute(Attribute::from_proto(proto)?)) - } - _ => { - todo!() - } - } - } -} - // impl ConceptApi for Thing {} // impl ThingApi for Thing { @@ -298,12 +223,6 @@ pub struct Entity { pub type_: EntityType, } -impl Entity { - pub(crate) fn from_proto(mut proto: typedb_protocol::Thing) -> Result { - Ok(Self { type_: EntityType::from_proto(proto.r#type.unwrap()), iid: proto.iid }) - } -} - // impl ThingApi for Entity { // // TODO: use enum_dispatch macro to avoid manually writing the duplicates of this method // fn get_iid(&self) -> &Vec { @@ -321,12 +240,6 @@ pub struct Relation { pub type_: RelationType, } -impl Relation { - pub(crate) fn from_proto(mut proto: typedb_protocol::Thing) -> Result { - Ok(Self { type_: RelationType::from_proto(proto.r#type.unwrap()), iid: proto.iid }) - } -} - // macro_rules! default_impl { // { impl $trait:ident $body:tt for $($t:ident),* $(,)? } => { // $(impl $trait for $t $body)* @@ -354,66 +267,6 @@ pub enum Attribute { DateTime(DateTimeAttribute), } -impl Attribute { - pub(crate) fn from_proto(mut proto: typedb_protocol::Thing) -> Result { - match attribute_type_proto::ValueType::from_i32(proto.r#type.unwrap().value_type).unwrap() { - ValueType::Object => { - todo!() - } - ValueType::Boolean => Ok(Self::Boolean(BooleanAttribute { - value: if let attribute_proto::value::Value::Boolean(value) = - proto.value.unwrap().value.unwrap() - { - value - } else { - todo!() - }, - iid: proto.iid, - })), - ValueType::Long => Ok(Self::Long(LongAttribute { - value: if let attribute_proto::value::Value::Long(value) = - proto.value.unwrap().value.unwrap() - { - value - } else { - todo!() - }, - iid: proto.iid, - })), - ValueType::Double => Ok(Self::Double(DoubleAttribute { - value: if let attribute_proto::value::Value::Double(value) = - proto.value.unwrap().value.unwrap() - { - value - } else { - todo!() - }, - iid: proto.iid, - })), - ValueType::String => Ok(Self::String(StringAttribute { - value: if let attribute_proto::value::Value::String(value) = - proto.value.unwrap().value.unwrap() - { - value - } else { - todo!() - }, - iid: proto.iid, - })), - ValueType::Datetime => Ok(Self::DateTime(DateTimeAttribute { - value: if let attribute_proto::value::Value::DateTime(value) = - proto.value.unwrap().value.unwrap() - { - NaiveDateTime::from_timestamp_opt(value / 1000, (value % 1000) as u32).unwrap() - } else { - todo!() - }, - iid: proto.iid, - })), - } - } -} - #[derive(Clone, Debug)] pub struct BooleanAttribute { pub iid: Vec, diff --git a/src/connection/cluster/client.rs b/src/connection/cluster/client.rs deleted file mode 100644 index af403b95..00000000 --- a/src/connection/cluster/client.rs +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Copyright (C) 2022 Vaticle - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::sync::Arc; - -use super::{DatabaseManager, Session}; -use crate::common::{ClusterRPC, Credential, Result, SessionType}; - -pub struct Client { - databases: DatabaseManager, - cluster_rpc: Arc, -} - -impl Client { - pub async fn new>(init_addresses: &[T], credential: Credential) -> Result { - let addresses = ClusterRPC::fetch_current_addresses(init_addresses, &credential).await?; - let cluster_rpc = ClusterRPC::new(addresses, credential)?; - let databases = DatabaseManager::new(cluster_rpc.clone()); - Ok(Self { cluster_rpc, databases }) - } - - pub fn databases(&mut self) -> &mut DatabaseManager { - &mut self.databases - } - - pub async fn session( - &mut self, - database_name: &str, - session_type: SessionType, - ) -> Result { - Session::new( - self.databases.get(database_name).await?, - session_type, - self.cluster_rpc.clone(), - ) - .await - } -} diff --git a/src/connection/cluster/database.rs b/src/connection/cluster/database.rs deleted file mode 100644 index 357bdb79..00000000 --- a/src/connection/cluster/database.rs +++ /dev/null @@ -1,247 +0,0 @@ -/* - * Copyright (C) 2022 Vaticle - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::{fmt, fmt::Debug, future::Future, sync::Arc, time::Duration}; - -use log::debug; -use tokio::time::sleep; - -use crate::{ - common::{ - error::ClientError, rpc::builder::cluster::database_manager::get_req, Address, ClusterRPC, - ClusterServerRPC, Error, Result, - }, - connection::server, -}; - -#[derive(Clone)] -pub struct Database { - pub name: String, - replicas: Vec, - cluster_rpc: Arc, -} - -impl Debug for Database { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("cluster::Database") - .field("name", &self.name) - .field("replicas", &self.replicas) - .finish() - } -} - -impl Database { - const PRIMARY_REPLICA_TASK_MAX_RETRIES: usize = 10; - const FETCH_REPLICAS_MAX_RETRIES: usize = 10; - const WAIT_FOR_PRIMARY_REPLICA_SELECTION: Duration = Duration::from_secs(2); - - pub(super) fn new( - proto: typedb_protocol::ClusterDatabase, - cluster_rpc: Arc, - ) -> Result { - let name = proto.name.clone(); - let replicas = Replica::from_proto(proto, &cluster_rpc); - Ok(Self { name, replicas, cluster_rpc }) - } - - pub(super) async fn get(name: &str, cluster_rpc: Arc) -> Result { - Ok(Self { - name: name.to_string(), - replicas: Replica::fetch_all(name, cluster_rpc.clone()).await?, - cluster_rpc, - }) - } - - pub async fn delete(mut self) -> Result { - self.run_on_primary_replica(|database, _, _| database.delete()).await - } - - pub async fn schema(&mut self) -> Result { - self.run_failsafe(|mut database, _, _| async move { database.schema().await }).await - } - - pub async fn type_schema(&mut self) -> Result { - self.run_failsafe(|mut database, _, _| async move { database.type_schema().await }).await - } - - pub async fn rule_schema(&mut self) -> Result { - self.run_failsafe(|mut database, _, _| async move { database.rule_schema().await }).await - } - - pub(crate) async fn run_failsafe(&mut self, task: F) -> Result - where - F: Fn(server::Database, ClusterServerRPC, bool) -> P, - P: Future>, - { - match self.run_on_any_replica(&task).await { - Err(Error::Client(ClientError::ClusterReplicaNotPrimary())) => { - debug!("Attempted to run on a non-primary replica, retrying on primary..."); - self.run_on_primary_replica(&task).await - } - res => res, - } - } - - async fn run_on_any_replica(&mut self, task: F) -> Result - where - F: Fn(server::Database, ClusterServerRPC, bool) -> P, - P: Future>, - { - let mut is_first_run = true; - for replica in self.replicas.iter() { - match task( - replica.database.clone(), - self.cluster_rpc.get_server_rpc(&replica.address), - is_first_run, - ) - .await - { - Err(Error::Client(ClientError::UnableToConnect())) => { - println!("Unable to connect to {}. Attempting next server.", replica.address); - } - res => return res, - } - is_first_run = false; - } - Err(self.cluster_rpc.unable_to_connect()) - } - - async fn run_on_primary_replica(&mut self, task: F) -> Result - where - F: Fn(server::Database, ClusterServerRPC, bool) -> P, - P: Future>, - { - let mut primary_replica = if let Some(replica) = self.primary_replica() { - replica - } else { - self.seek_primary_replica().await? - }; - - for retry in 0..Self::PRIMARY_REPLICA_TASK_MAX_RETRIES { - match task( - primary_replica.database.clone(), - self.cluster_rpc.get_server_rpc(&primary_replica.address), - retry == 0, - ) - .await - { - Err(Error::Client( - ClientError::ClusterReplicaNotPrimary() | ClientError::UnableToConnect(), - )) => { - debug!("Primary replica error, waiting..."); - Self::wait_for_primary_replica_selection().await; - primary_replica = self.seek_primary_replica().await?; - } - res => return res, - } - } - Err(self.cluster_rpc.unable_to_connect()) - } - - async fn seek_primary_replica(&mut self) -> Result { - for _ in 0..Self::FETCH_REPLICAS_MAX_RETRIES { - self.replicas = Replica::fetch_all(&self.name, self.cluster_rpc.clone()).await?; - if let Some(replica) = self.primary_replica() { - return Ok(replica); - } - Self::wait_for_primary_replica_selection().await; - } - Err(self.cluster_rpc.unable_to_connect()) - } - - fn primary_replica(&mut self) -> Option { - self.replicas.iter().filter(|r| r.is_primary).max_by_key(|r| r.term).cloned() - } - - async fn wait_for_primary_replica_selection() { - sleep(Self::WAIT_FOR_PRIMARY_REPLICA_SELECTION).await; - } -} - -#[derive(Clone)] -pub struct Replica { - address: Address, - database_name: String, - is_primary: bool, - term: i64, - is_preferred: bool, - database: server::Database, -} - -impl Debug for Replica { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Replica") - .field("address", &self.address) - .field("database_name", &self.database_name) - .field("is_primary", &self.is_primary) - .field("term", &self.term) - .field("is_preferred", &self.is_preferred) - .finish() - } -} - -impl Replica { - fn new( - name: &str, - metadata: typedb_protocol::cluster_database::Replica, - server_rpc: ClusterServerRPC, - ) -> Self { - Self { - address: metadata.address.parse().expect("Invalid URI received from the server"), - database_name: name.to_owned(), - is_primary: metadata.primary, - term: metadata.term, - is_preferred: metadata.preferred, - database: server::Database::new(name, server_rpc.into()), - } - } - - fn from_proto(proto: typedb_protocol::ClusterDatabase, cluster_rpc: &ClusterRPC) -> Vec { - proto - .replicas - .into_iter() - .map(|replica| { - let server_rpc = cluster_rpc.get_server_rpc(&replica.address.parse().unwrap()); - Replica::new(&proto.name, replica, server_rpc) - }) - .collect() - } - - async fn fetch_all(name: &str, cluster_rpc: Arc) -> Result> { - for mut rpc in cluster_rpc.iter_server_rpcs_cloned() { - let res = rpc.databases_get(get_req(name)).await; - match res { - Ok(res) => { - return Ok(Replica::from_proto(res.database.unwrap(), &cluster_rpc)); - } - Err(Error::Client(ClientError::UnableToConnect())) => { - println!( - "Failed to fetch replica info for database '{}' from {}. Attempting next server.", - name, - rpc.address() - ); - } - Err(err) => return Err(err), - } - } - Err(cluster_rpc.unable_to_connect()) - } -} diff --git a/src/connection/cluster/database_manager.rs b/src/connection/cluster/database_manager.rs deleted file mode 100644 index 48fd1190..00000000 --- a/src/connection/cluster/database_manager.rs +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Copyright (C) 2022 Vaticle - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::{fmt::Debug, future::Future, sync::Arc}; - -use super::Database; -use crate::{ - common::{ - error::ClientError, - rpc::builder::{ - cluster::database_manager::all_req, - core::database_manager::{contains_req, create_req}, - }, - ClusterRPC, ClusterServerRPC, Result, - }, - connection::server, -}; - -#[derive(Clone, Debug)] -pub struct DatabaseManager { - cluster_rpc: Arc, -} - -impl DatabaseManager { - pub(crate) fn new(cluster_rpc: Arc) -> Self { - Self { cluster_rpc } - } - - pub async fn get(&mut self, name: &str) -> Result { - Database::get(name, self.cluster_rpc.clone()).await - } - - pub async fn contains(&mut self, name: &str) -> Result { - Ok(self - .run_failsafe(name, move |database, mut server_rpc, _| { - let req = contains_req(&database.name); - async move { server_rpc.databases_contains(req).await } - }) - .await? - .contains) - } - - pub async fn create(&mut self, name: &str) -> Result { - self.run_failsafe(name, |database, mut server_rpc, _| { - let req = create_req(&database.name); - async move { server_rpc.databases_create(req).await } - }) - .await?; - Ok(()) - } - - pub async fn all(&mut self) -> Result> { - let mut error_buffer = Vec::with_capacity(self.cluster_rpc.server_rpc_count()); - for mut server_rpc in self.cluster_rpc.iter_server_rpcs_cloned() { - match server_rpc.databases_all(all_req()).await { - Ok(list) => { - return list - .databases - .into_iter() - .map(|proto_db| Database::new(proto_db, self.cluster_rpc.clone())) - .collect() - } - Err(err) => error_buffer.push(format!("- {}: {}", server_rpc.address(), err)), - } - } - Err(ClientError::ClusterAllNodesFailed(error_buffer.join("\n")))? - } - - async fn run_failsafe(&mut self, name: &str, task: F) -> Result - where - F: Fn(server::Database, ClusterServerRPC, bool) -> P, - P: Future>, - { - Database::get(name, self.cluster_rpc.clone()).await?.run_failsafe(&task).await - } -} diff --git a/src/connection/cluster/session.rs b/src/connection/cluster/session.rs deleted file mode 100644 index 12bc7467..00000000 --- a/src/connection/cluster/session.rs +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Copyright (C) 2022 Vaticle - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::sync::Arc; - -use super::Database; -use crate::{ - common::{ClusterRPC, Result, SessionType, TransactionType}, - connection::{core, server, server::Transaction}, -}; - -pub struct Session { - pub database: Database, - pub session_type: SessionType, - - server_session: server::Session, - cluster_rpc: Arc, -} - -impl Session { - // TODO options - pub(crate) async fn new( - mut database: Database, - session_type: SessionType, - cluster_rpc: Arc, - ) -> Result { - let server_session = database - .run_failsafe(|database, server_rpc, _| async { - let database_name = database.name; - server::Session::new( - database_name.as_str(), - session_type, - core::Options::default(), - server_rpc.into(), - ) - .await - }) - .await?; - - Ok(Self { database, session_type, server_session, cluster_rpc }) - } - - //TODO options - pub async fn transaction(&mut self, transaction_type: TransactionType) -> Result { - let (session, transaction) = self - .database - .run_failsafe(|database, server_rpc, is_first_run| { - let session_type = self.session_type; - let session = &self.server_session; - async move { - if is_first_run { - let transaction = session.transaction(transaction_type).await?; - Ok((None, transaction)) - } else { - let server_session = server::Session::new( - database.name.as_str(), - session_type, - core::Options::default(), - server_rpc.into(), - ) - .await?; - let transaction = server_session.transaction(transaction_type).await?; - Ok((Some(server_session), transaction)) - } - } - }) - .await?; - - if let Some(session) = session { - self.server_session = session; - } - - Ok(transaction) - } -} diff --git a/src/connection/connection.rs b/src/connection/connection.rs new file mode 100644 index 00000000..1d4d176d --- /dev/null +++ b/src/connection/connection.rs @@ -0,0 +1,332 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::{ + collections::{HashMap, HashSet}, + fmt, + sync::{Arc, Mutex}, + time::Duration, +}; + +use itertools::Itertools; +use tokio::{ + select, + sync::mpsc::{unbounded_channel as unbounded_async, UnboundedReceiver, UnboundedSender}, + time::{sleep_until, Instant}, +}; + +use super::{ + network::transmitter::{RPCTransmitter, TransactionTransmitter}, + runtime::BackgroundRuntime, + TransactionStream, +}; +use crate::{ + common::{ + address::Address, + error::{ConnectionError, Error}, + info::{DatabaseInfo, SessionInfo}, + Result, SessionID, SessionType, TransactionType, + }, + connection::message::{Request, Response, TransactionRequest}, + error::InternalError, + Credential, Options, +}; + +#[derive(Clone)] +pub struct Connection { + server_connections: HashMap, + background_runtime: Arc, +} + +impl Connection { + pub fn new_plaintext(address: impl AsRef) -> Result { + let address: Address = address.as_ref().parse()?; + let background_runtime = Arc::new(BackgroundRuntime::new()?); + let server_connection = ServerConnection::new_plaintext(background_runtime.clone(), address.clone())?; + Ok(Self { server_connections: [(address, server_connection)].into(), background_runtime }) + } + + pub fn new_encrypted + Sync>(init_addresses: &[T], credential: Credential) -> Result { + let background_runtime = Arc::new(BackgroundRuntime::new()?); + + let init_addresses = init_addresses.iter().map(|addr| addr.as_ref().parse()).try_collect()?; + let addresses = Self::fetch_current_addresses(background_runtime.clone(), init_addresses, credential.clone())?; + + let mut server_connections = HashMap::with_capacity(addresses.len()); + for address in addresses { + let server_connection = + ServerConnection::new_encrypted(background_runtime.clone(), address.clone(), credential.clone())?; + server_connections.insert(address, server_connection); + } + + Ok(Self { server_connections, background_runtime }) + } + + fn fetch_current_addresses( + background_runtime: Arc, + addresses: Vec
, + credential: Credential, + ) -> Result> { + for address in addresses { + let server_connection = + ServerConnection::new_encrypted(background_runtime.clone(), address.clone(), credential.clone())?; + match server_connection.servers_all() { + Ok(servers) => return Ok(servers.into_iter().collect()), + Err(Error::Connection(ConnectionError::UnableToConnect())) => (), + Err(err) => Err(err)?, + } + } + Err(ConnectionError::UnableToConnect())? + } + + pub fn force_close(self) -> Result { + self.server_connections.values().map(ServerConnection::force_close).try_collect()?; + self.background_runtime.force_close() + } + + pub(crate) fn server_count(&self) -> usize { + self.server_connections.len() + } + + pub(crate) fn addresses(&self) -> impl Iterator { + self.server_connections.keys() + } + + pub(crate) fn connection(&self, address: &Address) -> Result<&ServerConnection> { + self.server_connections + .get(address) + .ok_or_else(|| InternalError::UnknownConnectionAddress(address.to_string()).into()) + } + + pub(crate) fn connections(&self) -> impl Iterator + '_ { + self.server_connections.values() + } + + pub(crate) fn unable_to_connect_error(&self) -> Error { + Error::Connection(ConnectionError::ClusterUnableToConnect( + self.addresses().map(Address::to_string).collect::>().join(","), + )) + } +} + +impl fmt::Debug for Connection { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Connection").field("server_connections", &self.server_connections).finish() + } +} + +#[derive(Clone)] +pub(crate) struct ServerConnection { + address: Address, + background_runtime: Arc, + open_sessions: Arc>>>, + request_transmitter: Arc, +} + +impl ServerConnection { + fn new_plaintext(background_runtime: Arc, address: Address) -> Result { + let request_transmitter = Arc::new(RPCTransmitter::start_plaintext(address.clone(), &background_runtime)?); + Ok(Self { address, background_runtime, open_sessions: Default::default(), request_transmitter }) + } + + fn new_encrypted( + background_runtime: Arc, + address: Address, + credential: Credential, + ) -> Result { + let request_transmitter = + Arc::new(RPCTransmitter::start_encrypted(address.clone(), credential, &background_runtime)?); + Ok(Self { address, background_runtime, open_sessions: Default::default(), request_transmitter }) + } + + pub(crate) fn address(&self) -> &Address { + &self.address + } + + async fn request_async(&self, request: Request) -> Result { + if !self.background_runtime.is_open() { + return Err(ConnectionError::ConnectionIsClosed().into()); + } + self.request_transmitter.request_async(request).await + } + + fn request_blocking(&self, request: Request) -> Result { + if !self.background_runtime.is_open() { + return Err(ConnectionError::ConnectionIsClosed().into()); + } + self.request_transmitter.request_blocking(request) + } + + pub(crate) fn force_close(&self) -> Result { + let session_ids: Vec = self.open_sessions.lock().unwrap().keys().cloned().collect(); + for session_id in session_ids.into_iter() { + self.close_session(session_id).ok(); + } + self.request_transmitter.force_close() + } + + pub(crate) fn servers_all(&self) -> Result> { + match self.request_blocking(Request::ServersAll)? { + Response::ServersAll { servers } => Ok(servers), + other => Err(InternalError::UnexpectedResponseType(format!("{other:?}")).into()), + } + } + + pub(crate) async fn database_exists(&self, database_name: String) -> Result { + match self.request_async(Request::DatabasesContains { database_name }).await? { + Response::DatabasesContains { contains } => Ok(contains), + other => Err(InternalError::UnexpectedResponseType(format!("{other:?}")).into()), + } + } + + pub(crate) async fn create_database(&self, database_name: String) -> Result { + self.request_async(Request::DatabaseCreate { database_name }).await?; + Ok(()) + } + + pub(crate) async fn get_database_replicas(&self, database_name: String) -> Result { + match self.request_async(Request::DatabaseGet { database_name }).await? { + Response::DatabaseGet { database } => Ok(database), + other => Err(InternalError::UnexpectedResponseType(format!("{other:?}")).into()), + } + } + + pub(crate) async fn all_databases(&self) -> Result> { + match self.request_async(Request::DatabasesAll).await? { + Response::DatabasesAll { databases } => Ok(databases), + other => Err(InternalError::UnexpectedResponseType(format!("{other:?}")).into()), + } + } + + pub(crate) async fn database_schema(&self, database_name: String) -> Result { + match self.request_async(Request::DatabaseSchema { database_name }).await? { + Response::DatabaseSchema { schema } => Ok(schema), + other => Err(InternalError::UnexpectedResponseType(format!("{other:?}")).into()), + } + } + + pub(crate) async fn database_type_schema(&self, database_name: String) -> Result { + match self.request_async(Request::DatabaseTypeSchema { database_name }).await? { + Response::DatabaseTypeSchema { schema } => Ok(schema), + other => Err(InternalError::UnexpectedResponseType(format!("{other:?}")).into()), + } + } + + pub(crate) async fn database_rule_schema(&self, database_name: String) -> Result { + match self.request_async(Request::DatabaseRuleSchema { database_name }).await? { + Response::DatabaseRuleSchema { schema } => Ok(schema), + other => Err(InternalError::UnexpectedResponseType(format!("{other:?}")).into()), + } + } + + pub(crate) async fn delete_database(&self, database_name: String) -> Result { + self.request_async(Request::DatabaseDelete { database_name }).await?; + Ok(()) + } + + pub(crate) async fn open_session( + &self, + database_name: String, + session_type: SessionType, + options: Options, + ) -> Result { + let start = Instant::now(); + match self.request_async(Request::SessionOpen { database_name, session_type, options }).await? { + Response::SessionOpen { session_id, server_duration } => { + let (pulse_shutdown_sink, pulse_shutdown_source) = unbounded_async(); + self.open_sessions.lock().unwrap().insert(session_id.clone(), pulse_shutdown_sink); + self.background_runtime.spawn(session_pulse( + session_id.clone(), + self.request_transmitter.clone(), + pulse_shutdown_source, + )); + Ok(SessionInfo { + address: self.address.clone(), + session_id, + network_latency: start.elapsed() - server_duration, + }) + } + other => Err(InternalError::UnexpectedResponseType(format!("{other:?}")).into()), + } + } + + pub(crate) fn close_session(&self, session_id: SessionID) -> Result { + if let Some(sink) = self.open_sessions.lock().unwrap().remove(&session_id) { + sink.send(()).ok(); + } + self.request_blocking(Request::SessionClose { session_id })?; + Ok(()) + } + + pub(crate) async fn open_transaction( + &self, + session_id: SessionID, + transaction_type: TransactionType, + options: Options, + network_latency: Duration, + ) -> Result { + match self + .request_async(Request::Transaction(TransactionRequest::Open { + session_id, + transaction_type, + options: options.clone(), + network_latency, + })) + .await? + { + Response::TransactionOpen { request_sink, response_source } => { + let transmitter = TransactionTransmitter::new(&self.background_runtime, request_sink, response_source); + Ok(TransactionStream::new(transaction_type, options, transmitter)) + } + other => Err(InternalError::UnexpectedResponseType(format!("{other:?}")).into()), + } + } +} + +impl fmt::Debug for ServerConnection { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ServerConnection") + .field("address", &self.address) + .field("open_sessions", &self.open_sessions) + .finish() + } +} + +async fn session_pulse( + session_id: SessionID, + request_transmitter: Arc, + mut shutdown_source: UnboundedReceiver<()>, +) { + const PULSE_INTERVAL: Duration = Duration::from_secs(5); + let mut next_pulse = Instant::now(); + loop { + select! { + _ = sleep_until(next_pulse) => { + request_transmitter + .request_async(Request::SessionPulse { session_id: session_id.clone() }) + .await + .ok(); + next_pulse += PULSE_INTERVAL; + } + _ = shutdown_source.recv() => break, + } + } +} diff --git a/src/connection/core/client.rs b/src/connection/core/client.rs deleted file mode 100644 index aa613a93..00000000 --- a/src/connection/core/client.rs +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Copyright (C) 2022 Vaticle - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use crate::{ - common::{CoreRPC, Result, SessionType}, - connection::{core, server}, -}; - -pub struct Client { - databases: core::DatabaseManager, - core_rpc: CoreRPC, -} - -impl Client { - pub async fn new(address: &str) -> Result { - let core_rpc = CoreRPC::connect(address.parse()?).await?; - Ok(Self { databases: core::DatabaseManager::new(core_rpc.clone()), core_rpc }) - } - - pub async fn with_default_address() -> Result { - Self::new("http://localhost:1729").await - } - - pub fn databases(&mut self) -> &mut core::DatabaseManager { - &mut self.databases - } - - pub async fn session( - &mut self, - database_name: &str, - session_type: SessionType, - ) -> Result { - self.session_with_options(database_name, session_type, core::Options::default()).await - } - - pub async fn session_with_options( - &mut self, - database_name: &str, - session_type: SessionType, - options: core::Options, - ) -> Result { - server::Session::new(database_name, session_type, options, self.core_rpc.clone().into()) - .await - } -} diff --git a/src/connection/core/database_manager.rs b/src/connection/core/database_manager.rs deleted file mode 100644 index 1d28d901..00000000 --- a/src/connection/core/database_manager.rs +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Copyright (C) 2022 Vaticle - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use crate::{ - common::{ - error::ClientError, - rpc::builder::core::database_manager::{all_req, contains_req, create_req}, - CoreRPC, Result, - }, - connection::server, -}; - -/// An interface for performing database-level operations against the connected server. -/// These operations include: -/// -/// - Listing [all databases][DatabaseManager::all] -/// - Creating a [new database][DatabaseManager::create] -/// - Checking if a database [exists][DatabaseManager::contains] -/// - Retrieving a [specific database][DatabaseManager::get] in order to perform further operations on it -/// -/// These operations all connect to the server to retrieve results. In the event of a connection -/// failure or other problem executing the operation, they will return an [`Err`][Err] result. -#[derive(Clone, Debug)] -pub struct DatabaseManager { - pub(crate) core_rpc: CoreRPC, -} - -impl DatabaseManager { - pub(crate) fn new(core_rpc: CoreRPC) -> Self { - DatabaseManager { core_rpc } - } - - /// Retrieves a single [`Database`][Database] by name. Returns an [`Err`][Err] if there does not - /// exist a database with the provided name. - pub async fn get(&mut self, name: &str) -> Result { - match self.contains(name).await? { - true => Ok(server::Database::new(name, self.core_rpc.clone().into())), - false => Err(ClientError::DatabaseDoesNotExist(name.to_string()))?, - } - } - - pub async fn contains(&mut self, name: &str) -> Result { - self.core_rpc.databases_contains(contains_req(name)).await.map(|res| res.contains) - } - - pub async fn create(&mut self, name: &str) -> Result { - self.core_rpc.databases_create(create_req(name)).await.map(|_| ()) - } - - pub async fn all(&mut self) -> Result> { - self.core_rpc.databases_all(all_req()).await.map(|res| { - res.names - .iter() - .map(|name| server::Database::new(name, self.core_rpc.clone().into())) - .collect() - }) - } -} diff --git a/src/connection/core/options.rs b/src/connection/core/options.rs deleted file mode 100644 index 778d69c4..00000000 --- a/src/connection/core/options.rs +++ /dev/null @@ -1,134 +0,0 @@ -/* - * Copyright (C) 2022 Vaticle - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::time::Duration; - -use typedb_protocol::{ - options::{ - ExplainOpt::Explain, InferOpt::Infer, ParallelOpt::Parallel, PrefetchOpt::Prefetch, - PrefetchSizeOpt::PrefetchSize, ReadAnyReplicaOpt::ReadAnyReplica, - SchemaLockAcquireTimeoutOpt::SchemaLockAcquireTimeoutMillis, - SessionIdleTimeoutOpt::SessionIdleTimeoutMillis, TraceInferenceOpt::TraceInference, - TransactionTimeoutOpt::TransactionTimeoutMillis, - }, - Options as OptionsProto, -}; - -macro_rules! options { - {pub struct $name:ident { $(pub $field_name:ident : Option<$field_type:ty>),* $(,)? }} => { - #[derive(Clone, Debug, Default)] - pub struct $name { - $(pub $field_name: Option<$field_type>,)* - } - - impl $name { - $( - pub fn $field_name(mut self, value: $field_type) -> Self { - self.$field_name = value.into(); - self - } - )* - } - }; -} - -options! { - pub struct Options { - pub infer: Option, - pub trace_inference: Option, - pub explain: Option, - pub parallel: Option, - pub prefetch: Option, - pub prefetch_size: Option, - pub session_idle_timeout: Option, - pub transaction_timeout: Option, - pub schema_lock_acquire_timeout: Option, - } -} - -options! { - pub struct ClusterOptions { - pub infer: Option, - pub trace_inference: Option, - pub explain: Option, - pub parallel: Option, - pub prefetch: Option, - pub prefetch_size: Option, - pub session_idle_timeout: Option, - pub transaction_timeout: Option, - pub schema_lock_acquire_timeout: Option, - pub read_any_replica: Option, - } -} - -impl Options { - pub fn new_core() -> Options { - Options::default() - } - - pub fn new_cluster() -> ClusterOptions { - ClusterOptions::default() - } - - pub(crate) fn to_proto(&self) -> OptionsProto { - OptionsProto { - infer_opt: self.infer.map(Infer), - trace_inference_opt: self.trace_inference.map(TraceInference), - explain_opt: self.explain.map(Explain), - parallel_opt: self.parallel.map(Parallel), - prefetch_size_opt: self.prefetch_size.map(PrefetchSize), - prefetch_opt: self.prefetch.map(Prefetch), - session_idle_timeout_opt: self - .session_idle_timeout - .map(|val| SessionIdleTimeoutMillis(val.as_millis() as i32)), - transaction_timeout_opt: self - .transaction_timeout - .map(|val| TransactionTimeoutMillis(val.as_millis() as i32)), - schema_lock_acquire_timeout_opt: self - .schema_lock_acquire_timeout - .map(|val| SchemaLockAcquireTimeoutMillis(val.as_millis() as i32)), - read_any_replica_opt: None, - } - } -} - -impl ClusterOptions { - pub(crate) fn to_proto(&self) -> OptionsProto { - OptionsProto { - infer_opt: self.infer.map(Infer), - trace_inference_opt: self.trace_inference.map(TraceInference), - explain_opt: self.explain.map(Explain), - parallel_opt: self.parallel.map(Parallel), - prefetch_size_opt: self.prefetch_size.map(PrefetchSize), - prefetch_opt: self.prefetch.map(Prefetch), - session_idle_timeout_opt: self - .session_idle_timeout - .map(|val| SessionIdleTimeoutMillis(val.as_millis() as i32)), - transaction_timeout_opt: self - .transaction_timeout - .map(|val| TransactionTimeoutMillis(val.as_millis() as i32)), - schema_lock_acquire_timeout_opt: self - .schema_lock_acquire_timeout - .map(|val| SchemaLockAcquireTimeoutMillis(val.as_millis() as i32)), - read_any_replica_opt: self.read_any_replica.map(ReadAnyReplica), - } - } -} diff --git a/src/connection/message.rs b/src/connection/message.rs new file mode 100644 index 00000000..9647fb13 --- /dev/null +++ b/src/connection/message.rs @@ -0,0 +1,147 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::time::Duration; + +use tokio::sync::mpsc::UnboundedSender; +use tonic::Streaming; +use typedb_protocol::transaction; + +use crate::{ + answer::{ConceptMap, Numeric}, + common::{address::Address, info::DatabaseInfo, RequestID, SessionID}, + Options, SessionType, TransactionType, +}; + +#[derive(Debug)] +pub(super) enum Request { + ServersAll, + + DatabasesContains { database_name: String }, + DatabaseCreate { database_name: String }, + DatabaseGet { database_name: String }, + DatabasesAll, + + DatabaseSchema { database_name: String }, + DatabaseTypeSchema { database_name: String }, + DatabaseRuleSchema { database_name: String }, + DatabaseDelete { database_name: String }, + + SessionOpen { database_name: String, session_type: SessionType, options: Options }, + SessionClose { session_id: SessionID }, + SessionPulse { session_id: SessionID }, + + Transaction(TransactionRequest), +} + +#[derive(Debug)] +pub(super) enum Response { + ServersAll { + servers: Vec
, + }, + + DatabasesContains { + contains: bool, + }, + DatabaseCreate, + DatabaseGet { + database: DatabaseInfo, + }, + DatabasesAll { + databases: Vec, + }, + + DatabaseDelete, + DatabaseSchema { + schema: String, + }, + DatabaseTypeSchema { + schema: String, + }, + DatabaseRuleSchema { + schema: String, + }, + + SessionOpen { + session_id: SessionID, + server_duration: Duration, + }, + SessionPulse, + SessionClose, + + TransactionOpen { + request_sink: UnboundedSender, + response_source: Streaming, + }, +} + +#[derive(Debug)] +pub(super) enum TransactionRequest { + Open { session_id: SessionID, transaction_type: TransactionType, options: Options, network_latency: Duration }, + Commit, + Rollback, + Query(QueryRequest), + Stream { request_id: RequestID }, +} + +#[derive(Debug)] +pub(super) enum TransactionResponse { + Open, + Commit, + Rollback, + Query(QueryResponse), +} + +#[derive(Debug)] +pub(super) enum QueryRequest { + Define { query: String, options: Options }, + Undefine { query: String, options: Options }, + Delete { query: String, options: Options }, + + Match { query: String, options: Options }, + Insert { query: String, options: Options }, + Update { query: String, options: Options }, + + MatchAggregate { query: String, options: Options }, + + Explain { explainable_id: i64, options: Options }, // TODO: ID type + + MatchGroup { query: String, options: Options }, + MatchGroupAggregate { query: String, options: Options }, +} + +#[derive(Debug)] +pub(super) enum QueryResponse { + Define, + Undefine, + Delete, + + Match { answers: Vec }, + Insert { answers: Vec }, + Update { answers: Vec }, + + MatchAggregate { answer: Numeric }, + + Explain {}, // TODO: explanations + + MatchGroup {}, // TODO: ConceptMapGroup + MatchGroupAggregate {}, // TODO: NumericGroup +} diff --git a/src/connection/mod.rs b/src/connection/mod.rs index 671b1c06..6b96ace1 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -19,6 +19,11 @@ * under the License. */ -pub mod cluster; -pub mod core; -pub mod server; +mod connection; +mod message; +mod network; +mod runtime; +mod transaction_stream; + +pub use self::connection::Connection; +pub(crate) use self::{connection::ServerConnection, transaction_stream::TransactionStream}; diff --git a/src/connection/network/channel.rs b/src/connection/network/channel.rs new file mode 100644 index 00000000..e75171b7 --- /dev/null +++ b/src/connection/network/channel.rs @@ -0,0 +1,137 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::sync::{Arc, RwLock}; + +use tonic::{ + body::BoxBody, + client::GrpcService, + service::{ + interceptor::{self, InterceptedService}, + Interceptor, + }, + transport::{channel, Channel, Error as TonicError}, + Request, Status, +}; + +use crate::{ + common::{address::Address, Result, StdResult}, + Credential, +}; + +type ResponseFuture = interceptor::ResponseFuture; + +pub(super) type PlainTextChannel = InterceptedService; +pub(super) type CallCredChannel = InterceptedService; + +pub(super) trait GRPCChannel: + GrpcService + Clone + Send + 'static +{ + fn is_plaintext(&self) -> bool; +} + +impl GRPCChannel for PlainTextChannel { + fn is_plaintext(&self) -> bool { + true + } +} + +impl GRPCChannel for CallCredChannel { + fn is_plaintext(&self) -> bool { + false + } +} + +pub(super) fn open_plaintext_channel(address: Address) -> PlainTextChannel { + PlainTextChannel::new(Channel::builder(address.into_uri()).connect_lazy(), PlainTextFacade) +} + +#[derive(Clone, Debug)] +pub(super) struct PlainTextFacade; + +impl Interceptor for PlainTextFacade { + fn call(&mut self, request: Request<()>) -> StdResult, Status> { + Ok(request) + } +} + +pub(super) fn open_encrypted_channel( + address: Address, + credential: Credential, +) -> Result<(CallCredChannel, Arc)> { + let mut builder = Channel::builder(address.into_uri()); + if credential.is_tls_enabled() { + builder = builder.tls_config(credential.tls_config().clone().unwrap())?; + } + let channel = builder.connect_lazy(); + let call_credentials = Arc::new(CallCredentials::new(credential)); + Ok((CallCredChannel::new(channel, CredentialInjector::new(call_credentials.clone())), call_credentials)) +} + +#[derive(Debug)] +pub(super) struct CallCredentials { + credential: Credential, + token: RwLock>, +} + +impl CallCredentials { + pub(super) fn new(credential: Credential) -> Self { + Self { credential, token: RwLock::new(None) } + } + + pub(super) fn username(&self) -> &str { + self.credential.username() + } + + pub(super) fn set_token(&self, token: String) { + *self.token.write().unwrap() = Some(token); + } + + pub(super) fn reset_token(&self) { + *self.token.write().unwrap() = None; + } + + pub(super) fn inject(&self, mut request: Request<()>) -> Request<()> { + request.metadata_mut().insert("username", self.credential.username().try_into().unwrap()); + match &*self.token.read().unwrap() { + Some(token) => request.metadata_mut().insert("token", token.try_into().unwrap()), + None => request.metadata_mut().insert("password", self.credential.password().try_into().unwrap()), + }; + request + } +} + +#[derive(Clone, Debug)] +pub(super) struct CredentialInjector { + call_credentials: Arc, +} + +impl CredentialInjector { + pub(super) fn new(call_credentials: Arc) -> Self { + Self { call_credentials } + } +} + +impl Interceptor for CredentialInjector { + fn call(&mut self, request: Request<()>) -> StdResult, Status> { + Ok(self.call_credentials.inject(request)) + } +} diff --git a/src/connection/core/mod.rs b/src/connection/network/mod.rs similarity index 86% rename from src/connection/core/mod.rs rename to src/connection/network/mod.rs index 4bbd6efe..6bd85c52 100644 --- a/src/connection/core/mod.rs +++ b/src/connection/network/mod.rs @@ -19,8 +19,7 @@ * under the License. */ -mod client; -mod database_manager; -mod options; - -pub use self::{client::Client, database_manager::DatabaseManager, options::Options}; +mod channel; +mod proto; +mod stub; +pub(super) mod transmitter; diff --git a/src/connection/network/proto/common.rs b/src/connection/network/proto/common.rs new file mode 100644 index 00000000..99c69a24 --- /dev/null +++ b/src/connection/network/proto/common.rs @@ -0,0 +1,74 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use typedb_protocol::{ + options::{ + ExplainOpt::Explain, InferOpt::Infer, ParallelOpt::Parallel, PrefetchOpt::Prefetch, + PrefetchSizeOpt::PrefetchSize, ReadAnyReplicaOpt::ReadAnyReplica, + SchemaLockAcquireTimeoutOpt::SchemaLockAcquireTimeoutMillis, SessionIdleTimeoutOpt::SessionIdleTimeoutMillis, + TraceInferenceOpt::TraceInference, TransactionTimeoutOpt::TransactionTimeoutMillis, + }, + session, transaction, Options as OptionsProto, +}; + +use super::IntoProto; +use crate::{Options, SessionType, TransactionType}; + +impl IntoProto for SessionType { + fn into_proto(self) -> session::Type { + match self { + SessionType::Data => session::Type::Data, + SessionType::Schema => session::Type::Schema, + } + } +} + +impl IntoProto for TransactionType { + fn into_proto(self) -> transaction::Type { + match self { + TransactionType::Read => transaction::Type::Read, + TransactionType::Write => transaction::Type::Write, + } + } +} + +impl IntoProto for Options { + fn into_proto(self) -> OptionsProto { + OptionsProto { + infer_opt: self.infer.map(Infer), + trace_inference_opt: self.trace_inference.map(TraceInference), + explain_opt: self.explain.map(Explain), + parallel_opt: self.parallel.map(Parallel), + prefetch_size_opt: self.prefetch_size.map(PrefetchSize), + prefetch_opt: self.prefetch.map(Prefetch), + session_idle_timeout_opt: self + .session_idle_timeout + .map(|val| SessionIdleTimeoutMillis(val.as_millis() as i32)), + transaction_timeout_opt: self + .transaction_timeout + .map(|val| TransactionTimeoutMillis(val.as_millis() as i32)), + schema_lock_acquire_timeout_opt: self + .schema_lock_acquire_timeout + .map(|val| SchemaLockAcquireTimeoutMillis(val.as_millis() as i32)), + read_any_replica_opt: self.read_any_replica.map(ReadAnyReplica), + } + } +} diff --git a/src/connection/network/proto/concept.rs b/src/connection/network/proto/concept.rs new file mode 100644 index 00000000..495f2529 --- /dev/null +++ b/src/connection/network/proto/concept.rs @@ -0,0 +1,196 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::collections::HashMap; + +use chrono::NaiveDateTime; +use typedb_protocol::{ + attribute::value::Value as ValueProto, attribute_type::ValueType, concept as concept_proto, numeric::Value, + r#type::Encoding, Concept as ConceptProto, ConceptMap as ConceptMapProto, Numeric as NumericProto, + Thing as ThingProto, Type as TypeProto, +}; + +use super::TryFromProto; +use crate::{ + answer::{ConceptMap, Numeric}, + concept::{ + Attribute, AttributeType, BooleanAttribute, BooleanAttributeType, Concept, DateTimeAttribute, + DateTimeAttributeType, DoubleAttribute, DoubleAttributeType, Entity, EntityType, LongAttribute, + LongAttributeType, Relation, RelationType, RoleType, RootAttributeType, RootThingType, ScopedLabel, + StringAttribute, StringAttributeType, Thing, ThingType, Type, + }, + connection::network::proto::FromProto, + error::{ConnectionError, InternalError}, + Result, +}; + +impl TryFromProto for Numeric { + fn try_from_proto(proto: NumericProto) -> Result { + match proto.value { + Some(Value::LongValue(long)) => Ok(Numeric::Long(long)), + Some(Value::DoubleValue(double)) => Ok(Numeric::Double(double)), + Some(Value::Nan(_)) => Ok(Numeric::NaN), + None => Err(ConnectionError::MissingResponseField("value").into()), + } + } +} + +impl TryFromProto for ConceptMap { + fn try_from_proto(proto: ConceptMapProto) -> Result { + let mut map = HashMap::with_capacity(proto.map.len()); + for (k, v) in proto.map { + map.insert(k, Concept::try_from_proto(v)?); + } + Ok(Self { map }) + } +} + +impl TryFromProto for Concept { + fn try_from_proto(proto: ConceptProto) -> Result { + let concept = proto.concept.ok_or(ConnectionError::MissingResponseField("concept"))?; + match concept { + concept_proto::Concept::Thing(thing) => Ok(Self::Thing(Thing::try_from_proto(thing)?)), + concept_proto::Concept::Type(type_) => Ok(Self::Type(Type::try_from_proto(type_)?)), + } + } +} + +impl TryFromProto for Encoding { + fn try_from_proto(proto: i32) -> Result { + Self::from_i32(proto).ok_or(InternalError::EnumOutOfBounds(proto, "Encoding").into()) + } +} + +impl TryFromProto for Type { + fn try_from_proto(proto: TypeProto) -> Result { + match Encoding::try_from_proto(proto.encoding)? { + Encoding::ThingType => Ok(Self::Thing(ThingType::Root(RootThingType::default()))), + Encoding::EntityType => Ok(Self::Thing(ThingType::Entity(EntityType::from_proto(proto)))), + Encoding::RelationType => Ok(Self::Thing(ThingType::Relation(RelationType::from_proto(proto)))), + Encoding::AttributeType => Ok(Self::Thing(ThingType::Attribute(AttributeType::try_from_proto(proto)?))), + Encoding::RoleType => Ok(Self::Role(RoleType::from_proto(proto))), + } + } +} + +impl FromProto for EntityType { + fn from_proto(proto: TypeProto) -> Self { + Self::new(proto.label) + } +} + +impl FromProto for RelationType { + fn from_proto(proto: TypeProto) -> Self { + Self::new(proto.label) + } +} + +impl TryFromProto for ValueType { + fn try_from_proto(proto: i32) -> Result { + Self::from_i32(proto).ok_or(InternalError::EnumOutOfBounds(proto, "ValueType").into()) + } +} + +impl TryFromProto for AttributeType { + fn try_from_proto(proto: TypeProto) -> Result { + match ValueType::try_from_proto(proto.value_type)? { + ValueType::Object => Ok(Self::Root(RootAttributeType::default())), + ValueType::Boolean => Ok(Self::Boolean(BooleanAttributeType { label: proto.label })), + ValueType::Long => Ok(Self::Long(LongAttributeType { label: proto.label })), + ValueType::Double => Ok(Self::Double(DoubleAttributeType { label: proto.label })), + ValueType::String => Ok(Self::String(StringAttributeType { label: proto.label })), + ValueType::Datetime => Ok(Self::DateTime(DateTimeAttributeType { label: proto.label })), + } + } +} + +impl FromProto for RoleType { + fn from_proto(proto: TypeProto) -> Self { + Self::new(ScopedLabel::new(proto.scope, proto.label)) + } +} + +impl TryFromProto for Thing { + fn try_from_proto(proto: ThingProto) -> Result { + let encoding = proto.r#type.clone().ok_or(ConnectionError::MissingResponseField("type"))?.encoding; + match Encoding::try_from_proto(encoding)? { + Encoding::EntityType => Ok(Self::Entity(Entity::try_from_proto(proto)?)), + Encoding::RelationType => Ok(Self::Relation(Relation::try_from_proto(proto)?)), + Encoding::AttributeType => Ok(Self::Attribute(Attribute::try_from_proto(proto)?)), + _ => todo!(), + } + } +} + +impl TryFromProto for Entity { + fn try_from_proto(proto: ThingProto) -> Result { + Ok(Self { + type_: EntityType::from_proto(proto.r#type.ok_or(ConnectionError::MissingResponseField("type"))?), + iid: proto.iid, + }) + } +} + +impl TryFromProto for Relation { + fn try_from_proto(proto: ThingProto) -> Result { + Ok(Self { + type_: RelationType::from_proto(proto.r#type.ok_or(ConnectionError::MissingResponseField("type"))?), + iid: proto.iid, + }) + } +} + +impl TryFromProto for Attribute { + fn try_from_proto(proto: ThingProto) -> Result { + let value = proto.value.and_then(|v| v.value).ok_or(ConnectionError::MissingResponseField("value"))?; + + let value_type = proto.r#type.ok_or(ConnectionError::MissingResponseField("type"))?.value_type; + let iid = proto.iid; + + match ValueType::try_from_proto(value_type)? { + ValueType::Object => todo!(), + ValueType::Boolean => Ok(Self::Boolean(BooleanAttribute { + value: if let ValueProto::Boolean(value) = value { value } else { unreachable!() }, + iid, + })), + ValueType::Long => Ok(Self::Long(LongAttribute { + value: if let ValueProto::Long(value) = value { value } else { unreachable!() }, + iid, + })), + ValueType::Double => Ok(Self::Double(DoubleAttribute { + value: if let ValueProto::Double(value) = value { value } else { unreachable!() }, + iid, + })), + ValueType::String => Ok(Self::String(StringAttribute { + value: if let ValueProto::String(value) = value { value } else { unreachable!() }, + iid, + })), + ValueType::Datetime => Ok(Self::DateTime(DateTimeAttribute { + value: if let ValueProto::DateTime(value) = value { + NaiveDateTime::from_timestamp_opt(value / 1000, (value % 1000) as u32 * 1_000_000).unwrap() + } else { + unreachable!() + }, + iid, + })), + } + } +} diff --git a/src/connection/network/proto/database.rs b/src/connection/network/proto/database.rs new file mode 100644 index 00000000..bc73473d --- /dev/null +++ b/src/connection/network/proto/database.rs @@ -0,0 +1,49 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use itertools::Itertools; +use typedb_protocol::{cluster_database::Replica as ReplicaProto, ClusterDatabase as DatabaseProto}; + +use super::TryFromProto; +use crate::{ + common::info::{DatabaseInfo, ReplicaInfo}, + Result, +}; + +impl TryFromProto for DatabaseInfo { + fn try_from_proto(proto: DatabaseProto) -> Result { + Ok(Self { + name: proto.name, + replicas: proto.replicas.into_iter().map(ReplicaInfo::try_from_proto).try_collect()?, + }) + } +} + +impl TryFromProto for ReplicaInfo { + fn try_from_proto(proto: ReplicaProto) -> Result { + Ok(Self { + address: proto.address.as_str().parse()?, + is_primary: proto.primary, + is_preferred: proto.preferred, + term: proto.term, + }) + } +} diff --git a/src/connection/network/proto/message.rs b/src/connection/network/proto/message.rs new file mode 100644 index 00000000..1a0c9f81 --- /dev/null +++ b/src/connection/network/proto/message.rs @@ -0,0 +1,370 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::time::Duration; + +use itertools::Itertools; +use typedb_protocol::{ + cluster_database_manager, core_database, core_database_manager, query_manager, server_manager, session, transaction, +}; + +use super::{FromProto, IntoProto, TryFromProto}; +use crate::{ + answer::{ConceptMap, Numeric}, + common::{info::DatabaseInfo, RequestID, Result}, + connection::{ + message::{QueryRequest, QueryResponse, Request, Response, TransactionRequest, TransactionResponse}, + network::proto::TryIntoProto, + }, + error::{ConnectionError, InternalError}, +}; + +impl TryIntoProto for Request { + fn try_into_proto(self) -> Result { + match self { + Request::ServersAll => Ok(server_manager::all::Req {}), + other => Err(InternalError::UnexpectedRequestType(format!("{other:?}")).into()), + } + } +} + +impl TryIntoProto for Request { + fn try_into_proto(self) -> Result { + match self { + Request::DatabasesContains { database_name } => { + Ok(core_database_manager::contains::Req { name: database_name }) + } + other => Err(InternalError::UnexpectedRequestType(format!("{other:?}")).into()), + } + } +} + +impl TryIntoProto for Request { + fn try_into_proto(self) -> Result { + match self { + Request::DatabaseCreate { database_name } => Ok(core_database_manager::create::Req { name: database_name }), + other => Err(InternalError::UnexpectedRequestType(format!("{other:?}")).into()), + } + } +} + +impl TryIntoProto for Request { + fn try_into_proto(self) -> Result { + match self { + Request::DatabaseGet { database_name } => Ok(cluster_database_manager::get::Req { name: database_name }), + other => Err(InternalError::UnexpectedRequestType(format!("{other:?}")).into()), + } + } +} + +impl TryIntoProto for Request { + fn try_into_proto(self) -> Result { + match self { + Request::DatabasesAll => Ok(cluster_database_manager::all::Req {}), + other => Err(InternalError::UnexpectedRequestType(format!("{other:?}")).into()), + } + } +} + +impl TryIntoProto for Request { + fn try_into_proto(self) -> Result { + match self { + Request::DatabaseDelete { database_name } => Ok(core_database::delete::Req { name: database_name }), + other => Err(InternalError::UnexpectedRequestType(format!("{other:?}")).into()), + } + } +} + +impl TryIntoProto for Request { + fn try_into_proto(self) -> Result { + match self { + Request::DatabaseSchema { database_name } => Ok(core_database::schema::Req { name: database_name }), + other => Err(InternalError::UnexpectedRequestType(format!("{other:?}")).into()), + } + } +} + +impl TryIntoProto for Request { + fn try_into_proto(self) -> Result { + match self { + Request::DatabaseTypeSchema { database_name } => { + Ok(core_database::type_schema::Req { name: database_name }) + } + other => Err(InternalError::UnexpectedRequestType(format!("{other:?}")).into()), + } + } +} + +impl TryIntoProto for Request { + fn try_into_proto(self) -> Result { + match self { + Request::DatabaseRuleSchema { database_name } => { + Ok(core_database::rule_schema::Req { name: database_name }) + } + other => Err(InternalError::UnexpectedRequestType(format!("{other:?}")).into()), + } + } +} + +impl TryIntoProto for Request { + fn try_into_proto(self) -> Result { + match self { + Request::SessionOpen { database_name, session_type, options } => Ok(session::open::Req { + database: database_name, + r#type: session_type.into_proto().into(), + options: Some(options.into_proto()), + }), + other => Err(InternalError::UnexpectedRequestType(format!("{other:?}")).into()), + } + } +} + +impl TryIntoProto for Request { + fn try_into_proto(self) -> Result { + match self { + Request::SessionPulse { session_id } => Ok(session::pulse::Req { session_id: session_id.into() }), + other => Err(InternalError::UnexpectedRequestType(format!("{other:?}")).into()), + } + } +} + +impl TryIntoProto for Request { + fn try_into_proto(self) -> Result { + match self { + Request::SessionClose { session_id } => Ok(session::close::Req { session_id: session_id.into() }), + other => Err(InternalError::UnexpectedRequestType(format!("{other:?}")).into()), + } + } +} + +impl TryIntoProto for Request { + fn try_into_proto(self) -> Result { + match self { + Request::Transaction(transaction_req) => { + Ok(transaction::Client { reqs: vec![transaction_req.into_proto()] }) + } + other => Err(InternalError::UnexpectedRequestType(format!("{other:?}")).into()), + } + } +} + +impl TryFromProto for Response { + fn try_from_proto(proto: server_manager::all::Res) -> Result { + let servers = proto.servers.into_iter().map(|server| server.address.parse()).try_collect()?; + Ok(Response::ServersAll { servers }) + } +} + +impl FromProto for Response { + fn from_proto(proto: core_database_manager::contains::Res) -> Self { + Self::DatabasesContains { contains: proto.contains } + } +} + +impl FromProto for Response { + fn from_proto(_proto: core_database_manager::create::Res) -> Self { + Self::DatabaseCreate + } +} + +impl TryFromProto for Response { + fn try_from_proto(proto: cluster_database_manager::get::Res) -> Result { + Ok(Response::DatabaseGet { + database: DatabaseInfo::try_from_proto( + proto.database.ok_or(ConnectionError::MissingResponseField("database"))?, + )?, + }) + } +} + +impl TryFromProto for Response { + fn try_from_proto(proto: cluster_database_manager::all::Res) -> Result { + Ok(Response::DatabasesAll { + databases: proto.databases.into_iter().map(DatabaseInfo::try_from_proto).try_collect()?, + }) + } +} + +impl FromProto for Response { + fn from_proto(_proto: core_database::delete::Res) -> Self { + Self::DatabaseDelete + } +} + +impl FromProto for Response { + fn from_proto(proto: core_database::schema::Res) -> Self { + Self::DatabaseSchema { schema: proto.schema } + } +} + +impl FromProto for Response { + fn from_proto(proto: core_database::type_schema::Res) -> Self { + Self::DatabaseTypeSchema { schema: proto.schema } + } +} + +impl FromProto for Response { + fn from_proto(proto: core_database::rule_schema::Res) -> Self { + Self::DatabaseRuleSchema { schema: proto.schema } + } +} + +impl FromProto for Response { + fn from_proto(proto: session::open::Res) -> Self { + Self::SessionOpen { + session_id: proto.session_id.into(), + server_duration: Duration::from_millis(proto.server_duration_millis as u64), + } + } +} + +impl FromProto for Response { + fn from_proto(_proto: session::pulse::Res) -> Self { + Self::SessionPulse + } +} + +impl FromProto for Response { + fn from_proto(_proto: session::close::Res) -> Self { + Self::SessionClose + } +} + +impl IntoProto for TransactionRequest { + fn into_proto(self) -> transaction::Req { + let mut request_id = None; + + let req = match self { + TransactionRequest::Open { session_id, transaction_type, options, network_latency } => { + transaction::req::Req::OpenReq(transaction::open::Req { + session_id: session_id.into(), + r#type: transaction_type.into_proto().into(), + options: Some(options.into_proto()), + network_latency_millis: network_latency.as_millis() as i32, + }) + } + TransactionRequest::Commit => transaction::req::Req::CommitReq(transaction::commit::Req {}), + TransactionRequest::Rollback => transaction::req::Req::RollbackReq(transaction::rollback::Req {}), + TransactionRequest::Query(query_request) => { + transaction::req::Req::QueryManagerReq(query_request.into_proto()) + } + TransactionRequest::Stream { request_id: req_id } => { + request_id = Some(req_id); + transaction::req::Req::StreamReq(transaction::stream::Req {}) + } + }; + + transaction::Req { + req_id: request_id.unwrap_or_else(RequestID::generate).into(), + metadata: Default::default(), + req: Some(req), + } + } +} + +impl TryFromProto for TransactionResponse { + fn try_from_proto(proto: transaction::Res) -> Result { + match proto.res { + Some(transaction::res::Res::OpenRes(_)) => Ok(TransactionResponse::Open), + Some(transaction::res::Res::CommitRes(_)) => Ok(TransactionResponse::Commit), + Some(transaction::res::Res::RollbackRes(_)) => Ok(TransactionResponse::Rollback), + Some(transaction::res::Res::QueryManagerRes(res)) => { + Ok(TransactionResponse::Query(QueryResponse::try_from_proto(res)?)) + } + Some(_) => todo!(), + None => Err(ConnectionError::MissingResponseField("res").into()), + } + } +} + +impl TryFromProto for TransactionResponse { + fn try_from_proto(proto: transaction::ResPart) -> Result { + match proto.res { + Some(transaction::res_part::Res::QueryManagerResPart(res_part)) => { + Ok(TransactionResponse::Query(QueryResponse::try_from_proto(res_part)?)) + } + Some(_) => todo!(), + None => Err(ConnectionError::MissingResponseField("res").into()), + } + } +} + +impl IntoProto for QueryRequest { + fn into_proto(self) -> query_manager::Req { + let (req, options) = match self { + QueryRequest::Define { query, options } => { + (query_manager::req::Req::DefineReq(query_manager::define::Req { query }), options) + } + QueryRequest::Undefine { query, options } => { + (query_manager::req::Req::UndefineReq(query_manager::undefine::Req { query }), options) + } + QueryRequest::Delete { query, options } => { + (query_manager::req::Req::DeleteReq(query_manager::delete::Req { query }), options) + } + + QueryRequest::Match { query, options } => { + (query_manager::req::Req::MatchReq(query_manager::r#match::Req { query }), options) + } + QueryRequest::Insert { query, options } => { + (query_manager::req::Req::InsertReq(query_manager::insert::Req { query }), options) + } + QueryRequest::Update { query, options } => { + (query_manager::req::Req::UpdateReq(query_manager::update::Req { query }), options) + } + + QueryRequest::MatchAggregate { query, options } => { + (query_manager::req::Req::MatchAggregateReq(query_manager::match_aggregate::Req { query }), options) + } + + _ => todo!(), + }; + query_manager::Req { req: Some(req), options: Some(options.into_proto()) } + } +} + +impl TryFromProto for QueryResponse { + fn try_from_proto(proto: query_manager::Res) -> Result { + match proto.res { + Some(query_manager::res::Res::DefineRes(_)) => Ok(QueryResponse::Define), + Some(query_manager::res::Res::UndefineRes(_)) => Ok(QueryResponse::Undefine), + Some(query_manager::res::Res::DeleteRes(_)) => Ok(QueryResponse::Delete), + Some(query_manager::res::Res::MatchAggregateRes(res)) => Ok(QueryResponse::MatchAggregate { + answer: Numeric::try_from_proto(res.answer.ok_or(ConnectionError::MissingResponseField("answer"))?)?, + }), + None => Err(ConnectionError::MissingResponseField("res").into()), + } + } +} + +impl TryFromProto for QueryResponse { + fn try_from_proto(proto: query_manager::ResPart) -> Result { + match proto.res { + Some(query_manager::res_part::Res::MatchResPart(res)) => Ok(QueryResponse::Match { + answers: res.answers.into_iter().map(ConceptMap::try_from_proto).try_collect()?, + }), + Some(query_manager::res_part::Res::InsertResPart(res)) => Ok(QueryResponse::Insert { + answers: res.answers.into_iter().map(ConceptMap::try_from_proto).try_collect()?, + }), + Some(_) => todo!(), + None => Err(ConnectionError::MissingResponseField("res").into()), + } + } +} diff --git a/src/common/rpc/mod.rs b/src/connection/network/proto/mod.rs similarity index 67% rename from src/common/rpc/mod.rs rename to src/connection/network/proto/mod.rs index e38fa19d..fc805bb2 100644 --- a/src/common/rpc/mod.rs +++ b/src/connection/network/proto/mod.rs @@ -19,17 +19,25 @@ * under the License. */ -pub(crate) mod builder; -mod channel; -mod cluster; -mod core; -mod server; -mod transaction; +mod common; +mod concept; +mod database; +mod message; -pub(crate) use self::{ - channel::Channel, - cluster::{ClusterRPC, ClusterServerRPC}, - core::CoreRPC, - server::ServerRPC, - transaction::TransactionRPC, -}; +use crate::Result; + +pub(super) trait IntoProto { + fn into_proto(self) -> Proto; +} + +pub(super) trait TryIntoProto { + fn try_into_proto(self) -> Result; +} + +pub(super) trait FromProto { + fn from_proto(proto: Proto) -> Self; +} + +pub(super) trait TryFromProto: Sized { + fn try_from_proto(proto: Proto) -> Result; +} diff --git a/src/connection/network/stub.rs b/src/connection/network/stub.rs new file mode 100644 index 00000000..fccae604 --- /dev/null +++ b/src/connection/network/stub.rs @@ -0,0 +1,264 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::sync::Arc; + +use futures::{future::BoxFuture, FutureExt, TryFutureExt}; +use log::{debug, trace}; +use tokio::sync::mpsc::{unbounded_channel as unbounded_async, UnboundedSender}; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tonic::{Response, Status, Streaming}; +use typedb_protocol::{ + cluster_database::Replica, cluster_database_manager, cluster_user, core_database, core_database_manager, + server_manager, session, transaction, type_db_client::TypeDbClient as CoreGRPC, + type_db_cluster_client::TypeDbClusterClient as ClusterGRPC, ClusterDatabase, +}; + +use super::channel::{CallCredentials, GRPCChannel}; +use crate::common::{address::Address, error::ConnectionError, Error, Result, StdResult}; + +type TonicResult = StdResult, Status>; + +#[derive(Clone, Debug)] +pub(super) struct RPCStub { + address: Address, + channel: Channel, + core_grpc: CoreGRPC, + cluster_grpc: ClusterGRPC, + call_credentials: Option>, +} + +impl RPCStub { + pub(super) async fn new( + address: Address, + channel: Channel, + call_credentials: Option>, + ) -> Result { + let this = Self { + address, + core_grpc: CoreGRPC::new(channel.clone()), + cluster_grpc: ClusterGRPC::new(channel.clone()), + channel, + call_credentials, + }; + let mut this = this.validated().await?; + this.renew_token().await?; + Ok(this) + } + + pub(super) async fn validated(mut self) -> Result { + self.databases_all(cluster_database_manager::all::Req {}).await?; + Ok(self) + } + + fn address(&self) -> &Address { + &self.address + } + + async fn call_with_auto_renew_token(&mut self, call: F) -> Result + where + for<'a> F: Fn(&'a mut Self) -> BoxFuture<'a, Result>, + { + match call(self).await { + Err(Error::Connection(ConnectionError::ClusterTokenCredentialInvalid())) => { + self.renew_token().await?; + call(self).await + } + res => res, + } + } + + async fn renew_token(&mut self) -> Result { + if let Some(call_credentials) = &self.call_credentials { + trace!("renewing token..."); + call_credentials.reset_token(); + let req = cluster_user::token::Req { username: call_credentials.username().to_owned() }; + trace!("sending token request..."); + let token = self.cluster_grpc.user_token(req).await?.into_inner().token; + call_credentials.set_token(token); + trace!("renewed token"); + } + Ok(()) + } + + pub(super) async fn servers_all(&mut self, req: server_manager::all::Req) -> Result { + self.single(|this| Box::pin(this.cluster_grpc.servers_all(req.clone()))).await + } + + pub(super) async fn databases_contains( + &mut self, + req: core_database_manager::contains::Req, + ) -> Result { + self.single(|this| Box::pin(this.core_grpc.databases_contains(req.clone()))).await + } + + pub(super) async fn databases_create( + &mut self, + req: core_database_manager::create::Req, + ) -> Result { + self.single(|this| Box::pin(this.core_grpc.databases_create(req.clone()))).await + } + + // FIXME: merge after protocol merge + pub(super) async fn databases_get( + &mut self, + req: cluster_database_manager::get::Req, + ) -> Result { + if self.channel.is_plaintext() { + self.databases_get_core(req).await + } else { + self.databases_get_cluster(req).await + } + } + + pub(super) async fn databases_all( + &mut self, + req: cluster_database_manager::all::Req, + ) -> Result { + if self.channel.is_plaintext() { + self.databases_all_core(req).await + } else { + self.databases_all_cluster(req).await + } + } + + async fn databases_get_core( + &mut self, + req: cluster_database_manager::get::Req, + ) -> Result { + Ok(cluster_database_manager::get::Res { + database: Some(ClusterDatabase { + name: req.name, + replicas: vec![Replica { + address: self.address().to_string(), + primary: true, + preferred: true, + term: 0, + }], + }), + }) + } + + async fn databases_all_core( + &mut self, + _req: cluster_database_manager::all::Req, + ) -> Result { + let database_names = + self.single(|this| Box::pin(this.core_grpc.databases_all(core_database_manager::all::Req {}))).await?.names; + Ok(cluster_database_manager::all::Res { + databases: database_names + .into_iter() + .map(|db_name| ClusterDatabase { + name: db_name, + replicas: vec![Replica { + address: self.address().to_string(), + primary: true, + preferred: true, + term: 0, + }], + }) + .collect(), + }) + } + + async fn databases_get_cluster( + &mut self, + req: cluster_database_manager::get::Req, + ) -> Result { + self.single(|this| Box::pin(this.cluster_grpc.databases_get(req.clone()))).await + } + + async fn databases_all_cluster( + &mut self, + req: cluster_database_manager::all::Req, + ) -> Result { + self.single(|this| Box::pin(this.cluster_grpc.databases_all(req.clone()))).await + } + // FIXME: end FIXME + + pub(super) async fn database_delete( + &mut self, + req: core_database::delete::Req, + ) -> Result { + self.single(|this| Box::pin(this.core_grpc.database_delete(req.clone()))).await + } + + pub(super) async fn database_schema( + &mut self, + req: core_database::schema::Req, + ) -> Result { + self.single(|this| Box::pin(this.core_grpc.database_schema(req.clone()))).await + } + + pub(super) async fn database_type_schema( + &mut self, + req: core_database::type_schema::Req, + ) -> Result { + self.single(|this| Box::pin(this.core_grpc.database_type_schema(req.clone()))).await + } + + pub(super) async fn database_rule_schema( + &mut self, + req: core_database::rule_schema::Req, + ) -> Result { + self.single(|this| Box::pin(this.core_grpc.database_rule_schema(req.clone()))).await + } + + pub(super) async fn session_open(&mut self, req: session::open::Req) -> Result { + self.single(|this| Box::pin(this.core_grpc.session_open(req.clone()))).await + } + + pub(super) async fn session_close(&mut self, req: session::close::Req) -> Result { + debug!("closing session"); + self.single(|this| Box::pin(this.core_grpc.session_close(req.clone()))).await + } + + pub(super) async fn session_pulse(&mut self, req: session::pulse::Req) -> Result { + self.single(|this| Box::pin(this.core_grpc.session_pulse(req.clone()))).await + } + + pub(super) async fn transaction( + &mut self, + open_req: transaction::Req, + ) -> Result<(UnboundedSender, Streaming)> { + self.call_with_auto_renew_token(|this| { + let transaction_req = transaction::Client { reqs: vec![open_req.clone()] }; + Box::pin(async { + let (sender, receiver) = unbounded_async(); + sender.send(transaction_req)?; + this.core_grpc + .transaction(UnboundedReceiverStream::new(receiver)) + .map_ok(|stream| Response::new((sender, stream.into_inner()))) + .map(|r| Ok(r?.into_inner())) + .await + }) + }) + .await + } + + async fn single(&mut self, call: F) -> Result + where + for<'a> F: Fn(&'a mut Self) -> BoxFuture<'a, TonicResult> + Send + Sync, + R: 'static, + { + self.call_with_auto_renew_token(|this| Box::pin(call(this).map(|r| Ok(r?.into_inner())))).await + } +} diff --git a/src/connection/server/mod.rs b/src/connection/network/transmitter/mod.rs similarity index 87% rename from src/connection/server/mod.rs rename to src/connection/network/transmitter/mod.rs index 8b865283..971012f7 100644 --- a/src/connection/server/mod.rs +++ b/src/connection/network/transmitter/mod.rs @@ -19,8 +19,8 @@ * under the License. */ -mod database; -mod session; +mod response_sink; +mod rpc; mod transaction; -pub use self::{database::Database, session::Session, transaction::Transaction}; +pub(in crate::connection) use self::{rpc::RPCTransmitter, transaction::TransactionTransmitter}; diff --git a/src/connection/network/transmitter/response_sink.rs b/src/connection/network/transmitter/response_sink.rs new file mode 100644 index 00000000..526a3110 --- /dev/null +++ b/src/connection/network/transmitter/response_sink.rs @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crossbeam::channel::Sender as SyncSender; +use log::error; +use tokio::sync::{mpsc::UnboundedSender, oneshot::Sender as AsyncOneshotSender}; + +use crate::{ + common::Result, + error::{ConnectionError, InternalError}, + Error, +}; + +#[derive(Debug)] +pub(super) enum ResponseSink { + AsyncOneShot(AsyncOneshotSender>), + BlockingOneShot(SyncSender>), + Streamed(UnboundedSender>), +} + +impl ResponseSink { + pub(super) fn finish(self, response: Result) { + let result = match self { + Self::AsyncOneShot(sink) => sink.send(response).map_err(|_| InternalError::SendError().into()), + Self::BlockingOneShot(sink) => sink.send(response).map_err(Error::from), + Self::Streamed(sink) => sink.send(response).map_err(Error::from), + }; + if let Err(err) = result { + error!("{}", err); + } + } + + pub(super) fn send(&self, response: Result) { + let result = match self { + Self::Streamed(sink) => sink.send(response).map_err(Error::from), + _ => unreachable!("attempted to stream over a one-shot callback"), + }; + if let Err(err) = result { + error!("{}", err); + } + } + + pub(super) async fn error(self, error: ConnectionError) { + match self { + Self::AsyncOneShot(sink) => sink.send(Err(error.into())).ok(), + Self::BlockingOneShot(sink) => sink.send(Err(error.into())).ok(), + Self::Streamed(sink) => sink.send(Err(error.into())).ok(), + }; + } +} diff --git a/src/connection/network/transmitter/rpc.rs b/src/connection/network/transmitter/rpc.rs new file mode 100644 index 00000000..def4f64e --- /dev/null +++ b/src/connection/network/transmitter/rpc.rs @@ -0,0 +1,161 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crossbeam::channel::{bounded as bounded_blocking, Receiver as SyncReceiver, Sender as SyncSender}; +use tokio::{ + select, + sync::{ + mpsc::{unbounded_channel as unbounded_async, UnboundedReceiver, UnboundedSender}, + oneshot::channel as oneshot_async, + }, +}; + +use super::response_sink::ResponseSink; +use crate::{ + common::{address::Address, Result}, + connection::{ + message::{Request, Response}, + network::{ + channel::{open_encrypted_channel, open_plaintext_channel, GRPCChannel}, + proto::{FromProto, IntoProto, TryFromProto, TryIntoProto}, + stub::RPCStub, + }, + runtime::BackgroundRuntime, + }, + Credential, Error, +}; + +fn oneshot_blocking() -> (SyncSender, SyncReceiver) { + bounded_blocking::(0) +} + +pub(in crate::connection) struct RPCTransmitter { + request_sink: UnboundedSender<(Request, ResponseSink)>, + shutdown_sink: UnboundedSender<()>, +} + +impl RPCTransmitter { + pub(in crate::connection) fn start_plaintext(address: Address, runtime: &BackgroundRuntime) -> Result { + let (request_sink, request_source) = unbounded_async(); + let (shutdown_sink, shutdown_source) = unbounded_async(); + runtime.run_blocking(async move { + let channel = open_plaintext_channel(address.clone()); + let rpc = RPCStub::new(address.clone(), channel, None).await?; + tokio::spawn(Self::dispatcher_loop(rpc, request_source, shutdown_source)); + Ok::<(), Error>(()) + })?; + Ok(Self { request_sink, shutdown_sink }) + } + + pub(in crate::connection) fn start_encrypted( + address: Address, + credential: Credential, + runtime: &BackgroundRuntime, + ) -> Result { + let (request_sink, request_source) = unbounded_async(); + let (shutdown_sink, shutdown_source) = unbounded_async(); + runtime.run_blocking(async move { + let (channel, call_credentials) = open_encrypted_channel(address.clone(), credential)?; + let rpc = RPCStub::new(address.clone(), channel, Some(call_credentials)).await?; + tokio::spawn(Self::dispatcher_loop(rpc, request_source, shutdown_source)); + Ok::<(), Error>(()) + })?; + Ok(Self { request_sink, shutdown_sink }) + } + + pub(in crate::connection) async fn request_async(&self, request: Request) -> Result { + let (response_sink, response) = oneshot_async(); + self.request_sink.send((request, ResponseSink::AsyncOneShot(response_sink)))?; + response.await? + } + + pub(in crate::connection) fn request_blocking(&self, request: Request) -> Result { + let (response_sink, response) = oneshot_blocking(); + self.request_sink.send((request, ResponseSink::BlockingOneShot(response_sink)))?; + response.recv()? + } + + pub(in crate::connection) fn force_close(&self) -> Result { + self.shutdown_sink.send(()).map_err(Into::into) + } + + async fn dispatcher_loop( + rpc: RPCStub, + mut request_source: UnboundedReceiver<(Request, ResponseSink)>, + mut shutdown_signal: UnboundedReceiver<()>, + ) { + while let Some((request, response_sink)) = select! { + request = request_source.recv() => request, + _ = shutdown_signal.recv() => None, + } { + let rpc = rpc.clone(); + tokio::spawn(async move { + let response = Self::send_request(rpc, request).await; + response_sink.finish(response); + }); + } + } + + async fn send_request(mut rpc: RPCStub, request: Request) -> Result { + match request { + Request::ServersAll => rpc.servers_all(request.try_into_proto()?).await.and_then(Response::try_from_proto), + + Request::DatabasesContains { .. } => { + rpc.databases_contains(request.try_into_proto()?).await.map(Response::from_proto) + } + Request::DatabaseCreate { .. } => { + rpc.databases_create(request.try_into_proto()?).await.map(Response::from_proto) + } + Request::DatabaseGet { .. } => { + rpc.databases_get(request.try_into_proto()?).await.and_then(Response::try_from_proto) + } + Request::DatabasesAll => { + rpc.databases_all(request.try_into_proto()?).await.and_then(Response::try_from_proto) + } + + Request::DatabaseDelete { .. } => { + rpc.database_delete(request.try_into_proto()?).await.map(Response::from_proto) + } + Request::DatabaseSchema { .. } => { + rpc.database_schema(request.try_into_proto()?).await.map(Response::from_proto) + } + Request::DatabaseTypeSchema { .. } => { + rpc.database_type_schema(request.try_into_proto()?).await.map(Response::from_proto) + } + Request::DatabaseRuleSchema { .. } => { + rpc.database_rule_schema(request.try_into_proto()?).await.map(Response::from_proto) + } + + Request::SessionOpen { .. } => rpc.session_open(request.try_into_proto()?).await.map(Response::from_proto), + Request::SessionPulse { .. } => { + rpc.session_pulse(request.try_into_proto()?).await.map(Response::from_proto) + } + Request::SessionClose { .. } => { + rpc.session_close(request.try_into_proto()?).await.map(Response::from_proto) + } + + Request::Transaction(transaction_request) => { + let (request_sink, response_source) = rpc.transaction(transaction_request.into_proto()).await?; + Ok(Response::TransactionOpen { request_sink, response_source }) + } + } + } +} diff --git a/src/connection/network/transmitter/transaction.rs b/src/connection/network/transmitter/transaction.rs new file mode 100644 index 00000000..41f8427d --- /dev/null +++ b/src/connection/network/transmitter/transaction.rs @@ -0,0 +1,271 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::{ + collections::HashMap, + ops::DerefMut, + sync::{Arc, RwLock}, + time::Duration, +}; + +use crossbeam::atomic::AtomicCell; +use futures::{Stream, StreamExt, TryStreamExt}; +use log::error; +use prost::Message; +use tokio::{ + select, + sync::{ + mpsc::{error::SendError, unbounded_channel as unbounded_async, UnboundedReceiver, UnboundedSender}, + oneshot::channel as oneshot_async, + }, + time::{sleep_until, Instant}, +}; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tonic::Streaming; +use typedb_protocol::transaction::{self, server::Server, stream::State}; + +use super::response_sink::ResponseSink; +use crate::{ + common::{error::ConnectionError, RequestID, Result}, + connection::{ + message::{TransactionRequest, TransactionResponse}, + network::proto::{IntoProto, TryFromProto}, + runtime::BackgroundRuntime, + }, +}; + +pub(in crate::connection) struct TransactionTransmitter { + request_sink: UnboundedSender<(TransactionRequest, Option>)>, + is_open: Arc>, + shutdown_sink: UnboundedSender<()>, +} + +impl Drop for TransactionTransmitter { + fn drop(&mut self) { + self.is_open.store(false); + self.shutdown_sink.send(()).ok(); + } +} + +impl TransactionTransmitter { + pub(in crate::connection) fn new( + background_runtime: &BackgroundRuntime, + request_sink: UnboundedSender, + response_source: Streaming, + ) -> Self { + let (buffer_sink, buffer_source) = unbounded_async(); + let (shutdown_sink, shutdown_source) = unbounded_async(); + let is_open = Arc::new(AtomicCell::new(true)); + background_runtime.spawn(Self::start_workers( + buffer_sink.clone(), + buffer_source, + request_sink, + response_source, + is_open.clone(), + shutdown_source, + )); + Self { request_sink: buffer_sink, is_open, shutdown_sink } + } + + pub(in crate::connection) async fn single(&self, req: TransactionRequest) -> Result { + if !self.is_open.load() { + return Err(ConnectionError::SessionIsClosed().into()); + } + let (res_sink, recv) = oneshot_async(); + self.request_sink.send((req, Some(ResponseSink::AsyncOneShot(res_sink))))?; + recv.await?.map(Into::into) + } + + pub(in crate::connection) fn stream( + &self, + req: TransactionRequest, + ) -> Result>> { + if !self.is_open.load() { + return Err(ConnectionError::SessionIsClosed().into()); + } + let (res_part_sink, recv) = unbounded_async(); + self.request_sink.send((req, Some(ResponseSink::Streamed(res_part_sink))))?; + Ok(UnboundedReceiverStream::new(recv).map_ok(Into::into)) + } + + async fn start_workers( + queue_sink: UnboundedSender<(TransactionRequest, Option>)>, + queue_source: UnboundedReceiver<(TransactionRequest, Option>)>, + request_sink: UnboundedSender, + response_source: Streaming, + is_open: Arc>, + shutdown_signal: UnboundedReceiver<()>, + ) { + let collector = ResponseCollector { request_sink: queue_sink, callbacks: Default::default(), is_open }; + tokio::spawn(Self::dispatch_loop(queue_source, request_sink, collector.clone(), shutdown_signal)); + tokio::spawn(Self::listen_loop(response_source, collector)); + } + + async fn dispatch_loop( + mut request_source: UnboundedReceiver<(TransactionRequest, Option>)>, + request_sink: UnboundedSender, + mut collector: ResponseCollector, + mut shutdown_signal: UnboundedReceiver<()>, + ) { + const MAX_GRPC_MESSAGE_LEN: usize = 1_000_000; + const DISPATCH_INTERVAL: Duration = Duration::from_millis(3); + + let mut request_buffer = TransactionRequestBuffer::default(); + let mut next_dispatch = Instant::now() + DISPATCH_INTERVAL; + loop { + select! { biased; + _ = shutdown_signal.recv() => { + if !request_buffer.is_empty() { + request_sink.send(request_buffer.take()).unwrap(); + } + break; + } + _ = sleep_until(next_dispatch) => { + if !request_buffer.is_empty() { + request_sink.send(request_buffer.take()).unwrap(); + } + next_dispatch = Instant::now() + DISPATCH_INTERVAL; + } + recv = request_source.recv() => { + if let Some((request, callback)) = recv { + let request = request.into_proto(); + if let Some(callback) = callback { + collector.register(request.req_id.clone().into(), callback); + } + if request_buffer.len() + request.encoded_len() > MAX_GRPC_MESSAGE_LEN { + request_sink.send(request_buffer.take()).unwrap(); + } + request_buffer.push(request); + } else { + break; + } + } + } + } + } + + async fn listen_loop(mut grpc_source: Streaming, collector: ResponseCollector) { + loop { + match grpc_source.next().await { + Some(Ok(message)) => collector.collect(message).await, + Some(Err(err)) => { + break collector.close(ConnectionError::TransactionIsClosedWithErrors(err.to_string())).await + } + None => break collector.close(ConnectionError::TransactionIsClosed()).await, + } + } + } +} + +#[derive(Default)] +struct TransactionRequestBuffer { + reqs: Vec, + len: usize, +} + +impl TransactionRequestBuffer { + fn is_empty(&self) -> bool { + self.reqs.is_empty() + } + + fn len(&self) -> usize { + self.len + } + + fn push(&mut self, request: transaction::Req) { + self.len += request.encoded_len(); + self.reqs.push(request); + } + + fn take(&mut self) -> transaction::Client { + self.len = 0; + transaction::Client { reqs: std::mem::take(&mut self.reqs) } + } +} + +#[derive(Clone)] +struct ResponseCollector { + request_sink: UnboundedSender<(TransactionRequest, Option>)>, + callbacks: Arc>>>, + is_open: Arc>, +} + +impl ResponseCollector { + fn register(&mut self, request_id: RequestID, callback: ResponseSink) { + self.callbacks.write().unwrap().insert(request_id, callback); + } + + async fn collect(&self, message: transaction::Server) { + match message.server { + Some(Server::Res(res)) => self.collect_res(res), + Some(Server::ResPart(res_part)) => self.collect_res_part(res_part).await, + None => error!("{}", ConnectionError::MissingResponseField("server")), + } + } + + fn collect_res(&self, res: transaction::Res) { + if matches!(res.res, Some(transaction::res::Res::OpenRes(_))) { + // Transaction::Open responses don't need to be collected. + return; + } + let req_id = res.req_id.clone().into(); + match self.callbacks.write().unwrap().remove(&req_id) { + Some(sink) => sink.finish(TransactionResponse::try_from_proto(res)), + _ => error!("{}", ConnectionError::UnknownRequestId(req_id)), + } + } + + async fn collect_res_part(&self, res_part: transaction::ResPart) { + let request_id = res_part.req_id.clone().into(); + + match res_part.res { + Some(transaction::res_part::Res::StreamResPart(stream_res_part)) => { + match State::from_i32(stream_res_part.state).expect("enum out of range") { + State::Done => { + self.callbacks.write().unwrap().remove(&request_id); + } + State::Continue => { + match self.request_sink.send((TransactionRequest::Stream { request_id }, None)) { + Err(SendError((TransactionRequest::Stream { request_id }, None))) => { + let callback = self.callbacks.write().unwrap().remove(&request_id).unwrap(); + callback.error(ConnectionError::TransactionIsClosed()).await; + } + _ => (), + } + } + } + } + Some(_) => match self.callbacks.read().unwrap().get(&request_id) { + Some(sink) => sink.send(TransactionResponse::try_from_proto(res_part)), + _ => error!("{}", ConnectionError::UnknownRequestId(request_id)), + }, + None => error!("{}", ConnectionError::MissingResponseField("res_part.res")), + } + } + + async fn close(self, error: ConnectionError) { + self.is_open.store(false); + let mut listeners = std::mem::take(self.callbacks.write().unwrap().deref_mut()); + for (_, listener) in listeners.drain() { + listener.error(error.clone()).await; + } + } +} diff --git a/src/connection/runtime.rs b/src/connection/runtime.rs new file mode 100644 index 00000000..faa45fd9 --- /dev/null +++ b/src/connection/runtime.rs @@ -0,0 +1,88 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::{future::Future, thread}; + +use crossbeam::{atomic::AtomicCell, channel::bounded as bounded_blocking}; +use tokio::{ + runtime, + sync::mpsc::{unbounded_channel as unbounded_async, UnboundedSender}, +}; + +use crate::common::Result; + +pub(super) struct BackgroundRuntime { + async_runtime_handle: runtime::Handle, + is_open: AtomicCell, + shutdown_sink: UnboundedSender<()>, +} + +impl BackgroundRuntime { + pub(super) fn new() -> Result { + let is_open = AtomicCell::new(true); + let (shutdown_sink, mut shutdown_source) = unbounded_async(); + let async_runtime = runtime::Builder::new_current_thread().enable_time().enable_io().build()?; + let async_runtime_handle = async_runtime.handle().clone(); + thread::Builder::new().name("gRPC worker".to_string()).spawn(move || { + async_runtime.block_on(async move { + shutdown_source.recv().await; + }) + })?; + Ok(Self { async_runtime_handle, is_open, shutdown_sink }) + } + + pub(super) fn is_open(&self) -> bool { + self.is_open.load() + } + + pub(super) fn force_close(&self) -> Result { + self.is_open.store(false); + self.shutdown_sink.send(())?; + Ok(()) + } + + pub(super) fn spawn(&self, future: F) + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + self.async_runtime_handle.spawn(future); + } + + pub(super) fn run_blocking(&self, future: F) -> F::Output + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + let (response_sink, response) = bounded_blocking(0); + self.async_runtime_handle.spawn(async move { + response_sink.send(future.await).ok(); + }); + response.recv().unwrap() + } +} + +impl Drop for BackgroundRuntime { + fn drop(&mut self) { + self.is_open.store(false); + self.shutdown_sink.send(()).ok(); + } +} diff --git a/src/connection/server/database.rs b/src/connection/server/database.rs deleted file mode 100644 index 0f038213..00000000 --- a/src/connection/server/database.rs +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Copyright (C) 2022 Vaticle - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::fmt::{Display, Formatter}; - -use crate::common::{ - rpc::builder::core::database::{delete_req, rule_schema_req, schema_req, type_schema_req}, - Result, ServerRPC, -}; - -#[derive(Clone, Debug)] -pub struct Database { - pub name: String, - server_rpc: ServerRPC, -} - -impl Database { - pub(crate) fn new(name: &str, server_rpc: ServerRPC) -> Self { - Database { name: name.into(), server_rpc } - } - - pub async fn delete(mut self) -> Result { - self.server_rpc.database_delete(delete_req(self.name.as_str())).await?; - Ok(()) - } - - pub async fn schema(&mut self) -> Result { - self.server_rpc.database_schema(schema_req(self.name.as_str())).await.map(|res| res.schema) - } - - pub async fn type_schema(&mut self) -> Result { - self.server_rpc - .database_type_schema(type_schema_req(self.name.as_str())) - .await - .map(|res| res.schema) - } - - pub async fn rule_schema(&mut self) -> Result { - self.server_rpc - .database_rule_schema(rule_schema_req(self.name.as_str())) - .await - .map(|res| res.schema) - } -} - -impl Display for Database { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.name) - } -} diff --git a/src/connection/server/session.rs b/src/connection/server/session.rs deleted file mode 100644 index ddf50c1b..00000000 --- a/src/connection/server/session.rs +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Copyright (C) 2022 Vaticle - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::time::{Duration, Instant}; - -use crossbeam::atomic::AtomicCell; -use futures::executor; -use log::warn; - -use crate::{ - common::{ - error::ClientError, - rpc::builder::session::{close_req, open_req}, - Result, ServerRPC, SessionType, TransactionType, - }, - connection::{core, server::Transaction}, -}; - -pub(crate) type SessionId = Vec; - -#[derive(Debug)] -pub struct Session { - pub db_name: String, - pub session_type: SessionType, - pub(crate) id: SessionId, - pub(crate) server_rpc: ServerRPC, - is_open_atomic: AtomicCell, - network_latency: Duration, -} - -impl Session { - pub(crate) async fn new( - db_name: &str, - session_type: SessionType, - options: core::Options, - mut server_rpc: ServerRPC, - ) -> Result { - let start_time = Instant::now(); - let open_req = open_req(db_name, session_type.to_proto(), options.to_proto()); - let res = server_rpc.session_open(open_req).await?; - // TODO: pulse task - Ok(Session { - db_name: String::from(db_name), - session_type, - network_latency: Self::compute_network_latency(start_time, res.server_duration_millis), - id: res.session_id, - server_rpc, - is_open_atomic: AtomicCell::new(true), - }) - } - - pub async fn transaction(&self, transaction_type: TransactionType) -> Result { - self.transaction_with_options(transaction_type, core::Options::default()).await - } - - pub async fn transaction_with_options( - &self, - transaction_type: TransactionType, - options: core::Options, - ) -> Result { - match self.is_open() { - true => { - Transaction::new( - &self.id, - transaction_type, - options, - self.network_latency, - &self.server_rpc, - ) - .await - } - false => Err(ClientError::SessionIsClosed())?, - } - } - - pub fn is_open(&self) -> bool { - self.is_open_atomic.load() - } - - pub async fn close(&mut self) { - if let Ok(true) = self.is_open_atomic.compare_exchange(true, false) { - // let res = self.session_close_sink.send(self.id.clone()); - let res = self.server_rpc.session_close(close_req(self.id.clone())).await; - // TODO: the request errors harmlessly if the session is already closed. Protocol should - // expose the cause of the error and we can use that to decide whether to warn here. - if res.is_err() { - warn!("{}", ClientError::SessionCloseFailed()) - } - } - } - - fn compute_network_latency(start_time: Instant, server_duration_millis: i32) -> Duration { - Duration::from_millis( - (Instant::now() - start_time).as_millis() as u64 - server_duration_millis as u64, - ) - } -} - -impl Drop for Session { - fn drop(&mut self) { - // TODO: this will stall in a single-threaded environment - executor::block_on(self.close()); - } -} diff --git a/src/connection/server/transaction.rs b/src/connection/server/transaction.rs deleted file mode 100644 index 0894cfe3..00000000 --- a/src/connection/server/transaction.rs +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright (C) 2022 Vaticle - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::{fmt::Debug, time::Duration}; - -use futures::Stream; -use typedb_protocol::transaction as transaction_proto; - -use crate::{ - common::{ - rpc::builder::transaction::{commit_req, open_req, rollback_req}, - Result, ServerRPC, TransactionRPC, TransactionType, - }, - connection::core, - query::QueryManager, -}; - -#[derive(Clone, Debug)] -pub struct Transaction { - pub type_: TransactionType, - pub options: core::Options, - pub query: QueryManager, - rpc: TransactionRPC, -} - -impl Transaction { - pub(crate) async fn new( - session_id: &[u8], - transaction_type: TransactionType, - options: core::Options, - network_latency: Duration, - server_rpc: &ServerRPC, - ) -> Result { - let open_req = open_req( - session_id.to_vec(), - transaction_type.to_proto(), - options.to_proto(), - network_latency.as_millis() as i32, - ); - let rpc = TransactionRPC::new(server_rpc, open_req).await?; - Ok(Transaction { type_: transaction_type, options, query: QueryManager::new(&rpc), rpc }) - } - - pub async fn commit(&mut self) -> Result { - self.single_rpc(commit_req()).await.map(|_| ()) - } - - pub async fn rollback(&mut self) -> Result { - self.single_rpc(rollback_req()).await.map(|_| ()) - } - - pub(crate) async fn single_rpc( - &mut self, - req: transaction_proto::Req, - ) -> Result { - self.rpc.single(req).await - } - - pub(crate) fn streaming_rpc( - &mut self, - req: transaction_proto::Req, - ) -> impl Stream> { - self.rpc.stream(req) - } - - // TODO: refactor to delegate work to a background process - pub async fn close(&self) { - self.rpc.close().await; - } -} diff --git a/src/connection/transaction_stream.rs b/src/connection/transaction_stream.rs new file mode 100644 index 00000000..7adf560e --- /dev/null +++ b/src/connection/transaction_stream.rs @@ -0,0 +1,153 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::{fmt, iter}; + +use futures::{stream, Stream, StreamExt}; + +use super::network::transmitter::TransactionTransmitter; +use crate::{ + answer::{ConceptMap, Numeric}, + common::Result, + connection::message::{QueryRequest, QueryResponse, TransactionRequest, TransactionResponse}, + error::InternalError, + Options, TransactionType, +}; + +pub(crate) struct TransactionStream { + type_: TransactionType, + options: Options, + transaction_transmitter: TransactionTransmitter, +} + +impl TransactionStream { + pub(super) fn new( + type_: TransactionType, + options: Options, + transaction_transmitter: TransactionTransmitter, + ) -> Self { + Self { type_, options, transaction_transmitter } + } + + pub(crate) fn type_(&self) -> TransactionType { + self.type_ + } + + pub(crate) fn options(&self) -> &Options { + &self.options + } + + pub(crate) async fn commit(&self) -> Result { + self.single(TransactionRequest::Commit).await?; + Ok(()) + } + + pub(crate) async fn rollback(&self) -> Result { + self.single(TransactionRequest::Rollback).await?; + Ok(()) + } + + pub(crate) async fn define(&self, query: String, options: Options) -> Result { + self.single(TransactionRequest::Query(QueryRequest::Define { query, options })).await?; + Ok(()) + } + + pub(crate) async fn undefine(&self, query: String, options: Options) -> Result { + self.single(TransactionRequest::Query(QueryRequest::Undefine { query, options })).await?; + Ok(()) + } + + pub(crate) async fn delete(&self, query: String, options: Options) -> Result { + self.single(TransactionRequest::Query(QueryRequest::Delete { query, options })).await?; + Ok(()) + } + + pub(crate) fn match_(&self, query: String, options: Options) -> Result>> { + let stream = self.query_stream(QueryRequest::Match { query, options })?; + Ok(stream.flat_map(|result| match result { + Ok(QueryResponse::Match { answers }) => stream_iter(answers.into_iter().map(Ok)), + Ok(other) => stream_once(Err(InternalError::UnexpectedResponseType(format!("{other:?}")).into())), + Err(err) => stream_once(Err(err)), + })) + } + + pub(crate) fn insert(&self, query: String, options: Options) -> Result>> { + let stream = self.query_stream(QueryRequest::Insert { query, options })?; + Ok(stream.flat_map(|result| match result { + Ok(QueryResponse::Insert { answers }) => stream_iter(answers.into_iter().map(Ok)), + Ok(other) => stream_once(Err(InternalError::UnexpectedResponseType(format!("{other:?}")).into())), + Err(err) => stream_once(Err(err)), + })) + } + + pub(crate) fn update(&self, query: String, options: Options) -> Result>> { + let stream = self.query_stream(QueryRequest::Update { query, options })?; + Ok(stream.flat_map(|result| match result { + Ok(QueryResponse::Update { answers }) => stream_iter(answers.into_iter().map(Ok)), + Ok(other) => stream_once(Err(InternalError::UnexpectedResponseType(format!("{other:?}")).into())), + Err(err) => stream_once(Err(err)), + })) + } + + pub(crate) async fn match_aggregate(&self, query: String, options: Options) -> Result { + match self.query_single(QueryRequest::MatchAggregate { query, options }).await? { + QueryResponse::MatchAggregate { answer } => Ok(answer), + other => Err(InternalError::UnexpectedResponseType(format!("{other:?}")).into()), + } + } + + async fn single(&self, req: TransactionRequest) -> Result { + self.transaction_transmitter.single(req).await + } + + async fn query_single(&self, req: QueryRequest) -> Result { + match self.single(TransactionRequest::Query(req)).await? { + TransactionResponse::Query(query) => Ok(query), + other => Err(InternalError::UnexpectedResponseType(format!("{other:?}")).into()), + } + } + + fn stream(&self, req: TransactionRequest) -> Result>> { + self.transaction_transmitter.stream(req) + } + + fn query_stream(&self, req: QueryRequest) -> Result>> { + Ok(self.stream(TransactionRequest::Query(req))?.map(|response| match response { + Ok(TransactionResponse::Query(query)) => Ok(query), + Ok(other) => Err(InternalError::UnexpectedResponseType(format!("{other:?}")).into()), + Err(err) => Err(err), + })) + } +} + +impl fmt::Debug for TransactionStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TransactionStream").field("type_", &self.type_).field("options", &self.options).finish() + } +} + +fn stream_once<'a, T: Send + 'a>(value: T) -> stream::BoxStream<'a, T> { + stream_iter(iter::once(value)) +} + +fn stream_iter<'a, T: Send + 'a>(iter: impl Iterator + Send + 'a) -> stream::BoxStream<'a, T> { + Box::pin(stream::iter(iter)) +} diff --git a/src/database/database.rs b/src/database/database.rs new file mode 100644 index 00000000..74c823aa --- /dev/null +++ b/src/database/database.rs @@ -0,0 +1,285 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::{fmt, future::Future, sync::RwLock, thread::sleep, time::Duration}; + +use itertools::Itertools; +use log::{debug, error}; + +use crate::{ + common::{ + address::Address, + error::ConnectionError, + info::{DatabaseInfo, ReplicaInfo}, + Error, Result, + }, + connection::ServerConnection, + Connection, +}; + +pub struct Database { + name: String, + replicas: RwLock>, + connection: Connection, +} + +impl Database { + const PRIMARY_REPLICA_TASK_MAX_RETRIES: usize = 10; + const FETCH_REPLICAS_MAX_RETRIES: usize = 10; + const WAIT_FOR_PRIMARY_REPLICA_SELECTION: Duration = Duration::from_secs(2); + + pub(super) fn new(database_info: DatabaseInfo, connection: Connection) -> Result { + let name = database_info.name.clone(); + let replicas = RwLock::new(Replica::try_from_info(database_info, &connection)?); + Ok(Self { name, replicas, connection }) + } + + pub(super) async fn get(name: String, connection: Connection) -> Result { + Ok(Self { + name: name.to_string(), + replicas: RwLock::new(Replica::fetch_all(name, connection.clone()).await?), + connection, + }) + } + + pub fn name(&self) -> &str { + self.name.as_str() + } + + pub(super) fn connection(&self) -> &Connection { + &self.connection + } + + pub async fn delete(self) -> Result { + self.run_on_primary_replica(|database, _, _| database.delete()).await + } + + pub async fn schema(&self) -> Result { + self.run_failsafe(|database, _, _| async move { database.schema().await }).await + } + + pub async fn type_schema(&self) -> Result { + self.run_failsafe(|database, _, _| async move { database.type_schema().await }).await + } + + pub async fn rule_schema(&self) -> Result { + self.run_failsafe(|database, _, _| async move { database.rule_schema().await }).await + } + + pub(super) async fn run_failsafe(&self, task: F) -> Result + where + F: Fn(ServerDatabase, ServerConnection, bool) -> P, + P: Future>, + { + match self.run_on_any_replica(&task).await { + Err(Error::Connection(ConnectionError::ClusterReplicaNotPrimary())) => { + debug!("Attempted to run on a non-primary replica, retrying on primary..."); + self.run_on_primary_replica(&task).await + } + res => res, + } + } + + async fn run_on_any_replica(&self, task: F) -> Result + where + F: Fn(ServerDatabase, ServerConnection, bool) -> P, + P: Future>, + { + let mut is_first_run = true; + let replicas = self.replicas.read().unwrap().clone(); + for replica in replicas.iter() { + match task(replica.database.clone(), self.connection.connection(&replica.address)?.clone(), is_first_run) + .await + { + Err(Error::Connection(ConnectionError::UnableToConnect())) => { + debug!("Unable to connect to {}. Attempting next server.", replica.address); + } + res => return res, + } + is_first_run = false; + } + Err(self.connection.unable_to_connect_error()) + } + + async fn run_on_primary_replica(&self, task: F) -> Result + where + F: Fn(ServerDatabase, ServerConnection, bool) -> P, + P: Future>, + { + let mut primary_replica = + if let Some(replica) = self.primary_replica() { replica } else { self.seek_primary_replica().await? }; + + for retry in 0..Self::PRIMARY_REPLICA_TASK_MAX_RETRIES { + match task( + primary_replica.database.clone(), + self.connection.connection(&primary_replica.address)?.clone(), + retry == 0, + ) + .await + { + Err(Error::Connection( + ConnectionError::ClusterReplicaNotPrimary() | ConnectionError::UnableToConnect(), + )) => { + debug!("Primary replica error, waiting..."); + Self::wait_for_primary_replica_selection().await; + primary_replica = self.seek_primary_replica().await?; + } + res => return res, + } + } + Err(self.connection.unable_to_connect_error()) + } + + async fn seek_primary_replica(&self) -> Result { + for _ in 0..Self::FETCH_REPLICAS_MAX_RETRIES { + let replicas = Replica::fetch_all(self.name.clone(), self.connection.clone()).await?; + *self.replicas.write().unwrap() = replicas; + if let Some(replica) = self.primary_replica() { + return Ok(replica); + } + Self::wait_for_primary_replica_selection().await; + } + Err(self.connection.unable_to_connect_error()) + } + + fn primary_replica(&self) -> Option { + self.replicas.read().unwrap().iter().filter(|r| r.is_primary).max_by_key(|r| r.term).cloned() + } + + async fn wait_for_primary_replica_selection() { + // FIXME: blocking sleep! Can't do agnostic async sleep. + sleep(Self::WAIT_FOR_PRIMARY_REPLICA_SELECTION); + } +} + +impl fmt::Debug for Database { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Database").field("name", &self.name).field("replicas", &self.replicas).finish() + } +} + +#[derive(Clone)] +pub(super) struct Replica { + address: Address, + database_name: String, + is_primary: bool, + term: i64, + is_preferred: bool, + database: ServerDatabase, +} + +impl Replica { + fn new(name: String, metadata: ReplicaInfo, server_connection: ServerConnection) -> Self { + Self { + address: metadata.address, + database_name: name.clone(), + is_primary: metadata.is_primary, + term: metadata.term, + is_preferred: metadata.is_preferred, + database: ServerDatabase::new(name, server_connection), + } + } + + fn try_from_info(database_info: DatabaseInfo, connection: &Connection) -> Result> { + database_info + .replicas + .into_iter() + .map(|replica| { + let server_connection = connection.connection(&replica.address)?.clone(); + Ok(Replica::new(database_info.name.clone(), replica, server_connection)) + }) + .try_collect() + } + + async fn fetch_all(name: String, connection: Connection) -> Result> { + for server_connection in connection.connections() { + let res = server_connection.get_database_replicas(name.clone()).await; + match res { + Ok(res) => { + return Replica::try_from_info(res, &connection); + } + Err(Error::Connection(ConnectionError::UnableToConnect())) => { + error!( + "Failed to fetch replica info for database '{}' from {}. Attempting next server.", + name, + server_connection.address() + ); + } + Err(err) => return Err(err), + } + } + Err(connection.unable_to_connect_error()) + } +} + +impl fmt::Debug for Replica { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Replica") + .field("address", &self.address) + .field("database_name", &self.database_name) + .field("is_primary", &self.is_primary) + .field("term", &self.term) + .field("is_preferred", &self.is_preferred) + .finish() + } +} + +#[derive(Clone, Debug)] +pub(super) struct ServerDatabase { + name: String, + connection: ServerConnection, +} + +impl ServerDatabase { + fn new(name: String, connection: ServerConnection) -> Self { + ServerDatabase { name, connection } + } + + pub(super) fn name(&self) -> &str { + self.name.as_str() + } + + pub(super) fn connection(&self) -> &ServerConnection { + &self.connection + } + + async fn delete(self) -> Result { + self.connection.delete_database(self.name).await + } + + async fn schema(&self) -> Result { + self.connection.database_schema(self.name.clone()).await + } + + async fn type_schema(&self) -> Result { + self.connection.database_type_schema(self.name.clone()).await + } + + async fn rule_schema(&self) -> Result { + self.connection.database_rule_schema(self.name.clone()).await + } +} + +impl fmt::Display for ServerDatabase { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.name) + } +} diff --git a/src/database/database_manager.rs b/src/database/database_manager.rs new file mode 100644 index 00000000..f29ee9d8 --- /dev/null +++ b/src/database/database_manager.rs @@ -0,0 +1,79 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::future::Future; + +use super::{database::ServerDatabase, Database}; +use crate::{ + common::{error::ConnectionError, Result}, + connection::ServerConnection, + Connection, +}; + +#[derive(Clone, Debug)] +pub struct DatabaseManager { + connection: Connection, +} + +impl DatabaseManager { + pub fn new(connection: Connection) -> Self { + Self { connection } + } + + pub async fn get(&self, name: impl Into) -> Result { + Database::get(name.into(), self.connection.clone()).await + } + + pub async fn contains(&self, name: impl Into) -> Result { + self.run_failsafe(name.into(), move |database, server_connection, _| async move { + server_connection.database_exists(database.name().to_owned()).await + }) + .await + } + + pub async fn create(&self, name: impl Into) -> Result { + self.run_failsafe(name.into(), |database, server_connection, _| async move { + server_connection.create_database(database.name().to_owned()).await + }) + .await + } + + pub async fn all(&self) -> Result> { + let mut error_buffer = Vec::with_capacity(self.connection.server_count()); + for server_connection in self.connection.connections() { + match server_connection.all_databases().await { + Ok(list) => { + return list.into_iter().map(|db_info| Database::new(db_info, self.connection.clone())).collect() + } + Err(err) => error_buffer.push(format!("- {}: {}", server_connection.address(), err)), + } + } + Err(ConnectionError::ClusterAllNodesFailed(error_buffer.join("\n")))? + } + + async fn run_failsafe(&self, name: String, task: F) -> Result + where + F: Fn(ServerDatabase, ServerConnection, bool) -> P, + P: Future>, + { + Database::get(name, self.connection.clone()).await?.run_failsafe(&task).await + } +} diff --git a/src/connection/cluster/mod.rs b/src/database/mod.rs similarity index 86% rename from src/connection/cluster/mod.rs rename to src/database/mod.rs index 4bff4fa7..6c3c5606 100644 --- a/src/connection/cluster/mod.rs +++ b/src/database/mod.rs @@ -19,11 +19,10 @@ * under the License. */ -mod client; mod database; mod database_manager; +mod query; mod session; +mod transaction; -pub use self::{ - client::Client, database::Database, database_manager::DatabaseManager, session::Session, -}; +pub use self::{database::Database, database_manager::DatabaseManager, session::Session, transaction::Transaction}; diff --git a/src/database/query.rs b/src/database/query.rs new file mode 100644 index 00000000..b7bd5a52 --- /dev/null +++ b/src/database/query.rs @@ -0,0 +1,98 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::sync::Arc; + +use futures::Stream; + +use crate::{ + answer::{ConceptMap, Numeric}, + common::Result, + connection::TransactionStream, + Options, +}; + +#[derive(Debug)] +pub struct QueryManager { + transaction_stream: Arc, +} + +impl QueryManager { + pub(super) fn new(transaction_stream: Arc) -> QueryManager { + QueryManager { transaction_stream } + } + + pub async fn define(&self, query: &str) -> Result { + self.define_with_options(query, Options::new()).await + } + + pub async fn define_with_options(&self, query: &str, options: Options) -> Result { + self.transaction_stream.define(query.to_string(), options).await + } + + pub async fn undefine(&self, query: &str) -> Result { + self.undefine_with_options(query, Options::new()).await + } + + pub async fn undefine_with_options(&self, query: &str, options: Options) -> Result { + self.transaction_stream.undefine(query.to_string(), options).await + } + + pub async fn delete(&self, query: &str) -> Result { + self.delete_with_options(query, Options::new()).await + } + + pub async fn delete_with_options(&self, query: &str, options: Options) -> Result { + self.transaction_stream.delete(query.to_string(), options).await + } + + pub fn match_(&self, query: &str) -> Result>> { + self.match_with_options(query, Options::new()) + } + + pub fn match_with_options(&self, query: &str, options: Options) -> Result>> { + self.transaction_stream.match_(query.to_string(), options) + } + + pub fn insert(&self, query: &str) -> Result>> { + self.insert_with_options(query, Options::new()) + } + + pub fn insert_with_options(&self, query: &str, options: Options) -> Result>> { + self.transaction_stream.insert(query.to_string(), options) + } + + pub fn update(&self, query: &str) -> Result>> { + self.update_with_options(query, Options::new()) + } + + pub fn update_with_options(&self, query: &str, options: Options) -> Result>> { + self.transaction_stream.update(query.to_string(), options) + } + + pub async fn match_aggregate(&self, query: &str) -> Result { + self.match_aggregate_with_options(query, Options::new()).await + } + + pub async fn match_aggregate_with_options(&self, query: &str, options: Options) -> Result { + self.transaction_stream.match_aggregate(query.to_string(), options).await + } +} diff --git a/src/database/session.rs b/src/database/session.rs new file mode 100644 index 00000000..b8262732 --- /dev/null +++ b/src/database/session.rs @@ -0,0 +1,129 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::sync::RwLock; + +use crossbeam::atomic::AtomicCell; +use log::warn; + +use crate::{ + common::{error::ConnectionError, info::SessionInfo, Result, SessionType, TransactionType}, + Database, Options, Transaction, +}; + +#[derive(Debug)] +pub struct Session { + database: Database, + server_session_info: RwLock, + session_type: SessionType, + is_open: AtomicCell, +} + +impl Drop for Session { + fn drop(&mut self) { + if let Err(err) = self.force_close() { + warn!("Error encountered while closing session: {}", err); + } + } +} + +impl Session { + pub async fn new(database: Database, session_type: SessionType) -> Result { + let server_session_info = RwLock::new( + database + .run_failsafe(|database, _, _| async move { + database + .connection() + .open_session(database.name().to_owned(), session_type, Options::default()) + .await + }) + .await?, + ); + + Ok(Self { database, session_type, server_session_info, is_open: AtomicCell::new(true) }) + } + + pub fn database_name(&self) -> &str { + self.database.name() + } + + pub fn type_(&self) -> SessionType { + self.session_type + } + + pub fn is_open(&self) -> bool { + self.is_open.load() + } + + pub fn force_close(&self) -> Result { + if self.is_open.compare_exchange(true, false).is_ok() { + let session_info = self.server_session_info.write().unwrap(); + let connection = self.database.connection().connection(&session_info.address).unwrap(); + connection.close_session(session_info.session_id.clone())?; + } + Ok(()) + } + + pub async fn transaction(&self, transaction_type: TransactionType) -> Result { + self.transaction_with_options(transaction_type, Options::new()).await + } + + pub async fn transaction_with_options( + &self, + transaction_type: TransactionType, + options: Options, + ) -> Result { + if !self.is_open() { + return Err(ConnectionError::SessionIsClosed().into()); + } + + let (session_info, transaction_stream) = self + .database + .run_failsafe(|database, _, is_first_run| { + let session_info = self.server_session_info.read().unwrap().clone(); + let session_type = self.session_type; + let options = options.clone(); + async move { + let connection = database.connection(); + let session_info = if is_first_run { + session_info + } else { + connection.open_session(database.name().to_owned(), session_type, options.clone()).await? + }; + Ok(( + session_info.clone(), + connection + .open_transaction( + session_info.session_id, + transaction_type, + options, + session_info.network_latency, + ) + .await?, + )) + } + }) + .await?; + + *self.server_session_info.write().unwrap() = session_info; + Transaction::new(transaction_stream) + } +} diff --git a/src/database/transaction.rs b/src/database/transaction.rs new file mode 100644 index 00000000..671ae9e4 --- /dev/null +++ b/src/database/transaction.rs @@ -0,0 +1,70 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::{fmt, marker::PhantomData, sync::Arc}; + +use super::query::QueryManager; +use crate::{ + common::{Result, TransactionType}, + connection::TransactionStream, + Options, +}; + +pub struct Transaction<'a> { + type_: TransactionType, + options: Options, + + query: QueryManager, + transaction_stream: Arc, + + _lifetime_guard: PhantomData<&'a ()>, +} + +impl Transaction<'_> { + pub(super) fn new(transaction_stream: TransactionStream) -> Result { + let transaction_stream = Arc::new(transaction_stream); + Ok(Transaction { + type_: transaction_stream.type_(), + options: transaction_stream.options().clone(), + query: QueryManager::new(transaction_stream.clone()), + transaction_stream, + _lifetime_guard: PhantomData::default(), + }) + } + + pub fn query(&self) -> &QueryManager { + &self.query + } + + pub async fn commit(self) -> Result { + self.transaction_stream.commit().await + } + + pub async fn rollback(&self) -> Result { + self.transaction_stream.rollback().await + } +} + +impl fmt::Debug for Transaction<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Transaction").field("type_", &self.type_).field("options", &self.options).finish() + } +} diff --git a/src/lib.rs b/src/lib.rs index aa246581..721b9fe1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,15 +19,14 @@ * under the License. */ -#![allow(dead_code)] - -pub mod answer; -pub mod common; +mod answer; +mod common; pub mod concept; -pub(crate) mod connection; -pub mod query; +mod connection; +mod database; pub use self::{ - common::{Credential, Error, Result, SessionType, TransactionType}, - connection::{cluster, core, server}, + common::{error, Credential, Error, Options, Result, SessionType, TransactionType}, + connection::Connection, + database::{Database, DatabaseManager, Session, Transaction}, }; diff --git a/src/query/mod.rs b/src/query/mod.rs deleted file mode 100644 index 111caccb..00000000 --- a/src/query/mod.rs +++ /dev/null @@ -1,188 +0,0 @@ -/* - * Copyright (C) 2022 Vaticle - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::iter::once; - -use futures::{stream, Stream, StreamExt}; -use query_manager::res::Res::MatchAggregateRes; -use typedb_protocol::{ - query_manager, - query_manager::res_part::Res::{InsertResPart, MatchResPart, UpdateResPart}, - transaction, -}; - -use crate::{ - answer::{ConceptMap, Numeric}, - common::{ - error::ClientError, - rpc::builder::query_manager::{ - define_req, delete_req, insert_req, match_aggregate_req, match_req, undefine_req, - update_req, - }, - Result, TransactionRPC, - }, - connection::core, -}; - -macro_rules! stream_concept_maps { - ($self:ident, $req:ident, $res_part_kind:ident, $query_type_str:tt) => { - $self.stream_answers($req).flat_map(|result: Result| { - match result { - Ok(res_part) => match res_part { - $res_part_kind(x) => { - stream::iter(x.answers.into_iter().map(|cm| ConceptMap::from_proto(cm))) - .left_stream() - } - _ => stream::iter(once(Err(ClientError::MissingResponseField(concat!( - "query_manager_res_part.", - $query_type_str, - "_res_part" - )) - .into()))) - .right_stream(), - }, - Err(err) => stream::iter(once(Err(err))).right_stream(), - } - }) - }; -} - -#[derive(Clone, Debug)] -pub struct QueryManager { - tx: TransactionRPC, -} - -impl QueryManager { - pub(crate) fn new(tx: &TransactionRPC) -> QueryManager { - QueryManager { tx: tx.clone() } - } - - pub async fn define(&mut self, query: &str) -> Result { - self.single_call(define_req(query, None)).await.map(|_| ()) - } - - pub async fn define_with_options(&mut self, query: &str, options: &core::Options) -> Result { - self.single_call(define_req(query, Some(options.to_proto()))).await.map(|_| ()) - } - - pub async fn delete(&mut self, query: &str) -> Result { - self.single_call(delete_req(query, None)).await.map(|_| ()) - } - - pub async fn delete_with_options(&mut self, query: &str, options: &core::Options) -> Result { - self.single_call(delete_req(query, Some(options.to_proto()))).await.map(|_| ()) - } - - pub fn insert(&mut self, query: &str) -> impl Stream> { - let req = insert_req(query, None); - stream_concept_maps!(self, req, InsertResPart, "insert") - } - - pub fn insert_with_options( - &mut self, - query: &str, - options: &core::Options, - ) -> impl Stream> { - let req = insert_req(query, Some(options.to_proto())); - stream_concept_maps!(self, req, InsertResPart, "insert") - } - - // TODO: investigate performance impact of using BoxStream - pub fn match_(&mut self, query: &str) -> impl Stream> { - let req = match_req(query, None); - stream_concept_maps!(self, req, MatchResPart, "match") - } - - pub fn match_with_options( - &mut self, - query: &str, - options: &core::Options, - ) -> impl Stream> { - let req = match_req(query, Some(options.to_proto())); - stream_concept_maps!(self, req, MatchResPart, "match") - } - - pub async fn match_aggregate(&mut self, query: &str) -> Result { - match self.single_call(match_aggregate_req(query, None)).await? { - MatchAggregateRes(res) => res.answer.unwrap().try_into(), - _ => Err(ClientError::MissingResponseField("match_aggregate_res"))?, - } - } - - pub async fn match_aggregate_with_options( - &mut self, - query: &str, - options: core::Options, - ) -> Result { - match self.single_call(match_aggregate_req(query, Some(options.to_proto()))).await? { - MatchAggregateRes(res) => res.answer.unwrap().try_into(), - _ => Err(ClientError::MissingResponseField("match_aggregate_res"))?, - } - } - - pub async fn undefine(&mut self, query: &str) -> Result { - self.single_call(undefine_req(query, None)).await.map(|_| ()) - } - - pub async fn undefine_with_options(&mut self, query: &str, options: &core::Options) -> Result { - self.single_call(undefine_req(query, Some(options.to_proto()))).await.map(|_| ()) - } - - pub fn update(&mut self, query: &str) -> impl Stream> { - let req = update_req(query, None); - stream_concept_maps!(self, req, UpdateResPart, "update") - } - - pub fn update_with_options( - &mut self, - query: &str, - options: &core::Options, - ) -> impl Stream> { - let req = update_req(query, Some(options.to_proto())); - stream_concept_maps!(self, req, UpdateResPart, "update") - } - - async fn single_call(&mut self, req: transaction::Req) -> Result { - match self.tx.single(req).await?.res { - Some(transaction::res::Res::QueryManagerRes(res)) => { - res.res.ok_or(ClientError::MissingResponseField("res.query_manager_res").into()) - } - _ => Err(ClientError::MissingResponseField("res.query_manager_res"))?, - } - } - - fn stream_answers( - &mut self, - req: transaction::Req, - ) -> impl Stream> { - self.tx.stream(req).map(|result: Result| match result { - Ok(tx_res_part) => match tx_res_part.res { - Some(transaction::res_part::Res::QueryManagerResPart(res_part)) => { - res_part.res.ok_or( - ClientError::MissingResponseField("res_part.query_manager_res_part").into(), - ) - } - _ => Err(ClientError::MissingResponseField("res_part.query_manager_res_part"))?, - }, - Err(err) => Err(err), - }) - } -} diff --git a/tests/BUILD b/tests/BUILD index 5d209075..ea2bd313 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -24,31 +24,36 @@ package(default_visibility = ["//visibility:public"]) load("@rules_rust//rust:defs.bzl", "rust_test", "rustfmt_test") load("@vaticle_bazel_distribution//artifact:rules.bzl", "artifact_extractor") load("@vaticle_dependencies//tool/checkstyle:rules.bzl", "checkstyle_test") -load("@vaticle_typedb_common//test:rules.bzl", "native_typedb_artifact") +load("@vaticle_typedb_common//runner:rules.bzl", "native_typedb_artifact") rust_test( - name = "queries_core", - srcs = ["queries_core.rs"], + name = "queries", + srcs = [ + "common.rs", + "queries.rs", + ], deps = [ "//:typedb_client", - - "@vaticle_dependencies//library/crates:chrono", - "@vaticle_dependencies//library/crates:futures", - "@vaticle_dependencies//library/crates:serial_test", - "@vaticle_dependencies//library/crates:tokio", + "@crates//:chrono", + "@crates//:futures", + "@crates//:serial_test", + "@crates//:tokio", ], ) rust_test( - name = "queries_cluster", - srcs = ["queries_cluster.rs"], + name = "runtimes", + srcs = [ + "common.rs", + "runtimes.rs", + ], deps = [ "//:typedb_client", - - "@vaticle_dependencies//library/crates:futures", - "@vaticle_dependencies//library/crates:serial_test", - "@vaticle_dependencies//library/crates:tokio", + "@crates//:async-std", + "@crates//:futures", + "@crates//:serial_test", + "@crates//:smol", ], ) @@ -80,10 +85,7 @@ artifact_extractor( rustfmt_test( name = "queries_rustfmt_test", - targets = [ - "queries_core", - "queries_cluster", - ] + targets = ["queries", "runtimes"] ) checkstyle_test( diff --git a/tests/common.rs b/tests/common.rs new file mode 100644 index 00000000..f8bc7ac5 --- /dev/null +++ b/tests/common.rs @@ -0,0 +1,62 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::path::PathBuf; + +use futures::TryFutureExt; +use typedb_client::{ + Connection, Credential, Database, DatabaseManager, Session, SessionType::Schema, TransactionType::Write, +}; + +pub const TEST_DATABASE: &str = "test"; + +pub fn new_core_connection() -> typedb_client::Result { + Connection::new_plaintext("127.0.0.1:1729") +} + +pub fn new_cluster_connection() -> typedb_client::Result { + Connection::new_encrypted( + &["localhost:11729", "localhost:21729", "localhost:31729"], + Credential::with_tls( + "admin", + "password", + Some(&PathBuf::from( + std::env::var("ROOT_CA") + .expect("ROOT_CA environment variable needs to be set for cluster tests to run"), + )), + )?, + ) +} + +pub async fn create_test_database_with_schema(connection: Connection, schema: &str) -> typedb_client::Result { + let databases = DatabaseManager::new(connection); + if databases.contains(TEST_DATABASE).await? { + databases.get(TEST_DATABASE).and_then(Database::delete).await?; + } + databases.create(TEST_DATABASE).await?; + + let database = databases.get(TEST_DATABASE).await?; + let session = Session::new(database, Schema).await?; + let transaction = session.transaction(Write).await?; + transaction.query().define(schema).await?; + transaction.commit().await?; + Ok(()) +} diff --git a/tests/queries.rs b/tests/queries.rs new file mode 100644 index 00000000..55bdc102 --- /dev/null +++ b/tests/queries.rs @@ -0,0 +1,332 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +mod common; + +use std::{sync::Arc, time::Instant}; + +use chrono::{NaiveDate, NaiveDateTime}; +use futures::StreamExt; +use serial_test::serial; +use tokio::sync::mpsc; +use typedb_client::{ + concept::{Attribute, Concept, DateTimeAttribute, StringAttribute, Thing}, + error::ConnectionError, + Connection, DatabaseManager, Error, Options, Session, + SessionType::Data, + TransactionType::{Read, Write}, +}; + +macro_rules! test_for_each_arg { + { + $perm_args:tt + $( $( #[ $extra_anno:meta ] )* $async:ident fn $test:ident $args:tt -> $ret:ty $test_impl:block )+ + } => { + test_for_each_arg!{ @impl $( $async fn $test $args $ret $test_impl )+ } + test_for_each_arg!{ @impl_per $perm_args { $( $( #[ $extra_anno ] )* $async fn $test )+ } } + }; + + { @impl $( $async:ident fn $test:ident $args:tt $ret:ty $test_impl:block )+ } => { + mod _impl { + use super::*; + $( pub $async fn $test $args -> $ret $test_impl )+ + } + }; + + { @impl_per { $($mod:ident => $arg:expr),+ $(,)? } $fns:tt } => { + $(test_for_each_arg!{ @impl_mod { $mod => $arg } $fns })+ + }; + + { @impl_mod { $mod:ident => $arg:expr } { $( $( #[ $extra_anno:meta ] )* async fn $test:ident )+ } } => { + mod $mod { + use super::*; + $( + #[tokio::test] + #[serial($mod)] + $( #[ $extra_anno ] )* + pub async fn $test() { + _impl::$test($arg).await.unwrap(); + } + )+ + } + }; +} + +test_for_each_arg! { + { + core => common::new_core_connection().unwrap(), + cluster => common::new_cluster_connection().unwrap(), + } + + async fn basic(connection: Connection) -> typedb_client::Result { + common::create_test_database_with_schema(connection.clone(), "define person sub entity;").await?; + let databases = DatabaseManager::new(connection); + assert!(databases.contains(common::TEST_DATABASE).await?); + + let session = Session::new(databases.get(common::TEST_DATABASE).await?, Data).await?; + let transaction = session.transaction(Write).await?; + let answer_stream = transaction.query().match_("match $x sub thing;")?; + let results: Vec<_> = answer_stream.collect().await; + transaction.commit().await?; + assert_eq!(results.len(), 5); + assert!(results.into_iter().all(|res| res.is_ok())); + + Ok(()) + } + + async fn query_error(connection: Connection) -> typedb_client::Result { + common::create_test_database_with_schema(connection.clone(), "define person sub entity;").await?; + let databases = DatabaseManager::new(connection); + + let session = Session::new(databases.get(common::TEST_DATABASE).await?, Data).await?; + let transaction = session.transaction(Write).await?; + let answer_stream = transaction.query().match_("match $x sub nonexistent-type;")?; + let results: Vec<_> = answer_stream.collect().await; + assert_eq!(results.len(), 1); + assert!(results.into_iter().all(|res| res.unwrap_err().to_string().contains("[TYR03]"))); + + Ok(()) + } + + async fn concurrent_transactions(connection: Connection) -> typedb_client::Result { + common::create_test_database_with_schema(connection.clone(), "define person sub entity;").await?; + let databases = DatabaseManager::new(connection); + + let session = Arc::new(Session::new(databases.get(common::TEST_DATABASE).await?, Data).await?); + + let (sender, mut receiver) = mpsc::channel(5 * 5 * 8); + + for _ in 0..8 { + let sender = sender.clone(); + let session = session.clone(); + tokio::spawn(async move { + for _ in 0..5 { + let transaction = session.transaction(Read).await.unwrap(); + let mut answer_stream = transaction.query().match_("match $x sub thing;").unwrap(); + while let Some(result) = answer_stream.next().await { + sender.send(result).await.unwrap(); + } + } + }); + } + drop(sender); // receiver expects data while any sender is live + + let mut results = Vec::with_capacity(5 * 5 * 8); + while let Some(result) = receiver.recv().await { + results.push(result); + } + assert_eq!(results.len(), 5 * 5 * 8); + assert!(results.into_iter().all(|res| res.is_ok())); + + Ok(()) + } + + async fn query_options(connection: Connection) -> typedb_client::Result { + let schema = r#"define + person sub entity, + owns name, + owns age; + name sub attribute, value string; + age sub attribute, value long; + rule age-rule: when { $x isa person; } then { $x has age 25; };"#; + common::create_test_database_with_schema(connection.clone(), schema).await?; + let databases = DatabaseManager::new(connection); + + let session = Session::new(databases.get(common::TEST_DATABASE).await?, Data).await?; + let transaction = session.transaction(Write).await?; + let data = "insert $x isa person, has name 'Alice'; $y isa person, has name 'Bob';"; + let _ = transaction.query().insert(data); + transaction.commit().await?; + + let transaction = session.transaction(Read).await?; + let age_count = transaction.query().match_aggregate("match $x isa age; count;").await?; + assert_eq!(age_count.into_i64(), 0); + + let with_inference = Options::new().infer(true); + let transaction = session.transaction_with_options(Read, with_inference).await?; + let age_count = transaction.query().match_aggregate("match $x isa age; count;").await?; + assert_eq!(age_count.into_i64(), 1); + + Ok(()) + } + + async fn many_concept_types(connection: Connection) -> typedb_client::Result { + let schema = r#"define + person sub entity, + owns name, + owns date-of-birth, + plays friendship:friend; + name sub attribute, value string; + date-of-birth sub attribute, value datetime; + friendship sub relation, + relates friend;"#; + common::create_test_database_with_schema(connection.clone(), schema).await?; + let databases = DatabaseManager::new(connection); + + let session = Session::new(databases.get(common::TEST_DATABASE).await?, Data).await?; + let transaction = session.transaction(Write).await?; + let data = r#"insert + $x isa person, has name "Alice", has date-of-birth 1994-10-03; + $y isa person, has name "Bob", has date-of-birth 1993-04-17; + (friend: $x, friend: $y) isa friendship;"#; + let _ = transaction.query().insert(data); + transaction.commit().await?; + + let transaction = session.transaction(Read).await?; + let mut answer_stream = transaction.query().match_( + r#"match + $p isa person, has name $name, has date-of-birth $date-of-birth; + $f($role: $p) isa friendship;"#, + )?; + + while let Some(result) = answer_stream.next().await { + assert!(result.is_ok()); + let mut result = result?.map; + let name = unwrap_string(result.remove("name").unwrap()); + let date_of_birth = unwrap_date_time(result.remove("date-of-birth").unwrap()).date(); + match name.as_str() { + "Alice" => assert_eq!(date_of_birth, NaiveDate::from_ymd_opt(1994, 10, 3).unwrap()), + "Bob" => assert_eq!(date_of_birth, NaiveDate::from_ymd_opt(1993, 4, 17).unwrap()), + _ => unreachable!(), + } + } + + Ok(()) + } + + async fn force_close_connection(connection: Connection) -> typedb_client::Result { + common::create_test_database_with_schema(connection.clone(), "define person sub entity;").await?; + let databases = DatabaseManager::new(connection.clone()); + + let database = databases.get(common::TEST_DATABASE).await?; + assert!(database.schema().await.is_ok()); + + let session = Session::new(databases.get(common::TEST_DATABASE).await?, Data).await?; + connection.clone().force_close()?; + + let schema = database.schema().await; + assert!(schema.is_err()); + assert!(matches!(schema, Err(Error::Connection(ConnectionError::ConnectionIsClosed())))); + + let database2 = databases.get(common::TEST_DATABASE).await; + assert!(database2.is_err()); + assert!(matches!(database2, Err(Error::Connection(ConnectionError::ConnectionIsClosed())))); + + let transaction = session.transaction(Write).await; + assert!(transaction.is_err()); + assert!(matches!(transaction, Err(Error::Connection(ConnectionError::ConnectionIsClosed())))); + + let session = Session::new(database, Data).await; + assert!(session.is_err()); + assert!(matches!(session, Err(Error::Connection(ConnectionError::ConnectionIsClosed())))); + + Ok(()) + } + + async fn force_close_session(connection: Connection) -> typedb_client::Result { + common::create_test_database_with_schema(connection.clone(), "define person sub entity;").await?; + let databases = DatabaseManager::new(connection.clone()); + + let session = Arc::new(Session::new(databases.get(common::TEST_DATABASE).await?, Data).await?); + let transaction = session.transaction(Write).await?; + + let session2 = session.clone(); + session2.force_close()?; + + let answer_stream = transaction.query().match_("match $x sub thing;"); + assert!(answer_stream.is_err()); + assert!(transaction.query().match_("match $x sub thing;").is_err()); + + let transaction = session.transaction(Write).await; + assert!(transaction.is_err()); + assert!(matches!(transaction, Err(Error::Connection(ConnectionError::SessionIsClosed())))); + + assert!(Session::new(databases.get(common::TEST_DATABASE).await?, Data).await.is_ok()); + + Ok(()) + } + + #[ignore] + async fn streaming_perf(connection: Connection) -> typedb_client::Result { + for i in 0..5 { + let schema = r#"define + person sub entity, owns name, owns age; + name sub attribute, value string; + age sub attribute, value long;"#; + common::create_test_database_with_schema(connection.clone(), schema).await?; + let databases = DatabaseManager::new(connection.clone()); + + let start_time = Instant::now(); + let session = Session::new(databases.get(common::TEST_DATABASE).await?, Data).await?; + let transaction = session.transaction(Write).await?; + for j in 0..100_000 { + drop(transaction.query().insert(format!("insert $x {j} isa age;").as_str())?); + } + transaction.commit().await?; + println!("iteration {i}: inserted and committed 100k attrs in {}ms", start_time.elapsed().as_millis()); + + let mut start_time = Instant::now(); + let session = Session::new(databases.get(common::TEST_DATABASE).await?, Data).await?; + let transaction = session.transaction(Read).await?; + let mut answer_stream = transaction.query().match_("match $x isa attribute;")?; + let mut sum: i64 = 0; + let mut idx = 0; + while let Some(result) = answer_stream.next().await { + match result { + Ok(concept_map) => { + for (_, concept) in concept_map { + if let Concept::Thing(Thing::Attribute(Attribute::Long(long_attr))) = concept { + sum += long_attr.value + } + } + } + Err(err) => { + panic!("An error occurred fetching answers of a Match query: {}", err) + } + } + idx = idx + 1; + if idx == 100_000 { + println!("iteration {i}: retrieved and summed 100k attrs in {}ms", start_time.elapsed().as_millis()); + start_time = Instant::now(); + } + } + println!("sum is {}", sum); + } + + Ok(()) + } +} + +// Concept helpers +// FIXME: should be removed after concept API is implemented +fn unwrap_date_time(concept: Concept) -> NaiveDateTime { + match concept { + Concept::Thing(Thing::Attribute(Attribute::DateTime(DateTimeAttribute { value, .. }))) => value, + _ => unreachable!(), + } +} + +fn unwrap_string(concept: Concept) -> String { + match concept { + Concept::Thing(Thing::Attribute(Attribute::String(StringAttribute { value, .. }))) => value, + _ => unreachable!(), + } +} diff --git a/tests/queries_cluster.rs b/tests/queries_cluster.rs deleted file mode 100644 index 058cafa4..00000000 --- a/tests/queries_cluster.rs +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Copyright (C) 2022 Vaticle - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::path::PathBuf; - -use futures::{StreamExt, TryFutureExt}; -use serial_test::serial; -use typedb_client::{ - cluster, - common::{Credential, SessionType::Data, TransactionType::Write}, -}; - -const TEST_DATABASE: &str = "test"; - -#[tokio::test(flavor = "multi_thread")] -#[serial] -async fn basic() { - let mut client = cluster::Client::new( - &["localhost:11729", "localhost:21729", "localhost:31729"], - Credential::with_tls( - "admin", - "password", - Some(&PathBuf::from(std::env::var("ROOT_CA").unwrap())), - ), - ) - .await - .unwrap(); - - if client.databases().contains(TEST_DATABASE).await.unwrap() { - client.databases().get(TEST_DATABASE).and_then(|db| db.delete()).await.unwrap(); - } - client.databases().create(TEST_DATABASE).await.unwrap(); - - assert!(client.databases().contains(TEST_DATABASE).await.unwrap()); - - let mut session = client.session(TEST_DATABASE, Data).await.unwrap(); - let mut transaction = session.transaction(Write).await.unwrap(); - let mut answer_stream = transaction.query.match_("match $x sub thing;"); - while let Some(result) = answer_stream.next().await { - assert!(result.is_ok()) - } - transaction.commit().await.unwrap(); -} diff --git a/tests/queries_core.rs b/tests/queries_core.rs deleted file mode 100644 index 3858a0b1..00000000 --- a/tests/queries_core.rs +++ /dev/null @@ -1,249 +0,0 @@ -/* - * Copyright (C) 2022 Vaticle - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::{sync::mpsc, time::Instant}; - -use chrono::{NaiveDate, NaiveDateTime}; -use futures::{StreamExt, TryFutureExt}; -use serial_test::serial; -use typedb_client::{ - common::{ - SessionType::{Data, Schema}, - TransactionType::{Read, Write}, - }, - concept::{Attribute, Concept, DateTimeAttribute, StringAttribute, Thing}, - core, server, -}; - -const TEST_DATABASE: &str = "test"; - -#[tokio::test(flavor = "multi_thread")] -#[serial] -async fn basic() { - let mut client = core::Client::with_default_address().await.unwrap(); - create_test_database_with_schema(&mut client, "define person sub entity;").await.unwrap(); - assert!(client.databases().contains(TEST_DATABASE).await.unwrap()); - - let session = client.session(TEST_DATABASE, Data).await.unwrap(); - let mut transaction = session.transaction(Write).await.unwrap(); - let mut answer_stream = transaction.query.match_("match $x sub thing;"); - while let Some(result) = answer_stream.next().await { - assert!(result.is_ok()) - } - transaction.commit().await.unwrap(); -} - -#[tokio::test(flavor = "multi_thread")] -#[serial] -async fn concurrent_queries() { - let mut client = core::Client::with_default_address().await.unwrap(); - create_test_database_with_schema(&mut client, "define person sub entity;").await.unwrap(); - - let session = client.session(TEST_DATABASE, Data).await.unwrap(); - let transaction = session.transaction(Write).await.unwrap(); - - let (sender, receiver) = mpsc::channel(); - - for _ in 0..5 { - let sender = sender.clone(); - let mut transaction = transaction.clone(); - tokio::spawn(async move { - for _ in 0..5 { - let mut answer_stream = transaction.query.match_("match $x sub thing;"); - while let Some(result) = answer_stream.next().await { - sender.send(result).unwrap(); - } - } - }); - } - drop(sender); // receiver expects data while any sender is live - - for received in receiver { - assert!(received.is_ok()); - } -} - -#[tokio::test(flavor = "multi_thread")] -#[serial] -async fn query_options() { - let mut client = core::Client::with_default_address().await.unwrap(); - let schema = r#"define - person sub entity, - owns name, - owns age; - name sub attribute, value string; - age sub attribute, value long; - rule age-rule: when { $x isa person; } then { $x has age 25; };"#; - create_test_database_with_schema(&mut client, schema).await.unwrap(); - - let session = client.session(TEST_DATABASE, Data).await.unwrap(); - let mut transaction = session.transaction(Write).await.unwrap(); - let data = "insert $x isa person, has name 'Alice'; $y isa person, has name 'Bob';"; - let _ = transaction.query.insert(data); - transaction.commit().await.unwrap(); - - let mut transaction = session.transaction(Read).await.unwrap(); - let age_count = transaction.query.match_aggregate("match $x isa age; count;").await.unwrap(); - assert_eq!(age_count.into_i64(), 0); - - let with_inference = core::Options::new_core().infer(true); - let mut transaction = session.transaction_with_options(Read, with_inference).await.unwrap(); - let age_count = transaction.query.match_aggregate("match $x isa age; count;").await.unwrap(); - assert_eq!(age_count.into_i64(), 1); -} - -#[tokio::test(flavor = "multi_thread")] -#[serial] -async fn many_concept_types() { - let mut client = core::Client::with_default_address().await.unwrap(); - let schema = r#"define - person sub entity, - owns name, - owns date-of-birth, - plays friendship:friend; - name sub attribute, value string; - date-of-birth sub attribute, value datetime; - friendship sub relation, - relates friend;"#; - create_test_database_with_schema(&mut client, schema).await.unwrap(); - - let session = client.session(TEST_DATABASE, Data).await.unwrap(); - let mut transaction = session.transaction(Write).await.unwrap(); - let data = r#"insert - $x isa person, has name "Alice", has date-of-birth 1994-10-03; - $y isa person, has name "Bob", has date-of-birth 1993-04-17; - (friend: $x, friend: $y) isa friendship;"#; - let _ = transaction.query.insert(data); - transaction.commit().await.unwrap(); - - let mut transaction = session.transaction(Read).await.unwrap(); - let mut answer_stream = transaction.query.match_( - r#"match - $p isa person, has name $name, has date-of-birth $date-of-birth; - $f($role: $p) isa friendship;"#, - ); - - while let Some(result) = answer_stream.next().await { - assert!(result.is_ok()); - let mut result = result.unwrap().map; - let name = unwrap_string(result.remove("name").unwrap()); - let date_of_birth = unwrap_date_time(result.remove("date-of-birth").unwrap()).date(); - match name.as_str() { - "Alice" => assert_eq!(date_of_birth, NaiveDate::from_ymd_opt(1994, 10, 3).unwrap()), - "Bob" => assert_eq!(date_of_birth, NaiveDate::from_ymd_opt(1993, 4, 17).unwrap()), - _ => unreachable!(), - } - } -} - -#[tokio::test(flavor = "multi_thread")] -#[serial] -#[ignore] -async fn streaming_perf() { - let mut client = core::Client::with_default_address().await.unwrap(); - for i in 0..5 { - let schema = r#"define - person sub entity, owns name, owns age; - name sub attribute, value string; - age sub attribute, value long;"#; - create_test_database_with_schema(&mut client, schema).await.unwrap(); - - let start_time = Instant::now(); - let session = client.session(TEST_DATABASE, Data).await.unwrap(); - let mut transaction = session.transaction(Write).await.unwrap(); - for j in 0..100_000 { - let _ = transaction.query.insert(format!("insert $x {j} isa age;").as_str()); - } - transaction.commit().await.unwrap(); - println!( - "iteration {i}: inserted and committed 100k attrs in {}ms", - (Instant::now() - start_time).as_millis() - ); - - let mut start_time = Instant::now(); - let session = client.session(TEST_DATABASE, Data).await.unwrap(); - let mut transaction = session.transaction(Read).await.unwrap(); - let mut answer_stream = transaction.query.match_("match $x isa attribute;"); - let mut sum: i64 = 0; - let mut idx = 0; - while let Some(result) = answer_stream.next().await { - match result { - Ok(concept_map) => { - for (_, concept) in concept_map { - if let Concept::Thing(Thing::Attribute(Attribute::Long(long_attr))) = - concept - { - sum += long_attr.value - } - } - } - Err(err) => { - panic!("An error occurred fetching answers of a Match query: {}", err) - } - } - idx = idx + 1; - if idx == 100_000 { - println!( - "iteration {i}: retrieved and summed 100k attrs in {}ms", - (Instant::now() - start_time).as_millis() - ); - start_time = Instant::now(); - } - } - println!("sum is {}", sum); - } -} - -async fn create_test_database_with_schema( - client: &mut core::Client, - schema: &str, -) -> typedb_client::Result { - if client.databases().contains(TEST_DATABASE).await.unwrap() { - client.databases().get(TEST_DATABASE).and_then(server::Database::delete).await.unwrap(); - } - client.databases().create(TEST_DATABASE).await.unwrap(); - - let mut session = client.session(TEST_DATABASE, Schema).await.unwrap(); - let mut transaction = session.transaction(Write).await.unwrap(); - transaction.query.define(schema).await.unwrap(); - transaction.commit().await.unwrap(); - session.close().await; - - Ok(()) -} - -// Concept helpers -// FIXME should be removed after concept API is implemented -fn unwrap_date_time(concept: Concept) -> NaiveDateTime { - match concept { - Concept::Thing(Thing::Attribute(Attribute::DateTime(DateTimeAttribute { - value, .. - }))) => value, - _ => unreachable!(), - } -} - -fn unwrap_string(concept: Concept) -> String { - match concept { - Concept::Thing(Thing::Attribute(Attribute::String(StringAttribute { value, .. }))) => value, - _ => unreachable!(), - } -} diff --git a/tests/runtimes.rs b/tests/runtimes.rs new file mode 100644 index 00000000..00943ed0 --- /dev/null +++ b/tests/runtimes.rs @@ -0,0 +1,89 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +mod common; + +use futures::StreamExt; +use serial_test::serial; +use typedb_client::{DatabaseManager, Session, SessionType::Data, TransactionType::Write}; + +#[test] +#[serial] +fn basic_async_std() { + async_std::task::block_on(async { + let connection = common::new_cluster_connection()?; + common::create_test_database_with_schema(connection.clone(), "define person sub entity;").await?; + let databases = DatabaseManager::new(connection); + assert!(databases.contains(common::TEST_DATABASE).await?); + + let session = Session::new(databases.get(common::TEST_DATABASE).await?, Data).await?; + let transaction = session.transaction(Write).await?; + let answer_stream = transaction.query().match_("match $x sub thing;")?; + let results: Vec<_> = answer_stream.collect().await; + transaction.commit().await?; + assert_eq!(results.len(), 5); + assert!(results.into_iter().all(|res| res.is_ok())); + Ok::<(), typedb_client::Error>(()) + }) + .unwrap(); +} + +#[test] +#[serial] +fn basic_smol() { + smol::block_on(async { + let connection = common::new_cluster_connection()?; + common::create_test_database_with_schema(connection.clone(), "define person sub entity;").await?; + let databases = DatabaseManager::new(connection); + assert!(databases.contains(common::TEST_DATABASE).await?); + + let session = Session::new(databases.get(common::TEST_DATABASE).await?, Data).await?; + let transaction = session.transaction(Write).await?; + let answer_stream = transaction.query().match_("match $x sub thing;")?; + let results: Vec<_> = answer_stream.collect().await; + transaction.commit().await?; + assert_eq!(results.len(), 5); + assert!(results.into_iter().all(|res| res.is_ok())); + Ok::<(), typedb_client::Error>(()) + }) + .unwrap(); +} + +#[test] +#[serial] +fn basic_futures() { + futures::executor::block_on(async { + let connection = common::new_cluster_connection()?; + common::create_test_database_with_schema(connection.clone(), "define person sub entity;").await?; + let databases = DatabaseManager::new(connection); + assert!(databases.contains(common::TEST_DATABASE).await?); + + let session = Session::new(databases.get(common::TEST_DATABASE).await?, Data).await?; + let transaction = session.transaction(Write).await?; + let answer_stream = transaction.query().match_("match $x sub thing;")?; + let results: Vec<_> = answer_stream.collect().await; + transaction.commit().await?; + assert_eq!(results.len(), 5); + assert!(results.into_iter().all(|res| res.is_ok())); + Ok::<(), typedb_client::Error>(()) + }) + .unwrap(); +}