Skip to content

Commit

Permalink
Add a rate limit on logger
Browse files Browse the repository at this point in the history
Signed-off-by: Yuri Nikolic <[email protected]>
  • Loading branch information
duricanikolic committed Aug 9, 2023
1 parent e772133 commit 3bd7a14
Show file tree
Hide file tree
Showing 6 changed files with 271 additions and 24 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
* [FEATURE] Add support for waiting on the rate limiter using the new `WaitN` method. #279
* [FEATURE] Add `log.BufferedLogger` type. #338
* [FEATURE] Add `flagext.ParseFlagsAndArguments()` and `flagext.ParseFlagsWithoutArguments()` utilities. #341
* [FEATURE] Add a rate limited logger `rateLimitedLogger`, used by `middleware.Log` for limiting the rate of logging errors `502` and `503`. The rate limit is configurable via the newly introduced `-server.log-error-rate` CLI flag. #352
* [ENHANCEMENT] Add configuration to customize backoff for the gRPC clients.
* [ENHANCEMENT] Use `SecretReader` interface to fetch secrets when configuring TLS. #274
* [ENHANCEMENT] Add middleware package. #38
Expand Down
79 changes: 79 additions & 0 deletions log/ratelimit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package log

import "golang.org/x/time/rate"

type rateLimitedLogger struct {
next Interface
limiter *rate.Limiter
}

// NewRateLimitedLogger returns a logger.Interface that is limited to a number
// of logs per second
func NewRateLimitedLogger(logger Interface, logsPerSecond rate.Limit, burstSize int) Interface {
return &rateLimitedLogger{
next: logger,
limiter: rate.NewLimiter(logsPerSecond, burstSize),
}
}

func (l *rateLimitedLogger) Debugf(format string, args ...interface{}) {
if l.limiter.Allow() {
l.next.Debugf(format, args...)
}
}

func (l *rateLimitedLogger) Debugln(args ...interface{}) {
if l.limiter.Allow() {
l.next.Debugln(args...)
}
}

func (l *rateLimitedLogger) Infof(format string, args ...interface{}) {
if l.limiter.Allow() {
l.next.Infof(format, args...)
}
}

func (l *rateLimitedLogger) Infoln(args ...interface{}) {
if l.limiter.Allow() {
l.next.Infoln(args...)
}
}

func (l *rateLimitedLogger) Errorf(format string, args ...interface{}) {
if l.limiter.Allow() {
l.next.Errorf(format, args...)
}
}

func (l *rateLimitedLogger) Errorln(args ...interface{}) {
if l.limiter.Allow() {
l.next.Errorln(args...)
}
}

func (l *rateLimitedLogger) Warnf(format string, args ...interface{}) {
if l.limiter.Allow() {
l.next.Warnf(format, args...)
}
}

func (l *rateLimitedLogger) Warnln(args ...interface{}) {
if l.limiter.Allow() {
l.next.Warnln(args...)
}
}

func (l *rateLimitedLogger) WithField(key string, value interface{}) Interface {
return &rateLimitedLogger{
next: l.next.WithField(key, value),
limiter: l.limiter,
}
}

func (l *rateLimitedLogger) WithFields(f Fields) Interface {
return &rateLimitedLogger{
next: l.next.WithFields(f),
limiter: l.limiter,
}
}
59 changes: 59 additions & 0 deletions log/ratelimit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package log

import (
"testing"
"time"

"github.com/stretchr/testify/assert"
)

type counterLogger struct {
count int
}

func (c *counterLogger) Debugf(_ string, _ ...interface{}) { c.count++ }
func (c *counterLogger) Debugln(_ ...interface{}) { c.count++ }
func (c *counterLogger) Infof(_ string, _ ...interface{}) { c.count++ }
func (c *counterLogger) Infoln(_ ...interface{}) { c.count++ }
func (c *counterLogger) Warnf(_ string, _ ...interface{}) { c.count++ }
func (c *counterLogger) Warnln(_ ...interface{}) { c.count++ }
func (c *counterLogger) Errorf(_ string, _ ...interface{}) { c.count++ }
func (c *counterLogger) Errorln(_ ...interface{}) { c.count++ }
func (c *counterLogger) WithField(_ string, _ interface{}) Interface {
return c
}
func (c *counterLogger) WithFields(Fields) Interface {
return c
}

