From db615033638608e1bfac8c7f0df91244617c2616 Mon Sep 17 00:00:00 2001 From: xuanyuan300 Date: Mon, 19 Sep 2022 23:47:16 +0800 Subject: [PATCH 1/2] feat(pisa-proxy, sharding): WIP: fix merge sort Signed-off-by: xuanyuan300 --- pisa-proxy/protocol/mysql/src/lib.rs | 2 +- pisa-proxy/runtime/mysql/Cargo.toml | 3 +- .../runtime/mysql/src/server/executor.rs | 109 ++++++++++++++---- pisa-proxy/runtime/mysql/src/server/server.rs | 4 + .../runtime/mysql/src/server/stmt_cache.rs | 18 ++- .../runtime/mysql/src/transaction_fsm.rs | 22 ++-- 6 files changed, 119 insertions(+), 39 deletions(-) diff --git a/pisa-proxy/protocol/mysql/src/lib.rs b/pisa-proxy/protocol/mysql/src/lib.rs index 88c68f59..a0783242 100644 --- a/pisa-proxy/protocol/mysql/src/lib.rs +++ b/pisa-proxy/protocol/mysql/src/lib.rs @@ -14,7 +14,7 @@ pub mod charset; pub mod client; -mod column; +pub mod column; pub mod err; mod macros; pub mod mysql_const; diff --git a/pisa-proxy/runtime/mysql/Cargo.toml b/pisa-proxy/runtime/mysql/Cargo.toml index c433c3af..3128f64d 100644 --- a/pisa-proxy/runtime/mysql/Cargo.toml +++ b/pisa-proxy/runtime/mysql/Cargo.toml @@ -34,4 +34,5 @@ tracing-subscriber = "0.3.9" tower = { version = "0.4.13" } #mysql-macro = { path = "../macros" } indexmap = "1.9.1" -lazy_static = "1.4.0" \ No newline at end of file +lazy_static = "1.4.0" +rayon = "1.5" \ No newline at end of file diff --git a/pisa-proxy/runtime/mysql/src/server/executor.rs b/pisa-proxy/runtime/mysql/src/server/executor.rs index 2e9b04e8..e9db5ba2 100644 --- a/pisa-proxy/runtime/mysql/src/server/executor.rs +++ b/pisa-proxy/runtime/mysql/src/server/executor.rs @@ -25,7 +25,7 @@ use mysql_protocol::{ err::ProtocolError, mysql_const::*, server::codec::{make_eof_packet, CommonPacket, PacketSend}, - util::{length_encode_int, is_eof}, + util::{length_encode_int, is_eof}, row::{RowDataText, RowDataBinary, RowDataTyp, RowData}, }; use pisa_error::error::{Error, ErrorKind}; use strategy::sharding_rewrite::{DataSource, ShardingRewriteOutput}; @@ -33,6 +33,10 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::codec::{Decoder, Encoder}; use crate::{mysql::{ReqContext, STMT_ID}, transaction_fsm::check_get_conn}; +use rayon::prelude::*; +use std::sync::Arc; +use mysql_protocol::column::ColumnInfo; +use mysql_protocol::column::Column; #[derive(Debug, thiserror::Error)] pub enum ExecuteError { @@ -79,13 +83,16 @@ where let mut conns = Self::shard_send_query(conns, &rewrite_outputs).await?; let shards_length = conns.len(); let mut shard_streams = Vec::with_capacity(shards_length); + for conn in conns.iter_mut() { - shard_streams.push(ResultsetStream::new(conn.framed.as_mut()).fuse()); + let s = ResultsetStream::new(conn.framed.as_mut()); + shard_streams.push(s.fuse()); } let mut merge_stream = MergeStream::new(shard_streams, shards_length); - Self::handle_shard_resultset(req, &mut merge_stream).await?; + let sharding_column = rewrite_outputs[0].sharding_column.clone(); + Self::handle_shard_resultset(req, &mut merge_stream, sharding_column, false).await?; if let Some(id) = curr_server_stmt_id { let stmt_conns = curr_cached_stmt_id.into_iter().zip(conns.into_iter()).collect(); @@ -97,7 +104,7 @@ where Ok(()) } - async fn handle_shard_resultset<'a>(req: &mut ReqContext, merge_stream: &mut MergeStream>) -> Result<(), Error> { + async fn handle_shard_resultset<'a>(req: &mut ReqContext, merge_stream: &mut MergeStream>, sharding_column: Option, is_binary: bool) -> Result<(), Error> { let header = merge_stream.next().await; let header = if let Some(header) = Self::get_shard_one_data(header)? { header.1 @@ -124,7 +131,7 @@ where .codec_mut() .encode(PacketSend::EncodeOffset(header[4..].into(), 0), &mut buf); - Self::get_columns(req, merge_stream, cols, &mut buf).await?; + let col_info = Self::get_columns(req, merge_stream, cols, &mut buf).await?; // read eof let _ = merge_stream.next().await; @@ -135,7 +142,7 @@ where .encode(PacketSend::EncodeOffset(make_eof_packet()[4..].into(), buf.len()), &mut buf); // get rows - Self::get_rows(req, merge_stream, &mut buf).await?; + Self::get_rows(req, merge_stream, &mut buf, sharding_column, col_info, is_binary).await?; let _ = req .framed @@ -150,23 +157,60 @@ where req: &mut ReqContext, stream: &mut MergeStream>, buf: &mut BytesMut, + sharding_column: Option, + col_info: Arc<[ColumnInfo]>, + is_binary: bool, ) -> Result<(), Error> { - while let Some(mut chunk) = stream.next().await { - let _ = Self::check_single_chunk(&mut chunk)?; - for row in chunk.into_iter() { - if let Some(row) = row { - // We have ensured `c` is Ok(_) by above step - let row = row.unwrap(); - if is_eof(&row) { - continue; - } + let row_data = match is_binary { + false => { + let row_data_text = RowDataText::new(col_info, &[][..]); + RowDataTyp::Text(row_data_text) + } + true => { + let row_data_binary = RowDataBinary::new(col_info, &[][..]); + RowDataTyp::Binary(row_data_binary) + }, + }; - let _ = req - .framed - .codec_mut() - .encode(PacketSend::EncodeOffset(row[4..].into(), buf.len()), buf); + while let Some(chunk) = stream.next().await { + let mut chunk = chunk.into_par_iter().filter_map(|x| { + if let Some(x) = x { + if let Ok(data) = &x { + if is_eof(data) { + return None + } + } + Some(x) + } else { + None + } + }).collect::, _>>().map_err(ErrorKind::from)?; + + if is_binary { + if let Some(name) = &sharding_column { + chunk.par_sort_by_cached_key(|x| { + let mut row_data = row_data.clone(); + row_data.with_buf(&x[4..]); + row_data.decode_with_name::(&name).unwrap() + }) + } + } else { + if let Some(name) = &sharding_column { + chunk.par_sort_by_cached_key(|x| { + let mut row_data = row_data.clone(); + row_data.with_buf(&x[4..]); + let value = row_data.decode_with_name::(&name).unwrap().unwrap(); + value.parse::().unwrap() + }) } } + + for row in chunk.iter() { + let _ = req + .framed + .codec_mut() + .encode(PacketSend::EncodeOffset(row[4..].into(), buf.len()), buf); + } } Ok(()) @@ -176,8 +220,10 @@ where stream: &mut MergeStream>, column_length: u64, buf: &mut BytesMut, - ) -> Result<(), Error> { + ) -> Result, Error> { + let mut col_buf = Vec::with_capacity(100); let mut idx: Option = None; + for _ in 0..column_length { let data = stream.next().await; let data = if let Some(idx) = idx { @@ -186,10 +232,10 @@ where if let Some(data) = data { data.map_err(ErrorKind::Protocol)? } else { - return Ok(()); + unreachable!() } } else { - return Ok(()); + unreachable!() } } else { // find index from chunk @@ -198,17 +244,21 @@ where idx = Some(data.0); data.1 } else { - return Ok(()); + unreachable!() } }; + col_buf.extend_from_slice(&data[..]); let _ = req .framed .codec_mut() .encode(PacketSend::EncodeOffset(data[4..].into(), buf.len()), buf); } - Ok(()) + let col_info = col_buf.as_slice().decode_columns(); + let arc_col_info: Arc<[ColumnInfo]> = col_info.into_boxed_slice().into(); + + Ok(arc_col_info) } fn get_shard_one_data( @@ -358,8 +408,9 @@ where shard_streams.push(ResultsetStream::new(conn.1.framed.as_mut()).fuse()); } + let sharding_column = req.stmt_cache.lock().get_sharding_column(stmt_id); let mut merge_stream = MergeStream::new(shard_streams, shard_length); - Self::handle_shard_resultset(req, &mut merge_stream).await?; + Self::handle_shard_resultset(req, &mut merge_stream, sharding_column, true).await?; req.stmt_cache.lock().put_all(stmt_id, conns); @@ -392,3 +443,11 @@ where Ok(sended_conns) } } + +#[cfg(test)] +mod test { + #[test] + fn test() { + assert_eq!(1, 1); + } +} diff --git a/pisa-proxy/runtime/mysql/src/server/server.rs b/pisa-proxy/runtime/mysql/src/server/server.rs index e9853624..7a109aab 100644 --- a/pisa-proxy/runtime/mysql/src/server/server.rs +++ b/pisa-proxy/runtime/mysql/src/server/server.rs @@ -143,6 +143,8 @@ where } route_sharding(input_typ, raw_sql, req.route_strategy.clone(), &mut rewrite_outputs); + let sharding_column = rewrite_outputs[0].sharding_column.clone(); + let (mut stmts, shard_conns) = Executor::shard_prepare_executor(req, rewrite_outputs, attrs, is_get_conn).await?; for i in stmts.iter().zip(shard_conns.into_iter()) { req.stmt_cache.lock().put(stmt_id, i.0.stmt_id, i.1) @@ -150,6 +152,8 @@ where let mut stmt = stmts.remove(0); stmt.stmt_id = stmt_id; + + req.stmt_cache.lock().put_sharding_column(stmt_id, sharding_column); Self::prepare_stmt(req, stmt).await?; Ok(()) diff --git a/pisa-proxy/runtime/mysql/src/server/stmt_cache.rs b/pisa-proxy/runtime/mysql/src/server/stmt_cache.rs index 769e1f99..debba239 100644 --- a/pisa-proxy/runtime/mysql/src/server/stmt_cache.rs +++ b/pisa-proxy/runtime/mysql/src/server/stmt_cache.rs @@ -24,12 +24,15 @@ struct Entry { pub struct StmtCache { // key is generated id by pisa, value is returnd stmt id from client cache: IndexMap>, + sharding_column_cache: IndexMap> } impl StmtCache { pub fn new() -> Self { Self { - cache: IndexMap::new() } + cache: IndexMap::new(), + sharding_column_cache: IndexMap::new(), + } } pub fn put(&mut self, server_stmt_id: u32, stmt_id: u32, conn: PoolConn) { @@ -94,4 +97,17 @@ impl StmtCache { pub fn remove(&mut self, server_stmt_id: u32) { self.cache.remove(&server_stmt_id); } + + pub fn put_sharding_column(&mut self, server_stmt_id: u32, sharding_column: Option) { + let _ = self.sharding_column_cache.insert(server_stmt_id, sharding_column); + } + + pub fn get_sharding_column(&mut self, server_stmt_id: u32) -> Option { + let name = self.sharding_column_cache.get(&server_stmt_id); + if let Some(name) = name { + name.clone() + } else { + None + } + } } diff --git a/pisa-proxy/runtime/mysql/src/transaction_fsm.rs b/pisa-proxy/runtime/mysql/src/transaction_fsm.rs index 70e9280d..b887d27b 100644 --- a/pisa-proxy/runtime/mysql/src/transaction_fsm.rs +++ b/pisa-proxy/runtime/mysql/src/transaction_fsm.rs @@ -99,6 +99,7 @@ pub fn query_rewrite( changes: vec![], target_sql: raw_sql.clone(), data_source: strategy::sharding_rewrite::DataSource::Endpoint(x.clone()), + sharding_column: None, }) .collect::>() }; @@ -527,30 +528,29 @@ pub fn build_conn_attrs(sess: &ServerHandshakeCodec) -> Vec { mod test { use super::*; - #[tokio::test] - async fn test_trigger() { - let lb = Arc::new(tokio::sync::Mutex::new(RouteStrategy::None)); - let mut tsm = TransFsm::new_trans_fsm(lb, Pool::new(1)); + #[test] + fn test_trigger() { + let mut tsm = TransFsm::new(Pool::new(1)); tsm.current_state = TransState::TransUseState; - let _ = tsm.trigger(TransEventName::QueryEvent, RouteInput::None).await; + let _ = tsm.trigger(TransEventName::QueryEvent); assert_eq!(tsm.current_state, TransState::TransUseState); assert_eq!(tsm.current_event, TransEventName::QueryEvent); - let _ = tsm.trigger(TransEventName::SetSessionEvent, RouteInput::None).await; + let _ = tsm.trigger(TransEventName::SetSessionEvent); assert_eq!(tsm.current_state, TransState::TransSetSessionState); assert_eq!(tsm.current_event, TransEventName::SetSessionEvent); - let _ = tsm.trigger(TransEventName::StartEvent, RouteInput::None).await; + let _ = tsm.trigger(TransEventName::StartEvent); assert_eq!(tsm.current_state, TransState::TransStartState); assert_eq!(tsm.current_event, TransEventName::StartEvent); - let _ = tsm.trigger(TransEventName::PrepareEvent, RouteInput::None).await; + let _ = tsm.trigger(TransEventName::PrepareEvent); assert_eq!(tsm.current_state, TransState::TransPrepareState); assert_eq!(tsm.current_event, TransEventName::PrepareEvent); - let _ = tsm.trigger(TransEventName::SendLongDataEvent, RouteInput::None).await; + let _ = tsm.trigger(TransEventName::SendLongDataEvent); assert_eq!(tsm.current_state, TransState::TransPrepareState); assert_eq!(tsm.current_event, TransEventName::SendLongDataEvent); - let _ = tsm.trigger(TransEventName::ExecuteEvent, RouteInput::None).await; + let _ = tsm.trigger(TransEventName::ExecuteEvent); assert_eq!(tsm.current_state, TransState::TransPrepareState); assert_eq!(tsm.current_event, TransEventName::ExecuteEvent); - let _ = tsm.trigger(TransEventName::CommitRollBackEvent, RouteInput::None).await; + let _ = tsm.trigger(TransEventName::CommitRollBackEvent); assert_eq!(tsm.current_state, TransState::TransDummyState); assert_eq!(tsm.current_event, TransEventName::CommitRollBackEvent); } From 8341eec90a3d92d11125dc7466ccbf7ee75245f0 Mon Sep 17 00:00:00 2001 From: xuanyuan300 Date: Tue, 20 Sep 2022 09:15:33 +0800 Subject: [PATCH 2/2] fix(pisa-proxy, sharding): fix table_iproduct sharding_column Signed-off-by: xuanyuan300 --- pisa-proxy/proxy/strategy/src/sharding_rewrite/mod.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pisa-proxy/proxy/strategy/src/sharding_rewrite/mod.rs b/pisa-proxy/proxy/strategy/src/sharding_rewrite/mod.rs index 1cb15431..ff8638bc 100644 --- a/pisa-proxy/proxy/strategy/src/sharding_rewrite/mod.rs +++ b/pisa-proxy/proxy/strategy/src/sharding_rewrite/mod.rs @@ -650,10 +650,12 @@ impl ShardingRewrite { ) -> Vec { let mut output = vec![]; let mut group_changes = IndexMap::>::new(); + let mut sharding_column = None; for t in tables.iter() { match t.1.table_strategy.as_ref().unwrap() { crate::config::StrategyType::TableStrategyConfig(config) => { + sharding_column = Some(config.table_sharding_column.clone()); for idx in 0..config.sharding_count as u64 { let target = self.change_table(t.2, "", idx); @@ -684,7 +686,7 @@ impl ShardingRewrite { changes: changes.into_iter().map(|x| RewriteChange::DatabaseChange(x)).collect(), target_sql: target_sql.to_string(), data_source: DataSource::Endpoint(ep), - sharding_column: None + sharding_column: sharding_column.clone(), }) } @@ -727,7 +729,7 @@ impl ShardingRewriter for ShardingRewrite { fn rewrite(&mut self, mut input: ShardingRewriteInput) -> Self::Output { self.set_raw_sql(input.raw_sql); let meta = self.get_meta(&mut input.ast); - self.database_strategy(meta) + self.table_strategy(meta) } }