Skip to content

Commit

Permalink
feat: lazy websocket connection
Browse files Browse the repository at this point in the history
only create actual websocket connection when there are pending swaps and
close it if there are no more pending swaps
  • Loading branch information
jackstar12 committed Feb 17, 2025
1 parent 71b97a5 commit 78401ae
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 38 deletions.
87 changes: 62 additions & 25 deletions pkg/boltz/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,25 @@ type SwapUpdate struct {
Id string `json:"id"`
}

type websocketState string

const (
disconnected websocketState = "disconnected"
connected websocketState = "connected"
reconnecting websocketState = "reconnecting"
closed websocketState = "closed"
)

type Websocket struct {
Updates chan SwapUpdate

apiUrl string
subscriptions chan bool
conn *websocket.Conn
closed bool
reconnect bool
dialer *websocket.Dialer
swapIds []string
apiUrl string
subscriptions chan bool
conn *websocket.Conn
dialer *websocket.Dialer
swapIds []string
state websocketState
reconnectInterval time.Duration
}

type wsResponse struct {
Expand All @@ -51,14 +60,19 @@ func (boltz *Api) NewWebsocket() *Websocket {
}

return &Websocket{
apiUrl: boltz.URL,
subscriptions: make(chan bool),
dialer: &dialer,
Updates: make(chan SwapUpdate),
apiUrl: boltz.URL,
subscriptions: make(chan bool),
dialer: &dialer,
Updates: make(chan SwapUpdate),
state: disconnected,
reconnectInterval: reconnectInterval,
}
}

