feat(pisa-proxy, sharding): Add merge sort by sharding column #310

Merged · merged 2 commits on Sep 20, 2022
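This PR threads the table sharding column from the sharding-rewrite layer into the MySQL runtime and, while merging result sets coming back from several shards, sorts each batch of rows by that column with rayon's parallel sort. A minimal, self-contained sketch of the ordering idea, assuming a toy `Row` type with a hypothetical `order_id` sharding column (the PR's real code instead decodes the key from MySQL text/binary wire buffers via `RowDataText`/`RowDataBinary`):

```rust
use rayon::prelude::*;

// Toy stand-in for a decoded result-set row; illustrative only.
struct Row {
    order_id: u64,         // hypothetical sharding column
    payload: &'static str, // rest of the row
}

fn merge_shard_rows(mut rows: Vec<Row>) -> Vec<Row> {
    // Rows arrive interleaved from multiple shards; sorting by the
    // sharding column presents the client with one ordered result set.
    rows.par_sort_by_cached_key(|r| r.order_id);
    rows
}

fn main() {
    let rows = vec![
        Row { order_id: 3, payload: "from shard 1" },
        Row { order_id: 1, payload: "from shard 0" },
        Row { order_id: 2, payload: "from shard 2" },
    ];
    for r in merge_shard_rows(rows) {
        println!("{}: {}", r.order_id, r.payload);
    }
}
```

`par_sort_by_cached_key` computes each key once and sorts in parallel, which suits a per-row decode of the wire format; the diff below uses the same rayon call.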
2 changes: 1 addition & 1 deletion pisa-proxy/protocol/mysql/src/lib.rs
@@ -14,7 +14,7 @@
 
 pub mod charset;
 pub mod client;
-mod column;
+pub mod column;
 pub mod err;
 mod macros;
 pub mod mysql_const;
6 changes: 4 additions & 2 deletions pisa-proxy/proxy/strategy/src/sharding_rewrite/mod.rs
@@ -650,10 +650,12 @@ impl ShardingRewrite {
     ) -> Vec<ShardingRewriteOutput> {
         let mut output = vec![];
         let mut group_changes = IndexMap::<usize, Vec<DatabaseChange>>::new();
+        let mut sharding_column = None;
 
         for t in tables.iter() {
            match t.1.table_strategy.as_ref().unwrap() {
                crate::config::StrategyType::TableStrategyConfig(config) => {
+                    sharding_column = Some(config.table_sharding_column.clone());
                    for idx in 0..config.sharding_count as u64 {
                        let target = self.change_table(t.2, "", idx);
 
@@ -684,7 +686,7 @@ impl ShardingRewrite {
                     changes: changes.into_iter().map(|x| RewriteChange::DatabaseChange(x)).collect(),
                     target_sql: target_sql.to_string(),
                     data_source: DataSource::Endpoint(ep),
-                    sharding_column: None
+                    sharding_column: sharding_column.clone(),
                 })
             }
 
@@ -727,7 +729,7 @@ impl ShardingRewriter<ShardingRewriteInput> for ShardingRewrite {
     fn rewrite(&mut self, mut input: ShardingRewriteInput) -> Self::Output {
         self.set_raw_sql(input.raw_sql);
         let meta = self.get_meta(&mut input.ast);
-        self.database_strategy(meta)
+        self.table_strategy(meta)
     }
 }
3 changes: 2 additions & 1 deletion pisa-proxy/runtime/mysql/Cargo.toml
@@ -34,4 +34,5 @@ tracing-subscriber = "0.3.9"
 tower = { version = "0.4.13" }
 #mysql-macro = { path = "../macros" }
 indexmap = "1.9.1"
-lazy_static = "1.4.0"
+lazy_static = "1.4.0"
+rayon = "1.5"
109 changes: 84 additions & 25 deletions pisa-proxy/runtime/mysql/src/server/executor.rs
@@ -25,14 +25,18 @@ use mysql_protocol::{
     err::ProtocolError,
     mysql_const::*,
     server::codec::{make_eof_packet, CommonPacket, PacketSend},
-    util::{length_encode_int, is_eof},
+    util::{length_encode_int, is_eof}, row::{RowDataText, RowDataBinary, RowDataTyp, RowData},
 };
 use pisa_error::error::{Error, ErrorKind};
 use strategy::sharding_rewrite::{DataSource, ShardingRewriteOutput};
 use tokio::io::{AsyncRead, AsyncWrite};
 use tokio_util::codec::{Decoder, Encoder};
 
 use crate::{mysql::{ReqContext, STMT_ID}, transaction_fsm::check_get_conn};
+use rayon::prelude::*;
+use std::sync::Arc;
+use mysql_protocol::column::ColumnInfo;
+use mysql_protocol::column::Column;
 
 #[derive(Debug, thiserror::Error)]
 pub enum ExecuteError {
@@ -79,13 +83,16 @@ where
         let mut conns = Self::shard_send_query(conns, &rewrite_outputs).await?;
         let shards_length = conns.len();
         let mut shard_streams = Vec::with_capacity(shards_length);
+
         for conn in conns.iter_mut() {
-            shard_streams.push(ResultsetStream::new(conn.framed.as_mut()).fuse());
+            let s = ResultsetStream::new(conn.framed.as_mut());
+            shard_streams.push(s.fuse());
         }
 
         let mut merge_stream = MergeStream::new(shard_streams, shards_length);
 
-        Self::handle_shard_resultset(req, &mut merge_stream).await?;
+        let sharding_column = rewrite_outputs[0].sharding_column.clone();
+        Self::handle_shard_resultset(req, &mut merge_stream, sharding_column, false).await?;
 
         if let Some(id) = curr_server_stmt_id {
             let stmt_conns = curr_cached_stmt_id.into_iter().zip(conns.into_iter()).collect();
@@ -97,7 +104,7 @@ where
         Ok(())
     }
 
-    async fn handle_shard_resultset<'a>(req: &mut ReqContext<T, C>, merge_stream: &mut MergeStream<ResultsetStream<'a>>) -> Result<(), Error> {
+    async fn handle_shard_resultset<'a>(req: &mut ReqContext<T, C>, merge_stream: &mut MergeStream<ResultsetStream<'a>>, sharding_column: Option<String>, is_binary: bool) -> Result<(), Error> {
         let header = merge_stream.next().await;
         let header = if let Some(header) = Self::get_shard_one_data(header)? {
             header.1
@@ -124,7 +131,7 @@
             .codec_mut()
             .encode(PacketSend::EncodeOffset(header[4..].into(), 0), &mut buf);
 
-        Self::get_columns(req, merge_stream, cols, &mut buf).await?;
+        let col_info = Self::get_columns(req, merge_stream, cols, &mut buf).await?;
 
         // read eof
         let _ = merge_stream.next().await;
@@ -135,7 +142,7 @@
             .encode(PacketSend::EncodeOffset(make_eof_packet()[4..].into(), buf.len()), &mut buf);
 
         // get rows
-        Self::get_rows(req, merge_stream, &mut buf).await?;
+        Self::get_rows(req, merge_stream, &mut buf, sharding_column, col_info, is_binary).await?;
 
         let _ = req
             .framed
@@ -150,23 +157,60 @@
         req: &mut ReqContext<T, C>,
         stream: &mut MergeStream<ResultsetStream<'a>>,
         buf: &mut BytesMut,
+        sharding_column: Option<String>,
+        col_info: Arc<[ColumnInfo]>,
+        is_binary: bool,
     ) -> Result<(), Error> {
-        while let Some(mut chunk) = stream.next().await {
-            let _ = Self::check_single_chunk(&mut chunk)?;
-            for row in chunk.into_iter() {
-                if let Some(row) = row {
-                    // We have ensured `c` is Ok(_) by above step
-                    let row = row.unwrap();
-                    if is_eof(&row) {
-                        continue;
-                    }
+        let row_data = match is_binary {
+            false => {
+                let row_data_text = RowDataText::new(col_info, &[][..]);
+                RowDataTyp::Text(row_data_text)
+            }
+            true => {
+                let row_data_binary = RowDataBinary::new(col_info, &[][..]);
+                RowDataTyp::Binary(row_data_binary)
+            },
+        };
 
-                    let _ = req
-                        .framed
-                        .codec_mut()
-                        .encode(PacketSend::EncodeOffset(row[4..].into(), buf.len()), buf);
-                }
-            }
+        while let Some(chunk) = stream.next().await {
+            let mut chunk = chunk.into_par_iter().filter_map(|x| {
+                if let Some(x) = x {
+                    if let Ok(data) = &x {
+                        if is_eof(data) {
+                            return None
+                        }
+                    }
+                    Some(x)
+                } else {
+                    None
+                }
+            }).collect::<Result<Vec<_>, _>>().map_err(ErrorKind::from)?;
+
+            if is_binary {
+                if let Some(name) = &sharding_column {
+                    chunk.par_sort_by_cached_key(|x| {
+                        let mut row_data = row_data.clone();
+                        row_data.with_buf(&x[4..]);
+                        row_data.decode_with_name::<u64>(&name).unwrap()
+                    })
+                }
+            } else {
+                if let Some(name) = &sharding_column {
+                    chunk.par_sort_by_cached_key(|x| {
+                        let mut row_data = row_data.clone();
+                        row_data.with_buf(&x[4..]);
+                        let value = row_data.decode_with_name::<String>(&name).unwrap().unwrap();
+                        value.parse::<u64>().unwrap()
+                    })
+                }
+            }
+
+            for row in chunk.iter() {
+                let _ = req
+                    .framed
+                    .codec_mut()
+                    .encode(PacketSend::EncodeOffset(row[4..].into(), buf.len()), buf);
+            }
         }
 
         Ok(())
Expand All @@ -176,8 +220,10 @@ where
stream: &mut MergeStream<ResultsetStream<'a>>,
column_length: u64,
buf: &mut BytesMut,
) -> Result<(), Error> {
) -> Result<Arc<[ColumnInfo]>, Error> {
let mut col_buf = Vec::with_capacity(100);
let mut idx: Option<usize> = None;

for _ in 0..column_length {
let data = stream.next().await;
let data = if let Some(idx) = idx {
Expand All @@ -186,10 +232,10 @@ where
if let Some(data) = data {
data.map_err(ErrorKind::Protocol)?
} else {
return Ok(());
unreachable!()
}
} else {
return Ok(());
unreachable!()
}
} else {
// find index from chunk
Expand All @@ -198,17 +244,21 @@ where
idx = Some(data.0);
data.1
} else {
return Ok(());
unreachable!()
}
};

col_buf.extend_from_slice(&data[..]);
let _ = req
.framed
.codec_mut()
.encode(PacketSend::EncodeOffset(data[4..].into(), buf.len()), buf);
}

Ok(())
let col_info = col_buf.as_slice().decode_columns();
let arc_col_info: Arc<[ColumnInfo]> = col_info.into_boxed_slice().into();

Ok(arc_col_info)
}

fn get_shard_one_data(
@@ -358,8 +408,9 @@ where
             shard_streams.push(ResultsetStream::new(conn.1.framed.as_mut()).fuse());
         }
 
+        let sharding_column = req.stmt_cache.lock().get_sharding_column(stmt_id);
         let mut merge_stream = MergeStream::new(shard_streams, shard_length);
-        Self::handle_shard_resultset(req, &mut merge_stream).await?;
+        Self::handle_shard_resultset(req, &mut merge_stream, sharding_column, true).await?;
 
         req.stmt_cache.lock().put_all(stmt_id, conns);
 
@@ -392,3 +443,11 @@ where
         Ok(sended_conns)
     }
 }
+
+#[cfg(test)]
+mod test {
+    #[test]
+    fn test() {
+        assert_eq!(1, 1);
+    }
+}
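A note on the `get_rows` change above: the sort key is extracted differently per protocol. Binary result sets decode the sharding column directly as an integer, while text result sets carry every value as a string that has to be parsed back into one. The toy enum below sketches that split; it is illustrative only, not the repo's `RowDataTyp`:

```rust
// Illustrative model of the text/binary split in get_rows.
enum RowValue {
    Text(String), // text protocol: values arrive as strings
    Binary(u64),  // binary protocol: values arrive already typed
}

fn sort_key(value: &RowValue) -> u64 {
    match value {
        // Text rows must parse the string back into a numeric key,
        // mirroring value.parse::<u64>() in the diff above.
        RowValue::Text(s) => s.parse::<u64>().expect("sharding column should be numeric"),
        // Binary rows can use the decoded integer directly.
        RowValue::Binary(v) => *v,
    }
}

fn main() {
    let rows = vec![RowValue::Text("2".into()), RowValue::Binary(1)];
    let mut keys: Vec<u64> = rows.iter().map(sort_key).collect();
    keys.sort();
    assert_eq!(keys, vec![1, 2]);
}
```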
4 changes: 4 additions & 0 deletions pisa-proxy/runtime/mysql/src/server/server.rs
@@ -143,13 +143,17 @@ where
         }
 
         route_sharding(input_typ, raw_sql, req.route_strategy.clone(), &mut rewrite_outputs);
+        let sharding_column = rewrite_outputs[0].sharding_column.clone();
+
         let (mut stmts, shard_conns) = Executor::shard_prepare_executor(req, rewrite_outputs, attrs, is_get_conn).await?;
         for i in stmts.iter().zip(shard_conns.into_iter()) {
             req.stmt_cache.lock().put(stmt_id, i.0.stmt_id, i.1)
         }
         let mut stmt = stmts.remove(0);
         stmt.stmt_id = stmt_id;
 
+
+        req.stmt_cache.lock().put_sharding_column(stmt_id, sharding_column);
         Self::prepare_stmt(req, stmt).await?;
 
         Ok(())
18 changes: 17 additions & 1 deletion pisa-proxy/runtime/mysql/src/server/stmt_cache.rs
@@ -24,12 +24,15 @@ struct Entry {
 pub struct StmtCache {
     // key is generated id by pisa, value is returnd stmt id from client
     cache: IndexMap<u32, Vec<Entry>>,
+    sharding_column_cache: IndexMap<u32, Option<String>>
 }
 
 impl StmtCache {
     pub fn new() -> Self {
         Self {
-            cache: IndexMap::new() }
+            cache: IndexMap::new(),
+            sharding_column_cache: IndexMap::new(),
+        }
     }
 
     pub fn put(&mut self, server_stmt_id: u32, stmt_id: u32, conn: PoolConn<ClientConn>) {
@@ -94,4 +97,17 @@ impl StmtCache {
     pub fn remove(&mut self, server_stmt_id: u32) {
         self.cache.remove(&server_stmt_id);
     }
+
+    pub fn put_sharding_column(&mut self, server_stmt_id: u32, sharding_column: Option<String>) {
+        let _ = self.sharding_column_cache.insert(server_stmt_id, sharding_column);
+    }
+
+    pub fn get_sharding_column(&mut self, server_stmt_id: u32) -> Option<String> {
+        let name = self.sharding_column_cache.get(&server_stmt_id);
+        if let Some(name) = name {
+            name.clone()
+        } else {
+            None
+        }
+    }
 }
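For prepared statements the sharding column is not re-derived at execute time: the server.rs diff stores it in the statement cache under the server-side statement id at PREPARE, and the executor reads it back before merging rows at EXECUTE. A simplified sketch of that round trip, with trimmed signatures (the real `StmtCache` also tracks pooled shard connections):

```rust
use indexmap::IndexMap;

// Minimal model of the new per-statement cache entry.
struct StmtCache {
    sharding_column_cache: IndexMap<u32, Option<String>>,
}

impl StmtCache {
    // Called at PREPARE: remember which column the rewrite sharded on.
    fn put_sharding_column(&mut self, stmt_id: u32, col: Option<String>) {
        self.sharding_column_cache.insert(stmt_id, col);
    }

    // Called at EXECUTE: recover the column to sort merged rows by.
    fn get_sharding_column(&self, stmt_id: u32) -> Option<String> {
        self.sharding_column_cache.get(&stmt_id).cloned().flatten()
    }
}

fn main() {
    let mut cache = StmtCache { sharding_column_cache: IndexMap::new() };
    cache.put_sharding_column(1, Some("order_id".into())); // hypothetical column
    assert_eq!(cache.get_sharding_column(1), Some("order_id".into()));
}
```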
22 changes: 11 additions & 11 deletions pisa-proxy/runtime/mysql/src/transaction_fsm.rs
@@ -99,6 +99,7 @@ pub fn query_rewrite(
             changes: vec![],
             target_sql: raw_sql.clone(),
             data_source: strategy::sharding_rewrite::DataSource::Endpoint(x.clone()),
+            sharding_column: None,
         })
         .collect::<Vec<_>>()
     };
@@ -527,30 +528,29 @@ pub fn build_conn_attrs(sess: &ServerHandshakeCodec) -> Vec<SessionAttr> {
 mod test {
     use super::*;
 
-    #[tokio::test]
-    async fn test_trigger() {
-        let lb = Arc::new(tokio::sync::Mutex::new(RouteStrategy::None));
-        let mut tsm = TransFsm::new_trans_fsm(lb, Pool::new(1));
+    #[test]
+    fn test_trigger() {
+        let mut tsm = TransFsm::new(Pool::new(1));
         tsm.current_state = TransState::TransUseState;
-        let _ = tsm.trigger(TransEventName::QueryEvent, RouteInput::None).await;
+        let _ = tsm.trigger(TransEventName::QueryEvent);
         assert_eq!(tsm.current_state, TransState::TransUseState);
         assert_eq!(tsm.current_event, TransEventName::QueryEvent);
-        let _ = tsm.trigger(TransEventName::SetSessionEvent, RouteInput::None).await;
+        let _ = tsm.trigger(TransEventName::SetSessionEvent);
         assert_eq!(tsm.current_state, TransState::TransSetSessionState);
         assert_eq!(tsm.current_event, TransEventName::SetSessionEvent);
-        let _ = tsm.trigger(TransEventName::StartEvent, RouteInput::None).await;
+        let _ = tsm.trigger(TransEventName::StartEvent);
         assert_eq!(tsm.current_state, TransState::TransStartState);
         assert_eq!(tsm.current_event, TransEventName::StartEvent);
-        let _ = tsm.trigger(TransEventName::PrepareEvent, RouteInput::None).await;
+        let _ = tsm.trigger(TransEventName::PrepareEvent);
         assert_eq!(tsm.current_state, TransState::TransPrepareState);
         assert_eq!(tsm.current_event, TransEventName::PrepareEvent);
-        let _ = tsm.trigger(TransEventName::SendLongDataEvent, RouteInput::None).await;
+        let _ = tsm.trigger(TransEventName::SendLongDataEvent);
         assert_eq!(tsm.current_state, TransState::TransPrepareState);
         assert_eq!(tsm.current_event, TransEventName::SendLongDataEvent);
-        let _ = tsm.trigger(TransEventName::ExecuteEvent, RouteInput::None).await;
+        let _ = tsm.trigger(TransEventName::ExecuteEvent);
         assert_eq!(tsm.current_state, TransState::TransPrepareState);
         assert_eq!(tsm.current_event, TransEventName::ExecuteEvent);
-        let _ = tsm.trigger(TransEventName::CommitRollBackEvent, RouteInput::None).await;
+        let _ = tsm.trigger(TransEventName::CommitRollBackEvent);
         assert_eq!(tsm.current_state, TransState::TransDummyState);
         assert_eq!(tsm.current_event, TransEventName::CommitRollBackEvent);
     }