Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pool flate readers #195

Merged
merged 1 commit into from
Dec 27, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 41 additions & 6 deletions compression.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,22 @@ import (

var (
flateWriterPool = sync.Pool{}
flateReaderPool = sync.Pool{}
)

func decompressNoContextTakeover(r io.Reader) io.Reader {
func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
const tail =
// Add four bytes as specified in RFC
"\x00\x00\xff\xff" +
// Add final block to squelch unexpected EOF error from flate reader.
"\x01\x00\x00\xff\xff"
return flate.NewReader(io.MultiReader(r, strings.NewReader(tail)))

i := flateReaderPool.Get()
if i == nil {
i = flate.NewReader(nil)
}
i.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil)
return &flateReadWrapper{i.(io.ReadCloser)}
}

func compressNoContextTakeover(w io.WriteCloser) (io.WriteCloser, error) {
Expand All @@ -36,7 +43,7 @@ func compressNoContextTakeover(w io.WriteCloser) (io.WriteCloser, error) {
fw = i.(*flate.Writer)
fw.Reset(tw)
}
return &flateWrapper{fw: fw, tw: tw}, err
return &flateWriteWrapper{fw: fw, tw: tw}, err
}

// truncWriter is an io.Writer that writes all but the last four bytes of the
Expand Down Expand Up @@ -75,19 +82,19 @@ func (w *truncWriter) Write(p []byte) (int, error) {
return n + nn, err
}

type flateWrapper struct {
type flateWriteWrapper struct {
fw *flate.Writer
tw *truncWriter
}

func (w *flateWrapper) Write(p []byte) (int, error) {
func (w *flateWriteWrapper) Write(p []byte) (int, error) {
if w.fw == nil {
return 0, errWriteClosed
}
return w.fw.Write(p)
}

func (w *flateWrapper) Close() error {
func (w *flateWriteWrapper) Close() error {
if w.fw == nil {
return errWriteClosed
}
Expand All @@ -103,3 +110,31 @@ func (w *flateWrapper) Close() error {
}
return err2
}

type flateReadWrapper struct {
fr io.ReadCloser
}

func (r *flateReadWrapper) Read(p []byte) (int, error) {
if r.fr == nil {
return 0, io.ErrClosedPipe
}
n, err := r.fr.Read(p)
if err == io.EOF {
// Preemptively place the reader back in the pool. This helps with
// scenarios where the application does not call NextReader() soon after
// this final read.
r.Close()
}
return n, err
}

func (r *flateReadWrapper) Close() error {
if r.fr == nil {
return io.ErrClosedPipe
}
err := r.fr.Close()
flateReaderPool.Put(r.fr)
r.fr = nil
return err
}
18 changes: 14 additions & 4 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ type Conn struct {
newCompressionWriter func(io.WriteCloser) (io.WriteCloser, error)

// Read fields
reader io.ReadCloser // the current reader returned to the application
readErr error
br *bufio.Reader
readRemaining int64 // bytes remaining in current frame.
Expand All @@ -253,7 +254,7 @@ type Conn struct {
messageReader *messageReader // the current low-level reader

readDecompress bool // whether last read frame had RSV1 set
newDecompressionReader func(io.Reader) io.Reader
newDecompressionReader func(io.Reader) io.ReadCloser
}

func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
Expand Down Expand Up @@ -855,6 +856,11 @@ func (c *Conn) handleProtocolError(message string) error {
// permanent. Once this method returns a non-nil error, all subsequent calls to
// this method return the same error.
func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
// Close previous reader, only relevant for decompression.
if c.reader != nil {
c.reader.Close()
c.reader = nil
}

c.messageReader = nil
c.readLength = 0
Expand All @@ -867,11 +873,11 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
}
if frameType == TextMessage || frameType == BinaryMessage {
c.messageReader = &messageReader{c}
var r io.Reader = c.messageReader
c.reader = c.messageReader
if c.readDecompress {
r = c.newDecompressionReader(r)
c.reader = c.newDecompressionReader(c.reader)
}
return frameType, r, nil
return frameType, c.reader, nil
}
}

Expand Down Expand Up @@ -933,6 +939,10 @@ func (r *messageReader) Read(b []byte) (int, error) {
return 0, err
}

func (r *messageReader) Close() error {
return nil
}

// ReadMessage is a helper method for getting a reader using NextReader and
// reading from that reader to a buffer.
func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
Expand Down