Skip to content
This repository has been archived by the owner on Sep 26, 2019. It is now read-only.

Commit

Permalink
Fix thread safety in SubscriptionManager. (#1540)
Browse files Browse the repository at this point in the history
Connection ID is now tracked as a property of the subscription rather than as a separate map, so we can use a single concurrent map to track all details of subscriptions.
  • Loading branch information
ajsutton authored and lucassaldanha committed Jun 10, 2019
1 parent 4753529 commit 0f1162d
Show file tree
Hide file tree
Showing 20 changed files with 182 additions and 222 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,18 @@

public class Subscription {

private final Long id;
private final Long subscriptionId;
private final String connectionId;
private final SubscriptionType subscriptionType;
private final Boolean includeTransaction;

public Subscription(
final Long id, final SubscriptionType subscriptionType, final Boolean includeTransaction) {
this.id = id;
final Long subscriptionId,
final String connectionId,
final SubscriptionType subscriptionType,
final Boolean includeTransaction) {
this.subscriptionId = subscriptionId;
this.connectionId = connectionId;
this.subscriptionType = subscriptionType;
this.includeTransaction = includeTransaction;
}
Expand All @@ -35,8 +40,12 @@ public SubscriptionType getSubscriptionType() {
return subscriptionType;
}

public Long getId() {
return id;
public Long getSubscriptionId() {
return subscriptionId;
}

public String getConnectionId() {
return connectionId;
}

public Boolean getIncludeTransaction() {
Expand All @@ -46,8 +55,10 @@ public Boolean getIncludeTransaction() {
@Override
public String toString() {
return MoreObjects.toStringHelper(this)
.add("id", id)
.add("subscriptionId", subscriptionId)
.add("connectionId", connectionId)
.add("subscriptionType", subscriptionType)
.add("includeTransaction", includeTransaction)
.toString();
}

Expand All @@ -64,11 +75,12 @@ public boolean equals(final Object o) {
return false;
}
final Subscription that = (Subscription) o;
return Objects.equals(id, that.id) && subscriptionType == that.subscriptionType;
return Objects.equals(subscriptionId, that.subscriptionId)
&& subscriptionType == that.subscriptionType;
}

@Override
public int hashCode() {
return Objects.hash(id, subscriptionType);
return Objects.hash(subscriptionId, subscriptionType);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,27 +23,31 @@

public class SubscriptionBuilder {

public Subscription build(final long id, final SubscribeRequest request) {
public Subscription build(
final long subscriptionId, final String connectionId, final SubscribeRequest request) {
final SubscriptionType subscriptionType = request.getSubscriptionType();
switch (subscriptionType) {
case NEW_BLOCK_HEADERS:
{
return new NewBlockHeadersSubscription(id, request.getIncludeTransaction());
return new NewBlockHeadersSubscription(
subscriptionId, connectionId, request.getIncludeTransaction());
}
case LOGS:
{
return new LogsSubscription(
id,
subscriptionId,
connectionId,
Optional.ofNullable(request.getFilterParameter())
.orElseThrow(IllegalArgumentException::new));
}
case SYNCING:
{
return new SyncingSubscription(id, subscriptionType);
return new SyncingSubscription(subscriptionId, connectionId, subscriptionType);
}
case NEW_PENDING_TRANSACTIONS:
default:
return new Subscription(id, subscriptionType, request.getIncludeTransaction());
return new Subscription(
subscriptionId, connectionId, subscriptionType, request.getIncludeTransaction());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,13 @@
import tech.pegasys.pantheon.ethereum.jsonrpc.websocket.subscription.request.UnsubscribeRequest;
import tech.pegasys.pantheon.ethereum.jsonrpc.websocket.subscription.response.SubscriptionResponse;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import io.vertx.core.AbstractVerticle;
import io.vertx.core.eventbus.Message;
import io.vertx.core.json.Json;
Expand All @@ -47,12 +43,9 @@ public class SubscriptionManager extends AbstractVerticle {
"SubscriptionManager::removeSubscriptions";

private final AtomicLong subscriptionCounter = new AtomicLong(0);
private final Map<Long, Subscription> subscriptions = new HashMap<>();
private final Map<String, List<Long>> connectionSubscriptionsMap = new HashMap<>();
private final Map<Long, Subscription> subscriptions = new ConcurrentHashMap<>();
private final SubscriptionBuilder subscriptionBuilder = new SubscriptionBuilder();

public SubscriptionManager() {}

@Override
public void start() {
vertx.eventBus().consumer(EVENTBUS_REMOVE_SUBSCRIPTIONS_ADDRESS, this::removeSubscriptions);
Expand All @@ -62,23 +55,11 @@ public Long subscribe(final SubscribeRequest request) {
LOG.debug("Subscribe request {}", request);

final long subscriptionId = subscriptionCounter.incrementAndGet();
final Subscription subscription = subscriptionBuilder.build(subscriptionId, request);
addSubscription(subscription, request.getConnectionId());

return subscription.getId();
}

private void addSubscription(final Subscription subscription, final String connectionId) {
subscriptions.put(subscription.getId(), subscription);
mapSubscriptionToConnection(connectionId, subscription.getId());
}
final Subscription subscription =
subscriptionBuilder.build(subscriptionId, request.getConnectionId(), request);
subscriptions.put(subscription.getSubscriptionId(), subscription);

private void mapSubscriptionToConnection(final String connectionId, final Long subscriptionId) {
if (connectionSubscriptionsMap.containsKey(connectionId)) {
connectionSubscriptionsMap.get(connectionId).add(subscriptionId);
} else {
connectionSubscriptionsMap.put(connectionId, Lists.newArrayList(subscriptionId));
}
return subscription.getSubscriptionId();
}

public boolean unsubscribe(final UnsubscribeRequest request) {
Expand All @@ -87,66 +68,39 @@ public boolean unsubscribe(final UnsubscribeRequest request) {

LOG.debug("Unsubscribe request subscriptionId = {}", subscriptionId);

if (!subscriptions.containsKey(subscriptionId)
|| !connectionOwnsSubscription(subscriptionId, connectionId)) {
final Subscription subscription = subscriptions.get(subscriptionId);
if (subscription == null || !subscription.getConnectionId().equals(connectionId)) {
throw new SubscriptionNotFoundException(subscriptionId);
}

destroySubscription(subscriptionId, connectionId);
destroySubscription(subscriptionId);

return true;
}

private boolean connectionOwnsSubscription(final Long subscriptionId, final String connectionId) {
return connectionSubscriptionsMap.get(connectionId) != null
&& connectionSubscriptionsMap.get(connectionId).contains(subscriptionId);
}

private void destroySubscription(final long subscriptionId, final String connectionId) {
private void destroySubscription(final long subscriptionId) {
subscriptions.remove(subscriptionId);

if (connectionSubscriptionsMap.containsKey(connectionId)) {
removeSubscriptionToConnectionMapping(connectionId, subscriptionId);
}
}

private void removeSubscriptionToConnectionMapping(
final String connectionId, final Long subscriptionId) {
if (connectionSubscriptionsMap.get(connectionId).size() > 1) {
connectionSubscriptionsMap.get(connectionId).remove(subscriptionId);
} else {
connectionSubscriptionsMap.remove(connectionId);
}
}

@VisibleForTesting
void removeSubscriptions(final Message<String> message) {
private void removeSubscriptions(final Message<String> message) {
final String connectionId = message.body();
if (connectionId == null || "".equals(connectionId)) {
LOG.warn("Received invalid connectionId ({}). No subscriptions removed.");
LOG.warn("Received invalid connectionId ({}). No subscriptions removed.", connectionId);
}

LOG.debug("Removing subscription for connectionId = {}", connectionId);

final List<Long> subscriptionIds =
Lists.newArrayList(
connectionSubscriptionsMap.getOrDefault(connectionId, Lists.newArrayList()));
subscriptionIds.forEach(subscriptionId -> destroySubscription(subscriptionId, connectionId));
}
LOG.debug("Removing subscription for connectionId {}", connectionId);

@VisibleForTesting
Map<Long, Subscription> subscriptions() {
return Maps.newHashMap(subscriptions);
subscriptions.values().stream()
.filter(subscription -> subscription.getConnectionId().equals(connectionId))
.forEach(subscription -> destroySubscription(subscription.getSubscriptionId()));
}

@VisibleForTesting
public Map<String, List<Long>> getConnectionSubscriptionsMap() {
return Maps.newHashMap(connectionSubscriptionsMap);
public Subscription getSubscriptionById(final Long subscriptionId) {
return subscriptions.get(subscriptionId);
}

public <T> List<T> subscriptionsOfType(final SubscriptionType type, final Class<T> clazz) {
return subscriptions.entrySet().stream()
.map(Entry::getValue)
return subscriptions.values().stream()
.filter(subscription -> subscription.isType(type))
.map(subscriptionBuilder.mapToSubscriptionClass(clazz))
.collect(Collectors.toList());
Expand All @@ -155,11 +109,10 @@ public <T> List<T> subscriptionsOfType(final SubscriptionType type, final Class<
public void sendMessage(final Long subscriptionId, final JsonRpcResult msg) {
final SubscriptionResponse response = new SubscriptionResponse(subscriptionId, msg);

connectionSubscriptionsMap.entrySet().stream()
.filter(e -> e.getValue().contains(subscriptionId))
.map(Entry::getKey)
.findFirst()
.ifPresent(connectionId -> vertx.eventBus().send(connectionId, Json.encode(response)));
final Subscription subscription = subscriptions.get(subscriptionId);
if (subscription != null) {
vertx.eventBus().send(subscription.getConnectionId(), Json.encode(response));
}
}

public <T> void notifySubscribersOnWorkerThread(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ public class NewBlockHeadersSubscription extends Subscription {

private final boolean includeTransactions;

public NewBlockHeadersSubscription(final Long subscriptionId, final boolean includeTransactions) {
super(subscriptionId, SubscriptionType.NEW_BLOCK_HEADERS, Boolean.FALSE);
public NewBlockHeadersSubscription(
final Long subscriptionId, final String connectionId, final boolean includeTransactions) {
super(subscriptionId, connectionId, SubscriptionType.NEW_BLOCK_HEADERS, Boolean.FALSE);
this.includeTransactions = includeTransactions;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public void onBlockAdded(final BlockAddedEvent event, final Blockchain blockchai
? blockWithCompleteTransaction(newBlockHash)
: blockWithTransactionHash(newBlockHash);

subscriptionManager.sendMessage(subscription.getId(), newBlock);
subscriptionManager.sendMessage(subscription.getSubscriptionId(), newBlock);
}
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ public class LogsSubscription extends Subscription {

private final FilterParameter filterParameter;

public LogsSubscription(final Long subscriptionId, final FilterParameter filterParameter) {
super(subscriptionId, SubscriptionType.LOGS, Boolean.FALSE);
public LogsSubscription(
final Long subscriptionId, final String connectionId, final FilterParameter filterParameter) {
super(subscriptionId, connectionId, SubscriptionType.LOGS, Boolean.FALSE);
this.filterParameter = filterParameter;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ private void sendLogToSubscription(
final int logIndex,
final LogsSubscription subscription) {
final LogWithMetadata logWithMetaData = logWithMetadata(logIndex, receiptWithMetadata, removed);
subscriptionManager.sendMessage(subscription.getId(), new LogResult(logWithMetaData));
subscriptionManager.sendMessage(
subscription.getSubscriptionId(), new LogResult(logWithMetaData));
}

// @formatter:off
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ private void notifySubscribers(final Hash pendingTransaction) {

final PendingTransactionResult msg = new PendingTransactionResult(pendingTransaction);
for (final Subscription subscription : subscriptions) {
subscriptionManager.sendMessage(subscription.getId(), msg);
subscriptionManager.sendMessage(subscription.getSubscriptionId(), msg);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ private void notifySubscribers(final Transaction pendingTransaction) {
new PendingTransactionDetailResult(pendingTransaction);
for (final Subscription subscription : subscriptions) {
if (Boolean.TRUE.equals(subscription.getIncludeTransaction())) {
subscriptionManager.sendMessage(subscription.getId(), detailResult);
subscriptionManager.sendMessage(subscription.getSubscriptionId(), detailResult);
} else {
subscriptionManager.sendMessage(subscription.getId(), hashResult);
subscriptionManager.sendMessage(subscription.getSubscriptionId(), hashResult);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@
public class SyncingSubscription extends Subscription {
private boolean firstMessageHasBeenSent = false;

public SyncingSubscription(final Long id, final SubscriptionType subscriptionType) {
super(id, subscriptionType, Boolean.FALSE);
public SyncingSubscription(
final Long subscriptionId,
final String connectionId,
final SubscriptionType subscriptionType) {
super(subscriptionId, connectionId, subscriptionType, Boolean.FALSE);
}

public void setFirstMessageHasBeenSent(final boolean firstMessageHasBeenSent) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,14 @@ private void sendSyncingToMatchingSubscriptions(final SyncStatus syncStatus) {
syncingSubscriptions -> {
if (syncStatus.inSync()) {
syncingSubscriptions.forEach(
s -> subscriptionManager.sendMessage(s.getId(), new NotSynchronisingResult()));
s ->
subscriptionManager.sendMessage(
s.getSubscriptionId(), new NotSynchronisingResult()));
} else {
syncingSubscriptions.forEach(
s -> subscriptionManager.sendMessage(s.getId(), new SyncingResult(syncStatus)));
s ->
subscriptionManager.sendMessage(
s.getSubscriptionId(), new SyncingResult(syncStatus)));
}
});
}
Expand Down
Loading

0 comments on commit 0f1162d

Please sign in to comment.