Skip to content

Commit

Permalink
chore: .
Browse files Browse the repository at this point in the history
  • Loading branch information
romange committed Jul 13, 2024
1 parent a8a4f83 commit d8da9e6
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 21 deletions.
5 changes: 5 additions & 0 deletions examples/gcs_demo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ void Run(SSL_CTX* ctx) {
io::Result<io::WriteFile*> dest_res =
cloud::OpenWriteGcsFile(bucket, prefix, &provider, conn_pool.get());
CHECK(dest_res);
unique_ptr<io::WriteFile> dest(*dest_res);
error_code ec = dest->Write(*src);
CHECK(!ec);
ec = dest->Close();
CHECK(!ec);
} else {
auto cb = [](cloud::GCS::ObjectItem item) {
cout << "Object: " << item.key << ", size: " << item.size << endl;
Expand Down
7 changes: 0 additions & 7 deletions util/cloud/gcp/gcp_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,6 @@
#include "base/logging.h"
#include "util/cloud/gcp/gcp_creds_provider.h"

#define RETURN_UNEXPECTED(x) \
do { \
auto ec = (x); \
if (ec) \
return nonstd::make_unexpected(ec); \
} while (false)

namespace util::cloud {
using namespace std;
namespace h2 = boost::beast::http;
Expand Down
7 changes: 7 additions & 0 deletions util/cloud/gcp/gcp_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,11 @@ class RobustSender {
GCPCredsProvider* provider_;
};

#define RETURN_UNEXPECTED(x) \
do { \
auto ec = (x); \
if (ec) \
return nonstd::make_unexpected(ec); \
} while (false)

} // namespace util::cloud
6 changes: 3 additions & 3 deletions util/cloud/gcp/gcs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ auto Unexpected(std::errc code) {
return ec; \
} while (false)

io::Result<string> ExpandFile(string_view path) {
io::Result<string> ExpandFilePath(string_view path) {
io::Result<io::StatShortVec> res = io::StatFiles(path);

if (!res) {
Expand All @@ -59,7 +59,7 @@ io::Result<string> ExpandFile(string_view path) {
}

std::error_code LoadGCPConfig(string* account_id, string* project_id) {
io::Result<string> path = ExpandFile("~/.config/gcloud/configurations/config_default");
io::Result<string> path = ExpandFilePath("~/.config/gcloud/configurations/config_default");
if (!path) {
return path.error();
}
Expand Down Expand Up @@ -194,7 +194,7 @@ error_code EnableKeepAlive(int fd) {
error_code GCPCredsProvider::Init(unsigned connect_ms, fb2::ProactorBase* pb) {
CHECK_GT(connect_ms, 0u);

io::Result<string> root_path = ExpandFile("~/.config/gcloud");
io::Result<string> root_path = ExpandFilePath("~/.config/gcloud");
if (!root_path) {
return root_path.error();
}
Expand Down
125 changes: 114 additions & 11 deletions util/cloud/gcp/gcs_file.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,49 @@
#include "util/cloud/gcp/gcs_file.h"

#include <absl/strings/str_cat.h>
#include <absl/strings/strip.h>

#include <boost/beast/http/empty_body.hpp>

#include "base/logging.h"
#include "base/flags.h"
#include "strings/escaping.h"
#include "util/cloud/gcp/gcp_utils.h"
#include "util/http/http_common.h"

ABSL_FLAG(bool, gcs_dry_upload, false, "");

namespace util {

namespace cloud {
using namespace std;
namespace h2 = boost::beast::http;
using boost::beast::multi_buffer;

namespace {

//! [from, to) limited range out of total. If total is < 0 then it's unknown.
string ContentRangeHeader(size_t from, size_t to, ssize_t total) {
DCHECK_LE(from, to);
string tmp{"bytes "};

if (from < to) { // common case.
absl::StrAppend(&tmp, from, "-", to - 1, "/"); // content-range is inclusive.
if (total >= 0) {
absl::StrAppend(&tmp, total);
} else {
tmp.push_back('*');
}
} else {
// We can write empty ranges only when we finalize the file and total is known.
DCHECK_GE(total, 0);
absl::StrAppend(&tmp, "*/", total);
}

return tmp;
}


// File handle that writes to GCS.
//
// This uses multipart uploads, where it will buffer upto the configured part
Expand All @@ -31,32 +59,105 @@ class GcsWriteFile : public io::WriteFile {

// Closes the object and completes the multipart upload. Therefore the object
// will not be uploaded unless Close is called.
std::error_code Close() override;
error_code Close() override;

GcsWriteFile(const string_view key, string_view upload_id, size_t part_size,
http::ClientPool* pool);
http::ClientPool* pool, GCPCredsProvider* creds_provider);

private:
std::string upload_id_;
using UploadRequest = h2::request<h2::dynamic_body>;

error_code FillBuf(const uint8* buffer, size_t length);
error_code Upload();
UploadRequest PrepareRequest(size_t to, ssize_t total);

string upload_id_;
multi_buffer body_mb_;
size_t uploaded_ = 0;
http::ClientPool* pool_;
GCPCredsProvider* creds_provider_;
};

GcsWriteFile::GcsWriteFile(string_view key, string_view upload_id, size_t part_size,
http::ClientPool* pool)
: io::WriteFile(key), upload_id_(upload_id) {
http::ClientPool* pool, GCPCredsProvider* creds_provider)
: io::WriteFile(key), upload_id_(upload_id), body_mb_(part_size), pool_(pool),
creds_provider_(creds_provider) {
}

io::Result<size_t> GcsWriteFile::WriteSome(const iovec* v, uint32_t len) {
size_t total_size = 0;
size_t total = 0;
for (uint32_t i = 0; i < len; ++i) {
total_size += v[i].iov_len;
RETURN_UNEXPECTED(FillBuf(reinterpret_cast<const uint8_t*>(v->iov_base), v->iov_len));
total += v->iov_len;
}
return total_size;
return total;
}

error_code GcsWriteFile::Close() {
return {};
}

error_code GcsWriteFile::FillBuf(const uint8* buffer, size_t length) {
size_t prepare_size = std::min(length, body_mb_.max_size() - body_mb_.size());
auto mbs = body_mb_.prepare(prepare_size);
size_t offs = 0;
for (auto mb : mbs) {
memcpy(mb.data(), buffer + offs, mb.size());
offs += mb.size();
}
CHECK_EQ(offs, prepare_size);
body_mb_.commit(prepare_size);
}

error_code GcsWriteFile::Upload() {
size_t body_size = body_mb_.size();
CHECK_GT(body_size, 0);
CHECK_EQ(0, body_size % (1U << 18)) << body_size; // Must be multiple of 256KB.

size_t to = uploaded_ + body_size;

UploadRequest req = PrepareRequest(to, -1);

error_code res;
if (!absl::GetFlag(FLAGS_gcs_dry_upload)) {
RobustSender sender(3, creds_provider);
res = SendGeneric(3, std::move(req)).status;
VLOG(1) << "Uploaded range " << uploaded_ << "/" << to << " for " << upload_id_;

Parser* upload_parser = CHECK_NOTNULL(parser());
const auto& resp_msg = upload_parser->get();
auto it = resp_msg.find(h2::field::range);
CHECK(it != resp_msg.end()) << resp_msg;

string_view range = FromBoostSV(it->value());
CHECK(absl::ConsumePrefix(&range, "bytes="));
size_t pos = range.find('-');
CHECK_LT(pos, range.size());
size_t uploaded_pos = 0;
CHECK(absl::SimpleAtoi(range.substr(pos + 1), &uploaded_pos));
CHECK_EQ(uploaded_pos + 1, to);


if (!res.ok())
return res;
}

uploaded_ = to;
return {};
}

auto GcsWriteFile::PrepareRequest(size_t to, ssize_t total) -> UploadRequest {
UploadRequest req(h2::verb::put, upload_id_, 11);
req.body() = std::move(body_mb_);
req.set(h2::field::content_range, ContentRangeHeader(uploaded_, to, total));
req.set(h2::field::content_type, http::kBinMime);
req.prepare_payload();

DCHECK_EQ(0, body_mb_.size());

return req;
}

} // namespace

io::Result<io::WriteFile*> OpenWriteGcsFile(const string& bucket, const string& key,
Expand All @@ -67,18 +168,20 @@ io::Result<io::WriteFile*> OpenWriteGcsFile(const string& bucket, const string&
strings::AppendUrlEncoded(key, &url);
string token = creds_provider->access_token();
EmptyRequest req = PrepareRequest(h2::verb::post, url, token);
req.prepare_payload(); // it's post request so it's required.

RobustSender sender(3, creds_provider);
auto client_handle = pool->GetHandle();
io::Result<RobustSender::HeaderParserPtr> res = sender.Send(client_handle.get(), &req);
if (!res) {
return nonstd::make_unexpected(res.error());
}
auto parser_ptr = std::move(*res);

RobustSender::HeaderParserPtr parser_ptr = std::move(*res);
const auto& headers = parser_ptr->get();
auto it = headers.find(h2::field::location);
if (it != headers.end()) {
LOG(ERROR) << "Could not find the header";
if (it == headers.end()) {
LOG(ERROR) << "Could not find location in " << headers;
return nonstd::make_unexpected(make_error_code(errc::connection_refused));
}

Expand Down

0 comments on commit d8da9e6

Please sign in to comment.