Skip to content

Commit

Permalink
Allow passing extra routes to job runner integration test server
Browse files Browse the repository at this point in the history
  • Loading branch information
moskyb committed May 2, 2024
1 parent f13425a commit 67c7673
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 35 deletions.
2 changes: 1 addition & 1 deletion agent/integration/config_allowlisting_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ func TestConfigAllowlisting(t *testing.T) {
}

e := createTestAgentEndpoint()
server := e.server(jobID)
server := e.server()
defer server.Close()

mb := mockBootstrap(t)
Expand Down
2 changes: 1 addition & 1 deletion agent/integration/job_environment_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestWhenCachePathsSetInJobStep_CachePathsEnvVarIsSet(t *testing.T) {

// create a mock agent API
e := createTestAgentEndpoint()
server := e.server("my-job-id")
server := e.server()
defer server.Close()

err := runJob(t, ctx, testRunJobConfig{
Expand Down
12 changes: 6 additions & 6 deletions agent/integration/job_runner_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func TestPreBootstrapHookScripts(t *testing.T) {

// Creates a mock agent API
e := createTestAgentEndpoint()
server := e.server(defaultJobID)
server := e.server()
t.Cleanup(server.Close)

j := &api.Job{
Expand Down Expand Up @@ -154,7 +154,7 @@ func TestPreBootstrapHookRefusesJob(t *testing.T) {

// create a mock agent API
e := createTestAgentEndpoint()
server := e.server("my-job-id")
server := e.server()
defer server.Close()

mb := mockBootstrap(t)
Expand Down Expand Up @@ -205,7 +205,7 @@ func TestJobRunner_WhenBootstrapExits_ItSendsTheExitStatusToTheAPI(t *testing.T)
mb.Expect().Once().AndExitWith(exit)

e := createTestAgentEndpoint()
server := e.server("my-job-id")
server := e.server()
defer server.Close()

err := runJob(t, ctx, testRunJobConfig{
Expand Down Expand Up @@ -254,7 +254,7 @@ func TestJobRunner_WhenJobHasToken_ItOverridesAccessToken(t *testing.T) {

// create a mock agent API
e := createTestAgentEndpoint()
server := e.server("my-job-id")
server := e.server()
defer server.Close()

err := runJob(t, ctx, testRunJobConfig{
Expand Down Expand Up @@ -294,7 +294,7 @@ func TestJobRunnerPassesAccessTokenToBootstrap(t *testing.T) {

// create a mock agent API
e := createTestAgentEndpoint()
server := e.server("my-job-id")
server := e.server()
defer server.Close()

err := runJob(t, ctx, testRunJobConfig{
Expand Down Expand Up @@ -333,7 +333,7 @@ func TestJobRunnerIgnoresPipelineChangesToProtectedVars(t *testing.T) {

// create a mock agent API
e := createTestAgentEndpoint()
server := e.server("my-job-id")
server := e.server()
defer server.Close()

runJob(t, ctx, testRunJobConfig{
Expand Down
2 changes: 1 addition & 1 deletion agent/integration/job_verification_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ func TestJobVerification(t *testing.T) {

// create a mock agent API
e := createTestAgentEndpoint()
server := e.server(tc.job.ID)
server := e.server()
defer server.Close()

mb := mockBootstrap(t)
Expand Down
129 changes: 103 additions & 26 deletions agent/integration/test_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ func runJob(t *testing.T, ctx context.Context, cfg testRunJobConfig) error {
AgentConfiguration: cfg.agentCfg,
MetricsScope: scope,
})

if err != nil {
t.Fatalf("agent.NewJobRunner() error = %v", err)
}
Expand Down Expand Up @@ -127,33 +128,109 @@ func (tae *testAgentEndpoint) logsFor(t *testing.T, jobID string) string {
return strings.Join(logChunks, "")
}

func (t *testAgentEndpoint) server(jobID string) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
t.mtx.Lock()
defer t.mtx.Unlock()

b, _ := io.ReadAll(req.Body)
t.calls[req.URL.Path] = append(t.calls[req.URL.Path], b)

switch req.URL.Path {
case "/jobs/" + jobID:
rw.WriteHeader(http.StatusOK)
fmt.Fprintf(rw, `{"state":"running"}`)
case "/jobs/" + jobID + "/start":
rw.WriteHeader(http.StatusOK)
case "/jobs/" + jobID + "/chunks":
sequence := req.URL.Query().Get("sequence")
seqNo, _ := strconv.Atoi(sequence)
r, _ := gzip.NewReader(bytes.NewBuffer(b))
uz, _ := io.ReadAll(r)
t.logChunks[seqNo] = string(uz)
rw.WriteHeader(http.StatusCreated)
case "/jobs/" + jobID + "/finish":
rw.WriteHeader(http.StatusOK)
default:
http.Error(rw, fmt.Sprintf("not found; method = %q, path = %q", req.Method, req.URL.Path), http.StatusNotFound)
type route struct {
Path string
Method string
http.HandlerFunc
}

func (t *testAgentEndpoint) getJobsHandler() http.HandlerFunc {
return func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
fmt.Fprintf(rw, `{"state":"running"}`)
}
}

func (t *testAgentEndpoint) chunksHandler() http.HandlerFunc {
return func(rw http.ResponseWriter, req *http.Request) {
b, err := io.ReadAll(req.Body)
if err != nil {
rw.WriteHeader(http.StatusInternalServerError)
return
}
}))

sequence := req.URL.Query().Get("sequence")
seqNo, err := strconv.Atoi(sequence)
if err != nil {
rw.WriteHeader(http.StatusBadRequest)
return
}

r, err := gzip.NewReader(bytes.NewBuffer(b))
if err != nil {
rw.WriteHeader(http.StatusInternalServerError)
return
}

uz, err := io.ReadAll(r)
if err != nil {
rw.WriteHeader(http.StatusInternalServerError)
return
}

t.logChunks[seqNo] = string(uz)
rw.WriteHeader(http.StatusCreated)
}
}

func (t *testAgentEndpoint) defaultRoutes() []route {
return []route{
{
Method: "GET",
Path: "/jobs/",
HandlerFunc: t.getJobsHandler(),
},
{
Method: "PUT",
Path: "/jobs/{id}/start",
HandlerFunc: func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) },
},
{
Method: "POST",
Path: "/jobs/{id}/chunks",
HandlerFunc: t.chunksHandler(),
},
{
Method: "PUT",
Path: "/jobs/{id}/finish",
HandlerFunc: func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) },
},
}
}

func (t *testAgentEndpoint) server(extraRoutes ...route) *httptest.Server {
mux := http.NewServeMux()

defaultRoutes := t.defaultRoutes()
routesUniq := make(map[string]http.HandlerFunc, len(defaultRoutes))
for _, r := range defaultRoutes {
routesUniq[fmt.Sprintf("%s %s", r.Method, r.Path)] = r.HandlerFunc
}

// extra routes overwrite default routes if they conflict
for _, r := range extraRoutes {
routesUniq[fmt.Sprintf("%s %s", r.Method, r.Path)] = r.HandlerFunc
}

wrapRecordRequest := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
b, _ := io.ReadAll(req.Body)
req.Body.Close()
req.Body = io.NopCloser(bytes.NewBuffer(b))

t.mtx.Lock()
t.calls[req.URL.Path] = append(t.calls[req.URL.Path], b)
t.mtx.Unlock()

next.ServeHTTP(rw, req)
})
}

for path, handler := range routesUniq {
mux.Handle(path, wrapRecordRequest(handler))
}

return httptest.NewServer(mux)
}

func mockPreBootstrap(t *testing.T, hooksDir string) *bintest.Mock {
Expand Down

0 comments on commit 67c7673

Please sign in to comment.