Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GroupBy GroupedObservables should not re-subscribe to parent sequence #283

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
500 changes: 438 additions & 62 deletions rxjava-core/src/main/java/rx/operators/OperationGroupBy.java
Original file line number Diff line number Diff line change
@@ -17,19 +17,26 @@

import static org.junit.Assert.*;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Collection;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import org.junit.Test;

import rx.Observable;
import rx.Observer;
import rx.Subscription;
import rx.observables.GroupedObservable;
import rx.subscriptions.BooleanSubscription;
import rx.subscriptions.Subscriptions;
import rx.util.functions.Action1;
import rx.util.functions.Func1;
import rx.util.functions.Functions;

@@ -55,69 +62,184 @@ public static <K, T> Func1<Observer<GroupedObservable<K, T>>, Subscription> grou
}

private static class GroupBy<K, V> implements Func1<Observer<GroupedObservable<K, V>>, Subscription> {

private final Observable<KeyValue<K, V>> source;
private final ConcurrentHashMap<K, GroupedSubject<K, V>> groupedObservables = new ConcurrentHashMap<K, GroupedSubject<K, V>>();
private final AtomicObservableSubscription actualParentSubscription = new AtomicObservableSubscription();
private final AtomicInteger numGroupSubscriptions = new AtomicInteger();
private final AtomicBoolean unsubscribeRequested = new AtomicBoolean(false);

private GroupBy(Observable<KeyValue<K, V>> source) {
this.source = source;
}

@Override
public Subscription call(final Observer<GroupedObservable<K, V>> observer) {
return source.subscribe(new GroupByObserver(observer));
}
final GroupBy<K, V> _this = this;
actualParentSubscription.wrap(source.subscribe(new Observer<KeyValue<K, V>>() {

@Override
public void onCompleted() {
// we need to propagate to all children I imagine ... we can't just leave all of those Observable/Observers hanging
for (GroupedSubject<K, V> o : groupedObservables.values()) {
o.onCompleted();
}
// now the parent
observer.onCompleted();
}

private class GroupByObserver implements Observer<KeyValue<K, V>> {
private final Observer<GroupedObservable<K, V>> underlying;
@Override
public void onError(Exception e) {
// we need to propagate to all children I imagine ... we can't just leave all of those Observable/Observers hanging
for (GroupedSubject<K, V> o : groupedObservables.values()) {
o.onError(e);
}
// now the parent
observer.onError(e);
}

private final ConcurrentHashMap<K, Boolean> keys = new ConcurrentHashMap<K, Boolean>();
@Override
public void onNext(KeyValue<K, V> value) {
GroupedSubject<K, V> gs = groupedObservables.get(value.key);
if (gs == null) {
if (unsubscribeRequested.get()) {
// unsubscribe has been requested so don't create new groups
// only send data to groups already created
return;
}
/*
* Technically the source should be single-threaded so we shouldn't need to do this but I am
* programming defensively as most operators are so this can work with a concurrent sequence
* if it ends up receiving one.
*/
GroupedSubject<K, V> newGs = GroupedSubject.<K, V> create(value.key, _this);
GroupedSubject<K, V> existing = groupedObservables.putIfAbsent(value.key, newGs);
if (existing == null) {
// we won so use the one we created
gs = newGs;
// since we won the creation we emit this new GroupedObservable
observer.onNext(gs);
} else {
// another thread beat us so use the existing one
gs = existing;
}
}
gs.onNext(value.value);
}
}));

private GroupByObserver(Observer<GroupedObservable<K, V>> underlying) {
this.underlying = underlying;
}
return new Subscription() {

@Override
public void onCompleted() {
underlying.onCompleted();
}
@Override
public void unsubscribe() {
if (numGroupSubscriptions.get() == 0) {
// if we have no group subscriptions we will unsubscribe
actualParentSubscription.unsubscribe();
// otherwise we mark to not send any more groups (waiting on existing groups to finish)
unsubscribeRequested.set(true);
}
}
};
}

@Override
public void onError(Exception e) {
underlying.onError(e);
}
/**
* Children notify of being subscribed to.
*
* @param key
*/
private void subscribeKey(K key) {
numGroupSubscriptions.incrementAndGet();
}

@Override
public void onNext(final KeyValue<K, V> args) {
K key = args.key;
boolean newGroup = keys.putIfAbsent(key, true) == null;
if (newGroup) {
underlying.onNext(buildObservableFor(source, key));
}
/**
* Children notify of being unsubscribed from.
*
* @param key
*/
private void unsubscribeKey(K key) {
int c = numGroupSubscriptions.decrementAndGet();
if (c == 0) {
actualParentSubscription.unsubscribe();
}
}
}

private static <K, R> GroupedObservable<K, R> buildObservableFor(Observable<KeyValue<K, R>> source, final K key) {
final Observable<R> observable = source.filter(new Func1<KeyValue<K, R>, Boolean>() {
@Override
public Boolean call(KeyValue<K, R> pair) {
return key.equals(pair.key);
}
}).map(new Func1<KeyValue<K, R>, R>() {
@Override
public R call(KeyValue<K, R> pair) {
return pair.value;
}
});
return new GroupedObservable<K, R>(key, new Func1<Observer<R>, Subscription>() {
private static class GroupedSubject<K, T> extends GroupedObservable<K, T> implements Observer<T> {

@Override
public Subscription call(Observer<R> observer) {
return observable.subscribe(observer);
}
static <K, T> GroupedSubject<K, T> create(final K key, final GroupBy<K, T> parent) {
@SuppressWarnings("unchecked")
final AtomicReference<Observer<T>> subscribedObserver = new AtomicReference<Observer<T>>(EMPTY_OBSERVER);

return new GroupedSubject<K, T>(key, new Func1<Observer<T>, Subscription>() {

private final AtomicObservableSubscription subscription = new AtomicObservableSubscription();

@Override
public Subscription call(Observer<T> observer) {
// register Observer
subscribedObserver.set(observer);

parent.subscribeKey(key);

return subscription.wrap(new Subscription() {

@SuppressWarnings("unchecked")
@Override
public void unsubscribe() {
// we remove the Observer so we stop emitting further events (they will be ignored if parent continues to send)
subscribedObserver.set(EMPTY_OBSERVER);
// now we need to notify the parent that we're unsubscribed
parent.unsubscribeKey(key);
}
});
}
}, subscribedObserver);
}

private final AtomicReference<Observer<T>> subscribedObserver;

public GroupedSubject(K key, Func1<Observer<T>, Subscription> onSubscribe, AtomicReference<Observer<T>> subscribedObserver) {
super(key, onSubscribe);
this.subscribedObserver = subscribedObserver;
}

@Override
public void onCompleted() {
subscribedObserver.get().onCompleted();
}

@Override
public void onError(Exception e) {
subscribedObserver.get().onError(e);
}

@Override
public void onNext(T v) {
subscribedObserver.get().onNext(v);
}

});
}

@SuppressWarnings("rawtypes")
private static Observer EMPTY_OBSERVER = new Observer() {

@Override
public void onCompleted() {
// do nothing
}

@Override
public void onError(Exception e) {
// do nothing
}

@Override
public void onNext(Object args) {
// do nothing
}

};

private static class KeyValue<K, V> {
private final K key;
private final V value;
@@ -141,45 +263,299 @@ public void testGroupBy() {
Observable<String> source = Observable.from("one", "two", "three", "four", "five", "six");
Observable<GroupedObservable<Integer, String>> grouped = Observable.create(groupBy(source, length));

Map<Integer, List<String>> map = toMap(grouped);
Map<Integer, Collection<String>> map = toMap(grouped);

assertEquals(3, map.size());
assertEquals(Arrays.asList("one", "two", "six"), map.get(3));
assertEquals(Arrays.asList("four", "five"), map.get(4));
assertEquals(Arrays.asList("three"), map.get(5));

assertArrayEquals(Arrays.asList("one", "two", "six").toArray(), map.get(3).toArray());
assertArrayEquals(Arrays.asList("four", "five").toArray(), map.get(4).toArray());
assertArrayEquals(Arrays.asList("three").toArray(), map.get(5).toArray());
}

@Test
public void testEmpty() {
Observable<String> source = Observable.from();
Observable<GroupedObservable<Integer, String>> grouped = Observable.create(groupBy(source, length));

Map<Integer, List<String>> map = toMap(grouped);
Map<Integer, Collection<String>> map = toMap(grouped);

assertTrue(map.isEmpty());
}

private static <K, V> Map<K, List<V>> toMap(Observable<GroupedObservable<K, V>> observable) {
Map<K, List<V>> result = new HashMap<K, List<V>>();
for (GroupedObservable<K, V> g : observable.toBlockingObservable().toIterable()) {
K key = g.getKey();
@Test
public void testError() {
Observable<String> sourceStrings = Observable.from("one", "two", "three", "four", "five", "six");
Observable<String> errorSource = Observable.error(new RuntimeException("forced failure"));
@SuppressWarnings("unchecked")
Observable<String> source = Observable.concat(sourceStrings, errorSource);

for (V value : g.toBlockingObservable().toIterable()) {
List<V> values = result.get(key);
if (values == null) {
values = new ArrayList<V>();
result.put(key, values);
}
Observable<GroupedObservable<Integer, String>> grouped = Observable.create(groupBy(source, length));

values.add(value);
final AtomicInteger groupCounter = new AtomicInteger();
final AtomicInteger eventCounter = new AtomicInteger();
final AtomicReference<Exception> error = new AtomicReference<Exception>();

grouped.mapMany(new Func1<GroupedObservable<Integer, String>, Observable<String>>() {

@Override
public Observable<String> call(final GroupedObservable<Integer, String> o) {
groupCounter.incrementAndGet();
return o.map(new Func1<String, String>() {

@Override
public String call(String v) {
return "Event => key: " + o.getKey() + " value: " + v;
}
});
}
}).subscribe(new Observer<String>() {

}
@Override
public void onCompleted() {

}

@Override
public void onError(Exception e) {
e.printStackTrace();
error.set(e);
}

@Override
public void onNext(String v) {
eventCounter.incrementAndGet();
System.out.println(v);

}
});

assertEquals(3, groupCounter.get());
assertEquals(6, eventCounter.get());
assertNotNull(error.get());
}

private static <K, V> Map<K, Collection<V>> toMap(Observable<GroupedObservable<K, V>> observable) {

final ConcurrentHashMap<K, Collection<V>> result = new ConcurrentHashMap<K, Collection<V>>();

observable.forEach(new Action1<GroupedObservable<K, V>>() {

@Override
public void call(final GroupedObservable<K, V> o) {
result.put(o.getKey(), new ConcurrentLinkedQueue<V>());
o.subscribe(new Action1<V>() {

@Override
public void call(V v) {
result.get(o.getKey()).add(v);
}

});
}
});

return result;
}

/**
* Assert that only a single subscription to a stream occurs and that all events are received.
*
* @throws Exception
*/
@Test
public void testGroupedEventStream() throws Exception {

final AtomicInteger eventCounter = new AtomicInteger();
final AtomicInteger subscribeCounter = new AtomicInteger();
final AtomicInteger groupCounter = new AtomicInteger();
final CountDownLatch latch = new CountDownLatch(1);
final int count = 100;
final int groupCount = 2;

Observable<Event> es = Observable.create(new Func1<Observer<Event>, Subscription>() {

@Override
public Subscription call(final Observer<Event> observer) {
System.out.println("*** Subscribing to EventStream ***");
subscribeCounter.incrementAndGet();
new Thread(new Runnable() {

@Override
public void run() {
for (int i = 0; i < count; i++) {
Event e = new Event();
e.source = i % groupCount;
e.message = "Event-" + i;
observer.onNext(e);
}
observer.onCompleted();
}

}).start();
return Subscriptions.empty();
}

});

es.groupBy(new Func1<Event, Integer>() {

@Override
public Integer call(Event e) {
return e.source;
}
}).mapMany(new Func1<GroupedObservable<Integer, Event>, Observable<String>>() {

@Override
public Observable<String> call(GroupedObservable<Integer, Event> eventGroupedObservable) {
System.out.println("GroupedObservable Key: " + eventGroupedObservable.getKey());
groupCounter.incrementAndGet();

return eventGroupedObservable.map(new Func1<Event, String>() {

@Override
public String call(Event event) {
return "Source: " + event.source + " Message: " + event.message;
}
});

};
}).subscribe(new Observer<String>() {

@Override
public void onCompleted() {
latch.countDown();
}

@Override
public void onError(Exception e) {
e.printStackTrace();
latch.countDown();
}

@Override
public void onNext(String outputMessage) {
System.out.println(outputMessage);
eventCounter.incrementAndGet();
}
});

latch.await(5000, TimeUnit.MILLISECONDS);
assertEquals(1, subscribeCounter.get());
assertEquals(groupCount, groupCounter.get());
assertEquals(count, eventCounter.get());

}

/*
* We will only take 1 group with 20 events from it and then unsubscribe.
*/
@Test
public void testUnsubscribe() throws InterruptedException {

final AtomicInteger eventCounter = new AtomicInteger();
final AtomicInteger subscribeCounter = new AtomicInteger();
final AtomicInteger groupCounter = new AtomicInteger();
final AtomicInteger sentEventCounter = new AtomicInteger();
final CountDownLatch latch = new CountDownLatch(1);
final int count = 100;
final int groupCount = 2;

Observable<Event> es = Observable.create(new Func1<Observer<Event>, Subscription>() {

@Override
public Subscription call(final Observer<Event> observer) {
final BooleanSubscription s = new BooleanSubscription();
System.out.println("*** Subscribing to EventStream ***");
subscribeCounter.incrementAndGet();
new Thread(new Runnable() {

@Override
public void run() {
for (int i = 0; i < count; i++) {
if (s.isUnsubscribed()) {
break;
}
Event e = new Event();
e.source = i % groupCount;
e.message = "Event-" + i;
observer.onNext(e);
sentEventCounter.incrementAndGet();
}
observer.onCompleted();
}

}).start();
return s;
}

});

es.groupBy(new Func1<Event, Integer>() {

@Override
public Integer call(Event e) {
return e.source;
}
})
.take(1) // we want only the first group
.mapMany(new Func1<GroupedObservable<Integer, Event>, Observable<String>>() {

@Override
public Observable<String> call(GroupedObservable<Integer, Event> eventGroupedObservable) {
System.out.println("GroupedObservable Key: " + eventGroupedObservable.getKey());
groupCounter.incrementAndGet();

return eventGroupedObservable
.take(20) // limit to only 20 events on this group
.map(new Func1<Event, String>() {

@Override
public String call(Event event) {
return "Source: " + event.source + " Message: " + event.message;
}
});

};
}).subscribe(new Observer<String>() {

@Override
public void onCompleted() {
latch.countDown();
}

@Override
public void onError(Exception e) {
e.printStackTrace();
latch.countDown();
}

@Override
public void onNext(String outputMessage) {
System.out.println(outputMessage);
eventCounter.incrementAndGet();
}
});

latch.await(5000, TimeUnit.MILLISECONDS);
assertEquals(1, subscribeCounter.get());
assertEquals(1, groupCounter.get());
assertEquals(20, eventCounter.get());
// sentEvents will go until 'eventCounter' hits 20 and then unsubscribes
// which means it will also send (but ignore) the 19 events for the other group
// It will not however send all 100 events.
assertEquals(39, sentEventCounter.get());

}

private static class Event {
int source;
String message;

@Override
public String toString() {
return "Event => source: " + source + " message: " + message;
}
}

}

}