diff --git a/src/main/java/org/elasticsearch/common/lucene/search/function/FiltersFunctionScoreQuery.java b/src/main/java/org/elasticsearch/common/lucene/search/function/FiltersFunctionScoreQuery.java index f6910fdf2c1e1..12bc7be809da2 100644 --- a/src/main/java/org/elasticsearch/common/lucene/search/function/FiltersFunctionScoreQuery.java +++ b/src/main/java/org/elasticsearch/common/lucene/search/function/FiltersFunctionScoreQuery.java @@ -177,7 +177,15 @@ public Explanation explain(AtomicReaderContext context, int doc) throws IOExcept } // First: Gather explanations for all filters List filterExplanations = new ArrayList<>(); + float weightSum = 0; for (FilterFunction filterFunction : filterFunctions) { + + if (filterFunction.function instanceof WeightFactorFunction) { + weightSum += ((WeightFactorFunction) filterFunction.function).getWeight(); + } else { + weightSum++; + } + Bits docSet = DocIdSets.toSafeBits(context.reader(), filterFunction.filter.getDocIdSet(context, context.reader().getLiveDocs())); if (docSet.get(doc)) { @@ -226,15 +234,13 @@ public Explanation explain(AtomicReaderContext context, int doc) throws IOExcept break; default: // Avg / Total double totalFactor = 0.0f; - int count = 0; for (int i = 0; i < filterExplanations.size(); i++) { totalFactor += filterExplanations.get(i).getValue(); - count++; } - if (count != 0) { + if (weightSum != 0) { factor = totalFactor; if (scoreMode == ScoreMode.Avg) { - factor /= count; + factor /= weightSum; } } } @@ -300,17 +306,21 @@ public float innerScore() throws IOException { } } else { // Avg / Total double totalFactor = 0.0f; - int count = 0; + float weightSum = 0; for (int i = 0; i < filterFunctions.length; i++) { if (docSets[i].get(docId)) { totalFactor += filterFunctions[i].function.score(docId, subQueryScore); - count++; + if (filterFunctions[i].function instanceof WeightFactorFunction) { + weightSum+= ((WeightFactorFunction)filterFunctions[i].function).getWeight(); + } else { + weightSum++; + } } } - if (count != 0) { + if (weightSum != 0) { factor = totalFactor; if (scoreMode == ScoreMode.Avg) { - factor /= count; + factor /= weightSum; } } } diff --git a/src/test/java/org/elasticsearch/search/functionscore/FunctionScoreTests.java b/src/test/java/org/elasticsearch/search/functionscore/FunctionScoreTests.java index 1ad5022cbbbe9..93980d6345827 100644 --- a/src/test/java/org/elasticsearch/search/functionscore/FunctionScoreTests.java +++ b/src/test/java/org/elasticsearch/search/functionscore/FunctionScoreTests.java @@ -254,8 +254,11 @@ protected double computeExpectedScore(float[] weights, float[] scores, String sc expectedScore = Float.MAX_VALUE; } + float weightSum = 0; + for (int i = 0; i < weights.length; i++) { double functionScore = (double) weights[i] * scores[i]; + weightSum += weights[i]; if ("avg".equals(scoreMode)) { expectedScore += functionScore; @@ -271,7 +274,7 @@ protected double computeExpectedScore(float[] weights, float[] scores, String sc } if ("avg".equals(scoreMode)) { - expectedScore /= weights.length; + expectedScore /= weightSum; } return expectedScore; }