Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(sharding): update sharding return value #306

Merged
merged 2 commits into from
Sep 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions pisa-proxy/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ impl RouteBalance for GenericRuleMatchInner {
mod test {
use endpoint::endpoint::Endpoint;
use loadbalance::balance::*;
use indexmap::IndexMap;

use super::RulesMatchBuilder;
use crate::{config::*, readwritesplitting::ReadWriteEndpoint, RouteBalance, RouteInput};
Expand All @@ -384,13 +385,15 @@ mod test {
regex: vec![String::from("^select")],
target: TargetRole::Read,
algorithm_name: AlgorithmName::Random,
node_group_name: vec![String::from("")],
}),
ReadWriteSplittingRule::Regex(RegexRule {
name: String::from("t2"),
rule_type: String::from("regex"),
regex: vec![String::from("^insert")],
target: TargetRole::Read,
algorithm_name: AlgorithmName::Random,
node_group_name: vec![String::from("")],
}),
];

Expand All @@ -415,7 +418,8 @@ mod test {
}],
};

let mut m = RulesMatchBuilder::build(rules, default_target, rw_endpoint);
let endpoint_group = IndexMap::new();
let mut m = RulesMatchBuilder::build(rules, default_target, endpoint_group, rw_endpoint);
let (b, target) = m.get(&RouteInput::Statement("insert"));
let endpoint = b.next();
assert_eq!(target, TargetRole::Read);
Expand Down
2 changes: 2 additions & 0 deletions pisa-proxy/proxy/strategy/src/readwritesplitting/static_rw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,15 @@ mod test {
regex: vec![String::from("^select")],
target: TargetRole::Read,
algorithm_name: AlgorithmName::Random,
node_group_name: vec![String::from("")],
}),
ReadWriteSplittingRule::Regex(RegexRule {
name: String::from("t2"),
rule_type: String::from("regex"),
regex: vec![String::from("^insert")],
target: TargetRole::ReadWrite,
algorithm_name: AlgorithmName::Random,
node_group_name: vec![String::from("")],
}),
];

Expand Down
74 changes: 28 additions & 46 deletions pisa-proxy/proxy/strategy/src/sharding_rewrite/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use std::vec;

use endpoint::endpoint::Endpoint;
use indexmap::IndexMap;
use crc32fast::Hasher;
use mysql_parser::ast::{SqlStmt, Visitor, TableIdent};

use crate::{config::{Sharding, StrategyType, ShardingAlgorithmName}, rewrite::{ShardingRewriter, ShardingRewriteInput}, route::BoxError};
Expand All @@ -34,7 +35,6 @@ impl CalcShardingIdx<u64> for u64 {
ShardingAlgorithmName::Mod => {
Some(self.wrapping_rem(id))
},

_ => None
}
}
Expand All @@ -46,7 +46,6 @@ impl CalcShardingIdx<i64> for i64 {
ShardingAlgorithmName::Mod => {
Some(self.wrapping_rem(id) as u64)
},

_ => None
}
}
Expand All @@ -57,8 +56,7 @@ impl CalcShardingIdx<f64> for f64 {
match algo {
ShardingAlgorithmName::Mod => {
Some((self % id).round() as u64)
},

}
_ => None
}
}
Expand All @@ -81,7 +79,6 @@ pub struct DatabaseChange {
pub struct ShardingRewriteOutput {
pub changes: Vec<RewriteChange>,
pub target_sql: String,
pub endpoint: Endpoint,
pub data_source: DataSource,
}

Expand Down Expand Up @@ -219,8 +216,7 @@ impl ShardingRewrite {
ShardingRewriteOutput {
changes,
target_sql,
endpoint: ep.take().unwrap().clone(),
data_source: DataSource::None,
data_source: DataSource::Endpoint(ep.take().unwrap().clone()),
}
]
)
Expand Down Expand Up @@ -267,19 +263,10 @@ impl ShardingRewrite {
None => continue
}
},

Err(e ) => return Err(e)
}
}

//for v in try_where {
// let v = v?;
// match v {
// Some((idx, num, _)) => wheres.push((idx, num)),
// None => continue
// }
//}

let expect_sum = wheres[0].1 as usize * wheres.len();
let sum: usize = wheres.iter().map(|x| x.1).sum::<u64>() as usize;

