Skip to content

Commit

Permalink
Added new functions to apoc.coll - shuffle(), randomItem(), randomIte…
Browse files Browse the repository at this point in the history
…ms(). (#296)
  • Loading branch information
InverseFalcon authored and jexp committed Mar 6, 2017
1 parent cd61246 commit 69ebe54
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 0 deletions.
3 changes: 3 additions & 0 deletions docs/overview.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,9 @@ Sometimes type information gets lost, these functions help you to coerce an "Any
| apoc.coll.unionAll(first, second) | creates the full union with duplicates of the two lists
| apoc.coll.split(list,value) | splits collection on given values rows of lists, value itself will not be part of resulting lists
| apoc.coll.indexOf(coll, value) | position of value in the list
| apoc.coll.shuffle(coll) | returns the shuffled list
| apoc.coll.randomItem(coll) | returns a random item from the list
| apoc.coll.randomItems(coll, itemCount, allowRepick: false) | returns a list of `itemCount` random items from the list, optionally allowing picked elements to be picked again
|===

=== Lookup Functions
Expand Down
43 changes: 43 additions & 0 deletions src/main/java/apoc/coll/Coll.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import org.neo4j.graphdb.Relationship;

import java.util.*;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.IntStream;
import java.util.stream.Stream;

Expand Down Expand Up @@ -257,4 +258,46 @@ public List<Object> unionAll(@Name("first") List<Object> first, @Name("second")
return list;
}

@UserFunction
@Description("apoc.coll.shuffle(coll) - returns the shuffled list")
public List<Object> shuffle(@Name("coll") List<Object> coll) {
List<Object> shuffledList = new ArrayList<>(coll);
Collections.shuffle(shuffledList);
return shuffledList;
}

@UserFunction
@Description("apoc.coll.randomItem(coll)- returns a random item from the list, or null on an empty or null list")
public Object randomItem(@Name("coll") List<Object> coll) {
if (coll == null || coll.isEmpty()) {
return null;
}

return coll.get(ThreadLocalRandom.current().nextInt(coll.size()));
}

@UserFunction
@Description("apoc.coll.randomItems(coll, itemCount, allowRepick: false) - returns a list of itemCount random items from the original list, optionally allowing picked elements to be picked again")
public List<Object> randomItems(@Name("coll") List<Object> coll, @Name("itemCount") long itemCount, @Name(value = "allowRepick", defaultValue = "false") boolean allowRepick) {
if (coll == null || coll.isEmpty() || itemCount <= 0) {
return Collections.emptyList();
}

List<Object> pickList = new ArrayList<>(coll);
List<Object> randomItems = new ArrayList<>((int)itemCount);
Random random = ThreadLocalRandom.current();

if (!allowRepick && itemCount >= coll.size()) {
Collections.shuffle(pickList);
return pickList;
}

while (randomItems.size() < itemCount) {
Object item = allowRepick ? pickList.get(random.nextInt(pickList.size()))
: pickList.remove(random.nextInt(pickList.size()));
randomItems.add(item);
}

return randomItems;
}
}
135 changes: 135 additions & 0 deletions src/test/java/apoc/coll/CollTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static java.util.Arrays.asList;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.neo4j.helpers.collection.Iterables.asSet;

public class CollTest {
Expand Down Expand Up @@ -210,4 +211,138 @@ public void testSetOperations() throws Exception {
testCall(db,"RETURN apoc.coll.removeAll([1,2],[3,2]) AS value", r -> assertEquals(asList(1L),r.get("value")));

}

@Test
public void testShuffle() throws Exception {
// with 10k elements, very remote chance of randomly getting same order
int elements = 10_000;
ArrayList<Long> original = new ArrayList<>(elements);
for (long i = 0; i< elements; i++) {
original.add(i);
}

Map<String, Object> params = new HashMap<>();
params.put("list", original);

testCall(db, "RETURN apoc.coll.shuffle({list}) as value", params,
(row) -> {
List<Object> result = (List<Object>) row.get("value");
assertEquals(original.size(), result.size());
assertTrue(original.containsAll(result));
assertFalse(original.equals(result));
});
}

@Test
public void testRandomItemOnNullAndEmptyList() throws Exception {
testCall(db, "RETURN apoc.coll.randomItem([]) as value",
(row) -> {
Object result = row.get("value");
assertEquals(null, result);
});

testCall(db, "RETURN apoc.coll.randomItem(null) as value",
(row) -> {
Object result = row.get("value");
assertEquals(null, result);
});
}

@Test
public void testRandomItem() throws Exception {
testCall(db, "RETURN apoc.coll.randomItem([1,2,3,4,5]) as value",
(row) -> {
Long result = (Long) row.get("value");
assertTrue(result >= 1 && result <= 5);
});
}

@Test
public void testRandomItemsOnNullAndEmptyList() throws Exception {
testCall(db, "RETURN apoc.coll.randomItems([], 5) as value",
(row) -> {
List<Object> result = (List<Object>) row.get("value");
assertTrue(result.isEmpty());
});

testCall(db, "RETURN apoc.coll.randomItems(null, 5) as value",
(row) -> {
List<Object> result = (List<Object>) row.get("value");
assertTrue(result.isEmpty());
});

testCall(db, "RETURN apoc.coll.randomItems([], 5, true) as value",
(row) -> {
List<Object> result = (List<Object>) row.get("value");
assertTrue(result.isEmpty());
});

testCall(db, "RETURN apoc.coll.randomItems(null, 5, true) as value",
(row) -> {
List<Object> result = (List<Object>) row.get("value");
assertTrue(result.isEmpty());
});
}

@Test
public void testRandomItems() throws Exception {
// with 100k elements, very remote chance of randomly getting same order
int elements = 100_000;
ArrayList<Long> original = new ArrayList<>(elements);
for (long i = 0; i< elements; i++) {
original.add(i);
}

Map<String, Object> params = new HashMap<>();
params.put("list", original);

testCall(db, "RETURN apoc.coll.randomItems({list}, 5000) as value", params,
(row) -> {
List<Object> result = (List<Object>) row.get("value");
assertEquals(result.size(), 5000);
assertTrue(original.containsAll(result));
assertFalse(result.equals(original.subList(0, 5000)));
});
}

@Test
public void testRandomItemsLargerThanOriginal() throws Exception {
// with 10k elements, very remote chance of randomly getting same order
int elements = 10_000;
ArrayList<Long> original = new ArrayList<>(elements);
for (long i = 0; i< elements; i++) {
original.add(i);
}

Map<String, Object> params = new HashMap<>();
params.put("list", original);

testCall(db, "RETURN apoc.coll.randomItems({list}, 20000) as value", params,
(row) -> {
List<Object> result = (List<Object>) row.get("value");
assertEquals(result.size(), 10000);
assertTrue(original.containsAll(result));
assertFalse(result.equals(original));
});
}

@Test
public void testRandomItemsLargerThanOriginalAllowingRepick() throws Exception {
// with 100k elements, very remote chance of randomly getting same order
int elements = 100_000;
ArrayList<Long> original = new ArrayList<>(elements);
for (long i = 0; i< elements; i++) {
original.add(i);
}

Map<String, Object> params = new HashMap<>();
params.put("list", original);

testCall(db, "RETURN apoc.coll.randomItems({list}, 11000, true) as value", params,
(row) -> {
List<Object> result = (List<Object>) row.get("value");
assertEquals(result.size(), 11000);
assertTrue(original.containsAll(result));
});
}
}

0 comments on commit 69ebe54

Please sign in to comment.