diff --git a/agent/integration/config_allowlisting_integration_test.go b/agent/integration/config_allowlisting_integration_test.go index 03a3337a8a..6d32ccb083 100644 --- a/agent/integration/config_allowlisting_integration_test.go +++ b/agent/integration/config_allowlisting_integration_test.go @@ -126,7 +126,7 @@ func TestConfigAllowlisting(t *testing.T) { } e := createTestAgentEndpoint() - server := e.server(jobID) + server := e.server() defer server.Close() mb := mockBootstrap(t) diff --git a/agent/integration/job_environment_integration_test.go b/agent/integration/job_environment_integration_test.go index a9f277c322..6f4496a4a5 100644 --- a/agent/integration/job_environment_integration_test.go +++ b/agent/integration/job_environment_integration_test.go @@ -36,7 +36,7 @@ func TestWhenCachePathsSetInJobStep_CachePathsEnvVarIsSet(t *testing.T) { // create a mock agent API e := createTestAgentEndpoint() - server := e.server("my-job-id") + server := e.server() defer server.Close() err := runJob(t, ctx, testRunJobConfig{ diff --git a/agent/integration/job_runner_integration_test.go b/agent/integration/job_runner_integration_test.go index 47c26bec0c..577f836c9b 100644 --- a/agent/integration/job_runner_integration_test.go +++ b/agent/integration/job_runner_integration_test.go @@ -94,7 +94,7 @@ func TestPreBootstrapHookScripts(t *testing.T) { // Creates a mock agent API e := createTestAgentEndpoint() - server := e.server(defaultJobID) + server := e.server() t.Cleanup(server.Close) j := &api.Job{ @@ -154,7 +154,7 @@ func TestPreBootstrapHookRefusesJob(t *testing.T) { // create a mock agent API e := createTestAgentEndpoint() - server := e.server("my-job-id") + server := e.server() defer server.Close() mb := mockBootstrap(t) @@ -205,7 +205,7 @@ func TestJobRunner_WhenBootstrapExits_ItSendsTheExitStatusToTheAPI(t *testing.T) mb.Expect().Once().AndExitWith(exit) e := createTestAgentEndpoint() - server := e.server("my-job-id") + server := e.server() defer server.Close() err := runJob(t, ctx, testRunJobConfig{ @@ -254,7 +254,7 @@ func TestJobRunner_WhenJobHasToken_ItOverridesAccessToken(t *testing.T) { // create a mock agent API e := createTestAgentEndpoint() - server := e.server("my-job-id") + server := e.server() defer server.Close() err := runJob(t, ctx, testRunJobConfig{ @@ -294,7 +294,7 @@ func TestJobRunnerPassesAccessTokenToBootstrap(t *testing.T) { // create a mock agent API e := createTestAgentEndpoint() - server := e.server("my-job-id") + server := e.server() defer server.Close() err := runJob(t, ctx, testRunJobConfig{ @@ -333,7 +333,7 @@ func TestJobRunnerIgnoresPipelineChangesToProtectedVars(t *testing.T) { // create a mock agent API e := createTestAgentEndpoint() - server := e.server("my-job-id") + server := e.server() defer server.Close() runJob(t, ctx, testRunJobConfig{ diff --git a/agent/integration/job_verification_integration_test.go b/agent/integration/job_verification_integration_test.go index bdf2d53704..9220dab2ec 100644 --- a/agent/integration/job_verification_integration_test.go +++ b/agent/integration/job_verification_integration_test.go @@ -556,7 +556,7 @@ func TestJobVerification(t *testing.T) { // create a mock agent API e := createTestAgentEndpoint() - server := e.server(tc.job.ID) + server := e.server() defer server.Close() mb := mockBootstrap(t) diff --git a/agent/integration/test_helpers.go b/agent/integration/test_helpers.go index faea48d0a1..b8aa95ff3c 100644 --- a/agent/integration/test_helpers.go +++ b/agent/integration/test_helpers.go @@ -70,6 +70,7 @@ func runJob(t *testing.T, ctx context.Context, cfg testRunJobConfig) error { AgentConfiguration: cfg.agentCfg, MetricsScope: scope, }) + if err != nil { t.Fatalf("agent.NewJobRunner() error = %v", err) } @@ -127,33 +128,109 @@ func (tae *testAgentEndpoint) logsFor(t *testing.T, jobID string) string { return strings.Join(logChunks, "") } -func (t *testAgentEndpoint) server(jobID string) *httptest.Server { - return httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - t.mtx.Lock() - defer t.mtx.Unlock() - - b, _ := io.ReadAll(req.Body) - t.calls[req.URL.Path] = append(t.calls[req.URL.Path], b) - - switch req.URL.Path { - case "/jobs/" + jobID: - rw.WriteHeader(http.StatusOK) - fmt.Fprintf(rw, `{"state":"running"}`) - case "/jobs/" + jobID + "/start": - rw.WriteHeader(http.StatusOK) - case "/jobs/" + jobID + "/chunks": - sequence := req.URL.Query().Get("sequence") - seqNo, _ := strconv.Atoi(sequence) - r, _ := gzip.NewReader(bytes.NewBuffer(b)) - uz, _ := io.ReadAll(r) - t.logChunks[seqNo] = string(uz) - rw.WriteHeader(http.StatusCreated) - case "/jobs/" + jobID + "/finish": - rw.WriteHeader(http.StatusOK) - default: - http.Error(rw, fmt.Sprintf("not found; method = %q, path = %q", req.Method, req.URL.Path), http.StatusNotFound) +type route struct { + Path string + Method string + http.HandlerFunc +} + +func (t *testAgentEndpoint) getJobsHandler() http.HandlerFunc { + return func(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(http.StatusOK) + fmt.Fprintf(rw, `{"state":"running"}`) + } +} + +func (t *testAgentEndpoint) chunksHandler() http.HandlerFunc { + return func(rw http.ResponseWriter, req *http.Request) { + b, err := io.ReadAll(req.Body) + if err != nil { + rw.WriteHeader(http.StatusInternalServerError) + return } - })) + + sequence := req.URL.Query().Get("sequence") + seqNo, err := strconv.Atoi(sequence) + if err != nil { + rw.WriteHeader(http.StatusBadRequest) + return + } + + r, err := gzip.NewReader(bytes.NewBuffer(b)) + if err != nil { + rw.WriteHeader(http.StatusInternalServerError) + return + } + + uz, err := io.ReadAll(r) + if err != nil { + rw.WriteHeader(http.StatusInternalServerError) + return + } + + t.logChunks[seqNo] = string(uz) + rw.WriteHeader(http.StatusCreated) + } +} + +func (t *testAgentEndpoint) defaultRoutes() []route { + return []route{ + { + Method: "GET", + Path: "/jobs/", + HandlerFunc: t.getJobsHandler(), + }, + { + Method: "PUT", + Path: "/jobs/{id}/start", + HandlerFunc: func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) }, + }, + { + Method: "POST", + Path: "/jobs/{id}/chunks", + HandlerFunc: t.chunksHandler(), + }, + { + Method: "PUT", + Path: "/jobs/{id}/finish", + HandlerFunc: func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) }, + }, + } +} + +func (t *testAgentEndpoint) server(extraRoutes ...route) *httptest.Server { + mux := http.NewServeMux() + + defaultRoutes := t.defaultRoutes() + routesUniq := make(map[string]http.HandlerFunc, len(defaultRoutes)) + for _, r := range defaultRoutes { + routesUniq[fmt.Sprintf("%s %s", r.Method, r.Path)] = r.HandlerFunc + } + + // extra routes overwrite default routes if they conflict + for _, r := range extraRoutes { + routesUniq[fmt.Sprintf("%s %s", r.Method, r.Path)] = r.HandlerFunc + } + + wrapRecordRequest := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + b, _ := io.ReadAll(req.Body) + req.Body.Close() + req.Body = io.NopCloser(bytes.NewBuffer(b)) + + t.mtx.Lock() + t.calls[req.URL.Path] = append(t.calls[req.URL.Path], b) + t.mtx.Unlock() + + next.ServeHTTP(rw, req) + }) + } + + for path, handler := range routesUniq { + mux.Handle(path, wrapRecordRequest(handler)) + } + + return httptest.NewServer(mux) } func mockPreBootstrap(t *testing.T, hooksDir string) *bintest.Mock {