Skip to content

Commit

Permalink
Refactor: Move the maybe-gzip-reader out to util.
Browse files Browse the repository at this point in the history
Signed-off-by: Philip Conrad <[email protected]>
  • Loading branch information
philipaconrad committed Jun 25, 2024
1 parent 4caaa64 commit cb39173
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 57 deletions.
21 changes: 1 addition & 20 deletions server/authorizer/authorizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
package authorizer

import (
"bytes"
"compress/gzip"
"context"
"io"
"net/http"
Expand Down Expand Up @@ -165,7 +163,7 @@ func makeInput(r *http.Request) (*http.Request, interface{}, error) {

if expectBody(r.Method, path) {
var err error
plaintextBody, err := readPlainBody(r)
plaintextBody, err := util.ReadMaybeCompressedBody(r)
if err != nil {
return r, nil, err
}
Expand Down Expand Up @@ -282,20 +280,3 @@ func GetBodyOnContext(ctx context.Context) (interface{}, bool) {
}
return input.parsed, true
}

// Note(philipc): Copied over from server/server.go
func readPlainBody(r *http.Request) (io.ReadCloser, error) {
if strings.Contains(r.Header.Get("Content-Encoding"), "gzip") {
gzReader, err := gzip.NewReader(r.Body)
if err != nil {
return nil, err
}
bytesBody, err := io.ReadAll(gzReader)
if err != nil {
return nil, err
}
defer gzReader.Close()
return io.NopCloser(bytes.NewReader(bytesBody)), err
}
return r.Body, nil
}
23 changes: 3 additions & 20 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package server

import (
"bytes"
"compress/gzip"
"context"
"crypto/tls"
"crypto/x509"
Expand Down Expand Up @@ -1337,7 +1336,7 @@ func (s *Server) v1CompilePost(w http.ResponseWriter, r *http.Request) {
m.Timer(metrics.RegoQueryParse).Start()

// decompress the input if sent as zip
body, err := readPlainBody(r)
body, err := util.ReadMaybeCompressedBody(r)
if err != nil {
writer.Error(w, http.StatusBadRequest, types.NewErrorV1(types.CodeInvalidParameter, "could not decompress the body"))
return
Expand Down Expand Up @@ -2732,7 +2731,7 @@ func readInputV0(r *http.Request) (ast.Value, *interface{}, error) {
}

// decompress the input if sent as zip
body, err := readPlainBody(r)
body, err := util.ReadMaybeCompressedBody(r)
if err != nil {
return nil, nil, fmt.Errorf("could not decompress the body: %w", err)
}
Expand Down Expand Up @@ -2785,7 +2784,7 @@ func readInputPostV1(r *http.Request) (ast.Value, *interface{}, error) {
var request types.DataRequestV1

// decompress the input if sent as zip
body, err := readPlainBody(r)
body, err := util.ReadMaybeCompressedBody(r)
if err != nil {
return nil, nil, fmt.Errorf("could not decompress the body: %w", err)
}
Expand Down Expand Up @@ -3023,22 +3022,6 @@ func annotateSpan(ctx context.Context, decisionID string) {
SetAttributes(attribute.String(otelDecisionIDAttr, decisionID))
}

func readPlainBody(r *http.Request) (io.ReadCloser, error) {
if strings.Contains(r.Header.Get("Content-Encoding"), "gzip") {
gzReader, err := gzip.NewReader(r.Body)
if err != nil {
return nil, err
}
bytesBody, err := io.ReadAll(gzReader)
if err != nil {
return nil, err
}
defer gzReader.Close()
return io.NopCloser(bytes.NewReader(bytesBody)), err
}
return r.Body, nil
}

func pretty(r *http.Request) bool {
return getBoolParam(r.URL, types.ParamPrettyV1, true)
}
Expand Down
25 changes: 8 additions & 17 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1682,12 +1682,12 @@ func TestDataPutV1IfNoneMatch(t *testing.T) {
}
}

// Ensure JSON payload is compressed with gzip.
func mustGZIPPayload(payload []byte) []byte {
// Compress JSON payload with gzip
var compressedPayload bytes.Buffer
gz := gzip.NewWriter(&compressedPayload)
if _, err := gz.Write(payload); err != nil {
panic(fmt.Errorf("Error closing gzip writer: %w", err))
panic(fmt.Errorf("Error writing to gzip writer: %w", err))
}
if err := gz.Close(); err != nil {
panic(fmt.Errorf("Error closing gzip writer: %w", err))
Expand All @@ -1702,38 +1702,32 @@ func TestDataGetV1CompressedRequestWithAuthorizer(t *testing.T) {
payload []byte
forcePayloadSizeField uint32 // Size to manually set the payload field for the gzip blob.
expRespHTTPStatus int
authzEnabled bool
}{
{
note: "empty message",
payload: mustGZIPPayload([]byte{}),
expRespHTTPStatus: 401,
authzEnabled: true,
},
{
note: "empty object",
payload: mustGZIPPayload([]byte(`{}`)),
expRespHTTPStatus: 401,
authzEnabled: true,
},
{
note: "basic authz - fail",
payload: mustGZIPPayload([]byte(`{"user": "bob"}`)),
expRespHTTPStatus: 401,
authzEnabled: true,
},
{
note: "basic authz - pass",
payload: mustGZIPPayload([]byte(`{"user": "alice"}`)),
expRespHTTPStatus: 200,
authzEnabled: true,
},
{
note: "basic authz - malicious size field",
payload: mustGZIPPayload([]byte(`{"user": "alice"}`)),
expRespHTTPStatus: 200,
forcePayloadSizeField: 134217728, // 128 MB
authzEnabled: true,
},
}

Expand Down Expand Up @@ -1762,28 +1756,25 @@ allow if {
panic(err)
}

opts := [](func(*Server)){func(s *Server) {
s.WithStore(store)
}}
if test.authzEnabled {
opts = append(opts, func(s *Server) {
opts := [](func(*Server)){
func(s *Server) {
s.WithStore(store)
},
func(s *Server) {
s.WithAuthorization(AuthorizationBasic)
})
},
}

f := newFixtureWithConfig(t, fmt.Sprintf(`{"server":{"decision_logs": %t}}`, true), opts...)

// execute the request
req := newReqV1(http.MethodPost, "/data/test", string(test.payload))
// req.Header.Set("Content-Type", "application/json")
req.Header.Set("Content-Encoding", "gzip")
f.reset()
f.server.Handler.ServeHTTP(f.recorder, req)
if f.recorder.Code != test.expRespHTTPStatus {
t.Fatalf("Unexpected HTTP status code, (exp,got): %d, %d", test.expRespHTTPStatus, f.recorder.Code)
}
fmt.Println(f.recorder.Body)
// panic("AAA")
})
}
}
Expand Down
26 changes: 26 additions & 0 deletions util/read_gzip_body.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package util

import (
"bytes"
"compress/gzip"
"io"
"net/http"
"strings"
)

// Note(philipc): Originally taken from server/server.go
func ReadMaybeCompressedBody(r *http.Request) (io.ReadCloser, error) {
if strings.Contains(r.Header.Get("Content-Encoding"), "gzip") {
gzReader, err := gzip.NewReader(r.Body)
if err != nil {
return nil, err
}
defer gzReader.Close()
bytesBody, err := io.ReadAll(gzReader)
if err != nil {
return nil, err
}
return io.NopCloser(bytes.NewReader(bytesBody)), err
}
return r.Body, nil
}

0 comments on commit cb39173

Please sign in to comment.