From 06fef3aaaf3b3a42224543c69a4642e884f351ee Mon Sep 17 00:00:00 2001 From: Christopher DEBOVE Date: Thu, 16 Jan 2025 13:04:54 +0100 Subject: [PATCH] feat: add is empty, is not empty, starts with, ends with filters operators --- models/ast/ast_filter.go | 10 ++ repositories/ingested_data_read_repository.go | 41 ++++- .../ingested_data_read_repository_test.go | 28 ++-- .../ast_eval/evaluate/evaluate_aggregator.go | 26 +++- usecases/ast_eval/evaluate/evaluate_filter.go | 4 + .../ast_eval/evaluate/evaluate_filter_test.go | 146 ++++++++++++++++-- 6 files changed, 218 insertions(+), 37 deletions(-) diff --git a/models/ast/ast_filter.go b/models/ast/ast_filter.go index f69ceaf3e..4febc72ec 100644 --- a/models/ast/ast_filter.go +++ b/models/ast/ast_filter.go @@ -1,5 +1,7 @@ package ast +import "slices" + type FilterOperator string const ( @@ -11,9 +13,17 @@ const ( FILTER_LESSER_OR_EQUAL FilterOperator = "<=" FILTER_IS_IN_LIST FilterOperator = "IsInList" FILTER_IS_NOT_IN_LIST FilterOperator = "IsNotInList" + FILTER_IS_EMPTY FilterOperator = "IsEmpty" + FILTER_IS_NOT_EMPTY FilterOperator = "IsNotEmpty" + FILTER_STARTS_WITH FilterOperator = "StringStartsWith" + FILTER_ENDS_WITH FilterOperator = "StringEndsWith" FILTER_UNKNOWN_OPERATION FilterOperator = "FILTER_UNKNOWN_OPERATION" ) +func (op FilterOperator) IsUnary() bool { + return slices.Contains([]FilterOperator{FILTER_IS_EMPTY, FILTER_IS_NOT_EMPTY}, op) +} + type Filter struct { TableName string FieldName string diff --git a/repositories/ingested_data_read_repository.go b/repositories/ingested_data_read_repository.go index efb4a003b..ff39a1f4f 100644 --- a/repositories/ingested_data_read_repository.go +++ b/repositories/ingested_data_read_repository.go @@ -12,6 +12,12 @@ import ( "github.com/checkmarble/marble-backend/models/ast" ) +// FilterWithType pairs an ast.Filter with the data type of the field it targets. TODO: decide where this type should live. +type FilterWithType struct { + Filter ast.Filter + FieldType models.DataType +} + type IngestedDataReadRepository interface { GetDbField(ctx context.Context, 
exec Executor, readParams models.DbFieldReadParams) (any, error) ListAllObjectIdsFromTable( @@ -33,7 +39,7 @@ type IngestedDataReadRepository interface { fieldName string, fieldType models.DataType, aggregator ast.Aggregator, - filters []ast.Filter, + filters []FilterWithType, ) (any, error) } @@ -314,7 +320,7 @@ func createQueryAggregated( fieldName string, fieldType models.DataType, aggregator ast.Aggregator, - filters []ast.Filter, + filters []FilterWithType, ) (squirrel.SelectBuilder, error) { var selectExpression string if aggregator == ast.AGGREGATOR_COUNT_DISTINCT { @@ -339,7 +345,8 @@ func createQueryAggregated( var err error for _, filter := range filters { - query, err = addConditionForOperator(query, qualifiedTableName, filter.FieldName, filter.Operator, filter.Value) + query, err = addConditionForOperator(query, qualifiedTableName, + filter.Filter.FieldName, filter.FieldType, filter.Filter.Operator, filter.Filter.Value) if err != nil { return squirrel.SelectBuilder{}, err } @@ -354,7 +361,7 @@ func (repo *IngestedDataReadRepositoryImpl) QueryAggregatedValue( fieldName string, fieldType models.DataType, aggregator ast.Aggregator, - filters []ast.Filter, + filters []FilterWithType, ) (any, error) { if err := validateClientDbExecutor(exec); err != nil { return nil, err @@ -377,7 +384,7 @@ func (repo *IngestedDataReadRepositoryImpl) QueryAggregatedValue( return result, nil } -func addConditionForOperator(query squirrel.SelectBuilder, tableName string, fieldName string, +func addConditionForOperator(query squirrel.SelectBuilder, tableName string, fieldName string, fieldType models.DataType, operator ast.FilterOperator, value any, ) (squirrel.SelectBuilder, error) { switch operator { @@ -393,6 +400,30 @@ func addConditionForOperator(query squirrel.SelectBuilder, tableName string, fie return query.Where(squirrel.Lt{fmt.Sprintf("%s.%s", tableName, fieldName): value}), nil case ast.FILTER_LESSER_OR_EQUAL: return query.Where(squirrel.LtOrEq{fmt.Sprintf("%s.%s", 
tableName, fieldName): value}), nil + case ast.FILTER_IS_EMPTY: + orCondition := squirrel.Or{ + squirrel.Eq{fmt.Sprintf("%s.%s", tableName, fieldName): nil}, + } + if fieldType == models.String { + orCondition = append(orCondition, squirrel.Eq{ + fmt.Sprintf("%s.%s", tableName, fieldName): "", + }) + } + return query.Where(orCondition), nil + case ast.FILTER_IS_NOT_EMPTY: + andCondition := squirrel.And{ + squirrel.NotEq{fmt.Sprintf("%s.%s", tableName, fieldName): nil}, + } + if fieldType == models.String { + andCondition = append(andCondition, + squirrel.NotEq{fmt.Sprintf("%s.%s", tableName, fieldName): ""}, + ) + } + return query.Where(andCondition), nil + case ast.FILTER_STARTS_WITH: + return query.Where(squirrel.Like{fmt.Sprintf("%s.%s", tableName, fieldName): fmt.Sprintf("%s%%", value)}), nil + case ast.FILTER_ENDS_WITH: + return query.Where(squirrel.Like{fmt.Sprintf("%s.%s", tableName, fieldName): fmt.Sprintf("%%%s", value)}), nil default: return query, fmt.Errorf("unknown operator %s: %w", operator, models.BadParameterError) } diff --git a/repositories/ingested_data_read_repository_test.go b/repositories/ingested_data_read_repository_test.go index 325ea28d8..0f6a23e58 100644 --- a/repositories/ingested_data_read_repository_test.go +++ b/repositories/ingested_data_read_repository_test.go @@ -105,7 +105,7 @@ func TestIngestedDataQueryAggregatedValueWithoutFilter(t *testing.T) { utils.DummyFieldNameForInt, models.Int, ast.AGGREGATOR_AVG, - []ast.Filter{}, + []FilterWithType{}, ) assert.Empty(t, err) sql, args, err := query.ToSql() @@ -123,7 +123,7 @@ func TestIngestedDataQueryCountWithoutFilter(t *testing.T) { utils.DummyFieldNameForInt, models.Int, ast.AGGREGATOR_COUNT, - []ast.Filter{}) + []FilterWithType{}) assert.Empty(t, err) sql, args, err := query.ToSql() assert.Empty(t, err) @@ -134,18 +134,24 @@ func TestIngestedDataQueryCountWithoutFilter(t *testing.T) { } func TestIngestedDataQueryAggregatedValueWithFilter(t *testing.T) { - filters := []ast.Filter{ + 
filters := []FilterWithType{ { - TableName: utils.DummyTableNameFirst, - FieldName: utils.DummyFieldNameForInt, - Operator: ast.FILTER_EQUAL, - Value: 1, + Filter: ast.Filter{ + TableName: utils.DummyTableNameFirst, + FieldName: utils.DummyFieldNameForInt, + Operator: ast.FILTER_EQUAL, + Value: 1, + }, + FieldType: models.Int, }, { - TableName: utils.DummyTableNameFirst, - FieldName: utils.DummyFieldNameForBool, - Operator: ast.FILTER_NOT_EQUAL, - Value: true, + Filter: ast.Filter{ + TableName: utils.DummyTableNameFirst, + FieldName: utils.DummyFieldNameForBool, + Operator: ast.FILTER_NOT_EQUAL, + Value: true, + }, + FieldType: models.Bool, }, } diff --git a/usecases/ast_eval/evaluate/evaluate_aggregator.go b/usecases/ast_eval/evaluate/evaluate_aggregator.go index 17d9b8037..8e02f38d9 100644 --- a/usecases/ast_eval/evaluate/evaluate_aggregator.go +++ b/usecases/ast_eval/evaluate/evaluate_aggregator.go @@ -73,6 +73,7 @@ func (a AggregatorEvaluator) Evaluate(ctx context.Context, arguments ast.Argumen } // Filters validation + var filtersWithType []repositories.FilterWithType if len(filters) > 0 { for _, filter := range filters { if filter.TableName != tableName { @@ -82,14 +83,28 @@ func (a AggregatorEvaluator) Evaluate(ctx context.Context, arguments ast.Argumen ast.NewNamedArgumentError("filters"), )) } - // At the first nil filter value found, stop and just return the default value for the aggregator - if filter.Value == nil { + + // On the first nil filter value for a non-unary operator, stop and just return the default value for the aggregator (unary operators carry no value) + if filter.Value == nil && !filter.Operator.IsUnary() { return a.defaultValueForAggregator(aggregator) } + + filterFieldType, err := getFieldType(a.DataModel, filter.TableName, filter.FieldName) + if err != nil { + return MakeEvaluateError(errors.Join( + errors.Wrap(err, fmt.Sprintf("field type for %s.%s not found in data model in Evaluate aggregator", filter.TableName, filter.FieldName)), 
ast.NewNamedArgumentError("fieldName"), + )) + } + + filtersWithType = append(filtersWithType, repositories.FilterWithType{ + Filter: filter, + FieldType: filterFieldType, + }) } } - result, err := a.runQueryInRepository(ctx, tableName, fieldName, fieldType, aggregator, filters) + result, err := a.runQueryInRepository(ctx, tableName, fieldName, fieldType, aggregator, filtersWithType) if err != nil { return MakeEvaluateError(errors.Wrap(err, "Error running aggregation query in repository")) } @@ -107,7 +122,7 @@ func (a AggregatorEvaluator) runQueryInRepository( fieldName string, fieldType models.DataType, aggregator ast.Aggregator, - filters []ast.Filter, + filters []repositories.FilterWithType, ) (any, error) { if a.ReturnFakeValue { return DryRunQueryAggregatedValue(a.DataModel, tableName, fieldName, aggregator) @@ -117,7 +132,8 @@ func (a AggregatorEvaluator) runQueryInRepository( if err != nil { return nil, err } - return a.IngestedDataReadRepository.QueryAggregatedValue(ctx, db, tableName, fieldName, fieldType, aggregator, filters) + return a.IngestedDataReadRepository.QueryAggregatedValue(ctx, db, tableName, + fieldName, fieldType, aggregator, filters) } func (a AggregatorEvaluator) defaultValueForAggregator(aggregator ast.Aggregator) (any, []error) { diff --git a/usecases/ast_eval/evaluate/evaluate_filter.go b/usecases/ast_eval/evaluate/evaluate_filter.go index 8be4c72a8..94a0035f5 100644 --- a/usecases/ast_eval/evaluate/evaluate_filter.go +++ b/usecases/ast_eval/evaluate/evaluate_filter.go @@ -25,6 +25,10 @@ var validTypeForFilterOperators = map[ast.FilterOperator][]models.DataType{ ast.FILTER_LESSER_OR_EQUAL: {models.Int, models.Float, models.String, models.Timestamp}, ast.FILTER_IS_IN_LIST: {models.String}, ast.FILTER_IS_NOT_IN_LIST: {models.String}, + ast.FILTER_IS_EMPTY: {models.Int, models.Float, models.String, models.Timestamp}, + ast.FILTER_IS_NOT_EMPTY: {models.Int, models.Float, models.String, models.Timestamp}, + ast.FILTER_STARTS_WITH: 
{models.String}, + ast.FILTER_ENDS_WITH: {models.String}, } func (f FilterEvaluator) Evaluate(ctx context.Context, arguments ast.Arguments) (any, []error) { diff --git a/usecases/ast_eval/evaluate/evaluate_filter_test.go b/usecases/ast_eval/evaluate/evaluate_filter_test.go index 54214450e..ff45d8b56 100644 --- a/usecases/ast_eval/evaluate/evaluate_filter_test.go +++ b/usecases/ast_eval/evaluate/evaluate_filter_test.go @@ -23,7 +23,7 @@ var dataModel = models.DataModel{ }, }, } -var filter = FilterEvaluator{DataModel: dataModel} +var filterWithBool = FilterEvaluator{DataModel: dataModel} func TestFilter(t *testing.T) { arguments := ast.Arguments{ @@ -38,12 +38,12 @@ func TestFilter(t *testing.T) { TableName: "table1", FieldName: "field1", Operator: ast.FILTER_EQUAL, - Value: 1, + Value: true, } - result, errs := filter.Evaluate(context.TODO(), arguments) + result, errs := filterWithBool.Evaluate(context.TODO(), arguments) assert.Empty(t, errs) - assert.ObjectsAreEqualValues(expectedResult, result) + assert.EqualValues(t, expectedResult, result) } func TestFilter_tableName_not_string(t *testing.T) { @@ -55,7 +55,7 @@ func TestFilter_tableName_not_string(t *testing.T) { "value": 1, }, } - _, errs := filter.Evaluate(context.TODO(), arguments) + _, errs := filterWithBool.Evaluate(context.TODO(), arguments) assert.NotEmpty(t, errs) } @@ -68,7 +68,7 @@ func TestFilter_fieldName_not_string(t *testing.T) { "value": 1, }, } - _, errs := filter.Evaluate(context.TODO(), arguments) + _, errs := filterWithBool.Evaluate(context.TODO(), arguments) assert.NotEmpty(t, errs) } @@ -81,7 +81,7 @@ func TestFilter_field_unknown(t *testing.T) { "value": 1, }, } - _, errs := filter.Evaluate(context.TODO(), arguments) + _, errs := filterWithBool.Evaluate(context.TODO(), arguments) assert.NotEmpty(t, errs) } @@ -94,7 +94,7 @@ func TestFilter_operator_invalid(t *testing.T) { "value": 1, }, } - _, errs := filter.Evaluate(context.TODO(), arguments) + _, errs := 
filterWithBool.Evaluate(context.TODO(), arguments) assert.NotEmpty(t, errs) } @@ -107,7 +107,7 @@ func TestFilter_operator_unknown(t *testing.T) { "value": 1, }, } - _, errs := filter.Evaluate(context.TODO(), arguments) + _, errs := filterWithBool.Evaluate(context.TODO(), arguments) assert.NotEmpty(t, errs) } @@ -120,7 +120,7 @@ func TestFilter_fieldType_incompatible(t *testing.T) { "value": 1, }, } - _, errs := filter.Evaluate(context.TODO(), arguments) + _, errs := filterWithBool.Evaluate(context.TODO(), arguments) assert.NotEmpty(t, errs) } @@ -133,7 +133,7 @@ func TestFilter_value_incompatible(t *testing.T) { "value": "incompatible_value", }, } - _, errs := filter.Evaluate(context.TODO(), arguments) + _, errs := filterWithBool.Evaluate(context.TODO(), arguments) assert.NotEmpty(t, errs) } @@ -171,7 +171,7 @@ func TestFilter_value_float(t *testing.T) { result, errs := filterWithInt.Evaluate(context.TODO(), arguments) assert.Empty(t, errs) - assert.ObjectsAreEqualValues(expectedResult, result) + assert.EqualValues(t, expectedResult, result) } var dataModelWithString = models.DataModel{ @@ -208,7 +208,7 @@ func TestFilter_is_in_list(t *testing.T) { result, errs := filterWithString.Evaluate(context.TODO(), arguments) assert.Empty(t, errs) - assert.ObjectsAreEqualValues(expectedResult, result) + assert.EqualValues(t, expectedResult, result) } func TestFilter_is_not_in_list(t *testing.T) { @@ -224,13 +224,13 @@ func TestFilter_is_not_in_list(t *testing.T) { expectedResult := ast.Filter{ TableName: "table1", FieldName: "field1", - // Operator: ast.FILTER_IS_NOT_IN_LIST, - Value: []string{"a", "b"}, + Operator: ast.FILTER_IS_NOT_IN_LIST, + Value: []string{"a", "b"}, } result, errs := filterWithString.Evaluate(context.TODO(), arguments) assert.Empty(t, errs) - assert.ObjectsAreEqualValues(expectedResult, result) + assert.EqualValues(t, expectedResult, result) } func TestFilter_is_in_list_invalid_value_type(t *testing.T) { @@ -260,3 +260,117 @@ func 
TestFilter_is_in_list_invalid_field_type(t *testing.T) { _, errs := filterWithInt.Evaluate(context.TODO(), arguments) assert.NotEmpty(t, errs) } + +func TestFilter_is_empty(t *testing.T) { + arguments := ast.Arguments{ + NamedArgs: map[string]any{ + "tableName": "table1", + "fieldName": "field1", + "operator": "IsEmpty", + }, + } + + expectedResult := ast.Filter{ + TableName: "table1", + FieldName: "field1", + Operator: ast.FILTER_IS_EMPTY, + Value: nil, + } + result, errs := filterWithString.Evaluate(context.TODO(), arguments) + assert.Empty(t, errs) + + assert.EqualValues(t, expectedResult, result) +} + +func TestFilter_is_not_empty(t *testing.T) { + arguments := ast.Arguments{ + NamedArgs: map[string]any{ + "tableName": "table1", + "fieldName": "field1", + "operator": "IsNotEmpty", + }, + } + + expectedResult := ast.Filter{ + TableName: "table1", + FieldName: "field1", + Operator: ast.FILTER_IS_NOT_EMPTY, + Value: nil, + } + result, errs := filterWithString.Evaluate(context.TODO(), arguments) + assert.Empty(t, errs) + + assert.EqualValues(t, expectedResult, result) +} + +func TestFilter_starts_with(t *testing.T) { + arguments := ast.Arguments{ + NamedArgs: map[string]any{ + "tableName": "table1", + "fieldName": "field1", + "operator": "StringStartsWith", + "value": "some_value", + }, + } + + expectedResult := ast.Filter{ + TableName: "table1", + FieldName: "field1", + Operator: ast.FILTER_STARTS_WITH, + Value: "some_value", + } + result, errs := filterWithString.Evaluate(context.TODO(), arguments) + assert.Empty(t, errs) + + assert.EqualValues(t, expectedResult, result) +} + +func TestFilter_starts_with_wrong_value_type(t *testing.T) { + arguments := ast.Arguments{ + NamedArgs: map[string]any{ + "tableName": "table1", + "fieldName": "field1", + "operator": "StringStartsWith", + "value": 1, + }, + } + + _, errs := filterWithString.Evaluate(context.TODO(), arguments) + assert.NotEmpty(t, errs) +} + +func TestFilter_ends_with(t *testing.T) { + arguments := 
ast.Arguments{ + NamedArgs: map[string]any{ + "tableName": "table1", + "fieldName": "field1", + "operator": "StringEndsWith", + "value": "some_value", + }, + } + + expectedResult := ast.Filter{ + TableName: "table1", + FieldName: "field1", + Operator: ast.FILTER_ENDS_WITH, + Value: "some_value", + } + result, errs := filterWithString.Evaluate(context.TODO(), arguments) + assert.Empty(t, errs) + + assert.EqualValues(t, expectedResult, result) +} + +func TestFilter_ends_with_wrong_value_type(t *testing.T) { + arguments := ast.Arguments{ + NamedArgs: map[string]any{ + "tableName": "table1", + "fieldName": "field1", + "operator": "StringEndsWith", + "value": 1, + }, + } + + _, errs := filterWithString.Evaluate(context.TODO(), arguments) + assert.NotEmpty(t, errs) +}