Skip to content

Commit

Permalink
remove files and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
labkode committed Oct 10, 2023
1 parent 0d32711 commit df01798
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 990 deletions.
32 changes: 21 additions & 11 deletions internal/http/interceptors/trace/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
package trace

import (
"context"
"net/http"

"github.com/cs3org/reva/pkg/trace"
Expand All @@ -36,23 +37,32 @@ func New() func(http.Handler) http.Handler {

func handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// try to get trace from context
traceID := trace.Get(ctx)
if traceID == "" {
// check if traceID is coming from header
traceID = r.Header.Get("X-Trace-ID")
if traceID == "" {
traceID = trace.Generate()
}
ctx = trace.Set(ctx, traceID)
}
traceID, ctx := getTraceID(r)

// in case the http service will call a grpc service,
// we set the outgoing context so the trace information is
// passed through the two protocols.
ctx = metadata.AppendToOutgoingContext(ctx, "revad-grpc-trace-id", traceID)

r = r.WithContext(ctx)
h.ServeHTTP(w, r)
})
}

func getTraceID(r *http.Request) (string, context.Context) {
ctx := r.Context()
// try to get trace from context
traceID := trace.Get(ctx)
if traceID == "" {
// check if traceID is coming from header
traceID = r.Header.Get("X-Trace-ID")
if traceID == "" {
traceID = r.Header.Get("X-Request-ID")
if traceID == "" {
traceID = trace.Generate()
}
}
ctx = trace.Set(ctx, traceID)
}
return traceID, ctx
}
58 changes: 58 additions & 0 deletions internal/http/interceptors/trace/trace_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package trace

import (
"context"
"net/http"
"testing"

"github.com/cs3org/reva/pkg/trace"
)

type testPair struct {
e string
r *http.Request
}

func TestGetTrace(t *testing.T) {
pairs := []*testPair{
&testPair{
r: newRequest(context.Background(), map[string]string{"X-Trace-ID": "def"}),
e: "def",
},
&testPair{
r: newRequest(context.Background(), map[string]string{"X-Request-ID": "abc"}),
e: "abc",
},
&testPair{
r: newRequest(trace.Set(context.Background(), "fgh"), nil),
e: "fgh",
},
}

for _, p := range pairs {
got, _ := getTraceID(p.r)
t.Logf("headers: %+v context: %+v got: %+v\n", p.r, p.r.Context(), got)
if got != p.e {
t.Fatal("expected: "+p.e, "got: "+got)
return
}
}

}

func newRequest(ctx context.Context, headers map[string]string) *http.Request {
r, _ := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost", nil)
for k, v := range headers {
r.Header.Set(k, v)
}
return r
}

func TestGenerateTrace(t *testing.T) {
got, _ := getTraceID(newRequest(context.Background(), nil))
if len(got) != 36 {
t.Fatal("expected random generated UUID 36 chars trace ID but got:" + got)
return
}

}
Loading

0 comments on commit df01798

Please sign in to comment.