Skip to content

Commit

Permalink
add test to demonstrate that the middleware streams the response body…
Browse files Browse the repository at this point in the history
… contents
  • Loading branch information
flxbk committed Jan 3, 2024
1 parent 4d8a08d commit 211723d
Showing 1 changed file with 83 additions and 0 deletions.
83 changes: 83 additions & 0 deletions pkg/frontend/querymiddleware/shard_active_series_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,89 @@ func Test_shardActiveSeriesMiddleware_RoundTrip(t *testing.T) {
}
}

func Test_shardActiveSeriesMiddleware_RoundTrip_ResponseBodyStreamed(t *testing.T) {
// This value needs to be set at least as large as the buffer size used by the
// actual code for this test to make sense.
const bufferSize = 512
const shardCount = 2

// Stub upstream with two responses that are larger than the buffer size and
// retain a reference to the response bodies. The responses use a custom body
// type that counts the number of bytes read, so we can assert on that later in
// the test.
var upstreamResponseBodies [shardCount]*bodyReadBytesCounter
var responseSize [shardCount]int
upstream := RoundTripFunc(func(r *http.Request) (*http.Response, error) {
// Extract requested shard index
require.NoError(t, r.ParseForm())
req, err := cardinality.DecodeActiveSeriesRequestFromValues(r.Form)
require.NoError(t, err)
shard, _, err := sharding.ShardFromMatchers(req.Matchers)
require.NoError(t, err)
require.NotNil(t, shard, "this test requires a shard to be requested")

// Make sure the response body is big enough to not be buffered entirely.
response := fmt.Sprintf(fmt.Sprintf(`{"data": [{"__name__": "metric-%%%dd"}]}`, bufferSize), shard.ShardIndex)
body := &bodyReadBytesCounter{body: io.NopCloser(strings.NewReader(response))}
upstreamResponseBodies[shard.ShardIndex] = body
responseSize[shard.ShardIndex] = len(response)

return &http.Response{StatusCode: http.StatusOK, Body: body}, nil
})

s := newShardActiveSeriesMiddleware(
upstream,
mockLimits{maxShardedQueries: shardCount, totalShards: shardCount},
log.NewNopLogger(),
)

r := httptest.NewRequest("POST", "/active_series", strings.NewReader(`selector={__name__=~"metric-.*"}`))
r.Header.Add("Content-Type", "application/x-www-form-urlencoded")

resp, err := s.RoundTrip(r.WithContext(user.InjectOrgID(r.Context(), "test")))
require.NoError(t, err)
defer func(body io.ReadCloser) {
_, _ = io.ReadAll(body)
_ = body.Close()
}(resp.Body)

// Check that upstream responses have been read only up to a max of the buffer size.
for _, body := range upstreamResponseBodies {
bytesRead := int(body.BytesRead())
require.GreaterOrEqual(t, bufferSize, bytesRead)
}

// Read and close the response body.
_, _ = io.ReadAll(resp.Body)
_ = resp.Body.Close()

// Check that upstream responses have been fully read now.
for i, body := range upstreamResponseBodies {
bytesRead := int(body.BytesRead())
assert.Equal(t, responseSize[i], bytesRead)
}
}

// bodyReadBytesCounter is a wrapper around a response body that counts the number of bytes read from it.
type bodyReadBytesCounter struct {
body io.ReadCloser
bytesRead atomic.Uint64
}

func (b *bodyReadBytesCounter) Read(p []byte) (n int, err error) {
read, err := b.body.Read(p)
b.bytesRead.Add(uint64(read))
return read, err
}

func (b *bodyReadBytesCounter) Close() error {
return b.body.Close()
}

func (b *bodyReadBytesCounter) BytesRead() uint64 {
return b.bytesRead.Load()
}

type result struct {
Data []labels.Labels `json:"data"`
Status string `json:"status,omitempty"`
Expand Down

0 comments on commit 211723d

Please sign in to comment.