diff --git a/cmd/collector/app/zipkin/http_handler_test.go b/cmd/collector/app/zipkin/http_handler_test.go index 3a29d56a13e..62a41d251f9 100644 --- a/cmd/collector/app/zipkin/http_handler_test.go +++ b/cmd/collector/app/zipkin/http_handler_test.go @@ -295,9 +295,9 @@ func TestSaveProtoSpansV2(t *testing.T) { resBody string }{ {Span: zipkinProto.Span{Id: validID, TraceId: validTraceID, LocalEndpoint: &zipkinProto.Endpoint{Ipv4: randBytesOfLen(4)}, Kind: zipkinProto.Span_CLIENT}, StatusCode: http.StatusAccepted}, - {Span: zipkinProto.Span{Id: randBytesOfLen(4)}, StatusCode: http.StatusBadRequest, resBody: "Unable to process request body: invalid length for Span ID\n"}, - {Span: zipkinProto.Span{Id: validID, TraceId: randBytesOfLen(32)}, StatusCode: http.StatusBadRequest, resBody: "Unable to process request body: invalid length for traceId\n"}, - {Span: zipkinProto.Span{Id: validID, TraceId: validTraceID, ParentId: randBytesOfLen(16)}, StatusCode: http.StatusBadRequest, resBody: "Unable to process request body: invalid length for parentId\n"}, + {Span: zipkinProto.Span{Id: randBytesOfLen(4)}, StatusCode: http.StatusBadRequest, resBody: "Unable to process request body: invalid length for SpanID\n"}, + {Span: zipkinProto.Span{Id: validID, TraceId: randBytesOfLen(32)}, StatusCode: http.StatusBadRequest, resBody: "Unable to process request body: invalid length for TraceID\n"}, + {Span: zipkinProto.Span{Id: validID, TraceId: validTraceID, ParentId: randBytesOfLen(16)}, StatusCode: http.StatusBadRequest, resBody: "Unable to process request body: invalid length for SpanID\n"}, {Span: zipkinProto.Span{Id: validID, TraceId: validTraceID, LocalEndpoint: &zipkinProto.Endpoint{Ipv4: randBytesOfLen(2)}}, StatusCode: http.StatusBadRequest, resBody: "Unable to process request body: wrong Ipv4\n"}, } for _, test := range tests { diff --git a/cmd/collector/app/zipkin/protov2.go b/cmd/collector/app/zipkin/protov2.go index ca4b8b0c681..bce0298a752 100644 --- a/cmd/collector/app/zipkin/protov2.go +++ b/cmd/collector/app/zipkin/protov2.go @@ -38,14 +38,17 @@ func protoSpansV2ToThrift(listOfSpans *zipkinProto.ListOfSpans) ([]*zipkincore.S } func protoSpanV2ToThrift(s *zipkinProto.Span) (*zipkincore.Span, error) { - if len(s.Id) != model.TraceIDShortBytesLen { - return nil, fmt.Errorf("invalid length for Span ID") + var id model.SpanID + var err error + if id, err = model.SpanIDFromBytes(s.Id); err != nil { + return nil, err } - id := binary.BigEndian.Uint64(s.Id) - traceID, err := traceIDFromBytes(s.TraceId) - if err != nil { + + var traceID model.TraceID + if traceID, err = model.TraceIDFromBytes(s.TraceId); err != nil { return nil, err } + ts, d := int64(s.Timestamp), int64(s.Duration) tSpan := &zipkincore.Span{ ID: int64(id), @@ -61,18 +64,17 @@ func protoSpanV2ToThrift(s *zipkinProto.Span) (*zipkincore.Span, error) { } if len(s.ParentId) > 0 { - if len(s.ParentId) != model.TraceIDShortBytesLen { - return nil, fmt.Errorf("invalid length for parentId") + var parentID model.SpanID + if parentID, err = model.SpanIDFromBytes(s.ParentId); err != nil { + return nil, err } - parentID := binary.BigEndian.Uint64(s.ParentId) signed := int64(parentID) tSpan.ParentID = &signed } var localE *zipkincore.Endpoint if s.LocalEndpoint != nil { - localE, err = protoEndpointV2ToThrift(s.LocalEndpoint) - if err != nil { + if localE, err = protoEndpointV2ToThrift(s.LocalEndpoint); err != nil { return nil, err } } @@ -105,21 +107,6 @@ func protoSpanV2ToThrift(s *zipkinProto.Span) (*zipkincore.Span, error) { return tSpan, nil } -func traceIDFromBytes(tid []byte) (model.TraceID, error) { - var hi, lo uint64 - switch { - case len(tid) > model.TraceIDLongBytesLen: - return model.TraceID{}, fmt.Errorf("invalid length for traceId") - case len(tid) > model.TraceIDShortBytesLen: - hiLen := len(tid) - model.TraceIDShortBytesLen - hi = binary.BigEndian.Uint64(tid[:hiLen]) - lo = binary.BigEndian.Uint64(tid[hiLen:]) - default: - lo = binary.BigEndian.Uint64(tid) - } - return model.TraceID{High: hi, Low: lo}, nil -} - func protoRemoteEndpToAddrAnno(e *zipkinProto.Endpoint, kind zipkinProto.Span_Kind) (*zipkincore.BinaryAnnotation, error) { rEndp, err := protoEndpointV2ToThrift(e) if err != nil { diff --git a/cmd/collector/app/zipkin/protov2_test.go b/cmd/collector/app/zipkin/protov2_test.go index 2ba5637d576..362186f2e45 100644 --- a/cmd/collector/app/zipkin/protov2_test.go +++ b/cmd/collector/app/zipkin/protov2_test.go @@ -82,9 +82,9 @@ func TestIdErrs(t *testing.T) { span zipkinProto.Span errMsg string }{ - {span: zipkinProto.Span{Id: randBytesOfLen(16)}, errMsg: "invalid length for Span ID"}, - {span: zipkinProto.Span{Id: validID, TraceId: invalidTraceID}, errMsg: "invalid length for traceId"}, - {span: zipkinProto.Span{Id: validID, TraceId: validTraceID, ParentId: invalidParentID}, errMsg: "invalid length for parentId"}, + {span: zipkinProto.Span{Id: randBytesOfLen(16)}, errMsg: "invalid length for SpanID"}, + {span: zipkinProto.Span{Id: validID, TraceId: invalidTraceID}, errMsg: "invalid length for TraceID"}, + {span: zipkinProto.Span{Id: validID, TraceId: validTraceID, ParentId: invalidParentID}, errMsg: "invalid length for SpanID"}, } for _, test := range tests { _, err := protoSpanV2ToThrift(&test.span) diff --git a/model/ids.go b/model/ids.go index a814d039789..8b34f7c9a33 100644 --- a/model/ids.go +++ b/model/ids.go @@ -25,10 +25,10 @@ import ( ) const ( - // TraceIDShortBytesLen indicates length of 64bit traceID when represented as list of bytes - TraceIDShortBytesLen = 8 - // TraceIDLongBytesLen indicates length of 128bit traceID when represented as list of bytes - TraceIDLongBytesLen = 16 + // traceIDShortBytesLen indicates length of 64bit traceID when represented as list of bytes + traceIDShortBytesLen = 8 + // traceIDLongBytesLen indicates length of 128bit traceID when represented as list of bytes + traceIDLongBytesLen = 16 ) // TraceID is a random 128bit identifier for a trace @@ -77,6 +77,21 @@ func TraceIDFromString(s string) (TraceID, error) { return TraceID{High: hi, Low: lo}, nil } +// TraceIDFromBytes creates a TraceID from list of bytes +func TraceIDFromBytes(data []byte) (TraceID, error) { + var t TraceID + switch { + case len(data) == traceIDLongBytesLen: + t.High = binary.BigEndian.Uint64(data[:traceIDShortBytesLen]) + t.Low = binary.BigEndian.Uint64(data[traceIDShortBytesLen:]) + case len(data) == traceIDShortBytesLen: + t.Low = binary.BigEndian.Uint64(data) + default: + return TraceID{}, fmt.Errorf("invalid length for TraceID") + } + return t, nil +} + // MarshalText is called by encoding/json, which we do not want people to use. func (t TraceID) MarshalText() ([]byte, error) { return nil, fmt.Errorf("unsupported method TraceID.MarshalText; please use github.com/gogo/protobuf/jsonpb for marshalling") @@ -102,12 +117,9 @@ func (t *TraceID) MarshalTo(data []byte) (n int, err error) { // Unmarshal inflates this trace ID from binary representation. Called by protobuf serialization. func (t *TraceID) Unmarshal(data []byte) error { - if len(data) < 16 { - return fmt.Errorf("buffer is too short") - } - t.High = binary.BigEndian.Uint64(data[:8]) - t.Low = binary.BigEndian.Uint64(data[8:]) - return nil + var err error + *t, err = TraceIDFromBytes(data) + return err } func marshalBytes(dst []byte, src []byte) (n int, err error) { @@ -166,6 +178,14 @@ func SpanIDFromString(s string) (SpanID, error) { return SpanID(id), nil } +// SpanIDFromBytes creates a SpandID from list of bytes +func SpanIDFromBytes(data []byte) (SpanID, error) { + if len(data) != traceIDShortBytesLen { + return SpanID(0), fmt.Errorf("invalid length for SpanID") + } + return NewSpanID(binary.BigEndian.Uint64(data)), nil +} + // MarshalText is called by encoding/json, which we do not want people to use. func (s SpanID) MarshalText() ([]byte, error) { return nil, fmt.Errorf("unsupported method SpanID.MarshalText; please use github.com/gogo/protobuf/jsonpb for marshalling") @@ -190,11 +210,9 @@ func (s *SpanID) MarshalTo(data []byte) (n int, err error) { // Unmarshal inflates span ID from a binary representation. Called by protobuf serialization. func (s *SpanID) Unmarshal(data []byte) error { - if len(data) < 8 { - return fmt.Errorf("buffer is too short") - } - *s = NewSpanID(binary.BigEndian.Uint64(data)) - return nil + var err error + *s, err = SpanIDFromBytes(data) + return err } // MarshalJSON converts span id into a base64 string enclosed in quotes. diff --git a/model/ids_test.go b/model/ids_test.go index ffdd4125709..aa9af0babfd 100644 --- a/model/ids_test.go +++ b/model/ids_test.go @@ -86,3 +86,45 @@ func TestTraceSpanIDMarshalProto(t *testing.T) { }) } } + +func TestSpanIDFromBytes(t *testing.T) { + errTests := [][]byte{ + {0, 0, 0, 0}, + {0, 0, 0, 0, 0, 0, 0, 13, 0}, + } + for _, data := range errTests { + _, err := model.SpanIDFromBytes(data) + require.Error(t, err) + assert.Equal(t, err.Error(), "invalid length for SpanID") + } + + spanID, err := model.SpanIDFromBytes([]byte{0, 0, 0, 0, 0, 0, 0, 13}) + require.NoError(t, err) + assert.Equal(t, spanID, model.NewSpanID(13)) +} + +func TestTraceIDFromBytes(t *testing.T) { + errTests := [][]byte{ + {0, 0, 0, 0, 0, 0, 0, 13, 0, 0, 0, 0, 0, 0, 0, 0, 13}, + {0, 0, 0, 0, 0, 0, 0, 13, 0, 0, 0, 0, 0, 0, 13}, + {0, 0, 0, 0, 0, 0, 13}, + } + for _, data := range errTests { + _, err := model.TraceIDFromBytes(data) + require.Error(t, err) + assert.Equal(t, err.Error(), "invalid length for TraceID") + } + + tests := []struct { + data []byte + expected model.TraceID + }{ + {data: []byte{0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3}, expected: model.NewTraceID(2, 3)}, + {data: []byte{0, 0, 0, 0, 0, 0, 0, 2}, expected: model.NewTraceID(0, 2)}, + } + for _, test := range tests { + traceID, err := model.TraceIDFromBytes(test.data) + require.NoError(t, err) + assert.Equal(t, traceID, test.expected) + } +} diff --git a/model/span_test.go b/model/span_test.go index 71acb4cfb5b..dfd4d20a101 100644 --- a/model/span_test.go +++ b/model/span_test.go @@ -175,7 +175,7 @@ func TestSpanIDUnmarshalJSONErrors(t *testing.T) { err = id.UnmarshalJSONPB(nil, []byte("")) require.Error(t, err) - assert.Contains(t, err.Error(), "buffer is too short") + assert.Contains(t, err.Error(), "invalid length for SpanID") err = id.UnmarshalJSONPB(nil, []byte("123")) require.Error(t, err) assert.Contains(t, err.Error(), "illegal base64 data")