From 19b1ec8be37e12cdfdd9b4d16153724b2d02e3d7 Mon Sep 17 00:00:00 2001 From: Koichi Yoshigoe Date: Tue, 3 Dec 2024 01:38:39 +0900 Subject: [PATCH] feat: accept GET method at /quitquitquit (#726) * fix: Accept GET method at /quitquitquit * dedicate POST/HEAD quitquitquit tests --- cmd/root.go | 2 +- cmd/root_test.go | 48 +++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index 2ffe801c..2ae5718a 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -1194,7 +1194,7 @@ func runSignalWrapper(cmd *Command) (err error) { func quitquitquit(quitOnce *sync.Once, shutdownCh chan<- error) http.HandlerFunc { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - if req.Method != http.MethodPost { + if req.Method != http.MethodPost && req.Method != http.MethodGet { rw.WriteHeader(400) return } diff --git a/cmd/root_test.go b/cmd/root_test.go index 00cf87d6..6cc1e60e 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -1192,7 +1192,7 @@ func TestPProfServer(t *testing.T) { } } -func TestQuitQuitQuit(t *testing.T) { +func TestQuitQuitQuitHTTPPost(t *testing.T) { c := NewCommand(WithDialer(&spyDialer{})) c.SilenceUsage = true c.SilenceErrors = true @@ -1206,14 +1206,56 @@ func TestQuitQuitQuit(t *testing.T) { err := c.ExecuteContext(ctx) errCh <- err }() - resp, err := tryDial("GET", "http://localhost:9192/quitquitquit") + resp, err := tryDial("HEAD", "http://localhost:9192/quitquitquit") if err != nil { t.Fatalf("failed to dial endpoint: %v", err) } if resp.StatusCode != http.StatusBadRequest { t.Fatalf("expected a 400 status, got = %v", resp.StatusCode) } - resp, err = http.Post("http://localhost:9192/quitquitquit", "", nil) + resp, err = tryDial("POST", "http://localhost:9192/quitquitquit") + if err != nil { + t.Fatalf("failed to dial endpoint: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected a 200 status, got = %v", resp.StatusCode) + } + + var gotErr error + select { + case err := <-errCh: + gotErr = err + case <-time.After(30 * time.Second): + t.Fatal("timeout waiting for error") + } + + if !errors.Is(gotErr, errQuitQuitQuit) { + t.Fatalf("want = %v, got = %v", errQuitQuitQuit, gotErr) + } +} + +func TestQuitQuitQuitGet(t *testing.T) { + c := NewCommand(WithDialer(&spyDialer{})) + c.SilenceUsage = true + c.SilenceErrors = true + c.SetArgs([]string{"--quitquitquit", "--admin-port", "9192", + "projects/proj/locations/region/clusters/clust/instances/inst"}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + errCh := make(chan error) + go func() { + err := c.ExecuteContext(ctx) + errCh <- err + }() + resp, err := tryDial("HEAD", "http://localhost:9192/quitquitquit") + if err != nil { + t.Fatalf("failed to dial endpoint: %v", err) + } + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected a 400 status, got = %v", resp.StatusCode) + } + resp, err = tryDial("GET", "http://localhost:9192/quitquitquit") if err != nil { t.Fatalf("failed to dial endpoint: %v", err) }