diff --git a/src/main/java/io/anserini/ltr/BaseFeatureExtractor.java b/src/main/java/io/anserini/ltr/BaseFeatureExtractor.java deleted file mode 100644 index cd47a5b35b..0000000000 --- a/src/main/java/io/anserini/ltr/BaseFeatureExtractor.java +++ /dev/null @@ -1,257 +0,0 @@ -/* - * Anserini: A Lucene toolkit for replicable information retrieval research - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.anserini.ltr; - -import io.anserini.analysis.AnalyzerUtils; -import io.anserini.ltr.feature.FeatureExtractors; -import io.anserini.rerank.RerankerContext; -import io.anserini.util.Qrels; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.apache.lucene.analysis.Analyzer; -import org.apache.lucene.document.Document; -import org.apache.lucene.index.IndexReader; -import org.apache.lucene.index.MultiBits; -import org.apache.lucene.index.MultiTerms; -import org.apache.lucene.index.Terms; -import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.Query; -import org.apache.lucene.search.ScoreDoc; -import org.apache.lucene.search.TopDocs; -import org.apache.lucene.util.Bits; - -import java.io.IOException; -import java.io.PrintStream; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; - -/** - * Feature extractor class that forms the base for other feature extractors - */ -abstract public class BaseFeatureExtractor { - private static final Logger LOG = LogManager.getLogger(BaseFeatureExtractor.class); - private IndexReader reader; - private Qrels qrels; - private Map> topics; - private Analyzer queryAnalyzer; - private final FeatureExtractors customFeatureExtractors; - - abstract protected String getIdField(); - - abstract protected String getTermVectorField(); - - protected FeatureExtractors getExtractors() { - return this.customFeatureExtractors; - } - - abstract protected Analyzer getAnalyzer(); - - abstract protected Set getFieldsToLoad(); - - abstract protected Query parseQuery(String queryText); - - abstract protected Query docIdQuery(String docId); - - public static String constructOutputString(K qid, int qrel, String docId, float[] features) { - StringBuilder sb = new StringBuilder(); - sb.append(qrel); - sb.append(" "); - sb.append("qid:"); - sb.append(qid); - for (int i = 0 ; i < features.length; i++) { - sb.append(" "); - sb.append(i+1); - sb.append(":"); - sb.append(features[i]); - } - sb.append(" # "); - sb.append(docId); - return sb.toString(); - } - /** - * Method used to print a line of feature vector to the output file - * @param out The output stream - * @param qid Qrel Id - * @param qrel The Qrel relevance value - * @param docId The stored Doc Id - * @param features The feature vector in featureNum:value form - */ - public static void writeFeatureVector(PrintStream out, K qid, int qrel, String docId, float[] features) { - out.print(constructOutputString(qid, qrel, docId,features)); - out.print("\n"); - } - - /** - * Factory method that will take 
the usual parameters for making a Web or Twitter feature extractor - * and a definition file. Will parse the definition file to build the FeatureExtractor chain - * @param reader - * @param qrels - * @param topics - * @param definitionFile - * @return - */ - static BaseFeatureExtractor parseExtractorsFromFile(IndexReader reader, Qrels qrels, - Map> topics, String definitionFile) { - - return null; - } - - /** - * Constructor that requires a reader to the index, the qrels and the topics - * @param reader - * @param qrels - * @param topics - */ - public BaseFeatureExtractor(IndexReader reader, Qrels qrels, Map> topics, - FeatureExtractors extractors) { - this.reader = reader; - this.qrels = qrels; - this.topics = topics; - this.queryAnalyzer = getAnalyzer(); - this.customFeatureExtractors = extractors; - } - - // Build all the reranker contexts because they will be reused once per query - @SuppressWarnings("unchecked") - private Map> buildRerankerContextMap() throws IOException { - Map> queryContextMap = new HashMap<>(); - IndexSearcher searcher = new IndexSearcher(reader); - - for (String qid : qrels.getQids()) { - // Construct the reranker context - LOG.debug(String.format("Constructing context for QID: %s", qid)); - String queryText = topics.get(Integer.parseInt(qid)).get("title"); - Query q = null; - - // We will not be checking for nulls here because the input should be correct, - // and if not it signals other issues - q = parseQuery(queryText); - List queryTokens = AnalyzerUtils.analyze(queryAnalyzer, queryText); - // Construct the reranker context - RerankerContext context = new RerankerContext<>(searcher, (K)qid, - q, null, queryText, - queryTokens, - null, null); - - queryContextMap.put(qid, context); - - } - LOG.debug("Completed constructing context for all qrels"); - return queryContextMap; - } - - private void printHeader(PrintStream out, FeatureExtractors extractors) { - out.println("#Extracting features with the following feature vector:"); - for (int i = 0; i < extractors.extractors.size(); i++) { - out.println(String.format("#%d:%s", i +1, extractors.extractors.get(i).getName())); - } - } - - /** - * Iterates through all the documents and print the features for each of the queries - * This way we are not iterating over the entire index for each query to save disk access - * @param out - * @throws IOException - */ - public void printFeatureForAllDocs(PrintStream out) throws IOException { - Map> queryContextMap = buildRerankerContextMap(); - FeatureExtractors extractors = getExtractors(); - Bits liveDocs = MultiBits.getLiveDocs(reader); - Set fieldsToLoad = getFieldsToLoad(); - - this.printHeader(out, extractors); - - for (int docId = 0; docId < reader.maxDoc(); docId ++) { - // Only check live docs if we have some - if (reader.hasDeletions() && (liveDocs == null || !liveDocs.get(docId))) { - LOG.warn(String.format("Document %d not in live docs", docId)); - continue; - } - Document doc = reader.document(docId, fieldsToLoad); - String docIdString = doc.get(getIdField()); - // NOTE doc frequencies should not be retrieved from here, term vector returned is as if on single document - // index - Terms terms = MultiTerms.getTerms(reader, getTermVectorField());//reader.getTermVector(docId, getTermVectorField()); - - if (terms == null) { - continue; - } - - for (Map.Entry> entry : queryContextMap.entrySet()) { - float[] featureValues = extractors.extractAll(doc, terms, entry.getValue()); - writeFeatureVector(out, entry.getKey(),qrels.getRelevanceGrade(entry.getKey(),docIdString), - 
docIdString, featureValues); - } - out.flush(); - LOG.debug(String.format("Completed computing feature vectors for doc %d", docId)); - } - } - - /** - * Prints feature vectors wrt to the qrels, one vector per qrel - * @param out - * @throws IOException - */ - public void printFeatures(PrintStream out) throws IOException { - Map> queryContextMap = buildRerankerContextMap(); - FeatureExtractors extractors = getExtractors(); - Bits liveDocs = MultiBits.getLiveDocs(reader); - Set fieldsToLoad = getFieldsToLoad(); - - // We need to open a searcher - IndexSearcher searcher = new IndexSearcher(reader); - - this.printHeader(out, extractors); - // Iterate through all the qrels and for each document id we have for them - LOG.debug("Processing queries"); - - for (String qid : this.qrels.getQids()) { - LOG.debug(String.format("Processing qid: %s", qid)); - // Get the map of documents - RerankerContext context = queryContextMap.get(qid); - - for (Map.Entry entry : this.qrels.getDocMap(qid).entrySet()) { - String docId = entry.getKey(); - int qrelScore = entry.getValue(); - // We issue a specific query - TopDocs topDocs = searcher.search(docIdQuery(docId), 1); - if (topDocs.totalHits.value == 0) { - LOG.warn(String.format("Document Id %s expected but not found in index, skipping...", docId)); - continue; - } - - ScoreDoc hit = topDocs.scoreDocs[0]; - Document doc = reader.document(hit.doc, fieldsToLoad); - - //TODO factor for test - Terms terms = reader.getTermVector(hit.doc, getTermVectorField()); - - if (terms == null) { - LOG.debug(String.format("No term vectors found for doc %s, qid %s", docId, qid)); - continue; - } - float[] featureValues = extractors.extractAll(doc, terms, context); - writeFeatureVector(out, qid ,qrelScore, - docId, featureValues); - } - LOG.debug(String.format("Finished processing for qid: %s", qid)); - out.flush(); - } - } -} diff --git a/src/main/java/io/anserini/ltr/DumpTweetsLtrData.java b/src/main/java/io/anserini/ltr/DumpTweetsLtrData.java deleted file mode 100644 index ef00bf8c41..0000000000 --- a/src/main/java/io/anserini/ltr/DumpTweetsLtrData.java +++ /dev/null @@ -1,152 +0,0 @@ -/* - * Anserini: A Lucene toolkit for replicable information retrieval research - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package io.anserini.ltr; - -import io.anserini.analysis.AnalyzerUtils; -import io.anserini.analysis.TweetAnalyzer; -import io.anserini.index.IndexArgs; -import io.anserini.ltr.feature.FeatureExtractors; -import io.anserini.rerank.RerankerCascade; -import io.anserini.rerank.RerankerContext; -import io.anserini.rerank.ScoredDocuments; -import io.anserini.search.SearchArgs; -import io.anserini.search.query.BagOfWordsQueryGenerator; -import io.anserini.search.topicreader.MicroblogTopicReader; -import io.anserini.search.topicreader.TopicReader; -import io.anserini.util.Qrels; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.apache.lucene.document.LongPoint; -import org.apache.lucene.index.DirectoryReader; -import org.apache.lucene.index.IndexReader; -import org.apache.lucene.search.BooleanClause; -import org.apache.lucene.search.BooleanQuery; -import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.Query; -import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.similarities.BM25Similarity; -import org.apache.lucene.store.Directory; -import org.apache.lucene.store.FSDirectory; -import org.kohsuke.args4j.CmdLineException; -import org.kohsuke.args4j.CmdLineParser; -import org.kohsuke.args4j.Option; -import org.kohsuke.args4j.OptionHandlerFilter; -import org.kohsuke.args4j.ParserProperties; - -import java.io.File; -import java.io.FileOutputStream; -import java.io.PrintStream; -import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.file.Paths; -import java.util.List; -import java.util.Map; -import java.util.SortedMap; -import java.util.TreeMap; - -@SuppressWarnings("deprecation") -public class DumpTweetsLtrData { - private static final Logger LOG = LogManager.getLogger(DumpTweetsLtrData.class); - - private DumpTweetsLtrData() {} - - private static class LtrArgs extends SearchArgs { - @Option(name = "-qrels", metaVar = "[file]", required = true, usage = "qrels file") - public String qrels; - - @Option(name = "-extractors", metaVar = "[file]", required = true, usage = "FeatureExtractors Definition File") - public String extractors = null; - } - - public static void main(String[] argv) throws Exception { - long curTime = System.nanoTime(); - LtrArgs args = new LtrArgs(); - CmdLineParser parser = new CmdLineParser(args, ParserProperties.defaults().withUsageWidth(90)); - - try { - parser.parseArgument(argv); - } catch (CmdLineException e) { - System.err.println(e.getMessage()); - parser.printUsage(System.err); - System.err.println("Example: DumpTweetsLtrData" + parser.printExample(OptionHandlerFilter.REQUIRED)); - return; - } - - LOG.info("Reading index at " + args.index); - Directory dir = FSDirectory.open(Paths.get(args.index)); - IndexReader reader = DirectoryReader.open(dir); - IndexSearcher searcher = new IndexSearcher(reader); - - searcher.setSimilarity(new BM25Similarity()); - Qrels qrels = new Qrels(args.qrels); - - FeatureExtractors extractors = null; - if (args.extractors != null) { - extractors = FeatureExtractors.loadExtractor(args.extractors); - } - - PrintStream out = new PrintStream(new FileOutputStream(new File(args.output))); - RerankerCascade cascade = new RerankerCascade(); - cascade.add(new TweetsLtrDataGenerator(out, qrels, extractors)); - - SortedMap> topics = new TreeMap<>(); - for (String singleTopicFile : args.topics) { - Path topicsFilePath = Paths.get(singleTopicFile); - if (!Files.exists(topicsFilePath) || !Files.isRegularFile(topicsFilePath) || 
!Files.isReadable(topicsFilePath)) { - throw new IllegalArgumentException("Topics file : " + topicsFilePath + " does not exist or is not a (readable) file."); - } - TopicReader tr = new MicroblogTopicReader(topicsFilePath); - topics.putAll(tr.read()); - } - - LOG.info("Initialized complete! (elapsed time = " + (System.nanoTime()-curTime)/1000000 + "ms)"); - long totalTime = 0; - int cnt = 0; - for (Map.Entry> entry : topics.entrySet()) { - long curQueryTime = System.nanoTime(); - Integer qID = entry.getKey(); - String queryString = entry.getValue().get("title"); - Long queryTime = Long.parseLong(entry.getValue().get("time")); - Query filter = LongPoint.newRangeQuery(IndexArgs.ID, 0L, queryTime); - Query query = new BagOfWordsQueryGenerator().buildQuery(IndexArgs.ID, - new TweetAnalyzer(), queryString); - BooleanQuery.Builder builder = new BooleanQuery.Builder(); - builder.add(filter, BooleanClause.Occur.FILTER); - builder.add(query, BooleanClause.Occur.MUST); - Query q = builder.build(); - - TopDocs rs = searcher.search(q, args.hits); - List queryTokens = AnalyzerUtils.analyze(new TweetAnalyzer(), queryString); - - RerankerContext context = new RerankerContext<>(searcher, Integer.parseInt(queryString), query, null, - queryString, queryTokens, filter, null); - - cascade.run(ScoredDocuments.fromTopDocs(rs, searcher), context); - long qtime = (System.nanoTime()-curQueryTime)/1000000; - LOG.info("Query " + qID + " (elapsed time = " + qtime + "ms)"); - totalTime += qtime; - cnt++; - } - - LOG.info("All queries completed!"); - LOG.info("Total elapsed time = " + totalTime + "ms"); - LOG.info("Average query latency = " + (totalTime/cnt) + "ms"); - - reader.close(); - out.close(); - } -} diff --git a/src/main/java/io/anserini/ltr/FeatureExtractorCli.java b/src/main/java/io/anserini/ltr/FeatureExtractorCli.java deleted file mode 100644 index cf0c823ff6..0000000000 --- a/src/main/java/io/anserini/ltr/FeatureExtractorCli.java +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Anserini: A Lucene toolkit for replicable information retrieval research - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package io.anserini.ltr; - -import io.anserini.ltr.feature.FeatureExtractors; -import io.anserini.search.topicreader.MicroblogTopicReader; -import io.anserini.search.topicreader.TopicReader; -import io.anserini.search.topicreader.TrecTopicReader; -import io.anserini.search.topicreader.WebxmlTopicReader; -import io.anserini.util.Qrels; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.apache.lucene.index.DirectoryReader; -import org.apache.lucene.index.IndexReader; -import org.apache.lucene.store.Directory; -import org.apache.lucene.store.FSDirectory; -import org.kohsuke.args4j.CmdLineException; -import org.kohsuke.args4j.CmdLineParser; -import org.kohsuke.args4j.Option; -import org.kohsuke.args4j.ParserProperties; - -import java.io.File; -import java.io.FileOutputStream; -import java.io.PrintStream; -import java.nio.file.Paths; -import java.util.Map; -import java.util.SortedMap; - -/** - * Main class for feature extractors feed in command line arguments to dump features - */ -public class FeatureExtractorCli { - private static final Logger LOG = LogManager.getLogger(FeatureExtractorCli.class); - - static class FeatureExtractionArgs { - @Option(name = "-index", metaVar = "[path]", required = true, usage = "Lucene index directory") - public String indexDir; - - @Option(name = "-qrel", metaVar = "[path]", required = true, usage = "Qrel File") - public String qrelFile; - - @Option(name = "-topic", metaVar = "[path]", required = true, usage = "Topic File") - public String topicsFile; - - @Option(name = "-out", metaVar = "[path]", required = true, usage = "Output File") - public String outputFile; - - @Option(name = "-collection", metaVar = "[path]", required = true, usage = "[clueweb|gov2|twitter]") - public String collection; - - @Option(name = "-extractors", metaVar = "[path]", required = false, usage = "FeatureExtractors File") - public String extractors = null; - - @SuppressWarnings("unchecked") - public TopicReader buildTopicReaderForCollection() throws Exception { - if ("clueweb".equals(collection)) { - return (TopicReader) new WebxmlTopicReader(Paths.get(topicsFile)); - } else if ("gov2".equals(collection)){ - return (TopicReader) new TrecTopicReader(Paths.get(topicsFile)); - } else if ("twitter".equals(collection)) { - return (TopicReader) new MicroblogTopicReader(Paths.get(topicsFile)); - } - - throw new RuntimeException("Unrecognized collection " + collection); - } - - @SuppressWarnings("unchecked") - public BaseFeatureExtractor buildBaseFeatureExtractor(IndexReader reader, Qrels qrels, Map> topics, FeatureExtractors extractors) { - if ("clueweb".equals(collection) || "gov2".equals(collection)) { - return new WebFeatureExtractor(reader, qrels, topics, extractors); - } else if ("twitter".equals(collection)) { - return (BaseFeatureExtractor) new TwitterFeatureExtractor(reader, qrels, (Map>) topics, extractors); - } - - throw new RuntimeException("Unrecognized collection " + collection); - } - } - - /** - * requires the user to supply the index directory and also the directory containing the qrels and topics - * @param args indexDir, qrelFile, topicFile, outputFile - */ - public static void main(String args[]) throws Exception { - - FeatureExtractionArgs parsedArgs = new FeatureExtractionArgs(); - CmdLineParser parser= new CmdLineParser(parsedArgs, ParserProperties.defaults().withUsageWidth(90)); - - try { - parser.parseArgument(args); - } catch (CmdLineException e) { - System.err.println(e.getMessage()); - 
parser.printUsage(System.err); - return; - } - - Directory indexDirectory = FSDirectory.open(Paths.get(parsedArgs.indexDir)); - IndexReader reader = DirectoryReader.open(indexDirectory); - Qrels qrels = new Qrels(parsedArgs.qrelFile); - - FeatureExtractors extractors = null; - if (parsedArgs.extractors != null) { - extractors = FeatureExtractors.loadExtractor(parsedArgs.extractors); - } - - // Query parser needed to construct the query object for feature extraction in the loop - PrintStream out = new PrintStream (new FileOutputStream(new File(parsedArgs.outputFile))); - - TopicReader tr = parsedArgs.buildTopicReaderForCollection(); - SortedMap> topics = tr.read(); - LOG.debug(String.format("%d topics found", topics.size())); - - BaseFeatureExtractor extractor = parsedArgs.buildBaseFeatureExtractor(reader, qrels, topics, extractors); - extractor.printFeatures(out); - } -} diff --git a/src/main/java/io/anserini/ltr/FeatureExtractorUtils.java b/src/main/java/io/anserini/ltr/FeatureExtractorUtils.java new file mode 100644 index 0000000000..28d0172848 --- /dev/null +++ b/src/main/java/io/anserini/ltr/FeatureExtractorUtils.java @@ -0,0 +1,263 @@ +/* + * Anserini: A Lucene toolkit for replicable information retrieval research + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package io.anserini.ltr;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import io.anserini.index.IndexArgs;
+import io.anserini.ltr.feature.FeatureExtractor;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.index.Terms;
+import org.apache.lucene.search.*;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.store.FSDirectory;
+
+import java.io.IOException;
+import java.nio.file.Paths;
+import java.util.*;
+import java.util.concurrent.*;
+
+/**
+ * Feature extractor class that is exposed in Pyserini
+ */
+public class FeatureExtractorUtils {
+  private static final Logger LOG = LogManager.getLogger(FeatureExtractorUtils.class);
+  private IndexReader reader;
+  private IndexSearcher searcher;
+  private List<FeatureExtractor> extractors = new ArrayList<>();
+  private Set<String> fieldsToLoad = new HashSet<>();
+  private ExecutorService pool;
+  private Map<String, Future<String>> tasks = new HashMap<>();
+
+  /**
+   * Set up a feature we wish to extract.
+   * @param extractor initialized FeatureExtractor instance
+   * @return this object, to allow chaining
+   */
+  public FeatureExtractorUtils add(FeatureExtractor extractor) {
+    extractors.add(extractor);
+    if((extractor.getField()!=null)&&(!fieldsToLoad.contains(extractor.getField())))
+      fieldsToLoad.add(extractor.getField());
+    return this;
+  }
+
+  /**
+   * Mainly used for testing.
+   * @param queryTokens tokenized query text
+   * @param docIds external document ids that you wish to collect; users need to make sure they are present
+   * @return a list of output objects, one per found document, each holding its feature vector
+   * @throws ExecutionException
+   * @throws InterruptedException
+   * @throws JsonProcessingException
+   */
+  public ArrayList<output> extract(List<String> queryTokens, List<String> docIds) throws ExecutionException, InterruptedException, JsonProcessingException {
+    ObjectMapper mapper = new ObjectMapper();
+    input root = new input();
+    root.qid = "-1";
+    root.queryTokens = queryTokens;
+    root.docIds = docIds;
+    this.lazyExtract(mapper.writeValueAsString(root));
+    String res = this.getResult(root.qid);
+    TypeReference<ArrayList<output>> typeref = new TypeReference<ArrayList<output>>() {};
+    return mapper.readValue(res, typeref);
+  }
+
+  /**
+   * Submit tasks to workers.
+   * @param qid unique query id; users need to make sure it is not duplicated
+   * @param queryTokens tokenized query text
+   * @param docIds external document ids that you wish to collect; users need to make sure they are present
+   */
+  public void addTask(String qid, List<String> queryTokens, List<String> docIds) {
+    if(tasks.containsKey(qid))
+      throw new IllegalArgumentException("duplicate qid");
+    tasks.put(qid, pool.submit(() -> {
+      List<FeatureExtractor> localExtractors = new ArrayList<>();
+      for(FeatureExtractor e: extractors){
+        localExtractors.add(e.clone());
+      }
+      ObjectMapper mapper = new ObjectMapper();
+      List<output> result = new ArrayList<>();
+      for(String docId: docIds) {
+        Query q = new TermQuery(new Term(IndexArgs.ID, docId));
+        TopDocs topDocs = searcher.search(q, 1);
+        if (topDocs.totalHits.value == 0) {
+          LOG.warn(String.format("Document Id %s expected but not found in index, skipping...", docId));
+          continue;
+        }
+
+        ScoreDoc hit = topDocs.scoreDocs[0];
+        Document doc = reader.document(hit.doc, fieldsToLoad);
+
+        Terms terms = reader.getTermVector(hit.doc, IndexArgs.CONTENTS);
+        List<Float> features = new ArrayList<>();
+        for (int i = 0; i < localExtractors.size(); i++) {
+          features.add(localExtractors.get(i).extract(doc, terms, String.join(",", queryTokens), queryTokens, reader));
+        }
+        result.add(new output(docId, features));
+      }
+      return mapper.writeValueAsString(result);
+    }));
+  }
+
+  /**
+   * Submit tasks to workers; exposed in Pyserini.
+   * @param jsonString a JSON-serialized input object carrying qid, queryTokens and docIds
+   * @throws JsonProcessingException
+   */
+  public void lazyExtract(String jsonString) throws JsonProcessingException {
+    ObjectMapper mapper = new ObjectMapper();
+    input root = mapper.readValue(jsonString, input.class);
+    this.addTask(root.qid, root.queryTokens, root.docIds);
+  }
+
+  /**
+   * Blocks until the result is ready.
+   * @param qid the query id for which you wish to fetch the result
+   * @return the extracted feature vectors, serialized as JSON
+   * @throws ExecutionException
+   * @throws InterruptedException
+   */
+  public String getResult(String qid) throws ExecutionException, InterruptedException {
+    return tasks.remove(qid).get();
+  }
+
+  /**
+   * @param indexDir index path to work on
+   * @throws IOException
+   */
+  public FeatureExtractorUtils(String indexDir) throws IOException {
+    Directory indexDirectory = FSDirectory.open(Paths.get(indexDir));
+    reader = DirectoryReader.open(indexDirectory);
+    searcher = new IndexSearcher(reader);
+    fieldsToLoad.add(IndexArgs.ID);
+    pool = Executors.newFixedThreadPool(1);
+  }
+
+  /**
+   * @param indexDir index path to work on
+   * @param workNum number of worker threads
+   * @throws IOException
+   */
+  public FeatureExtractorUtils(String indexDir, int workNum) throws IOException {
+    Directory indexDirectory = FSDirectory.open(Paths.get(indexDir));
+    reader = DirectoryReader.open(indexDirectory);
+    searcher = new IndexSearcher(reader);
+    fieldsToLoad.add(IndexArgs.ID);
+    pool = Executors.newFixedThreadPool(workNum);
+  }
+
+  /**
+   * For testing purposes.
+   * @param reader initialized IndexReader
+   * @throws IOException
+   */
+  public FeatureExtractorUtils(IndexReader reader) throws IOException {
+    this.reader = reader;
+    searcher = new IndexSearcher(reader);
+    fieldsToLoad.add(IndexArgs.ID);
+    pool = Executors.newFixedThreadPool(1);
+  }
+
+  /**
+   * @param reader initialized IndexReader
+   * @param workNum number of worker threads
+   * @throws IOException
+   */
+  public FeatureExtractorUtils(IndexReader reader, int workNum) throws IOException {
+    this.reader = reader;
+    searcher = new IndexSearcher(reader);
+    fieldsToLoad.add(IndexArgs.ID);
+    pool = Executors.newFixedThreadPool(workNum);
+  }
+
+  /**
+   * Close the pool and reader; avoids thread-leaking warnings during tests.
+   * @throws IOException
+   */
+  public void close() throws IOException {
+    pool.shutdown();
+    reader.close();
+  }
+
+}
+
+class input{
+  String qid;
+  List<String> queryTokens;
+  List<String> docIds;
+
+  input(){}
+
+  public String getQid() {
+    return qid;
+  }
+
+  public List<String> getDocIds() {
+    return docIds;
+  }
+
+  public List<String> getQueryTokens() {
+    return queryTokens;
+  }
+
+  public void setQid(String qid) {
+    this.qid = qid;
+  }
+
+  public void setDocIds(List<String> docIds) {
+    this.docIds = docIds;
+  }
+
+  public void setQueryTokens(List<String> queryTokens) {
+    this.queryTokens = queryTokens;
+  }
+}
+
+class output{
+  String pid;
+  List<Float> features;
+
+  output(){}
+
+  output(String pid, List<Float> features){
+    this.pid = pid;
+    this.features = features;
+  }
+
+  public String getPid() {
+    return pid;
+  }
+
+  public List<Float> getFeatures() {
+    return features;
+  }
+
+  public void setPid(String pid) {
+    this.pid = pid;
+  }
+
+  public void setFeatures(List<Float> features) {
+    this.features = features;
+  }
+}
\ No newline at end of file
diff --git a/src/main/java/io/anserini/ltr/TweetsLtrDataGenerator.java
b/src/main/java/io/anserini/ltr/TweetsLtrDataGenerator.java deleted file mode 100644 index a2286b41fb..0000000000 --- a/src/main/java/io/anserini/ltr/TweetsLtrDataGenerator.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Anserini: A Lucene toolkit for replicable information retrieval research - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.anserini.ltr; - -import io.anserini.index.IndexArgs; -import io.anserini.ltr.feature.FeatureExtractors; -import io.anserini.rerank.Reranker; -import io.anserini.rerank.RerankerContext; -import io.anserini.rerank.ScoredDocuments; -import io.anserini.util.Qrels; -import org.apache.lucene.index.IndexReader; -import org.apache.lucene.index.Terms; - -import java.io.FileNotFoundException; -import java.io.IOException; -import java.io.PrintStream; - -public class TweetsLtrDataGenerator implements Reranker { - private final PrintStream out; - private final Qrels qrels; - private final FeatureExtractors extractorChain; - - - public TweetsLtrDataGenerator(PrintStream out, Qrels qrels, FeatureExtractors extractors) throws FileNotFoundException { - this.out = out; - this.qrels = qrels; - this.extractorChain = extractors == null ? WebFeatureExtractor.getDefaultExtractors() : extractors; - - } - - @Override - public ScoredDocuments rerank(ScoredDocuments docs, RerankerContext context) { - IndexReader reader = context.getIndexSearcher().getIndexReader(); - - for (int i = 0; i < docs.documents.length; i++) { - Terms terms = null; - try { - terms = reader.getTermVector(docs.ids[i], IndexArgs.CONTENTS); - } catch (IOException e) { - continue; - } - - String qid = ((String)context.getQueryId()).replaceFirst("^MB0*", ""); - String docid = docs.documents[i].getField(IndexArgs.ID).stringValue(); - - out.print(qrels.getRelevanceGrade(qid, docid)); - out.print(" qid:" + qid); - out.print(" 1:" + docs.scores[i]); - - float[] intFeatures = this.extractorChain.extractAll(docs.documents[i], terms, context); - - for (int j=0; j { - private static final Logger LOG = LogManager.getLogger(TwitterFeatureExtractor.class); - private static final FeatureExtractors DEFAULT_EXTRACTOR_CHAIN = FeatureExtractors. 
- createFeatureExtractorChain(new UnigramFeatureExtractor(), - new UnorderedSequentialPairsFeatureExtractor(6), - new UnorderedSequentialPairsFeatureExtractor(8), - new UnorderedSequentialPairsFeatureExtractor(10), - new OrderedSequentialPairsFeatureExtractor(6), - new OrderedSequentialPairsFeatureExtractor(8), - new OrderedSequentialPairsFeatureExtractor(10), - new MatchingTermCount(), - new QueryLength(), - new SumMatchingTf(), - new TermFrequencyFeatureExtractor(), - new BM25FeatureExtractor(), - new TFIDFFeatureExtractor(), - new UniqueTermCount(), - new DocSizeFeatureExtractor(), - new AvgICTFFeatureExtractor(), - new AvgIDFFeatureExtractor(), - new SimplifiedClarityFeatureExtractor(), - new PMIFeatureExtractor(), - new SCQFeatureExtractor(), - new LinkCount(), - new TwitterFollowerCount(), - new TwitterFriendCount(), - new IsTweetReply(), - new HashtagCount() - ); - /** - * Constructor that requires a reader to the index, the qrels and the topics - * - * @param reader - * @param qrels - * @param topics - */ - public TwitterFeatureExtractor(IndexReader reader, Qrels qrels, Map> topics) { - super(reader, qrels, topics, getDefaultExtractors()); - LOG.debug("Twitter Feature Extractor initialized."); - } - - /** - * Constructor that requires a reader to the index, the qrels and the topics - * - * @param reader - * @param qrels - * @param topics - */ - public TwitterFeatureExtractor(IndexReader reader, Qrels qrels, - Map> topics, FeatureExtractors featureExtractors) { - super(reader, qrels, topics, featureExtractors == null ? getDefaultExtractors() : featureExtractors); - LOG.debug("Twitter Feature Extractor initialized with custom feature extractors."); - } - - - - @Override - protected String getIdField() { - return IndexArgs.ID; - } - - @Override - protected String getTermVectorField() { - return IndexArgs.CONTENTS; - } - - public static FeatureExtractors getDefaultExtractors() { - return DEFAULT_EXTRACTOR_CHAIN; - } - - @Override - protected Analyzer getAnalyzer() { - return new TweetAnalyzer(); - } - - @Override - protected Set getFieldsToLoad() { - return new HashSet<>(Arrays.asList( - getIdField(), - getTermVectorField(), - TweetField.FOLLOWERS_COUNT.name, - TweetField.FRIENDS_COUNT.name, - TweetField.IN_REPLY_TO_STATUS_ID.name) - ); - } - - @Override - protected Query parseQuery(String queryText) { - LOG.debug(String.format("Parsing query: %s", queryText) ); - return new BagOfWordsQueryGenerator().buildQuery(IndexArgs.CONTENTS, new TweetAnalyzer(), queryText); - } - - @Override - protected Query docIdQuery(String docId) { - long docIdLong = Long.parseLong(docId); - return LongPoint.newRangeQuery(getIdField(), docIdLong, docIdLong); - } -} diff --git a/src/main/java/io/anserini/ltr/TwitterFeatureReranker.java b/src/main/java/io/anserini/ltr/TwitterFeatureReranker.java deleted file mode 100644 index 40dc3856ea..0000000000 --- a/src/main/java/io/anserini/ltr/TwitterFeatureReranker.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Anserini: A Lucene toolkit for replicable information retrieval research - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.anserini.ltr; - -import io.anserini.index.IndexArgs; -import io.anserini.ltr.feature.FeatureExtractors; -import io.anserini.rerank.Reranker; -import io.anserini.rerank.RerankerContext; -import io.anserini.rerank.ScoredDocuments; -import io.anserini.util.Qrels; -import org.apache.lucene.index.IndexReader; -import org.apache.lucene.index.Terms; - -import java.io.IOException; -import java.io.PrintStream; - -/** - * Used to rerank according to features - * - */ -public class TwitterFeatureReranker implements Reranker { - private final PrintStream out; - private final Qrels qrels; - private final FeatureExtractors extractors; - - public TwitterFeatureReranker(PrintStream out, Qrels qrels, FeatureExtractors extractors) { - this.out = out; - this.qrels = qrels; - this.extractors = extractors == null ? TwitterFeatureExtractor.getDefaultExtractors() : extractors; - } - - @Override - public ScoredDocuments rerank(ScoredDocuments docs, RerankerContext context) { - IndexReader reader = context.getIndexSearcher().getIndexReader(); - - for (int i = 0; i < docs.documents.length; i++) { - Terms terms = null; - try { - terms = reader.getTermVector(docs.ids[i], IndexArgs.CONTENTS); - } catch (IOException e) { - continue; - } - - int qid = context.getQueryId(); - String docid = docs.documents[i].getField(IndexArgs.ID).stringValue(); - - out.print(qrels.getRelevanceGrade(qid, docid)); - out.print(" qid:" + qid); - - float[] intFeatures = this.extractors.extractAll(docs.documents[i], terms, context); - - // TODO use model to rerank - } - - return docs; - } - - @Override - public String tag() { return ""; } -} diff --git a/src/main/java/io/anserini/ltr/WebCollectionLtrDataGenerator.java b/src/main/java/io/anserini/ltr/WebCollectionLtrDataGenerator.java deleted file mode 100644 index 27623cf15a..0000000000 --- a/src/main/java/io/anserini/ltr/WebCollectionLtrDataGenerator.java +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Anserini: A Lucene toolkit for replicable information retrieval research - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package io.anserini.ltr; - -import io.anserini.index.IndexArgs; -import io.anserini.ltr.feature.FeatureExtractors; -import io.anserini.rerank.Reranker; -import io.anserini.rerank.RerankerContext; -import io.anserini.rerank.ScoredDocuments; -import io.anserini.util.Qrels; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.apache.lucene.document.Document; -import org.apache.lucene.index.IndexReader; -import org.apache.lucene.index.Terms; - -import java.io.IOException; -import java.io.PrintStream; - -/** - * A reranker that will be used to dump feature vectors - * for documents retrieved from a search - */ -public class WebCollectionLtrDataGenerator implements Reranker { - private static final Logger LOG = LogManager.getLogger(WebCollectionLtrDataGenerator.class); - - private PrintStream out; - private Qrels qrels; - private final FeatureExtractors extractorChain; - - /** - * Constructor - * @param out The output stream to actually print it - */ - public WebCollectionLtrDataGenerator(PrintStream out, Qrels qrels, FeatureExtractors extractors) { - this.out = out; - this.qrels = qrels; - this.extractorChain = extractors == null ? WebFeatureExtractor.getDefaultExtractors() : extractors; - } - - @Override - public ScoredDocuments rerank(ScoredDocuments docs, RerankerContext context) { - Document[] documents = docs.documents; - IndexReader reader = context.getIndexSearcher().getIndexReader(); - int qid = context.getQueryId(); - LOG.info("Beginning rerank"); - for (int i =0; i < docs.documents.length; i++ ) { - try { - Terms terms = reader.getTermVector(docs.ids[i], IndexArgs.CONTENTS); - float[] features = this.extractorChain.extractAll(documents[i], terms, context); - String docId = documents[i].get(IndexArgs.ID); - // QREL 0 in this case, will be assigned if needed later - //qid - BaseFeatureExtractor.writeFeatureVector(out, qid, this.qrels.getRelevanceGrade(qid, docId), docId, features); - LOG.info("Finished writing vectors"); - } catch (IOException e) { - LOG.error(String.format("IOExecption trying to retrieve feature vector for %d doc", docs.ids[i])); - continue; - } - } - // Does nothing to the actual docs, we just need to extract the feature vector - return docs; - } - - @Override - public String tag() { return ""; } -} diff --git a/src/main/java/io/anserini/ltr/WebFeatureExtractor.java b/src/main/java/io/anserini/ltr/WebFeatureExtractor.java deleted file mode 100644 index d11fbfa99e..0000000000 --- a/src/main/java/io/anserini/ltr/WebFeatureExtractor.java +++ /dev/null @@ -1,151 +0,0 @@ -/* - * Anserini: A Lucene toolkit for replicable information retrieval research - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package io.anserini.ltr; - -import io.anserini.index.IndexArgs; -import io.anserini.ltr.feature.FeatureExtractors; -import io.anserini.ltr.feature.OrderedSequentialPairsFeatureExtractor; -import io.anserini.ltr.feature.UnigramFeatureExtractor; -import io.anserini.ltr.feature.UnorderedSequentialPairsFeatureExtractor; -import io.anserini.ltr.feature.base.AvgICTFFeatureExtractor; -import io.anserini.ltr.feature.base.AvgIDFFeatureExtractor; -import io.anserini.ltr.feature.base.BM25FeatureExtractor; -import io.anserini.ltr.feature.base.DocSizeFeatureExtractor; -import io.anserini.ltr.feature.base.MatchingTermCount; -import io.anserini.ltr.feature.base.PMIFeatureExtractor; -import io.anserini.ltr.feature.base.QueryLength; -import io.anserini.ltr.feature.base.SCQFeatureExtractor; -import io.anserini.ltr.feature.base.SimplifiedClarityFeatureExtractor; -import io.anserini.ltr.feature.base.SumMatchingTf; -import io.anserini.ltr.feature.base.TFIDFFeatureExtractor; -import io.anserini.ltr.feature.base.TermFrequencyFeatureExtractor; -import io.anserini.ltr.feature.base.UniqueTermCount; -import io.anserini.util.Qrels; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.apache.lucene.analysis.Analyzer; -import org.apache.lucene.analysis.en.EnglishAnalyzer; -import org.apache.lucene.index.IndexReader; -import org.apache.lucene.index.Term; -import org.apache.lucene.queryparser.classic.ParseException; -import org.apache.lucene.queryparser.classic.QueryParser; -import org.apache.lucene.search.Query; -import org.apache.lucene.search.TermQuery; - -import java.util.Arrays; -import java.util.HashSet; -import java.util.Map; -import java.util.Set; - -/** - * Feature extractor for the gov two collection - */ -public class WebFeatureExtractor extends BaseFeatureExtractor { - private static final Logger LOG = LogManager.getLogger(WebFeatureExtractor.class); - - //************************************************** - //************************************************** - private static final FeatureExtractors DEFAULT_EXTRACTOR_CHAIN = FeatureExtractors. 
- createFeatureExtractorChain(new UnigramFeatureExtractor(), - new UnorderedSequentialPairsFeatureExtractor(6), - new UnorderedSequentialPairsFeatureExtractor(8), - new UnorderedSequentialPairsFeatureExtractor(10), - new OrderedSequentialPairsFeatureExtractor(6), - new OrderedSequentialPairsFeatureExtractor(8), - new OrderedSequentialPairsFeatureExtractor(10), - new MatchingTermCount(), - new QueryLength(), - new SumMatchingTf(), - new TermFrequencyFeatureExtractor(), - new BM25FeatureExtractor(), - new TFIDFFeatureExtractor(), - new UniqueTermCount(), - new DocSizeFeatureExtractor(), - new AvgICTFFeatureExtractor(), - new AvgIDFFeatureExtractor(), - new SimplifiedClarityFeatureExtractor(), - new PMIFeatureExtractor(), - new SCQFeatureExtractor() - ); - - //************************************************** - //************************************************** - - private QueryParser parser; - - public WebFeatureExtractor(IndexReader reader, Qrels qrels, Map> topics) { - this(reader, qrels, topics, getDefaultExtractors()); - LOG.debug("Web Feature extractor initialized."); - } - - /** - * FeatureExtractor constructor requires an index reader, qrels, and topics - * also takes in optional customExtractors, if null, the default will be used - * @param reader - * @param qrels - * @param topics - * @param customExtractors - */ - @SuppressWarnings("unchecked") - public WebFeatureExtractor(IndexReader reader, Qrels qrels, Map> topics, - FeatureExtractors customExtractors) { - super(reader, qrels, topics, customExtractors == null ? getDefaultExtractors() : customExtractors); - this.parser = new QueryParser(getTermVectorField(), getAnalyzer()); - LOG.debug("Web Feature extractor initialized."); - } - - @Override - protected String getIdField() { - return IndexArgs.ID; - } - - @Override - protected String getTermVectorField() { - return IndexArgs.CONTENTS; - } - - public static FeatureExtractors getDefaultExtractors() { - return DEFAULT_EXTRACTOR_CHAIN; - } - - @Override - protected Analyzer getAnalyzer() { - return new EnglishAnalyzer(); - } - - @Override - protected Set getFieldsToLoad() { - return new HashSet<>(Arrays.asList(getIdField(), getTermVectorField())); - } - - @Override - protected Query parseQuery(String queryText) { - try { - return this.parser.parse(queryText); - } catch (ParseException e) { - LOG.error(String.format("Unable to parse query for query text %s, error %s", - queryText, e)); - return null; - } - } - - @Override - protected Query docIdQuery(String docId) { - return new TermQuery(new Term(getIdField(), docId)); - } - -} diff --git a/src/main/java/io/anserini/ltr/feature/FeatureExtractor.java b/src/main/java/io/anserini/ltr/feature/FeatureExtractor.java index e08f2f45ad..8b00ce53e9 100644 --- a/src/main/java/io/anserini/ltr/feature/FeatureExtractor.java +++ b/src/main/java/io/anserini/ltr/feature/FeatureExtractor.java @@ -16,63 +16,46 @@ package io.anserini.ltr.feature; -import io.anserini.ltr.feature.base.AvgICTFFeatureExtractor; -import io.anserini.ltr.feature.base.AvgIDFFeatureExtractor; -import io.anserini.ltr.feature.base.BM25FeatureExtractor; -import io.anserini.ltr.feature.base.DocSizeFeatureExtractor; -import io.anserini.ltr.feature.base.MatchingTermCount; -import io.anserini.ltr.feature.base.PMIFeatureExtractor; -import io.anserini.ltr.feature.base.QueryLength; -import io.anserini.ltr.feature.base.SCQFeatureExtractor; -import io.anserini.ltr.feature.base.SimplifiedClarityFeatureExtractor; -import io.anserini.ltr.feature.base.SumMatchingTf; -import 
io.anserini.ltr.feature.base.TFIDFFeatureExtractor;
-import io.anserini.ltr.feature.base.TermFrequencyFeatureExtractor;
-import io.anserini.ltr.feature.base.UniqueTermCount;
-import io.anserini.ltr.feature.twitter.HashtagCount;
-import io.anserini.ltr.feature.twitter.IsTweetReply;
-import io.anserini.ltr.feature.twitter.LinkCount;
-import io.anserini.ltr.feature.twitter.TwitterFollowerCount;
-import io.anserini.ltr.feature.twitter.TwitterFriendCount;
 import io.anserini.rerank.RerankerContext;
 import org.apache.lucene.document.Document;
+import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.Terms;
 
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 
 /**
  * A feature extractor.
  */
-public interface FeatureExtractor<T> {
-  //********************************************************
-  // TODO normalize names
-  final Map<String, Class<?>> EXTRACTOR_MAP = new HashMap<String, Class<?>>() {{
-    put("AvgICTF", AvgICTFFeatureExtractor.class);
-    put("SimplifiedClarityScore", SimplifiedClarityFeatureExtractor.class);
-    put("PMIFeature", PMIFeatureExtractor.class);
-    put("AvgSCQ", SCQFeatureExtractor.class);
-    put("SumMatchingTf", SumMatchingTf.class);
-    put("UnigramsFeatureExtractor", UnigramFeatureExtractor.class);
-    put("AvgIDF", AvgIDFFeatureExtractor.class);
-    put("BM25Feature", BM25FeatureExtractor.class);
-    put("DocSize", DocSizeFeatureExtractor.class);
-    put("MatchingTermCount", MatchingTermCount.class);
-    put("QueryLength", QueryLength.class);
-    put("SumTermFrequency", TermFrequencyFeatureExtractor.class);
-    put("TFIDF", TFIDFFeatureExtractor.class);
-    put("UniqueQueryTerms", UniqueTermCount.class);
-    put("UnorderedSequentialPairs", UnorderedSequentialPairsFeatureExtractor.class);
-    put("OrderedSequentialPairs", OrderedSequentialPairsFeatureExtractor.class);
-    put("TwitterHashtagCount", HashtagCount.class);
-    put("IsTweetReply", IsTweetReply.class);
-    put("TwitterLinkCount", LinkCount.class);
-    put("TwitterFollowerCount", TwitterFollowerCount.class);
-    put("TwitterFriendCount", TwitterFriendCount.class);
-  }};
+public interface FeatureExtractor {
 
-  float extract(Document doc, Terms terms, RerankerContext<T> context);
+  /**
+   * @param doc the document we work on
+   * @param terms an iterator over the term vector of the content field
+   * @param queryText original query text
+   * @param queryTokens tokenized query text
+   * @param reader in case the extractor needs some global information
+   * @return feature value
+   */
+  float extract(Document doc, Terms terms, String queryText, List<String> queryTokens, IndexReader reader);
 
+  /**
+   * We need to make sure each thread has a thread-local copy of the extractors;
+   * otherwise we will have concurrency problems.
+   * @return a copy with the same set up
+   */
+  FeatureExtractor clone();
+
+  /**
+   * Used to tell the corresponding feature name for each column in the feature vector.
+   * @return feature name
+   */
   String getName();
 
+  /**
+   * @return the field this feature extractor needs to load
+   */
+  String getField();
+
 }
diff --git a/src/main/java/io/anserini/ltr/feature/FeatureExtractors.java b/src/main/java/io/anserini/ltr/feature/FeatureExtractors.java
deleted file mode 100644
index 73f5f0dbfc..0000000000
--- a/src/main/java/io/anserini/ltr/feature/FeatureExtractors.java
+++ /dev/null
@@ -1,124 +0,0 @@
-/*
- * Anserini: A Lucene toolkit for replicable information retrieval research
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.anserini.ltr.feature; - -import com.fasterxml.jackson.core.JsonFactory; -import com.fasterxml.jackson.core.JsonParser; -import com.fasterxml.jackson.core.JsonParser.Feature; -import com.fasterxml.jackson.databind.DeserializationFeature; -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.module.SimpleModule; -import io.anserini.rerank.RerankerContext; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.apache.lucene.document.Document; -import org.apache.lucene.index.Terms; - -import java.io.FileReader; -import java.util.ArrayList; -import java.util.List; - -/** - * A collection of {@link FeatureExtractor}s. - */ -public class FeatureExtractors { - private static final Logger LOG = LogManager.getLogger(FeatureExtractors.class); - - //******************************************************** - //******************************************************** - private static final String JSON_KEY = "extractors"; - private static final String NAME_KEY = "name"; - private static final String CONFIG_KEY = "params"; - - public static FeatureExtractors loadExtractor(String filePath) throws Exception { - JsonParser extractorJson = new JsonFactory().createParser(new FileReader(filePath)); - return FeatureExtractors.fromJson(extractorJson); - } - - public static FeatureExtractors fromJson(JsonParser jsonParser) throws Exception { - FeatureExtractors extractors = new FeatureExtractors(); - - ObjectMapper objectMapper = new ObjectMapper(); - SimpleModule module = new SimpleModule(); - module.addDeserializer(UnorderedSequentialPairsFeatureExtractor.class, - new UnorderedSequentialPairsFeatureExtractor.Deserializer()); - module.addDeserializer(OrderedSequentialPairsFeatureExtractor.class, - new OrderedSequentialPairsFeatureExtractor.Deserializer()); - module.addDeserializer(OrderedQueryPairsFeatureExtractor.class, - new OrderedQueryPairsFeatureExtractor.Deserializer()); - module.addDeserializer(UnorderedQueryPairsFeatureExtractor.class, - new UnorderedQueryPairsFeatureExtractor.Deserializer()); - objectMapper.registerModule(module); - objectMapper.disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES); - - jsonParser.configure(Feature.ALLOW_UNQUOTED_FIELD_NAMES, true); - JsonNode node = objectMapper.readTree(jsonParser); - for (JsonNode extractor : node.get(JSON_KEY)) { - String extractorName = extractor.get(NAME_KEY).asText(); - if (!FeatureExtractor.EXTRACTOR_MAP.containsKey(extractorName)) { - LOG.warn(String.format("Unknown extractor %s encountered, skipping", extractorName)); - continue; - } - - if (extractor.has(CONFIG_KEY)) { - JsonNode config = extractor.get(CONFIG_KEY); - JsonParser configJsonParser = objectMapper.treeAsTokens(config); - FeatureExtractor parsedExtractor = (FeatureExtractor) objectMapper - .readValue(configJsonParser, FeatureExtractor.EXTRACTOR_MAP.get(extractorName)); - extractors.add(parsedExtractor); - } else { - FeatureExtractor parsedExtractor = (FeatureExtractor) 
FeatureExtractor.EXTRACTOR_MAP.get(extractorName) - .getConstructor().newInstance(); - extractors.add(parsedExtractor); - } - } - - return extractors; - } - - public static FeatureExtractors createFeatureExtractorChain(FeatureExtractor... extractors) { - FeatureExtractors chain = new FeatureExtractors(); - for (FeatureExtractor extractor : extractors) { - chain.add(extractor); - } - - return chain; - } - - //******************************************************** - - public List extractors = new ArrayList<>(); - - public FeatureExtractors() {} - - public FeatureExtractors add(FeatureExtractor extractor) { - extractors.add(extractor); - return this; - } - - @SuppressWarnings("unchecked") - public float[] extractAll(Document doc, Terms terms, RerankerContext context) { - float[] features = new float[extractors.size()]; - - for (int i=0; i implements FeatureExtractor { +public class OrderedQueryPairsFeatureExtractor implements FeatureExtractor { private static final Logger LOG = LogManager.getLogger(OrderedQueryPairsFeatureExtractor.class); protected static ArrayList gapSizes = new ArrayList<>(); @@ -45,29 +47,8 @@ public class OrderedQueryPairsFeatureExtractor implements FeatureExtractor protected static Map singleCountMap = new HashMap<>(); protected static Map> queryPairMap = new HashMap<>(); - protected static String lastProcessedId = ""; protected static Document lastProcessedDoc = null; - public static class Deserializer extends StdDeserializer - { - public Deserializer() { - this(null); - } - - public Deserializer(Class vc) { - super(vc); - } - - @Override - public OrderedQueryPairsFeatureExtractor - deserialize(JsonParser jsonParser, DeserializationContext context) throws IOException - { - JsonNode node = jsonParser.getCodec().readTree(jsonParser); - int gapSize = node.get("gapSize").asInt(); - return new OrderedQueryPairsFeatureExtractor(gapSize); - } - } - public OrderedQueryPairsFeatureExtractor(int gapSize) { this.gapSize = gapSize; // Add a window to the counters @@ -75,13 +56,12 @@ public OrderedQueryPairsFeatureExtractor(int gapSize) { gapSizes.add(gapSize); } - private static void resetCounters(String newestQuery, Document newestDoc) { + private static void resetCounters(Document newestDoc) { singleCountMap.clear(); queryPairMap.clear(); for (int i : counters.keySet()) { counters.get(i).phraseCountMap.clear(); } - lastProcessedId = newestQuery; lastProcessedDoc = newestDoc; } @@ -99,13 +79,12 @@ protected void populateQueryPairMap(List queryTokens) { } } - protected float computeOrderedFrequencyScore(Document doc, Terms terms, RerankerContext context) throws IOException { + protected float computeOrderedFrequencyScore(Document doc, Terms terms, List queryTokens) throws IOException { // Only compute the score once for all window sizes on the same document - if (!context.getQueryId().equals(lastProcessedId) || lastProcessedDoc != doc) { - resetCounters((String)context.getQueryId(), doc); + if (lastProcessedDoc != doc) { + resetCounters(doc); - List queryTokens = context.getQueryTokens(); populateQueryPairMap(queryTokens); // Now make the call to the static method @@ -124,9 +103,9 @@ protected float computeOrderedFrequencyScore(Document doc, Terms terms, Reranker } @Override - public float extract(Document doc, Terms terms, RerankerContext context) { + public float extract(Document doc, Terms terms, String queryText, List queryTokens, IndexReader reader) { try { - return computeOrderedFrequencyScore(doc, terms, context); + return computeOrderedFrequencyScore(doc, terms, 
queryTokens); } catch (IOException e) { LOG.error("IOException, returning 0.0f"); return 0.0f; @@ -137,4 +116,14 @@ public float extract(Document doc, Terms terms, RerankerContext context) { public String getName() { return "OrderedAllPairs" + this.gapSize; } + + @Override + public String getField() { + return null; + } + + @Override + public FeatureExtractor clone() { + return new OrderedQueryPairsFeatureExtractor(this.gapSize); + } } diff --git a/src/main/java/io/anserini/ltr/feature/OrderedSequentialPairsFeatureExtractor.java b/src/main/java/io/anserini/ltr/feature/OrderedSequentialPairsFeatureExtractor.java index c318d74ea4..2612736567 100644 --- a/src/main/java/io/anserini/ltr/feature/OrderedSequentialPairsFeatureExtractor.java +++ b/src/main/java/io/anserini/ltr/feature/OrderedSequentialPairsFeatureExtractor.java @@ -20,10 +20,12 @@ import com.fasterxml.jackson.databind.DeserializationContext; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.deser.std.StdDeserializer; +import io.anserini.index.IndexArgs; import io.anserini.rerank.RerankerContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.lucene.document.Document; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.Terms; import java.io.IOException; @@ -39,7 +41,7 @@ * This feature extractor will return the number of phrases * in a specified gap size */ -public class OrderedSequentialPairsFeatureExtractor implements FeatureExtractor { +public class OrderedSequentialPairsFeatureExtractor implements FeatureExtractor { private static final Logger LOG = LogManager.getLogger(OrderedSequentialPairsFeatureExtractor.class); protected static ArrayList gapSizes = new ArrayList<>(); @@ -47,36 +49,14 @@ public class OrderedSequentialPairsFeatureExtractor implements FeatureExtract protected static Map singleCountMap = new HashMap<>(); protected static Map> queryPairMap = new HashMap<>(); - protected static String lastProcessedId = ""; protected static Document lastProcessedDoc = null; - public static class Deserializer extends StdDeserializer - { - public Deserializer() { - this(null); - } - - public Deserializer(Class vc) { - super(vc); - } - - @Override - public OrderedSequentialPairsFeatureExtractor - deserialize(JsonParser jsonParser, DeserializationContext context) throws IOException - { - JsonNode node = jsonParser.getCodec().readTree(jsonParser); - int gapSize = node.get("gapSize").asInt(); - return new OrderedSequentialPairsFeatureExtractor(gapSize); - } - } - - private static void resetCounters(String newestQuery, Document newestDoc) { + private static void resetCounters(Document newestDoc) { singleCountMap.clear(); queryPairMap.clear(); for (int i : counters.keySet()) { counters.get(i).phraseCountMap.clear(); } - lastProcessedId = newestQuery; lastProcessedDoc = newestDoc; } @@ -91,9 +71,9 @@ public OrderedSequentialPairsFeatureExtractor(int gapSize) { } @Override - public float extract(Document doc, Terms terms, RerankerContext context) { + public float extract(Document doc, Terms terms, String queryText, List queryTokens, IndexReader reader) { try { - return computeOrderedFrequencyScore(doc, terms, context); + return computeOrderedFrequencyScore(doc, terms, queryTokens); } catch (IOException e) { LOG.error("IOException, returning 0.0f"); return 0.0f; @@ -117,13 +97,12 @@ protected void populateQueryPairMap(List queryTokens) { } } - protected float computeOrderedFrequencyScore(Document doc, Terms terms, 
RerankerContext context) throws IOException { + protected float computeOrderedFrequencyScore(Document doc, Terms terms,List queryTokens) throws IOException { // Only compute the score once for all window sizes on the same document - if (!context.getQueryId().equals(lastProcessedId) || lastProcessedDoc != doc) { - resetCounters(context.getQueryId().toString(), doc); + if (lastProcessedDoc != doc) { + resetCounters(doc); - List queryTokens = context.getQueryTokens(); populateQueryPairMap(queryTokens); // Now make the call to the static method @@ -145,4 +124,14 @@ protected float computeOrderedFrequencyScore(Document doc, Terms terms, Reranker public String getName() { return "OrderedSequentialPairs" + this.gapSize; } + + @Override + public String getField() { + return null; + } + + @Override + public FeatureExtractor clone() { + return new OrderedSequentialPairsFeatureExtractor(this.gapSize); + } } diff --git a/src/main/java/io/anserini/ltr/feature/SequentialDependenceModel.java b/src/main/java/io/anserini/ltr/feature/SequentialDependenceModel.java index aa2e59276e..0bb54f2aec 100644 --- a/src/main/java/io/anserini/ltr/feature/SequentialDependenceModel.java +++ b/src/main/java/io/anserini/ltr/feature/SequentialDependenceModel.java @@ -16,6 +16,7 @@ package io.anserini.ltr.feature; +import io.anserini.index.IndexArgs; import io.anserini.rerank.RerankerContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -35,10 +36,9 @@ /** * Implementation of the Sequential Dependence term dependence model */ -public class SequentialDependenceModel implements FeatureExtractor { +public class SequentialDependenceModel implements FeatureExtractor { private static final Logger LOG = LogManager.getLogger(SequentialDependenceModel.class); - private static final String NAME = "SDM"; private static final int WINDOW_SIZE = 8; private float lambdaT = 0.5f; @@ -51,9 +51,7 @@ public SequentialDependenceModel(float lambdaT, float lambdaO, float lambdaU) { this.lambdaU = lambdaU; } - private float computeUnorderedFrequencyScore(Document doc, Terms terms, RerankerContext context) throws IOException { - List queryTokens = context.getQueryTokens(); - + private float computeUnorderedFrequencyScore(Document doc, Terms terms, List queryTokens) throws IOException { // Construct token stream with offset 0 TokenStream stream = new TokenStreamFromTermVector(terms, 0); CharTermAttribute termAttribute = stream.addAttribute(CharTermAttribute.class); @@ -120,8 +118,7 @@ private float computeUnorderedFrequencyScore(Document doc, Terms terms, Reranker return score; } - private float computeOrderedFrequencyScore(Document doc, Terms terms, RerankerContext context) throws IOException { - List queryTokens = context.getQueryTokens(); + private float computeOrderedFrequencyScore(Document doc, Terms terms, List queryTokens) throws IOException { Map queryPairMap = new HashMap<>(); Map phraseCountMap = new HashMap<>(); Map singleCountMap = new HashMap<>(); @@ -171,16 +168,12 @@ private float computeOrderedFrequencyScore(Document doc, Terms terms, RerankerCo /** * The single term scoring function: lambda* log( (1-alpha) tf/ |D|) - * @param doc * @param terms - * @param context * @return */ - private float computeFullIndependenceScore(Document doc, Terms terms, RerankerContext context) throws IOException { + private float computeFullIndependenceScore(Terms terms) throws IOException { // tf can be calculated by iterating over terms, number of times a term occurs in doc // |D| total number of terms can be 
calculated by iterating over stream - IndexReader reader = context.getIndexSearcher().getIndexReader(); - List queryTokenList = context.getQueryTokens(); Map termCount = new HashMap<>(); TokenStream stream = new TokenStreamFromTermVector(terms, 0); @@ -208,15 +201,15 @@ private float computeFullIndependenceScore(Document doc, Terms terms, RerankerCo } @Override - public float extract(Document doc, Terms terms, RerankerContext context) { + public float extract(Document doc, Terms terms, String queryText, List queryTokens, IndexReader reader) { float orderedWindowScore = 0.0f; float unorderedDependenceScore = 0.0f; float independentScore = 0.0f; try { - independentScore = computeFullIndependenceScore(doc, terms, context); - orderedWindowScore = computeOrderedFrequencyScore(doc, terms, context); - unorderedDependenceScore = computeUnorderedFrequencyScore(doc, terms, context); + independentScore = computeFullIndependenceScore(terms); + orderedWindowScore = computeOrderedFrequencyScore(doc, terms, queryTokens); + unorderedDependenceScore = computeUnorderedFrequencyScore(doc, terms, queryTokens); LOG.debug(String.format("independent: %f, ordered: %f, unordered: %f", independentScore, orderedWindowScore, unorderedDependenceScore)); } catch (IOException e) { e.printStackTrace(); @@ -226,6 +219,16 @@ public float extract(Document doc, Terms terms, RerankerContext context) { @Override public String getName() { - return NAME; + return "SDM"; + } + + @Override + public String getField() { + return null; + } + + @Override + public FeatureExtractor clone() { + return new SequentialDependenceModel(this.lambdaT, this.lambdaO, this.lambdaU); } } diff --git a/src/main/java/io/anserini/ltr/feature/UnigramFeatureExtractor.java b/src/main/java/io/anserini/ltr/feature/UnigramFeatureExtractor.java deleted file mode 100644 index fbe7f5bb55..0000000000 --- a/src/main/java/io/anserini/ltr/feature/UnigramFeatureExtractor.java +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Anserini: A Lucene toolkit for replicable information retrieval research - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
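(Aside on the SequentialDependenceModel hunks above: the extractor combines its three component scores by linear interpolation using the constructor weights lambdaT, lambdaO, and lambdaU. The combining statement itself falls outside the visible hunk context, so the following Java sketch is an assumed illustration of the standard SDM combination, not a quote of the committed code.)

```java
// Assumed sketch of SDM's linear interpolation of its three components; the
// actual combining statement is outside the hunks shown above, so names here
// simply mirror the fields and locals visible in the diff.
class SdmCombineSketch {
  static float combine(float lambdaT, float lambdaO, float lambdaU,
                       float independent, float ordered, float unordered) {
    return lambdaT * independent + lambdaO * ordered + lambdaU * unordered;
  }
}
```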
- */ - -package io.anserini.ltr.feature; - -import io.anserini.rerank.RerankerContext; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.apache.lucene.analysis.TokenStream; -import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; -import org.apache.lucene.document.Document; -import org.apache.lucene.index.IndexReader; -import org.apache.lucene.index.Terms; -import org.apache.lucene.search.highlight.TokenStreamFromTermVector; - -import java.io.IOException; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -/** - * Counts unigrams - */ -public class UnigramFeatureExtractor implements FeatureExtractor { - private static final Logger LOG = LogManager.getLogger(UnigramFeatureExtractor.class); - - @Override - public float extract(Document doc, Terms terms, RerankerContext context) { - try { - return computeFullIndependenceScore(doc, terms, context); - } catch (IOException e) { - LOG.error("IOException while counting unigrams, returning 0"); - return 0.0f; - } - } - - /** - * The single term scoring function: lambda* log( (1-alpha) tf/ |D|) - * @param doc - * @param terms - * @param context - * @return - */ - private float computeFullIndependenceScore(Document doc, Terms terms, RerankerContext context) throws IOException { - // tf can be calculated by iterating over terms, number of times a term occurs in doc - // |D| total number of terms can be calculated by iterating over stream - IndexReader reader = context.getIndexSearcher().getIndexReader(); - List queryTokenList = context.getQueryTokens(); - Map termCount = new HashMap<>(); - - for (String queryToken : queryTokenList) { - termCount.put(queryToken, 0); - } - TokenStream stream = new TokenStreamFromTermVector(terms, -1); - CharTermAttribute termAttribute = stream.addAttribute(CharTermAttribute.class); - - stream.reset(); - float docSize =0; - // Count all the tokens - while (stream.incrementToken()) { - docSize++; - String token = termAttribute.toString(); - if (termCount.containsKey(token)) { - termCount.put(token, termCount.get(token) + 1); - } - } - float score = 0.0f; - // Smoothing count of 1 - docSize++; - // Only compute the score for what's in term count all else 0 - for (String queryToken : termCount.keySet()) { - score += termCount.get(queryToken); - } - - stream.end(); - stream.close(); - return score; - } - - @Override - public String getName() { - return "UnigramsFeatureExtractor"; - } -} diff --git a/src/main/java/io/anserini/ltr/feature/UnorderedQueryPairsFeatureExtractor.java b/src/main/java/io/anserini/ltr/feature/UnorderedQueryPairsFeatureExtractor.java index 6b68914db6..4de74a2525 100644 --- a/src/main/java/io/anserini/ltr/feature/UnorderedQueryPairsFeatureExtractor.java +++ b/src/main/java/io/anserini/ltr/feature/UnorderedQueryPairsFeatureExtractor.java @@ -20,8 +20,10 @@ import com.fasterxml.jackson.databind.DeserializationContext; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.deser.std.StdDeserializer; +import io.anserini.index.IndexArgs; import io.anserini.rerank.RerankerContext; import org.apache.lucene.document.Document; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.Terms; import java.io.IOException; @@ -36,37 +38,16 @@ /** * Counts all unordered pairs of query tokens */ -public class UnorderedQueryPairsFeatureExtractor implements FeatureExtractor { +public class UnorderedQueryPairsFeatureExtractor implements FeatureExtractor { protected static ArrayList 
gapSizes = new ArrayList<>(); protected static Map counters = new HashMap<>(); protected static Map singleCountMap = new HashMap<>(); protected static Map> queryPairMap = new HashMap<>(); protected static Map> backQueryPairMap = new HashMap<>(); - protected static String lastProcessedId = ""; protected static Document lastProcessedDoc = null; - public static class Deserializer extends StdDeserializer - { - public Deserializer() { - this(null); - } - - public Deserializer(Class vc) { - super(vc); - } - - @Override - public UnorderedQueryPairsFeatureExtractor - deserialize(JsonParser jsonParser, DeserializationContext context) throws IOException - { - JsonNode node = jsonParser.getCodec().readTree(jsonParser); - int gapSize = node.get("gapSize").asInt(); - return new UnorderedQueryPairsFeatureExtractor(gapSize); - } - } - - private static void resetCounters(String newestQuery, Document newestDoc) { + private static void resetCounters(Document newestDoc) { singleCountMap.clear(); backQueryPairMap.clear(); @@ -74,7 +55,6 @@ private static void resetCounters(String newestQuery, Document newestDoc) { for (int i : counters.keySet()) { counters.get(i).phraseCountMap.clear(); } - lastProcessedId = newestQuery; lastProcessedDoc = newestDoc; } @@ -109,11 +89,10 @@ protected void populateQueryMaps(List queryTokens) { singleCountMap.put(queryTokens.get(queryTokens.size() - 1), 0); } - protected float computeUnorderedFrequencyScore(Document doc, Terms terms, RerankerContext context) throws IOException { + protected float computeUnorderedFrequencyScore(Document doc, Terms terms, List queryTokens) throws IOException { - if (!context.getQueryId().equals(lastProcessedId) || doc != lastProcessedDoc) { - resetCounters(context.getQueryId().toString(), doc); - List queryTokens = context.getQueryTokens(); + if (doc != lastProcessedDoc) { + resetCounters(doc); populateQueryMaps(queryTokens); @@ -131,9 +110,9 @@ protected float computeUnorderedFrequencyScore(Document doc, Terms terms, Rerank return score; } @Override - public float extract(Document doc, Terms terms, RerankerContext context) { + public float extract(Document doc, Terms terms, String queryText, List queryTokens, IndexReader reader) { try { - return computeUnorderedFrequencyScore(doc, terms, context); + return computeUnorderedFrequencyScore(doc, terms, queryTokens); } catch (IOException e) { e.printStackTrace(); return 0.0f; @@ -144,4 +123,14 @@ public float extract(Document doc, Terms terms, RerankerContext context) { public String getName() { return "UnorderedQueryTokenPairs" + this.gapSize; } + + @Override + public String getField() { + return null; + } + + @Override + public FeatureExtractor clone() { + return new UnorderedQueryPairsFeatureExtractor(this.gapSize); + } } diff --git a/src/main/java/io/anserini/ltr/feature/UnorderedSequentialPairsFeatureExtractor.java b/src/main/java/io/anserini/ltr/feature/UnorderedSequentialPairsFeatureExtractor.java index e64a4855e1..2c117c2c04 100644 --- a/src/main/java/io/anserini/ltr/feature/UnorderedSequentialPairsFeatureExtractor.java +++ b/src/main/java/io/anserini/ltr/feature/UnorderedSequentialPairsFeatureExtractor.java @@ -20,8 +20,10 @@ import com.fasterxml.jackson.databind.DeserializationContext; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.deser.std.StdDeserializer; +import io.anserini.index.IndexArgs; import io.anserini.rerank.RerankerContext; import org.apache.lucene.document.Document; +import org.apache.lucene.index.IndexReader; import 
org.apache.lucene.index.Terms; import java.io.IOException; @@ -37,7 +39,7 @@ * This is a feature extractor that will calculate the * unordered count of phrases in the window specified */ -public class UnorderedSequentialPairsFeatureExtractor implements FeatureExtractor { +public class UnorderedSequentialPairsFeatureExtractor implements FeatureExtractor { protected static ArrayList gapSizes = new ArrayList<>(); protected static Map counters = new HashMap<>(); @@ -45,30 +47,9 @@ public class UnorderedSequentialPairsFeatureExtractor implements FeatureExtra protected static Map singleCountMap = new HashMap<>(); protected static Map> queryPairMap = new HashMap<>(); protected static Map> backQueryPairMap = new HashMap<>(); - protected static String lastProcessedId = ""; protected static Document lastProcessedDoc = null; - public static class Deserializer extends StdDeserializer - { - public Deserializer() { - this(null); - } - - public Deserializer(Class vc) { - super(vc); - } - - @Override - public UnorderedSequentialPairsFeatureExtractor - deserialize(JsonParser jsonParser, DeserializationContext context) throws IOException - { - JsonNode node = jsonParser.getCodec().readTree(jsonParser); - int gapSize = node.get("gapSize").asInt(); - return new UnorderedSequentialPairsFeatureExtractor(gapSize); - } - } - - private static void resetCounters(String newestQuery, Document newestDoc) { + private static void resetCounters(Document newestDoc) { singleCountMap.clear(); backQueryPairMap.clear(); @@ -76,7 +57,6 @@ private static void resetCounters(String newestQuery, Document newestDoc) { for (int i : counters.keySet()) { counters.get(i).phraseCountMap.clear(); } - lastProcessedId = newestQuery; lastProcessedDoc = newestDoc; } @@ -116,11 +96,10 @@ protected void populateQueryMaps(List queryTokens) { singleCountMap.put(queryTokens.get(queryTokens.size() -1), 0); } - protected float computeUnorderedFrequencyScore(Document doc, Terms terms, RerankerContext context) throws IOException { + protected float computeUnorderedFrequencyScore(Document doc, Terms terms, List queryTokens) throws IOException { - if (!context.getQueryId().equals(lastProcessedId) || doc != lastProcessedDoc) { - resetCounters(context.getQueryId().toString(), doc); - List queryTokens = context.getQueryTokens(); + if (doc != lastProcessedDoc) { + resetCounters(doc); populateQueryMaps(queryTokens); @@ -139,9 +118,9 @@ protected float computeUnorderedFrequencyScore(Document doc, Terms terms, Rerank } @Override - public float extract(Document doc, Terms terms, RerankerContext context) { + public float extract(Document doc, Terms terms, String queryText, List queryTokens, IndexReader reader) { try { - return computeUnorderedFrequencyScore(doc, terms, context); + return computeUnorderedFrequencyScore(doc, terms, queryTokens); } catch (IOException e) { e.printStackTrace(); return 0.0f; @@ -152,4 +131,14 @@ public float extract(Document doc, Terms terms, RerankerContext context) { public String getName() { return "UnorderedSequentialPairs" + this.gapSize; } + + @Override + public String getField() { + return null; + } + + @Override + public FeatureExtractor clone() { + return new UnorderedSequentialPairsFeatureExtractor(this.gapSize); + } } diff --git a/src/main/java/io/anserini/ltr/feature/base/AvgICTFFeatureExtractor.java b/src/main/java/io/anserini/ltr/feature/base/AvgICTFFeatureExtractor.java index 2a7b66eb4c..faf9294007 100644 --- a/src/main/java/io/anserini/ltr/feature/base/AvgICTFFeatureExtractor.java +++ 
b/src/main/java/io/anserini/ltr/feature/base/AvgICTFFeatureExtractor.java @@ -16,13 +16,16 @@ package io.anserini.ltr.feature.base; +import io.anserini.index.IndexArgs; import io.anserini.ltr.feature.FeatureExtractor; import io.anserini.rerank.RerankerContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.lucene.document.Document; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.Terms; import org.apache.lucene.index.TermsEnum; +import org.mockito.internal.matchers.Null; import java.io.IOException; import java.util.ArrayList; @@ -33,7 +36,7 @@ * Carmel, Yom-Tov Estimating query difficulty for Information Retrieval * log(|D| / tf) */ -public class AvgICTFFeatureExtractor implements FeatureExtractor { +public class AvgICTFFeatureExtractor implements FeatureExtractor { private static final Logger LOG = LogManager.getLogger(AvgICTFFeatureExtractor.class); // Calculate term frequencies, if error returns an empty map, counting all tf = 0 @@ -61,15 +64,25 @@ private float getSumICTF(Terms terms, List queryTokens) { return sumICTF; } @Override - public float extract(Document doc, Terms terms, RerankerContext context) { + public float extract(Document doc, Terms terms, String queryText, List queryTokens, IndexReader reader) { // We need docSize, and tf for each term - float sumIctf = getSumICTF(terms, context.getQueryTokens()); + float sumIctf = getSumICTF(terms, queryTokens); // Compute the average by dividing - return sumIctf / context.getQueryTokens().size(); + return sumIctf / queryTokens.size(); } @Override public String getName() { return "AvgICTF"; } + + @Override + public String getField() { + return null; + } + + @Override + public FeatureExtractor clone() { + return new AvgICTFFeatureExtractor(); + } } diff --git a/src/main/java/io/anserini/ltr/feature/base/AvgIDFFeatureExtractor.java b/src/main/java/io/anserini/ltr/feature/base/AvgIDFFeatureExtractor.java index e2bd541321..09083b3eb8 100644 --- a/src/main/java/io/anserini/ltr/feature/base/AvgIDFFeatureExtractor.java +++ b/src/main/java/io/anserini/ltr/feature/base/AvgIDFFeatureExtractor.java @@ -33,7 +33,7 @@ * Average IDF, idf calculated using log( 1+ (N - N_t + 0.5)/(N_t + 0.5)) * where N is the total number of docs, calculated like in BM25 */ -public class AvgIDFFeatureExtractor implements FeatureExtractor { +public class AvgIDFFeatureExtractor implements FeatureExtractor { private static final Logger LOG = LogManager.getLogger(AvgIDFFeatureExtractor.class); private float sumIdf(IndexReader reader, List queryTokens, @@ -47,13 +47,11 @@ private float sumIdf(IndexReader reader, List queryTokens, } @Override - public float extract(Document doc, Terms terms, RerankerContext context) { - IndexReader reader = context.getIndexSearcher().getIndexReader(); - + public float extract(Document doc, Terms terms, String queryText, List queryTokens, IndexReader reader) { long numDocs = reader.numDocs() - reader.numDeletedDocs(); try { - float sumIdf = sumIdf(reader, context.getQueryTokens(), numDocs, IndexArgs.CONTENTS); - return sumIdf / (float) context.getQueryTokens().size(); + float sumIdf = sumIdf(reader, queryTokens, numDocs, IndexArgs.CONTENTS); + return sumIdf / (float) queryTokens.size(); } catch (IOException e) { LOG.warn("Error computing AvgIdf, returning 0"); return 0.0f; @@ -64,4 +62,14 @@ public float extract(Document doc, Terms terms, RerankerContext context) { public String getName() { return "AvgIDF"; } + + @Override + public String getField() { + return 
null; + } + + @Override + public FeatureExtractor clone() { + return new AvgIDFFeatureExtractor(); + } } diff --git a/src/main/java/io/anserini/ltr/feature/base/BM25FeatureExtractor.java b/src/main/java/io/anserini/ltr/feature/base/BM25FeatureExtractor.java index e4026d92c1..156483ab49 100644 --- a/src/main/java/io/anserini/ltr/feature/base/BM25FeatureExtractor.java +++ b/src/main/java/io/anserini/ltr/feature/base/BM25FeatureExtractor.java @@ -41,7 +41,7 @@ * Lucene uses the norm value encoded in the index, we are calculating it as is * also we do not have any boosting, the field norm is also not available */ -public class BM25FeatureExtractor implements FeatureExtractor { +public class BM25FeatureExtractor implements FeatureExtractor { private static final Logger LOG = LogManager.getLogger(BM25FeatureExtractor.class); public static Map getDocFreqs(IndexReader reader, List queryTokens, String field) throws IOException { @@ -98,14 +98,10 @@ private long getSumTermFrequency(IndexReader reader, String fieldName) { * the formula used: * sum ( IDF(qi) * (tf(qi,D) * (k+1)) / (tf(qi,D) + k * (1-b + b*|D| / avgFL)) * IDF and avgFL computation are described above. - * @param doc document - * @param terms terms - * @param context reranker context - * @return BM25 score */ @Override - public float extract(Document doc, Terms terms, RerankerContext context) { - Set queryTokens = new HashSet<>(context.getQueryTokens()); + public float extract(Document doc, Terms terms, String queryText, List queryTokens, IndexReader reader) { + Set queryTokenSet = new HashSet<>(queryTokens); TermsEnum termsEnum = null; try { @@ -115,7 +111,6 @@ public float extract(Document doc, Terms terms, RerankerContext context) { return 0.0f; } - IndexReader reader = context.getIndexSearcher().getIndexReader(); long maxDocs = reader.numDocs(); long sumTotalTermFreq = getSumTermFrequency(reader, IndexArgs.CONTENTS); // Compute by iterating @@ -125,7 +120,7 @@ public float extract(Document doc, Terms terms, RerankerContext context) { // the term vector here is only a partial term vector that treats this as if we only have 1 document in the index Map docFreqMap = null; try { - docFreqMap = getDocFreqs(reader, context.getQueryTokens(), IndexArgs.CONTENTS); + docFreqMap = getDocFreqs(reader, queryTokens, IndexArgs.CONTENTS); } catch (IOException e) { LOG.warn("Unable to retrieve document frequencies."); docFreqMap = new HashMap<>(); @@ -136,7 +131,7 @@ public float extract(Document doc, Terms terms, RerankerContext context) { while (termsEnum.next() != null) { String termString = termsEnum.term().utf8ToString(); docSize += termsEnum.totalTermFreq(); - if (queryTokens.contains(termString)) { + if (queryTokenSet.contains(termString)) { termFreqMap.put(termString, termsEnum.totalTermFreq()); } } @@ -147,7 +142,7 @@ public float extract(Document doc, Terms terms, RerankerContext context) { float score = 0.0f; // Iterate over the query tokens double avgFL = computeAvgFL(sumTotalTermFreq, maxDocs); - for (String token : queryTokens) { + for (String token : queryTokenSet) { long docFreq = docFreqMap.getOrDefault(token, 0); double termFreq = termFreqMap.containsKey(token) ? 
termFreqMap.get(token) : 0; double numerator = (this.k1 + 1) * termFreq; @@ -161,7 +156,12 @@ public float extract(Document doc, Terms terms, RerankerContext context) { @Override public String getName() { - return "BM25Feature"; + return "BM25"; + } + + @Override + public String getField() { + return null; } public double getK1() { @@ -171,4 +171,9 @@ public double getK1() { public double getB() { return b; } + + @Override + public FeatureExtractor clone() { + return new BM25FeatureExtractor(this.k1, this.b); + } } diff --git a/src/main/java/io/anserini/ltr/feature/base/DocSizeFeatureExtractor.java b/src/main/java/io/anserini/ltr/feature/base/DocSizeFeatureExtractor.java index 0bdd088857..647f415735 100644 --- a/src/main/java/io/anserini/ltr/feature/base/DocSizeFeatureExtractor.java +++ b/src/main/java/io/anserini/ltr/feature/base/DocSizeFeatureExtractor.java @@ -16,24 +16,27 @@ package io.anserini.ltr.feature.base; +import io.anserini.index.IndexArgs; import io.anserini.ltr.feature.FeatureExtractor; import io.anserini.rerank.RerankerContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.lucene.document.Document; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.Terms; import org.apache.lucene.index.TermsEnum; import java.io.IOException; +import java.util.List; /** * Returns the size of the document */ -public class DocSizeFeatureExtractor implements FeatureExtractor { +public class DocSizeFeatureExtractor implements FeatureExtractor { private static final Logger LOG = LogManager.getLogger(DocSizeFeatureExtractor.class); @Override - public float extract(Document doc, Terms terms, RerankerContext context) { + public float extract(Document doc, Terms terms, String queryText, List queryTokens, IndexReader reader) { float score; try { score = (float)terms.getSumTotalTermFreq(); @@ -55,4 +58,14 @@ public float extract(Document doc, Terms terms, RerankerContext context) { public String getName() { return "DocSize"; } + + @Override + public String getField() { + return null; + } + + @Override + public FeatureExtractor clone() { + return new DocSizeFeatureExtractor(); + } } diff --git a/src/main/java/io/anserini/ltr/feature/base/MatchingTermCount.java b/src/main/java/io/anserini/ltr/feature/base/MatchingTermCount.java index 85891878c9..eefc09d2ce 100644 --- a/src/main/java/io/anserini/ltr/feature/base/MatchingTermCount.java +++ b/src/main/java/io/anserini/ltr/feature/base/MatchingTermCount.java @@ -16,9 +16,11 @@ package io.anserini.ltr.feature.base; +import io.anserini.index.IndexArgs; import io.anserini.ltr.feature.FeatureExtractor; import io.anserini.rerank.RerankerContext; import org.apache.lucene.document.Document; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.Terms; import org.apache.lucene.index.TermsEnum; import org.apache.lucene.util.BytesRef; @@ -30,12 +32,11 @@ * Computes the number of query terms that are found in the document. If there are three terms in * the query and all three terms are found in the document, the feature value is three. 
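(Aside on the BM25FeatureExtractor hunk above: in its javadoc formula, "tf(qi,D)" is the term's frequency inside document D, not its document frequency. A hedged sketch of the per-term accumulation follows; every name and parameter here is an assumption for illustration rather than a copy of the class's own fields.)

```java
import java.util.Map;
import java.util.Set;

// Hedged sketch of BM25 per-term scoring as described in the hunk above:
// idf(t) * tf(t,D) * (k1+1) / (tf(t,D) + k1 * (1 - b + b * |D| / avgFL)).
class Bm25Sketch {
  static float score(Set<String> queryTokens, Map<String, Long> termFreqs,
                     Map<String, Integer> docFreqs, long numDocs,
                     double docSize, double avgFL, double k1, double b) {
    float score = 0.0f;
    for (String token : queryTokens) {
      double tf = termFreqs.getOrDefault(token, 0L);   // tf of token in D
      long df = docFreqs.getOrDefault(token, 0);       // docs containing token
      double idf = Math.log(1 + (numDocs - df + 0.5d) / (df + 0.5d));
      score += idf * (k1 + 1) * tf / (tf + k1 * (1 - b + b * docSize / avgFL));
    }
    return score;
  }
}
```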
*/ -public class MatchingTermCount implements FeatureExtractor { +public class MatchingTermCount implements FeatureExtractor { @Override - public float extract(Document doc, Terms terms, RerankerContext context) { + public float extract(Document doc, Terms terms, String queryText, List queryTokens, IndexReader reader) { try { - List queryTokens = context.getQueryTokens(); TermsEnum termsEnum = terms.iterator(); int matching = 0; @@ -57,4 +58,14 @@ public float extract(Document doc, Terms terms, RerankerContext context) { public String getName() { return "MatchingTermCount"; } + + @Override + public String getField() { + return null; + } + + @Override + public FeatureExtractor clone() { + return new MatchingTermCount(); + } } diff --git a/src/main/java/io/anserini/ltr/feature/base/PMIFeatureExtractor.java b/src/main/java/io/anserini/ltr/feature/base/PMIFeatureExtractor.java index 74f0599f9a..f0a49b6ec6 100644 --- a/src/main/java/io/anserini/ltr/feature/base/PMIFeatureExtractor.java +++ b/src/main/java/io/anserini/ltr/feature/base/PMIFeatureExtractor.java @@ -41,7 +41,7 @@ * where pr are the MLE * described on page 22 of Carmel, Yom-Tov 2010 */ -public class PMIFeatureExtractor implements FeatureExtractor { +public class PMIFeatureExtractor implements FeatureExtractor { private String lastQueryProcessed = ""; private float lastComputedValue = 0f; @@ -74,17 +74,15 @@ private int countPostingIntersect(PostingsEnum firstEnum, PostingsEnum secondEnu } @Override - public float extract(Document doc, Terms terms, RerankerContext context) { + public float extract(Document doc, Terms terms, String queryText, List queryTokens, IndexReader reader) { // We need docfreqs of each token // and also doc freqs of each pair - if (!this.lastQueryProcessed.equals(context.getQueryText())) { - this.lastQueryProcessed = context.getQueryText(); + if (!this.lastQueryProcessed.equals(queryText)) { + this.lastQueryProcessed = queryText; this.lastComputedValue = 0.0f; - Set querySet = new HashSet<>(context.getQueryTokens()); - IndexReader reader = context.getIndexSearcher().getIndexReader(); + Set querySet = new HashSet<>(queryTokens); Map docFreqs = new HashMap<>(); - List queryTokens = new ArrayList<>(querySet); try { @@ -135,6 +133,16 @@ public float extract(Document doc, Terms terms, RerankerContext context) { @Override public String getName() { - return "PMIFeature"; + return "PMI"; + } + + @Override + public String getField() { + return null; + } + + @Override + public FeatureExtractor clone() { + return new PMIFeatureExtractor(); } } diff --git a/src/main/java/io/anserini/ltr/feature/base/QueryLength.java b/src/main/java/io/anserini/ltr/feature/base/QueryLength.java index 70f5e2756c..f2a2d53b71 100644 --- a/src/main/java/io/anserini/ltr/feature/base/QueryLength.java +++ b/src/main/java/io/anserini/ltr/feature/base/QueryLength.java @@ -16,9 +16,11 @@ package io.anserini.ltr.feature.base; +import io.anserini.index.IndexArgs; import io.anserini.ltr.feature.FeatureExtractor; import io.anserini.rerank.RerankerContext; import org.apache.lucene.document.Document; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.Terms; import java.util.List; @@ -27,11 +29,10 @@ * QueryCount * Compute the query length (number of terms in the query). 
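(For readers tracking the API change: QueryLength above is about the smallest extractor possible under the refactored interface. A minimal sketch of the new contract follows; the five-argument extract signature is taken from the hunks in this diff, while the class itself is hypothetical.)

```java
import java.util.List;

import io.anserini.ltr.feature.FeatureExtractor;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Terms;

// Hypothetical extractor showing the refactored contract: query state arrives
// as explicit arguments (queryText, queryTokens, reader) instead of through
// RerankerContext, and getField()/clone() are now part of the interface.
public class ConstantOneFeatureExtractor implements FeatureExtractor {
  @Override
  public float extract(Document doc, Terms terms, String queryText,
                       List<String> queryTokens, IndexReader reader) {
    return 1.0f; // real extractors inspect doc, terms, or the reader here
  }

  @Override
  public String getName() { return "ConstantOne"; }

  @Override
  public String getField() { return null; } // null: not bound to a single field

  @Override
  public FeatureExtractor clone() { return new ConstantOneFeatureExtractor(); }
}
```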
*/ -public class QueryLength implements FeatureExtractor { +public class QueryLength implements FeatureExtractor { @Override - public float extract(Document doc, Terms terms, RerankerContext context) { - List queryTokens = context.getQueryTokens(); + public float extract(Document doc, Terms terms, String queryText, List queryTokens, IndexReader reader) { return queryTokens.size(); } @@ -39,4 +40,14 @@ public float extract(Document doc, Terms terms, RerankerContext context) { public String getName() { return "QueryLength"; } + + @Override + public String getField() { + return null; + } + + @Override + public FeatureExtractor clone() { + return new QueryLength(); + } } diff --git a/src/main/java/io/anserini/ltr/feature/base/SCQFeatureExtractor.java b/src/main/java/io/anserini/ltr/feature/base/SCQFeatureExtractor.java index 17ad18d9a4..40165e93af 100644 --- a/src/main/java/io/anserini/ltr/feature/base/SCQFeatureExtractor.java +++ b/src/main/java/io/anserini/ltr/feature/base/SCQFeatureExtractor.java @@ -34,7 +34,7 @@ * Avg( (1 + log(tf(t,D))) * idf(t)) found on page 33 of Carmel, Yom-Tov 2010 * D is the collection term frequency */ -public class SCQFeatureExtractor implements FeatureExtractor { +public class SCQFeatureExtractor implements FeatureExtractor { private static final Logger LOG = LogManager.getLogger(SCQFeatureExtractor.class); private String lastQueryProcessed = ""; @@ -63,16 +63,14 @@ private float sumSCQ(IndexReader reader, List queryTokens, } @Override - public float extract(Document doc, Terms terms, RerankerContext context) { - IndexReader reader = context.getIndexSearcher().getIndexReader(); - - if (!lastQueryProcessed.equals(context.getQueryText())) { - this.lastQueryProcessed = context.getQueryText(); + public float extract(Document doc, Terms terms, String queryText, List queryTokens, IndexReader reader) { + if (!lastQueryProcessed.equals(queryText)) { + this.lastQueryProcessed = queryText; this.lastComputedScore = 0.0f; try { - float sumScq = sumSCQ(reader, context.getQueryTokens(), IndexArgs.CONTENTS); - this.lastComputedScore = sumScq / context.getQueryTokens().size(); + float sumScq = sumSCQ(reader, queryTokens, IndexArgs.CONTENTS); + this.lastComputedScore = sumScq / queryTokens.size(); } catch (IOException e) { this.lastComputedScore = 0.0f; } @@ -85,4 +83,14 @@ public float extract(Document doc, Terms terms, RerankerContext context) { public String getName() { return "AvgSCQ"; } + + @Override + public String getField() { + return null; + } + + @Override + public FeatureExtractor clone() { + return new SCQFeatureExtractor(); + } } diff --git a/src/main/java/io/anserini/ltr/feature/base/SimplifiedClarityFeatureExtractor.java b/src/main/java/io/anserini/ltr/feature/base/SCSFeatureExtractor.java similarity index 79% rename from src/main/java/io/anserini/ltr/feature/base/SimplifiedClarityFeatureExtractor.java rename to src/main/java/io/anserini/ltr/feature/base/SCSFeatureExtractor.java index 67071cedb9..9439b7df9f 100644 --- a/src/main/java/io/anserini/ltr/feature/base/SimplifiedClarityFeatureExtractor.java +++ b/src/main/java/io/anserini/ltr/feature/base/SCSFeatureExtractor.java @@ -33,7 +33,7 @@ * SCS = sum (P[t|q]) * log(P[t|q] / P[t|D]) * page 20 of Carmel, Yom-Tov 2010 */ -public class SimplifiedClarityFeatureExtractor implements FeatureExtractor { +public class SCSFeatureExtractor implements FeatureExtractor { private String lastQueryProcessed = ""; private float lastComputedScore = 0.0f; @@ -67,17 +67,15 @@ private float sumSC(IndexReader reader, Map 
queryTokenMap, } @Override - public float extract(Document doc, Terms terms, RerankerContext context) { + public float extract(Document doc, Terms terms, String queryText, List queryTokens, IndexReader reader) { - if (!this.lastQueryProcessed.equals(context.getQueryText())) { - this.lastQueryProcessed = context.getQueryText(); + if (!this.lastQueryProcessed.equals(queryText)) { + this.lastQueryProcessed = queryText; this.lastComputedScore = 0.0f; - Map queryCountMap = queryTermMap(context.getQueryTokens()); + Map queryCountMap = queryTermMap(queryTokens); try { - this.lastComputedScore = sumSC(context.getIndexSearcher().getIndexReader(), - queryCountMap, context.getQueryTokens().size(), - IndexArgs.CONTENTS); + this.lastComputedScore = sumSC(reader, queryCountMap, queryTokens.size(), IndexArgs.CONTENTS); } catch (IOException e) { this.lastComputedScore = 0.0f; } @@ -88,6 +86,16 @@ public float extract(Document doc, Terms terms, RerankerContext context) { @Override public String getName() { - return "SimplifiedClarityScore"; + return "SCS"; + } + + @Override + public String getField() { + return null; + } + + @Override + public FeatureExtractor clone() { + return new SCSFeatureExtractor(); } } diff --git a/src/main/java/io/anserini/ltr/feature/base/TermFrequencyFeatureExtractor.java b/src/main/java/io/anserini/ltr/feature/base/SumMatchingTF.java similarity index 76% rename from src/main/java/io/anserini/ltr/feature/base/TermFrequencyFeatureExtractor.java rename to src/main/java/io/anserini/ltr/feature/base/SumMatchingTF.java index 53f32a9eb4..55ebdca79e 100644 --- a/src/main/java/io/anserini/ltr/feature/base/TermFrequencyFeatureExtractor.java +++ b/src/main/java/io/anserini/ltr/feature/base/SumMatchingTF.java @@ -16,28 +16,27 @@ package io.anserini.ltr.feature.base; +import io.anserini.index.IndexArgs; import io.anserini.ltr.feature.FeatureExtractor; import io.anserini.rerank.RerankerContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.lucene.document.Document; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.Terms; import org.apache.lucene.index.TermsEnum; import java.io.IOException; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; -import java.util.Set; +import java.util.*; /** * Computes the sum of term frequencies for each query token. 
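(The rename below keeps SumMatchingTF's counting logic intact; as the hunk shows, the core is a single pass over the document's term vector. A self-contained sketch of that loop, assuming `terms` is the per-document term vector and `queryTokens` the analyzed query:)

```java
import java.io.IOException;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;

// Sketch of the matching-term-frequency pass used by SumMatchingTF.
class SumMatchingTfSketch {
  static float sum(Terms terms, List<String> queryTokens) throws IOException {
    Set<String> querySet = new HashSet<>(queryTokens);
    TermsEnum termsEnum = terms.iterator();
    float sum = 0;
    while (termsEnum.next() != null) {
      String term = termsEnum.term().utf8ToString();
      if (querySet.contains(term)) {
        sum += termsEnum.totalTermFreq(); // tf within this single document
      }
    }
    return sum;
  }
}
```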
*/ -public class TermFrequencyFeatureExtractor implements FeatureExtractor { - private static final Logger LOG = LogManager.getLogger(TermFrequencyFeatureExtractor.class); +public class SumMatchingTF implements FeatureExtractor { + private static final Logger LOG = LogManager.getLogger(SumMatchingTF.class); @Override - public float extract(Document doc, Terms terms, RerankerContext context) { + public float extract(Document doc, Terms terms, String queryText, List queryTokenList, IndexReader reader) { TermsEnum termsEnum = null; try { @@ -48,7 +47,7 @@ public float extract(Document doc, Terms terms, RerankerContext context) { } Map termFreqMap = new HashMap<>(); - Set queryTokens = new HashSet<>(context.getQueryTokens()); + Set queryTokens = new HashSet<>(queryTokenList); try { while (termsEnum.next() != null) { String termString = termsEnum.term().utf8ToString(); @@ -73,6 +72,16 @@ public float extract(Document doc, Terms terms, RerankerContext context) { @Override public String getName() { - return "SumTermFrequency"; + return "SumMatchingTF"; + } + + @Override + public String getField() { + return null; + } + + @Override + public FeatureExtractor clone() { + return new SumMatchingTF(); } } diff --git a/src/main/java/io/anserini/ltr/feature/base/SumMatchingTf.java b/src/main/java/io/anserini/ltr/feature/base/SumMatchingTf.java deleted file mode 100644 index 60be872a8c..0000000000 --- a/src/main/java/io/anserini/ltr/feature/base/SumMatchingTf.java +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Anserini: A Lucene toolkit for replicable information retrieval research - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.anserini.ltr.feature.base; - -import io.anserini.ltr.feature.FeatureExtractor; -import io.anserini.rerank.RerankerContext; -import org.apache.lucene.document.Document; -import org.apache.lucene.index.Terms; -import org.apache.lucene.index.TermsEnum; -import org.apache.lucene.util.BytesRef; - -import java.io.IOException; -import java.util.List; - -/** - * Computes the sum of the term frequencies of the matching terms. That is, if there are two query - * terms and the first occurs twice in the document and the second occurs once in the document, the - * sum of the matching term frequencies is three. 
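(Previewing the TFIDFFeatureExtractor hunk below: the feature accumulates per-token tf and idf and squares the idf, with no length or query norms. A sketch assuming Lucene's ClassicSimilarity, consistent with the similarity calls visible in the hunk, and assuming the tf and df maps are gathered as the extractor does:)

```java
import java.util.List;
import java.util.Map;

import org.apache.lucene.search.similarities.ClassicSimilarity;

// Assumed sketch of the TF-IDF accumulation: tf() is sqrt(freq), idf() is
// ClassicSimilarity's log-based idf, and the extractor squares idf.
class TfidfSketch {
  static float score(Map<String, Long> termFreqMap, Map<String, Integer> docFreqMap,
                     List<String> queryTokens, long numDocs) {
    ClassicSimilarity similarity = new ClassicSimilarity();
    float score = 0.0f;
    for (String token : queryTokens) {
      float tf = similarity.tf(termFreqMap.getOrDefault(token, 0L));
      float idf = similarity.idf(docFreqMap.getOrDefault(token, 0), numDocs);
      score += tf * idf * idf;
    }
    return score;
  }
}
```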
- */ -public class SumMatchingTf implements FeatureExtractor { - - @Override - public float extract(Document doc, Terms terms, RerankerContext context) { - try { - List queryTokens = context.getQueryTokens(); - TermsEnum termsEnum = terms.iterator(); - int sum = 0; - - BytesRef text = null; - while ((text = termsEnum.next()) != null) { - String term = text.utf8ToString(); - if (queryTokens.contains(term)) { - sum += (int) termsEnum.totalTermFreq(); - } - } - return sum; - - } catch (IOException e) { - return 0; - } - } - - @Override - public String getName() { - return "SumMatchingTf"; - } -} diff --git a/src/main/java/io/anserini/ltr/feature/base/TFIDFFeatureExtractor.java b/src/main/java/io/anserini/ltr/feature/base/TFIDFFeatureExtractor.java index c749bbe874..93f7acb0a9 100644 --- a/src/main/java/io/anserini/ltr/feature/base/TFIDFFeatureExtractor.java +++ b/src/main/java/io/anserini/ltr/feature/base/TFIDFFeatureExtractor.java @@ -31,28 +31,28 @@ import java.io.IOException; import java.util.HashMap; +import java.util.List; import java.util.Map; /** * Computes the TFIDF feature according to Lucene's formula, * Not the same because we don't compute length norm or query norm, with boost 1 */ -public class TFIDFFeatureExtractor implements FeatureExtractor { +public class TFIDFFeatureExtractor implements FeatureExtractor { private static final Logger LOG = LogManager.getLogger(TFIDFFeatureExtractor.class); @Override - public float extract(Document doc, Terms terms, RerankerContext context) { + public float extract(Document doc, Terms terms, String queryText, List queryTokens, IndexReader reader) { float score = 0.0f; Map countMap = new HashMap<>(); Map docFreqs = new HashMap<>(); - IndexReader reader = context.getIndexSearcher().getIndexReader(); long numDocs = reader.numDocs(); - for (Object queryToken : context.getQueryTokens()) { + for (String queryToken : queryTokens) { try { - docFreqs.put((String)queryToken, reader.docFreq(new Term(IndexArgs.CONTENTS, (String)queryToken))); + docFreqs.put(queryToken, reader.docFreq(new Term(IndexArgs.CONTENTS, queryToken))); } catch (IOException e) { LOG.error("Error trying to read document frequency"); - docFreqs.put((String)queryToken, 0); + docFreqs.put(queryToken, 0); } } @@ -60,7 +60,7 @@ public float extract(Document doc, Terms terms, RerankerContext context) { TermsEnum termsEnum = terms.iterator(); while (termsEnum.next() != null) { String termString = termsEnum.term().utf8ToString(); - if (context.getQueryTokens().contains(termString)) { + if (queryTokens.contains(termString)) { countMap.put(termString, termsEnum.totalTermFreq()); } } @@ -75,9 +75,9 @@ public float extract(Document doc, Terms terms, RerankerContext context) { //float coord = similarity.coord(countMap.size(), context.getQueryTokens().size()); // coord removed in Lucene 7 - for (Object token : context.getQueryTokens()) { - long termFreq = countMap.getOrDefault(token.toString(), 0L); - long docFreq = docFreqs.getOrDefault(token.toString(), 0); + for (String token : queryTokens) { + long termFreq = countMap.getOrDefault(token, 0L); + long docFreq = docFreqs.getOrDefault(token, 0); float tf = similarity.tf(termFreq); float idf = similarity.idf(docFreq, numDocs); score += tf * idf*idf; @@ -92,4 +92,14 @@ public float extract(Document doc, Terms terms, RerankerContext context) { public String getName() { return "TFIDF"; } + + @Override + public String getField() { + return null; + } + + @Override + public FeatureExtractor clone() { + return new TFIDFFeatureExtractor(); + } } diff 
--git a/src/main/java/io/anserini/ltr/feature/base/UniqueTermCount.java b/src/main/java/io/anserini/ltr/feature/base/UniqueTermCount.java index 35a43f312e..d899cf8e94 100644 --- a/src/main/java/io/anserini/ltr/feature/base/UniqueTermCount.java +++ b/src/main/java/io/anserini/ltr/feature/base/UniqueTermCount.java @@ -16,26 +16,39 @@ package io.anserini.ltr.feature.base; +import io.anserini.index.IndexArgs; import io.anserini.ltr.feature.FeatureExtractor; import io.anserini.rerank.RerankerContext; import org.apache.lucene.document.Document; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.Terms; import java.util.HashSet; +import java.util.List; import java.util.Set; /** * Count of unique query terms */ -public class UniqueTermCount implements FeatureExtractor { +public class UniqueTermCount implements FeatureExtractor { @Override - public float extract(Document doc, Terms terms, RerankerContext context) { - Set queryTokens = new HashSet<>(context.getQueryTokens()); - return queryTokens.size(); + public float extract(Document doc, Terms terms, String queryText, List queryTokens, IndexReader reader) { + Set queryTokenSet = new HashSet<>(queryTokens); + return queryTokenSet.size(); } @Override public String getName() { return "UniqueQueryTerms"; } + + @Override + public String getField() { + return null; + } + + @Override + public FeatureExtractor clone() { + return new UniqueTermCount(); + } } diff --git a/src/main/java/io/anserini/ltr/feature/twitter/HashtagCount.java b/src/main/java/io/anserini/ltr/feature/twitter/HashtagCount.java index c20ab7f422..27a135f961 100644 --- a/src/main/java/io/anserini/ltr/feature/twitter/HashtagCount.java +++ b/src/main/java/io/anserini/ltr/feature/twitter/HashtagCount.java @@ -20,12 +20,15 @@ import io.anserini.ltr.feature.FeatureExtractor; import io.anserini.rerank.RerankerContext; import org.apache.lucene.document.Document; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.Terms; +import java.util.List; + public class HashtagCount implements FeatureExtractor { @Override - public float extract(Document doc, Terms terms, RerankerContext context) { + public float extract(Document doc, Terms terms, String queryText, List queryTokens, IndexReader reader) { String str = doc.getField(IndexArgs.CONTENTS).stringValue(); final String matchStr = "#"; @@ -47,4 +50,14 @@ public float extract(Document doc, Terms terms, RerankerContext context) { public String getName() { return "TwitterHashtagCount"; } + + @Override + public String getField() { + return IndexArgs.CONTENTS; + } + + @Override + public FeatureExtractor clone() { + return new HashtagCount(); + } } diff --git a/src/main/java/io/anserini/ltr/feature/twitter/IsTweetReply.java b/src/main/java/io/anserini/ltr/feature/twitter/IsTweetReply.java index aceddd2e27..324ea58bcf 100644 --- a/src/main/java/io/anserini/ltr/feature/twitter/IsTweetReply.java +++ b/src/main/java/io/anserini/ltr/feature/twitter/IsTweetReply.java @@ -16,15 +16,19 @@ package io.anserini.ltr.feature.twitter; +import io.anserini.index.IndexArgs; import io.anserini.index.generator.TweetGenerator.TweetField; import io.anserini.ltr.feature.FeatureExtractor; import io.anserini.rerank.RerankerContext; import org.apache.lucene.document.Document; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.Terms; +import java.util.List; + public class IsTweetReply implements FeatureExtractor { @Override - public float extract(Document doc, Terms terms, RerankerContext context) { + public 
float extract(Document doc, Terms terms, String queryText, List queryTokens, IndexReader reader) { return doc.getField(TweetField.IN_REPLY_TO_STATUS_ID.name) == null ? 0.0f : 1.0f; } @@ -32,4 +36,14 @@ public float extract(Document doc, Terms terms, RerankerContext context) { public String getName() { return "IsTweetReply"; } + + @Override + public String getField() { + return TweetField.IN_REPLY_TO_STATUS_ID.name; + } + + @Override + public FeatureExtractor clone() { + return new IsTweetReply(); + } } diff --git a/src/main/java/io/anserini/ltr/feature/twitter/LinkCount.java b/src/main/java/io/anserini/ltr/feature/twitter/LinkCount.java index 58a07032a0..501b88fc24 100644 --- a/src/main/java/io/anserini/ltr/feature/twitter/LinkCount.java +++ b/src/main/java/io/anserini/ltr/feature/twitter/LinkCount.java @@ -20,11 +20,14 @@ import io.anserini.ltr.feature.FeatureExtractor; import io.anserini.rerank.RerankerContext; import org.apache.lucene.document.Document; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.Terms; +import java.util.List; + public class LinkCount implements FeatureExtractor { @Override - public float extract(Document doc, Terms terms, RerankerContext context) { + public float extract(Document doc, Terms terms, String queryText, List queryTokens, IndexReader reader) { final String str = doc.getField(IndexArgs.CONTENTS).stringValue(); final String matchStr = "http://"; @@ -46,4 +49,14 @@ public float extract(Document doc, Terms terms, RerankerContext context) { public String getName() { return "TwitterLinkCount"; } + + @Override + public String getField() { + return IndexArgs.CONTENTS; + } + + @Override + public FeatureExtractor clone() { + return new LinkCount(); + } } diff --git a/src/main/java/io/anserini/ltr/feature/twitter/TwitterFollowerCount.java b/src/main/java/io/anserini/ltr/feature/twitter/TwitterFollowerCount.java index 01d5e40dac..43f846aebe 100644 --- a/src/main/java/io/anserini/ltr/feature/twitter/TwitterFollowerCount.java +++ b/src/main/java/io/anserini/ltr/feature/twitter/TwitterFollowerCount.java @@ -16,15 +16,19 @@ package io.anserini.ltr.feature.twitter; +import io.anserini.index.IndexArgs; import io.anserini.index.generator.TweetGenerator.TweetField; import io.anserini.ltr.feature.FeatureExtractor; import io.anserini.rerank.RerankerContext; import org.apache.lucene.document.Document; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.Terms; +import java.util.List; + public class TwitterFollowerCount implements FeatureExtractor { @Override - public float extract(Document doc, Terms terms, RerankerContext context) { + public float extract(Document doc, Terms terms, String queryText, List queryTokens, IndexReader reader) { return (float) (int) doc.getField(TweetField.FOLLOWERS_COUNT.name).numericValue(); } @@ -32,4 +36,14 @@ public float extract(Document doc, Terms terms, RerankerContext context) { public String getName() { return "TwitterFollowerCount"; } + + @Override + public String getField() { + return TweetField.FOLLOWERS_COUNT.name; + } + + @Override + public FeatureExtractor clone() { + return new TwitterFollowerCount(); + } } diff --git a/src/main/java/io/anserini/ltr/feature/twitter/TwitterFriendCount.java b/src/main/java/io/anserini/ltr/feature/twitter/TwitterFriendCount.java index 83b5de8378..3fe89b9ff9 100644 --- a/src/main/java/io/anserini/ltr/feature/twitter/TwitterFriendCount.java +++ b/src/main/java/io/anserini/ltr/feature/twitter/TwitterFriendCount.java @@ -16,15 +16,19 @@ package 
io.anserini.ltr.feature.twitter; +import io.anserini.index.IndexArgs; import io.anserini.index.generator.TweetGenerator.TweetField; import io.anserini.ltr.feature.FeatureExtractor; import io.anserini.rerank.RerankerContext; import org.apache.lucene.document.Document; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.Terms; +import java.util.List; + public class TwitterFriendCount implements FeatureExtractor { @Override - public float extract(Document doc, Terms terms, RerankerContext context) { + public float extract(Document doc, Terms terms, String queryText, List queryTokens, IndexReader reader) { return (float) (int) doc.getField(TweetField.FRIENDS_COUNT.name).numericValue(); } @@ -32,4 +36,14 @@ public float extract(Document doc, Terms terms, RerankerContext context) { public String getName() { return "TwitterFriendCount"; } + + @Override + public String getField() { + return TweetField.FRIENDS_COUNT.name; + } + + @Override + public FeatureExtractor clone() { + return new TwitterFriendCount(); + } } diff --git a/src/test/java/io/anserini/ltr/ICTFFeatureExtractorTest.java b/src/test/java/io/anserini/ltr/AvgICTFFeatureExtractorTest.java similarity index 69% rename from src/test/java/io/anserini/ltr/ICTFFeatureExtractorTest.java rename to src/test/java/io/anserini/ltr/AvgICTFFeatureExtractorTest.java index 0451de1271..c1a4a03cf4 100644 --- a/src/test/java/io/anserini/ltr/ICTFFeatureExtractorTest.java +++ b/src/test/java/io/anserini/ltr/AvgICTFFeatureExtractorTest.java @@ -16,46 +16,48 @@ package io.anserini.ltr; -import io.anserini.ltr.feature.FeatureExtractors; +import io.anserini.ltr.feature.FeatureExtractor; import io.anserini.ltr.feature.base.AvgICTFFeatureExtractor; import org.junit.Test; import java.io.IOException; +import java.util.concurrent.ExecutionException; /** * Test ICTF feature extractor is implemented according to * the Carmel, Yom-Tov synthesis series book */ -public class ICTFFeatureExtractorTest extends BaseFeatureExtractorTest { +public class AvgICTFFeatureExtractorTest extends BaseFeatureExtractorTest { + + private static FeatureExtractor EXTRACTOR = new AvgICTFFeatureExtractor(); - private static FeatureExtractors EXTRACTOR = getChain(new AvgICTFFeatureExtractor()); @Test - public void testSingleQueryPhrase() throws IOException { + public void testSingleQueryPhrase() throws IOException, ExecutionException, InterruptedException { float[] expected = {0}; assertFeatureValues(expected, "document", "document", EXTRACTOR); } @Test - public void testSingleQuery2() throws IOException { + public void testSingleQuery2() throws IOException, ExecutionException, InterruptedException { float[] expected = {1.38629f}; assertFeatureValues(expected, "document", "document multiple tokens more", EXTRACTOR); } @Test - public void testSingleQuery3() throws IOException { + public void testSingleQuery3() throws IOException, ExecutionException, InterruptedException { float[] expected = {0.693147f}; assertFeatureValues(expected, "document", "document document test more tokens document", EXTRACTOR); } @Test - public void testMultiQuery() throws IOException { + public void testMultiQuery() throws IOException, ExecutionException, InterruptedException { float[] expected = {0.20273f}; assertFeatureValues(expected, "document test", "document document missing", EXTRACTOR); } @Test - public void testMultiQuery2() throws IOException { + public void testMultiQuery2() throws IOException, ExecutionException, InterruptedException { // log(8/3)*0.5 + log(8/2) * 0.5 float[] 
expected = {1.18356f}; assertFeatureValues(expected, "document test", "document document test test more tokens document tokens", EXTRACTOR); diff --git a/src/test/java/io/anserini/ltr/AvgIDFFeatureExtractorTest.java b/src/test/java/io/anserini/ltr/AvgIDFFeatureExtractorTest.java index adf7068045..3fb578dcd6 100644 --- a/src/test/java/io/anserini/ltr/AvgIDFFeatureExtractorTest.java +++ b/src/test/java/io/anserini/ltr/AvgIDFFeatureExtractorTest.java @@ -16,31 +16,33 @@ package io.anserini.ltr; -import io.anserini.ltr.feature.FeatureExtractors; +import io.anserini.ltr.feature.FeatureExtractor; import io.anserini.ltr.feature.base.AvgIDFFeatureExtractor; import org.junit.Test; import java.io.IOException; import java.util.Arrays; +import java.util.List; +import java.util.concurrent.ExecutionException; public class AvgIDFFeatureExtractorTest extends BaseFeatureExtractorTest { - private static FeatureExtractors EXTRACTOR = getChain(new AvgIDFFeatureExtractor()); + private static FeatureExtractor EXTRACTOR = new AvgIDFFeatureExtractor(); @Test - public void testSingleDoc() throws IOException { + public void testSingleDoc() throws IOException, ExecutionException, InterruptedException { float[] expected = {0.2876f}; assertFeatureValues(expected, "document", "test document", EXTRACTOR); } @Test - public void testSingleDocMissingToken() throws IOException { + public void testSingleDocMissingToken() throws IOException, ExecutionException, InterruptedException { float[] expected = {0.836985f}; assertFeatureValues(expected, "document test", "document missing token", EXTRACTOR); } @Test - public void testMultipleDocMultipleTokens() throws IOException { + public void testMultipleDocMultipleTokens() throws IOException, ExecutionException, InterruptedException { // N = 7 // N_document = 4 0.57536 // N_token = 0 2.77258 diff --git a/src/test/java/io/anserini/ltr/BM25FeatureExtractorTest.java b/src/test/java/io/anserini/ltr/BM25FeatureExtractorTest.java index 51da345bae..b69d8f6902 100644 --- a/src/test/java/io/anserini/ltr/BM25FeatureExtractorTest.java +++ b/src/test/java/io/anserini/ltr/BM25FeatureExtractorTest.java @@ -22,6 +22,8 @@ import java.io.IOException; import java.util.Arrays; +import java.util.List; +import java.util.concurrent.ExecutionException; /** * Tests that BM25 score is computed according to our formula @@ -31,9 +33,12 @@ public class BM25FeatureExtractorTest extends BaseFeatureExtractorTest private static final FeatureExtractor EXTRACTOR = new BM25FeatureExtractor(0.9,0.4); // 1.25,0.75 private static final FeatureExtractor EXTRACTOR2 = new BM25FeatureExtractor(); + private static List EXTRACTORS = getChain(EXTRACTOR, EXTRACTOR2); + + @Test - public void testSingleDocSingleQuery() throws IOException { + public void testSingleDocSingleQuery() throws IOException, ExecutionException, InterruptedException { String docText = "single document test case"; String queryText = "test"; //df, tf =1, avgFL = 4, numDocs = 1 @@ -43,12 +48,12 @@ public void testSingleDocSingleQuery() throws IOException { // 0.287682 * 2.25 / (1 + 1.25 *(0.25 + 0.75)) = 0.287682 float[] expected = {0.287682f,0.287682f}; - assertFeatureValues(expected, queryText, docText, getChain(EXTRACTOR, EXTRACTOR2)); + assertFeatureValues(expected, queryText, docText, EXTRACTORS); } @Test - public void testSingleDocMultiQuery() throws IOException { + public void testSingleDocMultiQuery() throws IOException, ExecutionException, InterruptedException { String docText = "single document test case"; String queryText = "test document"; //df, 
tf =1, avgFL = 4, numDocs = 1 @@ -58,12 +63,12 @@ public void testSingleDocMultiQuery() throws IOException { // 0.287682 * 2.25 / (1 + 1.25 *(0.25 + 0.75)) = 0.287682 float[] expected = {0.575364f,0.575364f}; - assertFeatureValues(expected, queryText, docText, getChain(EXTRACTOR, EXTRACTOR2)); + assertFeatureValues(expected, queryText, docText, EXTRACTORS); } @Test - public void testMultiDocSingleQuery() throws IOException { + public void testMultiDocSingleQuery() throws IOException, ExecutionException, InterruptedException { String queryText = "test"; //df , tf =1, avgFL = 3, numDocs = 3 //idf = log(1 + (3- 1 + 0.5 / 1 + 0.5)) = 0.98082 @@ -73,12 +78,12 @@ public void testMultiDocSingleQuery() throws IOException { float[] expected = {0.92255f,0.8612f}; assertFeatureValues(expected, queryText, Arrays.asList("single document test case", - "another document", "yet another document"), getChain(EXTRACTOR, EXTRACTOR2),0); + "another document", "yet another document"), EXTRACTORS,0); } @Test - public void testMultiDocMultiQuery() throws IOException { + public void testMultiDocMultiQuery() throws IOException, ExecutionException, InterruptedException { String queryText = "test document"; //df , tf =1, avgFL = 3, numDocs = 3 //idf = log(1 + (3- 1 + 0.5 / 1 + 0.5)) = 0.98082 @@ -92,11 +97,11 @@ public void testMultiDocMultiQuery() throws IOException { float[] expected = {1.04814f,0.97844f}; assertFeatureValues(expected, queryText, Arrays.asList("single document test case", - "another document", "yet another document"), getChain(EXTRACTOR, EXTRACTOR2),0); + "another document", "yet another document"), EXTRACTORS,0); } @Test - public void testMultiDocMultiQuery2() throws IOException { + public void testMultiDocMultiQuery2() throws IOException, ExecutionException, InterruptedException { String queryText = "test document"; //df , tf =1, avgFL = 3, numDocs = 3 //idf = log(1 + (3- 1 + 0.5 / 1 + 0.5)) = 0.98082 @@ -110,7 +115,7 @@ public void testMultiDocMultiQuery2() throws IOException { float[] expected = {1.30555f,1.2435f}; assertFeatureValues(expected, queryText, Arrays.asList("single document test case test", - "another document", "more document"), getChain(EXTRACTOR, EXTRACTOR2),0); + "another document", "more document"), EXTRACTORS,0); } diff --git a/src/test/java/io/anserini/ltr/BaseFeatureExtractorTest.java b/src/test/java/io/anserini/ltr/BaseFeatureExtractorTest.java index 74c5e1b038..35711cc247 100644 --- a/src/test/java/io/anserini/ltr/BaseFeatureExtractorTest.java +++ b/src/test/java/io/anserini/ltr/BaseFeatureExtractorTest.java @@ -19,34 +19,22 @@ import io.anserini.analysis.AnalyzerUtils; import io.anserini.index.IndexArgs; import io.anserini.ltr.feature.FeatureExtractor; -import io.anserini.ltr.feature.FeatureExtractors; -import io.anserini.rerank.RerankerContext; import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.en.EnglishAnalyzer; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.FieldType; -import org.apache.lucene.index.DirectoryReader; -import org.apache.lucene.index.IndexOptions; -import org.apache.lucene.index.IndexReader; -import org.apache.lucene.index.IndexWriter; -import org.apache.lucene.index.IndexWriterConfig; -import org.apache.lucene.index.Terms; -import org.apache.lucene.queryparser.classic.ParseException; -import org.apache.lucene.queryparser.classic.QueryParser; -import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.store.ByteBuffersDirectory; +import 
org.apache.lucene.document.StringField; +import org.apache.lucene.index.*; import org.apache.lucene.store.Directory; -import org.apache.lucene.store.MockDirectoryWrapper; +import org.apache.lucene.store.FSDirectory; import org.apache.lucene.util.LuceneTestCase; import org.junit.After; import org.junit.Before; import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Random; +import java.util.*; +import java.util.concurrent.ExecutionException; /** * This class will contain setup and teardown code for testing feature extractors @@ -54,65 +42,29 @@ abstract public class BaseFeatureExtractorTest extends LuceneTestCase { protected static final String TEST_FIELD_NAME = IndexArgs.CONTENTS; protected static final Analyzer TEST_ANALYZER = new EnglishAnalyzer(); - protected static final QueryParser TEST_PARSER = new QueryParser(TEST_FIELD_NAME, TEST_ANALYZER); - protected static final String DEFAULT_QID = "1"; // Acceptable delta for float assert protected static final float DELTA = 0.01f; - protected Directory DIRECTORY; + protected Directory DIRECTORY; protected IndexWriter testWriter; - /** - * A lot of feature extractors are tested individually, easy way to wrap the chain - * @param extractors The extractors - * @return - */ - protected static FeatureExtractors getChain(FeatureExtractor... extractors ) { - FeatureExtractors chain = new FeatureExtractors(); - for (FeatureExtractor extractor : extractors) { - chain.add(extractor); - } - return chain; + protected static List<FeatureExtractor> getChain(FeatureExtractor... extractors) { + return Arrays.asList(extractors); } - protected Document addTestDocument(String testText) throws IOException { + protected void addTestDocument(String testText, String docId) throws IOException { FieldType fieldType = new FieldType(); fieldType.setStored(true); fieldType.setStoreTermVectors(true); - fieldType.setStoreTermVectorOffsets(true); fieldType.setStoreTermVectorPositions(true); - fieldType.setIndexOptions(IndexOptions.DOCS_AND_FREQS_AND_POSITIONS_AND_OFFSETS); + fieldType.setIndexOptions(IndexOptions.DOCS_AND_FREQS_AND_POSITIONS); Field field = new Field(TEST_FIELD_NAME, testText, fieldType); Document doc = new Document(); doc.add(field); + doc.add(new StringField(IndexArgs.ID, docId, Field.Store.YES)); testWriter.addDocument(doc); testWriter.commit(); - return doc; - } - - /** - * The reranker context constructed will return with a searcher - * and the query we want with dummy query ids and null filter - * @return - */ - @SuppressWarnings("unchecked") - protected RerankerContext makeTestContext(String queryText) { - try { - RerankerContext context = new RerankerContext( - new IndexSearcher(DirectoryReader.open(DIRECTORY)), - (T) DEFAULT_QID, - TEST_PARSER.parse(queryText), - null, - queryText, - AnalyzerUtils.analyze(TEST_ANALYZER, queryText), - null, null); - return context; - } catch (ParseException e) { - return null; - } catch (IOException e) { - return null; - } } /** @@ -124,8 +76,7 @@ protected RerankerContext makeTestContext(String queryText) { @Before public void setUp() throws Exception { super.setUp(); - // Use a RAMDirectory instead of MemoryIndex because we might test with multiple documents - DIRECTORY = new MockDirectoryWrapper(new Random(), new ByteBuffersDirectory()); + DIRECTORY = FSDirectory.open(createTempDir()); testWriter = new IndexWriter(DIRECTORY, new IndexWriterConfig(TEST_ANALYZER)); } @@ -147,14 +98,20 @@ public void tearDown() throws Exception { * @param queryText */ protected void assertFeatureValues(float[] expected, String queryText, String docText, - FeatureExtractors extractors) throws IOException { + List<FeatureExtractor> extractors) throws IOException, ExecutionException, InterruptedException { assertFeatureValues(expected, queryText, Arrays.asList(docText), extractors,0); } // just add a signature for single extractor protected void assertFeatureValues(float[] expected, String queryText, String docText, - FeatureExtractor extractor) throws IOException { - assertFeatureValues(expected, queryText, Arrays.asList(docText), getChain(extractor),0); + FeatureExtractor extractor) throws IOException, ExecutionException, InterruptedException { + assertFeatureValues(expected, queryText, Arrays.asList(docText), Arrays.asList(extractor),0); + } + + // just add a signature for a single extractor over multiple documents + protected void assertFeatureValues(float[] expected, String queryText, List<String> docTexts, + FeatureExtractor extractor, int docToExtract) throws IOException, ExecutionException, InterruptedException { + assertFeatureValues(expected, queryText, docTexts, Arrays.asList(extractor),docToExtract); } /** @@ -166,21 +123,33 @@ protected void assertFeatureValues(float[] expected, String queryText, String do * @param docToExtract Index of the document we want to compute features for */ protected void assertFeatureValues(float[] expected, String queryText, List<String> docTexts, - FeatureExtractors extractors, int docToExtract) throws IOException { - List<Document> addedDocs = new ArrayList<>(); + List<FeatureExtractor> extractors, int docToExtract) throws IOException, ExecutionException, InterruptedException { + int id = 0; for (String docText : docTexts) { - Document testDoc = addTestDocument(docText); - addedDocs.add(testDoc); + addTestDocument(docText, String.format("doc%s", id)); + id += 1; } testWriter.forceMerge(1); - Document testDoc = addedDocs.get(docToExtract); - RerankerContext context = makeTestContext(queryText); - IndexReader reader = context.getIndexSearcher().getIndexReader(); - Terms terms = reader.getTermVector(docToExtract, TEST_FIELD_NAME); - float[] extractedFeatureValues = extractors.extractAll(testDoc, terms, context); - - assertArrayEquals(expected, extractedFeatureValues, DELTA); + FeatureExtractorUtils utils = new FeatureExtractorUtils(DirectoryReader.open(DIRECTORY)); + for (FeatureExtractor extractor : extractors) { + utils.add(extractor); + } + String docIdToExtract = String.format("doc%s", docToExtract); + ArrayList<output> extractedFeatureValues = utils.extract(AnalyzerUtils.analyze(TEST_ANALYZER, queryText), Arrays.asList(docIdToExtract)); + List<Float> extractFeatures = null; + for (output doc : extractedFeatureValues) { + if (doc.pid.equals(docIdToExtract) && extractFeatures == null) + extractFeatures = doc.features; + } + float[] extractFeaturesArray = new float[extractFeatures.size()]; + for (int i = 0; i < extractFeaturesArray.length; i++) { + extractFeaturesArray[i] = extractFeatures.get(i).floatValue(); + } + assertArrayEquals(expected, extractFeaturesArray, DELTA); + utils.close(); } }
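Every rewritten test funnels through the harness above, so it is worth seeing the new extraction flow once outside of JUnit. The following is an illustrative sketch only: it assumes the FeatureExtractorUtils constructor, add(), extract(), close(), and the output fields pid/features exactly as they are used in assertFeatureValues above, and that the output type is package-visible; the index path, class name, and extractor choices are placeholders, not part of this patch.

package io.anserini.ltr; // same package, so the package-private `output` type is visible

import io.anserini.analysis.AnalyzerUtils;
import io.anserini.ltr.feature.base.BM25FeatureExtractor;
import org.apache.lucene.analysis.en.EnglishAnalyzer;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.store.FSDirectory;

import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;

public class ExtractionFlowSketch {
  public static void main(String[] args) throws Exception {
    // Open a reader over an existing index; the path is a placeholder.
    FeatureExtractorUtils utils = new FeatureExtractorUtils(
        DirectoryReader.open(FSDirectory.open(Paths.get("path/to/index"))));
    // Register extractors; each contributes one value, in registration order.
    utils.add(new BM25FeatureExtractor(0.9, 0.4));
    utils.add(new BM25FeatureExtractor());
    // Analyze the query with the same analyzer used at index time, then
    // pull feature vectors for the requested document ids.
    List<String> queryTokens = AnalyzerUtils.analyze(new EnglishAnalyzer(), "test document");
    for (output doc : utils.extract(queryTokens, Arrays.asList("doc0"))) {
      System.out.println(doc.pid + " -> " + doc.features);
    }
    utils.close();
  }
}

diff --git a/src/test/java/io/anserini/ltr/BigramFeaturesTest.java b/src/test/java/io/anserini/ltr/BigramFeaturesTest.java index b3e18037d7..092264db5d 100644 --- a/src/test/java/io/anserini/ltr/BigramFeaturesTest.java +++ b/src/test/java/io/anserini/ltr/BigramFeaturesTest.java @@ -16,14 +16,13 @@ package io.anserini.ltr; -import io.anserini.ltr.feature.FeatureExtractors; -import io.anserini.ltr.feature.OrderedQueryPairsFeatureExtractor; -import io.anserini.ltr.feature.OrderedSequentialPairsFeatureExtractor; -import 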
io.anserini.ltr.feature.UnorderedQueryPairsFeatureExtractor; -import io.anserini.ltr.feature.UnorderedSequentialPairsFeatureExtractor; +import io.anserini.ltr.feature.*; import org.junit.Test; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutionException; /** * Create some temporary documents and test the correctness of ordered and unordered @@ -31,51 +30,62 @@ */ public class BigramFeaturesTest extends BaseFeatureExtractorTest { - private FeatureExtractors getUnorderedChain() { - FeatureExtractors chain = new FeatureExtractors(); - chain.add(new UnorderedSequentialPairsFeatureExtractor(2)); - chain.add(new UnorderedSequentialPairsFeatureExtractor(4)); - chain.add(new UnorderedSequentialPairsFeatureExtractor(6)); - return chain; + private List<FeatureExtractor> getUnorderedChain() { + return getChain( + new UnorderedSequentialPairsFeatureExtractor(2), + new UnorderedSequentialPairsFeatureExtractor(4), + new UnorderedSequentialPairsFeatureExtractor(6) + ); } - private FeatureExtractors getOrderedChain() { - FeatureExtractors chain = new FeatureExtractors(); - chain.add(new OrderedSequentialPairsFeatureExtractor(2)); - chain.add(new OrderedSequentialPairsFeatureExtractor(4)); - chain.add(new OrderedSequentialPairsFeatureExtractor(6)); - return chain; + private List<FeatureExtractor> getOrderedChain() { + return getChain( + new OrderedSequentialPairsFeatureExtractor(2), + new OrderedSequentialPairsFeatureExtractor(4), + new OrderedSequentialPairsFeatureExtractor(6) + ); } - private FeatureExtractors getMixedChain() { - FeatureExtractors chain = new FeatureExtractors(); - chain.add(new OrderedSequentialPairsFeatureExtractor(2)); - chain.add(new OrderedSequentialPairsFeatureExtractor(4)); - chain.add(new OrderedSequentialPairsFeatureExtractor(6)); - chain.add(new UnorderedSequentialPairsFeatureExtractor(2)); - chain.add(new UnorderedSequentialPairsFeatureExtractor(4)); - chain.add(new UnorderedSequentialPairsFeatureExtractor(6)); - return chain; + private List<FeatureExtractor> getMixedChain() { + return getChain( + new OrderedSequentialPairsFeatureExtractor(2), + new OrderedSequentialPairsFeatureExtractor(4), + new OrderedSequentialPairsFeatureExtractor(6), + new UnorderedSequentialPairsFeatureExtractor(2), + new UnorderedSequentialPairsFeatureExtractor(4), + new UnorderedSequentialPairsFeatureExtractor(6) + ); } - private FeatureExtractors getAllPairsOrdered() { - FeatureExtractors chain = new FeatureExtractors(); - chain.add(new OrderedQueryPairsFeatureExtractor(2)); - chain.add(new OrderedQueryPairsFeatureExtractor(4)); - chain.add(new OrderedQueryPairsFeatureExtractor(6)); - return chain; + private List<FeatureExtractor> getAllPairsOrdered() { + return getChain( + new OrderedQueryPairsFeatureExtractor(2), + new OrderedQueryPairsFeatureExtractor(4), + new OrderedQueryPairsFeatureExtractor(6) + ); } - private FeatureExtractors getAllPairsUnOrdered() { - FeatureExtractors chain = new FeatureExtractors(); - chain.add(new UnorderedQueryPairsFeatureExtractor(2)); - chain.add(new UnorderedQueryPairsFeatureExtractor(4)); - chain.add(new UnorderedQueryPairsFeatureExtractor(6)); - return chain; + private List<FeatureExtractor> getAllPairsUnOrdered() { + return getChain( + new UnorderedQueryPairsFeatureExtractor(2), + new UnorderedQueryPairsFeatureExtractor(4), + new UnorderedQueryPairsFeatureExtractor(6) + ); } + private List<FeatureExtractor> getMixedSequentialAllPairs() { + return getChain( + new OrderedSequentialPairsFeatureExtractor(2), + new UnorderedSequentialPairsFeatureExtractor(2), + new OrderedQueryPairsFeatureExtractor(2), + new UnorderedQueryPairsFeatureExtractor(2) + ); + } + + private static FeatureExtractor bigram = new OrderedSequentialPairsFeatureExtractor(1); + @Test - public void testSimpleQuery () throws IOException { + public void testSimpleQuery () throws IOException, ExecutionException, InterruptedException { String testText = "a simple document"; String testQuery = "simple document"; float[] expected = {1,1,1}; @@ -84,7 +94,7 @@ public void testSimpleQuery () throws IOException { } @Test - public void testMultipleUnorderedQuery() throws IOException { + public void testMultipleUnorderedQuery() throws IOException, ExecutionException, InterruptedException { String testText = "document more token simple test case"; String testQuery = "simple document test case"; @@ -99,7 +109,7 @@ public void testMultipleUnorderedQuery() throws IOException { } @Test - public void testMixedMultipleQuery() throws IOException { + public void testMixedMultipleQuery() throws IOException, ExecutionException, InterruptedException { String testText = "bunch words document test simple case document test case simple, test document"; String testQuery = "document test"; @@ -114,7 +124,7 @@ public void testMixedMultipleQuery() throws IOException { } @Test - public void testSimpleCountOrderedAllPairs() throws IOException { + public void testSimpleCountOrderedAllPairs() throws IOException, ExecutionException, InterruptedException { String testText = "bunch words document test simple case large text length size"; String testQuery = "bunch words test"; @@ -124,7 +134,7 @@ public void testSimpleCountOrderedAllPairs() throws IOException { } @Test - public void testSimpleCountUnorderedAllPairs() throws IOException { + public void testSimpleCountUnorderedAllPairs() throws IOException, ExecutionException, InterruptedException { String testText = "bunch words document test simple case large text length size"; String testQuery = "test document text"; @@ -134,7 +144,7 @@ public void testSimpleCountUnorderedAllPairs() throws IOException { } @Test - public void testDuplicateStartingTokens() throws IOException { + public void testDuplicateStartingTokens() throws IOException, ExecutionException, InterruptedException { String testText = "document test document bunch"; String testQuery = "document test document bunch"; @@ -148,7 +158,7 @@ public void testDuplicateStartingTokens() throws IOException { } @Test - public void testDuplicateAllPairs() throws IOException { + public void testDuplicateAllPairs() throws IOException, ExecutionException, InterruptedException { String testText = "document case document test bunch"; String testQuery = "document case test"; @@ -163,45 +173,38 @@ public void testDuplicateAllPairs() throws IOException { } @Test - public void testMixedSequentialAllPairs() throws IOException { + public void testMixedSequentialAllPairs() throws IOException, ExecutionException, InterruptedException { String testText = "document test word word word test case word document"; String testQuery = "document test case"; // pairs counted: document test, test case, document case float[] expected = {2,2,2,3}; - assertFeatureValues(expected, testQuery, testText, - getChain(new OrderedSequentialPairsFeatureExtractor(2), - new UnorderedSequentialPairsFeatureExtractor(2), - new OrderedQueryPairsFeatureExtractor(2), - new UnorderedQueryPairsFeatureExtractor(2))); + assertFeatureValues(expected, testQuery, testText, getMixedSequentialAllPairs()); } @Test - public void testSimpleBigramCount() throws IOException { + public void testSimpleBigramCount() throws 
IOException, ExecutionException, InterruptedException { String testText = "document test document test"; String testQuery = "missing phrase"; float[] expected = {0.0f}; - assertFeatureValues(expected, testQuery, testText, - getChain(new OrderedSequentialPairsFeatureExtractor(1))); + assertFeatureValues(expected, testQuery, testText, bigram); } @Test - public void testSimpleBigramCount2() throws IOException { + public void testSimpleBigramCount2() throws IOException, ExecutionException, InterruptedException { String testText = "document test document test"; String testQuery = "document tests"; float[] expected = {2f}; - assertFeatureValues(expected, testQuery, testText, - getChain(new OrderedSequentialPairsFeatureExtractor(1))); + assertFeatureValues(expected, testQuery, testText, bigram); } @Test - public void testBigramCountMultiple() throws IOException { + public void testBigramCountMultiple() throws IOException, ExecutionException, InterruptedException { String testText = "test document test document multiple tokens multiple phrase"; String testQuery = "test document multiple"; //test document x 2 + document multiple float[] expected = {3f}; - assertFeatureValues(expected, testQuery, testText, - getChain(new OrderedSequentialPairsFeatureExtractor(1))); + assertFeatureValues(expected, testQuery, testText, bigram); } } diff --git a/src/test/java/io/anserini/ltr/DocSizeFeatureExtractorTest.java b/src/test/java/io/anserini/ltr/DocSizeFeatureExtractorTest.java index ac4a841a7b..a79ecbea0f 100644 --- a/src/test/java/io/anserini/ltr/DocSizeFeatureExtractorTest.java +++ b/src/test/java/io/anserini/ltr/DocSizeFeatureExtractorTest.java @@ -16,30 +16,34 @@ package io.anserini.ltr; +import io.anserini.ltr.feature.FeatureExtractor; import io.anserini.ltr.feature.base.DocSizeFeatureExtractor; import org.junit.Test; import java.io.IOException; import java.util.Arrays; +import java.util.concurrent.ExecutionException; /** * Test we get the doc size correctly */ public class DocSizeFeatureExtractorTest extends BaseFeatureExtractorTest { + private static FeatureExtractor EXTRACTOR = new DocSizeFeatureExtractor(); + @Test - public void testSingleDoc() throws IOException { + public void testSingleDoc() throws IOException, ExecutionException, InterruptedException { float[] expected = {5}; assertFeatureValues(expected, "query text can't be empty", "document size independent of query document", - new DocSizeFeatureExtractor()); + EXTRACTOR); } @Test - public void testMultipleDocs() throws IOException { + public void testMultipleDocs() throws IOException, ExecutionException, InterruptedException { float[] expected = {5}; assertFeatureValues(expected, "query text", Arrays.asList("first document", "second document", "test document document document test"), - getChain(new DocSizeFeatureExtractor()), 2); + EXTRACTOR, 2); } } diff --git a/src/test/java/io/anserini/ltr/FeatureExtractionArgsTest.java b/src/test/java/io/anserini/ltr/FeatureExtractionArgsTest.java deleted file mode 100644 index 92e513e98b..0000000000 --- a/src/test/java/io/anserini/ltr/FeatureExtractionArgsTest.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Anserini: A Lucene toolkit for replicable information retrieval research - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.anserini.ltr; - -import io.anserini.ltr.FeatureExtractorCli.FeatureExtractionArgs; -import io.anserini.search.topicreader.MicroblogTopicReader; -import io.anserini.search.topicreader.TopicReader; -import io.anserini.search.topicreader.TrecTopicReader; -import io.anserini.search.topicreader.WebxmlTopicReader; -import org.junit.Assert; -import org.junit.Test; -import org.kohsuke.args4j.CmdLineException; -import org.kohsuke.args4j.CmdLineParser; -import org.kohsuke.args4j.ParserProperties; - -/** - * Tests that Arguments for the {@link FeatureExtractorCli} can be parsed. - */ -public class FeatureExtractionArgsTest { - - @Test - public void checkThatTopicReaderForCluewebCollectionCanBeCreated() throws Exception { - FeatureExtractionArgs args = createFeatureExtractionArgsWithCollection("clueweb"); - TopicReader topicReaderForCollection = args.buildTopicReaderForCollection(); - - Assert.assertEquals(WebxmlTopicReader.class, topicReaderForCollection.getClass()); - } - - @Test - public void checkThatTopicReaderForGov2CollectionCanBeCreated() throws Exception { - FeatureExtractionArgs args = createFeatureExtractionArgsWithCollection("gov2"); - TopicReader topicReaderForCollection = args.buildTopicReaderForCollection(); - - Assert.assertEquals(TrecTopicReader.class, topicReaderForCollection.getClass()); - } - - @Test - public void checkThatTopicReaderForTwitterCollectionCanBeCreated() throws Exception { - FeatureExtractionArgs args = createFeatureExtractionArgsWithCollection("twitter"); - TopicReader topicReaderForCollection = args.buildTopicReaderForCollection(); - - Assert.assertEquals(MicroblogTopicReader.class, topicReaderForCollection.getClass()); - } - - private static FeatureExtractionArgs createFeatureExtractionArgsWithCollection(String collection) throws CmdLineException { - String[] args = createProgramArgsWithCollection(collection); - return parseFeatureExtractionArgs(args); - } - - private static String[] createProgramArgsWithCollection(String collection) { - return new String[] { "-index", "example-index-arg", "-qrel", "example-qrel-arg", "-topic", "example-topic-arg", - "-out", "example-out-arg", "-collection", collection }; - } - - private static FeatureExtractionArgs parseFeatureExtractionArgs(String[] args) throws CmdLineException { - FeatureExtractionArgs parsedArgs = new FeatureExtractionArgs(); - CmdLineParser parser = new CmdLineParser(parsedArgs, ParserProperties.defaults().withUsageWidth(90)); - - parser.parseArgument(args); - - return parsedArgs; - } -} diff --git a/src/test/java/io/anserini/ltr/FeatureExtractorChainFromJsonTest.java b/src/test/java/io/anserini/ltr/FeatureExtractorChainFromJsonTest.java deleted file mode 100644 index e7ee664fc2..0000000000 --- a/src/test/java/io/anserini/ltr/FeatureExtractorChainFromJsonTest.java +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Anserini: A Lucene toolkit for replicable information retrieval research - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.anserini.ltr; - -import com.fasterxml.jackson.core.JsonFactory; -import com.fasterxml.jackson.core.JsonParser; -import io.anserini.ltr.feature.FeatureExtractors; -import org.junit.Test; - -/** - * Will test that constructing a feature extractor chain works correctly - */ -public class FeatureExtractorChainFromJsonTest extends BaseFeatureExtractorTest{ - - private JsonFactory jsonFactory = new JsonFactory(); - - @Test - public void testEmptyChain() throws Exception { - String jsonString = "{extractors: []}"; - JsonParser jsonParser = jsonFactory.createParser(jsonString); - - FeatureExtractors emptyChain = FeatureExtractors.fromJson(jsonParser); - assertNotNull(emptyChain); - } - - @Test - public void testChainSingleExtractorNoParam() throws Exception { - String jsonString = "{extractors: [ {name: \"AvgSCQ\"} ]}"; - JsonParser jsonParser = jsonFactory.createParser(jsonString); - String testText = "test document"; - String testQuery = "document"; - //idf = 0.28768 - //tf =1 - float [] expected = {-0.24590f}; - FeatureExtractors chain = FeatureExtractors.fromJson(jsonParser); - assertFeatureValues(expected, testQuery, testText, chain); - } - - @Test - public void testChainSingleExtractorParam() throws Exception { - String jsonString = "{extractors: [ {name: \"BM25Feature\", params: {k1:0.9, b:0.4}} ]}"; - JsonParser jsonParser = jsonFactory.createParser(jsonString); - String docText = "single document test case"; - String queryText = "test"; - //df, tf =1, avgFL = 4, numDocs = 1 - //idf = log(1 + (0.5 / 1 + 0.5)) = 0.287682 - - // 0.287682* 1.9 / (1 + 0.9 * (0.6 + 0.4 * (4/4))) = 1 * 0.287682 - // 0.287682 * 2.25 / (1 + 1.25 *(0.25 + 0.75)) = 0.287682 - float[] expected = {0.287682f}; - FeatureExtractors chain = FeatureExtractors.fromJson(jsonParser); - assertFeatureValues(expected, queryText, docText, chain); - } - - @Test - public void testMultipleExtractorNoParam() throws Exception { - String jsonString = "{extractors: [ {name: \"AvgIDF\"}, {name: \"SumTermFrequency\"} ]}"; - JsonParser jsonParser = jsonFactory.createParser(jsonString); - String docText = "document missing token"; - String queryText = "document test"; - float[] expected = {0.836985f, 1f}; - - FeatureExtractors chain = FeatureExtractors.fromJson(jsonParser); - - assertFeatureValues(expected, queryText, docText, chain); - } - - @Test - public void testMultipleExtractorMixed() throws Exception { - String jsonString = "{extractors: [ {name: \"DocSize\"}, {name: \"QueryLength\"}," + - "{name: \"OrderedSequentialPairs\", params:{gapSize: 2}}, {name: \"UnorderedSequentialPairs\", params:{gapSize : 2}}" + - ", {name: \"OrderedSequentialPairs\", params:{gapSize: 5}} ]}"; - JsonParser jsonParser = jsonFactory.createParser(jsonString); - String testText = "document test word word word test bunch word document"; - String testQuery = "document test bunch"; - FeatureExtractors chain = FeatureExtractors.fromJson(jsonParser); - - // document test, test bunch, bunch document - float[] expected = {9f, 3f, 2f, 2f, 4f}; - assertFeatureValues(expected, testQuery, testText,chain); - - - } -} diff --git 
a/src/test/java/io/anserini/ltr/LoadFeatureExtractorFromFileTest.java b/src/test/java/io/anserini/ltr/LoadFeatureExtractorFromFileTest.java deleted file mode 100644 index c7926ab5b6..0000000000 --- a/src/test/java/io/anserini/ltr/LoadFeatureExtractorFromFileTest.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Anserini: A Lucene toolkit for replicable information retrieval research - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.anserini.ltr; - -import io.anserini.ltr.feature.FeatureExtractors; -import org.junit.Test; - -/** - * Test loading feature extractors from files - */ -public class LoadFeatureExtractorFromFileTest extends BaseFeatureExtractorTest{ - - @Test - public void testMultipleExtractorNoParam() throws Exception { - String jsonFile = "src/test/resources/MixedFeatureExtractor.txt"; - String docText = "document missing token"; - String queryText = "document test"; - float[] expected = {0.836985f, 1f}; - - FeatureExtractors chain = FeatureExtractors.loadExtractor(jsonFile); - - assertFeatureValues(expected, queryText, docText, chain); - } -} diff --git a/src/test/java/io/anserini/ltr/MatchingTermCountTest.java b/src/test/java/io/anserini/ltr/MatchingTermCountTest.java new file mode 100644 index 0000000000..bdffa3f504 --- /dev/null +++ b/src/test/java/io/anserini/ltr/MatchingTermCountTest.java @@ -0,0 +1,60 @@ +package io.anserini.ltr; + +import io.anserini.ltr.feature.FeatureExtractor; +import io.anserini.ltr.feature.base.MatchingTermCount; +import org.junit.Test; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.ExecutionException; + +public class MatchingTermCountTest extends BaseFeatureExtractorTest { + + private FeatureExtractor EXTRACTOR = new MatchingTermCount(); + + @Test + public void testAllMissing() throws IOException, ExecutionException, InterruptedException { + float[] expected = {0}; + assertFeatureValues(expected, "nothing", "document test missing all", EXTRACTOR); + } + + @Test + public void testSingleTermDoc() throws IOException, ExecutionException, InterruptedException { + String testText = "document document document another"; + String testQuery = "document"; + float[] expected = {1}; + + assertFeatureValues(expected, testQuery, testText, EXTRACTOR); + } + + @Test + public void testMissingTermDoc() throws IOException, ExecutionException, InterruptedException { + String testText = "document test simple tokens"; + String testQuery = "simple missing"; + float[] expected = {1}; + + assertFeatureValues(expected, testQuery, testText, EXTRACTOR); + } + + @Test + public void testMultipleTermsDoc() throws IOException, ExecutionException, InterruptedException { + String testText = "document with multiple document term document multiple some missing"; + String testQuery = "document multiple missing"; + float[] expected = {3}; + + assertFeatureValues(expected, testQuery, testText, EXTRACTOR); + } + + @Test + public void testMatchingTermCountWithMultipleDocs() throws IOException, ExecutionException, InterruptedException { + List<String> docs = Arrays.asList("document document", "document with multiple terms", + "document to test", "test terms tokens", "another test document"); + // We want to test that the expected value of count 1 is found for document + // at index 2 + String queryText = "document"; + float[] expected = {1}; + + assertFeatureValues(expected, queryText, docs, EXTRACTOR, 2); + } +} diff --git a/src/test/java/io/anserini/ltr/PMIFeatureExtractorTest.java b/src/test/java/io/anserini/ltr/PMIFeatureExtractorTest.java index f16a8bf366..cc92cd4c73 100644 --- a/src/test/java/io/anserini/ltr/PMIFeatureExtractorTest.java +++ b/src/test/java/io/anserini/ltr/PMIFeatureExtractorTest.java @@ -16,22 +16,24 @@ package io.anserini.ltr; -import io.anserini.ltr.feature.FeatureExtractors; +import io.anserini.ltr.feature.FeatureExtractor; import io.anserini.ltr.feature.base.PMIFeatureExtractor; import org.junit.Test; import java.io.IOException; import java.util.Arrays; +import java.util.List; +import java.util.concurrent.ExecutionException; /** * Test implementation of PMI */ public class PMIFeatureExtractorTest extends BaseFeatureExtractorTest { - private static FeatureExtractors EXTRACTOR = getChain(new PMIFeatureExtractor()); + private static FeatureExtractor EXTRACTOR = new PMIFeatureExtractor(); @Test - public void testSingleDocSimpleQuery() throws IOException { + public void testSingleDocSimpleQuery() throws IOException, ExecutionException, InterruptedException { String testText = "test document multiple tokens"; String testQuery = "test document"; float[] expected = {0f}; @@ -40,7 +42,7 @@ public void testSingleDocSimpleQuery() throws IOException { } @Test - public void testMultipleDocSimpleQuery() throws IOException { + public void testMultipleDocSimpleQuery() throws IOException, ExecutionException, InterruptedException { float[] expected = {-1.43916f}; String testQuery = "test document token"; // 3 query pairs: test document, document token, test token @@ -62,7 +64,7 @@ public void testMultipleDocSimpleQuery() throws IOException { } @Test - public void testBadQueries() throws IOException { + public void testBadQueries() throws IOException, ExecutionException, InterruptedException { float[] expected = {0.0f}; String testQuery = "missing tokens"; assertFeatureValues(expected, testQuery, @@ -74,7 +76,7 @@ public void testBadQueries() throws IOException { } @Test - public void testNoIntersect() throws IOException { + public void testNoIntersect() throws IOException, ExecutionException, InterruptedException { float[] expected = {0.0f}; String testQuery = "test document"; assertFeatureValues(expected, testQuery, diff --git a/src/test/java/io/anserini/ltr/QueryLengthTest.java b/src/test/java/io/anserini/ltr/QueryLengthTest.java new file mode 100644 index 0000000000..5fa1192673 --- /dev/null +++ b/src/test/java/io/anserini/ltr/QueryLengthTest.java @@ -0,0 +1,27 @@ +package io.anserini.ltr; + +import io.anserini.ltr.feature.FeatureExtractor; +import io.anserini.ltr.feature.base.QueryLength; +import org.junit.Test; + +import java.io.IOException; +import java.util.Arrays; +import java.util.concurrent.ExecutionException; + +public class QueryLengthTest extends BaseFeatureExtractorTest { + private static FeatureExtractor EXTRACTOR = new QueryLength(); + + @Test + public void testSingleDoc() throws IOException, ExecutionException, InterruptedException { + float[] expected = {3}; + assertFeatureValues(expected, "simple test query", "document size independent of query document", + 
EXTRACTOR); + } + + @Test + public void testMultipleDocs() throws IOException, ExecutionException, InterruptedException { + float[] expected = {2}; + assertFeatureValues(expected, "just test", Arrays.asList("first document", + "second document", "test document document document test"), EXTRACTOR, 2); + } +} diff --git a/src/test/java/io/anserini/ltr/SCQFeatureExtractorTest.java b/src/test/java/io/anserini/ltr/SCQFeatureExtractorTest.java index 418811b6bd..6f9644a65f 100644 --- a/src/test/java/io/anserini/ltr/SCQFeatureExtractorTest.java +++ b/src/test/java/io/anserini/ltr/SCQFeatureExtractorTest.java @@ -15,19 +15,21 @@ */ package io.anserini.ltr; -import io.anserini.ltr.feature.FeatureExtractors; +import io.anserini.ltr.feature.FeatureExtractor; import io.anserini.ltr.feature.base.SCQFeatureExtractor; import org.junit.Test; import java.io.IOException; import java.util.Arrays; +import java.util.List; +import java.util.concurrent.ExecutionException; public class SCQFeatureExtractorTest extends BaseFeatureExtractorTest { - private static FeatureExtractors EXTRACTOR = getChain(new SCQFeatureExtractor()); + private static FeatureExtractor EXTRACTOR = new SCQFeatureExtractor(); @Test - public void testSimpleSingleDocument() throws IOException { + public void testSimpleSingleDocument() throws IOException, ExecutionException, InterruptedException { String testText = "test document"; String testQuery = "document"; //idf = 0.28768 @@ -37,7 +39,7 @@ public void testSimpleSingleDocument() throws IOException { } @Test - public void testSingleDocumentMultipleQueryToken() throws IOException { + public void testSingleDocumentMultipleQueryToken() throws IOException, ExecutionException, InterruptedException { String testText = "test document more tokens than just two document "; String testQuery = "document missing"; @@ -46,7 +48,7 @@ public void testSingleDocumentMultipleQueryToken() throws IOException { } @Test - public void testSimpleMultiDocument() throws IOException { + public void testSimpleMultiDocument() throws IOException, ExecutionException, InterruptedException { String testQuery = "test document"; // idf = 0.47 // tf = 3 diff --git a/src/test/java/io/anserini/ltr/SCSFeatureExtractorTest.java b/src/test/java/io/anserini/ltr/SCSFeatureExtractorTest.java index ac03b5f99f..661fdb815c 100644 --- a/src/test/java/io/anserini/ltr/SCSFeatureExtractorTest.java +++ b/src/test/java/io/anserini/ltr/SCSFeatureExtractorTest.java @@ -16,21 +16,23 @@ package io.anserini.ltr; -import io.anserini.ltr.feature.FeatureExtractors; -import io.anserini.ltr.feature.base.SimplifiedClarityFeatureExtractor; +import io.anserini.ltr.feature.FeatureExtractor; +import io.anserini.ltr.feature.base.SCSFeatureExtractor; import org.junit.Test; import java.io.IOException; import java.util.Arrays; +import java.util.concurrent.ExecutionException; /** * Tests the simplified clarity feature */ public class SCSFeatureExtractorTest extends BaseFeatureExtractorTest { - private FeatureExtractors EXTRACTOR = getChain(new SimplifiedClarityFeatureExtractor()); + + private FeatureExtractor EXTRACTOR = new SCSFeatureExtractor(); @Test - public void testBadQuery() throws IOException { + public void testBadQuery() throws IOException, ExecutionException, InterruptedException { String testQuery = "test"; // P[t|q] = 1 // P[t|D] = 0 @@ -41,7 +43,7 @@ public void testBadQuery() throws IOException { } @Test - public void testSimpleQuery() throws IOException { + public void testSimpleQuery() throws IOException, ExecutionException, InterruptedException 
{ String testQuery = "test"; // P[t|q] = 1 @@ -59,7 +61,7 @@ public void testSimpleQuery() throws IOException { } @Test - public void testMultipleTokensQuery() throws IOException { + public void testMultipleTokensQuery() throws IOException, ExecutionException, InterruptedException { String testQuery = "test document"; // P[t|q] = 1/2 diff --git a/src/test/java/io/anserini/ltr/TermFrequencyFeatureExtractorTest.java b/src/test/java/io/anserini/ltr/SumMatchingTFTest.java similarity index 63% rename from src/test/java/io/anserini/ltr/TermFrequencyFeatureExtractorTest.java rename to src/test/java/io/anserini/ltr/SumMatchingTFTest.java index e651b1366e..7ec86e86d6 100644 --- a/src/test/java/io/anserini/ltr/TermFrequencyFeatureExtractorTest.java +++ b/src/test/java/io/anserini/ltr/SumMatchingTFTest.java @@ -16,60 +16,58 @@ package io.anserini.ltr; -import io.anserini.ltr.feature.FeatureExtractors; -import io.anserini.ltr.feature.base.TermFrequencyFeatureExtractor; +import io.anserini.ltr.feature.FeatureExtractor; +import io.anserini.ltr.feature.base.SumMatchingTF; import org.junit.Test; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.concurrent.ExecutionException; /** * Test the term frequency feature extractor is correct */ -public class TermFrequencyFeatureExtractorTest extends BaseFeatureExtractorTest { +public class SumMatchingTFTest extends BaseFeatureExtractorTest { - private FeatureExtractors getChain() { - FeatureExtractors chain = new FeatureExtractors(); - chain.add(new TermFrequencyFeatureExtractor()); - return chain; - } + private FeatureExtractor EXTRACTOR = new SumMatchingTF(); @Test - public void testAllMissing() throws IOException { + public void testAllMissing() throws IOException, ExecutionException, InterruptedException { float[] expected = {0}; - assertFeatureValues(expected, "nothing", "document test missing all", getChain()); + assertFeatureValues(expected, "nothing", "document test missing all", EXTRACTOR); } @Test - public void testSingleTermDoc() throws IOException { + public void testSingleTermDoc() throws IOException, ExecutionException, InterruptedException { String testText = "document document document another"; String testQuery = "document"; float[] expected = {3}; - assertFeatureValues(expected, testQuery, testText, getChain()); + assertFeatureValues(expected, testQuery, testText, EXTRACTOR); } @Test - public void testMissingTermDoc() throws IOException { + public void testMissingTermDoc() throws IOException, ExecutionException, InterruptedException { String testText = "document test simple tokens"; String testQuery = "simple missing"; float[] expected = {1}; - assertFeatureValues(expected, testQuery, testText, getChain()); + assertFeatureValues(expected, testQuery, testText, EXTRACTOR); } @Test - public void testMultipleTermsDoc() throws IOException { + public void testMultipleTermsDoc() throws IOException, ExecutionException, InterruptedException { String testText = "document with multiple document term document multiple some missing"; String testQuery = "document multiple missing"; float[] expected = {6}; - assertFeatureValues(expected, testQuery, testText, getChain()); + assertFeatureValues(expected, testQuery, testText, EXTRACTOR); } @Test - public void testTermFrequencyWithMultipleDocs() throws IOException { + public void testTermFrequencyWithMultipleDocs() throws IOException, ExecutionException, InterruptedException { List<String> docs = Arrays.asList("document document", "document with multiple terms", "document to test", "test terms tokens", "another test document"); // We want to test that the expected value of count 1 is found for document @@ -77,6 +75,6 @@ public void testTermFrequencyWithMultipleDocs() throws IOException { String queryText = "document"; float[] expected = {1}; - assertFeatureValues(expected, queryText, docs, getChain(), 2); + assertFeatureValues(expected, queryText, docs, EXTRACTOR, 2); } } diff --git a/src/test/java/io/anserini/ltr/TFIDFFeatureExtractorTest.java b/src/test/java/io/anserini/ltr/TFIDFFeatureExtractorTest.java index 6db4aec3f6..c61a047991 100644 --- a/src/test/java/io/anserini/ltr/TFIDFFeatureExtractorTest.java +++ b/src/test/java/io/anserini/ltr/TFIDFFeatureExtractorTest.java @@ -16,11 +16,13 @@ package io.anserini.ltr; +import io.anserini.ltr.feature.FeatureExtractor; import io.anserini.ltr.feature.base.TFIDFFeatureExtractor; import org.junit.Test; import java.io.IOException; import java.util.Arrays; +import java.util.concurrent.ExecutionException; /** * Make sure that TFIDF feature extractor gives the scores as calculated by the formula @@ -28,54 +30,54 @@ */ public class TFIDFFeatureExtractorTest extends BaseFeatureExtractorTest { + private FeatureExtractor EXTRACTOR = new TFIDFFeatureExtractor(); + @Test - public void testTFIDFOnSingleDocSingleQuery() throws IOException { + public void testTFIDFOnSingleDocSingleQuery() throws IOException, ExecutionException, InterruptedException { float[] expected = {1f}; - assertFeatureValues(expected, "document", "single document test case", - new TFIDFFeatureExtractor() ); + assertFeatureValues(expected, "document", "single document test case", EXTRACTOR); } @Test - public void testTFIDFOnSingleDocMultiQuery() throws IOException { + public void testTFIDFOnSingleDocMultiQuery() throws IOException, ExecutionException, InterruptedException { float[] expected = {2f}; - assertFeatureValues(expected, "document test", "single document test case", - new TFIDFFeatureExtractor() ); + assertFeatureValues(expected, "document test", "single document test case", EXTRACTOR); } @Test - public void testTFIDFOnMultiDocSingleQuery() throws IOException { + public void testTFIDFOnMultiDocSingleQuery() throws IOException, ExecutionException, InterruptedException { String queryText = "document"; float[] expected = {1f}; assertFeatureValues(expected, queryText, Arrays.asList("single document test case", - "another document test"),getChain(new TFIDFFeatureExtractor()), 0 ); + "another document test"), EXTRACTOR, 0 ); } @Test - public void testTFIDFOnMultiDocMultiQuery() throws IOException { + public void testTFIDFOnMultiDocMultiQuery() throws IOException, ExecutionException, InterruptedException { String queryText = "document test"; float[] expected = {2f}; assertFeatureValues(expected, queryText, Arrays.asList("single document test case", - "another document test"),getChain(new TFIDFFeatureExtractor()), 0 ); + "another document test"), EXTRACTOR, 0 ); } @Test - public void testTFIDFOnMultiDocMultiQuery2() throws IOException { + public void testTFIDFOnMultiDocMultiQuery2() throws IOException, ExecutionException, InterruptedException { String queryText = "document test"; float[] expected = {2.9753323f}; assertFeatureValues(expected, queryText, Arrays.asList("single document test case", - "another document"),getChain(new TFIDFFeatureExtractor()), 0 ); + "another document"), EXTRACTOR, 0 ); } @Test - public void testTFIDFOnMultiDocMultiQuery3() throws 
IOException, ExecutionException, InterruptedException { String queryText = "document test"; float[] expected = {3.8667474f}; assertFeatureValues(expected, queryText, Arrays.asList("single document test case", - "new document", "another document"),getChain(new TFIDFFeatureExtractor()), 0 ); + "new document", "another document"), EXTRACTOR, 0 ); } } diff --git a/src/test/java/io/anserini/ltr/UnigramFeaturesTest.java b/src/test/java/io/anserini/ltr/UnigramFeaturesTest.java deleted file mode 100644 index ecb0690dde..0000000000 --- a/src/test/java/io/anserini/ltr/UnigramFeaturesTest.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Anserini: A Lucene toolkit for replicable information retrieval research - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.anserini.ltr; - -import io.anserini.ltr.feature.UnigramFeatureExtractor; -import org.junit.Test; - -import java.io.IOException; - -/** - * Tests the unigram count feature - */ -public class UnigramFeaturesTest extends BaseFeatureExtractorTest { - - @Test - public void testSingleQueryTermCounts() throws IOException { - String testText = "document document simple test case"; - String testQuery = "document"; - float [] expected = {2}; - assertFeatureValues(expected, testQuery, testText, new UnigramFeatureExtractor()); - } - - @Test - public void testNonMatchQuery() throws IOException { - String testText = "document document simple"; - String testQuery = "case"; - float[] expected = {0}; - - assertFeatureValues(expected, testQuery, testText, new UnigramFeatureExtractor()); - } - - @Test - public void testPartialMatches() throws IOException { - String testText = "simple test case document"; - String testQuery = "simple document unigram"; - float[] expected = {2}; - - assertFeatureValues(expected, testQuery, testText, new UnigramFeatureExtractor()); - } - - @Test - public void testMultipleMatches() throws IOException { - String testText = "simple simple document test case document"; - String testQuery = "document simple case nonexistent query"; - float[] expected = {5}; - - assertFeatureValues(expected, testQuery, testText, new UnigramFeatureExtractor()); - } - - -} diff --git a/src/test/java/io/anserini/ltr/UniqueTermCountFeatureExtractorTest.java b/src/test/java/io/anserini/ltr/UniqueTermCountFeatureExtractorTest.java new file mode 100644 index 0000000000..89d0a83d64 --- /dev/null +++ b/src/test/java/io/anserini/ltr/UniqueTermCountFeatureExtractorTest.java @@ -0,0 +1,28 @@ +package io.anserini.ltr; + +import io.anserini.ltr.feature.FeatureExtractor; +import io.anserini.ltr.feature.base.UniqueTermCount; +import org.junit.Test; + +import java.io.IOException; +import java.util.Arrays; +import java.util.concurrent.ExecutionException; + +public class UniqueTermCountFeatureExtractorTest extends BaseFeatureExtractorTest { + private static FeatureExtractor EXTRACTOR = new UniqueTermCount(); + + @Test + public void testSingleDoc() throws IOException, ExecutionException, InterruptedException { + float[] expected = {3}; + 
assertFeatureValues(expected, "simple test query", "document size independent of query document", + EXTRACTOR); + } + + @Test + public void testMultipleDocs() throws IOException, ExecutionException, InterruptedException { + float[] expected = {2}; + assertFeatureValues(expected, "just test just test", Arrays.asList("first document", + "second document", "test document document document test"), EXTRACTOR, 2); + } +} +
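A closing note on the hand-computed expectations in the AvgIDF and BM25 test comments above: they all derive from the same smoothed idf, ln(1 + (N - df + 0.5) / (df + 0.5)), which the comments spell out numerically (e.g. "log(1 + (3 - 1 + 0.5 / 1 + 0.5)) = 0.98082"). The sketch below only replays that arithmetic so the magic numbers are checkable; the formula and constants are taken from the test comments, not from the extractor sources, so treat it as an assumption about the implementation rather than a statement of it.

public class ScoreArithmeticCheck {
  // Smoothed idf as written in the test comments above.
  static double idf(double numDocs, double df) {
    return Math.log(1 + (numDocs - df + 0.5) / (df + 0.5));
  }

  // Per-term BM25 score as spelled out in BM25FeatureExtractorTest's comments.
  static double bm25(double tf, double df, double numDocs,
                     double docLen, double avgFL, double k1, double b) {
    return idf(numDocs, df) * tf * (k1 + 1)
        / (tf + k1 * (1 - b + b * docLen / avgFL));
  }

  public static void main(String[] args) {
    // AvgIDFFeatureExtractorTest.testSingleDocMissingToken: N = 1,
    // df("document") = 1, df("test") = 0, average idf = 0.836985.
    System.out.println((idf(1, 1) + idf(1, 0)) / 2);     // ~0.836985
    // BM25FeatureExtractorTest.testSingleDocSingleQuery: tf = 1, df = 1,
    // N = 1, docLen = avgFL = 4. With tf = 1 and docLen == avgFL the
    // numerator and denominator both reduce to (k1 + 1), so the score
    // collapses to the bare idf for any (k1, b) - which is why the
    // (0.9, 0.4) and default (1.25, 0.75) extractors agree on 0.287682.
    System.out.println(bm25(1, 1, 1, 4, 4, 0.9, 0.4));   // ~0.287682
    System.out.println(bm25(1, 1, 1, 4, 4, 1.25, 0.75)); // ~0.287682
  }
}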