diff --git a/Cargo.lock b/Cargo.lock index 146608d5339..dccd967b8ba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1495,6 +1495,7 @@ dependencies = [ "slog-async", "slog-envlogger", "slog-term", + "sqlparser", "stable-hash 0.3.4", "stable-hash 0.4.4", "strum_macros", @@ -4106,6 +4107,15 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b9b39299b249ad65f3b7e96443bad61c02ca5cd3589f46cb6d610a0fd6c0d6a" +[[package]] +name = "sqlparser" +version = "0.43.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f95c4bae5aba7cd30bd506f7140026ade63cff5afd778af8854026f9606bf5d4" +dependencies = [ + "log", +] + [[package]] name = "stable-hash" version = "0.3.4" diff --git a/docs/aggregations.md b/docs/aggregations.md index f7819e67dda..166b5fb8733 100644 --- a/docs/aggregations.md +++ b/docs/aggregations.md @@ -87,6 +87,11 @@ _dimensions_, and fields with the `@aggregate` directive are called _aggregates_. A timeseries type really represents many timeseries, one for each combination of values for the dimensions. +The same timeseries can be used for multiple aggregations. For example, the +`Stats` aggregation could also be formed by aggregating over the `TokenData` +timeseries. Since `Stats` doesn't have a `token` dimension, all aggregates +will be formed across all tokens. + Each `@aggregate` by default starts at 0 for each new bucket and therefore just aggregates over the time interval for the bucket. The `@aggregate` directive also accepts a boolean flag `cumulative` that indicates whether @@ -97,10 +102,6 @@ the entire timeseries up to the end of the time interval for the bucket. aggregations, and it doesn't seem like it used in practice, we won't initially support it. (same for variance, stddev etc.) -**TODO** The timeseries type can be simplified for some situations if -aggregations can be done over expressions, for example over `priceUSD * -amount` to track `totalVolumeUSD` **TODO** It might be necessary to allow `@aggregate` fields that are only used for some intervals. We could allow that with syntax like `@aggregate(fn: .., arg: .., interval: "day")` @@ -132,7 +133,10 @@ annotation. These attributes must be of a numeric type (`Int`, `Int8`, `BigInt`, or `BigDecimal`) The annotation must have two arguments: - `fn`: the name of an aggregation function -- `arg`: the name of an attribute in the timeseries type +- `arg`: the name of an attribute in the timeseries type, or an expression + using only constants and attributes of the timeseries type + +#### Aggregation functions The following aggregation functions are currently supported: @@ -149,16 +153,42 @@ The `first` and `last` aggregation function calculate the first and last value in an interval by sorting the data by `id`; `graph-node` enforces correctness here by automatically setting the `id` for timeseries entities. +#### Aggregation expressions + +The `arg` can be the name of any attribute in the timeseries type, or an +expression using only constants and attributes of the timeseries type such +as `price * amount` or `greatest(amount0, amount1)`. Expressions use SQL +syntax and support a subset of builtin SQL functions, operators, and other +constructs. + +Supported operators are `+`, `-`, `*`, `/`, `%`, `^`, `=`, `!=`, `<`, `<=`, +`>`, `>=`, `<->`, `and`, `or`, and `not`. In addition, the operators `is +[not] {null|true|false}` and `is [not] distinct from` are supported.
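+ +For instance, assuming a `TokenData` timeseries with numeric fields +`priceUSD` and `refPrice` (illustrative names, not defined in this +document), these operators can be combined into arbitrary arithmetic over +the timeseries attributes: + +```graphql +type Stats @aggregation(intervals: ["hour"], source: "TokenData") { + id: Int8! + timestamp: Int8! + # sum of squared deviations from the reference price, using `-` and `^` + drift: BigDecimal! @aggregate(fn: "sum", arg: "(priceUSD - refPrice) ^ 2") +} +```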
+ +The supported SQL functions are the [math +functions](https://www.postgresql.org/docs/current/functions-math.html) +`abs`, `ceil`, `ceiling`, `div`, `floor`, `gcd`, `lcm`, `mod`, `power`, +`sign`, and the [conditional +functions](https://www.postgresql.org/docs/current/functions-conditional.html) +`coalesce`, `nullif`, `greatest`, and `least`. + +The +[conditional expression](https://www.postgresql.org/docs/current/functions-conditional.html#FUNCTIONS-CASE) +`case when .. else .. end` is also supported. + +Some examples of valid expressions, assuming the underlying timeseries +contains the mentioned fields: + +- Aggregate the value of a token: `@aggregate(fn: "sum", arg: "priceUSD * amount")` +- Aggregate the maximum positive amount of two different amounts: + `@aggregate(fn: "max", arg: "greatest(amount0, amount1, 0)")` +- Conditionally sum an amount: `@aggregate(fn: "sum", arg: "case when amount0 > amount1 then amount0 else 0 end")` + ## Querying _This section is not implemented yet, and will require a bit more thought about details_ -**TODO** As written, timeseries points like `TokenData` can be queried like -any other entity. It would be nice to restrict how these data points can be -queried, maybe even forbid it, as that would give us more latitude in how we -store that data. - We create a toplevel query field for each aggregation. That query field accepts the following arguments: diff --git a/graph/Cargo.toml b/graph/Cargo.toml index 21f522df3b1..b7d28682566 100644 --- a/graph/Cargo.toml +++ b/graph/Cargo.toml @@ -88,6 +88,7 @@ web3 = { git = "https://github.com/graphprotocol/rust-web3", branch = "graph-pat "arbitrary_precision", ] } serde_plain = "1.0.2" +sqlparser = "0.43.1" [dev-dependencies] clap = { version = "3.2.25", features = ["derive", "env"] } diff --git a/graph/src/lib.rs b/graph/src/lib.rs index 024336c047b..8ab0c90dbd7 100644 --- a/graph/src/lib.rs +++ b/graph/src/lib.rs @@ -51,6 +51,7 @@ pub use petgraph; pub use prometheus; pub use semver; pub use slog; +pub use sqlparser; pub use stable_hash; pub use stable_hash_legacy; pub use tokio; diff --git a/graph/src/schema/entity_type.rs b/graph/src/schema/entity_type.rs index e8706a9ee69..7827dd928dd 100644 --- a/graph/src/schema/entity_type.rs +++ b/graph/src/schema/entity_type.rs @@ -11,10 +11,7 @@ use crate::{ util::intern::Atom, }; -use super::{ - input_schema::{ObjectType, POI_OBJECT}, - EntityKey, Field, InputSchema, InterfaceType, -}; +use super::{EntityKey, Field, InputSchema, InterfaceType, ObjectType, POI_OBJECT}; /// A reference to a type in the input schema.
It should mostly be the /// reference to a concrete entity type, either one declared with `@entity` diff --git a/graph/src/schema/input_schema.rs b/graph/src/schema/input/mod.rs similarity index 97% rename from graph/src/schema/input_schema.rs rename to graph/src/schema/input/mod.rs index f5e269fd852..be7eb7f37b6 100644 --- a/graph/src/schema/input_schema.rs +++ b/graph/src/schema/input/mod.rs @@ -23,8 +23,10 @@ use crate::prelude::{s, DeploymentHash}; use crate::schema::api::api_schema; use crate::util::intern::{Atom, AtomPool}; -use super::fulltext::FulltextDefinition; -use super::{ApiSchema, AsEntityTypeName, EntityType, Schema}; +use crate::schema::fulltext::FulltextDefinition; +use crate::schema::{ApiSchema, AsEntityTypeName, EntityType, Schema}; + +pub mod sqlexpr; /// The name of the PoI entity type pub(crate) const POI_OBJECT: &str = "Poi$"; @@ -753,42 +755,26 @@ impl AggregationMapping { } } -#[derive(Clone, PartialEq, Debug)] -pub struct Arg { - pub name: Word, - pub value_type: ValueType, -} - -impl Arg { - fn new(name: Word, src_type: &s::ObjectType) -> Self { - let value_type = src_type - .field(&name) - .unwrap() - .field_type - .value_type() - .unwrap(); - Self { name, value_type } - } -} - +/// The `@aggregate` annotation in an aggregation. The annotation controls +/// how values from the source table are aggregated #[derive(PartialEq, Debug)] pub struct Aggregate { + /// The name of the aggregate field in the aggregation pub name: Word, + /// The function used to aggregate the values pub func: AggregateFn, - pub arg: Arg, + /// The field to aggregate in the source table + pub arg: Word, + /// The type of the field `name` in the aggregation pub field_type: s::Type, + /// The `ValueType` corresponding to `field_type` pub value_type: ValueType, + /// Whether the aggregation is cumulative pub cumulative: bool, } impl Aggregate { - fn new( - _schema: &Schema, - src_type: &s::ObjectType, - name: &str, - field_type: &s::Type, - dir: &s::Directive, - ) -> Self { + fn new(_schema: &Schema, name: &str, field_type: &s::Type, dir: &s::Directive) -> Self { let func = dir .argument("fn") .unwrap() @@ -803,8 +789,7 @@ impl Aggregate { let arg = dir .argument("arg") .map(|arg| Word::from(arg.as_str().unwrap())) - .map(|arg| Arg::new(arg, src_type)) - .unwrap_or_else(|| Arg::new(ID.clone(), src_type)); + .unwrap_or_else(|| ID.clone()); let cumulative = dir .argument(kw::CUMULATIVE) .map(|arg| match arg { @@ -861,7 +846,6 @@ impl Aggregation { .unwrap() .as_str() .unwrap(); - let src_type = schema.document.get_object_type_definition(source).unwrap(); let source = pool.lookup(source).unwrap(); let fields: Box<[_]> = agg_type .fields @@ -873,9 +857,7 @@ impl Aggregation { .fields .iter() .filter_map(|field| field.find_directive(kw::AGGREGATE).map(|dir| (field, dir))) - .map(|(field, dir)| { - Aggregate::new(schema, src_type, &field.name, &field.field_type, dir) - }) + .map(|(field, dir)| Aggregate::new(schema, &field.name, &field.field_type, dir)) .collect(); let obj_types = intervals @@ -1675,7 +1657,7 @@ mod validations { }, prelude::s, schema::{ - input_schema::{kw, AggregateFn, AggregationInterval}, + input::{kw, sqlexpr, AggregateFn, AggregationInterval}, FulltextAlgorithm, FulltextLanguage, Schema as BaseSchema, SchemaValidationError, SchemaValidationError as Err, Strings, SCHEMA_TYPE_NAME, }, @@ -2493,46 +2475,56 @@ mod validations { continue; } }; - let arg_type = match source.field(arg) { - Some(arg_field) => match arg_field.field_type.value_type() { - Ok(arg_type) if 
arg_type.is_numeric() => arg_type, - Ok(_) | Err(_) => { - errors.push(Err::AggregationNonNumericArg( - agg_type.name.to_owned(), - field.name.to_owned(), - source.name.to_owned(), - arg.to_owned(), - )); - continue; - } - }, - None => { - errors.push(Err::AggregationUnknownArg( + let field_type = match field.field_type.value_type() { + Ok(field_type) => field_type, + Err(_) => { + errors.push(Err::NonNumericAggregate( agg_type.name.to_owned(), field.name.to_owned(), - arg.to_owned(), )); continue; } }; - let field_type = match field.field_type.value_type() { - Ok(field_type) => field_type, - Err(_) => { - errors.push(Err::NonNumericAggregate( + // It would be nicer to use a proper struct here + // and have that implement + // `sqlexpr::ExprVisitor` but we need access to + // a bunch of local variables that would make + // setting up that struct a bit awkward, so we + // use a closure instead + let check_ident = |ident: &str| -> Result<(), SchemaValidationError> { + let arg_type = match source.field(ident) { + Some(arg_field) => match arg_field.field_type.value_type() { + Ok(arg_type) if arg_type.is_numeric() => arg_type, + Ok(_) | Err(_) => { + return Err(Err::AggregationNonNumericArg( + agg_type.name.to_owned(), + field.name.to_owned(), + source.name.to_owned(), + arg.to_owned(), + )); + } + }, + None => { + return Err(Err::AggregationUnknownArg( + agg_type.name.to_owned(), + field.name.to_owned(), + arg.to_owned(), + )); + } + }; + if arg_type > field_type { + return Err(Err::AggregationNonMatchingArg( agg_type.name.to_owned(), field.name.to_owned(), + arg.to_owned(), + arg_type.to_str().to_owned(), + field_type.to_str().to_owned(), )); - continue; } + Ok(()) }; - if arg_type > field_type { - errors.push(Err::AggregationNonMatchingArg( - agg_type.name.to_owned(), - field.name.to_owned(), - arg.to_owned(), - arg_type.to_str().to_owned(), - field_type.to_str().to_owned(), - )); + if let Err(mut errs) = sqlexpr::parse(arg, check_ident) { + errors.append(&mut errs); } } None => { @@ -3051,7 +3043,7 @@ type Gravatar @entity { if errs.iter().any(|err| { err.to_string().contains(&msg) || format!("{err:?}").contains(&msg) }) { - println!("{file_name} failed as expected: {errs:?}",) + // println!("{file_name} failed as expected: {errs:?}",) } else { let msgs: Vec<_> = errs.iter().map(|err| err.to_string()).collect(); panic!( @@ -3060,7 +3052,7 @@ type Gravatar @entity { } } (true, Ok(_)) => { - println!("{file_name} validated as expected") + // println!("{file_name} validated as expected") } } } @@ -3074,7 +3066,7 @@ mod tests { data::store::ID, prelude::DeploymentHash, schema::{ - input_schema::{POI_DIGEST, POI_OBJECT}, + input::{POI_DIGEST, POI_OBJECT}, EntityType, }, }; diff --git a/graph/src/schema/input/sqlexpr.rs b/graph/src/schema/input/sqlexpr.rs new file mode 100644 index 00000000000..5e0d8c95f6c --- /dev/null +++ b/graph/src/schema/input/sqlexpr.rs @@ -0,0 +1,333 @@ +//! Tools for parsing SQL expressions +use sqlparser::ast as p; +use sqlparser::dialect::PostgreSqlDialect; +use sqlparser::parser::{Parser as SqlParser, ParserError}; +use sqlparser::tokenizer::Tokenizer; + +use crate::schema::SchemaValidationError; + +pub(crate) trait CheckIdentFn: Fn(&str) -> Result<(), SchemaValidationError> {} + +impl<T> CheckIdentFn for T where T: Fn(&str) -> Result<(), SchemaValidationError> {} + +/// Parse a SQL expression and check that it only uses whitelisted +/// operations and functions.
The `check_ident` function is called for each +/// identifier in the expression +pub(crate) fn parse<F: CheckIdentFn>( sql: &str, check_ident: F, ) -> Result<(), Vec<SchemaValidationError>> { + let mut validator = Validator { + check_ident, + errors: Vec::new(), + }; + VisitExpr::visit(sql, &mut validator) + .map(|_| ()) + .map_err(|()| validator.errors) +} + +/// A visitor for `VisitExpr` that gets called for the constructs for which +/// we need different behavior between validation and query generation in +/// `store/postgres/src/relational/rollup.rs`. Note that the visitor can +/// mutate both itself (e.g., to store errors) and the expression it is +/// visiting. +pub trait ExprVisitor { + /// Visit an identifier (column name). Must return `Err` if the + /// identifier is not allowed + fn visit_ident(&mut self, ident: &mut p::Ident) -> Result<(), ()>; + /// Visit a function name. Must return `Err` if the function is not + /// allowed + fn visit_func_name(&mut self, func: &mut p::Ident) -> Result<(), ()>; + /// Called when we encounter a construct that is not supported like a + /// subquery + fn not_supported(&mut self, msg: String); + /// Called if the SQL expression we are visiting has SQL syntax errors + fn parse_error(&mut self, e: sqlparser::parser::ParserError); +} + +pub struct VisitExpr<'a> { + visitor: Box<&'a mut dyn ExprVisitor>, +} + +impl<'a> VisitExpr<'a> { + fn nope(&mut self, construct: &str) -> Result<(), ()> { + self.not_supported(format!("Expressions using {construct} are not supported")) + } + + fn illegal_function(&mut self, msg: String) -> Result<(), ()> { + self.not_supported(format!("Illegal function: {msg}")) + } + + fn not_supported(&mut self, msg: String) -> Result<(), ()> { + self.visitor.not_supported(msg); + Err(()) + } + + /// Parse `sql` into an expression and traverse it, calling back into + /// `visitor` at the appropriate places. Return the parsed expression, + /// which might have been changed by the visitor, on success. On error, + /// return `Err(())`. The visitor will know the details of the error + /// since this can only happen if `visit_ident` or `visit_func_name` + /// returned an error, or `parse_error` or `not_supported` was called.
+ pub fn visit(sql: &str, visitor: &'a mut dyn ExprVisitor) -> Result<p::Expr, ()> { + let dialect = PostgreSqlDialect {}; + + let mut parser = SqlParser::new(&dialect); + let tokens = Tokenizer::new(&dialect, sql) + .with_unescape(true) + .tokenize_with_location() + .unwrap(); + parser = parser.with_tokens_with_locations(tokens); + let mut visit = VisitExpr { + visitor: Box::new(visitor), + }; + let mut expr = match parser.parse_expr() { + Ok(expr) => expr, + Err(e) => { + visitor.parse_error(e); + return Err(()); + } + }; + visit.visit_expr(&mut expr).map(|()| expr) + } + + fn visit_expr(&mut self, expr: &mut p::Expr) -> Result<(), ()> { + use p::Expr::*; + + match expr { + Identifier(ident) => self.visitor.visit_ident(ident), + BinaryOp { left, op, right } => { + self.check_binary_op(op)?; + self.visit_expr(left)?; + self.visit_expr(right)?; + Ok(()) + } + UnaryOp { op, expr } => { + self.check_unary_op(op)?; + self.visit_expr(expr)?; + Ok(()) + } + Function(func) => self.visit_func(func), + Value(_) => Ok(()), + Case { + operand, + conditions, + results, + else_result, + } => { + if let Some(operand) = operand { + self.visit_expr(operand)?; + } + for condition in conditions { + self.visit_expr(condition)?; + } + for result in results { + self.visit_expr(result)?; + } + if let Some(else_result) = else_result { + self.visit_expr(else_result)?; + } + Ok(()) + } + Cast { + expr, + data_type: _, + format: _, + } => self.visit_expr(expr), + Nested(expr) | IsFalse(expr) | IsNotFalse(expr) | IsTrue(expr) | IsNotTrue(expr) + | IsNull(expr) | IsNotNull(expr) => self.visit_expr(expr), + IsDistinctFrom(expr1, expr2) | IsNotDistinctFrom(expr1, expr2) => { + self.visit_expr(expr1)?; + self.visit_expr(expr2)?; + Ok(()) + } + CompoundIdentifier(_) => self.nope("CompoundIdentifier"), + JsonAccess { .. } => self.nope("JsonAccess"), + CompositeAccess { .. } => self.nope("CompositeAccess"), + IsUnknown(_) => self.nope("IsUnknown"), + IsNotUnknown(_) => self.nope("IsNotUnknown"), + InList { .. } => self.nope("InList"), + InSubquery { .. } => self.nope("InSubquery"), + InUnnest { .. } => self.nope("InUnnest"), + Between { .. } => self.nope("Between"), + Like { .. } => self.nope("Like"), + ILike { .. } => self.nope("ILike"), + SimilarTo { .. } => self.nope("SimilarTo"), + RLike { .. } => self.nope("RLike"), + AnyOp { .. } => self.nope("AnyOp"), + AllOp { .. } => self.nope("AllOp"), + Convert { .. } => self.nope("Convert"), + TryCast { .. } => self.nope("TryCast"), + SafeCast { .. } => self.nope("SafeCast"), + AtTimeZone { .. } => self.nope("AtTimeZone"), + Extract { .. } => self.nope("Extract"), + Ceil { .. } => self.nope("Ceil"), + Floor { .. } => self.nope("Floor"), + Position { .. } => self.nope("Position"), + Substring { .. } => self.nope("Substring"), + Trim { .. } => self.nope("Trim"), + Overlay { .. } => self.nope("Overlay"), + Collate { .. } => self.nope("Collate"), + IntroducedString { .. } => self.nope("IntroducedString"), + TypedString { .. } => self.nope("TypedString"), + MapAccess { .. } => self.nope("MapAccess"), + AggregateExpressionWithFilter { .. } => self.nope("AggregateExpressionWithFilter"), + Exists { .. } => self.nope("Exists"), + Subquery(_) => self.nope("Subquery"), + ArraySubquery(_) => self.nope("ArraySubquery"), + ListAgg(_) => self.nope("ListAgg"), + ArrayAgg(_) => self.nope("ArrayAgg"), + GroupingSets(_) => self.nope("GroupingSets"), + Cube(_) => self.nope("Cube"), + Rollup(_) => self.nope("Rollup"), + Tuple(_) => self.nope("Tuple"), + Struct { .. } => self.nope("Struct"), + Named { ..
} => self.nope("Named"), + ArrayIndex { .. } => self.nope("ArrayIndex"), + Array(_) => self.nope("Array"), + Interval(_) => self.nope("Interval"), + MatchAgainst { .. } => self.nope("MatchAgainst"), + Wildcard => self.nope("Wildcard"), + QualifiedWildcard(_) => self.nope("QualifiedWildcard"), + } + } + + fn visit_func(&mut self, func: &mut p::Function) -> Result<(), ()> { + let p::Function { + name, + args: pargs, + filter, + null_treatment, + over, + distinct: _, + special: _, + order_by, + } = func; + + if filter.is_some() || null_treatment.is_some() || over.is_some() || !order_by.is_empty() { + return self.illegal_function(format!("call to {name} uses an illegal feature")); + } + + let idents = &mut name.0; + if idents.len() != 1 { + return self.illegal_function(format!( + "function name {name} uses a qualified name with '.'" + )); + } + self.visitor.visit_func_name(&mut idents[0])?; + for arg in pargs { + use p::FunctionArg::*; + match arg { + Named { .. } => { + return self.illegal_function(format!("call to {name} uses a named argument")); + } + Unnamed(arg) => match arg { + p::FunctionArgExpr::Expr(expr) => { + self.visit_expr(expr)?; + } + p::FunctionArgExpr::QualifiedWildcard(_) | p::FunctionArgExpr::Wildcard => { + return self + .illegal_function(format!("call to {name} uses a wildcard argument")); + } + }, + }; + } + Ok(()) + } + + fn check_binary_op(&mut self, op: &p::BinaryOperator) -> Result<(), ()> { + use p::BinaryOperator::*; + match op { + Plus | Minus | Multiply | Divide | Modulo | PGExp | Gt | Lt | GtEq | LtEq + | Spaceship | Eq | NotEq | And | Or => Ok(()), + StringConcat + | Xor + | BitwiseOr + | BitwiseAnd + | BitwiseXor + | DuckIntegerDivide + | MyIntegerDivide + | Custom(_) + | PGBitwiseXor + | PGBitwiseShiftLeft + | PGBitwiseShiftRight + | PGOverlap + | PGRegexMatch + | PGRegexIMatch + | PGRegexNotMatch + | PGRegexNotIMatch + | PGLikeMatch + | PGILikeMatch + | PGNotLikeMatch + | PGNotILikeMatch + | PGStartsWith + | PGCustomBinaryOperator(_) => { + self.not_supported(format!("binary operator {op} is not supported")) + } + } + } + + fn check_unary_op(&mut self, op: &p::UnaryOperator) -> Result<(), ()> { + use p::UnaryOperator::*; + match op { + Plus | Minus | Not => Ok(()), + PGBitwiseNot | PGSquareRoot | PGCubeRoot | PGPostfixFactorial | PGPrefixFactorial + | PGAbs => self.not_supported(format!("unary operator {op} is not supported")), + } + } +} + +/// An `ExprVisitor` that validates an expression +struct Validator<F> { + check_ident: F, + errors: Vec<SchemaValidationError>, +} + +const FN_WHITELIST: [&'static str; 14] = [ + // Clearly deterministic functions from + // https://www.postgresql.org/docs/current/functions-math.html, Table + // 9.5. We could also add trig functions (Table 9.7 and 9.8), but under + // no circumstances random functions from Table 9.6 + "abs", "ceil", "ceiling", "div", "floor", "gcd", "lcm", "mod", "power", "sign", + // Conditional functions from + // https://www.postgresql.org/docs/current/functions-conditional.html.
+ "coalesce", "nullif", "greatest", "least", +]; + +impl ExprVisitor for Validator { + fn visit_ident(&mut self, ident: &mut p::Ident) -> Result<(), ()> { + match (self.check_ident)(&ident.value) { + Ok(()) => Ok(()), + Err(e) => { + self.errors.push(e); + Err(()) + } + } + } + + fn visit_func_name(&mut self, func: &mut p::Ident) -> Result<(), ()> { + let p::Ident { value, quote_style } = &func; + let whitelisted = match quote_style { + Some(_) => FN_WHITELIST.contains(&value.as_str()), + None => FN_WHITELIST + .iter() + .any(|name| name.eq_ignore_ascii_case(value)), + }; + if whitelisted { + Ok(()) + } else { + self.not_supported(format!("Function {func} is not supported")); + Err(()) + } + } + + fn not_supported(&mut self, msg: String) { + self.errors + .push(SchemaValidationError::ExprNotSupported(msg)); + } + + fn parse_error(&mut self, e: ParserError) { + self.errors + .push(SchemaValidationError::ExprParseError(e.to_string())); + } +} diff --git a/graph/src/schema/mod.rs b/graph/src/schema/mod.rs index 345bfda199d..9847ca93643 100644 --- a/graph/src/schema/mod.rs +++ b/graph/src/schema/mod.rs @@ -21,7 +21,7 @@ pub mod ast; mod entity_key; mod entity_type; mod fulltext; -mod input_schema; +mod input; pub use api::{is_introspection_field, APISchemaError, INTROSPECTION_QUERY_TYPE}; @@ -29,7 +29,9 @@ pub use api::{ApiSchema, ErrorPolicy}; pub use entity_key::EntityKey; pub use entity_type::{AsEntityTypeName, EntityType}; pub use fulltext::{FulltextAlgorithm, FulltextConfig, FulltextDefinition, FulltextLanguage}; -pub use input_schema::{ +pub use input::sqlexpr::{ExprVisitor, VisitExpr}; +pub(crate) use input::POI_OBJECT; +pub use input::{ kw, Aggregate, AggregateFn, Aggregation, AggregationInterval, AggregationMapping, Field, InputSchema, InterfaceType, ObjectOrInterface, ObjectType, TypeKind, }; @@ -179,6 +181,12 @@ pub enum SchemaValidationError { AggregationsNotSupported(Version), #[error("Using Int8 as the type for the `id` field is not supported with spec version {0}; please migrate the subgraph to the latest version")] IdTypeInt8NotSupported(Version), + #[error("{0}")] + ExprNotSupported(String), + #[error("Expressions can't us the function {0}")] + ExprIllegalFunction(String), + #[error("Failed to parse expression: {0}")] + ExprParseError(String), } /// A validated and preprocessed GraphQL schema for a subgraph. diff --git a/graph/src/schema/test_schemas/ts_expr_random.graphql b/graph/src/schema/test_schemas/ts_expr_random.graphql new file mode 100644 index 00000000000..4472fc7a498 --- /dev/null +++ b/graph/src/schema/test_schemas/ts_expr_random.graphql @@ -0,0 +1,14 @@ +# fail: ExprNotSupported("Function random is not supported") +# Random must not be allowed as it would introduce nondeterministic behavior +type Data @entity(timeseries: true) { + id: Int8! + timestamp: Int8! + price0: BigDecimal! + price1: BigDecimal! +} + +type Stats @aggregation(intervals: ["hour", "day"], source: "Data") { + id: Int8! + timestamp: Int8! + max_price: BigDecimal! @aggregate(fn: "max", arg: "random()") +} diff --git a/graph/src/schema/test_schemas/ts_expr_simple.graphql b/graph/src/schema/test_schemas/ts_expr_simple.graphql new file mode 100644 index 00000000000..ed15c14ceb3 --- /dev/null +++ b/graph/src/schema/test_schemas/ts_expr_simple.graphql @@ -0,0 +1,25 @@ +# valid: Minimal example +type Data @entity(timeseries: true) { + id: Int8! + timestamp: Int8! + price0: BigDecimal! + price1: BigDecimal! +} + +type Stats @aggregation(intervals: ["hour", "day"], source: "Data") { + id: Int8! 
timestamp: Int8! + max_price: BigDecimal! @aggregate(fn: "max", arg: "greatest(price0, price1)") + abs_price: BigDecimal! @aggregate(fn: "sum", arg: "abs(price0) + abs(price1)") + price0_sq: BigDecimal! @aggregate(fn: "sum", arg: "power(price0, 2)") + sum_sq: BigDecimal! @aggregate(fn: "sum", arg: "price0 * price0") + sum_sq_cross: BigDecimal! @aggregate(fn: "sum", arg: "price0 * price1") + + max_some: BigDecimal! + @aggregate( + fn: "max" + arg: "case when price0 > price1 then price0 else 0 end" + ) + + max_cast: BigDecimal! @aggregate(fn: "sum", arg: "(price0/7)::int4") +} diff --git a/graph/src/schema/test_schemas/ts_expr_syntax_err.graphql b/graph/src/schema/test_schemas/ts_expr_syntax_err.graphql new file mode 100644 index 00000000000..fb5def64a12 --- /dev/null +++ b/graph/src/schema/test_schemas/ts_expr_syntax_err.graphql @@ -0,0 +1,13 @@ +# fail: ExprParseError("sql parser error: Expected an expression:, found: EOF") +type Data @entity(timeseries: true) { + id: Int8! + timestamp: Int8! + price0: BigDecimal! + price1: BigDecimal! +} + +type Stats @aggregation(intervals: ["hour", "day"], source: "Data") { + id: Int8! + timestamp: Int8! + max_price: BigDecimal! @aggregate(fn: "max", arg: "greatest(price0,") +} diff --git a/store/postgres/src/relational/rollup.rs b/store/postgres/src/relational/rollup.rs index 4e5bd7674a7..9a997a2bf0f 100644 --- a/store/postgres/src/relational/rollup.rs +++ b/store/postgres/src/relational/rollup.rs @@ -53,6 +53,7 @@ //! group by id, timestamp, <dimensions>) //! select id, timestamp, <dimensions>, <aggregates> from combined //! ``` +use std::collections::HashSet; use std::fmt; use std::ops::Range; use std::sync::Arc; @@ -62,17 +63,87 @@ use diesel::{sql_query, PgConnection, RunQueryDsl as _}; use diesel::sql_types::{BigInt, Integer}; use graph::blockchain::BlockTime; use graph::components::store::{BlockNumber, StoreError}; +use graph::constraint_violation; use graph::data::store::IdType; -use graph::schema::{Aggregate, AggregateFn, Aggregation, AggregationInterval}; +use graph::schema::{ + Aggregate, AggregateFn, Aggregation, AggregationInterval, ExprVisitor, VisitExpr, +}; +use graph::sqlparser::ast as p; +use graph::sqlparser::parser::ParserError; use crate::relational::Table; use super::{Column, SqlName}; +/// Rewrite `expr` by replacing field names with column names and return the +/// rewritten SQL expression and the columns used in the expression +fn rewrite<'a>(table: &'a Table, expr: &str) -> Result<(String, Vec<&'a str>), StoreError> { + struct Rewriter<'a> { + table: &'a Table, + // All columns used in the expression + columns: HashSet<&'a str>, + // The first error we encounter.
Any error here is really an + // oversight in the schema validation; that should have caught all + // possible problems + error: Option<StoreError>, + } + + impl<'a> ExprVisitor for Rewriter<'a> { + fn visit_ident(&mut self, ident: &mut p::Ident) -> Result<(), ()> { + match self.table.column_for_field(&ident.value) { + Ok(column) => { + self.columns.insert(&column.name); + ident.value = column.name.to_string(); + ident.quote_style = Some('"'); + Ok(()) + } + Err(e) => { + self.not_supported(e.to_string()); + Err(()) + } + } + } + + fn visit_func_name(&mut self, _func: &mut p::Ident) -> Result<(), ()> { + Ok(()) + } + + fn not_supported(&mut self, msg: String) { + if self.error.is_none() { + self.error = Some(constraint_violation!( + "Schema validation should have found expression errors: {}", + msg + )); + } + } + + fn parse_error(&mut self, e: ParserError) { + self.not_supported(e.to_string()) + } + } + + let mut visitor = Rewriter { + table, + columns: HashSet::new(), + error: None, + }; + let expr = match VisitExpr::visit(expr, &mut visitor) { + Ok(expr) => expr, + Err(()) => return Err(visitor.error.unwrap()), + }; + if let Some(e) = visitor.error { + return Err(e); + } + let mut columns = visitor.columns.into_iter().collect::<Vec<_>>(); + columns.sort(); + Ok((expr.to_string(), columns)) +} + #[derive(Debug, Clone)] pub(crate) struct Agg<'a> { aggregate: &'a Aggregate, - src_column: &'a Column, + src_columns: Vec<&'a str>, + expr: String, agg_column: &'a Column, } @@ -82,41 +153,42 @@ impl<'a> Agg<'a> { src_table: &'a Table, agg_table: &'a Table, ) -> Result<Self, StoreError> { - let src_column = src_table.column_for_field(&aggregate.arg.name)?; + let (expr, src_columns) = rewrite(src_table, &aggregate.arg)?; let agg_column = agg_table.column_for_field(&aggregate.name)?; Ok(Self { aggregate, - src_column, + src_columns, + expr, agg_column, }) } - fn aggregate_over(&self, src: &SqlName, time: &str, w: &mut dyn fmt::Write) -> fmt::Result { + fn aggregate_over(&self, src: &str, time: &str, w: &mut dyn fmt::Write) -> fmt::Result { use AggregateFn::*; match self.aggregate.func { - Sum => write!(w, "sum(\"{}\")", src)?, - Max => write!(w, "max(\"{}\")", src)?, - Min => write!(w, "min(\"{}\")", src)?, + Sum => write!(w, "sum({})", src)?, + Max => write!(w, "max({})", src)?, + Min => write!(w, "min({})", src)?, First => { let sql_type = self.agg_column.column_type.sql_type(); - write!(w, "arg_min_{}((\"{}\", {time}))", sql_type, src)? + write!(w, "arg_min_{}(({}, {time}))", sql_type, src)? } Last => { let sql_type = self.agg_column.column_type.sql_type(); - write!(w, "arg_max_{}((\"{}\", {time}))", sql_type, src)? + write!(w, "arg_max_{}(({}, {time}))", sql_type, src)? } Count => write!(w, "count(*)")?, } write!(w, " as \"{}\"", self.agg_column.name) } - /// Generate a SQL fragment `func(src_column) as agg_column` where + /// Generate a SQL fragment `func(expr) as agg_column` where /// `func` is the aggregation function.
The `time` parameter is the name /// of the column with respect to which `first` and `last` should decide /// which values are earlier or later fn aggregate(&self, time: &str, w: &mut dyn fmt::Write) -> fmt::Result { - self.aggregate_over(&self.src_column.name, time, w) + self.aggregate_over(&self.expr, time, w) } /// Generate a SQL fragment `func(src_column) as agg_column` where @@ -130,7 +202,8 @@ impl<'a> Agg<'a> { Sum | Max | Min | First | Last => { // For these, combining and aggregating is done by the same // function - return self.aggregate_over(&self.agg_column.name, time, w); + let name = format!("\"{}\"", self.agg_column.name); + return self.aggregate_over(&name, time, w); } Count => write!(w, "sum(\"{}\")", self.agg_column.name)?, } @@ -273,11 +346,12 @@ impl<'a> RollupSql<'a> { " from (select id, timestamp/{secs}*{secs} as timestamp, " )?; write_dims(self.dimensions, w)?; - let agg_srcs = { + let agg_srcs: Vec<&str> = { let mut agg_srcs: Vec<_> = self .aggregates .iter() - .map(|agg| agg.src_column.name.as_str()) + .flat_map(|agg| &agg.src_columns) + .map(|col| *col) .filter(|&col| col != "id" && col != "timestamp") .collect(); agg_srcs.sort(); @@ -499,6 +573,7 @@ mod tests { id: Int8! timestamp: Int8! max: BigDecimal! @aggregate(fn: "max", arg: "price") + max_value: BigDecimal! @aggregate(fn: "max", arg: "price * amount") } type OpenClose @aggregation(intervals: ["day"], source: "Data") { @@ -538,9 +613,10 @@ mod tests { group by timestamp, "token""#; const TOTAL_SQL: &str = r#"\ - insert into "sgd007"."total_stats_day"(id, timestamp, block$, "max") \ - select max(id) as id, timestamp, $3, max("price") as "max" from (\ - select id, timestamp/86400*86400 as timestamp, "price" from "sgd007"."data" \ + insert into "sgd007"."total_stats_day"(id, timestamp, block$, "max", "max_value") \ + select max(id) as id, timestamp, $3, max("price") as "max", \ + max("price" * "amount") as "max_value" from (\ + select id, timestamp/86400*86400 as timestamp, "amount", "price" from "sgd007"."data" \ where "sgd007"."data".timestamp >= $1 and "sgd007"."data".timestamp < $2 \ order by "sgd007"."data".timestamp) data \ group by timestamp"#; diff --git a/store/test-store/tests/postgres/aggregation.rs b/store/test-store/tests/postgres/aggregation.rs index 984cb13c943..a5dce4ee20b 100644 --- a/store/test-store/tests/postgres/aggregation.rs +++ b/store/test-store/tests/postgres/aggregation.rs @@ -38,9 +38,12 @@ type Data @entity(timeseries: true) { timestamp: Int8! token: Bytes! sum: BigDecimal! @aggregate(fn: "sum", arg: "price") + sum_sq: BigDecimal! @aggregate(fn: "sum", arg: "price * price") max: BigDecimal! @aggregate(fn: "max", arg: "amount") - first: BigDecimal! @aggregate(fn: "first", arg: "amount") + first: BigDecimal @aggregate(fn: "first", arg: "amount") last: BigDecimal! @aggregate(fn: "last", arg: "amount") + value: BigDecimal! @aggregate(fn: "sum", arg: "price * amount") + totalValue: BigDecimal! @aggregate(fn: "sum", arg: "price * amount", cumulative: true) } type TotalStats @aggregation(intervals: ["hour"], source: "Data") { @@ -167,16 +170,24 @@ fn stats_hour(schema: &InputSchema) -> Vec> { // Stats_hour aggregations over BLOCKS[0..=1], i.e., at BLOCKS[2] let block2 = vec![ - entity! { schema => id: 11i64, timestamp: 0i64, token: TOKEN1.clone(), sum: bd(3), max: bd(10), first: bd(10), last: bd(2) }, - entity! { schema => id: 12i64, timestamp: 0i64, token: TOKEN2.clone(), sum: bd(3), max: bd(20), first: bd(1), last: bd(20) }, + entity! 
{ schema => id: 11i64, timestamp: 0i64, token: TOKEN1.clone(), + sum: bd(3), sum_sq: bd(5), max: bd(10), first: bd(10), last: bd(2), + value: bd(14), totalValue: bd(14) }, + entity! { schema => id: 12i64, timestamp: 0i64, token: TOKEN2.clone(), + sum: bd(3), sum_sq: bd(5), max: bd(20), first: bd(1), last: bd(20), + value: bd(41), totalValue: bd(41) }, ]; let block3 = { let mut v1 = block2.clone(); // Stats_hour aggregations over BLOCKS[2], i.e., at BLOCKS[3] let mut v2 = vec![ - entity! { schema => id: 21i64, timestamp: 3600i64, token: TOKEN1.clone(), sum: bd(3), max: bd(30), first: bd(30), last: bd(30) }, - entity! { schema => id: 22i64, timestamp: 3600i64, token: TOKEN2.clone(), sum: bd(3), max: bd(3), first: bd(3), last: bd(3) }, + entity! { schema => id: 21i64, timestamp: 3600i64, token: TOKEN1.clone(), + sum: bd(3), sum_sq: bd(9), max: bd(30), first: bd(30), last: bd(30), + value: bd(90), totalValue: bd(104) }, + entity! { schema => id: 22i64, timestamp: 3600i64, token: TOKEN2.clone(), + sum: bd(3), sum_sq: bd(9), max: bd(3), first: bd(3), last: bd(3), + value: bd(9), totalValue: bd(50)}, ]; v1.append(&mut v2); v1
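
To make the new visitor API concrete, here is a minimal sketch (hypothetical code, not part of this patch) of a third `ExprVisitor` implementation that only collects the identifiers an expression mentions, wired through the same `VisitExpr::visit` entry point that `Validator` and `Rewriter` use:

```rust
use graph::schema::{ExprVisitor, VisitExpr};
use graph::sqlparser::ast as p;
use graph::sqlparser::parser::ParserError;

/// Collects every column identifier mentioned in an expression; any
/// unsupported construct or syntax error is recorded as a plain string.
struct IdentCollector {
    idents: Vec<String>,
    errors: Vec<String>,
}

impl ExprVisitor for IdentCollector {
    fn visit_ident(&mut self, ident: &mut p::Ident) -> Result<(), ()> {
        // Unlike `Rewriter`, leave the identifier untouched
        self.idents.push(ident.value.clone());
        Ok(())
    }

    fn visit_func_name(&mut self, _func: &mut p::Ident) -> Result<(), ()> {
        // Accept any function; `Validator` is where FN_WHITELIST is enforced
        Ok(())
    }

    fn not_supported(&mut self, msg: String) {
        self.errors.push(msg);
    }

    fn parse_error(&mut self, e: ParserError) {
        self.errors.push(e.to_string());
    }
}

fn main() {
    let mut collector = IdentCollector { idents: Vec::new(), errors: Vec::new() };
    // `price * amount` parses into a BinaryOp over two identifiers
    match VisitExpr::visit("price * amount", &mut collector) {
        Ok(expr) => println!("parsed `{expr}`, identifiers: {:?}", collector.idents),
        Err(()) => eprintln!("rejected: {:?}", collector.errors),
    }
}
```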