Skip to content

Commit

Permalink
Enforce query-frontend limits middleware for remote reads too
Browse files Browse the repository at this point in the history
Signed-off-by: Marco Pracucci <[email protected]>
  • Loading branch information
pracucci committed Jun 17, 2024
1 parent f4867bf commit edeb04e
Show file tree
Hide file tree
Showing 7 changed files with 506 additions and 129 deletions.
5 changes: 4 additions & 1 deletion pkg/frontend/querymiddleware/limits.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,14 @@ func (l limitsMiddleware) Do(ctx context.Context, r MetricsQueryRequest) (Respon
// Clamp the time range based on the max query lookback and block retention period.
blocksRetentionPeriod := validation.SmallestPositiveNonZeroDurationPerTenant(tenantIDs, l.CompactorBlocksRetentionPeriod)
maxQueryLookback := validation.SmallestPositiveNonZeroDurationPerTenant(tenantIDs, l.MaxQueryLookback)
maxLookback := util_math.Min(blocksRetentionPeriod, maxQueryLookback)
maxLookback := util_math.Min(blocksRetentionPeriod, maxQueryLookback) // TODO bug! we should not compare zero values

if maxLookback > 0 {
minStartTime := util.TimeToMillis(time.Now().Add(-maxLookback))

if r.GetEnd() < minStartTime {
// TODO this is something that will not work with remote read requests. Maybe we can manipulate the end too?

// The request is fully outside the allowed range, so we can return an
// empty response.
level.Debug(log).Log(
Expand Down
269 changes: 185 additions & 84 deletions pkg/frontend/querymiddleware/limits_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package querymiddleware

import (
"context"
"errors"
"fmt"
"net/http"
"strings"
Expand All @@ -16,6 +17,9 @@ import (

"github.com/go-kit/log"
"github.com/grafana/dskit/user"
"github.com/prometheus/prometheus/model/labels"
"github.com/prometheus/prometheus/prompb"
"github.com/prometheus/prometheus/promql/parser"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -92,38 +96,51 @@ func TestLimitsMiddleware_MaxQueryLookback(t *testing.T) {

for testName, testData := range tests {
t.Run(testName, func(t *testing.T) {
req := &PrometheusRangeQueryRequest{
start: util.TimeToMillis(testData.reqStartTime),
end: util.TimeToMillis(testData.reqEndTime),
reqs := map[string]MetricsQueryRequest{
"range query": &PrometheusRangeQueryRequest{
start: util.TimeToMillis(testData.reqStartTime),
end: util.TimeToMillis(testData.reqEndTime),
},
"remote read": &remoteReadQueryRequest{
path: remoteReadPathSuffix,
query: &prompb.Query{
StartTimestampMs: util.TimeToMillis(testData.reqStartTime),
EndTimestampMs: util.TimeToMillis(testData.reqEndTime),
},
},
}

limits := mockLimits{maxQueryLookback: testData.maxQueryLookback, compactorBlocksRetentionPeriod: testData.blocksRetentionPeriod}
middleware := newLimitsMiddleware(limits, log.NewNopLogger())

innerRes := newEmptyPrometheusResponse()
inner := &mockHandler{}
inner.On("Do", mock.Anything, mock.Anything).Return(innerRes, nil)

ctx := user.InjectOrgID(context.Background(), "test")
outer := middleware.Wrap(inner)
res, err := outer.Do(ctx, req)
require.NoError(t, err)

if testData.expectedSkipped {
// We expect an empty response, but not the one returned by the inner handler
// which we expect has been skipped.
assert.NotSame(t, innerRes, res)
assert.Len(t, inner.Calls, 0)
} else {
// We expect the response returned by the inner handler.
assert.Same(t, innerRes, res)

// Assert on the time range of the request passed to the inner handler (5s delta).
delta := float64(5000)
require.Len(t, inner.Calls, 1)

assert.InDelta(t, util.TimeToMillis(testData.expectedStartTime), inner.Calls[0].Arguments.Get(1).(MetricsQueryRequest).GetStart(), delta)
assert.InDelta(t, util.TimeToMillis(testData.expectedEndTime), inner.Calls[0].Arguments.Get(1).(MetricsQueryRequest).GetEnd(), delta)
for reqType, req := range reqs {
t.Run(reqType, func(t *testing.T) {
limits := mockLimits{maxQueryLookback: testData.maxQueryLookback, compactorBlocksRetentionPeriod: testData.blocksRetentionPeriod}
middleware := newLimitsMiddleware(limits, log.NewNopLogger())

innerRes := newEmptyPrometheusResponse()
inner := &mockHandler{}
inner.On("Do", mock.Anything, mock.Anything).Return(innerRes, nil)

ctx := user.InjectOrgID(context.Background(), "test")
outer := middleware.Wrap(inner)
res, err := outer.Do(ctx, req)
require.NoError(t, err)

if testData.expectedSkipped {
// We expect an empty response, but not the one returned by the inner handler
// which we expect has been skipped.
assert.NotSame(t, innerRes, res)
assert.Len(t, inner.Calls, 0)
} else {
// We expect the response returned by the inner handler.
assert.Same(t, innerRes, res)

// Assert on the time range of the request passed to the inner handler (5s delta).
delta := float64(5000)
require.Len(t, inner.Calls, 1)

assert.InDelta(t, util.TimeToMillis(testData.expectedStartTime), inner.Calls[0].Arguments.Get(1).(MetricsQueryRequest).GetStart(), delta)
assert.InDelta(t, util.TimeToMillis(testData.expectedEndTime), inner.Calls[0].Arguments.Get(1).(MetricsQueryRequest).GetEnd(), delta)
}
})
}
})
}
Expand Down Expand Up @@ -166,34 +183,64 @@ func TestLimitsMiddleware_MaxQueryExpressionSizeBytes(t *testing.T) {

for testName, testData := range tests {
t.Run(testName, func(t *testing.T) {
req := &PrometheusRangeQueryRequest{
queryExpr: parseQuery(t, testData.query),
start: util.TimeToMillis(now.Add(-time.Hour * 2)),
end: util.TimeToMillis(now.Add(-time.Hour)),
}

limits := multiTenantMockLimits{
byTenant: map[string]mockLimits{
"test1": {maxQueryExpressionSizeBytes: testData.queryLimits["test1"]},
"test2": {maxQueryExpressionSizeBytes: testData.queryLimits["test2"]},
startMs := util.TimeToMillis(now.Add(-time.Hour * 2))
endMs := util.TimeToMillis(now.Add(-time.Hour))

reqs := map[string]MetricsQueryRequest{
"range query": &PrometheusRangeQueryRequest{
queryExpr: parseQuery(t, testData.query),
start: startMs,
end: endMs,
},
"remote read": &remoteReadQueryRequest{
path: remoteReadPathSuffix,
promQuery: testData.query,
query: &prompb.Query{
StartTimestampMs: startMs,
EndTimestampMs: endMs,
Matchers: func() []*prompb.LabelMatcher {
v := &findVectorSelectorsVisitor{}
require.NoError(t, parser.Walk(v, parseQuery(t, testData.query), nil))

// This test requires the query has only 1 vector selector.
require.Len(t, v.selectors, 1)
require.NotEmpty(t, v.selectors[0].LabelMatchers)

matchers, err := toLabelMatchers(v.selectors[0].LabelMatchers)
require.NoError(t, err)

return matchers
}(),
},
},
}
middleware := newLimitsMiddleware(limits, log.NewNopLogger())

innerRes := newEmptyPrometheusResponse()
inner := &mockHandler{}
inner.On("Do", mock.Anything, mock.Anything).Return(innerRes, nil)

ctx := user.InjectOrgID(context.Background(), "test1|test2")
outer := middleware.Wrap(inner)
res, err := outer.Do(ctx, req)

if testData.expectError {
require.Error(t, err)
require.Contains(t, err.Error(), "err-mimir-max-query-expression-size-bytes")
} else {
require.NoError(t, err)
require.Same(t, innerRes, res)

for reqType, req := range reqs {
t.Run(reqType, func(t *testing.T) {
limits := multiTenantMockLimits{
byTenant: map[string]mockLimits{
"test1": {maxQueryExpressionSizeBytes: testData.queryLimits["test1"]},
"test2": {maxQueryExpressionSizeBytes: testData.queryLimits["test2"]},
},
}
middleware := newLimitsMiddleware(limits, log.NewNopLogger())

innerRes := newEmptyPrometheusResponse()
inner := &mockHandler{}
inner.On("Do", mock.Anything, mock.Anything).Return(innerRes, nil)

ctx := user.InjectOrgID(context.Background(), "test1|test2")
outer := middleware.Wrap(inner)
res, err := outer.Do(ctx, req)

if testData.expectError {
require.Error(t, err)
require.Contains(t, err.Error(), "err-mimir-max-query-expression-size-bytes")
} else {
require.NoError(t, err)
require.Same(t, innerRes, res)
}
})
}
})
}
Expand Down Expand Up @@ -255,36 +302,50 @@ func TestLimitsMiddleware_MaxQueryLength(t *testing.T) {

for testName, testData := range tests {
t.Run(testName, func(t *testing.T) {
req := &PrometheusRangeQueryRequest{
start: util.TimeToMillis(testData.reqStartTime),
end: util.TimeToMillis(testData.reqEndTime),
reqs := map[string]MetricsQueryRequest{
"range query": &PrometheusRangeQueryRequest{
start: util.TimeToMillis(testData.reqStartTime),
end: util.TimeToMillis(testData.reqEndTime),
},
"remote read": &remoteReadQueryRequest{
path: remoteReadPathSuffix,
promQuery: "",
query: &prompb.Query{
StartTimestampMs: util.TimeToMillis(testData.reqStartTime),
EndTimestampMs: util.TimeToMillis(testData.reqEndTime),
},
},
}

limits := mockLimits{maxQueryLength: testData.maxQueryLength, maxTotalQueryLength: testData.maxTotalQueryLength}
middleware := newLimitsMiddleware(limits, log.NewNopLogger())

innerRes := newEmptyPrometheusResponse()
inner := &mockHandler{}
inner.On("Do", mock.Anything, mock.Anything).Return(innerRes, nil)

ctx := user.InjectOrgID(context.Background(), "test")
outer := middleware.Wrap(inner)
res, err := outer.Do(ctx, req)

if testData.expectedErr != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), testData.expectedErr)
assert.Nil(t, res)
assert.Len(t, inner.Calls, 0)
} else {
// We expect the response returned by the inner handler.
require.NoError(t, err)
assert.Same(t, innerRes, res)

// The time range of the request passed to the inner handler should have not been manipulated.
require.Len(t, inner.Calls, 1)
assert.Equal(t, util.TimeToMillis(testData.reqStartTime), inner.Calls[0].Arguments.Get(1).(MetricsQueryRequest).GetStart())
assert.Equal(t, util.TimeToMillis(testData.reqEndTime), inner.Calls[0].Arguments.Get(1).(MetricsQueryRequest).GetEnd())
for reqType, req := range reqs {
t.Run(reqType, func(t *testing.T) {
limits := mockLimits{maxQueryLength: testData.maxQueryLength, maxTotalQueryLength: testData.maxTotalQueryLength}
middleware := newLimitsMiddleware(limits, log.NewNopLogger())

innerRes := newEmptyPrometheusResponse()
inner := &mockHandler{}
inner.On("Do", mock.Anything, mock.Anything).Return(innerRes, nil)

ctx := user.InjectOrgID(context.Background(), "test")
outer := middleware.Wrap(inner)
res, err := outer.Do(ctx, req)

if testData.expectedErr != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), testData.expectedErr)
assert.Nil(t, res)
assert.Len(t, inner.Calls, 0)
} else {
// We expect the response returned by the inner handler.
require.NoError(t, err)
assert.Same(t, innerRes, res)

// The time range of the request passed to the inner handler should have not been manipulated.
require.Len(t, inner.Calls, 1)
assert.Equal(t, util.TimeToMillis(testData.reqStartTime), inner.Calls[0].Arguments.Get(1).(MetricsQueryRequest).GetStart())
assert.Equal(t, util.TimeToMillis(testData.reqEndTime), inner.Calls[0].Arguments.Get(1).(MetricsQueryRequest).GetEnd())
}
})
}
})
}
Expand Down Expand Up @@ -728,3 +789,43 @@ func BenchmarkLimitedParallelismRoundTripper(b *testing.B) {
}
}
}

type findVectorSelectorsVisitor struct {
selectors []*parser.VectorSelector
}

func (v *findVectorSelectorsVisitor) Visit(node parser.Node, _ []parser.Node) (parser.Visitor, error) {
selector, ok := node.(*parser.VectorSelector)
if !ok {
return v, nil
}

v.selectors = append(v.selectors, selector)
return v, nil
}

// TODO copied from prometheus/storage/remote: expose it
func toLabelMatchers(matchers []*labels.Matcher) ([]*prompb.LabelMatcher, error) {
pbMatchers := make([]*prompb.LabelMatcher, 0, len(matchers))
for _, m := range matchers {
var mType prompb.LabelMatcher_Type
switch m.Type {
case labels.MatchEqual:
mType = prompb.LabelMatcher_EQ
case labels.MatchNotEqual:
mType = prompb.LabelMatcher_NEQ
case labels.MatchRegexp:
mType = prompb.LabelMatcher_RE
case labels.MatchNotRegexp:
mType = prompb.LabelMatcher_NRE
default:
return nil, errors.New("invalid matcher type")
}
pbMatchers = append(pbMatchers, &prompb.LabelMatcher{
Type: mType,
Name: m.Name,
Value: m.Value,
})
}
return pbMatchers, nil
}
Loading

0 comments on commit edeb04e

Please sign in to comment.