Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Server-side CHUNKING and BINARYMIME support #104

Closed
wants to merge 6 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
server: Implement CHUNKING extension support
foxcpp committed Jun 11, 2020

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
commit 54e1bd52b5f13e45cc5fa45f23054a15bd7addcc
127 changes: 127 additions & 0 deletions chunk_reader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package smtp

import (
"errors"
"io"
"io/ioutil"
)

// ErrDataReset is returned by Reader pased to Data function if client does not
// send another BDAT command and instead closes connection or issues RSET command.
var ErrDataReset = errors.New("smtp: message transmission aborted")

// chunkReader implements io.Reader by consuming a sequence of size-limited
// "chunks" from underlying Reader allowing them to be interleaved with other
// protocol data.
//
// It is caller responsibility to not use Reader while chunk is being processed.
// This can be enforced by blocking on chunkEnd channel that is used to signal the
// end of another chunk being reached.
type chunkReader struct {
remainingBytes int
conn io.Reader
chunks chan int
// Sent to by abort() to unlock running Read.
rset chan struct{}
currentChunk *io.LimitedReader

chunkEnd chan struct{}
}

func (cr *chunkReader) addChunk(size int) {
cr.chunks <- size
}

func (cr *chunkReader) end() {
close(cr.chunks)
}

func (cr *chunkReader) abort() {
close(cr.rset)
close(cr.chunkEnd)
}

func (cr *chunkReader) discardCurrentChunk() error {
if cr.currentChunk == nil {
return nil
}
_, err := io.Copy(ioutil.Discard, cr.currentChunk)
return err
}

func (cr *chunkReader) waitNextChunk() error {
select {
case <-cr.rset:
return ErrDataReset
case r, ok := <-cr.chunks:
if !ok {
// Okay, that's the end.
return io.EOF
}
cr.currentChunk = &io.LimitedReader{R: cr.conn, N: int64(r)}
return nil
}
}

func (cr *chunkReader) Read(b []byte) (int, error) {
/*
Possible states:

1. We are at the start of next chunk.
cr.currentChunk == nil, cr.chunks is not closed.

2. We are in the middle of chunk.
cr.currentchunk != nil

3. Chunk ended, cr.currentChunk returns io.EOF.
Generate an 250 response and wait for the next chunk.
*/

if cr.currentChunk == nil {
if err := cr.waitNextChunk(); err != nil {
return 0, err
}
}

n, err := cr.currentChunk.Read(b)
if err == io.EOF {
cr.chunkEnd <- struct{}{}
cr.currentChunk = nil
err = nil
}

if cr.remainingBytes != 0 /* no limit */ {
cr.remainingBytes -= n
if cr.remainingBytes <= 0 {
return 0, ErrDataTooLarge
}
}

// Strip CR from slice contents.
offset := 0
for i, chr := range b {
if chr == '\r' {
offset += 1
}
if i+offset >= len(b) {
break
}
b[i] = b[i+offset]
}

// We also likely left garbage in remaining bytes but lets hope backend
// code does not assume they are intact.
return n - offset, err

}

