Skip to content

Commit

Permalink
Improve execution of Roundable unit-tests
Browse files Browse the repository at this point in the history
Signed-off-by: Ketan Verma <[email protected]>
  • Loading branch information
ketanv3 committed Dec 18, 2023
1 parent b7cb42a commit 2e82d0a
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 27 deletions.
8 changes: 0 additions & 8 deletions libs/common/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,6 @@ if (BuildParams.runtimeJavaVersion >= JavaVersion.VERSION_20) {
classpath -= sourceSets.main.output
}

tasks.register('roundableSimdTest', Test) {
group 'verification'
include '**/RoundableTests.class'
systemProperty 'opensearch.experimental.feature.simd.rounding.enabled', 'true'
}

check.dependsOn(roundableSimdTest)

forbiddenApisJava20 {
failOnMissingClasses = false
ignoreSignaturesOfMissingClasses = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,36 +8,61 @@

package org.opensearch.common.round;

import org.opensearch.common.SuppressForbidden;
import org.opensearch.test.OpenSearchTestCase;

import java.util.List;
import java.lang.invoke.MethodHandles;
import java.lang.reflect.InvocationTargetException;

public class RoundableTests extends OpenSearchTestCase {

public void testRoundingEmptyArray() {
Throwable throwable = assertThrows(IllegalArgumentException.class, () -> RoundableFactory.create(new long[0], 0));
assertEquals("at least one value must be present", throwable.getMessage());
public void testBidirectionalLinearSearcher() {
assertRounding(BidirectionalLinearSearcher::new);
}

public void testRoundingSmallArray() {
int size = randomIntBetween(1, 64);
long[] values = randomArrayOfSortedValues(size);
Roundable roundable = RoundableFactory.create(values, size);

assertEquals("BidirectionalLinearSearcher", roundable.getClass().getSimpleName());
assertRounding(roundable, values, size);
public void testBinarySearcher() {
assertRounding(BinarySearcher::new);
}

public void testRoundingLargeArray() {
int size = randomIntBetween(65, 256);
long[] values = randomArrayOfSortedValues(size);
Roundable roundable = RoundableFactory.create(values, size);
@SuppressForbidden(reason = "Reflective construction of BtreeSearcher since it's not supported below Java 20")
public void testBtreeSearcher() {
RoundableSupplier supplier;

try {
Class<?> clz = MethodHandles.lookup().findClass("org.opensearch.common.round.BtreeSearcher");
supplier = (values, size) -> {
try {
return (Roundable) clz.getDeclaredConstructor(long[].class, int.class).newInstance(values, size);
} catch (InvocationTargetException e) {
// Failed to instantiate the class. Unwrap if the nested exception is already a runtime exception,
// say due to an IllegalArgumentException due to bad constructor arguments.
if (e.getCause() instanceof RuntimeException) {
throw (RuntimeException) e.getCause();
} else {
throw new RuntimeException(e);
}
} catch (Exception e) {
throw new RuntimeException(e);
}
};
} catch (ClassNotFoundException e) {
assumeTrue("BtreeSearcher is not supported below Java 20", false);
return;
} catch (IllegalAccessException e) {
throw new RuntimeException(e);
}

assertTrue(List.of("BtreeSearcher", "BinarySearcher").contains(roundable.getClass().getSimpleName()));
assertRounding(roundable, values, size);
assertRounding(supplier);
}

private void assertRounding(Roundable roundable, long[] values, int size) {
private void assertRounding(RoundableSupplier supplier) {
Throwable throwable = assertThrows(IllegalArgumentException.class, () -> supplier.get(new long[0], 0));
assertEquals("at least one value must be present", throwable.getMessage());

int size = randomIntBetween(1, 256);
long[] values = randomArrayOfSortedValues(size);
Roundable roundable = supplier.get(values, size);

for (int i = 0; i < 100000; i++) {
// Index of the expected round-down point.
int idx = randomIntBetween(0, size - 1);
Expand All @@ -55,7 +80,7 @@ private void assertRounding(Roundable roundable, long[] values, int size) {
assertEquals(expected, roundable.floor(key));
}

Throwable throwable = assertThrows(AssertionError.class, () -> roundable.floor(values[0] - 1));
throwable = assertThrows(AssertionError.class, () -> roundable.floor(values[0] - 1));
assertEquals("key must be greater than or equal to " + values[0], throwable.getMessage());
}

Expand All @@ -69,4 +94,9 @@ private static long[] randomArrayOfSortedValues(int size) {

return values;
}

@FunctionalInterface
private interface RoundableSupplier {
Roundable get(long[] values, int size);
}
}

0 comments on commit 2e82d0a

Please sign in to comment.