Skip to content

Commit

Permalink
feat(acl): add acl keys to acl setuser command (#2258)
Browse files Browse the repository at this point in the history
* add parsing of ACL keys
* add ACL keys to acl setuser command
  • Loading branch information
kostasrim authored Dec 8, 2023
1 parent 636507c commit b642fb6
Show file tree
Hide file tree
Showing 9 changed files with 110 additions and 40 deletions.
3 changes: 3 additions & 0 deletions src/facade/conn_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,14 @@ class ConnectionContext {
// How many async subscription sources are active: monitor and/or pubsub - at most 2.
uint8_t subscriptions;

// TODO fix inherit actual values from default
std::string authed_username{"default"};
uint32_t acl_categories{dfly::acl::ALL};
std::vector<uint64_t> acl_commands;
// Skip ACL validation, used by internal commands and commands run on admin port
bool skip_acl_validation = false;
// keys
dfly::acl::AclKeys keys{{}, true};

private:
Connection* owner_;
Expand Down
19 changes: 11 additions & 8 deletions src/facade/dragonfly_connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <absl/strings/match.h>
#include <mimalloc.h>

#include <numeric>
#include <variant>

#include "absl/strings/str_cat.h"
Expand Down Expand Up @@ -205,9 +206,12 @@ size_t Connection::MessageHandle::UsedMemory() const {
return msg.capacity();
}
size_t operator()(const AclUpdateMessagePtr& msg) {
return sizeof(AclUpdateMessage) + msg->username.capacity() * sizeof(string) +
msg->commands.capacity() * sizeof(vector<int>) +
msg->categories.capacity() * sizeof(uint32_t);
size_t key_cap = std::accumulate(
msg->keys.key_globs.begin(), msg->keys.key_globs.end(), 0, [](auto acc, auto& str) {
return acc + (str.first.capacity() * sizeof(char)) + sizeof(str.second);
});
return sizeof(AclUpdateMessage) + msg->username.capacity() * sizeof(char) +
msg->commands.capacity() * sizeof(uint64_t) + key_cap;
}
size_t operator()(const MigrationRequestMessage& msg) {
return 0;
Expand Down Expand Up @@ -240,11 +244,10 @@ void Connection::DispatchOperations::operator()(const MonitorMessage& msg) {

void Connection::DispatchOperations::operator()(const AclUpdateMessage& msg) {
if (self->cntx()) {
for (size_t id = 0; id < msg.username.size(); ++id) {
if (msg.username[id] == self->cntx()->authed_username) {
self->cntx()->acl_categories = msg.categories[id];
self->cntx()->acl_commands = msg.commands[id];
}
if (msg.username == self->cntx()->authed_username) {
self->cntx()->acl_categories = msg.categories;
self->cntx()->acl_commands = msg.commands;
self->cntx()->keys = msg.keys;
}
}
}
Expand Down
8 changes: 5 additions & 3 deletions src/facade/dragonfly_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "base/io_buf.h"
#include "core/fibers.h"
#include "facade/acl_commands_def.h"
#include "facade/facade_types.h"
#include "facade/resp_expr.h"
#include "util/connection.h"
Expand Down Expand Up @@ -101,9 +102,10 @@ class Connection : public util::Connection {

// ACL Update message, contains ACL updates to be applied to the connection.
struct AclUpdateMessage {
std::vector<std::string> username;
std::vector<uint32_t> categories;
std::vector<std::vector<uint64_t>> commands;
std::string username;
uint32_t categories;
std::vector<uint64_t> commands;
dfly::acl::AclKeys keys;
};

// Migration request message, the dispatch fiber stops to give way for thread migration.
Expand Down
39 changes: 20 additions & 19 deletions src/server/acl/acl_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,14 @@ void AclFamily::List(CmdArgList args, ConnectionContext* cntx) {
}
}

void AclFamily::StreamUpdatesToAllProactorConnections(const std::vector<std::string>& user,
const std::vector<uint32_t>& update_cat,
const NestedVector& update_commands) {
auto update_cb = [&user, &update_cat, &update_commands]([[maybe_unused]] size_t id,
util::Connection* conn) {
void AclFamily::StreamUpdatesToAllProactorConnections(const std::string& user, uint32_t update_cat,
const Commands& update_commands,
const AclKeys& update_keys) {
auto update_cb = [&]([[maybe_unused]] size_t id, util::Connection* conn) {
DCHECK(conn);
auto connection = static_cast<facade::Connection*>(conn);
DCHECK(user.size() == update_cat.size());
connection->SendAclUpdateAsync(
facade::Connection::AclUpdateMessage{user, update_cat, update_commands});
facade::Connection::AclUpdateMessage{user, update_cat, update_commands, update_keys});
};

if (main_listener_) {
Expand All @@ -97,14 +95,20 @@ using facade::ErrorReply;

void AclFamily::SetUser(CmdArgList args, ConnectionContext* cntx) {
std::string_view username = facade::ToSV(args[0]);
auto req = ParseAclSetUser(args.subspan(1), *cmd_registry_);
auto error_case = [cntx](ErrorReply&& error) { cntx->SendError(std::move(error)); };
auto update_case = [username, cntx, this](User::UpdateRequest&& req) {
auto user_with_lock = registry_->MaybeAddAndUpdateWithLock(username, std::move(req));
if (user_with_lock.exists) {
StreamUpdatesToAllProactorConnections({std::string(username)},
{user_with_lock.user.AclCategory()},
{user_with_lock.user.AclCommands()});
auto reg = registry_->GetRegistryWithWriteLock();
const bool exists = reg.registry.contains(username);
const bool has_all_keys = exists ? reg.registry.find(username)->second.Keys().all_keys : false;

auto req = ParseAclSetUser(args.subspan(1), *cmd_registry_, false, has_all_keys);

auto error_case = [cntx](ErrorReply&& error) { cntx->SendError(error); };

auto update_case = [username, &reg, cntx, this, exists](User::UpdateRequest&& req) {
auto& user = reg.registry[username];
user.Update(std::move(req));
if (exists) {
StreamUpdatesToAllProactorConnections(std::string(username), user.AclCategory(),
user.AclCommands(), user.Keys());
}
cntx->SendOk();
};
Expand Down Expand Up @@ -273,13 +277,10 @@ std::optional<facade::ErrorReply> AclFamily::LoadToRegistryFromFile(std::string_
EvictOpenConnectionsOnAllProactorsWithRegistry(registry);
registry.clear();
}
std::vector<uint32_t> categories;
NestedVector commands;

for (size_t i = 0; i < usernames.size(); ++i) {
auto& user = registry[usernames[i]];
user.Update(std::move(requests[i]));
categories.push_back(user.AclCategory());
commands.push_back(user.AclCommands());
}

if (!registry.contains("default")) {
Expand Down
8 changes: 4 additions & 4 deletions src/server/acl/acl_family.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ class AclFamily final {

// Helper function that updates all open connections and their
// respective ACL fields on all the available proactor threads
using NestedVector = std::vector<std::vector<uint64_t>>;
void StreamUpdatesToAllProactorConnections(const std::vector<std::string>& user,
const std::vector<uint32_t>& update_cat,
const NestedVector& update_commands);
using Commands = std::vector<uint64_t>;
void StreamUpdatesToAllProactorConnections(const std::string& user, uint32_t update_cat,
const Commands& update_commands,
const AclKeys& update_keys);

// Helper function that closes all open connection from the deleted user
void EvictOpenConnectionsOnAllProactors(std::string_view user);
Expand Down
59 changes: 56 additions & 3 deletions src/server/acl/helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,38 @@ std::string PrettyPrintSha(std::string_view pass, bool all) {
return absl::BytesToHexString(pass.substr(0, 15)).substr(0, 15);
};

std::optional<ParseKeyResult> MaybeParseAclKey(std::string_view command) {
if (absl::EqualsIgnoreCase(command, "ALLKEYS") || command == "~*") {
return ParseKeyResult{"", {}, true};
}

if (absl::EqualsIgnoreCase(command, "RESETKEYS")) {
return ParseKeyResult{"", {}, false, true};
}

auto op = KeyOp::READ_WRITE;

if (absl::StartsWith(command, "%RW")) {
command = command.substr(3);
} else if (absl::StartsWith(command, "%R")) {
op = KeyOp::READ;
command = command.substr(2);
} else if (absl::StartsWith(command, "%W")) {
op = KeyOp::WRITE;
command = command.substr(2);
}

if (!absl::StartsWith(command, "~")) {
return {};
}

auto key = command.substr(1);
if (key.empty()) {
return {};
}
return ParseKeyResult{std::string(key), op};
}

std::optional<std::string> MaybeParsePassword(std::string_view command, bool hashed) {
if (command == "nopass") {
return std::string(command);
Expand Down Expand Up @@ -190,7 +222,7 @@ using facade::ErrorReply;
template <typename T>
std::variant<User::UpdateRequest, ErrorReply> ParseAclSetUser(T args,
const CommandRegistry& registry,
bool hashed) {
bool hashed, bool has_all_keys) {
User::UpdateRequest req;

for (auto& arg : args) {
Expand All @@ -202,6 +234,26 @@ std::variant<User::UpdateRequest, ErrorReply> ParseAclSetUser(T args,
req.is_hashed = hashed;
continue;
}

if (auto res = MaybeParseAclKey(facade::ToSV(arg)); res) {
auto& [glob, op, all_keys, reset_keys] = *res;
if ((has_all_keys && !all_keys && !reset_keys) ||
(req.allow_all_keys && !all_keys && !reset_keys)) {
return ErrorReply(
"Error in ACL SETUSER modifier '~tmp': Adding a pattern after the * pattern (or the "
"'allkeys' flag) is not valid and does not have any effect. Try 'resetkeys' to start "
"with an empty list of patterns");
}

req.allow_all_keys = all_keys;
req.reset_all_keys = reset_keys;
if (reset_keys) {
has_all_keys = false;
}
req.keys.push_back({std::move(glob), op, all_keys, reset_keys});
continue;
}

std::string buffer;
std::string_view command;
if constexpr (std::is_same_v<T, facade::CmdArgList>) {
Expand Down Expand Up @@ -252,8 +304,9 @@ using facade::CmdArgList;

template std::variant<User::UpdateRequest, ErrorReply>
ParseAclSetUser<std::vector<std::string_view>&>(std::vector<std::string_view>&,
const CommandRegistry& registry, bool hashed);
const CommandRegistry& registry, bool hashed,
bool has_all_keys);

template std::variant<User::UpdateRequest, ErrorReply> ParseAclSetUser<CmdArgList>(
CmdArgList args, const CommandRegistry& registry, bool hashed);
CmdArgList args, const CommandRegistry& registry, bool hashed, bool has_all_keys);
} // namespace dfly::acl
10 changes: 9 additions & 1 deletion src/server/acl/helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,19 @@ std::pair<OptCommand, bool> MaybeParseAclCommand(std::string_view command,

template <typename T>
std::variant<User::UpdateRequest, facade::ErrorReply> ParseAclSetUser(
T args, const CommandRegistry& registry, bool hashed = false);
T args, const CommandRegistry& registry, bool hashed = false, bool has_all_keys = false);

using MaterializedContents = std::optional<std::vector<std::vector<std::string_view>>>;

MaterializedContents MaterializeFileContents(std::vector<std::string>* usernames,
std::string_view file_contents);

struct ParseKeyResult {
std::string glob;
KeyOp op;
bool all_keys{false};
bool reset_keys{false};
};

std::optional<ParseKeyResult> MaybeParseAclKey(std::string_view command);
} // namespace dfly::acl
2 changes: 1 addition & 1 deletion src/server/acl/user.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ const AclKeys& User::Keys() const {
return keys_;
}

void User::SetKeyGlobs(std::vector<UpdateKey>&& keys) {
void User::SetKeyGlobs(std::vector<UpdateKey> keys) {
for (auto& key : keys) {
if (key.all_keys) {
keys_.key_globs.clear();
Expand Down
2 changes: 1 addition & 1 deletion src/server/acl/user.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class User final {
void SetPasswordHash(std::string_view password, bool is_hashed);

// For ACL key globs
void SetKeyGlobs(std::vector<UpdateKey>&& keys);
void SetKeyGlobs(std::vector<UpdateKey> keys);

// when optional is empty, the special `nopass` password is implied
// password hashed with xx64
Expand Down

0 comments on commit b642fb6

Please sign in to comment.