diff --git a/configure.go b/configure.go new file mode 100644 index 00000000..692ec9b0 --- /dev/null +++ b/configure.go @@ -0,0 +1,231 @@ +package nsq + +import ( + "errors" + "fmt" + "reflect" + "strconv" + "time" +) + +// Configure takes an option as a string and a value as an interface and +// attempts to set the appropriate configuration option on the reader instance. +// +// It attempts to coerce the value into the right format depending on the named +// option and the underlying type of the value passed in. +// +// It returns an error for an invalid option or value. +func (q *Reader) Configure(option string, value interface{}) error { + getDuration := func(v interface{}) (time.Duration, error) { + switch v.(type) { + case string: + return time.ParseDuration(v.(string)) + case int, int16, uint16, int32, uint32, int64, uint64: + // treat like ms + return time.Duration(reflect.ValueOf(v).Int()) * time.Millisecond, nil + case time.Duration: + return v.(time.Duration), nil + } + return 0, errors.New("invalid value type") + } + + getBool := func(v interface{}) (bool, error) { + switch value.(type) { + case bool: + return value.(bool), nil + case string: + return strconv.ParseBool(v.(string)) + case int, int16, uint16, int32, uint32, int64, uint64: + return reflect.ValueOf(value).Int() != 0, nil + } + return false, errors.New("invalid value type") + } + + getFloat64 := func(v interface{}) (float64, error) { + switch value.(type) { + case string: + return strconv.ParseFloat(value.(string), 64) + case int, int16, uint16, int32, uint32, int64, uint64: + return float64(reflect.ValueOf(value).Int()), nil + case float64: + return value.(float64), nil + } + return 0, errors.New("invalid value type") + } + + getInt64 := func(v interface{}) (int64, error) { + switch value.(type) { + case string: + return strconv.ParseInt(v.(string), 10, 64) + case int, int16, uint16, int32, uint32, int64, uint64: + return reflect.ValueOf(value).Int(), nil + } + return 0, errors.New("invalid value type") + } + + switch option { + case "read_timeout": + v, err := getDuration(value) + if err != nil { + return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) + } + if v > 5*time.Minute || v < 100*time.Millisecond { + return errors.New(fmt.Sprintf("invalid %s ! 100ms <= %s <= 5m", option, v)) + } + q.ReadTimeout = v + case "write_timeout": + v, err := getDuration(value) + if err != nil { + return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) + } + if v > 5*time.Minute || v < 100*time.Millisecond { + return errors.New(fmt.Sprintf("invalid %s ! 100ms <= %s <= 5m", option, v)) + } + q.WriteTimeout = v + case "lookupd_poll_interval": + v, err := getDuration(value) + if err != nil { + return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) + } + if v > 5*time.Minute || v < 5*time.Second { + return errors.New(fmt.Sprintf("invalid %s ! 5s <= %s <= 5m", option, v)) + } + q.LookupdPollInterval = v + case "lookupd_poll_jitter": + v, err := getFloat64(value) + if err != nil { + return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) + } + if v < 0 || v > 1 { + return errors.New(fmt.Sprintf("invalid %s ! 0 <= %v <= 1", option, v)) + } + q.LookupdPollJitter = v + case "max_requeue_delay": + v, err := getDuration(value) + if err != nil { + return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) + } + if v > 60*time.Minute || v < 0 { + return errors.New(fmt.Sprintf("invalid %s ! 
0 <= %s <= 60m", option, v)) + } + q.MaxRequeueDelay = v + case "default_requeue_delay": + v, err := getDuration(value) + if err != nil { + return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) + } + if v > 60*time.Minute || v < 0 { + return errors.New(fmt.Sprintf("invalid %s ! 0 <= %s <= 60m", option, v)) + } + q.DefaultRequeueDelay = v + case "backoff_multiplier": + v, err := getDuration(value) + if err != nil { + return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) + } + if v > 60*time.Minute || v < time.Second { + return errors.New(fmt.Sprintf("invalid %s ! 1s <= %s <= 60m", option, v)) + } + q.BackoffMultiplier = v + case "max_attempt_count": + v, err := getInt64(value) + if err != nil { + return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) + } + if v < 1 || v > 65535 { + return errors.New(fmt.Sprintf("invalid %s ! 1 <= %d <= 65535", option, v)) + } + q.MaxAttemptCount = uint16(v) + case "low_rdy_idle_timeout": + v, err := getDuration(value) + if err != nil { + return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) + } + if v > 5*time.Minute || v < time.Second { + return errors.New(fmt.Sprintf("invalid %s ! 1s <= %s <= 5m", option, v)) + } + q.LowRdyIdleTimeout = v + case "heartbeat_interval": + v, err := getDuration(value) + if err != nil { + return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) + } + q.HeartbeatInterval = v + case "output_buffer_size": + v, err := getInt64(value) + if err != nil { + return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) + } + q.OutputBufferSize = v + case "output_buffer_timeout": + v, err := getDuration(value) + if err != nil { + return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) + } + q.OutputBufferTimeout = v + case "tls_v1": + v, err := getBool(value) + if err != nil { + return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) + } + q.TLSv1 = v + case "deflate": + v, err := getBool(value) + if err != nil { + return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) + } + q.Deflate = v + case "deflate_level": + v, err := getInt64(value) + if err != nil { + return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) + } + if v < 1 || v > 9 { + return errors.New(fmt.Sprintf("invalid %s ! 1 <= %d <= 9", option, v)) + } + q.DeflateLevel = int(v) + case "sample_rate": + v, err := getInt64(value) + if err != nil { + return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) + } + if v < 0 || v > 99 { + return errors.New(fmt.Sprintf("invalid %s ! 0 <= %d <= 99", option, v)) + } + q.SampleRate = int32(v) + case "user_agent": + q.UserAgent = value.(string) + case "snappy": + v, err := getBool(value) + if err != nil { + return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) + } + q.Snappy = v + case "max_in_flight": + v, err := getInt64(value) + if err != nil { + return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) + } + if v < 1 { + return errors.New(fmt.Sprintf("invalid %s ! 1 <= %d", option, v)) + } + q.SetMaxInFlight(int(v)) + case "max_backoff_duration": + v, err := getDuration(value) + if err != nil { + return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) + } + if v > 60*time.Minute || v < 0 { + return errors.New(fmt.Sprintf("invalid %s ! 
0 <= %s <= 60m", option, v)) + } + q.SetMaxBackoffDuration(v) + case "verbose": + v, err := getBool(value) + if err != nil { + return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) + } + q.VerboseLogging = v + } + + return nil +} diff --git a/conn.go b/conn.go index 099c7494..68027031 100644 --- a/conn.go +++ b/conn.go @@ -5,22 +5,37 @@ import ( "bytes" "compress/flate" "crypto/tls" + "encoding/json" "errors" "fmt" "io" + "log" "net" + "os" + "strings" "sync" + "sync/atomic" "time" "github.com/mreiferson/go-snappystream" ) -type nsqConn struct { +// IdentifyResponse represents the metadata +// returned from an IDENTIFY command to nsqd +type IdentifyResponse struct { + MaxRdyCount int64 `json:"max_rdy_count"` + TLSv1 bool `json:"tls_v1"` + Deflate bool `json:"deflate"` + Snappy bool `json:"snappy"` +} + +// Conn represents a connection to nsqd +// +// Conn exposes a set of callbacks for the +// various events that occur on a connection +type Conn struct { // 64bit atomic vars need to be first for proper alignment on 32bit platforms messagesInFlight int64 - messagesReceived uint64 - messagesFinished uint64 - messagesRequeued uint64 maxRdyCount int64 rdyCount int64 lastRdyCount int64 @@ -38,94 +53,228 @@ type nsqConn struct { r io.Reader w io.Writer + // ResponseCB is called when the connection + // receives a FrameTypeResponse from nsqd + ResponseCB func(*Conn, []byte) + + // ErrorCB is called when the connection + // receives a FrameTypeError from nsqd + ErrorCB func(*Conn, []byte) + + // MessageCB is called when the connection + // receives a FrameTypeMessage from nsqd + MessageCB func(*Conn, *Message) + + // MessageProcessedCB is called when the connection + // handles a FIN or REQ command from a message handler + MessageProcessedCB func(*Conn, *FinishedMessage) + + // IOErrorCB is called when the connection experiences + // a low-level TCP transport error + IOErrorCB func(*Conn, error) + + // HeartbeatCB is called when the connection + // receives a heartbeat from nsqd + HeartbeatCB func(*Conn) + + // CloseCB is called when the connection + // closes, after all cleanup + CloseCB func(*Conn) + + cmdBuf bytes.Buffer + flateWriter *flate.Writer - readTimeout time.Duration - writeTimeout time.Duration + ReadTimeout time.Duration + WriteTimeout time.Duration backoffCounter int32 rdyRetryTimer *time.Timer - rdyChan chan *nsqConn finishedMessages chan *FinishedMessage cmdChan chan *Command exitChan chan int drainReady chan int + ShortIdentifier string // an identifier to send to nsqd when connecting (defaults: short hostname) + LongIdentifier string // an identifier to send to nsqd when connecting (defaults: long hostname) + + HeartbeatInterval time.Duration // duration of time between heartbeats + SampleRate int32 // set the sampleRate of the client's messagePump (requires nsqd 0.2.25+) + UserAgent string // a string identifying the agent for this client in the spirit of HTTP (default: "/") + + // transport layer security + TLSv1 bool // negotiate enabling TLS + TLSConfig *tls.Config // client TLS configuration + + // compression + Deflate bool // negotiate enabling Deflate compression + DeflateLevel int // the compression level to negotiate for Deflate + Snappy bool // negotiate enabling Snappy compression + + // output buffering + OutputBufferSize int64 // size of the buffer (in bytes) used by nsqd for buffering writes to this connection + OutputBufferTimeout time.Duration // timeout (in ms) used by nsqd before flushing buffered writes (set to 0 to disable). 
Warning: configuring clients with an extremely low (< 25ms) output_buffer_timeout has a significant effect on nsqd CPU usage (particularly with > 50 clients connected). + stopFlag int32 stopper sync.Once wg sync.WaitGroup + + readLoopRunning int32 } -func newNSQConn(rdyChan chan *nsqConn, addr string, - topic string, channel string, - readTimeout time.Duration, writeTimeout time.Duration) (*nsqConn, error) { - conn, err := net.DialTimeout("tcp", addr, time.Second) +// NewConn returns a new Conn instance +func NewConn(addr string, topic string, channel string) *Conn { + hostname, err := os.Hostname() if err != nil { - return nil, err + log.Fatalf("ERROR: unable to get hostname %s", err.Error()) } - - nc := &nsqConn{ - Conn: conn, - + return &Conn{ addr: addr, topic: topic, channel: channel, - r: conn, - w: conn, + ReadTimeout: DefaultClientTimeout, + WriteTimeout: time.Second, - readTimeout: readTimeout, - writeTimeout: writeTimeout, maxRdyCount: 2500, lastMsgTimestamp: time.Now().UnixNano(), finishedMessages: make(chan *FinishedMessage), cmdChan: make(chan *Command), - rdyChan: rdyChan, exitChan: make(chan int), drainReady: make(chan int), + + ShortIdentifier: strings.Split(hostname, ".")[0], + LongIdentifier: hostname, + + DeflateLevel: 6, + OutputBufferSize: 16 * 1024, + OutputBufferTimeout: 250 * time.Millisecond, + + HeartbeatInterval: DefaultClientTimeout / 2, + + UserAgent: fmt.Sprintf("go-nsq/%s", VERSION), } +} - _, err = nc.Write(MagicV2) +// Connect dials and bootstraps the nsqd connection +// (including IDENTIFY) and returns the IdentifyResponse +func (c *Conn) Connect() (*IdentifyResponse, error) { + conn, err := net.DialTimeout("tcp", c.addr, time.Second) if err != nil { - nc.Close() - return nil, fmt.Errorf("[%s] failed to write magic - %s", addr, err.Error()) + return nil, err } + c.Conn = conn + c.r = conn + c.w = conn + + _, err = c.Write(MagicV2) + if err != nil { + c.Close() + return nil, fmt.Errorf("[%s] failed to write magic - %s", c.addr, err) + } + + resp, err := c.identify() + if err != nil { + return nil, err + } + + c.wg.Add(2) + atomic.StoreInt32(&c.readLoopRunning, 1) + go c.readLoop() + go c.writeLoop() + return resp, nil +} + +// Close idempotently closes the connection +func (c *Conn) Close() error { + // so that external users dont need + // to do this dance... 
+ // (would only happen if the dial failed) + if c.Conn == nil { + return nil + } + return c.Conn.Close() +} + +// Stop gracefully initiates connection close +// allowing in-flight messages to finish +func (c *Conn) Stop() { + atomic.StoreInt32(&c.stopFlag, 1) +} + +// IsStopping indicates whether or not the +// connection is currently in the processing of +// gracefully closing +func (c *Conn) IsStopping() bool { + return atomic.LoadInt32(&c.stopFlag) == 1 +} + +// RDY returns the current RDY count +func (c *Conn) RDY() int64 { + return atomic.LoadInt64(&c.rdyCount) +} + +// LastRDY returns the previously set RDY count +func (c *Conn) LastRDY() int64 { + return atomic.LoadInt64(&c.lastRdyCount) +} + +// SetRDY stores the specified RDY count +func (c *Conn) SetRDY(rdy int64) { + atomic.StoreInt64(&c.rdyCount, rdy) + atomic.StoreInt64(&c.lastRdyCount, rdy) +} + +// MaxRDY returns the nsqd negotiated maximum +// RDY count that it will accept for this connection +func (c *Conn) MaxRDY() int64 { + return c.maxRdyCount +} - return nc, nil +// LastMessageTime returns a time.Time representing +// the time at which the last message was received +func (c *Conn) LastMessageTime() time.Time { + return time.Unix(0, atomic.LoadInt64(&c.lastMsgTimestamp)) } -func (c *nsqConn) String() string { +// Address returns the configured destination nsqd address +func (c *Conn) Address() string { + return c.addr +} + +// String returns the fully-qualified address/topic/channel +func (c *Conn) String() string { return fmt.Sprintf("%s/%s/%s", c.addr, c.topic, c.channel) } -func (c *nsqConn) Read(p []byte) (int, error) { - c.SetReadDeadline(time.Now().Add(c.readTimeout)) +// Read performs a deadlined read on the underlying TCP connection +func (c *Conn) Read(p []byte) (int, error) { + c.SetReadDeadline(time.Now().Add(c.ReadTimeout)) return c.r.Read(p) } -func (c *nsqConn) Write(p []byte) (int, error) { - c.SetWriteDeadline(time.Now().Add(c.writeTimeout)) +// Write performs a deadlined write on the underlying TCP connection +func (c *Conn) Write(p []byte) (int, error) { + c.SetWriteDeadline(time.Now().Add(c.WriteTimeout)) return c.w.Write(p) } -func (c *nsqConn) enableReadBuffering() { - c.r = bufio.NewReader(c.r) -} - -func (c *nsqConn) sendCommand(buf *bytes.Buffer, cmd *Command) error { +// SendCommand writes the specified Command to the underlying +// TCP connection according to the NSQ TCP protocol spec +func (c *Conn) SendCommand(cmd *Command) error { c.Lock() defer c.Unlock() - buf.Reset() - err := cmd.Write(buf) + c.cmdBuf.Reset() + err := cmd.Write(&c.cmdBuf) if err != nil { return err } - _, err = buf.WriteTo(c) + _, err = c.cmdBuf.WriteTo(c) if err != nil { return err } @@ -137,7 +286,10 @@ func (c *nsqConn) sendCommand(buf *bytes.Buffer, cmd *Command) error { return nil } -func (c *nsqConn) readUnpackedResponse() (int32, []byte, error) { +// ReadUnpackedResponse reads and parses data from the underlying +// TCP connection according to the NSQ TCP protocol spec and +// returns the frameType, data or error +func (c *Conn) ReadUnpackedResponse() (int32, []byte, error) { resp, err := ReadResponse(c) if err != nil { return -1, nil, err @@ -145,7 +297,81 @@ func (c *nsqConn) readUnpackedResponse() (int32, []byte, error) { return UnpackResponse(resp) } -func (c *nsqConn) upgradeTLS(conf *tls.Config) error { +func (c *Conn) identify() (*IdentifyResponse, error) { + ci := make(map[string]interface{}) + ci["short_id"] = c.ShortIdentifier + ci["long_id"] = c.LongIdentifier + ci["tls_v1"] = c.TLSv1 + ci["deflate"] = 
c.Deflate + ci["deflate_level"] = c.DeflateLevel + ci["snappy"] = c.Snappy + ci["feature_negotiation"] = true + ci["heartbeat_interval"] = int64(c.HeartbeatInterval / time.Millisecond) + ci["sample_rate"] = c.SampleRate + ci["user_agent"] = c.UserAgent + ci["output_buffer_size"] = c.OutputBufferSize + ci["output_buffer_timeout"] = int64(c.OutputBufferTimeout / time.Millisecond) + cmd, err := Identify(ci) + if err != nil { + return nil, ErrIdentify{Reason: err.Error()} + } + + err = c.SendCommand(cmd) + if err != nil { + return nil, ErrIdentify{Reason: err.Error()} + } + + frameType, data, err := c.ReadUnpackedResponse() + if err != nil { + return nil, ErrIdentify{Reason: err.Error()} + } + + if frameType == FrameTypeError { + return nil, ErrIdentify{string(data)} + } + + // check to see if the server was able to respond w/ capabilities + // i.e. it was a JSON response + if data[0] != '{' { + return nil, nil + } + + resp := &IdentifyResponse{} + err = json.Unmarshal(data, resp) + if err != nil { + return nil, ErrIdentify{err.Error()} + } + + c.maxRdyCount = resp.MaxRdyCount + + if resp.TLSv1 { + err := c.upgradeTLS(c.TLSConfig) + if err != nil { + return nil, ErrIdentify{err.Error()} + } + } + + if resp.Deflate { + err := c.upgradeDeflate(c.DeflateLevel) + if err != nil { + return nil, ErrIdentify{err.Error()} + } + } + + if resp.Snappy { + err := c.upgradeSnappy() + if err != nil { + return nil, ErrIdentify{err.Error()} + } + } + + // now that connection is bootstrapped, enable read buffering + c.r = bufio.NewReader(c.r) + + return resp, nil +} + +func (c *Conn) upgradeTLS(conf *tls.Config) error { c.tlsConn = tls.Client(c.Conn, conf) err := c.tlsConn.Handshake() if err != nil { @@ -153,7 +379,7 @@ func (c *nsqConn) upgradeTLS(conf *tls.Config) error { } c.r = c.tlsConn c.w = c.tlsConn - frameType, data, err := c.readUnpackedResponse() + frameType, data, err := c.ReadUnpackedResponse() if err != nil { return err } @@ -163,7 +389,7 @@ func (c *nsqConn) upgradeTLS(conf *tls.Config) error { return nil } -func (c *nsqConn) upgradeDeflate(level int) error { +func (c *Conn) upgradeDeflate(level int) error { conn := c.Conn if c.tlsConn != nil { conn = c.tlsConn @@ -172,7 +398,7 @@ func (c *nsqConn) upgradeDeflate(level int) error { fw, _ := flate.NewWriter(conn, level) c.flateWriter = fw c.w = fw - frameType, data, err := c.readUnpackedResponse() + frameType, data, err := c.ReadUnpackedResponse() if err != nil { return err } @@ -182,14 +408,14 @@ func (c *nsqConn) upgradeDeflate(level int) error { return nil } -func (c *nsqConn) upgradeSnappy() error { +func (c *Conn) upgradeSnappy() error { conn := c.Conn if c.tlsConn != nil { conn = c.tlsConn } c.r = snappystream.NewReader(conn, snappystream.SkipVerifyChecksum) c.w = snappystream.NewWriter(conn) - frameType, data, err := c.readUnpackedResponse() + frameType, data, err := c.ReadUnpackedResponse() if err != nil { return err } @@ -198,3 +424,201 @@ func (c *nsqConn) upgradeSnappy() error { } return nil } + +func (c *Conn) readLoop() { + for { + if atomic.LoadInt32(&c.stopFlag) == 1 { + goto exit + } + + frameType, data, err := c.ReadUnpackedResponse() + if err != nil { + c.IOErrorCB(c, err) + goto exit + } + + if frameType == FrameTypeResponse && bytes.Equal(data, []byte("_heartbeat_")) { + c.HeartbeatCB(c) + err := c.SendCommand(Nop()) + if err != nil { + c.IOErrorCB(c, err) + goto exit + } + continue + } + + switch frameType { + case FrameTypeResponse: + c.ResponseCB(c, data) + case FrameTypeMessage: + msg, err := DecodeMessage(data) + if err != 
nil { + c.IOErrorCB(c, err) + goto exit + } + msg.cmdChan = c.cmdChan + msg.responseChan = c.finishedMessages + msg.exitChan = c.exitChan + + atomic.AddInt64(&c.rdyCount, -1) + atomic.AddInt64(&c.messagesInFlight, 1) + atomic.StoreInt64(&c.lastMsgTimestamp, time.Now().UnixNano()) + + c.MessageCB(c, msg) + case FrameTypeError: + c.ErrorCB(c, data) + default: + c.IOErrorCB(c, fmt.Errorf("unknown frame type %d", frameType)) + } + } + +exit: + atomic.StoreInt32(&c.readLoopRunning, 0) + // start the connection close + messagesInFlight := atomic.LoadInt64(&c.messagesInFlight) + if messagesInFlight == 0 { + // if we exited readLoop with no messages in flight + // we need to explicitly trigger the close because + // writeLoop won't + c.close() + } else { + log.Printf("[%s] delaying close, %d outstanding messages", + c, messagesInFlight) + } + c.wg.Done() + log.Printf("[%s] readLoop exiting", c) +} + +func (c *Conn) writeLoop() { + for { + select { + case <-c.exitChan: + log.Printf("[%s] breaking out of writeLoop", c) + // Indicate drainReady because we will not pull any more off finishedMessages + close(c.drainReady) + goto exit + case cmd := <-c.cmdChan: + err := c.SendCommand(cmd) + if err != nil { + log.Printf("[%s] error sending command %s - %s", c, cmd, err) + c.close() + continue + } + case finishedMsg := <-c.finishedMessages: + // Decrement this here so it is correct even if we can't respond to nsqd + msgsInFlight := atomic.AddInt64(&c.messagesInFlight, -1) + + if finishedMsg.Success { + err := c.SendCommand(Finish(finishedMsg.Id)) + if err != nil { + log.Printf("[%s] error finishing %s - %s", c, finishedMsg.Id, err.Error()) + c.close() + continue + } + } else { + err := c.SendCommand(Requeue(finishedMsg.Id, finishedMsg.RequeueDelayMs)) + if err != nil { + log.Printf("[%s] error requeueing %s - %s", c, finishedMsg.Id, err.Error()) + c.close() + continue + } + } + + c.MessageProcessedCB(c, finishedMsg) + + if msgsInFlight == 0 && atomic.LoadInt32(&c.stopFlag) == 1 { + c.close() + continue + } + } + } + +exit: + c.wg.Done() + log.Printf("[%s] writeLoop exiting", c) +} + +func (c *Conn) close() { + // a "clean" connection close is orchestrated as follows: + // + // 1. CLOSE cmd sent to nsqd + // 2. CLOSE_WAIT response received from nsqd + // 3. set c.stopFlag + // 4. readLoop() exits + // a. if messages-in-flight > 0 delay close() + // i. writeLoop() continues receiving on c.finishedMessages chan + // x. when messages-in-flight == 0 call close() + // b. else call close() immediately + // 5. c.exitChan close + // a. writeLoop() exits + // i. c.drainReady close + // 6a. launch cleanup() goroutine (we're racing with intraprocess + // routed messages, see comments below) + // a. wait on c.drainReady + // b. loop and receive on c.finishedMessages chan + // until messages-in-flight == 0 + // i. ensure that readLoop has exited + // 6b. launch waitForCleanup() goroutine + // b. wait on waitgroup (covers readLoop() and writeLoop() + // and cleanup goroutine) + // c. underlying TCP connection close + // d. trigger CloseCB() + // + c.stopper.Do(func() { + log.Printf("[%s] beginning close", c) + close(c.exitChan) + + c.wg.Add(1) + go c.cleanup() + + go c.waitForCleanup() + }) +} + +func (c *Conn) cleanup() { + <-c.drainReady + ticker := time.NewTicker(100 * time.Millisecond) + // finishLoop has exited, drain any remaining in flight messages + for { + // we're racing with router which potentially has a message + // for handling... 
+ // + // infinitely loop until the connection's waitgroup is satisfied, + // ensuring that both finishLoop and router have exited, at which + // point we can be guaranteed that messagesInFlight accurately + // represents whatever is left... continue until 0. + var msgsInFlight int64 + select { + case <-c.finishedMessages: + msgsInFlight = atomic.AddInt64(&c.messagesInFlight, -1) + case <-ticker.C: + msgsInFlight = atomic.LoadInt64(&c.messagesInFlight) + } + if msgsInFlight > 0 { + log.Printf("[%s] draining... waiting for %d messages in flight", c, msgsInFlight) + continue + } + // until the readLoop has exited we cannot be sure that there + // still won't be a race + if atomic.LoadInt32(&c.readLoopRunning) == 1 { + log.Printf("[%s] draining... readLoop still running", c) + continue + } + goto exit + } + +exit: + ticker.Stop() + c.wg.Done() + log.Printf("[%s] finished draining, cleanup exiting", c) +} + +func (c *Conn) waitForCleanup() { + // this blocks until readLoop and writeLoop + // (and cleanup goroutine above) have exited + c.wg.Wait() + // actually close the underlying connection + c.Close() + log.Printf("[%s] clean close complete", c) + c.CloseCB(c) +} diff --git a/errors.go b/errors.go new file mode 100644 index 00000000..2a2b8227 --- /dev/null +++ b/errors.go @@ -0,0 +1,23 @@ +package nsq + +import ( + "errors" + "fmt" +) + +// returned from ConnectToNSQ when already connected +var ErrAlreadyConnected = errors.New("already connected") + +// returned from updateRdy if over max-in-flight +var ErrOverMaxInFlight = errors.New("over configure max-inflight") + +// returned from ConnectToLookupd when given lookupd address exists already +var ErrLookupdAddressExists = errors.New("lookupd address already exists") + +type ErrIdentify struct { + Reason string +} + +func (e ErrIdentify) Error() string { + return fmt.Sprintf("failed to IDENTIFY - %s", e.Reason) +} diff --git a/reader.go b/reader.go index 5feb2f92..71fc9bec 100644 --- a/reader.go +++ b/reader.go @@ -3,7 +3,6 @@ package nsq import ( "bytes" "crypto/tls" - "encoding/json" "errors" "fmt" "log" @@ -11,24 +10,12 @@ import ( "math/rand" "net" "net/url" - "os" - "reflect" "strconv" - "strings" "sync" "sync/atomic" "time" ) -// returned from ConnectToNSQ when already connected -var ErrAlreadyConnected = errors.New("already connected") - -// returned from updateRdy if over max-in-flight -var ErrOverMaxInFlight = errors.New("over configure max-inflight") - -// returned from ConnectToLookupd when given lookupd address exists already -var ErrLookupdAddressExists = errors.New("lookupd address already exists") - // Handler is the synchronous interface to Reader. 
// // Implement this interface for handlers that return whether or not message @@ -84,18 +71,15 @@ type Reader struct { MessagesFinished uint64 // an atomic counter - # of messages FINished MessagesRequeued uint64 // an atomic counter - # of messages REQueued totalRdyCount int64 - messagesInFlight int64 backoffDuration int64 sync.RWMutex // basics - TopicName string // name of topic to subscribe to - ChannelName string // name of channel to subscribe to - ShortIdentifier string // an identifier to send to nsqd when connecting (defaults: short hostname) - LongIdentifier string // an identifier to send to nsqd when connecting (defaults: long hostname) - VerboseLogging bool // enable verbose logging - ExitChan chan int // read from this channel to block your main loop + TopicName string // name of topic to subscribe to + ChannelName string // name of channel to subscribe to + VerboseLogging bool // enable verbose logging + ExitChan chan int // read from this channel to block your main loop // network deadlines ReadTimeout time.Duration // the deadline set for network reads @@ -114,6 +98,13 @@ type Reader struct { MaxAttemptCount uint16 // maximum number of times this reader will attempt to process a message LowRdyIdleTimeout time.Duration // the amount of time in seconds to wait for a message from a producer when in a state where RDY counts are re-distributed (ie. max_in_flight < num_producers) + ShortIdentifier string // an identifier to send to nsqd when connecting (defaults: short hostname) + LongIdentifier string // an identifier to send to nsqd when connecting (defaults: long hostname) + + HeartbeatInterval time.Duration // duration of time between heartbeats + SampleRate int32 // set the sampleRate of the client's messagePump (requires nsqd 0.2.25+) + UserAgent string // a string identifying the agent for this client in the spirit of HTTP (default: "/") + // transport layer security TLSv1 bool // negotiate enabling TLS TLSConfig *tls.Config // client TLS configuration @@ -123,14 +114,15 @@ type Reader struct { DeflateLevel int // the compression level to negotiate for Deflate Snappy bool // negotiate enabling Snappy compression - SampleRate int32 // set the sampleRate of the client's messagePump (requires nsqd 0.2.25+) - UserAgent string // a string identifying the agent for this client in the spirit of HTTP (default: "/") + // output buffering + OutputBufferSize int64 // size of the buffer (in bytes) used by nsqd for buffering writes to this connection + OutputBufferTimeout time.Duration // timeout (in ms) used by nsqd before flushing buffered writes (set to 0 to disable). Warning: configuring clients with an extremely low (< 25ms) output_buffer_timeout has a significant effect on nsqd CPU usage (particularly with > 50 clients connected). 
// internal variables maxBackoffDuration time.Duration maxBackoffCount int32 backoffChan chan bool - rdyChan chan *nsqConn + rdyChan chan *Conn needRDYRedistributed int32 backoffCounter int32 @@ -139,8 +131,9 @@ type Reader struct { incomingMessages chan *Message + rdyRetryTimers map[string]*time.Timer pendingConnections map[string]bool - nsqConnections map[string]*nsqConn + connections map[string]*Conn lookupdRecheckChan chan int lookupdHTTPAddrs []string @@ -164,10 +157,6 @@ func NewReader(topic string, channel string) (*Reader, error) { return nil, errors.New("invalid channel name") } - hostname, err := os.Hostname() - if err != nil { - log.Fatalf("ERROR: unable to get hostname %s", err.Error()) - } q := &Reader{ TopicName: topic, ChannelName: channel, @@ -183,23 +172,19 @@ func NewReader(topic string, channel string) (*Reader, error) { MaxRequeueDelay: 15 * time.Minute, BackoffMultiplier: time.Second, - ShortIdentifier: strings.Split(hostname, ".")[0], - LongIdentifier: hostname, - ReadTimeout: DefaultClientTimeout, WriteTimeout: time.Second, - DeflateLevel: 6, - incomingMessages: make(chan *Message), + rdyRetryTimers: make(map[string]*time.Timer), pendingConnections: make(map[string]bool), - nsqConnections: make(map[string]*nsqConn), + connections: make(map[string]*Conn), lookupdRecheckChan: make(chan int, 1), // used at connection close to force a possible reconnect maxInFlight: 1, backoffChan: make(chan bool), - rdyChan: make(chan *nsqConn, 1), + rdyChan: make(chan *Conn, 1), ExitChan: make(chan int), } @@ -208,210 +193,6 @@ func NewReader(topic string, channel string) (*Reader, error) { return q, nil } -// Configure takes an option as a string and a value as an interface and -// attempts to set the appropriate configuration option on the reader instance. -// -// It attempts to coerce the value into the right format depending on the named -// option and the underlying type of the value passed in. -// -// It returns an error for an invalid option or value. 
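(For reference, the coercion behavior of the relocated Configure method — now in configure.go at the top of this diff — can be exercised as sketched below. This is a minimal, hypothetical example: the import path, topic/channel names, and option values are illustrative assumptions, not part of this change.)

package main

import (
	"log"

	nsq "github.com/bitly/go-nsq" // import path assumed for illustration
)

func main() {
	q, err := nsq.NewReader("events", "archive") // hypothetical topic/channel
	if err != nil {
		log.Fatal(err)
	}
	// strings are parsed with time.ParseDuration / strconv.ParseBool,
	// bare integers are treated as milliseconds (see getDuration in configure.go)
	if err := q.Configure("read_timeout", "2s"); err != nil {
		log.Fatal(err)
	}
	_ = q.Configure("lookupd_poll_interval", 30000) // 30000ms -> 30s
	_ = q.Configure("max_in_flight", 200)           // delegates to q.SetMaxInFlight
	_ = q.Configure("tls_v1", "true")
}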
-func (q *Reader) Configure(option string, value interface{}) error { - getDuration := func(v interface{}) (time.Duration, error) { - switch v.(type) { - case string: - return time.ParseDuration(v.(string)) - case int, int16, uint16, int32, uint32, int64, uint64: - // treat like ms - return time.Duration(reflect.ValueOf(v).Int()) * time.Millisecond, nil - case time.Duration: - return v.(time.Duration), nil - } - return 0, errors.New("invalid value type") - } - - getBool := func(v interface{}) (bool, error) { - switch value.(type) { - case bool: - return value.(bool), nil - case string: - return strconv.ParseBool(v.(string)) - case int, int16, uint16, int32, uint32, int64, uint64: - return reflect.ValueOf(value).Int() == 0, nil - } - return false, errors.New("invalid value type") - } - - getFloat64 := func(v interface{}) (float64, error) { - switch value.(type) { - case string: - return strconv.ParseFloat(value.(string), 64) - case int, int16, uint16, int32, uint32, int64, uint64: - return float64(reflect.ValueOf(value).Int()), nil - case float64: - return value.(float64), nil - } - return 0, errors.New("invalid value type") - } - - getInt64 := func(v interface{}) (int64, error) { - switch value.(type) { - case string: - return strconv.ParseInt(v.(string), 10, 64) - case int, int16, uint16, int32, uint32, int64, uint64: - return reflect.ValueOf(value).Int(), nil - } - return 0, errors.New("invalid value type") - } - - switch option { - case "read_timeout": - v, err := getDuration(value) - if err != nil { - return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) - } - if v > 5*time.Minute || v < 100*time.Millisecond { - return errors.New(fmt.Sprintf("invalid %s ! 100ms <= %s <= 5m", option, v)) - } - q.ReadTimeout = v - case "write_timeout": - v, err := getDuration(value) - if err != nil { - return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) - } - if v > 5*time.Minute || v < 100*time.Millisecond { - return errors.New(fmt.Sprintf("invalid %s ! 100ms <= %s <= 5m", option, v)) - } - q.WriteTimeout = v - case "lookupd_poll_interval": - v, err := getDuration(value) - if err != nil { - return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) - } - if v > 5*time.Minute || v < 5*time.Second { - return errors.New(fmt.Sprintf("invalid %s ! 5s <= %s <= 5m", option, v)) - } - q.LookupdPollInterval = v - case "lookupd_poll_jitter": - v, err := getFloat64(value) - if err != nil { - return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) - } - if v < 0 || v > 1 { - return errors.New(fmt.Sprintf("invalid %s ! 0 <= %v <= 1", option, v)) - } - q.LookupdPollJitter = v - case "max_requeue_delay": - v, err := getDuration(value) - if err != nil { - return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) - } - if v > 60*time.Minute || v < 0 { - return errors.New(fmt.Sprintf("invalid %s ! 0 <= %s <= 60m", option, v)) - } - q.MaxRequeueDelay = v - case "default_requeue_delay": - v, err := getDuration(value) - if err != nil { - return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) - } - if v > 60*time.Minute || v < 0 { - return errors.New(fmt.Sprintf("invalid %s ! 0 <= %s <= 60m", option, v)) - } - q.DefaultRequeueDelay = v - case "backoff_multiplier": - v, err := getDuration(value) - if err != nil { - return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) - } - if v > 60*time.Minute || v < time.Second { - return errors.New(fmt.Sprintf("invalid %s ! 
1s <= %s <= 60m", option, v)) - } - q.BackoffMultiplier = v - case "max_attempt_count": - v, err := getInt64(value) - if err != nil { - return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) - } - if v < 1 || v > 65535 { - return errors.New(fmt.Sprintf("invalid %s ! 1 <= %d <= 65535", option, v)) - } - q.MaxAttemptCount = uint16(v) - case "low_rdy_idle_timeout": - v, err := getDuration(value) - if err != nil { - return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) - } - if v > 5*time.Minute || v < time.Second { - return errors.New(fmt.Sprintf("invalid %s ! 1s <= %s <= 5m", option, v)) - } - q.LowRdyIdleTimeout = v - case "tls_v1": - v, err := getBool(value) - if err != nil { - return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) - } - q.TLSv1 = v - case "deflate": - v, err := getBool(value) - if err != nil { - return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) - } - q.Deflate = v - case "deflate_level": - v, err := getInt64(value) - if err != nil { - return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) - } - if v < 1 || v > 9 { - return errors.New(fmt.Sprintf("invalid %s ! 1 <= %d <= 9", option, v)) - } - q.DeflateLevel = int(v) - case "sample_rate": - v, err := getInt64(value) - if err != nil { - return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) - } - if v < 0 || v > 99 { - return errors.New(fmt.Sprintf("invalid %s ! 0 <= %d <= 99", option, v)) - } - q.SampleRate = int32(v) - case "user_agent": - q.UserAgent = value.(string) - case "snappy": - v, err := getBool(value) - if err != nil { - return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) - } - q.Snappy = v - case "max_in_flight": - v, err := getInt64(value) - if err != nil { - return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) - } - if v < 1 { - return errors.New(fmt.Sprintf("invalid %s ! 1 <= %d", option, v)) - } - q.SetMaxInFlight(int(v)) - case "max_backoff_duration": - v, err := getDuration(value) - if err != nil { - return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) - } - if v > 60*time.Minute || v < 0 { - return errors.New(fmt.Sprintf("invalid %s ! 0 <= %s <= 60m", option, v)) - } - q.SetMaxBackoffDuration(v) - case "verbose": - v, err := getBool(value) - if err != nil { - return errors.New(fmt.Sprintf("invalid %s - %s", option, err)) - } - q.VerboseLogging = v - } - - return nil -} - // ConnectionMaxInFlight calculates the per-connection max-in-flight count. // // This may change dynamically based on the number of connections to nsqd the Reader @@ -419,7 +200,7 @@ func (q *Reader) Configure(option string, value interface{}) error { func (q *Reader) ConnectionMaxInFlight() int64 { b := float64(q.MaxInFlight()) q.RLock() - s := b / float64(len(q.nsqConnections)) + s := b / float64(len(q.connections)) q.RUnlock() return int64(math.Min(math.Max(1, s), b)) } @@ -430,7 +211,7 @@ func (q *Reader) IsStarved() bool { q.RLock() defer q.RUnlock() - for _, conn := range q.nsqConnections { + for _, conn := range q.connections { threshold := int64(float64(atomic.LoadInt64(&conn.lastRdyCount)) * 0.85) inFlight := atomic.LoadInt64(&conn.messagesInFlight) if inFlight >= threshold && inFlight > 0 && atomic.LoadInt32(&conn.stopFlag) != 1 { @@ -459,8 +240,8 @@ func (q *Reader) SetMaxInFlight(maxInFlight int) { q.RLock() defer q.RUnlock() - for _, c := range q.nsqConnections { - c.rdyChan <- c + for _, c := range q.connections { + q.rdyChan <- c } } @@ -598,8 +379,6 @@ func (q *Reader) queryLookupd() { // automatically. 
This method is useful when you want to connect to a single, local, // instance. func (q *Reader) ConnectToNSQ(addr string) error { - var buf bytes.Buffer - if atomic.LoadInt32(&q.stopFlag) == 1 { return errors.New("reader stopped") } @@ -609,7 +388,7 @@ func (q *Reader) ConnectToNSQ(addr string) error { } q.RLock() - _, ok := q.nsqConnections[addr] + _, ok := q.connections[addr] _, pendingOk := q.pendingConnections[addr] if ok || pendingOk { q.RUnlock() @@ -619,361 +398,183 @@ func (q *Reader) ConnectToNSQ(addr string) error { log.Printf("[%s] connecting to nsqd", addr) - connection, err := newNSQConn(q.rdyChan, addr, - q.TopicName, q.ChannelName, q.ReadTimeout, q.WriteTimeout) - if err != nil { - return err + conn := NewConn(addr, q.TopicName, q.ChannelName) + if q.ReadTimeout > 0 { + conn.ReadTimeout = q.ReadTimeout + } + if q.WriteTimeout > 0 { + conn.WriteTimeout = q.WriteTimeout + } + conn.Deflate = q.Deflate + if q.DeflateLevel > 0 { + conn.DeflateLevel = q.DeflateLevel + } + conn.Snappy = q.Snappy + conn.TLSv1 = q.TLSv1 + conn.TLSConfig = q.TLSConfig + if q.ShortIdentifier != "" { + conn.ShortIdentifier = q.ShortIdentifier + } + if q.LongIdentifier != "" { + conn.LongIdentifier = q.LongIdentifier + } + if q.HeartbeatInterval != 0 { + conn.HeartbeatInterval = q.HeartbeatInterval } + if q.SampleRate > 0 { + conn.SampleRate = q.SampleRate + } + if q.UserAgent != "" { + conn.UserAgent = q.UserAgent + } + if q.OutputBufferSize != 0 { + conn.OutputBufferSize = q.OutputBufferSize + } + if q.OutputBufferTimeout != 0 { + conn.OutputBufferTimeout = q.OutputBufferTimeout + } + conn.MessageCB = func(c *Conn, msg *Message) { + q.onConnectionMessage(c, msg) + } + conn.MessageProcessedCB = func(c *Conn, finishedMsg *FinishedMessage) { + q.onConnectionMessageProcessed(c, finishedMsg) + } + conn.ResponseCB = func(c *Conn, data []byte) { + q.onConnectionResponse(c, data) + } + conn.ErrorCB = func(c *Conn, data []byte) { + q.onConnectionError(c, data) + } + conn.HeartbeatCB = func(c *Conn) { + q.onConnectionHeartbeat(c) + } + conn.IOErrorCB = func(c *Conn, err error) { + q.onConnectionIOError(c, err) + } + conn.CloseCB = func(c *Conn) { + q.onConnectionClosed(c) + } + cleanupConnection := func() { q.Lock() delete(q.pendingConnections, addr) q.Unlock() - connection.Close() + conn.Close() } - q.pendingConnections[addr] = true - // set the user_agent string to the default if there is no user input version - userAgent := fmt.Sprintf("go-nsq/%s", VERSION) - if q.UserAgent != "" { - userAgent = q.UserAgent - } - - ci := make(map[string]interface{}) - ci["short_id"] = q.ShortIdentifier - ci["long_id"] = q.LongIdentifier - ci["tls_v1"] = q.TLSv1 - ci["deflate"] = q.Deflate - ci["deflate_level"] = q.DeflateLevel - ci["snappy"] = q.Snappy - ci["feature_negotiation"] = true - ci["sample_rate"] = q.SampleRate - ci["user_agent"] = userAgent - cmd, err := Identify(ci) - if err != nil { - cleanupConnection() - return fmt.Errorf("[%s] failed to create identify command - %s", connection, err.Error()) - } + q.pendingConnections[addr] = true - err = connection.sendCommand(&buf, cmd) + resp, err := conn.Connect() if err != nil { cleanupConnection() - return fmt.Errorf("[%s] failed to identify - %s", connection, err.Error()) + return err } - _, data, err := connection.readUnpackedResponse() - if err != nil { - cleanupConnection() - return fmt.Errorf("[%s] error reading response %s", connection, err.Error()) - } - - // check to see if the server was able to respond w/ capabilities - if data[0] == '{' { - resp := struct { - 
MaxRdyCount int64 `json:"max_rdy_count"` - TLSv1 bool `json:"tls_v1"` - Deflate bool `json:"deflate"` - Snappy bool `json:"snappy"` - SampleRate int32 `json:"sample_rate"` - }{} - err := json.Unmarshal(data, &resp) - if err != nil { - cleanupConnection() - return fmt.Errorf("[%s] error (%s) unmarshaling IDENTIFY response %s", connection, err.Error(), data) - } - - log.Printf("[%s] IDENTIFY response: %+v", connection, resp) - - connection.maxRdyCount = resp.MaxRdyCount + if resp != nil { + log.Printf("[%s] IDENTIFY response: %+v", conn, resp) if resp.MaxRdyCount < int64(q.MaxInFlight()) { log.Printf("[%s] max RDY count %d < reader max in flight %d, truncation possible", - connection, resp.MaxRdyCount, q.MaxInFlight()) + conn, resp.MaxRdyCount, q.MaxInFlight()) } - if resp.TLSv1 { - log.Printf("[%s] upgrading to TLS", connection) - err := connection.upgradeTLS(q.TLSConfig) - if err != nil { - cleanupConnection() - return fmt.Errorf("[%s] error (%s) upgrading to TLS", connection, err.Error()) - } + log.Printf("[%s] upgrading to TLS", conn) } - if resp.Deflate { - log.Printf("[%s] upgrading to Deflate", connection) - err := connection.upgradeDeflate(q.DeflateLevel) - if err != nil { - connection.Close() - return fmt.Errorf("[%s] error (%s) upgrading to deflate", connection, err.Error()) - } + log.Printf("[%s] upgrading to Deflate", conn) } - if resp.Snappy { - log.Printf("[%s] upgrading to Snappy", connection) - err := connection.upgradeSnappy() - if err != nil { - connection.Close() - return fmt.Errorf("[%s] error (%s) upgrading to snappy", connection, err.Error()) - } + log.Printf("[%s] upgrading to Snappy", conn) } } - cmd = Subscribe(q.TopicName, q.ChannelName) - err = connection.sendCommand(&buf, cmd) + cmd := Subscribe(q.TopicName, q.ChannelName) + err = conn.SendCommand(cmd) if err != nil { cleanupConnection() - return fmt.Errorf("[%s] failed to subscribe to %s:%s - %s", connection, q.TopicName, q.ChannelName, err.Error()) + return fmt.Errorf("[%s] failed to subscribe to %s:%s - %s", + conn, q.TopicName, q.ChannelName, err.Error()) } - connection.enableReadBuffering() - q.Lock() delete(q.pendingConnections, addr) - q.nsqConnections[addr] = connection + q.connections[addr] = conn q.Unlock() // pre-emptive signal to existing connections to lower their RDY count q.RLock() - for _, c := range q.nsqConnections { - c.rdyChan <- c + for _, c := range q.connections { + q.rdyChan <- c } q.RUnlock() - connection.wg.Add(2) - go q.readLoop(connection) - go q.finishLoop(connection) - return nil } -func handleError(q *Reader, c *nsqConn, errMsg string) { - log.Printf(errMsg) - atomic.StoreInt32(&c.stopFlag, 1) - - q.RLock() - numLookupd := len(q.lookupdHTTPAddrs) - q.RUnlock() - if numLookupd == 0 { - go func(addr string) { - for { - log.Printf("[%s] re-connecting in 15 seconds...", addr) - time.Sleep(15 * time.Second) - if atomic.LoadInt32(&q.stopFlag) == 1 { - break - } - err := q.ConnectToNSQ(addr) - if err != nil && err != ErrAlreadyConnected { - log.Printf("ERROR: failed to connect to %s - %s", - addr, err.Error()) - continue - } - break - } - }(c.RemoteAddr().String()) - } +func (q *Reader) onConnectionMessage(c *Conn, msg *Message) { + atomic.AddInt64(&q.totalRdyCount, -1) + atomic.AddUint64(&q.MessagesReceived, 1) + q.incomingMessages <- msg + q.rdyChan <- c } -func (q *Reader) readLoop(c *nsqConn) { - for { - if atomic.LoadInt32(&c.stopFlag) == 1 || atomic.LoadInt32(&q.stopFlag) == 1 { - // start the connection close - if atomic.LoadInt64(&c.messagesInFlight) == 0 { - q.stopFinishLoop(c) - } 
else { - log.Printf("[%s] delaying close, %d outstanding messages", - c, c.messagesInFlight) - } - goto exit - } - - frameType, data, err := c.readUnpackedResponse() - if err != nil { - handleError(q, c, fmt.Sprintf("[%s] error (%s) reading response %d %s", - c, err.Error(), frameType, data)) - continue +func (q *Reader) onConnectionMessageProcessed(c *Conn, finishedMsg *FinishedMessage) { + if finishedMsg.Success { + if q.VerboseLogging { + log.Printf("[%s] finishing %s", c, finishedMsg.Id) } - - switch frameType { - case FrameTypeMessage: - msg, err := DecodeMessage(data) - msg.cmdChan = c.cmdChan - msg.responseChan = c.finishedMessages - - if err != nil { - handleError(q, c, fmt.Sprintf("[%s] error (%s) decoding message %s", - c, err.Error(), data)) - continue - } - - remain := atomic.AddInt64(&c.rdyCount, -1) - atomic.AddInt64(&q.totalRdyCount, -1) - atomic.AddUint64(&c.messagesReceived, 1) - atomic.AddUint64(&q.MessagesReceived, 1) - atomic.AddInt64(&c.messagesInFlight, 1) - atomic.AddInt64(&q.messagesInFlight, 1) - atomic.StoreInt64(&c.lastMsgTimestamp, time.Now().UnixNano()) - - if q.VerboseLogging { - log.Printf("[%s] (remain %d) FrameTypeMessage: %s - %s", - c, remain, msg.Id, msg.Body) - } - - q.incomingMessages <- msg - c.rdyChan <- c - case FrameTypeResponse: - switch { - case bytes.Equal(data, []byte("CLOSE_WAIT")): - // server is ready for us to close (it ack'd our StartClose) - // we can assume we will not receive any more messages over this channel - // (but we can still write back responses) - log.Printf("[%s] received ACK from nsqd - now in CLOSE_WAIT", c) - atomic.StoreInt32(&c.stopFlag, 1) - case bytes.Equal(data, []byte("_heartbeat_")): - var buf bytes.Buffer - log.Printf("[%s] heartbeat received", c) - err := c.sendCommand(&buf, Nop()) - if err != nil { - handleError(q, c, fmt.Sprintf("[%s] error sending NOP - %s", - c, err.Error())) - goto exit - } - } - case FrameTypeError: - log.Printf("[%s] error from nsqd %s", c, data) - default: - log.Printf("[%s] unknown message type %d", c, frameType) + atomic.AddUint64(&q.MessagesFinished, 1) + } else { + if q.VerboseLogging { + log.Printf("[%s] requeuing %s", c, finishedMsg.Id) } + atomic.AddUint64(&q.MessagesRequeued, 1) } - -exit: - c.wg.Done() - log.Printf("[%s] readLoop exiting", c) + q.backoffChan <- finishedMsg.Success } -func (q *Reader) finishLoop(c *nsqConn) { - var buf bytes.Buffer - - for { - select { - case <-c.exitChan: - log.Printf("[%s] breaking out of finish loop", c) - // Indicate drainReady because we will not pull any more off finishedMessages - close(c.drainReady) - goto exit - case cmd := <-c.cmdChan: - err := c.sendCommand(&buf, cmd) - if err != nil { - log.Printf("[%s] error sending command %s - %s", c, cmd, err) - q.stopFinishLoop(c) - continue - } - case msg := <-c.finishedMessages: - // Decrement this here so it is correct even if we can't respond to nsqd - atomic.AddInt64(&q.messagesInFlight, -1) - atomic.AddInt64(&c.messagesInFlight, -1) - - if msg.Success { - if q.VerboseLogging { - log.Printf("[%s] finishing %s", c, msg.Id) - } - - err := c.sendCommand(&buf, Finish(msg.Id)) - if err != nil { - log.Printf("[%s] error finishing %s - %s", c, msg.Id, err.Error()) - q.stopFinishLoop(c) - continue - } - - atomic.AddUint64(&c.messagesFinished, 1) - atomic.AddUint64(&q.MessagesFinished, 1) - } else { - if q.VerboseLogging { - log.Printf("[%s] requeuing %s", c, msg.Id) - } - - err := c.sendCommand(&buf, Requeue(msg.Id, msg.RequeueDelayMs)) - if err != nil { - log.Printf("[%s] error requeueing %s - %s", 
c, msg.Id, err.Error()) - q.stopFinishLoop(c) - continue - } - - atomic.AddUint64(&c.messagesRequeued, 1) - atomic.AddUint64(&q.MessagesRequeued, 1) - } - - q.backoffChan <- msg.Success - - if atomic.LoadInt64(&c.messagesInFlight) == 0 && - (atomic.LoadInt32(&c.stopFlag) == 1 || atomic.LoadInt32(&q.stopFlag) == 1) { - q.stopFinishLoop(c) - continue - } - } +func (q *Reader) onConnectionResponse(c *Conn, data []byte) { + switch { + case bytes.Equal(data, []byte("CLOSE_WAIT")): + // server is ready for us to close (it ack'd our StartClose) + // we can assume we will not receive any more messages over this channel + // (but we can still write back responses) + log.Printf("[%s] received ACK from nsqd - now in CLOSE_WAIT", c) + c.Stop() } +} -exit: - c.wg.Done() - log.Printf("[%s] finishLoop exiting", c) +func (q *Reader) onConnectionError(c *Conn, data []byte) { + log.Printf("[%s] error from nsqd %s", c, data) } -func (q *Reader) stopFinishLoop(c *nsqConn) { - c.stopper.Do(func() { - log.Printf("[%s] beginning stopFinishLoop", c) - close(c.exitChan) - c.Close() - go q.cleanupConnection(c) - }) +func (q *Reader) onConnectionHeartbeat(c *Conn) { + log.Printf("[%s] heartbeat received", c) } -func (q *Reader) cleanupConnection(c *nsqConn) { - go func() { - <-c.drainReady - ticker := time.NewTicker(100 * time.Millisecond) - // finishLoop has exited, drain any remaining in flight messages - for { - // we're racing with readLoop which potentially has a message - // for handling... - // - // infinitely loop until the connection's waitgroup is satisfied, - // ensuring that both finishLoop and readLoop have exited, at which - // point we can be guaranteed that messagesInFlight accurately - // represents whatever is left... continue until 0. - var msgsInFlight int64 - select { - case <-c.finishedMessages: - msgsInFlight = atomic.AddInt64(&c.messagesInFlight, -1) - case <-ticker.C: - msgsInFlight = atomic.LoadInt64(&c.messagesInFlight) - } - if msgsInFlight > 0 { - log.Printf("[%s] draining... 
waiting for %d messages in flight", c, msgsInFlight) - continue - } - log.Printf("[%s] done draining finishedMessages", c) - ticker.Stop() - return - } - }() +func (q *Reader) onConnectionIOError(c *Conn, err error) { + log.Printf("[%s] IO Error - %s", c, err) + c.Stop() +} - // this blocks until finishLoop and readLoop have exited - c.wg.Wait() +func (q *Reader) onConnectionClosed(c *Conn) { + var hasRDYRetryTimer bool // remove this connections RDY count from the reader's total - rdyCount := atomic.LoadInt64(&c.rdyCount) + rdyCount := c.RDY() atomic.AddInt64(&q.totalRdyCount, -rdyCount) c.Lock() - hasRDYRetryTimer := c.rdyRetryTimer != nil - if c.rdyRetryTimer != nil { + if timer, ok := q.rdyRetryTimers[c.String()]; ok { // stop any pending retry of an old RDY update - c.rdyRetryTimer.Stop() - c.rdyRetryTimer = nil + timer.Stop() + delete(q.rdyRetryTimers, c.String()) + hasRDYRetryTimer = true } c.Unlock() q.Lock() - delete(q.nsqConnections, c.addr) - left := len(q.nsqConnections) + delete(q.connections, c.Address()) + left := len(q.connections) q.Unlock() log.Printf("there are %d connections left alive", left) @@ -1003,6 +604,24 @@ func (q *Reader) cleanupConnection(c *nsqConn) { case q.lookupdRecheckChan <- 1: default: } + } else if numLookupd == 0 && atomic.LoadInt32(&q.stopFlag) == 0 { + // there are no lookupd, try to reconnect after a bit + go func(addr string) { + for { + log.Printf("[%s] re-connecting in 15 seconds...", addr) + time.Sleep(15 * time.Second) + if atomic.LoadInt32(&q.stopFlag) == 1 { + break + } + err := q.ConnectToNSQ(addr) + if err != nil && err != ErrAlreadyConnected { + log.Printf("ERROR: failed to connect to %s - %s", + addr, err.Error()) + continue + } + break + } + }(c.RemoteAddr().String()) } } @@ -1032,13 +651,13 @@ func (q *Reader) rdyLoop() { for { select { case <-backoffTimerChan: - var choice *nsqConn + var choice *Conn q.RLock() // pick a random connection to test the waters var i int - idx := rand.Intn(len(q.nsqConnections)) - for _, c := range q.nsqConnections { + idx := rand.Intn(len(q.connections)) + for _, c := range q.connections { if i == idx { choice = c break @@ -1062,8 +681,8 @@ func (q *Reader) rdyLoop() { } // send ready immediately - remain := atomic.LoadInt64(&c.rdyCount) - lastRdyCount := atomic.LoadInt64(&c.lastRdyCount) + remain := c.RDY() + lastRdyCount := c.LastRDY() count := q.ConnectionMaxInFlight() // refill when at 1, or at 25%, or if connections have changed and we have too many RDY if remain <= 1 || remain < (lastRdyCount/4) || (count > 0 && count < remain) { @@ -1108,7 +727,7 @@ func (q *Reader) rdyLoop() { if backoffCounter == 0 && backoffUpdated { count := q.ConnectionMaxInFlight() q.RLock() - for _, c := range q.nsqConnections { + for _, c := range q.connections { if q.VerboseLogging { log.Printf("[%s] exiting backoff. returning to RDY %d", c, count) } @@ -1130,7 +749,7 @@ func (q *Reader) rdyLoop() { // send RDY 0 immediately (to *all* connections) q.RLock() - for _, c := range q.nsqConnections { + for _, c := range q.connections { if q.VerboseLogging { log.Printf("[%s] in backoff. 
sending RDY 0", c) } @@ -1153,27 +772,27 @@ exit: log.Printf("rdyLoop exiting") } -func (q *Reader) updateRDY(c *nsqConn, count int64) error { - if atomic.LoadInt32(&c.stopFlag) != 0 { +func (q *Reader) updateRDY(c *Conn, count int64) error { + if c.IsStopping() { return nil } // never exceed the nsqd's configured max RDY count - if count > c.maxRdyCount { - count = c.maxRdyCount + if count > c.MaxRDY() { + count = c.MaxRDY() } // stop any pending retry of an old RDY update c.Lock() - if c.rdyRetryTimer != nil { - c.rdyRetryTimer.Stop() - c.rdyRetryTimer = nil + if timer, ok := q.rdyRetryTimers[c.String()]; ok { + timer.Stop() + delete(q.rdyRetryTimers, c.String()) } c.Unlock() // never exceed our global max in flight. truncate if possible. // this could help a new connection get partial max-in-flight - rdyCount := atomic.LoadInt64(&c.rdyCount) + rdyCount := c.RDY() maxPossibleRdy := int64(q.MaxInFlight()) - atomic.LoadInt64(&q.totalRdyCount) + rdyCount if maxPossibleRdy > 0 && maxPossibleRdy < count { count = maxPossibleRdy @@ -1184,9 +803,10 @@ func (q *Reader) updateRDY(c *nsqConn, count int64) error { // in order to prevent eternal starvation we reschedule this attempt // (if any other RDY update succeeds this timer will be stopped) c.Lock() - c.rdyRetryTimer = time.AfterFunc(5*time.Second, func() { - q.updateRDY(c, count) - }) + q.rdyRetryTimers[c.String()] = time.AfterFunc(5*time.Second, + func() { + q.updateRDY(c, count) + }) c.Unlock() } return ErrOverMaxInFlight @@ -1195,21 +815,18 @@ func (q *Reader) updateRDY(c *nsqConn, count int64) error { return q.sendRDY(c, count) } -func (q *Reader) sendRDY(c *nsqConn, count int64) error { - var buf bytes.Buffer - - if count == 0 && atomic.LoadInt64(&c.lastRdyCount) == 0 { +func (q *Reader) sendRDY(c *Conn, count int64) error { + if count == 0 && c.LastRDY() == 0 { // no need to send. 
It's already that RDY count return nil } - atomic.AddInt64(&q.totalRdyCount, -atomic.LoadInt64(&c.rdyCount)+count) - atomic.StoreInt64(&c.rdyCount, count) - atomic.StoreInt64(&c.lastRdyCount, count) - err := c.sendCommand(&buf, Ready(int(count))) + atomic.AddInt64(&q.totalRdyCount, -c.RDY()+count) + c.SetRDY(count) + err := c.SendCommand(Ready(int(count))) if err != nil { - handleError(q, c, fmt.Sprintf("[%s] error sending RDY %d - %s", - c, count, err.Error())) + log.Printf("[%s] error sending RDY %d - %s", c, count, err) + c.Stop() return err } return nil @@ -1221,7 +838,7 @@ func (q *Reader) redistributeRDY() { } q.RLock() - numConns := len(q.nsqConnections) + numConns := len(q.connections) q.RUnlock() maxInFlight := q.MaxInFlight() if numConns > maxInFlight { @@ -1240,11 +857,10 @@ func (q *Reader) redistributeRDY() { } q.RLock() - possibleConns := make([]*nsqConn, 0, len(q.nsqConnections)) - for _, c := range q.nsqConnections { - lastMsgTimestamp := atomic.LoadInt64(&c.lastMsgTimestamp) - lastMsgDuration := time.Now().Sub(time.Unix(0, lastMsgTimestamp)) - rdyCount := atomic.LoadInt64(&c.rdyCount) + possibleConns := make([]*Conn, 0, len(q.connections)) + for _, c := range q.connections { + lastMsgDuration := time.Now().Sub(c.LastMessageTime()) + rdyCount := c.RDY() if q.VerboseLogging { log.Printf("[%s] rdy: %d (last message received %s)", c, rdyCount, lastMsgDuration) @@ -1275,8 +891,6 @@ func (q *Reader) redistributeRDY() { // Stop will gracefully stop the Reader func (q *Reader) Stop() { - var buf bytes.Buffer - if !atomic.CompareAndSwapInt32(&q.stopFlag, 0, 1) { return } @@ -1284,17 +898,18 @@ func (q *Reader) Stop() { log.Printf("stopping reader") q.RLock() - l := len(q.nsqConnections) + l := len(q.connections) q.RUnlock() if l == 0 { q.stopHandlers() } else { q.RLock() - for _, c := range q.nsqConnections { - err := c.sendCommand(&buf, StartClose()) + for _, c := range q.connections { + err := c.SendCommand(StartClose()) if err != nil { log.Printf("[%s] failed to start close - %s", c, err.Error()) + c.Stop() } } q.RUnlock() diff --git a/writer.go b/writer.go index ef7b116e..91885375 100644 --- a/writer.go +++ b/writer.go @@ -1,13 +1,9 @@ package nsq import ( - "bufio" - "bytes" + "crypto/tls" "errors" "log" - "net" - "os" - "strings" "sync" "sync/atomic" "time" @@ -19,23 +15,45 @@ import ( // and will lazily connect to that instance (and re-connect) // when Publish commands are executed. 
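(A hedged usage sketch for the refactored Writer, based on the PublishAsync and WriterTransaction declarations below; the import path, nsqd address, topic, and payload are placeholders.)

package main

import (
	"log"

	nsq "github.com/bitly/go-nsq" // import path assumed for illustration
)

func main() {
	w := nsq.NewWriter("127.0.0.1:4150") // hypothetical nsqd address
	doneChan := make(chan *nsq.WriterTransaction, 1)

	// the Writer connects lazily on the first publish (per the doc comment above)
	if err := w.PublishAsync("events", []byte("hello"), doneChan); err != nil {
		log.Fatal(err)
	}
	t := <-doneChan
	if t.Error != nil {
		log.Printf("publish failed: %s", t.Error)
	}
	w.Stop()
}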
diff --git a/writer.go b/writer.go
index ef7b116e..91885375 100644
--- a/writer.go
+++ b/writer.go
@@ -1,13 +1,9 @@
 package nsq
 
 import (
-	"bufio"
-	"bytes"
+	"crypto/tls"
 	"errors"
 	"log"
-	"net"
-	"os"
-	"strings"
 	"sync"
 	"sync/atomic"
 	"time"
@@ -19,23 +15,45 @@ import (
 // and will lazily connect to that instance (and re-connect)
 // when Publish commands are executed.
 type Writer struct {
-	net.Conn
+	Addr string
+	conn *Conn
 
-	WriteTimeout      time.Duration
-	Addr              string
-	HeartbeatInterval time.Duration
-	ShortIdentifier   string
-	LongIdentifier    string
+	responseChan  chan []byte
+	errorChan     chan []byte
+	ioErrorChan   chan error
+	heartbeatChan chan int
+	closeChan     chan int
+
+	// network deadlines
+	ReadTimeout  time.Duration // the deadline set for network reads
+	WriteTimeout time.Duration // the deadline set for network writes
+
+	ShortIdentifier string // an identifier to send to nsqd when connecting (defaults: short hostname)
+	LongIdentifier  string // an identifier to send to nsqd when connecting (defaults: long hostname)
+
+	HeartbeatInterval time.Duration // duration of time between heartbeats
+	UserAgent         string        // a string identifying the agent for this client in the spirit of HTTP (default: "/")
+
+	// transport layer security
+	TLSv1     bool        // negotiate enabling TLS
+	TLSConfig *tls.Config // client TLS configuration
+
+	// compression
+	Deflate      bool // negotiate enabling Deflate compression
+	DeflateLevel int  // the compression level to negotiate for Deflate
+	Snappy       bool // negotiate enabling Snappy compression
+
+	// output buffering
+	OutputBufferSize    int64         // size of the buffer (in bytes) used by nsqd for buffering writes to this connection
+	OutputBufferTimeout time.Duration // timeout (in ms) used by nsqd before flushing buffered writes (set to 0 to disable). Warning: configuring clients with an extremely low (< 25ms) output_buffer_timeout has a significant effect on nsqd CPU usage (particularly with > 50 clients connected).
 
 	concurrentWriters int32
 
 	transactionChan chan *WriterTransaction
-	dataChan        chan []byte
 	transactions    []*WriterTransaction
 	state           int32
 	stopFlag        int32
 	exitChan        chan int
-	closeChan       chan int
 	wg              sync.WaitGroup
 }
@@ -65,22 +83,16 @@ var ErrStopped = errors.New("stopped")
 
 // NewWriter returns an instance of Writer for the specified address
 func NewWriter(addr string) *Writer {
-	hostname, err := os.Hostname()
-	if err != nil {
-		log.Fatalf("ERROR: unable to get hostname %s", err.Error())
-	}
 	return &Writer{
+		Addr: addr,
+
 		transactionChan: make(chan *WriterTransaction),
 		exitChan:        make(chan int),
+		responseChan:    make(chan []byte),
+		errorChan:       make(chan []byte),
+		ioErrorChan:     make(chan error),
+		heartbeatChan:   make(chan int),
 		closeChan:       make(chan int),
-		dataChan:        make(chan []byte),
-
-		// can be overriden before connecting
-		Addr:              addr,
-		WriteTimeout:      time.Second,
-		HeartbeatInterval: DefaultClientTimeout / 2,
-		ShortIdentifier:   strings.Split(hostname, ".")[0],
-		LongIdentifier:    hostname,
 	}
 }
@@ -94,6 +106,7 @@ func (w *Writer) Stop() {
 	if !atomic.CompareAndSwapInt32(&w.stopFlag, 0, 1) {
 		return
 	}
+	close(w.exitChan)
 	w.close()
 	w.wg.Wait()
 }
@@ -105,7 +118,8 @@ func (w *Writer) Stop() {
 // the supplied `doneChan` (if specified)
 // will receive a `WriterTransaction` instance with the supplied variadic arguments
 // (and the response `FrameType`, `Data`, and `Error`)
-func (w *Writer) PublishAsync(topic string, body []byte, doneChan chan *WriterTransaction, args ...interface{}) error {
+func (w *Writer) PublishAsync(topic string, body []byte, doneChan chan *WriterTransaction,
+	args ...interface{}) error {
 	return w.sendCommandAsync(Publish(topic, body), doneChan, args)
 }
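
With the configuration now exposed as plain struct fields, a producer is set up by assignment before the first publish triggers the lazy connect. A usage sketch based on the NewWriter/PublishAsync/Stop signatures in this diff; the import path, address, and topic name are illustrative and error handling is trimmed:

package main

import (
	"log"
	"time"

	"github.com/bitly/go-nsq" // illustrative import path
)

func main() {
	w := nsq.NewWriter("127.0.0.1:4150")
	w.HeartbeatInterval = 30 * time.Second
	w.WriteTimeout = time.Second

	done := make(chan *nsq.WriterTransaction, 1)
	// the variadic args are handed back on the transaction for correlation
	if err := w.PublishAsync("test_topic", []byte("hello"), done, "msg-1"); err != nil {
		log.Fatal(err)
	}
	t := <-done
	log.Printf("frame=%d err=%v", t.FrameType, t.Error)

	w.Stop()
}
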
@@ -116,7 +130,8 @@ func (w *Writer) PublishAsync(topic string, body [][]byte, doneChan chan *WriterTr
 // the supplied `doneChan` (if specified)
 // will receive a `WriterTransaction` instance with the supplied variadic arguments
 // (and the response `FrameType`, `Data`, and `Error`)
-func (w *Writer) MultiPublishAsync(topic string, body [][]byte, doneChan chan *WriterTransaction, args ...interface{}) error {
+func (w *Writer) MultiPublishAsync(topic string, body [][]byte, doneChan chan *WriterTransaction,
+	args ...interface{}) error {
 	cmd, err := MultiPublish(topic, body)
 	if err != nil {
 		return err
@@ -151,7 +166,8 @@ func (w *Writer) sendCommand(cmd *Command) (int32, []byte, error) {
 	return t.FrameType, t.Data, t.Error
 }
 
-func (w *Writer) sendCommandAsync(cmd *Command, doneChan chan *WriterTransaction, args []interface{}) error {
+func (w *Writer) sendCommandAsync(cmd *Command, doneChan chan *WriterTransaction,
+	args []interface{}) error {
 	// keep track of how many outstanding writers we're dealing with
 	// in order to later ensure that we clean them all up...
 	atomic.AddInt32(&w.concurrentWriters, 1)
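
sendCommand remains a thin synchronous wrapper over sendCommandAsync: allocate a buffered done channel, queue the transaction, block on the reply, and return its FrameType/Data/Error. A self-contained sketch of that shape; the real router goroutine is replaced here by an inline completion, and all names are illustrative:

package main

import "fmt"

// transaction carries the same reply fields the Writer's transactions do.
type transaction struct {
	FrameType int32
	Data      []byte
	Error     error
	doneChan  chan *transaction
}

// sendAsync stands in for sendCommandAsync: the real code hands the
// transaction to the router goroutine; here it is completed immediately just
// to show the shape of the pattern.
func sendAsync(done chan *transaction) error {
	t := &transaction{FrameType: 0, Data: []byte("OK"), doneChan: done}
	go func() { t.doneChan <- t }()
	return nil
}

// send is the synchronous wrapper: a buffered done channel plus a blocking
// receive, mirroring Writer.sendCommand over sendCommandAsync.
func send() (int32, []byte, error) {
	done := make(chan *transaction, 1)
	if err := sendAsync(done); err != nil {
		return -1, nil, err
	}
	t := <-done
	return t.FrameType, t.Data, t.Error
}

func main() {
	frameType, data, err := send()
	fmt.Println(frameType, string(data), err)
}
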
@@ -190,68 +206,85 @@ func (w *Writer) connect() error {
 	}
 
 	log.Printf("[%s] connecting...", w)
-	conn, err := net.DialTimeout("tcp", w.Addr, time.Second*5)
-	if err != nil {
-		log.Printf("ERROR: [%s] failed to dial %s - %s", w, w.Addr, err)
-		atomic.StoreInt32(&w.state, StateInit)
-		return err
+
+	conn := NewConn(w.Addr, "", "")
+	if w.ReadTimeout > 0 {
+		conn.ReadTimeout = w.ReadTimeout
+	}
+	if w.WriteTimeout > 0 {
+		conn.WriteTimeout = w.WriteTimeout
+	}
+	conn.Deflate = w.Deflate
+	if w.DeflateLevel > 0 {
+		conn.DeflateLevel = w.DeflateLevel
+	}
+	conn.Snappy = w.Snappy
+	conn.TLSv1 = w.TLSv1
+	conn.TLSConfig = w.TLSConfig
+	if w.ShortIdentifier != "" {
+		conn.ShortIdentifier = w.ShortIdentifier
+	}
+	if w.LongIdentifier != "" {
+		conn.LongIdentifier = w.LongIdentifier
+	}
+	if w.HeartbeatInterval != 0 {
+		conn.HeartbeatInterval = w.HeartbeatInterval
+	}
+	if w.OutputBufferSize != 0 {
+		conn.OutputBufferSize = w.OutputBufferSize
+	}
+	if w.OutputBufferTimeout != 0 {
+		conn.OutputBufferTimeout = w.OutputBufferTimeout
+	}
+	if w.UserAgent != "" {
+		conn.UserAgent = w.UserAgent
 	}
 
-	w.closeChan = make(chan int)
-	w.Conn = conn
+	conn.ResponseCB = func(c *Conn, data []byte) {
+		w.responseChan <- data
+	}
 
-	w.SetWriteDeadline(time.Now().Add(w.WriteTimeout))
-	_, err = w.Write(MagicV2)
-	if err != nil {
-		log.Printf("ERROR: [%s] failed to write magic - %s", w, err)
-		w.close()
-		return err
+	conn.ErrorCB = func(c *Conn, data []byte) {
+		w.errorChan <- data
 	}
 
-	ci := make(map[string]interface{})
-	ci["short_id"] = w.ShortIdentifier
-	ci["long_id"] = w.LongIdentifier
-	ci["heartbeat_interval"] = int64(w.HeartbeatInterval / time.Millisecond)
-	ci["feature_negotiation"] = true
-	cmd, err := Identify(ci)
-	if err != nil {
-		log.Printf("ERROR: [%s] failed to create IDENTIFY command - %s", w, err)
-		w.close()
-		return err
+	conn.HeartbeatCB = func(c *Conn) {
+		w.heartbeatChan <- 1
 	}
 
-	w.SetWriteDeadline(time.Now().Add(w.WriteTimeout))
-	err = cmd.Write(w)
-	if err != nil {
-		log.Printf("ERROR: [%s] failed to write IDENTIFY - %s", w, err)
-		w.close()
-		return err
+	conn.IOErrorCB = func(c *Conn, err error) {
+		w.ioErrorChan <- err
 	}
 
-	w.SetReadDeadline(time.Now().Add(w.HeartbeatInterval * 2))
-	resp, err := ReadResponse(w)
-	if err != nil {
-		log.Printf("ERROR: [%s] failed to read IDENTIFY response - %s", w, err)
-		w.close()
-		return err
+	conn.CloseCB = func(c *Conn) {
+		w.closeChan <- 1
 	}
 
-	frameType, data, err := UnpackResponse(resp)
+	resp, err := conn.Connect()
 	if err != nil {
-		log.Printf("ERROR: [%s] failed to unpack IDENTIFY response - %s", w, resp)
-		w.close()
+		conn.Close()
+		log.Printf("ERROR: [%s] failed to IDENTIFY - %s", w, err)
+		atomic.StoreInt32(&w.state, StateInit)
 		return err
 	}
 
-	if frameType == FrameTypeError {
-		log.Printf("ERROR: [%s] IDENTIFY returned error response - %s", w, data)
-		w.close()
-		return errors.New(string(data))
+	if resp != nil {
+		log.Printf("[%s] IDENTIFY response: %+v", w, resp)
+		if resp.TLSv1 {
+			log.Printf("[%s] upgrading to TLS", w)
+		}
+		if resp.Deflate {
+			log.Printf("[%s] upgrading to Deflate", w)
+		}
+		if resp.Snappy {
+			log.Printf("[%s] upgrading to Snappy", w)
+		}
 	}
 
-	w.wg.Add(2)
-	go w.readLoop()
-	go w.messageRouter()
+	w.conn = conn
+
+	w.wg.Add(1)
+	go w.router()
 
 	return nil
 }
@@ -260,8 +293,7 @@ func (w *Writer) close() {
 	if !atomic.CompareAndSwapInt32(&w.state, StateConnected, StateDisconnected) {
 		return
 	}
-	close(w.closeChan)
-	w.Conn.Close()
+	w.conn.Close()
 	go func() {
 		// we need to handle this in a goroutine so we don't
 		// block the caller from making progress
@@ -270,46 +302,29 @@
 	}()
 }
 
-func (w *Writer) messageRouter() {
+func (w *Writer) router() {
 	for {
 		select {
 		case t := <-w.transactionChan:
 			w.transactions = append(w.transactions, t)
-			w.SetWriteDeadline(time.Now().Add(w.WriteTimeout))
-			err := t.cmd.Write(w.Conn)
+			err := w.conn.SendCommand(t.cmd)
 			if err != nil {
 				log.Printf("ERROR: [%s] failed writing %s", w, err)
 				w.close()
-				goto exit
-			}
-		case buf := <-w.dataChan:
-			frameType, data, err := UnpackResponse(buf)
-			if err != nil {
-				log.Printf("ERROR: [%s] failed (%s) unpacking response %d %s", w, err, frameType, data)
-				w.close()
-				goto exit
-			}
-
-			if frameType == FrameTypeResponse && bytes.Equal(data, []byte("_heartbeat_")) {
-				log.Printf("[%s] heartbeat received", w)
-				w.SetWriteDeadline(time.Now().Add(w.WriteTimeout))
-				err := Nop().Write(w.Conn)
-				if err != nil {
-					log.Printf("ERROR: [%s] failed sending heartbeat - %s", w, err)
-					w.close()
-					goto exit
-				}
-				continue
			}
-
-			t := w.transactions[0]
-			w.transactions = w.transactions[1:]
-			t.FrameType = frameType
-			t.Data = data
-			t.Error = nil
-			t.finish()
+		case data := <-w.responseChan:
+			w.popTransaction(FrameTypeResponse, data)
+		case data := <-w.errorChan:
+			w.popTransaction(FrameTypeError, data)
+		case <-w.heartbeatChan:
+			log.Printf("[%s] heartbeat received", w)
+		case err := <-w.ioErrorChan:
+			log.Printf("ERROR: [%s] %s", w, err)
+			w.close()
 		case <-w.closeChan:
 			goto exit
+		case <-w.exitChan:
+			goto exit
 		}
 	}
 
@@ -319,6 +334,15 @@ exit:
 	log.Printf("[%s] exiting messageRouter()", w)
 }
 
+func (w *Writer) popTransaction(frameType int32, data []byte) {
+	t := w.transactions[0]
+	w.transactions = w.transactions[1:]
+	t.FrameType = frameType
+	t.Data = data
+	t.Error = nil
+	t.finish()
+}
+
 func (w *Writer) transactionCleanup() {
 	// clean up transactions we can easily account for
 	for _, t := range w.transactions {
@@ -346,27 +370,3 @@ func (w *Writer) transactionCleanup() {
 		}
 	}
 }
-
-func (w *Writer) readLoop() {
-	rbuf := bufio.NewReader(w.Conn)
-	for {
-		w.SetReadDeadline(time.Now().Add(w.HeartbeatInterval * 2))
-		resp, err := ReadResponse(rbuf)
-		if err != nil {
-			if !strings.Contains(err.Error(), "use of closed network connection") {
-				log.Printf("ERROR: [%s] reading response %s", w, err)
-			}
-			w.close()
-			goto exit
-		}
-		select {
-		case w.dataChan <- resp:
-		case <-w.closeChan:
-			goto exit
-		}
-	}
-
-exit:
-	w.wg.Done()
-	log.Printf("[%s] exiting readLoop()", w)
-}
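
The rewritten connect() adapts Conn's callbacks onto channels so that router() is the only goroutine touching w.transactions. A condensed, self-contained sketch of that callback-to-channel pattern, with illustrative types rather than the library's:

package main

import "fmt"

// conn stands in for the library's *Conn: it exposes hooks that fire from its
// read loop. writer adapts those hooks onto channels so a single router
// goroutine owns all mutable state (the pending transactions in the real code).
type conn struct {
	onResponse func(data []byte)
	onIOError  func(err error)
}

type writer struct {
	responseChan chan []byte
	ioErrorChan  chan error
	exitChan     chan int
}

func (w *writer) wire(c *conn) {
	c.onResponse = func(data []byte) { w.responseChan <- data }
	c.onIOError = func(err error) { w.ioErrorChan <- err }
}

func (w *writer) router() {
	for {
		select {
		case data := <-w.responseChan:
			fmt.Printf("response: %s\n", data)
		case err := <-w.ioErrorChan:
			fmt.Printf("io error: %v\n", err)
			return
		case <-w.exitChan:
			return
		}
	}
}

func main() {
	w := &writer{
		responseChan: make(chan []byte),
		ioErrorChan:  make(chan error),
		exitChan:     make(chan int),
	}
	c := &conn{}
	w.wire(c)
	go func() {
		c.onResponse([]byte("OK")) // simulates a frame arriving on the read loop
		close(w.exitChan)
	}()
	w.router()
}
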
nil") } - if err.Error() != "E_BAD_BODY IDENTIFY heartbeat interval (100) is invalid" { + if identifyError, ok := err.(ErrIdentify); !ok || + identifyError.Reason != "E_BAD_BODY IDENTIFY heartbeat interval (100) is invalid" { t.Fatalf("wrong error - %s", err) }