Skip to content

Commit

Permalink
add support sse writer
Browse files Browse the repository at this point in the history
  • Loading branch information
acoshift committed Dec 22, 2024
1 parent 5751161 commit 30a063f
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 10 deletions.
37 changes: 27 additions & 10 deletions arpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,14 @@ func (m *Manager) Decode(r *http.Request, v any) error {
return WrapError(v.UnmarshalRequest(r))
}

// GET without body
if r.Method == http.MethodGet {
if v, ok := v.(FormUnmarshaler); ok {
return WrapError(v.UnmarshalForm(r.Form))
}
return nil
}

return ErrUnsupported
}

Expand Down Expand Up @@ -177,19 +185,21 @@ func (m *Manager) NotFoundHandler() http.Handler {
type mapIndex int

const (
_ mapIndex = iota
miContext // context.Context
miRequest // *http.Request
miResponseWriter // http.ResponseWriter
miAny // any
miError // error
_ mapIndex = iota
miContext // context.Context
miRequest // *http.Request
miResponseWriter // http.ResponseWriter
miSSEResponseWriter // SSEResponseWriter
miAny // any
miError // error
)

const (
strContext = "context.Context"
strRequest = "*http.Request"
strResponseWriter = "http.ResponseWriter"
strError = "error"
strContext = "context.Context"
strRequest = "*http.Request"
strResponseWriter = "http.ResponseWriter"
strSSEResponseWriter = "arpc.SSEResponseWriter"
strError = "error"
)

func setOrPanic(m map[mapIndex]int, k mapIndex, v int) {
Expand Down Expand Up @@ -238,6 +248,9 @@ func (m *Manager) Handler(f any) http.Handler {
case strResponseWriter:
setOrPanic(mapIn, miResponseWriter, i)
hasWriter = true
case strSSEResponseWriter:
setOrPanic(mapIn, miSSEResponseWriter, i)
hasWriter = true
default:
setOrPanic(mapIn, miAny, i)
}
Expand Down Expand Up @@ -314,6 +327,10 @@ func (m *Manager) Handler(f any) http.Handler {
if i, ok := mapIn[miResponseWriter]; ok {
vIn[i] = reflect.ValueOf(w)
}
// inject sse response writer
if i, ok := mapIn[miSSEResponseWriter]; ok {
vIn[i] = reflect.ValueOf(newSSEResponseWriter(w))
}

vOut := fv.Call(vIn)
// check error
Expand Down
27 changes: 27 additions & 0 deletions arpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ func f1(r *request) int {
func f2() {
}

func f3(ctx context.Context, w arpc.SSEResponseWriter) error {
w.Write([]byte("data: 1\n\n"))
<-ctx.Done()
return nil
}

func TestSuccess(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -276,3 +282,24 @@ func TestMiddleware(t *testing.T) {
assert.JSONEq(t, `{"ok":true,"result":{}}`, w.Body.String())
})
}

func TestSSE(t *testing.T) {
t.Parallel()

m := arpc.New()
h := m.Handler(f3)
ctx, cancel := context.WithCancel(context.Background())
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, "GET", "/", nil)

Check failure on line 293 in arpc_test.go

View workflow job for this annotation

GitHub Actions / Go 1.22

undefined: httptest.NewRequestWithContext

Check failure on line 293 in arpc_test.go

View workflow job for this annotation

GitHub Actions / Go 1.22

undefined: httptest.NewRequestWithContext
waitExit := make(chan struct{})
go func() {
h.ServeHTTP(w, r)
waitExit <- struct{}{}
}()
cancel()
<-waitExit

assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "text/event-stream", w.Header().Get("Content-Type"))
assert.Equal(t, "data: 1\n\n", w.Body.String())
}
48 changes: 48 additions & 0 deletions sse.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package arpc

import "net/http"

type SSEResponseWriter interface {
http.ResponseWriter
http.Flusher
}

var _ SSEResponseWriter = (*sseResponseWriter)(nil)

type sseResponseWriter struct {
wrote bool
w http.ResponseWriter
}

func newSSEResponseWriter(w http.ResponseWriter) SSEResponseWriter {
return &sseResponseWriter{w: w}
}

func (w *sseResponseWriter) writeHeader() {
if w.wrote {
return
}
w.WriteHeader(http.StatusOK)
}

func (w *sseResponseWriter) Header() http.Header {
return w.w.Header()
}

func (w *sseResponseWriter) WriteHeader(statusCode int) {
if w.wrote {
return
}
w.wrote = true
w.w.Header().Set("Content-Type", "text/event-stream")
w.w.WriteHeader(statusCode)
}

func (w *sseResponseWriter) Write(b []byte) (int, error) {
w.writeHeader()
return w.w.Write(b)
}

func (w *sseResponseWriter) Flush() {
w.w.(http.Flusher).Flush()
}

0 comments on commit 30a063f

Please sign in to comment.