diff --git a/pisa-proxy/proxy/strategy/src/sharding_rewrite/mod.rs b/pisa-proxy/proxy/strategy/src/sharding_rewrite/mod.rs index ff8638bc..ff9546b5 100644 --- a/pisa-proxy/proxy/strategy/src/sharding_rewrite/mod.rs +++ b/pisa-proxy/proxy/strategy/src/sharding_rewrite/mod.rs @@ -13,6 +13,7 @@ // limitations under the License. mod meta; +mod genric_meta; use std::vec; @@ -20,13 +21,12 @@ use endpoint::endpoint::Endpoint; use indexmap::IndexMap; use mysql_parser::ast::{SqlStmt, Visitor, TableIdent}; -use self::meta::{ +use self::{meta::{ FieldMeta, InsertValsMeta, RewriteMetaData, WhereMeta, WhereMetaRightDataType, -}; +}, genric_meta::ShardingMeta}; use crate::{ config::{Sharding, ShardingAlgorithmName, StrategyType}, rewrite::{ShardingRewriteInput, ShardingRewriter}, - route::BoxError, }; pub trait CalcShardingIdx { @@ -112,13 +112,19 @@ pub enum ShardingRewriteError { ShardingColumnNotFound(String), #[error("parse str to u64 error {0:?}")] - ParseError(#[from] std::num::ParseIntError), + ParseIntError(#[from] std::num::ParseIntError), + + #[error("parse str to u64 error {0:?}")] + ParseFloatError(#[from] std::num::ParseFloatError), #[error("calc mod error")] CalcModError, #[error("enpoint not found when using actual_datanodes")] - EndpointNotFound + EndpointNotFound, + + #[error("fields is empty")] + FieldsIsEmpty, } struct ChangeInsertMeta { @@ -126,6 +132,11 @@ struct ChangeInsertMeta { row_value_span: mysql_parser::Span, } +enum StrategyTyp { + Database, + Table, +} + impl ShardingRewrite { pub fn new(rules: Vec, endpoints: Vec, has_rw: bool) -> Self { ShardingRewrite { rules, raw_sql: "".to_string(), endpoints, has_rw } @@ -142,55 +153,21 @@ impl ShardingRewrite { fn database_strategy( &self, meta: RewriteMetaData, - ) -> Result, BoxError> { - let tables = meta.get_tables(); - let try_tables = self.find_table_rule(tables); - - if try_tables.is_empty() { - return Ok(vec![]); - } - + try_tables: Vec<(u8, Sharding, &TableIdent)>, + ) -> Result, ShardingRewriteError> { let wheres = meta.get_wheres(); if wheres.is_empty() { return Ok(self.database_strategy_iproduct(try_tables)); } - let try_where = Self::find_where(wheres, |query_id, meta| { - let rule = try_tables.iter().find(|x| x.0 == query_id); - if let Some(rule) = rule { - let strategy = match &rule.1.database_strategy { - Some(StrategyType::DatabaseStrategyConfig(strategy)) => strategy, - - _ => unreachable!(), - }; - - let algo = &strategy.database_sharding_algorithm_name; - let actual_nodes_length = rule.1.actual_datanodes.len() as u64; - - return Self::parse_where( - meta, - algo, - actual_nodes_length, - query_id, - &strategy.database_sharding_column, - ); - } - - Ok(None) - }); - - let mut wheres = Vec::with_capacity(try_where.len()); - for v in try_where { - match v { - Ok(data) => match data { - Some((idx, num, _)) => wheres.push((idx, num)), - None => continue, - }, - - Err(e) => return Err(e), + let wheres = Self::find_try_where(StrategyTyp::Database, &try_tables, wheres)?.into_iter().filter_map(|x| + match x { + Some((idx, num, _)) => Some((idx, num)), + None => None, } - } + + ).collect::>(); let expect_sum = wheres[0].1 as usize * wheres.len(); let sum: usize = wheres.iter().map(|x| x.1).sum::() as usize; @@ -258,81 +235,30 @@ impl ShardingRewrite { fn table_strategy( &self, meta: RewriteMetaData, - ) -> Result, BoxError> { - let tables = meta.get_tables(); - let try_tables = self.find_table_rule(tables); - if try_tables.is_empty() { - return Ok(vec![]); - } - + try_tables: Vec<(u8, Sharding, &TableIdent)>, + ) -> Result, ShardingRewriteError> { let wheres = meta.get_wheres(); let inserts = meta.get_inserts(); let fields = meta.get_fields(); if !inserts.is_empty() { - let outputs = try_tables.into_iter().map(|(query_id, rule, table)| { - let strategy = if let Some(strategy) = &rule.table_strategy { - if let StrategyType::TableStrategyConfig(config) = strategy { - config - } else { - unreachable!() - } - } else { - unreachable!() - }; - - self.change_insert_sql( - &rule, - &table, - &inserts.get(&query_id).unwrap(), - &fields.get(&query_id).unwrap(), - &strategy.table_sharding_column, - &strategy.table_sharding_algorithm_name, - *&strategy.sharding_count as u64, - ) - }).collect::, _>>()?.into_iter().flatten().collect::>(); + if fields.is_empty() { + return Err(ShardingRewriteError::FieldsIsEmpty) + } - return Ok(outputs); + return self.change_insert_sql(try_tables, fields, inserts); } if wheres.is_empty() { return Ok(self.table_strategy_iproduct(try_tables.clone())); } - let try_where = Self::find_where(wheres, |query_id, meta| { - let rule = try_tables.iter().find(|x| x.0 == query_id); - if let Some(rule) = rule { - let strategy = match &rule.1.table_strategy { - Some(StrategyType::TableStrategyConfig(strategy)) => strategy, - - _ => unreachable!(), - }; - - let algo = &strategy.table_sharding_algorithm_name; - let sharding_count = strategy.sharding_count as u64; - - return Self::parse_where( - meta, - algo, - sharding_count, - query_id, - &strategy.table_sharding_column, - ); - } - - Ok(None) - }); - - let mut wheres = Vec::with_capacity(try_where.len()); - for v in try_where { - match v { - Ok(data) => match data { - Some((idx, num, _)) => wheres.push((idx, num)), - None => continue, - }, - Err(e ) => return Err(e) + let wheres = Self::find_try_where(StrategyTyp::Table, &try_tables, wheres)?.into_iter().filter_map(|x| { + match x { + Some((idx, num, _)) => Some((idx, num)), + None => None, } - } + }).collect::>(); let expect_sum = wheres[0].1 as usize * wheres.len(); let sum: usize = wheres.iter().map(|x| x.1).sum::() as usize; @@ -443,12 +369,39 @@ impl ShardingRewrite { .collect::>() } + fn find_try_where<'a>(strategy_typ: StrategyTyp, try_tables: &[(u8, Sharding, &TableIdent)], wheres: &'a IndexMap>) -> Result>, ShardingRewriteError> { + Self::find_where(wheres, |query_id, meta| { + let rule = try_tables.iter().find(|x| x.0 == query_id); + if let Some(rule) = rule { + let (sharding_column, algo, sharding_count) = match strategy_typ { + StrategyTyp::Database => { + (rule.1.get_sharding_column().0.unwrap(), rule.1.get_algo().0.unwrap(), rule.1.get_sharding_count().0.unwrap()) + } + + StrategyTyp::Table => { + (rule.1.get_sharding_column().1.unwrap(), rule.1.get_algo().1.unwrap(), rule.1.get_sharding_count().1.unwrap()) + } + }; + + return Self::parse_where( + meta, + algo, + sharding_count, + query_id, + sharding_column, + ); + } + + Ok(None) + }) + } + fn find_where( wheres: &IndexMap>, calc_fn: F, - ) -> Vec, BoxError>> + ) -> Result>, ShardingRewriteError> where - F: Fn(u8, &WhereMeta) -> Result, BoxError>, + F: Fn(u8, &WhereMeta) -> Result, ShardingRewriteError>, { wheres .iter() @@ -466,7 +419,7 @@ impl ShardingRewrite { ) }) .flatten() - .collect::>() + .collect::, _>>() } fn parse_where<'b>( @@ -475,7 +428,7 @@ impl ShardingRewrite { sharding_count: u64, query_id: u8, sharding_column: &str, - ) -> Result, BoxError> { + ) -> Result, ShardingRewriteError> { match meta { WhereMeta::BinaryExpr { left, right } => { if left != sharding_column { @@ -509,7 +462,33 @@ impl ShardingRewrite { } } - fn change_insert_sql( + fn change_insert_sql(&self, try_tables: Vec<(u8, Sharding, &TableIdent)>, fields: &IndexMap>,inserts: &IndexMap>) -> Result, ShardingRewriteError> { + let outputs = try_tables.into_iter().map(|(query_id, rule, table)| { + let strategy = if let Some(strategy) = &rule.table_strategy { + if let StrategyType::TableStrategyConfig(config) = strategy { + config + } else { + unreachable!() + } + } else { + unreachable!() + }; + + self.change_insert_sql_inner( + &rule, + &table, + &inserts.get(&query_id).unwrap(), + &fields.get(&query_id).unwrap(), + &strategy.table_sharding_column, + &strategy.table_sharding_algorithm_name, + *&strategy.sharding_count as u64, + ) + }).collect::, _>>()?.into_iter().flatten().collect::>(); + + Ok(outputs) + } + + fn change_insert_sql_inner( &self, rule: &Sharding, table: &TableIdent, @@ -609,8 +588,13 @@ impl ShardingRewrite { ) -> Vec { let mut output = vec![]; let mut group_changes = IndexMap::>::new(); + let mut sharding_column = None; for t in tables.iter() { + if let StrategyType::DatabaseStrategyConfig(config) = &t.1.database_strategy.as_ref().unwrap() { + sharding_column = Some(config.database_sharding_column.clone()) + } + for (idx, node) in t.1.actual_datanodes.iter().enumerate() { let target = self.change_table(t.2, node, 0); @@ -638,7 +622,7 @@ impl ShardingRewrite { changes: changes.into_iter().map(|x| RewriteChange::DatabaseChange(x)).collect(), target_sql, data_source: DataSource::Endpoint(ep), - sharding_column: None, + sharding_column: sharding_column.clone(), }) } output @@ -654,7 +638,7 @@ impl ShardingRewrite { for t in tables.iter() { match t.1.table_strategy.as_ref().unwrap() { - crate::config::StrategyType::TableStrategyConfig(config) => { + StrategyType::TableStrategyConfig(config) => { sharding_column = Some(config.table_sharding_column.clone()); for idx in 0..config.sharding_count as u64 { let target = self.change_table(t.2, "", idx); @@ -725,11 +709,28 @@ impl ShardingRewrite { } impl ShardingRewriter for ShardingRewrite { - type Output = Result, BoxError>; + type Output = Result, ShardingRewriteError>; fn rewrite(&mut self, mut input: ShardingRewriteInput) -> Self::Output { self.set_raw_sql(input.raw_sql); let meta = self.get_meta(&mut input.ast); - self.table_strategy(meta) + let tables = meta.get_tables().clone(); + let try_tables = self.find_table_rule(&tables); + if try_tables.is_empty() { + return Ok(vec![]); + } + + // Strategy according to first element of `try_tables`. + let rule = &try_tables[0].1; + + if rule.database_strategy.is_some() { + return self.database_strategy(meta, try_tables) + } + + if rule.table_strategy.is_some() { + return self.table_strategy(meta, try_tables) + } + + return Ok(vec![]) } } @@ -739,7 +740,7 @@ mod test { use mysql_parser::parser::Parser; use super::ShardingRewrite; - use crate::config::{DatabaseStrategyConfig, Sharding, ShardingAlgorithmName, StrategyType}; + use crate::{config::{DatabaseStrategyConfig, Sharding, ShardingAlgorithmName, StrategyType}, rewrite::{ShardingRewriteInput, ShardingRewriter}}; fn get_database_sharding_config() -> (Vec, Vec) { ( @@ -812,28 +813,31 @@ mod test { let config = get_database_sharding_config(); let raw_sql = "SELECT idx from db.tshard where idx = 3"; let parser = Parser::new(); - let mut ast = parser.parse(raw_sql).unwrap(); + let ast = parser.parse(raw_sql).unwrap(); let mut sr = ShardingRewrite::new(config.0.clone(), config.1.clone(), false); - sr.set_raw_sql(raw_sql.to_string()); - let meta = sr.get_meta(&mut ast[0]); - - let res = sr.database_strategy(meta).unwrap(); + let input = ShardingRewriteInput { + raw_sql: raw_sql.to_string(), + ast: ast[0].clone(), + }; + let res = sr.rewrite(input).unwrap(); assert_eq!(res[0].target_sql, "SELECT idx from ds1.tshard where idx = 3"); let raw_sql = "SELECT idx from db.tshard where idx = 3 and idx = (SELECT idx from db.tshard where idx = 3)"; - let mut sr = ShardingRewrite::new(config.0.clone(), config.1.clone(), false); - sr.set_raw_sql(raw_sql.to_string()); - let mut ast = parser.parse(raw_sql).unwrap(); - let meta = sr.get_meta(&mut ast[0]); - let res = sr.database_strategy(meta).unwrap(); + let ast = parser.parse(raw_sql).unwrap(); + let input = ShardingRewriteInput { + raw_sql: raw_sql.to_string(), + ast: ast[0].clone(), + }; + let res = sr.rewrite(input).unwrap(); assert_eq!(res[0].target_sql, "SELECT idx from ds1.tshard where idx = 3 and idx = (SELECT idx from ds1.tshard where idx = 3)"); let raw_sql = "SELECT idx from db.tshard where idx = 3 and idx = (SELECT idx from db.tshard where idx = 4)"; - let mut sr = ShardingRewrite::new(config.0.clone(), config.1, false); - sr.set_raw_sql(raw_sql.to_string()); - let mut ast = parser.parse(raw_sql).unwrap(); - let meta = sr.get_meta(&mut ast[0]); - let res = sr.database_strategy(meta).unwrap(); + let ast = parser.parse(raw_sql).unwrap(); + let input = ShardingRewriteInput { + raw_sql: raw_sql.to_string(), + ast: ast[0].clone(), + }; + let res = sr.rewrite(input).unwrap(); assert_eq!( res.into_iter().map(|x| x.target_sql).collect::>(), vec![ @@ -848,51 +852,55 @@ mod test { let config = get_table_sharding_config(); let raw_sql = "SELECT idx from db.tshard where idx > 3".to_string(); let parser = Parser::new(); - let mut ast = parser.parse(&raw_sql).unwrap(); + let ast = parser.parse(&raw_sql).unwrap(); + let input = ShardingRewriteInput { + raw_sql: raw_sql.to_string(), + ast: ast[0].clone(), + }; let mut sr = ShardingRewrite::new(config.0.clone(), config.1.clone(), false); - sr.set_raw_sql(raw_sql); - let meta = sr.get_meta(&mut ast[0]); - - let res = sr.table_strategy(meta).unwrap(); + let res = sr.rewrite(input).unwrap(); assert_eq!( res.into_iter().map(|x| x.target_sql).collect::>(), vec![ - "SELECT idx from db.tshard00000 where idx > 3", - "SELECT idx from db.tshard00001 where idx > 3", - "SELECT idx from db.tshard00002 where idx > 3", - "SELECT idx from db.tshard00003 where idx > 3", + "SELECT idx from db.tshard_00000 where idx > 3", + "SELECT idx from db.tshard_00001 where idx > 3", + "SELECT idx from db.tshard_00002 where idx > 3", + "SELECT idx from db.tshard_00003 where idx > 3", ], ); let raw_sql = "SELECT idx from db.tshard where idx = 4".to_string(); - let mut ast = parser.parse(&raw_sql).unwrap(); - let mut sr = ShardingRewrite::new(config.0.clone(), config.1.clone(), false); - sr.set_raw_sql(raw_sql); - let meta = sr.get_meta(&mut ast[0]); - let res = sr.table_strategy(meta).unwrap(); - assert_eq!(res[0].target_sql, "SELECT idx from db.tshard00000 where idx = 4".to_string()); + let ast = parser.parse(&raw_sql).unwrap(); + let input = ShardingRewriteInput { + raw_sql: raw_sql.to_string(), + ast: ast[0].clone(), + }; + let res = sr.rewrite(input).unwrap(); + assert_eq!(res[0].target_sql, "SELECT idx from db.tshard_00000 where idx = 4".to_string()); let raw_sql = "SELECT idx from db.tshard where idx = 3 and idx = (SELECT idx from db.tshard where idx = 3)".to_string(); - let mut ast = parser.parse(&raw_sql).unwrap(); - let mut sr = ShardingRewrite::new(config.0.clone(), config.1.clone(), false); - sr.set_raw_sql(raw_sql); - let meta = sr.get_meta(&mut ast[0]); - let res = sr.table_strategy(meta).unwrap(); - assert_eq!(res[0].target_sql, "SELECT idx from db.tshard00003 where idx = 3 and idx = (SELECT idx from db.tshard00003 where idx = 3)".to_string()); + let ast = parser.parse(&raw_sql).unwrap(); + let input = ShardingRewriteInput { + raw_sql: raw_sql.to_string(), + ast: ast[0].clone(), + }; + let res = sr.rewrite(input).unwrap(); + assert_eq!(res[0].target_sql, "SELECT idx from db.tshard_00003 where idx = 3 and idx = (SELECT idx from db.tshard_00003 where idx = 3)".to_string()); let raw_sql = "SELECT idx from db.tshard where idx = 3 and idx = (SELECT idx from db.tshard where idx = 4)".to_string(); - let mut ast = parser.parse(&raw_sql).unwrap(); - let mut sr = ShardingRewrite::new(config.0.clone(), config.1.clone(), false); - sr.set_raw_sql(raw_sql); - let meta = sr.get_meta(&mut ast[0]); - let res = sr.table_strategy(meta).unwrap(); + let ast = parser.parse(&raw_sql).unwrap(); + let input = ShardingRewriteInput { + raw_sql: raw_sql.to_string(), + ast: ast[0].clone(), + }; + let res = sr.rewrite(input).unwrap(); assert_eq!( res.into_iter().map(|x| x.target_sql).collect::>(), vec![ - "SELECT idx from db.tshard00000 where idx = 3 and idx = (SELECT idx from db.tshard00000 where idx = 4)", - "SELECT idx from db.tshard00001 where idx = 3 and idx = (SELECT idx from db.tshard00001 where idx = 4)", - "SELECT idx from db.tshard00002 where idx = 3 and idx = (SELECT idx from db.tshard00002 where idx = 4)", - "SELECT idx from db.tshard00003 where idx = 3 and idx = (SELECT idx from db.tshard00003 where idx = 4)", + "SELECT idx from db.tshard_00000 where idx = 3 and idx = (SELECT idx from db.tshard_00000 where idx = 4)", + "SELECT idx from db.tshard_00001 where idx = 3 and idx = (SELECT idx from db.tshard_00001 where idx = 4)", + "SELECT idx from db.tshard_00002 where idx = 3 and idx = (SELECT idx from db.tshard_00002 where idx = 4)", + "SELECT idx from db.tshard_00003 where idx = 3 and idx = (SELECT idx from db.tshard_00003 where idx = 4)", ], ); } @@ -902,12 +910,13 @@ mod test { let config = get_table_sharding_config(); let raw_sql = "INSERT INTO db.tshard(idx) VALUES (12), (13), (16)".to_string(); let parser = Parser::new(); - let mut ast = parser.parse(&raw_sql).unwrap(); + let ast = parser.parse(&raw_sql).unwrap(); let mut sr = ShardingRewrite::new(config.0.clone(), config.1.clone(), false); - sr.set_raw_sql(raw_sql); - let meta = sr.get_meta(&mut ast[0]); - - let res = sr.table_strategy(meta).unwrap(); + let input = ShardingRewriteInput { + raw_sql: raw_sql.to_string(), + ast: ast[0].clone(), + }; + let res = sr.rewrite(input).unwrap(); assert_eq!( res.into_iter().map(|x| x.target_sql).collect::>(), @@ -916,6 +925,5 @@ mod test { "INSERT INTO db.tshard_00001(idx) VALUES (13)", ], ); - } }