Skip to content

Commit

Permalink
Move websocket headers to opt function 'WithWebsocketHeaders'
Browse files Browse the repository at this point in the history
  • Loading branch information
HaraldNordgren committed Dec 1, 2024
1 parent 5913cd6 commit 221ad76
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 11 deletions.
21 changes: 13 additions & 8 deletions graphql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,16 +133,10 @@ type WebSocketOption func(*webSocketClient)
//
// The client does not support queries nor mutations, and will return an error
// if passed a request that attempts one.
func NewClientUsingWebSocket(endpoint string, wsDialer Dialer, headers http.Header, opts ...WebSocketOption) WebSocketClient {
if headers == nil {
headers = http.Header{}
}
if headers.Get("Sec-WebSocket-Protocol") == "" {
headers.Add("Sec-WebSocket-Protocol", "graphql-transport-ws")
}
func NewClientUsingWebSocket(endpoint string, wsDialer Dialer, opts ...WebSocketOption) WebSocketClient {
client := &webSocketClient{
Dialer: wsDialer,
Header: headers,
header: http.Header{},
errChan: make(chan error),
endpoint: endpoint,
subscriptions: subscriptionMap{map_: make(map[string]subscription)},
Expand All @@ -152,6 +146,10 @@ func NewClientUsingWebSocket(endpoint string, wsDialer Dialer, headers http.Head
opt(client)
}

if client.header.Get("Sec-WebSocket-Protocol") == "" {
client.header.Add("Sec-WebSocket-Protocol", "graphql-transport-ws")
}

return client
}

Expand All @@ -163,6 +161,13 @@ func WithConnectionParams(connParams map[string]interface{}) WebSocketOption {
}
}

// WithWebsocketHeader sets a header to be sent to the server.
func WithWebsocketHeader(header http.Header) WebSocketOption {
return func(ws *webSocketClient) {
ws.header = header
}
}

func newClient(endpoint string, httpClient Doer, method string) Client {
if httpClient == nil || httpClient == (*http.Client)(nil) {
httpClient = http.DefaultClient
Expand Down
4 changes: 2 additions & 2 deletions graphql/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ const (

type webSocketClient struct {
Dialer Dialer
Header http.Header
header http.Header
endpoint string
conn WSConn
connParams map[string]interface{}
Expand Down Expand Up @@ -169,7 +169,7 @@ func checkConnectionAckReceived(message []byte) (bool, error) {
}

func (w *webSocketClient) Start(ctx context.Context) (errChan chan error, err error) {
w.conn, err = w.Dialer.DialContext(ctx, w.endpoint, w.Header)
w.conn, err = w.Dialer.DialContext(ctx, w.endpoint, w.header)
if err != nil {
return nil, err
}
Expand Down
1 change: 0 additions & 1 deletion internal/integration/roundtrip.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,6 @@ func newRoundtripWebSocketClient(t *testing.T, endpoint string, opts ...graphql.
wsWrapped: graphql.NewClientUsingWebSocket(
endpoint,
&MyDialer{Dialer: dialer},
nil,
opts...,
),
t: t,
Expand Down
1 change: 1 addition & 0 deletions internal/integration/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ func RunServer() *httptest.Server {
graphql.RegisterExtension(ctx, "foobar", "test")
return next(ctx)
})

return httptest.NewServer(gqlgenServer)
}

Expand Down

0 comments on commit 221ad76

Please sign in to comment.