Skip to content

Commit

Permalink
Handle SMTP connection errors
Browse files Browse the repository at this point in the history
  • Loading branch information
kayrus committed Jul 6, 2022
1 parent 14bc7fa commit 71f1376
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 23 deletions.
27 changes: 17 additions & 10 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ import (
const errThreshold = 3

type ConnectionState struct {
Hostname string
LocalAddr net.Addr
RemoteAddr net.Addr
TLS tls.ConnectionState
ServerDomain string
Hostname string
LocalAddr net.Addr
RemoteAddr net.Addr
TLS tls.ConnectionState
}

type Conn struct {
Expand Down Expand Up @@ -101,12 +102,14 @@ func (c *Conn) handle(cmd string, arg string) {
c.Close()

stack := debug.Stack()
c.server.ErrorLog.Printf("panic serving %v: %v\n%s", c.State().RemoteAddr, err, stack)
c.server.ErrorLog.Printf(c, "panic serving %v: %w\n%s", c.State().RemoteAddr, err, stack)
}
}()

if cmd == "" {
c.protocolError(500, EnhancedCode{5, 5, 2}, "Error: bad syntax")
msg := "Error: bad syntax"
c.server.ErrorLog.Printf(c, "%s", msg)
c.protocolError(500, EnhancedCode{5, 5, 2}, msg)
return
}

Expand Down Expand Up @@ -147,14 +150,17 @@ func (c *Conn) handle(cmd string, arg string) {
c.Close()
case "AUTH":
if c.server.AuthDisabled {
c.protocolError(500, EnhancedCode{5, 5, 2}, "Syntax error, AUTH command unrecognized")
msg := "Syntax error, AUTH command unrecognized"
c.server.ErrorLog.Printf(c, "%s", msg)
c.protocolError(500, EnhancedCode{5, 5, 2}, msg)
} else {
c.handleAuth(arg)
}
case "STARTTLS":
c.handleStartTLS()
default:
msg := fmt.Sprintf("Syntax errors, %v command unrecognized", cmd)
c.server.ErrorLog.Printf(c, "%s", msg)
c.protocolError(500, EnhancedCode{5, 5, 2}, msg)
}
}
Expand Down Expand Up @@ -210,6 +216,7 @@ func (c *Conn) State() ConnectionState {
state.TLS = tlsState
}

state.ServerDomain = c.server.Domain
state.Hostname = c.helo
state.LocalAddr = c.conn.LocalAddr()
state.RemoteAddr = c.conn.RemoteAddr()
Expand Down Expand Up @@ -616,7 +623,7 @@ func (c *Conn) handleStartTLS() {
if err == io.EOF {
return
}
c.server.ErrorLog.Printf("TLS handshake error for %s: %v", c.conn.RemoteAddr(), err)
c.server.ErrorLog.Printf(c, "TLS handshake error for %s: %w", c.State().RemoteAddr, err)
c.WriteResponse(550, EnhancedCode{5, 0, 0}, "Handshake error")
return
}
Expand Down Expand Up @@ -822,7 +829,7 @@ func (c *Conn) handlePanic(err interface{}, status *statusCollector) {
}

stack := debug.Stack()
c.server.ErrorLog.Printf("panic serving %v: %v\n%s", c.State().RemoteAddr, err, stack)
c.server.ErrorLog.Printf(c, "panic serving %v: %w\n%s", c.State().RemoteAddr, err, stack)
}

func (c *Conn) createStatusCollector() *statusCollector {
Expand Down Expand Up @@ -915,7 +922,7 @@ func (c *Conn) handleDataLMTP() {
})

