From 54e1bd52b5f13e45cc5fa45f23054a15bd7addcc Mon Sep 17 00:00:00 2001
From: "fox.cpp" <fox.cpp@disroot.org>
Date: Fri, 12 Jun 2020 00:00:06 +0300
Subject: [PATCH 1/6] server: Implement CHUNKING extension support

---
 chunk_reader.go | 127 +++++++++++++++++++++
 conn.go         | 182 +++++++++++++++++++++++++++---
 server.go       |   2 +-
 server_test.go  | 293 ++++++++++++++++++++++++++++++++++++++++++++++++
 smtp.go         |   1 +
 5 files changed, 590 insertions(+), 15 deletions(-)
 create mode 100644 chunk_reader.go

diff --git a/chunk_reader.go b/chunk_reader.go
new file mode 100644
index 0000000..3c8a627
--- /dev/null
+++ b/chunk_reader.go
@@ -0,0 +1,127 @@
+package smtp
+
+import (
+	"errors"
+	"io"
+	"io/ioutil"
+)
+
+// ErrDataReset is returned by Reader pased to Data function if client does not
+// send another BDAT command and instead closes connection or issues RSET command.
+var ErrDataReset = errors.New("smtp: message transmission aborted")
+
+// chunkReader implements io.Reader by consuming a sequence of size-limited
+// "chunks" from underlying Reader allowing them to be interleaved with other
+// protocol data.
+//
+// It is caller responsibility to not use Reader while chunk is being processed.
+// This can be enforced by blocking on chunkEnd channel that is used to signal the
+// end of another chunk being reached.
+type chunkReader struct {
+	remainingBytes int
+	conn           io.Reader
+	chunks         chan int
+	// Sent to by abort() to unlock running Read.
+	rset         chan struct{}
+	currentChunk *io.LimitedReader
+
+	chunkEnd chan struct{}
+}
+
+func (cr *chunkReader) addChunk(size int) {
+	cr.chunks <- size
+}
+
+func (cr *chunkReader) end() {
+	close(cr.chunks)
+}
+
+func (cr *chunkReader) abort() {
+	close(cr.rset)
+	close(cr.chunkEnd)
+}
+
+func (cr *chunkReader) discardCurrentChunk() error {
+	if cr.currentChunk == nil {
+		return nil
+	}
+	_, err := io.Copy(ioutil.Discard, cr.currentChunk)
+	return err
+}
+
+func (cr *chunkReader) waitNextChunk() error {
+	select {
+	case <-cr.rset:
+		return ErrDataReset
+	case r, ok := <-cr.chunks:
+		if !ok {
+			// Okay, that's the end.
+			return io.EOF
+		}
+		cr.currentChunk = &io.LimitedReader{R: cr.conn, N: int64(r)}
+		return nil
+	}
+}
+
+func (cr *chunkReader) Read(b []byte) (int, error) {
+	/*
+		Possible states:
+
+		1. We are at the start of next chunk.
+		cr.currentChunk == nil, cr.chunks is not closed.
+
+		2. We are in the middle of chunk.
+		cr.currentchunk != nil
+
+		3. Chunk ended, cr.currentChunk returns io.EOF.
+		Generate an 250 response and wait for the next chunk.
+	*/
+
+	if cr.currentChunk == nil {
+		if err := cr.waitNextChunk(); err != nil {
+			return 0, err
+		}
+	}
+
+	n, err := cr.currentChunk.Read(b)
+	if err == io.EOF {
+		cr.chunkEnd <- struct{}{}
+		cr.currentChunk = nil
+		err = nil
+	}
+
+	if cr.remainingBytes != 0 /* no limit */ {
+		cr.remainingBytes -= n
+		if cr.remainingBytes <= 0 {
+			return 0, ErrDataTooLarge
+		}
+	}
+
+	// Strip CR from slice contents.
+	offset := 0
+	for i, chr := range b {
+		if chr == '\r' {
+			offset += 1
+		}
+		if i+offset >= len(b) {
+			break
+		}
+		b[i] = b[i+offset]
+	}
+
+	// We also likely left garbage in remaining bytes but lets hope backend
+	// code does not assume they are intact.
+	return n - offset, err
+
+}
+
+func newChunkReader(conn io.Reader, maxBytes int) *chunkReader {
+	return &chunkReader{
+		remainingBytes: maxBytes,
+		conn:           conn,
+		chunks:         make(chan int, 1),
+		// buffer to make sure abort() will not block if Read is not running.
+		rset:     make(chan struct{}, 1),
+		chunkEnd: make(chan struct{}, 1),
+	}
+}
diff --git a/conn.go b/conn.go
index 50d748b..f50c33e 100644
--- a/conn.go
+++ b/conn.go
@@ -33,6 +33,9 @@ type Conn struct {
 	session   Session
 	locker    sync.Mutex
 
+	chunkReader *chunkReader
+	bdatError   chan error
+
 	fromReceived bool
 	recipients   []string
 }
