diff --git a/src/main/java/io/reactivex/Observable.java b/src/main/java/io/reactivex/Observable.java index 13bb33ee45..7a33bdbe36 100644 --- a/src/main/java/io/reactivex/Observable.java +++ b/src/main/java/io/reactivex/Observable.java @@ -238,9 +238,7 @@ public final Observable flatMap(Function 0 required but it was " + maxConcurrency); } - if (bufferSize <= 0) { - throw new IllegalArgumentException("bufferSize > 0 required but it was " + bufferSize); - } + validateBufferSize(bufferSize); if (onSubscribe instanceof PublisherScalarSource) { PublisherScalarSource scalar = (PublisherScalarSource) onSubscribe; return create(scalar.flatMap(mapper)); @@ -443,9 +441,7 @@ public final Observable observeOn(Scheduler scheduler, boolean delayError) { public final Observable observeOn(Scheduler scheduler, boolean delayError, int bufferSize) { Objects.requireNonNull(scheduler); - if (bufferSize <= 0) { - throw new IllegalArgumentException("bufferSize > 0 required but it was " + bufferSize); - } + validateBufferSize(bufferSize); return lift(new OperatorObserveOn<>(scheduler, delayError, bufferSize)); } @@ -476,9 +472,7 @@ public final ConnectableObservable publish() { } public final ConnectableObservable publish(int bufferSize) { - if (bufferSize <= 0) { - throw new IllegalArgumentException("bufferSize > 0 required but it was " + bufferSize); - } + validateBufferSize(bufferSize); return OperatorPublish.create(this, bufferSize); } @@ -487,9 +481,7 @@ public final Observable publish(Function, ? extends } public final Observable publish(Function, ? extends Observable> selector, int bufferSize) { - if (bufferSize <= 0) { - throw new IllegalArgumentException("bufferSize > 0 required but it was " + bufferSize); - } + validateBufferSize(bufferSize); Objects.requireNonNull(selector); return OperatorPublish.create(this, selector, bufferSize); } @@ -940,15 +932,14 @@ public final Observable> groupBy(Function 0 required but it was " + bufferSize); - } + validateBufferSize(bufferSize); return lift(new OperatorGroupBy<>(keySelector, valueSelector, bufferSize, delayError)); } @SuppressWarnings("unchecked") private static Function toFunction(BiFunction biFunction) { + Objects.requireNonNull(biFunction); return a -> { if (a.length != 2) { throw new IllegalArgumentException("Array of size 2 expected but got " + a.length); @@ -1036,9 +1027,7 @@ public static Observable zipArray(Function 0 required but it was " + bufferSize); - } + validateBufferSize(bufferSize); return create(new PublisherZip<>(sources, null, zipper, bufferSize, delayError)); } @@ -1047,9 +1036,7 @@ public static Observable zipIterable(Function> sources) { Objects.requireNonNull(zipper); Objects.requireNonNull(sources); - if (bufferSize <= 0) { - throw new IllegalArgumentException("bufferSize > 0 required but it was " + bufferSize); - } + validateBufferSize(bufferSize); return create(new PublisherZip<>(null, sources, zipper, bufferSize, delayError)); } @@ -1462,10 +1449,10 @@ public final Observable skipLast(long time, TimeUnit unit, Scheduler schedule public final Observable skipLast(long time, TimeUnit unit, Scheduler scheduler, boolean delayError, int bufferSize) { Objects.requireNonNull(unit); Objects.requireNonNull(scheduler); - if (bufferSize <= 0) { - throw new IllegalArgumentException("bufferSize > 0 required but it was " + bufferSize); - } - return lift(new OperatorSkipLastTimed<>(time, unit, scheduler, bufferSize, delayError)); + validateBufferSize(bufferSize); + // the internal buffer holds pairs of (timestamp, value) so double the default buffer size + int s = bufferSize << 1; + return lift(new OperatorSkipLastTimed<>(time, unit, scheduler, s, delayError)); } public final Observable takeLast(long time, TimeUnit unit) { @@ -1499,9 +1486,7 @@ public final Observable takeLast(long time, TimeUnit unit, Scheduler schedule public final Observable takeLast(long count, long time, TimeUnit unit, Scheduler scheduler, boolean delayError, int bufferSize) { Objects.requireNonNull(unit); Objects.requireNonNull(scheduler); - if (bufferSize <= 0) { - throw new IllegalArgumentException("bufferSize > 0 required but it was " + bufferSize); - } + validateBufferSize(bufferSize); if (count < 0) { throw new IllegalArgumentException("count >= 0 required but it was " + count); } @@ -1550,9 +1535,7 @@ public final Observable switchMap(Function Observable switchMap(Function> mapper, int bufferSize) { Objects.requireNonNull(mapper); - if (bufferSize <= 0) { - throw new IllegalArgumentException("bufferSize > 0 required but it was " + bufferSize); - } + validateBufferSize(bufferSize); return lift(new OperatorSwitchMap<>(mapper, bufferSize)); } @@ -1653,9 +1636,121 @@ public static Observable sequenceEqual(Publisher p1, P Objects.requireNonNull(p1); Objects.requireNonNull(p2); Objects.requireNonNull(isEqual); + validateBufferSize(bufferSize); + return create(new PublisherSequenceEqual<>(p1, p2, isEqual, bufferSize)); + } + + public static Observable combineLatest(Publisher[] sources, Function combiner) { + return combineLatest(sources, combiner, false, bufferSize()); + } + + public static Observable combineLatest(Publisher[] sources, Function combiner, boolean delayError) { + return combineLatest(sources, combiner, delayError, bufferSize()); + } + + @SafeVarargs + public static Observable combineLatest(Function combiner, boolean delayError, int bufferSize, Publisher... sources) { + return combineLatest(sources, combiner, delayError, bufferSize); + } + + public static Observable combineLatest(Publisher[] sources, Function combiner, boolean delayError, int bufferSize) { + validateBufferSize(bufferSize); + Objects.requireNonNull(combiner); + if (sources.length == 0) { + return empty(); + } + // the queue holds a pair of values so we need to double the capacity + int s = bufferSize << 1; + return create(new PublisherCombineLatest<>(sources, null, combiner, s, delayError)); + } + + public static Observable combineLatest(Iterable> sources, Function combiner) { + return combineLatest(sources, combiner, false, bufferSize()); + } + + public static Observable combineLatest(Iterable> sources, Function combiner, boolean delayError) { + return combineLatest(sources, combiner, delayError, bufferSize()); + } + + public static Observable combineLatest(Iterable> sources, Function combiner, boolean delayError, int bufferSize) { + Objects.requireNonNull(sources); + Objects.requireNonNull(combiner); + validateBufferSize(bufferSize); + + // the queue holds a pair of values so we need to double the capacity + int s = bufferSize << 1; + return create(new PublisherCombineLatest<>(null, sources, combiner, s, delayError)); + } + + private static void validateBufferSize(int bufferSize) { if (bufferSize <= 0) { throw new IllegalArgumentException("bufferSize > 0 required but it was " + bufferSize); } - return create(new PublisherSequenceEqual<>(p1, p2, isEqual, bufferSize)); } + + public static Observable combineLatest( + Publisher p1, Publisher p2, + BiFunction combiner) { + Function f = toFunction(combiner); + return combineLatest(f, false, bufferSize(), p1, p2); + } + + public static Observable combineLatest( + Publisher p1, Publisher p2, + Publisher p3, + Function3 combiner) { + return combineLatest(combiner, false, bufferSize(), p1, p2, p3); + } + + public static Observable combineLatest( + Publisher p1, Publisher p2, + Publisher p3, Publisher p4, + Function4 combiner) { + return combineLatest(combiner, false, bufferSize(), p1, p2, p3, p4); + } + + public static Observable combineLatest( + Publisher p1, Publisher p2, + Publisher p3, Publisher p4, + Publisher p5, + Function5 combiner) { + return combineLatest(combiner, false, bufferSize(), p1, p2, p3, p4, p5); + } + + public static Observable combineLatest( + Publisher p1, Publisher p2, + Publisher p3, Publisher p4, + Publisher p5, Publisher p6, + Function6 combiner) { + return combineLatest(combiner, false, bufferSize(), p1, p2, p3, p4, p5, p6); + } + + public static Observable combineLatest( + Publisher p1, Publisher p2, + Publisher p3, Publisher p4, + Publisher p5, Publisher p6, + Publisher p7, + Function7 combiner) { + return combineLatest(combiner, false, bufferSize(), p1, p2, p3, p4, p5, p6, p7); + } + + public static Observable combineLatest( + Publisher p1, Publisher p2, + Publisher p3, Publisher p4, + Publisher p5, Publisher p6, + Publisher p7, Publisher p8, + Function8 combiner) { + return combineLatest(combiner, false, bufferSize(), p1, p2, p3, p4, p5, p6, p7, p8); + } + + public static Observable combineLatest( + Publisher p1, Publisher p2, + Publisher p3, Publisher p4, + Publisher p5, Publisher p6, + Publisher p7, Publisher p8, + Publisher p9, + Function9 combiner) { + return combineLatest(combiner, false, bufferSize(), p1, p2, p3, p4, p5, p6, p7, p8, p9); + } + } diff --git a/src/main/java/io/reactivex/internal/operators/PublisherCombineLatest.java b/src/main/java/io/reactivex/internal/operators/PublisherCombineLatest.java new file mode 100644 index 0000000000..781fa19709 --- /dev/null +++ b/src/main/java/io/reactivex/internal/operators/PublisherCombineLatest.java @@ -0,0 +1,406 @@ +/** + * Copyright 2015 Netflix, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in + * compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is + * distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See + * the License for the specific language governing permissions and limitations under the License. + */ + +package io.reactivex.internal.operators; + +import java.util.Queue; +import java.util.concurrent.atomic.*; +import java.util.function.Function; + +import org.reactivestreams.*; + +import io.reactivex.internal.queue.SpscLinkedArrayQueue; +import io.reactivex.internal.subscriptions.*; +import io.reactivex.internal.util.BackpressureHelper; +import io.reactivex.plugins.RxJavaPlugins; + +public final class PublisherCombineLatest implements Publisher { + final Publisher[] sources; + final Iterable> sourcesIterable; + final Function combiner; + final int bufferSize; + final boolean delayError; + + public PublisherCombineLatest(Publisher[] sources, + Iterable> sourcesIterable, + Function combiner, int bufferSize, + boolean delayError) { + this.sources = sources; + this.sourcesIterable = sourcesIterable; + this.combiner = combiner; + this.bufferSize = bufferSize; + this.delayError = delayError; + } + + + @Override + @SuppressWarnings("unchecked") + public void subscribe(Subscriber s) { + Publisher[] sources = this.sources; + int count = 0; + if (sources == null) { + sources = new Publisher[8]; + for (Publisher p : sourcesIterable) { + if (count == sources.length) { + Publisher[] b = new Publisher[count + count >> 2]; + System.arraycopy(sources, 0, b, 0, count); + sources = b; + } + sources[count++] = p; + } + } else { + count = sources.length; + } + + if (count == 0) { + s.onSubscribe(EmptySubscription.INSTANCE); + s.onComplete(); + return; + } + + LatestCoordinator lc = new LatestCoordinator<>(s, combiner, count, bufferSize, delayError); + lc.subscribe(sources); + } + + static final class LatestCoordinator extends AtomicInteger implements Subscription { + /** */ + private static final long serialVersionUID = 8567835998786448817L; + final Subscriber actual; + final Function combiner; + final int count; + final CombinerSubscriber[] subscribers; + final int bufferSize; + final Object[] latest; + final SpscLinkedArrayQueue queue; + final boolean delayError; + + volatile boolean cancelled; + + volatile boolean done; + + volatile long requested; + @SuppressWarnings("rawtypes") + static final AtomicLongFieldUpdater REQUESTED = + AtomicLongFieldUpdater.newUpdater(LatestCoordinator.class, "requested"); + + volatile Throwable error; + @SuppressWarnings("rawtypes") + static final AtomicReferenceFieldUpdater ERROR = + AtomicReferenceFieldUpdater.newUpdater(LatestCoordinator.class, Throwable.class, "error"); + + + int active; + int complete; + + @SuppressWarnings("unchecked") + public LatestCoordinator(Subscriber actual, + Function combiner, + int count, int bufferSize, boolean delayError) { + this.actual = actual; + this.combiner = combiner; + this.count = count; + this.bufferSize = bufferSize; + this.delayError = delayError; + this.latest = new Object[count]; + this.subscribers = new CombinerSubscriber[count]; + this.queue = new SpscLinkedArrayQueue<>(bufferSize); + } + + public void subscribe(Publisher[] sources) { + Subscriber[] as = subscribers; + int len = as.length; + for (int i = 0; i < len; i++) { + as[i] = new CombinerSubscriber<>(this, i); + } + lazySet(0); // release array contents + actual.onSubscribe(this); + for (int i = 0; i < len; i++) { + if (cancelled) { + return; + } + sources[i].subscribe(as[i]); + } + } + + @Override + public void request(long n) { + if (SubscriptionHelper.validateRequest(n)) { + return; + } + BackpressureHelper.add(REQUESTED, this, n); + drain(); + } + + @Override + public void cancel() { + if (!cancelled) { + cancelled = true; + + if (getAndIncrement() == 0) { + cancel(queue); + } + } + } + + void cancel(Queue q) { + q.clear(); + for (CombinerSubscriber s : subscribers) { + s.cancel(); + } + } + + void combine(T value, int index) { + + CombinerSubscriber cs = subscribers[index]; + + int a; + int c; + int len; + boolean empty; + boolean f; + synchronized (this) { + len = latest.length; + Object o = latest[index]; + a = active; + if (o == null) { + active = ++a; + } + c = complete; + if (value == null) { + complete = ++c; + } else { + latest[index] = value; + } + f = a == len; + empty = c == len; + if (!empty) { + queue.offer(cs, latest.clone()); + } else { + done = true; + } + } + if (!f && value != null) { + cs.request(1); + return; + } + drain(); + } + + void drain() { + if (getAndIncrement() != 0) { + return; + } + + final Queue q = queue; + final Subscriber a = actual; + final boolean delayError = this.delayError; + + int missed = 1; + for (;;) { + + if (checkTerminated(done, q.isEmpty(), a, q, delayError)) { + return; + } + + long r = requested; + boolean unbounded = r == Long.MAX_VALUE; + long e = 0L; + + while (r != 0L) { + + boolean d = done; + @SuppressWarnings("unchecked") + CombinerSubscriber cs = (CombinerSubscriber)q.peek(); + boolean empty = cs == null; + + if (checkTerminated(d, empty, a, q, delayError)) { + return; + } + + if (empty) { + break; + } + + q.poll(); + Object[] array = (Object[])q.poll(); + + if (array == null) { + cancelled = true; + cancel(q); + a.onError(new IllegalStateException("Broken queue?! Sender received but not the array.")); + return; + } + + R v; + try { + v = combiner.apply(array); + } catch (Throwable ex) { + cancelled = true; + cancel(q); + a.onError(ex); + return; + } + + a.onNext(v); + + cs.request(1); + + r--; + e--; + } + + if (e != 0) { + if (!unbounded) { + REQUESTED.addAndGet(this, e); + } + } + + missed = addAndGet(-missed); + if (missed == 0) { + break; + } + } + } + + + boolean checkTerminated(boolean d, boolean empty, Subscriber a, Queue q, boolean delayError) { + if (cancelled) { + cancel(q); + return true; + } + if (d) { + if (delayError) { + if (empty) { + Throwable e = error; + if (e != null) { + a.onError(e); + } else { + a.onComplete(); + } + return true; + } + } else { + Throwable e = error; + if (e != null) { + cancel(q); + a.onError(e); + return true; + } else + if (empty) { + a.onComplete(); + return true; + } + } + } + return false; + } + + void onError(Throwable e) { + for (;;) { + Throwable curr = error; + if (curr != null) { + e.addSuppressed(curr); + } + Throwable next = e; + if (ERROR.compareAndSet(this, curr, next)) { + return; + } + } + } + } + + static final class CombinerSubscriber implements Subscriber, Subscription { + final LatestCoordinator parent; + final int index; + + boolean done; + + volatile Subscription s; + @SuppressWarnings("rawtypes") + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater(CombinerSubscriber.class, Subscription.class, "s"); + + static final Subscription CANCELLED = new Subscription() { + @Override + public void request(long n) { + + } + + @Override + public void cancel() { + + } + }; + + public CombinerSubscriber(LatestCoordinator parent, int index) { + this.parent = parent; + this.index = index; + } + + @Override + public void onSubscribe(Subscription s) { + if (!S.compareAndSet(this, null, s)) { + s.cancel(); + if (s != CANCELLED) { + SubscriptionHelper.reportSubscriptionSet(); + } + return; + } + s.request(parent.bufferSize); + } + + @Override + public void onNext(T t) { + if (done) { + return; + } + parent.combine(t, index); + } + + @Override + public void onError(Throwable t) { + if (done) { + RxJavaPlugins.onError(t); + return; + } + parent.onError(t); + done = true; + parent.combine(null, index); + } + + @Override + public void onComplete() { + if (done) { + return; + } + done = true; + parent.combine(null, index); + } + + @Override + public void request(long n) { + s.request(n); + } + + @Override + public void cancel() { + Subscription a = s; + if (a != CANCELLED) { + a = S.getAndSet(this, CANCELLED); + if (a != CANCELLED && a != null) { + a.cancel(); + } + } + } + } +}