diff --git a/lib/attack.go b/lib/attack.go index b44fa94f..8436858f 100644 --- a/lib/attack.go +++ b/lib/attack.go @@ -13,6 +13,8 @@ import ( "sync" "time" + "strconv" + "golang.org/x/net/http2" ) @@ -331,7 +333,7 @@ func (a *Attacker) attack(tr Targeter, name string, workers *sync.WaitGroup, tic func (a *Attacker) hit(tr Targeter, name string) *Result { var ( res = Result{Attack: name} - tgt Target + tgt = Target{Header: make(http.Header)} err error ) @@ -348,6 +350,8 @@ func (a *Attacker) hit(tr Targeter, name string) *Result { } }() + tgt.Header.Set("X-Vegeta-Seq", strconv.FormatUint(res.Seq, 10)) + if err = tr(&tgt); err != nil { a.Stop() return &res diff --git a/lib/attack_test.go b/lib/attack_test.go index e8b97fcf..8dc9397c 100644 --- a/lib/attack_test.go +++ b/lib/attack_test.go @@ -13,6 +13,7 @@ import ( "os" "path/filepath" "reflect" + "strconv" "strings" "testing" "time" @@ -358,3 +359,21 @@ func TestClient(t *testing.T) { t.Errorf("Expected timeout error") } } + +func TestSeqHeader(t *testing.T) { + t.Parallel() + + var got http.Header + + atk := NewAttacker() + for i := 0; i < 5; i++ { + res := atk.hit(func(trt *Target) error { + got = trt.Header + return nil + }, "") + + if got.Get("X-Vegeta-Seq") != strconv.FormatUint(res.Seq, 10) { + t.Fatalf("got: %v, want: %v", got.Get("X-Vegeta-Seq"), res.Seq) + } + } +}