@@ -135,6 +138,8 @@ func (c *Conn) handle(cmd string, arg string) {
 		c.WriteResponse(250, EnhancedCode{2, 0, 0}, "Session reset")
 	case "DATA":
 		c.handleData(arg)
+	case "BDAT":
+		c.handleBdat(arg)
 	case "QUIT":
 		c.WriteResponse(221, EnhancedCode{2, 0, 0}, "Goodnight and good luck")
 		c.Close()
@@ -169,9 +174,17 @@ func (c *Conn) SetSession(session Session) {
 }
 
 func (c *Conn) Close() error {
-	if session := c.Session(); session != nil {
-		session.Logout()
-		c.SetSession(nil)
+	c.locker.Lock()
+	defer c.locker.Unlock()
+
+	if c.session != nil {
+		c.session.Logout()
+		c.session = nil
+	}
+
+	if c.chunkReader != nil {
+		c.chunkReader.abort()
+		c.chunkReader = nil
 	}
 
 	return c.conn.Close()
@@ -584,6 +597,11 @@ func (c *Conn) handleData(arg string) {
 		return
 	}
 
+	if c.chunkReader != nil {
+		c.WriteResponse(502, EnhancedCode{5, 5, 1}, "DATA command cannot be used together with BDAT.")
+		return
+	}
+
 	if !c.fromReceived || len(c.recipients) == 0 {
 		c.WriteResponse(502, EnhancedCode{5, 5, 1}, "Missing RCPT TO command.")
 		return
@@ -607,6 +625,126 @@ func (c *Conn) handleData(arg string) {
 
 }
 
+func (c *Conn) handleBdat(arg string) {
+	args := strings.Split(arg, " ")
+	if len(args) == 0 {
+		c.WriteResponse(501, EnhancedCode{5, 5, 4}, "Missing chunk size argument")
+		return
+	}
+
+	if !c.fromReceived || len(c.recipients) == 0 {
+		c.WriteResponse(502, EnhancedCode{5, 5, 1}, "Missing RCPT TO command.")
+		return
+	}
+
+	last := false
+	if len(args) == 2 {
+		if !strings.EqualFold(args[1], "LAST") {
+			c.WriteResponse(501, EnhancedCode{5, 5, 4}, "Unknown BDAT argument")
+			return
+		}
+		last = true
+	}
+
+	// ParseUint instead of Atoi so we will not accept negative values.
+	size, err := strconv.ParseUint(args[0], 10, 32)
+	if err != nil {
+		c.WriteResponse(501, EnhancedCode{5, 5, 4}, "Malformed size argument")
+		return
+	}
+
+	var status *statusCollector
+	if c.server.LMTP {
+		status = c.createStatusCollector()
+	}
+
+	paniced := make(chan struct{}, 1)
+
+	if c.chunkReader == nil {
+		c.chunkReader = newChunkReader(c.text.R, c.server.MaxMessageBytes)
+		c.bdatError = make(chan error, 1)
+
+		chunkReader := c.chunkReader
+
+		go func() {
+			defer func() {
+				if err := recover(); err != nil {
+					c.handlePanic(err, status)
+					paniced <- struct{}{}
+				}
+			}()
+
+			var err error
+			if !c.server.LMTP {
+				err = c.Session().Data(c.chunkReader)
+			} else {
+				lmtpSession, ok := c.Session().(LMTPSession)
+				if !ok {
+					err = c.Session().Data(c.chunkReader)
+					for _, rcpt := range c.recipients {
+						status.SetStatus(rcpt, err)
+					}
+				} else {
+					err = lmtpSession.LMTPData(c.chunkReader, status)
+				}
+			}
+
+			chunkReader.discardCurrentChunk()
+			c.bdatError <- err
+		}()
+	}
+
+	c.chunkReader.addChunk(int(size))
+
+	select {
+	// Wait for Data to consume chunk.
+	case <-c.chunkReader.chunkEnd:
+	case <-paniced:
+		c.WriteResponse(420, EnhancedCode{4, 0, 0}, "Internal server error")
+		c.Close()
+		return
+	case err := <-c.bdatError:
+		// This code path handles early errors that backend may return before
+		// reading _all_ chunks. chunkEnd is not sent to in this case.
+		//
+		// The RFC says the connection is in indeterminate state in this case
+		// so we don't bother resetting it.
+		if c.server.LMTP {
+			status.fillRemaining(err)
+			for i, rcpt := range c.recipients {
+				code, enchCode, msg := toSMTPStatus(<-status.status[i])
+				c.WriteResponse(code, enchCode, "<"+rcpt+"> "+msg)
+			}
+		} else {
+			code, enhancedCode, msg := toSMTPStatus(err)
+			c.WriteResponse(code, enhancedCode, msg)
+		}
+		return
+	}
+
+	if !last {
+		c.WriteResponse(250, EnhancedCode{2, 0, 0}, "Continue")
+		return
+	}
+
+	// This code path handles errors backend may return after processing all
+	// chunks. That is, client had the chance to send all chunks.
+
+	c.chunkReader.end()
+	err = <-c.bdatError
+	if c.server.LMTP {
+		status.fillRemaining(err)
+		for i, rcpt := range c.recipients {
+			code, enchCode, msg := toSMTPStatus(<-status.status[i])
+			c.WriteResponse(code, enchCode, "<"+rcpt+"> "+msg)
+		}
+	} else {
+		code, enhancedCode, msg := toSMTPStatus(err)
+		c.WriteResponse(code, enhancedCode, msg)
+	}
+	c.reset()
+}
+
 type statusCollector struct {
 	// Contains map from recipient to list of channels that are used for that
 	// recipient.
@@ -649,9 +787,7 @@ func (s *statusCollector) SetStatus(rcptTo string, err error) {
 	}
 }
 
-func (c *Conn) handleDataLMTP() {
-	r := newDataReader(c)
-
+func (c *Conn) createStatusCollector() *statusCollector {
 	rcptCounts := make(map[string]int, len(c.recipients))
 
 	status := &statusCollector{
@@ -670,6 +806,26 @@ func (c *Conn) handleDataLMTP() {
 		status.status = append(status.status, status.statusMap[rcpt])
 	}
 
+	return status
+}
+
+func (c *Conn) handlePanic(err interface{}, status *statusCollector) {
+	if status != nil {
+		status.fillRemaining(&SMTPError{
+			Code:         421,
+			EnhancedCode: EnhancedCode{4, 0, 0},
+			Message:      "Internal server error",
+		})
+	}
+
+	stack := debug.Stack()
+	c.server.ErrorLog.Printf("panic serving %v: %v\n%s", c.State().RemoteAddr, err, stack)
+}
+
+func (c *Conn) handleDataLMTP() {
+	r := newDataReader(c)
+
+	status := c.createStatusCollector()
 	done := make(chan bool, 1)
 
 	lmtpSession, ok := c.Session().(LMTPSession)
@@ -685,14 +841,7 @@ func (c *Conn) handleDataLMTP() {
 		go func() {
 			defer func() {
 				if err := recover(); err != nil {
-					status.fillRemaining(&SMTPError{
-						Code:         421,
-						EnhancedCode: EnhancedCode{4, 0, 0},
-						Message:      "Internal server error",
-					})
-
-					stack := debug.Stack()
-					c.server.ErrorLog.Printf("panic serving %v: %v\n%s", c.State().RemoteAddr, err, stack)
+					c.handlePanic(err, status)
 					done <- false
 				}
 			}()
@@ -779,6 +928,11 @@ func (c *Conn) reset() {
 	c.locker.Lock()
 	defer c.locker.Unlock()
 
+	if c.chunkReader != nil {
+		c.chunkReader.abort()
+		c.chunkReader = nil
+	}
+
 	if c.session != nil {
 		c.session.Reset()
 	}
diff --git a/server.go b/server.go
index 927af99..2632d39 100755
--- a/server.go
+++ b/server.go
@@ -78,7 +78,7 @@ func NewServer(be Backend) *Server {
 		Backend:  be,
 		done:     make(chan struct{}, 1),
 		ErrorLog: log.New(os.Stderr, "smtp/server ", log.LstdFlags),
-		caps:     []string{"PIPELINING", "8BITMIME", "ENHANCEDSTATUSCODES"},
+		caps:     []string{"PIPELINING", "8BITMIME", "ENHANCEDSTATUSCODES", "CHUNKING"},
 		auths: map[string]SaslServerFactory{
 			sasl.Plain: func(conn *Conn) sasl.Server {
 				return sasl.NewPlainServer(func(identity, username, password string) error {
diff --git a/server_test.go b/server_test.go
index 57370b9..9a9c3e2 100644
--- a/server_test.go
+++ b/server_test.go
@@ -31,6 +31,15 @@ type backend struct {
 	}
 	lmtpStatusSync chan struct{}
 
+	// Errors returned by Data method.
+	dataErrors chan error
+
+	// Error that will be returned by Data method.
+	dataErr error
+
+	// Read N bytes of message before returning dataErr.
+	dataErrOffset int64
+
 	panicOnMail bool
 	userErr     error
 }
@@ -98,7 +107,23 @@ func (s *session) Rcpt(to string) error {
 }
 
 func (s *session) Data(r io.Reader) error {
+	if s.backend.dataErr != nil {
+
+		if s.backend.dataErrOffset != 0 {
+			io.CopyN(ioutil.Discard, r, s.backend.dataErrOffset)
+		}
+
+		err := s.backend.dataErr
+		if s.backend.dataErrors != nil {
+			s.backend.dataErrors <- err
+		}
+		return err
+	}
+
 	if b, err := ioutil.ReadAll(r); err != nil {
+		if s.backend.dataErrors != nil {
+			s.backend.dataErrors <- err
+		}
 		return err
 	} else {
 		s.msg.Data = b
@@ -107,6 +132,9 @@ func (s *session) Data(r io.Reader) error {
 		} else {
 			s.backend.messages = append(s.backend.messages, s.msg)
 		}
+		if s.backend.dataErrors != nil {
+			s.backend.dataErrors <- nil
+		}
 	}
 	return nil
 }
@@ -752,3 +780,268 @@ func TestStrictServerBad(t *testing.T) {
 		t.Fatal("Invalid MAIL response:", scanner.Text())
 	}
 }
+
+func TestServer_Chunking(t *testing.T) {
+	be, s, c, scanner := testServerAuthenticated(t)
+	defer s.Close()
+	defer c.Close()
+
+	io.WriteString(c, "MAIL FROM:<root@nsa.gov>\r\n")
+	scanner.Scan()
+	if !strings.HasPrefix(scanner.Text(), "250 ") {
+		t.Fatal("Invalid MAIL response:", scanner.Text())
+	}
+
+	io.WriteString(c, "RCPT TO:<root@gchq.gov.uk>\r\n")
+	scanner.Scan()
+	if !strings.HasPrefix(scanner.Text(), "250 ") {
+		t.Fatal("Invalid RCPT response:", scanner.Text())
+	}
+
+	io.WriteString(c, "BDAT 8\r\n")
+	io.WriteString(c, "Hey <3\r\n")
+	scanner.Scan()
+	if !strings.HasPrefix(scanner.Text(), "250 ") {
+		t.Fatal("Invalid BDAT response:", scanner.Text())
+	}
+
+	io.WriteString(c, "BDAT 8 LAST\r\n")
+	io.WriteString(c, "Hey :3\r\n")
+	scanner.Scan()
+	if !strings.HasPrefix(scanner.Text(), "250 ") {
+		t.Fatal("Invalid BDAT response:", scanner.Text())
+	}
+
+	if len(be.messages) != 1 || len(be.anonmsgs) != 0 {
+		t.Fatal("Invalid number of sent messages:", be.messages, be.anonmsgs)
+	}
+
+	msg := be.messages[0]
+	if msg.From != "root@nsa.gov" {
+		t.Fatal("Invalid mail sender:", msg.From)
+	}
+	if len(msg.To) != 1 || msg.To[0] != "root@gchq.gov.uk" {
+		t.Fatal("Invalid mail recipients:", msg.To)
+	}
+	if string(msg.Data) != "Hey <3\nHey :3\n" {
+		t.Fatal("Invalid mail data:", string(msg.Data), msg.Data)
+	}
+}
+
+func TestServer_Chunking_LMTP(t *testing.T) {
+	be, s, c, scanner := testServerAuthenticated(t)
+	s.LMTP = true
+	defer s.Close()
+	defer c.Close()
+
+	io.WriteString(c, "MAIL FROM:<root@nsa.gov>\r\n")
+	scanner.Scan()
+	if !strings.HasPrefix(scanner.Text(), "250 ") {
+		t.Fatal("Invalid MAIL response:", scanner.Text())
+	}
+
+	io.WriteString(c, "RCPT TO:<root@gchq.gov.uk>\r\n")
+	scanner.Scan()
+	if !strings.HasPrefix(scanner.Text(), "250 ") {
+		t.Fatal("Invalid RCPT response:", scanner.Text())
+	}
+	io.WriteString(c, "RCPT TO:<toor@gchq.gov.uk>\r\n")
+	scanner.Scan()
+	if !strings.HasPrefix(scanner.Text(), "250 ") {
+		t.Fatal("Invalid RCPT response:", scanner.Text())
+	}
+
+	io.WriteString(c, "BDAT 8\r\n")
+	io.WriteString(c, "Hey <3\r\n")
+	scanner.Scan()
+	if !strings.HasPrefix(scanner.Text(), "250 ") {
+		t.Fatal("Invalid BDAT response:", scanner.Text())
+	}
+
+	io.WriteString(c, "BDAT 8 LAST\r\n")
+	io.WriteString(c, "Hey :3\r\n")
+	scanner.Scan()
+	if !strings.HasPrefix(scanner.Text(), "250 ") {
+		t.Fatal("Invalid BDAT response:", scanner.Text())
+	}
+	scanner.Scan()
+	if !strings.HasPrefix(scanner.Text(), "250 ") {
+		t.Fatal("Invalid BDAT response:", scanner.Text())
+	}
+
+	if len(be.messages) != 1 || len(be.anonmsgs) != 0 {
+		t.Fatal("Invalid number of sent messages:", be.messages, be.anonmsgs)
+	}
+
+	msg := be.messages[0]
+	if msg.From != "root@nsa.gov" {
+		t.Fatal("Invalid mail sender:", msg.From)
+	}
+	if string(msg.Data) != "Hey <3\nHey :3\n" {
+		t.Fatal("Invalid mail data:", string(msg.Data), msg.Data)
+	}
+}
+
+func TestServer_Chunking_Reset(t *testing.T) {
+	be, s, c, scanner := testServerAuthenticated(t)
+	defer s.Close()
+	defer c.Close()
+	be.dataErrors = make(chan error, 10)
+
+	io.WriteString(c, "MAIL FROM:<root@nsa.gov>\r\n")
+	scanner.Scan()
+	if !strings.HasPrefix(scanner.Text(), "250 ") {
+		t.Fatal("Invalid MAIL response:", scanner.Text())
+	}
+
+	io.WriteString(c, "RCPT TO:<root@gchq.gov.uk>\r\n")
+	scanner.Scan()
+	if !strings.HasPrefix(scanner.Text(), "250 ") {
+		t.Fatal("Invalid RCPT response:", scanner.Text())
+	}
+
+	io.WriteString(c, "BDAT 8\r\n")
+	io.WriteString(c, "Hey <3\r\n")
+	scanner.Scan()
+	if !strings.HasPrefix(scanner.Text(), "250 ") {
+		t.Fatal("Invalid BDAT response:", scanner.Text())
+	}
+
+	// Client changed its mind... Note, in this case Data method error is discarded and not returned to the cilent.
+	io.WriteString(c, "RSET\r\n")
+	scanner.Scan()
+	if !strings.HasPrefix(scanner.Text(), "250 ") {
+		t.Fatal("Invalid BDAT response:", scanner.Text())
+	}
+
+	if err := <-be.dataErrors; err != smtp.ErrDataReset {
+		t.Fatal("Backend received a different error:", err)
+	}
+}
+
+func TestServer_Chunking_ClosedInTheMiddle(t *testing.T) {
+	be, s, c, scanner := testServerAuthenticated(t)
+	defer s.Close()
+	defer c.Close()
+	be.dataErrors = make(chan error, 10)
+
+	io.WriteString(c, "MAIL FROM:<root@nsa.gov>\r\n")
+	scanner.Scan()
+	if !strings.HasPrefix(scanner.Text(), "250 ") {
+		t.Fatal("Invalid MAIL response:", scanner.Text())
+	}
+
+	io.WriteString(c, "RCPT TO:<root@gchq.gov.uk>\r\n")
+	scanner.Scan()
+	if !strings.HasPrefix(scanner.Text(), "250 ") {
+		t.Fatal("Invalid RCPT response:", scanner.Text())
+	}
+
+	io.WriteString(c, "BDAT 8\r\n")
+	io.WriteString(c, "Hey <")
+
+	// Bye!
+	c.Close()
+
+	if err := <-be.dataErrors; err != smtp.ErrDataReset {
+		t.Fatal("Backend received a different error:", err)
+	}
+}
+
+func TestServer_Chunking_EarlyError(t *testing.T) {
+	be, s, c, scanner := testServerAuthenticated(t)
+	defer s.Close()
+	defer c.Close()
+
+	be.dataErr = &smtp.SMTPError{
+		Code:         555,
+		EnhancedCode: smtp.EnhancedCode{5, 0, 0},
+		Message:      "I failed",
+	}
+
+	io.WriteString(c, "MAIL FROM:<root@nsa.gov>\r\n")
+	scanner.Scan()
+	if !strings.HasPrefix(scanner.Text(), "250 ") {
+		t.Fatal("Invalid MAIL response:", scanner.Text())
+	}
+
+	io.WriteString(c, "RCPT TO:<root@gchq.gov.uk>\r\n")
+	scanner.Scan()
+	if !strings.HasPrefix(scanner.Text(), "250 ") {
+		t.Fatal("Invalid RCPT response:", scanner.Text())
+	}
+
+	io.WriteString(c, "BDAT 8\r\n")
+	io.WriteString(c, "Hey <3\r\n")
+	scanner.Scan()
+	if !strings.HasPrefix(scanner.Text(), "555 5.0.0 I failed") {
+		t.Fatal("Invalid BDAT response:", scanner.Text())
+	}
+}
+
+func TestServer_Chunking_EarlyErrorDuringChunk(t *testing.T) {
+	be, s, c, scanner := testServerAuthenticated(t)
+	defer s.Close()
+	defer c.Close()
+
+	be.dataErr = &smtp.SMTPError{
+		Code:         555,
+		EnhancedCode: smtp.EnhancedCode{5, 0, 0},
+		Message:      "I failed",
+	}
+	be.dataErrOffset = 5
+
+	io.WriteString(c, "MAIL FROM:<root@nsa.gov>\r\n")
+	scanner.Scan()
+	if !strings.HasPrefix(scanner.Text(), "250 ") {
+		t.Fatal("Invalid MAIL response:", scanner.Text())
+	}
+
+	io.WriteString(c, "RCPT TO:<root@gchq.gov.uk>\r\n")
+	scanner.Scan()
+	if !strings.HasPrefix(scanner.Text(), "250 ") {
+		t.Fatal("Invalid RCPT response:", scanner.Text())
+	}
+
+	io.WriteString(c, "BDAT 8\r\n")
+	io.WriteString(c, "Hey <3\r\n")
+	scanner.Scan()
+	if !strings.HasPrefix(scanner.Text(), "555 5.0.0 I failed") {
+		t.Fatal("Invalid BDAT response:", scanner.Text())
+	}
+
+	// See that command stream state is not corrupted e.g. server is still not
+	// waiting for remaining chunk octets.
+	io.WriteString(c, "NOOP\r\n")
+	scanner.Scan()
+	if !strings.HasPrefix(scanner.Text(), "250 ") {
+		t.Fatal("Invalid RCPT response:", scanner.Text())
+	}
+}
+
+func TestServer_Chunking_tooLongMessage(t *testing.T) {
+	be, s, c, scanner := testServerAuthenticated(t)
+	defer s.Close()
+
+	s.MaxMessageBytes = 50
+
+	io.WriteString(c, "MAIL FROM:<root@nsa.gov>\r\n")
+	scanner.Scan()
+	io.WriteString(c, "RCPT TO:<root@gchq.gov.uk>\r\n")
+	scanner.Scan()
+	io.WriteString(c, "BDAT 30\r\n")
+	io.WriteString(c, "This is a very long message.\r\n")
+	scanner.Scan()
+
+	io.WriteString(c, "BDAT 96 LAST\r\n")
+	io.WriteString(c, "Much longer than you can possibly imagine.\r\n")
+	io.WriteString(c, "And much longer than the server's MaxMessageBytes.\r\n")
+	scanner.Scan()
+	if !strings.HasPrefix(scanner.Text(), "552 ") {
+		t.Fatal("Invalid DATA response, expected an error but got:", scanner.Text())
+	}
+
+	if len(be.messages) != 0 || len(be.anonmsgs) != 0 {
+		t.Fatal("Invalid number of sent messages:", be.messages, be.anonmsgs)
+	}
+}
diff --git a/smtp.go b/smtp.go
index e5045a7..b293779 100644
--- a/smtp.go
+++ b/smtp.go
@@ -8,6 +8,7 @@
 //	ENHANCEDSTATUSCODES	RFC 2034
 //  SMTPUTF8		RFC 6531
 //  REQUIRETLS		draft-ietf-uta-smtp-require-tls-09
+//  CHUNKING		RFC 3030
 //
 // LMTP (RFC 2033) is also supported.
 //

From 23052006b7f03df002e5c8c88fff423f24d77d70 Mon Sep 17 00:00:00 2001
From: "fox.cpp" <fox.cpp@disroot.org>
Date: Fri, 12 Jun 2020 00:14:21 +0300
Subject: [PATCH 2/6] server: Add BINARYMIME support

---
 conn.go   | 10 +++++++++-
 server.go |  2 +-
 smtp.go   |  1 +
 3 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/conn.go b/conn.go
index f50c33e..c1d9ac5 100644
--- a/conn.go
+++ b/conn.go
@@ -35,6 +35,7 @@ type Conn struct {
 
 	chunkReader *chunkReader
 	bdatError   chan error
+	binarymime  bool
 
 	fromReceived bool
 	recipients   []string
@@ -348,6 +349,8 @@ func (c *Conn) handleMail(arg string) {
 				opts.RequireTLS = true
 			case "BODY":
 				switch value {
+				case "BINARYMIME":
+					c.binarymime = true
 				case "7BIT", "8BITMIME":
 				default:
 					c.WriteResponse(500, EnhancedCode{5, 5, 4}, "Unknown BODY value")
@@ -598,7 +601,11 @@ func (c *Conn) handleData(arg string) {
 	}
 
 	if c.chunkReader != nil {
-		c.WriteResponse(502, EnhancedCode{5, 5, 1}, "DATA command cannot be used together with BDAT.")
+		c.WriteResponse(502, EnhancedCode{5, 5, 1}, "DATA command cannot be used together with BDAT")
+		return
+	}
+	if c.binarymime {
+		c.WriteResponse(502, EnhancedCode{5, 5, 1}, "DATA command cannot be used with BODY=BINARYMIME")
 		return
 	}
 
@@ -938,4 +945,5 @@ func (c *Conn) reset() {
 	}
 	c.fromReceived = false
 	c.recipients = nil
+	c.binarymime = false
 }
diff --git a/server.go b/server.go
index 2632d39..762974e 100755
--- a/server.go
+++ b/server.go
@@ -78,7 +78,7 @@ func NewServer(be Backend) *Server {
 		Backend:  be,
 		done:     make(chan struct{}, 1),
 		ErrorLog: log.New(os.Stderr, "smtp/server ", log.LstdFlags),
-		caps:     []string{"PIPELINING", "8BITMIME", "ENHANCEDSTATUSCODES", "CHUNKING"},
+		caps:     []string{"PIPELINING", "8BITMIME", "ENHANCEDSTATUSCODES", "CHUNKING", "BINARYMIME"},
 		auths: map[string]SaslServerFactory{
 			sasl.Plain: func(conn *Conn) sasl.Server {
 				return sasl.NewPlainServer(func(identity, username, password string) error {
diff --git a/smtp.go b/smtp.go
index b293779..77ef835 100644
--- a/smtp.go
+++ b/smtp.go
@@ -9,6 +9,7 @@
 //  SMTPUTF8		RFC 6531
 //  REQUIRETLS		draft-ietf-uta-smtp-require-tls-09
 //  CHUNKING		RFC 3030
+//  BINARYMIME		RFC 3030
 //
 // LMTP (RFC 2033) is also supported.
 //

From 066ac01c6ef5ddd7ef4280216a0181141e9c7cd2 Mon Sep 17 00:00:00 2001
From: "fox.cpp" <fox.cpp@disroot.org>
Date: Fri, 12 Jun 2020 00:20:02 +0300
Subject: [PATCH 3/6] server: Disallow RCPT TO/MAIL FROM while body chunks are
 being sent

---
 conn.go | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/conn.go b/conn.go
index c1d9ac5..7fe68e5 100644
--- a/conn.go
+++ b/conn.go
@@ -276,6 +276,11 @@ func (c *Conn) handleMail(arg string) {
 		return
 	}
 
+	if c.chunkReader != nil {
+		c.WriteResponse(503, EnhancedCode{5, 5, 1}, "RCPT TO not allowed while BDAT is in progress")
+		return
+	}
+
 	if c.Session() == nil {
 		state := c.State()
 		session, err := c.server.Backend.AnonymousLogin(&state)
@@ -458,6 +463,11 @@ func (c *Conn) handleRcpt(arg string) {
 	// TODO: This trim is probably too forgiving
 	recipient := strings.Trim(arg[3:], "<> ")
 
+	if c.chunkReader != nil {
+		c.WriteResponse(503, EnhancedCode{5, 5, 1}, "RCPT TO not allowed while BDAT is in progress")
+		return
+	}
+
 	if c.server.MaxRecipients > 0 && len(c.recipients) >= c.server.MaxRecipients {
 		c.WriteResponse(552, EnhancedCode{5, 5, 3}, fmt.Sprintf("Maximum limit of %v recipients reached", c.server.MaxRecipients))
 		return

From c7cff6ed3168631b1fef77ba6eab5342b0c73d03 Mon Sep 17 00:00:00 2001
From: "fox.cpp" <fox.cpp@disroot.org>
Date: Fri, 12 Jun 2020 22:14:34 +0300
Subject: [PATCH 4/6] Clarify data flow and rename chunkReader.conn

---
 chunk_reader.go |  9 +++++----
 conn.go         | 19 +++++++++++++------
 2 files changed, 18 insertions(+), 10 deletions(-)

diff --git a/chunk_reader.go b/chunk_reader.go
index 3c8a627..88d3774 100644
--- a/chunk_reader.go
+++ b/chunk_reader.go
@@ -19,7 +19,7 @@ var ErrDataReset = errors.New("smtp: message transmission aborted")
 // end of another chunk being reached.
 type chunkReader struct {
 	remainingBytes int
-	conn           io.Reader
+	r              io.Reader
 	chunks         chan int
 	// Sent to by abort() to unlock running Read.
 	rset         chan struct{}
@@ -58,7 +58,7 @@ func (cr *chunkReader) waitNextChunk() error {
 			// Okay, that's the end.
 			return io.EOF
 		}
-		cr.currentChunk = &io.LimitedReader{R: cr.conn, N: int64(r)}
+		cr.currentChunk = &io.LimitedReader{R: cr.r, N: int64(r)}
 		return nil
 	}
 }
@@ -74,7 +74,8 @@ func (cr *chunkReader) Read(b []byte) (int, error) {
 		cr.currentchunk != nil
 
 		3. Chunk ended, cr.currentChunk returns io.EOF.
-		Generate an 250 response and wait for the next chunk.
+		Signal connection handling code to send 250 (using chunkEnd)
+		and wait for the next chunk to arrive.
 	*/
 
 	if cr.currentChunk == nil {
@@ -118,7 +119,7 @@ func (cr *chunkReader) Read(b []byte) (int, error) {
 func newChunkReader(conn io.Reader, maxBytes int) *chunkReader {
 	return &chunkReader{
 		remainingBytes: maxBytes,
-		conn:           conn,
+		r:              conn,
 		chunks:         make(chan int, 1),
 		// buffer to make sure abort() will not block if Read is not running.
 		rset:     make(chan struct{}, 1),
diff --git a/conn.go b/conn.go
index 7fe68e5..485d0d2 100644
--- a/conn.go
+++ b/conn.go
@@ -681,6 +681,11 @@ func (c *Conn) handleBdat(arg string) {
 		c.chunkReader = newChunkReader(c.text.R, c.server.MaxMessageBytes)
 		c.bdatError = make(chan error, 1)
 
+		// If chunkReader.abort() is called from somewhere else (e.g.
+		// connection is being forcibly closed) then it will also reset
+		// chunkReader to nil. However, it is important to keep
+		// the original instance so we can call discardCurrentChunk.
+		// This is also the case for RSET command.
 		chunkReader := c.chunkReader
 
 		go func() {
@@ -714,8 +719,11 @@ func (c *Conn) handleBdat(arg string) {
 	c.chunkReader.addChunk(int(size))
 
 	select {
-	// Wait for Data to consume chunk.
 	case <-c.chunkReader.chunkEnd:
+		if !last {
+			c.WriteResponse(250, EnhancedCode{2, 0, 0}, "Continue")
+			return
+		}
 	case <-paniced:
 		c.WriteResponse(420, EnhancedCode{4, 0, 0}, "Internal server error")
 		c.Close()
@@ -739,15 +747,14 @@ func (c *Conn) handleBdat(arg string) {
 		return
 	}
 
-	if !last {
-		c.WriteResponse(250, EnhancedCode{2, 0, 0}, "Continue")
-		return
-	}
-
 	// This code path handles errors backend may return after processing all
 	// chunks. That is, client had the chance to send all chunks.
 
+	// This unlocks Read that Data method may still run, making it return EOF
+	// and allowing Data to complete if it reads all body.
 	c.chunkReader.end()
+
+	// We then wait for it complete and return error to us.
 	err = <-c.bdatError
 	if c.server.LMTP {
 		status.fillRemaining(err)

From b085f82c53cd98e3ce573713e04381d0986ba8a0 Mon Sep 17 00:00:00 2001
From: "fox.cpp" <fox.cpp@disroot.org>
Date: Sat, 13 Jun 2020 21:40:40 +0300
Subject: [PATCH 5/6] server: Do not apply CRLF->LF conversion if
 BODY=BINARYMIME is used

---
 chunk_reader.go |  8 +++++++-
 conn.go         |  2 +-
 server_test.go  | 47 +++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 55 insertions(+), 2 deletions(-)

diff --git a/chunk_reader.go b/chunk_reader.go
index 88d3774..8fec1d8 100644
--- a/chunk_reader.go
+++ b/chunk_reader.go
@@ -19,6 +19,7 @@ var ErrDataReset = errors.New("smtp: message transmission aborted")
 // end of another chunk being reached.
 type chunkReader struct {
 	remainingBytes int
+	stripCR        bool
 	r              io.Reader
 	chunks         chan int
 	// Sent to by abort() to unlock running Read.
@@ -98,6 +99,10 @@ func (cr *chunkReader) Read(b []byte) (int, error) {
 		}
 	}
 
+	if !cr.stripCR {
+		return n, err
+	}
+
 	// Strip CR from slice contents.
 	offset := 0
 	for i, chr := range b {
@@ -116,8 +121,9 @@ func (cr *chunkReader) Read(b []byte) (int, error) {
 
 }
 
-func newChunkReader(conn io.Reader, maxBytes int) *chunkReader {
+func newChunkReader(conn io.Reader, maxBytes int, stripCR bool) *chunkReader {
 	return &chunkReader{
+		stripCR:        stripCR,
 		remainingBytes: maxBytes,
 		r:              conn,
 		chunks:         make(chan int, 1),
diff --git a/conn.go b/conn.go
index 485d0d2..ca521e5 100644
--- a/conn.go
+++ b/conn.go
@@ -678,7 +678,7 @@ func (c *Conn) handleBdat(arg string) {
 	paniced := make(chan struct{}, 1)
 
 	if c.chunkReader == nil {
-		c.chunkReader = newChunkReader(c.text.R, c.server.MaxMessageBytes)
+		c.chunkReader = newChunkReader(c.text.R, c.server.MaxMessageBytes, !c.binarymime)
 		c.bdatError = make(chan error, 1)
 
 		// If chunkReader.abort() is called from somewhere else (e.g.
diff --git a/server_test.go b/server_test.go
index 9a9c3e2..89ae1ea 100644
--- a/server_test.go
+++ b/server_test.go
@@ -828,6 +828,53 @@ func TestServer_Chunking(t *testing.T) {
 	}
 }
 
+func TestServer_Chunking_Binarymime(t *testing.T) {
+	be, s, c, scanner := testServerAuthenticated(t)
+	defer s.Close()
+	defer c.Close()
+
+	io.WriteString(c, "MAIL FROM:<root@nsa.gov> BODY=BINARYMIME\r\n")
+	scanner.Scan()
+	if !strings.HasPrefix(scanner.Text(), "250 ") {
+		t.Fatal("Invalid MAIL response:", scanner.Text())
+	}
+
+	io.WriteString(c, "RCPT TO:<root@gchq.gov.uk>\r\n")
+	scanner.Scan()
+	if !strings.HasPrefix(scanner.Text(), "250 ") {
+		t.Fatal("Invalid RCPT response:", scanner.Text())
+	}
+
+	io.WriteString(c, "BDAT 8\r\n")
+	io.WriteString(c, "Hey <3\r\n")
+	scanner.Scan()
+	if !strings.HasPrefix(scanner.Text(), "250 ") {
+		t.Fatal("Invalid BDAT response:", scanner.Text())
+	}
+
+	io.WriteString(c, "BDAT 8 LAST\r\n")
+	io.WriteString(c, "Hey :3\r\n")
+	scanner.Scan()
+	if !strings.HasPrefix(scanner.Text(), "250 ") {
+		t.Fatal("Invalid BDAT response:", scanner.Text())
+	}
+
+	if len(be.messages) != 1 || len(be.anonmsgs) != 0 {
+		t.Fatal("Invalid number of sent messages:", be.messages, be.anonmsgs)
+	}
+
+	msg := be.messages[0]
+	if msg.From != "root@nsa.gov" {
+		t.Fatal("Invalid mail sender:", msg.From)
+	}
+	if len(msg.To) != 1 || msg.To[0] != "root@gchq.gov.uk" {
+		t.Fatal("Invalid mail recipients:", msg.To)
+	}
+	if string(msg.Data) != "Hey <3\r\nHey :3\r\n" {
+		t.Fatal("Invalid mail data:", string(msg.Data), msg.Data)
+	}
+}
+
 func TestServer_Chunking_LMTP(t *testing.T) {
 	be, s, c, scanner := testServerAuthenticated(t)
 	s.LMTP = true

From e797215553b133513fa3c7bd3a43bd9f64699621 Mon Sep 17 00:00:00 2001
From: "fox.cpp" <fox.cpp@disroot.org>
Date: Sat, 13 Jun 2020 21:41:35 +0300
Subject: [PATCH 6/6] Remove misleading comment

---
 chunk_reader.go | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/chunk_reader.go b/chunk_reader.go
index 8fec1d8..1c6258b 100644
--- a/chunk_reader.go
+++ b/chunk_reader.go
@@ -127,8 +127,7 @@ func newChunkReader(conn io.Reader, maxBytes int, stripCR bool) *chunkReader {
 		remainingBytes: maxBytes,
 		r:              conn,
 		chunks:         make(chan int, 1),
-		// buffer to make sure abort() will not block if Read is not running.
-		rset:     make(chan struct{}, 1),
-		chunkEnd: make(chan struct{}, 1),
+		rset:           make(chan struct{}),
+		chunkEnd:       make(chan struct{}, 1),
 	}
 }