func (boltz *Websocket) Connect() error {
if boltz.state == closed {
return errors.New("websocket is closed")
}
wsUrl, err := url.Parse(boltz.apiUrl)
if err != nil {
return err
Expand All @@ -75,9 +89,10 @@ func (boltz *Websocket) Connect() error {
if err != nil {
return fmt.Errorf("could not connect to boltz ws at %s: %w", wsUrl, err)
}
boltz.conn = conn

logger.Infof("Connected to Boltz ws at %s", wsUrl)
boltz.conn = conn
boltz.state = connected

setDeadline := func() error {
return conn.SetReadDeadline(time.Now().Add(pingInterval + pongWait))
Expand All @@ -95,7 +110,7 @@ func (boltz *Websocket) Connect() error {
// Will not wait longer with writing than for the response
err := conn.WriteControl(websocket.PingMessage, nil, time.Now().Add(pongWait))
if err != nil {
if boltz.closed {
if boltz.state == closed {
return
}
logger.Errorf("could not send ping: %s", err)
Expand All @@ -108,11 +123,14 @@ func (boltz *Websocket) Connect() error {
for {
msgType, message, err := conn.ReadMessage()
if err != nil {
if boltz.closed {
close(boltz.Updates)
if boltz.state == closed || boltz.state == disconnected {
boltz.conn = nil
if boltz.state == closed {
close(boltz.Updates)
}
return
}
if !boltz.reconnect {
if boltz.state == connected {
logger.Error("could not receive message: " + err.Error())
}
break
Expand Down Expand Up @@ -156,12 +174,12 @@ func (boltz *Websocket) Connect() error {
}
for {
pingTicker.Stop()
if boltz.reconnect {
boltz.reconnect = false
if boltz.state == reconnecting {
return
} else {
logger.Errorf("lost connection to boltz ws, reconnecting in %s", reconnectInterval)
time.Sleep(reconnectInterval)
boltz.state = reconnecting
logger.Errorf("lost connection to boltz ws, reconnecting in %s", boltz.reconnectInterval)
time.Sleep(boltz.reconnectInterval)
}
err := boltz.Connect()
if err == nil {
Expand All @@ -178,7 +196,7 @@ func (boltz *Websocket) Connect() error {
}

func (boltz *Websocket) subscribe(swapIds []string) error {
if boltz.closed {
if boltz.state == closed {
return errors.New("websocket is closed")
}
logger.Infof("Subscribing to Swaps: %v", swapIds)
Expand All @@ -201,6 +219,11 @@ func (boltz *Websocket) subscribe(swapIds []string) error {
}

func (boltz *Websocket) Subscribe(swapIds []string) error {
if boltz.state == disconnected {
if err := boltz.Connect(); err != nil {
return fmt.Errorf("could not connect boltz ws: %w", err)
}
}
if err := boltz.subscribe(swapIds); err != nil {
// the connection might be dead, so forcefully reconnect
if err := boltz.Reconnect(); err != nil {
Expand All @@ -219,19 +242,33 @@ func (boltz *Websocket) Unsubscribe(swapId string) {
return id == swapId
})
logger.Debugf("Unsubscribed from swap %s", swapId)
if len(boltz.swapIds) == 0 {
logger.Debugf("No more pending swaps, disconnecting websocket")
boltz.state = disconnected
if err := boltz.close(); err != nil {
logger.Warnf("could not close boltz ws: %v", err)
}
}
}

func (boltz *Websocket) close() error {
if conn := boltz.conn; conn != nil {
return conn.Close()
}
return nil
}

func (boltz *Websocket) Close() error {
boltz.closed = true
return boltz.conn.Close()
boltz.state = closed
return boltz.close()
}

func (boltz *Websocket) Reconnect() error {
if boltz.closed {
if boltz.state == closed {
return errors.New("websocket is closed")
}
logger.Infof("Force reconnecting to Boltz ws")
boltz.reconnect = true
boltz.state = reconnecting
if err := boltz.conn.Close(); err != nil {
logger.Warnf("could not close boltz ws: %v", err)
}
Expand Down
95 changes: 82 additions & 13 deletions pkg/boltz/ws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,96 @@ import (
"github.com/BoltzExchange/boltz-client/v2/internal/logger"
"github.com/stretchr/testify/require"
"testing"
"time"
)

func TestWebsocketReconnect(t *testing.T) {
func setupWs(t *testing.T) *Websocket {
logger.Init(logger.Options{Level: "debug"})
api := Api{URL: "http://localhost:9001"}

ws := api.NewWebsocket()
err := ws.Connect()
require.NoError(t, err)
require.Nil(t, ws.conn)
require.Equal(t, disconnected, ws.state)
return ws
}

func TestWebsocketLazy(t *testing.T) {
ws := setupWs(t)

firstId := "swapId"
err = ws.Subscribe([]string{firstId})
secondId := "anotherSwapId"
err := ws.Subscribe([]string{firstId, secondId})
require.NoError(t, err)

firstConn := ws.conn
err = firstConn.Close()
require.NoError(t, err)
require.Equal(t, []string{firstId, secondId}, ws.swapIds)
require.Equal(t, connected, ws.state)
require.NotNil(t, ws.conn)

anotherId := "anotherSwapId"
err = ws.Subscribe([]string{anotherId})
require.NoError(t, err)
require.NotEqual(t, firstConn, ws.conn, "subscribe should reconnect forcefully")
require.Equal(t, []string{firstId, anotherId}, ws.swapIds)
ws.Unsubscribe(firstId)
require.Equal(t, connected, ws.state)
require.Equal(t, []string{secondId}, ws.swapIds)
ws.Unsubscribe(secondId)

require.Equal(t, disconnected, ws.state)
require.Nil(t, ws.conn)
}

func TestWebsocketReconnect(t *testing.T) {
setup := func(t *testing.T) *Websocket {
ws := setupWs(t)
require.NoError(t, ws.Connect())
require.NotNil(t, ws.conn)
require.Equal(t, connected, ws.state)
return ws
}

t.Run("Automatic", func(t *testing.T) {
ws := setup(t)
ws.reconnectInterval = 50 * time.Millisecond
firstConn := ws.conn
require.NoError(t, ws.conn.Close())

waitFor := time.Second
require.Eventually(t, func() bool {
return ws.state == reconnecting
}, waitFor, ws.reconnectInterval/2)

require.Eventually(t, func() bool {
return ws.state == connected
}, waitFor, ws.reconnectInterval)

newConn := ws.conn
require.NotNil(t, newConn)
require.NotEqual(t, firstConn, newConn)
})

t.Run("Force", func(t *testing.T) {
ws := setup(t)
firstConn := ws.conn

err := ws.Subscribe([]string{"swapId"})
require.NoError(t, err)

require.NoError(t, ws.conn.Close())

err = ws.Subscribe([]string{"anotherSwapId"})
require.NoError(t, err)
require.NotEqual(t, firstConn, ws.conn, "subscribe should reconnect forcefully")
require.Equal(t, connected, ws.state)
})

}

func TestWebsocketShutdown(t *testing.T) {
ws := setupWs(t)
require.NoError(t, ws.Connect())
require.NotNil(t, ws.conn)
require.Equal(t, connected, ws.state)

require.NoError(t, ws.Close())
require.Eventually(t, func() bool {
return ws.state == closed
}, time.Second, 10*time.Millisecond)
require.Nil(t, ws.conn)

require.Error(t, ws.Subscribe([]string{"swapId"}))
}

0 comments on commit 78401ae

Please sign in to comment.