Skip to content

Commit

Permalink
Merge pull request #61 from basvanbeek/master
Browse files Browse the repository at this point in the history
middleware: Improved http.Handler logic and added RequestSampler option
  • Loading branch information
basvanbeek authored Jul 27, 2018
2 parents 8a54c36 + 8e1f1f4 commit d455a56
Show file tree
Hide file tree
Showing 2 changed files with 306 additions and 1 deletion.
242 changes: 241 additions & 1 deletion middleware/http/server.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package http

import (
"io"
"net/http"
"strconv"
"sync/atomic"
Expand All @@ -16,6 +17,7 @@ type handler struct {
next http.Handler
tagResponseSize bool
defaultTags map[string]string
requestSampler func(r *http.Request) bool
}

// ServerOption allows Middleware to be optionally configured.
Expand Down Expand Up @@ -46,6 +48,14 @@ func SpanName(name string) ServerOption {
}
}

// RequestSampler allows one to set the sampling decision based on the details
// found in the http.Request.
func RequestSampler(sampleFunc func(r *http.Request) bool) ServerOption {
return func(h *handler) {
h.requestSampler = sampleFunc
}
}

// NewServerMiddleware returns a http.Handler middleware with Zipkin tracing.
func NewServerMiddleware(t *zipkin.Tracer, options ...ServerOption) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
Expand All @@ -67,6 +77,11 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// try to extract B3 Headers from upstream
sc := h.tracer.Extract(b3.ExtractHTTP(r))

if h.requestSampler != nil && sc.Sampled == nil {
sample := h.requestSampler(r)
sc.Sampled = &sample
}

remoteEndpoint, _ := zipkin.NewEndpoint("", r.RemoteAddr)

if len(h.name) == 0 {
Expand Down Expand Up @@ -114,7 +129,7 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}()

// call next http Handler func using our updated context.
h.next.ServeHTTP(ri, r.WithContext(ctx))
h.next.ServeHTTP(ri.wrap(), r.WithContext(ctx))
}

// rwInterceptor intercepts the ResponseWriter so it can track response size
Expand Down Expand Up @@ -147,3 +162,228 @@ func (r *rwInterceptor) getStatusCode() int {
func (r *rwInterceptor) getResponseSize() string {
return strconv.FormatUint(atomic.LoadUint64(&r.size), 10)
}

func (r *rwInterceptor) wrap() http.ResponseWriter {
var (
hj, i0 = r.w.(http.Hijacker)
cn, i1 = r.w.(http.CloseNotifier)
pu, i2 = r.w.(http.Pusher)
fl, i3 = r.w.(http.Flusher)
rf, i4 = r.w.(io.ReaderFrom)
)

switch {
case !i0 && !i1 && !i2 && !i3 && !i4:
return struct {
http.ResponseWriter
}{r}
case !i0 && !i1 && !i2 && !i3 && i4:
return struct {
http.ResponseWriter
io.ReaderFrom
}{r, rf}
case !i0 && !i1 && !i2 && i3 && !i4:
return struct {
http.ResponseWriter
http.Flusher
}{r, fl}
case !i0 && !i1 && !i2 && i3 && i4:
return struct {
http.ResponseWriter
http.Flusher
io.ReaderFrom
}{r, fl, rf}
case !i0 && !i1 && i2 && !i3 && !i4:
return struct {
http.ResponseWriter
http.Pusher
}{r, pu}
case !i0 && !i1 && i2 && !i3 && i4:
return struct {
http.ResponseWriter
http.Pusher
io.ReaderFrom
}{r, pu, rf}
case !i0 && !i1 && i2 && i3 && !i4:
return struct {
http.ResponseWriter
http.Pusher
http.Flusher
}{r, pu, fl}
case !i0 && !i1 && i2 && i3 && i4:
return struct {
http.ResponseWriter
http.Pusher
http.Flusher
io.ReaderFrom
}{r, pu, fl, rf}
case !i0 && i1 && !i2 && !i3 && !i4:
return struct {
http.ResponseWriter
http.CloseNotifier
}{r, cn}
case !i0 && i1 && !i2 && !i3 && i4:
return struct {
http.ResponseWriter
http.CloseNotifier
io.ReaderFrom
}{r, cn, rf}
case !i0 && i1 && !i2 && i3 && !i4:
return struct {
http.ResponseWriter
http.CloseNotifier
http.Flusher
}{r, cn, fl}
case !i0 && i1 && !i2 && i3 && i4:
return struct {
http.ResponseWriter
http.CloseNotifier
http.Flusher
io.ReaderFrom
}{r, cn, fl, rf}
case !i0 && i1 && i2 && !i3 && !i4:
return struct {
http.ResponseWriter
http.CloseNotifier
http.Pusher
}{r, cn, pu}
case !i0 && i1 && i2 && !i3 && i4:
return struct {
http.ResponseWriter
http.CloseNotifier
http.Pusher
io.ReaderFrom
}{r, cn, pu, rf}
case !i0 && i1 && i2 && i3 && !i4:
return struct {
http.ResponseWriter
http.CloseNotifier
http.Pusher
http.Flusher
}{r, cn, pu, fl}
case !i0 && i1 && i2 && i3 && i4:
return struct {
http.ResponseWriter
http.CloseNotifier
http.Pusher
http.Flusher
io.ReaderFrom
}{r, cn, pu, fl, rf}
case i0 && !i1 && !i2 && !i3 && !i4:
return struct {
http.ResponseWriter
http.Hijacker
}{r, hj}
case i0 && !i1 && !i2 && !i3 && i4:
return struct {
http.ResponseWriter
http.Hijacker
io.ReaderFrom
}{r, hj, rf}
case i0 && !i1 && !i2 && i3 && !i4:
return struct {
http.ResponseWriter
http.Hijacker
http.Flusher
}{r, hj, fl}
case i0 && !i1 && !i2 && i3 && i4:
return struct {
http.ResponseWriter
http.Hijacker
http.Flusher
io.ReaderFrom
}{r, hj, fl, rf}
case i0 && !i1 && i2 && !i3 && !i4:
return struct {
http.ResponseWriter
http.Hijacker
http.Pusher
}{r, hj, pu}
case i0 && !i1 && i2 && !i3 && i4:
return struct {
http.ResponseWriter
http.Hijacker
http.Pusher
io.ReaderFrom
}{r, hj, pu, rf}
case i0 && !i1 && i2 && i3 && !i4:
return struct {
http.ResponseWriter
http.Hijacker
http.Pusher
http.Flusher
}{r, hj, pu, fl}
case i0 && !i1 && i2 && i3 && i4:
return struct {
http.ResponseWriter
http.Hijacker
http.Pusher
http.Flusher
io.ReaderFrom
}{r, hj, pu, fl, rf}
case i0 && i1 && !i2 && !i3 && !i4:
return struct {
http.ResponseWriter
http.Hijacker
http.CloseNotifier
}{r, hj, cn}
case i0 && i1 && !i2 && !i3 && i4:
return struct {
http.ResponseWriter
http.Hijacker
http.CloseNotifier
io.ReaderFrom
}{r, hj, cn, rf}
case i0 && i1 && !i2 && i3 && !i4:
return struct {
http.ResponseWriter
http.Hijacker
http.CloseNotifier
http.Flusher
}{r, hj, cn, fl}
case i0 && i1 && !i2 && i3 && i4:
return struct {
http.ResponseWriter
http.Hijacker
http.CloseNotifier
http.Flusher
io.ReaderFrom
}{r, hj, cn, fl, rf}
case i0 && i1 && i2 && !i3 && !i4:
return struct {
http.ResponseWriter
http.Hijacker
http.CloseNotifier
http.Pusher
}{r, hj, cn, pu}
case i0 && i1 && i2 && !i3 && i4:
return struct {
http.ResponseWriter
http.Hijacker
http.CloseNotifier
http.Pusher
io.ReaderFrom
}{r, hj, cn, pu, rf}
case i0 && i1 && i2 && i3 && !i4:
return struct {
http.ResponseWriter
http.Hijacker
http.CloseNotifier
http.Pusher
http.Flusher
}{r, hj, cn, pu, fl}
case i0 && i1 && i2 && i3 && i4:
return struct {
http.ResponseWriter
http.Hijacker
http.CloseNotifier
http.Pusher
http.Flusher
io.ReaderFrom
}{r, hj, cn, pu, fl, rf}
default:
return struct {
http.ResponseWriter
}{r}
}
}
65 changes: 65 additions & 0 deletions middleware/http/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,68 @@ func TestHTTPDefaultSpanName(t *testing.T) {
t.Errorf("Expected span name %s, got %s", want, have)
}
}

