diff --git a/source/common/config/BUILD b/source/common/config/BUILD index 156ea3b17fa5..082f1cc30f50 100644 --- a/source/common/config/BUILD +++ b/source/common/config/BUILD @@ -342,6 +342,7 @@ envoy_cc_library( ":api_type_oracle_lib", "//source/common/common:assert_lib", "//source/common/protobuf", + "//source/common/protobuf:visitor_lib", "//source/common/protobuf:well_known_lib", "@envoy_api//envoy/config/core/v3:pkg_cc_proto", ], diff --git a/source/common/config/version_converter.cc b/source/common/config/version_converter.cc index 181313fbbaa9..e1f557c5130b 100644 --- a/source/common/config/version_converter.cc +++ b/source/common/config/version_converter.cc @@ -4,6 +4,7 @@ #include "common/common/assert.h" #include "common/config/api_type_oracle.h" +#include "common/protobuf/visitor.h" #include "common/protobuf/well_known.h" #include "absl/strings/match.h" @@ -31,30 +32,6 @@ class ProtoVisitor { virtual void onMessage(Protobuf::Message&, const void*){}; }; -// TODO(htuch): refactor these message visitor patterns into utility.cc and share with -// MessageUtil::checkForUnexpectedFields. -void traverseMutableMessage(ProtoVisitor& visitor, Protobuf::Message& message, const void* ctxt) { - visitor.onMessage(message, ctxt); - const Protobuf::Descriptor* descriptor = message.GetDescriptor(); - const Protobuf::Reflection* reflection = message.GetReflection(); - for (int i = 0; i < descriptor->field_count(); ++i) { - const Protobuf::FieldDescriptor* field = descriptor->field(i); - const void* field_ctxt = visitor.onField(message, *field, ctxt); - // If this is a message, recurse to scrub deprecated fields in the sub-message. - if (field->cpp_type() == Protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { - if (field->is_repeated()) { - const int size = reflection->FieldSize(message, field); - for (int j = 0; j < size; ++j) { - traverseMutableMessage(visitor, *reflection->MutableRepeatedMessage(&message, field, j), - field_ctxt); - } - } else if (reflection->HasField(message, field)) { - traverseMutableMessage(visitor, *reflection->MutableMessage(&message, field), field_ctxt); - } - } - } -} - // Reinterpret a Protobuf message as another Protobuf message by converting to wire format and back. // This only works for messages that can be effectively duck typed this way, e.g. with a subtype // relationship modulo field name. @@ -88,7 +65,7 @@ DynamicMessagePtr createForDescriptorWithCast(const Protobuf::Message& message, // internally, we later want to recover their original types. void annotateWithOriginalType(const Protobuf::Descriptor& prev_descriptor, Protobuf::Message& next_message) { - class TypeAnnotatingProtoVisitor : public ProtoVisitor { + class TypeAnnotatingProtoVisitor : public ProtobufMessage::ProtoVisitor { public: void onMessage(Protobuf::Message& message, const void* ctxt) override { const Protobuf::Descriptor* descriptor = message.GetDescriptor(); @@ -127,7 +104,7 @@ void annotateWithOriginalType(const Protobuf::Descriptor& prev_descriptor, } }; TypeAnnotatingProtoVisitor proto_visitor; - traverseMutableMessage(proto_visitor, next_message, &prev_descriptor); + ProtobufMessage::traverseMutableMessage(proto_visitor, next_message, &prev_descriptor); } } // namespace @@ -140,7 +117,7 @@ void VersionConverter::upgrade(const Protobuf::Message& prev_message, } void VersionConverter::eraseOriginalTypeInformation(Protobuf::Message& message) { - class TypeErasingProtoVisitor : public ProtoVisitor { + class TypeErasingProtoVisitor : public ProtobufMessage::ProtoVisitor { public: void onMessage(Protobuf::Message& message, const void*) override { const Protobuf::Reflection* reflection = message.GetReflection(); @@ -149,7 +126,7 @@ void VersionConverter::eraseOriginalTypeInformation(Protobuf::Message& message) } }; TypeErasingProtoVisitor proto_visitor; - traverseMutableMessage(proto_visitor, message, nullptr); + ProtobufMessage::traverseMutableMessage(proto_visitor, message, nullptr); } DynamicMessagePtr VersionConverter::recoverOriginal(const Protobuf::Message& upgraded_message) { @@ -226,7 +203,7 @@ void VersionConverter::prepareMessageForGrpcWire(Protobuf::Message& message, } void VersionUtil::scrubHiddenEnvoyDeprecated(Protobuf::Message& message) { - class HiddenFieldScrubbingProtoVisitor : public ProtoVisitor { + class HiddenFieldScrubbingProtoVisitor : public ProtobufMessage::ProtoVisitor { public: const void* onField(Protobuf::Message& message, const Protobuf::FieldDescriptor& field, const void*) override { @@ -238,7 +215,7 @@ void VersionUtil::scrubHiddenEnvoyDeprecated(Protobuf::Message& message) { } }; HiddenFieldScrubbingProtoVisitor proto_visitor; - traverseMutableMessage(proto_visitor, message, nullptr); + ProtobufMessage::traverseMutableMessage(proto_visitor, message, nullptr); } } // namespace Config diff --git a/source/common/protobuf/BUILD b/source/common/protobuf/BUILD index 34bd1f7e18d4..9a9aa1f30624 100644 --- a/source/common/protobuf/BUILD +++ b/source/common/protobuf/BUILD @@ -68,12 +68,20 @@ envoy_cc_library( "//source/common/common:utility_lib", "//source/common/config:api_type_oracle_lib", "//source/common/config:version_converter_lib", + "//source/common/protobuf:visitor_lib", "@com_github_cncf_udpa//udpa/annotations:pkg_cc_proto", "@envoy_api//envoy/annotations:pkg_cc_proto", "@envoy_api//envoy/type/v3:pkg_cc_proto", ], ) +envoy_cc_library( + name = "visitor_lib", + srcs = ["visitor.cc"], + hdrs = ["visitor.h"], + deps = [":protobuf"], +) + envoy_cc_library( name = "well_known_lib", hdrs = ["well_known.h"], diff --git a/source/common/protobuf/utility.cc b/source/common/protobuf/utility.cc index 117ae9555aa4..e3715a13c79d 100644 --- a/source/common/protobuf/utility.cc +++ b/source/common/protobuf/utility.cc @@ -13,6 +13,7 @@ #include "common/config/version_converter.h" #include "common/protobuf/message_validator_impl.h" #include "common/protobuf/protobuf.h" +#include "common/protobuf/visitor.h" #include "common/protobuf/well_known.h" #include "absl/strings/match.h" @@ -403,80 +404,77 @@ void checkForDeprecatedNonRepeatedEnumValue(const Protobuf::Message& message, message); } -void checkForUnexpectedFields(const Protobuf::Message& message, - ProtobufMessage::ValidationVisitor& validation_visitor, - Runtime::Loader* runtime) { - // Reject unknown fields. - const auto& unknown_fields = message.GetReflection()->GetUnknownFields(message); - if (!unknown_fields.empty()) { - std::string error_msg; - for (int n = 0; n < unknown_fields.field_count(); ++n) { - if (unknown_fields.field(n).number() == ProtobufWellKnown::OriginalTypeFieldNumber) { - continue; - } - error_msg += absl::StrCat(n > 0 ? ", " : "", unknown_fields.field(n).number()); - } - // We use the validation visitor but have hard coded behavior below for deprecated fields. - // TODO(htuch): Unify the deprecated and unknown visitor handling behind the validation - // visitor pattern. https://github.com/envoyproxy/envoy/issues/8092. - if (!error_msg.empty()) { - validation_visitor.onUnknownField("type " + message.GetTypeName() + - " with unknown field set {" + error_msg + "}"); - } - } +class UnexpectedFieldProtoVisitor : public ProtobufMessage::ConstProtoVisitor { +public: + UnexpectedFieldProtoVisitor(ProtobufMessage::ValidationVisitor& validation_visitor, + Runtime::Loader* runtime) + : validation_visitor_(validation_visitor), runtime_(runtime) {} - const Protobuf::Descriptor* descriptor = message.GetDescriptor(); - const Protobuf::Reflection* reflection = message.GetReflection(); - for (int i = 0; i < descriptor->field_count(); ++i) { - const Protobuf::FieldDescriptor* field = descriptor->field(i); - absl::string_view filename = filenameFromPath(field->file()->name()); + const void* onField(const Protobuf::Message& message, const Protobuf::FieldDescriptor& field, + const void*) override { + const Protobuf::Reflection* reflection = message.GetReflection(); + absl::string_view filename = filenameFromPath(field.file()->name()); // Before we check to see if the field is in use, see if there's a // deprecated default enum value. - checkForDeprecatedNonRepeatedEnumValue(message, filename, field, reflection, runtime); + checkForDeprecatedNonRepeatedEnumValue(message, filename, &field, reflection, runtime_); // If this field is not in use, continue. - if ((field->is_repeated() && reflection->FieldSize(message, field) == 0) || - (!field->is_repeated() && !reflection->HasField(message, field))) { - continue; + if ((field.is_repeated() && reflection->FieldSize(message, &field) == 0) || + (!field.is_repeated() && !reflection->HasField(message, &field))) { + return nullptr; } // If this field is deprecated, warn or throw an error. - if (field->options().deprecated()) { + if (field.options().deprecated()) { const std::string warning = absl::StrCat( - "Using {}deprecated option '", field->full_name(), "' from file ", filename, + "Using {}deprecated option '", field.full_name(), "' from file ", filename, ". This configuration will be removed from " "Envoy soon. Please see https://www.envoyproxy.io/docs/envoy/latest/intro/deprecated " "for details."); - deprecatedFieldHelper( - runtime, true /*deprecated*/, - field->options().GetExtension(envoy::annotations::disallowed_by_default), - absl::StrCat("envoy.deprecated_features:", field->full_name()), warning, message); + deprecatedFieldHelper(runtime_, true /*deprecated*/, + field.options().GetExtension(envoy::annotations::disallowed_by_default), + absl::StrCat("envoy.deprecated_features:", field.full_name()), warning, + message); } + return nullptr; + } - // If this is a message, recurse to check for deprecated fields in the sub-message. - if (field->cpp_type() == Protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { - if (field->is_repeated()) { - const int size = reflection->FieldSize(message, field); - for (int j = 0; j < size; ++j) { - checkForUnexpectedFields(reflection->GetRepeatedMessage(message, field, j), - validation_visitor, runtime); + void onMessage(const Protobuf::Message& message, const void*) override { + // Reject unknown fields. + const auto& unknown_fields = message.GetReflection()->GetUnknownFields(message); + if (!unknown_fields.empty()) { + std::string error_msg; + for (int n = 0; n < unknown_fields.field_count(); ++n) { + if (unknown_fields.field(n).number() == ProtobufWellKnown::OriginalTypeFieldNumber) { + continue; } - } else { - checkForUnexpectedFields(reflection->GetMessage(message, field), validation_visitor, - runtime); + error_msg += absl::StrCat(n > 0 ? ", " : "", unknown_fields.field(n).number()); + } + // We use the validation visitor but have hard coded behavior below for deprecated fields. + // TODO(htuch): Unify the deprecated and unknown visitor handling behind the validation + // visitor pattern. https://github.com/envoyproxy/envoy/issues/8092. + if (!error_msg.empty()) { + validation_visitor_.onUnknownField("type " + message.GetTypeName() + + " with unknown field set {" + error_msg + "}"); } } } -} + +private: + ProtobufMessage::ValidationVisitor& validation_visitor_; + Runtime::Loader* runtime_; +}; } // namespace void MessageUtil::checkForUnexpectedFields(const Protobuf::Message& message, ProtobufMessage::ValidationVisitor& validation_visitor, Runtime::Loader* runtime) { - ::Envoy::checkForUnexpectedFields(API_RECOVER_ORIGINAL(message), validation_visitor, runtime); + UnexpectedFieldProtoVisitor unexpected_field_visitor(validation_visitor, runtime); + ProtobufMessage::traverseMessage(unexpected_field_visitor, API_RECOVER_ORIGINAL(message), + nullptr); } std::string MessageUtil::getYamlStringFromMessage(const Protobuf::Message& message, diff --git a/source/common/protobuf/visitor.cc b/source/common/protobuf/visitor.cc new file mode 100644 index 000000000000..d978f964ff0f --- /dev/null +++ b/source/common/protobuf/visitor.cc @@ -0,0 +1,50 @@ +#include "common/protobuf/visitor.h" + +namespace Envoy { +namespace ProtobufMessage { + +void traverseMutableMessage(ProtoVisitor& visitor, Protobuf::Message& message, const void* ctxt) { + visitor.onMessage(message, ctxt); + const Protobuf::Descriptor* descriptor = message.GetDescriptor(); + const Protobuf::Reflection* reflection = message.GetReflection(); + for (int i = 0; i < descriptor->field_count(); ++i) { + const Protobuf::FieldDescriptor* field = descriptor->field(i); + const void* field_ctxt = visitor.onField(message, *field, ctxt); + // If this is a message, recurse to visit fields in the sub-message. + if (field->cpp_type() == Protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + if (field->is_repeated()) { + const int size = reflection->FieldSize(message, field); + for (int j = 0; j < size; ++j) { + traverseMutableMessage(visitor, *reflection->MutableRepeatedMessage(&message, field, j), + field_ctxt); + } + } else if (reflection->HasField(message, field)) { + traverseMutableMessage(visitor, *reflection->MutableMessage(&message, field), field_ctxt); + } + } + } +} +void traverseMessage(ConstProtoVisitor& visitor, const Protobuf::Message& message, + const void* ctxt) { + visitor.onMessage(message, ctxt); + const Protobuf::Descriptor* descriptor = message.GetDescriptor(); + const Protobuf::Reflection* reflection = message.GetReflection(); + for (int i = 0; i < descriptor->field_count(); ++i) { + const Protobuf::FieldDescriptor* field = descriptor->field(i); + const void* field_ctxt = visitor.onField(message, *field, ctxt); + // If this is a message, recurse to scrub deprecated fields in the sub-message. + if (field->cpp_type() == Protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + if (field->is_repeated()) { + const int size = reflection->FieldSize(message, field); + for (int j = 0; j < size; ++j) { + traverseMessage(visitor, reflection->GetRepeatedMessage(message, field, j), field_ctxt); + } + } else if (reflection->HasField(message, field)) { + traverseMessage(visitor, reflection->GetMessage(message, field), field_ctxt); + } + } + } +} + +} // namespace ProtobufMessage +} // namespace Envoy diff --git a/source/common/protobuf/visitor.h b/source/common/protobuf/visitor.h new file mode 100644 index 000000000000..93ec07af93b7 --- /dev/null +++ b/source/common/protobuf/visitor.h @@ -0,0 +1,43 @@ +#pragma once + +#include "common/protobuf/protobuf.h" + +namespace Envoy { +namespace ProtobufMessage { + +class ProtoVisitor { +public: + virtual ~ProtoVisitor() = default; + + // Invoked when a field is visited, with the message, field descriptor and context. Returns a new + // context for use when traversing the sub-message in a field. + virtual const void* onField(Protobuf::Message&, const Protobuf::FieldDescriptor&, + const void* ctxt) { + return ctxt; + } + + // Invoked when a message is visited, with the message and a context. + virtual void onMessage(Protobuf::Message&, const void*){}; +}; + +class ConstProtoVisitor { +public: + virtual ~ConstProtoVisitor() = default; + + // Invoked when a field is visited, with the message, field descriptor and context. Returns a new + // context for use when traversing the sub-message in a field. + virtual const void* onField(const Protobuf::Message&, const Protobuf::FieldDescriptor&, + const void* ctxt) { + return ctxt; + } + + // Invoked when a message is visited, with the message and a context. + virtual void onMessage(const Protobuf::Message&, const void*){}; +}; + +void traverseMutableMessage(ProtoVisitor& visitor, Protobuf::Message& message, const void* ctxt); +void traverseMessage(ConstProtoVisitor& visitor, const Protobuf::Message& message, + const void* ctxt); + +} // namespace ProtobufMessage +} // namespace Envoy