diff --git a/e2e/mysqlr/gen_methods.go b/e2e/mysqlr/gen_methods.go index 862fa0a..83b7c39 100644 --- a/e2e/mysqlr/gen_methods.go +++ b/e2e/mysqlr/gen_methods.go @@ -111,3 +111,126 @@ func (m *sqlMethods) Blog(ctx context.Context, req *BlogReq, opts ...RawQueryOpt } return results, nil } + +type BlogAggrResp struct { + Count any `sql:"count"` +} + +type BlogAggrReq struct { + Id int64 `sql:"id"` +} + +func (req *BlogAggrReq) Params() []any { + var params []any + + if req.Id != 0 { + params = append(params, req.Id) + } + + return params +} + +func (req *BlogAggrReq) Condition() string { + var conditions []string + if req.Id != 0 { + conditions = append(conditions, "id = ?") + } + var query string + if len(conditions) > 0 { + query += " WHERE " + strings.Join(conditions, " AND ") + } + return query +} + +const _BlogAggrSQL = "SELECT COUNT(`id`) AS `count` FROM `blogs` %s" + +// BlogAggr is a raw query handler generated function for `e2e/mysqlr/sqls/blog_aggr.sql`. +func (m *sqlMethods) BlogAggr(ctx context.Context, req *BlogAggrReq, opts ...RawQueryOptionHandler) ([]*BlogAggrResp, error) { + + rawQueryOption := &RawQueryOption{} + + for _, o := range opts { + o(rawQueryOption) + } + + query := fmt.Sprintf(_BlogAggrSQL, req.Condition()) + + rows, err := db.GetMysql(db.WithDB(rawQueryOption.db)).QueryContext(ctx, query, req.Params()...) + if err != nil { + return nil, err + } + defer rows.Close() + + var results []*BlogAggrResp + for rows.Next() { + var o BlogAggrResp + err = rows.Scan(&o.Count) + if err != nil { + return nil, err + } + results = append(results, &o) + } + return results, nil +} + +type BlogFuncResp struct { + UTitle any `sql:"u_title"` + LenTitle any `sql:"len_title"` +} + +type BlogFuncReq struct { + Id int64 `sql:"id"` +} + +func (req *BlogFuncReq) Params() []any { + var params []any + + if req.Id != 0 { + params = append(params, req.Id) + } + + return params +} + +func (req *BlogFuncReq) Condition() string { + var conditions []string + if req.Id != 0 { + conditions = append(conditions, "id = ?") + } + var query string + if len(conditions) > 0 { + query += " WHERE " + strings.Join(conditions, " AND ") + } + return query +} + +const _BlogFuncSQL = "SELECT UPPER(`title`) AS `u_title`,LENGTH(`title`) AS `len_title` FROM `blogs` %s" + +// BlogFunc is a raw query handler generated function for `e2e/mysqlr/sqls/blog_func.sql`. +func (m *sqlMethods) BlogFunc(ctx context.Context, req *BlogFuncReq, opts ...RawQueryOptionHandler) ([]*BlogFuncResp, error) { + + rawQueryOption := &RawQueryOption{} + + for _, o := range opts { + o(rawQueryOption) + } + + query := fmt.Sprintf(_BlogFuncSQL, req.Condition()) + + rows, err := db.GetMysql(db.WithDB(rawQueryOption.db)).QueryContext(ctx, query, req.Params()...) + if err != nil { + return nil, err + } + defer rows.Close() + + var results []*BlogFuncResp + for rows.Next() { + var o BlogFuncResp + err = rows.Scan(&o.UTitle, &o.LenTitle) + if err != nil { + return nil, err + } + results = append(results, &o) + } + return results, nil +} diff --git a/e2e/mysqlr/mysqlr_test.go b/e2e/mysqlr/mysqlr_test.go index 7784a7e..b75a66e 100644 --- a/e2e/mysqlr/mysqlr_test.go +++ b/e2e/mysqlr/mysqlr_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "github.com/ezbuy/ezorm/v2/pkg/orm" "github.com/stretchr/testify/assert" ) @@ -160,6 +161,35 @@ func TestBlogsCRUD(t *testing.T) { assert.Equal(t, 0, len(resp)) }) + t.Run("MySQLFunction", func(t *testing.T) { + resps, err := GetRawQuery().BlogAggr(ctx, &BlogAggrReq{ + Id: 0, + }, WithDB(db.DB)) + assert.NoError(t, err) + assert.Equal(t, 1, len(resps)) + resp := resps[0] + i, err := orm.ToResult[sql.NullInt64](resp.Count) + assert.NoError(t, err) + assert.Equal(t, int64(1), i.Int64) + + resp2s, err := GetRawQuery().BlogFunc(ctx, &BlogFuncReq{ + Id: 1, + }, WithDB(db.DB)) + assert.NoError(t, err) + assert.Equal(t, 1, len(resp2s)) + resp2 := resp2s[0] + s, err := orm.ToResult[sql.NullString](resp2.UTitle) + assert.NoError(t, err) + if !assert.Equal(t, "TEST", s.String) { + t.Errorf("resp2.UTitle: %#v\n", resp2.UTitle) + } + i2, err := orm.ToResult[sql.NullInt64](resp2.LenTitle) + assert.NoError(t, err) + if !assert.Equal(t, int64(4), i2.Int64) { + t.Errorf("resp2.LenTitle: %#v\n", resp2.LenTitle) + } + }) + t.Run("Delete", func(t *testing.T) { af, err := BlogDBMgr(db).DeleteByPrimaryKey(ctx, 1, 1) assert.NoError(t, err) diff --git a/e2e/mysqlr/sqls/blog_aggr.sql b/e2e/mysqlr/sqls/blog_aggr.sql new file mode 100644 index 0000000..be71b01 --- /dev/null +++ b/e2e/mysqlr/sqls/blog_aggr.sql @@ -0,0 +1,6 @@ +SELECT + COUNT(id) as count +FROM + blogs +WHERE + id > 1; diff --git a/e2e/mysqlr/sqls/blog_func.sql b/e2e/mysqlr/sqls/blog_func.sql new file mode 100644 index 0000000..7a7d572 --- /dev/null +++ b/e2e/mysqlr/sqls/blog_func.sql @@ -0,0 +1,8 @@ + +SELECT + UPPER(title) as u_title, + LENGTH(title) as len_title +FROM + blogs +WHERE + id = 1; diff --git a/internal/parser/x/query/tidb_parser.go b/internal/parser/x/query/tidb_parser.go index 4457eed..6147292 100644 --- a/internal/parser/x/query/tidb_parser.go +++ b/internal/parser/x/query/tidb_parser.go @@ -118,6 +118,29 @@ func (tp *TiDBParser) parse(node ast.Node, n int) error { } } } + if expr, ok := f.Expr.(*ast.FuncCallExpr); ok { + field := &QueryField{ + Alias: f.AsName.String(), + } + var txt bytes.Buffer + txt.WriteString(expr.FnName.O) + for _, args := range expr.Args { + txt.WriteString("_") + var arg strings.Builder + args.Format(&arg) + txt.WriteString(arg.String()) + } + field.Name = txt.String() + field.Type = T_ANY + if len(expr.Args) > 0 { + for _, arg := range expr.Args { + if col, ok := arg.(*ast.ColumnNameExpr); ok { + tp.meta.AppendResult(col.Name.Table.String(), field) + tp.b.resultFields = append(tp.b.resultFields, field) + } + } + } + } if expr, ok := f.Expr.(*ast.ColumnNameExpr); ok { field := &QueryField{ Alias: f.AsName.String(), diff --git a/v2/pkg/orm/any.go b/v2/pkg/orm/any.go new file mode 100644 index 0000000..ed51f16 --- /dev/null +++ b/v2/pkg/orm/any.go @@ -0,0 +1,56 @@ +package orm + +import ( + "database/sql" + "fmt" +) + +type Result interface { + sql.NullInt64 | sql.NullString +} + +var ErrScanResultTypeUnsupported = fmt.Errorf("scan result type unsupported") + +// ToResult is helper function which sits in the background +// when we cannot determine the underlying type of MySQL scan result(such as generic function result). + +// Inspired by https://github.com/go-sql-driver/mysql/issues/86, +// if the Go SQL driver cannot determine the underlying type of scan result(`interface{}` or `any`), +// it will fallback to TEXT protocol to communicate with MySQL server, +// therefore, the result can only be `[]uint8`(`[]byte`) or raw `string`. +func ToResult[T Result](rawField any) (T, error) { + var t T + switch rawField := rawField.(type) { + default: + return t, fmt.Errorf("rawField type got %T", rawField) + case string: + switch any(t).(type) { + default: + return t, ErrScanResultTypeUnsupported + case sql.NullString: + s := &sql.NullString{} + if err := s.Scan(rawField); err != nil { + return t, err + } + t = any(*s).(T) + } + case []byte: + switch any(t).(type) { + default: + return t, ErrScanResultTypeUnsupported + case sql.NullString: + s := &sql.NullString{} + if err := s.Scan(rawField); err != nil { + return t, err + } + t = any(*s).(T) + case sql.NullInt64: + i := &sql.NullInt64{} + if err := i.Scan(rawField); err != nil { + return t, err + } + t = any(*i).(T) + } + } + return t, nil +}