Skip to content

Commit

Permalink
add context traversal from WASM (envoyproxy#142)
Browse files Browse the repository at this point in the history
Signed-off-by: Kuat Yessenov <[email protected]>
  • Loading branch information
kyessenov authored and PiotrSikora committed Aug 23, 2019
1 parent b7827d9 commit 83a3be9
Show file tree
Hide file tree
Showing 9 changed files with 275 additions and 1 deletion.
27 changes: 27 additions & 0 deletions api/wasm/cpp/proxy_wasm_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class WasmData {
WasmData(const char* data, size_t size) : data_(data), size_(size) {}
~WasmData() { ::free(const_cast<char*>(data_)); }
const char* data() { return data_; }
size_t size() { return size_; }
StringView view() { return {data_, size_}; }
std::string toString() { return std::string(view()); }
std::vector<std::pair<StringView, StringView>> pairs();
Expand Down Expand Up @@ -528,6 +529,32 @@ inline WasmResult getPluginDirection(PluginDirection *direction_ptr) {
PROXY_EXPRESSION_GET("plugin.direction", reinterpret_cast<uint32_t*>(direction_ptr));
}

// Generic selector
inline absl::optional<WasmDataPtr> getSelectorExpression(std::initializer_list<absl::string_view> parts) {
size_t size = 0;
for (auto part: parts) {
size += part.size() + 1; // null terminated string value
}

char* buffer = static_cast<char*>(::malloc(size));
char* b = buffer;

for (auto part : parts) {
memcpy(b, part.data(), part.size());
b += part.size();
*b++ = 0;
}

const char* value_ptr = nullptr;
size_t value_size = 0;
auto result = proxy_getSelectorExpression(buffer, size, &value_ptr, &value_size);
::free(buffer);
if (result != WasmResult::Ok) {
return {};
}
return std::make_unique<WasmData>(value_ptr, value_size);
}

// Metadata
inline WasmResult getMetadata(MetadataType type, StringView key, WasmDataPtr *wasm_data) {
const char* value_ptr = nullptr;
Expand Down
4 changes: 4 additions & 0 deletions api/wasm/cpp/proxy_wasm_externs.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ extern "C" WasmResult proxy_getMetadataStruct(MetadataType type, const char* nam
extern "C" WasmResult proxy_setMetadataStruct(MetadataType type, const char* name_ptr, size_t name_size,
const char* value_ptr, size_t value_size);

// Generic selector
extern "C" WasmResult proxy_getSelectorExpression(const char* path_ptr, size_t path_size,
const char** value_ptr_ptr, size_t* value_size_ptr);

// Continue/Reply/Route
extern "C" WasmResult proxy_continueRequest();
extern "C" WasmResult proxy_continueResponse();
Expand Down
5 changes: 5 additions & 0 deletions source/extensions/common/wasm/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ envoy_cc_library(
":wasm_hdr",
":wasm_vm_lib",
"//api/wasm/cpp:shared_lib",
"//external:abseil_base",
"//external:abseil_node_hash_map",
"//source/common/buffer:buffer_lib",
"//source/common/common:enum_to_int",
Expand All @@ -74,5 +75,9 @@ envoy_cc_library(
"//source/common/http:utility_lib",
"//source/common/tracing:http_tracer_lib",
"//source/extensions/common/wasm/null:null_lib",
"//source/extensions/filters/common/expr:context_lib",
"@com_google_cel_cpp//eval/eval:field_access",
"@com_google_cel_cpp//eval/eval:field_backed_list_impl",
"@com_google_cel_cpp//eval/eval:field_backed_map_impl",
],
)
1 change: 1 addition & 0 deletions source/extensions/common/wasm/null/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ envoy_cc_library(
":null_vm_plugin_interface",
"//api/wasm/cpp:api_lib",
"//api/wasm/cpp:shared_lib",
"//external:abseil_base",
"//external:abseil_node_hash_map",
"//source/common/common:assert_lib",
"//source/common/protobuf",
Expand Down
20 changes: 19 additions & 1 deletion source/extensions/common/wasm/null/sample_plugin/plugin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#else

#include "extensions/common/wasm/null/null_plugin.h"
#include "absl/base/casts.h"

namespace Envoy {
namespace Extensions {
Expand Down Expand Up @@ -45,7 +46,24 @@ FilterDataStatus PluginContext::onRequestBody(size_t body_buffer_length, bool /*

void PluginContext::onLog() {
auto path = getRequestHeader(":path");
logWarn("onLog " + std::to_string(id()) + " " + std::string(path->view()));
if (path->view() == "/test_context") {
logWarn("request.path: " + getSelectorExpression({"request", "path"}).value()->toString());
logWarn("node.metadata: " +
getSelectorExpression({"node", "metadata", "istio.io/metadata"}).value()->toString());
logWarn("metadata: " + getSelectorExpression({"metadata", "filter_metadata", "envoy.wasm",
"wasm_request_get_key"})
.value()
->toString());
auto responseCode = getSelectorExpression({"response", "code"}).value();
if (responseCode->size() == sizeof(int64_t)) {
char buf[sizeof(int64_t)];
responseCode->view().copy(buf, sizeof(int64_t), 0);
int64_t code = absl::bit_cast<int64_t>(buf);
logWarn("response.code: " + absl::StrCat(code));
}
} else {
logWarn("onLog " + std::to_string(id()) + " " + std::string(path->view()));
}
}

void PluginContext::onDone() { logWarn("onDone " + std::to_string(id())); }
Expand Down
7 changes: 7 additions & 0 deletions source/extensions/common/wasm/null/wasm_api_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ inline WasmResult proxy_setMetadataStruct(MetadataType type, const char* name_pt
WS(name_size), WR(value_ptr), WS(value_size)));
}

// Generic selector
inline WasmResult proxy_getSelectorExpression(const char* path_ptr, size_t path_size,
const char** value_ptr_ptr, size_t* value_size_ptr) {
return wordToWasmResult(getSelectorExpressionHandler(
current_context_, WR(path_ptr), WS(path_size), WR(value_ptr_ptr), WR(value_size_ptr)));
}

// Continue
inline WasmResult proxy_continueRequest() {
return wordToWasmResult(continueRequestHandler(current_context_));
Expand Down
171 changes: 171 additions & 0 deletions source/extensions/common/wasm/wasm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,15 @@
#include "common/tracing/http_tracer_impl.h"

#include "extensions/common/wasm/well_known_names.h"
#include "extensions/filters/common/expr/context.h"

#include "absl/base/casts.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/node_hash_map.h"
#include "absl/synchronization/mutex.h"
#include "eval/eval/field_access.h"
#include "eval/eval/field_backed_list_impl.h"
#include "eval/eval/field_backed_map_impl.h"

namespace Envoy {
namespace Extensions {
Expand Down Expand Up @@ -462,6 +467,25 @@ Word setMetadataStructHandler(void* raw_context, Word type, Word name_ptr, Word
name.value(), value.value())));
}

// Generic selector
Word getSelectorExpressionHandler(void* raw_context, Word path_ptr, Word path_size,
Word value_ptr_ptr, Word value_size_ptr) {
auto context = WASM_CONTEXT(raw_context);
auto path = context->wasmVm()->getMemory(path_ptr, path_size);
if (!path.has_value()) {
return wasmResultToWord(WasmResult::InvalidMemoryAccess);
}
std::string value;
auto result = context->getSelectorExpression(path.value(), &value);
if (result != WasmResult::Ok) {
return wasmResultToWord(result);
}
if (!context->wasm()->copyToPointerSize(value, value_ptr_ptr, value_size_ptr)) {
return wasmResultToWord(WasmResult::InvalidMemoryAccess);
}
return wasmResultToWord(WasmResult::Ok);
}

// Continue/Reply/Route
Word continueRequestHandler(void* raw_context) {
auto context = WASM_CONTEXT(raw_context);
Expand Down Expand Up @@ -1016,6 +1040,153 @@ uint64_t Context::getCurrentTimeNanoseconds() {
.count();
}

WasmResult serializeValue(Filters::Common::Expr::CelValue value, std::string* result) {
using Filters::Common::Expr::CelValue;
switch (value.type()) {
case CelValue::Type::kMessage:
if (value.MessageOrDie() != nullptr) {
value.MessageOrDie()->SerializeToString(result);
return WasmResult::Ok;
}
case CelValue::Type::kString:
result->assign(value.StringOrDie().value().data(), value.StringOrDie().value().size());
return WasmResult::Ok;
case CelValue::Type::kBytes:
result->assign(value.BytesOrDie().value().data(), value.BytesOrDie().value().size());
return WasmResult::Ok;
case CelValue::Type::kInt64: {
auto out = value.Int64OrDie();
result->assign(reinterpret_cast<const char*>(&out), sizeof(int64_t));
return WasmResult::Ok;
}
case CelValue::Type::kUint64: {
auto out = value.Uint64OrDie();
result->assign(reinterpret_cast<const char*>(&out), sizeof(uint64_t));
return WasmResult::Ok;
}
case CelValue::Type::kDouble: {
auto out = value.DoubleOrDie();
result->assign(reinterpret_cast<const char*>(&out), sizeof(double));
return WasmResult::Ok;
}
case CelValue::Type::kBool: {
auto out = value.BoolOrDie();
result->assign(reinterpret_cast<const char*>(&out), sizeof(bool));
return WasmResult::Ok;
}
case CelValue::Type::kDuration: {
auto out = value.DurationOrDie();
result->assign(reinterpret_cast<const char*>(&out), sizeof(absl::Duration));
return WasmResult::Ok;
}
case CelValue::Type::kTimestamp: {
auto out = value.TimestampOrDie();
result->assign(reinterpret_cast<const char*>(&out), sizeof(absl::Time));
return WasmResult::Ok;
}
default:
// TODO: lists and maps
break;
}

return WasmResult::SerializationFailure;
}

WasmResult Context::getSelectorExpression(absl::string_view path, std::string* result) {
using Filters::Common::Expr::CelValue;
using google::api::expr::runtime::FieldBackedListImpl;
using google::api::expr::runtime::FieldBackedMapImpl;

bool first = true;
CelValue value;
Protobuf::Arena arena;
const StreamInfo::StreamInfo& info = decoder_callbacks_->streamInfo();
const auto request_headers = request_headers_ ? request_headers_ : access_log_request_headers_;
const auto response_headers =
response_headers_ ? response_headers_ : access_log_response_headers_;
const auto response_trailers =
response_trailers_ ? response_trailers_ : access_log_response_trailers_;

size_t start = 0;
while (true) {
if (start >= path.size()) {
break;
}

size_t end = path.find('\0', start);
if (end == absl::string_view::npos) {
// this should not happen unless the input string is not null-terminated in the view
return WasmResult::ParseFailure;
}
auto part = path.substr(start, end - start);
start = end + 1;

// top-level ident
if (first) {
first = false;
if (part == "metadata") {
value = CelValue::CreateMessage(&info.dynamicMetadata(), &arena);
} else if (part == "request") {
value = CelValue::CreateMap(Protobuf::Arena::Create<Filters::Common::Expr::RequestWrapper>(
&arena, request_headers, info));
} else if (part == "response") {
value = CelValue::CreateMap(Protobuf::Arena::Create<Filters::Common::Expr::ResponseWrapper>(
&arena, response_headers, response_trailers, info));
} else if (part == "connection") {
value = CelValue::CreateMap(
Protobuf::Arena::Create<Filters::Common::Expr::ConnectionWrapper>(&arena, info));
} else if (part == "node") {
value = CelValue::CreateMessage(&wasm_->local_info_.node(), &arena);
} else if (part == "source") {
value = CelValue::CreateMap(
Protobuf::Arena::Create<Filters::Common::Expr::PeerWrapper>(&arena, info, false));
} else if (part == "destination") {
value = CelValue::CreateMap(
Protobuf::Arena::Create<Filters::Common::Expr::PeerWrapper>(&arena, info, true));
} else {
return WasmResult::NotFound;
}
continue;
}

if (value.IsMap()) {
auto& map = *value.MapOrDie();
auto field = map[CelValue::CreateString(part)];
if (field.has_value()) {
value = field.value();
} else {
return {};
}
} else if (value.IsMessage()) {
auto msg = value.MessageOrDie();
if (msg == nullptr) {
return {};
}
const Protobuf::Descriptor* desc = msg->GetDescriptor();
const Protobuf::FieldDescriptor* field_desc = desc->FindFieldByName(std::string(part));
if (field_desc == nullptr) {
return {};
} else if (field_desc->is_map()) {
value = CelValue::CreateMap(
Protobuf::Arena::Create<FieldBackedMapImpl>(&arena, msg, field_desc, &arena));
} else if (field_desc->is_repeated()) {
value = CelValue::CreateList(
Protobuf::Arena::Create<FieldBackedListImpl>(&arena, msg, field_desc, &arena));
} else {
auto status =
google::api::expr::runtime::CreateValueFromSingleField(msg, field_desc, &arena, &value);
if (!status.ok()) {
return {};
}
}
} else {
return {};
}
}

return serializeValue(value, result);
}

// Shared Data
WasmResult Context::getSharedData(absl::string_view key, std::pair<std::string, uint32_t>* data) {
return global_shared_data.get(wasm_->id(), key, data);
Expand Down
5 changes: 5 additions & 0 deletions source/extensions/common/wasm/wasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ Word getMetadataStructHandler(void* raw_context, Word type, Word name_ptr, Word
Word value_ptr_ptr, Word value_size_ptr);
Word setMetadataStructHandler(void* raw_context, Word type, Word name_ptr, Word name_size,
Word value_ptr, Word value_size);
Word getSelectorExpressionHandler(void* raw_context, Word path_ptr, Word path_size,
Word value_ptr_ptr, Word value_size_ptr);
Word continueRequestHandler(void* raw_context);
Word continueResponseHandler(void* raw_context);
Word sendLocalResponseHandler(void* raw_context, Word response_code, Word response_code_details_ptr,
Expand Down Expand Up @@ -254,6 +256,9 @@ class Context : public Http::StreamFilter,
virtual std::string getTlsVersion(StreamType type);
virtual absl::optional<bool> peerCertificatePresented(StreamType type);

// Generic resolver producing a serialized value
virtual WasmResult getSelectorExpression(absl::string_view path, std::string* result);

// Metadata
// When used with MetadataType::Request/Response refers to metadata with name "envoy.wasm": the
// values are serialized ProtobufWkt::Struct Value
Expand Down
36 changes: 36 additions & 0 deletions test/extensions/filters/http/wasm/wasm_filter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,42 @@ TEST_F(WasmHttpFilterTest, NullPluginRequestHeadersOnly) {
filter_->onDestroy();
}

TEST_F(WasmHttpFilterTest, NullVmResolver) {
setupNullConfig("null_vm_plugin");
setupFilter();
envoy::api::v2::core::Node node_data;
ProtobufWkt::Value node_val;
node_val.set_string_value("sample_data");
(*node_data.mutable_metadata()->mutable_fields())["istio.io/metadata"] = node_val;
EXPECT_CALL(local_info_, node()).WillRepeatedly(ReturnRef(node_data));

request_stream_info_.metadata_.mutable_filter_metadata()->insert(
Protobuf::MapPair<std::string, ProtobufWkt::Struct>(
HttpFilters::HttpFilterNames::get().Wasm,
MessageUtil::keyValueStruct("wasm_request_get_key", "wasm_request_get_value")));
EXPECT_CALL(request_stream_info_, responseCode()).WillRepeatedly(Return(403));
EXPECT_CALL(decoder_callbacks_, streamInfo()).WillRepeatedly(ReturnRef(request_stream_info_));
EXPECT_CALL(*filter_,
scriptLog_(spdlog::level::debug, Eq(absl::string_view("onRequestHeaders 2"))));
EXPECT_CALL(*filter_,
scriptLog_(spdlog::level::info, Eq(absl::string_view("header path /test_context"))));

// test outputs should match inputs
EXPECT_CALL(*filter_, scriptLog_(spdlog::level::warn,
Eq(absl::string_view("request.path: /test_context"))));
EXPECT_CALL(*filter_,
scriptLog_(spdlog::level::warn, Eq(absl::string_view("node.metadata: sample_data"))));
EXPECT_CALL(*filter_, scriptLog_(spdlog::level::warn,
Eq(absl::string_view("metadata: wasm_request_get_value"))));
EXPECT_CALL(*filter_,
scriptLog_(spdlog::level::warn, Eq(absl::string_view("response.code: 403"))));

Http::TestHeaderMapImpl request_headers{{":path", "/test_context"}};
EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(request_headers, true));
StreamInfo::MockStreamInfo log_stream_info;
filter_->log(&request_headers, nullptr, nullptr, log_stream_info);
}

TEST_P(WasmHttpFilterTest, SharedData) {
setupConfig(TestEnvironment::readFileToStringForTest(TestEnvironment::substitute(
"{{ test_rundir }}/test/extensions/filters/http/wasm/test_data/shared_cpp.wasm")));
Expand Down

0 comments on commit 83a3be9

Please sign in to comment.