From 60c2d375fc663f1fa5cc7c48dc09f3f74a2008dd Mon Sep 17 00:00:00 2001 From: Marcin Januszkiewicz Date: Thu, 29 Aug 2024 17:18:42 +0200 Subject: [PATCH] Fixes --- .../elasticsearch/query/TraveltimeScorer.java | 3 +- .../elasticsearch/query/TraveltimeWeight.java | 51 ++++++++++++++----- 2 files changed, 40 insertions(+), 14 deletions(-) diff --git a/7.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java b/7.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java index 3958877..c55b3dc 100644 --- a/7.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java +++ b/7.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java @@ -2,7 +2,6 @@ import it.unimi.dsi.fastutil.longs.Long2IntMap; import java.io.IOException; -import lombok.Getter; import lombok.RequiredArgsConstructor; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Scorer; @@ -17,7 +16,7 @@ public class TraveltimeScorer extends Scorer { private class TraveltimeFilteredDocs extends DocIdSetIterator { private final TraveltimeWeight.FilteredIterator backing; - @Getter private long currentValue = 0; + private long currentValue = 0; private boolean currentValueDirty = true; private void invalidateCurrentValue() { diff --git a/7.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java b/7.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java index b7e8541..231698c 100644 --- a/7.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java +++ b/7.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java @@ -121,18 +121,45 @@ public Scorer scorer(LeafReaderContext context) throws IOException { val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - val results = - protoFetcher.getTimes( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - ttQuery.getParams().getMode(), - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType()); - - for (int index = 0; index < results.size(); index++) { - if (results.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); + if (ttQuery.getParams().isIncludeDistance()) { + val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); + + val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); + + val timeDistance = + protoFetcher.getTimesAndDistances( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + mode, + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); + + val times = timeDistance.getLeft(); + val distances = timeDistance.getRight(); + + for (int index = 0; index < times.size(); index++) { + if (times.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); + pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); + } + } + + TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); + } else { + val results = + protoFetcher.getTimes( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + ttQuery.getParams().getMode(), + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); + + for (int index = 0; index < results.size(); index++) { + if (results.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); + } } }