-
Notifications
You must be signed in to change notification settings - Fork 810
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve decompression within ParseProtoReader. #3682
Changes from 9 commits
f92608a
25f3e3b
22e7809
d8d7185
7f900c6
54213e7
bdb8f12
ebfdda8
7b51ccb
beadfad
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,14 +10,15 @@ import ( | |
"net/http" | ||
"strings" | ||
|
||
"github.com/blang/semver" | ||
"github.com/gogo/protobuf/proto" | ||
"github.com/golang/snappy" | ||
"github.com/opentracing/opentracing-go" | ||
otlog "github.com/opentracing/opentracing-go/log" | ||
"gopkg.in/yaml.v2" | ||
) | ||
|
||
const messageSizeLargerErrFmt = "received message larger than max (%d vs %d)" | ||
|
||
// WriteJSONResponse writes some JSON as a HTTP response. | ||
func WriteJSONResponse(w http.ResponseWriter, v interface{}) { | ||
w.Header().Set("Content-Type", "application/json") | ||
|
@@ -81,71 +82,22 @@ type CompressionType int | |
// Values for CompressionType | ||
const ( | ||
NoCompression CompressionType = iota | ||
FramedSnappy | ||
RawSnappy | ||
) | ||
|
||
var rawSnappyFromVersion = semver.MustParse("0.1.0") | ||
|
||
// CompressionTypeFor a given version of the Prometheus remote storage protocol. | ||
// See https://github.com/prometheus/prometheus/issues/2692. | ||
func CompressionTypeFor(version string) CompressionType { | ||
ver, err := semver.Make(version) | ||
if err != nil { | ||
return FramedSnappy | ||
} | ||
|
||
if ver.GTE(rawSnappyFromVersion) { | ||
return RawSnappy | ||
} | ||
return FramedSnappy | ||
} | ||
|
||
// ParseProtoReader parses a compressed proto from an io.Reader. | ||
func ParseProtoReader(ctx context.Context, reader io.Reader, expectedSize, maxSize int, req proto.Message, compression CompressionType) error { | ||
var body []byte | ||
var err error | ||
sp := opentracing.SpanFromContext(ctx) | ||
if sp != nil { | ||
sp.LogFields(otlog.String("event", "util.ParseProtoRequest[start reading]")) | ||
} | ||
var buf bytes.Buffer | ||
if expectedSize > 0 { | ||
if expectedSize > maxSize { | ||
return fmt.Errorf("message expected size larger than max (%d vs %d)", expectedSize, maxSize) | ||
} | ||
buf.Grow(expectedSize + bytes.MinRead) // extra space guarantees no reallocation | ||
} | ||
switch compression { | ||
case NoCompression: | ||
// Read from LimitReader with limit max+1. So if the underlying | ||
// reader is over limit, the result will be bigger than max. | ||
_, err = buf.ReadFrom(io.LimitReader(reader, int64(maxSize)+1)) | ||
body = buf.Bytes() | ||
case FramedSnappy: | ||
_, err = buf.ReadFrom(io.LimitReader(snappy.NewReader(reader), int64(maxSize)+1)) | ||
body = buf.Bytes() | ||
case RawSnappy: | ||
_, err = buf.ReadFrom(reader) | ||
body = buf.Bytes() | ||
if sp != nil { | ||
sp.LogFields(otlog.String("event", "util.ParseProtoRequest[decompress]"), | ||
otlog.Int("size", len(body))) | ||
} | ||
if err == nil && len(body) <= maxSize { | ||
body, err = snappy.Decode(nil, body) | ||
} | ||
} | ||
body, err := decompressRequest(reader, expectedSize, maxSize, compression, sp) | ||
if err != nil { | ||
return err | ||
} | ||
if len(body) > maxSize { | ||
return fmt.Errorf("received message larger than max (%d vs %d)", len(body), maxSize) | ||
} | ||
|
||
if sp != nil { | ||
sp.LogFields(otlog.String("event", "util.ParseProtoRequest[unmarshal]"), | ||
otlog.Int("size", len(body))) | ||
sp.LogFields(otlog.String("event", "util.ParseProtoRequest[unmarshal]"), otlog.Int("size", len(body))) | ||
} | ||
|
||
// We re-implement proto.Unmarshal here as it calls XXX_Unmarshal first, | ||
|
@@ -163,6 +115,89 @@ func ParseProtoReader(ctx context.Context, reader io.Reader, expectedSize, maxSi | |
return nil | ||
} | ||
|
||
func decompressRequest(reader io.Reader, expectedSize, maxSize int, compression CompressionType, sp opentracing.Span) (body []byte, err error) { | ||
defer func() { | ||
if err != nil && len(body) > maxSize { | ||
err = fmt.Errorf(messageSizeLargerErrFmt, len(body), maxSize) | ||
} | ||
}() | ||
if expectedSize > maxSize { | ||
return nil, fmt.Errorf(messageSizeLargerErrFmt, expectedSize, maxSize) | ||
} | ||
buffer, ok := tryBufferFromReader(reader) | ||
if ok { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it would be better to check for non-nil buffer. It would be slightly more robust, as reader implementing |
||
body, err = decompressFromBuffer(buffer, maxSize, compression, sp) | ||
return | ||
} | ||
body, err = decompressFromReader(reader, expectedSize, maxSize, compression, sp) | ||
return | ||
} | ||
|
||
func decompressFromReader(reader io.Reader, expectedSize, maxSize int, compression CompressionType, sp opentracing.Span) ([]byte, error) { | ||
var ( | ||
buf bytes.Buffer | ||
body []byte | ||
err error | ||
) | ||
if expectedSize > 0 { | ||
buf.Grow(expectedSize + bytes.MinRead) // extra space guarantees no reallocation | ||
} | ||
reader = io.LimitReader(reader, int64(maxSize)+1) | ||
switch compression { | ||
case NoCompression: | ||
// Read from LimitReader with limit max+1. So if the underlying | ||
// reader is over limit, the result will be bigger than max. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Move this comment above, where |
||
_, err = buf.ReadFrom(reader) | ||
body = buf.Bytes() | ||
case RawSnappy: | ||
_, err = buf.ReadFrom(reader) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this a possible DoS attack? (I see it is what the old code did) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could fix it up later; I don't mind merging this as-is. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah I see the potential issue, not sure if there's a better option then using a limitReader in both case ? Decoding the length still leave us open for hijacked/fake requests. I made the change let me know what you think ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think using LimitReader in all paths is correct, if the point is to stop someone blowing up the process. Is 'result bigger than max' actually detected in the NoCompression case now? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes because we read a bit more the defer in |
||
if err != nil { | ||
return nil, err | ||
} | ||
body, err = decompressFromBuffer(&buf, maxSize, RawSnappy, sp) | ||
} | ||
return body, err | ||
} | ||
|
||
func decompressFromBuffer(buffer *bytes.Buffer, maxSize int, compression CompressionType, sp opentracing.Span) ([]byte, error) { | ||
if len(buffer.Bytes()) > maxSize { | ||
return nil, fmt.Errorf(messageSizeLargerErrFmt, len(buffer.Bytes()), maxSize) | ||
} | ||
switch compression { | ||
case NoCompression: | ||
return buffer.Bytes(), nil | ||
case RawSnappy: | ||
if sp != nil { | ||
sp.LogFields(otlog.String("event", "util.ParseProtoRequest[decompress]"), | ||
otlog.Int("size", len(buffer.Bytes()))) | ||
} | ||
size, err := snappy.DecodedLen(buffer.Bytes()) | ||
if err != nil { | ||
return nil, err | ||
} | ||
if size > maxSize { | ||
return nil, fmt.Errorf(messageSizeLargerErrFmt, size, maxSize) | ||
} | ||
body, err := snappy.Decode(nil, buffer.Bytes()) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return body, nil | ||
} | ||
return nil, nil | ||
} | ||
|
||
// tryBufferFromReader attempts to cast the reader to a `*bytes.Buffer` this is possible when using httpgrpc. | ||
// If it fails it will return nil and false. | ||
func tryBufferFromReader(reader io.Reader) (*bytes.Buffer, bool) { | ||
if bufReader, ok := reader.(interface { | ||
BytesBuffer() *bytes.Buffer | ||
}); ok && bufReader != nil { | ||
return bufReader.BytesBuffer(), true | ||
} | ||
return nil, false | ||
} | ||
|
||
// SerializeProtoResponse serializes a protobuf response into an HTTP response. | ||
func SerializeProtoResponse(w http.ResponseWriter, resp proto.Message, compression CompressionType) error { | ||
data, err := proto.Marshal(resp) | ||
|
@@ -173,14 +208,6 @@ func SerializeProtoResponse(w http.ResponseWriter, resp proto.Message, compressi | |
|
||
switch compression { | ||
case NoCompression: | ||
case FramedSnappy: | ||
buf := bytes.Buffer{} | ||
writer := snappy.NewBufferedWriter(&buf) | ||
if _, err := writer.Write(data); err != nil { | ||
return err | ||
} | ||
writer.Close() | ||
data = buf.Bytes() | ||
case RawSnappy: | ||
data = snappy.Encode(nil, data) | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: There is no need to use defer in this function, it just makes code more tricky to follow.