diff --git a/rxjava-core/src/main/java/rx/operators/OperationGroupBy.java b/rxjava-core/src/main/java/rx/operators/OperationGroupBy.java index 1c2e6e969c..b95ed09b7c 100644 --- a/rxjava-core/src/main/java/rx/operators/OperationGroupBy.java +++ b/rxjava-core/src/main/java/rx/operators/OperationGroupBy.java @@ -17,12 +17,16 @@ import static org.junit.Assert.*; -import java.util.ArrayList; import java.util.Arrays; -import java.util.HashMap; -import java.util.List; +import java.util.Collection; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import org.junit.Test; @@ -30,6 +34,9 @@ import rx.Observer; import rx.Subscription; import rx.observables.GroupedObservable; +import rx.subscriptions.BooleanSubscription; +import rx.subscriptions.Subscriptions; +import rx.util.functions.Action1; import rx.util.functions.Func1; import rx.util.functions.Functions; @@ -55,7 +62,12 @@ public static Func1>, Subscription> grou } private static class GroupBy implements Func1>, Subscription> { + private final Observable> source; + private final ConcurrentHashMap> groupedObservables = new ConcurrentHashMap>(); + private final AtomicObservableSubscription actualParentSubscription = new AtomicObservableSubscription(); + private final AtomicInteger numGroupSubscriptions = new AtomicInteger(); + private final AtomicBoolean unsubscribeRequested = new AtomicBoolean(false); private GroupBy(Observable> source) { this.source = source; @@ -63,61 +75,171 @@ private GroupBy(Observable> source) { @Override public Subscription call(final Observer> observer) { - return source.subscribe(new GroupByObserver(observer)); - } + final GroupBy _this = this; + actualParentSubscription.wrap(source.subscribe(new Observer>() { + + @Override + public void onCompleted() { + // we need to propagate to all children I imagine ... we can't just leave all of those Observable/Observers hanging + for (GroupedSubject o : groupedObservables.values()) { + o.onCompleted(); + } + // now the parent + observer.onCompleted(); + } - private class GroupByObserver implements Observer> { - private final Observer> underlying; + @Override + public void onError(Exception e) { + // we need to propagate to all children I imagine ... we can't just leave all of those Observable/Observers hanging + for (GroupedSubject o : groupedObservables.values()) { + o.onError(e); + } + // now the parent + observer.onError(e); + } - private final ConcurrentHashMap keys = new ConcurrentHashMap(); + @Override + public void onNext(KeyValue value) { + GroupedSubject gs = groupedObservables.get(value.key); + if (gs == null) { + if (unsubscribeRequested.get()) { + // unsubscribe has been requested so don't create new groups + // only send data to groups already created + return; + } + /* + * Technically the source should be single-threaded so we shouldn't need to do this but I am + * programming defensively as most operators are so this can work with a concurrent sequence + * if it ends up receiving one. + */ + GroupedSubject newGs = GroupedSubject. create(value.key, _this); + GroupedSubject existing = groupedObservables.putIfAbsent(value.key, newGs); + if (existing == null) { + // we won so use the one we created + gs = newGs; + // since we won the creation we emit this new GroupedObservable + observer.onNext(gs); + } else { + // another thread beat us so use the existing one + gs = existing; + } + } + gs.onNext(value.value); + } + })); - private GroupByObserver(Observer> underlying) { - this.underlying = underlying; - } + return new Subscription() { - @Override - public void onCompleted() { - underlying.onCompleted(); - } + @Override + public void unsubscribe() { + if (numGroupSubscriptions.get() == 0) { + // if we have no group subscriptions we will unsubscribe + actualParentSubscription.unsubscribe(); + // otherwise we mark to not send any more groups (waiting on existing groups to finish) + unsubscribeRequested.set(true); + } + } + }; + } - @Override - public void onError(Exception e) { - underlying.onError(e); - } + /** + * Children notify of being subscribed to. + * + * @param key + */ + private void subscribeKey(K key) { + numGroupSubscriptions.incrementAndGet(); + } - @Override - public void onNext(final KeyValue args) { - K key = args.key; - boolean newGroup = keys.putIfAbsent(key, true) == null; - if (newGroup) { - underlying.onNext(buildObservableFor(source, key)); - } + /** + * Children notify of being unsubscribed from. + * + * @param key + */ + private void unsubscribeKey(K key) { + int c = numGroupSubscriptions.decrementAndGet(); + if (c == 0) { + actualParentSubscription.unsubscribe(); } } } - private static GroupedObservable buildObservableFor(Observable> source, final K key) { - final Observable observable = source.filter(new Func1, Boolean>() { - @Override - public Boolean call(KeyValue pair) { - return key.equals(pair.key); - } - }).map(new Func1, R>() { - @Override - public R call(KeyValue pair) { - return pair.value; - } - }); - return new GroupedObservable(key, new Func1, Subscription>() { + private static class GroupedSubject extends GroupedObservable implements Observer { - @Override - public Subscription call(Observer observer) { - return observable.subscribe(observer); - } + static GroupedSubject create(final K key, final GroupBy parent) { + @SuppressWarnings("unchecked") + final AtomicReference> subscribedObserver = new AtomicReference>(EMPTY_OBSERVER); + + return new GroupedSubject(key, new Func1, Subscription>() { + + private final AtomicObservableSubscription subscription = new AtomicObservableSubscription(); + + @Override + public Subscription call(Observer observer) { + // register Observer + subscribedObserver.set(observer); + + parent.subscribeKey(key); + + return subscription.wrap(new Subscription() { + + @SuppressWarnings("unchecked") + @Override + public void unsubscribe() { + // we remove the Observer so we stop emitting further events (they will be ignored if parent continues to send) + subscribedObserver.set(EMPTY_OBSERVER); + // now we need to notify the parent that we're unsubscribed + parent.unsubscribeKey(key); + } + }); + } + }, subscribedObserver); + } + + private final AtomicReference> subscribedObserver; + + public GroupedSubject(K key, Func1, Subscription> onSubscribe, AtomicReference> subscribedObserver) { + super(key, onSubscribe); + this.subscribedObserver = subscribedObserver; + } + + @Override + public void onCompleted() { + subscribedObserver.get().onCompleted(); + } + + @Override + public void onError(Exception e) { + subscribedObserver.get().onError(e); + } + + @Override + public void onNext(T v) { + subscribedObserver.get().onNext(v); + } - }); } + @SuppressWarnings("rawtypes") + private static Observer EMPTY_OBSERVER = new Observer() { + + @Override + public void onCompleted() { + // do nothing + } + + @Override + public void onError(Exception e) { + // do nothing + } + + @Override + public void onNext(Object args) { + // do nothing + } + + }; + private static class KeyValue { private final K key; private final V value; @@ -141,13 +263,12 @@ public void testGroupBy() { Observable source = Observable.from("one", "two", "three", "four", "five", "six"); Observable> grouped = Observable.create(groupBy(source, length)); - Map> map = toMap(grouped); + Map> map = toMap(grouped); assertEquals(3, map.size()); - assertEquals(Arrays.asList("one", "two", "six"), map.get(3)); - assertEquals(Arrays.asList("four", "five"), map.get(4)); - assertEquals(Arrays.asList("three"), map.get(5)); - + assertArrayEquals(Arrays.asList("one", "two", "six").toArray(), map.get(3).toArray()); + assertArrayEquals(Arrays.asList("four", "five").toArray(), map.get(4).toArray()); + assertArrayEquals(Arrays.asList("three").toArray(), map.get(5).toArray()); } @Test @@ -155,31 +276,286 @@ public void testEmpty() { Observable source = Observable.from(); Observable> grouped = Observable.create(groupBy(source, length)); - Map> map = toMap(grouped); + Map> map = toMap(grouped); assertTrue(map.isEmpty()); } - private static Map> toMap(Observable> observable) { - Map> result = new HashMap>(); - for (GroupedObservable g : observable.toBlockingObservable().toIterable()) { - K key = g.getKey(); + @Test + public void testError() { + Observable sourceStrings = Observable.from("one", "two", "three", "four", "five", "six"); + Observable errorSource = Observable.error(new RuntimeException("forced failure")); + @SuppressWarnings("unchecked") + Observable source = Observable.concat(sourceStrings, errorSource); - for (V value : g.toBlockingObservable().toIterable()) { - List values = result.get(key); - if (values == null) { - values = new ArrayList(); - result.put(key, values); - } + Observable> grouped = Observable.create(groupBy(source, length)); - values.add(value); + final AtomicInteger groupCounter = new AtomicInteger(); + final AtomicInteger eventCounter = new AtomicInteger(); + final AtomicReference error = new AtomicReference(); + + grouped.mapMany(new Func1, Observable>() { + + @Override + public Observable call(final GroupedObservable o) { + groupCounter.incrementAndGet(); + return o.map(new Func1() { + + @Override + public String call(String v) { + return "Event => key: " + o.getKey() + " value: " + v; + } + }); } + }).subscribe(new Observer() { - } + @Override + public void onCompleted() { + + } + + @Override + public void onError(Exception e) { + e.printStackTrace(); + error.set(e); + } + + @Override + public void onNext(String v) { + eventCounter.incrementAndGet(); + System.out.println(v); + + } + }); + + assertEquals(3, groupCounter.get()); + assertEquals(6, eventCounter.get()); + assertNotNull(error.get()); + } + + private static Map> toMap(Observable> observable) { + + final ConcurrentHashMap> result = new ConcurrentHashMap>(); + + observable.forEach(new Action1>() { + + @Override + public void call(final GroupedObservable o) { + result.put(o.getKey(), new ConcurrentLinkedQueue()); + o.subscribe(new Action1() { + + @Override + public void call(V v) { + result.get(o.getKey()).add(v); + } + + }); + } + }); return result; } + /** + * Assert that only a single subscription to a stream occurs and that all events are received. + * + * @throws Exception + */ + @Test + public void testGroupedEventStream() throws Exception { + + final AtomicInteger eventCounter = new AtomicInteger(); + final AtomicInteger subscribeCounter = new AtomicInteger(); + final AtomicInteger groupCounter = new AtomicInteger(); + final CountDownLatch latch = new CountDownLatch(1); + final int count = 100; + final int groupCount = 2; + + Observable es = Observable.create(new Func1, Subscription>() { + + @Override + public Subscription call(final Observer observer) { + System.out.println("*** Subscribing to EventStream ***"); + subscribeCounter.incrementAndGet(); + new Thread(new Runnable() { + + @Override + public void run() { + for (int i = 0; i < count; i++) { + Event e = new Event(); + e.source = i % groupCount; + e.message = "Event-" + i; + observer.onNext(e); + } + observer.onCompleted(); + } + + }).start(); + return Subscriptions.empty(); + } + + }); + + es.groupBy(new Func1() { + + @Override + public Integer call(Event e) { + return e.source; + } + }).mapMany(new Func1, Observable>() { + + @Override + public Observable call(GroupedObservable eventGroupedObservable) { + System.out.println("GroupedObservable Key: " + eventGroupedObservable.getKey()); + groupCounter.incrementAndGet(); + + return eventGroupedObservable.map(new Func1() { + + @Override + public String call(Event event) { + return "Source: " + event.source + " Message: " + event.message; + } + }); + + }; + }).subscribe(new Observer() { + + @Override + public void onCompleted() { + latch.countDown(); + } + + @Override + public void onError(Exception e) { + e.printStackTrace(); + latch.countDown(); + } + + @Override + public void onNext(String outputMessage) { + System.out.println(outputMessage); + eventCounter.incrementAndGet(); + } + }); + + latch.await(5000, TimeUnit.MILLISECONDS); + assertEquals(1, subscribeCounter.get()); + assertEquals(groupCount, groupCounter.get()); + assertEquals(count, eventCounter.get()); + + } + + /* + * We will only take 1 group with 20 events from it and then unsubscribe. + */ + @Test + public void testUnsubscribe() throws InterruptedException { + + final AtomicInteger eventCounter = new AtomicInteger(); + final AtomicInteger subscribeCounter = new AtomicInteger(); + final AtomicInteger groupCounter = new AtomicInteger(); + final AtomicInteger sentEventCounter = new AtomicInteger(); + final CountDownLatch latch = new CountDownLatch(1); + final int count = 100; + final int groupCount = 2; + + Observable es = Observable.create(new Func1, Subscription>() { + + @Override + public Subscription call(final Observer observer) { + final BooleanSubscription s = new BooleanSubscription(); + System.out.println("*** Subscribing to EventStream ***"); + subscribeCounter.incrementAndGet(); + new Thread(new Runnable() { + + @Override + public void run() { + for (int i = 0; i < count; i++) { + if (s.isUnsubscribed()) { + break; + } + Event e = new Event(); + e.source = i % groupCount; + e.message = "Event-" + i; + observer.onNext(e); + sentEventCounter.incrementAndGet(); + } + observer.onCompleted(); + } + + }).start(); + return s; + } + + }); + + es.groupBy(new Func1() { + + @Override + public Integer call(Event e) { + return e.source; + } + }) + .take(1) // we want only the first group + .mapMany(new Func1, Observable>() { + + @Override + public Observable call(GroupedObservable eventGroupedObservable) { + System.out.println("GroupedObservable Key: " + eventGroupedObservable.getKey()); + groupCounter.incrementAndGet(); + + return eventGroupedObservable + .take(20) // limit to only 20 events on this group + .map(new Func1() { + + @Override + public String call(Event event) { + return "Source: " + event.source + " Message: " + event.message; + } + }); + + }; + }).subscribe(new Observer() { + + @Override + public void onCompleted() { + latch.countDown(); + } + + @Override + public void onError(Exception e) { + e.printStackTrace(); + latch.countDown(); + } + + @Override + public void onNext(String outputMessage) { + System.out.println(outputMessage); + eventCounter.incrementAndGet(); + } + }); + + latch.await(5000, TimeUnit.MILLISECONDS); + assertEquals(1, subscribeCounter.get()); + assertEquals(1, groupCounter.get()); + assertEquals(20, eventCounter.get()); + // sentEvents will go until 'eventCounter' hits 20 and then unsubscribes + // which means it will also send (but ignore) the 19 events for the other group + // It will not however send all 100 events. + assertEquals(39, sentEventCounter.get()); + + } + + private static class Event { + int source; + String message; + + @Override + public String toString() { + return "Event => source: " + source + " message: " + message; + } + } + } }