Skip to content

Commit

Permalink
Query-frontend: enforce limitsMiddleware for remote read requests (#8374
Browse files Browse the repository at this point in the history
)

* Enforce query-frontend limits middleware for remote reads too

Signed-off-by: Marco Pracucci <[email protected]>

* Fixed linter

Signed-off-by: Marco Pracucci <[email protected]>

* Remove zero value initialization

Signed-off-by: Marco Pracucci <[email protected]>

* Improve remoteReadQueryRequest.WithStartEnd() logic

Signed-off-by: Marco Pracucci <[email protected]>

* Fixed TestRemoteReadRoundTripper_ShouldAllowMiddlewaresToManipulateRequest

Signed-off-by: Marco Pracucci <[email protected]>

---------

Signed-off-by: Marco Pracucci <[email protected]>
  • Loading branch information
pracucci authored Jun 17, 2024
1 parent 5898e30 commit 5893053
Show file tree
Hide file tree
Showing 6 changed files with 604 additions and 137 deletions.
269 changes: 185 additions & 84 deletions pkg/frontend/querymiddleware/limits_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package querymiddleware

import (
"context"
"errors"
"fmt"
"net/http"
"strings"
Expand All @@ -16,6 +17,9 @@ import (

"github.com/go-kit/log"
"github.com/grafana/dskit/user"
"github.com/prometheus/prometheus/model/labels"
"github.com/prometheus/prometheus/prompb"
"github.com/prometheus/prometheus/promql/parser"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -132,38 +136,51 @@ func TestLimitsMiddleware_MaxQueryLookback(t *testing.T) {

for testName, testData := range tests {
t.Run(testName, func(t *testing.T) {
req := &PrometheusRangeQueryRequest{
start: util.TimeToMillis(testData.reqStartTime),
end: util.TimeToMillis(testData.reqEndTime),
reqs := map[string]MetricsQueryRequest{
"range query": &PrometheusRangeQueryRequest{
start: util.TimeToMillis(testData.reqStartTime),
end: util.TimeToMillis(testData.reqEndTime),
},
"remote read": &remoteReadQueryRequest{
path: remoteReadPathSuffix,
query: &prompb.Query{
StartTimestampMs: util.TimeToMillis(testData.reqStartTime),
EndTimestampMs: util.TimeToMillis(testData.reqEndTime),
},
},
}

limits := mockLimits{maxQueryLookback: testData.maxQueryLookback, compactorBlocksRetentionPeriod: testData.blocksRetentionPeriod}
middleware := newLimitsMiddleware(limits, log.NewNopLogger())

innerRes := newEmptyPrometheusResponse()
inner := &mockHandler{}
inner.On("Do", mock.Anything, mock.Anything).Return(innerRes, nil)

ctx := user.InjectOrgID(context.Background(), "test")
outer := middleware.Wrap(inner)
res, err := outer.Do(ctx, req)
require.NoError(t, err)

if testData.expectedSkipped {
// We expect an empty response, but not the one returned by the inner handler
// which we expect has been skipped.
assert.NotSame(t, innerRes, res)
assert.Len(t, inner.Calls, 0)
} else {
// We expect the response returned by the inner handler.
assert.Same(t, innerRes, res)

// Assert on the time range of the request passed to the inner handler (5s delta).
delta := float64(5000)
require.Len(t, inner.Calls, 1)

assert.InDelta(t, util.TimeToMillis(testData.expectedStartTime), inner.Calls[0].Arguments.Get(1).(MetricsQueryRequest).GetStart(), delta)
assert.InDelta(t, util.TimeToMillis(testData.expectedEndTime), inner.Calls[0].Arguments.Get(1).(MetricsQueryRequest).GetEnd(), delta)
for reqType, req := range reqs {
t.Run(reqType, func(t *testing.T) {
limits := mockLimits{maxQueryLookback: testData.maxQueryLookback, compactorBlocksRetentionPeriod: testData.blocksRetentionPeriod}
middleware := newLimitsMiddleware(limits, log.NewNopLogger())

innerRes := newEmptyPrometheusResponse()
inner := &mockHandler{}
inner.On("Do", mock.Anything, mock.Anything).Return(innerRes, nil)

ctx := user.InjectOrgID(context.Background(), "test")
outer := middleware.Wrap(inner)
res, err := outer.Do(ctx, req)
require.NoError(t, err)

if testData.expectedSkipped {
// We expect an empty response, but not the one returned by the inner handler
// which we expect has been skipped.
assert.NotSame(t, innerRes, res)
assert.Len(t, inner.Calls, 0)
} else {
// We expect the response returned by the inner handler.
assert.Same(t, innerRes, res)

// Assert on the time range of the request passed to the inner handler (5s delta).
delta := float64(5000)
require.Len(t, inner.Calls, 1)

assert.InDelta(t, util.TimeToMillis(testData.expectedStartTime), inner.Calls[0].Arguments.Get(1).(MetricsQueryRequest).GetStart(), delta)
assert.InDelta(t, util.TimeToMillis(testData.expectedEndTime), inner.Calls[0].Arguments.Get(1).(MetricsQueryRequest).GetEnd(), delta)
}
})
}
})
}
Expand Down Expand Up @@ -206,34 +223,64 @@ func TestLimitsMiddleware_MaxQueryExpressionSizeBytes(t *testing.T) {

for testName, testData := range tests {
t.Run(testName, func(t *testing.T) {
req := &PrometheusRangeQueryRequest{
queryExpr: parseQuery(t, testData.query),
start: util.TimeToMillis(now.Add(-time.Hour * 2)),
end: util.TimeToMillis(now.Add(-time.Hour)),
}

limits := multiTenantMockLimits{
byTenant: map[string]mockLimits{
"test1": {maxQueryExpressionSizeBytes: testData.queryLimits["test1"]},
"test2": {maxQueryExpressionSizeBytes: testData.queryLimits["test2"]},
startMs := util.TimeToMillis(now.Add(-time.Hour * 2))
endMs := util.TimeToMillis(now.Add(-time.Hour))

reqs := map[string]MetricsQueryRequest{
"range query": &PrometheusRangeQueryRequest{
queryExpr: parseQuery(t, testData.query),
start: startMs,
end: endMs,
},
"remote read": &remoteReadQueryRequest{
path: remoteReadPathSuffix,
promQuery: testData.query,
query: &prompb.Query{
StartTimestampMs: startMs,
EndTimestampMs: endMs,
Matchers: func() []*prompb.LabelMatcher {
v := &findVectorSelectorsVisitor{}
require.NoError(t, parser.Walk(v, parseQuery(t, testData.query), nil))

// This test requires the query has only 1 vector selector.
require.Len(t, v.selectors, 1)
require.NotEmpty(t, v.selectors[0].LabelMatchers)

matchers, err := toLabelMatchers(v.selectors[0].LabelMatchers)
require.NoError(t, err)

return matchers
}(),
},
},
}
middleware := newLimitsMiddleware(limits, log.NewNopLogger())

innerRes := newEmptyPrometheusResponse()
inner := &mockHandler{}
inner.On("Do", mock.Anything, mock.Anything).Return(innerRes, nil)

ctx := user.InjectOrgID(context.Background(), "test1|test2")
outer := middleware.Wrap(inner)
res, err := outer.Do(ctx, req)

if testData.expectError {
require.Error(t, err)
require.Contains(t, err.Error(), "err-mimir-max-query-expression-size-bytes")
} else {
require.NoError(t, err)
require.Same(t, innerRes, res)

for reqType, req := range reqs {
t.Run(reqType, func(t *testing.T) {
limits := multiTenantMockLimits{
byTenant: map[string]mockLimits{
"test1": {maxQueryExpressionSizeBytes: testData.queryLimits["test1"]},
"test2": {maxQueryExpressionSizeBytes: testData.queryLimits["test2"]},
},
}
middleware := newLimitsMiddleware(limits, log.NewNopLogger())

innerRes := newEmptyPrometheusResponse()
inner := &mockHandler{}
inner.On("Do", mock.Anything, mock.Anything).Return(innerRes, nil)

ctx := user.InjectOrgID(context.Background(), "test1|test2")
outer := middleware.Wrap(inner)
res, err := outer.Do(ctx, req)

if testData.expectError {
require.Error(t, err)
require.Contains(t, err.Error(), "err-mimir-max-query-expression-size-bytes")
} else {
require.NoError(t, err)
require.Same(t, innerRes, res)
}
})
}
})
}
Expand Down Expand Up @@ -295,36 +342,49 @@ func TestLimitsMiddleware_MaxQueryLength(t *testing.T) {

for testName, testData := range tests {
t.Run(testName, func(t *testing.T) {
req := &PrometheusRangeQueryRequest{
start: util.TimeToMillis(testData.reqStartTime),
end: util.TimeToMillis(testData.reqEndTime),
reqs := map[string]MetricsQueryRequest{
"range query": &PrometheusRangeQueryRequest{
start: util.TimeToMillis(testData.reqStartTime),
end: util.TimeToMillis(testData.reqEndTime),
},
"remote read": &remoteReadQueryRequest{
path: remoteReadPathSuffix,
query: &prompb.Query{
StartTimestampMs: util.TimeToMillis(testData.reqStartTime),
EndTimestampMs: util.TimeToMillis(testData.reqEndTime),
},
},
}

limits := mockLimits{maxQueryLength: testData.maxQueryLength, maxTotalQueryLength: testData.maxTotalQueryLength}
middleware := newLimitsMiddleware(limits, log.NewNopLogger())

innerRes := newEmptyPrometheusResponse()
inner := &mockHandler{}
inner.On("Do", mock.Anything, mock.Anything).Return(innerRes, nil)

ctx := user.InjectOrgID(context.Background(), "test")
outer := middleware.Wrap(inner)
res, err := outer.Do(ctx, req)

if testData.expectedErr != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), testData.expectedErr)
assert.Nil(t, res)
assert.Len(t, inner.Calls, 0)
} else {
// We expect the response returned by the inner handler.
require.NoError(t, err)
assert.Same(t, innerRes, res)

// The time range of the request passed to the inner handler should have not been manipulated.
require.Len(t, inner.Calls, 1)
assert.Equal(t, util.TimeToMillis(testData.reqStartTime), inner.Calls[0].Arguments.Get(1).(MetricsQueryRequest).GetStart())
assert.Equal(t, util.TimeToMillis(testData.reqEndTime), inner.Calls[0].Arguments.Get(1).(MetricsQueryRequest).GetEnd())
for reqType, req := range reqs {
t.Run(reqType, func(t *testing.T) {
limits := mockLimits{maxQueryLength: testData.maxQueryLength, maxTotalQueryLength: testData.maxTotalQueryLength}
middleware := newLimitsMiddleware(limits, log.NewNopLogger())

innerRes := newEmptyPrometheusResponse()
inner := &mockHandler{}
inner.On("Do", mock.Anything, mock.Anything).Return(innerRes, nil)

ctx := user.InjectOrgID(context.Background(), "test")
outer := middleware.Wrap(inner)
res, err := outer.Do(ctx, req)

if testData.expectedErr != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), testData.expectedErr)
assert.Nil(t, res)
assert.Len(t, inner.Calls, 0)
} else {
// We expect the response returned by the inner handler.
require.NoError(t, err)
assert.Same(t, innerRes, res)

// The time range of the request passed to the inner handler should have not been manipulated.
require.Len(t, inner.Calls, 1)
assert.Equal(t, util.TimeToMillis(testData.reqStartTime), inner.Calls[0].Arguments.Get(1).(MetricsQueryRequest).GetStart())
assert.Equal(t, util.TimeToMillis(testData.reqEndTime), inner.Calls[0].Arguments.Get(1).(MetricsQueryRequest).GetEnd())
}
})
}
})
}
Expand Down Expand Up @@ -777,3 +837,44 @@ func TestSmallestPositiveNonZeroDuration(t *testing.T) {
assert.Equal(t, time.Duration(1), smallestPositiveNonZeroDuration(0, 1, -1))
assert.Equal(t, time.Duration(1), smallestPositiveNonZeroDuration(0, 2, 1))
}

type findVectorSelectorsVisitor struct {
selectors []*parser.VectorSelector
}

func (v *findVectorSelectorsVisitor) Visit(node parser.Node, _ []parser.Node) (parser.Visitor, error) {
selector, ok := node.(*parser.VectorSelector)
if !ok {
return v, nil
}

v.selectors = append(v.selectors, selector)
return v, nil
}

// This function has been copied from:
// https://github.com/prometheus/prometheus/blob/5efc8dd27b6e68d5102b77bc708e52c9821c5101/storage/remote/codec.go#L569
func toLabelMatchers(matchers []*labels.Matcher) ([]*prompb.LabelMatcher, error) {
pbMatchers := make([]*prompb.LabelMatcher, 0, len(matchers))
for _, m := range matchers {
var mType prompb.LabelMatcher_Type
switch m.Type {
case labels.MatchEqual:
mType = prompb.LabelMatcher_EQ
case labels.MatchNotEqual:
mType = prompb.LabelMatcher_NEQ
case labels.MatchRegexp:
mType = prompb.LabelMatcher_RE
case labels.MatchNotRegexp:
mType = prompb.LabelMatcher_NRE
default:
return nil, errors.New("invalid matcher type")
}
pbMatchers = append(pbMatchers, &prompb.LabelMatcher{
Type: mType,
Name: m.Name,
Value: m.Value,
})
}
return pbMatchers, nil
}
Loading

0 comments on commit 5893053

Please sign in to comment.