diff --git a/pkg/frontend/querymiddleware/shard_active_series_test.go b/pkg/frontend/querymiddleware/shard_active_series_test.go index 4ab4d891530..5f09fce37b7 100644 --- a/pkg/frontend/querymiddleware/shard_active_series_test.go +++ b/pkg/frontend/querymiddleware/shard_active_series_test.go @@ -328,6 +328,89 @@ func Test_shardActiveSeriesMiddleware_RoundTrip(t *testing.T) { } } +func Test_shardActiveSeriesMiddleware_RoundTrip_ResponseBodyStreamed(t *testing.T) { + // This value needs to be set at least as large as the buffer size used by the + // actual code for this test to make sense. + const bufferSize = 512 + const shardCount = 2 + + // Stub upstream with two responses that are larger than the buffer size and + // retain a reference to the response bodies. The responses use a custom body + // type that counts the number of bytes read, so we can assert on that later in + // the test. + var upstreamResponseBodies [shardCount]*bodyReadBytesCounter + var responseSize [shardCount]int + upstream := RoundTripFunc(func(r *http.Request) (*http.Response, error) { + // Extract requested shard index + require.NoError(t, r.ParseForm()) + req, err := cardinality.DecodeActiveSeriesRequestFromValues(r.Form) + require.NoError(t, err) + shard, _, err := sharding.ShardFromMatchers(req.Matchers) + require.NoError(t, err) + require.NotNil(t, shard, "this test requires a shard to be requested") + + // Make sure the response body is big enough to not be buffered entirely. + response := fmt.Sprintf(fmt.Sprintf(`{"data": [{"__name__": "metric-%%%dd"}]}`, bufferSize), shard.ShardIndex) + body := &bodyReadBytesCounter{body: io.NopCloser(strings.NewReader(response))} + upstreamResponseBodies[shard.ShardIndex] = body + responseSize[shard.ShardIndex] = len(response) + + return &http.Response{StatusCode: http.StatusOK, Body: body}, nil + }) + + s := newShardActiveSeriesMiddleware( + upstream, + mockLimits{maxShardedQueries: shardCount, totalShards: shardCount}, + log.NewNopLogger(), + ) + + r := httptest.NewRequest("POST", "/active_series", strings.NewReader(`selector={__name__=~"metric-.*"}`)) + r.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + resp, err := s.RoundTrip(r.WithContext(user.InjectOrgID(r.Context(), "test"))) + require.NoError(t, err) + defer func(body io.ReadCloser) { + _, _ = io.ReadAll(body) + _ = body.Close() + }(resp.Body) + + // Check that upstream responses have been read only up to a max of the buffer size. + for _, body := range upstreamResponseBodies { + bytesRead := int(body.BytesRead()) + require.GreaterOrEqual(t, bufferSize, bytesRead) + } + + // Read and close the response body. + _, _ = io.ReadAll(resp.Body) + _ = resp.Body.Close() + + // Check that upstream responses have been fully read now. + for i, body := range upstreamResponseBodies { + bytesRead := int(body.BytesRead()) + assert.Equal(t, responseSize[i], bytesRead) + } +} + +// bodyReadBytesCounter is a wrapper around a response body that counts the number of bytes read from it. +type bodyReadBytesCounter struct { + body io.ReadCloser + bytesRead atomic.Uint64 +} + +func (b *bodyReadBytesCounter) Read(p []byte) (n int, err error) { + read, err := b.body.Read(p) + b.bytesRead.Add(uint64(read)) + return read, err +} + +func (b *bodyReadBytesCounter) Close() error { + return b.body.Close() +} + +func (b *bodyReadBytesCounter) BytesRead() uint64 { + return b.bytesRead.Load() +} + type result struct { Data []labels.Labels `json:"data"` Status string `json:"status,omitempty"`