Skip to content

Commit

Permalink
http: use std::function for headermap iteration (envoyproxy#12103)
Browse files Browse the repository at this point in the history
The existing cast-to-void*-then-back method for passing in a context is
type-unsafe and can make for hard-to-locate errors. By using a
std::function instead of a function pointer, the caller can get
compilation error instead of runtime errors, and doesn't have to do any
sort of bundling dance with std::pair or a custom struct to pass
multiple context items in.

Signed-off-by: Alex Konradi <[email protected]>
Signed-off-by: scheler <[email protected]>
  • Loading branch information
akonradi authored and scheler committed Aug 4, 2020
1 parent be9e031 commit 6a4c9af
Show file tree
Hide file tree
Showing 41 changed files with 499 additions and 708 deletions.
11 changes: 4 additions & 7 deletions include/envoy/http/header_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -519,24 +519,21 @@ class HeaderMap {
/**
* Callback when calling iterate() over a const header map.
* @param header supplies the header entry.
* @param context supplies the context passed to iterate().
* @return Iterate::Continue to continue iteration.
* @return Iterate::Continue to continue iteration, or Iterate::Break to stop;
*/
using ConstIterateCb = Iterate (*)(const HeaderEntry&, void*);
using ConstIterateCb = std::function<Iterate(const HeaderEntry&)>;

/**
* Iterate over a constant header map.
* @param cb supplies the iteration callback.
* @param context supplies the context that will be passed to the callback.
*/
virtual void iterate(ConstIterateCb cb, void* context) const PURE;
virtual void iterate(ConstIterateCb cb) const PURE;

/**
* Iterate over a constant header map in reverse order.
* @param cb supplies the iteration callback.
* @param context supplies the context that will be passed to the callback.
*/
virtual void iterateReverse(ConstIterateCb cb, void* context) const PURE;
virtual void iterateReverse(ConstIterateCb cb) const PURE;

/**
* Clears the headers in the map.
Expand Down
13 changes: 5 additions & 8 deletions source/common/grpc/google_async_client_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,14 +173,11 @@ void GoogleAsyncStreamImpl::initialize(bool /*buffer_body_for_retry*/) {
// copy headers here.
auto initial_metadata = Http::RequestHeaderMapImpl::create();
callbacks_.onCreateInitialMetadata(*initial_metadata);
initial_metadata->iterate(
[](const Http::HeaderEntry& header, void* ctxt) {
auto* client_context = static_cast<grpc::ClientContext*>(ctxt);
client_context->AddMetadata(std::string(header.key().getStringView()),
std::string(header.value().getStringView()));
return Http::HeaderMap::Iterate::Continue;
},
&ctxt_);
initial_metadata->iterate([this](const Http::HeaderEntry& header) {
ctxt_.AddMetadata(std::string(header.key().getStringView()),
std::string(header.value().getStringView()));
return Http::HeaderMap::Iterate::Continue;
});
// Invoke stub call.
rw_ = parent_.stub_->PrepareCall(&ctxt_, "/" + service_full_name_ + "/" + method_name_,
&parent_.tls_.completionQueue());
Expand Down
14 changes: 5 additions & 9 deletions source/common/http/header_list_view.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,11 @@ namespace Envoy {
namespace Http {

HeaderListView::HeaderListView(const HeaderMap& header_map) {
header_map.iterate(
[](const Http::HeaderEntry& header, void* context) -> HeaderMap::Iterate {
auto* context_ptr = static_cast<HeaderListView*>(context);
context_ptr->keys_.emplace_back(std::reference_wrapper<const HeaderString>(header.key()));
context_ptr->values_.emplace_back(
std::reference_wrapper<const HeaderString>(header.value()));
return HeaderMap::Iterate::Continue;
},
this);
header_map.iterate([this](const Http::HeaderEntry& header) -> HeaderMap::Iterate {
keys_.emplace_back(std::reference_wrapper<const HeaderString>(header.key()));
values_.emplace_back(std::reference_wrapper<const HeaderString>(header.value()));
return HeaderMap::Iterate::Continue;
});
}

} // namespace Http
Expand Down
60 changes: 27 additions & 33 deletions source/common/http/header_map_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -264,28 +264,27 @@ void HeaderMapImpl::subtractSize(uint64_t size) {
}

void HeaderMapImpl::copyFrom(HeaderMap& lhs, const HeaderMap& header_map) {
header_map.iterate(
[](const HeaderEntry& header, void* context) -> HeaderMap::Iterate {
// TODO(mattklein123) PERF: Avoid copying here if not necessary.
HeaderString key_string;
key_string.setCopy(header.key().getStringView());
HeaderString value_string;
value_string.setCopy(header.value().getStringView());

static_cast<HeaderMap*>(context)->addViaMove(std::move(key_string),
std::move(value_string));
return HeaderMap::Iterate::Continue;
},
&lhs);
header_map.iterate([&lhs](const HeaderEntry& header) -> HeaderMap::Iterate {
// TODO(mattklein123) PERF: Avoid copying here if not necessary.
HeaderString key_string;
key_string.setCopy(header.key().getStringView());
HeaderString value_string;
value_string.setCopy(header.value().getStringView());

lhs.addViaMove(std::move(key_string), std::move(value_string));
return HeaderMap::Iterate::Continue;
});
}

namespace {

// This is currently only used in tests and is not optimized for performance.
HeaderMap::Iterate collectAllHeaders(const HeaderEntry& header, void* headers) {
static_cast<std::vector<std::pair<absl::string_view, absl::string_view>>*>(headers)->push_back(
std::make_pair(header.key().getStringView(), header.value().getStringView()));
return HeaderMap::Iterate::Continue;
HeaderMap::ConstIterateCb
collectAllHeaders(std::vector<std::pair<absl::string_view, absl::string_view>>* dest) {
return [dest](const HeaderEntry& header) -> HeaderMap::Iterate {
dest->push_back(std::make_pair(header.key().getStringView(), header.value().getStringView()));
return HeaderMap::Iterate::Continue;
};
};

} // namespace
Expand All @@ -298,7 +297,7 @@ bool HeaderMapImpl::operator==(const HeaderMap& rhs) const {

std::vector<std::pair<absl::string_view, absl::string_view>> rhs_headers;
rhs_headers.reserve(rhs.size());
rhs.iterate(collectAllHeaders, &rhs_headers);
rhs.iterate(collectAllHeaders(&rhs_headers));

auto i = headers_.begin();
auto j = rhs_headers.begin();
Expand Down Expand Up @@ -462,17 +461,17 @@ HeaderEntry* HeaderMapImpl::getExisting(const LowerCaseString& key) {
return nullptr;
}

void HeaderMapImpl::iterate(HeaderMap::ConstIterateCb cb, void* context) const {
void HeaderMapImpl::iterate(HeaderMap::ConstIterateCb cb) const {
for (const HeaderEntryImpl& header : headers_) {
if (cb(header, context) == HeaderMap::Iterate::Break) {
if (cb(header) == HeaderMap::Iterate::Break) {
break;
}
}
}

void HeaderMapImpl::iterateReverse(HeaderMap::ConstIterateCb cb, void* context) const {
void HeaderMapImpl::iterateReverse(HeaderMap::ConstIterateCb cb) const {
for (auto it = headers_.rbegin(); it != headers_.rend(); it++) {
if (cb(*it, context) == HeaderMap::Iterate::Break) {
if (cb(*it) == HeaderMap::Iterate::Break) {
break;
}
}
Expand Down Expand Up @@ -527,17 +526,12 @@ size_t HeaderMapImpl::removePrefix(const LowerCaseString& prefix) {
}

void HeaderMapImpl::dumpState(std::ostream& os, int indent_level) const {
using IterateData = std::pair<std::ostream*, const char*>;
const char* spaces = spacesForLevel(indent_level);
IterateData iterate_data = std::make_pair(&os, spaces);
iterate(
[](const HeaderEntry& header, void* context) -> HeaderMap::Iterate {
auto* data = static_cast<IterateData*>(context);
*data->first << data->second << "'" << header.key().getStringView() << "', '"
<< header.value().getStringView() << "'\n";
return HeaderMap::Iterate::Continue;
},
&iterate_data);
iterate([&os,
spaces = spacesForLevel(indent_level)](const HeaderEntry& header) -> HeaderMap::Iterate {
os << spaces << "'" << header.key().getStringView() << "', '" << header.value().getStringView()
<< "'\n";
return HeaderMap::Iterate::Continue;
});
}

HeaderMapImpl::HeaderEntryImpl& HeaderMapImpl::maybeCreateInline(HeaderEntryImpl** entry,
Expand Down
12 changes: 5 additions & 7 deletions source/common/http/header_map_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ class HeaderMapImpl : NonCopyable {
void setCopy(const LowerCaseString& key, absl::string_view value);
uint64_t byteSize() const;
const HeaderEntry* get(const LowerCaseString& key) const;
void iterate(HeaderMap::ConstIterateCb cb, void* context) const;
void iterateReverse(HeaderMap::ConstIterateCb cb, void* context) const;
void iterate(HeaderMap::ConstIterateCb cb) const;
void iterateReverse(HeaderMap::ConstIterateCb cb) const;
void clear();
size_t remove(const LowerCaseString& key);
size_t removePrefix(const LowerCaseString& key);
Expand Down Expand Up @@ -298,11 +298,9 @@ template <class Interface> class TypedHeaderMapImpl : public HeaderMapImpl, publ
const HeaderEntry* get(const LowerCaseString& key) const override {
return HeaderMapImpl::get(key);
}
void iterate(HeaderMap::ConstIterateCb cb, void* context) const override {
HeaderMapImpl::iterate(cb, context);
}
void iterateReverse(HeaderMap::ConstIterateCb cb, void* context) const override {
HeaderMapImpl::iterateReverse(cb, context);
void iterate(HeaderMap::ConstIterateCb cb) const override { HeaderMapImpl::iterate(cb); }
void iterateReverse(HeaderMap::ConstIterateCb cb) const override {
HeaderMapImpl::iterateReverse(cb);
}
void clear() override { HeaderMapImpl::clear(); }
size_t remove(const LowerCaseString& key) override { return HeaderMapImpl::remove(key); }
Expand Down
37 changes: 15 additions & 22 deletions source/common/http/header_utility.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,18 +78,13 @@ HeaderUtility::HeaderData::HeaderData(const envoy::config::route::v3::HeaderMatc

void HeaderUtility::getAllOfHeader(const HeaderMap& headers, absl::string_view key,
std::vector<absl::string_view>& out) {
auto args = std::make_pair(LowerCaseString(std::string(key)), &out);

headers.iterate(
[](const HeaderEntry& header, void* context) -> HeaderMap::Iterate {
auto key_ret =
static_cast<std::pair<LowerCaseString, std::vector<absl::string_view>*>*>(context);
if (header.key() == key_ret->first.get().c_str()) {
key_ret->second->emplace_back(header.value().getStringView());
}
return HeaderMap::Iterate::Continue;
},
&args);
headers.iterate([key = LowerCaseString(std::string(key)),
&out](const HeaderEntry& header) -> HeaderMap::Iterate {
if (header.key() == key.get().c_str()) {
out.emplace_back(header.value().getStringView());
}
return HeaderMap::Iterate::Continue;
});
}

bool HeaderUtility::matchHeaders(const HeaderMap& request_headers,
Expand Down Expand Up @@ -170,16 +165,14 @@ bool HeaderUtility::isConnectResponse(const RequestHeaderMapPtr& request_headers
}

void HeaderUtility::addHeaders(HeaderMap& headers, const HeaderMap& headers_to_add) {
headers_to_add.iterate(
[](const HeaderEntry& header, void* context) -> HeaderMap::Iterate {
HeaderString k;
k.setCopy(header.key().getStringView());
HeaderString v;
v.setCopy(header.value().getStringView());
static_cast<HeaderMap*>(context)->addViaMove(std::move(k), std::move(v));
return HeaderMap::Iterate::Continue;
},
&headers);
headers_to_add.iterate([&headers](const HeaderEntry& header) -> HeaderMap::Iterate {
HeaderString k;
k.setCopy(header.key().getStringView());
HeaderString v;
v.setCopy(header.value().getStringView());
headers.addViaMove(std::move(k), std::move(v));
return HeaderMap::Iterate::Continue;
});
}

bool HeaderUtility::isEnvoyInternalRequest(const RequestHeaderMap& headers) {
Expand Down
44 changes: 19 additions & 25 deletions source/common/http/http1/codec_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,27 +110,24 @@ void ResponseEncoderImpl::encode100ContinueHeaders(const ResponseHeaderMap& head
void StreamEncoderImpl::encodeHeadersBase(const RequestOrResponseHeaderMap& headers,
absl::optional<uint64_t> status, bool end_stream) {
bool saw_content_length = false;
headers.iterate(
[](const HeaderEntry& header, void* context) -> HeaderMap::Iterate {
absl::string_view key_to_use = header.key().getStringView();
uint32_t key_size_to_use = header.key().size();
// Translate :authority -> host so that upper layers do not need to deal with this.
if (key_size_to_use > 1 && key_to_use[0] == ':' && key_to_use[1] == 'a') {
key_to_use = absl::string_view(Headers::get().HostLegacy.get());
key_size_to_use = Headers::get().HostLegacy.get().size();
}
headers.iterate([this](const HeaderEntry& header) -> HeaderMap::Iterate {
absl::string_view key_to_use = header.key().getStringView();
uint32_t key_size_to_use = header.key().size();
// Translate :authority -> host so that upper layers do not need to deal with this.
if (key_size_to_use > 1 && key_to_use[0] == ':' && key_to_use[1] == 'a') {
key_to_use = absl::string_view(Headers::get().HostLegacy.get());
key_size_to_use = Headers::get().HostLegacy.get().size();
}

// Skip all headers starting with ':' that make it here.
if (key_to_use[0] == ':') {
return HeaderMap::Iterate::Continue;
}
// Skip all headers starting with ':' that make it here.
if (key_to_use[0] == ':') {
return HeaderMap::Iterate::Continue;
}

static_cast<StreamEncoderImpl*>(context)->encodeFormattedHeader(
key_to_use, header.value().getStringView());
encodeFormattedHeader(key_to_use, header.value().getStringView());

return HeaderMap::Iterate::Continue;
},
this);
return HeaderMap::Iterate::Continue;
});

if (headers.ContentLength()) {
saw_content_length = true;
Expand Down Expand Up @@ -234,13 +231,10 @@ void StreamEncoderImpl::encodeTrailersBase(const HeaderMap& trailers) {
// Finalize the body
connection_.buffer().add(LAST_CHUNK);

trailers.iterate(
[](const HeaderEntry& header, void* context) -> HeaderMap::Iterate {
static_cast<StreamEncoderImpl*>(context)->encodeFormattedHeader(
header.key().getStringView(), header.value().getStringView());
return HeaderMap::Iterate::Continue;
},
this);
trailers.iterate([this](const HeaderEntry& header) -> HeaderMap::Iterate {
encodeFormattedHeader(header.key().getStringView(), header.value().getStringView());
return HeaderMap::Iterate::Continue;
});

connection_.flushOutput();
connection_.buffer().add(CRLF);
Expand Down
11 changes: 4 additions & 7 deletions source/common/http/http2/codec_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,10 @@ static void insertHeader(std::vector<nghttp2_nv>& headers, const HeaderEntry& he
void ConnectionImpl::StreamImpl::buildHeaders(std::vector<nghttp2_nv>& final_headers,
const HeaderMap& headers) {
final_headers.reserve(headers.size());
headers.iterate(
[](const HeaderEntry& header, void* context) -> HeaderMap::Iterate {
std::vector<nghttp2_nv>* final_headers = static_cast<std::vector<nghttp2_nv>*>(context);
insertHeader(*final_headers, header);
return HeaderMap::Iterate::Continue;
},
&final_headers);
headers.iterate([&final_headers](const HeaderEntry& header) -> HeaderMap::Iterate {
insertHeader(final_headers, header);
return HeaderMap::Iterate::Continue;
});
}

void ConnectionImpl::ServerStreamImpl::encode100ContinueHeaders(const ResponseHeaderMap& headers) {
Expand Down
Loading

0 comments on commit 6a4c9af

Please sign in to comment.