Skip to content

Commit

Permalink
signalmeow: update websocket auth
Browse files Browse the repository at this point in the history
  • Loading branch information
tulir committed Jan 15, 2025
1 parent 53ad7fa commit 8afde9b
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 26 deletions.
9 changes: 2 additions & 7 deletions pkg/signalmeow/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,7 @@ func (cli *Client) ConnectAuthedWS(ctx context.Context, requestHandler web.Reque
Str("username", username).
Logger()
ctx = log.WithContext(ctx)
username = url.QueryEscape(username)
password = url.QueryEscape(password)
path := web.WebsocketPath +
"?login=" + username +
"&password=" + password
authedWS := web.NewSignalWebsocket(path, &username, &password)
authedWS := web.NewSignalWebsocket(url.UserPassword(username, password))
statusChan := authedWS.Connect(ctx, &requestHandler)
cli.AuthedWS = authedWS
return statusChan, nil
Expand All @@ -104,7 +99,7 @@ func (cli *Client) ConnectUnauthedWS(ctx context.Context) (chan web.SignalWebsoc
Str("websocket_type", "unauthed").
Logger()
ctx = log.WithContext(ctx)
unauthedWS := web.NewSignalWebsocket(web.WebsocketPath, nil, nil)
unauthedWS := web.NewSignalWebsocket(nil)
statusChan := unauthedWS.Connect(ctx, nil)
cli.UnauthedWS = unauthedWS
return statusChan, nil
Expand Down
2 changes: 1 addition & 1 deletion pkg/signalmeow/contactdiscovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func (cli *Client) doContactDiscovery(ctx context.Context, req *signalpb.CDSClie
Path: path.Join("v1", ProdContactDiscoveryMrenclave, "discovery"),
}).String()
log.Trace().Msg("Connecting to contact discovery websocket")
ws, _, err := web.OpenWebsocketURL(ctx, addr)
ws, _, err := web.OpenWebsocket(ctx, addr)
if err != nil {
var closeErr websocket.CloseError
if errors.As(err, &closeErr) && closeErr.Code == rateLimitCloseCode {
Expand Down
12 changes: 10 additions & 2 deletions pkg/signalmeow/provisioning.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,11 @@ func PerformProvisioning(ctx context.Context, deviceStore store.DeviceStore, dev

timeoutCtx, cancel := context.WithTimeout(ctx, 2*time.Minute)
defer cancel()
ws, resp, err := web.OpenWebsocket(timeoutCtx, web.WebsocketProvisioningPath)
ws, resp, err := web.OpenWebsocket(timeoutCtx, (&url.URL{
Scheme: "wss",
Host: web.APIHostname,
Path: web.WebsocketProvisioningPath,
}).String())
if err != nil {
log.Err(err).Any("resp", resp).Msg("error opening provisioning websocket")
c <- ProvisioningResponse{State: StateProvisioningError, Err: err}
Expand Down Expand Up @@ -388,7 +392,11 @@ func confirmDevice(
return nil, fmt.Errorf("failed to encrypt device name: %w", err)
}

ws, resp, err := web.OpenWebsocket(ctx, web.WebsocketPath)
ws, resp, err := web.OpenWebsocket(ctx, (&url.URL{
Scheme: "wss",
Host: web.APIHostname,
Path: web.WebsocketPath,
}).String())
if err != nil {
log.Err(err).Any("resp", resp).Msg("error opening websocket")
return nil, err
Expand Down
28 changes: 12 additions & 16 deletions pkg/signalmeow/web/signalwebsocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"sync"
"time"
Expand All @@ -43,20 +44,13 @@ type RequestHandlerFunc func(context.Context, *signalpb.WebSocketRequestMessage)

type SignalWebsocket struct {
ws *websocket.Conn
path string
basicAuth *string
basicAuth *url.Userinfo
sendChannel chan SignalWebsocketSendMessage
statusChannel chan SignalWebsocketConnectionStatus
}

func NewSignalWebsocket(path string, username *string, password *string) *SignalWebsocket {
var basicAuth *string
if username != nil && password != nil {
b := base64.StdEncoding.EncodeToString([]byte(*username + ":" + *password))
basicAuth = &b
}
func NewSignalWebsocket(basicAuth *url.Userinfo) *SignalWebsocket {
return &SignalWebsocket{
path: path,
basicAuth: basicAuth,
sendChannel: make(chan SignalWebsocketSendMessage),
statusChannel: make(chan SignalWebsocketConnectionStatus),
Expand Down Expand Up @@ -187,6 +181,12 @@ func (s *SignalWebsocket) connectLoop(
retrying := false
errorCount := 0
isFirstConnect := true
wsURL := (&url.URL{
Scheme: "wss",
Host: APIHostname,
Path: WebsocketPath,
User: s.basicAuth,
}).String()
for {
if retrying {
if backoff > maxBackoff {
Expand All @@ -204,7 +204,7 @@ func (s *SignalWebsocket) connectLoop(
}
isFirstConnect = false

ws, resp, err := OpenWebsocket(ctx, s.path)
ws, resp, err := OpenWebsocket(ctx, wsURL)
if resp != nil {
if resp.StatusCode != 101 {
// Server didn't want to open websocket
Expand Down Expand Up @@ -555,7 +555,7 @@ func (s *SignalWebsocket) sendRequestInternal(
retryCount int,
) (*signalpb.WebSocketResponseMessage, error) {
if s.basicAuth != nil {
request.Headers = append(request.Headers, "authorization:Basic "+*s.basicAuth)
request.Headers = append(request.Headers, "authorization:Basic "+s.basicAuth.String())
}
responseChannel := make(chan *signalpb.WebSocketResponseMessage, 1)
if s.sendChannel == nil {
Expand Down Expand Up @@ -590,11 +590,7 @@ func (s *SignalWebsocket) sendRequestInternal(
return response, nil
}

func OpenWebsocket(ctx context.Context, path string) (*websocket.Conn, *http.Response, error) {
return OpenWebsocketURL(ctx, "wss://"+APIHostname+path)
}

func OpenWebsocketURL(ctx context.Context, url string) (*websocket.Conn, *http.Response, error) {
func OpenWebsocket(ctx context.Context, url string) (*websocket.Conn, *http.Response, error) {
opt := &websocket.DialOptions{
HTTPClient: SignalHTTPClient,
HTTPHeader: make(http.Header, 2),
Expand Down

0 comments on commit 8afde9b

Please sign in to comment.