diff --git a/tools/shell/embedded_shell.cpp b/tools/shell/embedded_shell.cpp
index a5251cb6a90..438c15277cc 100644
--- a/tools/shell/embedded_shell.cpp
+++ b/tools/shell/embedded_shell.cpp
@@ -401,6 +401,66 @@ EmbeddedShell::EmbeddedShell(std::shared_ptr<Database> database, std::shared_ptr
     }
 }
 
+// Decodes the shell's Unicode escapes (backslash + 'u' + 4 hex digits, or backslash + 'U' + 8 hex
+// digits) in `input` into UTF-8 in a single left-to-right pass.
+// A UTF-16 high surrogate (D800-DBFF) must be followed immediately by a 4-digit low surrogate
+// escape (DC00-DFFF); the two are combined into one codepoint. Decoded output is never rescanned,
+// so an escape that yields a backslash cannot start a new escape sequence.
+// Throws std::runtime_error for codepoint 0, codepoints above 0x10FFFF, malformed surrogate pairs
+// and codepoints for which Utf8Proc::codepointToUtf8 reports failure.
+std::string decodeEscapeSequences(const std::string& input) {
+    // Built once: compiling a std::regex is costly and the pattern never changes.
+    static const std::regex unicodeRegex(R"(\\u([0-9A-Fa-f]{4})|\\U([0-9A-Fa-f]{8}))");
+    std::string result;
+    result.reserve(input.size());
+    auto cursor = input.cbegin();
+    const auto end = input.cend();
+    std::smatch match;
+    while (std::regex_search(cursor, end, match, unicodeRegex)) {
+        // Keep the literal text in front of the escape sequence as is.
+        result.append(cursor, match[0].first);
+        cursor = match[0].second;
+        const std::string codepointStr = match[1].matched ? match[1].str() : match[2].str();
+        auto codepoint = static_cast<uint32_t>(std::stoull(codepointStr, nullptr, 16));
+        if (codepoint > 0x10FFFF) {
+            throw std::runtime_error("Invalid Unicode codepoint");
+        }
+        if (codepoint == 0) {
+            throw std::runtime_error("Null character not allowed");
+        }
+        if (0xD800 <= codepoint && codepoint <= 0xDBFF) {
+            // High surrogate: the low surrogate has to start exactly where this escape ended.
+            std::smatch lowMatch;
+            if (!std::regex_search(cursor, end, lowMatch, unicodeRegex,
+                    std::regex_constants::match_continuous)) {
+                throw std::runtime_error("Unmatched high surrogate");
+            }
+            // Only the 4-digit form can carry a surrogate; group 1 is unmatched for the 8-digit form.
+            if (!lowMatch[1].matched) {
+                throw std::runtime_error("Invalid surrogate pair");
+            }
+            const auto lowSurrogate =
+                static_cast<uint32_t>(std::stoull(lowMatch[1].str(), nullptr, 16));
+            if (lowSurrogate < 0xDC00 || lowSurrogate > 0xDFFF) {
+                throw std::runtime_error("Invalid surrogate pair");
+            }
+            codepoint = 0x10000 + ((codepoint - 0xD800) << 10) + (lowSurrogate - 0xDC00);
+            cursor = lowMatch[0].second;
+        }
+
+        // Convert codepoint to UTF-8
+        char utf8Char[5] = {0}; // UTF-8 characters can be up to 4 bytes + null terminator
+        int size = 0;
+        if (!Utf8Proc::codepointToUtf8(codepoint, size, utf8Char)) {
+            throw std::runtime_error("Failed to convert codepoint to UTF-8");
+        }
+        result.append(utf8Char, size);
+    }
+    // Copy whatever follows the last escape sequence.
+    result.append(cursor, end);
+    return result;
+}
+
 std::vector<std::unique_ptr<QueryResult>> EmbeddedShell::processInput(std::string input) {
     std::string query;
     std::stringstream ss;
@@ -413,13 +473,22 @@ std::vector<std::unique_ptr<QueryResult>> EmbeddedShell::processInput(std::strin
         continueLine = false;
     }
     input = input.erase(input.find_last_not_of(" \t\n\r\f\v") + 1);
+    // Decode escape sequences
+    std::string unicodeInput;
+    try {
+        unicodeInput = decodeEscapeSequences(input);
+    } catch (const std::exception& e) {
+        printf("Error: %s\n", e.what());
+        historyLine = input;
+        return queryResults;
+    }
     // process shell commands
-    if (!continueLine && input[0] == ':') {
-        processShellCommands(input);
+    if (!continueLine && unicodeInput[0] == ':') {
+        processShellCommands(unicodeInput);
     // process queries
-    } else if (!input.empty() && cypherComplete((char*)input.c_str())) {
+    } else if (!unicodeInput.empty() && cypherComplete((char*)unicodeInput.c_str())) {
         ss.clear();
-        ss.str(input);
+        ss.str(unicodeInput);
         while (getline(ss, query, ';')) {
             queryResults.push_back(conn->query(query));
         }
diff --git a/tools/shell/test/test_shell_basics.py b/tools/shell/test/test_shell_basics.py
index 92e2af9ea58..72138e867f8 100644
--- a/tools/shell/test/test_shell_basics.py
+++ b/tools/shell/test/test_shell_basics.py
@@ -235,3 +235,31 @@ def test_shell_auto_completion(temp_db) -> None:
     test.send_statement("() ret\t")
     test.send_finished_statement(" *;\n")
     assert test.shell_process.expect_exact(["\u2502 coolTable \u2502", pexpect.EOF]) == 0
+
+
+def test_shell_unicode_input(temp_db) -> None:
+    test = (
+        ShellTest()
+        .add_argument(temp_db)
+        .statement('CREATE NODE TABLE IF NOT EXISTS `B\\u00fccher` (title STRING, price INT64, PRIMARY KEY (title));\n')
+        .statement("CREATE (n:`B\\u00fccher` {title: 'Der Thron der Sieben K\\u00f6nigreiche'}) SET n.price = 20;\n")
+        .statement('MATCH (n:B\\u00fccher) RETURN label(n);\n')
+        .statement('return "\\uD83D\\uDE01";\n')  # surrogate pair for grinning face emoji
+        .statement('return "\\U0001F601";\n')  # grinning face emoji
+        .statement('return "\\uD83D";\n')  # unmatched surrogate pair
+        .statement('return "\\uDE01";\n')  # unmatched surrogate pair
+        .statement('return "\\uD83D\\uDBFF";\n')  # bad lower surrogate
+        .statement('return "\\u000";\n')  # bad unicode codepoint
+        .statement('return "\\u0000";\n')  # Null character
+        .statement('return "\\U00110000";\n')  # Invalid codepoint
+    )
+    result = test.run()
+    result.check_stdout("\u2502 B\u00fccher")
+    result.check_stdout("\u2502 \U0001F601")  # grinning face emoji
+    result.check_stdout("\u2502 \U0001F601")  # grinning face emoji
+    result.check_stdout("Error: Unmatched high surrogate")
+    result.check_stdout("Error: Failed to convert codepoint to UTF-8")
+    result.check_stdout("Error: Invalid surrogate pair")
+    result.check_stdout("Error: Parser exception: Invalid input : expected rule oC_RegularQuery (line: 1, offset: 7)")
+    result.check_stdout("Error: Null character not allowed")
+    result.check_stdout("Error: Invalid Unicode codepoint")