Skip to content

Commit

Permalink
feat(protocol): add handling for set statetment (database-mesh#101)
Browse files Browse the repository at this point in the history
* feat(protocol): add  handling for character set
feat(runtime): add handling for set statement

Signed-off-by: xuanyuan300 <[email protected]>

* chore(style): rustfmt

Signed-off-by: xuanyuan300 <[email protected]>

* chore(protocol): remove unused comments

Signed-off-by: xuanyuan300 <[email protected]>

* chore(runtime): remove unused dependencies

Signed-off-by: xuanyuan300 <[email protected]>

* chore(protocol): remove unused comments

Signed-off-by: xuanyuan300 <[email protected]>

* refactor(protocol): refactor send_query_discard_result method

Signed-off-by: xuanyuan300 <[email protected]>

* fix(runtime): remmove async for handle_set_stmt

Signed-off-by: xuanyuan300 <[email protected]>

* fix(runtime): call init_session_attr when get_conn is None

Signed-off-by: xuanyuan300 <[email protected]>

* chore(protocol): add comments

Signed-off-by: xuanyuan300 <[email protected]>
  • Loading branch information
xuanyuan300 authored and mlycore committed Jun 27, 2022
1 parent d3e762d commit 6832f30
Show file tree
Hide file tree
Showing 10 changed files with 369 additions and 584 deletions.
2 changes: 1 addition & 1 deletion pisa-proxy/parser/mysql/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ mod test {
//"START TRANSACTION",
//"COMMIT",
//"ROLLBACK",
//"set names utf8mb4",
"set names utf8mb4",
//"SET character_set_connection = gbk;",
//"SET character_set_results = gbk;",
//"SET character_set_client = \"gbk\";",
Expand Down
1 change: 1 addition & 0 deletions pisa-proxy/protocol/mysql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,6 @@ thiserror = "1.0"
num-traits = "0.2"
num-derive = "0.3"
async-trait = "0.1"
regex = "1"
conn_pool = { path = "../../proxy/pool" }
protocol_codegen = { path = "../codegen" }
710 changes: 173 additions & 537 deletions pisa-proxy/protocol/mysql/src/charset.rs

Large diffs are not rendered by default.

68 changes: 58 additions & 10 deletions pisa-proxy/protocol/mysql/src/client/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,20 @@
use std::{convert::From, str};

use byteorder::{ByteOrder, LittleEndian};
use bytes::{BufMut, BytesMut};
use bytes::{Buf, BufMut, BytesMut};
use futures::{SinkExt, StreamExt};
use rand::rngs::OsRng;
use rsa::{pkcs8::DecodePublicKey, PaddingScheme, PublicKey, RsaPublicKey};
use sha1::Sha1;
use tokio_util::codec::{Decoder, Encoder, Framed};
use regex::Regex;

use super::{codec::ClientCodec, stream::LocalStream};
use crate::{charset::DEFAULT_COLLATION_ID, err::ProtocolError, mysql_const::*, util::*};
use crate::{charset::*, err::ProtocolError, mysql_const::*, util::*};

lazy_static! {
static ref RE: Regex = Regex::new(r"^(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)").unwrap();
}

/// Handshake state
#[derive(Debug, Clone)]
Expand All @@ -45,21 +50,43 @@ impl Default for HandshakeState {
}
}

#[derive(Debug, Default, Clone)]
pub struct ServerVersion {
pub major: u8,
pub minor: u8,
pub patch: u8,
}

impl From<(&str, &str, &str)> for ServerVersion {
fn from(version: (&str, &str, &str)) -> ServerVersion {
let major = version.0.parse::<u8>().unwrap();
let minor = version.1.parse::<u8>().unwrap();
let patch = version.2.parse::<u8>().unwrap();

ServerVersion {
major,
minor,
patch,
}
}
}

#[derive(Debug, Default, Clone)]
pub struct ClientAuth {
pub next_state: HandshakeState,
pub connection_id: u32,
pub salt: Vec<u8>,
pub capability: u32,
pub client_capability: u32,
pub charset: u8,
pub charset: String,
pub status: u16,
pub auth_plugin_name: String,
pub tls_config: Option<()>,
pub user: String,
pub password: String,
pub db: String,
pub seq: u8,
pub server_version: ServerVersion,
}

impl ClientAuth {
Expand All @@ -73,11 +100,12 @@ impl ClientAuth {
status: 0,
auth_plugin_name: "".to_string(),
tls_config: None,
charset: 0,
charset: "".to_string(),
user: "".to_string(),
password: "".to_string(),
db: "".to_string(),
seq: 0,
server_version: ServerVersion::default(),
}
}

Expand All @@ -97,7 +125,19 @@ impl ClientAuth {

// skip server version, end with 0x00
let pos = data.iter().position(|&x| x == 0x00).unwrap();
let _ = data.split_to(pos + 1);
let version_bytes = data.split_to(pos + 1);
let version = str::from_utf8(&version_bytes).unwrap();
if let Some(caps) = RE.captures(version) {
let ver = ServerVersion::from(
(
caps.name("major").unwrap().as_str(),
caps.name("minor").unwrap().as_str(),
caps.name("patch").unwrap().as_str(),
)
);

self.server_version = ver;
}

// connection id length is 4
self.connection_id = LittleEndian::read_u32(&data.split_to(4));
Expand All @@ -122,9 +162,14 @@ impl ClientAuth {
return Ok(self.clone());
}

// skip server charset
// server charset
// self.charset = data[pos]
self.charset = data.split_to(1)[0] as u8;
let charset_id = data.get_u8();
match self.server_version.major {
5 => self.charset = CHARSET_ID_NAME_MYSQL5[&charset_id].to_string(),
_ => self.charset = CHARSET_ID_NAME_MYSQL8[&charset_id].to_string(),
}


self.status = LittleEndian::read_u16(&data.split_to(2));

Expand Down Expand Up @@ -251,8 +296,11 @@ impl ClientAuth {
//data[11] = 0x00;

//charset [1 byte]
// data[12] = DEFAULT_COLLATION_ID as u8;
data.put_u8(DEFAULT_COLLATION_ID);
self.charset = DEFAULT_CHARSET_NAME.to_string();
match self.server_version.major {
5 => data.put_u8(CHARSET_NAME_ID_MYSQL5[DEFAULT_CHARSET_NAME]),
_ => data.put_u8(CHARSET_NAME_ID_MYSQL8[DEFAULT_CHARSET_NAME]),
}

data.put_slice(&[0; 23]);

Expand Down Expand Up @@ -622,7 +670,7 @@ mod test {
assert_eq!(c.salt[0], 0x29);
assert_eq!(c.salt[c.salt.len() - 1], 0x59);
assert_eq!(c.auth_plugin_name, "caching_sha2_password".to_string());
assert_eq!(c.charset, 0xff);
assert_eq!(c.charset, "utf8mb4");
}

// test auth success with mysql_native_password plugin
Expand Down
26 changes: 26 additions & 0 deletions pisa-proxy/protocol/mysql/src/client/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

use std::{
ops::{Deref, DerefMut},
pin::Pin,
task::{Context, Poll},
};
Expand Down Expand Up @@ -49,6 +50,31 @@ pub enum ClientCodec {
Common(Framed<LocalStream, CommonCodec>),
}

// Access `AuthInfo` struct by dereferencing the `ClientCodec` struct.
impl Deref for ClientCodec {
type Target = ClientAuth;
fn deref(&self) -> &Self::Target {
match self {
Self::ClientAuth(framed) => framed.codec(),
Self::Resultset(framed) => framed.codec().auth_info.as_ref().unwrap(),
Self::Stmt(framed) => framed.codec().auth_info.as_ref().unwrap(),
Self::Common(framed) => framed.codec().auth_info.as_ref().unwrap(),
}
}
}

// Modify `AuthInfo` struct by dereferencing the `ClientCodec` struct.
impl DerefMut for ClientCodec {
fn deref_mut(&mut self) -> &mut Self::Target {
match self {
Self::ClientAuth(framed) => framed.codec_mut(),
Self::Resultset(framed) => framed.codec_mut().auth_info.as_mut().unwrap(),
Self::Stmt(framed) => framed.codec_mut().auth_info.as_mut().unwrap(),
Self::Common(framed) => framed.codec_mut().auth_info.as_mut().unwrap(),
}
}
}

#[derive(Debug)]
#[pin_project]
pub struct ResultsetStream<'a> {
Expand Down
41 changes: 32 additions & 9 deletions pisa-proxy/protocol/mysql/src/client/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ use crate::{err::ProtocolError, mysql_const::*};
#[derive(Debug, Default)]
pub struct ClientConn {
pub framed: Option<Box<ClientCodec>>,
pub auth_info: Option<ClientAuth>,
user: String,
password: String,
endpoint: String,
Expand Down Expand Up @@ -74,21 +73,18 @@ impl ClientConn {
))));

let res = handshake(*(framed.take().unwrap())).await?;
let auth_info = Some(res.0.codec().clone());
let framed = Some(Box::new(ClientCodec::ClientAuth(res.0)));

Ok(ClientConn {
user: self.user.clone(),
password: self.password.clone(),
endpoint: self.endpoint.clone(),
framed,
auth_info,
})
}

