From fbac4456d3ff093ffe027f13a323b31b95a6e6e6 Mon Sep 17 00:00:00 2001 From: scbizu Date: Wed, 26 Jul 2023 14:03:55 +0800 Subject: [PATCH] internal/parser,v2/pkg/orm: support mysql functions and adds generic ToResult helper MySQL driver will scan `any` type field into `[]byte` as [it mentioned](https://github.com/go-sql-driver/mysql/issues/441) However, we cannot determine the underlying type of every mysql functions currently. We need a helper function to generic the `[]byte` -> wellknown SQL type(e.g.: `NullInt64`) to avoid spamming the conversion code all over our business codebase. --- e2e/mysqlr/gen_methods.go | 123 +++++++++++++++++++++++++ e2e/mysqlr/mysqlr_test.go | 30 ++++++ e2e/mysqlr/sqls/blog_aggr.sql | 6 ++ e2e/mysqlr/sqls/blog_func.sql | 8 ++ internal/parser/x/query/tidb_parser.go | 23 +++++ v2/pkg/orm/any.go | 56 +++++++++++ 6 files changed, 246 insertions(+) create mode 100644 e2e/mysqlr/sqls/blog_aggr.sql create mode 100644 e2e/mysqlr/sqls/blog_func.sql create mode 100644 v2/pkg/orm/any.go diff --git a/e2e/mysqlr/gen_methods.go b/e2e/mysqlr/gen_methods.go index 862fa0a..83b7c39 100644 --- a/e2e/mysqlr/gen_methods.go +++ b/e2e/mysqlr/gen_methods.go @@ -111,3 +111,126 @@ func (m *sqlMethods) Blog(ctx context.Context, req *BlogReq, opts ...RawQueryOpt } return results, nil } + +type BlogAggrResp struct { + Count any `sql:"count"` +} + +type BlogAggrReq struct { + Id int64 `sql:"id"` +} + +func (req *BlogAggrReq) Params() []any { + var params []any + + if req.Id != 0 { + params = append(params, req.Id) + } + + return params +} + +func (req *BlogAggrReq) Condition() string { + var conditions []string + if req.Id != 0 { + conditions = append(conditions, "id = ?") + } + var query string + if len(conditions) > 0 { + query += " WHERE " + strings.Join(conditions, " AND ") + } + return query +} + +const _BlogAggrSQL = "SELECT COUNT(`id`) AS `count` FROM `blogs` %s" + +// BlogAggr is a raw query handler generated function for `e2e/mysqlr/sqls/blog_aggr.sql`. +func (m *sqlMethods) BlogAggr(ctx context.Context, req *BlogAggrReq, opts ...RawQueryOptionHandler) ([]*BlogAggrResp, error) { + + rawQueryOption := &RawQueryOption{} + + for _, o := range opts { + o(rawQueryOption) + } + + query := fmt.Sprintf(_BlogAggrSQL, req.Condition()) + + rows, err := db.GetMysql(db.WithDB(rawQueryOption.db)).QueryContext(ctx, query, req.Params()...) + if err != nil { + return nil, err + } + defer rows.Close() + + var results []*BlogAggrResp + for rows.Next() { + var o BlogAggrResp + err = rows.Scan(&o.Count) + if err != nil { + return nil, err + } + results = append(results, &o) + } + return results, nil +} + +type BlogFuncResp struct { + UTitle any `sql:"u_title"` + LenTitle any `sql:"len_title"` +} + +type BlogFuncReq struct { + Id int64 `sql:"id"` +} + +func (req *BlogFuncReq) Params() []any { + var params []any + + if req.Id != 0 { + params = append(params, req.Id) + } + + return params +} + +func (req *BlogFuncReq) Condition() string { + var conditions []string + if req.Id != 0 { + conditions = append(conditions, "id = ?") + } + var query string + if len(conditions) > 0 { + query += " WHERE " + strings.Join(conditions, " AND ") + } + return query +} + +const _BlogFuncSQL = "SELECT UPPER(`title`) AS `u_title`,LENGTH(`title`) AS `len_title` FROM `blogs` %s" + +// BlogFunc is a raw query handler generated function for `e2e/mysqlr/sqls/blog_func.sql`. +func (m *sqlMethods) BlogFunc(ctx context.Context, req *BlogFuncReq, opts ...RawQueryOptionHandler) ([]*BlogFuncResp, error) { + + rawQueryOption := &RawQueryOption{} + + for _, o := range opts { + o(rawQueryOption) + } + + query := fmt.Sprintf(_BlogFuncSQL, req.Condition()) + + rows, err := db.GetMysql(db.WithDB(rawQueryOption.db)).QueryContext(ctx, query, req.Params()...) + if err != nil { + return nil, err + } + defer rows.Close() + + var results []*BlogFuncResp + for rows.Next() { + var o BlogFuncResp + err = rows.Scan(&o.UTitle, &o.LenTitle) + if err != nil { + return nil, err + } + results = append(results, &o) + } + return results, nil +} diff --git a/e2e/mysqlr/mysqlr_test.go b/e2e/mysqlr/mysqlr_test.go index 7784a7e..b75a66e 100644 --- a/e2e/mysqlr/mysqlr_test.go +++ b/e2e/mysqlr/mysqlr_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "github.com/ezbuy/ezorm/v2/pkg/orm" "github.com/stretchr/testify/assert" ) @@ -160,6 +161,35 @@ func TestBlogsCRUD(t *testing.T) { assert.Equal(t, 0, len(resp)) }) + t.Run("MySQLFunction", func(t *testing.T) { + resps, err := GetRawQuery().BlogAggr(ctx, &BlogAggrReq{ + Id: 0, + }, WithDB(db.DB)) + assert.NoError(t, err) + assert.Equal(t, 1, len(resps)) + resp := resps[0] + i, err := orm.ToResult[sql.NullInt64](resp.Count) + assert.NoError(t, err) + assert.Equal(t, int64(1), i.Int64) + + resp2s, err := GetRawQuery().BlogFunc(ctx, &BlogFuncReq{ + Id: 1, + }, WithDB(db.DB)) + assert.NoError(t, err) + assert.Equal(t, 1, len(resp2s)) + resp2 := resp2s[0] + s, err := orm.ToResult[sql.NullString](resp2.UTitle) + assert.NoError(t, err) + if !assert.Equal(t, "TEST", s.String) { + t.Errorf("resp2.UTitle: %#v\n", resp2.UTitle) + } + i2, err := orm.ToResult[sql.NullInt64](resp2.LenTitle) + assert.NoError(t, err) + if !assert.Equal(t, int64(4), i2.Int64) { + t.Errorf("resp2.LenTitle: %#v\n", resp2.LenTitle) + } + }) + t.Run("Delete", func(t *testing.T) { af, err := BlogDBMgr(db).DeleteByPrimaryKey(ctx, 1, 1) assert.NoError(t, err) diff --git a/e2e/mysqlr/sqls/blog_aggr.sql b/e2e/mysqlr/sqls/blog_aggr.sql new file mode 100644 index 0000000..be71b01 --- /dev/null +++ b/e2e/mysqlr/sqls/blog_aggr.sql @@ -0,0 +1,6 @@ +SELECT + COUNT(id) as count +FROM + blogs +WHERE + id > 1; diff --git a/e2e/mysqlr/sqls/blog_func.sql b/e2e/mysqlr/sqls/blog_func.sql new file mode 100644 index 0000000..7a7d572 --- /dev/null +++ b/e2e/mysqlr/sqls/blog_func.sql @@ -0,0 +1,8 @@ + +SELECT + UPPER(title) as u_title, + LENGTH(title) as len_title +FROM + blogs +WHERE + id = 1; diff --git a/internal/parser/x/query/tidb_parser.go b/internal/parser/x/query/tidb_parser.go index 4457eed..6147292 100644 --- a/internal/parser/x/query/tidb_parser.go +++ b/internal/parser/x/query/tidb_parser.go @@ -118,6 +118,29 @@ func (tp *TiDBParser) parse(node ast.Node, n int) error { } } } + if expr, ok := f.Expr.(*ast.FuncCallExpr); ok { + field := &QueryField{ + Alias: f.AsName.String(), + } + var txt bytes.Buffer + txt.WriteString(expr.FnName.O) + for _, args := range expr.Args { + txt.WriteString("_") + var arg strings.Builder + args.Format(&arg) + txt.WriteString(arg.String()) + } + field.Name = txt.String() + field.Type = T_ANY + if len(expr.Args) > 0 { + for _, arg := range expr.Args { + if col, ok := arg.(*ast.ColumnNameExpr); ok { + tp.meta.AppendResult(col.Name.Table.String(), field) + tp.b.resultFields = append(tp.b.resultFields, field) + } + } + } + } if expr, ok := f.Expr.(*ast.ColumnNameExpr); ok { field := &QueryField{ Alias: f.AsName.String(), diff --git a/v2/pkg/orm/any.go b/v2/pkg/orm/any.go new file mode 100644 index 0000000..ed51f16 --- /dev/null +++ b/v2/pkg/orm/any.go @@ -0,0 +1,56 @@ +package orm + +import ( + "database/sql" + "fmt" +) + +type Result interface { + sql.NullInt64 | sql.NullString +} + +var ErrScanResultTypeUnsupported = fmt.Errorf("scan result type unsupported") + +// ToResult is helper function which sits in the background +// when we cannot determine the underlying type of MySQL scan result(such as generic function result). + +// Inspired by https://github.com/go-sql-driver/mysql/issues/86, +// if the Go SQL driver cannot determine the underlying type of scan result(`interface{}` or `any`), +// it will fallback to TEXT protocol to communicate with MySQL server, +// therefore, the result can only be `[]uint8`(`[]byte`) or raw `string`. +func ToResult[T Result](rawField any) (T, error) { + var t T + switch rawField := rawField.(type) { + default: + return t, fmt.Errorf("rawField type got %T", rawField) + case string: + switch any(t).(type) { + default: + return t, ErrScanResultTypeUnsupported + case sql.NullString: + s := &sql.NullString{} + if err := s.Scan(rawField); err != nil { + return t, err + } + t = any(*s).(T) + } + case []byte: + switch any(t).(type) { + default: + return t, ErrScanResultTypeUnsupported + case sql.NullString: + s := &sql.NullString{} + if err := s.Scan(rawField); err != nil { + return t, err + } + t = any(*s).(T) + case sql.NullInt64: + i := &sql.NullInt64{} + if err := i.Scan(rawField); err != nil { + return t, err + } + t = any(*i).(T) + } + } + return t, nil +}