From 30a063f93e8677745f0e31bc864c749b55db5594 Mon Sep 17 00:00:00 2001 From: Thanatat Tamtan Date: Sun, 22 Dec 2024 12:43:59 +0700 Subject: [PATCH] add support sse writer --- arpc.go | 37 +++++++++++++++++++++++++++---------- arpc_test.go | 27 +++++++++++++++++++++++++++ sse.go | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 102 insertions(+), 10 deletions(-) create mode 100644 sse.go diff --git a/arpc.go b/arpc.go index 40e38ff..716f824 100644 --- a/arpc.go +++ b/arpc.go @@ -143,6 +143,14 @@ func (m *Manager) Decode(r *http.Request, v any) error { return WrapError(v.UnmarshalRequest(r)) } + // GET without body + if r.Method == http.MethodGet { + if v, ok := v.(FormUnmarshaler); ok { + return WrapError(v.UnmarshalForm(r.Form)) + } + return nil + } + return ErrUnsupported } @@ -177,19 +185,21 @@ func (m *Manager) NotFoundHandler() http.Handler { type mapIndex int const ( - _ mapIndex = iota - miContext // context.Context - miRequest // *http.Request - miResponseWriter // http.ResponseWriter - miAny // any - miError // error + _ mapIndex = iota + miContext // context.Context + miRequest // *http.Request + miResponseWriter // http.ResponseWriter + miSSEResponseWriter // SSEResponseWriter + miAny // any + miError // error ) const ( - strContext = "context.Context" - strRequest = "*http.Request" - strResponseWriter = "http.ResponseWriter" - strError = "error" + strContext = "context.Context" + strRequest = "*http.Request" + strResponseWriter = "http.ResponseWriter" + strSSEResponseWriter = "arpc.SSEResponseWriter" + strError = "error" ) func setOrPanic(m map[mapIndex]int, k mapIndex, v int) { @@ -238,6 +248,9 @@ func (m *Manager) Handler(f any) http.Handler { case strResponseWriter: setOrPanic(mapIn, miResponseWriter, i) hasWriter = true + case strSSEResponseWriter: + setOrPanic(mapIn, miSSEResponseWriter, i) + hasWriter = true default: setOrPanic(mapIn, miAny, i) } @@ -314,6 +327,10 @@ func (m *Manager) Handler(f any) http.Handler { if i, ok := mapIn[miResponseWriter]; ok { vIn[i] = reflect.ValueOf(w) } + // inject sse response writer + if i, ok := mapIn[miSSEResponseWriter]; ok { + vIn[i] = reflect.ValueOf(newSSEResponseWriter(w)) + } vOut := fv.Call(vIn) // check error diff --git a/arpc_test.go b/arpc_test.go index 2d04b80..b6eb63e 100644 --- a/arpc_test.go +++ b/arpc_test.go @@ -27,6 +27,12 @@ func f1(r *request) int { func f2() { } +func f3(ctx context.Context, w arpc.SSEResponseWriter) error { + w.Write([]byte("data: 1\n\n")) + <-ctx.Done() + return nil +} + func TestSuccess(t *testing.T) { t.Parallel() @@ -276,3 +282,24 @@ func TestMiddleware(t *testing.T) { assert.JSONEq(t, `{"ok":true,"result":{}}`, w.Body.String()) }) } + +func TestSSE(t *testing.T) { + t.Parallel() + + m := arpc.New() + h := m.Handler(f3) + ctx, cancel := context.WithCancel(context.Background()) + w := httptest.NewRecorder() + r := httptest.NewRequestWithContext(ctx, "GET", "/", nil) + waitExit := make(chan struct{}) + go func() { + h.ServeHTTP(w, r) + waitExit <- struct{}{} + }() + cancel() + <-waitExit + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "text/event-stream", w.Header().Get("Content-Type")) + assert.Equal(t, "data: 1\n\n", w.Body.String()) +} diff --git a/sse.go b/sse.go new file mode 100644 index 0000000..8fd8feb --- /dev/null +++ b/sse.go @@ -0,0 +1,48 @@ +package arpc + +import "net/http" + +type SSEResponseWriter interface { + http.ResponseWriter + http.Flusher +} + +var _ SSEResponseWriter = (*sseResponseWriter)(nil) + +type sseResponseWriter struct { + wrote bool + w http.ResponseWriter +} + +func newSSEResponseWriter(w http.ResponseWriter) SSEResponseWriter { + return &sseResponseWriter{w: w} +} + +func (w *sseResponseWriter) writeHeader() { + if w.wrote { + return + } + w.WriteHeader(http.StatusOK) +} + +func (w *sseResponseWriter) Header() http.Header { + return w.w.Header() +} + +func (w *sseResponseWriter) WriteHeader(statusCode int) { + if w.wrote { + return + } + w.wrote = true + w.w.Header().Set("Content-Type", "text/event-stream") + w.w.WriteHeader(statusCode) +} + +func (w *sseResponseWriter) Write(b []byte) (int, error) { + w.writeHeader() + return w.w.Write(b) +} + +func (w *sseResponseWriter) Flush() { + w.w.(http.Flusher).Flush() +}