Skip to content

Commit

Permalink
feat(mqtt): Use Token.WaitTimeout
Browse files Browse the repository at this point in the history
When connecting to a broker and publishing a message, use the
`Token.WaitTimeout` method instead of `Token.Wait`. `Token.Wait` waits
indefinitely, which can lead to situations when the cleint never
succeeds in connecting or publishing.

The timeout for each operation can be configured independently by
setting `mqtt-connect-timeout` and `mqtt-publish-timeout`. Both values
default to 30 seconds. The flags are hidden, as they should not commonly
be required to be changed by users.

Signed-off-by: Link Dupont <[email protected]>
  • Loading branch information
subpop committed Dec 6, 2023
1 parent 463df06 commit 0ad12f9
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 13 deletions.
53 changes: 44 additions & 9 deletions cmd/yggd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,18 @@ func main() {
Value: 0 * time.Second,
Hidden: true,
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: "mqtt-connect-timeout",
Usage: "Sets the time to wait before giving up to `DURATION` when connecting to an MQTT broker",
Value: 30 * time.Second,
Hidden: true,
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: "mqtt-publish-timeout",
Usage: "Sets the time to wait before giving up to `DURATION` when publishing a message to an MQTT broker",
Value: 30 * time.Second,
Hidden: true,
}),
}

// This BeforeFunc will load flag values from a config file only if the
Expand Down Expand Up @@ -277,7 +289,11 @@ func main() {
})
log.Tracef("subscribed to topic: %v", topic)

go publishConnectionStatus(client, d.makeDispatchersMap())
go publishConnectionStatus(
client,
d.makeDispatchersMap(),
c.Duration("mqtt-publish-timeout"),
)
})
mqttClientOpts.SetDefaultPublishHandler(func(c mqtt.Client, m mqtt.Message) {
log.Errorf("unhandled message: %v", string(m.Payload()))
Expand Down Expand Up @@ -324,7 +340,18 @@ func main() {
)

mqttClient := mqtt.NewClient(mqttClientOpts)
if token := mqttClient.Connect(); token.Wait() && token.Error() != nil {
log.Infof("connecting to broker: %v", c.StringSlice("broker"))
token := mqttClient.Connect()
if !token.WaitTimeout(c.Duration("mqtt-connect-timeout")) {
return cli.Exit(
fmt.Errorf(
"cannot connect to broker: connection timeout: %v elapsed",
c.Duration("mqtt-connect-timeout"),
),
1,
)
}
if token.Error() != nil {
return cli.Exit(fmt.Errorf("cannot connect to broker: %w", token.Error()), 1)
}

Expand All @@ -350,7 +377,11 @@ func main() {
}
}
prevDispatchersHash.Store(sum)
go publishConnectionStatus(mqttClient, dispatchers)
go publishConnectionStatus(
mqttClient,
dispatchers,
c.Duration("mqtt-publish-timeout"),
)
}
}()

Expand All @@ -360,7 +391,7 @@ func main() {

// Start a goroutine that receives yggdrasil.Data values on a 'recv'
// channel and publish them to MQTT.
go publishReceivedData(mqttClient, d.recvQ)
go publishReceivedData(mqttClient, d.recvQ, c.Duration("mqtt-publish-timeout"))

// Locate and start worker child processes.
workerPath := filepath.Join(yggdrasil.LibexecDir, yggdrasil.LongName)
Expand Down Expand Up @@ -402,21 +433,25 @@ func main() {
// Start a goroutine that watches the tags file for write events and
// publishes connection status messages when the file changes.
go func() {
c := make(chan notify.EventInfo, 1)
events := make(chan notify.EventInfo, 1)

fp := filepath.Join(yggdrasil.SysconfDir, yggdrasil.LongName, "tags.toml")

if err := notify.Watch(fp, c, notify.InCloseWrite, notify.InDelete); err != nil {
if err := notify.Watch(fp, events, notify.InCloseWrite, notify.InDelete); err != nil {
log.Infof("cannot start watching '%v': %v", fp, err)
return
}
defer notify.Stop(c)
defer notify.Stop(events)

for e := range c {
for e := range events {
log.Debugf("received inotify event %v", e.Event())
switch e.Event() {
case notify.InCloseWrite, notify.InDelete:
go publishConnectionStatus(mqttClient, d.makeDispatchersMap())
go publishConnectionStatus(
mqttClient,
d.makeDispatchersMap(),
c.Duration("mqtt-publish-timeout"),
)
}
}
}()
Expand Down
20 changes: 16 additions & 4 deletions cmd/yggd/mqtt.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,11 @@ func handleControlMessage(client mqtt.Client, msg mqtt.Message) {
}
}

func publishConnectionStatus(c mqtt.Client, dispatchers map[string]map[string]string) {
func publishConnectionStatus(
c mqtt.Client,
dispatchers map[string]map[string]string,
timeout time.Duration,
) {
facts, err := yggdrasil.GetCanonicalFacts()
if err != nil {
log.Errorf("cannot get canonical facts: %v", err)
Expand Down Expand Up @@ -125,13 +129,17 @@ func publishConnectionStatus(c mqtt.Client, dispatchers map[string]map[string]st

topic := fmt.Sprintf("%v/%v/control/out", yggdrasil.TopicPrefix, ClientID)

if token := c.Publish(topic, 1, false, data); token.Wait() && token.Error() != nil {
token := c.Publish(topic, 1, false, data)
if !token.WaitTimeout(timeout) {
log.Errorf("cannot publish message: connection timeout: %v elapsed", timeout)
}
if token.Error() != nil {
log.Errorf("failed to publish message: %v", token.Error())
}
log.Debugf("published message %v to topic %v", msg.MessageID, topic)
}

func publishReceivedData(client mqtt.Client, c <-chan yggdrasil.Data) {
func publishReceivedData(client mqtt.Client, c <-chan yggdrasil.Data, timeout time.Duration) {
for d := range c {
topic := fmt.Sprintf("%v/%v/data/out", yggdrasil.TopicPrefix, ClientID)

Expand All @@ -141,7 +149,11 @@ func publishReceivedData(client mqtt.Client, c <-chan yggdrasil.Data) {
continue
}

if token := client.Publish(topic, 1, false, data); token.Wait() && token.Error() != nil {
token := client.Publish(topic, 1, false, data)
if !token.WaitTimeout(timeout) {
log.Errorf("cannot publish message: connection timeout: %v elapsed", timeout)
}
if token.Error() != nil {
log.Errorf("failed to publish message: %v", token.Error())
}
log.Debugf("published message %v to topic %v", d.MessageID, topic)
Expand Down

0 comments on commit 0ad12f9

Please sign in to comment.