diff --git a/internal/metadata/metadata.go b/internal/metadata/metadata.go index b8733dbf340d..b2980f8ac44a 100644 --- a/internal/metadata/metadata.go +++ b/internal/metadata/metadata.go @@ -22,6 +22,9 @@ package metadata import ( + "fmt" + "strings" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/resolver" ) @@ -72,3 +75,46 @@ func Set(addr resolver.Address, md metadata.MD) resolver.Address { addr.Attributes = addr.Attributes.WithValue(mdKey, mdValue(md)) return addr } + +// Validate returns an error if the input md contains invalid keys or values. +// +// If the header is not a pseudo-header, the following items are checked: +// - header names must contain one or more characters from this set [0-9 a-z _ - .]. +// - if the header-name ends with a "-bin" suffix, no validation of the header value is performed. +// - otherwise, the header value must contain one or more characters from the set [%x20-%x7E]. +func Validate(md metadata.MD) error { + for k, vals := range md { + // pseudo-header will be ignored + if k[0] == ':' { + continue + } + // check key, for i that saving a conversion if not using for range + for i := 0; i < len(k); i++ { + r := k[i] + if !(r >= 'a' && r <= 'z') && !(r >= '0' && r <= '9') && r != '.' && r != '-' && r != '_' { + return fmt.Errorf("header key %q contains illegal characters not in [0-9a-z-_.]", k) + } + } + if strings.HasSuffix(k, "-bin") { + continue + } + // check value + for _, val := range vals { + if hasNotPrintable(val) { + return fmt.Errorf("header key %q contains value with non-printable ASCII characters", k) + } + } + } + return nil +} + +// hasNotPrintable return true if msg contains any characters which are not in %x20-%x7E +func hasNotPrintable(msg string) bool { + // for i that saving a conversion if not using for range + for i := 0; i < len(msg); i++ { + if msg[i] < 0x20 || msg[i] > 0x7E { + return true + } + } + return false +} diff --git a/internal/metadata/metadata_test.go b/internal/metadata/metadata_test.go index 1aa0f9798e8c..80f1a44bb6ac 100644 --- a/internal/metadata/metadata_test.go +++ b/internal/metadata/metadata_test.go @@ -19,6 +19,8 @@ package metadata import ( + "errors" + "reflect" "testing" "github.com/google/go-cmp/cmp" @@ -84,3 +86,28 @@ func TestSet(t *testing.T) { }) } } + +func TestValidate(t *testing.T) { + for _, test := range []struct { + md metadata.MD + want error + }{ + { + md: map[string][]string{string(rune(0x19)): {"testVal"}}, + want: errors.New("header key \"\\x19\" contains illegal characters not in [0-9a-z-_.]"), + }, + { + md: map[string][]string{"test": {string(rune(0x19))}}, + want: errors.New("header key \"test\" contains value with non-printable ASCII characters"), + }, + { + md: map[string][]string{"test-bin": {string(rune(0x19))}}, + want: nil, + }, + } { + err := Validate(test.md) + if !reflect.DeepEqual(err, test.want) { + t.Errorf("validating metadata which is %v got err :%v, want err :%v", test.md, err, test.want) + } + } +} diff --git a/stream.go b/stream.go index 625d47b34e59..5b8b34a12f33 100644 --- a/stream.go +++ b/stream.go @@ -36,6 +36,7 @@ import ( "google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/grpcrand" "google.golang.org/grpc/internal/grpcutil" + imetadata "google.golang.org/grpc/internal/metadata" iresolver "google.golang.org/grpc/internal/resolver" "google.golang.org/grpc/internal/serviceconfig" "google.golang.org/grpc/internal/transport" @@ -164,6 +165,11 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth } func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) { + if md, _, ok := metadata.FromOutgoingContextRaw(ctx); ok { + if err := imetadata.Validate(md); err != nil { + return nil, status.Error(codes.Internal, err.Error()) + } + } if channelz.IsOn() { cc.incrCallsStarted() defer func() { @@ -1446,11 +1452,20 @@ func (ss *serverStream) SetHeader(md metadata.MD) error { if md.Len() == 0 { return nil } + err := imetadata.Validate(md) + if err != nil { + return status.Error(codes.Internal, err.Error()) + } return ss.s.SetHeader(md) } func (ss *serverStream) SendHeader(md metadata.MD) error { - err := ss.t.WriteHeader(ss.s, md) + err := imetadata.Validate(md) + if err != nil { + return status.Error(codes.Internal, err.Error()) + } + + err = ss.t.WriteHeader(ss.s, md) if ss.binlog != nil && !ss.serverHeaderBinlogged { h, _ := ss.s.Header() ss.binlog.Log(&binarylog.ServerHeader{ @@ -1465,6 +1480,9 @@ func (ss *serverStream) SetTrailer(md metadata.MD) { if md.Len() == 0 { return } + if err := imetadata.Validate(md); err != nil { + logger.Errorf("stream: failed to validate md when setting trailer, err: %v", err) + } ss.s.SetTrailer(md) } diff --git a/test/metadata_test.go b/test/metadata_test.go new file mode 100644 index 000000000000..e3da918fc722 --- /dev/null +++ b/test/metadata_test.go @@ -0,0 +1,120 @@ +/* + * + * Copyright 2022 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package test + +import ( + "context" + "fmt" + "io" + "reflect" + "testing" + "time" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + testpb "google.golang.org/grpc/test/grpc_testing" +) + +func (s) TestInvalidMetadata(t *testing.T) { + grpctest.TLogger.ExpectErrorN("stream: failed to validate md when setting trailer", 2) + + tests := []struct { + md metadata.MD + want error + recv error + }{ + { + md: map[string][]string{string(rune(0x19)): {"testVal"}}, + want: status.Error(codes.Internal, "header key \"\\x19\" contains illegal characters not in [0-9a-z-_.]"), + recv: status.Error(codes.Internal, "invalid header field name \"\\x19\""), + }, + { + md: map[string][]string{"test": {string(rune(0x19))}}, + want: status.Error(codes.Internal, "header key \"test\" contains value with non-printable ASCII characters"), + recv: status.Error(codes.Internal, "invalid header field value \"\\x19\""), + }, + { + md: map[string][]string{"test-bin": {string(rune(0x19))}}, + want: nil, + recv: io.EOF, + }, + { + md: map[string][]string{"test": {"value"}}, + want: nil, + recv: io.EOF, + }, + } + + testNum := 0 + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error { + _, err := stream.Recv() + if err != nil { + return err + } + test := tests[testNum] + testNum++ + if err := stream.SetHeader(test.md); !reflect.DeepEqual(test.want, err) { + return fmt.Errorf("call stream.SendHeader(md) validate metadata which is %v got err :%v, want err :%v", test.md, err, test.want) + } + if err := stream.SendHeader(test.md); !reflect.DeepEqual(test.want, err) { + return fmt.Errorf("call stream.SendHeader(md) validate metadata which is %v got err :%v, want err :%v", test.md, err, test.want) + } + stream.SetTrailer(test.md) + return nil + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting ss endpoint server: %v", err) + } + defer ss.Stop() + + for _, test := range tests { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + ctx = metadata.NewOutgoingContext(ctx, test.md) + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); !reflect.DeepEqual(test.want, err) { + t.Errorf("call ss.Client.EmptyCall() validate metadata which is %v got err :%v, want err :%v", test.md, err, test.want) + } + } + + // call the stream server's api to drive the server-side unit testing + for _, test := range tests { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + stream, err := ss.Client.FullDuplexCall(ctx) + defer cancel() + if err != nil { + t.Errorf("call ss.Client.FullDuplexCall(context.Background()) will success but got err :%v", err) + continue + } + if err := stream.Send(&testpb.StreamingOutputCallRequest{}); err != nil { + t.Errorf("call ss.Client stream Send(nil) will success but got err :%v", err) + } + if _, err := stream.Recv(); !reflect.DeepEqual(test.recv, err) { + t.Errorf("stream.Recv() = _, get err :%v, want err :%v", err, test.recv) + } + } +}