func TestRateLimitedLoggerLogs(t *testing.T) {
c := &counterLogger{}
r := NewRateLimitedLogger(c, 1, 1)

r.Errorln("asdf")
assert.Equal(t, 1, c.count)
}

func TestRateLimitedLoggerLimits(t *testing.T) {
c := &counterLogger{}
r := NewRateLimitedLogger(c, 2, 2)

r.Errorln("asdf")
r.Infoln("asdf")
r.Debugln("asdf")
assert.Equal(t, 2, c.count)
time.Sleep(time.Second)
r.Infoln("asdf")
assert.Equal(t, 3, c.count)
}

func TestRateLimitedLoggerWithFields(t *testing.T) {
c := &counterLogger{}
r := NewRateLimitedLogger(c, 1, 1)
r2 := r.WithField("key", "value")

r.Errorf("asdf")
r2.Errorln("asdf")
r2.Warnln("asdf")
assert.Equal(t, 1, c.count)
}
76 changes: 54 additions & 22 deletions middleware/logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
// Log middleware logs http requests
type Log struct {
Log log.Interface
HighVolumeErrorLog log.Interface
DisableRequestSuccessLog bool
LogRequestHeaders bool // LogRequestHeaders true -> dump http headers at debug log level
LogRequestAtInfoLevel bool // LogRequestAtInfoLevel true -> log requests at info log level
Expand All @@ -32,7 +33,7 @@ var defaultExcludedHeaders = map[string]bool{
"Authorization": true,
}

func NewLogMiddleware(log log.Interface, logRequestHeaders bool, logRequestAtInfoLevel bool, sourceIPs *SourceIPExtractor, headersList []string) Log {
func NewLogMiddleware(log log.Interface, highVolumeErrorLog log.Interface, logRequestHeaders bool, logRequestAtInfoLevel bool, sourceIPs *SourceIPExtractor, headersList []string) Log {
httpHeadersToExclude := map[string]bool{}
for header := range defaultExcludedHeaders {
httpHeadersToExclude[header] = true
Expand All @@ -43,42 +44,68 @@ func NewLogMiddleware(log log.Interface, logRequestHeaders bool, logRequestAtInf

return Log{
Log: log,
HighVolumeErrorLog: highVolumeErrorLog,
LogRequestHeaders: logRequestHeaders,
LogRequestAtInfoLevel: logRequestAtInfoLevel,
SourceIPs: sourceIPs,
HTTPHeadersToExclude: httpHeadersToExclude,
}
}

// logWithRequest information from the request and context as fields.
func (l Log) logWithRequest(r *http.Request) log.Interface {
localLog := l.Log
// logsWithFields returns this Log's Log and HighVolumeErrorLog instances enriched
// with the details from the request and context as fields.
// If any of the instances is not set, the corresponding retirned value is nil.
func (l Log) logsWithFields(r *http.Request) (log.Interface, log.Interface) {
logWithRequest := l.logWithFields(r, l.Log)
highVolumeErrorLogWithRequest := l.logWithFields(r, l.HighVolumeErrorLog)

return logWithRequest, highVolumeErrorLogWithRequest
}

// logWithFields enriches the given log.Interface instance with the details from
// the request and context as fields. If the former is nil, nil is returned.
func (l Log) logWithFields(r *http.Request, logger log.Interface) log.Interface {
logWithFields := logger
if logWithFields == nil {
return nil
}
traceID, ok := tracing.ExtractTraceID(r.Context())
if ok {
localLog = localLog.WithField("traceID", traceID)
logWithFields = logWithFields.WithField("traceID", traceID)
}

if l.SourceIPs != nil {
ips := l.SourceIPs.Get(r)
if ips != "" {
localLog = localLog.WithField("sourceIPs", ips)
logWithFields = logWithFields.WithField("sourceIPs", ips)
}
}

return user.LogWith(r.Context(), localLog)
return user.LogWith(r.Context(), logWithFields)
}

// logHighVolumeError logs details about the error passed as input.
// If the passed highVolumeErrorLog is set, the error is logged there at Warn level.
// Otherwise, the error is logged by using the passed log, at Debug level.
func (l Log) logHighVolumeError(highVolumeErrorLog, log log.Interface, format string, args ...interface{}) {
if highVolumeErrorLog != nil {
highVolumeErrorLog.Warnf(format, args...)
} else {
log.Debugf(format, args...)
}
}

// Wrap implements Middleware
func (l Log) Wrap(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
begin := time.Now()
uri := r.RequestURI // capture the URI before running next, as it may get rewritten
requestLog := l.logWithRequest(r)
requestLogger, highVolumeErrorLogger := l.logsWithFields(r)
// Log headers before running 'next' in case other interceptors change the data.
headers, err := dumpRequest(r, l.HTTPHeadersToExclude)
if err != nil {
headers = nil
requestLog.Errorf("Could not dump request headers: %v", err)
requestLogger.Errorf("Could not dump request headers: %v", err)
}
var buf bytes.Buffer
wrapped := newBadResponseLoggingWriter(w, &buf)
Expand All @@ -89,12 +116,12 @@ func (l Log) Wrap(next http.Handler) http.Handler {
if writeErr != nil {
if errors.Is(writeErr, context.Canceled) {
if l.LogRequestAtInfoLevel {
requestLog.Infof("%s %s %s, request cancelled: %s ws: %v; %s", r.Method, uri, time.Since(begin), writeErr, IsWSHandshakeRequest(r), headers)
requestLogger.Infof("%s %s %s, request cancelled: %s ws: %v; %s", r.Method, uri, time.Since(begin), writeErr, IsWSHandshakeRequest(r), headers)
} else {
requestLog.Debugf("%s %s %s, request cancelled: %s ws: %v; %s", r.Method, uri, time.Since(begin), writeErr, IsWSHandshakeRequest(r), headers)
requestLogger.Debugf("%s %s %s, request cancelled: %s ws: %v; %s", r.Method, uri, time.Since(begin), writeErr, IsWSHandshakeRequest(r), headers)
}
} else {
requestLog.Warnf("%s %s %s, error: %s ws: %v; %s", r.Method, uri, time.Since(begin), writeErr, IsWSHandshakeRequest(r), headers)
requestLogger.Warnf("%s %s %s, error: %s ws: %v; %s", r.Method, uri, time.Since(begin), writeErr, IsWSHandshakeRequest(r), headers)
}

return
Expand All @@ -105,23 +132,28 @@ func (l Log) Wrap(next http.Handler) http.Handler {
case statusCode >= 200 && statusCode < 300 && l.DisableRequestSuccessLog:
return

case 100 <= statusCode && statusCode < 500 || statusCode == http.StatusBadGateway || statusCode == http.StatusServiceUnavailable:
case 100 <= statusCode && statusCode < 500:
if l.LogRequestAtInfoLevel {
requestLog.Infof("%s %s (%d) %s", r.Method, uri, statusCode, time.Since(begin))

if l.LogRequestHeaders && headers != nil {
requestLog.Infof("ws: %v; %s", IsWSHandshakeRequest(r), string(headers))
requestLogger.Infof("%s %s (%d) %s ws: %v; %s", r.Method, uri, statusCode, time.Since(begin), IsWSHandshakeRequest(r), string(headers))
} else {
requestLogger.Infof("%s %s (%d) %s", r.Method, uri, statusCode, time.Since(begin))
}
} else {
if l.LogRequestHeaders && headers != nil {
requestLogger.Debugf("%s %s (%d) %s ws: %v; %s", r.Method, uri, statusCode, time.Since(begin), IsWSHandshakeRequest(r), string(headers))
} else {
requestLogger.Debugf("%s %s (%d) %s", r.Method, uri, statusCode, time.Since(begin))
}
return
}

requestLog.Debugf("%s %s (%d) %s", r.Method, uri, statusCode, time.Since(begin))
case statusCode == http.StatusBadGateway || statusCode == http.StatusServiceUnavailable:
if l.LogRequestHeaders && headers != nil {
requestLog.Debugf("ws: %v; %s", IsWSHandshakeRequest(r), string(headers))
l.logHighVolumeError(highVolumeErrorLogger, requestLogger, "%s %s (%d) %s ws: %v; %s", r.Method, uri, statusCode, time.Since(begin), IsWSHandshakeRequest(r), string(headers))
} else {
l.logHighVolumeError(highVolumeErrorLogger, requestLogger, "%s %s (%d) %s", r.Method, uri, statusCode, time.Since(begin))
}
default:
requestLog.Warnf("%s %s (%d) %s Response: %q ws: %v; %s",
r.Method, uri, statusCode, time.Since(begin), buf.Bytes(), IsWSHandshakeRequest(r), headers)
requestLogger.Warnf("%s %s (%d) %s Response: %q ws: %v; %s", r.Method, uri, statusCode, time.Since(begin), buf.Bytes(), IsWSHandshakeRequest(r), headers)
}
})
}
Expand Down
70 changes: 69 additions & 1 deletion middleware/logging_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func TestLoggingRequestWithExcludedHeaders(t *testing.T) {
logrusLogger.Out = buf
logrusLogger.Level = logrus.DebugLevel

loggingMiddleware := NewLogMiddleware(log.Logrus(logrusLogger), true, false, nil, tc.excludeHeaderList)
loggingMiddleware := NewLogMiddleware(log.Logrus(logrusLogger), nil, true, false, nil, tc.excludeHeaderList)

handler := func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "<html><body>Hello world!</body></html>")
Expand All @@ -206,6 +206,74 @@ func TestLoggingRequestWithExcludedHeaders(t *testing.T) {
}
}

func TestLoggingRequestsWithError(t *testing.T) {
for _, tc := range []struct {
err error
statusCode int
highVolumeErrorLog bool
logContains []string
}{{
err: errors.New("Bad Gateway"),
statusCode: 502,
logContains: []string{"debug", "GET http://example.com/foo (502)"},
}, {
err: errors.New("Bad Gateway"),
statusCode: 502,
highVolumeErrorLog: true,
logContains: []string{"warning", "GET http://example.com/foo (502)"},
}, {
err: errors.New("Service Unavailable"),
statusCode: 503,
logContains: []string{"debug", "GET http://example.com/foo (503)"},
}, {
err: errors.New("Service Unavailable"),
statusCode: 503,
highVolumeErrorLog: true,
logContains: []string{"warning", "GET http://example.com/foo (503)"},
}, {
err: errors.New("Gateway Timeout"),
statusCode: 504,
logContains: []string{"warning", "GET http://example.com/foo (504)"},
}} {
buf := bytes.NewBuffer(nil)
logrusLogger := logrus.New()
logrusLogger.Out = buf
logrusLogger.Level = logrus.DebugLevel

var highVolumeErrorLog log.Interface
if tc.highVolumeErrorLog {
highVolumeErrorLog = log.NewRateLimitedLogger(log.Logrus(logrusLogger), 1, 1)
}

loggingMiddleware := Log{
Log: log.Logrus(logrusLogger),
HighVolumeErrorLog: highVolumeErrorLog,
LogRequestAtInfoLevel: true,
}
handler := func(w http.ResponseWriter, r *http.Request) {
if tc.err == nil {
_, _ = io.WriteString(w, "<html><body>Hello World!</body></html>")
} else {
w.WriteHeader(tc.statusCode)
}
}
loggingHandler := loggingMiddleware.Wrap(http.HandlerFunc(handler))

req := httptest.NewRequest("GET", "http://example.com/foo", nil)
recorder := httptest.NewRecorder()

w := errorWriter{
err: tc.err,
w: recorder,
}
loggingHandler.ServeHTTP(w, req)

for _, content := range tc.logContains {
require.True(t, bytes.Contains(buf.Bytes(), []byte(content)))
}
}
}

type errorWriter struct {
err error

Expand Down
Loading

0 comments on commit 3bd7a14

Please sign in to comment.