Skip to content

Commit

Permalink
Added Unicode \u and \U parsing to the cli (#4492)
Browse files Browse the repository at this point in the history
* Added unicode \u and \U parsing to the cli

* Added tests

* Minor test fixes

* Minor fixes
  • Loading branch information
MSebanc authored and ray6080 committed Dec 17, 2024
1 parent 72d2ace commit c36c9ba
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 4 deletions.
68 changes: 64 additions & 4 deletions tools/shell/embedded_shell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,57 @@ EmbeddedShell::EmbeddedShell(std::shared_ptr<Database> database, std::shared_ptr
}
}

std::string decodeEscapeSequences(const std::string& input) {
std::regex unicodeRegex(R"(\\u([0-9A-Fa-f]{4})|\\U([0-9A-Fa-f]{8}))");
std::string result = input;
std::smatch match;
while (std::regex_search(result, match, unicodeRegex)) {
std::string codepointStr;
if (match[1].matched) {
codepointStr = match[1].str();
} else if (match[2].matched) {
codepointStr = match[2].str();
}
uint32_t codepoint = static_cast<uint32_t>(std::stoull(codepointStr, nullptr, 16));
if (codepoint > 0x10FFFF) {
throw std::runtime_error("Invalid Unicode codepoint");
}
if (codepoint == 0) {
throw std::runtime_error("Null character not allowed");
}
// Check for surrogate pairs
if (0xD800 <= codepoint && codepoint <= 0xDBFF) {
// High surrogate, look for the next low surrogate
std::smatch nextMatch;
std::string remainingString = result.substr(match.position() + match.length());
if (std::regex_search(remainingString, nextMatch, unicodeRegex)) {
std::string nextCodepointStr = nextMatch[1].str();
int nextCodepoint = std::stoi(nextCodepointStr, nullptr, 16);
if (0xDC00 <= nextCodepoint && nextCodepoint <= 0xDFFF) {
// Valid surrogate pair
codepoint = 0x10000 + ((codepoint - 0xD800) << 10) + (nextCodepoint - 0xDC00);
result.replace(match.position() + match.length(), nextMatch.length(), "");
} else {
throw std::runtime_error("Invalid surrogate pair");
}
} else {
throw std::runtime_error("Unmatched high surrogate");
}
}

// Convert codepoint to UTF-8
char utf8Char[5] = {0}; // UTF-8 characters can be up to 4 bytes + null terminator
int size = 0;
if (!Utf8Proc::codepointToUtf8(codepoint, size, utf8Char)) {
throw std::runtime_error("Failed to convert codepoint to UTF-8");
}

// Replace the escape sequence with the actual UTF-8 character
result.replace(match.position(), match.length(), std::string(utf8Char, size));
}
return result;
}

std::vector<std::unique_ptr<QueryResult>> EmbeddedShell::processInput(std::string input) {
std::string query;
std::stringstream ss;
Expand All @@ -413,13 +464,22 @@ std::vector<std::unique_ptr<QueryResult>> EmbeddedShell::processInput(std::strin
continueLine = false;
}
input = input.erase(input.find_last_not_of(" \t\n\r\f\v") + 1);
// Decode escape sequences
std::string unicodeInput;
try {
unicodeInput = decodeEscapeSequences(input);
} catch (std::exception& e) {
printf("Error: %s\n", e.what());
historyLine = input;
return queryResults;
}
// process shell commands
if (!continueLine && input[0] == ':') {
processShellCommands(input);
if (!continueLine && unicodeInput[0] == ':') {
processShellCommands(unicodeInput);
// process queries
} else if (!input.empty() && cypherComplete((char*)input.c_str())) {
} else if (!unicodeInput.empty() && cypherComplete((char*)unicodeInput.c_str())) {
ss.clear();
ss.str(input);
ss.str(unicodeInput);
while (getline(ss, query, ';')) {
queryResults.push_back(conn->query(query));
}
Expand Down
28 changes: 28 additions & 0 deletions tools/shell/test/test_shell_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,31 @@ def test_shell_auto_completion(temp_db) -> None:
test.send_statement("() ret\t")
test.send_finished_statement(" *;\n")
assert test.shell_process.expect_exact(["\u2502 coolTable \u2502", pexpect.EOF]) == 0


def test_shell_unicode_input(temp_db) -> None:
test = (
ShellTest()
.add_argument(temp_db)
.statement('CREATE NODE TABLE IF NOT EXISTS `B\\u00fccher` (title STRING, price INT64, PRIMARY KEY (title));\n')
.statement("CREATE (n:`B\\u00fccher` {title: 'Der Thron der Sieben K\\00f6nigreiche'}) SET n.price = 20;\n")
.statement('MATCH (n:B\\u00fccher) RETURN label(n);\n')
.statement('return "\\uD83D\\uDE01";\n') # surrogate pair for grinning face emoji
.statement('return "\\U0001F601";\n') # grinning face emoji
.statement('return "\\uD83D";\n') # unmatched surrogate pair
.statement('return "\\uDE01";\n') # unmatched surrogate pair
.statement('return "\\uD83D\\uDBFF";\n') # bad lower surrogate
.statement('return "\\u000";\n') # bad unicode codepoint
.statement('return "\\u0000";\n') # Null character
.statement('return "\\U00110000";\n') # Invalid codepoint
)
result = test.run()
result.check_stdout("\u2502 B\u00fccher")
result.check_stdout("\u2502 \U0001F601") # grinning face emoji
result.check_stdout("\u2502 \U0001F601") # grinning face emoji
result.check_stdout("Error: Unmatched high surrogate")
result.check_stdout("Error: Failed to convert codepoint to UTF-8")
result.check_stdout("Error: Invalid surrogate pair")
result.check_stdout("Error: Parser exception: Invalid input <return \">: expected rule oC_RegularQuery (line: 1, offset: 7)")
result.check_stdout("Error: Null character not allowed")
result.check_stdout("Error: Invalid Unicode codepoint")

0 comments on commit c36c9ba

Please sign in to comment.