Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] HeapBasedRateTracker uses time provider to allow simluating of time in unit tests #3941

Merged
merged 1 commit into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
package org.opensearch.security.util.ratetracking;

import java.util.Arrays;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.function.LongSupplier;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
Expand All @@ -33,16 +35,22 @@ public class HeapBasedRateTracker<ClientIdType> implements RateTracker<ClientIdT
private final Logger log = LogManager.getLogger(this.getClass());

private final Cache<ClientIdType, ClientRecord> cache;
private final LongSupplier timeProvider;
private final long timeWindowMs;
private final int maxTimeOffsets;

public HeapBasedRateTracker(long timeWindowMs, int allowedTries, int maxEntries) {
this(timeWindowMs, allowedTries, maxEntries, null);
}

public HeapBasedRateTracker(long timeWindowMs, int allowedTries, int maxEntries, LongSupplier timeProvider) {
if (allowedTries < 2) {
throw new IllegalArgumentException("allowedTries must be >= 2");
}

this.timeWindowMs = timeWindowMs;
this.maxTimeOffsets = allowedTries > 2 ? allowedTries - 2 : 0;
this.timeProvider = Optional.ofNullable(timeProvider).orElse(System::currentTimeMillis);
this.cache = CacheBuilder.newBuilder()
.expireAfterAccess(this.timeWindowMs, TimeUnit.MILLISECONDS)
.maximumSize(maxEntries)
Expand Down Expand Up @@ -89,7 +97,7 @@ private class ClientRecord {
private short timeOffsetEnd = -1;

synchronized boolean track() {
long timestamp = System.currentTimeMillis();
long timestamp = timeProvider.getAsLong();

if (this.startTime == -1 || timestamp - getMostRecent() >= timeWindowMs) {
this.startTime = timestamp;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

package org.opensearch.security.auth.limiting;

import org.junit.Ignore;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.LongSupplier;

import org.junit.Test;

import org.opensearch.security.util.ratetracking.HeapBasedRateTracker;
Expand All @@ -27,9 +29,12 @@

public class HeapBasedRateTrackerTest {

private final AtomicLong currentTime = new AtomicLong(1);
private LongSupplier timeProvider = () -> currentTime.getAndAdd(1);

@Test
public void simpleTest() throws Exception {
HeapBasedRateTracker<String> tracker = new HeapBasedRateTracker<>(100, 5, 100_000);
HeapBasedRateTracker<String> tracker = new HeapBasedRateTracker<>(100, 5, 100_000, timeProvider);

assertFalse(tracker.track("a"));
assertFalse(tracker.track("a"));
Expand All @@ -40,9 +45,8 @@ public void simpleTest() throws Exception {
}

@Test
@Ignore // https://github.com/opensearch-project/security/issues/2193
public void expiryTest() throws Exception {
HeapBasedRateTracker<String> tracker = new HeapBasedRateTracker<>(100, 5, 100_000);
HeapBasedRateTracker<String> tracker = new HeapBasedRateTracker<>(100, 5, 100_000, timeProvider);

assertFalse(tracker.track("a"));
assertFalse(tracker.track("a"));
Expand All @@ -58,42 +62,41 @@ public void expiryTest() throws Exception {

assertFalse(tracker.track("c"));

Thread.sleep(50);
currentTime.addAndGet(50);

assertFalse(tracker.track("c"));
assertFalse(tracker.track("c"));
assertFalse(tracker.track("c"));

Thread.sleep(55);
currentTime.addAndGet(55);

assertFalse(tracker.track("c"));
assertTrue(tracker.track("c"));

assertFalse(tracker.track("a"));

Thread.sleep(55);
currentTime.addAndGet(55);
assertFalse(tracker.track("c"));
assertFalse(tracker.track("c"));
assertTrue(tracker.track("c"));

}

@Test
@Ignore // https://github.com/opensearch-project/security/issues/2193
public void maxTwoTriesTest() throws Exception {
HeapBasedRateTracker<String> tracker = new HeapBasedRateTracker<>(100, 2, 100_000);
HeapBasedRateTracker<String> tracker = new HeapBasedRateTracker<>(100, 2, 100_000, timeProvider);

assertFalse(tracker.track("a"));
assertTrue(tracker.track("a"));

assertFalse(tracker.track("b"));
Thread.sleep(50);
currentTime.addAndGet(50);
assertTrue(tracker.track("b"));

Thread.sleep(55);
currentTime.addAndGet(55);
assertTrue(tracker.track("b"));

Thread.sleep(105);
currentTime.addAndGet(105);
assertFalse(tracker.track("b"));
assertTrue(tracker.track("b"));

Expand Down
Loading