From ffe8856dba23128e190c9a841eba616c80f0ba8e Mon Sep 17 00:00:00 2001 From: Ping-Lin Chang Date: Mon, 20 Jun 2022 02:45:35 +0100 Subject: [PATCH] feat: add filter for list pipeline --- go.mod | 5 +- go.sum | 8 +- integration-test/rest-pipeline.js | 37 ++++ integration-test/rest.js | 29 ++- pkg/datamodel/datamodel.go | 16 +- pkg/handler/handler.go | 27 ++- pkg/repository/repository.go | 19 +- pkg/repository/transpiler.go | 323 ++++++++++++++++++++++++++++ pkg/service/mock_repository_test.go | 9 +- pkg/service/service.go | 7 +- pkg/usage/usage.go | 3 +- 11 files changed, 452 insertions(+), 31 deletions(-) create mode 100644 pkg/repository/transpiler.go diff --git a/go.mod b/go.mod index 9c8709268..f5426b810 100644 --- a/go.mod +++ b/go.mod @@ -11,13 +11,14 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 github.com/grpc-ecosystem/grpc-gateway/v2 v2.10.3 github.com/iancoleman/strcase v0.2.0 - github.com/instill-ai/protogen-go v0.1.5-alpha.0.20220615161406-12d8be2b3938 + github.com/instill-ai/protogen-go v0.1.5-alpha.0.20220620010454-ca08bcdfc4dd github.com/instill-ai/usage-client v0.0.0-20220607201439-d646c37f5b02 github.com/instill-ai/x v0.1.0-alpha.0.20220604235252-39fcffc82edb github.com/knadh/koanf v1.4.1 github.com/mennanov/fieldmask-utils v0.5.0 github.com/rs/cors v1.8.2 github.com/stretchr/testify v1.7.2 + go.einride.tech/aip v0.54.1 go.temporal.io/sdk v1.15.0 go.uber.org/zap v1.21.0 golang.org/x/net v0.0.0-20220614195744-fb05da6f9022 @@ -67,7 +68,7 @@ require ( golang.org/x/sys v0.0.0-20220614162138-6c1b26c55098 // indirect golang.org/x/text v0.3.7 // indirect golang.org/x/time v0.0.0-20220609170525-579cf78fd858 // indirect - google.golang.org/genproto v0.0.0-20220614165028-45ed7f3ff16e // indirect + google.golang.org/genproto v0.0.0-20220614165028-45ed7f3ff16e gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index b220b6efe..6d0889c98 100644 --- a/go.sum +++ b/go.sum @@ -694,8 +694,8 @@ github.com/imdario/mergo v0.3.10/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH github.com/imdario/mergo v0.3.11/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA= github.com/imdario/mergo v0.3.12/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= -github.com/instill-ai/protogen-go v0.1.5-alpha.0.20220615161406-12d8be2b3938 h1:LR/Zk/EyqXID8BKm+AnrPvd4L4qqCRD/9sLIWvL3KF0= -github.com/instill-ai/protogen-go v0.1.5-alpha.0.20220615161406-12d8be2b3938/go.mod h1:d9ebEdwMX2Las4OScym45qbQM+xcBQITqvq/8anTVas= +github.com/instill-ai/protogen-go v0.1.5-alpha.0.20220620010454-ca08bcdfc4dd h1:m6XdTfydrZwaQHqLIDMn0/hGp2C6e8HxRoO6YclrL+U= +github.com/instill-ai/protogen-go v0.1.5-alpha.0.20220620010454-ca08bcdfc4dd/go.mod h1:d9ebEdwMX2Las4OScym45qbQM+xcBQITqvq/8anTVas= github.com/instill-ai/usage-client v0.0.0-20220607201439-d646c37f5b02 h1:7dhRYHERy+NbvESpaQ0NPOo3CiiDpvLsARR90Ftkiqw= github.com/instill-ai/usage-client v0.0.0-20220607201439-d646c37f5b02/go.mod h1:saH0H46iHHMxBx+znN3CoE4IOylbTlpQUPj0Do06yKo= github.com/instill-ai/x v0.1.0-alpha.0.20220604235252-39fcffc82edb h1:70AJVfr463jWkgPQ1w281zsQ1LK/tOW5INTNc+yOBsI= @@ -1188,6 +1188,8 @@ github.com/yvasiyarov/gorelic v0.0.0-20141212073537-a9bba5b9ab50/go.mod h1:NUSPS github.com/yvasiyarov/newrelic_platform_go v0.0.0-20140908184405-b21fdbd4370f/go.mod h1:GlGEuHIJweS1mbCqG+7vt2nvWLzLLnRHbXz5JKd/Qbg= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= gitlab.com/nyarla/go-crypt v0.0.0-20160106005555-d9a5dc2b789b/go.mod h1:T3BPAOm2cqquPa0MKWeNkmOM5RQsRhkrwMWonFMN7fE= +go.einride.tech/aip v0.54.1 h1:srys7sFWPixEqyOu0gWuZAC86p4UAnWJIQcA01Ys3R4= +go.einride.tech/aip v0.54.1/go.mod h1:tUzBlpbLzt0LbL2GcO7RHQyHdnVFK25KvfZ638MTbgk= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ= @@ -1901,9 +1903,11 @@ gorm.io/gorm v1.21.4/go.mod h1:0HFTzE/SqkGTzK6TlDPPQbAYCluiVvhzoA1+aVyzenw= gorm.io/gorm v1.23.4/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= gorm.io/gorm v1.23.6 h1:KFLdNgri4ExFFGTRGGFWON2P1ZN28+9SJRN8voOoYe0= gorm.io/gorm v1.23.6/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= +gotest.tools v2.2.0+incompatible h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo= gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= gotest.tools/v3 v3.0.2/go.mod h1:3SzNCllyD9/Y+b5r9JIKQ474KzkZyqLqEfYqMsX94Bk= gotest.tools/v3 v3.0.3/go.mod h1:Z7Lb0S5l+klDB31fvDQX8ss/FlKDxtlFlw3Oa8Ymbl8= +gotest.tools/v3 v3.1.0 h1:rVV8Tcg/8jHUkPUorwjaMTtemIMVXfIPKiOqnhEhakk= gotest.tools/v3 v3.1.0/go.mod h1:fHy7eyTmJFO5bQbUsEGQ1v4m2J3Jz9eWL54TP2/ZuYQ= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/integration-test/rest-pipeline.js b/integration-test/rest-pipeline.js index 39a9bcccb..ad325ac08 100644 --- a/integration-test/rest-pipeline.js +++ b/integration-test/rest-pipeline.js @@ -218,6 +218,43 @@ export function CheckList() { [`GET /v1alpha/pipelines?page_size=100&page_token=${resFirst100.json().next_page_token} response next_page_token is empty`]: (r) => r.json().next_page_token == "", }); + // Filtering + check(http.request("GET", `${pipelineHost}/v1alpha/pipelines?filter=mode=MODE_SYNC`, null, {headers: {"Content-Type": "application/json",}}), { + [`GET /v1alpha/pipelines?filter=mode=MODE_SYNC response 200`]: (r) => r.status == 200, + [`GET /v1alpha/pipelines?filter=mode=MODE_SYNC response pipelines.length > 0`]: (r) => r.json().pipelines.length > 0, + }); + + check(http.request("GET", `${pipelineHost}/v1alpha/pipelines?filter=mode=MODE_SYNC%20AND%20state=STATE_ACTIVE`, null, {headers: {"Content-Type": "application/json",}}), { + [`GET /v1alpha/pipelines?filter=mode=MODE_SYNC%20AND%20state=STATE_ACTIVE response 200`]: (r) => r.status == 200, + [`GET /v1alpha/pipelines?filter=mode=MODE_SYNC%20AND%20state=STATE_ACTIVE response pipelines.length > 0`]: (r) => r.json().pipelines.length > 0, + }); + + check(http.request("GET", `${pipelineHost}/v1alpha/pipelines?filter=state=STATE_ACTIVE%20AND%20create_time>timestamp%28%222000-06-19T23:31:08.657Z%22%29`, null, {headers: {"Content-Type": "application/json",}}), { + [`GET /v1alpha/pipelines?filter=state=STATE_ACTIVE%20AND%20create_time%20>%20timestamp%28%222000-06-19T23:31:08.657Z%22%29 response 200`]: (r) => r.status == 200, + [`GET /v1alpha/pipelines?filter=state=STATE_ACTIVE%20AND%20create_time%20>%20timestamp%28%222000-06-19T23:31:08.657Z%22%29 response pipelines.length > 0`]: (r) => r.json().pipelines.length > 0, + }); + + // Get UUID for foreign resources + var srcConnUid = http.get(`${connectorHost}/v1alpha/source-connectors/source-http`, {}, {headers: {"Content-Type": "application/json"},}).json().source_connector.uid + var srcConnPermalink = `source-connectors/${srcConnUid}` + + var dstConnUid = http.get(`${connectorHost}/v1alpha/destination-connectors/destination-http`, {}, {headers: {"Content-Type": "application/json"},}).json().destination_connector.uid + var dstConnPermalink = `destination-connectors/${dstConnUid}` + + var modelUid = http.get(`${modelHost}/v1alpha/models/${constant.model_id}`, {}, {headers: {"Content-Type": "application/json"},}).json().model.uid + var modelInstUid = http.get(`${modelHost}/v1alpha/models/${constant.model_id}/instances/latest`, {}, {headers: {"Content-Type": "application/json"},}).json().instance.uid + var modelInstPermalink = `models/${modelUid}/instances/${modelInstUid}` + + check(http.request("GET", `${pipelineHost}/v1alpha/pipelines?filter=mode=MODE_SYNC%20AND%20recipe.source=%22${srcConnPermalink}%22`, null, {headers: {"Content-Type": "application/json",}}), { + [`GET /v1alpha/pipelines?filter=mode=MODE_SYNC%20AND%20recipe.source=%22${srcConnPermalink}%22 response 200`]: (r) => r.status == 200, + [`GET /v1alpha/pipelines?filter=mode=MODE_SYNC%20AND%20recipe.source=%22${srcConnPermalink}%22 response pipelines.length > 0`]: (r) => r.json().pipelines.length > 0, + }); + + check(http.request("GET", `${pipelineHost}/v1alpha/pipelines?filter=mode=MODE_SYNC%20AND%20recipe.source=%22${srcConnPermalink}%22%20AND%20recipe.model_instances:%22${modelInstPermalink}%22`, null, {headers: {"Content-Type": "application/json",}}), { + [`GET /v1alpha/pipelines?filter=mode=MODE_SYNC%20AND%20recipe.source=%22${srcConnPermalink}%22%20AND%20recipe.model_instances:%22${modelInstPermalink}%22 response 200`]: (r) => r.status == 200, + [`GET /v1alpha/pipelines?filter=mode=MODE_SYNC%20AND%20recipe.source=%22${srcConnPermalink}%22%20AND%20recipe.model_instances:%22${modelInstPermalink}%22 response pipelines.length > 0`]: (r) => r.json().pipelines.length > 0, + }); + // Delete the pipelines for (const reqBody of reqBodies) { check(http.request( diff --git a/integration-test/rest.js b/integration-test/rest.js index 2b04c28fc..5d626a16c 100644 --- a/integration-test/rest.js +++ b/integration-test/rest.js @@ -25,7 +25,8 @@ export let options = { export function setup() { group("Connector Backend API: Create a http source connector", function () { - check(http.request("POST", `${connectorHost}/v1alpha/source-connectors`, + + var res = http.request("POST", `${connectorHost}/v1alpha/source-connectors`, JSON.stringify({ "id": "source-http", "source_connector_definition": "source-connector-definitions/source-http", @@ -34,13 +35,16 @@ export function setup() { } }), { headers: { "Content-Type": "application/json" }, - }), { + }) + check(res, { "POST /v1alpha/source-connectors response status for creating directness HTTP source connector 201": (r) => r.status === 201, }) + }); group("Connector Backend API: Create a http destination connector", function () { - check(http.request("POST", `${connectorHost}/v1alpha/destination-connectors`, + + var res = http.request("POST", `${connectorHost}/v1alpha/destination-connectors`, JSON.stringify({ "id": "destination-http", "destination_connector_definition": "destination-connector-definitions/destination-http", @@ -49,13 +53,17 @@ export function setup() { } }), { headers: { "Content-Type": "application/json" }, - }), { + }) + + check(res, { "POST /v1alpha/destination-connectors response status for creating directness HTTP destination connector 201": (r) => r.status === 201, }) + }); group("Connector Backend API: Create a CSV destination connector", function () { - check(http.request("POST", `${connectorHost}/v1alpha/destination-connectors`, + + var res = http.request("POST", `${connectorHost}/v1alpha/destination-connectors`, JSON.stringify({ "id": constant.dstCSVConnID, "destination_connector_definition": "destination-connector-definitions/destination-csv", @@ -66,9 +74,12 @@ export function setup() { } }), { headers: { "Content-Type": "application/json" }, - }), { + }) + + check(res, { "POST /v1alpha/destination-connectors response status for creating CSV destination connector 201": (r) => r.status === 201, }) + }); group("Model Backend API: Deploy a detection model", function () { @@ -86,11 +97,13 @@ export function setup() { "POST /v1alpha/models:multipart task det response status": (r) => r.status === 201 }); - check(http.post(`${modelHost}/v1alpha/models/${constant.model_id}/instances/latest:deploy`, {}, { + var res = http.post(`${modelHost}/v1alpha/models/${constant.model_id}/instances/latest:deploy`, {}, { headers: { "Content-Type": "application/json" }, - }), { + }) + + check(res, { [`POST /v1alpha/models/${constant.model_id}/instances/latest:deploy online task det response status`]: (r) => r.status === 200 }); diff --git a/pkg/datamodel/datamodel.go b/pkg/datamodel/datamodel.go index ca9594de4..4b33244ad 100644 --- a/pkg/datamodel/datamodel.go +++ b/pkg/datamodel/datamodel.go @@ -47,28 +47,28 @@ type Pipeline struct { type PipelineMode pipelinePB.Pipeline_Mode // Scan function for custom GORM type PipelineMode -func (c *PipelineMode) Scan(value interface{}) error { - *c = PipelineMode(pipelinePB.Pipeline_Mode_value[value.(string)]) +func (p *PipelineMode) Scan(value interface{}) error { + *p = PipelineMode(pipelinePB.Pipeline_Mode_value[value.(string)]) return nil } // Value function for custom GORM type PipelineMode -func (c PipelineMode) Value() (driver.Value, error) { - return pipelinePB.Pipeline_Mode(c).String(), nil +func (p PipelineMode) Value() (driver.Value, error) { + return pipelinePB.Pipeline_Mode(p).String(), nil } // PipelineState is an alias type for Protobuf enum Pipeline.State type PipelineState pipelinePB.Pipeline_State // Scan function for custom GORM type PipelineState -func (c *PipelineState) Scan(value interface{}) error { - *c = PipelineState(pipelinePB.Pipeline_State_value[value.(string)]) +func (p *PipelineState) Scan(value interface{}) error { + *p = PipelineState(pipelinePB.Pipeline_State_value[value.(string)]) return nil } // Value function for custom GORM type PipelineState -func (c PipelineState) Value() (driver.Value, error) { - return pipelinePB.Pipeline_State(c).String(), nil +func (p PipelineState) Value() (driver.Value, error) { + return pipelinePB.Pipeline_State(p).String(), nil } // Recipe is the data model of the pipeline recipe diff --git a/pkg/handler/handler.go b/pkg/handler/handler.go index 34c31e8ba..9e772f3b6 100644 --- a/pkg/handler/handler.go +++ b/pkg/handler/handler.go @@ -10,6 +10,7 @@ import ( "github.com/gofrs/uuid" "github.com/gogo/status" "github.com/iancoleman/strcase" + "go.einride.tech/aip/filtering" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" @@ -113,7 +114,31 @@ func (h *handler) ListPipeline(ctx context.Context, req *pipelinePB.ListPipeline return &pipelinePB.ListPipelineResponse{}, err } - dbPipelines, totalSize, nextPageToken, err := h.service.ListPipeline(owner, req.GetPageSize(), req.GetPageToken(), isBasicView) + var mode pipelinePB.Pipeline_Mode + var state pipelinePB.Pipeline_State + declarations, err := filtering.NewDeclarations([]filtering.DeclarationOption{ + filtering.DeclareStandardFunctions(), + filtering.DeclareFunction("time.now", filtering.NewFunctionOverload("time.now", filtering.TypeTimestamp)), + filtering.DeclareIdent("uid", filtering.TypeString), + filtering.DeclareIdent("id", filtering.TypeString), + filtering.DeclareIdent("description", filtering.TypeString), + filtering.DeclareIdent("recipe", filtering.TypeMap(filtering.TypeString, filtering.TypeString)), + filtering.DeclareEnumIdent("mode", mode.Type()), + filtering.DeclareEnumIdent("state", state.Type()), + filtering.DeclareIdent("owner", filtering.TypeString), + filtering.DeclareIdent("create_time", filtering.TypeTimestamp), + filtering.DeclareIdent("update_time", filtering.TypeTimestamp), + }...) + if err != nil { + return &pipelinePB.ListPipelineResponse{}, err + } + + filter, err := filtering.ParseFilter(req, declarations) + if err != nil { + return &pipelinePB.ListPipelineResponse{}, err + } + + dbPipelines, totalSize, nextPageToken, err := h.service.ListPipeline(owner, req.GetPageSize(), req.GetPageToken(), isBasicView, filter) if err != nil { return &pipelinePB.ListPipelineResponse{}, err } diff --git a/pkg/repository/repository.go b/pkg/repository/repository.go index fea3e7b7e..31bbeaead 100644 --- a/pkg/repository/repository.go +++ b/pkg/repository/repository.go @@ -6,9 +6,11 @@ import ( "github.com/gofrs/uuid" "github.com/jackc/pgconn" + "go.einride.tech/aip/filtering" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "gorm.io/gorm" + "gorm.io/gorm/clause" "github.com/instill-ai/pipeline-backend/pkg/datamodel" "github.com/instill-ai/x/paginate" @@ -23,7 +25,7 @@ const MaxPageSize = 100 // Repository interface type Repository interface { CreatePipeline(pipeline *datamodel.Pipeline) error - ListPipeline(owner string, pageSize int64, pageToken string, isBasicView bool) ([]datamodel.Pipeline, int64, string, error) + ListPipeline(owner string, pageSize int64, pageToken string, isBasicView bool, filter filtering.Filter) ([]datamodel.Pipeline, int64, string, error) GetPipelineByID(id string, owner string, isBasicView bool) (*datamodel.Pipeline, error) GetPipelineByUID(uid uuid.UUID, owner string, isBasicView bool) (*datamodel.Pipeline, error) UpdatePipeline(id string, owner string, pipeline *datamodel.Pipeline) error @@ -55,7 +57,7 @@ func (r *repository) CreatePipeline(pipeline *datamodel.Pipeline) error { return nil } -func (r *repository) ListPipeline(owner string, pageSize int64, pageToken string, isBasicView bool) (pipelines []datamodel.Pipeline, totalSize int64, nextPageToken string, err error) { +func (r *repository) ListPipeline(owner string, pageSize int64, pageToken string, isBasicView bool, filter filtering.Filter) (pipelines []datamodel.Pipeline, totalSize int64, nextPageToken string, err error) { if result := r.db.Model(&datamodel.Pipeline{}).Where("owner = ?", owner).Count(&totalSize); result.Error != nil { return nil, 0, "", status.Errorf(codes.Internal, result.Error.Error()) @@ -83,6 +85,12 @@ func (r *repository) ListPipeline(owner string, pageSize int64, pageToken string queryBuilder.Omit("pipeline.recipe") } + if expr, err := r.transpileFilter(filter); err != nil { + return nil, 0, "", status.Errorf(codes.Internal, err.Error()) + } else if expr != nil { + queryBuilder.Clauses(expr) + } + var createTime time.Time rows, err := queryBuilder.Rows() if err != nil { @@ -189,3 +197,10 @@ func (r *repository) UpdatePipelineState(id string, owner string, state datamode } return nil } + +// TranspileFilter transpiles a parsed AIP filter expression to GORM DB clauses +func (r *repository) transpileFilter(filter filtering.Filter) (*clause.Expr, error) { + return (&Transpiler{ + filter: filter, + }).Transpile() +} diff --git a/pkg/repository/transpiler.go b/pkg/repository/transpiler.go new file mode 100644 index 000000000..7bb0b624b --- /dev/null +++ b/pkg/repository/transpiler.go @@ -0,0 +1,323 @@ +package repository + +import ( + "fmt" + "time" + + "go.einride.tech/aip/filtering" + "gorm.io/gorm/clause" + + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" + + expr "google.golang.org/genproto/googleapis/api/expr/v1alpha1" +) + +// Transpiler data +type Transpiler struct { + filter filtering.Filter +} + +// Transpile executes the transpilation on the filter +func (t *Transpiler) Transpile() (*clause.Expr, error) { + if t.filter.CheckedExpr == nil { + return nil, nil + } + expr, err := t.transpileExpr(t.filter.CheckedExpr.Expr) + if err != nil { + return nil, err + } + return expr, nil +} + +func (t *Transpiler) transpileExpr(e *expr.Expr) (*clause.Expr, error) { + switch e.ExprKind.(type) { + case *expr.Expr_CallExpr: + return t.transpileCallExpr(e) + case *expr.Expr_IdentExpr: + return t.transpileIdentExpr(e) + case *expr.Expr_ConstExpr: + return t.transpileConstExpr(e) + case *expr.Expr_SelectExpr: + return t.transpileSelectExpr(e) + default: + return nil, fmt.Errorf("unsupported expr: %v", e) + } +} + +func (t *Transpiler) transpileConstExpr(e *expr.Expr) (*clause.Expr, error) { + switch kind := e.GetConstExpr().ConstantKind.(type) { + case *expr.Constant_BoolValue: + return &clause.Expr{Vars: []interface{}{kind.BoolValue}}, nil + case *expr.Constant_DoubleValue: + return &clause.Expr{Vars: []interface{}{kind.DoubleValue}}, nil + case *expr.Constant_Int64Value: + return &clause.Expr{Vars: []interface{}{kind.Int64Value}}, nil + case *expr.Constant_StringValue: + return &clause.Expr{Vars: []interface{}{kind.StringValue}}, nil + case *expr.Constant_Uint64Value: + return &clause.Expr{Vars: []interface{}{kind.Uint64Value}}, nil + + default: + return nil, fmt.Errorf("unsupported const expr: %v", kind) + } +} + +func (t *Transpiler) transpileCallExpr(e *expr.Expr) (*clause.Expr, error) { + switch e.GetCallExpr().Function { + case filtering.FunctionHas: + return t.transpileHasCallExpr(e) + case filtering.FunctionEquals: + return t.transpileComparisonCallExpr(e, clause.Eq{}) + case filtering.FunctionNotEquals: + return t.transpileComparisonCallExpr(e, clause.Neq{}) + case filtering.FunctionLessThan: + return t.transpileComparisonCallExpr(e, clause.Lt{}) + case filtering.FunctionLessEquals: + return t.transpileComparisonCallExpr(e, clause.Lte{}) + case filtering.FunctionGreaterThan: + return t.transpileComparisonCallExpr(e, clause.Gt{}) + case filtering.FunctionGreaterEquals: + return t.transpileComparisonCallExpr(e, clause.Gte{}) + case filtering.FunctionAnd: + return t.transpileBinaryLogicalCallExpr(e, clause.AndConditions{}) + case filtering.FunctionOr: + return t.transpileBinaryLogicalCallExpr(e, clause.OrConditions{}) + case filtering.FunctionNot: + return t.transpileNotCallExpr(e) + case filtering.FunctionTimestamp: + return t.transpileTimestampCallExpr(e) + default: + return nil, fmt.Errorf("unsupported function call: %s", e.GetCallExpr().Function) + } +} + +func (t *Transpiler) transpileIdentExpr(e *expr.Expr) (*clause.Expr, error) { + + identExpr := e.GetIdentExpr() + identType, ok := t.filter.CheckedExpr.TypeMap[e.Id] + if !ok { + return nil, fmt.Errorf("unknown type of ident expr %d", e.Id) + } + if messageType := identType.GetMessageType(); messageType != "" { + if enumType, err := protoregistry.GlobalTypes.FindEnumByName(protoreflect.FullName(messageType)); err == nil { + if enumValue := enumType.Descriptor().Values().ByName(protoreflect.Name(identExpr.Name)); enumValue != nil { + // TODO: Configurable support for string literals. + return &clause.Expr{ + Vars: []interface{}{enumValue.Name()}, + WithoutParentheses: true, + }, nil + } + } + } + return &clause.Expr{ + SQL: identExpr.Name, + Vars: nil, + WithoutParentheses: true, + }, nil +} + +func (t *Transpiler) transpileSelectExpr(e *expr.Expr) (*clause.Expr, error) { + selectExpr := e.GetSelectExpr() + operand, err := t.transpileExpr(selectExpr.Operand) + if err != nil { + return nil, err + } + return &clause.Expr{ + SQL: fmt.Sprintf("%s ->> '%s'", operand.SQL, selectExpr.Field), + Vars: nil, + WithoutParentheses: true, + }, nil +} + +func (t *Transpiler) transpileNotCallExpr(e *expr.Expr) (*clause.Expr, error) { + callExpr := e.GetCallExpr() + if len(callExpr.Args) != 1 { + return nil, fmt.Errorf( + "unexpected number of arguments to `%s` expression: %d", + filtering.FunctionNot, + len(callExpr.Args), + ) + } + rhsExpr, err := t.transpileExpr(callExpr.Args[0]) + if err != nil { + return nil, err + } + return &clause.Expr{ + SQL: fmt.Sprintf("NOT %s", rhsExpr.SQL), + WithoutParentheses: true, + }, nil +} + +func (t *Transpiler) transpileComparisonCallExpr(e *expr.Expr, op interface{}) (*clause.Expr, error) { + callExpr := e.GetCallExpr() + if len(callExpr.Args) != 2 { + return nil, fmt.Errorf( + "unexpected number of arguments to `%s`: %d", + callExpr.GetFunction(), + len(callExpr.Args), + ) + } + + ident, err := t.transpileExpr(callExpr.Args[0]) + if err != nil { + return nil, err + } + + con, err := t.transpileExpr(callExpr.Args[1]) + if err != nil { + return nil, err + } + + var sql string + var vars []interface{} + switch op.(type) { + case clause.Eq: + sql = fmt.Sprintf("%s = ?", ident.SQL) + vars = append(vars, con.Vars...) + case clause.Neq: + sql = fmt.Sprintf("%s <> ?", ident.SQL) + vars = append(vars, con.Vars...) + case clause.Lt: + sql = fmt.Sprintf("%s < ?", ident.SQL) + vars = append(vars, con.Vars...) + case clause.Lte: + sql = fmt.Sprintf("%s <= ?", ident.SQL) + vars = append(vars, con.Vars...) + case clause.Gt: + sql = fmt.Sprintf("%s > ?", ident.SQL) + vars = append(vars, con.Vars...) + case clause.Gte: + sql = fmt.Sprintf("%s >= ?", ident.SQL) + vars = append(vars, con.Vars...) + } + + return &clause.Expr{ + SQL: sql, + Vars: vars, + WithoutParentheses: true, + }, nil +} + +func (t *Transpiler) transpileBinaryLogicalCallExpr(e *expr.Expr, op clause.Expression) (*clause.Expr, error) { + callExpr := e.GetCallExpr() + if len(callExpr.Args) != 2 { + return nil, fmt.Errorf( + "unexpected number of arguments to `%s`: %d", + callExpr.GetFunction(), + len(callExpr.Args), + ) + } + lhsExpr, err := t.transpileExpr(callExpr.Args[0]) + if err != nil { + return nil, err + } + rhsExpr, err := t.transpileExpr(callExpr.Args[1]) + if err != nil { + return nil, err + } + + var sql string + switch op.(type) { + case clause.AndConditions: + sql = fmt.Sprintf("%s AND %s", lhsExpr.SQL, rhsExpr.SQL) + case clause.OrConditions: + sql = fmt.Sprintf("%s OR %s", lhsExpr.SQL, rhsExpr.SQL) + } + + return &clause.Expr{ + SQL: sql, + Vars: append(lhsExpr.Vars, rhsExpr.Vars), + WithoutParentheses: true, + }, nil +} + +func (t *Transpiler) transpileHasCallExpr(e *expr.Expr) (*clause.Expr, error) { + callExpr := e.GetCallExpr() + if len(callExpr.Args) != 2 { + return nil, fmt.Errorf("unexpected number of arguments to `in` expression: %d", len(callExpr.Args)) + } + + if callExpr.Args[1].GetConstExpr() == nil { + return nil, fmt.Errorf("TODO: add support for transpiling `:` where RHS is other than Const") + } + + switch callExpr.Args[0].ExprKind.(type) { + case *expr.Expr_IdentExpr: + identExpr := callExpr.Args[0] + constExpr := callExpr.Args[1] + identType, ok := t.filter.CheckedExpr.TypeMap[callExpr.Args[0].Id] + if !ok { + return nil, fmt.Errorf("unknown type of ident expr %d", e.Id) + } + switch { + // Repeated primitives: + // > Repeated fields query to see if the repeated structure contains a matching element. + case identType.GetListType().GetElemType().GetPrimitive() != expr.Type_PRIMITIVE_TYPE_UNSPECIFIED: + iden, err := t.transpileIdentExpr(identExpr) + if err != nil { + return nil, err + } + con, err := t.transpileConstExpr(constExpr) + if err != nil { + return nil, err + } + return &clause.Expr{ + SQL: fmt.Sprintf("? = ANY(%s)", iden.SQL), + Vars: con.Vars, + WithoutParentheses: false, + }, nil + default: + return nil, fmt.Errorf("TODO: add support for transpiling `:` on other types than repeated primitives") + } + case *expr.Expr_SelectExpr: + operand := callExpr.Args[0].GetSelectExpr().Operand + field := callExpr.Args[0].GetSelectExpr().Field + constExpr := callExpr.Args[1] + + iden, err := t.transpileIdentExpr(operand) + if err != nil { + return nil, err + } + + con, err := t.transpileConstExpr(constExpr) + if err != nil { + return nil, err + } + con.Vars[0] = "%" + con.Vars[0].(string) + "%" + return &clause.Expr{ + SQL: fmt.Sprintf("%s ->> '%s' LIKE ?", iden.SQL, field), + Vars: con.Vars, + WithoutParentheses: false, + }, nil + default: + return nil, fmt.Errorf("TODO: add support for transpiling `:` where LHS is other than Ident and Select") + } + +} + +func (t *Transpiler) transpileTimestampCallExpr(e *expr.Expr) (*clause.Expr, error) { + + callExpr := e.GetCallExpr() + if len(callExpr.Args) != 1 { + return nil, fmt.Errorf( + "unexpected number of arguments to `%s`: %d", callExpr.Function, len(callExpr.Args), + ) + } + constArg, ok := callExpr.Args[0].ExprKind.(*expr.Expr_ConstExpr) + if !ok { + return nil, fmt.Errorf("expected constant string arg to %s", callExpr.Function) + } + stringArg, ok := constArg.ConstExpr.ConstantKind.(*expr.Constant_StringValue) + if !ok { + return nil, fmt.Errorf("expected constant string arg to %s", callExpr.Function) + } + timeArg, err := time.Parse(time.RFC3339, stringArg.StringValue) + if err != nil { + return nil, fmt.Errorf("invalid string arg to %s: %w", callExpr.Function, err) + } + return &clause.Expr{ + Vars: []interface{}{timeArg}, + WithoutParentheses: true, + }, nil +} diff --git a/pkg/service/mock_repository_test.go b/pkg/service/mock_repository_test.go index 37b7ea7dd..126b23ac9 100644 --- a/pkg/service/mock_repository_test.go +++ b/pkg/service/mock_repository_test.go @@ -10,6 +10,7 @@ import ( uuid "github.com/gofrs/uuid" gomock "github.com/golang/mock/gomock" datamodel "github.com/instill-ai/pipeline-backend/pkg/datamodel" + filtering "go.einride.tech/aip/filtering" ) // MockRepository is a mock of Repository interface. @@ -94,9 +95,9 @@ func (mr *MockRepositoryMockRecorder) GetPipelineByUID(arg0, arg1, arg2 interfac } // ListPipeline mocks base method. -func (m *MockRepository) ListPipeline(arg0 string, arg1 int64, arg2 string, arg3 bool) ([]datamodel.Pipeline, int64, string, error) { +func (m *MockRepository) ListPipeline(arg0 string, arg1 int64, arg2 string, arg3 bool, arg4 filtering.Filter) ([]datamodel.Pipeline, int64, string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ListPipeline", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "ListPipeline", arg0, arg1, arg2, arg3, arg4) ret0, _ := ret[0].([]datamodel.Pipeline) ret1, _ := ret[1].(int64) ret2, _ := ret[2].(string) @@ -105,9 +106,9 @@ func (m *MockRepository) ListPipeline(arg0 string, arg1 int64, arg2 string, arg3 } // ListPipeline indicates an expected call of ListPipeline. -func (mr *MockRepositoryMockRecorder) ListPipeline(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockRepositoryMockRecorder) ListPipeline(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListPipeline", reflect.TypeOf((*MockRepository)(nil).ListPipeline), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListPipeline", reflect.TypeOf((*MockRepository)(nil).ListPipeline), arg0, arg1, arg2, arg3, arg4) } // UpdatePipeline mocks base method. diff --git a/pkg/service/service.go b/pkg/service/service.go index fb17b3895..c378fe68c 100644 --- a/pkg/service/service.go +++ b/pkg/service/service.go @@ -11,6 +11,7 @@ import ( "github.com/go-redis/redis/v9" "github.com/gofrs/uuid" "github.com/gogo/status" + "go.einride.tech/aip/filtering" "google.golang.org/grpc/codes" "github.com/instill-ai/pipeline-backend/internal/resource" @@ -26,7 +27,7 @@ import ( // Service interface type Service interface { CreatePipeline(pipeline *datamodel.Pipeline) (*datamodel.Pipeline, error) - ListPipeline(ownerRscName string, pageSize int64, pageToken string, isBasicView bool) ([]datamodel.Pipeline, int64, string, error) + ListPipeline(ownerRscName string, pageSize int64, pageToken string, isBasicView bool, filter filtering.Filter) ([]datamodel.Pipeline, int64, string, error) GetPipelineByID(id string, ownerRscName string, isBasicView bool) (*datamodel.Pipeline, error) GetPipelineByUID(uid uuid.UUID, ownerRscName string, isBasicView bool) (*datamodel.Pipeline, error) UpdatePipeline(id string, ownerRscName string, updatedPipeline *datamodel.Pipeline) (*datamodel.Pipeline, error) @@ -103,14 +104,14 @@ func (s *service) CreatePipeline(dbPipeline *datamodel.Pipeline) (*datamodel.Pip return dbCreatedPipeline, nil } -func (s *service) ListPipeline(ownerRscName string, pageSize int64, pageToken string, isBasicView bool) ([]datamodel.Pipeline, int64, string, error) { +func (s *service) ListPipeline(ownerRscName string, pageSize int64, pageToken string, isBasicView bool, filter filtering.Filter) ([]datamodel.Pipeline, int64, string, error) { ownerPermalink, err := s.ownerRscNameToPermalink(ownerRscName) if err != nil { return nil, 0, "", status.Errorf(codes.InvalidArgument, err.Error()) } - dbPipelines, ps, pt, err := s.repository.ListPipeline(ownerPermalink, pageSize, pageToken, isBasicView) + dbPipelines, ps, pt, err := s.repository.ListPipeline(ownerPermalink, pageSize, pageToken, isBasicView, filter) if err != nil { return nil, 0, "", err } diff --git a/pkg/usage/usage.go b/pkg/usage/usage.go index e8075e184..a04c6e7c7 100644 --- a/pkg/usage/usage.go +++ b/pkg/usage/usage.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/go-redis/redis/v9" + "go.einride.tech/aip/filtering" "github.com/instill-ai/pipeline-backend/internal/logger" "github.com/instill-ai/pipeline-backend/pkg/datamodel" @@ -64,7 +65,7 @@ func (u *usage) RetrieveUsageData() interface{} { pipeSyncModeNum := int64(0) pipeAsyncModeNum := int64(0) for { - dbPipelines, _, pipeNextPageToken, err := u.repository.ListPipeline(fmt.Sprintf("users/%s", user.GetUid()), int64(repository.MaxPageSize), pipePageToken, true) + dbPipelines, _, pipeNextPageToken, err := u.repository.ListPipeline(fmt.Sprintf("users/%s", user.GetUid()), int64(repository.MaxPageSize), pipePageToken, true, filtering.Filter{}) if err != nil { logger.Error(fmt.Sprintf("%s", err)) }