From 8afde9b2478712d91b4c7ce20cd60025c80fd6f7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 15 Jan 2025 23:40:50 +0200 Subject: [PATCH] signalmeow: update websocket auth --- pkg/signalmeow/client.go | 9 ++------- pkg/signalmeow/contactdiscovery.go | 2 +- pkg/signalmeow/provisioning.go | 12 ++++++++++-- pkg/signalmeow/web/signalwebsocket.go | 28 ++++++++++++--------------- 4 files changed, 25 insertions(+), 26 deletions(-) diff --git a/pkg/signalmeow/client.go b/pkg/signalmeow/client.go index fcb463a5..8f44db05 100644 --- a/pkg/signalmeow/client.go +++ b/pkg/signalmeow/client.go @@ -84,12 +84,7 @@ func (cli *Client) ConnectAuthedWS(ctx context.Context, requestHandler web.Reque Str("username", username). Logger() ctx = log.WithContext(ctx) - username = url.QueryEscape(username) - password = url.QueryEscape(password) - path := web.WebsocketPath + - "?login=" + username + - "&password=" + password - authedWS := web.NewSignalWebsocket(path, &username, &password) + authedWS := web.NewSignalWebsocket(url.UserPassword(username, password)) statusChan := authedWS.Connect(ctx, &requestHandler) cli.AuthedWS = authedWS return statusChan, nil @@ -104,7 +99,7 @@ func (cli *Client) ConnectUnauthedWS(ctx context.Context) (chan web.SignalWebsoc Str("websocket_type", "unauthed"). Logger() ctx = log.WithContext(ctx) - unauthedWS := web.NewSignalWebsocket(web.WebsocketPath, nil, nil) + unauthedWS := web.NewSignalWebsocket(nil) statusChan := unauthedWS.Connect(ctx, nil) cli.UnauthedWS = unauthedWS return statusChan, nil diff --git a/pkg/signalmeow/contactdiscovery.go b/pkg/signalmeow/contactdiscovery.go index b732339c..33b0ad3d 100644 --- a/pkg/signalmeow/contactdiscovery.go +++ b/pkg/signalmeow/contactdiscovery.go @@ -111,7 +111,7 @@ func (cli *Client) doContactDiscovery(ctx context.Context, req *signalpb.CDSClie Path: path.Join("v1", ProdContactDiscoveryMrenclave, "discovery"), }).String() log.Trace().Msg("Connecting to contact discovery websocket") - ws, _, err := web.OpenWebsocketURL(ctx, addr) + ws, _, err := web.OpenWebsocket(ctx, addr) if err != nil { var closeErr websocket.CloseError if errors.As(err, &closeErr) && closeErr.Code == rateLimitCloseCode { diff --git a/pkg/signalmeow/provisioning.go b/pkg/signalmeow/provisioning.go index 507f5d97..4aedc449 100644 --- a/pkg/signalmeow/provisioning.go +++ b/pkg/signalmeow/provisioning.go @@ -87,7 +87,11 @@ func PerformProvisioning(ctx context.Context, deviceStore store.DeviceStore, dev timeoutCtx, cancel := context.WithTimeout(ctx, 2*time.Minute) defer cancel() - ws, resp, err := web.OpenWebsocket(timeoutCtx, web.WebsocketProvisioningPath) + ws, resp, err := web.OpenWebsocket(timeoutCtx, (&url.URL{ + Scheme: "wss", + Host: web.APIHostname, + Path: web.WebsocketProvisioningPath, + }).String()) if err != nil { log.Err(err).Any("resp", resp).Msg("error opening provisioning websocket") c <- ProvisioningResponse{State: StateProvisioningError, Err: err} @@ -388,7 +392,11 @@ func confirmDevice( return nil, fmt.Errorf("failed to encrypt device name: %w", err) } - ws, resp, err := web.OpenWebsocket(ctx, web.WebsocketPath) + ws, resp, err := web.OpenWebsocket(ctx, (&url.URL{ + Scheme: "wss", + Host: web.APIHostname, + Path: web.WebsocketPath, + }).String()) if err != nil { log.Err(err).Any("resp", resp).Msg("error opening websocket") return nil, err diff --git a/pkg/signalmeow/web/signalwebsocket.go b/pkg/signalmeow/web/signalwebsocket.go index 4f817c86..c273c314 100644 --- a/pkg/signalmeow/web/signalwebsocket.go +++ b/pkg/signalmeow/web/signalwebsocket.go @@ -22,6 +22,7 @@ import ( "errors" "fmt" "net/http" + "net/url" "strings" "sync" "time" @@ -43,20 +44,13 @@ type RequestHandlerFunc func(context.Context, *signalpb.WebSocketRequestMessage) type SignalWebsocket struct { ws *websocket.Conn - path string - basicAuth *string + basicAuth *url.Userinfo sendChannel chan SignalWebsocketSendMessage statusChannel chan SignalWebsocketConnectionStatus } -func NewSignalWebsocket(path string, username *string, password *string) *SignalWebsocket { - var basicAuth *string - if username != nil && password != nil { - b := base64.StdEncoding.EncodeToString([]byte(*username + ":" + *password)) - basicAuth = &b - } +func NewSignalWebsocket(basicAuth *url.Userinfo) *SignalWebsocket { return &SignalWebsocket{ - path: path, basicAuth: basicAuth, sendChannel: make(chan SignalWebsocketSendMessage), statusChannel: make(chan SignalWebsocketConnectionStatus), @@ -187,6 +181,12 @@ func (s *SignalWebsocket) connectLoop( retrying := false errorCount := 0 isFirstConnect := true + wsURL := (&url.URL{ + Scheme: "wss", + Host: APIHostname, + Path: WebsocketPath, + User: s.basicAuth, + }).String() for { if retrying { if backoff > maxBackoff { @@ -204,7 +204,7 @@ func (s *SignalWebsocket) connectLoop( } isFirstConnect = false - ws, resp, err := OpenWebsocket(ctx, s.path) + ws, resp, err := OpenWebsocket(ctx, wsURL) if resp != nil { if resp.StatusCode != 101 { // Server didn't want to open websocket @@ -555,7 +555,7 @@ func (s *SignalWebsocket) sendRequestInternal( retryCount int, ) (*signalpb.WebSocketResponseMessage, error) { if s.basicAuth != nil { - request.Headers = append(request.Headers, "authorization:Basic "+*s.basicAuth) + request.Headers = append(request.Headers, "authorization:Basic "+s.basicAuth.String()) } responseChannel := make(chan *signalpb.WebSocketResponseMessage, 1) if s.sendChannel == nil { @@ -590,11 +590,7 @@ func (s *SignalWebsocket) sendRequestInternal( return response, nil } -func OpenWebsocket(ctx context.Context, path string) (*websocket.Conn, *http.Response, error) { - return OpenWebsocketURL(ctx, "wss://"+APIHostname+path) -} - -func OpenWebsocketURL(ctx context.Context, url string) (*websocket.Conn, *http.Response, error) { +func OpenWebsocket(ctx context.Context, url string) (*websocket.Conn, *http.Response, error) { opt := &websocket.DialOptions{ HTTPClient: SignalHTTPClient, HTTPHeader: make(http.Header, 2),