From bd3868305ff2d8d6a6acaffd9390f9ab18c3835f Mon Sep 17 00:00:00 2001 From: Roman Gershman Date: Fri, 12 Jul 2024 07:46:34 +0300 Subject: [PATCH] chore: refactor gcp code (#294) --- util/cloud/gcp/CMakeLists.txt | 2 +- util/cloud/gcp/gcp_creds_provider.h | 64 ++++++++++++++++ util/cloud/gcp/gcp_utils.cc | 114 ++++++++++++++++++++++++++++ util/cloud/gcp/gcp_utils.h | 38 ++++++++++ util/cloud/gcp/gcs.cc | 72 +++--------------- util/cloud/gcp/gcs.h | 54 +------------ util/cloud/gcp/gcs_file.cc | 56 ++++++++++++++ util/cloud/gcp/gcs_file.h | 42 ++++++++++ 8 files changed, 326 insertions(+), 116 deletions(-) create mode 100644 util/cloud/gcp/gcp_creds_provider.h create mode 100644 util/cloud/gcp/gcp_utils.cc create mode 100644 util/cloud/gcp/gcp_utils.h create mode 100644 util/cloud/gcp/gcs_file.cc create mode 100644 util/cloud/gcp/gcs_file.h diff --git a/util/cloud/gcp/CMakeLists.txt b/util/cloud/gcp/CMakeLists.txt index d19f6d62..837774c8 100644 --- a/util/cloud/gcp/CMakeLists.txt +++ b/util/cloud/gcp/CMakeLists.txt @@ -1,3 +1,3 @@ -add_library(gcp_lib gcs.cc) +add_library(gcp_lib gcs.cc gcs_file.cc gcp_utils.cc) cxx_link(gcp_lib http_client_lib strings_lib TRDP::rapidjson) diff --git a/util/cloud/gcp/gcp_creds_provider.h b/util/cloud/gcp/gcp_creds_provider.h new file mode 100644 index 00000000..0eea471a --- /dev/null +++ b/util/cloud/gcp/gcp_creds_provider.h @@ -0,0 +1,64 @@ +// Copyright 2024, Roman Gershman. All rights reserved. +// See LICENSE for licensing terms. + +#pragma once + +#include "base/RWSpinLock.h" + +namespace util { + +namespace fb2 { +class ProactorBase; +} // namespace fb2 + +namespace cloud { + +class GCPCredsProvider { + GCPCredsProvider(const GCPCredsProvider&) = delete; + GCPCredsProvider& operator=(const GCPCredsProvider&) = delete; + + public: + GCPCredsProvider() = default; + + std::error_code Init(unsigned connect_ms, fb2::ProactorBase* pb); + + const std::string& project_id() const { + return project_id_; + } + + const std::string& client_id() const { + return client_id_; + } + + // Thread-safe method to access the token. + std::string access_token() const { + folly::RWSpinLock::ReadHolder lock(lock_); + return access_token_; + } + + time_t expire_time() const { + return expire_time_.load(std::memory_order_acquire); + } + + // Thread-safe method issues refresh of the token. + // Right now will do the refresh unconditonally. + // TODO: to use expire_time_ to skip the refresh if expire time is far away. + std::error_code RefreshToken(fb2::ProactorBase* pb); + + private: + bool use_instance_metadata_ = false; + unsigned connect_ms_ = 0; + + fb2::ProactorBase* pb_ = nullptr; + std::string account_id_; + std::string project_id_; + + std::string client_id_, client_secret_, refresh_token_; + + mutable folly::RWSpinLock lock_; // protects access_token_ + std::string access_token_; + std::atomic expire_time_ = 0; // seconds since epoch +}; + +} // namespace cloud +} // namespace util \ No newline at end of file diff --git a/util/cloud/gcp/gcp_utils.cc b/util/cloud/gcp/gcp_utils.cc new file mode 100644 index 00000000..7358048b --- /dev/null +++ b/util/cloud/gcp/gcp_utils.cc @@ -0,0 +1,114 @@ +// Copyright 2024, Roman Gershman. All rights reserved. +// See LICENSE for licensing terms. + +#include "util/cloud/gcp/gcp_utils.h" + +#include + +#include + +#include "base/logging.h" +#include "util/cloud/gcp/gcp_creds_provider.h" + +#define RETURN_UNEXPECTED(x) \ + do { \ + auto ec = (x); \ + if (ec) \ + return nonstd::make_unexpected(ec); \ + } while (false) + +namespace util::cloud { +using namespace std; +namespace h2 = boost::beast::http; + +namespace { + +bool IsUnauthorized(const h2::header& resp) { + if (resp.result() != h2::status::unauthorized) { + return false; + } + auto it = resp.find("WWW-Authenticate"); + + return it != resp.end(); +} + +inline bool DoesServerPushback(h2::status st) { + return st == h2::status::too_many_requests || + h2::to_status_class(st) == h2::status_class::server_error; +} + +} // namespace + +const char GCP_API_DOMAIN[] = "www.googleapis.com"; + +string AuthHeader(string_view access_token) { + return absl::StrCat("Bearer ", access_token); +} + +EmptyRequest PrepareRequest(h2::verb req_verb, std::string_view url, + const string_view access_token) { + EmptyRequest req{req_verb, boost::beast::string_view{url.data(), url.size()}, 11}; + req.set(h2::field::host, GCP_API_DOMAIN); + req.set(h2::field::authorization, AuthHeader(access_token)); + req.keep_alive(true); + + return req; +} + +RobustSender::RobustSender(unsigned num_iterations, GCPCredsProvider* provider) + : num_iterations_(num_iterations), provider_(provider) { +} + +auto RobustSender::Send(http::Client* client, EmptyRequest* req) -> io::Result { + error_code ec; + for (unsigned i = 0; i < num_iterations_; ++i) { // Iterate for possible token refresh. + VLOG(1) << "HttpReq" << i << ": " << *req << ", socket " << client->native_handle(); + + RETURN_UNEXPECTED(client->Send(*req)); + HeaderParserPtr parser(new h2::response_parser()); + RETURN_UNEXPECTED(client->ReadHeader(parser.get())); + { + const auto& msg = parser->get(); + VLOG(1) << "RespHeader" << i << ": " << msg; + + if (!parser->keep_alive()) { + LOG(FATAL) << "TBD: Schedule reconnect due to conn-close header"; + } + + // Partial content can appear because of the previous reconnect. + if (msg.result() == h2::status::ok || msg.result() == h2::status::partial_content) { + return parser; + } + } + // We have some kind of error, possibly with body that needs to be drained. + h2::response_parser drainer(std::move(*parser)); + RETURN_UNEXPECTED(client->Recv(&drainer)); + const auto& msg = drainer.get(); + + if (DoesServerPushback(msg.result())) { + LOG(INFO) << "Retrying(" << client->native_handle() << ") with " << msg; + + ThisFiber::SleepFor(1s); + i = 0; // Can potentially deadlock + continue; + } + + if (IsUnauthorized(msg)) { + RETURN_UNEXPECTED(provider_->RefreshToken(client->proactor())); + req->set(h2::field::authorization, AuthHeader(provider_->access_token())); + + continue; + } + + if (msg.result() == h2::status::forbidden) { + return nonstd::make_unexpected(make_error_code(errc::operation_not_permitted)); + } + + ec = make_error_code(errc::bad_message); + LOG(DFATAL) << "Unexpected response " << msg << "\n" << msg.body() << "\n"; + } + + return nonstd::make_unexpected(ec); +} + +} // namespace util::cloud \ No newline at end of file diff --git a/util/cloud/gcp/gcp_utils.h b/util/cloud/gcp/gcp_utils.h new file mode 100644 index 00000000..0c5bd8a3 --- /dev/null +++ b/util/cloud/gcp/gcp_utils.h @@ -0,0 +1,38 @@ +// Copyright 2024, Roman Gershman. All rights reserved. +// See LICENSE for licensing terms. + +#pragma once + +#include +#include + +#include "io/io.h" +#include "util/http/http_client.h" + +namespace util::cloud { +class GCPCredsProvider; + +extern const char GCP_API_DOMAIN[]; + +using EmptyRequest = boost::beast::http::request; + +EmptyRequest PrepareRequest(boost::beast::http::verb req_verb, std::string_view url, + const std::string_view access_token); + +std::string AuthHeader(std::string_view access_token); + +class RobustSender { + public: + using HeaderParserPtr = + std::unique_ptr>; + + RobustSender(unsigned num_iterations, GCPCredsProvider* provider); + + io::Result Send(http::Client* client, EmptyRequest* req); + + private: + unsigned num_iterations_; + GCPCredsProvider* provider_; +}; + +} // namespace util::cloud \ No newline at end of file diff --git a/util/cloud/gcp/gcs.cc b/util/cloud/gcp/gcs.cc index c5202291..946c5422 100644 --- a/util/cloud/gcp/gcs.cc +++ b/util/cloud/gcp/gcs.cc @@ -15,6 +15,7 @@ #include "io/file_util.h" #include "io/line_reader.h" #include "strings/escaping.h" +#include "util/cloud/gcp/gcp_utils.h" using namespace std; namespace h2 = boost::beast::http; @@ -24,9 +25,6 @@ namespace util { namespace cloud { namespace { -constexpr char kDomain[] = "www.googleapis.com"; - -using EmptyRequest = h2::request; auto Unexpected(std::errc code) { return nonstd::make_unexpected(make_error_code(code)); @@ -46,28 +44,6 @@ auto Unexpected(std::errc code) { return ec; \ } while (false) -string AuthHeader(string_view access_token) { - return absl::StrCat("Bearer ", access_token); -} - -EmptyRequest PrepareRequest(h2::verb req_verb, boost::beast::string_view url, - const string_view access_token) { - EmptyRequest req(req_verb, url, 11); - req.set(h2::field::host, kDomain); - req.set(h2::field::authorization, AuthHeader(access_token)); - req.keep_alive(true); - - return req; -} - -bool IsUnauthorized(const h2::header& resp) { - if (resp.result() != h2::status::unauthorized) { - return false; - } - auto it = resp.find("WWW-Authenticate"); - - return it != resp.end(); -} io::Result ExpandFile(string_view path) { io::Result res = io::StatFiles(path); @@ -177,36 +153,6 @@ io::Result ParseTokenResponse(std::string&& response) { return result; } -using EmptyParserPtr = std::unique_ptr>; -io::Result SendWithToken(GCPCredsProvider* provider, http::Client* client, - EmptyRequest* req) { - error_code ec; - for (unsigned i = 0; i < 2; ++i) { // Iterate for possible token refresh. - VLOG(1) << "HttpReq" << i << ": " << *req << ", socket " << client->native_handle(); - - RETURN_UNEXPECTED(client->Send(*req)); - EmptyParserPtr parser(new h2::response_parser()); - RETURN_UNEXPECTED(client->ReadHeader(parser.get())); - - VLOG(1) << "RespHeader" << i << ": " << parser->get(); - - if (parser->get().result() == h2::status::ok) { - return parser; - }; - - if (IsUnauthorized(parser->get())) { - RETURN_UNEXPECTED(provider->RefreshToken(client->proactor())); - req->set(h2::field::authorization, AuthHeader(provider->access_token())); - - continue; - } - ec = make_error_code(errc::bad_message); - LOG(DFATAL) << "Unexpected response " << parser.get(); - } - - return nonstd::make_unexpected(ec); -} - #define FETCH_ARRAY_MEMBER(val) \ if (!(val).IsArray()) \ return make_error_code(errc::bad_message); \ @@ -310,7 +256,7 @@ GCS::~GCS() { std::error_code GCS::Connect(unsigned msec) { client_->set_connect_timeout_ms(msec); - return client_->Connect(kDomain, "443", ssl_ctx_); + return client_->Connect(GCP_API_DOMAIN, "443", ssl_ctx_); } error_code GCS::ListBuckets(ListBucketCb cb) { @@ -321,12 +267,13 @@ error_code GCS::ListBuckets(ListBucketCb cb) { rj::Document doc; + RobustSender sender(2, &creds_provider_); + while (true) { - io::Result parse_res = - SendWithToken(&creds_provider_, client_.get(), &http_req); + io::Result parse_res = sender.Send(client_.get(), &http_req); if (!parse_res) return parse_res.error(); - EmptyParserPtr empty_parser = std::move(*parse_res); + RobustSender::HeaderParserPtr empty_parser = std::move(*parse_res); h2::response_parser resp(std::move(*empty_parser)); RETURN_ERROR(client_->Recv(&resp)); @@ -376,12 +323,13 @@ error_code GCS::List(string_view bucket, string_view prefix, bool recursive, auto http_req = PrepareRequest(h2::verb::get, url, creds_provider_.access_token()); rj::Document doc; + RobustSender sender(2, &creds_provider_); while (true) { - io::Result parse_res = - SendWithToken(&creds_provider_, client_.get(), &http_req); + io::Result parse_res = sender.Send(client_.get(), &http_req); if (!parse_res) return parse_res.error(); - EmptyParserPtr empty_parser = std::move(*parse_res); + RobustSender::HeaderParserPtr empty_parser = std::move(*parse_res); + h2::response_parser resp(std::move(*empty_parser)); RETURN_ERROR(client_->Recv(&resp)); diff --git a/util/cloud/gcp/gcs.h b/util/cloud/gcp/gcs.h index a2c5cbb7..ee1ca918 100644 --- a/util/cloud/gcp/gcs.h +++ b/util/cloud/gcp/gcs.h @@ -7,66 +7,14 @@ #include +#include "util/cloud/gcp/gcp_creds_provider.h" #include "util/http/http_client.h" -#include "base/RWSpinLock.h" typedef struct ssl_ctx_st SSL_CTX; namespace util { - -namespace fb2 { -class ProactorBase; -} // namespace fb2 - namespace cloud { -class GCPCredsProvider { - GCPCredsProvider(const GCPCredsProvider&) = delete; - GCPCredsProvider& operator=(const GCPCredsProvider&) = delete; - - public: - GCPCredsProvider() = default; - - std::error_code Init(unsigned connect_ms, fb2::ProactorBase* pb); - - const std::string& project_id() const { - return project_id_; - } - - const std::string& client_id() const { - return client_id_; - } - - // Thread-safe method to access the token. - std::string access_token() const { - folly::RWSpinLock::ReadHolder lock(lock_); - return access_token_; - } - - time_t expire_time() const { - return expire_time_.load(std::memory_order_acquire); - } - - // Thread-safe method issues refresh of the token. - // Right now will do the refresh unconditonally. - // TODO: to use expire_time_ to skip the refresh if expire time is far away. - std::error_code RefreshToken(fb2::ProactorBase* pb); - - private: - bool use_instance_metadata_ = false; - unsigned connect_ms_ = 0; - - fb2::ProactorBase* pb_ = nullptr; - std::string account_id_; - std::string project_id_; - - std::string client_id_, client_secret_, refresh_token_; - - mutable folly::RWSpinLock lock_; // protects access_token_ - std::string access_token_; - std::atomic expire_time_ = 0; // seconds since epoch -}; - class GCS { public: using BucketItem = std::string_view; diff --git a/util/cloud/gcp/gcs_file.cc b/util/cloud/gcp/gcs_file.cc new file mode 100644 index 00000000..fb2959d8 --- /dev/null +++ b/util/cloud/gcp/gcs_file.cc @@ -0,0 +1,56 @@ +// Copyright 2024, Roman Gershman. All rights reserved. +// See LICENSE for licensing terms. + +#include "util/cloud/gcp/gcs_file.h" + +#include + +#include + +#include "strings/escaping.h" +#include "util/cloud/gcp/gcp_utils.h" + +namespace util { + +namespace cloud { +using namespace std; +namespace h2 = boost::beast::http; + +namespace { + + +} // namespace + +io::Result GcsWriteFile::Open(const string& bucket, const string& key, + GCPCredsProvider* creds_provider, + http::ClientPool* pool, size_t part_size) { + string url = "/upload/storage/v1/b/"; + absl::StrAppend(&url, bucket, "/o?uploadType=resumable&name="); + strings::AppendUrlEncoded(key, &url); + string token = creds_provider->access_token(); + auto req = PrepareRequest(h2::verb::post, url, token); + string upload_id; +#if 0 + ApiSenderDynamicBody sender("start_write", gce, pool); + auto res = sender.SendGeneric(3, std::move(req)); + if (!res.ok()) + return res.status; + + const auto& resp = sender.parser()->get(); + + // HttpsClientPool::ClientHandle handle = std::move(res.obj); + + auto it = resp.find(h2::field::location); + if (it == resp.end()) { + return Status(StatusCode::PARSE_ERROR, "Can not find location header"); + } + string upload_id = string(it->value()); + + +#endif + + return new GcsWriteFile(key, upload_id, part_size, pool); +} + +} // namespace cloud +} // namespace util diff --git a/util/cloud/gcp/gcs_file.h b/util/cloud/gcp/gcs_file.h new file mode 100644 index 00000000..8b0ab968 --- /dev/null +++ b/util/cloud/gcp/gcs_file.h @@ -0,0 +1,42 @@ +// Copyright 2024, Roman Gershman. All rights reserved. +// See LICENSE for licensing terms. + +#pragma once + +#include "io/file.h" +#include "util/http/https_client_pool.h" +#include "util/cloud/gcp/gcp_creds_provider.h" + +namespace util { + +namespace cloud { + +// File handle that writes to GCS. +// +// This uses multipart uploads, where it will buffer upto the configured part +// size before uploading. +class GcsWriteFile : public io::WriteFile { + public: + static constexpr size_t kDefaultPartSize = 1ULL << 23; // 8MB. + + // Writes bytes to the GCS object. This will either buffer internally or + // write a part to GCS. + io::Result WriteSome(const iovec* v, uint32_t len) override; + + // Closes the object and completes the multipart upload. Therefore the object + // will not be uploaded unless Close is called. + std::error_code Close() override; + + static io::Result Open(const std::string& bucket, const std::string& key, + GCPCredsProvider* creds_provider, + http::ClientPool* pool, size_t part_size = kDefaultPartSize); + + private: + GcsWriteFile(const std::string& key, const std::string& upload_id, + size_t part_size, http::ClientPool* pool); + + std::string upload_id_; +}; + +} // namespace cloud +} // namespace util \ No newline at end of file