Expand Down Expand Up @@ -325,8 +312,7 @@ impl ShardingRewrite {
ShardingRewriteOutput {
changes,
target_sql: target_sql.to_string(),
endpoint: ep.take().unwrap().clone(),
data_source: DataSource::None
data_source: DataSource::Endpoint(ep.take().unwrap().clone())
}
]
)
Expand Down Expand Up @@ -384,7 +370,8 @@ impl ShardingRewrite {
).flatten().collect::<Vec<_>>()
}

fn parse_where<'b>(meta: &'b WhereMeta, algo: &ShardingAlgorithmName, actual_nodes_length: u64, query_id: u8, sharding_column: &str) -> Result<Option<(u8, u64, &'b WhereMeta)>, BoxError> {
fn parse_where<'b>(meta: &'b WhereMeta, algo: &ShardingAlgorithmName, sharding_count: u64, query_id: u8, sharding_column: &str) -> Result<Option<(u8, u64, &'b WhereMeta)>, BoxError> {

match meta {
WhereMeta::BinaryExpr { left, right } => {
if left != sharding_column {
Expand All @@ -394,17 +381,17 @@ impl ShardingRewrite {
let num = match right {
WhereMetaRightDataType::Num(val) => {
let val = val.parse::<u64>()?;
val.calc(algo, actual_nodes_length)
val.calc(algo, sharding_count)
},

WhereMetaRightDataType::SignedNum(val) => {
let val = val.parse::<i64>()?;
val.calc(algo, actual_nodes_length as i64)
val.calc(algo, sharding_count as i64)
},

WhereMetaRightDataType::FloatNum(val) => {
let val = val.parse::<f64>()?;
val.calc(algo, actual_nodes_length as f64)
val.calc(algo, sharding_count as f64)
}
_ => return Ok(None)
};
Expand Down Expand Up @@ -445,12 +432,11 @@ impl ShardingRewrite {
offset = change.target.len() - change.span.len();
}

let endpoint = self.endpoints[group].clone();
let ep = self.endpoints[group].clone();
output.push(ShardingRewriteOutput {
changes: changes.into_iter().map(|x| RewriteChange::DatabaseChange(x)).collect(),
target_sql,
endpoint,
data_source: DataSource::None,
data_source: DataSource::Endpoint(ep),
})
}
output
Expand Down Expand Up @@ -488,12 +474,11 @@ impl ShardingRewrite {
offset = change.target.len() - change.span.len();
}

let endpoint = self.endpoints[0].clone();
let ep = self.endpoints[0].clone();
output.push(ShardingRewriteOutput {
changes: changes.into_iter().map(|x| RewriteChange::DatabaseChange(x)).collect(),
target_sql: target_sql.to_string(),
endpoint,
data_source: DataSource::None
data_source: DataSource::Endpoint(ep)
})
}

Expand All @@ -507,13 +492,12 @@ impl ShardingRewrite {
if actual_node.len() == 0 {
target.push_str(schema);
target.push('.');
target.push_str(&format!("{}000{}", &table.name, table_idx.to_string()));
target.push_str(&format!("{}{:05}", &table.name, table_idx));
} else {
target.push_str(actual_node);
target.push_str(".");
target.push_str(&table.name);
}
//target.push(' ');
}
target
}

Expand Down Expand Up @@ -541,8 +525,6 @@ impl ShardingRewriter<ShardingRewriteInput> for ShardingRewrite {
}
}



#[cfg(test)]
mod test {
use endpoint::endpoint::Endpoint;
Expand Down Expand Up @@ -637,23 +619,23 @@ mod test {
let raw_sql = "SELECT idx from db.tshard where idx = 3";
let parser = Parser::new();
let mut ast = parser.parse(raw_sql).unwrap();
let sr = ShardingRewrite::new(config.0.clone(), config.1.clone(), false);
let mut sr = ShardingRewrite::new(config.0.clone(), config.1.clone(), false);
sr.set_raw_sql(raw_sql.to_string());
let meta = sr.get_meta(&mut ast[0]);

let res = sr.database_strategy(meta).unwrap();
assert_eq!(res[0].target_sql, "SELECT idx from ds1.tshard where idx = 3");

let raw_sql = "SELECT idx from db.tshard where idx = 3 and idx = (SELECT idx from db.tshard where idx = 3)";
let sr = ShardingRewrite::new(config.0.clone(), config.1.clone(), false);
let mut sr = ShardingRewrite::new(config.0.clone(), config.1.clone(), false);
sr.set_raw_sql(raw_sql.to_string());
let mut ast = parser.parse(raw_sql).unwrap();
let meta = sr.get_meta(&mut ast[0]);
let res = sr.database_strategy( meta).unwrap();
assert_eq!(res[0].target_sql, "SELECT idx from ds1.tshard where idx = 3 and idx = (SELECT idx from ds1.tshard where idx = 3)");

let raw_sql = "SELECT idx from db.tshard where idx = 3 and idx = (SELECT idx from db.tshard where idx = 4)";
let sr = ShardingRewrite::new(config.0.clone(), config.1, false);
let mut sr = ShardingRewrite::new(config.0.clone(), config.1, false);
sr.set_raw_sql(raw_sql.to_string());
let mut ast = parser.parse(raw_sql).unwrap();
let meta = sr.get_meta(&mut ast[0]);
Expand Down Expand Up @@ -681,10 +663,10 @@ mod test {
assert_eq!(
res.into_iter().map(|x| x.target_sql).collect::<Vec<_>>(),
vec![
"SELECT idx from db.tshard0000 where idx > 3",
"SELECT idx from db.tshard0001 where idx > 3",
"SELECT idx from db.tshard0002 where idx > 3",
"SELECT idx from db.tshard0003 where idx > 3",
"SELECT idx from db.tshard00000 where idx > 3",
"SELECT idx from db.tshard00001 where idx > 3",
"SELECT idx from db.tshard00002 where idx > 3",
"SELECT idx from db.tshard00003 where idx > 3",
],
);

Expand All @@ -694,15 +676,15 @@ mod test {
sr.set_raw_sql(raw_sql);
let meta = sr.get_meta(&mut ast[0]);
let res = sr.table_strategy(meta).unwrap();
assert_eq!(res[0].target_sql, "SELECT idx from db.tshard0000 where idx = 4".to_string());
assert_eq!(res[0].target_sql, "SELECT idx from db.tshard00000 where idx = 4".to_string());

let raw_sql = "SELECT idx from db.tshard where idx = 3 and idx = (SELECT idx from db.tshard where idx = 3)".to_string();
let mut ast = parser.parse(&raw_sql).unwrap();
let mut sr = ShardingRewrite::new(config.0.clone(), config.1.clone(), false);
sr.set_raw_sql(raw_sql);
let meta = sr.get_meta(&mut ast[0]);
let res = sr.table_strategy(meta).unwrap();
assert_eq!(res[0].target_sql, "SELECT idx from db.tshard0003 where idx = 3 and idx = (SELECT idx from db.tshard0003 where idx = 3)".to_string());
assert_eq!(res[0].target_sql, "SELECT idx from db.tshard00003 where idx = 3 and idx = (SELECT idx from db.tshard00003 where idx = 3)".to_string());

let raw_sql = "SELECT idx from db.tshard where idx = 3 and idx = (SELECT idx from db.tshard where idx = 4)".to_string();
let mut ast = parser.parse(&raw_sql).unwrap();
Expand All @@ -713,10 +695,10 @@ mod test {
assert_eq!(
res.into_iter().map(|x| x.target_sql).collect::<Vec<_>>(),
vec![
"SELECT idx from db.tshard0000 where idx = 3 and idx = (SELECT idx from db.tshard0000 where idx = 4)",
"SELECT idx from db.tshard0001 where idx = 3 and idx = (SELECT idx from db.tshard0001 where idx = 4)",
"SELECT idx from db.tshard0002 where idx = 3 and idx = (SELECT idx from db.tshard0002 where idx = 4)",
"SELECT idx from db.tshard0003 where idx = 3 and idx = (SELECT idx from db.tshard0003 where idx = 4)",
"SELECT idx from db.tshard00000 where idx = 3 and idx = (SELECT idx from db.tshard00000 where idx = 4)",
"SELECT idx from db.tshard00001 where idx = 3 and idx = (SELECT idx from db.tshard00001 where idx = 4)",
"SELECT idx from db.tshard00002 where idx = 3 and idx = (SELECT idx from db.tshard00002 where idx = 4)",
"SELECT idx from db.tshard00003 where idx = 3 and idx = (SELECT idx from db.tshard00003 where idx = 4)",
],
);
}
Expand Down
1 change: 0 additions & 1 deletion pisa-proxy/runtime/mysql/src/transaction_fsm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ pub fn query_rewrite(
.map(|x| ShardingRewriteOutput {
changes: vec![],
target_sql: raw_sql.clone(),
endpoint: x.clone(),
data_source: strategy::sharding_rewrite::DataSource::Endpoint(x.clone()),
})
.collect::<Vec<_>>()
Expand Down