diff --git a/.github/workflows/rust.yml b/.github/workflows/test.yml similarity index 99% rename from .github/workflows/rust.yml rename to .github/workflows/test.yml index 71e4fec26..470861528 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/test.yml @@ -1,4 +1,4 @@ -name: Rust +name: Testing on: push: diff --git a/shotover-proxy/Cargo.toml b/shotover-proxy/Cargo.toml index 4041fc3f5..8b1a24793 100644 --- a/shotover-proxy/Cargo.toml +++ b/shotover-proxy/Cargo.toml @@ -7,7 +7,7 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -tokio = { version = "1.6.1", features = ["full", "macros"] } +tokio = { version = "1.11.0", features = ["full", "macros"] } tokio-util = { version = "0.6.1", features = ["full"]} tokio-stream = "0.1.2" bytes = "1.0.0" @@ -79,7 +79,7 @@ num_cpus = "1.0" serial_test = "0.5.1" test-helpers = { path = "../test-helpers" } hex-literal = "0.3.3" -reqwest = { version = "0.11.4", features = ["blocking"] } +ntest = "0.7.3" [[bench]] name = "redis_benches" diff --git a/shotover-proxy/examples/cassandra-standalone/topology.yaml b/shotover-proxy/examples/cassandra-standalone/topology.yaml index 8c4b85344..cf18b5e2c 100644 --- a/shotover-proxy/examples/cassandra-standalone/topology.yaml +++ b/shotover-proxy/examples/cassandra-standalone/topology.yaml @@ -14,11 +14,10 @@ sources: - clustering chain_config: main_chain: -# - Printer - CodecDestination: bypass_result_processing: true remote_address: "172.18.0.2:9042" named_topics: testtopic: 5 source_to_chain_mapping: - cassandra_prod: main_chain \ No newline at end of file + cassandra_prod: main_chain diff --git a/shotover-proxy/examples/null-cassandra/config.yaml b/shotover-proxy/examples/null-cassandra/config.yaml new file mode 100644 index 000000000..42ea740ca --- /dev/null +++ b/shotover-proxy/examples/null-cassandra/config.yaml @@ -0,0 +1,3 @@ +--- +main_log_level: "info,shotover_proxy=info" 
+observability_interface: "0.0.0.0:9001" diff --git a/shotover-proxy/examples/null-cassandra/topology.yaml b/shotover-proxy/examples/null-cassandra/topology.yaml new file mode 100644 index 000000000..9acfbcedc --- /dev/null +++ b/shotover-proxy/examples/null-cassandra/topology.yaml @@ -0,0 +1,21 @@ +--- +sources: + cassandra_prod: + Cassandra: + bypass_query_processing: true + listen_addr: "127.0.0.1:9042" + cassandra_ks: + system.local: + - key + test.simple: + - pk + test.clustering: + - pk + - clustering +chain_config: + main_chain: + - Null +named_topics: + testtopic: 5 +source_to_chain_mapping: + cassandra_prod: main_chain diff --git a/shotover-proxy/examples/null-redis/config.yaml b/shotover-proxy/examples/null-redis/config.yaml new file mode 100644 index 000000000..42ea740ca --- /dev/null +++ b/shotover-proxy/examples/null-redis/config.yaml @@ -0,0 +1,3 @@ +--- +main_log_level: "info,shotover_proxy=info" +observability_interface: "0.0.0.0:9001" diff --git a/shotover-proxy/examples/null-redis/topology.yaml b/shotover-proxy/examples/null-redis/topology.yaml new file mode 100644 index 000000000..30bb48b6b --- /dev/null +++ b/shotover-proxy/examples/null-redis/topology.yaml @@ -0,0 +1,13 @@ +--- +sources: + redis_prod: + Redis: + batch_size_hint: 1 + listen_addr: "127.0.0.1:6379" +chain_config: + redis_chain: + - Null +named_topics: + testtopic: 5 +source_to_chain_mapping: + redis_prod: redis_chain diff --git a/shotover-proxy/src/config/topology.rs b/shotover-proxy/src/config/topology.rs index fe8d79bc4..b7a173d50 100644 --- a/shotover-proxy/src/config/topology.rs +++ b/shotover-proxy/src/config/topology.rs @@ -12,7 +12,7 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::sync::oneshot::Sender as OneSender; -use tokio::sync::{broadcast, mpsc}; +use tokio::sync::watch; use tracing::info; #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] @@ -107,14 +107,14 @@ impl Topology 
{ #[allow(clippy::type_complexity)] pub async fn run_chains( &self, - trigger_shutdown_rx: broadcast::Sender<()>, + trigger_shutdown_rx: watch::Receiver, ) -> Result<(Vec, Receiver<()>)> { let mut topics = self.build_topics(); info!("Loaded topics {:?}", topics.topics_tx.keys()); let mut sources_list: Vec = Vec::new(); - let (shutdown_complete_tx, shutdown_complete_rx) = mpsc::channel(1); + let (shutdown_complete_tx, shutdown_complete_rx) = channel(1); let chains = self.build_chains(&topics).await?; info!("Loaded chains {:?}", chains.keys()); diff --git a/shotover-proxy/src/runner.rs b/shotover-proxy/src/runner.rs index fb3032826..62fb2af91 100644 --- a/shotover-proxy/src/runner.rs +++ b/shotover-proxy/src/runner.rs @@ -4,9 +4,9 @@ use std::net::SocketAddr; use anyhow::{anyhow, Result}; use clap::{crate_version, Clap}; use metrics_exporter_prometheus::PrometheusBuilder; -use tokio::runtime::{self, Runtime}; +use tokio::runtime::{self, Handle as RuntimeHandle, Runtime}; use tokio::signal; -use tokio::sync::broadcast; +use tokio::sync::watch; use tokio::task::JoinHandle; use tracing::{debug, error, info}; use tracing_appender::non_blocking::{NonBlocking, WorkerGuard}; @@ -51,7 +51,8 @@ impl Default for ConfigOpts { } pub struct Runner { - runtime: Runtime, + runtime: Option, + runtime_handle: RuntimeHandle, topology: Topology, config: Config, tracing: TracingState, @@ -62,18 +63,13 @@ impl Runner { let config = Config::from_file(params.config_file.clone())?; let topology = Topology::from_file(params.topology_file.clone())?; - let runtime = runtime::Builder::new_multi_thread() - .enable_all() - .thread_name("RPProxy-Thread") - .thread_stack_size(params.stack_size) - .worker_threads(params.core_threads) - .build() - .unwrap(); - let tracing = TracingState::new(config.main_log_level.as_str())?; + let (runtime_handle, runtime) = Runner::get_runtime(params.stack_size, params.core_threads); + Ok(Runner { runtime, + runtime_handle, topology, config, tracing, @@ -87,36 
+83,60 @@ impl Runner { let socket: SocketAddr = self.config.observability_interface.parse()?; let exporter = LogFilterHttpExporter::new(handle, socket, self.tracing.handle.clone()); - self.runtime.spawn(exporter.async_run()); + + self.runtime_handle.spawn(exporter.async_run()); Ok(self) } pub fn run_spawn(self) -> RunnerSpawned { - let (trigger_shutdown_tx, _) = broadcast::channel(1); - let handle = - self.runtime - .spawn(run(self.topology, self.config, trigger_shutdown_tx.clone())); + let (trigger_shutdown_tx, trigger_shutdown_rx) = watch::channel(false); + + let join_handle = + self.runtime_handle + .spawn(run(self.topology, self.config, trigger_shutdown_rx)); RunnerSpawned { + runtime_handle: self.runtime_handle, runtime: self.runtime, tracing_guard: self.tracing.guard, trigger_shutdown_tx, - handle, + join_handle, } } pub fn run_block(self) -> Result<()> { - let (trigger_shutdown_tx, _) = broadcast::channel(1); + let (trigger_shutdown_tx, trigger_shutdown_rx) = watch::channel(false); - let trigger_shutdown_tx_clone = trigger_shutdown_tx.clone(); - self.runtime.spawn(async move { + self.runtime_handle.spawn(async move { signal::ctrl_c().await.unwrap(); - trigger_shutdown_tx_clone.send(()).unwrap(); + trigger_shutdown_tx.send(true).unwrap(); }); - self.runtime - .block_on(run(self.topology, self.config, trigger_shutdown_tx)) + self.runtime_handle + .block_on(run(self.topology, self.config, trigger_shutdown_rx)) + } + + /// Get handle for an existing runtime or create one + fn get_runtime(stack_size: usize, core_threads: usize) -> (RuntimeHandle, Option) { + if let Ok(handle) = tokio::runtime::Handle::try_current() { + // Using block_in_place to trigger a panic in case the runtime is set up in single-threaded mode. + // Shotover does not function correctly in single threaded mode (currently hangs) + // and block_in_place gives an error message explaining to setup the runtime in multi-threaded mode. 
+ tokio::task::block_in_place(|| {}); + + (handle, None) + } else { + let runtime = runtime::Builder::new_multi_thread() + .enable_all() + .thread_name("Shotover-Proxy-Thread") + .thread_stack_size(stack_size) + .worker_threads(core_threads) + .build() + .unwrap(); + + (runtime.handle().clone(), Some(runtime)) + } } } @@ -172,16 +192,17 @@ impl TracingState { } pub struct RunnerSpawned { - pub runtime: Runtime, - pub handle: JoinHandle>, + pub runtime: Option, + pub runtime_handle: RuntimeHandle, + pub join_handle: JoinHandle>, pub tracing_guard: WorkerGuard, - pub trigger_shutdown_tx: broadcast::Sender<()>, + pub trigger_shutdown_tx: watch::Sender, } pub async fn run( topology: Topology, config: Config, - trigger_shutdown_tx: broadcast::Sender<()>, + trigger_shutdown_tx: watch::Receiver, ) -> Result<()> { info!("Starting Shotover {}", crate_version!()); info!(configuration = ?config); diff --git a/shotover-proxy/src/server.rs b/shotover-proxy/src/server.rs index a275f5577..1244f88eb 100644 --- a/shotover-proxy/src/server.rs +++ b/shotover-proxy/src/server.rs @@ -6,7 +6,7 @@ use futures::StreamExt; use metrics::gauge; use std::sync::Arc; use tokio::net::{TcpListener, TcpStream}; -use tokio::sync::{broadcast, mpsc, Semaphore}; +use tokio::sync::{mpsc, watch, Semaphore}; use tokio::time; use tokio::time::timeout; use tokio::time::Duration; @@ -61,23 +61,13 @@ pub struct TcpCodecListener { /// The initial `shutdown` trigger is provided by the `run` caller. The /// server is responsible for gracefully shutting down active connections. /// When a connection task is spawned, it is passed a broadcast receiver - /// handle. When a graceful shutdown is initiated, a `()` value is sent via - /// the broadcast::Sender. Each active connection receives it, reaches a + /// handle. When a graceful shutdown is initiated, a `true` value is sent via + /// the watch::Sender. Each active connection receives it, reaches a /// safe terminal state, and completes the task. 
- pub trigger_shutdown_tx: broadcast::Sender<()>, + pub trigger_shutdown_rx: watch::Receiver, /// Used as part of the graceful shutdown process to wait for client /// connections to complete processing. - /// - /// Tokio channels are closed once all `Sender` handles go out of scope. - /// When a channel is closed, the receiver receives `None`. This is - /// leveraged to detect all connection handlers completing. When a - /// connection handler is initialized, it is assigned a clone of - /// `shutdown_complete_tx`. When the listener shuts down, it drops the - /// sender held by this `shutdown_complete_tx` field. Once all handler tasks - /// complete, all clones of the `Sender` are also dropped. This results in - /// `shutdown_complete_rx.recv()` completing with `None`. At this point, it - /// is safe to exit the server process. pub shutdown_complete_tx: mpsc::Sender<()>, } @@ -184,7 +174,7 @@ impl TcpCodecListener { limit_connections: self.limit_connections.clone(), // Receive shutdown notifications. - shutdown: Shutdown::new(self.trigger_shutdown_tx.subscribe()), + shutdown: Shutdown::new(self.trigger_shutdown_rx.clone()), // Notifies the receiver half once all clones are // dropped. @@ -275,7 +265,6 @@ pub struct Handler { /// which point the connection is terminated. shutdown: Shutdown, - /// Not used directly. Instead, when `Handler` is dropped...? _shutdown_complete: mpsc::Sender<()>, } @@ -321,7 +310,7 @@ impl Handler { }); tokio::spawn(async move { - let rx_stream = UnboundedReceiverStream::new(out_rx).map(|x| Ok(x)); + let rx_stream = UnboundedReceiverStream::new(out_rx).map(Ok); let r = rx_stream.forward(writer).await; debug!("Stream ended {:?}", r); }); @@ -418,12 +407,12 @@ pub struct Shutdown { shutdown: bool, /// The receive half of the channel used to listen for shutdown. - notify: broadcast::Receiver<()>, + notify: watch::Receiver, } impl Shutdown { /// Create a new `Shutdown` backed by the given `broadcast::Receiver`. 
- pub(crate) fn new(notify: broadcast::Receiver<()>) -> Shutdown { + pub(crate) fn new(notify: watch::Receiver) -> Shutdown { Shutdown { shutdown: false, notify, @@ -443,8 +432,11 @@ impl Shutdown { return; } - // Cannot receive a "lag error" as only one value is ever sent. - self.notify.recv().await.unwrap(); + // check we didn't receive a shutdown message before the receiver was created + if !*self.notify.borrow() { + // Await the shutdown message + self.notify.changed().await.unwrap(); + } // Remember that the signal has been received. self.shutdown = true; diff --git a/shotover-proxy/src/sources/cassandra_source.rs b/shotover-proxy/src/sources/cassandra_source.rs index 680148c7e..58e6c92d1 100644 --- a/shotover-proxy/src/sources/cassandra_source.rs +++ b/shotover-proxy/src/sources/cassandra_source.rs @@ -5,7 +5,7 @@ use anyhow::Result; use async_trait::async_trait; use serde::{Deserialize, Serialize}; use tokio::runtime::Handle; -use tokio::sync::{broadcast, mpsc, Semaphore}; +use tokio::sync::{mpsc, watch, Semaphore}; use tokio::task::JoinHandle; use tracing::{error, info}; @@ -30,7 +30,7 @@ impl SourcesFromConfig for CassandraConfig { &self, chain: &TransformChain, _topics: &mut TopicHolder, - trigger_shutdown_tx: broadcast::Sender<()>, + trigger_shutdown_rx: watch::Receiver, shutdown_complete_tx: mpsc::Sender<()>, ) -> Result> { Ok(vec![Sources::Cassandra( @@ -38,7 +38,7 @@ impl SourcesFromConfig for CassandraConfig { chain, self.listen_addr.clone(), self.cassandra_ks.clone(), - trigger_shutdown_tx, + trigger_shutdown_rx, shutdown_complete_tx, self.bypass_query_processing.unwrap_or(true), self.connection_limit, @@ -61,7 +61,7 @@ impl CassandraSource { chain: &TransformChain, listen_addr: String, cassandra_ks: HashMap>, - trigger_shutdown_tx: broadcast::Sender<()>, + mut trigger_shutdown_rx: watch::Receiver, shutdown_complete_tx: mpsc::Sender<()>, bypass: bool, connection_limit: Option, @@ -71,8 +71,6 @@ impl CassandraSource { info!("Starting Cassandra 
source on [{}]", listen_addr); - let mut trigger_shutdown_rx = trigger_shutdown_tx.subscribe(); - let mut listener = TcpCodecListener { chain: chain.clone(), source_name: name.to_string(), @@ -81,37 +79,38 @@ impl CassandraSource { hard_connection_limit: hard_connection_limit.unwrap_or(false), codec: CassandraCodec2::new(cassandra_ks, bypass), limit_connections: Arc::new(Semaphore::new(connection_limit.unwrap_or(512))), - trigger_shutdown_tx, + trigger_shutdown_rx: trigger_shutdown_rx.clone(), shutdown_complete_tx, }; - let jh = Handle::current().spawn(async move { - tokio::select! { - res = listener.run() => { - if let Err(err) = res { - error!(cause = %err, "failed to accept"); + let join_handle = Handle::current().spawn(async move { + // Check we didn't receive a shutdown signal before the receiver was created + if !*trigger_shutdown_rx.borrow() { + tokio::select! { + res = listener.run() => { + if let Err(err) = res { + error!(cause = %err, "failed to accept"); + } + } + _ = trigger_shutdown_rx.changed() => { + info!("cassandra source shutting down") } - } - _ = trigger_shutdown_rx.recv() => { - info!("cassandra source shutting down") } } let TcpCodecListener { - trigger_shutdown_tx, shutdown_complete_tx, .. 
} = listener; drop(shutdown_complete_tx); - drop(trigger_shutdown_tx); Ok(()) }); CassandraSource { name, - join_handle: jh, + join_handle, listen_addr, } } diff --git a/shotover-proxy/src/sources/mod.rs b/shotover-proxy/src/sources/mod.rs index f0fffe861..610d75835 100644 --- a/shotover-proxy/src/sources/mod.rs +++ b/shotover-proxy/src/sources/mod.rs @@ -5,7 +5,7 @@ use crate::sources::redis_source::{RedisConfig, RedisSource}; use crate::transforms::chain::TransformChain; use async_trait::async_trait; use serde::{Deserialize, Serialize}; -use tokio::sync::{broadcast, mpsc}; +use tokio::sync::{mpsc, watch}; use tokio::task::JoinHandle; use anyhow::Result; @@ -62,20 +62,20 @@ impl SourcesConfig { &self, chain: &TransformChain, topics: &mut TopicHolder, - trigger_shutdown_tx: broadcast::Sender<()>, + trigger_shutdown_rx: watch::Receiver, shutdown_complete_tx: mpsc::Sender<()>, ) -> Result> { match self { SourcesConfig::Cassandra(c) => { - c.get_source(chain, topics, trigger_shutdown_tx, shutdown_complete_tx) + c.get_source(chain, topics, trigger_shutdown_rx, shutdown_complete_tx) .await } SourcesConfig::Mpsc(m) => { - m.get_source(chain, topics, trigger_shutdown_tx, shutdown_complete_tx) + m.get_source(chain, topics, trigger_shutdown_rx, shutdown_complete_tx) .await } SourcesConfig::Redis(r) => { - r.get_source(chain, topics, trigger_shutdown_tx, shutdown_complete_tx) + r.get_source(chain, topics, trigger_shutdown_rx, shutdown_complete_tx) .await } } @@ -88,7 +88,7 @@ pub trait SourcesFromConfig: Send { &self, chain: &TransformChain, topics: &mut TopicHolder, - trigger_shutdown: broadcast::Sender<()>, + trigger_shutdown_rx: watch::Receiver, shutdown_complete_tx: mpsc::Sender<()>, ) -> Result>; } diff --git a/shotover-proxy/src/sources/mpsc_source.rs b/shotover-proxy/src/sources/mpsc_source.rs index 62cb64417..513f93b7d 100644 --- a/shotover-proxy/src/sources/mpsc_source.rs +++ b/shotover-proxy/src/sources/mpsc_source.rs @@ -12,7 +12,7 @@ use async_trait::async_trait; 
use serde::{Deserialize, Serialize}; use std::time::Instant; use tokio::runtime::Handle; -use tokio::sync::{broadcast, mpsc}; +use tokio::sync::{mpsc, watch}; use tokio::task::JoinHandle; use tracing::info; use tracing::warn; @@ -29,7 +29,7 @@ impl SourcesFromConfig for AsyncMpscConfig { &self, chain: &TransformChain, topics: &mut TopicHolder, - trigger_shutdown_on_drop_rx: broadcast::Sender<()>, + trigger_shutdown_on_drop_rx: watch::Receiver, shutdown_complete_tx: mpsc::Sender<()>, ) -> Result> { if let Some(rx) = topics.get_rx(&self.topic_name) { @@ -41,7 +41,7 @@ impl SourcesFromConfig for AsyncMpscConfig { chain.clone(), rx, &self.topic_name, - Shutdown::new(trigger_shutdown_on_drop_rx.subscribe()), + Shutdown::new(trigger_shutdown_on_drop_rx), shutdown_complete_tx, behavior.clone(), ))]) diff --git a/shotover-proxy/src/sources/redis_source.rs b/shotover-proxy/src/sources/redis_source.rs index 366d57337..2c1a75f78 100644 --- a/shotover-proxy/src/sources/redis_source.rs +++ b/shotover-proxy/src/sources/redis_source.rs @@ -8,7 +8,7 @@ use async_trait::async_trait; use serde::{Deserialize, Serialize}; use std::sync::Arc; use tokio::runtime::Handle; -use tokio::sync::{broadcast, mpsc, Semaphore}; +use tokio::sync::{mpsc, watch, Semaphore}; use tokio::task::JoinHandle; use tracing::{error, info}; @@ -28,7 +28,7 @@ impl SourcesFromConfig for RedisConfig { &self, chain: &TransformChain, _topics: &mut TopicHolder, - trigger_shutdown_tx: broadcast::Sender<()>, + trigger_shutdown_rx: watch::Receiver, shutdown_complete_tx: mpsc::Sender<()>, ) -> Result> { Ok(vec![Sources::Redis( @@ -36,7 +36,7 @@ impl SourcesFromConfig for RedisConfig { chain, self.listen_addr.clone(), self.batch_size_hint, - trigger_shutdown_tx, + trigger_shutdown_rx, shutdown_complete_tx, self.connection_limit, self.hard_connection_limit, @@ -58,7 +58,7 @@ impl RedisSource { chain: &TransformChain, listen_addr: String, batch_hint: u64, - trigger_shutdown_tx: broadcast::Sender<()>, + mut 
trigger_shutdown_rx: watch::Receiver, shutdown_complete_tx: mpsc::Sender<()>, connection_limit: Option, hard_connection_limit: Option, @@ -66,8 +66,6 @@ impl RedisSource { info!("Starting Redis source on [{}]", listen_addr); let name = "Redis Source"; - let mut trigger_shutdown_rx = trigger_shutdown_tx.subscribe(); - let mut listener = TcpCodecListener { chain: chain.clone(), source_name: name.to_string(), @@ -76,30 +74,32 @@ impl RedisSource { hard_connection_limit: hard_connection_limit.unwrap_or(false), codec: RedisCodec::new(false, batch_hint as usize), limit_connections: Arc::new(Semaphore::new(connection_limit.unwrap_or(512))), - trigger_shutdown_tx, + trigger_shutdown_rx: trigger_shutdown_rx.clone(), shutdown_complete_tx, }; let join_handle = Handle::current().spawn(async move { - tokio::select! { - res = listener.run() => { - if let Err(err) = res { - error!(cause = %err, "failed to accept"); + // Check we didn't receive a shutdown signal before the receiver was created + if !*trigger_shutdown_rx.borrow() { + tokio::select! { + res = listener.run() => { + if let Err(err) = res { + error!(cause = %err, "failed to accept"); + } } - } - _ = trigger_shutdown_rx.recv() => { - info!("redis source shutting down") + _ = trigger_shutdown_rx.changed() => { + info!("redis source shutting down") + } + } } let TcpCodecListener { - trigger_shutdown_tx, shutdown_complete_tx, .. 
} = listener; drop(shutdown_complete_tx); - drop(trigger_shutdown_tx); Ok(()) }); diff --git a/shotover-proxy/src/transforms/mod.rs b/shotover-proxy/src/transforms/mod.rs index de6bc638e..d975d60dd 100644 --- a/shotover-proxy/src/transforms/mod.rs +++ b/shotover-proxy/src/transforms/mod.rs @@ -196,6 +196,7 @@ pub enum TransformsConfig { RedisCluster(RedisClusterConfig), RedisTimestampTagger, Printer, + Null, ParallelMap(ParallelMapConfig), PoolConnections(ConnectionBalanceAndPoolConfig), Coalesce(CoalesceConfig), @@ -219,6 +220,7 @@ impl TransformsConfig { Ok(Transforms::RedisTimeStampTagger(RedisTimestampTagger::new())) } TransformsConfig::Printer => Ok(Transforms::Printer(Printer::new())), + TransformsConfig::Null => Ok(Transforms::Null(Null::new())), TransformsConfig::RedisCluster(r) => r.get_source(topics).await, TransformsConfig::ParallelMap(s) => s.get_source(topics).await, TransformsConfig::PoolConnections(s) => s.get_source(topics).await, diff --git a/shotover-proxy/tests/helpers/mod.rs b/shotover-proxy/tests/helpers/mod.rs index 06373cd46..87a28cb86 100644 --- a/shotover-proxy/tests/helpers/mod.rs +++ b/shotover-proxy/tests/helpers/mod.rs @@ -4,14 +4,15 @@ use shotover_proxy::runner::{ConfigOpts, Runner}; use std::net::TcpStream; use std::thread; use std::time::Duration; -use tokio::runtime::Runtime; -use tokio::sync::broadcast; +use tokio::runtime::{Handle as RuntimeHandle, Runtime}; +use tokio::sync::watch; use tokio::task::JoinHandle; pub struct ShotoverManager { - pub runtime: Runtime, - pub handle: Option>>, - pub trigger_shutdown_tx: broadcast::Sender<()>, + pub runtime: Option, + pub runtime_handle: RuntimeHandle, + pub join_handle: Option>>, + pub trigger_shutdown_tx: watch::Sender, } impl ShotoverManager { @@ -33,7 +34,8 @@ impl ShotoverManager { ShotoverManager { runtime: spawn.runtime, - handle: Some(spawn.handle), + runtime_handle: spawn.runtime_handle, + join_handle: Some(spawn.join_handle), trigger_shutdown_tx: spawn.trigger_shutdown_tx, } } 
@@ -68,9 +70,10 @@ impl Drop for ShotoverManager { // We only shutdown shotover to test the shutdown process not because we need to clean up any resources. // So skipping shutdown on panic is fine. } else { - self.trigger_shutdown_tx.send(()).unwrap(); - self.runtime - .block_on(self.handle.take().unwrap()) + self.trigger_shutdown_tx.send(true).unwrap(); + + let _enter_guard = self.runtime_handle.enter(); + futures::executor::block_on(self.join_handle.take().unwrap()) .unwrap() .unwrap(); } diff --git a/shotover-proxy/tests/lib.rs b/shotover-proxy/tests/lib.rs index 564a1c72b..f14806e09 100644 --- a/shotover-proxy/tests/lib.rs +++ b/shotover-proxy/tests/lib.rs @@ -1,4 +1,4 @@ -pub mod admin; pub mod codec; mod helpers; pub mod redis_int_tests; +pub mod runner; diff --git a/shotover-proxy/tests/admin/mod.rs b/shotover-proxy/tests/runner/mod.rs similarity index 56% rename from shotover-proxy/tests/admin/mod.rs rename to shotover-proxy/tests/runner/mod.rs index cdc007157..1bf23588c 100644 --- a/shotover-proxy/tests/admin/mod.rs +++ b/shotover-proxy/tests/runner/mod.rs @@ -1 +1,2 @@ mod observability_int_tests; +mod runner_int_tests; diff --git a/shotover-proxy/tests/admin/observability_int_tests.rs b/shotover-proxy/tests/runner/observability_int_tests.rs similarity index 74% rename from shotover-proxy/tests/admin/observability_int_tests.rs rename to shotover-proxy/tests/runner/observability_int_tests.rs index 424279b1e..54923c656 100644 --- a/shotover-proxy/tests/admin/observability_int_tests.rs +++ b/shotover-proxy/tests/runner/observability_int_tests.rs @@ -2,9 +2,9 @@ use crate::helpers::ShotoverManager; use serial_test::serial; use test_helpers::docker_compose::DockerCompose; -#[test] +#[tokio::test(flavor = "multi_thread")] #[serial] -fn test_metrics() { +async fn test_metrics() { let _compose = DockerCompose::new("examples/redis-passthrough/docker-compose.yml"); let shotover_manager = @@ -22,10 +22,11 @@ fn test_metrics() { .arg(43) .execute(&mut connection); - 
let body = reqwest::blocking::get("http://localhost:9001/metrics") - .unwrap() - .text() - .unwrap(); + let client = hyper::Client::new(); + let uri = "http://localhost:9001/metrics".parse().unwrap(); + let res = client.get(uri).await.unwrap(); + let body_bytes = hyper::body::to_bytes(res.into_body()).await.unwrap(); + let body = String::from_utf8(body_bytes.to_vec()).unwrap(); // If the body contains these substrings, we can assume metrics are working assert!(body.contains("# TYPE shotover_transform_total counter")); diff --git a/shotover-proxy/tests/runner/runner_int_tests.rs b/shotover-proxy/tests/runner/runner_int_tests.rs new file mode 100644 index 000000000..067e2994f --- /dev/null +++ b/shotover-proxy/tests/runner/runner_int_tests.rs @@ -0,0 +1,41 @@ +use serial_test::serial; +use std::any::Any; + +use crate::helpers::ShotoverManager; + +#[tokio::test(flavor = "multi_thread")] +#[serial] +async fn test_runtime_use_existing() { + let shotover_manager = ShotoverManager::from_topology_file("examples/null-redis/topology.yaml"); + + // Assert that shotover is using the test runtime + let handle = tokio::runtime::Handle::current(); + assert_eq!(handle.type_id(), shotover_manager.runtime_handle.type_id()); + + // Assert that shotover did not create a runtime for itself + assert!(shotover_manager.runtime.is_none()); +} + +#[tokio::test(flavor = "current_thread")] +#[ntest::timeout(100)] +async fn test_shotover_panics_in_single_thread_runtime() { + let result = std::panic::catch_unwind(|| { + ShotoverManager::from_topology_file("examples/null-redis/topology.yaml"); + }); + assert!(result.is_err()); +} + +#[test] +#[serial] +fn test_runtime_create() { + let shotover_manager = ShotoverManager::from_topology_file("examples/null-redis/topology.yaml"); + + // Assert that shotover created a runtime for itself + assert!(shotover_manager.runtime.is_some()); +} + +#[test] +#[serial] +fn test_early_shutdown_cassandra_source() { + 
ShotoverManager::from_topology_file("examples/null-cassandra/topology.yaml"); +}