Skip to content

Commit

Permalink
Add awaitSubscribed() to test async sources (#1661)
Browse files Browse the repository at this point in the history
Motivation:
For some tests it is useful to be able to wait for the source being
subscribed. The `subscribe()` may offloaded using the 
`subscribeOn()` operator and the thread which initiated the 
`subscribe()` regains control before the subscribe completes on
the offloaded thread. `awaitSubscribed()` allows waiting until the
subscribe is asynchronously completed.
Modifications:
Add `awaitSubscribed()` method to `TestCompletable`, `TestSingle`, and
`TestPublisher`. Shared code is moved to new `AwaitUtils` class
Result:
Additional testing capability.
  • Loading branch information
bondolo authored and idelpivnitskiy committed Feb 9, 2022
1 parent 580a5a2 commit 917791d
Show file tree
Hide file tree
Showing 6 changed files with 153 additions and 85 deletions.
1 change: 1 addition & 0 deletions servicetalk-concurrent-api/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ dependencies {

testFixturesImplementation testFixtures(project(":servicetalk-concurrent-internal"))
testFixturesImplementation project(":servicetalk-utils-internal")
testFixturesImplementation project(":servicetalk-concurrent-test-internal")
testFixturesImplementation "com.google.code.findbugs:jsr305:$jsr305Version"
testFixturesImplementation "org.junit.jupiter:junit-jupiter-api"
testFixturesImplementation "junit:junit:$junitVersion"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@

import io.servicetalk.concurrent.Cancellable;
import io.servicetalk.concurrent.CompletableSource;
import io.servicetalk.concurrent.test.internal.AwaitUtils;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
import java.util.function.Function;
Expand Down Expand Up @@ -53,6 +55,7 @@ public final class TestCompletable extends Completable implements CompletableSou
private final Function<Subscriber, Subscriber> subscriberFunction;
private final List<Throwable> exceptions = new CopyOnWriteArrayList<>();
private volatile Subscriber subscriber = new WaitingSubscriber();
private final CountDownLatch subscriberLatch = new CountDownLatch(1);

/**
* Create a {@code TestCompletable} with the defaults. See <b>Defaults</b> section of class javadoc.
Expand All @@ -74,6 +77,14 @@ public boolean isSubscribed() {
return !(subscriber instanceof WaitingSubscriber);
}

/**
* Awaits until this {@link TestCompletable} is subscribed, even if interrupted. If interrupted the
* {@link Thread#isInterrupted()} will be set upon return.
*/
public void awaitSubscribed() {
AwaitUtils.awaitUninterruptibly(subscriberLatch);
}

@Override
protected void handleSubscribe(final Subscriber subscriber) {
try {
Expand All @@ -85,6 +96,7 @@ protected void handleSubscribe(final Subscriber subscriber) {
final WaitingSubscriber waiter = (WaitingSubscriber) currSubscriber;
waiter.realSubscriber(newSubscriber);
}
subscriberLatch.countDown();
break;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@
package io.servicetalk.concurrent.api;

import io.servicetalk.concurrent.PublisherSource;
import io.servicetalk.concurrent.test.internal.AwaitUtils;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
import java.util.function.Function;
Expand Down Expand Up @@ -54,6 +56,7 @@ public final class TestPublisher<T> extends Publisher<T> implements PublisherSou
private final Function<Subscriber<? super T>, Subscriber<? super T>> subscriberFunction;
private final List<Throwable> exceptions = new CopyOnWriteArrayList<>();
private volatile Subscriber<? super T> subscriber = new WaitingSubscriber<>();
private final CountDownLatch subscriberLatch = new CountDownLatch(1);

/**
* Create a {@code TestPublisher} with the defaults. See <b>Defaults</b> section of class javadoc.
Expand All @@ -75,6 +78,14 @@ public boolean isSubscribed() {
return !(subscriber instanceof WaitingSubscriber);
}

/**
* Awaits until this {@link TestPublisher} is subscribed, even if interrupted. If interrupted the
* {@link Thread#isInterrupted()} will be set upon return.
*/
public void awaitSubscribed() {
AwaitUtils.awaitUninterruptibly(subscriberLatch);
}

@Override
protected void handleSubscribe(final Subscriber<? super T> subscriber) {
try {
Expand All @@ -87,6 +98,7 @@ protected void handleSubscribe(final Subscriber<? super T> subscriber) {
final WaitingSubscriber<T> waiter = (WaitingSubscriber<T>) currSubscriber;
waiter.realSubscriber(newSubscriber);
}
subscriberLatch.countDown();
break;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@

import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
import java.util.function.Function;
import javax.annotation.Nullable;

import static io.servicetalk.concurrent.test.internal.AwaitUtils.awaitUninterruptibly;
import static java.util.Objects.requireNonNull;

/**
Expand All @@ -54,6 +56,7 @@ public final class TestSingle<T> extends Single<T> implements SingleSource<T> {
private final Function<Subscriber<? super T>, Subscriber<? super T>> subscriberFunction;
private final List<Throwable> exceptions = new CopyOnWriteArrayList<>();
private volatile Subscriber<? super T> subscriber = new WaitingSubscriber<>();
private final CountDownLatch subscriberLatch = new CountDownLatch(1);

/**
* Create a {@code TestSingle} with the defaults. See <b>Defaults</b> section of class javadoc.
Expand All @@ -75,6 +78,14 @@ public boolean isSubscribed() {
return !(subscriber instanceof WaitingSubscriber);
}

/**
* Awaits until this Single is subscribed, even if interrupted. If interrupted the {@link Thread#isInterrupted()}
* will be set upon return.
*/
public void awaitSubscribed() {
awaitUninterruptibly(subscriberLatch);
}

@Override
protected void handleSubscribe(final Subscriber<? super T> subscriber) {
try {
Expand All @@ -87,6 +98,7 @@ protected void handleSubscribe(final Subscriber<? super T> subscriber) {
final WaitingSubscriber<T> waiter = (WaitingSubscriber<T>) currSubscriber;
waiter.realSubscriber(newSubscriber);
}
subscriberLatch.countDown();
break;
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Copyright © 2021 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.servicetalk.concurrent.test.internal;

import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;

import static java.util.concurrent.TimeUnit.NANOSECONDS;

public final class AwaitUtils {
private AwaitUtils() {
// no instances
}

public static void awaitUninterruptibly(CountDownLatch latch) {
boolean interrupted = false;
try {
do {
try {
latch.await();
return;
} catch (InterruptedException e) {
interrupted = true;
}
} while (true);
} finally {
if (interrupted) {
Thread.currentThread().interrupt();
}
}
}

public static boolean awaitUninterruptibly(CountDownLatch latch, long timeout, TimeUnit unit) {
final long startTime = System.nanoTime();
final long timeoutNanos = NANOSECONDS.convert(timeout, unit);
long waitTime = timeoutNanos;
boolean interrupted = false;
try {
do {
try {
return latch.await(waitTime, NANOSECONDS);
} catch (InterruptedException e) {
interrupted = true;
}
waitTime = timeoutNanos - (System.nanoTime() - startTime);
if (waitTime <= 0) {
return true;
}
} while (true);
} finally {
if (interrupted) {
Thread.currentThread().interrupt();
}
}
}

public static <T> T takeUninterruptibly(BlockingQueue<T> queue) {
boolean interrupted = false;
try {
do {
try {
return queue.take();
} catch (InterruptedException e) {
interrupted = true;
}
} while (true);
} finally {
if (interrupted) {
Thread.currentThread().interrupt();
}
}
}

@Nullable
public static <T> T pollUninterruptibly(BlockingQueue<T> queue, long timeout, TimeUnit unit) {
final long startTime = System.nanoTime();
final long timeoutNanos = NANOSECONDS.convert(timeout, unit);
long waitTime = timeout;
boolean interrupted = false;
try {
do {
try {
return queue.poll(waitTime, NANOSECONDS);
} catch (InterruptedException e) {
interrupted = true;
}
waitTime = timeoutNanos - (System.nanoTime() - startTime);
if (waitTime <= 0) {
return null;
}
} while (true);
} finally {
if (interrupted) {
Thread.currentThread().interrupt();
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@

import static io.servicetalk.concurrent.internal.TerminalNotification.complete;
import static io.servicetalk.concurrent.internal.TerminalNotification.error;
import static io.servicetalk.concurrent.test.internal.AwaitUtils.awaitUninterruptibly;
import static io.servicetalk.concurrent.test.internal.AwaitUtils.pollUninterruptibly;
import static io.servicetalk.concurrent.test.internal.AwaitUtils.takeUninterruptibly;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.NANOSECONDS;
import static java.util.stream.Collectors.toList;

/**
Expand Down Expand Up @@ -325,88 +327,4 @@ private static Object wrapNull(@Nullable Object item) {
private static <T> T unwrapNull(Object item) {
return item == NULL_ON_NEXT ? null : (T) item;
}

private static void awaitUninterruptibly(CountDownLatch latch) {
boolean interrupted = false;
try {
do {
try {
latch.await();
return;
} catch (InterruptedException e) {
interrupted = true;
}
} while (true);
} finally {
if (interrupted) {
Thread.currentThread().interrupt();
}
}
}

private static boolean awaitUninterruptibly(CountDownLatch latch, long timeout, TimeUnit unit) {
final long startTime = System.nanoTime();
final long timeoutNanos = NANOSECONDS.convert(timeout, unit);
long waitTime = timeoutNanos;
boolean interrupted = false;
try {
do {
try {
return latch.await(waitTime, NANOSECONDS);
} catch (InterruptedException e) {
interrupted = true;
}
waitTime = timeoutNanos - (System.nanoTime() - startTime);
if (waitTime <= 0) {
return true;
}
} while (true);
} finally {
if (interrupted) {
Thread.currentThread().interrupt();
}
}
}

private static <T> T takeUninterruptibly(BlockingQueue<T> queue) {
boolean interrupted = false;
try {
do {
try {
return queue.take();
} catch (InterruptedException e) {
interrupted = true;
}
} while (true);
} finally {
if (interrupted) {
Thread.currentThread().interrupt();
}
}
}

@Nullable
private static <T> T pollUninterruptibly(BlockingQueue<T> queue, long timeout, TimeUnit unit) {
final long startTime = System.nanoTime();
final long timeoutNanos = NANOSECONDS.convert(timeout, unit);
long waitTime = timeout;
boolean interrupted = false;
try {
do {
try {
return queue.poll(waitTime, NANOSECONDS);
} catch (InterruptedException e) {
interrupted = true;
}
waitTime = timeoutNanos - (System.nanoTime() - startTime);
if (waitTime <= 0) {
return null;
}
} while (true);
} finally {
if (interrupted) {
Thread.currentThread().interrupt();
}
}
}
}

0 comments on commit 917791d

Please sign in to comment.