Skip to content

Commit

Permalink
feat(sharding): add table strategy (#305)
Browse files Browse the repository at this point in the history
Signed-off-by: wangbo <[email protected]>

Signed-off-by: wangbo <[email protected]>
Co-authored-by: wangbo <[email protected]>
  • Loading branch information
wbtlb and wangbo authored Sep 19, 2022
1 parent bc9c113 commit 37b84ca
Showing 1 changed file with 255 additions and 8 deletions.
263 changes: 255 additions & 8 deletions pisa-proxy/proxy/strategy/src/sharding_rewrite/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ impl ShardingRewrite {
let would_changes:Vec<DatabaseChange> = try_tables.iter().filter_map(|x| {
let w = wheres.iter().find(|w| w.0 == x.0);
if let Some(w) = w {
let target = self.change_table(x.2, &x.1.actual_datanodes[w.1 as usize]);
let target = self.change_table(x.2, &x.1.actual_datanodes[w.1 as usize], 0);
Some(
DatabaseChange {
span: x.2.span,
Expand Down Expand Up @@ -225,6 +225,112 @@ impl ShardingRewrite {
]
)
}

/// Rewrites `self.raw_sql` for rules that shard by table.
///
/// Steps:
/// 1. Look up the sharding rule of every table referenced by the statement;
///    with no matching rule an empty vector is returned.
/// 2. Without any `WHERE` metadata the statement is fanned out to every
///    table shard via [`Self::table_strategy_iproduct`].
/// 3. Otherwise each `WHERE` is parsed into a shard index. Only when every
///    parsed clause resolves to the *same* shard is a single rewritten
///    statement produced; in all other cases (no usable clause, differing
///    shards, no table matched by a clause) the statement is fanned out.
///
/// # Errors
///
/// Propagates the first error produced while parsing a `WHERE` clause, and
/// returns an `io::ErrorKind::NotFound` error when the rule's first actual
/// datanode does not name a configured endpoint.
///
/// # Panics
///
/// Panics (`unreachable!`) if a matched rule carries no
/// `StrategyType::TableStrategyConfig`.
fn table_strategy(&self, meta: RewriteMetaData) -> Result<Vec<ShardingRewriteOutput>, Box<dyn std::error::Error>> {
    let tables = meta.get_tables();
    let try_tables = self.find_table_rule(tables);
    if try_tables.is_empty() {
        return Ok(vec![]);
    }

    let wheres = meta.get_wheres();
    if wheres.is_empty() {
        return Ok(self.table_strategy_iproduct(try_tables));
    }

    let try_where = Self::find_where(wheres, |query_id, meta| {
        let rule = try_tables.iter().find(|x| x.0 == query_id);
        if let Some(rule) = rule {
            let strategy = match &rule.1.table_strategy {
                Some(StrategyType::TableStrategyConfig(strategy)) => strategy,
                // Rules reaching this point are expected to be table-strategy rules.
                _ => unreachable!(),
            };

            let algo = &strategy.table_sharding_algorithm_name;
            let sharding_count = strategy.sharding_count as u64;

            return Self::parse_where(meta, algo, sharding_count, query_id, &strategy.table_sharding_column);
        }

        Ok(None)
    });

    // Keep only the (query id, shard index) pairs of clauses that produced a
    // shard. The error is returned explicitly rather than with `?` so no
    // `From` conversion on the closure's error type is required.
    let mut wheres = Vec::with_capacity(try_where.len());
    for v in try_where {
        match v {
            Ok(Some((idx, num, _))) => wheres.push((idx, num)),
            Ok(None) => {}
            Err(e) => return Err(e),
        }
    }

    // A single target shard exists only if every clause agrees on it.
    // Comparing element-wise (instead of comparing sums) avoids treating
    // e.g. shards [1, 0, 2] as "all equal to 1", and an empty list simply
    // falls through to the fan-out path instead of panicking.
    let same_shard = match wheres.first() {
        Some(first) => wheres.iter().all(|w| w.1 == first.1),
        None => false,
    };
    if !same_shard {
        return Ok(self.table_strategy_iproduct(try_tables));
    }

    let would_changes: Vec<DatabaseChange> = try_tables
        .iter()
        .filter_map(|x| {
            wheres.iter().find(|w| w.0 == x.0).map(|w| DatabaseChange {
                span: x.2.span,
                shard_idx: w.1,
                target: self.change_table(x.2, "", w.1),
                rule: x.1.clone(),
            })
        })
        .collect();

    // No table could be tied to a parsed clause: the shard is undecidable.
    if would_changes.is_empty() {
        return Ok(self.table_strategy_iproduct(try_tables));
    }

    // The endpoint is the one named by the first actual datanode of the
    // first changed table's rule.
    let endpoint = would_changes
        .first()
        .and_then(|change| change.rule.actual_datanodes.first())
        .and_then(|node| self.endpoints.iter().find(|e| e.name == *node));
    let endpoint = match endpoint {
        Some(ep) => ep.clone(),
        None => {
            return Err(Box::new(std::io::Error::new(std::io::ErrorKind::NotFound, "endpoint not found")))
        }
    };

    let mut target_sql = self.raw_sql.clone();
    // Running length difference introduced by the replacements applied so
    // far; it has to accumulate so that the third and later spans are still
    // located correctly in the already-edited SQL.
    let mut offset = 0;

    let changes = would_changes
        .into_iter()
        .map(|x| {
            Self::change_sql(&mut target_sql, x.span, &x.target, offset);
            offset += x.target.len() - x.span.len();

            RewriteChange::DatabaseChange(x)
        })
        .collect::<Vec<_>>();

    Ok(vec![ShardingRewriteOutput {
        changes,
        target_sql: target_sql.to_string(),
        endpoint,
        data_source: DataSource::None,
    }])
}

fn find_table_rule<'a>(&self, tables: &'a IndexMap<u8, Vec<TableIdent>>) -> Vec<(u8, Sharding, &'a TableIdent)> {
Self::find_table(tables, |idx, meta| {
Expand Down Expand Up @@ -318,7 +424,7 @@ impl ShardingRewrite {

for t in tables.iter() {
for (idx, node) in t.1.actual_datanodes.iter().enumerate() {
let target = self.change_table(t.2, node);
let target = self.change_table(t.2, node, 0);

let change = DatabaseChange {
span: t.2.span,
Expand Down Expand Up @@ -350,15 +456,66 @@ impl ShardingRewrite {
output
}

fn change_table(&self, table: &TableIdent, actual_node: &str) -> String {
/// Fans the raw SQL out to every table shard.
///
/// For each shard index `0..sharding_count` of every given table a
/// `DatabaseChange` is created; changes are grouped by shard index and one
/// rewritten statement is produced per index, in which every sharded table
/// is replaced by its shard with that index.
///
/// Every output uses the first configured endpoint (`self.endpoints[0]`).
///
/// # Panics
///
/// Panics if `self.endpoints` is empty while there is at least one shard to
/// emit, or if a table's rule has no `TableStrategyConfig`.
fn table_strategy_iproduct(&self, tables: Vec<(u8, Sharding, &TableIdent)>) -> Vec<ShardingRewriteOutput> {
    let mut group_changes = IndexMap::<usize, Vec<DatabaseChange>>::new();

    for t in tables.iter() {
        match t.1.table_strategy.as_ref() {
            Some(crate::config::StrategyType::TableStrategyConfig(config)) => {
                for idx in 0..config.sharding_count as u64 {
                    let change = DatabaseChange {
                        span: t.2.span,
                        target: self.change_table(t.2, "", idx),
                        shard_idx: idx,
                        rule: t.1.clone(),
                    };

                    group_changes.entry(idx as usize).or_insert_with(Vec::new).push(change);
                }
            }
            // Callers only pass rules that carry a table strategy.
            _ => unreachable!(),
        }
    }

    let mut output = Vec::with_capacity(group_changes.len());
    for (_, changes) in group_changes.into_iter() {
        let mut target_sql = self.raw_sql.clone();
        // Accumulated growth of the SQL caused by earlier replacements; it
        // must be summed so spans after the second one stay aligned.
        let mut offset = 0;
        for change in changes.iter() {
            Self::change_sql(&mut target_sql, change.span, &change.target, offset);
            offset += change.target.len() - change.span.len();
        }

        let endpoint = self.endpoints[0].clone();
        output.push(ShardingRewriteOutput {
            changes: changes.into_iter().map(RewriteChange::DatabaseChange).collect(),
            target_sql: target_sql.to_string(),
            endpoint,
            data_source: DataSource::None,
        });
    }

    output
}

/// Builds the replacement text for a table reference.
///
/// * `actual_node` non-empty: the table is re-qualified with the
///   datanode, giving `<actual_node>.<table name>`; `table_idx` is ignored.
/// * `actual_node` empty (table sharding): the table name gets the literal
///   suffix `000` followed by `table_idx`, e.g. index `3` yields
///   `<schema>.<table name>0003`. The `000` prefix is fixed, so an index of
///   `10` yields `...00010`.
///
/// A table written without a schema no longer panics: in the table-sharding
/// case the result is just the suffixed table name.
/// NOTE(review): this assumes the table's span then covers only the bare
/// name — confirm against the parser.
fn change_table(&self, table: &TableIdent, actual_node: &str, table_idx: u64) -> String {
    if actual_node.is_empty() {
        match table.schema.as_ref() {
            Some(schema) => format!("{}.{}000{}", schema, table.name, table_idx),
            None => format!("{}000{}", table.name, table_idx),
        }
    } else {
        format!("{}.{}", actual_node, table.name)
    }
}
}

fn change_sql(target_sql: &mut String, span: mysql_parser::Span, target: &str, offset: usize) {
for _ in 0 .. span.len() {
Expand Down Expand Up @@ -438,6 +595,42 @@ mod test {
)
}

/// Test fixture: one rule sharding table `tshard` into 4 tables by
/// `idx` with the `Mod` algorithm, plus the single endpoint `ds001`.
fn get_table_sharding_config() -> (Vec<Sharding>, Vec<Endpoint>) {
    let strategy = crate::config::TableStrategyConfig {
        datanode_name: String::from("ds001"),
        table_sharding_algorithm_name: ShardingAlgorithmName::Mod,
        table_sharding_column: String::from("idx"),
        sharding_count: 4,
    };

    let rule = Sharding {
        table_name: String::from("tshard"),
        actual_datanodes: vec![String::from("ds001")],
        binding_tables: None,
        broadcast_tables: None,
        table_strategy: Some(StrategyType::TableStrategyConfig(strategy)),
        database_strategy: None,
        database_table_strategy: None,
    };

    let endpoint = Endpoint {
        weight: 1,
        name: "ds001".to_string(),
        db: "db".to_string(),
        user: "user".to_string(),
        password: "password".to_string(),
        addr: "127.0.0.1:3306".to_string(),
    };

    (vec![rule], vec![endpoint])
}

#[test]
fn test_database_sharding_strategy() {
let config = get_database_sharding_config();
Expand Down Expand Up @@ -473,4 +666,58 @@ mod test {
],
)
}
}

#[test]
fn test_table_sharding_strategy() {
    let config = get_table_sharding_config();
    let parser = Parser::new();

    // Runs the table strategy on `sql` with a fresh rewriter and returns the
    // rewritten statements in output order.
    let rewrite = |sql: &str| -> Vec<String> {
        let raw_sql = sql.to_string();
        let mut ast = parser.parse(&raw_sql).unwrap();
        let mut sr = ShardingRewrite::new(config.0.clone(), config.1.clone(), false);
        sr.set_raw_sql(raw_sql);
        let meta = sr.get_meta(&mut ast[0]);
        let outputs = sr.table_strategy(meta).unwrap();
        let sqls: Vec<String> = outputs.into_iter().map(|x| x.target_sql).collect();
        sqls
    };

    // A range predicate cannot pin a shard: expect one statement per shard.
    let fanned_out = rewrite("SELECT idx from db.tshard where idx > 3");
    assert_eq!(
        fanned_out,
        vec![
            "SELECT idx from db.tshard0000 where idx > 3",
            "SELECT idx from db.tshard0001 where idx > 3",
            "SELECT idx from db.tshard0002 where idx > 3",
            "SELECT idx from db.tshard0003 where idx > 3",
        ],
    );

    // Equality on the sharding column: 4 mod 4 selects shard 0.
    let single = rewrite("SELECT idx from db.tshard where idx = 4");
    assert_eq!(single[0], "SELECT idx from db.tshard0000 where idx = 4".to_string());

    // Outer query and subquery agree on shard 3: both references are rewritten.
    let nested_same = rewrite("SELECT idx from db.tshard where idx = 3 and idx = (SELECT idx from db.tshard where idx = 3)");
    assert_eq!(nested_same[0], "SELECT idx from db.tshard0003 where idx = 3 and idx = (SELECT idx from db.tshard0003 where idx = 3)".to_string());

    // Outer query and subquery disagree (3 vs 0): fall back to the fan-out.
    let nested_diff = rewrite("SELECT idx from db.tshard where idx = 3 and idx = (SELECT idx from db.tshard where idx = 4)");
    assert_eq!(
        nested_diff,
        vec![
            "SELECT idx from db.tshard0000 where idx = 3 and idx = (SELECT idx from db.tshard0000 where idx = 4)",
            "SELECT idx from db.tshard0001 where idx = 3 and idx = (SELECT idx from db.tshard0001 where idx = 4)",
            "SELECT idx from db.tshard0002 where idx = 3 and idx = (SELECT idx from db.tshard0002 where idx = 4)",
            "SELECT idx from db.tshard0003 where idx = 3 and idx = (SELECT idx from db.tshard0003 where idx = 4)",
        ],
    );
}
}

0 comments on commit 37b84ca

Please sign in to comment.