diff --git a/emathroughput.go b/emathroughput.go
new file mode 100644
index 0000000..da40bc3
--- /dev/null
+++ b/emathroughput.go
@@ -0,0 +1,359 @@
+package dynsampler
+
+import (
+	"encoding/json"
+	"errors"
+	"fmt"
+	"math"
+	"sync"
+	"time"
+)
+
+// EMAThroughput implements Sampler and attempts to achieve a given throughput
+// rate, weighting rare traffic and frequent traffic differently so as to end
+// up with the goal throughput.
+//
+// Based on the EMASampleRate implementation, EMAThroughput differs in that
+// instead of trying to achieve a given sample rate, it tries to reach a given
+// throughput of events. During bursts of traffic, it will reduce sample
+// rates so as to keep the number of events per second roughly constant.
+//
+// Like the EMA sampler, it maintains an Exponential Moving Average of counts
+// seen per key, and adjusts this average at regular intervals. The weight
+// applied to more recent intervals is defined by `weight`, a number in the
+// range (0, 1); larger values weight the average more toward recent
+// observations. In other words, a larger weight will cause sample rates to
+// adapt more quickly to traffic patterns, while a smaller weight will result
+// in sample rates that are less sensitive to bursts or drops in traffic and
+// thus more consistent over time.
+//
+// New keys that are not found in the EMA will always have a sample
+// rate of 1. Keys that occur more frequently will be sampled on a logarithmic
+// curve. In other words, every key will be represented at least once in any
+// given window and more frequent keys will have their sample rate
+// increased proportionally to wind up with the goal throughput.
+type EMAThroughput struct {
+	// AdjustmentInterval defines how often we adjust the moving average from
+	// recent observations. Default 15s.
+	AdjustmentInterval time.Duration
+
+	// Weight is a value in the range (0, 1) indicating the weighting factor used to adjust
+	// the EMA. With larger values, newer data will influence the average more, and older
+	// values will be factored out more quickly. In mathematical literature concerning EMA,
+	// this is referred to as the `alpha` constant.
+	// Default is 0.5.
+	Weight float64
+
+	// InitialSampleRate is the sample rate to use during startup, before we
+	// have accumulated enough data to calculate a reasonable desired sample
+	// rate. This is mainly useful in situations where unsampled throughput is
+	// high enough to cause problems.
+	// Default 10.
+	InitialSampleRate int
+
+	// GoalThroughputPerSec is the target number of events to send per second.
+	// Sample rates are generated to squash the total throughput down to match the
+	// goal throughput. Actual throughput may exceed the goal throughput. Default 100.
+	GoalThroughputPerSec int
+
+	// MaxKeys, if greater than 0, limits the number of distinct keys tracked in the EMA.
+	// Once MaxKeys is reached, new keys will not be included in the sample rate map, but
+	// existing keys will continue to be counted.
+	MaxKeys int
+
+	// AgeOutValue indicates the threshold for removing keys from the EMA. The EMA of any key will approach 0
+	// if it is not repeatedly observed, but will never truly reach it, so we have to decide what constitutes "zero".
+	// Keys with averages below this threshold will be removed from the EMA. Default is the same as Weight, as this prevents
+	// a key with the smallest integer value (1) from being aged out immediately. This value should generally be <= Weight,
+	// unless you have very specific reasons to set it higher.
+	AgeOutValue float64
+
+	// BurstMultiple, if set, is multiplied by the sum of the running average of counts to define
+	// the burst detection threshold. If total counts observed for a given interval exceed the threshold,
+	// the EMA is updated immediately rather than waiting on the AdjustmentInterval.
+	// Defaults to 2; a negative value disables burst detection. With the default of 2, burst detection
+	// will kick in if your traffic suddenly doubles.
+	BurstMultiple float64
+
+	// BurstDetectionDelay indicates the number of intervals to run after Start is called before burst detection kicks in.
+	// Defaults to 3.
+	BurstDetectionDelay uint
+
+	savedSampleRates map[string]int
+	currentCounts    map[string]float64
+	movingAverage    map[string]float64
+	burstThreshold   float64
+	currentBurstSum  float64
+	intervalCount    uint
+	burstSignal      chan struct{}
+
+	// haveData indicates that we have gotten a sample of traffic. Before we've
+	// gotten any samples of traffic, we should use the default goal sample rate
+	// for all events instead of sampling everything at 1
+	haveData bool
+	updating bool
+	done     chan struct{}
+
+	lock sync.Mutex
+
+	// used only in tests
+	testSignalMapsDone chan struct{}
+}
+
+// Ensure we implement the sampler interface
+var _ Sampler = (*EMAThroughput)(nil)
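To make the Weight and AgeOutValue semantics concrete, here is a minimal standalone sketch of the EMA recurrence this sampler relies on. The adjustAverage helper called later in this file is not part of this diff; the sketch assumes it applies the standard recurrence newAvg = weight*value + (1-weight)*oldAvg. It shows how a key's average converges while traffic flows and then decays below AgeOutValue once the key goes quiet (all numbers are illustrative, not taken from the source).

package main

import "fmt"

func main() {
	const weight = 0.5      // corresponds to EMAThroughput.Weight (alpha)
	const ageOutValue = 0.5 // corresponds to EMAThroughput.AgeOutValue

	avg := 0.0
	// Five intervals that each observe 100 events for one key.
	for i := 0; i < 5; i++ {
		avg = weight*100 + (1-weight)*avg // standard EMA recurrence
	}
	fmt.Printf("after traffic: %.2f\n", avg) // ~96.88, close to the true rate of 100

	// The key then goes quiet. Each empty interval halves the average, and once
	// it falls below ageOutValue the sampler would delete the key entirely.
	intervals := 0
	for avg >= ageOutValue {
		avg = (1 - weight) * avg
		intervals++
	}
	fmt.Printf("aged out after %d quiet intervals (avg %.4f)\n", intervals, avg)
}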
+
+func (e *EMAThroughput) Start() error {
+	// apply defaults
+	if e.AdjustmentInterval == 0 {
+		e.AdjustmentInterval = 15 * time.Second
+	}
+	if e.AdjustmentInterval < 1*time.Millisecond {
+		return fmt.Errorf("the AdjustmentInterval %v is unreasonably short for a throughput sampler", e.AdjustmentInterval)
+	}
+	if e.InitialSampleRate == 0 {
+		e.InitialSampleRate = 10
+	}
+	if e.GoalThroughputPerSec == 0 {
+		e.GoalThroughputPerSec = 100
+	}
+	if e.Weight == 0 {
+		e.Weight = 0.5
+	}
+	if e.AgeOutValue == 0 {
+		e.AgeOutValue = e.Weight
+	}
+	if e.BurstMultiple == 0 {
+		e.BurstMultiple = 2
+	}
+	if e.BurstDetectionDelay == 0 {
+		e.BurstDetectionDelay = 3
+	}
+
+	// Don't override these maps at startup in case they were loaded from a previous state
+	e.currentCounts = make(map[string]float64)
+	if e.savedSampleRates == nil {
+		e.savedSampleRates = make(map[string]int)
+	}
+	if e.movingAverage == nil {
+		e.movingAverage = make(map[string]float64)
+	}
+	e.burstSignal = make(chan struct{})
+	e.done = make(chan struct{})
+
+	go func() {
+		ticker := time.NewTicker(e.AdjustmentInterval)
+		defer ticker.Stop()
+		for {
+			select {
+			case <-e.burstSignal:
+				// reset ticker when we get a burst
+				ticker.Stop()
+				ticker = time.NewTicker(e.AdjustmentInterval)
+				e.updateMaps()
+			case <-ticker.C:
+				e.updateMaps()
+				e.intervalCount++
+			case <-e.done:
+				return
+			}
+		}
+	}()
+	return nil
+}
+
+func (e *EMAThroughput) Stop() error {
+	close(e.done)
+	return nil
+}
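As a concrete illustration of when the burst branch above is serviced: the threshold is the EMA total across all keys multiplied by BurstMultiple, and it is only consulted once BurstDetectionDelay regular intervals have elapsed. The numbers below are hypothetical; the check mirrors the one GetSampleRateMulti performs later in this file.

package main

import "fmt"

func main() {
	// Hypothetical steady state: the EMA across all keys sums to 500 events per
	// adjustment interval, and BurstMultiple has its default value of 2.
	emaSum := 500.0
	burstMultiple := 2.0
	burstThreshold := emaSum * burstMultiple // 1000, recomputed by updateMaps each interval

	// Counts observed so far in the current (still open) interval.
	currentBurstSum := 1200.0
	intervalCount := uint(5)
	burstDetectionDelay := uint(3)

	// Once enough intervals have passed, exceeding the threshold sends burstSignal,
	// which makes the goroutine above update the EMA immediately and reset its ticker.
	if burstThreshold > 0 && currentBurstSum >= burstThreshold && intervalCount >= burstDetectionDelay {
		fmt.Println("burst detected: update the EMA now instead of waiting for the ticker")
	}
}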
+
+// updateMaps calculates a new saved rate map based on the contents of the
+// counter map
+func (e *EMAThroughput) updateMaps() {
+	e.lock.Lock()
+	if e.testSignalMapsDone != nil {
+		defer func() {
+			e.testSignalMapsDone <- struct{}{}
+		}()
+	}
+	// short circuit if no traffic
+	if len(e.currentCounts) == 0 {
+		// No traffic in the last interval; don't update anything. This is deliberate, to avoid
+		// the average decaying when there's no traffic (it comes in bursts, or there's some kind of outage).
+		e.lock.Unlock()
+		return
+	}
+	// If there is another updateMaps call in progress, bail
+	if e.updating {
+		e.lock.Unlock()
+		return
+	}
+	e.updating = true
+	// make a local copy of the sample counters for calculation
+	tmpCounts := e.currentCounts
+	e.currentCounts = make(map[string]float64)
+	e.currentBurstSum = 0
+	e.lock.Unlock()
+
+	e.updateEMA(tmpCounts)
+
+	// Sum the EMA counts across all keys; this total drives the burst detection
+	// threshold below.
+	var sumEvents float64
+	for _, count := range e.movingAverage {
+		sumEvents += math.Max(1, count)
+	}
+
+	// Store this for burst detection. This is checked in GetSampleRate
+	// so we need to grab the lock when we update it.
+	e.lock.Lock()
+	e.burstThreshold = sumEvents * e.BurstMultiple
+	e.lock.Unlock()
+
+	// Calculate the desired average sample rate per second based on the volume we've received.
+	// InitialSampleRate := float64(sumEvents) / e.AdjustmentInterval.Seconds() / float64(e.GoalThroughputPerSec)
+	// goalCount := float64(sumEvents) / InitialSampleRate
+
+	// Calculate the number of events we'd like to let through per adjustment interval;
+	// this simplifies to the goal throughput per second times the interval length in seconds.
+	goalCount := float64(e.GoalThroughputPerSec) * e.AdjustmentInterval.Seconds()
+
+	// goalRatio is the goalCount divided by the sum of all the log values - it
+	// determines what percentage of the total event space belongs to each key
+	var logSum float64
+	for _, count := range e.movingAverage {
+		// We take the max of (1, count) because count * weight is < 1 for
+		// very small counts, which throws off the logSum and can cause
+		// incorrect sample rates to be computed when throughput is low
+		logSum += math.Log10(math.Max(1, count))
+	}
+	goalRatio := goalCount / logSum
+
+	newSavedSampleRates := calculateSampleRates(goalRatio, e.movingAverage)
+	e.lock.Lock()
+	defer e.lock.Unlock()
+	e.savedSampleRates = newSavedSampleRates
+	e.haveData = true
+	e.updating = false
+}
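A worked example of the math in updateMaps may help. The calculateSampleRates helper is shared package code not shown in this diff; the sketch below stands in for it with the simplified rule rate ~ count / (goalRatio * log10(count)), floored at 1, and assumes the per-interval goal is throughput per second times the interval length. The configuration and EMA values are invented for illustration.

package main

import (
	"fmt"
	"math"
)

func main() {
	// Hypothetical configuration and EMA state.
	goalThroughputPerSec := 100.0
	intervalSeconds := 15.0
	ema := map[string]float64{"hot": 100000, "warm": 1000, "rare": 10}

	// Events we would like to keep per adjustment interval.
	goalCount := goalThroughputPerSec * intervalSeconds // 1500

	// Each key's share of the goal is proportional to log10 of its average count.
	var logSum float64
	for _, count := range ema {
		logSum += math.Log10(math.Max(1, count))
	}
	goalRatio := goalCount / logSum // 1500 / 9, roughly 166.7

	// Simplified stand-in for calculateSampleRates; the real helper also
	// redistributes quota that low-volume keys don't use.
	kept := 0.0
	for key, count := range ema {
		quota := goalRatio * math.Log10(math.Max(1, count))
		rate := math.Max(1, count/quota)
		kept += count / rate
		fmt.Printf("%s: count %.0f -> sample rate %.0f\n", key, count, rate)
	}
	fmt.Printf("~%.0f events kept per interval (goal %.0f)\n", kept, goalCount)
}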
+
+// GetSampleRate takes a key and returns the appropriate sample rate for that
+// key.
+func (e *EMAThroughput) GetSampleRate(key string) int {
+	return e.GetSampleRateMulti(key, 1)
+}
+
+// GetSampleRateMulti takes a key and a count of events it represents, and
+// returns the appropriate sample rate for that key.
+func (e *EMAThroughput) GetSampleRateMulti(key string, count int) int {
+	e.lock.Lock()
+	defer e.lock.Unlock()
+
+	// Enforce MaxKeys limit on the size of the map
+	if e.MaxKeys > 0 {
+		// If a key already exists, increment it. If not, but we're under the limit, store a new key
+		if _, found := e.currentCounts[key]; found || len(e.currentCounts) < e.MaxKeys {
+			e.currentCounts[key] += float64(count)
+			e.currentBurstSum += float64(count)
+		}
+	} else {
+		e.currentCounts[key] += float64(count)
+		e.currentBurstSum += float64(count)
+	}
+
+	// Enforce the burst threshold
+	if e.burstThreshold > 0 && e.currentBurstSum >= e.burstThreshold && e.intervalCount >= e.BurstDetectionDelay {
+		// reset the burst sum to prevent additional burst updates from occurring while updateMaps is running
+		e.currentBurstSum = 0
+		// send but don't block - consuming is blocked on updateMaps, which takes the same lock we're holding
+		select {
+		case e.burstSignal <- struct{}{}:
+		default:
+		}
+	}
+
+	if !e.haveData {
+		return e.InitialSampleRate
+	}
+	if rate, found := e.savedSampleRates[key]; found {
+		return rate
+	}
+	return 1
+}
+
+func (e *EMAThroughput) updateEMA(newCounts map[string]float64) {
+	keysToUpdate := make([]string, 0, len(e.movingAverage))
+	for key := range e.movingAverage {
+		keysToUpdate = append(keysToUpdate, key)
+	}
+
+	// Update any existing keys with new values
+	for _, key := range keysToUpdate {
+		var newAvg float64
+		// Was this key seen in the last interval? Adjust by that amount
+		if val, found := newCounts[key]; found {
+			newAvg = adjustAverage(e.movingAverage[key], val, e.Weight)
+		} else {
+			// Otherwise adjust by zero
+			newAvg = adjustAverage(e.movingAverage[key], 0, e.Weight)
+		}
+
+		// Age out this value if it's too small to care about for calculating sample rates
+		// This is also necessary to keep our map from growing forever.
+		if newAvg < e.AgeOutValue {
+			delete(e.movingAverage, key)
+		} else {
+			e.movingAverage[key] = newAvg
+		}
+		// We've processed this key - don't process it again when we look at new counts
+		delete(newCounts, key)
+	}
+
+	for key := range newCounts {
+		newAvg := adjustAverage(0, newCounts[key], e.Weight)
+		if newAvg >= e.AgeOutValue {
+			e.movingAverage[key] = newAvg
+		}
+	}
+}
+
+type emaThroughputState struct {
+	// These fields are exported for use by `json.Marshal` and `json.Unmarshal`
+	SavedSampleRates map[string]int     `json:"saved_sample_rates"`
+	MovingAverage    map[string]float64 `json:"moving_average"`
+}
+
+// SaveState returns a byte array with a JSON representation of the sampler state
+func (e *EMAThroughput) SaveState() ([]byte, error) {
+	e.lock.Lock()
+	defer e.lock.Unlock()
+
+	if e.savedSampleRates == nil {
+		return nil, errors.New("saved sample rate map is nil")
+	}
+	if e.movingAverage == nil {
+		return nil, errors.New("moving average map is nil")
+	}
+	s := &emaThroughputState{SavedSampleRates: e.savedSampleRates, MovingAverage: e.movingAverage}
+	return json.Marshal(s)
+}
+
+// LoadState accepts a byte array with a JSON representation of a previous instance's
+// state
+func (e *EMAThroughput) LoadState(state []byte) error {
+	e.lock.Lock()
+	defer e.lock.Unlock()
+
+	s := emaThroughputState{}
+	err := json.Unmarshal(state, &s)
+	if err != nil {
+		return err
+	}
+
+	// Load the previously calculated sample rates
+	e.savedSampleRates = s.SavedSampleRates
+	e.movingAverage = s.MovingAverage
+	// Allow GetSampleRate to return calculated sample rates from the loaded map
+	e.haveData = true
+
+	return nil
+}
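Before the tests, a minimal end-to-end usage sketch of the sampler defined above. The import path is an assumption based on the package name and should be adjusted to this repository's actual module path; the rand-based keep/drop decision mirrors the convention used in the tests below rather than any API of the library.

package main

import (
	"log"
	"math/rand"
	"time"

	"github.com/honeycombio/dynsampler-go"
)

func main() {
	sampler := &dynsampler.EMAThroughput{
		AdjustmentInterval:   15 * time.Second,
		GoalThroughputPerSec: 100,
		Weight:               0.5,
	}
	if err := sampler.Start(); err != nil {
		log.Fatal(err)
	}
	defer sampler.Stop()

	// In the hot path: one GetSampleRate call per event, keyed by whatever
	// dimension you want throughput balanced across (e.g. "service:status").
	for i := 0; i < 10; i++ {
		key := "checkout:200"
		rate := sampler.GetSampleRate(key)
		if rand.Intn(rate) == 0 {
			// keep the event and record `rate` on it so counts can be re-weighted later
			log.Printf("kept %s at rate %d", key, rate)
		}
	}

	// State can be persisted across restarts so a fresh process doesn't fall
	// back to InitialSampleRate while the EMA warms up.
	if state, err := sampler.SaveState(); err == nil {
		_ = state // e.g. write to disk; a new instance would call LoadState(state) before Start()
	}
}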
"github.com/stretchr/testify/assert" +) + +func TestUpdateEMAThroughput(t *testing.T) { + e := &EMAThroughput{ + movingAverage: make(map[string]float64), + Weight: 0.2, + AgeOutValue: 0.2, + } + + tests := []struct { + keyAValue float64 + keyAExpected float64 + keyBValue float64 + keyBExpected float64 + keyCValue float64 + keyCExpected float64 + }{ + {463, 93, 235, 47, 0, 0}, + {176, 109, 458, 129, 0, 0}, + {345, 156, 470, 197, 0, 0}, + {339, 193, 317, 221, 0, 0}, + {197, 194, 165, 210, 0, 0}, + {387, 232, 95, 187, 6960, 1392}, + } + + for _, tt := range tests { + counts := make(map[string]float64) + counts["a"] = tt.keyAValue + counts["b"] = tt.keyBValue + counts["c"] = tt.keyCValue + e.updateEMA(counts) + assert.Equal(t, tt.keyAExpected, math.Round(e.movingAverage["a"])) + assert.Equal(t, tt.keyBExpected, math.Round(e.movingAverage["b"])) + assert.Equal(t, tt.keyCExpected, math.Round(e.movingAverage["c"])) + } +} + +func TestEMAThroughputSampleGetSampleRateStartup(t *testing.T) { + e := &EMAThroughput{ + InitialSampleRate: 10, + currentCounts: map[string]float64{}, + } + rate := e.GetSampleRate("key") + assert.Equal(t, rate, 10) + assert.Equal(t, e.currentCounts["key"], float64(1)) +} + +func TestEMAThroughputSampleUpdateMapsSparseCounts(t *testing.T) { + e := &EMAThroughput{ + GoalThroughputPerSec: 10, + AdjustmentInterval: 1 * time.Second, + Weight: 0.2, + AgeOutValue: 0.2, + } + + e.movingAverage = make(map[string]float64) + e.savedSampleRates = make(map[string]int) + + for i := 0; i <= 100; i++ { + input := make(map[string]float64) + // simulate steady stream of input from one key + input["largest_count"] = 40 + // sporadic keys with single counts that come and go with each interval + for j := 0; j < 5; j++ { + key := randomString(8) + input[key] = 1 + } + e.currentCounts = input + e.updateMaps() + } + assert.Equal(t, 4, e.savedSampleRates["largest_count"]) +} + +func TestEMAThroughputAgesOutSmallValues(t *testing.T) { + e := &EMAThroughput{ + GoalThroughputPerSec: 10, + AdjustmentInterval: 1 * time.Second, + Weight: 0.2, + AgeOutValue: 0.2, + } + e.movingAverage = make(map[string]float64) + for i := 0; i < 100; i++ { + e.currentCounts = map[string]float64{"foo": 500.0} + e.updateMaps() + } + assert.Equal(t, 1, len(e.movingAverage)) + assert.Equal(t, float64(500), math.Round(e.movingAverage["foo"])) + for i := 0; i < 100; i++ { + // "observe" no occurrences of foo for many iterations + e.currentCounts = map[string]float64{"asdf": 1} + e.updateMaps() + } + _, found := e.movingAverage["foo"] + assert.Equal(t, false, found) + _, found = e.movingAverage["asdf"] + assert.Equal(t, true, found) +} + +func TestEMAThroughputBurstDetection(t *testing.T) { + // Set the adjustment interval very high so that we never run the regular interval + e := &EMAThroughput{AdjustmentInterval: 1 * time.Hour} + err := e.Start() + assert.Nil(t, err) + + // set some counts and compute the EMA + e.currentCounts = map[string]float64{"foo": 1000} + e.updateMaps() + // should have a burst threshold computed now from this average + // 1000 = 0.5 (weight) * 1000 * 2 (threshold multiplier) + assert.Equal(t, float64(1000), e.burstThreshold) + + // Let's try and trigger a burst: + for i := 0; i <= 1000; i++ { + e.GetSampleRate("bar") + } + // burst sum isn't reset even though we're above our burst threshold + // This is because we haven't processed enough intervals to do burst detection yet + assert.Equal(t, float64(1001), e.currentBurstSum) + // Now let's cheat and say we have + e.intervalCount = 
+	e.testSignalMapsDone = make(chan struct{})
+	e.GetSampleRate("bar")
+	// wait on updateMaps to complete
+	<-e.testSignalMapsDone
+	// currentBurstSum has been reset
+	assert.Equal(t, float64(0), e.currentBurstSum)
+
+	// ensure EMA is updated
+	assert.Equal(t, float64(501), e.movingAverage["bar"])
+}
+
+func TestEMAThroughputUpdateMapsRace(t *testing.T) {
+	e := &EMAThroughput{AdjustmentInterval: 1 * time.Hour}
+	e.testSignalMapsDone = make(chan struct{}, 1000)
+	err := e.Start()
+	assert.Nil(t, err)
+	for i := 0; i < 1000; i++ {
+		e.GetSampleRate("foo")
+		go e.updateMaps()
+	}
+	done := 0
+	for done != 1000 {
+		<-e.testSignalMapsDone
+		done++
+	}
+}
+
+func TestEMAThroughputSampleRateSaveState(t *testing.T) {
+	var sampler Sampler
+	esr := &EMAThroughput{}
+	// ensure the interface is implemented
+	sampler = esr
+	err := sampler.Start()
+	assert.Nil(t, err)
+
+	esr.lock.Lock()
+	esr.savedSampleRates = map[string]int{"foo": 2, "bar": 4}
+	esr.movingAverage = map[string]float64{"foo": 500.1234, "bar": 9999.99}
+	esr.haveData = true
+	esr.lock.Unlock()
+
+	assert.Equal(t, 2, sampler.GetSampleRate("foo"))
+	assert.Equal(t, 4, sampler.GetSampleRate("bar"))
+
+	state, err := sampler.SaveState()
+	assert.Nil(t, err)
+
+	var newSampler Sampler
+	esr2 := &EMAThroughput{}
+	newSampler = esr2
+
+	err = newSampler.LoadState(state)
+	assert.Nil(t, err)
+	err = newSampler.Start()
+	assert.Nil(t, err)
+
+	assert.Equal(t, 2, newSampler.GetSampleRate("foo"))
+	assert.Equal(t, 4, newSampler.GetSampleRate("bar"))
+	esr2.lock.Lock()
+	defer esr2.lock.Unlock()
+	assert.Equal(t, float64(500.1234), esr2.movingAverage["foo"])
+	assert.Equal(t, float64(9999.99), esr2.movingAverage["bar"])
+}
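For reference, the state blob exercised by the test above is plain JSON whose shape follows the emaThroughputState struct tags. A round-trip might look like the sketch below; the import path is assumed (adjust to this repository's module path) and the printed JSON shown in the comment is an invented example, not captured output.

package main

import (
	"fmt"
	"log"

	"github.com/honeycombio/dynsampler-go"
)

func main() {
	old := &dynsampler.EMAThroughput{}
	if err := old.Start(); err != nil {
		log.Fatal(err)
	}
	// ... traffic flows for a while, then the process shuts down ...
	state, err := old.SaveState()
	if err != nil {
		log.Fatal(err)
	}
	old.Stop()

	// The blob is JSON shaped by the struct tags, for example:
	// {"saved_sample_rates":{"foo":2},"moving_average":{"foo":500.1234}}
	fmt.Println(string(state))

	// A replacement instance loads it before Start so it can serve the previous
	// rates immediately instead of falling back to InitialSampleRate.
	fresh := &dynsampler.EMAThroughput{}
	if err := fresh.LoadState(state); err != nil {
		log.Fatal(err)
	}
	if err := fresh.Start(); err != nil {
		log.Fatal(err)
	}
	defer fresh.Stop()
}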
+
+// This is a long test that generates a lot of random data and runs it through the sampler.
+// The goal is to determine whether we actually hit the specified target throughput (within a tolerance)
+// an acceptable number of times. Most of the time, the throughput of observations kept should be close
+// to the target rate.
+func TestEMAThroughputSampleRateHitsTargetRate(t *testing.T) {
+	mrand.Seed(time.Now().Unix())
+	testThroughputs := []int{100, 1000}
+	testKeyCount := []int{10, 30}
+	toleranceFraction := float64(0.2)
+
+	for _, throughput := range testThroughputs {
+		tolerance := float64(throughput) * toleranceFraction
+		toleranceUpper := float64(throughput) + tolerance
+		toleranceLower := float64(throughput) - tolerance
+
+		for _, keyCount := range testKeyCount {
+			sampler := &EMAThroughput{
+				AdjustmentInterval:   1 * time.Second,
+				GoalThroughputPerSec: throughput,
+				Weight:               0.5,
+				AgeOutValue:          0.5,
+				currentCounts:        make(map[string]float64),
+				movingAverage:        make(map[string]float64),
+			}
+
+			// build a consistent set of keys to use
+			keys := make([]string, keyCount)
+			for i := 0; i < keyCount; i++ {
+				keys[i] = randomString(8)
+			}
+
+			for i, key := range keys {
+				// generate key counts of different magnitudes
+				base := math.Pow10(i%3 + 1)
+				count := float64((i%10)+1)*base + float64(mrand.Intn(int(base)))
+				sampler.currentCounts[key] = count
+			}
+
+			// build an initial set of sample values so we don't just return the target
+			sampler.updateMaps()
+
+			var success int
+
+			grandTotalKept := 0
+			// each tick is 1 second
+			for i := 0; i < 100; i++ {
+				totalKeptObservations := 0
+				for j, key := range keys {
+					base := math.Pow10(j%3 + 1)
+					count := float64((j%10)+1)*base + float64(mrand.Intn(int(base)))
+					for k := 0; k < int(count); k++ {
+						rate := sampler.GetSampleRate(key)
+						if mrand.Intn(rate) == 0 {
+							totalKeptObservations++
+						}
+					}
+				}
+				grandTotalKept += totalKeptObservations
+
+				if totalKeptObservations <= int(toleranceUpper) && totalKeptObservations >= int(toleranceLower) {
+					success++
+				}
+				sampler.updateMaps()
+			}
+			assert.GreaterOrEqual(t, grandTotalKept, throughput*90, "totalKept too low: %d expected: %d\n", grandTotalKept, throughput*100)
+			assert.LessOrEqual(t, grandTotalKept, throughput*110, "totalKept too high: %d expected: %d\n", grandTotalKept, throughput*100)
+
+			assert.True(t, success >= 90, "target throughput test %d with key count %d failed with success rate of %d%%", throughput, keyCount, success)
+		}
+	}
+}
diff --git a/genericsampler_test.go b/genericsampler_test.go
index 1b99738..e5dc642 100644
--- a/genericsampler_test.go
+++ b/genericsampler_test.go
@@ -56,6 +56,18 @@ func TestGenericSamplerBehavior(t *testing.T) {
 				LookbackFrequencyDuration: 1 * time.Second,
 			}, []int{1, 1, 1, 2, 6, 19, 58, 174},
 		},
+		{"EMAThroughput",
+			&dynsampler.EMAThroughput{
+				AdjustmentInterval:   1 * time.Second,
+				GoalThroughputPerSec: 100,
+			}, []int{1, 1, 2, 3, 6, 13, 31, 77},
+		},
+		{"EMAThroughputLowTraffic",
+			&dynsampler.EMAThroughput{
+				AdjustmentInterval:   1 * time.Second,
+				GoalThroughputPerSec: 100000,
+			}, []int{1, 1, 1, 1, 1, 1, 1, 1},
+		},
 	}
 
 	const (
@@ -95,7 +107,7 @@ func TestGenericSamplerBehavior(t *testing.T) {
 			for k := 0; k < nkeys; k++ {
 				// if !isCloseTo(tt.want[k], results[k]) {
 				if tt.want[k] != results[k] {
-					t.Errorf("results not close enough = for key %s (%d) want %d, got %d\n", keys[k], k, tt.want[k], results[k])
+					t.Errorf("results do not match for key %s (%d) want %d, got %d\n", keys[k], k, tt.want[k], results[k])
 				}
 			}
 		})