From 07923dacff1ab46051157e6d8657d51c1782f43d Mon Sep 17 00:00:00 2001 From: Umputun Date: Fri, 24 Jan 2025 12:53:54 -0600 Subject: [PATCH] Add tests for mount point conflicts and deep nested mounts --- group_test.go | 125 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 125 insertions(+) diff --git a/group_test.go b/group_test.go index c51a5cf..b77c5a2 100644 --- a/group_test.go +++ b/group_test.go @@ -2244,6 +2244,131 @@ func TestMiddlewareOrder(t *testing.T) { } } +func TestMountPointMethodConflicts(t *testing.T) { + group := routegroup.New(http.NewServeMux()) + + // register handler for /api directly + group.HandleFunc("GET /api", func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("api root")) + }) + + // mount a group at /api + api := group.Mount("/api") + api.HandleFunc("/users", func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("users")) + }) + + srv := httptest.NewServer(group) + defer srv.Close() + + t.Run("get /api root", func(t *testing.T) { + resp, err := http.Get(srv.URL + "/api") + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } + if string(body) != "api root" { + t.Errorf("expected 'api root', got %q", string(body)) + } + }) + + t.Run("get /api/users", func(t *testing.T) { + resp, err := http.Get(srv.URL + "/api/users") + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } + if string(body) != "users" { + t.Errorf("expected 'users', got %q", string(body)) + } + }) +} + +func TestDeepNestedMounts(t *testing.T) { + var callOrder []string + mkMiddleware := func(name string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callOrder = append(callOrder, "before "+name) + next.ServeHTTP(w, r) + callOrder = append(callOrder, "after "+name) + }) + } + } + + group := routegroup.New(http.NewServeMux()) + group.Use(mkMiddleware("root")) + + v1 := group.Mount("/v1") + v1.Use(mkMiddleware("v1")) + + api := v1.Mount("/api") + api.Use(mkMiddleware("api")) + + users := api.Mount("/users") + users.Use(mkMiddleware("users")) + + users.HandleFunc("/list", func(w http.ResponseWriter, _ *http.Request) { + callOrder = append(callOrder, "handler") + _, _ = w.Write([]byte("users list")) + }) + + srv := httptest.NewServer(group) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/v1/api/users/list") + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } + if string(body) != "users list" { + t.Errorf("expected 'users list', got %q", string(body)) + } + + expected := []string{ + "before root", + "before v1", + "before api", + "before users", + "handler", + "after users", + "after api", + "after v1", + "after root", + } + + if !reflect.DeepEqual(callOrder, expected) { + t.Errorf("middleware execution order mismatch\nwant: %v\ngot: %v", expected, callOrder) + } +} + func ExampleNew() { group := routegroup.New(http.NewServeMux())