Skip to content

Commit

Permalink
Custom transport mode
Browse files Browse the repository at this point in the history
  • Loading branch information
magik6k committed Jun 6, 2024
1 parent 81c1e3f commit 5e380fa
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 0 deletions.
57 changes: 57 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
Expand Down Expand Up @@ -129,6 +130,62 @@ func NewMergeClient(ctx context.Context, addr string, namespace string, outs []i

}

// NewCustomClient is like NewMergeClient in single-request (http) mode, except it allows for a custom doRequest function
func NewCustomClient(namespace string, outs []interface{}, doRequest func(ctx context.Context, body []byte) (io.ReadCloser, error), opts ...Option) (ClientCloser, error) {
config := defaultConfig()
for _, o := range opts {
o(&config)
}

c := client{
namespace: namespace,
paramEncoders: config.paramEncoders,
errors: config.errors,
}

stop := make(chan struct{})
c.exiting = stop

c.doRequest = func(ctx context.Context, cr clientRequest) (clientResponse, error) {
b, err := json.Marshal(&cr.req)
if err != nil {
return clientResponse{}, xerrors.Errorf("marshalling request: %w", err)
}

rawResp, err := doRequest(ctx, b)
if err != nil {
return clientResponse{}, xerrors.Errorf("doRequest failed: %w", err)
}

defer rawResp.Close()

var resp clientResponse
if cr.req.ID != nil { // non-notification
if err := json.NewDecoder(rawResp).Decode(&resp); err != nil {
return clientResponse{}, xerrors.Errorf("unmarshaling response: %w", err)
}

if resp.ID, err = normalizeID(resp.ID); err != nil {
return clientResponse{}, xerrors.Errorf("failed to response ID: %w", err)
}

if resp.ID != cr.req.ID {
return clientResponse{}, xerrors.New("request and response id didn't match")
}
}

return resp, nil
}

if err := c.provide(outs); err != nil {
return nil, err
}

return func() {
close(stop)
}, nil
}

func httpClient(ctx context.Context, addr string, namespace string, outs []interface{}, requestHeader http.Header, config Config) (ClientCloser, error) {
c := client{
namespace: namespace,
Expand Down
4 changes: 4 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ func (s *RPCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.handleReader(ctx, r.Body, w, rpcError)
}

func (s *RPCServer) HandleRequest(ctx context.Context, r io.Reader, w io.Writer) {
s.handleReader(ctx, r, w, rpcError)
}

func rpcError(wf func(func(io.Writer)), req *request, code ErrorCode, err error) {
log.Errorf("RPC Error: %s", err)
wf(func(w io.Writer) {
Expand Down

0 comments on commit 5e380fa

Please sign in to comment.