Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Unicode \u and \U parsing to the cli #4492

Merged
merged 4 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 64 additions & 4 deletions tools/shell/embedded_shell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,57 @@ EmbeddedShell::EmbeddedShell(std::shared_ptr<Database> database, std::shared_ptr
}
}

/**
 * Replaces every `\uXXXX` (4 hex digits) and `\UXXXXXXXX` (8 hex digits) escape sequence in
 * `input` with the UTF-8 encoding of the code point it names.
 *
 * The input is decoded in a single left-to-right pass: text produced by decoding is never
 * scanned again, so a decoded backslash cannot combine with following text into a new escape.
 *
 * A `\u` high surrogate (D800-DBFF) must be followed *immediately* by a `\u` low surrogate
 * (DC00-DFFF); the pair is combined into one supplementary-plane code point.
 *
 * Text that does not form a complete escape (e.g. `\u000`) is copied through unchanged.
 *
 * @param input text possibly containing Unicode escape sequences.
 * @return a copy of `input` with all escape sequences decoded.
 * @throws std::runtime_error
 *   - "Invalid Unicode codepoint": the value is greater than 0x10FFFF;
 *   - "Null character not allowed": the value is 0;
 *   - "Unmatched high surrogate": a high surrogate is not directly followed by an escape;
 *   - "Invalid surrogate pair": the escape after a high surrogate is not a `\u` low surrogate;
 *   - "Failed to convert codepoint to UTF-8": Utf8Proc::codepointToUtf8 reported failure.
 */
std::string decodeEscapeSequences(const std::string& input) {
    // Compiled once: constructing a std::regex is expensive and the pattern never changes.
    static const std::regex unicodeRegex(R"(\\u([0-9A-Fa-f]{4})|\\U([0-9A-Fa-f]{8}))");

    std::string result;
    result.reserve(input.size());

    auto cursor = input.cbegin();
    const auto inputEnd = input.cend();
    std::smatch match;
    while (std::regex_search(cursor, inputEnd, match, unicodeRegex)) {
        // Copy the literal text that precedes this escape sequence unchanged.
        result.append(cursor, match[0].first);
        auto next = match[0].second;

        // Exactly one of the two capture groups participates in a successful match.
        const std::string codepointStr = match[1].matched ? match[1].str() : match[2].str();
        uint32_t codepoint = static_cast<uint32_t>(std::stoull(codepointStr, nullptr, 16));
        if (codepoint > 0x10FFFF) {
            throw std::runtime_error("Invalid Unicode codepoint");
        }
        if (codepoint == 0) {
            throw std::runtime_error("Null character not allowed");
        }

        // A high surrogate is only meaningful as the first half of a UTF-16 pair.
        if (0xD800 <= codepoint && codepoint <= 0xDBFF) {
            // match_continuous anchors the search at `next`, so only an escape that directly
            // follows the high surrogate is considered; escapes further along are unrelated.
            std::smatch lowMatch;
            if (!std::regex_search(next, inputEnd, lowMatch, unicodeRegex,
                    std::regex_constants::match_continuous)) {
                throw std::runtime_error("Unmatched high surrogate");
            }
            // Only the 4-digit `\u` form can express a low surrogate; a `\U` escape here
            // (group 1 not matched) can never complete the pair.
            if (!lowMatch[1].matched) {
                throw std::runtime_error("Invalid surrogate pair");
            }
            const uint32_t lowSurrogate =
                static_cast<uint32_t>(std::stoull(lowMatch[1].str(), nullptr, 16));
            if (lowSurrogate < 0xDC00 || lowSurrogate > 0xDFFF) {
                throw std::runtime_error("Invalid surrogate pair");
            }
            codepoint = 0x10000 + ((codepoint - 0xD800) << 10) + (lowSurrogate - 0xDC00);
            // Consume the low-surrogate escape together with the high one.
            next = lowMatch[0].second;
        }

        // Convert codepoint to UTF-8.
        // NOTE(review): a lone low surrogate (DC00-DFFF) is not rejected above; the shell test
        // expects "Failed to convert codepoint to UTF-8" for it, so presumably
        // Utf8Proc::codepointToUtf8 refuses surrogate values -- confirm against Utf8Proc.
        char utf8Char[5] = {0}; // UTF-8 characters can be up to 4 bytes + null terminator
        int size = 0;
        if (!Utf8Proc::codepointToUtf8(codepoint, size, utf8Char)) {
            throw std::runtime_error("Failed to convert codepoint to UTF-8");
        }

        result.append(utf8Char, size);
        cursor = next;
    }
    // Copy whatever literal text remains after the last escape sequence.
    result.append(cursor, inputEnd);
    return result;
}

