From 1ff313fd3964355d49d53ffa403be7ad6a41a603 Mon Sep 17 00:00:00 2001 From: akarnokd Date: Mon, 26 Jan 2015 13:26:50 +0100 Subject: [PATCH 1/2] Merge with max concurrency now supports backpressure. --- .../operators/OperatorMergeMaxConcurrent.java | 288 +++++++++++++++--- .../OperatorMergeMaxConcurrentTest.java | 157 +++++++++- 2 files changed, 393 insertions(+), 52 deletions(-) diff --git a/src/main/java/rx/internal/operators/OperatorMergeMaxConcurrent.java b/src/main/java/rx/internal/operators/OperatorMergeMaxConcurrent.java index d75425bb6d..9f28f3199e 100644 --- a/src/main/java/rx/internal/operators/OperatorMergeMaxConcurrent.java +++ b/src/main/java/rx/internal/operators/OperatorMergeMaxConcurrent.java @@ -15,12 +15,14 @@ */ package rx.internal.operators; -import java.util.LinkedList; -import java.util.Queue; -import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; -import rx.Observable; +import java.util.*; +import java.util.concurrent.atomic.*; + +import rx.*; import rx.Observable.Operator; -import rx.Subscriber; +import rx.Observable; +import rx.exceptions.MissingBackpressureException; +import rx.internal.util.RxRingBuffer; import rx.observers.SerializedSubscriber; import rx.subscriptions.CompositeSubscription; @@ -47,9 +49,24 @@ public Subscriber> call(Subscriber ch final CompositeSubscription csub = new CompositeSubscription(); child.add(csub); - return new SourceSubscriber(maxConcurrency, s, csub); + SourceSubscriber ssub = new SourceSubscriber(maxConcurrency, s, csub); + child.setProducer(new MergeMaxConcurrentProducer(ssub)); + + return ssub; + } + /** Routes the requests from downstream to the sourcesubscriber. */ + static final class MergeMaxConcurrentProducer implements Producer { + final SourceSubscriber ssub; + public MergeMaxConcurrentProducer(SourceSubscriber ssub) { + this.ssub = ssub; + } + @Override + public void request(long n) { + ssub.downstreamRequest(n); + } } static final class SourceSubscriber extends Subscriber> { + final NotificationLite nl = NotificationLite.instance(); final int maxConcurrency; final Subscriber s; final CompositeSubscription csub; @@ -57,24 +74,50 @@ static final class SourceSubscriber extends Subscriber WIP_UPDATER + static final AtomicIntegerFieldUpdater WIP = AtomicIntegerFieldUpdater.newUpdater(SourceSubscriber.class, "wip"); + volatile int sourceIndex; + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater SOURCE_INDEX + = AtomicIntegerFieldUpdater.newUpdater(SourceSubscriber.class, "sourceIndex"); /** Guarded by guard. */ int active; /** Guarded by guard. */ final Queue> queue; + /** Indicates the emitting phase. Guarded by this. */ + boolean emitting; + /** Counts the missed emitting calls. Guarded by this. */ + int missedEmitting; + /** The last buffer index in the round-robin drain scheme. Accessed while emitting == true. */ + int lastIndex; + + /** Guarded by itself. */ + final List subscribers; + + volatile long requested; + @SuppressWarnings("rawtypes") + static final AtomicLongFieldUpdater REQUESTED + = AtomicLongFieldUpdater.newUpdater(SourceSubscriber.class, "requested"); + + public SourceSubscriber(int maxConcurrency, Subscriber s, CompositeSubscription csub) { super(s); this.maxConcurrency = maxConcurrency; this.s = s; this.csub = csub; this.guard = new Object(); - this.queue = new LinkedList>(); + this.queue = new ArrayDeque>(maxConcurrency); + this.subscribers = Collections.synchronizedList(new ArrayList()); this.wip = 1; } + @Override + public void onStart() { + request(maxConcurrency); + } + @Override public void onNext(Observable t) { synchronized (guard) { @@ -94,50 +137,213 @@ void subscribeNext() { queue.poll(); } - Subscriber itemSub = new Subscriber() { - boolean once = true; - @Override - public void onNext(T t) { - s.onNext(t); - } - - @Override - public void onError(Throwable e) { - SourceSubscriber.this.onError(e); - } - - @Override - public void onCompleted() { - if (once) { - once = false; - synchronized (guard) { - active--; - } - csub.remove(this); - - subscribeNext(); - - SourceSubscriber.this.onCompleted(); - } - } - - }; + MergeItemSubscriber itemSub = new MergeItemSubscriber(SOURCE_INDEX.getAndIncrement(this)); + subscribers.add(itemSub); + csub.add(itemSub); - WIP_UPDATER.incrementAndGet(this); + + WIP.incrementAndGet(this); t.unsafeSubscribe(itemSub); + + request(1); } @Override public void onError(Throwable e) { - s.onError(e); - unsubscribe(); + Object[] active; + synchronized (subscribers) { + active = subscribers.toArray(); + subscribers.clear(); + } + + try { + s.onError(e); + + unsubscribe(); + } finally { + for (Object o : active) { + @SuppressWarnings("unchecked") + MergeItemSubscriber a = (MergeItemSubscriber)o; + a.release(); + } + } + } @Override public void onCompleted() { - if (WIP_UPDATER.decrementAndGet(this) == 0) { - s.onCompleted(); + WIP.decrementAndGet(this); + drain(); + } + + protected void downstreamRequest(long n) { + for (;;) { + long r = requested; + long u; + if (r != Long.MAX_VALUE && n == Long.MAX_VALUE) { + u = Long.MAX_VALUE; + } else + if (r + n < 0) { + u = Long.MAX_VALUE; + } else { + u = r + n; + } + if (REQUESTED.compareAndSet(this, r, u)) { + break; + } + } + drain(); + } + + protected void drain() { + synchronized (this) { + if (emitting) { + missedEmitting++; + return; + } + emitting = true; + missedEmitting = 0; + } + final List.MergeItemSubscriber> subs = subscribers; + final Subscriber child = s; + Object[] active = new Object[subs.size()]; + do { + long r; + + outer: + while ((r = requested) > 0) { + int idx = lastIndex; + synchronized (subs) { + if (subs.size() == active.length) { + active = subs.toArray(active); + } else { + active = subs.toArray(); + } + } + + int resumeIndex = 0; + int j = 0; + for (Object o : active) { + @SuppressWarnings("unchecked") + MergeItemSubscriber e = (MergeItemSubscriber)o; + if (e.index == idx) { + resumeIndex = j; + break; + } + j++; + } + int sumConsumed = 0; + for (int i = 0; i < active.length; i++) { + j = (i + resumeIndex) % active.length; + + @SuppressWarnings("unchecked") + final MergeItemSubscriber e = (MergeItemSubscriber)active[j]; + final RxRingBuffer b = e.buffer; + lastIndex = e.index; + + if (!e.once && b.peek() == null) { + subs.remove(e); + + synchronized (guard) { + this.active--; + } + csub.remove(e); + + e.release(); + + subscribeNext(); + + WIP.decrementAndGet(this); + + continue outer; + } + + int consumed = 0; + Object v; + while (r > 0 && (v = b.poll()) != null) { + nl.accept(child, v); + if (child.isUnsubscribed()) { + return; + } + r--; + consumed++; + } + if (consumed > 0) { + sumConsumed += consumed; + REQUESTED.addAndGet(this, -consumed); + e.requestMore(consumed); + } + if (r == 0) { + break outer; + } + } + if (sumConsumed == 0) { + break; + } + } + + if (active.length == 0) { + if (wip == 0) { + child.onCompleted(); + return; + } + } + synchronized (this) { + if (missedEmitting == 0) { + emitting = false; + break; + } + missedEmitting = 0; + } + } while (true); + } + final class MergeItemSubscriber extends Subscriber { + volatile boolean once = true; + final int index; + final RxRingBuffer buffer; + + public MergeItemSubscriber(int index) { + buffer = RxRingBuffer.getSpmcInstance(); + this.index = index; + } + + @Override + public void onStart() { + request(RxRingBuffer.SIZE); + } + + @Override + public void onNext(T t) { + try { + buffer.onNext(t); + } catch (MissingBackpressureException ex) { + onError(ex); + return; + } + + drain(); + } + + @Override + public void onError(Throwable e) { + SourceSubscriber.this.onError(e); + } + + @Override + public void onCompleted() { + if (once) { + once = false; + drain(); + } + } + /** Request more from upstream. */ + void requestMore(long n) { + request(n); + } + void release() { + // NO-OP for now + buffer.release(); } } } diff --git a/src/test/java/rx/internal/operators/OperatorMergeMaxConcurrentTest.java b/src/test/java/rx/internal/operators/OperatorMergeMaxConcurrentTest.java index be9a04cb7a..238526a3c3 100644 --- a/src/test/java/rx/internal/operators/OperatorMergeMaxConcurrentTest.java +++ b/src/test/java/rx/internal/operators/OperatorMergeMaxConcurrentTest.java @@ -15,23 +15,19 @@ */ package rx.internal.operators; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; +import static org.junit.Assert.*; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Iterator; -import java.util.List; +import java.util.*; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicInteger; -import org.junit.Before; -import org.junit.Test; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; +import org.junit.*; +import org.mockito.*; +import rx.*; import rx.Observable; import rx.Observer; -import rx.Subscriber; +import rx.observers.TestSubscriber; import rx.schedulers.Schedulers; public class OperatorMergeMaxConcurrentTest { @@ -157,4 +153,143 @@ public void testMergeALotOfSourcesOneByOneSynchronouslyTakeHalf() { } assertEquals(j, n / 2); } + + @Test + public void testSimple() { + for (int i = 1; i < 100; i++) { + TestSubscriber ts = new TestSubscriber(); + List> sourceList = new ArrayList>(i); + List result = new ArrayList(i); + for (int j = 1; j <= i; j++) { + sourceList.add(Observable.just(j)); + result.add(j); + } + + Observable.merge(sourceList, i).subscribe(ts); + + ts.assertNoErrors(); + ts.assertTerminalEvent(); + ts.assertReceivedOnNext(result); + } + } + @Test + public void testSimpleOneLess() { + for (int i = 2; i < 100; i++) { + TestSubscriber ts = new TestSubscriber(); + List> sourceList = new ArrayList>(i); + List result = new ArrayList(i); + for (int j = 1; j <= i; j++) { + sourceList.add(Observable.just(j)); + result.add(j); + } + + Observable.merge(sourceList, i - 1).subscribe(ts); + + ts.assertNoErrors(); + ts.assertTerminalEvent(); + ts.assertReceivedOnNext(result); + } + } + @Test(timeout = 10000) + public void testSympleAsyncLoop() { + for (int i = 0; i < 200; i++) { + testSimpleAsync(); + } + } + @Test(timeout = 10000) + public void testSimpleAsync() { + for (int i = 1; i < 100; i++) { + TestSubscriber ts = new TestSubscriber(); + List> sourceList = new ArrayList>(i); + Set expected = new HashSet(i); + for (int j = 1; j <= i; j++) { + sourceList.add(Observable.just(j).subscribeOn(Schedulers.io())); + expected.add(j); + } + + Observable.merge(sourceList, i).subscribe(ts); + + ts.awaitTerminalEvent(); + ts.assertNoErrors(); + Set actual = new HashSet(ts.getOnNextEvents()); + + assertEquals(expected, actual); + } + } + @Test(timeout = 10000) + public void testSimpleOneLessAsyncLoop() { + for (int i = 0; i < 200; i++) { + testSimpleOneLessAsync(); + } + } + @Test(timeout = 10000) + public void testSimpleOneLessAsync() { + for (int i = 2; i < 100; i++) { + TestSubscriber ts = new TestSubscriber(); + List> sourceList = new ArrayList>(i); + Set expected = new HashSet(i); + for (int j = 1; j <= i; j++) { + sourceList.add(Observable.just(j).subscribeOn(Schedulers.io())); + expected.add(j); + } + + Observable.merge(sourceList, i - 1).subscribe(ts); + + ts.awaitTerminalEvent(); + ts.assertNoErrors(); + Set actual = new HashSet(ts.getOnNextEvents()); + + assertEquals(expected, actual); + } + } + @Test(timeout = 5000) + public void testBackpressureHonored() throws Exception { + List> sourceList = new ArrayList>(3); + + sourceList.add(Observable.range(0, 100000).subscribeOn(Schedulers.io())); + sourceList.add(Observable.range(0, 100000).subscribeOn(Schedulers.io())); + sourceList.add(Observable.range(0, 100000).subscribeOn(Schedulers.io())); + + final CountDownLatch cdl = new CountDownLatch(5); + + TestSubscriber ts = new TestSubscriber() { + @Override + public void onStart() { + request(0); + } + @Override + public void onNext(Integer t) { + super.onNext(t); + cdl.countDown(); + } + }; + + Observable.merge(sourceList, 2).subscribe(ts); + + ts.requestMore(5); + + cdl.await(); + + ts.assertNoErrors(); + assertEquals(5, ts.getOnNextEvents().size()); + assertEquals(0, ts.getOnCompletedEvents().size()); + + ts.unsubscribe(); + } + @Test(timeout = 5000) + public void testTake() throws Exception { + List> sourceList = new ArrayList>(3); + + sourceList.add(Observable.range(0, 100000).subscribeOn(Schedulers.io())); + sourceList.add(Observable.range(0, 100000).subscribeOn(Schedulers.io())); + sourceList.add(Observable.range(0, 100000).subscribeOn(Schedulers.io())); + + TestSubscriber ts = new TestSubscriber(); + + Observable.merge(sourceList, 2).take(5).subscribe(ts); + + ts.awaitTerminalEvent(); + ts.assertNoErrors(); + assertEquals(5, ts.getOnNextEvents().size()); + } } From a868569c44fd8f44b0be0826f0387d788714c636 Mon Sep 17 00:00:00 2001 From: akarnokd Date: Mon, 26 Jan 2015 13:41:32 +0100 Subject: [PATCH 2/2] Less concurrent threads and more in-line timeout detection. --- .../operators/OperatorMergeMaxConcurrentTest.java | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/test/java/rx/internal/operators/OperatorMergeMaxConcurrentTest.java b/src/test/java/rx/internal/operators/OperatorMergeMaxConcurrentTest.java index 238526a3c3..52c7ee21f2 100644 --- a/src/test/java/rx/internal/operators/OperatorMergeMaxConcurrentTest.java +++ b/src/test/java/rx/internal/operators/OperatorMergeMaxConcurrentTest.java @@ -18,7 +18,7 @@ import static org.junit.Assert.*; import java.util.*; -import java.util.concurrent.CountDownLatch; +import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicInteger; import org.junit.*; @@ -191,14 +191,14 @@ public void testSimpleOneLess() { } } @Test(timeout = 10000) - public void testSympleAsyncLoop() { + public void testSimpleAsyncLoop() { for (int i = 0; i < 200; i++) { testSimpleAsync(); } } @Test(timeout = 10000) public void testSimpleAsync() { - for (int i = 1; i < 100; i++) { + for (int i = 1; i < 50; i++) { TestSubscriber ts = new TestSubscriber(); List> sourceList = new ArrayList>(i); Set expected = new HashSet(i); @@ -209,7 +209,7 @@ public void testSimpleAsync() { Observable.merge(sourceList, i).subscribe(ts); - ts.awaitTerminalEvent(); + ts.awaitTerminalEvent(1, TimeUnit.SECONDS); ts.assertNoErrors(); Set actual = new HashSet(ts.getOnNextEvents()); @@ -224,7 +224,7 @@ public void testSimpleOneLessAsyncLoop() { } @Test(timeout = 10000) public void testSimpleOneLessAsync() { - for (int i = 2; i < 100; i++) { + for (int i = 2; i < 50; i++) { TestSubscriber ts = new TestSubscriber(); List> sourceList = new ArrayList>(i); Set expected = new HashSet(i); @@ -235,7 +235,7 @@ public void testSimpleOneLessAsync() { Observable.merge(sourceList, i - 1).subscribe(ts); - ts.awaitTerminalEvent(); + ts.awaitTerminalEvent(1, TimeUnit.SECONDS); ts.assertNoErrors(); Set actual = new HashSet(ts.getOnNextEvents());