diff --git a/conn.go b/conn.go index d60d7c0..317fbaf 100644 --- a/conn.go +++ b/conn.go @@ -21,10 +21,11 @@ import ( const errThreshold = 3 type ConnectionState struct { - Hostname string - LocalAddr net.Addr - RemoteAddr net.Addr - TLS tls.ConnectionState + ServerDomain string + Hostname string + LocalAddr net.Addr + RemoteAddr net.Addr + TLS tls.ConnectionState } type Conn struct { @@ -101,12 +102,14 @@ func (c *Conn) handle(cmd string, arg string) { c.Close() stack := debug.Stack() - c.server.ErrorLog.Printf("panic serving %v: %v\n%s", c.State().RemoteAddr, err, stack) + c.server.ErrorLog.Printf(c, "panic serving %v: %w\n%s", c.State().RemoteAddr, err, stack) } }() if cmd == "" { - c.protocolError(500, EnhancedCode{5, 5, 2}, "Error: bad syntax") + msg := "Error: bad syntax" + c.server.ErrorLog.Printf(c, "%s", msg) + c.protocolError(500, EnhancedCode{5, 5, 2}, msg) return } @@ -147,7 +150,9 @@ func (c *Conn) handle(cmd string, arg string) { c.Close() case "AUTH": if c.server.AuthDisabled { - c.protocolError(500, EnhancedCode{5, 5, 2}, "Syntax error, AUTH command unrecognized") + msg := "Syntax error, AUTH command unrecognized" + c.server.ErrorLog.Printf(c, "%s", msg) + c.protocolError(500, EnhancedCode{5, 5, 2}, msg) } else { c.handleAuth(arg) } @@ -155,6 +160,7 @@ func (c *Conn) handle(cmd string, arg string) { c.handleStartTLS() default: msg := fmt.Sprintf("Syntax errors, %v command unrecognized", cmd) + c.server.ErrorLog.Printf(c, "%s", msg) c.protocolError(500, EnhancedCode{5, 5, 2}, msg) } } @@ -210,6 +216,7 @@ func (c *Conn) State() ConnectionState { state.TLS = tlsState } + state.ServerDomain = c.server.Domain state.Hostname = c.helo state.LocalAddr = c.conn.LocalAddr() state.RemoteAddr = c.conn.RemoteAddr() @@ -616,7 +623,7 @@ func (c *Conn) handleStartTLS() { if err == io.EOF { return } - c.server.ErrorLog.Printf("TLS handshake error for %s: %v", c.conn.RemoteAddr(), err) + c.server.ErrorLog.Printf(c, "TLS handshake error for %s: %w", c.State().RemoteAddr, err) c.WriteResponse(550, EnhancedCode{5, 0, 0}, "Handshake error") return } @@ -822,7 +829,7 @@ func (c *Conn) handlePanic(err interface{}, status *statusCollector) { } stack := debug.Stack() - c.server.ErrorLog.Printf("panic serving %v: %v\n%s", c.State().RemoteAddr, err, stack) + c.server.ErrorLog.Printf(c, "panic serving %v: %w\n%s", c.State().RemoteAddr, err, stack) } func (c *Conn) createStatusCollector() *statusCollector { @@ -915,7 +922,7 @@ func (c *Conn) handleDataLMTP() { }) stack := debug.Stack() - c.server.ErrorLog.Printf("panic serving %v: %v\n%s", c.State().RemoteAddr, err, stack) + c.server.ErrorLog.Printf(c, "panic serving %v: %w\n%s", c.State().RemoteAddr, err, stack) done <- false } }() diff --git a/server.go b/server.go index 2ebe019..1378cd3 100644 --- a/server.go +++ b/server.go @@ -3,11 +3,13 @@ package smtp import ( "crypto/tls" "errors" + "fmt" "io" "log" "net" "os" "sync" + "syscall" "time" "github.com/emersion/go-sasl" @@ -20,8 +22,20 @@ type SaslServerFactory func(conn *Conn) sasl.Server // Logger interface is used by Server to report unexpected internal errors. type Logger interface { - Printf(format string, v ...interface{}) - Println(v ...interface{}) + Printf(c *Conn, format string, v ...interface{}) + Println(c *Conn, v ...interface{}) +} + +type DefaultLogger struct { + *log.Logger +} + +func (l *DefaultLogger) Printf(_ *Conn, format string, v ...interface{}) { + l.Logger.Println(fmt.Errorf(format, v...)) +} + +func (l *DefaultLogger) Println(_ *Conn, v ...interface{}) { + l.Logger.Println(v...) } // A SMTP server. @@ -81,7 +95,7 @@ func NewServer(be Backend) *Server { Backend: be, done: make(chan struct{}, 1), - ErrorLog: log.New(os.Stderr, "smtp/server ", log.LstdFlags), + ErrorLog: &DefaultLogger{log.New(os.Stderr, "smtp/server ", log.LstdFlags)}, caps: []string{"PIPELINING", "8BITMIME", "ENHANCEDSTATUSCODES", "CHUNKING"}, auths: map[string]SaslServerFactory{ sasl.Plain: func(conn *Conn) sasl.Server { @@ -131,16 +145,17 @@ func (s *Server) Serve(l net.Listener) error { if max := 1 * time.Second; tempDelay > max { tempDelay = max } - s.ErrorLog.Printf("accept error: %s; retrying in %s", err, tempDelay) + s.ErrorLog.Printf(nil, "accept error: %w; retrying in %s", err, tempDelay) time.Sleep(tempDelay) continue } return err } go func() { - err := s.handleConn(newConn(c, s)) + conn := newConn(c, s) + err := s.handleConn(conn) if err != nil { - s.ErrorLog.Printf("handler error: %s", err) + s.ErrorLog.Printf(conn, "handler error: %w", err) } }() } @@ -174,7 +189,7 @@ func (s *Server) handleConn(c *Conn) error { // preserve remote address from PROXY protocol err.Addr = c.conn.RemoteAddr() } - s.ErrorLog.Printf("TLS handshake error: %w", err) + s.ErrorLog.Printf(c, "TLS handshake error: %w", err) return err } } @@ -186,7 +201,9 @@ func (s *Server) handleConn(c *Conn) error { if err == nil { cmd, arg, err := parseCmd(line) if err != nil { - c.protocolError(501, EnhancedCode{5, 5, 2}, "Bad command") + msg := "Bad command" + s.ErrorLog.Printf(c, "%s: %w", msg, err) + c.protocolError(501, EnhancedCode{5, 5, 2}, msg) continue } @@ -195,17 +212,36 @@ func (s *Server) handleConn(c *Conn) error { if err == io.EOF { return nil } + if err == ErrTooLongLine { - c.WriteResponse(500, EnhancedCode{5, 4, 0}, "Too long line, closing connection") + msg := "Too long line, closing connection" + s.ErrorLog.Printf(c, "%s: %w", msg, err) + c.WriteResponse(500, EnhancedCode{5, 4, 0}, msg) return nil } + if err, ok := err.(*net.OpError); ok { + if err.Err == net.ErrClosed { + return nil + } + if errors.Is(err, syscall.ECONNRESET) && c.Session() == nil { + // healthcheck monitor + return nil + } + // preserve remote address from PROXY protocol + err.Addr = c.conn.RemoteAddr() + } + if neterr, ok := err.(net.Error); ok && neterr.Timeout() { - c.WriteResponse(221, EnhancedCode{2, 4, 2}, "Idle timeout, bye bye") + msg := "Idle timeout, bye bye" + s.ErrorLog.Printf(c, "%s: %w", msg, err) + c.WriteResponse(221, EnhancedCode{2, 4, 2}, msg) return nil } - c.WriteResponse(221, EnhancedCode{2, 4, 0}, "Connection error, sorry") + msg := "Connection error, sorry" + s.ErrorLog.Printf(c, "%s: %w", msg, err) + c.WriteResponse(221, EnhancedCode{2, 4, 0}, msg) return err } } diff --git a/server_test.go b/server_test.go index 4e0b00f..6afd4ef 100644 --- a/server_test.go +++ b/server_test.go @@ -293,7 +293,7 @@ func TestServerAcceptErrorHandling(t *testing.T) { s := smtp.NewServer(be) s.Domain = "localhost" s.AllowInsecureAuth = true - s.ErrorLog = log.New(errorLog, "", 0) + s.ErrorLog = &smtp.DefaultLogger{log.New(errorLog, "", 0)} l := newFailingListener() var serveError error @@ -435,7 +435,7 @@ func TestServerPanicRecover(t *testing.T) { s.Backend.(*backend).panicOnMail = true // Don't log panic in tests to not confuse people who run 'go test'. - s.ErrorLog = log.New(ioutil.Discard, "", 0) + s.ErrorLog = &smtp.DefaultLogger{log.New(ioutil.Discard, "", 0)} io.WriteString(c, "MAIL FROM:\r\n") scanner.Scan()