pub async fn handshake(&mut self) -> Result<(bool, Vec<u8>), ProtocolError> {
let res = handshake(*(self.framed.take().unwrap())).await?;
self.auth_info = Some(res.0.codec().clone());
self.framed = Some(Box::new(ClientCodec::ClientAuth(res.0)));

Ok((res.1, res.2))
Expand All @@ -102,7 +98,7 @@ impl ClientConn {

let mut resultset_codec = framed.into_resultset();

resultset_codec.send(ResultSendCommand::Binary((0x03, val))).await?;
resultset_codec.send(ResultSendCommand::Binary((COM_QUERY, val))).await?;

self.framed = Some(Box::new(ClientCodec::Resultset(resultset_codec)));

Expand Down Expand Up @@ -206,9 +202,22 @@ impl ClientConn {
Ok(CommonStream::new(self.framed.as_mut()))
}

// Send query, but discard result
pub async fn send_query_discard_result(&mut self, val: &str) -> Result<(), ProtocolError> {
let mut stream = self.send_common_command(COM_QUERY, val.as_bytes()).await?;

while stream.next().await.is_some() {}

Ok(())
}

pub fn get_endpoint(&self) -> Option<String> {
Some(self.endpoint.clone())
}

pub fn set_charset(&mut self, name: &str) {
self.framed.as_mut().unwrap().charset = name.to_string()
}
}

impl Clone for ClientConn {
Expand All @@ -218,7 +227,6 @@ impl Clone for ClientConn {
password: self.password.clone(),
endpoint: self.endpoint.clone(),
framed: None,
auth_info: None,
}
}
}
Expand Down Expand Up @@ -254,11 +262,26 @@ impl ConnAttr for ClientConn {
}

