diff --git a/Cargo.lock.msrv b/Cargo.lock.msrv
index 4ddbe9dae1..0bc434d30e 100644
--- a/Cargo.lock.msrv
+++ b/Cargo.lock.msrv
@@ -586,6 +586,17 @@ dependencies = [
  "futures-util",
 ]
 
+[[package]]
+name = "futures-batch"
+version = "0.6.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6f444c45a1cb86f2a7e301469fd50a82084a60dadc25d94529a8312276ecb71a"
+dependencies = [
+ "futures",
+ "futures-timer",
+ "pin-utils",
+]
+
 [[package]]
 name = "futures-channel"
 version = "0.3.28"
@@ -642,6 +653,12 @@ version = "0.3.28"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65"
 
+[[package]]
+name = "futures-timer"
+version = "3.0.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
+
 [[package]]
 name = "futures-util"
 version = "0.3.28"
@@ -1490,6 +1507,7 @@ dependencies = [
  "criterion",
  "dashmap",
  "futures",
+ "futures-batch",
  "hashbrown 0.14.0",
  "histogram",
  "itertools 0.11.0",
@@ -1514,6 +1532,7 @@ dependencies = [
  "time",
  "tokio",
  "tokio-openssl",
+ "tokio-stream",
  "tracing",
  "tracing-subscriber",
  "url",
@@ -1861,6 +1880,17 @@ dependencies = [
  "tokio",
 ]
 
+[[package]]
+name = "tokio-stream"
+version = "0.1.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "267ac89e0bec6e691e5813911606935d77c476ff49024f98abcea3e7b15e37af"
+dependencies = [
+ "futures-core",
+ "pin-project-lite",
+ "tokio",
+]
+
 [[package]]
 name = "toml_datetime"
 version = "0.6.3"
diff --git a/scylla-cql/src/types/serialize/row.rs b/scylla-cql/src/types/serialize/row.rs
index f980064388..6b968060a7 100644
--- a/scylla-cql/src/types/serialize/row.rs
+++ b/scylla-cql/src/types/serialize/row.rs
@@ -80,6 +80,17 @@ pub trait SerializeRow {
     /// the bind marker types and names so that the values can be properly
     /// type checked and serialized.
     fn is_empty(&self) -> bool;
+
+    /// Specialization that allows the driver to skip re-serializing the row if it is
+    /// already a `SerializedValues`.
+    ///
+    /// Note that when using this, it is the user's responsibility to ensure that this
+    /// `SerializedValues` has been generated with the same prepared statement as the
+    /// query is going to be made with.
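+    ///
+    /// A minimal sketch of the intended round trip (hypothetical `session`, not
+    /// part of this diff):
+    ///
+    /// ```rust,ignore
+    /// let prepared = session.prepare("INSERT INTO ks.t (a, b) VALUES (?, ?)").await?;
+    /// let serialized = prepared.serialize_values(&(1, 2))?;
+    /// // `SerializedValues` implements `SerializeRow`, so handing it back to the
+    /// // driver reuses the buffer instead of serializing a second time.
+    /// session.execute(&prepared, &serialized).await?;
+    /// ```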
fallback_impl_contents { @@ -255,12 +266,36 @@ impl SerializeRow for &T { ctx: &RowSerializationContext<'_>, writer: &mut RowWriter, ) -> Result<(), SerializationError> { - ::serialize(self, ctx, writer) + ::serialize(*self, ctx, writer) } #[inline] fn is_empty(&self) -> bool { - ::is_empty(self) + ::is_empty(*self) + } + + #[inline] + fn already_serialized(&self) -> Option<&SerializedValues> { + ::already_serialized(*self) + } +} + +impl SerializeRow for SerializedValues { + fn serialize( + &self, + _ctx: &RowSerializationContext<'_>, + writer: &mut RowWriter, + ) -> Result<(), SerializationError> { + writer.append_serialize_row(self); + Ok(()) + } + + fn is_empty(&self) -> bool { + self.is_empty() + } + + fn already_serialized(&self) -> Option<&SerializedValues> { + Some(self) } } diff --git a/scylla/Cargo.toml b/scylla/Cargo.toml index 5c4667c95e..fb47a22ba9 100644 --- a/scylla/Cargo.toml +++ b/scylla/Cargo.toml @@ -89,6 +89,8 @@ tracing-subscriber = { version = "0.3.14", features = ["env-filter"] } assert_matches = "1.5.0" rand_chacha = "0.3.1" time = "0.3" +futures-batch = "0.6.1" +tokio-stream = "0.1.14" [[bench]] name = "benchmark" diff --git a/scylla/src/statement/batch.rs b/scylla/src/statement/batch.rs index efe95031e2..d668c32d65 100644 --- a/scylla/src/statement/batch.rs +++ b/scylla/src/statement/batch.rs @@ -2,9 +2,13 @@ use std::borrow::Cow; use std::sync::Arc; use crate::history::HistoryListener; +use crate::load_balancing; use crate::retry_policy::RetryPolicy; +use crate::routing::Shard; use crate::statement::{prepared_statement::PreparedStatement, query::Query}; use crate::transport::execution_profile::ExecutionProfileHandle; +use crate::transport::NodeRef; +use crate::Session; use super::StatementConfig; use super::{Consistency, SerialConsistency}; @@ -142,6 +146,82 @@ impl Batch { pub fn get_execution_profile_handle(&self) -> Option<&ExecutionProfileHandle> { self.config.execution_profile_handle.as_ref() } + + /// Associates the batch with a new execution profile that will have a load + /// balancing policy that will enforce the use of the provided [`NodeRef`] + /// to the extent possible. + /// + /// This should typically be used in conjunction with + /// [`Session::shard_for_statement`], where you would constitute a batch + /// by assigning to the same batch all the statements that would be executed + /// in the same shard. 
+        writer.append_serialize_row(self);
+        Ok(())
+    }
+
+    fn is_empty(&self) -> bool {
+        self.is_empty()
+    }
+
+    fn already_serialized(&self) -> Option<&SerializedValues> {
+        Some(self)
     }
 }
diff --git a/scylla/Cargo.toml b/scylla/Cargo.toml
index 5c4667c95e..fb47a22ba9 100644
--- a/scylla/Cargo.toml
+++ b/scylla/Cargo.toml
@@ -89,6 +89,8 @@ tracing-subscriber = { version = "0.3.14", features = ["env-filter"] }
 assert_matches = "1.5.0"
 rand_chacha = "0.3.1"
 time = "0.3"
+futures-batch = "0.6.1"
+tokio-stream = "0.1.14"
 
 [[bench]]
 name = "benchmark"
diff --git a/scylla/src/statement/batch.rs b/scylla/src/statement/batch.rs
index efe95031e2..d668c32d65 100644
--- a/scylla/src/statement/batch.rs
+++ b/scylla/src/statement/batch.rs
@@ -2,9 +2,13 @@ use std::borrow::Cow;
 use std::sync::Arc;
 
 use crate::history::HistoryListener;
+use crate::load_balancing;
 use crate::retry_policy::RetryPolicy;
+use crate::routing::Shard;
 use crate::statement::{prepared_statement::PreparedStatement, query::Query};
 use crate::transport::execution_profile::ExecutionProfileHandle;
+use crate::transport::NodeRef;
+use crate::Session;
 
 use super::StatementConfig;
 use super::{Consistency, SerialConsistency};
@@ -142,6 +146,82 @@ impl Batch {
     pub fn get_execution_profile_handle(&self) -> Option<&ExecutionProfileHandle> {
         self.config.execution_profile_handle.as_ref()
     }
+
+    /// Associates the batch with a new execution profile whose load balancing
+    /// policy enforces the use of the provided [`NodeRef`] to the extent possible.
+    ///
+    /// This should typically be used in conjunction with
+    /// [`Session::shard_for_statement`]: you would constitute a batch by assigning
+    /// to the same batch all the statements that would be executed on the same
+    /// shard.
+    ///
+    /// Since it is not guaranteed that subsequent calls to the load balancer
+    /// would re-assign the statement to the same node, you should use this
+    /// method to enforce the use of the original node that was envisioned by
+    /// `shard_for_statement` for the batch:
+    ///
+    /// ```rust
+    /// # use scylla::Session;
+    /// # use std::error::Error;
+    /// # async fn check_only_compiles(session: &Session) -> Result<(), Box<dyn Error>> {
+    /// use scylla::{batch::Batch, serialize::row::SerializedValues};
+    ///
+    /// let prepared_statement = session
+    ///     .prepare("INSERT INTO ks.tab(a, b) VALUES(?, ?)")
+    ///     .await?;
+    ///
+    /// let serialized_values: SerializedValues = prepared_statement.serialize_values(&(1, 2))?;
+    /// let shard = session.shard_for_statement(&prepared_statement, &serialized_values)?;
+    ///
+    /// // Send that to a task that will handle statements targeted to the same shard
+    ///
+    /// // On that task:
+    /// // Constitute a batch with all the statements that would be executed on the same shard
+    ///
+    /// let mut batch: Batch = Default::default();
+    /// if let Some((node, shard_idx)) = shard {
+    ///     batch.enforce_target_node(&node, shard_idx, &session);
+    /// }
+    /// let mut batch_values = Vec::new();
+    ///
+    /// // As the task handling statements targeted to this shard receives them,
+    /// // it appends them to the batch
+    /// batch.append_statement(prepared_statement);
+    /// batch_values.push(serialized_values);
+    ///
+    /// // Run the batch
+    /// session.batch(&batch, batch_values).await?;
+    /// # Ok(())
+    /// # }
+    /// ```
+    ///
+    /// If the target node is not available anymore at the time of executing the
+    /// statement, it will fall back to the original load balancing policy:
+    /// - Either the one currently set on the [`Batch`], if any
+    /// - Or the [`Session`]'s default one, if the `Batch` has none
+    pub fn enforce_target_node(
+        &mut self,
+        node: NodeRef<'_>,
+        shard: Shard,
+        base_execution_profile_from_session: &Session,
+    ) {
+        let execution_profile_handle = self.get_execution_profile_handle().unwrap_or_else(|| {
+            base_execution_profile_from_session.get_default_execution_profile_handle()
+        });
+        self.set_execution_profile_handle(Some(
+            execution_profile_handle
+                .pointee_to_builder()
+                .load_balancing_policy(Arc::new(load_balancing::EnforceTargetShardPolicy::new(
+                    node,
+                    shard,
+                    execution_profile_handle.load_balancing_policy(),
+                )))
+                .build()
+                .into_handle(),
+        ))
+    }
 }
 
 impl Default for Batch {
diff --git a/scylla/src/statement/prepared_statement.rs b/scylla/src/statement/prepared_statement.rs
index 6287a1492e..bbbdf1e39b 100644
--- a/scylla/src/statement/prepared_statement.rs
+++ b/scylla/src/statement/prepared_statement.rs
@@ -459,7 +459,7 @@ impl PreparedStatement {
         self.config.execution_profile_handle.as_ref()
     }
 
-    pub(crate) fn serialize_values(
+    pub fn serialize_values(
         &self,
         values: &impl SerializeRow,
     ) -> Result<SerializedValues, SerializationError> {
diff --git a/scylla/src/transport/execution_profile.rs b/scylla/src/transport/execution_profile.rs
index 7b7a14faaf..fc659a9378 100644
--- a/scylla/src/transport/execution_profile.rs
+++ b/scylla/src/transport/execution_profile.rs
@@ -485,4 +485,19 @@ impl ExecutionProfileHandle {
     pub fn map_to_another_profile(&mut self, profile: ExecutionProfile) {
         self.0 .0.store(profile.0)
     }
+
+    /// Get the load balancing policy associated with this execution profile.
+    ///
+    /// This may be useful if one wants to construct a new load balancing policy
+    /// that is based on the one associated with this execution profile.
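+    ///
+    /// A sketch of the intended composition (`WrappingPolicy` is hypothetical, not
+    /// part of the driver):
+    ///
+    /// ```rust,ignore
+    /// let current = handle.load_balancing_policy();
+    /// let profile = handle
+    ///     .pointee_to_builder()
+    ///     .load_balancing_policy(Arc::new(WrappingPolicy::with_fallback(current)))
+    ///     .build();
+    /// ```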
+    pub fn load_balancing_policy(&self) -> Arc<dyn LoadBalancingPolicy> {
+        // Exposed as a building block of `Batch::enforce_target_node`, in case a user
+        // wants more control than what that method does.
+        // Since the load balancing policy being accessible from the
+        // `ExecutionProfileHandle` is already public API (it is documented to be
+        // preserved by `pointee_to_builder`), having this as public API doesn't
+        // prevent any more non-breaking evolution than would already be blocked
+        // anyway.
+        self.0 .0.load().load_balancing_policy.clone()
+    }
 }
diff --git a/scylla/src/transport/load_balancing/default.rs b/scylla/src/transport/load_balancing/default.rs
index 4280c855fa..062de20926 100644
--- a/scylla/src/transport/load_balancing/default.rs
+++ b/scylla/src/transport/load_balancing/default.rs
@@ -703,7 +703,7 @@ impl DefaultPolicy {
         vec.into_iter()
     }
 
-    fn is_alive(node: NodeRef, _shard: Option<Shard>) -> bool {
+    pub(crate) fn is_alive(node: NodeRef, _shard: Option<Shard>) -> bool {
         // For now, we leave this as a stub, until we have time to improve node events.
         // node.is_enabled() && !node.is_down()
         node.is_enabled()
diff --git a/scylla/src/transport/load_balancing/enforce_node.rs b/scylla/src/transport/load_balancing/enforce_node.rs
new file mode 100644
index 0000000000..e2e7dde11c
--- /dev/null
+++ b/scylla/src/transport/load_balancing/enforce_node.rs
@@ -0,0 +1,62 @@
+use super::{DefaultPolicy, FallbackPlan, LoadBalancingPolicy, NodeRef, RoutingInfo};
+use crate::{
+    routing::Shard,
+    transport::{cluster::ClusterData, Node},
+};
+use std::sync::Arc;
+use uuid::Uuid;
+
+/// This policy will always return the same node, unless it is not available
+/// anymore, in which case it will fall back to the provided policy.
+///
+/// This is meant to be used for shard-aware batching.
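+///
+/// A sketch of manual construction, mirroring what `Batch::enforce_target_node`
+/// does internally (`node`, `shard` and `fallback` are assumed to be in scope):
+///
+/// ```rust,ignore
+/// let policy = Arc::new(EnforceTargetShardPolicy::new(&node, shard, fallback));
+/// let profile = ExecutionProfile::builder()
+///     .load_balancing_policy(policy)
+///     .build();
+/// ```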
+#[derive(Debug)]
+pub struct EnforceTargetShardPolicy {
+    target_node: Uuid,
+    shard: Shard,
+    fallback: Arc<dyn LoadBalancingPolicy>,
+}
+
+impl EnforceTargetShardPolicy {
+    pub fn new(
+        target_node: &Arc<Node>,
+        shard: Shard,
+        fallback: Arc<dyn LoadBalancingPolicy>,
+    ) -> Self {
+        Self {
+            target_node: target_node.host_id,
+            shard,
+            fallback,
+        }
+    }
+}
+impl LoadBalancingPolicy for EnforceTargetShardPolicy {
+    fn pick<'a>(
+        &'a self,
+        query: &'a RoutingInfo,
+        cluster: &'a ClusterData,
+    ) -> Option<(NodeRef<'a>, Option<Shard>)> {
+        cluster
+            .known_peers
+            .get(&self.target_node)
+            .map(|node| (node, Some(self.shard)))
+            .filter(|&(node, shard)| DefaultPolicy::is_alive(node, shard))
+            .or_else(|| self.fallback.pick(query, cluster))
+    }
+
+    fn fallback<'a>(
+        &'a self,
+        query: &'a RoutingInfo,
+        cluster: &'a ClusterData,
+    ) -> FallbackPlan<'a> {
+        self.fallback.fallback(query, cluster)
+    }
+
+    fn name(&self) -> String {
+        format!(
+            "Enforce target shard Load balancing policy - Node: {} - fallback: {}",
+            self.target_node,
+            self.fallback.name()
+        )
+    }
+}
diff --git a/scylla/src/transport/load_balancing/mod.rs b/scylla/src/transport/load_balancing/mod.rs
index 691f8657ee..7c5e6769fd 100644
--- a/scylla/src/transport/load_balancing/mod.rs
+++ b/scylla/src/transport/load_balancing/mod.rs
@@ -12,9 +12,13 @@ use scylla_cql::{
 use std::time::Duration;
 
 mod default;
+mod enforce_node;
 mod plan;
-pub use default::{DefaultPolicy, DefaultPolicyBuilder, LatencyAwarenessBuilder};
 pub use plan::Plan;
+pub use {
+    default::{DefaultPolicy, DefaultPolicyBuilder, LatencyAwarenessBuilder},
+    enforce_node::EnforceTargetShardPolicy,
+};
 
 /// Represents info about statement that can be used by load balancing policies.
 #[derive(Default, Clone, Debug)]
diff --git a/scylla/src/transport/session.rs b/scylla/src/transport/session.rs
index a81a115a4b..2385ea3c3e 100644
--- a/scylla/src/transport/session.rs
+++ b/scylla/src/transport/session.rs
@@ -7,6 +7,7 @@ use crate::cloud::CloudConfig;
 
 use crate::history;
 use crate::history::HistoryListener;
+use crate::routing;
 use crate::utils::pretty::{CommaSeparatedDisplayer, CqlValueDisplayer};
 use arc_swap::ArcSwapOption;
 use async_trait::async_trait;
@@ -18,8 +19,8 @@ pub use scylla_cql::errors::TranslationError;
 use scylla_cql::frame::response::result::{deser_cql_value, ColumnSpec, Rows};
 use scylla_cql::frame::response::NonErrorResponse;
 use scylla_cql::types::serialize::batch::BatchValues;
-use scylla_cql::types::serialize::row::SerializeRow;
-use std::borrow::Borrow;
+use scylla_cql::types::serialize::row::{SerializeRow, SerializedValues};
+use std::borrow::{Borrow, Cow};
 use std::collections::HashMap;
 use std::fmt::Display;
 use std::future::Future;
@@ -984,96 +985,92 @@ impl Session {
         values: impl SerializeRow,
         paging_state: Option<Bytes>,
     ) -> Result<QueryResult, QueryError> {
-        let serialized_values = prepared.serialize_values(&values)?;
-        let values_ref = &serialized_values;
-        let paging_state_ref = &paging_state;
-
-        let (partition_key, token) = prepared
-            .extract_partition_key_and_calculate_token(prepared.get_partitioner_name(), values_ref)?
-            .unzip();
-
-        let execution_profile = prepared
-            .get_execution_profile_handle()
-            .unwrap_or_else(|| self.get_default_execution_profile_handle())
-            .access();
-
-        let table_spec = prepared.get_table_spec();
-
-        let statement_info = RoutingInfo {
-            consistency: prepared
-                .config
-                .consistency
-                .unwrap_or(execution_profile.consistency),
-            serial_consistency: prepared
-                .config
-                .serial_consistency
-                .unwrap_or(execution_profile.serial_consistency),
-            token,
-            table: table_spec,
-            is_confirmed_lwt: prepared.is_confirmed_lwt(),
+        // Inlining + const propagation makes this optimization zero-cost in unrelated cases:
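+        // (`Cow` lets the pre-serialized fast path borrow the caller's buffer, while
+        // the generic slow path owns a freshly serialized one, so both can feed the
+        // same non-generic inner function below.)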
+        let serialized_values = match values.already_serialized() {
+            None => Cow::Owned(prepared.serialize_values(&values)?),
+            Some(serialized) => Cow::Borrowed(serialized),
         };
 
-        let span = RequestSpan::new_prepared(
-            partition_key.as_ref().map(|pk| pk.iter()),
-            token,
-            serialized_values.buffer_size(),
-        );
+        return non_generic_inner(self, prepared, &serialized_values, &paging_state).await;
 
+        /// Avoid monomorphizing this whole function for every SerializeRow type
+        async fn non_generic_inner(
+            self_: &Session,
+            prepared: &PreparedStatement,
+            values_ref: &SerializedValues,
+            paging_state_ref: &Option<Bytes>,
+        ) -> Result<QueryResult, QueryError> {
+            let (partition_key, token) = prepared
+                .extract_partition_key_and_calculate_token(
+                    prepared.get_partitioner_name(),
+                    values_ref,
+                )?
+                .unzip();
+
+            let (execution_profile, statement_info) =
+                self_.execution_profile_and_routing_info_from_prepared_statement(prepared, token);
+
+            let span = RequestSpan::new_prepared(
+                partition_key.as_ref().map(|pk| pk.iter()),
+                token,
+                values_ref.buffer_size(),
+            );
 
-        if !span.span().is_disabled() {
-            if let (Some(table_spec), Some(token)) = (statement_info.table, token) {
-                let cluster_data = self.get_cluster_data();
-                let replicas: smallvec::SmallVec<[_; 8]> = cluster_data
-                    .get_token_endpoints_iter(table_spec, token)
-                    .collect();
-                span.record_replicas(&replicas)
+            if !span.span().is_disabled() {
+                if let (Some(table_spec), Some(token)) = (statement_info.table, token) {
+                    let cluster_data = self_.get_cluster_data();
+                    let replicas: smallvec::SmallVec<[_; 8]> = cluster_data
+                        .get_token_endpoints_iter(table_spec, token)
+                        .collect();
+                    span.record_replicas(&replicas)
+                }
             }
-        }
 
-        let run_query_result: RunQueryResult<NonErrorQueryResponse> = self
-            .run_query(
-                statement_info,
-                &prepared.config,
-                execution_profile,
-                |connection: Arc<Connection>,
-                 consistency: Consistency,
-                 execution_profile: &ExecutionProfileInner| {
-                    let serial_consistency = prepared
-                        .config
-                        .serial_consistency
-                        .unwrap_or(execution_profile.serial_consistency);
-                    async move {
-                        connection
-                            .execute_with_consistency(
-                                prepared,
-                                values_ref,
-                                consistency,
-                                serial_consistency,
-                                paging_state_ref.clone(),
-                            )
-                            .await
-                            .and_then(QueryResponse::into_non_error_query_response)
-                    }
-                },
-                &span,
-            )
-            .instrument(span.span().clone())
-            .await?;
+            let run_query_result: RunQueryResult<NonErrorQueryResponse> = self_
+                .run_query(
+                    statement_info,
+                    &prepared.config,
+                    execution_profile,
+                    |connection: Arc<Connection>,
+                     consistency: Consistency,
+                     execution_profile: &ExecutionProfileInner| {
+                        let serial_consistency = prepared
+                            .config
+                            .serial_consistency
+                            .unwrap_or(execution_profile.serial_consistency);
+                        async move {
+                            connection
+                                .execute_with_consistency(
+                                    prepared,
+                                    values_ref,
+                                    consistency,
+                                    serial_consistency,
+                                    paging_state_ref.clone(),
+                                )
+                                .await
+                                .and_then(QueryResponse::into_non_error_query_response)
+                        }
+                    },
+                    &span,
+                )
+                .instrument(span.span().clone())
+                .await?;
 
-        let response = match run_query_result {
-            RunQueryResult::IgnoredWriteError => NonErrorQueryResponse {
-                response: NonErrorResponse::Result(result::Result::Void),
-                tracing_id: None,
-                warnings: Vec::new(),
-            },
-            RunQueryResult::Completed(response) => response,
-        };
+            let response = match run_query_result {
+                RunQueryResult::IgnoredWriteError => NonErrorQueryResponse {
+                    response: NonErrorResponse::Result(result::Result::Void),
+                    tracing_id: None,
+                    warnings: Vec::new(),
+                },
+                RunQueryResult::Completed(response) => response,
+            };
 
-        self.handle_set_keyspace_response(&response).await?;
-        self.handle_auto_await_schema_agreement(&response).await?;
+            self_.handle_set_keyspace_response(&response).await?;
+            self_.handle_auto_await_schema_agreement(&response).await?;
 
-        let result = response.into_query_result()?;
-        span.record_result_fields(&result);
-        Ok(result)
+            let result = response.into_query_result()?;
+            span.record_result_fields(&result);
+            Ok(result)
+        }
     }
 
     /// Run a prepared query with paging\
@@ -1822,6 +1819,74 @@
         Ok(in_agreement.then_some(local_version))
     }
 
+    /// Get a node/shard that the load balancer would potentially target if running this query.
+    ///
+    /// This may help constituting shard-aware batches (see [`Batch::enforce_target_node`]).
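+    ///
+    /// A sketch of typical use (`values` is assumed to match the statement's bind
+    /// markers):
+    ///
+    /// ```rust,ignore
+    /// let serialized = prepared.serialize_values(&values)?;
+    /// if let Some((node, shard)) = session.shard_for_statement(&prepared, &serialized)? {
+    ///     // Route `serialized` to the queue batching statements for (node, shard)
+    /// }
+    /// ```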
+    #[allow(clippy::type_complexity)]
+    pub fn shard_for_statement(
+        &self,
+        prepared: &PreparedStatement,
+        serialized_values: &SerializedValues,
+    ) -> Result<Option<(Arc<Node>, routing::Shard)>, QueryError> {
+        let token = match prepared.extract_partition_key_and_calculate_token(
+            prepared.get_partitioner_name(),
+            serialized_values,
+        )? {
+            Some((_partition_key, token)) => token,
+            None => return Ok(None),
+        };
+
+        let (execution_profile, routing_info) =
+            self.execution_profile_and_routing_info_from_prepared_statement(prepared, Some(token));
+        let cluster_data = self.cluster.get_data();
+
+        let mut query_plan = load_balancing::Plan::new(
+            &*execution_profile.load_balancing_policy,
+            &routing_info,
+            &cluster_data,
+        );
+        // We can't return the full iterator here because it borrows from local variables.
+        // In order to achieve that, two designs would be possible:
+        // - Construct a self-referential struct and implement Iterator on it via e.g. Ouroboros
+        // - Take a closure as a parameter that would receive the local iterator and return
+        //   anything, and have this function directly return what the closure returns
+        // Most likely though, people would use this for some kind of shard-awareness
+        // optimization for batching, and are consequently not interested in subsequent nodes.
+        // Until that need arises, let's just expose the first node, as it is simpler.
+        Ok(query_plan
+            .next()
+            .map(|(node, shard)| (Arc::clone(node), shard)))
+    }
+
+    fn execution_profile_and_routing_info_from_prepared_statement<'p>(
+        &self,
+        prepared: &'p PreparedStatement,
+        token: Option<Token>,
+    ) -> (Arc<ExecutionProfileInner>, RoutingInfo<'p>) {
+        let execution_profile = prepared
+            .get_execution_profile_handle()
+            .unwrap_or_else(|| self.get_default_execution_profile_handle())
+            .access();
+
+        let table_spec = prepared.get_table_spec();
+
+        let routing_info = RoutingInfo {
+            consistency: prepared
+                .config
+                .consistency
+                .unwrap_or(execution_profile.consistency),
+            serial_consistency: prepared
+                .config
+                .serial_consistency
+                .unwrap_or(execution_profile.serial_consistency),
+            token,
+            table: table_spec,
+            is_confirmed_lwt: prepared.is_confirmed_lwt(),
+        };
+
+        (execution_profile, routing_info)
+    }
+
     /// Retrieves the handle to execution profile that is used by this session
     /// by default, i.e. when an executed statement does not define its own handle.
     pub fn get_default_execution_profile_handle(&self) -> &ExecutionProfileHandle {
diff --git a/scylla/tests/integration/main.rs b/scylla/tests/integration/main.rs
index 06a2ab429a..4697bafbe7 100644
--- a/scylla/tests/integration/main.rs
+++ b/scylla/tests/integration/main.rs
@@ -4,6 +4,7 @@ mod hygiene;
 mod lwt_optimisation;
 mod new_session;
 mod retries;
+mod shard_aware_batching;
 mod shards;
 mod silent_prepare_query;
 mod skip_metadata_optimization;
diff --git a/scylla/tests/integration/shard_aware_batching.rs b/scylla/tests/integration/shard_aware_batching.rs
new file mode 100644
index 0000000000..5bae286bce
--- /dev/null
+++ b/scylla/tests/integration/shard_aware_batching.rs
@@ -0,0 +1,230 @@
+use crate::utils::test_with_3_node_cluster;
+use futures::prelude::*;
+use futures_batch::ChunksTimeoutStreamExt;
+use scylla::retry_policy::FallthroughRetryPolicy;
+use scylla::routing::Shard;
+use scylla::serialize::row::SerializedValues;
+use scylla::test_utils::unique_keyspace_name;
+use scylla::transport::session::Session;
+use scylla::{ExecutionProfile, SessionBuilder};
+use std::collections::HashMap;
+use std::sync::Arc;
+use std::time::Duration;
+use tokio::sync::mpsc;
+use tokio_stream::wrappers::ReceiverStream;
+
+use scylla_proxy::{
+    Condition, ProxyError, Reaction, RequestOpcode, RequestReaction, RequestRule, RunningProxy,
+    ShardAwareness, WorkerError,
+};
+
+#[tokio::test]
+#[ntest::timeout(20000)]
+#[cfg(not(scylla_cloud_tests))]
+async fn shard_aware_batching_pattern_routes_to_proper_shard() {
+    let res = test_with_3_node_cluster(ShardAwareness::QueryNode, run_test).await;
+
+    match res {
+        Ok(()) => (),
+        Err(ProxyError::Worker(WorkerError::DriverDisconnected(_))) => (),
+        Err(err) => panic!("{}", err),
+    }
+}
+
+async fn run_test(
+    proxy_uris: [String; 3],
+    translation_map: HashMap<std::net::SocketAddr, std::net::SocketAddr>,
+    mut running_proxy: RunningProxy,
+) -> RunningProxy {
+    // This is just to increase the likelihood that only intended prepared statements
+    // (which contain this mark) are captured by the proxy.
+    const MAGIC_MARK: i32 = 123;
+
+    // We set up the proxy so that it passes us information about which node was
+    // queried (via prepared_rx).
+
+    let prepared_rule = |tx| {
+        RequestRule(
+            Condition::and(
+                Condition::RequestOpcode(RequestOpcode::Batch),
+                Condition::BodyContainsCaseSensitive(Box::new(MAGIC_MARK.to_be_bytes())),
+            ),
+            RequestReaction::noop().with_feedback_when_performed(tx),
+        )
+    };
+
+    let mut prepared_rxs = [0, 1, 2].map(|i| {
+        let (prepared_tx, prepared_rx) = mpsc::unbounded_channel();
+        running_proxy.running_nodes[i].change_request_rules(Some(vec![prepared_rule(prepared_tx)]));
+        prepared_rx
+    });
+    let shards_for_nodes_test_check: Arc<tokio::sync::Mutex<HashMap<uuid::Uuid, Vec<Shard>>>> =
+        Default::default();
+
+    let handle = ExecutionProfile::builder()
+        .retry_policy(Box::new(FallthroughRetryPolicy))
+        .build()
+        .into_handle();
+
+    // DB preparation phase
+    let session: Arc<Session> = Arc::new(
+        SessionBuilder::new()
+            .known_node(proxy_uris[0].as_str())
+            .default_execution_profile_handle(handle)
+            .address_translator(Arc::new(translation_map))
+            .build()
+            .await
+            .unwrap(),
+    );
+
+    // Create schema
+    let ks = unique_keyspace_name();
+    session.query(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 3}}", ks), &[]).await.unwrap();
+    session.use_keyspace(ks, false).await.unwrap();
+
+    session
+        .query("CREATE TABLE t (a int primary key, b int)", &[])
+        .await
+        .unwrap();
+
+    // We will check which nodes (and shards) were queried.
+    let prepared_statement = session
+        .prepare("INSERT INTO t (a, b) VALUES (?, ?)")
+        .await
+        .unwrap();
+
+    assert!(prepared_statement.is_token_aware());
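+
+    // Shard-aware routing is only possible for token-aware statements: without a
+    // computable token, `shard_for_statement` below would return `Ok(None)`.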
+
+    // Build the shard-aware batching system
+
+    #[derive(Clone, Copy, PartialEq, Eq, Hash)]
+    struct DestinationShard {
+        node_id: uuid::Uuid,
+        shard_id_on_node: u32,
+    }
+    let mut channels_for_shards: HashMap<
+        DestinationShard,
+        tokio::sync::mpsc::Sender<SerializedValues>,
+    > = HashMap::new();
+    let mut batching_tasks: Vec<tokio::task::JoinHandle<()>> = Vec::new(); // To make sure nothing panicked
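+
+    // One sender per destination (node, shard): statements that route to the same
+    // shard are funneled to a dedicated task, which batches them together.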
+    for i in 0..150 {
+        let values = (i, MAGIC_MARK);
+
+        let serialized_values = prepared_statement
+            .serialize_values(&values)
+            .expect("Failed to serialize values");
+
+        let (node, shard_id_on_node) = session
+            .shard_for_statement(&prepared_statement, &serialized_values)
+            .expect("Error when getting shard for statement")
+            .expect("Query is not shard-aware");
+        let destination_shard = DestinationShard {
+            node_id: node.host_id,
+            shard_id_on_node,
+        };
+
+        // Typically, if statements may come from several places, the `channels_for_shards`
+        // `HashMap` would be behind a mutex, but for this example we keep it simple.
+        // Create the task that constitutes and sends the batches for this shard,
+        // if it doesn't already exist.
+        let sender = channels_for_shards
+            .entry(destination_shard)
+            .or_insert_with(|| {
+                let (sender, receiver) = tokio::sync::mpsc::channel(10000);
+                let prepared_statement = prepared_statement.clone();
+                let session = session.clone();
+
+                let mut scylla_batch =
+                    scylla::batch::Batch::new(scylla::batch::BatchType::Unlogged);
+                scylla_batch.enforce_target_node(&node, shard_id_on_node, &session);
+
+                let shards_for_nodes_test_check_clone = Arc::clone(&shards_for_nodes_test_check);
+                batching_tasks.push(tokio::spawn(async move {
+                    let mut batches = ReceiverStream::new(receiver)
+                        .chunks_timeout(10, Duration::from_millis(100));
+
+                    while let Some(batch) = batches.next().await {
+                        // Obviously, if the actual prepared statement depends on each
+                        // element of the batch, this requires adjustment
+                        scylla_batch.statements.resize_with(batch.len(), || {
+                            scylla::batch::BatchStatement::PreparedStatement(
+                                prepared_statement.clone(),
+                            )
+                        });
+
+                        // Take a global lock to make the test deterministic
+                        // (and because we need to push stuff in there to test that
+                        // shard-awareness is respected)
+                        let mut shards_for_nodes_test_check =
+                            shards_for_nodes_test_check_clone.lock().await;
+
+                        session
+                            .batch(&scylla_batch, &batch)
+                            .await
+                            .expect("Query to send batch failed");
+
+                        shards_for_nodes_test_check
+                            .entry(destination_shard.node_id)
+                            .or_default()
+                            .push(destination_shard.shard_id_on_node);
+                    }
+                }));
+                sender
+            });
+
+        sender
+            .send(serialized_values)
+            .await
+            .expect("Failed to send serialized values to dedicated channel");
+    }
+
+    // Let's drop the senders, which ensures that all batches are sent immediately,
+    // then wait for all the tasks to finish, and ensure that there were no errors.
+    // In a production setting these dynamically instantiated tasks may be monitored
+    // more easily by using e.g. `tokio_tasks_shutdown`.
+    std::mem::drop(channels_for_shards);
+    for task in batching_tasks {
+        task.await.unwrap();
+    }
+
+    // Finally, check that batching was indeed shard-aware.
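+    // Each batching task recorded the (node, shard) it targeted once per batch sent,
+    // while the proxy recorded, per node, the shard that actually received each BATCH
+    // frame. If routing was shard-aware, both sides must agree, up to node order.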
+
+    let mut expected: Vec<Vec<Shard>> = Arc::try_unwrap(shards_for_nodes_test_check)
+        .expect("All batching tasks have finished")
+        .into_inner()
+        .into_values()
+        .collect();
+
+    let mut nodes_shards_calls: Vec<Vec<Shard>> = Vec::new();
+    for rx in prepared_rxs.iter_mut() {
+        let mut shards_calls = Vec::new();
+        shards_calls.push(
+            rx.recv()
+                .await
+                .expect("Each node should have received at least one message")
+                .1
+                .unwrap_or({
+                    // Cassandra case (non-scylla)
+                    0
+                })
+                .into(),
+        );
+        while let Ok((_, call_shard)) = rx.try_recv() {
+            shards_calls.push(
+                call_shard
+                    .unwrap_or({
+                        // Cassandra case (non-scylla)
+                        0
+                    })
+                    .into(),
+            )
+        }
+        nodes_shards_calls.push(shards_calls);
+    }
+
+    // We don't know which node is which, but once node identity is disregarded,
+    // both sides should agree about what was sent to which shard.
+    dbg!(&expected, &nodes_shards_calls);
+    expected.sort_unstable();
+    nodes_shards_calls.sort_unstable();
+    assert_eq!(expected, nodes_shards_calls);
+
+    running_proxy
+}