Skip to content

Commit

Permalink
chore: refactor gcp code (#294)
Browse files Browse the repository at this point in the history
  • Loading branch information
romange authored Jul 12, 2024
1 parent 52d95e1 commit bd38683
Show file tree
Hide file tree
Showing 8 changed files with 326 additions and 116 deletions.
2 changes: 1 addition & 1 deletion util/cloud/gcp/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
add_library(gcp_lib gcs.cc)
add_library(gcp_lib gcs.cc gcs_file.cc gcp_utils.cc)

cxx_link(gcp_lib http_client_lib strings_lib TRDP::rapidjson)
64 changes: 64 additions & 0 deletions util/cloud/gcp/gcp_creds_provider.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Copyright 2024, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.

#pragma once

#include "base/RWSpinLock.h"

namespace util {

namespace fb2 {
class ProactorBase;
} // namespace fb2

namespace cloud {

// Holds GCP credential data (project/client ids, refresh token) together with the
// current access token and its expiration time.
// The class is non-copyable: both copy operations are deleted.
class GCPCredsProvider {
GCPCredsProvider(const GCPCredsProvider&) = delete;
GCPCredsProvider& operator=(const GCPCredsProvider&) = delete;

public:
GCPCredsProvider() = default;

// Initializes the provider; returns an error code on failure.
// NOTE(review): the implementation is not visible here - connect_ms is presumably
// a connect timeout in milliseconds and pb the proactor used for I/O; confirm.
std::error_code Init(unsigned connect_ms, fb2::ProactorBase* pb);

// Returns a reference to the stored project id. No locking is performed.
const std::string& project_id() const {
return project_id_;
}

// Returns a reference to the stored client id. No locking is performed.
const std::string& client_id() const {
return client_id_;
}

// Thread-safe method to access the token.
// Returns a copy (not a reference) that is made while holding lock_ for reading.
std::string access_token() const {
folly::RWSpinLock::ReadHolder lock(lock_);
return access_token_;
}

// Returns the token expiration time (seconds since epoch), read with acquire ordering.
time_t expire_time() const {
return expire_time_.load(std::memory_order_acquire);
}

// Thread-safe method issues refresh of the token.
// Right now will do the refresh unconditionally.
// TODO: to use expire_time_ to skip the refresh if expire time is far away.
std::error_code RefreshToken(fb2::ProactorBase* pb);

private:
// NOTE(review): presumably selects fetching credentials from the instance metadata
// service instead of local credentials - confirm against Init().
bool use_instance_metadata_ = false;
unsigned connect_ms_ = 0;

fb2::ProactorBase* pb_ = nullptr;
std::string account_id_;
std::string project_id_;

std::string client_id_, client_secret_, refresh_token_;

mutable folly::RWSpinLock lock_; // protects access_token_
std::string access_token_;
std::atomic<time_t> expire_time_ = 0; // seconds since epoch
};

} // namespace cloud
} // namespace util
114 changes: 114 additions & 0 deletions util/cloud/gcp/gcp_utils.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// Copyright 2024, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.

#include "util/cloud/gcp/gcp_utils.h"

#include <absl/strings/str_cat.h>

#include <boost/beast/http/string_body.hpp>

#include "base/logging.h"
#include "util/cloud/gcp/gcp_creds_provider.h"

// Evaluates `x` exactly once and, if the result converts to true (i.e. it is an error),
// returns it from the enclosing function wrapped by nonstd::make_unexpected.
// The do/while(false) wrapper makes the macro usable as a single statement.
#define RETURN_UNEXPECTED(x) \
do { \
auto ec = (x); \
if (ec) \
return nonstd::make_unexpected(ec); \
} while (false)

namespace util::cloud {
using namespace std;
namespace h2 = boost::beast::http;

namespace {

// Returns true only for a 401 (Unauthorized) response that also carries
// a "WWW-Authenticate" header.
bool IsUnauthorized(const h2::header<false, h2::fields>& resp) {
  const bool is_401 = resp.result() == h2::status::unauthorized;

  // Short-circuit: the header lookup happens only for 401 responses.
  return is_401 && resp.find("WWW-Authenticate") != resp.end();
}

// Tells whether the status asks the client to back off: either 429 (Too Many Requests)
// or any status that belongs to the server-error class.
inline bool DoesServerPushback(h2::status st) {
  if (st == h2::status::too_many_requests) {
    return true;
  }

  const h2::status_class st_class = h2::to_status_class(st);
  return st_class == h2::status_class::server_error;
}

} // namespace

// Host name of the Google APIs endpoint. PrepareRequest() puts it into the Host header.
const char GCP_API_DOMAIN[] = "www.googleapis.com";

// Builds the value of the Authorization header: "Bearer " followed by the token.
string AuthHeader(string_view access_token) {
  string header{"Bearer "};
  header.append(access_token);
  return header;
}

// Creates a body-less HTTP/1.1 request for `url` with the given verb.
// The request targets GCP_API_DOMAIN (Host header), carries a Bearer authorization
// header built from `access_token`, and asks to keep the connection alive.
EmptyRequest PrepareRequest(h2::verb req_verb, std::string_view url,
                            const string_view access_token) {
  constexpr unsigned kHttpVersion = 11;  // HTTP/1.1 in Beast's encoding.
  const boost::beast::string_view target{url.data(), url.size()};

  EmptyRequest request{req_verb, target, kHttpVersion};

  // Headers are set in the same order as before to keep the serialized form intact.
  request.set(h2::field::host, GCP_API_DOMAIN);
  request.set(h2::field::authorization, AuthHeader(access_token));
  request.keep_alive(true);

  return request;
}

RobustSender::RobustSender(unsigned num_iterations, GCPCredsProvider* provider)
: num_iterations_(num_iterations), provider_(provider) {
}

// Sends `req` over `client` and reads the response header.
//
// Returns the header parser when the server answers 200 (OK) or 206 (Partial Content);
// the response body is left unread for the caller.
// For any other status the body is drained and then:
//  - 429 / 5xx: sleeps for 1 second and retries. These retries do not consume the attempt
//    budget, so the call keeps retrying for as long as the server pushes back.
//  - 401 with WWW-Authenticate: refreshes the token through the provider, updates the
//    Authorization header of `req` in place and retries (consumes one attempt).
//  - 403: fails immediately with errc::operation_not_permitted.
//  - anything else: logs DFATAL, records errc::bad_message and consumes one attempt.
//
// Transport errors from Send/ReadHeader/Recv and errors from RefreshToken are returned as is.
// When the attempt budget (num_iterations_) is exhausted, the last recorded error is
// returned; it is never a "success" error code.
auto RobustSender::Send(http::Client* client, EmptyRequest* req) -> io::Result<HeaderParserPtr> {
  // Reported only if the loop body never runs (num_iterations_ == 0). Starting from a real
  // error guarantees we never return an "unexpected" that wraps a success code.
  error_code ec = make_error_code(errc::invalid_argument);

  // `i` is advanced manually so that the pushback branch can restart the budget from 0.
  // With `++i` in the loop header, `i = 0; continue;` would actually restart from 1.
  for (unsigned i = 0; i < num_iterations_;) {  // Iterate for possible token refresh.
    VLOG(1) << "HttpReq" << i << ": " << *req << ", socket " << client->native_handle();

    RETURN_UNEXPECTED(client->Send(*req));
    auto parser = std::make_unique<h2::response_parser<h2::empty_body>>();
    RETURN_UNEXPECTED(client->ReadHeader(parser.get()));
    {
      const auto& msg = parser->get();
      VLOG(1) << "RespHeader" << i << ": " << msg;

      if (!parser->keep_alive()) {
        LOG(FATAL) << "TBD: Schedule reconnect due to conn-close header";
      }

      // Partial content can appear because of the previous reconnect.
      if (msg.result() == h2::status::ok || msg.result() == h2::status::partial_content) {
        return parser;
      }
    }

    // We have some kind of error, possibly with body that needs to be drained
    // so that the connection can be reused for the next request.
    h2::response_parser<h2::string_body> drainer(std::move(*parser));
    RETURN_UNEXPECTED(client->Recv(&drainer));
    const auto& msg = drainer.get();

    if (DoesServerPushback(msg.result())) {
      LOG(INFO) << "Retrying(" << client->native_handle() << ") with " << msg;

      ThisFiber::SleepFor(1s);
      i = 0;  // Restart the budget. Note: loops forever if the server never recovers.
      continue;
    }

    // Every path below consumes one attempt.
    ++i;

    if (IsUnauthorized(msg)) {
      RETURN_UNEXPECTED(provider_->RefreshToken(client->proactor()));
      req->set(h2::field::authorization, AuthHeader(provider_->access_token()));

      // If this was the last attempt, report that authorization kept failing
      // instead of falling out of the loop with an empty error code.
      ec = make_error_code(errc::permission_denied);
      continue;
    }

    if (msg.result() == h2::status::forbidden) {
      return nonstd::make_unexpected(make_error_code(errc::operation_not_permitted));
    }

    ec = make_error_code(errc::bad_message);
    LOG(DFATAL) << "Unexpected response " << msg << "\n" << msg.body() << "\n";
  }

  return nonstd::make_unexpected(ec);
}

} // namespace util::cloud
38 changes: 38 additions & 0 deletions util/cloud/gcp/gcp_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright 2024, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.

#pragma once

#include <boost/beast/http/empty_body.hpp>
#include <memory>

#include "io/io.h"
#include "util/http/http_client.h"

namespace util::cloud {
class GCPCredsProvider;

extern const char GCP_API_DOMAIN[];

using EmptyRequest = boost::beast::http::request<boost::beast::http::empty_body>;

EmptyRequest PrepareRequest(boost::beast::http::verb req_verb, std::string_view url,
const std::string_view access_token);

std::string AuthHeader(std::string_view access_token);

// Helper that sends a body-less request over an http::Client and reacts to
// non-successful responses (server pushback, expired token) by retrying.
class RobustSender {
public:
// Owning pointer to a parser that holds the response header only (empty body).
using HeaderParserPtr =
std::unique_ptr<boost::beast::http::response_parser<boost::beast::http::empty_body>>;

// num_iterations bounds the send loop inside Send().
// provider is stored as a raw pointer and is used to refresh the access token.
// NOTE(review): provider is not owned and presumably must outlive this object - confirm.
RobustSender(unsigned num_iterations, GCPCredsProvider* provider);

// Sends req via client and returns the response header parser on 200/206 responses,
// or an error code otherwise. On a 401 response carrying WWW-Authenticate the token is
// refreshed and the Authorization header of req is updated in place before resending.
io::Result<HeaderParserPtr> Send(http::Client* client, EmptyRequest* req);

private:
unsigned num_iterations_;
GCPCredsProvider* provider_;
};

} // namespace util::cloud
72 changes: 10 additions & 62 deletions util/cloud/gcp/gcs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "io/file_util.h"
#include "io/line_reader.h"
#include "strings/escaping.h"
#include "util/cloud/gcp/gcp_utils.h"

using namespace std;
namespace h2 = boost::beast::http;
Expand All @@ -24,9 +25,6 @@ namespace util {
namespace cloud {

namespace {
constexpr char kDomain[] = "www.googleapis.com";

using EmptyRequest = h2::request<h2::empty_body>;

auto Unexpected(std::errc code) {
return nonstd::make_unexpected(make_error_code(code));
Expand All @@ -46,28 +44,6 @@ auto Unexpected(std::errc code) {
return ec; \
} while (false)

string AuthHeader(string_view access_token) {
return absl::StrCat("Bearer ", access_token);
}

EmptyRequest PrepareRequest(h2::verb req_verb, boost::beast::string_view url,
const string_view access_token) {
EmptyRequest req(req_verb, url, 11);
req.set(h2::field::host, kDomain);
req.set(h2::field::authorization, AuthHeader(access_token));
req.keep_alive(true);

return req;
}

bool IsUnauthorized(const h2::header<false, h2::fields>& resp) {
if (resp.result() != h2::status::unauthorized) {
return false;
}
auto it = resp.find("WWW-Authenticate");

return it != resp.end();
}

io::Result<string> ExpandFile(string_view path) {
io::Result<io::StatShortVec> res = io::StatFiles(path);
Expand Down Expand Up @@ -177,36 +153,6 @@ io::Result<TokenTtl> ParseTokenResponse(std::string&& response) {
return result;
}

using EmptyParserPtr = std::unique_ptr<h2::response_parser<h2::empty_body>>;
io::Result<EmptyParserPtr> SendWithToken(GCPCredsProvider* provider, http::Client* client,
EmptyRequest* req) {
error_code ec;
for (unsigned i = 0; i < 2; ++i) { // Iterate for possible token refresh.
VLOG(1) << "HttpReq" << i << ": " << *req << ", socket " << client->native_handle();

RETURN_UNEXPECTED(client->Send(*req));
EmptyParserPtr parser(new h2::response_parser<h2::empty_body>());
RETURN_UNEXPECTED(client->ReadHeader(parser.get()));

VLOG(1) << "RespHeader" << i << ": " << parser->get();

if (parser->get().result() == h2::status::ok) {
return parser;
};

if (IsUnauthorized(parser->get())) {
RETURN_UNEXPECTED(provider->RefreshToken(client->proactor()));
req->set(h2::field::authorization, AuthHeader(provider->access_token()));

continue;
}
ec = make_error_code(errc::bad_message);
LOG(DFATAL) << "Unexpected response " << parser.get();
}

return nonstd::make_unexpected(ec);
}

#define FETCH_ARRAY_MEMBER(val) \
if (!(val).IsArray()) \
return make_error_code(errc::bad_message); \
Expand Down Expand Up @@ -310,7 +256,7 @@ GCS::~GCS() {
std::error_code GCS::Connect(unsigned msec) {
client_->set_connect_timeout_ms(msec);

return client_->Connect(kDomain, "443", ssl_ctx_);
return client_->Connect(GCP_API_DOMAIN, "443", ssl_ctx_);
}

error_code GCS::ListBuckets(ListBucketCb cb) {
Expand All @@ -321,12 +267,13 @@ error_code GCS::ListBuckets(ListBucketCb cb) {

rj::Document doc;

RobustSender sender(2, &creds_provider_);

while (true) {
io::Result<EmptyParserPtr> parse_res =
SendWithToken(&creds_provider_, client_.get(), &http_req);
io::Result<RobustSender::HeaderParserPtr> parse_res = sender.Send(client_.get(), &http_req);
if (!parse_res)
return parse_res.error();
EmptyParserPtr empty_parser = std::move(*parse_res);
RobustSender::HeaderParserPtr empty_parser = std::move(*parse_res);
h2::response_parser<h2::string_body> resp(std::move(*empty_parser));
RETURN_ERROR(client_->Recv(&resp));

Expand Down Expand Up @@ -376,12 +323,13 @@ error_code GCS::List(string_view bucket, string_view prefix, bool recursive,
auto http_req = PrepareRequest(h2::verb::get, url, creds_provider_.access_token());

rj::Document doc;
RobustSender sender(2, &creds_provider_);
while (true) {
io::Result<EmptyParserPtr> parse_res =
SendWithToken(&creds_provider_, client_.get(), &http_req);
io::Result<RobustSender::HeaderParserPtr> parse_res = sender.Send(client_.get(), &http_req);
if (!parse_res)
return parse_res.error();
EmptyParserPtr empty_parser = std::move(*parse_res);
RobustSender::HeaderParserPtr empty_parser = std::move(*parse_res);

h2::response_parser<h2::string_body> resp(std::move(*empty_parser));
RETURN_ERROR(client_->Recv(&resp));

Expand Down
Loading

0 comments on commit bd38683

Please sign in to comment.