From b0a883767ad4d51b2be36eab3649f5b7f310f48f Mon Sep 17 00:00:00 2001 From: Travis Ralston Date: Thu, 14 Mar 2024 19:36:03 -0600 Subject: [PATCH] Initial implementation --- .github/workflows/main.yml | 47 +++++++ .gitignore | 3 + README.md | 13 +- bucket.go | 111 +++++++++++++++++ bucket_test.go | 249 +++++++++++++++++++++++++++++++++++++ examples/basic/main.go | 52 ++++++++ go.mod | 11 ++ go.sum | 10 ++ 8 files changed, 495 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/main.yml create mode 100644 bucket.go create mode 100644 bucket_test.go create mode 100644 examples/basic/main.go create mode 100644 go.mod create mode 100644 go.sum diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 0000000..1e97451 --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,47 @@ +name: Go package +on: [push] +jobs: + build: + name: 'Go Build (1.21)' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version: '1.21' + - name: Install dependencies + run: go get . + - name: Build + run: go build ./... + static: + name: 'Go Static (1.21)' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version: '1.21' + - run: 'go install honnef.co/go/tools/cmd/staticcheck@latest' + - run: 'go vet ./...' + - run: 'staticcheck ./...' + test: + name: 'Go Test (1.21)' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version: '1.21' + - name: Install dependencies + run: go get . + - name: Test + run: go test -cover -vet all -coverprofile cover.out . + - name: Coverage Check + run: | + go tool cover -func ./cover.out + val=$(go tool cover -func cover.out | fgrep total | awk '{print $3}') + if [[ "100.0%" != $val ]] + then + echo 'Test coverage is less than 100.0%' + exit 1 + fi diff --git a/.gitignore b/.gitignore index 3b735ec..7b6fd0f 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,6 @@ # Go workspace file go.work + +# Custom +/.idea \ No newline at end of file diff --git a/README.md b/README.md index 040a89b..d38eb60 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,13 @@ # go-leaky-bucket -Leaky bucket meter implementation in Go +[Leaky bucket meter](https://en.wikipedia.org/wiki/Leaky_bucket#As_a_meter) implementation in Go. + +This implementation atomically drains when its value is being mutated rather than using a timer or continual +drain. Before mutation, the bucket will drain the supplied number of units as many times as necessary to match +the precision of the supplied interval. + +For example, if a bucket is created which drains 5 units every 2 minutes, then after 2.5 minutes only 5 units will +be drained. However, the 30 seconds of "unused" drain time will be accounted for to ensure future drains are kept +accurate. If another 1.5 minutes were to pass, the bucket will drain by another 5 units because the unused time +was recorded. + +See [`./examples`](./examples) for usage and inspiration. \ No newline at end of file diff --git a/bucket.go b/bucket.go new file mode 100644 index 0000000..ae1c73e --- /dev/null +++ b/bucket.go @@ -0,0 +1,111 @@ +package leaky + +import ( + "encoding/gob" + "errors" + "sync" + "time" +) + +var ErrBucketFull = errors.New("leaky: bucket full or would overflow") + +func init() { + gob.Register(&Bucket{}) +} + +type Bucket struct { + DrainBy int64 + DrainInterval time.Duration + Capacity int64 + + value int64 + lastDrain time.Time + lock sync.Mutex +} + +func NewBucket(drainBy int64, drainEvery time.Duration, capacity int64) (*Bucket, error) { + if drainBy <= 0 || drainEvery <= 0 { + return nil, errors.New("leaky: bucket never drains") + } + if capacity <= 0 { + return nil, errors.New("leaky: bucket can never fill") + } + return &Bucket{ + DrainBy: drainBy, + DrainInterval: drainEvery, + Capacity: capacity, + value: 0, + lastDrain: time.Now(), + lock: sync.Mutex{}, + }, nil +} + +func (b *Bucket) drain() { + b.lock.Lock() + defer b.lock.Unlock() + + if b.lastDrain.IsZero() { + b.lastDrain = time.Now() // assume we've never drained + } + + if b.value <= 0 { + b.value = 0 + b.lastDrain = time.Now() + return // nothing to drain, so don't bother + } + + since := time.Since(b.lastDrain) + drainTime := since.Truncate(b.DrainInterval) + leaks := int64(drainTime.Abs() / b.DrainInterval.Abs()) + b.value -= b.DrainBy * leaks + if b.value < 0 { + b.value = 0 + } + b.lastDrain = time.Now().Add((since - drainTime) * -1) +} + +func (b *Bucket) Peek() int64 { + return b.value +} + +func (b *Bucket) Value() int64 { + b.drain() + return b.value +} + +func (b *Bucket) Remaining() int64 { + b.drain() + return b.Capacity - b.value +} + +func (b *Bucket) Add(amount int64) error { + b.drain() + + b.lock.Lock() + defer b.lock.Unlock() + + newValue := b.value + amount + if newValue > b.Capacity { + return ErrBucketFull + } + b.value = newValue + return nil +} + +func (b *Bucket) Set(value int64, resetDrain bool) error { + if value < 0 { + return errors.New("leaky: bucket value cannot be negative") + } + if value > b.Capacity { + return errors.New("leaky: bucket value cannot exceed capacity") + } + + b.lock.Lock() + defer b.lock.Unlock() + + b.value = value + if resetDrain { + b.lastDrain = time.Now() + } + return nil +} diff --git a/bucket_test.go b/bucket_test.go new file mode 100644 index 0000000..39b73f7 --- /dev/null +++ b/bucket_test.go @@ -0,0 +1,249 @@ +package leaky + +import ( + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +var createCaseFunctions = []func(drainBy int64, drainEvery time.Duration, capacity int64) (*Bucket, error){ + func(drainBy int64, drainEvery time.Duration, capacity int64) (*Bucket, error) { + return &Bucket{ + DrainBy: drainBy, + DrainInterval: drainEvery, + Capacity: capacity, + }, nil + }, + func(drainBy int64, drainEvery time.Duration, capacity int64) (*Bucket, error) { + return NewBucket(drainBy, drainEvery, capacity) + }, +} + +func TestNewBucket(t *testing.T) { + var err error + + // Zero drain + _, err = NewBucket(0, time.Minute, 300) + assert.EqualError(t, err, "leaky: bucket never drains") + _, err = NewBucket(5, 0*time.Minute, 300) + assert.EqualError(t, err, "leaky: bucket never drains") + + // Negative drain + _, err = NewBucket(-10, time.Minute, 300) + assert.EqualError(t, err, "leaky: bucket never drains") + _, err = NewBucket(5, -10*time.Minute, 300) + assert.EqualError(t, err, "leaky: bucket never drains") + + // No capacity + _, err = NewBucket(5, time.Minute, 0) + assert.EqualError(t, err, "leaky: bucket can never fill") + _, err = NewBucket(5, time.Minute, -10) + assert.EqualError(t, err, "leaky: bucket can never fill") + + // Happy path + bucket, err := NewBucket(5, time.Minute, 300) + assert.Nil(t, err) + assert.NotNil(t, bucket) + assert.Equal(t, int64(5), bucket.DrainBy) + assert.Equal(t, time.Minute, bucket.DrainInterval) + assert.Equal(t, int64(300), bucket.Capacity) + assert.Equal(t, int64(0), bucket.value) + assert.Equal(t, false, bucket.lastDrain.IsZero()) // ensure we set a timestamp +} + +func TestBucket_drain(t *testing.T) { + for i, createFn := range createCaseFunctions { + bucket, err := createFn(5, time.Minute, 300) + if err != nil { + t.Errorf("TestBucket_drain(case:%d): unexpected error %v", i, err) + continue + } + + // Shouldn't drain when empty + assert.Equal(t, int64(0), bucket.value) + bucket.drain() + assert.Equal(t, int64(0), bucket.value) + + // Still shouldn't drain when empty + bucket.lastDrain = time.Now().Add(-1 * bucket.DrainInterval) + bucket.drain() + assert.Equal(t, int64(0), bucket.value) + + // Shouldn't become negative when drained + bucket.value = bucket.DrainBy / 2 + bucket.lastDrain = time.Now().Add(-1 * bucket.DrainInterval) + bucket.drain() + assert.Equal(t, int64(0), bucket.value) + + // Should drain exactly 1 interval + bucket.value = bucket.Capacity + bucket.lastDrain = time.Now().Add(-1 * bucket.DrainInterval) + bucket.drain() + assert.Equal(t, bucket.Capacity-bucket.DrainBy, bucket.value) + + // Should drain exactly 2 intervals, but only to zero (not become negative) + bucket.value = bucket.DrainBy + bucket.lastDrain = time.Now().Add(-2 * bucket.DrainInterval) + bucket.drain() + assert.Equal(t, int64(0), bucket.value) + + // Should drain exactly 2 intervals + bucket.value = bucket.DrainBy * 3 + bucket.lastDrain = time.Now().Add(-2 * bucket.DrainInterval) + bucket.drain() + assert.Equal(t, bucket.DrainBy, bucket.value) + + // Should drain by partial intervals + bucket.value = bucket.DrainBy * 3 + bucket.lastDrain = time.Now().Add(-1 * (bucket.DrainInterval / 2)) + bucket.drain() + assert.Equal(t, bucket.DrainBy*3, bucket.value) + bucket.lastDrain = bucket.lastDrain.Add(-3 * (bucket.DrainInterval / 2)) // 1.5x interval + bucket.drain() + assert.Equal(t, bucket.DrainBy*1, bucket.value) + bucket.lastDrain = bucket.lastDrain.Add(-3 * (bucket.DrainInterval / 2)) // 1.5x interval + bucket.drain() + assert.Equal(t, int64(0), bucket.value) + } +} + +func TestBucket_Peek(t *testing.T) { + for i, createFn := range createCaseFunctions { + bucket, err := createFn(5, time.Minute, 300) + if err != nil { + t.Errorf("TestBucket_Peek(case:%d): unexpected error %v", i, err) + continue + } + + // Doesn't drain on call, even if it could + bucket.lastDrain = time.Now().Add(-1 * bucket.DrainInterval) + bucket.value = 100 + assert.Equalf(t, int64(100), bucket.Peek(), "Bucket_Peek(case:%d) should be equal", i) + } +} + +func TestBucket_Value(t *testing.T) { + for i, createFn := range createCaseFunctions { + bucket, err := createFn(5, time.Minute, 300) + if err != nil { + t.Errorf("TestBucket_Value(case:%d): unexpected error %v", i, err) + continue + } + + // Doesn't drain on first call + bucket.value = 100 + assert.Equalf(t, int64(100), bucket.Value(), "TestBucket_Value(case:%d) should be equal", i) + + // Does drain if required + bucket.lastDrain = time.Now().Add(-1 * bucket.DrainInterval) + assert.Equalf(t, int64(95), bucket.Value(), "TestBucket_Value(case:%d) should be equal", i) + } +} + +func TestBucket_Remaining(t *testing.T) { + for i, createFn := range createCaseFunctions { + bucket, err := createFn(5, time.Minute, 300) + if err != nil { + t.Errorf("TestBucket_Remaining(case:%d): unexpected error %v", i, err) + continue + } + + // Doesn't drain on first call + bucket.value = 100 + assert.Equalf(t, int64(200), bucket.Remaining(), "TestBucket_Remaining(case:%d) should be equal", i) + + // Does drain if required + bucket.lastDrain = time.Now().Add(-1 * bucket.DrainInterval) + assert.Equalf(t, int64(205), bucket.Remaining(), "TestBucket_Remaining(case:%d) should be equal", i) + } +} + +func TestBucket_Add(t *testing.T) { + for i, createFn := range createCaseFunctions { + bucket, err := createFn(5, time.Minute, 300) + if err != nil { + t.Errorf("Bucket_Add(case:%d): unexpected error %v", i, err) + continue + } + + if err = bucket.Add(100); err != nil { + t.Errorf("Bucket_Add(case:%d): unexpected Add error %v", i, err) + } + assert.Equalf(t, int64(100), bucket.value, "Bucket_Add(case:%d) should be equal", i) + + // Test overflow + if err = bucket.Add(250); err != nil { + if !errors.Is(err, ErrBucketFull) { + t.Errorf("Bucket_Add(case:%d): expected overflow error, got %v", i, err) + } + } else if err == nil { + t.Errorf("Bucket_Add(case:%d): expected overflow error, got nil", i) + } + assert.Equalf(t, int64(100), bucket.value, "Bucket_Add(case:%d) should be equal", i) + + // Test exact fill + if err = bucket.Add(200); err != nil { + t.Errorf("Bucket_Add(case:%d): unexpected Add error %v", i, err) + } + assert.Equalf(t, int64(300), bucket.value, "Bucket_Add(case:%d) should be equal", i) + + // Drains before add + bucket.value = bucket.Capacity + bucket.lastDrain = time.Now().Add(-1 * bucket.DrainInterval) + if err = bucket.Add(bucket.DrainBy); err != nil { + t.Errorf("Bucket_Add(case:%d): unexpected Add error %v", i, err) + } + assert.Equalf(t, int64(300), bucket.value, "Bucket_Add(case:%d) should be equal", i) + } +} + +func TestBucket_Set(t *testing.T) { + for i, createFn := range createCaseFunctions { + bucket, err := createFn(5, time.Minute, 300) + if err != nil { + t.Errorf("TestBucket_Set(case:%d): unexpected error %v", i, err) + continue + } + + // Must be positive value + if err = bucket.Set(-1, true); err != nil { + assert.EqualErrorf(t, err, "leaky: bucket value cannot be negative", "TestBucket_Set(case:%d)", i) + } else { + t.Errorf("TestBucket_Set(case:%d): expected error, got nil", i) + } + + // Must be less than capacity + if err = bucket.Set(bucket.Capacity+1, true); err != nil { + assert.EqualErrorf(t, err, "leaky: bucket value cannot exceed capacity", "TestBucket_Set(case:%d)", i) + } else { + t.Errorf("TestBucket_Set(case:%d): expected error, got nil", i) + } + + // Can be zero, and resets lastDrain, and doesn't drain + bucket.lastDrain = time.Now().Add(-5 * bucket.DrainInterval) + if err = bucket.Set(0, true); err != nil { + t.Errorf("TestBucket_Set(case:%d): unexpected Set error %v", i, err) + } + assert.Equal(t, int64(0), bucket.value) + assert.InDeltaf(t, 0*time.Millisecond, time.Since(bucket.lastDrain), float64(10*time.Millisecond), "TestBucket_Set(case:%d)", i) + + // Can be positive, and resets lastDrain, and doesn't drain + bucket.lastDrain = time.Now().Add(-5 * bucket.DrainInterval) + if err = bucket.Set(5, true); err != nil { + t.Errorf("TestBucket_Set(case:%d): unexpected Set error %v", i, err) + } + assert.Equal(t, int64(5), bucket.value) + assert.InDeltaf(t, 0*time.Millisecond, time.Since(bucket.lastDrain), float64(10*time.Millisecond), "TestBucket_Set(case:%d)", i) + + // Doesn't reset lastDrain when resetDrain=false + drainTime := time.Now().Add(-5 * bucket.DrainInterval) + bucket.lastDrain = drainTime + if err = bucket.Set(10, false); err != nil { + t.Errorf("TestBucket_Set(case:%d): unexpected Set error %v", i, err) + } + assert.Equalf(t, int64(10), bucket.value, "TestBucket_Set(case:%d) should be equal", i) + assert.Equalf(t, drainTime, bucket.lastDrain, "TestBucket_Set(case:%d) should be equal", i) + } +} diff --git a/examples/basic/main.go b/examples/basic/main.go new file mode 100644 index 0000000..8c05e39 --- /dev/null +++ b/examples/basic/main.go @@ -0,0 +1,52 @@ +package main + +import ( + "errors" + "fmt" + "time" + + "github.com/t2bot/go-leaky-bucket" +) + +func main() { + bucket, err := leaky.NewBucket(5, time.Minute, 300) + if err != nil { + panic(err) // TODO: Handle error + } + + // Try to add to the bucket + if err = bucket.Add(50); errors.Is(err, leaky.ErrBucketFull) { + panic("bucket is full") // or cancel the request, return a 429, etc + } else if err != nil { + panic(err) // TODO: Handle error + } else { + // continue processing normally + } + + // Inspect the bucket + // All of these operations cause a drain to happen before returning a value. + fmt.Println("Remaining capacity:", bucket.Remaining()) + fmt.Println("Size:", bucket.Value()) + + // Inspect the bucket *without* causing a drain + // Caution: it may have been a while since the last drain. You probably want `.Value()` + fmt.Println("Undrained size:", bucket.Peek()) + + // Force the bucket to have a particular size + if err = bucket.Set(42 /*resetDrain=*/, true); err != nil { + panic(err) // TODO: Handle error + } else { + // The bucket is now set to 42, and the drain has been reset. Calling `.Value()` or similar + // right now would not cause the size to decrease. + // + // If for some reason you'd like to leave the drain status unchanged, set resetDrain to false. + } + fmt.Println("Size after Set:", bucket.Value()) // will not drain, even if a minute had passed + + // Expand the bucket in any direction + bucket.Capacity = 700 + bucket.DrainBy = 40 + bucket.DrainInterval = time.Hour + fmt.Println("Remaining capacity after expansion:", bucket.Remaining()) + fmt.Println("Size after expansion:", bucket.Value()) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..8a14936 --- /dev/null +++ b/go.mod @@ -0,0 +1,11 @@ +module github.com/t2bot/go-leaky-bucket + +go 1.21 + +require github.com/stretchr/testify v1.9.0 + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..60ce688 --- /dev/null +++ b/go.sum @@ -0,0 +1,10 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=