From 6832f30f2a3b04297753594aa0b1dfd5ea64ef70 Mon Sep 17 00:00:00 2001 From: pierce Date: Fri, 17 Jun 2022 18:30:33 +0800 Subject: [PATCH] feat(protocol): add handling for `set` statetment (#101) * feat(protocol): add handling for character set feat(runtime): add handling for set statement Signed-off-by: xuanyuan300 * chore(style): rustfmt Signed-off-by: xuanyuan300 * chore(protocol): remove unused comments Signed-off-by: xuanyuan300 * chore(runtime): remove unused dependencies Signed-off-by: xuanyuan300 * chore(protocol): remove unused comments Signed-off-by: xuanyuan300 * refactor(protocol): refactor send_query_discard_result method Signed-off-by: xuanyuan300 * fix(runtime): remmove async for handle_set_stmt Signed-off-by: xuanyuan300 * fix(runtime): call init_session_attr when get_conn is None Signed-off-by: xuanyuan300 * chore(protocol): add comments Signed-off-by: xuanyuan300 --- pisa-proxy/parser/mysql/src/parser.rs | 2 +- pisa-proxy/protocol/mysql/Cargo.toml | 1 + pisa-proxy/protocol/mysql/src/charset.rs | 710 +++++------------- pisa-proxy/protocol/mysql/src/client/auth.rs | 68 +- pisa-proxy/protocol/mysql/src/client/codec.rs | 26 + pisa-proxy/protocol/mysql/src/client/conn.rs | 41 +- pisa-proxy/protocol/mysql/src/server/conn.rs | 11 +- pisa-proxy/proxy/pool/src/conn_pool.rs | 2 + pisa-proxy/runtime/mysql/src/server/server.rs | 46 +- .../runtime/mysql/src/transaction_fsm.rs | 46 +- 10 files changed, 369 insertions(+), 584 deletions(-) diff --git a/pisa-proxy/parser/mysql/src/parser.rs b/pisa-proxy/parser/mysql/src/parser.rs index 61132a26..7e1b1b9a 100644 --- a/pisa-proxy/parser/mysql/src/parser.rs +++ b/pisa-proxy/parser/mysql/src/parser.rs @@ -119,7 +119,7 @@ mod test { //"START TRANSACTION", //"COMMIT", //"ROLLBACK", - //"set names utf8mb4", + "set names utf8mb4", //"SET character_set_connection = gbk;", //"SET character_set_results = gbk;", //"SET character_set_client = \"gbk\";", diff --git a/pisa-proxy/protocol/mysql/Cargo.toml b/pisa-proxy/protocol/mysql/Cargo.toml index a9c41fcc..b00903ab 100644 --- a/pisa-proxy/protocol/mysql/Cargo.toml +++ b/pisa-proxy/protocol/mysql/Cargo.toml @@ -32,5 +32,6 @@ thiserror = "1.0" num-traits = "0.2" num-derive = "0.3" async-trait = "0.1" +regex = "1" conn_pool = { path = "../../proxy/pool" } protocol_codegen = { path = "../codegen" } diff --git a/pisa-proxy/protocol/mysql/src/charset.rs b/pisa-proxy/protocol/mysql/src/charset.rs index 01109464..370ee148 100644 --- a/pisa-proxy/protocol/mysql/src/charset.rs +++ b/pisa-proxy/protocol/mysql/src/charset.rs @@ -14,556 +14,192 @@ use std::collections::HashMap; -pub type CollationId = u8; +pub const DEFAULT_CHARSET_NAME: &str = "utf8mb4"; +pub const DEFAULT_COLLATION_NAME: &str = "utf8mb4_general_ci"; +// Data from `select ID,CHARACTER_SET_NAME from COLLATIONS where IS_DEFAULT='yes';`. //charset key is charset name and value is default collation id lazy_static! { - static ref CHARSET_IDS: HashMap = HashMap::from([ - (String::from("big5"), 1), - (String::from("dec8"), 3), - (String::from("cp850"), 4), - (String::from("hp8"), 6), - (String::from("koi8r"), 7), - (String::from("latin1"), 8), - (String::from("latin2"), 9), - (String::from("swe7"), 10), - (String::from("ascii"), 11), - (String::from("ujis"), 12), - (String::from("sjis"), 13), - (String::from("hebrew"), 16), - (String::from("tis620"), 18), - (String::from("euckr"), 19), - (String::from("koi8u"), 22), - (String::from("gb2312"), 24), - (String::from("greek"), 25), - (String::from("cp1250"), 26), - (String::from("gbk"), 28), - (String::from("latin5"), 30), - (String::from("armscii8"), 32), - (String::from("utf8"), 33), - (String::from("ucs2"), 35), - (String::from("cp866"), 36), - (String::from("keybcs2"), 37), - (String::from("macce"), 38), - (String::from("macroman"), 39), - (String::from("cp852"), 40), - (String::from("latin7"), 41), - (String::from("utf8mb4"), 45), - (String::from("cp1251"), 51), - (String::from("utf16"), 54), - (String::from("utf16le"), 56), - (String::from("cp1256"), 57), - (String::from("cp1257"), 59), - (String::from("utf32"), 60), - (String::from("binary"), 63), - (String::from("geostd8"), 92), - (String::from("cp932"), 95), - (String::from("eucjpms"), 97), + pub static ref CHARSET_NAME_ID_MYSQL5: HashMap<&'static str, u8> = HashMap::from([ + ("big5", 1), + ("dec8", 3), + ("cp850", 4), + ("hp8", 6), + ("koi8r", 7), + ("latin1", 8), + ("latin2", 9), + ("swe7", 10), + ("ascii", 11), + ("ujis", 12), + ("sjis", 13), + ("hebrew", 16), + ("tis620", 18), + ("euckr", 19), + ("koi8u", 22), + ("gb2312", 24), + ("greek", 25), + ("cp1250", 26), + ("gbk", 28), + ("latin5", 30), + ("armscii8", 32), + ("utf8", 33), + ("ucs2", 35), + ("cp866", 36), + ("keybcs2", 37), + ("macce", 38), + ("macroman", 39), + ("cp852", 40), + ("latin7", 41), + ("utf8mb4", 45), + ("cp1251", 51), + ("utf16", 54), + ("utf16le", 56), + ("cp1256", 57), + ("cp1257", 59), + ("utf32", 60), + ("binary", 63), + ("geostd8", 92), + ("cp932", 95), + ("eucjpms", 97), + ("gb18030", 248), ]); } -//charset key is charset name and value is default collation name +//charset key is charset id and value is charset name lazy_static! { - static ref CHARSETS: HashMap = HashMap::from([ - (String::from("big5"), String::from("big5_chinese_ci")), - (String::from("dec8"), String::from("dec8_swedish_ci")), - (String::from("cp850"), String::from("cp850_general_ci")), - (String::from("hp8"), String::from("hp8_english_ci")), - (String::from("koi8r"), String::from("koi8r_general_ci)")), - (String::from("latin1"), String::from("latin1_swedish_ci")), - (String::from("latin2"), String::from("latin2_general_ci)")), - (String::from("swe7"), String::from("swe7_swedish_ci)")), - (String::from("ascii"), String::from("ascii_general_ci)")), - (String::from("ujis"), String::from("ujis_japanese_ci")), - (String::from("sjis"), String::from("sjis_japanese_ci")), - (String::from("hebrew"), String::from("hebrew_general_ci")), - (String::from("tis620"), String::from("tis620_thai_ci")), - (String::from("euckr"), String::from("euckr_korean_ci")), - (String::from("koi8u"), String::from("koi8u_general_ci")), - (String::from("gb2312"), String::from("gb2312_chinese_ci")), - (String::from("greek"), String::from("greek_general_ci")), - (String::from("cp1250"), String::from("cp1250_general_ci")), - (String::from("gbk"), String::from("gbk_chinese_ci")), - (String::from("latin5"), String::from("latin5_turkish_ci")), - (String::from("armscii8"), String::from("armscii8_general_ci")), - (String::from("utf8"), String::from("utf8_general_ci")), - (String::from("ucs2"), String::from("ucs2_general_ci")), - (String::from("cp866"), String::from("cp866_general_ci")), - (String::from("keybcs2"), String::from("keybcs2_general_ci")), - (String::from("macce"), String::from("macce_general_ci")), - (String::from("macroman"), String::from("macroman_general_ci")), - (String::from("cp852"), String::from("cp852_general_ci")), - (String::from("latin7"), String::from("latin7_general_ci")), - (String::from("utf8mb4"), String::from("utf8mb4_general_ci")), - (String::from("cp1251"), String::from("cp1251_general_ci")), - (String::from("utf16"), String::from("utf16_general_ci")), - (String::from("utf16le"), String::from("utf16le_general_ci")), - (String::from("cp1256"), String::from("cp1256_general_ci")), - (String::from("cp1257"), String::from("cp1257_general_ci")), - (String::from("utf32"), String::from("utf32_general_ci")), - (String::from("binary"), String::from("binary")), - (String::from("geostd8"), String::from("geostd8_general_ci")), - (String::from("cp932"), String::from("cp932_japanese_ci")), - (String::from("eucjpms"), String::from("eucjpms_japanese_ci")), + pub static ref CHARSET_ID_NAME_MYSQL5: HashMap = HashMap::from([ + (1, "big5"), + (3, "dec8"), + (4, "cp850"), + (6, "hp8"), + (7, "koi8r"), + (8, "latin1"), + (9, "latin2"), + (10, "swe7"), + (11, "ascii"), + (12, "ujis"), + (13, "sjis"), + (16, "hebrew"), + (18, "tis620"), + (19, "euckr"), + (22, "koi8u"), + (24, "gb2312"), + (25, "greek"), + (26, "cp1250"), + (28, "gbk"), + (30, "latin5"), + (32, "armscii8"), + (33, "utf8"), + (35, "ucs2"), + (36, "cp866"), + (37, "keybcs2"), + (38, "macce"), + (39, "macroman"), + (40, "cp852"), + (41, "latin7"), + (45, "utf8mb4"), + (51, "cp1251"), + (54, "utf16"), + (56, "utf16le"), + (57, "cp1256"), + (59, "cp1257"), + (60, "utf32"), + (63, "binary"), + (92, "geostd8"), + (95, "cp932"), + (97, "eucjpms"), + (248, "gb18030"), ]); } lazy_static! { - static ref COLLATIONS: HashMap = HashMap::from([ - (1, String::from("big5_chinese_ci")), - (2, String::from("latin2_czech_cs")), - (3, String::from("dec8_swedish_ci")), - (4, String::from("cp850_general_ci")), - (5, String::from("latin1_german1_ci")), - (6, String::from("hp8_english_ci")), - (7, String::from("koi8r_general_ci")), - (8, String::from("latin1_swedish_ci")), - (9, String::from("latin2_general_ci")), - (10, String::from("swe7_swedish_ci")), - (11, String::from("ascii_general_ci")), - (12, String::from("ujis_japanese_ci")), - (13, String::from("sjis_japanese_ci")), - (14, String::from("cp1251_bulgarian_ci")), - (15, String::from("latin1_danish_ci")), - (16, String::from("hebrew_general_ci")), - (18, String::from("tis620_thai_ci")), - (19, String::from("euckr_korean_ci")), - (20, String::from("latin7_estonian_cs")), - (21, String::from("latin2_hungarian_ci")), - (22, String::from("koi8u_general_ci")), - (23, String::from("cp1251_ukrainian_ci")), - (24, String::from("gb2312_chinese_ci")), - (25, String::from("greek_general_ci")), - (26, String::from("cp1250_general_ci")), - (27, String::from("latin2_croatian_ci")), - (28, String::from("gbk_chinese_ci")), - (29, String::from("cp1257_lithuanian_ci")), - (30, String::from("latin5_turkish_ci")), - (31, String::from("latin1_german2_ci")), - (32, String::from("armscii8_general_ci")), - (33, String::from("utf8_general_ci")), - (34, String::from("cp1250_czech_cs")), - (35, String::from("ucs2_general_ci")), - (36, String::from("cp866_general_ci")), - (37, String::from("keybcs2_general_ci")), - (38, String::from("macce_general_ci")), - (39, String::from("macroman_general_ci")), - (40, String::from("cp852_general_ci")), - (41, String::from("latin7_general_ci")), - (42, String::from("latin7_general_cs")), - (43, String::from("macce_bin")), - (44, String::from("cp1250_croatian_ci")), - (45, String::from("utf8mb4_general_ci")), - (46, String::from("utf8mb4_bin")), - (47, String::from("latin1_bin")), - (48, String::from("latin1_general_ci")), - (49, String::from("latin1_general_cs")), - (50, String::from("cp1251_bin")), - (51, String::from("cp1251_general_ci")), - (52, String::from("cp1251_general_cs")), - (53, String::from("macroman_bin")), - (54, String::from("utf16_general_ci")), - (55, String::from("utf16_bin")), - (56, String::from("utf16le_general_ci")), - (57, String::from("cp1256_general_ci")), - (58, String::from("cp1257_bin")), - (59, String::from("cp1257_general_ci")), - (60, String::from("utf32_general_ci")), - (61, String::from("utf32_bin")), - (62, String::from("utf16le_bin")), - (63, String::from("binary")), - (64, String::from("armscii8_bin")), - (65, String::from("ascii_bin")), - (66, String::from("cp1250_bin")), - (67, String::from("cp1256_bin")), - (68, String::from("cp866_bin")), - (69, String::from("dec8_bin")), - (70, String::from("greek_bin")), - (71, String::from("hebrew_bin")), - (72, String::from("hp8_bin")), - (73, String::from("keybcs2_bin")), - (74, String::from("koi8r_bin")), - (75, String::from("koi8u_bin")), - (77, String::from("latin2_bin")), - (78, String::from("latin5_bin")), - (79, String::from("latin7_bin")), - (80, String::from("cp850_bin")), - (81, String::from("cp852_bin")), - (82, String::from("swe7_bin")), - (83, String::from("utf8_bin")), - (84, String::from("big5_bin")), - (85, String::from("euckr_bin")), - (86, String::from("gb2312_bin")), - (87, String::from("gbk_bin")), - (88, String::from("sjis_bin")), - (89, String::from("tis620_bin")), - (90, String::from("ucs2_bin")), - (91, String::from("ujis_bin")), - (92, String::from("geostd8_general_ci")), - (93, String::from("geostd8_bin")), - (94, String::from("latin1_spanish_ci")), - (95, String::from("cp932_japanese_ci")), - (96, String::from("cp932_bin")), - (97, String::from("eucjpms_japanese_ci")), - (98, String::from("eucjpms_bin")), - (99, String::from("cp1250_polish_ci")), - (101, String::from("utf16_unicode_ci")), - (102, String::from("utf16_icelandic_ci")), - (103, String::from("utf16_latvian_ci")), - (104, String::from("utf16_romanian_ci")), - (105, String::from("utf16_slovenian_ci")), - (106, String::from("utf16_polish_ci")), - (107, String::from("utf16_estonian_ci")), - (108, String::from("utf16_spanish_ci")), - (109, String::from("utf16_swedish_ci")), - (110, String::from("utf16_turkish_ci")), - (111, String::from("utf16_czech_ci")), - (112, String::from("utf16_danish_ci")), - (113, String::from("utf16_lithuanian_ci")), - (114, String::from("utf16_slovak_ci")), - (115, String::from("utf16_spanish2_ci")), - (116, String::from("utf16_roman_ci")), - (117, String::from("utf16_persian_ci")), - (118, String::from("utf16_esperanto_ci")), - (119, String::from("utf16_hungarian_ci")), - (120, String::from("utf16_sinhala_ci")), - (121, String::from("utf16_german2_ci")), - (122, String::from("utf16_croatian_ci")), - (123, String::from("utf16_unicode_520_ci")), - (124, String::from("utf16_vietnamese_ci")), - (128, String::from("ucs2_unicode_ci")), - (129, String::from("ucs2_icelandic_ci")), - (130, String::from("ucs2_latvian_ci")), - (131, String::from("ucs2_romanian_ci")), - (132, String::from("ucs2_slovenian_ci")), - (133, String::from("ucs2_polish_ci")), - (134, String::from("ucs2_estonian_ci")), - (135, String::from("ucs2_spanish_ci")), - (136, String::from("ucs2_swedish_ci")), - (137, String::from("ucs2_turkish_ci")), - (138, String::from("ucs2_czech_ci")), - (139, String::from("ucs2_danish_ci")), - (140, String::from("ucs2_lithuanian_ci")), - (141, String::from("ucs2_slovak_ci")), - (142, String::from("ucs2_spanish2_ci")), - (143, String::from("ucs2_roman_ci")), - (144, String::from("ucs2_persian_ci")), - (145, String::from("ucs2_esperanto_ci")), - (146, String::from("ucs2_hungarian_ci")), - (147, String::from("ucs2_sinhala_ci")), - (148, String::from("ucs2_german2_ci")), - (149, String::from("ucs2_croatian_ci")), - (150, String::from("ucs2_unicode_520_ci")), - (151, String::from("ucs2_vietnamese_ci")), - (159, String::from("ucs2_general_mysql500_ci")), - (160, String::from("utf32_unicode_ci")), - (161, String::from("utf32_icelandic_ci")), - (162, String::from("utf32_latvian_ci")), - (163, String::from("utf32_romanian_ci")), - (164, String::from("utf32_slovenian_ci")), - (165, String::from("utf32_polish_ci")), - (166, String::from("utf32_estonian_ci")), - (167, String::from("utf32_spanish_ci")), - (168, String::from("utf32_swedish_ci")), - (169, String::from("utf32_turkish_ci")), - (170, String::from("utf32_czech_ci")), - (171, String::from("utf32_danish_ci")), - (172, String::from("utf32_lithuanian_ci")), - (173, String::from("utf32_slovak_ci")), - (174, String::from("utf32_spanish2_ci")), - (175, String::from("utf32_roman_ci")), - (176, String::from("utf32_persian_ci")), - (177, String::from("utf32_esperanto_ci")), - (178, String::from("utf32_hungarian_ci")), - (179, String::from("utf32_sinhala_ci")), - (180, String::from("utf32_german2_ci")), - (181, String::from("utf32_croatian_ci")), - (182, String::from("utf32_unicode_520_ci")), - (183, String::from("utf32_vietnamese_ci")), - (192, String::from("utf8_unicode_ci")), - (193, String::from("utf8_icelandic_ci")), - (194, String::from("utf8_latvian_ci")), - (195, String::from("utf8_romanian_ci")), - (196, String::from("utf8_slovenian_ci")), - (197, String::from("utf8_polish_ci")), - (198, String::from("utf8_estonian_ci")), - (199, String::from("utf8_spanish_ci")), - (200, String::from("utf8_swedish_ci")), - (201, String::from("utf8_turkish_ci")), - (202, String::from("utf8_czech_ci")), - (203, String::from("utf8_danish_ci")), - (204, String::from("utf8_lithuanian_ci")), - (205, String::from("utf8_slovak_ci")), - (206, String::from("utf8_spanish2_ci")), - (207, String::from("utf8_roman_ci")), - (208, String::from("utf8_persian_ci")), - (209, String::from("utf8_esperanto_ci")), - (210, String::from("utf8_hungarian_ci")), - (211, String::from("utf8_sinhala_ci")), - (212, String::from("utf8_german2_ci")), - (213, String::from("utf8_croatian_ci")), - (214, String::from("utf8_unicode_520_ci")), - (215, String::from("utf8_vietnamese_ci")), - (223, String::from("utf8_general_mysql500_ci")), - (224, String::from("utf8mb4_unicode_ci")), - (225, String::from("utf8mb4_icelandic_ci")), - (226, String::from("utf8mb4_latvian_ci")), - (227, String::from("utf8mb4_romanian_ci")), - (228, String::from("utf8mb4_slovenian_ci")), - (229, String::from("utf8mb4_polish_ci")), - (230, String::from("utf8mb4_estonian_ci")), - (231, String::from("utf8mb4_spanish_ci")), - (232, String::from("utf8mb4_swedish_ci")), - (233, String::from("utf8mb4_turkish_ci")), - (234, String::from("utf8mb4_czech_ci")), - (235, String::from("utf8mb4_danish_ci")), - (236, String::from("utf8mb4_lithuanian_ci")), - (237, String::from("utf8mb4_slovak_ci")), - (238, String::from("utf8mb4_spanish2_ci")), - (239, String::from("utf8mb4_roman_ci")), - (240, String::from("utf8mb4_persian_ci")), - (241, String::from("utf8mb4_esperanto_ci")), - (242, String::from("utf8mb4_hungarian_ci")), - (243, String::from("utf8mb4_sinhala_ci")), - (244, String::from("utf8mb4_german2_ci")), - (245, String::from("utf8mb4_croatian_ci")), - (246, String::from("utf8mb4_unicode_520_ci")), - (247, String::from("utf8mb4_vietnamese_ci")), - ]); -} -lazy_static! { - static ref COLLATION_NAMES: HashMap = HashMap::from([ //TODO: HashMap<&str, CollationId> - (String::from("big5_chinese_ci"), 1), - (String::from("latin2_czech_cs"), 2), - (String::from("dec8_swedish_ci"), 3), - (String::from("cp850_general_ci"), 4), - (String::from("latin1_german1_ci"), 5), - (String::from("hp8_english_ci"), 6), - (String::from("koi8r_general_ci"), 7), - (String::from("latin1_swedish_ci"), 8), - (String::from("latin2_general_ci"), 9), - (String::from("swe7_swedish_ci"), 10), - (String::from("ascii_general_ci"), 11), - (String::from("ujis_japanese_ci"), 12), - (String::from("sjis_japanese_ci"), 13), - (String::from("cp1251_bulgarian_ci"), 14), - (String::from("latin1_danish_ci"), 15), - (String::from("hebrew_general_ci"), 16), - (String::from("tis620_thai_ci"), 18), - (String::from("euckr_korean_ci"), 19), - (String::from("latin7_estonian_cs"), 20), - (String::from("latin2_hungarian_ci"), 21), - (String::from("koi8u_general_ci"), 22), - (String::from("cp1251_ukrainian_ci"), 23), - (String::from("gb2312_chinese_ci"), 24), - (String::from("greek_general_ci"), 25), - (String::from("cp1250_general_ci"), 26), - (String::from("latin2_croatian_ci"), 27), - (String::from("gbk_chinese_ci"), 28), - (String::from("cp1257_lithuanian_ci"), 29), - (String::from("latin5_turkish_ci"), 30), - (String::from("latin1_german2_ci"), 31), - (String::from("armscii8_general_ci"), 32), - (String::from("utf8_general_ci"), 33), - (String::from("cp1250_czech_cs"), 34), - (String::from("ucs2_general_ci"), 35), - (String::from("cp866_general_ci"), 36), - (String::from("keybcs2_general_ci"), 37), - (String::from("macce_general_ci"), 38), - (String::from("macroman_general_ci"), 39), - (String::from("cp852_general_ci"), 40), - (String::from("latin7_general_ci"), 41), - (String::from("latin7_general_cs"), 42), - (String::from("macce_bin"), 43), - (String::from("cp1250_croatian_ci"), 44), - (String::from("utf8mb4_general_ci"), 45), - (String::from("utf8mb4_bin"), 46), - (String::from("latin1_bin"), 47), - (String::from("latin1_general_ci"), 48), - (String::from("latin1_general_cs"), 49), - (String::from("cp1251_bin"), 50), - (String::from("cp1251_general_ci"), 51), - (String::from("cp1251_general_cs"), 52), - (String::from("macroman_bin"), 53), - (String::from("utf16_general_ci"), 54), - (String::from("utf16_bin"), 55), - (String::from("utf16le_general_ci"), 56), - (String::from("cp1256_general_ci"), 57), - (String::from("cp1257_bin"), 58), - (String::from("cp1257_general_ci"), 59), - (String::from("utf32_general_ci"), 60), - (String::from("utf32_bin"), 61), - (String::from("utf16le_bin"), 62), - (String::from("binary"), 63), - (String::from("armscii8_bin"), 64), - (String::from("ascii_bin"), 65), - (String::from("cp1250_bin"), 66), - (String::from("cp1256_bin"), 67), - (String::from("cp866_bin"), 68), - (String::from("dec8_bin"), 69), - (String::from("greek_bin"), 70), - (String::from("hebrew_bin"), 71), - (String::from("hp8_bin"), 72), - (String::from("keybcs2_bin"), 73), - (String::from("koi8r_bin"), 74), - (String::from("koi8u_bin"), 75), - (String::from("latin2_bin"), 77), - (String::from("latin5_bin"), 78), - (String::from("latin7_bin"), 79), - (String::from("cp850_bin"), 80), - (String::from("cp852_bin"), 81), - (String::from("swe7_bin"), 82), - (String::from("utf8_bin"), 83), - (String::from("big5_bin"), 84), - (String::from("euckr_bin"), 85), - (String::from("gb2312_bin"), 86), - (String::from("gbk_bin"), 87), - (String::from("sjis_bin"), 88), - (String::from("tis620_bin"), 89), - (String::from("ucs2_bin"), 90), - (String::from("ujis_bin"), 91), - (String::from("geostd8_general_ci"), 92), - (String::from("geostd8_bin"), 93), - (String::from("latin1_spanish_ci"), 94), - (String::from("cp932_japanese_ci"), 95), - (String::from("cp932_bin"), 96), - (String::from("eucjpms_japanese_ci"), 97), - (String::from("eucjpms_bin"), 98), - (String::from("cp1250_polish_ci"), 99), - (String::from("utf16_unicode_ci"), 101), - (String::from("utf16_icelandic_ci"), 102), - (String::from("utf16_latvian_ci"), 103), - (String::from("utf16_romanian_ci"), 104), - (String::from("utf16_slovenian_ci"), 105), - (String::from("utf16_polish_ci"), 106), - (String::from("utf16_estonian_ci"), 107), - (String::from("utf16_spanish_ci"), 108), - (String::from("utf16_swedish_ci"), 109), - (String::from("utf16_turkish_ci"), 110), - (String::from("utf16_czech_ci"), 111), - (String::from("utf16_danish_ci"), 112), - (String::from("utf16_lithuanian_ci"), 113), - (String::from("utf16_slovak_ci"), 114), - (String::from("utf16_spanish2_ci"), 115), - (String::from("utf16_roman_ci"), 116), - (String::from("utf16_persian_ci"), 117), - (String::from("utf16_esperanto_ci"), 118), - (String::from("utf16_hungarian_ci"), 119), - (String::from("utf16_sinhala_ci"), 120), - (String::from("utf16_german2_ci"), 121), - (String::from("utf16_croatian_ci"), 122), - (String::from("utf16_unicode_520_ci"), 123), - (String::from("utf16_vietnamese_ci"), 124), - (String::from("ucs2_unicode_ci"), 128), - (String::from("ucs2_icelandic_ci"), 129), - (String::from("ucs2_latvian_ci"), 130), - (String::from("ucs2_romanian_ci"), 131), - (String::from("ucs2_slovenian_ci"), 132), - (String::from("ucs2_polish_ci"), 133), - (String::from("ucs2_estonian_ci"), 134), - (String::from("ucs2_spanish_ci"), 135), - (String::from("ucs2_swedish_ci"), 136), - (String::from("ucs2_turkish_ci"), 137), - (String::from("ucs2_czech_ci"), 138), - (String::from("ucs2_danish_ci"), 139), - (String::from("ucs2_lithuanian_ci"), 140), - (String::from("ucs2_slovak_ci"), 141), - (String::from("ucs2_spanish2_ci"), 142), - (String::from("ucs2_roman_ci"), 143), - (String::from("ucs2_persian_ci"), 144), - (String::from("ucs2_esperanto_ci"), 145), - (String::from("ucs2_hungarian_ci"), 146), - (String::from("ucs2_sinhala_ci"), 147), - (String::from("ucs2_german2_ci"), 148), - (String::from("ucs2_croatian_ci"), 149), - (String::from("ucs2_unicode_520_ci"), 150), - (String::from("ucs2_vietnamese_ci"), 151), - (String::from("ucs2_general_mysql500_ci"), 159), - (String::from("utf32_unicode_ci"), 160), - (String::from("utf32_icelandic_ci"), 161), - (String::from("utf32_latvian_ci"), 162), - (String::from("utf32_romanian_ci"), 163), - (String::from("utf32_slovenian_ci"), 164), - (String::from("utf32_polish_ci"), 165), - (String::from("utf32_estonian_ci"), 166), - (String::from("utf32_spanish_ci"), 167), - (String::from("utf32_swedish_ci"), 168), - (String::from("utf32_turkish_ci"), 169), - (String::from("utf32_czech_ci"), 170), - (String::from("utf32_danish_ci"), 171), - (String::from("utf32_lithuanian_ci"), 172), - (String::from("utf32_slovak_ci"), 173), - (String::from("utf32_spanish2_ci"), 174), - (String::from("utf32_roman_ci"), 175), - (String::from("utf32_persian_ci"), 176), - (String::from("utf32_esperanto_ci"), 177), - (String::from("utf32_hungarian_ci"), 178), - (String::from("utf32_sinhala_ci"), 179), - (String::from("utf32_german2_ci"), 180), - (String::from("utf32_croatian_ci"), 181), - (String::from("utf32_unicode_520_ci"), 182), - (String::from("utf32_vietnamese_ci"), 183), - (String::from("utf8_unicode_ci"), 192), - (String::from("utf8_icelandic_ci"), 193), - (String::from("utf8_latvian_ci"), 194), - (String::from("utf8_romanian_ci"), 195), - (String::from("utf8_slovenian_ci"), 196), - (String::from("utf8_polish_ci"), 197), - (String::from("utf8_estonian_ci"), 198), - (String::from("utf8_spanish_ci"), 199), - (String::from("utf8_swedish_ci"), 200), - (String::from("utf8_turkish_ci"), 201), - (String::from("utf8_czech_ci"), 202), - (String::from("utf8_danish_ci"), 203), - (String::from("utf8_lithuanian_ci"), 204), - (String::from("utf8_slovak_ci"), 205), - (String::from("utf8_spanish2_ci"), 206), - (String::from("utf8_roman_ci"), 207), - (String::from("utf8_persian_ci"), 208), - (String::from("utf8_esperanto_ci"), 209), - (String::from("utf8_hungarian_ci"), 210), - (String::from("utf8_sinhala_ci"), 211), - (String::from("utf8_german2_ci"), 212), - (String::from("utf8_croatian_ci"), 213), - (String::from("utf8_unicode_520_ci"), 214), - (String::from("utf8_vietnamese_ci"), 215), - (String::from("utf8_general_mysql500_ci"), 223), - (String::from("utf8mb4_unicode_ci"), 224), - (String::from("utf8mb4_icelandic_ci"), 225), - (String::from("utf8mb4_latvian_ci"), 226), - (String::from("utf8mb4_romanian_ci"), 227), - (String::from("utf8mb4_slovenian_ci"), 228), - (String::from("utf8mb4_polish_ci"), 229), - (String::from("utf8mb4_estonian_ci"), 230), - (String::from("utf8mb4_spanish_ci"), 231), - (String::from("utf8mb4_swedish_ci"), 232), - (String::from("utf8mb4_turkish_ci"), 233), - (String::from("utf8mb4_czech_ci"), 234), - (String::from("utf8mb4_danish_ci"), 235), - (String::from("utf8mb4_lithuanian_ci"), 236), - (String::from("utf8mb4_slovak_ci"), 237), - (String::from("utf8mb4_spanish2_ci"), 238), - (String::from("utf8mb4_roman_ci"), 239), - (String::from("utf8mb4_persian_ci"), 240), - (String::from("utf8mb4_esperanto_ci"), 241), - (String::from("utf8mb4_hungarian_ci"), 242), - (String::from("utf8mb4_sinhala_ci"), 243), - (String::from("utf8mb4_german2_ci"), 244), - (String::from("utf8mb4_croatian_ci"), 245), - (String::from("utf8mb4_unicode_520_ci"), 246), - (String::from("utf8mb4_vietnamese_ci"), 247), + pub static ref CHARSET_NAME_ID_MYSQL8: HashMap<&'static str, u8> = HashMap::from([ + ("big5", 1), + ("dec8", 3), + ("cp850", 4), + ("hp8", 6), + ("koi8r", 7), + ("latin1", 8), + ("latin2", 9), + ("swe7", 10), + ("ascii", 11), + ("ujis", 12), + ("sjis", 13), + ("hebrew", 16), + ("tis620", 18), + ("euckr", 19), + ("koi8u", 22), + ("gb2312", 24), + ("greek", 25), + ("cp1250", 26), + ("gbk", 28), + ("latin5", 30), + ("armscii8", 32), + ("utf8", 33), + ("ucs2", 35), + ("cp866", 36), + ("keybcs2", 37), + ("macce", 38), + ("macroman", 39), + ("cp852", 40), + ("latin7", 41), + ("cp1251", 51), + ("utf16", 54), + ("utf16le", 56), + ("cp1256", 57), + ("cp1257", 59), + ("utf32", 60), + ("binary", 63), + ("geostd8", 92), + ("cp932", 95), + ("eucjpms", 97), + ("gb18030", 248), + ("utf8mb4", 255), ]); } lazy_static! { - pub static ref DEFAULT_CHARSET: String = "utf8".to_string(); - pub static ref DEFAULT_COLLATION_NAME: String = "utf8_general_ci".to_string(); -} - -pub static DEFAULT_COLLATION_ID: CollationId = 255; - -#[test] -fn test_charset() { - assert_eq!(COLLATION_NAMES["utf8mb4_vietnamese_ci"], 247); - //assert_eq!(DEFAULT_CHARSET, "utf8"); + pub static ref CHARSET_ID_NAME_MYSQL8: HashMap = HashMap::from([ + (1, "big5"), + (3, "dec8"), + (4, "cp850"), + (6, "hp8"), + (7, "koi8r"), + (8, "latin1"), + (9, "latin2"), + (10, "swe7"), + (11, "ascii"), + (12, "ujis"), + (13, "sjis"), + (16, "hebrew"), + (18, "tis620"), + (19, "euckr"), + (22, "koi8u"), + (24, "gb2312"), + (25, "greek"), + (26, "cp1250"), + (28, "gbk"), + (30, "latin5"), + (32, "armscii8"), + (33, "utf8"), + (35, "ucs2"), + (36, "cp866"), + (37, "keybcs2"), + (38, "macce"), + (39, "macroman"), + (40, "cp852"), + (41, "latin7"), + (51, "cp1251"), + (54, "utf16"), + (56, "utf16le"), + (57, "cp1256"), + (59, "cp1257"), + (60, "utf32"), + (63, "binary"), + (92, "geostd8"), + (95, "cp932"), + (97, "eucjpms"), + (248, "gb18030"), + (255, "utf8mb4"), + ]); } diff --git a/pisa-proxy/protocol/mysql/src/client/auth.rs b/pisa-proxy/protocol/mysql/src/client/auth.rs index 4d8c1549..c1242adb 100644 --- a/pisa-proxy/protocol/mysql/src/client/auth.rs +++ b/pisa-proxy/protocol/mysql/src/client/auth.rs @@ -15,15 +15,20 @@ use std::{convert::From, str}; use byteorder::{ByteOrder, LittleEndian}; -use bytes::{BufMut, BytesMut}; +use bytes::{Buf, BufMut, BytesMut}; use futures::{SinkExt, StreamExt}; use rand::rngs::OsRng; use rsa::{pkcs8::DecodePublicKey, PaddingScheme, PublicKey, RsaPublicKey}; use sha1::Sha1; use tokio_util::codec::{Decoder, Encoder, Framed}; +use regex::Regex; use super::{codec::ClientCodec, stream::LocalStream}; -use crate::{charset::DEFAULT_COLLATION_ID, err::ProtocolError, mysql_const::*, util::*}; +use crate::{charset::*, err::ProtocolError, mysql_const::*, util::*}; + +lazy_static! { + static ref RE: Regex = Regex::new(r"^(?P\d+)\.(?P\d+)\.(?P\d+)").unwrap(); +} /// Handshake state #[derive(Debug, Clone)] @@ -45,6 +50,27 @@ impl Default for HandshakeState { } } +#[derive(Debug, Default, Clone)] +pub struct ServerVersion { + pub major: u8, + pub minor: u8, + pub patch: u8, +} + +impl From<(&str, &str, &str)> for ServerVersion { + fn from(version: (&str, &str, &str)) -> ServerVersion { + let major = version.0.parse::().unwrap(); + let minor = version.1.parse::().unwrap(); + let patch = version.2.parse::().unwrap(); + + ServerVersion { + major, + minor, + patch, + } + } +} + #[derive(Debug, Default, Clone)] pub struct ClientAuth { pub next_state: HandshakeState, @@ -52,7 +78,7 @@ pub struct ClientAuth { pub salt: Vec, pub capability: u32, pub client_capability: u32, - pub charset: u8, + pub charset: String, pub status: u16, pub auth_plugin_name: String, pub tls_config: Option<()>, @@ -60,6 +86,7 @@ pub struct ClientAuth { pub password: String, pub db: String, pub seq: u8, + pub server_version: ServerVersion, } impl ClientAuth { @@ -73,11 +100,12 @@ impl ClientAuth { status: 0, auth_plugin_name: "".to_string(), tls_config: None, - charset: 0, + charset: "".to_string(), user: "".to_string(), password: "".to_string(), db: "".to_string(), seq: 0, + server_version: ServerVersion::default(), } } @@ -97,7 +125,19 @@ impl ClientAuth { // skip server version, end with 0x00 let pos = data.iter().position(|&x| x == 0x00).unwrap(); - let _ = data.split_to(pos + 1); + let version_bytes = data.split_to(pos + 1); + let version = str::from_utf8(&version_bytes).unwrap(); + if let Some(caps) = RE.captures(version) { + let ver = ServerVersion::from( + ( + caps.name("major").unwrap().as_str(), + caps.name("minor").unwrap().as_str(), + caps.name("patch").unwrap().as_str(), + ) + ); + + self.server_version = ver; + } // connection id length is 4 self.connection_id = LittleEndian::read_u32(&data.split_to(4)); @@ -122,9 +162,14 @@ impl ClientAuth { return Ok(self.clone()); } - // skip server charset + // server charset // self.charset = data[pos] - self.charset = data.split_to(1)[0] as u8; + let charset_id = data.get_u8(); + match self.server_version.major { + 5 => self.charset = CHARSET_ID_NAME_MYSQL5[&charset_id].to_string(), + _ => self.charset = CHARSET_ID_NAME_MYSQL8[&charset_id].to_string(), + } + self.status = LittleEndian::read_u16(&data.split_to(2)); @@ -251,8 +296,11 @@ impl ClientAuth { //data[11] = 0x00; //charset [1 byte] - // data[12] = DEFAULT_COLLATION_ID as u8; - data.put_u8(DEFAULT_COLLATION_ID); + self.charset = DEFAULT_CHARSET_NAME.to_string(); + match self.server_version.major { + 5 => data.put_u8(CHARSET_NAME_ID_MYSQL5[DEFAULT_CHARSET_NAME]), + _ => data.put_u8(CHARSET_NAME_ID_MYSQL8[DEFAULT_CHARSET_NAME]), + } data.put_slice(&[0; 23]); @@ -622,7 +670,7 @@ mod test { assert_eq!(c.salt[0], 0x29); assert_eq!(c.salt[c.salt.len() - 1], 0x59); assert_eq!(c.auth_plugin_name, "caching_sha2_password".to_string()); - assert_eq!(c.charset, 0xff); + assert_eq!(c.charset, "utf8mb4"); } // test auth success with mysql_native_password plugin diff --git a/pisa-proxy/protocol/mysql/src/client/codec.rs b/pisa-proxy/protocol/mysql/src/client/codec.rs index 5dfe2628..a736a71b 100644 --- a/pisa-proxy/protocol/mysql/src/client/codec.rs +++ b/pisa-proxy/protocol/mysql/src/client/codec.rs @@ -13,6 +13,7 @@ // limitations under the License. use std::{ + ops::{Deref, DerefMut}, pin::Pin, task::{Context, Poll}, }; @@ -49,6 +50,31 @@ pub enum ClientCodec { Common(Framed), } +// Access `AuthInfo` struct by dereferencing the `ClientCodec` struct. +impl Deref for ClientCodec { + type Target = ClientAuth; + fn deref(&self) -> &Self::Target { + match self { + Self::ClientAuth(framed) => framed.codec(), + Self::Resultset(framed) => framed.codec().auth_info.as_ref().unwrap(), + Self::Stmt(framed) => framed.codec().auth_info.as_ref().unwrap(), + Self::Common(framed) => framed.codec().auth_info.as_ref().unwrap(), + } + } +} + +// Modify `AuthInfo` struct by dereferencing the `ClientCodec` struct. +impl DerefMut for ClientCodec { + fn deref_mut(&mut self) -> &mut Self::Target { + match self { + Self::ClientAuth(framed) => framed.codec_mut(), + Self::Resultset(framed) => framed.codec_mut().auth_info.as_mut().unwrap(), + Self::Stmt(framed) => framed.codec_mut().auth_info.as_mut().unwrap(), + Self::Common(framed) => framed.codec_mut().auth_info.as_mut().unwrap(), + } + } +} + #[derive(Debug)] #[pin_project] pub struct ResultsetStream<'a> { diff --git a/pisa-proxy/protocol/mysql/src/client/conn.rs b/pisa-proxy/protocol/mysql/src/client/conn.rs index 50061fe8..5d45bb12 100644 --- a/pisa-proxy/protocol/mysql/src/client/conn.rs +++ b/pisa-proxy/protocol/mysql/src/client/conn.rs @@ -34,7 +34,6 @@ use crate::{err::ProtocolError, mysql_const::*}; #[derive(Debug, Default)] pub struct ClientConn { pub framed: Option>, - pub auth_info: Option, user: String, password: String, endpoint: String, @@ -74,7 +73,6 @@ impl ClientConn { )))); let res = handshake(*(framed.take().unwrap())).await?; - let auth_info = Some(res.0.codec().clone()); let framed = Some(Box::new(ClientCodec::ClientAuth(res.0))); Ok(ClientConn { @@ -82,13 +80,11 @@ impl ClientConn { password: self.password.clone(), endpoint: self.endpoint.clone(), framed, - auth_info, }) } pub async fn handshake(&mut self) -> Result<(bool, Vec), ProtocolError> { let res = handshake(*(self.framed.take().unwrap())).await?; - self.auth_info = Some(res.0.codec().clone()); self.framed = Some(Box::new(ClientCodec::ClientAuth(res.0))); Ok((res.1, res.2)) @@ -102,7 +98,7 @@ impl ClientConn { let mut resultset_codec = framed.into_resultset(); - resultset_codec.send(ResultSendCommand::Binary((0x03, val))).await?; + resultset_codec.send(ResultSendCommand::Binary((COM_QUERY, val))).await?; self.framed = Some(Box::new(ClientCodec::Resultset(resultset_codec))); @@ -206,9 +202,22 @@ impl ClientConn { Ok(CommonStream::new(self.framed.as_mut())) } + // Send query, but discard result + pub async fn send_query_discard_result(&mut self, val: &str) -> Result<(), ProtocolError> { + let mut stream = self.send_common_command(COM_QUERY, val.as_bytes()).await?; + + while stream.next().await.is_some() {} + + Ok(()) + } + pub fn get_endpoint(&self) -> Option { Some(self.endpoint.clone()) } + + pub fn set_charset(&mut self, name: &str) { + self.framed.as_mut().unwrap().charset = name.to_string() + } } impl Clone for ClientConn { @@ -218,7 +227,6 @@ impl Clone for ClientConn { password: self.password.clone(), endpoint: self.endpoint.clone(), framed: None, - auth_info: None, } } } @@ -254,11 +262,26 @@ impl ConnAttr for ClientConn { } fn get_db(&self) -> Option { - if let Some(auth_info) = &self.auth_info { - if auth_info.db.is_empty() { + let codec = self.framed.as_ref(); + + if let Some(codec) = codec { + if codec.db.is_empty() { + None + } else { + Some(codec.db.clone()) + } + } else { + None + } + } + + fn get_charset(&self) -> Option { + let codec = self.framed.as_ref(); + if let Some(codec) = codec { + if codec.charset.is_empty() { None } else { - Some(auth_info.db.clone()) + Some(codec.charset.clone()) } } else { None diff --git a/pisa-proxy/protocol/mysql/src/server/conn.rs b/pisa-proxy/protocol/mysql/src/server/conn.rs index 80cab551..2bbb8e54 100644 --- a/pisa-proxy/protocol/mysql/src/server/conn.rs +++ b/pisa-proxy/protocol/mysql/src/server/conn.rs @@ -52,14 +52,13 @@ const DEFAULT_CAPABILITY: u32 = CLIENT_LONG_PASSWORD pub struct Connection { salt: Vec, status: u16, - collation: CollationId, capability: u32, connection_id: u32, - _charset: String, user: String, password: String, auth_data: BytesMut, pub auth_plugin_name: String, + pub charset: String, pub db: String, pub affected_rows: i64, pub pkt: Packet, @@ -72,10 +71,9 @@ impl Connection { Connection { salt: crate::util::random_buf(20), status: SERVER_STATUS_AUTOCOMMIT, - collation: DEFAULT_COLLATION_ID, capability: 0, connection_id: CONNECTION_ID.load(Ordering::Relaxed), - _charset: DEFAULT_CHARSET.to_string(), + charset: DEFAULT_CHARSET_NAME.to_string(), auth_plugin_name: "".to_string(), user, password, @@ -141,10 +139,7 @@ impl Connection { data.put_u8((DEFAULT_CAPABILITY >> 8) as u8); //charset, utf-8 default - if self.collation == 0 { - self.collation = DEFAULT_COLLATION_ID; - } - data.put_u8(self.collation); + data.put_u8(CHARSET_NAME_ID_MYSQL5[&*self.charset]); //status data.put_u8(self.status as u8); diff --git a/pisa-proxy/proxy/pool/src/conn_pool.rs b/pisa-proxy/proxy/pool/src/conn_pool.rs index c952e6f1..27b1a10e 100644 --- a/pisa-proxy/proxy/pool/src/conn_pool.rs +++ b/pisa-proxy/proxy/pool/src/conn_pool.rs @@ -39,6 +39,8 @@ pub trait ConnAttr { fn get_endpoint(&self) -> String; // Get current db on conn fn get_db(&self) -> Option; + // Get current charset + fn get_charset(&self) -> Option; } #[derive(Debug)] diff --git a/pisa-proxy/runtime/mysql/src/server/server.rs b/pisa-proxy/runtime/mysql/src/server/server.rs index 8ccc6b60..922b2c4b 100644 --- a/pisa-proxy/runtime/mysql/src/server/server.rs +++ b/pisa-proxy/runtime/mysql/src/server/server.rs @@ -21,7 +21,7 @@ use conn_pool::Pool; use futures::StreamExt; use loadbalance::balance::BalanceType; use mysql_parser::{ - ast::{BeginStmt, SqlStmt}, + ast::{BeginStmt, SqlStmt, SetOptValues, SetOpts}, parser::{ParseError, Parser}, }; use mysql_protocol::{ @@ -326,26 +326,50 @@ impl MySqlServer { "COM_QUERY", client_conn.get_endpoint().unwrap().as_str() ); - let sql = str::from_utf8(payload).unwrap(); - let stream = client_conn.send_query(payload).await?; - let res = match self.get_ast(payload) { + let stream = match self.get_ast(payload) { Err(err) => { error!("err: {:?}", err); - self.handle_query_resultset(stream).await + client_conn.send_query(payload).await? } - Ok(stmt) => match &stmt[0].clone() { + Ok(stmt) => match &stmt[0] { + SqlStmt::Set(stmt) => { + self.handle_set_stmt(stmt); + client_conn.send_query(payload).await? + }, //TODO: split sql stmt for sql audit - SqlStmt::BeginStmt(stmt) => self.handle_begin_stmt(stream, stmt, sql).await, - _ => self.handle_query_resultset(stream).await, + SqlStmt::BeginStmt(_stmt) => client_conn.send_query(payload).await?, + _ => client_conn.send_query(payload).await?, }, }; + + self.handle_query_resultset(stream).await?; + let ep = client_conn.get_endpoint().unwrap(); self.trans_fsm.put_conn(client_conn); collect_sql_under_processing_dec!(self, "COM_QUERY", ep.as_str()); collect_sql_processed_duration!(self, "COM_QUERY", ep.as_str(), earlier); - res + Ok(()) + } + + // Set charset name + fn handle_set_stmt(&mut self, stmt: &SetOptValues) { + match stmt { + SetOptValues::OptValues(vals) => { + match &vals.opt { + SetOpts::SetNames(name) => { + if let Some(name) = &name.charset_name { + self.client.charset = name.clone(); + self.trans_fsm.set_charset(name.clone()) + } + }, + _ => {} + } + }, + + _ => {} + } } pub async fn handle_query_resultset<'b>( @@ -459,11 +483,11 @@ impl MySqlServer { self.client.pkt.conn.shutdown().await.map_err(ProtocolError::Io) } - async fn handle_begin_stmt<'b>( + // TODO, add handle for begin stmt + async fn _handle_begin_stmt<'b>( &mut self, stream: ResultsetStream<'b>, _stmt: &BeginStmt, - _sql: &str, ) -> Result<(), ProtocolError> { if let Err(err) = self.trans_fsm.trigger(TransEventName::StartEvent).await { error!("err: {:?}", err); diff --git a/pisa-proxy/runtime/mysql/src/transaction_fsm.rs b/pisa-proxy/runtime/mysql/src/transaction_fsm.rs index 324ee314..be5a1c56 100644 --- a/pisa-proxy/runtime/mysql/src/transaction_fsm.rs +++ b/pisa-proxy/runtime/mysql/src/transaction_fsm.rs @@ -231,6 +231,7 @@ pub struct TransFsm { pub client_conn: Option>, pub endpoint: Option, pub db: Option, + pub charset: String, } impl TransFsm { @@ -244,6 +245,7 @@ impl TransFsm { client_conn: None, endpoint: None, db: None, + charset: String::from("utf8mb4"), } } @@ -252,17 +254,13 @@ impl TransFsm { if event.name == state_name && event.src_state == self.current_state { match event.src_state { TransState::TransDummyState => { - let (mut client_conn, endpoint) = event + let (client_conn, endpoint) = event .driver .as_ref() .unwrap() .get_driver_conn(self.lb.clone(), &mut self.pool) .await?; - if let None = client_conn.get_db() { - if let Some(db) = &self.db { - client_conn.send_use_db(db).await.map_err(ErrorKind::Protocol)?; - } - } + self.client_conn = Some(client_conn); self.endpoint = endpoint; } @@ -281,13 +279,24 @@ impl TransFsm { self.db = Some(db) } + // Set current charset + pub fn set_charset(&mut self, name: String) { + self.charset = name; + } + pub async fn get_conn(&mut self) -> Result, Error> { let conn = self.client_conn.take(); let addr = self.endpoint.as_ref().unwrap().addr.as_ref(); match conn { - Some(client_conn) => Ok(client_conn), + Some(mut client_conn) => { + self.init_session_attr(&mut client_conn).await?; + Ok(client_conn) + } None => match self.pool.get_conn_with_opts(addr).await { - Ok(client_conn) => Ok(client_conn), + Ok(mut client_conn) => { + self.init_session_attr(&mut client_conn).await?; + Ok(client_conn) + }, Err(err) => Err(Error::new(ErrorKind::Protocol(err))), }, } @@ -296,4 +305,25 @@ impl TransFsm { pub fn put_conn(&mut self, conn: PoolConn) { self.client_conn = Some(conn) } + + // init session attrs, db and charset attr + #[inline] + async fn init_session_attr(&mut self, conn: &mut PoolConn) -> Result<(), Error> { + // set db + if self.db != conn.get_db() { + if let Some(db) = &self.db { + conn.send_use_db(db).await.map_err(ErrorKind::Protocol)?; + } + } + + //set charset + if Some(&self.charset) != conn.get_charset().as_ref() { + conn.set_charset(&self.charset); + conn.send_query_discard_result(&format!("SET NAMES {}", &self.charset)) + .await + .map_err(ErrorKind::Protocol)?; + } + + Ok(()) + } }