diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 8c75bdf..3016608 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -30,7 +30,22 @@ jobs: pulsar-version: [ 2.10.6, 2.11.4, 3.0.8, 3.2.4, 3.3.3, 4.0.1 ] steps: - name: Start Pulsar Standalone Container - run: docker run --name pulsar -p 6650:6650 -p 8080:8080 -d -e GITHUB_ACTIONS=true -e CI=true apachepulsar/pulsar:${{ matrix.pulsar-version }} bin/pulsar standalone + run: | + docker run --name pulsar \ + -p 6650:6650 \ + -p 8080:8080 \ + -d \ + -e GITHUB_ACTIONS=true \ + -e CI=true \ + -e PULSAR_PREFIX_transactionCoordinatorEnabled="true" \ + -e PULSAR_PREFIX_systemTopicEnabled="true" \ + -e PULSAR_PREFIX_metadataStoreUrl="zk:127.0.0.1:2181" \ + apachepulsar/pulsar:${{ matrix.pulsar-version }} sh -c \ + "bin/apply-config-from-env.py \ + conf/standalone.conf && \ + bin/pulsar-daemon start zookeeper && \ + bin/pulsar initialize-transaction-coordinator-metadata -cs 127.0.0.1:2181 -c standalone && \ + bin/pulsar standalone" - uses: actions/checkout@v3 - uses: Swatinem/rust-cache@v2 - name: Run tests diff --git a/examples/transaction.rs b/examples/transaction.rs new file mode 100644 index 0000000..957b7ba --- /dev/null +++ b/examples/transaction.rs @@ -0,0 +1,191 @@ +#[macro_use] +extern crate serde; +use std::{env, time::Duration}; + +use futures::TryStreamExt; +use pulsar::{ + authentication::oauth2::OAuth2Authentication, message::proto, producer, Authentication, + Consumer, DeserializeMessage, Error as PulsarError, Payload, Pulsar, SerializeMessage, + TokioExecutor, +}; + +#[derive(Serialize, Deserialize, Debug)] +struct TestData { + data: String, +} + +impl SerializeMessage for TestData { + fn serialize_message(input: Self) -> Result { + let payload = serde_json::to_vec(&input).map_err(|e| PulsarError::Custom(e.to_string()))?; + Ok(producer::Message { + payload, + ..Default::default() + }) + } +} + +impl DeserializeMessage for TestData { + type Output = Result; + + fn deserialize_message(payload: &Payload) -> Self::Output { + serde_json::from_slice(&payload.data) + } +} + +#[tokio::main] +async fn main() -> Result<(), pulsar::Error> { + env_logger::init(); + + let addr = env::var("PULSAR_ADDRESS") + .ok() + .unwrap_or_else(|| "pulsar://127.0.0.1:6650".to_string()); + + let mut builder = Pulsar::builder(addr, TokioExecutor).with_transactions(); + + if let Ok(token) = env::var("PULSAR_TOKEN") { + let authentication = Authentication { + name: "token".to_string(), + data: token.into_bytes(), + }; + + builder = builder.with_auth(authentication); + } else if let Ok(oauth2_cfg) = env::var("PULSAR_OAUTH2") { + builder = builder.with_auth_provider(OAuth2Authentication::client_credentials( + serde_json::from_str(oauth2_cfg.as_str()) + .unwrap_or_else(|_| panic!("invalid oauth2 config [{}]", oauth2_cfg.as_str())), + )); + } + + let pulsar: Pulsar<_> = builder.build().await?; + + let input_topic = "persistent://public/default/test-input-topic"; + let output_topic_one = "persistent://public/default/test-output-topic-1"; + let output_topic_two = "persistent://public/default/test-output-topic-2"; + + let producer_builder = pulsar.producer().with_options(producer::ProducerOptions { + schema: Some(proto::Schema { + r#type: proto::schema::Type::String as i32, + ..Default::default() + }), + ..Default::default() + }); + + let mut input_producer = producer_builder + .clone() + .with_topic(input_topic) + .build() + .await?; + + let mut output_producer_one = producer_builder + .clone() + .with_topic(output_topic_one) + .build() + .await?; + + let mut output_producer_two = producer_builder + .clone() + .with_topic(output_topic_two) + .build() + .await?; + + let mut input_consumer: Consumer = pulsar + .consumer() + .with_topic(input_topic) + .with_subscription("test-input-subscription") + .build() + .await?; + + let mut output_consumer_one: Consumer = pulsar + .consumer() + .with_topic(output_topic_one) + .with_subscription("test-output-subscription-1") + .build() + .await?; + + let mut output_consumer_two: Consumer = pulsar + .consumer() + .with_topic(output_topic_two) + .with_subscription("test-output-subscription-2") + .build() + .await?; + + let count = 2; + + for i in 0..count { + input_producer + .send_non_blocking(TestData { + data: format!("Hello Pulsar! count : {}", i), + }) + .await? + .await?; + } + + for i in 0..count { + let msg = input_consumer + .try_next() + .await? + .expect("No message received"); + + let txn = pulsar + .new_txn()? + .with_timeout(Duration::from_secs(10)) + .build() + .await?; + + output_producer_one + .create_message() + .with_content(TestData { + data: format!("Hello Pulsar! output_topic_one count : {}", i), + }) + .with_txn(&txn) + .send_non_blocking() + .await?; + + output_producer_two + .create_message() + .with_content(TestData { + data: format!("Hello Pulsar! output_topic_two count : {}", i), + }) + .with_txn(&txn) + .send_non_blocking() + .await?; + + input_consumer.txn_ack(&msg, &txn).await?; + + if let Err(e) = txn.commit().await { + match e { + pulsar::Error::Transaction(pulsar::error::TransactionError::Conflict) => { + // If TransactionConflictException is not thrown, + // you need to redeliver or negativeAcknowledge this message, + // or else this message will not be received again. + input_consumer.nack(&msg).await?; + } + _ => (), + } + + txn.abort().await?; + + return Err(e); + } + } + + for _ in 0..count { + let msg = output_consumer_one + .try_next() + .await? + .expect("No message received"); + + println!("Received transaction message: {:?}", msg.deserialize()); + } + + for _ in 0..count { + let msg = output_consumer_two + .try_next() + .await? + .expect("No message received"); + + println!("Received transaction message: {:?}", msg.deserialize()); + } + + Ok(()) +} diff --git a/src/client.rs b/src/client.rs index 873b571..3d89bf9 100644 --- a/src/client.rs +++ b/src/client.rs @@ -21,6 +21,7 @@ use crate::{ }, producer::{self, ProducerBuilder, SendFuture}, service_discovery::ServiceDiscovery, + transaction::{coord::TransactionCoordinatorClient, TransactionBuilder}, }; /// Helper trait for consumer deserialization @@ -156,6 +157,8 @@ impl SerializeMessage for &str { pub struct Pulsar { pub(crate) manager: Arc>, service_discovery: Arc>, + /// The transaction coordinator client, if transactions are enabled. + tc_client: Option>>, // this field is an Option to avoid a cyclic dependency between Pulsar // and run_producer: the run_producer loop needs a client to create // a multitopic producer, this producer stores internally a copy @@ -179,6 +182,7 @@ impl Pulsar { connection_retry_parameters: Option, operation_retry_parameters: Option, tls_options: Option, + transactions_enabled: bool, outbound_channel_size: Option, executor: Exe, ) -> Result { @@ -221,17 +225,24 @@ impl Pulsar { let (producer, producer_rx) = mpsc::unbounded(); let mut client = Pulsar { - manager, + manager: Arc::clone(&manager), service_discovery, + tc_client: None, producer: None, operation_retry_options, executor, }; + if transactions_enabled { + let tc_client = TransactionCoordinatorClient::new(client.clone(), manager).await?; + client.tc_client = Some(Arc::new(tc_client)); + } + let _ = client .executor .spawn(Box::pin(run_producer(client.clone(), producer_rx))); client.producer = Some(producer); + Ok(client) } @@ -256,10 +267,33 @@ impl Pulsar { operation_retry_options: None, tls_options: None, outbound_channel_size: None, + transactions_enabled: false, executor, } } + /// Creates a new transaction builder. If transactions were not enabled when creating the + /// Pulsar client, this will return an error. + /// + /// ```rust,no_run + /// use std::time::Duration; + /// use pulsar::Transaction; + /// + /// # async fn run(pulsar: pulsar::Pulsar) -> Result<(), pulsar::Error> { + /// let txn = pulsar.new_txn()?.with_timeout(Duration::from_millis(1000)).build().await?; + /// # Ok(()) + /// # } + pub fn new_txn(&self) -> Result, Error> { + if let Some(tc_client) = self.tc_client.as_ref() { + Ok(TransactionBuilder::new( + Arc::clone(&self.executor), + Arc::clone(tc_client), + )) + } else { + Err(Error::Custom("Transactions are not enabled".into())) + } + } + /// creates a consumer builder /// /// ```rust,no_run @@ -457,6 +491,7 @@ pub struct PulsarBuilder { operation_retry_options: Option, tls_options: Option, outbound_channel_size: Option, + transactions_enabled: bool, executor: Exe, } @@ -560,6 +595,13 @@ impl PulsarBuilder { self } + /// Enable transactions on this Pulsar client + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub fn with_transactions(mut self) -> Self { + self.transactions_enabled = true; + self + } + /// creates the Pulsar client and connects it #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] pub async fn build(self) -> Result, Error> { @@ -570,6 +612,7 @@ impl PulsarBuilder { operation_retry_options, tls_options, outbound_channel_size, + transactions_enabled, executor, } = self; @@ -579,6 +622,7 @@ impl PulsarBuilder { connection_retry_options, operation_retry_options, tls_options, + transactions_enabled, outbound_channel_size, executor, ) diff --git a/src/connection.rs b/src/connection.rs index 1b23e26..a851fb0 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -34,6 +34,8 @@ use crate::{ BaseCommand, Codec, Message, }, producer::{self, ProducerOptions}, + proto::ProtocolVersion, + transaction::TransactionId, Certificate, }; @@ -286,6 +288,7 @@ impl SerialId { pub struct ConnectionSender { connection_id: Uuid, tx: async_channel::Sender, + protocol_version: ProtocolVersion, registrations: mpsc::UnboundedSender, receiver_shutdown: Option>, request_id: SerialId, @@ -299,6 +302,7 @@ impl ConnectionSender { pub(crate) fn new( connection_id: Uuid, tx: async_channel::Sender, + protocol_version: ProtocolVersion, registrations: mpsc::UnboundedSender, receiver_shutdown: oneshot::Sender<()>, request_id: SerialId, @@ -309,6 +313,7 @@ impl ConnectionSender { ConnectionSender { connection_id, tx, + protocol_version, registrations, receiver_shutdown: Some(receiver_shutdown), request_id, @@ -431,6 +436,83 @@ impl ConnectionSender { .await } + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub async fn new_txn( + &self, + tc_id: u64, + timeout: Option, + ) -> Result { + let request_id = self.request_id.get(); + let msg = messages::new_txn(tc_id, timeout, request_id); + self.send_message(msg, RequestKey::RequestId(request_id), |resp| { + resp.command.new_txn_response + }) + .await + } + + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub async fn end_txn( + &self, + txn_id: TransactionId, + action: proto::TxnAction, + ) -> Result { + let request_id = self.request_id.get(); + let msg = messages::end_txn(txn_id, action, request_id); + self.send_message(msg, RequestKey::RequestId(request_id), |resp| { + resp.command.end_txn_response + }) + .await + } + + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub async fn tc_client_connect_request( + &self, + tc_id: u64, + ) -> Result { + let request_id = self.request_id.get(); + let msg = messages::tc_client_connect_request(tc_id, request_id); + self.send_message(msg, RequestKey::RequestId(request_id), |resp| { + resp.command.tc_client_connect_response + }) + .await + } + + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub async fn add_partition_to_txn( + &self, + txn_id: TransactionId, + partitions: Vec, + ) -> Result { + let request_id = self.request_id.get(); + let msg = messages::add_partition_to_txn(txn_id, partitions, request_id); + self.send_message(msg, RequestKey::RequestId(request_id), |resp| { + resp.command.add_partition_to_txn_response + }) + .await + } + + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub async fn add_subscription_to_txn( + &self, + txn_id: TransactionId, + topic: String, + subscription: String, + ) -> Result { + let request_id = self.request_id.get(); + let msg = messages::add_subscription_to_txn( + txn_id, + vec![proto::Subscription { + topic, + subscription, + }], + request_id, + ); + self.send_message(msg, RequestKey::RequestId(request_id), |resp| { + resp.command.add_subscription_to_txn_response + }) + .await + } + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] pub async fn create_producer( &self, @@ -539,10 +621,11 @@ impl ConnectionSender { &self, consumer_id: u64, message_ids: Vec, + txn_id: Option, cumulative: bool, ) -> Result<(), ConnectionError> { self.tx - .send(messages::ack(consumer_id, message_ids, cumulative)) + .send(messages::ack(consumer_id, message_ids, txn_id, cumulative)) .await?; Ok(()) } @@ -1172,7 +1255,7 @@ impl Connection { .await?; let msg = stream.next().await; - match msg { + let protocol_version = match msg { Some(Ok(Message { command: proto::BaseCommand { @@ -1186,9 +1269,20 @@ impl Connection { Some(Ok(msg)) => { let cmd = msg.command.clone(); trace!("received connection response: {:?}", msg); - msg.command.connected.ok_or_else(|| { - ConnectionError::Unexpected(format!("Unexpected message from pulsar: {cmd:?}")) - }) + + match msg.command.connected { + Some(proto::CommandConnected { + protocol_version, .. + }) => protocol_version + .map::, _>(|v| v.try_into()) + .transpose() + .map_err(|_| { + ConnectionError::UnexpectedResponse("Invalid protocol version".into()) + }), + _ => Err(ConnectionError::Unexpected(format!( + "Unexpected message from pulsar: {cmd:?}" + ))), + } } Some(Err(e)) => Err(e), None => Err(ConnectionError::Disconnected), @@ -1263,6 +1357,7 @@ impl Connection { let sender = ConnectionSender::new( connection_id, tx, + protocol_version.unwrap_or(ProtocolVersion::V0), registrations_tx, receiver_shutdown_tx, SerialId::new(), @@ -1294,6 +1389,11 @@ impl Connection { &self.url } + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub fn protocol_version(&self) -> ProtocolVersion { + self.sender().protocol_version + } + /// Chain to send a message, e.g. conn.sender().send_ping() #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] pub fn sender(&self) -> &ConnectionSender { @@ -1344,6 +1444,8 @@ pub(crate) mod messages { Message, Payload, }, producer::{self, ProducerOptions}, + proto::TxnAction, + transaction::TransactionId, }; #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] @@ -1361,7 +1463,7 @@ pub(crate) mod messages { auth_data, proxy_to_broker_url, client_version: proto::client_version(), - protocol_version: Some(12), + protocol_version: Some(19), ..Default::default() }), ..Default::default() @@ -1486,6 +1588,8 @@ pub(crate) mod messages { producer_id, sequence_id, num_messages: message.num_messages_in_batch, + txnid_least_bits: message.txn_id.map(|id| id.least_bits()), + txnid_most_bits: message.txn_id.map(|id| id.most_bits()), ..Default::default() }), ..Default::default() @@ -1654,6 +1758,7 @@ pub(crate) mod messages { pub fn ack( consumer_id: u64, message_id: Vec, + txn_id: Option, cumulative: bool, ) -> Message { Message { @@ -1667,6 +1772,8 @@ pub(crate) mod messages { proto::command_ack::AckType::Individual as i32 }, message_id, + txnid_least_bits: txn_id.map(|id| id.least_bits()), + txnid_most_bits: txn_id.map(|id| id.most_bits()), validation_error: None, properties: Vec::new(), ..Default::default() @@ -1767,6 +1874,96 @@ pub(crate) mod messages { payload: None, } } + + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub fn new_txn(tc_id: u64, txn_ttl_seconds: Option, request_id: u64) -> Message { + Message { + command: proto::BaseCommand { + r#type: CommandType::NewTxn as i32, + new_txn: Some(proto::CommandNewTxn { + request_id, + tc_id: Some(tc_id), + txn_ttl_seconds, + }), + ..Default::default() + }, + payload: None, + } + } + + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub fn end_txn(txn_id: TransactionId, txn_action: TxnAction, request_id: u64) -> Message { + Message { + command: proto::BaseCommand { + r#type: CommandType::EndTxn as i32, + end_txn: Some(proto::CommandEndTxn { + request_id, + txnid_least_bits: Some(txn_id.least_bits()), + txnid_most_bits: Some(txn_id.most_bits()), + txn_action: Some(txn_action as i32), + }), + ..Default::default() + }, + payload: None, + } + } + + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub fn tc_client_connect_request(tc_id: u64, request_id: u64) -> Message { + Message { + command: proto::BaseCommand { + r#type: CommandType::TcClientConnectRequest as i32, + tc_client_connect_request: Some(proto::CommandTcClientConnectRequest { + request_id, + tc_id, + }), + ..Default::default() + }, + payload: None, + } + } + + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub fn add_partition_to_txn( + txn_id: TransactionId, + partitions: Vec, + request_id: u64, + ) -> Message { + Message { + command: proto::BaseCommand { + r#type: CommandType::AddPartitionToTxn as i32, + add_partition_to_txn: Some(proto::CommandAddPartitionToTxn { + request_id, + txnid_least_bits: Some(txn_id.least_bits()), + txnid_most_bits: Some(txn_id.most_bits()), + partitions, + }), + ..Default::default() + }, + payload: None, + } + } + + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub fn add_subscription_to_txn( + txn_id: TransactionId, + subscription: Vec, + request_id: u64, + ) -> Message { + Message { + command: proto::BaseCommand { + r#type: CommandType::AddSubscriptionToTxn as i32, + add_subscription_to_txn: Some(proto::CommandAddSubscriptionToTxn { + request_id, + txnid_least_bits: Some(txn_id.least_bits()), + txnid_most_bits: Some(txn_id.most_bits()), + subscription, + }), + ..Default::default() + }, + payload: None, + } + } } #[cfg(test)] diff --git a/src/consumer/data.rs b/src/consumer/data.rs index 38200c0..1552751 100644 --- a/src/consumer/data.rs +++ b/src/consumer/data.rs @@ -5,6 +5,7 @@ use futures::channel::{mpsc, oneshot}; use crate::{ connection::Connection, message::{proto::MessageIdData, Message as RawMessage}, + transaction::TransactionId, Error, Executor, Payload, }; @@ -16,7 +17,7 @@ pub enum EngineEvent { } pub enum EngineMessage { - Ack(MessageIdData, bool), + Ack(MessageIdData, Option, bool), Nack(MessageIdData), UnackedRedelivery, GetConnection(oneshot::Sender>>), diff --git a/src/consumer/engine.rs b/src/consumer/engine.rs index 0d7b6f6..a4da600 100644 --- a/src/consumer/engine.rs +++ b/src/consumer/engine.rs @@ -23,9 +23,9 @@ use crate::{ proto::{command_subscribe::SubType, MessageIdData}, Message as RawMessage, }, - proto, - proto::{BaseCommand, CommandCloseConsumer, CommandMessage}, + proto::{self, BaseCommand, CommandCloseConsumer, CommandMessage}, retry_op::retry_subscribe_consumer, + transaction::TransactionId, Error, Executor, Payload, Pulsar, }; @@ -279,8 +279,8 @@ impl ConsumerEngine { trace!("ack channel was closed"); false } - Some(EngineMessage::Ack(message_id, cumulative)) => { - self.ack(message_id, cumulative).await; + Some(EngineMessage::Ack(message_id, maybe_txn, cumulative)) => { + self.ack(message_id, maybe_txn, cumulative).await; true } Some(EngineMessage::Nack(message_id)) => { @@ -338,13 +338,18 @@ impl ConsumerEngine { } #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] - async fn ack(&mut self, message_id: MessageIdData, cumulative: bool) { + async fn ack( + &mut self, + message_id: MessageIdData, + txn_id: Option, + cumulative: bool, + ) { // FIXME: this does not handle cumulative acks self.unacked_messages.remove(&message_id); let res = self .connection .sender() - .send_ack(self.id, vec![message_id], cumulative) + .send_ack(self.id, vec![message_id], txn_id, cumulative) .await; if res.is_err() { error!("ack error: {:?}", res); @@ -561,7 +566,7 @@ impl ConsumerEngine { Error::Custom("DLQ send error".to_string()) })?; - self.ack(message_id, false).await; + self.ack(message_id, None, false).await; } _ => self.send_to_consumer(message_id, payload).await?, } diff --git a/src/consumer/mod.rs b/src/consumer/mod.rs index 2a48298..0a19aa2 100644 --- a/src/consumer/mod.rs +++ b/src/consumer/mod.rs @@ -37,7 +37,7 @@ use crate::{ executor::Executor, message::proto::{command_subscribe::SubType, MessageIdData, Schema}, proto::CommandConsumerStatsResponse, - DeserializeMessage, Pulsar, + DeserializeMessage, Pulsar, Transaction, }; enum InnerConsumer { @@ -108,6 +108,29 @@ impl Consumer { } } + /// acknowledges a single message as part of a transaction + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub async fn txn_ack(&mut self, msg: &Message, txn: &Transaction) -> Result<(), Error> { + match &mut self.inner { + InnerConsumer::Single(c) => c.txn_ack(msg, txn).await, + InnerConsumer::Multi(c) => c.txn_ack(msg, txn).await, + } + } + + /// acknowledges a single message as part of a transaction, with a given ID + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub async fn txn_ack_with_id( + &mut self, + msg: &Message, + msg_id: MessageIdData, + txn: &Transaction, + ) -> Result<(), Error> { + match &mut self.inner { + InnerConsumer::Single(c) => c.txn_ack_with_id(msg, msg_id, txn).await, + InnerConsumer::Multi(c) => c.txn_ack_with_id(msg, msg_id, txn).await, + } + } + /// acknowledges a single message #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] pub async fn ack(&mut self, msg: &Message) -> Result<(), ConsumerError> { diff --git a/src/consumer/multi.rs b/src/consumer/multi.rs index d0fae2e..e4af538 100644 --- a/src/consumer/multi.rs +++ b/src/consumer/multi.rs @@ -15,9 +15,8 @@ use crate::{ consumer::{config::ConsumerConfig, message::Message, topic::TopicConsumer}, error::{ConnectionError, ConsumerError}, message::proto::{MessageIdData, Schema}, - proto, - proto::CommandConsumerStatsResponse, - DeserializeMessage, Error, Executor, Pulsar, + proto::{self, CommandConsumerStatsResponse}, + DeserializeMessage, Error, Executor, Pulsar, Transaction, }; /// A consumer that can subscribe on multiple topics, from a regex matching @@ -199,6 +198,29 @@ impl MultiTopicConsumer { })); } + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub async fn txn_ack(&mut self, msg: &Message, txn: &Transaction) -> Result<(), Error> { + if let Some(c) = self.consumers.get_mut(&msg.topic) { + c.txn_ack(msg, txn).await + } else { + Err(ConnectionError::Unexpected(format!("no consumer for topic {}", msg.topic)).into()) + } + } + + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub async fn txn_ack_with_id( + &mut self, + msg: &Message, + msg_id: MessageIdData, + txn: &Transaction, + ) -> Result<(), Error> { + if let Some(c) = self.consumers.get_mut(&msg.topic) { + c.txn_ack_with_id(msg, msg_id, txn).await + } else { + Err(ConnectionError::Unexpected(format!("no consumer for topic {}", msg.topic)).into()) + } + } + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] pub async fn ack(&mut self, msg: &Message) -> Result<(), ConsumerError> { if let Some(c) = self.consumers.get_mut(&msg.topic) { diff --git a/src/consumer/topic.rs b/src/consumer/topic.rs index 6837512..23620f4 100644 --- a/src/consumer/topic.rs +++ b/src/consumer/topic.rs @@ -27,7 +27,7 @@ use crate::{ message::proto::{MessageIdData, Schema}, proto::CommandConsumerStatsResponse, retry_op::retry_subscribe_consumer, - BrokerAddress, DeserializeMessage, Error, Executor, Payload, Pulsar, + BrokerAddress, DeserializeMessage, Error, Executor, Payload, Pulsar, Transaction, }; // this is entirely public for use in reader.rs @@ -179,10 +179,47 @@ impl TopicConsumer { Ok(()) } + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + async fn txn_ack_inner( + &mut self, + msg: &Message, + msg_id: Option, + txn: &Transaction, + ) -> Result<(), Error> { + txn.register_acked_topic(self.topic(), self.config().subscription.clone()) + .await?; + + self.engine_tx + .send(EngineMessage::Ack( + msg_id.unwrap_or(msg.message_id().clone()), + Some(txn.id()), + false, + )) + .await + .map_err::(Into::into)?; + + Ok(()) + } + + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub async fn txn_ack(&mut self, msg: &Message, txn: &Transaction) -> Result<(), Error> { + self.txn_ack_inner(msg, None, txn).await + } + + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub async fn txn_ack_with_id( + &mut self, + msg: &Message, + msg_id: MessageIdData, + txn: &Transaction, + ) -> Result<(), Error> { + self.txn_ack_inner(msg, Some(msg_id), txn).await + } + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] pub async fn ack(&mut self, msg: &Message) -> Result<(), ConsumerError> { self.engine_tx - .send(EngineMessage::Ack(msg.message_id().clone(), false)) + .send(EngineMessage::Ack(msg.message_id().clone(), None, false)) .await?; Ok(()) } @@ -190,7 +227,7 @@ impl TopicConsumer { #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] pub async fn ack_with_id(&mut self, msg_id: MessageIdData) -> Result<(), ConsumerError> { self.engine_tx - .send(EngineMessage::Ack(msg_id, false)) + .send(EngineMessage::Ack(msg_id, None, false)) .await?; Ok(()) } @@ -203,7 +240,7 @@ impl TopicConsumer { #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] pub async fn cumulative_ack(&mut self, msg: &Message) -> Result<(), ConsumerError> { self.engine_tx - .send(EngineMessage::Ack(msg.message_id().clone(), true)) + .send(EngineMessage::Ack(msg.message_id().clone(), None, true)) .await?; Ok(()) } @@ -214,7 +251,7 @@ impl TopicConsumer { msg_id: MessageIdData, ) -> Result<(), ConsumerError> { self.engine_tx - .send(EngineMessage::Ack(msg_id, true)) + .send(EngineMessage::Ack(msg_id, None, true)) .await?; Ok(()) } diff --git a/src/error.rs b/src/error.rs index a5b2535..18e41ec 100644 --- a/src/error.rs +++ b/src/error.rs @@ -7,7 +7,11 @@ use std::{ }, }; -use crate::{message::proto::ServerError, producer::SendFuture}; +use crate::{ + message::proto::ServerError, + producer::SendFuture, + transaction::{State as TransactionState, TransactionId}, +}; #[derive(Debug)] pub enum Error { @@ -18,6 +22,7 @@ pub enum Error { Authentication(AuthenticationError), Custom(String), Executor, + Transaction(TransactionError), } impl From for Error { @@ -59,6 +64,7 @@ impl fmt::Display for Error { Error::Authentication(e) => write!(f, "authentication error: {e}"), Error::Custom(e) => write!(f, "error: {e}"), Error::Executor => write!(f, "could not spawn task"), + Error::Transaction(e) => write!(f, "transaction error: {e}"), } } } @@ -74,6 +80,7 @@ impl std::error::Error for Error { Error::Authentication(e) => e.source(), Error::Custom(_) => None, Error::Executor => None, + Error::Transaction(e) => e.source(), } } } @@ -89,6 +96,7 @@ pub enum ConnectionError { Encoding(String), SocketAddr(String), UnexpectedResponse(String), + UnsupportedProtocolVersion(u32), #[cfg(any(feature = "tokio-runtime", feature = "async-std-runtime"))] Tls(native_tls::Error), #[cfg(all( @@ -197,6 +205,9 @@ impl fmt::Display for ConnectionError { ConnectionError::NotFound => write!(f, "error looking up URL"), ConnectionError::Canceled => write!(f, "canceled request"), ConnectionError::Shutdown => write!(f, "The connection was shut down"), + ConnectionError::UnsupportedProtocolVersion(v) => { + write!(f, "Unsupported protocol version: {v}") + } } } } @@ -211,6 +222,59 @@ impl std::error::Error for ConnectionError { } } +#[derive(Debug)] +pub enum TransactionError { + InvalidState(TransactionState), + /// Invalid timeout value + InvalidTimeout, + /// Transaction has timed out + TimedOut, + /// Transaction not found + NotFound, + /// Transaction Conflict + Conflict, + /// Transaction Meta Handler not found + MetaHandlerNotFound(TransactionId), + /// Transaction Coordinator not found + CoordinatorNotFound, +} + +impl std::error::Error for TransactionError { + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + TransactionError::InvalidState(_) => None, + TransactionError::InvalidTimeout => None, + TransactionError::TimedOut => None, + TransactionError::NotFound => None, + TransactionError::Conflict => None, + TransactionError::MetaHandlerNotFound(_) => None, + TransactionError::CoordinatorNotFound => None, + } + } +} + +impl fmt::Display for TransactionError { + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + TransactionError::InvalidState(state) => { + write!(f, "Transaction is in invalid state: {state}") + } + TransactionError::InvalidTimeout => write!(f, "Invalid timeout value"), + TransactionError::TimedOut => write!(f, "Transaction has timed out"), + TransactionError::NotFound => write!(f, "Transaction not found"), + TransactionError::Conflict => write!(f, "Transaction conflict"), + TransactionError::MetaHandlerNotFound(id) => { + write!(f, "Transaction Meta Handler not found for txn id: {id}") + } + TransactionError::CoordinatorNotFound => { + write!(f, "Transaction Coordinator not found") + } + } + } +} + #[derive(Debug)] pub enum ConsumerError { Connection(ConnectionError), @@ -219,6 +283,7 @@ pub enum ConsumerError { ChannelFull, Closed, BuildError, + Transaction(TransactionError), } impl From for ConsumerError { @@ -262,6 +327,7 @@ impl fmt::Display for ConsumerError { "cannot send message to the consumer engine: the channel is closed" ), ConsumerError::BuildError => write!(f, "Error while building the consumer."), + ConsumerError::Transaction(e) => write!(f, "Transaction error: {e}"), } } } diff --git a/src/lib.rs b/src/lib.rs index 166c424..b8118d5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -182,6 +182,7 @@ pub use message::{ Payload, }; pub use producer::{MultiTopicProducer, Producer, ProducerOptions}; +pub use transaction::{Transaction, TransactionBuilder}; pub mod authentication; mod client; @@ -196,6 +197,7 @@ pub mod producer; pub mod reader; mod retry_op; mod service_discovery; +pub mod transaction; #[cfg(all( any(feature = "tokio-rustls-runtime", feature = "async-std-rustls-runtime"), @@ -518,6 +520,109 @@ mod tests { assert!(redelivery < Duration::from_secs(1)); } + #[tokio::test] + #[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))] + async fn transactions() { + let _result = log::set_logger(&TEST_LOGGER); + log::set_max_level(LevelFilter::Debug); + + let addr = "pulsar://127.0.0.1:6650"; + let output_topic = format!("test_transactions_output_{}", rand::random::()); + + let pulsar: Pulsar<_> = Pulsar::builder(addr, TokioExecutor) + .with_transactions() + .build() + .await + .unwrap(); + + let mut output_producer = pulsar + .producer() + .with_topic(&output_topic) + .build() + .await + .unwrap(); + + let mut output_consumer: Consumer = pulsar + .consumer() + .with_topic(&output_topic) + .build() + .await + .unwrap(); + + let txn = pulsar + .new_txn() + .unwrap() + .with_timeout(Duration::from_secs(10)) + .build() + .await + .unwrap(); + + output_producer + .create_message() + .with_content("test 1".to_string()) + .with_txn(&txn) + .send_non_blocking() + .await + .unwrap(); + + // Ensure that consumers cannot see messages from an open transaction + assert!(timeout(Duration::from_secs(1), output_consumer.next()) + .await + .is_err()); + + txn.commit().await.unwrap(); + + let txn = pulsar + .new_txn() + .unwrap() + .with_timeout(Duration::from_secs(10)) + .build() + .await + .unwrap(); + + // Ensure that consumers see messages from a committed transaction + let msg = output_consumer.next().await.unwrap().unwrap(); + assert_eq!("test 1", std::str::from_utf8(&msg.payload.data).unwrap()); + output_consumer.txn_ack(&msg, &txn).await.unwrap(); + + output_producer + .create_message() + .with_content("test 2".to_string()) + .with_txn(&txn) + .send_non_blocking() + .await + .unwrap(); + + txn.commit().await.unwrap(); + + // We shouldn't see "test 2" now + let msg = output_consumer.next().await.unwrap().unwrap(); + assert_eq!("test 2", std::str::from_utf8(&msg.payload.data).unwrap()); + + let txn = pulsar + .new_txn() + .unwrap() + .with_timeout(Duration::from_secs(10)) + .build() + .await + .unwrap(); + + output_producer + .create_message() + .with_content("test 3".to_string()) + .with_txn(&txn) + .send_non_blocking() + .await + .unwrap(); + + txn.abort().await.unwrap(); + + // We shouldn't see any more messages + assert!(timeout(Duration::from_secs(1), output_consumer.next()) + .await + .is_err()); + } + #[tokio::test] #[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))] async fn batching() { diff --git a/src/message.rs b/src/message.rs index 2c9fd8e..339d409 100644 --- a/src/message.rs +++ b/src/message.rs @@ -115,6 +115,28 @@ impl Message { | BaseCommand { get_schema_response: Some(CommandGetSchemaResponse { request_id, .. }), .. + } + | BaseCommand { + new_txn_response: Some(CommandNewTxnResponse { request_id, .. }), + .. + } + | BaseCommand { + tc_client_connect_response: Some(CommandTcClientConnectResponse { request_id, .. }), + .. + } + | BaseCommand { + add_partition_to_txn_response: + Some(CommandAddPartitionToTxnResponse { request_id, .. }), + .. + } + | BaseCommand { + add_subscription_to_txn_response: + Some(CommandAddSubscriptionToTxnResponse { request_id, .. }), + .. + } + | BaseCommand { + end_txn_response: Some(CommandEndTxnResponse { request_id, .. }), + .. } => Some(RequestKey::RequestId(*request_id)), BaseCommand { send: diff --git a/src/producer.rs b/src/producer.rs index 2ba383d..8b6d7cf 100644 --- a/src/producer.rs +++ b/src/producer.rs @@ -29,7 +29,8 @@ use crate::{ BatchedMessage, }, retry_op::retry_create_producer, - BrokerAddress, Error, Pulsar, + transaction::TransactionId, + BrokerAddress, Error, Pulsar, Transaction, }; type ProducerId = u64; @@ -81,6 +82,8 @@ pub struct Message { pub event_time: ::std::option::Option, /// current version of the schema pub schema_version: ::std::option::Option>, + /// Transaction ID + pub txn_id: ::std::option::Option, } /// internal message type carrying options that must be defined @@ -113,6 +116,8 @@ pub(crate) struct ProducerMessage { /// UTC Unix timestamp in milliseconds, time at which the message should be /// delivered to consumers pub deliver_at_time: ::std::option::Option, + /// Transaction ID + pub txn_id: ::std::option::Option, } impl From for ProducerMessage { @@ -126,6 +131,7 @@ impl From for ProducerMessage { replicate_to: m.replicate_to, event_time: m.event_time, schema_version: m.schema_version, + txn_id: m.txn_id, ..Default::default() } } @@ -754,7 +760,7 @@ impl TopicProducer { message: ProducerMessage, ) -> Result>, Error> { loop { - let msg = message.clone(); + let msg: ProducerMessage = message.clone(); match self.connection.sender().send( self.id, self.name.clone(), @@ -1076,6 +1082,7 @@ pub struct MessageBuilder<'a, T, Exe: Executor> { partition_key: Option, ordering_key: Option>, deliver_at_time: Option, + txn: Option<&'a Transaction>, event_time: Option, content: T, } @@ -1090,6 +1097,7 @@ impl<'a, Exe: Executor> MessageBuilder<'a, (), Exe> { partition_key: None, ordering_key: None, deliver_at_time: None, + txn: None, event_time: None, content: (), } @@ -1107,6 +1115,7 @@ impl<'a, T, Exe: Executor> MessageBuilder<'a, T, Exe> { ordering_key: self.ordering_key, deliver_at_time: self.deliver_at_time, event_time: self.event_time, + txn: self.txn, content, } } @@ -1142,6 +1151,13 @@ impl<'a, T, Exe: Executor> MessageBuilder<'a, T, Exe> { self } + /// adds the message to the specified transaction + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub fn with_txn(mut self, txn: &'a Transaction) -> Self { + self.txn = Some(txn); + self + } + /// delivers the message at this date /// Note: The delayed and scheduled message attributes are only applied to shared subscription. /// With other subscription types, the messages will still be delivered immediately. @@ -1191,7 +1207,18 @@ impl<'a, T: SerializeMessage + Sized, Exe: Executor> MessageBuilder<'a, T, Exe> /// sends the message through the producer that created it #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] - pub async fn send_non_blocking(self) -> Result { + pub async fn send_non_blocking(mut self) -> Result { + if let Some(txn) = self.txn.as_mut() { + // Register the partitions that will be a part of this transaction, if they + // are not already registered + let partitions = self + .producer + .partitions() + .unwrap_or(vec![self.producer.topic().to_owned()]); + + txn.register_produced_partitions(partitions).await?; + } + let MessageBuilder { producer, properties, @@ -1200,6 +1227,7 @@ impl<'a, T: SerializeMessage + Sized, Exe: Executor> MessageBuilder<'a, T, Exe> content, deliver_at_time, event_time, + txn, } = self; let mut message = T::serialize_message(content)?; @@ -1207,6 +1235,7 @@ impl<'a, T: SerializeMessage + Sized, Exe: Executor> MessageBuilder<'a, T, Exe> message.partition_key = partition_key; message.ordering_key = ordering_key; message.event_time = event_time; + message.txn_id = txn.map(|txn| txn.id()); let mut producer_message: ProducerMessage = message.into(); producer_message.deliver_at_time = deliver_at_time; diff --git a/src/reader.rs b/src/reader.rs index 3f71b6a..243f7b9 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -55,9 +55,11 @@ impl Stream for Reader { let message_id = msg.message_id().clone(); this.state = Some(State::PollingAck( msg, - Box::pin( - async move { acker.send(EngineMessage::Ack(message_id, false)).await }, - ), + Box::pin(async move { + acker + .send(EngineMessage::Ack(message_id, None, false)) + .await + }), )); Pin::new(this).poll_next(cx) } diff --git a/src/transaction/coord.rs b/src/transaction/coord.rs new file mode 100644 index 0000000..8740a4d --- /dev/null +++ b/src/transaction/coord.rs @@ -0,0 +1,128 @@ +use std::{ + collections::BTreeMap, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, RwLock, + }, +}; + +use super::{meta_store_handler::TransactionMetaStoreHandler, TransactionId}; +use crate::{ + connection_manager::ConnectionManager, error::TransactionError, proto::TxnAction, Error, + Executor, Pulsar, +}; + +pub struct TransactionCoordinatorClient { + handlers: BTreeMap>, + epoch: AtomicU64, +} + +impl TransactionCoordinatorClient { + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub async fn new( + client: Pulsar, + connection_mgr: Arc>, + ) -> Result { + debug!("Starting transaction coordinator client"); + + let partitions = client + .lookup_partitioned_topic_number(super::TC_ASSIGN_TOPIC) + .await?; + + if partitions == 0 { + error!("The transaction coordinator is not enabled or has not been initialized"); + return Err(Error::Transaction(TransactionError::CoordinatorNotFound)); + } + + let mut handlers = BTreeMap::new(); + + for partition in 0..partitions { + let connection = super::get_tc_connection(&client, &connection_mgr, partition).await?; + + let handler = TransactionMetaStoreHandler::new( + client.clone(), + Arc::clone(&connection_mgr), + RwLock::new(connection), + Arc::clone(&connection_mgr.executor), + partition, + ) + .await?; + + handlers.insert(partition.into(), handler); + } + + Ok(Self { + handlers, + epoch: AtomicU64::new(0), + }) + } + + fn get_handler_for_txn( + &self, + txn_id: TransactionId, + ) -> Result<&TransactionMetaStoreHandler, Error> { + self.handlers + .get(&txn_id.most_bits()) + .ok_or(Error::Transaction(TransactionError::MetaHandlerNotFound( + txn_id, + ))) + } + + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub async fn commit_txn(&self, txn_id: TransactionId) -> Result<(), Error> { + let handler = self.get_handler_for_txn(txn_id)?; + + handler.end_txn(txn_id, TxnAction::Commit).await + } + + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub async fn abort_txn(&self, txn_id: TransactionId) -> Result<(), Error> { + let handler = self.get_handler_for_txn(txn_id)?; + + handler.end_txn(txn_id, TxnAction::Abort).await + } + + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub async fn new_txn(&self, timeout: Option) -> Result { + self.next_handler().new_txn(timeout).await + } + + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub async fn add_publish_partitions_to_txn( + &self, + txn_id: TransactionId, + partitions: Vec, + ) -> Result<(), Error> { + let handler = self.get_handler_for_txn(txn_id)?; + + handler + .add_publish_partitions_to_txn(txn_id, partitions) + .await + } + + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub async fn add_subscription_to_txn( + &self, + txn_id: TransactionId, + topic: String, + subscription: String, + ) -> Result<(), Error> { + let handler = self.get_handler_for_txn(txn_id)?; + + handler + .add_subscription_to_txn(txn_id, topic, subscription) + .await + } + + /// Get the next transaction meta store handler to use. This effectively + /// load balances across handlers in a round-robin fashion. + fn next_handler(&self) -> &TransactionMetaStoreHandler { + let index = self + .epoch + .fetch_add(1, Ordering::Relaxed) + .checked_rem(self.handlers.len() as u64) + .expect("epoch overflow"); + + self.handlers.get(&index).expect("handler not found") + } +} diff --git a/src/transaction/meta_store_handler.rs b/src/transaction/meta_store_handler.rs new file mode 100644 index 0000000..fecaca5 --- /dev/null +++ b/src/transaction/meta_store_handler.rs @@ -0,0 +1,200 @@ +use std::{ + future::Future, + sync::{Arc, RwLock}, + time::Duration, +}; + +use rand::Rng; + +use super::TransactionId; +use crate::{ + connection::Connection, + connection_manager::ConnectionManager, + error::{server_error, ConnectionError}, + proto::{ServerError, TxnAction}, + Error, Executor, Pulsar, +}; + +pub struct TransactionMetaStoreHandler { + client: Pulsar, + connection_mgr: Arc>, + connection: RwLock>>, + executor: Arc, + transaction_coordinator_id: u32, +} + +impl TransactionMetaStoreHandler { + const OP_MAX_RETRIES: u32 = 3; + const OP_MIN_BACKOFF: Duration = Duration::from_millis(10); + const OP_MAX_BACKOFF: Duration = Duration::from_millis(1000); + + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub async fn new( + client: Pulsar, + connection_mgr: Arc>, + connection: RwLock>>, + executor: Arc, + transaction_coordinator_id: u32, + ) -> Result { + let handler = Self { + client, + connection_mgr, + connection, + executor, + transaction_coordinator_id, + }; + + handler.tc_client_connect().await?; + + Ok(handler) + } + + pub async fn reconnect(&self) -> Result<(), Error> { + // `get_tc_connection` will reconnect and swap the connection associated with the address + // if the connection is not healthy. If the topic lookup points to the same broker + // and the connection is healthy, this is effectively a no-op. In the case we discover + // the TC is on a different broker, we'll swap the connection to the new broker's address. + *self.connection.write().expect("poisoned lock") = super::get_tc_connection( + &self.client, + &self.connection_mgr, + self.transaction_coordinator_id, + ) + .await?; + + Ok(()) + } + + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + async fn extract_with_retry(&self, op: F, extract: E) -> Result + where + R: Future>, + F: Fn() -> R, + E: Fn(T) -> ((Option, Option), S), + { + let mut retries = 0u32; + + loop { + let res = op().await?; + let ((code, message), extracted) = extract(res); + + if let Some(err) = code { + let server_error = server_error(err); + + match server_error { + // TC Not Found errors shouldn't bubble up to the user before + // we've tried reconnecting a few times + error @ Some(ServerError::TransactionCoordinatorNotFound) => { + if retries >= Self::OP_MAX_RETRIES { + return Err(ConnectionError::PulsarError(error, message).into()); + } + + let jitter = Duration::from_millis(rand::thread_rng().gen_range(0..500)); + let backoff = std::cmp::min( + Self::OP_MIN_BACKOFF * 2u32.saturating_pow(retries) + jitter, + Self::OP_MAX_BACKOFF, + ); + retries += 1; + + error!( + "Received error: {:?}. Retrying in {} ms", + error, + backoff.as_millis() + ); + + self.executor.delay(backoff).await; + self.reconnect().await?; + continue; + } + error => { + return Err(ConnectionError::PulsarError(error, message).into()); + } + } + } + + return Ok(extracted); + } + } + + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub async fn tc_client_connect(&self) -> Result<(), Error> { + let connection = self.connection.read().expect("poisoned lock").clone(); + let transaction_coordinator_id = self.transaction_coordinator_id; + let op = || { + connection + .sender() + .tc_client_connect_request(transaction_coordinator_id.into()) + }; + + self.extract_with_retry(op, |res| ((res.error, res.message), ())) + .await + } + + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub async fn add_publish_partitions_to_txn( + &self, + txn_id: TransactionId, + partitions: Vec, + ) -> Result<(), Error> { + let connection = self.connection.read().expect("poisoned lock").clone(); + let op = || { + connection + .sender() + .add_partition_to_txn(txn_id, partitions.clone()) + }; + + self.extract_with_retry(op, |res| ((res.error, res.message), ())) + .await + } + + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub async fn add_subscription_to_txn( + &self, + txn_id: TransactionId, + topic: String, + subscription: String, + ) -> Result<(), Error> { + let connection = self.connection.read().expect("poisoned lock").clone(); + let op = || { + connection + .sender() + .add_subscription_to_txn(txn_id, topic.clone(), subscription.clone()) + }; + + self.extract_with_retry(op, |res| ((res.error, res.message), ())) + .await + } + + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub async fn new_txn(&self, timeout: Option) -> Result { + let connection = self.connection.read().expect("poisoned lock").clone(); + let transaction_coordinator_id = self.transaction_coordinator_id; + let op = || { + connection + .sender() + .new_txn(transaction_coordinator_id.into(), timeout) + }; + + let (txnid_most_bits, txnid_least_bits) = self + .extract_with_retry(op, |res| { + ( + (res.error, res.message), + (res.txnid_most_bits, res.txnid_least_bits), + ) + }) + .await?; + + Ok(TransactionId::new( + txnid_most_bits.expect("should be present"), + txnid_least_bits.expect("should be present"), + )) + } + + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub async fn end_txn(&self, txn_id: TransactionId, txn_action: TxnAction) -> Result<(), Error> { + let connection = self.connection.read().expect("poisoned lock").clone(); + let op = || connection.sender().end_txn(txn_id, txn_action); + + self.extract_with_retry(op, |res| ((res.error, res.message), ())) + .await + } +} diff --git a/src/transaction/mod.rs b/src/transaction/mod.rs new file mode 100644 index 0000000..d116ca9 --- /dev/null +++ b/src/transaction/mod.rs @@ -0,0 +1,391 @@ +pub(crate) mod coord; +mod meta_store_handler; +// pub mod transaction; + +use std::{ + collections::BTreeSet, + sync::{Arc, Mutex}, + time::Duration, +}; + +use coord::TransactionCoordinatorClient; +use futures::lock::Mutex as AsyncMutex; + +pub use crate::Executor; +use crate::{ + connection::Connection, + connection_manager::ConnectionManager, + error::{ConnectionError, TransactionError}, + proto::{ProtocolVersion, ServerError}, + Error, Pulsar, +}; + +const TC_ASSIGN_TOPIC: &str = "persistent://pulsar/system/transaction_coordinator_assign"; + +fn get_tc_assign_topic_name(partition: u32) -> String { + format!("{}-partition-{}", TC_ASSIGN_TOPIC, partition) +} + +async fn get_tc_connection( + client: &Pulsar, + connection_mgr: &Arc>, + transaction_coordinator_id: u32, +) -> Result>, Error> { + let address = client + .lookup_topic(get_tc_assign_topic_name(transaction_coordinator_id)) + .await?; + + let connection = connection_mgr.get_connection(&address).await?; + let remote_endpoint_protocol_version = connection.protocol_version(); + + // If the broker isn't speaking at least protocol version 19, we should fail with an + // error. + if remote_endpoint_protocol_version < ProtocolVersion::V19 { + error!( + "The remote endpoint associated with the connection is speaking protocol version {} which is not supported", + remote_endpoint_protocol_version.as_str_name() + ); + + return Err(ConnectionError::UnsupportedProtocolVersion( + remote_endpoint_protocol_version as u32, + ) + .into()); + } + + Ok(connection) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +/// The state of a transaction. +pub enum State { + /// When a transaction is in the `OPEN` state, messages can be produced and acked with this + /// transaction. + /// + /// When a transaction is in the `OPEN` state, it can commit or abort. + Open, + /// When a client invokes a commit, the transaction state is changed from `OPEN` to + /// `COMMITTING`. + Committing, + /// When a client invokes an abort, the transaction state is changed from `OPEN` to `ABORTING`. + Aborting, + /// When a client receives a response to a commit, the transaction state is changed from + /// `COMMITTING` to `COMMITTED`. + Committed, + /// When a client receives a response to an abort, the transaction state is changed from + /// `ABORTING` to `ABORTED`. + Aborted, + /// When a client invokes a commit or an abort, but a transaction does not exist in a + /// coordinator, then the state is changed to `ERROR`. + /// + /// When a client invokes a commit, but the transaction state in a coordinator is `ABORTED` or + /// `ABORTING`, then the state is changed to `ERROR`. + /// + /// When a client invokes an abort, but the transaction state in a coordinator is `COMMITTED` + /// or `COMMITTING`, then the state is changed to `ERROR`. + Error, + /// When a transaction is timed out and the transaction state is `OPEN`, + /// then the transaction state is changed from `OPEN` to `TIME_OUT`. + TimeOut, +} + +impl std::fmt::Display for State { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let state = match self { + State::Open => "Open", + State::Committing => "Committing", + State::Aborting => "Aborting", + State::Committed => "Committed", + State::Aborted => "Aborted", + State::Error => "Error", + State::TimeOut => "TimeOut", + }; + + write!(f, "{}", state) + } +} + +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] +pub struct TransactionId { + /// The most significant 64 bits of this transaction id. + most_sig_bits: u64, + /// The least significant 64 bits of this transaction id. + least_sig_bits: u64, +} + +impl TransactionId { + pub fn new(most_sig_bits: u64, least_sig_bits: u64) -> Self { + Self { + most_sig_bits, + least_sig_bits, + } + } + + pub fn least_bits(&self) -> u64 { + self.least_sig_bits + } + + pub fn most_bits(&self) -> u64 { + self.most_sig_bits + } +} + +impl std::fmt::Display for TransactionId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "({},{})", self.most_sig_bits, self.least_sig_bits) + } +} + +pub struct TransactionBuilder { + exe: Arc, + tc_client: Arc>, + timeout: Option, +} + +impl TransactionBuilder { + const TRANSACTION_TIMEOUT_DEFAULT: u64 = 60_000; // 1 minute + + pub(crate) fn new(exe: Arc, tc_client: Arc>) -> Self { + Self { + exe, + tc_client, + timeout: None, + } + } + + /// Set the transaction timeout in milliseconds. + pub fn with_timeout(mut self, timeout: Duration) -> Self { + self.timeout = Some(timeout.as_millis()); + self + } + + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub async fn build(self) -> Result, Error> { + let timeout: u64 = self + .timeout + .unwrap_or(Self::TRANSACTION_TIMEOUT_DEFAULT.into()) + .try_into() + .map_err(|_| Error::Transaction(TransactionError::InvalidTimeout))?; + + let txn_id = self.tc_client.new_txn(Some(timeout)).await?; + + Ok(Transaction::new(txn_id, self.tc_client, &self.exe, timeout)) + } +} + +#[derive(Clone)] +pub struct Transaction { + state: Arc>, + id: TransactionId, + tc_client: Arc>, + published_partitions: Arc>>, + registered_subscriptions: Arc>>, +} + +impl Transaction { + pub(self) fn new( + id: TransactionId, + tc_client: Arc>, + exe: &Arc, + timeout_ms: u64, + ) -> Self { + let state = Arc::new(Mutex::new(State::Open)); + // Spawn a best-effort task to timeout the transaction. + // Checking for timeouts from the client side is not 100% reliable, but it should + // give more information to the caller in cases where the transaction has definitely + // timed out. + let _ = exe.spawn({ + // We really only need a weak reference to the state here since we don't want + // this task keeping the state alive if the transaction has already been dropped. + let weak_state = Arc::downgrade(&state); + let timeout_fut = exe.delay(Duration::from_millis(timeout_ms)); + + Box::pin(async move { + timeout_fut.await; + + if let Some(state) = weak_state.upgrade() { + *state.lock().expect("poisoned lock") = State::TimeOut; + } + }) + }); + + Self { + state, + id, + tc_client, + published_partitions: Arc::new(AsyncMutex::new(BTreeSet::new())), + registered_subscriptions: Arc::new(AsyncMutex::new(BTreeSet::new())), + } + } + + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub(crate) async fn register_produced_partitions( + &self, + partitions: Vec, + ) -> Result<(), Error> { + self.ensure_and_swap_state(State::Open, None)?; + + let mut published_partitions = self.published_partitions.lock().await; + + if partitions.iter().all(|p| published_partitions.contains(p)) { + return Ok(()); + } + + self.tc_client + .add_publish_partitions_to_txn(self.id, partitions.clone()) + .await?; + + debug!( + "Added publish partitions to txn {}: {:?}", + self.id, partitions + ); + + published_partitions.extend(partitions); + + Ok(()) + } + + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub(crate) async fn register_acked_topic( + &self, + topic: String, + subscription: String, + ) -> Result<(), Error> { + self.ensure_and_swap_state(State::Open, None)?; + + let mut registered_subscriptions = self.registered_subscriptions.lock().await; + + if registered_subscriptions.contains(&(topic.clone(), subscription.clone())) { + return Ok(()); + } + + self.tc_client + .add_subscription_to_txn(self.id, topic.clone(), subscription.clone()) + .await?; + + registered_subscriptions.insert((topic, subscription)); + + Ok(()) + } + + /// Commit the transaction. + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub async fn commit(&self) -> Result<(), Error> { + self.ensure_and_swap_state(State::Open, Some(State::Committing))?; + + if let Err(err) = self.tc_client.commit_txn(self.id).await { + error!("Unable to commit transaction: {}", err); + + match err { + Error::Connection(ConnectionError::PulsarError( + Some(err @ ServerError::TransactionNotFound), + _, + )) + | Error::Connection(ConnectionError::PulsarError( + Some(err @ ServerError::InvalidTxnStatus), + _, + )) => { + self.set_state(State::Error); + + if err == ServerError::TransactionNotFound { + error!("Transaction {} not found", self.id); + return Err(Error::Transaction(TransactionError::NotFound)); + } + } + Error::Connection(ConnectionError::PulsarError( + Some(ServerError::TransactionConflict), + _, + )) => { + warn!("Transaction conflict observed for txn {}", self.id); + return Err(Error::Transaction(TransactionError::Conflict)); + } + _ => (), + } + + return Err(err); + } + + self.set_state(State::Committed); + + Ok(()) + } + + /// Abort the transaction. + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub async fn abort(&self) -> Result<(), Error> { + self.ensure_and_swap_state(State::Open, Some(State::Aborting))?; + + if let Err(err) = self.tc_client.abort_txn(self.id).await { + error!("Unable to abort transaction: {}", err); + + match err { + Error::Connection(ConnectionError::PulsarError( + Some(err @ ServerError::TransactionNotFound), + _, + )) + | Error::Connection(ConnectionError::PulsarError( + Some(err @ ServerError::InvalidTxnStatus), + _, + )) => { + self.set_state(State::Error); + + if err == ServerError::TransactionNotFound { + error!("Transaction {} not found", self.id); + return Err(Error::Transaction(TransactionError::NotFound)); + } + } + Error::Connection(ConnectionError::PulsarError( + Some(ServerError::TransactionConflict), + _, + )) => { + warn!("Transaction conflict observed for txn {}", self.id); + return Err(Error::Transaction(TransactionError::Conflict)); + } + _ => (), + } + + return Err(err); + } + + self.set_state(State::Aborted); + + Ok(()) + } + + fn ensure_and_swap_state( + &self, + expected_state: State, + new_state: Option, + ) -> Result<(), Error> { + let mut actual_state = self.state.lock().expect("poisoned lock"); + + if *actual_state != expected_state { + error!( + "Transaction {} is in state {}, expected state {}", + self.id, *actual_state, expected_state + ); + return Err(Error::Transaction(TransactionError::InvalidState( + *actual_state, + ))); + } + + if let Some(new_state) = new_state { + *actual_state = new_state; + } + + Ok(()) + } + + fn set_state(&self, state: State) { + *self.state.lock().expect("poisoned lock") = state; + } + + /// Get the transaction id. + pub fn id(&self) -> TransactionId { + self.id + } + + /// Get the state of the transaction. + pub fn state(&self) -> State { + *self.state.lock().expect("poisoned lock") + } +}