diff --git a/pisa-proxy/proxy/strategy/src/sharding_rewrite/mod.rs b/pisa-proxy/proxy/strategy/src/sharding_rewrite/mod.rs index 7ee427ad..46cf5ff4 100644 --- a/pisa-proxy/proxy/strategy/src/sharding_rewrite/mod.rs +++ b/pisa-proxy/proxy/strategy/src/sharding_rewrite/mod.rs @@ -182,7 +182,7 @@ impl ShardingRewrite { let would_changes:Vec = try_tables.iter().filter_map(|x| { let w = wheres.iter().find(|w| w.0 == x.0); if let Some(w) = w { - let target = self.change_table(x.2, &x.1.actual_datanodes[w.1 as usize]); + let target = self.change_table(x.2, &x.1.actual_datanodes[w.1 as usize], 0); Some( DatabaseChange { span: x.2.span, @@ -225,6 +225,112 @@ impl ShardingRewrite { ] ) } + + fn table_strategy(&self, meta: RewriteMetaData) -> Result, Box> { + let tables = meta.get_tables(); + let try_tables = self.find_table_rule(tables); + if try_tables.is_empty() { + return Ok(vec![]) + } + + let wheres = meta.get_wheres(); + if wheres.is_empty() { + return Ok(self.table_strategy_iproduct(try_tables.clone())); + } + + let try_where = Self::find_where(wheres, |query_id, meta| { + let rule = try_tables.iter().find(|x| x.0 == query_id); + if let Some(rule) = rule { + let strategy = match &rule.1.table_strategy { + Some(StrategyType::TableStrategyConfig(strategy)) => { + strategy + } + + _ => unreachable!(), + }; + + let algo = &strategy.table_sharding_algorithm_name; + let sharding_count = strategy.sharding_count as u64; + + return Self::parse_where(meta, algo, sharding_count, query_id, &strategy.table_sharding_column) + } + + Ok(None) + }); + + let mut wheres = Vec::with_capacity(try_where.len()); + for v in try_where { + match v { + Ok(data) => { + match data { + Some((idx, num, _)) => wheres.push((idx, num)), + None => continue + } + }, + + Err(e ) => return Err(e) + } + } + + //for v in try_where { + // let v = v?; + // match v { + // Some((idx, num, _)) => wheres.push((idx, num)), + // None => continue + // } + //} + + let expect_sum = wheres[0].1 as usize * wheres.len(); + let sum: usize = wheres.iter().map(|x| x.1).sum::() as usize; + + if expect_sum != sum { + return Ok(self.table_strategy_iproduct(try_tables)); + } + + let would_changes:Vec = try_tables.iter().filter_map(|x| { + let w = wheres.iter().find(|w| w.0 == x.0); + if let Some(w) = w { + let target = self.change_table(x.2, "", w.1); + Some( + DatabaseChange { + span: x.2.span, + shard_idx: w.1, + target, + rule: x.1.clone(), + } + ) + } else { + None + } + }).collect::>(); + + let mut target_sql = self.raw_sql.clone(); + let mut offset = 0; + let sharding_rule = &would_changes[0].rule.clone(); + + let changes = would_changes.into_iter().map(|x| { + Self::change_sql(&mut target_sql, x.span, &x.target, offset); + offset = x.target.len() - x.span.len(); + + RewriteChange::DatabaseChange(x) + }).collect::>(); + + let mut ep = self.endpoints.iter().find(|e| e.name == sharding_rule.actual_datanodes[0]); + if ep.is_none() { + return Err(Box::new(std::io::Error::new(std::io::ErrorKind::NotFound, "endpoint not found"))) + } + + Ok( + vec![ + ShardingRewriteOutput { + changes, + target_sql: target_sql.to_string(), + endpoint: ep.take().unwrap().clone(), + data_source: DataSource::None + } + ] + ) + } fn find_table_rule<'a>(&self, tables: &'a IndexMap>) -> Vec<(u8, Sharding, &'a TableIdent)> { Self::find_table(tables, |idx, meta| { @@ -318,7 +424,7 @@ impl ShardingRewrite { for t in tables.iter() { for (idx, node) in t.1.actual_datanodes.iter().enumerate() { - let target = self.change_table(t.2, node); + let target = self.change_table(t.2, node, 0); let change = DatabaseChange { span: t.2.span, @@ -350,15 +456,66 @@ impl ShardingRewrite { output } - fn change_table(&self, table: &TableIdent, actual_node: &str) -> String { + fn table_strategy_iproduct(&self, tables: Vec<(u8, Sharding, &TableIdent)>) -> Vec { + let mut output = vec![]; + let mut group_changes = IndexMap::>::new(); + + for t in tables.iter() { + match t.1.table_strategy.as_ref().unwrap() { + crate::config::StrategyType::TableStrategyConfig(config) => { + for idx in 0..config.sharding_count as u64 { + let target = self.change_table(t.2, "", idx); + + let change = DatabaseChange { + span: t.2.span, + target, + shard_idx: idx as u64, + rule: t.1.clone() + }; + + group_changes.entry(idx as usize).or_insert(vec![]).push(change); + } + } + _ => unreachable!() + } + } + + for(_, changes) in group_changes.into_iter() { + let mut offset = 0; + let mut target_sql = self.raw_sql.clone(); + for change in changes.iter() { + Self::change_sql(&mut target_sql, change.span, &change.target, offset); + offset = change.target.len() - change.span.len(); + } + + let endpoint = self.endpoints[0].clone(); + output.push(ShardingRewriteOutput { + changes: changes.into_iter().map(|x| RewriteChange::DatabaseChange(x)).collect(), + target_sql: target_sql.to_string(), + endpoint, + data_source: DataSource::None + }) + } + + output + } + + fn change_table(&self, table: &TableIdent, actual_node: &str, table_idx: u64) -> String { let schema = table.schema.as_ref().unwrap(); let mut target = String::with_capacity(schema.len()); - target.push_str(actual_node); - target.push('.'); - target.push_str(&table.name); + + if actual_node.len() == 0 { + target.push_str(schema); + target.push('.'); + target.push_str(&format!("{}000{}", &table.name, table_idx.to_string())); + } else { + target.push_str(actual_node); + target.push_str("."); + target.push_str(&table.name); + } //target.push(' '); target - } + } fn change_sql(target_sql: &mut String, span: mysql_parser::Span, target: &str, offset: usize) { for _ in 0 .. span.len() { @@ -438,6 +595,42 @@ mod test { ) } + fn get_table_sharding_config() -> (Vec, Vec) { + ( + vec![ + Sharding { + table_name: "tshard".to_string(), + actual_datanodes: vec!["ds001".to_string()], + binding_tables: None, + broadcast_tables: None, + table_strategy: Some( + StrategyType::TableStrategyConfig( + crate::config::TableStrategyConfig { + datanode_name: "ds001".to_string(), + table_sharding_algorithm_name: ShardingAlgorithmName::Mod, + table_sharding_column: "idx".to_string(), + sharding_count: 4, + }, + ), + ), + database_strategy: None, + database_table_strategy: None, + } + ], + + vec![ + Endpoint { + weight: 1, + name: String::from("ds001"), + db: String::from("db"), + user: String::from("user"), + password: String::from("password"), + addr: String::from("127.0.0.1:3306"), + }, + ] + ) + } + #[test] fn test_database_sharding_strategy() { let config = get_database_sharding_config(); @@ -473,4 +666,58 @@ mod test { ], ) } -} \ No newline at end of file + + #[test] + fn test_table_sharding_strategy() { + let config = get_table_sharding_config(); + let raw_sql = "SELECT idx from db.tshard where idx > 3".to_string(); + let parser = Parser::new(); + let mut ast = parser.parse(&raw_sql).unwrap(); + let mut sr = ShardingRewrite::new(config.0.clone(), config.1.clone(), false); + sr.set_raw_sql(raw_sql); + let meta = sr.get_meta(&mut ast[0]); + + let res = sr.table_strategy(meta).unwrap(); + assert_eq!( + res.into_iter().map(|x| x.target_sql).collect::>(), + vec![ + "SELECT idx from db.tshard0000 where idx > 3", + "SELECT idx from db.tshard0001 where idx > 3", + "SELECT idx from db.tshard0002 where idx > 3", + "SELECT idx from db.tshard0003 where idx > 3", + ], + ); + + let raw_sql = "SELECT idx from db.tshard where idx = 4".to_string(); + let mut ast = parser.parse(&raw_sql).unwrap(); + let mut sr = ShardingRewrite::new(config.0.clone(), config.1.clone(), false); + sr.set_raw_sql(raw_sql); + let meta = sr.get_meta(&mut ast[0]); + let res = sr.table_strategy(meta).unwrap(); + assert_eq!(res[0].target_sql, "SELECT idx from db.tshard0000 where idx = 4".to_string()); + + let raw_sql = "SELECT idx from db.tshard where idx = 3 and idx = (SELECT idx from db.tshard where idx = 3)".to_string(); + let mut ast = parser.parse(&raw_sql).unwrap(); + let mut sr = ShardingRewrite::new(config.0.clone(), config.1.clone(), false); + sr.set_raw_sql(raw_sql); + let meta = sr.get_meta(&mut ast[0]); + let res = sr.table_strategy(meta).unwrap(); + assert_eq!(res[0].target_sql, "SELECT idx from db.tshard0003 where idx = 3 and idx = (SELECT idx from db.tshard0003 where idx = 3)".to_string()); + + let raw_sql = "SELECT idx from db.tshard where idx = 3 and idx = (SELECT idx from db.tshard where idx = 4)".to_string(); + let mut ast = parser.parse(&raw_sql).unwrap(); + let mut sr = ShardingRewrite::new(config.0.clone(), config.1.clone(), false); + sr.set_raw_sql(raw_sql); + let meta = sr.get_meta(&mut ast[0]); + let res = sr.table_strategy(meta).unwrap(); + assert_eq!( + res.into_iter().map(|x| x.target_sql).collect::>(), + vec![ + "SELECT idx from db.tshard0000 where idx = 3 and idx = (SELECT idx from db.tshard0000 where idx = 4)", + "SELECT idx from db.tshard0001 where idx = 3 and idx = (SELECT idx from db.tshard0001 where idx = 4)", + "SELECT idx from db.tshard0002 where idx = 3 and idx = (SELECT idx from db.tshard0002 where idx = 4)", + "SELECT idx from db.tshard0003 where idx = 3 and idx = (SELECT idx from db.tshard0003 where idx = 4)", + ], + ); + } +}