From b7d996bdc26aa9a2d446b33111f853be8b961593 Mon Sep 17 00:00:00 2001
From: Roman Gershman
Date: Sun, 24 Mar 2024 10:36:46 +0200
Subject: [PATCH] chore: preparation for basic http api

The goal is to provide very basic support for simple commands; fancy
features like pipelining and blocking commands won't work.

1. Added optional registration for the /api handler.
2. Implemented parsing of the POST body.
3. Added a basic formatting routine for the response. It does not cover
   all the commands but should suffice for basic usage.

The API is a POST method, and the body of the request should contain the
command arguments formatted as a json array, for example
'["set", "foo", "bar", "ex", "100"]'. The response is a json object with
either a `result` field holding the response of the command or an `error`
field containing the error message sent by the server. See the `test_http`
test in tests/dragonfly/connection_test.py for more details.

Signed-off-by: Roman Gershman
---
 helio                              |   2 +-
 src/facade/dragonfly_connection.cc |   9 +-
 src/facade/reply_capture.h         |  10 +-
 src/server/CMakeLists.txt          |   2 +-
 src/server/http_api.cc             | 224 +++++++++++++++++++++++++++++
 src/server/http_api.h              |  16 +++
 src/server/main_service.cc         |  13 +-
 tests/dragonfly/connection_test.py |  36 +++--
 8 files changed, 280 insertions(+), 32 deletions(-)
 create mode 100644 src/server/http_api.cc
 create mode 100644 src/server/http_api.h

diff --git a/helio b/helio
index 8985263c3acc..23eb57246dfd 160000
--- a/helio
+++ b/helio
@@ -1 +1 @@
-Subproject commit 8985263c3acca038752e8f9fdd8e9f61d2ec2b6f
+Subproject commit 23eb57246dfd6660f5cabd202e129e4d7e9834f6
diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc
index db89b05f5c25..d82749c72685 100644
--- a/src/facade/dragonfly_connection.cc
+++ b/src/facade/dragonfly_connection.cc
@@ -89,7 +89,8 @@ void SendProtocolError(RedisParser::Result pres, SinkReplyBuilder* builder) {
 // https://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html
 // One place to find a good implementation would be https://github.com/h2o/picohttpparser
 bool MatchHttp11Line(string_view line) {
-  return absl::StartsWith(line, "GET ") && absl::EndsWith(line, "HTTP/1.1");
+  return (absl::StartsWith(line, "GET ") || absl::StartsWith(line, "POST ")) &&
+         absl::EndsWith(line, "HTTP/1.1");
 }
 
 void UpdateIoBufCapacity(const base::IoBuf& io_buf, ConnectionStats* stats,
@@ -651,11 +652,13 @@ void Connection::HandleRequests() {
 
   http_res = CheckForHttpProto(peer);
   if (http_res) {
+    cc_.reset(service_->CreateContext(peer, this));
     if (*http_res) {
       VLOG(1) << "HTTP1.1 identified";
       is_http_ = true;
       HttpConnection http_conn{http_listener_};
       http_conn.SetSocket(peer);
+      http_conn.set_user_data(cc_.get());
       auto ec = http_conn.ParseFromBuffer(io_buf_.InputBuffer());
       io_buf_.ConsumeInput(io_buf_.InputLen());
       if (!ec) {
@@ -666,7 +669,6 @@ void Connection::HandleRequests() {
       // this connection.
       http_conn.ReleaseSocket();
     } else {
-      cc_.reset(service_->CreateContext(peer, this));
       if (breaker_cb_) {
         socket_->RegisterOnErrorCb([this](int32_t mask) { this->OnBreakCb(mask); });
       }
@@ -674,9 +676,8 @@ void Connection::HandleRequests() {
       ConnectionFlow(peer);
 
       socket_->CancelOnErrorCb();  // noop if nothing is registered.
-
-      cc_.reset();
     }
+    cc_.reset();
   }
 
   VLOG(1) << "Closed connection for peer " << remote_ep;
diff --git a/src/facade/reply_capture.h b/src/facade/reply_capture.h
index 7004faf5c9a1..9ff9157ebd90 100644
--- a/src/facade/reply_capture.h
+++ b/src/facade/reply_capture.h
@@ -47,11 +47,9 @@ class CapturingReplyBuilder : public RedisReplyBuilder {
 
   void StartCollection(unsigned len, CollectionType type) override;
 
- private:
+ public:
   using Error = std::pair<std::string, std::string>;  // SendError (msg, type)
   using Null = std::nullptr_t;                        // SendNull or SendNullArray
-  struct SimpleString : public std::string {};  // SendSimpleString
-  struct BulkString : public std::string {};    // SendBulkString
 
   struct StrArrPayload {
     bool simple;
@@ -66,7 +64,9 @@ class CapturingReplyBuilder : public RedisReplyBuilder {
     bool with_scores;
   };
 
- public:
+  struct SimpleString : public std::string {};  // SendSimpleString
+  struct BulkString : public std::string {};    // SendBulkString
+
   CapturingReplyBuilder(ReplyMode mode = ReplyMode::FULL)
       : RedisReplyBuilder{nullptr}, reply_mode_{mode}, stack_{}, current_{} {
   }
@@ -89,7 +89,6 @@ class CapturingReplyBuilder : public RedisReplyBuilder {
   // If an error is stored inside payload, get a reference to it.
   static std::optional GetError(const Payload& pl);
 
- private:
   struct CollectionPayload {
     CollectionPayload(unsigned len, CollectionType type);
 
@@ -98,6 +97,7 @@ class CapturingReplyBuilder : public RedisReplyBuilder {
     std::vector<Payload> arr;
   };
 
+ private:
  private:
   // Send payload directly, bypassing external interface. For efficient passing between two
   // captures.
diff --git a/src/server/CMakeLists.txt b/src/server/CMakeLists.txt
index df40cf6cee3b..bc261a7a267e 100644
--- a/src/server/CMakeLists.txt
+++ b/src/server/CMakeLists.txt
@@ -36,7 +36,7 @@ SET(SEARCH_FILES search/search_family.cc search/doc_index.cc search/doc_accessor
 
 add_library(dragonfly_lib engine_shard_set.cc channel_store.cc
             config_registry.cc conn_context.cc debugcmd.cc dflycmd.cc
-            generic_family.cc hset_family.cc json_family.cc
+            generic_family.cc hset_family.cc http_api.cc json_family.cc
             ${SEARCH_FILES}
             list_family.cc main_service.cc memory_cmd.cc rdb_load.cc rdb_save.cc
             replica.cc protocol_client.cc
diff --git a/src/server/http_api.cc b/src/server/http_api.cc
new file mode 100644
index 000000000000..4a490b05c705
--- /dev/null
+++ b/src/server/http_api.cc
@@ -0,0 +1,224 @@
+// Copyright 2024, DragonflyDB authors. All rights reserved.
+// See LICENSE for licensing terms.
+//
+
+#include "server/http_api.h"
+
+#include "base/logging.h"
+#include "core/flatbuffers.h"
+#include "facade/conn_context.h"
+#include "facade/reply_builder.h"
+#include "server/main_service.h"
+#include "util/http/http_common.h"
+
+namespace dfly {
+using namespace util;
+using namespace std;
+namespace h2 = boost::beast::http;
+using facade::CapturingReplyBuilder;
+
+namespace {
+
+bool IsValidReq(flexbuffers::Reference req) {
+  if (!req.IsVector()) {
+    return false;
+  }
+
+  auto vec = req.AsVector();
+  if (vec.size() == 0) {
+    return false;
+  }
+
+  for (size_t i = 0; i < vec.size(); ++i) {
+    if (!vec[i].IsString()) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Escape a string so that it is legal to print it in JSON text.
+std::string JsonEscape(string_view input) {
+  auto hex_digit = [](int c) -> char { return c < 10 ? c + '0' : c - 10 + 'a'; };
+
+  string out;
+  out.reserve(input.size() + 2);
+  out.push_back('\"');
+
+  auto* p = reinterpret_cast<const uint8_t*>(input.begin());
+  auto* e = reinterpret_cast<const uint8_t*>(input.end());
+
+  while (p < e) {
+    if (*p == '\\' || *p == '\"') {
+      out.push_back('\\');
+      out.push_back(*p++);
+    } else if (*p <= 0x1f) {
+      switch (*p) {
+        case '\b':
+          out.append("\\b");
+          p++;
+          break;
+        case '\f':
+          out.append("\\f");
+          p++;
+          break;
+        case '\n':
+          out.append("\\n");
+          p++;
+          break;
+        case '\r':
+          out.append("\\r");
+          p++;
+          break;
+        case '\t':
+          out.append("\\t");
+          p++;
+          break;
+        default:
+          // this condition captures non readable chars with value < 32,
+          // so size = 1 byte (e.g control chars).
+          out.append("\\u00");
+          out.push_back(hex_digit((*p & 0xf0) >> 4));
+          out.push_back(hex_digit(*p & 0xf));
+          p++;
+      }
+    } else {
+      out.push_back(*p++);
+    }
+  }
+
+  out.push_back('\"');
+  return out;
+}
+
+struct CaptureVisitor {
+  CaptureVisitor() {
+    str = R"({"result":)";
+  }
+
+  void operator()(monostate) {
+  }
+
+  void operator()(long v) {
+    absl::StrAppend(&str, v);
+  }
+
+  void operator()(double v) {
+    absl::StrAppend(&str, v);
+  }
+
+  void operator()(const CapturingReplyBuilder::SimpleString& ss) {
+    absl::StrAppend(&str, "\"", ss, "\"");
+  }
+
+  void operator()(const CapturingReplyBuilder::BulkString& bs) {
+    absl::StrAppend(&str, JsonEscape(bs));
+  }
+
+  void operator()(CapturingReplyBuilder::Null) {
+    absl::StrAppend(&str, "null");
+  }
+
+  void operator()(CapturingReplyBuilder::Error err) {
+    str = absl::StrCat(R"({"error": ")", err.first);
+  }
+
+  void operator()(facade::OpStatus status) {
+    absl::StrAppend(&str, "\"", facade::StatusToMsg(status), "\"");
+  }
+
+  void operator()(const CapturingReplyBuilder::StrArrPayload& sa) {
+    absl::StrAppend(&str, "not_implemented");
+  }
+
+  void operator()(const unique_ptr<CapturingReplyBuilder::CollectionPayload>& cp) {
+    if (!cp) {
+      absl::StrAppend(&str, "null");
+      return;
+    }
+    if (cp->len == 0 && cp->type == facade::RedisReplyBuilder::ARRAY) {
+      absl::StrAppend(&str, "[]");
+      return;
+    }
+
+    absl::StrAppend(&str, "[");
+    for (auto& pl : cp->arr) {
+      visit(*this, std::move(pl));
+    }
+  }
+
+  void operator()(facade::SinkReplyBuilder::MGetResponse resp) {
+    absl::StrAppend(&str, "not_implemented");
+  }
+
+  void operator()(const CapturingReplyBuilder::ScoredArray& sarr) {
+    absl::StrAppend(&str, "[");
+    for (const auto& [key, score] : sarr.arr) {
+      absl::StrAppend(&str, "{", JsonEscape(key), ":", score, "},");
+    }
+    if (sarr.arr.size() > 0) {
+      str.pop_back();
+    }
+    absl::StrAppend(&str, "]");
+  }
+
+  string str;
+};
+
+}  // namespace
+
+void HttpAPI(const http::QueryArgs& args, HttpRequest&& req, Service* service,
+             HttpContext* http_cntx) {
+  auto& body = req.body();
+
+  flexbuffers::Builder fbb;
+  flatbuffers::Parser parser;
+  flexbuffers::Reference doc;
+  bool success = parser.ParseFlexBuffer(body.c_str(), nullptr, &fbb);
+  if (success) {
+    fbb.Finish();
+    doc = flexbuffers::GetRoot(fbb.GetBuffer());
+    if (!IsValidReq(doc)) {
+      success = false;
+    }
+  }
+
+  if (!success) {
+    auto response = http::MakeStringResponse(h2::status::bad_request);
+    http::SetMime(http::kTextMime, &response);
+    response.body() = "Failed to parse json\r\n";
+    http_cntx->Invoke(std::move(response));
+    return;
+  }
+
+  vector<string> cmd_args;
+  flexbuffers::Vector vec = doc.AsVector();
+  for (size_t i = 0; i < vec.size(); ++i) {
+    cmd_args.push_back(vec[i].AsString().c_str());
+  }
+  vector<facade::MutableSlice> cmd_slices(cmd_args.size());
+  for (size_t i = 0; i < cmd_args.size(); ++i) {
+    cmd_slices[i] = absl::MakeSpan(cmd_args[i]);
+  }
+
+  facade::ConnectionContext* context = (facade::ConnectionContext*)http_cntx->user_data();
+  DCHECK(context);
+
+  facade::CapturingReplyBuilder reply_builder;
+  auto* prev = context->Inject(&reply_builder);
+  // TODO: to finish this.
+  service->DispatchCommand(absl::MakeSpan(cmd_slices), context);
+  facade::CapturingReplyBuilder::Payload payload = reply_builder.Take();
+
+  context->Inject(prev);
+  auto response = http::MakeStringResponse();
+  http::SetMime(http::kJsonMime, &response);
+
+  CaptureVisitor visitor;
+  std::visit(visitor, std::move(payload));
+  visitor.str.append("}\r\n");
+  response.body() = visitor.str;
+  http_cntx->Invoke(std::move(response));
+}
+
+}  // namespace dfly
diff --git a/src/server/http_api.h b/src/server/http_api.h
new file mode 100644
index 000000000000..2f76ba26c168
--- /dev/null
+++ b/src/server/http_api.h
@@ -0,0 +1,16 @@
+// Copyright 2024, DragonflyDB authors. All rights reserved.
+// See LICENSE for licensing terms.
+//
+
+#pragma once
+
+#include "util/http/http_handler.h"
+
+namespace dfly {
+class Service;
+using HttpRequest = util::HttpListenerBase::RequestType;
+
+void HttpAPI(const util::http::QueryArgs& args, HttpRequest&& req, Service* service,
+             util::HttpContext* send);
+
+}  // namespace dfly
diff --git a/src/server/main_service.cc b/src/server/main_service.cc
index cfc5e6259db2..67b8c4d6d831 100644
--- a/src/server/main_service.cc
+++ b/src/server/main_service.cc
@@ -1,4 +1,4 @@
-// Copyright 2022, DragonflyDB authors. All rights reserved.
+// Copyright 2024, DragonflyDB authors. All rights reserved.
 // See LICENSE for licensing terms.
 //
 
@@ -40,6 +40,7 @@ extern "C" {
 #include "server/generic_family.h"
 #include "server/hll_family.h"
 #include "server/hset_family.h"
+#include "server/http_api.h"
 #include "server/json_family.h"
 #include "server/list_family.h"
 #include "server/multi_command_squasher.h"
@@ -83,6 +84,9 @@ ABSL_FLAG(bool, admin_nopass, false,
           "If set, would enable open admin access to console on the assigned port, without "
           "authorization needed.");
 
+ABSL_FLAG(bool, expose_http_api, false,
+          "If set, will expose a POST /api handler for sending redis commands as json array.");
+
 ABSL_FLAG(dfly::MemoryBytesFlag, maxmemory, dfly::MemoryBytesFlag{},
           "Limit on maximum-memory that is used by the database. "
           "0 - means the program will automatically determine its maximum memory usage. "
@@ -2441,6 +2445,13 @@ void Service::ConfigureHttpHandlers(util::HttpListenerBase* base, bool is_privil
   base->RegisterCb("/clusterz", [this](const http::QueryArgs& args, HttpContext* send) {
     return ClusterHtmlPage(args, send, &cluster_family_);
   });
+
+  if (absl::GetFlag(FLAGS_expose_http_api)) {
+    base->RegisterCb("/api",
+                     [this](const http::QueryArgs& args, HttpRequest&& req, HttpContext* send) {
+                       HttpAPI(args, std::move(req), this, send);
+                     });
+  }
 }
 
 void Service::OnClose(facade::ConnectionContext* cntx) {
diff --git a/tests/dragonfly/connection_test.py b/tests/dragonfly/connection_test.py
index cfa1aab1505d..a211f61ca55e 100755
--- a/tests/dragonfly/connection_test.py
+++ b/tests/dragonfly/connection_test.py
@@ -7,6 +7,7 @@
 from redis.exceptions import ConnectionError as redis_conn_error, ResponseError
 import async_timeout
 from dataclasses import dataclass
+from aiohttp import ClientSession
 
 from . import dfly_args
 from .instance import DflyInstance, DflyInstanceFactory
 
@@ -67,7 +68,6 @@ def should_exclude(cmd: str):
 """
 
 
-@pytest.mark.asyncio
 @dfly_args({"proactor_threads": 4})
 async def test_monitor_command(async_pool):
     monitor = CollectingMonitor(aioredis.Redis(connection_pool=async_pool))
@@ -90,7 +90,6 @@ async def test_monitor_command(async_pool):
 """
 
 
-@pytest.mark.asyncio
 @dfly_args({"proactor_threads": 4, "multi_exec_squash": "true"})
 async def test_monitor_command_multi(async_pool):
     monitor = CollectingMonitor(aioredis.Redis(connection_pool=async_pool))
@@ -127,7 +126,6 @@ async def test_monitor_command_multi(async_pool):
 """
 
 
-@pytest.mark.asyncio
 @dfly_args({"proactor_threads": 4, "lua_auto_async": "false"})
 async def test_monitor_command_lua(async_pool):
     monitor = CollectingMonitor(aioredis.Redis(connection_pool=async_pool))
@@ -151,7 +149,6 @@ async def test_monitor_command_lua(async_pool):
 """
 
 
-@pytest.mark.asyncio
 async def test_pipeline_support(async_client):
     def generate(max):
         for i in range(max):
@@ -200,7 +197,6 @@ async def run_pipeline_mode(async_client: aioredis.Redis, messages):
 """
 
 
-@pytest.mark.asyncio
 async def test_pubsub_command(async_client):
     def generate(max):
         for i in range(max):
@@ -276,7 +272,6 @@ async def run_multi_pubsub(async_client, messages, channel_name):
 """
 
 
-@pytest.mark.asyncio
 async def test_multi_pubsub(async_client):
     def generate(max):
         for i in range(max):
@@ -293,7 +288,6 @@ def generate(max):
 """
 
 
-@pytest.mark.asyncio
 async def test_pubsub_subcommand_for_numsub(async_client):
     subs1 = [async_client.pubsub() for i in range(5)]
     for s in subs1:
@@ -343,7 +337,6 @@ async def test_pubsub_subcommand_for_numsub(async_client):
 """
 
 
-@pytest.mark.asyncio
 @pytest.mark.slow
 @dfly_args({"proactor_threads": "1", "subscriber_thread_limit": "100"})
 async def test_publish_stuck(df_server: DflyInstance, async_client: aioredis.Redis):
@@ -381,7 +374,6 @@ async def pub_task():
     await pub
 
 
-@pytest.mark.asyncio
 async def test_subscribers_with_active_publisher(df_server: DflyInstance, max_connections=100):
     # TODO: I am not how to customize the max connections for the pool.
     async_pool = aioredis.ConnectionPool(
@@ -562,7 +554,6 @@ async def test_large_cmd(async_client: aioredis.Redis):
     assert len(res) == MAX_ARR_SIZE
 
 
-@pytest.mark.asyncio
 async def test_reject_non_tls_connections_on_tls(with_tls_server_args, df_local_factory):
     server: DflyInstance = df_local_factory.create(
         no_tls_on_admin_port="true",
@@ -583,7 +574,6 @@ async def test_reject_non_tls_connections_on_tls(with_tls_server_args, df_local_
     await client.close()
 
 
-@pytest.mark.asyncio
 async def test_tls_insecure(with_ca_tls_server_args, with_tls_client_args, df_local_factory):
     server = df_local_factory.create(port=BASE_PORT, **with_ca_tls_server_args)
     server.start()
@@ -593,7 +583,6 @@ async def test_tls_insecure(with_ca_tls_server_args, with_tls_client_args, df_lo
     await client.close()
 
 
-@pytest.mark.asyncio
 async def test_tls_full_auth(with_ca_tls_server_args, with_ca_tls_client_args, df_local_factory):
     server = df_local_factory.create(port=BASE_PORT, **with_ca_tls_server_args)
     server.start()
@@ -603,7 +592,6 @@ async def test_tls_full_auth(with_ca_tls_server_args, with_ca_tls_client_args, d
     await client.close()
 
 
-@pytest.mark.asyncio
 async def test_tls_reject(
     with_ca_tls_server_args, with_tls_client_args, df_local_factory: DflyInstanceFactory
 ):
@@ -620,7 +608,6 @@ async def test_tls_reject(
     await client.close()
 
 
-@pytest.mark.asyncio
 @dfly_args({"proactor_threads": "4", "pipeline_squash": 10})
 async def test_squashed_pipeline(async_client: aioredis.Redis):
     p = async_client.pipeline(transaction=False)
@@ -638,7 +625,6 @@ async def test_squashed_pipeline(async_client: aioredis.Redis):
         res = res[11:]
 
 
-@pytest.mark.asyncio
 @dfly_args({"proactor_threads": "4", "pipeline_squash": 10})
 async def test_squashed_pipeline_seeder(df_server, df_seeder_factory):
     seeder = df_seeder_factory.create(port=df_server.port, keys=10_000)
@@ -650,7 +636,6 @@ async def test_squashed_pipeline_seeder(df_server, df_seeder_factory):
 """
 
 
-@pytest.mark.asyncio
 @dfly_args({"proactor_threads": "4", "pipeline_squash": 1})
 async def test_squashed_pipeline_multi(async_client: aioredis.Redis):
     p = async_client.pipeline(transaction=False)
@@ -670,7 +655,6 @@ async def test_squashed_pipeline_multi(async_client: aioredis.Redis):
     await p.execute()
 
 
-@pytest.mark.asyncio
 async def test_unix_domain_socket(df_local_factory, tmp_dir):
     server = df_local_factory.create(proactor_threads=1, port=BASE_PORT, unixsocket="./df.sock")
     server.start()
@@ -688,7 +672,6 @@ async def test_unix_domain_socket(df_local_factory, tmp_dir):
 
 
 @pytest.mark.slow
-@pytest.mark.asyncio
 async def test_nested_client_pause(async_client: aioredis.Redis):
     async def do_pause():
         await async_client.execute_command("CLIENT", "PAUSE", "1000", "WRITE")
@@ -715,7 +698,6 @@ async def do_write():
     await p3
 
 
-@pytest.mark.asyncio
 async def test_blocking_command_client_pause(async_client: aioredis.Redis):
     """
     1. Check client pause success when blocking transaction is running
@@ -743,7 +725,6 @@ async def lpush_command():
     await blocking
 
 
-@pytest.mark.asyncio
 async def test_multiple_blocking_commands_client_pause(async_client: aioredis.Redis):
     """
     Check running client pause command simultaneously with running multiple blocking command
@@ -765,3 +746,18 @@ async def client_pause():
 
     assert not all.done()
     await all
+
+
+@dfly_args({"proactor_threads": "1", "expose_http_api": "true"})
+async def test_http(df_server: DflyInstance):
+    client = df_server.client()
+    async with ClientSession() as session:
+        async with session.get(f"http://localhost:{df_server.port}") as resp:
+            assert resp.status == 200
+
+        body = '["set", "foo", "bar", "ex", "100"]'
+        async with session.post(f"http://localhost:{df_server.port}/api", data=body) as resp:
+            assert resp.status == 200
+            text = await resp.text()
+            assert text.strip() == '{"result":"OK"}'
+    assert await client.ttl("foo") > 0
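
Usage sketch: the /api handler added by this patch can be exercised outside the test
suite with the same aiohttp calls that test_http uses. This is a minimal illustration,
not part of the patch; the port (6379) and the standalone-script framing are assumptions,
while the POST /api endpoint, the json-array body, and the {"result": ...} / {"error": ...}
response shape come from the commit message and code above.

    import asyncio

    from aiohttp import ClientSession


    async def main():
        # Assumption: a dragonfly instance was started with --expose_http_api=true
        # and serves HTTP on its main port, here taken to be localhost:6379.
        base = "http://localhost:6379"
        async with ClientSession() as session:
            # Command arguments go in the POST body as a json array.
            body = '["set", "foo", "bar", "ex", "100"]'
            async with session.post(f"{base}/api", data=body) as resp:
                # Expected: 200 {"result":"OK"}
                print(resp.status, (await resp.text()).strip())

            # A read of the same key; the bulk-string reply is JSON-escaped by the server.
            async with session.post(f"{base}/api", data='["get", "foo"]') as resp:
                # Expected: {"result":"bar"}
                print((await resp.text()).strip())


    asyncio.run(main())

A failed command (for example, a wrong number of arguments) comes back with an "error"
field instead of "result", carrying the error message the server would send over RESP.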