Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sec: Adjust flag validation for TLS. #1582

Merged
merged 6 commits into from
Jul 30, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 31 additions & 12 deletions src/server/protocol_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,26 @@ static ProtocolClient::SSL_CTX* CreateSslClientCntx() {
const auto& tls_key_file = GetFlag(FLAGS_tls_key_file);
unsigned mask = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;

CHECK_EQ(1, SSL_CTX_use_PrivateKey_file(ctx, tls_key_file.c_str(), SSL_FILETYPE_PEM));
const auto& tls_cert_file = GetFlag(FLAGS_tls_cert_file);
// Load client certificate if given.
if (!tls_key_file.empty()) {
royjacobson marked this conversation as resolved.
Show resolved Hide resolved
CHECK_EQ(1, SSL_CTX_use_PrivateKey_file(ctx, tls_key_file.c_str(), SSL_FILETYPE_PEM));
// We checked that the flag is non empty in ValidateClientTlsFlags.
const auto& tls_cert_file = GetFlag(FLAGS_tls_cert_file);
royjacobson marked this conversation as resolved.
Show resolved Hide resolved

CHECK_EQ(1, SSL_CTX_use_certificate_chain_file(ctx, tls_cert_file.c_str()));
CHECK_EQ(1, SSL_CTX_use_certificate_chain_file(ctx, tls_cert_file.c_str()));
}

// Load custom certificate validation if given.
const auto& tls_ca_cert_file = GetFlag(FLAGS_tls_ca_cert_file);
const auto& tls_ca_cert_dir = GetFlag(FLAGS_tls_ca_cert_dir);

const auto* file = tls_ca_cert_file.empty() ? nullptr : tls_ca_cert_file.data();
const auto* dir = tls_ca_cert_dir.empty() ? nullptr : tls_ca_cert_dir.data();
CHECK_EQ(1, SSL_CTX_load_verify_locations(ctx, file, dir));
if (file || dir) {
royjacobson marked this conversation as resolved.
Show resolved Hide resolved
CHECK_EQ(1, SSL_CTX_load_verify_locations(ctx, file, dir));
} else {
CHECK_EQ(1, SSL_CTX_set_default_verify_paths(ctx));
}

CHECK_EQ(1, SSL_CTX_set_cipher_list(ctx, "DEFAULT"));
SSL_CTX_set_min_proto_version(ctx, TLS1_2_VERSION);
Expand Down Expand Up @@ -142,22 +151,32 @@ std::string ProtocolClient::ServerContext::Description() const {
return absl::StrCat(host, ":", port);
}

void ProtocolClient::ValidateTlsFlags() const {
if (absl::GetFlag(FLAGS_tls_cert_file).empty()) {
LOG(ERROR) << "tls_cert_file flag should be set";
exit(1);
void ValidateClientTlsFlags() {
royjacobson marked this conversation as resolved.
Show resolved Hide resolved
if (!absl::GetFlag(FLAGS_tls_replication)) {
return;
}

bool has_auth = false;
kostasrim marked this conversation as resolved.
Show resolved Hide resolved

if (!absl::GetFlag(FLAGS_tls_key_file).empty()) {
if (absl::GetFlag(FLAGS_tls_cert_file).empty()) {
LOG(ERROR) << "tls_cert_file flag should be set";
exit(1);
}
has_auth = true;
kostasrim marked this conversation as resolved.
Show resolved Hide resolved
}

if (absl::GetFlag(FLAGS_tls_ca_cert_file).empty() &&
absl::GetFlag(FLAGS_tls_ca_cert_dir).empty()) {
LOG(ERROR) << "Either or both tls_ca_cert_file or tls_ca_cert_dir flags must be set";
if (!absl::GetFlag(FLAGS_masterauth).empty())
has_auth = true;
kostasrim marked this conversation as resolved.
Show resolved Hide resolved

if (!has_auth) {
kostasrim marked this conversation as resolved.
Show resolved Hide resolved
LOG(ERROR) << "No authentication method configured!";
exit(1);
}
}

void ProtocolClient::MaybeInitSslCtx() {
if (absl::GetFlag(FLAGS_tls_replication)) {
ValidateTlsFlags();
ssl_ctx_ = CreateSslClientCntx();
}
}
Expand Down
6 changes: 5 additions & 1 deletion src/server/protocol_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class ConnectionContext;
class JournalExecutor;
struct JournalReader;

void ValidateClientTlsFlags();
royjacobson marked this conversation as resolved.
Show resolved Hide resolved

// A helper class for implementing a Redis client that talks to a redis server.
// This class should be inherited from.
class ProtocolClient {
Expand Down Expand Up @@ -130,12 +132,14 @@ class ProtocolClient {
std::string last_resp_;

uint64_t last_io_time_ = 0; // in ns, monotonic clock.

#ifdef DFLY_USE_SSL
void ValidateTlsFlags() const;

void MaybeInitSslCtx();

SSL_CTX* ssl_ctx_{nullptr};
#else
void* ssl_ctx_{nullptr};
royjacobson marked this conversation as resolved.
Show resolved Hide resolved
#endif
};

Expand Down
33 changes: 33 additions & 0 deletions src/server/server_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ extern "C" {
#include "server/journal/journal.h"
#include "server/main_service.h"
#include "server/memory_cmd.h"
#include "server/protocol_client.h"
#include "server/rdb_load.h"
#include "server/rdb_save.h"
#include "server/script_mgr.h"
Expand Down Expand Up @@ -68,6 +69,9 @@ ABSL_FLAG(int, epoll_file_threads, 0,
ABSL_DECLARE_FLAG(uint32_t, port);
ABSL_DECLARE_FLAG(bool, cache_mode);
ABSL_DECLARE_FLAG(uint32_t, hz);
ABSL_DECLARE_FLAG(bool, tls);
ABSL_DECLARE_FLAG(string, tls_ca_cert_file);
ABSL_DECLARE_FLAG(string, tls_ca_cert_dir);

namespace dfly {

Expand Down Expand Up @@ -419,6 +423,32 @@ void SlowLog(CmdArgList args, ConnectionContext* cntx) {
(*cntx)->SendError(UnknownSubCmd(sub_cmd, "SLOWLOG"), kSyntaxErrType);
}

// Check that if TLS is used at least one form of client authentication is
// enabled. That means either using a password or giving a root
// certificate for authenticating client certificates which will
// be required.
void ValidateServerTlsFlags() {
royjacobson marked this conversation as resolved.
Show resolved Hide resolved
if (!absl::GetFlag(FLAGS_tls)) {
return;
}

bool has_auth = false;
kostasrim marked this conversation as resolved.
Show resolved Hide resolved

if (!dfly::GetPassword().empty()) {
has_auth = true;
}

if (!(absl::GetFlag(FLAGS_tls_ca_cert_file).empty() &&
absl::GetFlag(FLAGS_tls_ca_cert_dir).empty())) {
has_auth = true;
}

if (!has_auth) {
LOG(ERROR) << "TLS configured but no authentication method is used!";
exit(1);
}
}

} // namespace

std::optional<SnapshotSpec> ParseSaveSchedule(string_view time) {
Expand Down Expand Up @@ -488,6 +518,9 @@ ServerFamily::ServerFamily(Service* service) : service_(*service) {
LOG(ERROR) << ec.Format();
exit(1);
}

ValidateServerTlsFlags();
ValidateClientTlsFlags();
}

ServerFamily::~ServerFamily() {
Expand Down
6 changes: 5 additions & 1 deletion tests/dragonfly/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ class DflyParams:
env: any


class DflyStartException(Exception):
pass


class DflyInstance:
"""
Represents a runnable and stoppable Dragonfly instance
Expand Down Expand Up @@ -81,7 +85,7 @@ def _check_status(self):
if not self.params.existing_port:
return_code = self.proc.poll()
if return_code is not None:
raise Exception(f"Failed to start instance, return code {return_code}")
raise DflyStartException(f"Failed to start instance, return code {return_code}")

def __getitem__(self, k):
return self.args.get(k)
Expand Down
50 changes: 50 additions & 0 deletions tests/dragonfly/tls_conf_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import pytest
import redis
from .utility import *
from . import DflyStartException

kostasrim marked this conversation as resolved.
Show resolved Hide resolved

royjacobson marked this conversation as resolved.
Show resolved Hide resolved
def test_tls_no_auth(df_factory, with_tls_server_args):
# Needs some authentication
server = df_factory.create(port=1111, **with_tls_server_args)
with pytest.raises(DflyStartException):
server.start()


def test_tls_no_key(df_factory):
# Needs a private key and certificate.
server = df_factory.create(port=1112, tls=None, requirepass="XXX")
with pytest.raises(DflyStartException):
server.start()


def test_tls_password(df_factory, with_tls_server_args):
server = df_factory.create(port=1113, requirepass="XXX", **with_tls_server_args)
server.start()
server.stop()


def test_tls_client_certs(df_factory, with_ca_tls_server_args):
server = df_factory.create(port=1114, **with_ca_tls_server_args)
server.start()
server.stop()


def test_client_tls_no_auth(df_factory):
server = df_factory.create(port=1115, tls_replication=None)
with pytest.raises(DflyStartException):
server.start()


def test_client_tls_password(df_factory):
server = df_factory.create(port=1116, tls_replication=None, masterauth="XXX")
server.start()
server.stop()


def test_client_tls_cert(df_factory, with_tls_server_args):
key_args = with_tls_server_args.copy()
key_args.pop("tls")
server = df_factory.create(port=1117, tls_replication=None, **key_args)
server.start()
server.stop()