func TestHTTPRequestSampler(t *testing.T) {
var (
spanRecorder = &recorder.ReporterRecorder{}
httpRecorder = httptest.NewRecorder()
requestBuf = bytes.NewBufferString("incoming data")
methodType = "POST"
httpHandlerFunc = http.HandlerFunc(httpHandler(200, nil, bytes.NewBufferString("")))
)

samplers := [](func(r *http.Request) bool){
nil,
func(r *http.Request) bool { return true },
func(r *http.Request) bool { return false },
}

for _, sampler := range samplers {
tr, _ := zipkin.NewTracer(spanRecorder, zipkin.WithLocalEndpoint(lep), zipkin.WithSampler(zipkin.AlwaysSample))

request, err := http.NewRequest(methodType, "/test", requestBuf)
if err != nil {
t.Fatalf("unable to create request")
}

handler := mw.NewServerMiddleware(tr, mw.RequestSampler(sampler))(httpHandlerFunc)

handler.ServeHTTP(httpRecorder, request)

spans := spanRecorder.Flush()

sampledSpans := 0
if sampler == nil || sampler(request) {
sampledSpans = 1
}

if want, have := sampledSpans, len(spans); want != have {
t.Errorf("Expected %d spans, got %d", want, have)
}
}

for _, sampler := range samplers {
tr, _ := zipkin.NewTracer(spanRecorder, zipkin.WithLocalEndpoint(lep), zipkin.WithSampler(zipkin.NeverSample))

request, err := http.NewRequest(methodType, "/test", requestBuf)
if err != nil {
t.Fatalf("unable to create request")
}

handler := mw.NewServerMiddleware(tr, mw.RequestSampler(sampler))(httpHandlerFunc)

handler.ServeHTTP(httpRecorder, request)

spans := spanRecorder.Flush()

sampledSpans := 0
if sampler != nil && sampler(request) {
sampledSpans = 1
}

if want, have := sampledSpans, len(spans); want != have {
t.Errorf("Expected %d spans, got %d", want, have)
}
}

}

0 comments on commit d455a56

Please sign in to comment.