diff --git a/server/src/main/java/org/opensearch/gateway/AsyncShardBatchFetch.java b/server/src/main/java/org/opensearch/gateway/AsyncShardBatchFetch.java index 4f39a39cea678..9f4cee99465a5 100644 --- a/server/src/main/java/org/opensearch/gateway/AsyncShardBatchFetch.java +++ b/server/src/main/java/org/opensearch/gateway/AsyncShardBatchFetch.java @@ -80,6 +80,16 @@ public synchronized void clearShard(ShardId shardId) { this.cache.deleteShard(shardId); } + /** + * Clear the cache for a given node and shardId. + * + * @param nodeId node id to be removed from the batch. + * @param shardId shard id to be removed from the batch. + */ + public synchronized void clearCache(String nodeId, ShardId shardId) { + this.cache.cleanCacheForNodeForShardId(nodeId, shardId); + } + /** * Cache implementation of transport actions returning batch of shards related data in the response. * Store node level responses of transport actions like {@link TransportNodesListGatewayStartedShardsBatch} or @@ -138,6 +148,14 @@ public void deleteShard(ShardId shardId) { } } + @Override + public void cleanCacheForNodeForShardId(String nodeId, ShardId shardId) { + if (shardIdToArray.containsKey(shardId)) { + Integer shardIdIndex = shardIdToArray.remove(shardId); + cache.get(nodeId).clearShard(shardIdIndex); + } + } + @Override public void initData(DiscoveryNode node) { cache.put(node.getId(), new NodeEntry<>(node.getId(), shardResponseClass, batchSize, emptyShardResponsePredicate)); diff --git a/server/src/main/java/org/opensearch/gateway/AsyncShardFetch.java b/server/src/main/java/org/opensearch/gateway/AsyncShardFetch.java index b664dd573ce67..5b9571a73bba9 100644 --- a/server/src/main/java/org/opensearch/gateway/AsyncShardFetch.java +++ b/server/src/main/java/org/opensearch/gateway/AsyncShardFetch.java @@ -315,6 +315,11 @@ public void deleteShard(ShardId shardId) { cache.clear(); // single shard cache can clear the full map } + @Override + public void cleanCacheForNodeForShardId(String nodeId, ShardId shardId) { + cache.remove(nodeId); // non batch cache only has one entry per node + } + /** * A node entry, holding the state of the fetched data for a specific shard * for a giving node. diff --git a/server/src/main/java/org/opensearch/gateway/AsyncShardFetchCache.java b/server/src/main/java/org/opensearch/gateway/AsyncShardFetchCache.java index 2a4e6181467b0..8e37137d4aa24 100644 --- a/server/src/main/java/org/opensearch/gateway/AsyncShardFetchCache.java +++ b/server/src/main/java/org/opensearch/gateway/AsyncShardFetchCache.java @@ -74,6 +74,8 @@ protected AsyncShardFetchCache(Logger logger, String type) { */ abstract void deleteShard(ShardId shardId); + abstract void cleanCacheForNodeForShardId(String nodeId, ShardId shardId); + /** * Returns the number of fetches that are currently ongoing. */ diff --git a/server/src/main/java/org/opensearch/gateway/ReplicaShardBatchAllocator.java b/server/src/main/java/org/opensearch/gateway/ReplicaShardBatchAllocator.java index 3459f1591b633..be7867b7823f6 100644 --- a/server/src/main/java/org/opensearch/gateway/ReplicaShardBatchAllocator.java +++ b/server/src/main/java/org/opensearch/gateway/ReplicaShardBatchAllocator.java @@ -56,7 +56,7 @@ public void processExistingRecoveries(RoutingAllocation allocation, List routingNode.getInitializingShards().forEach(shardRouting -> { if (currentBatchedShards.containsKey(shardRouting.shardId()) && shardRouting.primary() == primary) { batchedShardsToAssign.add(shardRouting.shardId()); + // Set updated shard routing in batch if it already exists + String batchId = currentBatchedShards.get(shardRouting.shardId()); + currentBatches.get(batchId).batchInfo.get(shardRouting.shardId()).setShardRouting(shardRouting); } })); @@ -410,10 +413,6 @@ private void ensureAsyncFetchStorePrimaryRecency(RoutingAllocation allocation) { Sets.difference(newEphemeralIds, lastSeenEphemeralIds) ) ); - // ToDo : Validate that we don't need below call for batch allocation - // storeShardBatchLookup.values().forEach(batch -> - // clearCacheForBatchPrimary(batchIdToStoreShardBatch.get(batch), allocation) - // ); batchIdToStoreShardBatch.values().forEach(batch -> clearCacheForBatchPrimary(batch, allocation)); // recalc to also (lazily) clear out old nodes. @@ -422,20 +421,16 @@ private void ensureAsyncFetchStorePrimaryRecency(RoutingAllocation allocation) { } private static void clearCacheForBatchPrimary(ShardsBatch batch, RoutingAllocation allocation) { - // We're not running below code because for removing a node from cache we need all replica's primaries - // to be assigned on same node. This was easy in single shard case and we're saving a call for a node - // if primary was already assigned for a replica. But here we don't keep track of per shard data in cache - // so it's not feasible to do any removal of node entry just based on single shard. - // ONLY run if single shard is present in the batch, to maintain backward compatibility - if (batch.getBatchedShards().size() == 1) { - List primaries = batch.getBatchedShards() - .stream() - .map(allocation.routingNodes()::activePrimary) - .filter(Objects::nonNull) - .collect(Collectors.toList()); - AsyncShardFetch fetch = batch.getAsyncFetcher(); - primaries.forEach(node -> fetch.clearCacheForNode(node.currentNodeId())); - } + // We need to clear the cache for the primary shard to ensure we do not cancel recoveries based on excessively + // stale data. We do this by clearing the cache of primary shards on nodes for all the active primaries of + // replicas in the current batch. + List primaries = batch.getBatchedShards() + .stream() + .map(allocation.routingNodes()::activePrimary) + .filter(Objects::nonNull) + .collect(Collectors.toList()); + AsyncShardBatchFetch fetch = batch.getAsyncFetcher(); + primaries.forEach(shardRouting -> fetch.clearCache(shardRouting.currentNodeId(), shardRouting.shardId())); } private boolean hasNewNodes(DiscoveryNodes nodes) {