-
-
Notifications
You must be signed in to change notification settings - Fork 709
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
511 additions
and
104 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,129 +1,318 @@ | ||
// # Aggregation example | ||
// | ||
// This example shows how you can use built-in aggregations. | ||
// We will use range buckets and compute the average in each bucket. | ||
// | ||
// We will use nested aggregations with buckets and metrics: | ||
// - Range buckets and compute the average in each bucket. | ||
// - Term aggregation and compute the min price in each bucket | ||
// --- | ||
|
||
use serde_json::Value; | ||
use serde_json::{Deserializer, Value}; | ||
use tantivy::aggregation::agg_req::{ | ||
Aggregation, Aggregations, BucketAggregation, BucketAggregationType, MetricAggregation, | ||
RangeAggregation, | ||
}; | ||
use tantivy::aggregation::agg_result::AggregationResults; | ||
use tantivy::aggregation::bucket::RangeAggregationRange; | ||
use tantivy::aggregation::metric::AverageAggregation; | ||
use tantivy::aggregation::AggregationCollector; | ||
use tantivy::query::TermQuery; | ||
use tantivy::schema::{self, IndexRecordOption, Schema, TextFieldIndexing}; | ||
use tantivy::{doc, Index, Term}; | ||
use tantivy::query::AllQuery; | ||
use tantivy::schema::{self, IndexRecordOption, Schema, TextFieldIndexing, FAST}; | ||
use tantivy::Index; | ||
|
||
fn main() -> tantivy::Result<()> { | ||
// # Create Schema | ||
// | ||
// Lets create a schema for a footwear shop, with 4 fields: name, category, stock and price. | ||
// category, stock and price will be fast fields as that's the requirement | ||
// for aggregation queries. | ||
// | ||
|
||
let mut schema_builder = Schema::builder(); | ||
// In preparation of the `TermsAggregation`, the category field is configured with: | ||
// - `set_fast` | ||
// - `raw` tokenizer | ||
// | ||
// The tokenizer is set to "raw", because the fast field uses the same dictionary as the | ||
// inverted index. (This behaviour will change in tantivy 0.20, where the fast field will | ||
// always be raw tokenized independent from the regular tokenizing) | ||
// | ||
let text_fieldtype = schema::TextOptions::default() | ||
.set_indexing_options( | ||
TextFieldIndexing::default().set_index_option(IndexRecordOption::WithFreqs), | ||
TextFieldIndexing::default() | ||
.set_index_option(IndexRecordOption::WithFreqs) | ||
.set_tokenizer("raw"), | ||
) | ||
.set_fast() | ||
.set_stored(); | ||
let text_field = schema_builder.add_text_field("text", text_fieldtype); | ||
let score_fieldtype = crate::schema::NumericOptions::default().set_fast(); | ||
let highscore_field = schema_builder.add_f64_field("highscore", score_fieldtype.clone()); | ||
let price_field = schema_builder.add_f64_field("price", score_fieldtype); | ||
schema_builder.add_text_field("category", text_fieldtype); | ||
schema_builder.add_f64_field("stock", FAST); | ||
schema_builder.add_f64_field("price", FAST); | ||
|
||
let schema = schema_builder.build(); | ||
|
||
// # Indexing documents | ||
// | ||
// Lets index a bunch of documents for this example. | ||
let index = Index::create_in_ram(schema); | ||
let index = Index::create_in_ram(schema.clone()); | ||
|
||
let data = r#"{ | ||
"name": "Almond Toe Court Shoes, Patent Black", | ||
"category": "Womens Footwear", | ||
"price": 99.00, | ||
"stock": 5 | ||
} | ||
{ | ||
"name": "Suede Shoes, Blue", | ||
"category": "Womens Footwear", | ||
"price": 42.00, | ||
"stock": 4 | ||
} | ||
{ | ||
"name": "Leather Driver Saddle Loafers, Tan", | ||
"category": "Mens Footwear", | ||
"price": 34.00, | ||
"stock": 12 | ||
} | ||
{ | ||
"name": "Flip Flops, Red", | ||
"category": "Mens Footwear", | ||
"price": 19.00, | ||
"stock": 6 | ||
} | ||
{ | ||
"name": "Flip Flops, Blue", | ||
"category": "Mens Footwear", | ||
"price": 19.00, | ||
"stock": 0 | ||
} | ||
{ | ||
"name": "Gold Button Cardigan, Black", | ||
"category": "Womens Casualwear", | ||
"price": 167.00, | ||
"stock": 6 | ||
} | ||
{ | ||
"name": "Cotton Shorts, Medium Red", | ||
"category": "Womens Casualwear", | ||
"price": 30.00, | ||
"stock": 5 | ||
} | ||
{ | ||
"name": "Fine Stripe Short SleeveShirt, Grey", | ||
"category": "Mens Casualwear", | ||
"price": 49.99, | ||
"stock": 9 | ||
} | ||
{ | ||
"name": "Fine Stripe Short SleeveShirt, Green", | ||
"category": "Mens Casualwear", | ||
"price": 49.99, | ||
"offer": 39.99, | ||
"stock": 9 | ||
} | ||
{ | ||
"name": "Sharkskin Waistcoat, Charcoal", | ||
"category": "Mens Formalwear", | ||
"price": 75.00, | ||
"stock": 2 | ||
} | ||
{ | ||
"name": "Lightweight Patch PocketBlazer, Deer", | ||
"category": "Mens Formalwear", | ||
"price": 175.50, | ||
"stock": 1 | ||
} | ||
{ | ||
"name": "Bird Print Dress, Black", | ||
"category": "Womens Formalwear", | ||
"price": 270.00, | ||
"stock": 10 | ||
} | ||
{ | ||
"name": "Mid Twist Cut-Out Dress, Pink", | ||
"category": "Womens Formalwear", | ||
"price": 540.00, | ||
"stock": 5 | ||
}"#; | ||
|
||
let stream = Deserializer::from_str(data).into_iter::<Value>(); | ||
|
||
let mut index_writer = index.writer(50_000_000)?; | ||
// writing the segment | ||
index_writer.add_document(doc!( | ||
text_field => "cool", | ||
highscore_field => 1f64, | ||
price_field => 0f64, | ||
))?; | ||
index_writer.add_document(doc!( | ||
text_field => "cool", | ||
highscore_field => 3f64, | ||
price_field => 1f64, | ||
))?; | ||
index_writer.add_document(doc!( | ||
text_field => "cool", | ||
highscore_field => 5f64, | ||
price_field => 1f64, | ||
))?; | ||
index_writer.add_document(doc!( | ||
text_field => "nohit", | ||
highscore_field => 6f64, | ||
price_field => 2f64, | ||
))?; | ||
index_writer.add_document(doc!( | ||
text_field => "cool", | ||
highscore_field => 7f64, | ||
price_field => 2f64, | ||
))?; | ||
index_writer.commit()?; | ||
index_writer.add_document(doc!( | ||
text_field => "cool", | ||
highscore_field => 11f64, | ||
price_field => 10f64, | ||
))?; | ||
index_writer.add_document(doc!( | ||
text_field => "cool", | ||
highscore_field => 14f64, | ||
price_field => 15f64, | ||
))?; | ||
|
||
index_writer.add_document(doc!( | ||
text_field => "cool", | ||
highscore_field => 15f64, | ||
price_field => 20f64, | ||
))?; | ||
let mut num_indexed = 0; | ||
for value in stream { | ||
let doc = schema.parse_document(&serde_json::to_string(&value.unwrap())?)?; | ||
index_writer.add_document(doc)?; | ||
num_indexed += 1; | ||
if num_indexed > 4 { | ||
// Writing the first segment | ||
index_writer.commit()?; | ||
} | ||
} | ||
|
||
// Writing the second segment | ||
index_writer.commit()?; | ||
|
||
// We have two segments now. The `AggregationCollector` will run the aggregation on each | ||
// segment and then merge the results into an `IntermediateAggregationResult`. | ||
|
||
let reader = index.reader()?; | ||
let text_field = reader.searcher().schema().get_field("text").unwrap(); | ||
|
||
let term_query = TermQuery::new( | ||
Term::from_field_text(text_field, "cool"), | ||
IndexRecordOption::Basic, | ||
); | ||
|
||
let sub_agg_req_1: Aggregations = vec![( | ||
"average_price".to_string(), | ||
Aggregation::Metric(MetricAggregation::Average( | ||
AverageAggregation::from_field_name("price".to_string()), | ||
)), | ||
)] | ||
.into_iter() | ||
.collect(); | ||
let searcher = reader.searcher(); | ||
// --- | ||
// # Aggregation Query | ||
// | ||
// | ||
// We can construct the query by building the request structure or by deserializing from JSON. | ||
// The JSON API is more stable and therefore recommended. | ||
// | ||
// ## Request 1 | ||
|
||
let agg_req_str = r#" | ||
{ | ||
"group_by_stock": { | ||
"aggs": { | ||
"average_price": { "avg": { "field": "price" } } | ||
}, | ||
"range": { | ||
"field": "stock", | ||
"ranges": [ | ||
{ "key": "few", "to": 1.0 }, | ||
{ "key": "some", "from": 1.0, "to": 10.0 }, | ||
{ "key": "many", "from": 10.0 } | ||
] | ||
} | ||
} | ||
} "#; | ||
|
||
// In this Aggregation we want to get the average price for different groups, depending on how | ||
// many items are in stock. We define custom ranges `few`, `some`, `many` via the | ||
// range aggregation. | ||
// For every bucket we want the average price, so we create a nested metric aggregation on the range bucket | ||
// aggregation. Only buckets support nested aggregations. | ||
// ### Request JSON API | ||
// | ||
|
||
let agg_req: Aggregations = serde_json::from_str(agg_req_str)?; | ||
let collector = AggregationCollector::from_aggs(agg_req, None, index.schema()); | ||
|
||
let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap(); | ||
let res2: Value = serde_json::to_value(agg_res)?; | ||
|
||
// ### Request Rust API | ||
// | ||
// This is exactly the same request as above, but via the rust structures. | ||
// | ||
|
||
let agg_req_1: Aggregations = vec![( | ||
"score_ranges".to_string(), | ||
let agg_req: Aggregations = vec![( | ||
"group_by_stock".to_string(), | ||
Aggregation::Bucket(BucketAggregation { | ||
bucket_agg: BucketAggregationType::Range(RangeAggregation { | ||
field: "highscore".to_string(), | ||
field: "stock".to_string(), | ||
ranges: vec![ | ||
(-1f64..9f64).into(), | ||
(9f64..14f64).into(), | ||
(14f64..20f64).into(), | ||
RangeAggregationRange { | ||
key: Some("few".into()), | ||
from: None, | ||
to: Some(1f64), | ||
}, | ||
RangeAggregationRange { | ||
key: Some("some".into()), | ||
from: Some(1f64), | ||
to: Some(10f64), | ||
}, | ||
RangeAggregationRange { | ||
key: Some("many".into()), | ||
from: Some(10f64), | ||
to: None, | ||
}, | ||
], | ||
..Default::default() | ||
}), | ||
sub_aggregation: sub_agg_req_1, | ||
sub_aggregation: vec![( | ||
"average_price".to_string(), | ||
Aggregation::Metric(MetricAggregation::Average( | ||
AverageAggregation::from_field_name("price".to_string()), | ||
)), | ||
)] | ||
.into_iter() | ||
.collect(), | ||
}), | ||
)] | ||
.into_iter() | ||
.collect(); | ||
|
||
let collector = AggregationCollector::from_aggs(agg_req_1, None, index.schema()); | ||
let collector = AggregationCollector::from_aggs(agg_req, None, index.schema()); | ||
// We use the `AllQuery` which will pass all documents to the AggregationCollector. | ||
let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap(); | ||
|
||
let searcher = reader.searcher(); | ||
let agg_res: AggregationResults = searcher.search(&term_query, &collector).unwrap(); | ||
let res1: Value = serde_json::to_value(agg_res)?; | ||
|
||
// ### Aggregation Result | ||
// | ||
// The resulting structure deserializes in the same JSON format as elastic search. | ||
// | ||
let expected_res = r#" | ||
{ | ||
"group_by_stock":{ | ||
"buckets":[ | ||
{"average_price":{"value":19.0},"doc_count":1,"key":"few","to":1.0}, | ||
{"average_price":{"value":124.748},"doc_count":10,"from":1.0,"key":"some","to":10.0}, | ||
{"average_price":{"value":152.0},"doc_count":2,"from":10.0,"key":"many"} | ||
] | ||
} | ||
} | ||
"#; | ||
let expected_json: Value = serde_json::from_str(expected_res)?; | ||
assert_eq!(expected_json, res1); | ||
assert_eq!(expected_json, res2); | ||
|
||
// ### Request 2 | ||
// | ||
// Now we are interested in the minimum price per category, so we create a bucket per | ||
// category via `TermsAggregation`. We are interested in the highest minimum prices, and set the order of | ||
// the buckets `"order": { "min_price": "desc" }` to be sorted by the the metric of the sub aggregation. (awesome) | ||
// | ||
let agg_req_str = r#" | ||
{ | ||
"min_price_per_category": { | ||
"aggs": { | ||
"min_price": { "min": { "field": "price" } } | ||
}, | ||
"terms": { | ||
"field": "category", | ||
"min_doc_count": 1, | ||
"order": { "min_price": "desc" } | ||
} | ||
} | ||
} "#; | ||
|
||
let agg_req: Aggregations = serde_json::from_str(agg_req_str)?; | ||
|
||
let collector = AggregationCollector::from_aggs(agg_req, None, index.schema()); | ||
|
||
let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap(); | ||
let res: Value = serde_json::to_value(agg_res)?; | ||
println!("{}", serde_json::to_string_pretty(&res)?); | ||
|
||
// Minimum price per category, sorted by minimum price descending | ||
// | ||
// As you can see, the starting prices for `Formalwear` are higher than `Casualwear`. | ||
// | ||
let expected_res = r#" | ||
{ | ||
"min_price_per_category": { | ||
"buckets": [ | ||
{ "doc_count": 2, "key": "Womens Formalwear", "min_price": { "value": 270.0 } }, | ||
{ "doc_count": 2, "key": "Mens Formalwear", "min_price": { "value": 75.0 } }, | ||
{ "doc_count": 2, "key": "Mens Casualwear", "min_price": { "value": 49.99 } }, | ||
{ "doc_count": 2, "key": "Womens Footwear", "min_price": { "value": 42.0 } }, | ||
{ "doc_count": 2, "key": "Womens Casualwear", "min_price": { "value": 30.0 } }, | ||
{ "doc_count": 3, "key": "Mens Footwear", "min_price": { "value": 19.0 } } | ||
], | ||
"sum_other_doc_count": 0 | ||
} | ||
} | ||
"#; | ||
let expected_json: Value = serde_json::from_str(expected_res)?; | ||
|
||
assert_eq!(expected_json, res); | ||
|
||
Ok(()) | ||
} |
Oops, something went wrong.