fn get_db(&self) -> Option<String> {
if let Some(auth_info) = &self.auth_info {
if auth_info.db.is_empty() {
let codec = self.framed.as_ref();

if let Some(codec) = codec {
if codec.db.is_empty() {
None
} else {
Some(codec.db.clone())
}
} else {
None
}
}

fn get_charset(&self) -> Option<String> {
let codec = self.framed.as_ref();
if let Some(codec) = codec {
if codec.charset.is_empty() {
None
} else {
Some(auth_info.db.clone())
Some(codec.charset.clone())
}
} else {
None
Expand Down
11 changes: 3 additions & 8 deletions pisa-proxy/protocol/mysql/src/server/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,13 @@ const DEFAULT_CAPABILITY: u32 = CLIENT_LONG_PASSWORD
pub struct Connection {
salt: Vec<u8>,
status: u16,
collation: CollationId,
capability: u32,
connection_id: u32,
_charset: String,
user: String,
password: String,
auth_data: BytesMut,
pub auth_plugin_name: String,
pub charset: String,
pub db: String,
pub affected_rows: i64,
pub pkt: Packet,
Expand All @@ -72,10 +71,9 @@ impl Connection {
Connection {
salt: crate::util::random_buf(20),
status: SERVER_STATUS_AUTOCOMMIT,
collation: DEFAULT_COLLATION_ID,
capability: 0,
connection_id: CONNECTION_ID.load(Ordering::Relaxed),
_charset: DEFAULT_CHARSET.to_string(),
charset: DEFAULT_CHARSET_NAME.to_string(),
auth_plugin_name: "".to_string(),
user,
password,
Expand Down Expand Up @@ -141,10 +139,7 @@ impl Connection {
data.put_u8((DEFAULT_CAPABILITY >> 8) as u8);

//charset, utf-8 default
if self.collation == 0 {
self.collation = DEFAULT_COLLATION_ID;
}
data.put_u8(self.collation);
data.put_u8(CHARSET_NAME_ID_MYSQL5[&*self.charset]);

//status
data.put_u8(self.status as u8);
Expand Down
2 changes: 2 additions & 0 deletions pisa-proxy/proxy/pool/src/conn_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ pub trait ConnAttr {
fn get_endpoint(&self) -> String;
// Get current db on conn
fn get_db(&self) -> Option<String>;
// Get current charset
fn get_charset(&self) -> Option<String>;
}

#[derive(Debug)]
Expand Down
Loading

0 comments on commit 6832f30

Please sign in to comment.