From 933ca42566accd7aee3d51c3342f983ca988dfc1 Mon Sep 17 00:00:00 2001 From: MSebanc Date: Fri, 8 Nov 2024 18:22:20 -0800 Subject: [PATCH 1/4] Added unicode \u and \U parsing to the cli --- tools/shell/embedded_shell.cpp | 63 +++++++++++++++++++++++++++++++--- 1 file changed, 59 insertions(+), 4 deletions(-) diff --git a/tools/shell/embedded_shell.cpp b/tools/shell/embedded_shell.cpp index a5251cb6a90..12e159a4ab2 100644 --- a/tools/shell/embedded_shell.cpp +++ b/tools/shell/embedded_shell.cpp @@ -401,6 +401,52 @@ EmbeddedShell::EmbeddedShell(std::shared_ptr database, std::shared_ptr } } +std::string decodeEscapeSequences(const std::string& input) { + std::regex unicodeRegex(R"(\\u([0-9A-Fa-f]{4})|\\U([0-9A-Fa-f]{8}))"); + std::string result = input; + std::smatch match; + while (std::regex_search(result, match, unicodeRegex)) { + std::string codepointStr; + if (match[1].matched) { + codepointStr = match[1].str(); + } else if (match[2].matched) { + codepointStr = match[2].str(); + } + int codepoint = std::stoi(codepointStr, nullptr, 16); + + // Check for surrogate pairs + if (0xD800 <= codepoint && codepoint <= 0xDBFF) { + // High surrogate, look for the next low surrogate + std::smatch nextMatch; + std::string remainingString = result.substr(match.position() + match.length()); + if (std::regex_search(remainingString, nextMatch, unicodeRegex)) { + std::string nextCodepointStr = nextMatch[1].str(); + int nextCodepoint = std::stoi(nextCodepointStr, nullptr, 16); + if (0xDC00 <= nextCodepoint && nextCodepoint <= 0xDFFF) { + // Valid surrogate pair + codepoint = 0x10000 + ((codepoint - 0xD800) << 10) + (nextCodepoint - 0xDC00); + result.replace(match.position() + match.length(), nextMatch.length(), ""); + } else { + throw std::runtime_error("Error: Invalid surrogate pair"); + } + } else { + throw std::runtime_error("Error: Unmatched high surrogate"); + } + } + + // Convert codepoint to UTF-8 + char utf8Char[5] = {0}; // UTF-8 characters can be up to 4 bytes + null terminator + int size = 0; + if (!Utf8Proc::codepointToUtf8(codepoint, size, utf8Char)) { + throw std::runtime_error("Error: Failed to convert codepoint to UTF-8"); + } + + // Replace the escape sequence with the actual UTF-8 character + result.replace(match.position(), match.length(), std::string(utf8Char, size)); + } + return result; +} + std::vector> EmbeddedShell::processInput(std::string input) { std::string query; std::stringstream ss; @@ -413,13 +459,22 @@ std::vector> EmbeddedShell::processInput(std::strin continueLine = false; } input = input.erase(input.find_last_not_of(" \t\n\r\f\v") + 1); + // Decode escape sequences + std::string unicodeInput; + try { + unicodeInput = decodeEscapeSequences(input); + } catch (std::exception& e) { + printf("Error: %s\n", e.what()); + historyLine = input; + return queryResults; + } // process shell commands - if (!continueLine && input[0] == ':') { - processShellCommands(input); + if (!continueLine && unicodeInput[0] == ':') { + processShellCommands(unicodeInput); // process queries - } else if (!input.empty() && cypherComplete((char*)input.c_str())) { + } else if (!unicodeInput.empty() && cypherComplete((char*)unicodeInput.c_str())) { ss.clear(); - ss.str(input); + ss.str(unicodeInput); while (getline(ss, query, ';')) { queryResults.push_back(conn->query(query)); } From 37a386837dc2a43d053b6150b3d3ada9485f05ba Mon Sep 17 00:00:00 2001 From: MSebanc Date: Tue, 12 Nov 2024 14:17:56 -0800 Subject: [PATCH 2/4] Added tests --- tools/shell/embedded_shell.cpp | 6 +++--- tools/shell/test/test_shell_basics.py | 24 ++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/tools/shell/embedded_shell.cpp b/tools/shell/embedded_shell.cpp index 12e159a4ab2..9bc210d7e3a 100644 --- a/tools/shell/embedded_shell.cpp +++ b/tools/shell/embedded_shell.cpp @@ -427,10 +427,10 @@ std::string decodeEscapeSequences(const std::string& input) { codepoint = 0x10000 + ((codepoint - 0xD800) << 10) + (nextCodepoint - 0xDC00); result.replace(match.position() + match.length(), nextMatch.length(), ""); } else { - throw std::runtime_error("Error: Invalid surrogate pair"); + throw std::runtime_error("Invalid surrogate pair"); } } else { - throw std::runtime_error("Error: Unmatched high surrogate"); + throw std::runtime_error("Unmatched high surrogate"); } } @@ -438,7 +438,7 @@ std::string decodeEscapeSequences(const std::string& input) { char utf8Char[5] = {0}; // UTF-8 characters can be up to 4 bytes + null terminator int size = 0; if (!Utf8Proc::codepointToUtf8(codepoint, size, utf8Char)) { - throw std::runtime_error("Error: Failed to convert codepoint to UTF-8"); + throw std::runtime_error("Failed to convert codepoint to UTF-8"); } // Replace the escape sequence with the actual UTF-8 character diff --git a/tools/shell/test/test_shell_basics.py b/tools/shell/test/test_shell_basics.py index 92e2af9ea58..cb627243336 100644 --- a/tools/shell/test/test_shell_basics.py +++ b/tools/shell/test/test_shell_basics.py @@ -235,3 +235,27 @@ def test_shell_auto_completion(temp_db) -> None: test.send_statement("() ret\t") test.send_finished_statement(" *;\n") assert test.shell_process.expect_exact(["\u2502 coolTable \u2502", pexpect.EOF]) == 0 + + +def test_shell_unicode_input(temp_db) -> None: + test = ( + ShellTest() + .add_argument(temp_db) + .statement('CREATE NODE TABLE IF NOT EXISTS `B\\u00fccher` (title STRING, price INT64, PRIMARY KEY (title));\n') # B�cher + .statement("CREATE (n:`B\\u00fccher` {title: 'Der Thron der Sieben K�nigreiche'}) SET n.price = 20;\n") # B�cher + .statement('MATCH (n:B\\u00fccher) RETURN label(n);\n') + .statement('return "\\uD83D\\uDE01";\n') # surrogate pair for grinning face emoji + .statement('return "\\U0001F601";\n') # grinning face emoji + .statement('return "\\uD83D";\n') # unmatched surrogate pair + .statement('return "\\uDE01";\n') # unmatched surrogate pair + .statement('return "\\uD83D\\uDBFF";\n') # bad lower surrogate + .statement('return "\\u000";\n') # bad unicode codepoint + ) + result = test.run() + result.check_stdout("\u2502 B\u00fccher") + result.check_stdout("\u2502 \U0001F601") # grinning face emoji + result.check_stdout("\u2502 \U0001F601") # grinning face emoji + result.check_stdout("Error: Unmatched high surrogate") + result.check_stdout("Error: Failed to convert codepoint to UTF-8") + result.check_stdout("Error: Invalid surrogate pair") + result.check_stdout("Error: Parser exception: Invalid input : expected rule oC_RegularQuery (line: 1, offset: 7)") From 42d4b64945ea123946af21251f6edcf22fbc3d90 Mon Sep 17 00:00:00 2001 From: MSebanc Date: Tue, 12 Nov 2024 14:41:34 -0800 Subject: [PATCH 3/4] Minor test fixes --- tools/shell/test/test_shell_basics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/shell/test/test_shell_basics.py b/tools/shell/test/test_shell_basics.py index cb627243336..eb931de2b6b 100644 --- a/tools/shell/test/test_shell_basics.py +++ b/tools/shell/test/test_shell_basics.py @@ -241,8 +241,8 @@ def test_shell_unicode_input(temp_db) -> None: test = ( ShellTest() .add_argument(temp_db) - .statement('CREATE NODE TABLE IF NOT EXISTS `B\\u00fccher` (title STRING, price INT64, PRIMARY KEY (title));\n') # B�cher - .statement("CREATE (n:`B\\u00fccher` {title: 'Der Thron der Sieben K�nigreiche'}) SET n.price = 20;\n") # B�cher + .statement('CREATE NODE TABLE IF NOT EXISTS `B\\u00fccher` (title STRING, price INT64, PRIMARY KEY (title));\n') + .statement("CREATE (n:`B\\u00fccher` {title: 'Der Thron der Sieben K\\00f6nigreiche'}) SET n.price = 20;\n") .statement('MATCH (n:B\\u00fccher) RETURN label(n);\n') .statement('return "\\uD83D\\uDE01";\n') # surrogate pair for grinning face emoji .statement('return "\\U0001F601";\n') # grinning face emoji From 36a721fdbaca4e92fc17d346eddd1d598f9ec0d4 Mon Sep 17 00:00:00 2001 From: MSebanc Date: Tue, 12 Nov 2024 15:26:37 -0800 Subject: [PATCH 4/4] Minor fixes --- tools/shell/embedded_shell.cpp | 9 +++++++-- tools/shell/test/test_shell_basics.py | 4 ++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/tools/shell/embedded_shell.cpp b/tools/shell/embedded_shell.cpp index 9bc210d7e3a..438c15277cc 100644 --- a/tools/shell/embedded_shell.cpp +++ b/tools/shell/embedded_shell.cpp @@ -412,8 +412,13 @@ std::string decodeEscapeSequences(const std::string& input) { } else if (match[2].matched) { codepointStr = match[2].str(); } - int codepoint = std::stoi(codepointStr, nullptr, 16); - + uint32_t codepoint = static_cast(std::stoull(codepointStr, nullptr, 16)); + if (codepoint > 0x10FFFF) { + throw std::runtime_error("Invalid Unicode codepoint"); + } + if (codepoint == 0) { + throw std::runtime_error("Null character not allowed"); + } // Check for surrogate pairs if (0xD800 <= codepoint && codepoint <= 0xDBFF) { // High surrogate, look for the next low surrogate diff --git a/tools/shell/test/test_shell_basics.py b/tools/shell/test/test_shell_basics.py index eb931de2b6b..72138e867f8 100644 --- a/tools/shell/test/test_shell_basics.py +++ b/tools/shell/test/test_shell_basics.py @@ -250,6 +250,8 @@ def test_shell_unicode_input(temp_db) -> None: .statement('return "\\uDE01";\n') # unmatched surrogate pair .statement('return "\\uD83D\\uDBFF";\n') # bad lower surrogate .statement('return "\\u000";\n') # bad unicode codepoint + .statement('return "\\u0000";\n') # Null character + .statement('return "\\U00110000";\n') # Invalid codepoint ) result = test.run() result.check_stdout("\u2502 B\u00fccher") @@ -259,3 +261,5 @@ def test_shell_unicode_input(temp_db) -> None: result.check_stdout("Error: Failed to convert codepoint to UTF-8") result.check_stdout("Error: Invalid surrogate pair") result.check_stdout("Error: Parser exception: Invalid input : expected rule oC_RegularQuery (line: 1, offset: 7)") + result.check_stdout("Error: Null character not allowed") + result.check_stdout("Error: Invalid Unicode codepoint")