Skip to content

Commit

Permalink
Consolidate DelayableWriteable (#55932)
Browse files Browse the repository at this point in the history
This commit includes a number of minor improvements around `DelayableWriteable`: javadocs were expanded and reworded, `get` was renamed to `expand` and `DelayableWriteable` no longer implements `Supplier`. Also a couple of methods are now private instead of package private.
  • Loading branch information
javanna committed Apr 30, 2020
1 parent c36bcb4 commit fc6422f
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@
import java.util.Map;
import java.util.function.Function;
import java.util.function.IntFunction;
import java.util.function.Supplier;
import java.util.stream.Collectors;

public final class SearchPhaseController {
Expand Down Expand Up @@ -437,7 +436,7 @@ public ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResul
* @see QuerySearchResult#consumeProfileResult()
*/
private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResult> queryResults,
List<Supplier<InternalAggregations>> bufferedAggs,
List<DelayableWriteable<InternalAggregations>> bufferedAggs,
List<TopDocs> bufferedTopDocs,
TopDocsStats topDocsStats, int numReducePhases, boolean isScrollRequest,
InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
Expand All @@ -462,7 +461,7 @@ private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResu
final boolean hasSuggest = firstResult.suggest() != null;
final boolean hasProfileResults = firstResult.hasProfileResults();
final boolean consumeAggs;
final List<Supplier<InternalAggregations>> aggregationsList;
final List<DelayableWriteable<InternalAggregations>> aggregationsList;
if (bufferedAggs != null) {
consumeAggs = false;
// we already have results from intermediate reduces and just need to perform the final reduce
Expand Down Expand Up @@ -527,18 +526,18 @@ private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResu
firstResult.sortValueFormats(), numReducePhases, size, from, false);
}

private InternalAggregations reduceAggs(
private static InternalAggregations reduceAggs(
InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
boolean performFinalReduce,
List<? extends Supplier<InternalAggregations>> aggregationsList
List<DelayableWriteable<InternalAggregations>> aggregationsList
) {
/*
* Parse the aggregations, clearing the list as we go so bits backing
* the DelayedWriteable can be collected immediately.
*/
List<InternalAggregations> toReduce = new ArrayList<>(aggregationsList.size());
for (int i = 0; i < aggregationsList.size(); i++) {
toReduce.add(aggregationsList.get(i).get());
toReduce.add(aggregationsList.get(i).expand());
aggregationsList.set(i, null);
}
return aggregationsList.isEmpty() ? null : InternalAggregations.topLevelReduce(toReduce,
Expand Down Expand Up @@ -701,7 +700,7 @@ private synchronized void consumeInternal(QuerySearchResult querySearchResult) {
if (hasAggs) {
List<InternalAggregations> aggs = new ArrayList<>(aggsBuffer.length);
for (int i = 0; i < aggsBuffer.length; i++) {
aggs.add(aggsBuffer[i].get());
aggs.add(aggsBuffer[i].expand());
aggsBuffer[i] = null; // null the buffer so it can be GCed now.
}
InternalAggregations reduced =
Expand Down Expand Up @@ -743,8 +742,8 @@ private synchronized void consumeInternal(QuerySearchResult querySearchResult) {
processedShards[querySearchResult.getShardIndex()] = querySearchResult.getSearchShardTarget();
}

private synchronized List<Supplier<InternalAggregations>> getRemainingAggs() {
return hasAggs ? Arrays.asList((Supplier<InternalAggregations>[]) aggsBuffer).subList(0, index) : null;
private synchronized List<DelayableWriteable<InternalAggregations>> getRemainingAggs() {
return hasAggs ? Arrays.asList((DelayableWriteable<InternalAggregations>[]) aggsBuffer).subList(0, index) : null;
}

private synchronized List<TopDocs> getRemainingTopDocs() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,26 @@
import org.elasticsearch.common.bytes.BytesReference;

import java.io.IOException;
import java.util.function.Supplier;

/**
* A holder for {@link Writeable}s that can delays reading the underlying
* {@linkplain Writeable} when it is read from a remote node.
* A holder for {@link Writeable}s that delays reading the underlying object
* on the receiving end. To be used for objects whose deserialized
* representation is inefficient to keep in memory compared to their
* corresponding serialized representation.
* The node that produces the {@link Writeable} calls {@link #referencing(Writeable)}
* to create a {@link DelayableWriteable} that serializes the inner object
* first to a buffer and writes the content of the buffer to the {@link StreamOutput}.
* The receiver node calls {@link #delayed(Reader, StreamInput)} to create a
* {@link DelayableWriteable} that reads the buffer from the @link {@link StreamInput}
* but delays creating the actual object by calling {@link #expand()} when needed.
* Multiple {@link DelayableWriteable}s coming from different nodes may be buffered
* on the receiver end, which may hold a mix of {@link DelayableWriteable}s that were
* produced locally (hence expanded) as well as received form another node (hence subject
* to delayed expansion). When such objects are buffered for some time it may be desirable
* to force their buffering in serialized format by calling
* {@link #asSerialized(Reader, NamedWriteableRegistry)}.
*/
public abstract class DelayableWriteable<T extends Writeable> implements Supplier<T>, Writeable {
public abstract class DelayableWriteable<T extends Writeable> implements Writeable {
/**
* Build a {@linkplain DelayableWriteable} that wraps an existing object
* but is serialized so that deserializing it can be delayed.
Expand All @@ -42,7 +55,7 @@ public static <T extends Writeable> DelayableWriteable<T> referencing(T referenc
/**
* Build a {@linkplain DelayableWriteable} that copies a buffer from
* the provided {@linkplain StreamInput} and deserializes the buffer
* when {@link Supplier#get()} is called.
* when {@link #expand()} is called.
*/
public static <T extends Writeable> DelayableWriteable<T> delayed(Writeable.Reader<T> reader, StreamInput in) throws IOException {
return new Serialized<>(reader, in.getVersion(), in.namedWriteableRegistry(), in.readBytesReference());
Expand All @@ -56,16 +69,21 @@ private DelayableWriteable() {}
*/
public abstract Serialized<T> asSerialized(Writeable.Reader<T> reader, NamedWriteableRegistry registry);

/**
* Expands the inner {@link Writeable} to its original representation and returns it
*/
public abstract T expand();

/**
* {@code true} if the {@linkplain Writeable} is being stored in
* serialized form, {@code false} otherwise.
*/
abstract boolean isSerialized();

private static class Referencing<T extends Writeable> extends DelayableWriteable<T> {
private T reference;
private final T reference;

Referencing(T reference) {
private Referencing(T reference) {
this.reference = reference;
}

Expand All @@ -75,17 +93,19 @@ public void writeTo(StreamOutput out) throws IOException {
}

@Override
public T get() {
public T expand() {
return reference;
}

@Override
public Serialized<T> asSerialized(Reader<T> reader, NamedWriteableRegistry registry) {
BytesStreamOutput buffer;
try {
return new Serialized<T>(reader, Version.CURRENT, registry, writeToBuffer(Version.CURRENT).bytes());
buffer = writeToBuffer(Version.CURRENT);
} catch (IOException e) {
throw new RuntimeException("unexpected error expanding aggregations", e);
throw new RuntimeException("unexpected error writing writeable to buffer", e);
}
return new Serialized<>(reader, Version.CURRENT, registry, buffer.bytes());
}

@Override
Expand All @@ -111,8 +131,8 @@ public static class Serialized<T extends Writeable> extends DelayableWriteable<T
private final NamedWriteableRegistry registry;
private final BytesReference serialized;

Serialized(Writeable.Reader<T> reader, Version serializedAtVersion,
NamedWriteableRegistry registry, BytesReference serialized) throws IOException {
private Serialized(Writeable.Reader<T> reader, Version serializedAtVersion,
NamedWriteableRegistry registry, BytesReference serialized) {
this.reader = reader;
this.serializedAtVersion = serializedAtVersion;
this.registry = registry;
Expand All @@ -136,20 +156,20 @@ public void writeTo(StreamOutput out) throws IOException {
* differences in the wire protocol. This ain't efficient but
* it should be quite rare.
*/
referencing(get()).writeTo(out);
referencing(expand()).writeTo(out);
}
}

@Override
public T get() {
public T expand() {
try {
try (StreamInput in = registry == null ?
serialized.streamInput() : new NamedWriteableAwareStreamInput(serialized.streamInput(), registry)) {
in.setVersion(serializedAtVersion);
return reader.read(in);
}
} catch (IOException e) {
throw new RuntimeException("unexpected error expanding aggregations", e);
throw new RuntimeException("unexpected error expanding serialized delayed writeable", e);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1103,7 +1103,7 @@ public <T extends Exception> T readException() throws IOException {
}

/**
* Get the registry of named writeables is his stream has one,
* Get the registry of named writeables if this stream has one,
* {@code null} otherwise.
*/
public NamedWriteableRegistry namedWriteableRegistry() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ public void writeToNoId(StreamOutput out) throws IOException {
} else {
out.writeBoolean(true);
if (out.getVersion().before(Version.V_7_7_0)) {
InternalAggregations aggs = aggregations.get();
InternalAggregations aggs = aggregations.expand();
aggs.writeTo(out);
if (out.getVersion().before(Version.V_7_2_0)) {
/*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ public void testRoundTripFromDelayedFromOldVersionWithNamedWriteable() throws IO
public void testSerializesWithRemoteVersion() throws IOException {
Version remoteVersion = VersionUtils.randomCompatibleVersion(random(), Version.CURRENT);
DelayableWriteable<SneakOtherSideVersionOnWire> original = DelayableWriteable.referencing(new SneakOtherSideVersionOnWire());
assertThat(roundTrip(original, SneakOtherSideVersionOnWire::new, remoteVersion).get().version, equalTo(remoteVersion));
assertThat(roundTrip(original, SneakOtherSideVersionOnWire::new, remoteVersion).expand().version, equalTo(remoteVersion));
}

public void testAsSerializedIsNoopOnSerialized() throws IOException {
Expand All @@ -172,7 +172,7 @@ public void testAsSerializedIsNoopOnSerialized() throws IOException {
private <T extends Writeable> void roundTripTestCase(DelayableWriteable<T> original, Writeable.Reader<T> reader) throws IOException {
DelayableWriteable<T> roundTripped = roundTrip(original, reader, Version.CURRENT);
assertTrue(roundTripped.isSerialized());
assertThat(roundTripped.get(), equalTo(original.get()));
assertThat(roundTripped.expand(), equalTo(original.expand()));
}

private <T extends Writeable> DelayableWriteable<T> roundTrip(DelayableWriteable<T> original,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ public void testSerialization() throws Exception {
assertEquals(querySearchResult.size(), deserialized.size());
assertEquals(querySearchResult.hasAggs(), deserialized.hasAggs());
if (deserialized.hasAggs()) {
Aggregations aggs = querySearchResult.consumeAggs().get();
Aggregations deserializedAggs = deserialized.consumeAggs().get();
Aggregations aggs = querySearchResult.consumeAggs().expand();
Aggregations deserializedAggs = deserialized.consumeAggs().expand();
assertEquals(aggs.asList(), deserializedAggs.asList());
}
assertEquals(querySearchResult.terminatedEarly(), deserialized.terminatedEarly());
Expand All @@ -114,7 +114,7 @@ public void testReadFromPre_7_1_0() throws IOException {
QuerySearchResult querySearchResult = new QuerySearchResult(in);
assertEquals(100, querySearchResult.getContextId().getId());
assertTrue(querySearchResult.hasAggs());
InternalAggregations aggs = querySearchResult.consumeAggs().get();
InternalAggregations aggs = querySearchResult.consumeAggs().expand();
assertEquals(1, aggs.asList().size());
// We deserialize and throw away top level pipeline aggs
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -379,18 +379,16 @@ public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits,
reducedAggs = () -> null;
} else {
/*
* Keep a reference to the serialiazed form of the partially
* Keep a reference to the serialized form of the partially
* reduced aggs and reduce it on the fly when someone asks
* for it. This will produce right-ish aggs. Much more right
* than if you don't do the final reduce. Its important that
* we wait until someone needs the result so we don't perform
* the final reduce only to throw it away. And it is important
* that we kep the reference to the serialized aggrgations
* because the SearchPhaseController *already* has that
* reference so we're not creating more garbage.
* for it. It's important that we wait until someone needs
* the result so we don't perform the final reduce only to
* throw it away. And it is important that we keep the reference
* to the serialized aggregations because SearchPhaseController
* *already* has that reference so we're not creating more garbage.
*/
reducedAggs = () ->
InternalAggregations.topLevelReduce(singletonList(aggregations.get()), aggReduceContextSupplier.get());
InternalAggregations.topLevelReduce(singletonList(aggregations.expand()), aggReduceContextSupplier.get());
}
searchResponse.get().updatePartialResponse(shards.size(), totalHits, reducedAggs, reducePhase);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ class MutableSearchResponse {
private int reducePhase;
/**
* The response produced by the search API. Once we receive it we stop
* building our own {@linkplain SearchResponse}s when you get the status
* and instead return this.
* building our own {@linkplain SearchResponse}s when get async search
* is called, and instead return this.
* @see #findOrBuildResponse(AsyncSearchTask)
*/
private SearchResponse finalResponse;
private ElasticsearchException failure;
Expand Down Expand Up @@ -157,10 +158,9 @@ private SearchResponse findOrBuildResponse(AsyncSearchTask task) {
/*
* Build the response, reducing aggs if we haven't already and
* storing the result of the reduction so we won't have to reduce
* a second time if you get the response again and nothing has
* changed. This does cost memory because we have a reference
* to the reduced aggs sitting around so it can't be GCed until
* we get an update.
* the same aggregation results a second time if nothing has changed.
* This does cost memory because we have a reference to the finally
* reduced aggs sitting around which can't be GCed until we get an update.
*/
InternalAggregations reducedAggs = reducedAggsSource.get();
reducedAggsSource = () -> reducedAggs;
Expand All @@ -183,8 +183,6 @@ synchronized AsyncSearchResponse toAsyncSearchResponseWithHeaders(AsyncSearchTas
return resp;
}



private void failIfFrozen() {
if (frozen) {
throw new IllegalStateException("invalid update received after the completion of the request");
Expand Down

0 comments on commit fc6422f

Please sign in to comment.