stack := debug.Stack()
c.server.ErrorLog.Printf("panic serving %v: %v\n%s", c.State().RemoteAddr, err, stack)
c.server.ErrorLog.Printf(c, "panic serving %v: %w\n%s", c.State().RemoteAddr, err, stack)
done <- false
}
}()
Expand Down
58 changes: 47 additions & 11 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ package smtp
import (
"crypto/tls"
"errors"
"fmt"
"io"
"log"
"net"
"os"
"sync"
"syscall"
"time"

"github.com/emersion/go-sasl"
Expand All @@ -20,8 +22,20 @@ type SaslServerFactory func(conn *Conn) sasl.Server

// Logger interface is used by Server to report unexpected internal errors.
type Logger interface {
Printf(format string, v ...interface{})
Println(v ...interface{})
Printf(c *Conn, format string, v ...interface{})
Println(c *Conn, v ...interface{})
}

type DefaultLogger struct {
*log.Logger
}

func (l *DefaultLogger) Printf(_ *Conn, format string, v ...interface{}) {
l.Logger.Println(fmt.Errorf(format, v...))
}

func (l *DefaultLogger) Println(_ *Conn, v ...interface{}) {
l.Logger.Println(v...)
}

// A SMTP server.
Expand Down Expand Up @@ -81,7 +95,7 @@ func NewServer(be Backend) *Server {

Backend: be,
done: make(chan struct{}, 1),
ErrorLog: log.New(os.Stderr, "smtp/server ", log.LstdFlags),
ErrorLog: &DefaultLogger{log.New(os.Stderr, "smtp/server ", log.LstdFlags)},
caps: []string{"PIPELINING", "8BITMIME", "ENHANCEDSTATUSCODES", "CHUNKING"},
auths: map[string]SaslServerFactory{
sasl.Plain: func(conn *Conn) sasl.Server {
Expand Down Expand Up @@ -131,16 +145,17 @@ func (s *Server) Serve(l net.Listener) error {
if max := 1 * time.Second; tempDelay > max {
tempDelay = max
}
s.ErrorLog.Printf("accept error: %s; retrying in %s", err, tempDelay)
s.ErrorLog.Printf(nil, "accept error: %w; retrying in %s", err, tempDelay)
time.Sleep(tempDelay)
continue
}
return err
}
go func() {
err := s.handleConn(newConn(c, s))
conn := newConn(c, s)
err := s.handleConn(conn)
if err != nil {
s.ErrorLog.Printf("handler error: %s", err)
s.ErrorLog.Printf(conn, "handler error: %w", err)
}
}()
}
Expand Down Expand Up @@ -174,7 +189,7 @@ func (s *Server) handleConn(c *Conn) error {
// preserve remote address from PROXY protocol
err.Addr = c.conn.RemoteAddr()
}
s.ErrorLog.Printf("TLS handshake error: %w", err)
s.ErrorLog.Printf(c, "TLS handshake error: %w", err)
return err
}
}
Expand All @@ -186,7 +201,9 @@ func (s *Server) handleConn(c *Conn) error {
if err == nil {
cmd, arg, err := parseCmd(line)
if err != nil {
c.protocolError(501, EnhancedCode{5, 5, 2}, "Bad command")
msg := "Bad command"
s.ErrorLog.Printf(c, "%s: %w", msg, err)
c.protocolError(501, EnhancedCode{5, 5, 2}, msg)
continue
}

Expand All @@ -195,17 +212,36 @@ func (s *Server) handleConn(c *Conn) error {
if err == io.EOF {
return nil
}

if err == ErrTooLongLine {
c.WriteResponse(500, EnhancedCode{5, 4, 0}, "Too long line, closing connection")
msg := "Too long line, closing connection"
s.ErrorLog.Printf(c, "%s: %w", msg, err)
c.WriteResponse(500, EnhancedCode{5, 4, 0}, msg)
return nil
}

if err, ok := err.(*net.OpError); ok {
if err.Err == net.ErrClosed {
return nil
}
if errors.Is(err, syscall.ECONNRESET) && c.Session() == nil {
// healthcheck monitor
return nil
}
// preserve remote address from PROXY protocol
err.Addr = c.conn.RemoteAddr()
}

if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
c.WriteResponse(221, EnhancedCode{2, 4, 2}, "Idle timeout, bye bye")
msg := "Idle timeout, bye bye"
s.ErrorLog.Printf(c, "%s: %w", msg, err)
c.WriteResponse(221, EnhancedCode{2, 4, 2}, msg)
return nil
}

c.WriteResponse(221, EnhancedCode{2, 4, 0}, "Connection error, sorry")
msg := "Connection error, sorry"
s.ErrorLog.Printf(c, "%s: %w", msg, err)
c.WriteResponse(221, EnhancedCode{2, 4, 0}, msg)
return err
}
}
Expand Down
4 changes: 2 additions & 2 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ func TestServerAcceptErrorHandling(t *testing.T) {
s := smtp.NewServer(be)
s.Domain = "localhost"
s.AllowInsecureAuth = true
s.ErrorLog = log.New(errorLog, "", 0)
s.ErrorLog = &smtp.DefaultLogger{log.New(errorLog, "", 0)}

l := newFailingListener()
var serveError error
Expand Down Expand Up @@ -435,7 +435,7 @@ func TestServerPanicRecover(t *testing.T) {

s.Backend.(*backend).panicOnMail = true
// Don't log panic in tests to not confuse people who run 'go test'.
s.ErrorLog = log.New(ioutil.Discard, "", 0)
s.ErrorLog = &smtp.DefaultLogger{log.New(ioutil.Discard, "", 0)}

io.WriteString(c, "MAIL FROM:<[email protected]>\r\n")
scanner.Scan()
Expand Down

0 comments on commit 71f1376

Please sign in to comment.