Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite mergeRequestModifiers to operate on pointers rather than copies. #224

Merged
merged 4 commits into from
Jul 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 10 additions & 16 deletions client_requests.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ func sendStructuredRequestParseResponse[ResponseT any](
path string,
body any,
parameters url.Values,
requestModifiersPerRequest requestModifiers,
requestModifiers requestModifiers,
) (*Response[ResponseT], error) {
var buf bytes.Buffer

Expand All @@ -155,7 +155,7 @@ func sendStructuredRequestParseResponse[ResponseT any](
path,
&buf,
parameters,
requestModifiersPerRequest,
requestModifiers,
)
}

Expand All @@ -167,7 +167,7 @@ func sendRequestParseResponse[ResponseT any](
path string,
body io.Reader,
parameters url.Values,
requestModifiersPerRequest requestModifiers,
requestModifiers requestModifiers,
) (*Response[ResponseT], error) {
// apply the client-level request timeout, if set
if client.configuration.RequestTimeout > 0 {
Expand All @@ -177,13 +177,10 @@ func sendRequestParseResponse[ResponseT any](
}

// clone the client-level request modifiers to prevent race conditions
requestModifiersClient := client.cloneClientRequestModifiers()
modifiers := client.cloneClientRequestModifiers()

// merge the client-level & request-level modifiers, preferring the later
modifiers := mergeRequestModifiers(
requestModifiersClient,
requestModifiersPerRequest,
)
// merge in the request-level request modifiers
mergeRequestModifiers(&modifiers, &requestModifiers)

req, err := client.newRequest(ctx, method, path, body, parameters, modifiers.headers)
if err != nil {
Expand Down Expand Up @@ -215,16 +212,13 @@ func sendRequestReturnRawResponse(
path string,
body io.Reader,
parameters url.Values,
requestModifiersPerRequest requestModifiers,
requestModifiers requestModifiers,
) (*http.Response, error) {
// clone the client-level request modifiers to prevent race conditions
requestModifiersClient := client.cloneClientRequestModifiers()
modifiers := client.cloneClientRequestModifiers()

// merge the client-level & request-level modifiers, preferring the later
modifiers := mergeRequestModifiers(
requestModifiersClient,
requestModifiersPerRequest,
)
// merge in the request-level request modifiers
mergeRequestModifiers(&modifiers, &requestModifiers)

req, err := client.newRequest(ctx, method, path, body, parameters, modifiers.headers)
if err != nil {
Expand Down
44 changes: 20 additions & 24 deletions request_modifiers.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,47 +245,43 @@ func (m *requestModifiers) additionalQueryParametersOrDefault() url.Values {
return m.additionalQueryParameters
}

// mergeRequestModifiers merges the two objects, preferring the per-request modifiers
func mergeRequestModifiers(perClient, perRequest requestModifiers) requestModifiers {
merged := perClient

if perRequest.headers.userAgent != "" {
merged.headers.userAgent = perRequest.headers.userAgent
// mergeRequestModifiers merges the values from *rhs into *lhs.
func mergeRequestModifiers(lhs, rhs *requestModifiers) {
if rhs.headers.userAgent != "" {
lhs.headers.userAgent = rhs.headers.userAgent
}

if perRequest.headers.token != "" {
merged.headers.token = perRequest.headers.token
if rhs.headers.token != "" {
lhs.headers.token = rhs.headers.token
}

if perRequest.headers.namespace != "" {
merged.headers.namespace = perRequest.headers.namespace
if rhs.headers.namespace != "" {
lhs.headers.namespace = rhs.headers.namespace
}

if len(perRequest.headers.mfaCredentials) != 0 {
merged.headers.mfaCredentials = perRequest.headers.mfaCredentials
if len(rhs.headers.mfaCredentials) != 0 {
lhs.headers.mfaCredentials = rhs.headers.mfaCredentials
}

if perRequest.headers.responseWrappingTTL != 0 {
merged.headers.responseWrappingTTL = perRequest.headers.responseWrappingTTL
if rhs.headers.responseWrappingTTL != 0 {
lhs.headers.responseWrappingTTL = rhs.headers.responseWrappingTTL
}

if perRequest.headers.replicationForwardingMode != ReplicationForwardNone {
merged.headers.replicationForwardingMode = perRequest.headers.replicationForwardingMode
if rhs.headers.replicationForwardingMode != ReplicationForwardNone {
lhs.headers.replicationForwardingMode = rhs.headers.replicationForwardingMode
}

if len(perRequest.headers.customHeaders) != 0 {
merged.headers.customHeaders = perRequest.headers.customHeaders
if len(rhs.headers.customHeaders) != 0 {
lhs.headers.customHeaders = rhs.headers.customHeaders
}

if len(perRequest.requestCallbacks) != 0 {
merged.requestCallbacks = perRequest.requestCallbacks
if len(rhs.requestCallbacks) != 0 {
lhs.requestCallbacks = rhs.requestCallbacks
}

if len(perRequest.responseCallbacks) != 0 {
merged.responseCallbacks = perRequest.responseCallbacks
if len(rhs.responseCallbacks) != 0 {
lhs.responseCallbacks = rhs.responseCallbacks
}

return merged
}

func validateToken(token string) error {
Expand Down
45 changes: 23 additions & 22 deletions request_modifiers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,43 +88,44 @@ func Test_validateCustomHeaders(t *testing.T) {
}
}

func Test_mergeRequestModifiers(t *testing.T) {
func Test_mergeRequestModifiers_overwrite(t *testing.T) {
cases := map[string]struct {
name string
a requestModifiers
b requestModifiers
lhs requestModifiers
rhs requestModifiers
expected requestModifiers
}{
"empty": {
a: requestModifiers{},
b: requestModifiers{},
lhs: requestModifiers{},
rhs: requestModifiers{},
expected: requestModifiers{},
},
"token-a": {
a: requestModifiers{headers: requestHeaders{token: "token-a"}},
b: requestModifiers{},
expected: requestModifiers{headers: requestHeaders{token: "token-a"}},
"token-in-lhs": {
lhs: requestModifiers{headers: requestHeaders{token: "token-lhs"}},
rhs: requestModifiers{},
expected: requestModifiers{headers: requestHeaders{token: "token-lhs"}},
},
"token-b": {
a: requestModifiers{},
b: requestModifiers{headers: requestHeaders{token: "token-b"}},
expected: requestModifiers{headers: requestHeaders{token: "token-b"}},
"token-in-rhs": {
lhs: requestModifiers{},
rhs: requestModifiers{headers: requestHeaders{token: "token-rhs"}},
expected: requestModifiers{headers: requestHeaders{token: "token-rhs"}},
},
"token-a-b": {
a: requestModifiers{headers: requestHeaders{token: "token-a"}},
b: requestModifiers{headers: requestHeaders{token: "token-b"}},
expected: requestModifiers{headers: requestHeaders{token: "token-b"}},
"token-in-both": {
lhs: requestModifiers{headers: requestHeaders{token: "token-lhs"}},
rhs: requestModifiers{headers: requestHeaders{token: "token-rhs"}},
expected: requestModifiers{headers: requestHeaders{token: "token-rhs"}},
},
"token-namespace": {
a: requestModifiers{headers: requestHeaders{token: "token-a"}},
b: requestModifiers{headers: requestHeaders{namespace: "namespace-b"}},
expected: requestModifiers{headers: requestHeaders{token: "token-a", namespace: "namespace-b"}},
"token-lhs-and-namespace-rhs": {
lhs: requestModifiers{headers: requestHeaders{token: "token-lhs"}},
rhs: requestModifiers{headers: requestHeaders{namespace: "namespace-rhs"}},
expected: requestModifiers{headers: requestHeaders{token: "token-lhs", namespace: "namespace-rhs"}},
},
}

for name, tc := range cases {
t.Run(name, func(t *testing.T) {
require.Equal(t, tc.expected, mergeRequestModifiers(tc.a, tc.b))
mergeRequestModifiers(&tc.lhs, &tc.rhs)
require.Equal(t, tc.expected, tc.lhs)
})
}
}