Skip to content

Commit

Permalink
Add tests for mount point conflicts and deep nested mounts
Browse files Browse the repository at this point in the history
  • Loading branch information
umputun committed Jan 24, 2025
1 parent 5b7962f commit 07923da
Showing 1 changed file with 125 additions and 0 deletions.
125 changes: 125 additions & 0 deletions group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2244,6 +2244,131 @@ func TestMiddlewareOrder(t *testing.T) {
}
}

func TestMountPointMethodConflicts(t *testing.T) {
group := routegroup.New(http.NewServeMux())

// register handler for /api directly
group.HandleFunc("GET /api", func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("api root"))
})

// mount a group at /api
api := group.Mount("/api")
api.HandleFunc("/users", func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("users"))
})

srv := httptest.NewServer(group)
defer srv.Close()

t.Run("get /api root", func(t *testing.T) {
resp, err := http.Get(srv.URL + "/api")
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}

if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
if string(body) != "api root" {
t.Errorf("expected 'api root', got %q", string(body))
}
})

t.Run("get /api/users", func(t *testing.T) {
resp, err := http.Get(srv.URL + "/api/users")
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}

if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
if string(body) != "users" {
t.Errorf("expected 'users', got %q", string(body))
}
})
}

func TestDeepNestedMounts(t *testing.T) {
var callOrder []string
mkMiddleware := func(name string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callOrder = append(callOrder, "before "+name)
next.ServeHTTP(w, r)
callOrder = append(callOrder, "after "+name)
})
}
}

group := routegroup.New(http.NewServeMux())
group.Use(mkMiddleware("root"))

v1 := group.Mount("/v1")
v1.Use(mkMiddleware("v1"))

api := v1.Mount("/api")
api.Use(mkMiddleware("api"))

users := api.Mount("/users")
users.Use(mkMiddleware("users"))

users.HandleFunc("/list", func(w http.ResponseWriter, _ *http.Request) {
callOrder = append(callOrder, "handler")
_, _ = w.Write([]byte("users list"))
})

srv := httptest.NewServer(group)
defer srv.Close()

resp, err := http.Get(srv.URL + "/v1/api/users/list")
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}

if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
if string(body) != "users list" {
t.Errorf("expected 'users list', got %q", string(body))
}

expected := []string{
"before root",
"before v1",
"before api",
"before users",
"handler",
"after users",
"after api",
"after v1",
"after root",
}

if !reflect.DeepEqual(callOrder, expected) {
t.Errorf("middleware execution order mismatch\nwant: %v\ngot: %v", expected, callOrder)
}
}

func ExampleNew() {
group := routegroup.New(http.NewServeMux())

Expand Down

0 comments on commit 07923da

Please sign in to comment.