diff --git a/src/grpcbox_client_stream.erl b/src/grpcbox_client_stream.erl index f1d7bf8..faf5b75 100644 --- a/src/grpcbox_client_stream.erl +++ b/src/grpcbox_client_stream.erl @@ -13,15 +13,16 @@ -include("grpcbox.hrl"). --define(headers(Scheme, Host, Path, Encoding, MessageType, MD), [{<<":method">>, <<"POST">>}, - {<<":path">>, Path}, - {<<":scheme">>, Scheme}, - {<<":authority">>, Host}, - {<<"grpc-encoding">>, Encoding}, - {<<"grpc-message-type">>, MessageType}, - {<<"content-type">>, <<"application/grpc+proto">>}, - {<<"user-agent">>, <<"grpc-erlang/0.9.2">>}, - {<<"te">>, <<"trailers">>} | MD]). +-define(protected_headers, [<<"content-type">>, <<"te">>]). +-define(pseudoheaders(Path, Scheme, Authority), [{<<":method">>, <<"POST">>}, + {<<":path">>, Path}, + {<<":scheme">>, Scheme}, + {<<":authority">>, Authority}]). +-define(headers(Encoding, MessageType, MD), (MD ++ [{<<"grpc-encoding">>, Encoding}, + {<<"grpc-message-type">>, MessageType}, + {<<"content-type">>, <<"application/grpc+proto">>}, + {<<"user-agent">>, <<"grpc-erlang/0.9.2">>}, + {<<"te">>, <<"trailers">>}])). new_stream(Ctx, Channel, Path, Def=#grpcbox_def{service=Service, message_type=MessageType, @@ -33,8 +34,9 @@ new_stream(Ctx, Channel, Path, Def=#grpcbox_def{service=Service, encoding := DefaultEncoding, stats_handler := StatsHandler}} -> Encoding = maps:get(encoding, Options, DefaultEncoding), - RequestHeaders = ?headers(Scheme, Authority, Path, encoding_to_binary(Encoding), - MessageType, metadata_headers(Ctx)), + UserHeaders = merge_headers(?headers(encoding_to_binary(Encoding), MessageType, metadata_headers(Ctx))), + RequestHeaders = ?pseudoheaders(Path, Scheme, Authority) ++ UserHeaders, + case h2_connection:new_stream(Conn, ?MODULE, [#{service => Service, marshal_fun => MarshalFun, unmarshal_fun => UnMarshalFun, @@ -70,7 +72,8 @@ send_request(Ctx, Channel, Path, Input, #grpcbox_def{service=Service, stats_handler := StatsHandler}} -> Encoding = maps:get(encoding, Options, DefaultEncoding), Body = grpcbox_frame:encode(Encoding, MarshalFun(Input)), - Headers = ?headers(Scheme, Authority, Path, encoding_to_binary(Encoding), MessageType, metadata_headers(Ctx)), + UserHeaders = merge_headers(?headers(encoding_to_binary(Encoding), MessageType, metadata_headers(Ctx))), + RequestHeaders = ?pseudoheaders(Path, Scheme, Authority) ++ UserHeaders, %% headers are sent in the same request as creating a new stream to ensure %% concurrent calls can't end up interleaving the sending of headers in such @@ -83,7 +86,7 @@ send_request(Ctx, Channel, Path, Input, #grpcbox_def{service=Service, buffer => <<>>, stats_handler => StatsHandler, stats => #{}, - client_pid => self()}], Headers, [], self()) of + client_pid => self()}], RequestHeaders, [], self()) of {error, _Code} = Err -> Err; {StreamId, Pid} -> @@ -223,3 +226,62 @@ encoding_to_binary(deflate) -> <<"deflate">>; encoding_to_binary(snappy) -> <<"snappy">>; encoding_to_binary(Custom) -> atom_to_binary(Custom, latin1). +merge_headers(Headers) -> + lists:foldl(fun merge_header_field/2, [], Headers). + +merge_header_field({K, V}, HeadersAcc) -> + case {is_protected_header(K), proplists:is_defined(K, HeadersAcc)} of + {true, true} -> + % is protected and already exists, skip + HeadersAcc; + {false, true} -> + % isn't protected and already exists, join + join_header_values({K, V}, HeadersAcc); + {_, false} -> + % doesn't exist, add + [{K, V} | HeadersAcc] + end. + +join_header_values({Name, Val}, HeadersAcc) -> + OrigVal = proplists:get_value(Name, HeadersAcc), + NewValue = <>/binary, Val/binary>>, + NewList = lists:keyreplace(Name, 1, HeadersAcc, {Name, NewValue}), + NewList. + +is_protected_header(Name) -> + lists:member(Name, ?protected_headers). + +-ifdef(TEST). +-include_lib("eunit/include/eunit.hrl"). + +merge_headers_test() -> + {Encoding, MsgType} = {<<"identity">>, <<"grpc.TestRequest">>}, + Ctx = ctx:new(), + Ctx1 = grpcbox_metadata:append_to_outgoing_ctx(Ctx, #{<<"content-type">> => <<"application/grpc">>, + <<"user-agent">> => <<"custom-grpc-client">>}), + Headers0 = ?headers(Encoding, MsgType, metadata_headers(Ctx1)), + Headers1 = merge_headers(Headers0), + + ?assertEqual([{<<"te">>, <<"trailers">>}, + {<<"grpc-message-type">>, <<"grpc.TestRequest">>}, + {<<"grpc-encoding">>, <<"identity">>}, + {<<"user-agent">>, <<"custom-grpc-client, grpc-erlang/0.9.2">>}, + {<<"content-type">>, <<"application/grpc">>} + ], Headers1), + ok. + +merge_headers_empty_ctx_test() -> + {Encoding, MsgType} = {<<"identity">>, <<"grpc.TestRequest">>}, + Ctx = ctx:new(), + Headers0 = ?headers(Encoding, MsgType, metadata_headers(Ctx)), + Headers1 = merge_headers(Headers0), + + ?assertEqual([{<<"te">>, <<"trailers">>}, + {<<"user-agent">>, <<"grpc-erlang/0.9.2">>}, + {<<"content-type">>, <<"application/grpc+proto">>}, + {<<"grpc-message-type">>, <<"grpc.TestRequest">>}, + {<<"grpc-encoding">>, <<"identity">>} + ], Headers1), + ok. + +-endif.