func newChunkReader(conn io.Reader, maxBytes int) *chunkReader {
return &chunkReader{
remainingBytes: maxBytes,
conn: conn,
chunks: make(chan int, 1),
// buffer to make sure abort() will not block if Read is not running.
rset: make(chan struct{}, 1),
chunkEnd: make(chan struct{}, 1),
}
}
182 changes: 168 additions & 14 deletions conn.go
Original file line number Diff line number Diff line change
@@ -33,6 +33,9 @@ type Conn struct {
session Session
locker sync.Mutex

chunkReader *chunkReader
bdatError chan error

fromReceived bool
recipients []string
}
@@ -135,6 +138,8 @@ func (c *Conn) handle(cmd string, arg string) {
c.WriteResponse(250, EnhancedCode{2, 0, 0}, "Session reset")
case "DATA":
c.handleData(arg)
case "BDAT":
c.handleBdat(arg)
case "QUIT":
c.WriteResponse(221, EnhancedCode{2, 0, 0}, "Goodnight and good luck")
c.Close()
@@ -169,9 +174,17 @@ func (c *Conn) SetSession(session Session) {
}

func (c *Conn) Close() error {
if session := c.Session(); session != nil {
session.Logout()
c.SetSession(nil)
c.locker.Lock()
defer c.locker.Unlock()

if c.session != nil {
c.session.Logout()
c.session = nil
}

if c.chunkReader != nil {
c.chunkReader.abort()
c.chunkReader = nil
}

return c.conn.Close()
@@ -584,6 +597,11 @@ func (c *Conn) handleData(arg string) {
return
}

if c.chunkReader != nil {
c.WriteResponse(502, EnhancedCode{5, 5, 1}, "DATA command cannot be used together with BDAT.")
return
}

if !c.fromReceived || len(c.recipients) == 0 {
c.WriteResponse(502, EnhancedCode{5, 5, 1}, "Missing RCPT TO command.")
return
@@ -607,6 +625,126 @@ func (c *Conn) handleData(arg string) {

}

func (c *Conn) handleBdat(arg string) {
args := strings.Split(arg, " ")
if len(args) == 0 {
c.WriteResponse(501, EnhancedCode{5, 5, 4}, "Missing chunk size argument")
return
}

if !c.fromReceived || len(c.recipients) == 0 {
c.WriteResponse(502, EnhancedCode{5, 5, 1}, "Missing RCPT TO command.")
return
}

last := false
if len(args) == 2 {
if !strings.EqualFold(args[1], "LAST") {
c.WriteResponse(501, EnhancedCode{5, 5, 4}, "Unknown BDAT argument")
return
}
last = true
}

// ParseUint instead of Atoi so we will not accept negative values.
size, err := strconv.ParseUint(args[0], 10, 32)
if err != nil {
c.WriteResponse(501, EnhancedCode{5, 5, 4}, "Malformed size argument")
return
}

var status *statusCollector
if c.server.LMTP {
status = c.createStatusCollector()
}

paniced := make(chan struct{}, 1)

if c.chunkReader == nil {
c.chunkReader = newChunkReader(c.text.R, c.server.MaxMessageBytes)
c.bdatError = make(chan error, 1)

chunkReader := c.chunkReader

go func() {
defer func() {
if err := recover(); err != nil {
c.handlePanic(err, status)
paniced <- struct{}{}
}
}()

var err error
if !c.server.LMTP {
err = c.Session().Data(c.chunkReader)
} else {
lmtpSession, ok := c.Session().(LMTPSession)
if !ok {
err = c.Session().Data(c.chunkReader)
for _, rcpt := range c.recipients {
status.SetStatus(rcpt, err)
}
} else {
err = lmtpSession.LMTPData(c.chunkReader, status)
}
}

chunkReader.discardCurrentChunk()
c.bdatError <- err
}()
}

c.chunkReader.addChunk(int(size))

select {
// Wait for Data to consume chunk.
case <-c.chunkReader.chunkEnd:
case <-paniced:
c.WriteResponse(420, EnhancedCode{4, 0, 0}, "Internal server error")
c.Close()
return
case err := <-c.bdatError:
// This code path handles early errors that backend may return before
// reading _all_ chunks. chunkEnd is not sent to in this case.
//
// The RFC says the connection is in indeterminate state in this case
// so we don't bother resetting it.
if c.server.LMTP {
status.fillRemaining(err)
for i, rcpt := range c.recipients {
code, enchCode, msg := toSMTPStatus(<-status.status[i])
c.WriteResponse(code, enchCode, "<"+rcpt+"> "+msg)
}
} else {
code, enhancedCode, msg := toSMTPStatus(err)
c.WriteResponse(code, enhancedCode, msg)
}
return
}

if !last {
c.WriteResponse(250, EnhancedCode{2, 0, 0}, "Continue")
return
}

// This code path handles errors backend may return after processing all
// chunks. That is, client had the chance to send all chunks.

c.chunkReader.end()
err = <-c.bdatError
if c.server.LMTP {
status.fillRemaining(err)
for i, rcpt := range c.recipients {
code, enchCode, msg := toSMTPStatus(<-status.status[i])
c.WriteResponse(code, enchCode, "<"+rcpt+"> "+msg)
}
} else {
code, enhancedCode, msg := toSMTPStatus(err)
c.WriteResponse(code, enhancedCode, msg)
}
c.reset()
}

type statusCollector struct {
// Contains map from recipient to list of channels that are used for that
// recipient.
@@ -649,9 +787,7 @@ func (s *statusCollector) SetStatus(rcptTo string, err error) {
}
}

func (c *Conn) handleDataLMTP() {
r := newDataReader(c)

func (c *Conn) createStatusCollector() *statusCollector {
rcptCounts := make(map[string]int, len(c.recipients))

status := &statusCollector{
@@ -670,6 +806,26 @@ func (c *Conn) handleDataLMTP() {
status.status = append(status.status, status.statusMap[rcpt])
}

return status
}

func (c *Conn) handlePanic(err interface{}, status *statusCollector) {
if status != nil {
status.fillRemaining(&SMTPError{
Code: 421,
EnhancedCode: EnhancedCode{4, 0, 0},
Message: "Internal server error",
})
}

stack := debug.Stack()
c.server.ErrorLog.Printf("panic serving %v: %v\n%s", c.State().RemoteAddr, err, stack)
}

func (c *Conn) handleDataLMTP() {
r := newDataReader(c)

status := c.createStatusCollector()
done := make(chan bool, 1)

lmtpSession, ok := c.Session().(LMTPSession)
@@ -685,14 +841,7 @@ func (c *Conn) handleDataLMTP() {
go func() {
defer func() {
if err := recover(); err != nil {
status.fillRemaining(&SMTPError{
Code: 421,
EnhancedCode: EnhancedCode{4, 0, 0},
Message: "Internal server error",
})

stack := debug.Stack()
c.server.ErrorLog.Printf("panic serving %v: %v\n%s", c.State().RemoteAddr, err, stack)
c.handlePanic(err, status)
done <- false
}
}()
@@ -779,6 +928,11 @@ func (c *Conn) reset() {
c.locker.Lock()
defer c.locker.Unlock()

if c.chunkReader != nil {
c.chunkReader.abort()
c.chunkReader = nil
}

if c.session != nil {
c.session.Reset()
}
2 changes: 1 addition & 1 deletion server.go
Original file line number Diff line number Diff line change
@@ -78,7 +78,7 @@ func NewServer(be Backend) *Server {
Backend: be,
done: make(chan struct{}, 1),
ErrorLog: log.New(os.Stderr, "smtp/server ", log.LstdFlags),
caps: []string{"PIPELINING", "8BITMIME", "ENHANCEDSTATUSCODES"},
caps: []string{"PIPELINING", "8BITMIME", "ENHANCEDSTATUSCODES", "CHUNKING"},
auths: map[string]SaslServerFactory{
sasl.Plain: func(conn *Conn) sasl.Server {
return sasl.NewPlainServer(func(identity, username, password string) error {
Loading