Runner changes: will use existing tokio::runtime or create one for itself and race condition fix #162

Merged Sep 6, 2021 · 20 commits
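The headline change: Runner now reuses an ambient Tokio runtime when one is already running, and only builds its own otherwise (see get_runtime in the src/runner.rs diff below). A minimal standalone sketch of that detection pattern, illustrative only, not shotover code:

use tokio::runtime::{Builder, Handle};

fn main() {
    // Outside any runtime, Handle::try_current() fails, so a runtime has
    // to be built and owned: the (handle, Some(runtime)) branch.
    assert!(Handle::try_current().is_err());

    let runtime = Builder::new_multi_thread().enable_all().build().unwrap();
    runtime.block_on(async {
        // Inside a runtime, try_current() succeeds and the handle can
        // spawn work without constructing a second runtime: the
        // (handle, None) branch.
        let handle = Handle::try_current().unwrap();
        handle.spawn(async {}).await.unwrap();
    });
}

Returning the owned Runtime alongside the handle matters: if the Option<Runtime> were dropped, the runtime would shut down and the handle would be useless.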
2 changes: 1 addition & 1 deletion .github/workflows/rust.yml → .github/workflows/test.yml
@@ -1,4 +1,4 @@
-name: Rust
+name: Testing
 
 on:
   push:
3 changes: 1 addition & 2 deletions shotover-proxy/Cargo.toml
@@ -7,7 +7,7 @@ edition = "2018"
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
-tokio = { version = "1.6.1", features = ["full", "macros"] }
+tokio = { version = "1.11.0", features = ["full", "macros"] }
 tokio-util = { version = "0.6.1", features = ["full"]}
 tokio-stream = "0.1.2"
 bytes = "1.0.0"
@@ -79,7 +79,6 @@ num_cpus = "1.0"
 serial_test = "0.5.1"
 test-helpers = { path = "../test-helpers" }
 hex-literal = "0.3.3"
-reqwest = { version = "0.11.4", features = ["blocking"] }
 
 [[bench]]
 name = "redis_benches"
3 changes: 1 addition & 2 deletions shotover-proxy/examples/cassandra-standalone/topology.yaml
@@ -14,11 +14,10 @@ sources:
           - clustering
 chain_config:
   main_chain:
-    # - Printer
     - CodecDestination:
         bypass_result_processing: true
         remote_address: "172.18.0.2:9042"
 named_topics:
   testtopic: 5
 source_to_chain_mapping:
-  cassandra_prod: main_chain
+  cassandra_prod: main_chain
3 changes: 3 additions & 0 deletions shotover-proxy/examples/null-cassandra/config.yaml
@@ -0,0 +1,3 @@
+---
+main_log_level: "info,shotover_proxy=info"
+observability_interface: "0.0.0.0:9001"
21 changes: 21 additions & 0 deletions shotover-proxy/examples/null-cassandra/topology.yaml
@@ -0,0 +1,21 @@
+---
+sources:
+  cassandra_prod:
+    Cassandra:
+      bypass_query_processing: true
+      listen_addr: "127.0.0.1:9042"
+      cassandra_ks:
+        system.local:
+          - key
+        test.simple:
+          - pk
+        test.clustering:
+          - pk
+          - clustering
+chain_config:
+  main_chain:
+    - Null
+named_topics:
+  testtopic: 5
+source_to_chain_mapping:
+  cassandra_prod: main_chain
3 changes: 3 additions & 0 deletions shotover-proxy/examples/null-redis/config.yaml
@@ -0,0 +1,3 @@
+---
+main_log_level: "info,shotover_proxy=info"
+observability_interface: "0.0.0.0:9001"
13 changes: 13 additions & 0 deletions shotover-proxy/examples/null-redis/topology.yaml
@@ -0,0 +1,13 @@
+---
+sources:
+  redis_prod:
+    Redis:
+      batch_size_hint: 1
+      listen_addr: "127.0.0.1:6379"
+chain_config:
+  redis_chain:
+    - Null
+named_topics:
+  testtopic: 5
+source_to_chain_mapping:
+  redis_prod: redis_chain
9 changes: 5 additions & 4 deletions shotover-proxy/src/config/topology.rs
@@ -10,9 +10,10 @@ use crate::transforms::{build_chain_from_config, TransformsConfig};
 use anyhow::{anyhow, Result};
 use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
+use std::sync::Arc;
 use tokio::sync::mpsc::{channel, Receiver, Sender};
 use tokio::sync::oneshot::Sender as OneSender;
-use tokio::sync::{broadcast, mpsc};
+use tokio::sync::watch;
 use tracing::info;
 
 #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
@@ -107,14 +108,14 @@ impl Topology {
     #[allow(clippy::type_complexity)]
     pub async fn run_chains(
         &self,
-        trigger_shutdown_rx: broadcast::Sender<()>,
+        trigger_shutdown_tx: Arc<watch::Sender<bool>>,
     ) -> Result<(Vec<Sources>, Receiver<()>)> {
         let mut topics = self.build_topics();
         info!("Loaded topics {:?}", topics.topics_tx.keys());
 
         let mut sources_list: Vec<Sources> = Vec::new();
 
-        let (shutdown_complete_tx, shutdown_complete_rx) = mpsc::channel(1);
+        let (shutdown_complete_tx, shutdown_complete_rx) = channel(1);
 
         let chains = self.build_chains(&topics).await?;
         info!("Loaded chains {:?}", chains.keys());
@@ -127,7 +128,7 @@
                 .get_source(
                     chain,
                     &mut topics,
-                    trigger_shutdown_rx.clone(),
+                    trigger_shutdown_tx.clone(),
                     shutdown_complete_tx.clone(),
                 )
                 .await?,
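run_chains now takes a shared Arc<watch::Sender<bool>> instead of a broadcast sender. A minimal sketch of how a single watch channel fans a shutdown signal out to many sources, illustrative only, not shotover code:

use std::sync::Arc;
use tokio::sync::watch;

#[tokio::main]
async fn main() {
    // One shared sender; every source clones a receiver and waits for
    // the value to flip from false to true.
    let (tx, rx) = watch::channel(false);
    let tx = Arc::new(tx);

    let mut tasks = Vec::new();
    for i in 0..3 {
        let mut rx = rx.clone();
        tasks.push(tokio::spawn(async move {
            // Check the current value before awaiting a change, the same
            // ordering the Shutdown fix below relies on.
            while !*rx.borrow() {
                rx.changed().await.unwrap();
            }
            println!("source {} shutting down", i);
        }));
    }

    // A single send is observed by every receiver.
    tx.send(true).unwrap();
    for t in tasks {
        t.await.unwrap();
    }
}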
80 changes: 52 additions & 28 deletions shotover-proxy/src/runner.rs
@@ -4,9 +4,10 @@ use std::net::SocketAddr;
 use anyhow::{anyhow, Result};
 use clap::{crate_version, Clap};
 use metrics_exporter_prometheus::PrometheusBuilder;
-use tokio::runtime::{self, Runtime};
+use std::sync::Arc;
+use tokio::runtime::{self, Handle as RuntimeHandle, Runtime};
 use tokio::signal;
-use tokio::sync::broadcast;
+use tokio::sync::watch;
 use tokio::task::JoinHandle;
 use tracing::{debug, error, info};
 use tracing_appender::non_blocking::{NonBlocking, WorkerGuard};
@@ -51,7 +52,8 @@ impl Default for ConfigOpts {
 }
 
 pub struct Runner {
-    runtime: Runtime,
+    runtime: Option<Runtime>,
+    runtime_handle: RuntimeHandle,
     topology: Topology,
     config: Config,
     tracing: TracingState,
@@ -62,18 +64,13 @@ impl Runner {
         let config = Config::from_file(params.config_file.clone())?;
         let topology = Topology::from_file(params.topology_file.clone())?;
 
-        let runtime = runtime::Builder::new_multi_thread()
-            .enable_all()
-            .thread_name("RPProxy-Thread")
-            .thread_stack_size(params.stack_size)
-            .worker_threads(params.core_threads)
-            .build()
-            .unwrap();
-
         let tracing = TracingState::new(config.main_log_level.as_str())?;
 
+        let (runtime_handle, runtime) = Runner::get_runtime(params.stack_size, params.core_threads);
+
         Ok(Runner {
             runtime,
+            runtime_handle,
             topology,
             config,
             tracing,
@@ -87,36 +84,61 @@ impl Runner {
 
         let socket: SocketAddr = self.config.observability_interface.parse()?;
         let exporter = LogFilterHttpExporter::new(handle, socket, self.tracing.handle.clone());
-        self.runtime.spawn(exporter.async_run());
+
+        self.runtime_handle.spawn(exporter.async_run());
 
         Ok(self)
     }

     pub fn run_spawn(self) -> RunnerSpawned {
-        let (trigger_shutdown_tx, _) = broadcast::channel(1);
-        let handle =
-            self.runtime
-                .spawn(run(self.topology, self.config, trigger_shutdown_tx.clone()));
+        let (trigger_shutdown_tx, trigger_shutdown_rx) = watch::channel(false);
+        let trigger_shutdown_tx_arc = Arc::new(trigger_shutdown_tx);
+
+        let join_handle = self.runtime_handle.spawn(run(
+            self.topology,
+            self.config,
+            trigger_shutdown_tx_arc.clone(),
+        ));
 
         RunnerSpawned {
+            runtime_handle: self.runtime_handle,
             runtime: self.runtime,
             tracing_guard: self.tracing.guard,
-            trigger_shutdown_tx,
-            handle,
+            trigger_shutdown_tx: trigger_shutdown_tx_arc,
+            trigger_shutdown_rx,
+            join_handle,
         }
     }

     pub fn run_block(self) -> Result<()> {
-        let (trigger_shutdown_tx, _) = broadcast::channel(1);
+        let (trigger_shutdown_tx, _trigger_shutdown_rx) = watch::channel(false);
 
-        let trigger_shutdown_tx_clone = trigger_shutdown_tx.clone();
-        self.runtime.spawn(async move {
+        let trigger_shutdown_tx_arc = Arc::new(trigger_shutdown_tx);
+        let trigger_shutdown_tx_arc_c = trigger_shutdown_tx_arc.clone();
+        self.runtime_handle.spawn(async move {
             signal::ctrl_c().await.unwrap();
-            trigger_shutdown_tx_clone.send(()).unwrap();
+            trigger_shutdown_tx_arc_c.send(true).unwrap();
         });
 
-        self.runtime
-            .block_on(run(self.topology, self.config, trigger_shutdown_tx))
+        self.runtime_handle
+            .block_on(run(self.topology, self.config, trigger_shutdown_tx_arc))
     }

+    /// Get handle for an existing runtime or create one
+    fn get_runtime(stack_size: usize, core_threads: usize) -> (RuntimeHandle, Option<Runtime>) {
+        if let Ok(handle) = tokio::runtime::Handle::try_current() {
+            (handle, None)
+        } else {
+            let runtime = runtime::Builder::new_multi_thread()
+                .enable_all()
+                .thread_name("Shotover-Proxy-Thread")
+                .thread_stack_size(stack_size)
+                .worker_threads(core_threads)
+                .build()
+                .unwrap();
+
+            (runtime.handle().clone(), Some(runtime))
+        }
+    }
 }

@@ -172,16 +194,18 @@ impl TracingState {
 }
 
 pub struct RunnerSpawned {
-    pub runtime: Runtime,
-    pub handle: JoinHandle<Result<()>>,
+    pub runtime: Option<Runtime>,
+    pub runtime_handle: RuntimeHandle,
+    pub join_handle: JoinHandle<Result<()>>,
     pub tracing_guard: WorkerGuard,
-    pub trigger_shutdown_tx: broadcast::Sender<()>,
+    pub trigger_shutdown_tx: Arc<watch::Sender<bool>>,
+    pub trigger_shutdown_rx: watch::Receiver<bool>,
 }

 pub async fn run(
     topology: Topology,
     config: Config,
-    trigger_shutdown_tx: broadcast::Sender<()>,
+    trigger_shutdown_tx: Arc<watch::Sender<bool>>,
 ) -> Result<()> {
     info!("Starting Shotover {}", crate_version!());
     info!(configuration = ?config);
32 changes: 12 additions & 20 deletions shotover-proxy/src/server.rs
@@ -6,7 +6,7 @@ use futures::StreamExt;
 use metrics::gauge;
 use std::sync::Arc;
 use tokio::net::{TcpListener, TcpStream};
-use tokio::sync::{broadcast, mpsc, Semaphore};
+use tokio::sync::{mpsc, watch, Semaphore};
 use tokio::time;
 use tokio::time::timeout;
 use tokio::time::Duration;
@@ -61,23 +61,13 @@ pub struct TcpCodecListener<C: Codec> {
     /// The initial `shutdown` trigger is provided by the `run` caller. The
     /// server is responsible for gracefully shutting down active connections.
     /// When a connection task is spawned, it is passed a broadcast receiver
-    /// handle. When a graceful shutdown is initiated, a `()` value is sent via
-    /// the broadcast::Sender. Each active connection receives it, reaches a
+    /// handle. When a graceful shutdown is initiated, a `true` value is sent via
+    /// the watch::Sender. Each active connection receives it, reaches a
     /// safe terminal state, and completes the task.
-    pub trigger_shutdown_tx: broadcast::Sender<()>,
+    pub trigger_shutdown_tx: Arc<watch::Sender<bool>>,
 
     /// Used as part of the graceful shutdown process to wait for client
     /// connections to complete processing.
-    ///
-    /// Tokio channels are closed once all `Sender` handles go out of scope.
-    /// When a channel is closed, the receiver receives `None`. This is
-    /// leveraged to detect all connection handlers completing. When a
-    /// connection handler is initialized, it is assigned a clone of
-    /// `shutdown_complete_tx`. When the listener shuts down, it drops the
-    /// sender held by this `shutdown_complete_tx` field. Once all handler tasks
-    /// complete, all clones of the `Sender` are also dropped. This results in
-    /// `shutdown_complete_rx.recv()` completing with `None`. At this point, it
-    /// is safe to exit the server process.
     pub shutdown_complete_tx: mpsc::Sender<()>,
 }
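The deleted paragraph described Tokio's close-on-drop idiom, which the listener still relies on even with the comment trimmed. A minimal sketch of that idiom, illustrative only, not shotover code:

use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    // Nothing is ever sent on this channel; it only tracks liveness.
    let (shutdown_complete_tx, mut shutdown_complete_rx) = mpsc::channel::<()>(1);

    for i in 0..3 {
        let tx = shutdown_complete_tx.clone();
        tokio::spawn(async move {
            println!("handler {} done", i);
            drop(tx); // each handler drops its clone when it finishes
        });
    }

    // Drop the listener's own sender so only the handlers' clones remain.
    drop(shutdown_complete_tx);

    // recv() resolves to None only once every Sender clone is dropped,
    // i.e. all handlers have completed; then it is safe to exit.
    assert!(shutdown_complete_rx.recv().await.is_none());
}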

@@ -275,7 +265,6 @@ pub struct Handler<C: Codec> {
     /// which point the connection is terminated.
     shutdown: Shutdown,
 
-    /// Not used directly. Instead, when `Handler` is dropped...?
     _shutdown_complete: mpsc::Sender<()>,
 }

@@ -321,7 +310,7 @@ impl<C: Codec + 'static> Handler<C> {
         });
 
         tokio::spawn(async move {
-            let rx_stream = UnboundedReceiverStream::new(out_rx).map(|x| Ok(x));
+            let rx_stream = UnboundedReceiverStream::new(out_rx).map(Ok);
             let r = rx_stream.forward(writer).await;
             debug!("Stream ended {:?}", r);
         });
@@ -418,12 +407,12 @@ pub struct Shutdown {
     shutdown: bool,
 
     /// The receive half of the channel used to listen for shutdown.
-    notify: broadcast::Receiver<()>,
+    notify: watch::Receiver<bool>,
 }
 
 impl Shutdown {
     /// Create a new `Shutdown` backed by the given `broadcast::Receiver`.
-    pub(crate) fn new(notify: broadcast::Receiver<()>) -> Shutdown {
+    pub(crate) fn new(notify: watch::Receiver<bool>) -> Shutdown {
         Shutdown {
             shutdown: false,
             notify,
@@ -443,8 +432,11 @@ impl Shutdown {
             return;
         }
 
-        // Cannot receive a "lag error" as only one value is ever sent.
-        self.notify.recv().await.unwrap();
+        // Check we didn't receive a shutdown message before the receiver was created.
+        if !*self.notify.borrow() {
+            // Await the shutdown message
+            self.notify.changed().await.unwrap();
+        }
 
         // Remember that the signal has been received.
         self.shutdown = true;
A review thread on this check:

@rukai (Member, Sep 3, 2021):
this looks like a race condition, but also this whole shutdown abstraction looks kind of weird.

@conorbros (Member, Author, Sep 3, 2021):
This code is adapted from the tokio mini-redis project.

(Member):
I was wondering why there were so many comments... hehe.
I'll take a look.

(Member):
This snippet demonstrates why we need the check:

#[tokio::main]
async fn main() {
    let (tx, mut rx) = tokio::sync::watch::channel(false);

    tx.send(true).unwrap();

    rx.changed().await.unwrap();
    println!("borrow rx: {}", *rx.borrow());

    let mut rx2 = rx.clone();

    rx2.changed().await.unwrap(); // hangs here!
    println!("borrow rx2: {}", *rx2.borrow());
}

Alright, fair enough.
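For completeness, here is the same snippet with the guard this PR adds to Shutdown::recv applied; a minimal sketch showing the clone no longer hangs:

#[tokio::main]
async fn main() {
    let (tx, mut rx) = tokio::sync::watch::channel(false);

    tx.send(true).unwrap();

    rx.changed().await.unwrap();
    println!("borrow rx: {}", *rx.borrow());

    let mut rx2 = rx.clone();

    // The guard: only await changed() if the value has not already
    // flipped to true; otherwise rx2 would wait forever, as above.
    if !*rx2.borrow() {
        rx2.changed().await.unwrap();
    }
    println!("borrow rx2: {}", *rx2.borrow());
}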