Skip to content

Commit

Permalink
feat: Use const ConnectionContext in VerifyCommand (#1633)
Browse files Browse the repository at this point in the history
* feat: Use const ConnectionContext in VerifyCommand

Signed-off-by: Vladislav Oleshko <[email protected]>
  • Loading branch information
dranikpg authored Aug 6, 2023
1 parent 6faa530 commit 3bc1e26
Show file tree
Hide file tree
Showing 9 changed files with 148 additions and 119 deletions.
22 changes: 22 additions & 0 deletions src/facade/facade_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@
#include <absl/container/flat_hash_map.h>
#include <absl/types/span.h>

#include <optional>
#include <string>
#include <string_view>
#include <variant>

#include "facade/op_status.h"

namespace facade {

Expand Down Expand Up @@ -56,6 +60,24 @@ struct ConnectionStats {
ConnectionStats& operator+=(const ConnectionStats& o);
};

struct ErrorReply {
explicit ErrorReply(std::string&& msg, std::string_view kind = {})
: message{move(msg)}, kind{kind} {
}
explicit ErrorReply(std::string_view msg, std::string_view kind = {}) : message{msg}, kind{kind} {
}
explicit ErrorReply(const char* msg,
std::string_view kind = {}) // to resolve ambiguity of constructors above
: message{std::string_view{msg}}, kind{kind} {
}
explicit ErrorReply(OpStatus status) : message{}, kind{}, status{status} {
}

std::variant<std::string, std::string_view> message;
std::string_view kind;
std::optional<OpStatus> status{std::nullopt};
};

inline MutableSlice ToMSS(absl::Span<uint8_t> span) {
return MutableSlice{reinterpret_cast<char*>(span.data()), span.size()};
}
Expand Down
8 changes: 8 additions & 0 deletions src/facade/reply_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,14 @@ void RedisReplyBuilder::SendError(string_view str, string_view err_type) {
}
}

void RedisReplyBuilder::SendError(ErrorReply error) {
if (error.status)
return SendError(*error.status);

string_view message_sv = visit([](auto&& str) -> string_view { return str; }, error.message);
SendError(message_sv, error.kind);
}

void RedisReplyBuilder::SendProtocolError(std::string_view str) {
SendError(absl::StrCat("-ERR Protocol error: ", str), "protocol_error");
}
Expand Down
3 changes: 3 additions & 0 deletions src/facade/reply_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <optional>
#include <string_view>

#include "facade/facade_types.h"
#include "facade/op_status.h"
#include "io/io.h"

Expand Down Expand Up @@ -174,6 +175,8 @@ class RedisReplyBuilder : public SinkReplyBuilder {
void SetResp3(bool is_resp3);

void SendError(std::string_view str, std::string_view type = {}) override;
virtual void SendError(ErrorReply error);

void SendMGetResponse(absl::Span<const OptResp>) override;

void SendStored() override;
Expand Down
7 changes: 7 additions & 0 deletions src/facade/reply_capture.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ void CapturingReplyBuilder::SendError(std::string_view str, std::string_view typ
Capture(Error{str, type});
}

void CapturingReplyBuilder::SendError(ErrorReply error) {
SKIP_LESS(ReplyMode::ONLY_ERR);

string message = visit([](auto&& str) -> string { return string{move(str)}; }, error.message);
Capture(Error{move(message), error.kind});
}

void CapturingReplyBuilder::SendMGetResponse(absl::Span<const OptResp> arr) {
SKIP_LESS(ReplyMode::FULL);
Capture(vector<OptResp>{arr.begin(), arr.end()});
Expand Down
1 change: 1 addition & 0 deletions src/facade/reply_capture.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class CapturingReplyBuilder : public RedisReplyBuilder {

public:
void SendError(std::string_view str, std::string_view type = {}) override;
void SendError(ErrorReply error) override;
void SendMGetResponse(absl::Span<const OptResp>) override;

// SendStored -> SendSimpleString("OK")
Expand Down
15 changes: 15 additions & 0 deletions src/server/command_registry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,21 @@ void CommandId::Invoke(CmdArgList args, ConnectionContext* cntx) const {
ent.second += (after - before) / 1000;
}

optional<facade::ErrorReply> CommandId::Validate(CmdArgList args) const {
if ((arity() > 0 && args.size() != size_t(arity())) ||
(arity() < 0 && args.size() < size_t(-arity()))) {
return facade::ErrorReply{facade::WrongNumArgsError(name()), kSyntaxErrType};
}

if (key_arg_step() == 2 && (args.size() % 2) == 0) {
return facade::ErrorReply{facade::WrongNumArgsError(name()), kSyntaxErrType};
}

if (validator_)
return validator_(args.subspan(1));
return nullopt;
}

CommandRegistry::CommandRegistry() {
vector<string> rename_command = GetFlag(FLAGS_rename_command);

Expand Down
26 changes: 12 additions & 14 deletions src/server/command_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <absl/types/span.h>

#include <functional>
#include <optional>

#include "base/function2.hpp"
#include "facade/command_id.h"
Expand Down Expand Up @@ -66,11 +67,16 @@ class CommandId : public facade::CommandId {
void(CmdArgList, ConnectionContext*) const>;

using ArgValidator = fu2::function_base<true, true, fu2::capacity_default, false, false,
bool(CmdArgList, ConnectionContext*) const>;
std::optional<facade::ErrorReply>(CmdArgList) const>;

bool is_multi_key() const {
return (last_key_ != first_key_) || (opt_mask_ & CO::VARIADIC_KEYS);
}
void Invoke(CmdArgList args, ConnectionContext* cntx) const;

// Returns error if validation failed, otherwise nullopt
std::optional<facade::ErrorReply> Validate(CmdArgList args) const;

bool IsTransactional() const;

static const char* OptName(CO::CommandOpt fl);

CommandId& SetHandler(Handler f) {
handler_ = std::move(f);
Expand All @@ -79,21 +85,13 @@ class CommandId : public facade::CommandId {

CommandId& SetValidator(ArgValidator f) {
validator_ = std::move(f);

return *this;
}

void Invoke(CmdArgList args, ConnectionContext* cntx) const;

// Returns true if validation succeeded.
bool Validate(CmdArgList args, ConnectionContext* cntx) const {
return !validator_ || validator_(std::move(args), cntx);
bool is_multi_key() const {
return (last_key_ != first_key_) || (opt_mask_ & CO::VARIADIC_KEYS);
}

bool IsTransactional() const;

static const char* OptName(CO::CommandOpt fl);

private:
Handler handler_;
ArgValidator validator_;
Expand Down
Loading

0 comments on commit 3bc1e26

Please sign in to comment.