Skip to content

Commit

Permalink
feat: accept GET method at /quitquitquit (#726)
Browse files Browse the repository at this point in the history
* fix: Accept GET method at /quitquitquit

* dedicate POST/HEAD quitquitquit tests
  • Loading branch information
kchygoe authored Dec 2, 2024
1 parent 281c9e3 commit 19b1ec8
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 4 deletions.
2 changes: 1 addition & 1 deletion cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -1194,7 +1194,7 @@ func runSignalWrapper(cmd *Command) (err error) {

func quitquitquit(quitOnce *sync.Once, shutdownCh chan<- error) http.HandlerFunc {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodPost {
if req.Method != http.MethodPost && req.Method != http.MethodGet {
rw.WriteHeader(400)
return
}
Expand Down
48 changes: 45 additions & 3 deletions cmd/root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1192,7 +1192,7 @@ func TestPProfServer(t *testing.T) {
}
}

func TestQuitQuitQuit(t *testing.T) {
func TestQuitQuitQuitHTTPPost(t *testing.T) {
c := NewCommand(WithDialer(&spyDialer{}))
c.SilenceUsage = true
c.SilenceErrors = true
Expand All @@ -1206,14 +1206,56 @@ func TestQuitQuitQuit(t *testing.T) {
err := c.ExecuteContext(ctx)
errCh <- err
}()
resp, err := tryDial("GET", "http://localhost:9192/quitquitquit")
resp, err := tryDial("HEAD", "http://localhost:9192/quitquitquit")
if err != nil {
t.Fatalf("failed to dial endpoint: %v", err)
}
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected a 400 status, got = %v", resp.StatusCode)
}
resp, err = http.Post("http://localhost:9192/quitquitquit", "", nil)
resp, err = tryDial("POST", "http://localhost:9192/quitquitquit")
if err != nil {
t.Fatalf("failed to dial endpoint: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected a 200 status, got = %v", resp.StatusCode)
}

var gotErr error
select {
case err := <-errCh:
gotErr = err
case <-time.After(30 * time.Second):
t.Fatal("timeout waiting for error")
}

if !errors.Is(gotErr, errQuitQuitQuit) {
t.Fatalf("want = %v, got = %v", errQuitQuitQuit, gotErr)
}
}

func TestQuitQuitQuitGet(t *testing.T) {
c := NewCommand(WithDialer(&spyDialer{}))
c.SilenceUsage = true
c.SilenceErrors = true
c.SetArgs([]string{"--quitquitquit", "--admin-port", "9192",
"projects/proj/locations/region/clusters/clust/instances/inst"})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

errCh := make(chan error)
go func() {
err := c.ExecuteContext(ctx)
errCh <- err
}()
resp, err := tryDial("HEAD", "http://localhost:9192/quitquitquit")
if err != nil {
t.Fatalf("failed to dial endpoint: %v", err)
}
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected a 400 status, got = %v", resp.StatusCode)
}
resp, err = tryDial("GET", "http://localhost:9192/quitquitquit")
if err != nil {
t.Fatalf("failed to dial endpoint: %v", err)
}
Expand Down

0 comments on commit 19b1ec8

Please sign in to comment.