Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Leave overflow logic to go-msgio #13

Merged
merged 1 commit into from
Jan 30, 2016
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 61 additions & 37 deletions p2p/crypto/secio/rw.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,6 @@ import (
context "gx/ipfs/QmZy2y8t9zQH2a1b8q2ZSLKp17ATuJoCNxxyMFG5qFExpt/go-net/context"
)

const MaxMsgSize = 8 * 1024 * 1024

var ErrMaxMessageSize = errors.New("attempted to read message larger than max size")

// ErrMACInvalid signals that a MAC verification failed
var ErrMACInvalid = errors.New("MAC verification failed")

Expand Down Expand Up @@ -85,9 +81,13 @@ type etmReader struct {
msgio.Reader
io.Closer

// buffer
// internal buffer returned from the msgio
buf []byte

// low and high watermark for the buffered data
lowat int
hiwat int

// params
msg msgio.ReadCloser // msgio for knowing where boundaries lie
str cipher.Stream // the stream cipher to encrypt with
Expand All @@ -105,67 +105,91 @@ func (r *etmReader) NextMsgLen() (int, error) {
return r.msg.NextMsgLen()
}

func (r *etmReader) drainBuf(buf []byte) int {
if r.buf == nil {
func (r *etmReader) drain(buf []byte) int {
// Return zero if there is no data remaining in the internal buffer.
if r.lowat == r.hiwat {
return 0
}

n := copy(buf, r.buf)
r.buf = r.buf[n:]
// Copy data to the output buffer.
n := copy(buf, r.buf[r.lowat:r.hiwat])

// Update the low watermark.
r.lowat += n

// Release the buffer and reset the watermarks if it has been fully read.
if r.lowat == r.hiwat {
r.msg.ReleaseMsg(r.buf)
r.buf = nil
r.lowat = 0
r.hiwat = 0
}

return n
}

func (r *etmReader) fill() error {
// Read a message from the underlying msgio.
msg, err := r.msg.ReadMsg()
if err != nil {
return err
}

// Check the MAC.
n, err := r.macCheckThenDecrypt(msg)
if err != nil {
r.msg.ReleaseMsg(msg)
return err
}

// Retain the buffer so it can be drained from and later released.
r.buf = msg
r.lowat = 0
r.hiwat = n

return nil
}

func (r *etmReader) Read(buf []byte) (int, error) {
r.Lock()
defer r.Unlock()

// first, check if we have anything in the buffer
copied := r.drainBuf(buf)
buf = buf[copied:]
// Return buffered data without reading more, if possible.
copied := r.drain(buf)
if copied > 0 {
return copied, nil
// return here to avoid complicating the rest...
// user can call io.ReadFull.
}

// check the buffer has enough space for the next msg
// Check the length of the next message.
fullLen, err := r.msg.NextMsgLen()
if err != nil {
return 0, err
}

if fullLen > MaxMsgSize {
return 0, ErrMaxMessageSize
}

buf2 := buf
changed := false
// if not enough space, allocate a new buffer.
// If the destination buffer is too short, fill an internal buffer and then
// drain as much of that into the output buffer as will fit.
if cap(buf) < fullLen {
buf2 = make([]byte, fullLen)
changed = true
err := r.fill()
if err != nil {
return 0, err
}

copied := r.drain(buf)
return copied, nil
}
buf2 = buf2[:fullLen]

n, err := io.ReadFull(r.msg, buf2)
// Otherwise, read directly into the destination buffer.
n, err := io.ReadFull(r.msg, buf[:fullLen])
if err != nil {
return n, err
return 0, err
}

m, err := r.macCheckThenDecrypt(buf2)
m, err := r.macCheckThenDecrypt(buf[:n])
if err != nil {
return 0, err
}
buf2 = buf2[:m]
if !changed {
return m, nil
}

n = copy(buf, buf2)
if len(buf2) > len(buf) {
r.buf = buf2[len(buf):] // had some left over? save it.
}
return n, nil
return m, nil
}

func (r *etmReader) ReadMsg() ([]byte, error) {
Expand Down