Skip to content

Commit

Permalink
Merge pull request #25 from stampzilla/bugfix/lostMessages
Browse files Browse the repository at this point in the history
Fixed bug with lost messages
  • Loading branch information
jonaz authored Mar 26, 2017
2 parents 7b58369 + 62b6c90 commit c389b24
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 30 deletions.
58 changes: 42 additions & 16 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,28 @@ type sendPackage struct {
timeout time.Duration

returnChan chan *serialapi.Message
sync.RWMutex
}

func (sp *sendPackage) Close() {
sp.RLock()
if sp.returnChan == nil {
sp.RUnlock()
return
}
sp.RUnlock()

sp.Lock()
close(sp.returnChan)
sp.returnChan = nil
sp.Unlock()
}

func (sp *sendPackage) IsOpen() bool {
sp.RLock()
defer sp.RUnlock()

return sp.returnChan != nil
}

func NewConnection() *Connection {
Expand Down Expand Up @@ -136,16 +158,23 @@ func (conn *Connection) Writer() {
go conn.timeoutWorker(pkg)
}

abort := make(chan struct{})
go func() {
select {
case result := <-conn.lastResult:
pkg.result = result
case <-time.After(time.Second):
logrus.Warn("Send timeout")
// SEND TIMEOUT
}
close(abort)
}()

logrus.Debugf("Write: %x", pkg.message)
conn.readWriteCloser.Write(pkg.message)
conn.Unlock()

select {
case result := <-conn.lastResult:
pkg.result = result
case <-time.After(time.Second):
// SEND TIMEOUT
}
<-abort
conn.lastCommand = ""
}
}
Expand Down Expand Up @@ -193,8 +222,7 @@ func (conn *Connection) Reader() error {
if msg.IsNAK() {
logrus.Warnf("Command failed: %s - %#v", c.uuid, c)
delete(conn.inFlight, index)
close(c.returnChan)
c.returnChan = nil
c.Close()
}

if msg.IsCAN() {
Expand All @@ -206,6 +234,7 @@ func (conn *Connection) Reader() error {
}
conn.RUnlock()
}

}

// The message is not compleatly read yet, wait for some more data
Expand All @@ -228,14 +257,14 @@ func (conn *Connection) Reader() error {
continue
}

if c.returnChan != nil {
if c.IsOpen() {
select {
case c.returnChan <- msg: // Try to deliver message
default:
case <-time.After(time.Second):
logrus.Warnf("Timeout writing response to requester: %#v", msg)
}

close(c.returnChan)
c.returnChan = nil
c.Close()
}
delete(conn.inFlight, index)
}
Expand Down Expand Up @@ -318,10 +347,7 @@ func (conn *Connection) timeoutWorker(sp *sendPackage) {
if index == sp.uuid {
logrus.Warnf("TIMEOUT: %s", sp.uuid)
delete(conn.inFlight, sp.uuid)
if c.returnChan != nil {
close(c.returnChan)
c.returnChan = nil
}
c.Close()
}
}
conn.Unlock()
Expand Down
15 changes: 8 additions & 7 deletions gozwave_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package gozwave
import (
"fmt"
"io"
"log"
"os"
"testing"
"text/tabwriter"
Expand All @@ -23,15 +22,16 @@ func TestConnect(t *testing.T) {
controller, err := ConnectWithCustomPortOpener("/test", "", mockPO)

assert.NoError(t, err)
log.Println("controller: ", controller)
//log.Println("controller: ", controller)

reply(t, mockPO.mockSerial.getFromWrite, mockPO.mockSerial.sendToRead) // Start up a conversation loop

// TODO: make something better than a sleep here
time.Sleep(10 * time.Millisecond)
time.Sleep(100 * time.Millisecond)

n := controller.Nodes.Get(8)
if assert.NotNil(t, n.ProtocolInfo) {
assert.NotNil(t, n)
if assert.NotNil(t, n.ProtocolInfo()) {
assert.Equal(t,
serialapi.FuncGetNodeProtocolInfo{
Listening: false,
Expand All @@ -45,7 +45,7 @@ func TestConnect(t *testing.T) {
Generic: 0x40,
Specific: 0x0,
},
*n.ProtocolInfo,
*n.ProtocolInfo(),
)
}
}
Expand All @@ -54,10 +54,11 @@ func reply(t *testing.T, c chan []byte, w chan string) {
replies := map[string][]string{
"06": []string{},
"01030002fe": []string{ // Request discovery nodes
"06", // Ack
"0125010205001d8000000000000000000000000000000000000000000000000000000000050044", // Answer with node 8 active
},
"0104004108b2": []string{ // Request node information node 8
"06", // ack
"06", // Ack
"01080141539c0004403c", // Answer with node information
},
}
Expand All @@ -82,7 +83,7 @@ func reply(t *testing.T, c chan []byte, w chan string) {
}

if len(reads) == len(replies) { // We got all messages expected. Exit the loop
<-time.After(time.Second)
//<-time.After(time.Second)
return
}
break
Expand Down
15 changes: 8 additions & 7 deletions nodes/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ type Node struct {
Id int `json:"id"`
IsAwake bool `json:"is_awake"`

ProtocolInfo *serialapi.FuncGetNodeProtocolInfo
protocolInfo *serialapi.FuncGetNodeProtocolInfo
ManufacurerSpecific *reports.ManufacturerSpecific

//Device *database.Device
Expand Down Expand Up @@ -149,7 +149,7 @@ func (n *Node) Identify() {
defer logrus.Infof("Ended identification on node %d", n.Id)

for {
if n.ProtocolInfo == nil {
if n.ProtocolInfo() == nil {
resp, err := n.RequestProtocolInfo()
if err != nil {
logrus.Errorf("Node ident: Failed RequestProtocolInfo: %s", err.Error())
Expand All @@ -158,18 +158,19 @@ func (n *Node) Identify() {
}

n.Lock()
n.ProtocolInfo = resp
n.protocolInfo = resp
n.IsAwake = resp.Listening
n.pushEvent(events.NodeUpdated{
Address: n.Id,
})
n.Unlock()

}

// set basic commandClasses
classes := database.GetMandatoryCommandClasses(n.ProtocolInfo().Generic, n.ProtocolInfo().Specific)

n.Lock()
n.CommandClasses = database.GetMandatoryCommandClasses(n.ProtocolInfo.Generic, n.ProtocolInfo.Specific)
n.CommandClasses = classes
n.Unlock()

//<-self.Connection.SendRaw([]byte{serialapi.GetNodeProtocolInfo, byte(index + 1)}) // Request node information
Expand Down Expand Up @@ -313,9 +314,9 @@ func (n *Node) HasCommand(c commands.ZWaveCommand) bool {
}

func (n *Node) IsDeviceClass(generic, specific byte) bool {
if n.ProtocolInfo == nil {
if n.ProtocolInfo() == nil {
return false
}

return n.ProtocolInfo.Generic == generic && n.ProtocolInfo.Specific == specific
return n.ProtocolInfo().Generic == generic && n.ProtocolInfo().Specific == specific
}
7 changes: 7 additions & 0 deletions nodes/protocol_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@ import (
"github.com/stampzilla/gozwave/serialapi"
)

func (n *Node) ProtocolInfo() *serialapi.FuncGetNodeProtocolInfo {
n.RLock()
defer n.RUnlock()

return n.protocolInfo
}

func (n *Node) RequestProtocolInfo() (*serialapi.FuncGetNodeProtocolInfo, error) {
cmd := serialapi.NewRaw(
[]byte{
Expand Down

0 comments on commit c389b24

Please sign in to comment.