std::vector<std::unique_ptr<QueryResult>> EmbeddedShell::processInput(std::string input) {
std::string query;
std::stringstream ss;
Expand All @@ -413,13 +464,22 @@ std::vector<std::unique_ptr<QueryResult>> EmbeddedShell::processInput(std::strin
continueLine = false;
}
input = input.erase(input.find_last_not_of(" \t\n\r\f\v") + 1);
// Decode escape sequences
std::string unicodeInput;
try {
unicodeInput = decodeEscapeSequences(input);
} catch (std::exception& e) {
printf("Error: %s\n", e.what());
historyLine = input;
return queryResults;
}
// process shell commands
if (!continueLine && input[0] == ':') {
processShellCommands(input);
if (!continueLine && unicodeInput[0] == ':') {
processShellCommands(unicodeInput);
// process queries
} else if (!input.empty() && cypherComplete((char*)input.c_str())) {
} else if (!unicodeInput.empty() && cypherComplete((char*)unicodeInput.c_str())) {
ss.clear();
ss.str(input);
ss.str(unicodeInput);
while (getline(ss, query, ';')) {
queryResults.push_back(conn->query(query));
}
Expand Down
28 changes: 28 additions & 0 deletions tools/shell/test/test_shell_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,31 @@ def test_shell_auto_completion(temp_db) -> None:
test.send_statement("() ret\t")
test.send_finished_statement(" *;\n")
assert test.shell_process.expect_exact(["\u2502 coolTable \u2502", pexpect.EOF]) == 0


def test_shell_unicode_input(temp_db) -> None:
    """Check that the shell decodes \\uXXXX / \\UXXXXXXXX escapes typed at the prompt.

    Sends statements containing valid escapes (BMP characters, a surrogate pair and an
    8-digit escape) and malformed ones, then checks stdout for the decoded characters and
    for the error message expected for each malformed case.
    """
    test = (
        ShellTest()
        .add_argument(temp_db)
        .statement('CREATE NODE TABLE IF NOT EXISTS `B\\u00fccher` (title STRING, price INT64, PRIMARY KEY (title));\n')
        # The title uses \u00f6 (o with diaeresis); the escape previously lacked its `u`.
        .statement("CREATE (n:`B\\u00fccher` {title: 'Der Thron der Sieben K\\u00f6nigreiche'}) SET n.price = 20;\n")
        .statement('MATCH (n:B\\u00fccher) RETURN label(n);\n')
        .statement('return "\\uD83D\\uDE01";\n')  # surrogate pair for grinning face emoji
        .statement('return "\\U0001F601";\n')  # grinning face emoji
        .statement('return "\\uD83D";\n')  # high surrogate with no low surrogate after it
        .statement('return "\\uDE01";\n')  # lone low surrogate
        .statement('return "\\uD83D\\uDBFF";\n')  # high surrogate followed by another high surrogate
        .statement('return "\\u000";\n')  # only 3 hex digits: not an escape, left for the parser
        .statement('return "\\u0000";\n')  # Null character
        .statement('return "\\U00110000";\n')  # code point above U+10FFFF
    )
    result = test.run()
    result.check_stdout("\u2502 B\u00fccher")
    result.check_stdout("\u2502 \U0001F601")  # grinning face emoji
    result.check_stdout("\u2502 \U0001F601")  # grinning face emoji
    result.check_stdout("Error: Unmatched high surrogate")
    result.check_stdout("Error: Failed to convert codepoint to UTF-8")
    result.check_stdout("Error: Invalid surrogate pair")
    result.check_stdout("Error: Parser exception: Invalid input <return \">: expected rule oC_RegularQuery (line: 1, offset: 7)")
    result.check_stdout("Error: Null character not allowed")
    result.check_stdout("Error: Invalid Unicode codepoint")