Skip to content

Commit

Permalink
Allow multiple values per websocket handshake extra header key
Browse files Browse the repository at this point in the history
  • Loading branch information
sergiu128 committed Aug 15, 2023
1 parent 949c152 commit f539e4a
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 17 deletions.
14 changes: 7 additions & 7 deletions codec/websocket/definitions.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,16 +109,16 @@ type AsyncFrameHandler = func(err error, f *Frame)
type ControlCallback = func(mt MessageType, payload []byte)

type Header struct {
Key string
Value string
Canonical bool
Key string
Values []string
CanonicalKey bool
}

func ExtraHeader(key string, value string, canonical bool) Header {
func ExtraHeader(canonicalKey bool, key string, values ...string) Header {
return Header{
Key: key,
Value: value,
Canonical: canonical,
Key: key,
Values: values,
CanonicalKey: canonicalKey,
}
}

Expand Down
8 changes: 5 additions & 3 deletions codec/websocket/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -774,12 +774,14 @@ func (s *WebsocketStream) upgrade(
req.Header.Set("Sec-Websocket-Version", "13")

for _, header := range headers {
if header.Canonical {
req.Header.Set(header.Key, header.Value)
if header.CanonicalKey {
for _, value := range header.Values {
req.Header.Add(header.Key, value)
}
} else {
req.Header[header.Key] = append(
req.Header[header.Key],
header.Value,
header.Values...,
)
}
}
Expand Down
34 changes: 27 additions & 7 deletions codec/websocket/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,17 @@ func TestClientSuccessfulHandshakeWithExtraHeaders(t *testing.T) {

assertState(t, ws, StateHandshake)

// Keys are automatically canonicalized by Go's protocol implementation -
// hence we don't care about their casing here.
expected := map[string][]string{
"k1": {"v1"},
"k2": {"v21", "v22"},
"k3": {"v31", "v32"},
"k4": {"v4"},
"k5": {"v51", "v52"},
"k6": {"v61", "v62"},
}

ws.AsyncHandshake(
"ws://localhost:8080",
func(err error) {
Expand All @@ -224,19 +235,28 @@ func TestClientSuccessfulHandshakeWithExtraHeaders(t *testing.T) {
assertState(t, ws, StateActive)
}
},
ExtraHeader("key1-asdf", "value1", true),
ExtraHeader("key2-asdf", "value2", false),
ExtraHeader(true, "k1", "v1"),
ExtraHeader(true, "k2", "v21", "v22"),
ExtraHeader(true, "k3", "v31"), ExtraHeader(true, "k3", "v32"),
ExtraHeader(false, "k4", "v4"),
ExtraHeader(false, "k5", "v51", "v52"),
ExtraHeader(false, "k6", "v61"), ExtraHeader(false, "k6", "v62"),
)

for !srv.IsClosed() {
ioc.RunOne()
}

if srv.Upgrade.Header.Get("key1-asdf") != "value1" {
t.Fatal("invalid extra header")
}
if srv.Upgrade.Header.Get("key2-asdf") != "value2" {
t.Fatal("invalid extra header")
for key := range expected {
given := srv.Upgrade.Header.Values(key)
if len(given) != len(expected[key]) {
t.Fatal("wrong extra header")
}
for i := 0; i < len(given); i++ {
if given[i] != expected[key][i] {
t.Fatal("wrong extra header")
}
}
}
}

Expand Down

0 comments on commit f539e4a

Please sign in to comment.