From d8da9e6a1ef2ff8f30c223bd053c95117a662872 Mon Sep 17 00:00:00 2001 From: Roman Gershman Date: Sat, 13 Jul 2024 10:54:21 +0300 Subject: [PATCH] chore: . --- examples/gcs_demo.cc | 5 ++ util/cloud/gcp/gcp_utils.cc | 7 -- util/cloud/gcp/gcp_utils.h | 7 ++ util/cloud/gcp/gcs.cc | 6 +- util/cloud/gcp/gcs_file.cc | 125 ++++++++++++++++++++++++++++++++---- 5 files changed, 129 insertions(+), 21 deletions(-) diff --git a/examples/gcs_demo.cc b/examples/gcs_demo.cc index 2d6074b3..5272b547 100644 --- a/examples/gcs_demo.cc +++ b/examples/gcs_demo.cc @@ -43,6 +43,11 @@ void Run(SSL_CTX* ctx) { io::Result dest_res = cloud::OpenWriteGcsFile(bucket, prefix, &provider, conn_pool.get()); CHECK(dest_res); + unique_ptr dest(*dest_res); + error_code ec = dest->Write(*src); + CHECK(!ec); + ec = dest->Close(); + CHECK(!ec); } else { auto cb = [](cloud::GCS::ObjectItem item) { cout << "Object: " << item.key << ", size: " << item.size << endl; diff --git a/util/cloud/gcp/gcp_utils.cc b/util/cloud/gcp/gcp_utils.cc index 7bf01cc4..77c75d6d 100644 --- a/util/cloud/gcp/gcp_utils.cc +++ b/util/cloud/gcp/gcp_utils.cc @@ -10,13 +10,6 @@ #include "base/logging.h" #include "util/cloud/gcp/gcp_creds_provider.h" -#define RETURN_UNEXPECTED(x) \ - do { \ - auto ec = (x); \ - if (ec) \ - return nonstd::make_unexpected(ec); \ - } while (false) - namespace util::cloud { using namespace std; namespace h2 = boost::beast::http; diff --git a/util/cloud/gcp/gcp_utils.h b/util/cloud/gcp/gcp_utils.h index 8952a560..0b7be5e6 100644 --- a/util/cloud/gcp/gcp_utils.h +++ b/util/cloud/gcp/gcp_utils.h @@ -41,4 +41,11 @@ class RobustSender { GCPCredsProvider* provider_; }; +#define RETURN_UNEXPECTED(x) \ + do { \ + auto ec = (x); \ + if (ec) \ + return nonstd::make_unexpected(ec); \ + } while (false) + } // namespace util::cloud \ No newline at end of file diff --git a/util/cloud/gcp/gcs.cc b/util/cloud/gcp/gcs.cc index 05908720..e8a9cea4 100644 --- a/util/cloud/gcp/gcs.cc +++ b/util/cloud/gcp/gcs.cc @@ -44,7 +44,7 @@ auto Unexpected(std::errc code) { return ec; \ } while (false) -io::Result ExpandFile(string_view path) { +io::Result ExpandFilePath(string_view path) { io::Result res = io::StatFiles(path); if (!res) { @@ -59,7 +59,7 @@ io::Result ExpandFile(string_view path) { } std::error_code LoadGCPConfig(string* account_id, string* project_id) { - io::Result path = ExpandFile("~/.config/gcloud/configurations/config_default"); + io::Result path = ExpandFilePath("~/.config/gcloud/configurations/config_default"); if (!path) { return path.error(); } @@ -194,7 +194,7 @@ error_code EnableKeepAlive(int fd) { error_code GCPCredsProvider::Init(unsigned connect_ms, fb2::ProactorBase* pb) { CHECK_GT(connect_ms, 0u); - io::Result root_path = ExpandFile("~/.config/gcloud"); + io::Result root_path = ExpandFilePath("~/.config/gcloud"); if (!root_path) { return root_path.error(); } diff --git a/util/cloud/gcp/gcs_file.cc b/util/cloud/gcp/gcs_file.cc index 8e3f6559..f96e8b4d 100644 --- a/util/cloud/gcp/gcs_file.cc +++ b/util/cloud/gcp/gcs_file.cc @@ -4,21 +4,49 @@ #include "util/cloud/gcp/gcs_file.h" #include +#include #include #include "base/logging.h" +#include "base/flags.h" #include "strings/escaping.h" #include "util/cloud/gcp/gcp_utils.h" +#include "util/http/http_common.h" + +ABSL_FLAG(bool, gcs_dry_upload, false, ""); namespace util { namespace cloud { using namespace std; namespace h2 = boost::beast::http; +using boost::beast::multi_buffer; namespace { +//! [from, to) limited range out of total. If total is < 0 then it's unknown. +string ContentRangeHeader(size_t from, size_t to, ssize_t total) { + DCHECK_LE(from, to); + string tmp{"bytes "}; + + if (from < to) { // common case. + absl::StrAppend(&tmp, from, "-", to - 1, "/"); // content-range is inclusive. + if (total >= 0) { + absl::StrAppend(&tmp, total); + } else { + tmp.push_back('*'); + } + } else { + // We can write empty ranges only when we finalize the file and total is known. + DCHECK_GE(total, 0); + absl::StrAppend(&tmp, "*/", total); + } + + return tmp; +} + + // File handle that writes to GCS. // // This uses multipart uploads, where it will buffer upto the configured part @@ -31,32 +59,105 @@ class GcsWriteFile : public io::WriteFile { // Closes the object and completes the multipart upload. Therefore the object // will not be uploaded unless Close is called. - std::error_code Close() override; + error_code Close() override; GcsWriteFile(const string_view key, string_view upload_id, size_t part_size, - http::ClientPool* pool); + http::ClientPool* pool, GCPCredsProvider* creds_provider); private: - std::string upload_id_; + using UploadRequest = h2::request; + + error_code FillBuf(const uint8* buffer, size_t length); + error_code Upload(); + UploadRequest PrepareRequest(size_t to, ssize_t total); + + string upload_id_; + multi_buffer body_mb_; + size_t uploaded_ = 0; + http::ClientPool* pool_; + GCPCredsProvider* creds_provider_; }; GcsWriteFile::GcsWriteFile(string_view key, string_view upload_id, size_t part_size, - http::ClientPool* pool) - : io::WriteFile(key), upload_id_(upload_id) { + http::ClientPool* pool, GCPCredsProvider* creds_provider) + : io::WriteFile(key), upload_id_(upload_id), body_mb_(part_size), pool_(pool), + creds_provider_(creds_provider) { } io::Result GcsWriteFile::WriteSome(const iovec* v, uint32_t len) { - size_t total_size = 0; + size_t total = 0; for (uint32_t i = 0; i < len; ++i) { - total_size += v[i].iov_len; + RETURN_UNEXPECTED(FillBuf(reinterpret_cast(v->iov_base), v->iov_len)); + total += v->iov_len; } - return total_size; + return total; } error_code GcsWriteFile::Close() { return {}; } +error_code GcsWriteFile::FillBuf(const uint8* buffer, size_t length) { + size_t prepare_size = std::min(length, body_mb_.max_size() - body_mb_.size()); + auto mbs = body_mb_.prepare(prepare_size); + size_t offs = 0; + for (auto mb : mbs) { + memcpy(mb.data(), buffer + offs, mb.size()); + offs += mb.size(); + } + CHECK_EQ(offs, prepare_size); + body_mb_.commit(prepare_size); +} + +error_code GcsWriteFile::Upload() { + size_t body_size = body_mb_.size(); + CHECK_GT(body_size, 0); + CHECK_EQ(0, body_size % (1U << 18)) << body_size; // Must be multiple of 256KB. + + size_t to = uploaded_ + body_size; + + UploadRequest req = PrepareRequest(to, -1); + + error_code res; + if (!absl::GetFlag(FLAGS_gcs_dry_upload)) { + RobustSender sender(3, creds_provider); + res = SendGeneric(3, std::move(req)).status; + VLOG(1) << "Uploaded range " << uploaded_ << "/" << to << " for " << upload_id_; + + Parser* upload_parser = CHECK_NOTNULL(parser()); + const auto& resp_msg = upload_parser->get(); + auto it = resp_msg.find(h2::field::range); + CHECK(it != resp_msg.end()) << resp_msg; + + string_view range = FromBoostSV(it->value()); + CHECK(absl::ConsumePrefix(&range, "bytes=")); + size_t pos = range.find('-'); + CHECK_LT(pos, range.size()); + size_t uploaded_pos = 0; + CHECK(absl::SimpleAtoi(range.substr(pos + 1), &uploaded_pos)); + CHECK_EQ(uploaded_pos + 1, to); + + + if (!res.ok()) + return res; + } + + uploaded_ = to; + return {}; +} + +auto GcsWriteFile::PrepareRequest(size_t to, ssize_t total) -> UploadRequest { + UploadRequest req(h2::verb::put, upload_id_, 11); + req.body() = std::move(body_mb_); + req.set(h2::field::content_range, ContentRangeHeader(uploaded_, to, total)); + req.set(h2::field::content_type, http::kBinMime); + req.prepare_payload(); + + DCHECK_EQ(0, body_mb_.size()); + + return req; +} + } // namespace io::Result OpenWriteGcsFile(const string& bucket, const string& key, @@ -67,6 +168,7 @@ io::Result OpenWriteGcsFile(const string& bucket, const string& strings::AppendUrlEncoded(key, &url); string token = creds_provider->access_token(); EmptyRequest req = PrepareRequest(h2::verb::post, url, token); + req.prepare_payload(); // it's post request so it's required. RobustSender sender(3, creds_provider); auto client_handle = pool->GetHandle(); @@ -74,11 +176,12 @@ io::Result OpenWriteGcsFile(const string& bucket, const string& if (!res) { return nonstd::make_unexpected(res.error()); } - auto parser_ptr = std::move(*res); + + RobustSender::HeaderParserPtr parser_ptr = std::move(*res); const auto& headers = parser_ptr->get(); auto it = headers.find(h2::field::location); - if (it != headers.end()) { - LOG(ERROR) << "Could not find the header"; + if (it == headers.end()) { + LOG(ERROR) << "Could not find location in " << headers; return nonstd::make_unexpected(make_error_code(errc::connection_refused)); }