Skip to content

Commit

Permalink
rpc_util: add bytes pool for the parser
Browse files Browse the repository at this point in the history
  • Loading branch information
hueypark committed Jan 25, 2023
1 parent a6376c9 commit 493be5d
Showing 1 changed file with 37 additions and 12 deletions.
49 changes: 37 additions & 12 deletions rpc_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -574,9 +574,7 @@ func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byt
if int(length) > maxReceiveMessageSize {
return 0, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", length, maxReceiveMessageSize)
}
// TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead
// of making it for each message:
msg = make([]byte, int(length))
msg = pool.Get(int(length))
if _, err := p.r.Read(msg); err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
Expand Down Expand Up @@ -689,12 +687,12 @@ type payloadInfo struct {
}

func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor) ([]byte, error) {
pf, d, err := p.recvMsg(maxReceiveMessageSize)
pf, buf, err := p.recvMsg(maxReceiveMessageSize)
if err != nil {
return nil, err
}
if payInfo != nil {
payInfo.wireLength = len(d)
payInfo.wireLength = len(buf)
}

if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil); st != nil {
Expand All @@ -706,10 +704,10 @@ func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxRecei
// To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor,
// use this decompressor as the default.
if dc != nil {
d, err = dc.Do(bytes.NewReader(d))
size = len(d)
buf, err = dc.Do(bytes.NewReader(buf))
size = len(buf)
} else {
d, size, err = decompress(compressor, d, maxReceiveMessageSize)
buf, size, err = decompress(compressor, buf, maxReceiveMessageSize)
}
if err != nil {
return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err)
Expand All @@ -720,7 +718,7 @@ func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxRecei
return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max (%d vs. %d)", size, maxReceiveMessageSize)
}
}
return d, nil
return buf, nil
}

// Using compressor, decompress d, returning data and size.
Expand Down Expand Up @@ -755,16 +753,18 @@ func decompress(compressor encoding.Compressor, d []byte, maxReceiveMessageSize
// dc takes precedence over compressor.
// TODO(dfawley): wrap the old compressor/decompressor using the new API?
func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m interface{}, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor) error {
d, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor)
buf, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor)
if err != nil {
return err
}
if err := c.Unmarshal(d, m); err != nil {
if err := c.Unmarshal(buf, m); err != nil {
return status.Errorf(codes.Internal, "grpc: failed to unmarshal the received message: %v", err)
}
if payInfo != nil {
payInfo.uncompressedBytes = d
payInfo.uncompressedBytes = make([]byte, len(buf))
copy(payInfo.uncompressedBytes, buf)
}
pool.Put(&buf)
return nil
}

Expand Down Expand Up @@ -914,3 +914,28 @@ const (
)

const grpcUA = "grpc-go/" + Version

type bytesPool struct {
sync.Pool
}

func (p *bytesPool) Get(size int) []byte {
bs := p.Pool.Get().(*[]byte)
if cap(*bs) < size {
*bs = make([]byte, size)
}

return (*bs)[:size]
}

func (p *bytesPool) Put(bs *[]byte) {
p.Pool.Put(bs)
}

var pool = bytesPool{
sync.Pool{
New: func() interface{} {
return new([]byte)
},
},
}

0 comments on commit 493be5d

Please sign in to comment.