Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query-frontend: enforce limitsMiddleware for remote read requests #8374

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
270 changes: 186 additions & 84 deletions pkg/frontend/querymiddleware/limits_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package querymiddleware

import (
"context"
"errors"
"fmt"
"net/http"
"strings"
Expand All @@ -16,6 +17,9 @@ import (

"github.com/go-kit/log"
"github.com/grafana/dskit/user"
"github.com/prometheus/prometheus/model/labels"
"github.com/prometheus/prometheus/prompb"
"github.com/prometheus/prometheus/promql/parser"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -132,38 +136,51 @@ func TestLimitsMiddleware_MaxQueryLookback(t *testing.T) {

for testName, testData := range tests {
t.Run(testName, func(t *testing.T) {
req := &PrometheusRangeQueryRequest{
start: util.TimeToMillis(testData.reqStartTime),
end: util.TimeToMillis(testData.reqEndTime),
reqs := map[string]MetricsQueryRequest{
"range query": &PrometheusRangeQueryRequest{
start: util.TimeToMillis(testData.reqStartTime),
end: util.TimeToMillis(testData.reqEndTime),
},
"remote read": &remoteReadQueryRequest{
path: remoteReadPathSuffix,
query: &prompb.Query{
StartTimestampMs: util.TimeToMillis(testData.reqStartTime),
EndTimestampMs: util.TimeToMillis(testData.reqEndTime),
},
},
}

limits := mockLimits{maxQueryLookback: testData.maxQueryLookback, compactorBlocksRetentionPeriod: testData.blocksRetentionPeriod}
middleware := newLimitsMiddleware(limits, log.NewNopLogger())

innerRes := newEmptyPrometheusResponse()
inner := &mockHandler{}
inner.On("Do", mock.Anything, mock.Anything).Return(innerRes, nil)

ctx := user.InjectOrgID(context.Background(), "test")
outer := middleware.Wrap(inner)
res, err := outer.Do(ctx, req)
require.NoError(t, err)

if testData.expectedSkipped {
// We expect an empty response, but not the one returned by the inner handler
// which we expect has been skipped.
assert.NotSame(t, innerRes, res)
assert.Len(t, inner.Calls, 0)
} else {
// We expect the response returned by the inner handler.
assert.Same(t, innerRes, res)

// Assert on the time range of the request passed to the inner handler (5s delta).
delta := float64(5000)
require.Len(t, inner.Calls, 1)

assert.InDelta(t, util.TimeToMillis(testData.expectedStartTime), inner.Calls[0].Arguments.Get(1).(MetricsQueryRequest).GetStart(), delta)
assert.InDelta(t, util.TimeToMillis(testData.expectedEndTime), inner.Calls[0].Arguments.Get(1).(MetricsQueryRequest).GetEnd(), delta)
for reqType, req := range reqs {
t.Run(reqType, func(t *testing.T) {
limits := mockLimits{maxQueryLookback: testData.maxQueryLookback, compactorBlocksRetentionPeriod: testData.blocksRetentionPeriod}
middleware := newLimitsMiddleware(limits, log.NewNopLogger())

innerRes := newEmptyPrometheusResponse()
inner := &mockHandler{}
inner.On("Do", mock.Anything, mock.Anything).Return(innerRes, nil)

ctx := user.InjectOrgID(context.Background(), "test")
outer := middleware.Wrap(inner)
res, err := outer.Do(ctx, req)
require.NoError(t, err)

if testData.expectedSkipped {
// We expect an empty response, but not the one returned by the inner handler
// which we expect has been skipped.
assert.NotSame(t, innerRes, res)
assert.Len(t, inner.Calls, 0)
} else {
// We expect the response returned by the inner handler.
assert.Same(t, innerRes, res)

// Assert on the time range of the request passed to the inner handler (5s delta).
delta := float64(5000)
require.Len(t, inner.Calls, 1)

assert.InDelta(t, util.TimeToMillis(testData.expectedStartTime), inner.Calls[0].Arguments.Get(1).(MetricsQueryRequest).GetStart(), delta)
assert.InDelta(t, util.TimeToMillis(testData.expectedEndTime), inner.Calls[0].Arguments.Get(1).(MetricsQueryRequest).GetEnd(), delta)
}
})
}
})
}
Expand Down Expand Up @@ -206,34 +223,64 @@ func TestLimitsMiddleware_MaxQueryExpressionSizeBytes(t *testing.T) {

for testName, testData := range tests {
t.Run(testName, func(t *testing.T) {
req := &PrometheusRangeQueryRequest{
queryExpr: parseQuery(t, testData.query),
start: util.TimeToMillis(now.Add(-time.Hour * 2)),
end: util.TimeToMillis(now.Add(-time.Hour)),
}

limits := multiTenantMockLimits{
byTenant: map[string]mockLimits{
"test1": {maxQueryExpressionSizeBytes: testData.queryLimits["test1"]},
"test2": {maxQueryExpressionSizeBytes: testData.queryLimits["test2"]},
startMs := util.TimeToMillis(now.Add(-time.Hour * 2))
endMs := util.TimeToMillis(now.Add(-time.Hour))

reqs := map[string]MetricsQueryRequest{
"range query": &PrometheusRangeQueryRequest{
queryExpr: parseQuery(t, testData.query),
start: startMs,
end: endMs,
},
"remote read": &remoteReadQueryRequest{
path: remoteReadPathSuffix,
promQuery: testData.query,
query: &prompb.Query{
StartTimestampMs: startMs,
EndTimestampMs: endMs,
Matchers: func() []*prompb.LabelMatcher {
v := &findVectorSelectorsVisitor{}
require.NoError(t, parser.Walk(v, parseQuery(t, testData.query), nil))

// This test requires the query has only 1 vector selector.
require.Len(t, v.selectors, 1)
require.NotEmpty(t, v.selectors[0].LabelMatchers)

matchers, err := toLabelMatchers(v.selectors[0].LabelMatchers)
require.NoError(t, err)

return matchers
}(),
},
},
}
middleware := newLimitsMiddleware(limits, log.NewNopLogger())

innerRes := newEmptyPrometheusResponse()
inner := &mockHandler{}
inner.On("Do", mock.Anything, mock.Anything).Return(innerRes, nil)

ctx := user.InjectOrgID(context.Background(), "test1|test2")
outer := middleware.Wrap(inner)
res, err := outer.Do(ctx, req)

if testData.expectError {
require.Error(t, err)
require.Contains(t, err.Error(), "err-mimir-max-query-expression-size-bytes")
} else {
require.NoError(t, err)
require.Same(t, innerRes, res)

for reqType, req := range reqs {
t.Run(reqType, func(t *testing.T) {
limits := multiTenantMockLimits{
byTenant: map[string]mockLimits{
"test1": {maxQueryExpressionSizeBytes: testData.queryLimits["test1"]},
"test2": {maxQueryExpressionSizeBytes: testData.queryLimits["test2"]},
},
}
middleware := newLimitsMiddleware(limits, log.NewNopLogger())

innerRes := newEmptyPrometheusResponse()
inner := &mockHandler{}
inner.On("Do", mock.Anything, mock.Anything).Return(innerRes, nil)

ctx := user.InjectOrgID(context.Background(), "test1|test2")
outer := middleware.Wrap(inner)
res, err := outer.Do(ctx, req)

if testData.expectError {
require.Error(t, err)
require.Contains(t, err.Error(), "err-mimir-max-query-expression-size-bytes")
} else {
require.NoError(t, err)
require.Same(t, innerRes, res)
}
})
}
})
}
Expand Down Expand Up @@ -295,36 +342,50 @@ func TestLimitsMiddleware_MaxQueryLength(t *testing.T) {

for testName, testData := range tests {
t.Run(testName, func(t *testing.T) {
req := &PrometheusRangeQueryRequest{
start: util.TimeToMillis(testData.reqStartTime),
end: util.TimeToMillis(testData.reqEndTime),
reqs := map[string]MetricsQueryRequest{
"range query": &PrometheusRangeQueryRequest{
start: util.TimeToMillis(testData.reqStartTime),
end: util.TimeToMillis(testData.reqEndTime),
},
"remote read": &remoteReadQueryRequest{
path: remoteReadPathSuffix,
promQuery: "",
query: &prompb.Query{
StartTimestampMs: util.TimeToMillis(testData.reqStartTime),
EndTimestampMs: util.TimeToMillis(testData.reqEndTime),
},
},
}

limits := mockLimits{maxQueryLength: testData.maxQueryLength, maxTotalQueryLength: testData.maxTotalQueryLength}
middleware := newLimitsMiddleware(limits, log.NewNopLogger())

innerRes := newEmptyPrometheusResponse()
inner := &mockHandler{}
inner.On("Do", mock.Anything, mock.Anything).Return(innerRes, nil)

ctx := user.InjectOrgID(context.Background(), "test")
outer := middleware.Wrap(inner)
res, err := outer.Do(ctx, req)

if testData.expectedErr != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), testData.expectedErr)
assert.Nil(t, res)
assert.Len(t, inner.Calls, 0)
} else {
// We expect the response returned by the inner handler.
require.NoError(t, err)
assert.Same(t, innerRes, res)

// The time range of the request passed to the inner handler should have not been manipulated.
require.Len(t, inner.Calls, 1)
assert.Equal(t, util.TimeToMillis(testData.reqStartTime), inner.Calls[0].Arguments.Get(1).(MetricsQueryRequest).GetStart())
assert.Equal(t, util.TimeToMillis(testData.reqEndTime), inner.Calls[0].Arguments.Get(1).(MetricsQueryRequest).GetEnd())
for reqType, req := range reqs {
t.Run(reqType, func(t *testing.T) {
limits := mockLimits{maxQueryLength: testData.maxQueryLength, maxTotalQueryLength: testData.maxTotalQueryLength}
middleware := newLimitsMiddleware(limits, log.NewNopLogger())

innerRes := newEmptyPrometheusResponse()
inner := &mockHandler{}
inner.On("Do", mock.Anything, mock.Anything).Return(innerRes, nil)

ctx := user.InjectOrgID(context.Background(), "test")
outer := middleware.Wrap(inner)
res, err := outer.Do(ctx, req)

if testData.expectedErr != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), testData.expectedErr)
assert.Nil(t, res)
assert.Len(t, inner.Calls, 0)
} else {
// We expect the response returned by the inner handler.
require.NoError(t, err)
assert.Same(t, innerRes, res)

// The time range of the request passed to the inner handler should have not been manipulated.
require.Len(t, inner.Calls, 1)
assert.Equal(t, util.TimeToMillis(testData.reqStartTime), inner.Calls[0].Arguments.Get(1).(MetricsQueryRequest).GetStart())
assert.Equal(t, util.TimeToMillis(testData.reqEndTime), inner.Calls[0].Arguments.Get(1).(MetricsQueryRequest).GetEnd())
}
})
}
})
}
Expand Down Expand Up @@ -777,3 +838,44 @@ func TestSmallestPositiveNonZeroDuration(t *testing.T) {
assert.Equal(t, time.Duration(1), smallestPositiveNonZeroDuration(0, 1, -1))
assert.Equal(t, time.Duration(1), smallestPositiveNonZeroDuration(0, 2, 1))
}

type findVectorSelectorsVisitor struct {
selectors []*parser.VectorSelector
}

func (v *findVectorSelectorsVisitor) Visit(node parser.Node, _ []parser.Node) (parser.Visitor, error) {
selector, ok := node.(*parser.VectorSelector)
if !ok {
return v, nil
}

v.selectors = append(v.selectors, selector)
return v, nil
}

// This function has been copied from:
// https://github.com/prometheus/prometheus/blob/5efc8dd27b6e68d5102b77bc708e52c9821c5101/storage/remote/codec.go#L569
func toLabelMatchers(matchers []*labels.Matcher) ([]*prompb.LabelMatcher, error) {
pbMatchers := make([]*prompb.LabelMatcher, 0, len(matchers))
for _, m := range matchers {
var mType prompb.LabelMatcher_Type
switch m.Type {
case labels.MatchEqual:
mType = prompb.LabelMatcher_EQ
case labels.MatchNotEqual:
mType = prompb.LabelMatcher_NEQ
case labels.MatchRegexp:
mType = prompb.LabelMatcher_RE
case labels.MatchNotRegexp:
mType = prompb.LabelMatcher_NRE
default:
return nil, errors.New("invalid matcher type")
}
pbMatchers = append(pbMatchers, &prompb.LabelMatcher{
Type: mType,
Name: m.Name,
Value: m.Value,
})
}
return pbMatchers, nil
}
Loading
Loading