From 8ea27dfb394e9a11b6dc6b05fecbbeea28e6cda6 Mon Sep 17 00:00:00 2001 From: csynineyang Date: Wed, 2 Nov 2022 22:05:20 +0800 Subject: [PATCH 01/20] add node config support (#464) --- pkg/admin/router/nodes.go | 44 ++++++++++++++++++++++++---- pkg/boot/discovery.go | 50 +++++++++++++++++++++++++++++--- pkg/boot/proto.go | 7 +++++ pkg/config/api.go | 6 ++++ pkg/config/tenant.go | 60 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 157 insertions(+), 10 deletions(-) diff --git a/pkg/admin/router/nodes.go b/pkg/admin/router/nodes.go index 3723801f..3e5bd074 100644 --- a/pkg/admin/router/nodes.go +++ b/pkg/admin/router/nodes.go @@ -34,9 +34,9 @@ import ( func init() { admin.Register(func(router admin.Router) { - router.GET("/tenants/:tenant/nodes", ListNodes) + router.GET("/tenants/:tenant/nodes", ListNodesByAdmin) router.POST("/tenants/:tenant/nodes", CreateNode) - router.GET("/tenants/:tenant/nodes/:node", GetNode) + router.GET("/tenants/:tenant/nodes/:node", GetNodeByAdmin) router.PUT("/tenants/:tenant/nodes/:node", UpdateNode) router.DELETE("/tenants/:tenant/nodes/:node", RemoveNode) }) @@ -73,6 +73,26 @@ func ListNodes(c *gin.Context) error { return nil } +func ListNodesByAdmin(c *gin.Context) error { + var results []config.Node + service := admin.GetService(c) + tenantName := c.Param("tenant") + nodesArray, err := service.ListNodesByAdmin(c, tenantName) + if err != nil { + return err + } + for _, node := range nodesArray { + result, err := service.GetNodeByAdmin(c, tenantName, node) + if err != nil { + return err + } + results = append(results, *result) + } + + c.JSON(http.StatusOK, results) + return nil +} + func GetNode(c *gin.Context) error { service := admin.GetService(c) tenant := c.Param("tenant") @@ -98,15 +118,27 @@ func GetNode(c *gin.Context) error { return nil } +func GetNodeByAdmin(c *gin.Context) error { + service := admin.GetService(c) + tenant := c.Param("tenant") + node := c.Param("node") + data, err := service.GetNodeByAdmin(c, tenant, node) + if err != nil { + return err + } + c.JSON(http.StatusOK, data) + return nil +} + func CreateNode(c *gin.Context) error { service := admin.GetService(c) tenant := c.Param("tenant") - var node *boot.NodeBody - if err := c.ShouldBindJSON(&node); err != nil { + var nodeBody *boot.NodeBody + if err := c.ShouldBindJSON(&nodeBody); err != nil { return exception.Wrap(exception.CodeInvalidParams, err) } - err := service.UpsertNode(c, tenant, "", node) + err := service.UpsertNode(c, tenant, nodeBody.Name, nodeBody) if err != nil { return err } @@ -119,7 +151,7 @@ func UpdateNode(c *gin.Context) error { tenant := c.Param("tenant") node := c.Param("node") var nodeBody *boot.NodeBody - if err := c.ShouldBindJSON(&nodeBody); err == nil { + if err := c.ShouldBindJSON(&nodeBody); err != nil { return exception.Wrap(exception.CodeInvalidParams, err) } diff --git a/pkg/boot/discovery.go b/pkg/boot/discovery.go index 41d7ed32..348510d9 100644 --- a/pkg/boot/discovery.go +++ b/pkg/boot/discovery.go @@ -183,13 +183,19 @@ func (fp *discovery) RemoveCluster(ctx context.Context, tenant, cluster string) } func (fp *discovery) UpsertNode(ctx context.Context, tenant, node string, body *NodeBody) error { - // TODO implement me - panic("implement me") + if err := fp.tenantOp.UpsertNode(tenant, node, body.Name, body.Host, body.Port, body.Username, body.Password, body.Database, body.Weight); err != nil { + return errors.Wrapf(err, "failed to upsert node '%s' for tenant '%s'", node, tenant) + } + + return nil } func (fp *discovery) RemoveNode(ctx context.Context, tenant, node string) error { - // TODO implement me - panic("implement me") + if err := fp.tenantOp.RemoveNode(tenant, node); err != nil { + return errors.Wrapf(err, "failed to remove node '%s' for tenant '%s'", node, tenant) + } + + return nil } func (fp *discovery) UpsertGroup(ctx context.Context, tenant, cluster, group string, body *GroupBody) error { @@ -501,6 +507,28 @@ func (fp *discovery) ListNodes(ctx context.Context, tenant, cluster, group strin return nodes, nil } +func (fp *discovery) ListNodesByAdmin(ctx context.Context, tenant string) ([]string, error) { + op, ok := fp.centers[tenant] + if !ok { + return nil, ErrorNoTenant + } + + cfg, err := op.LoadAll(context.Background()) + if err != nil { + return nil, err + } + + if cfg == nil || len(cfg.Nodes) == 0 { + return nil, nil + } + + ret := make([]string, 0, len(cfg.Nodes)) + for _, it := range cfg.Nodes { + ret = append(ret, it.Name) + } + return ret, nil +} + func (fp *discovery) ListTables(ctx context.Context, tenant, cluster string) ([]string, error) { op, ok := fp.centers[tenant] if !ok { @@ -562,6 +590,20 @@ func (fp *discovery) GetNode(ctx context.Context, tenant, cluster, group, node s return nodes[nodeId], nil } +func (fp *discovery) GetNodeByAdmin(ctx context.Context, tenant, node string) (*config.Node, error) { + op, ok := fp.centers[tenant] + if !ok { + return nil, ErrorNoTenant + } + + nodes, err := fp.loadNodes(op) + if err != nil { + return nil, err + } + + return nodes[node], nil +} + func (fp *discovery) GetTable(ctx context.Context, tenant, cluster, tableName string) (*rule.VTable, error) { op, ok := fp.centers[tenant] if !ok { diff --git a/pkg/boot/proto.go b/pkg/boot/proto.go index ebb3902a..f8dbf53d 100644 --- a/pkg/boot/proto.go +++ b/pkg/boot/proto.go @@ -43,6 +43,7 @@ type ClusterBody struct { } type NodeBody struct { + Name string `yaml:"name" json:"name"` Host string `yaml:"host" json:"host"` Port int `yaml:"port" json:"port"` Username string `yaml:"username" json:"username"` @@ -97,9 +98,15 @@ type ConfigProvider interface { // ListNodes lists the node names. ListNodes(ctx context.Context, tenant, cluster, group string) ([]string, error) + // ListNodesByAdmin lists the node names by admin. + ListNodesByAdmin(ctx context.Context, tenant string) ([]string, error) + // GetNode returns the node info. GetNode(ctx context.Context, tenant, cluster, group, node string) (*config.Node, error) + // GetNodeByAdmin returns the node info by admin. + GetNodeByAdmin(ctx context.Context, tenant, node string) (*config.Node, error) + // ListTables lists the table names. ListTables(ctx context.Context, tenant, cluster string) ([]string, error) diff --git a/pkg/config/api.go b/pkg/config/api.go index 3fddf2da..825de2a5 100644 --- a/pkg/config/api.go +++ b/pkg/config/api.go @@ -129,6 +129,12 @@ type ( // CreateTenantUser creates a user. CreateTenantUser(tenant, username, password string) error + // UpsertNode creates a node, or updates a node. + UpsertNode(tenant, node, name, host string, port int, username, password, database, weight string) error + + // RemoveNode removes a node. + RemoveNode(tenant, name string) error + // Subscribe subscribes tenants change Subscribe(ctx context.Context, c EventCallback) context.CancelFunc } diff --git a/pkg/config/tenant.go b/pkg/config/tenant.go index 29c4ece4..7b3c83a9 100644 --- a/pkg/config/tenant.go +++ b/pkg/config/tenant.go @@ -241,6 +241,66 @@ func (tp *tenantOperate) RemoveTenant(name string) error { return tp.op.Save(DefaultTenantsPath, data) } +func (tp *tenantOperate) UpsertNode(tenant, node, name, host string, port int, username, password, database, weight string) error { + p := NewPathInfo(tenant) + + prev, err := tp.op.Get(p.DefaultConfigDataNodesPath) + if err != nil { + return errors.WithStack(err) + } + + var nodes Nodes + if err := yaml.Unmarshal(prev, &nodes); err != nil { + return errors.WithStack(err) + } + + nodes[node] = &Node{ + Name: name, + Host: host, + Port: port, + Username: username, + Password: password, + Database: database, + Weight: weight, + } + + b, err := yaml.Marshal(nodes) + if err != nil { + return errors.WithStack(err) + } + + if err := tp.op.Save(p.DefaultConfigDataNodesPath, b); err != nil { + return errors.WithStack(err) + } + return nil +} + +func (tp *tenantOperate) RemoveNode(tenant, name string) error { + p := NewPathInfo(tenant) + + prev, err := tp.op.Get(p.DefaultConfigDataNodesPath) + if err != nil { + return errors.WithStack(err) + } + + var nodes Nodes + if err := yaml.Unmarshal(prev, &nodes); err != nil { + return errors.WithStack(err) + } + + delete(nodes, name) + + b, err := yaml.Marshal(nodes) + if err != nil { + return errors.WithStack(err) + } + + if err := tp.op.Save(p.DefaultConfigDataNodesPath, b); err != nil { + return errors.WithStack(err) + } + return nil +} + func (tp *tenantOperate) Close() error { for i := range tp.cancels { tp.cancels[i]() From 9c56b9841f41bd74ad74e4ab447b2f13e65449ab Mon Sep 17 00:00:00 2001 From: csynineyang Date: Wed, 23 Nov 2022 20:08:42 +0800 Subject: [PATCH 02/20] Support MySQL CAST_CHAR function. --- pkg/runtime/function/cast_char.go | 214 +++++++++++++++++++++++++ pkg/runtime/function/cast_char_test.go | 56 +++++++ pkg/runtime/function/cast_nchar.go | 6 +- 3 files changed, 273 insertions(+), 3 deletions(-) create mode 100644 pkg/runtime/function/cast_char.go create mode 100644 pkg/runtime/function/cast_char_test.go diff --git a/pkg/runtime/function/cast_char.go b/pkg/runtime/function/cast_char.go new file mode 100644 index 00000000..595ed43e --- /dev/null +++ b/pkg/runtime/function/cast_char.go @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "bytes" + "context" + "fmt" + "io/ioutil" + "strings" + "unicode/utf8" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/util/runes" +) + +import ( + gxbig "github.com/dubbogo/gost/math/big" + "github.com/pkg/errors" + "golang.org/x/text/encoding/charmap" + "golang.org/x/text/encoding/simplifiedchinese" + "golang.org/x/text/encoding/unicode" + "golang.org/x/text/transform" +) + +// FuncCastChar is https://dev.mysql.com/doc/refman/5.6/en/cast-functions.html#function_cast +const FuncCastChar = "CAST_CHAR" + +var _ proto.Func = (*castcharFunc)(nil) + +func init() { + proto.RegisterFunc(FuncCastChar, castcharFunc{}) +} + +type castcharFunc struct{} + +func (a castcharFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto.Value, error) { + // expr + val1, err := inputs[0].Value(ctx) + if err != nil { + return nil, errors.WithStack(err) + } + + if len(inputs) < 3 { + return val1, nil + } + + // N + var num int64 + val2, err := inputs[1].Value(ctx) + if err != nil { + return nil, errors.WithStack(err) + } + d2, _ := gxbig.NewDecFromString(fmt.Sprint(val2)) + if d2.IsNegative() { + ret, err := d2.ToInt() + if err != nil { + return "", err + } + num = ret + } else if !strings.Contains(fmt.Sprint(val2), ".") { + ret, err := d2.ToInt() + if err != nil { + return "", err + } + num = ret + } else { + ret, err := d2.ToFloat64() + if err != nil { + return "", err + } + num = int64(ret) + } + + // charset_info + val3, err := inputs[2].Value(ctx) + if err != nil { + return "", err + } + + return a.getResult(runes.ConvertToRune(val1), num, fmt.Sprint(val3)) +} + +func (a castcharFunc) NumInput() int { + return 3 +} + +func (a castcharFunc) getResult(runes []rune, num int64, charEncode string) (string, error) { + var srcString string + + // N + if num > int64(len(runes)) { + srcString = string(runes) + } else if num >= 0 { + srcString = string(runes[:num]) + } else { + srcString = string(runes) + } + + // charset_info + if !utf8.ValidString(srcString) { + // source string only support utf8 + return srcString, nil + } + charInfo := strings.Split(charEncode, " ") + if len(charInfo) >= 3 && charInfo[0] == "CHARACTER" && charInfo[1] == "SET" { + if strings.ToLower(charInfo[2]) == "gbk" { + // CHARACTER SET gbk + srcEncode := simplifiedchinese.GBK.NewEncoder() + dstString, err := srcEncode.String(srcString) + if err == nil { + return dstString, nil + } else { + return "", errors.New("CHAR[(N)][charset_info] gbk encode error") + } + } else if strings.ToLower(charInfo[2]) == "gb18030" { + // CHARACTER SET gb18030 + srcEncode := simplifiedchinese.GB18030.NewEncoder() + dstString, err := srcEncode.String(srcString) + if err == nil { + return dstString, nil + } else { + return "", errors.New("CHAR[(N)][charset_info] gb18030 encode error") + } + } else if strings.ToLower(charInfo[2]) == "latin2" { + // CHARACTER SET latin2 + srcEncode := charmap.ISO8859_2.NewEncoder() + dstString, err := srcEncode.String(srcString) + if err == nil { + return dstString, nil + } else { + return "", errors.New("CHAR[(N)][charset_info] latin2 encode error") + } + } else if strings.ToLower(charInfo[2]) == "latin5" { + // CHARACTER SET latin5 + srcEncode := charmap.ISO8859_9.NewEncoder() + dstString, err := srcEncode.String(srcString) + if err == nil { + return dstString, nil + } else { + return "", errors.New("CHAR[(N)][charset_info] latin5 encode error") + } + } else if strings.ToLower(charInfo[2]) == "greek" { + // CHARACTER SET greek + srcEncode := charmap.ISO8859_7.NewEncoder() + dstString, err := srcEncode.String(srcString) + if err == nil { + return dstString, nil + } else { + return "", errors.New("CHAR[(N)][charset_info] greek encode error") + } + } else if strings.ToLower(charInfo[2]) == "hebrew" { + // CHARACTER SET hebrew + srcEncode := charmap.ISO8859_8.NewEncoder() + dstString, err := srcEncode.String(srcString) + if err == nil { + return dstString, nil + } else { + return "", errors.New("CHAR[(N)][charset_info] hebrew encode error") + } + } else if strings.ToLower(charInfo[2]) == "latin7" { + // CHARACTER SET latin7 + srcEncode := charmap.ISO8859_13.NewEncoder() + dstString, err := srcEncode.String(srcString) + if err == nil { + return dstString, nil + } else { + return "", errors.New("CHAR[(N)][charset_info] latin7 encode error") + } + } else { + return "", errors.New("CHAR[(N)][charset_info] Variable charset_info is not supported") + } + } else if len(charInfo) >= 1 && charInfo[0] == "ASCII" { + // ASCII: CHARACTER SET latin1 + srcEncode := charmap.ISO8859_1.NewEncoder() + dstString, err := srcEncode.String(srcString) + if err == nil { + return dstString, nil + } else { + return "", errors.New("CHAR[(N)][charset_info] latin1 encode error") + } + } else if len(charInfo) >= 1 && charInfo[0] == "UNICODE" { + // UNICODE: CHARACTER SET ucs2 + srcReader := bytes.NewReader([]byte(srcString)) + //UTF-16 bigendian, no-bom + trans := transform.NewReader(srcReader, + unicode.UTF16(unicode.BigEndian, unicode.IgnoreBOM).NewEncoder()) + dstString, err := ioutil.ReadAll(trans) + if err == nil { + return string(dstString), nil + } else { + return "", errors.New("CHAR[(N)][charset_info] ucs2 encode error") + } + } else { + return "", errors.New("CHAR[(N)][charset_info] Variable charset_info is invalid") + } +} diff --git a/pkg/runtime/function/cast_char_test.go b/pkg/runtime/function/cast_char_test.go new file mode 100644 index 00000000..ec12396a --- /dev/null +++ b/pkg/runtime/function/cast_char_test.go @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "context" + "fmt" + "testing" +) + +import ( + "github.com/stretchr/testify/assert" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +func TestFuncCastChar(t *testing.T) { + fn := proto.MustGetFunc(FuncCastChar) + assert.Equal(t, 3, fn.NumInput()) + type tt struct { + inFirst proto.Value + inSecond proto.Value + intThird proto.Value + want string + } + for _, v := range []tt{ + {"Hello", -1, "ASCII", "Hello"}, + {"Hello", -1, "UNICODE", "\x00H\x00e\x00l\x00l\x00o"}, + {"Hello世界", -1, "CHARACTER SET gbk", "Hello\xca\xc0\xbd\xe7"}, + {"Hello世界", -1, "CHARACTER SET gb18030", "Hello\xca\xc0\xbd\xe7"}, + {"Hello世界", 5, "CHARACTER SET latin2", "Hello"}, + } { + t.Run(v.want, func(t *testing.T) { + out, err := fn.Apply(context.Background(), proto.ToValuer(v.inFirst), proto.ToValuer(v.inSecond), proto.ToValuer(v.intThird)) + assert.NoError(t, err) + assert.Equal(t, v.want, fmt.Sprint(out)) + }) + } +} diff --git a/pkg/runtime/function/cast_nchar.go b/pkg/runtime/function/cast_nchar.go index 3309f971..68295bc9 100644 --- a/pkg/runtime/function/cast_nchar.go +++ b/pkg/runtime/function/cast_nchar.go @@ -66,20 +66,20 @@ func (a castncharFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto if err != nil { return "", err } - return getResult(runes.ConvertToRune(val1), num) + return a.getResult(runes.ConvertToRune(val1), num) } num, err := d2.ToFloat64() if err != nil { return "", err } - return getResult(runes.ConvertToRune(val1), int64(num)) + return a.getResult(runes.ConvertToRune(val1), int64(num)) } func (a castncharFunc) NumInput() int { return 2 } -func getResult(runes []rune, num int64) (string, error) { +func (a castncharFunc) getResult(runes []rune, num int64) (string, error) { if num > int64(len(runes)) { return string(runes), nil } else if num >= 0 { From fac50282d19dd805618397918d88469a36eac753 Mon Sep 17 00:00:00 2001 From: csynineyang Date: Mon, 28 Nov 2022 21:37:40 +0800 Subject: [PATCH 03/20] format style --- pkg/runtime/function/cast_char.go | 43 +++++++++++++------------- pkg/runtime/function/cast_char_test.go | 9 +++--- 2 files changed, 26 insertions(+), 26 deletions(-) diff --git a/pkg/runtime/function/cast_char.go b/pkg/runtime/function/cast_char.go index 595ed43e..27fc0a8d 100644 --- a/pkg/runtime/function/cast_char.go +++ b/pkg/runtime/function/cast_char.go @@ -26,20 +26,22 @@ import ( "unicode/utf8" ) -import ( - "github.com/arana-db/arana/pkg/proto" - "github.com/arana-db/arana/pkg/util/runes" -) - import ( gxbig "github.com/dubbogo/gost/math/big" + "github.com/pkg/errors" + "golang.org/x/text/encoding/charmap" "golang.org/x/text/encoding/simplifiedchinese" "golang.org/x/text/encoding/unicode" "golang.org/x/text/transform" ) +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/util/runes" +) + // FuncCastChar is https://dev.mysql.com/doc/refman/5.6/en/cast-functions.html#function_cast const FuncCastChar = "CAST_CHAR" @@ -70,12 +72,9 @@ func (a castcharFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto. } d2, _ := gxbig.NewDecFromString(fmt.Sprint(val2)) if d2.IsNegative() { - ret, err := d2.ToInt() - if err != nil { - return "", err - } - num = ret - } else if !strings.Contains(fmt.Sprint(val2), ".") { + return "", errors.New("CHAR[(N) Variable N is not allowed to be negative") + } + if !strings.Contains(fmt.Sprint(val2), ".") { ret, err := d2.ToInt() if err != nil { return "", err @@ -115,13 +114,13 @@ func (a castcharFunc) getResult(runes []rune, num int64, charEncode string) (str } // charset_info - if !utf8.ValidString(srcString) { + if !utf8.ValidString(srcString) || strings.EqualFold(charEncode, "") { // source string only support utf8 return srcString, nil } charInfo := strings.Split(charEncode, " ") - if len(charInfo) >= 3 && charInfo[0] == "CHARACTER" && charInfo[1] == "SET" { - if strings.ToLower(charInfo[2]) == "gbk" { + if len(charInfo) >= 3 && strings.EqualFold(charInfo[0], "CHARACTER") && strings.EqualFold(charInfo[1], "SET") { + if strings.EqualFold(charInfo[2], "gbk") { // CHARACTER SET gbk srcEncode := simplifiedchinese.GBK.NewEncoder() dstString, err := srcEncode.String(srcString) @@ -130,7 +129,7 @@ func (a castcharFunc) getResult(runes []rune, num int64, charEncode string) (str } else { return "", errors.New("CHAR[(N)][charset_info] gbk encode error") } - } else if strings.ToLower(charInfo[2]) == "gb18030" { + } else if strings.EqualFold(charInfo[2], "gb18030") { // CHARACTER SET gb18030 srcEncode := simplifiedchinese.GB18030.NewEncoder() dstString, err := srcEncode.String(srcString) @@ -139,7 +138,7 @@ func (a castcharFunc) getResult(runes []rune, num int64, charEncode string) (str } else { return "", errors.New("CHAR[(N)][charset_info] gb18030 encode error") } - } else if strings.ToLower(charInfo[2]) == "latin2" { + } else if strings.EqualFold(charInfo[2], "latin2") { // CHARACTER SET latin2 srcEncode := charmap.ISO8859_2.NewEncoder() dstString, err := srcEncode.String(srcString) @@ -148,7 +147,7 @@ func (a castcharFunc) getResult(runes []rune, num int64, charEncode string) (str } else { return "", errors.New("CHAR[(N)][charset_info] latin2 encode error") } - } else if strings.ToLower(charInfo[2]) == "latin5" { + } else if strings.EqualFold(charInfo[2], "latin5") { // CHARACTER SET latin5 srcEncode := charmap.ISO8859_9.NewEncoder() dstString, err := srcEncode.String(srcString) @@ -157,7 +156,7 @@ func (a castcharFunc) getResult(runes []rune, num int64, charEncode string) (str } else { return "", errors.New("CHAR[(N)][charset_info] latin5 encode error") } - } else if strings.ToLower(charInfo[2]) == "greek" { + } else if strings.EqualFold(charInfo[2], "greek") { // CHARACTER SET greek srcEncode := charmap.ISO8859_7.NewEncoder() dstString, err := srcEncode.String(srcString) @@ -166,7 +165,7 @@ func (a castcharFunc) getResult(runes []rune, num int64, charEncode string) (str } else { return "", errors.New("CHAR[(N)][charset_info] greek encode error") } - } else if strings.ToLower(charInfo[2]) == "hebrew" { + } else if strings.EqualFold(charInfo[2], "hebrew") { // CHARACTER SET hebrew srcEncode := charmap.ISO8859_8.NewEncoder() dstString, err := srcEncode.String(srcString) @@ -175,7 +174,7 @@ func (a castcharFunc) getResult(runes []rune, num int64, charEncode string) (str } else { return "", errors.New("CHAR[(N)][charset_info] hebrew encode error") } - } else if strings.ToLower(charInfo[2]) == "latin7" { + } else if strings.EqualFold(charInfo[2], "latin7") { // CHARACTER SET latin7 srcEncode := charmap.ISO8859_13.NewEncoder() dstString, err := srcEncode.String(srcString) @@ -187,7 +186,7 @@ func (a castcharFunc) getResult(runes []rune, num int64, charEncode string) (str } else { return "", errors.New("CHAR[(N)][charset_info] Variable charset_info is not supported") } - } else if len(charInfo) >= 1 && charInfo[0] == "ASCII" { + } else if len(charInfo) >= 1 && strings.EqualFold(charInfo[0], "ASCII") { // ASCII: CHARACTER SET latin1 srcEncode := charmap.ISO8859_1.NewEncoder() dstString, err := srcEncode.String(srcString) @@ -196,7 +195,7 @@ func (a castcharFunc) getResult(runes []rune, num int64, charEncode string) (str } else { return "", errors.New("CHAR[(N)][charset_info] latin1 encode error") } - } else if len(charInfo) >= 1 && charInfo[0] == "UNICODE" { + } else if len(charInfo) >= 1 && strings.EqualFold(charInfo[0], "UNICODE") { // UNICODE: CHARACTER SET ucs2 srcReader := bytes.NewReader([]byte(srcString)) //UTF-16 bigendian, no-bom diff --git a/pkg/runtime/function/cast_char_test.go b/pkg/runtime/function/cast_char_test.go index ec12396a..f1dd6aa5 100644 --- a/pkg/runtime/function/cast_char_test.go +++ b/pkg/runtime/function/cast_char_test.go @@ -41,10 +41,11 @@ func TestFuncCastChar(t *testing.T) { want string } for _, v := range []tt{ - {"Hello", -1, "ASCII", "Hello"}, - {"Hello", -1, "UNICODE", "\x00H\x00e\x00l\x00l\x00o"}, - {"Hello世界", -1, "CHARACTER SET gbk", "Hello\xca\xc0\xbd\xe7"}, - {"Hello世界", -1, "CHARACTER SET gb18030", "Hello\xca\xc0\xbd\xe7"}, + {"Hello", len("Hello"), "ASCII", "Hello"}, + {"Hello", len("Hello"), "unicode", "\x00H\x00e\x00l\x00l\x00o"}, + {"Hello世界", len("Hello世界"), "CHARACTER SET GBK", "Hello\xca\xc0\xbd\xe7"}, + {"Hello世界", len("Hello世界"), "CHARACTER SET gb18030", "Hello\xca\xc0\xbd\xe7"}, + {"Hello世界", len("Hello世界"), "", "Hello世界"}, {"Hello世界", 5, "CHARACTER SET latin2", "Hello"}, } { t.Run(v.want, func(t *testing.T) { From c05d948f073ce4a96e96c83911a4e4711530c70b Mon Sep 17 00:00:00 2001 From: csynineyang Date: Mon, 19 Dec 2022 16:46:00 +0800 Subject: [PATCH 04/20] Support MySQL CAST_TIME function. (#570) --- pkg/runtime/function/cast_time.go | 249 +++++++++++++++++++++++++ pkg/runtime/function/cast_time_test.go | 69 +++++++ 2 files changed, 318 insertions(+) create mode 100644 pkg/runtime/function/cast_time.go create mode 100644 pkg/runtime/function/cast_time_test.go diff --git a/pkg/runtime/function/cast_time.go b/pkg/runtime/function/cast_time.go new file mode 100644 index 00000000..2d069e37 --- /dev/null +++ b/pkg/runtime/function/cast_time.go @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "context" + "fmt" + "regexp" + "strconv" + "strings" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +// FuncCastTime is https://dev.mysql.com/doc/refman/5.6/en/cast-functions.html#function_cast +const FuncCastTime = "CAST_TIME" + +var _ proto.Func = (*castTimeFunc)(nil) + +func init() { + proto.RegisterFunc(FuncCastTime, castTimeFunc{}) +} + +type castTimeFunc struct{} + +func (a castTimeFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto.Value, error) { + // expr + val, err := inputs[0].Value(ctx) + if err != nil { + return nil, errors.WithStack(err) + } + + timeArgs := fmt.Sprint(val) + if strings.Compare(timeArgs, "") == 0 { + return a.DefaultTimeValue(), nil + } + + // negative flag + nega := false + if strings.Compare(string(timeArgs[0]), "-") == 0 { + nega = true + timeArgs = string(timeArgs[1:]) + } + + // fractional seconds + frac := false + if strings.Contains(timeArgs, ".") { + timeArr := strings.Split(timeArgs, ".") + if len(timeArr) >= 2 && len(timeArr[1]) >= 1 { + timeFrac, _ := strconv.Atoi(string(timeArr[1][0])) + if timeFrac >= 5 { + frac = true + } + } + timeArgs = timeArr[0] + } + + if strings.Contains(timeArgs, " ") { + // format - D hhh:mm:ss,D hhh:mm,D hhh + pat := "^\\d{1,2} \\d{1,3}(:\\d{1,2}){0,2}$" + match, err := regexp.MatchString(pat, timeArgs) + if !match || err != nil { + return a.DefaultTimeValue(), nil + } + + timeArrLeft := strings.Split(timeArgs, " ") + timeDay, _ := strconv.Atoi(timeArrLeft[0]) + timeArrRight := strings.Split(timeArrLeft[1], ":") + timeHour := 0 + if len(timeArrRight) >= 1 { + timeHour, _ = strconv.Atoi(timeArrRight[0]) + } + timeHour += timeDay * 24 + timeMinutes := 0 + if len(timeArrRight) >= 2 { + timeMinutes, _ = strconv.Atoi(timeArrRight[1]) + } + timeSecond := 0 + if len(timeArrRight) >= 3 { + timeSecond, _ = strconv.Atoi(timeArrRight[2]) + } + + if !a.IsHourValid(timeHour) { + return a.MaxTimeValue(), nil + } + if a.IsMinutesValid(timeMinutes) && a.IsSecondValid(timeSecond) { + timeStr := a.TimeOutput(timeHour, timeMinutes, timeSecond, nega, frac) + return timeStr, nil + } + } else if strings.Contains(timeArgs, ":") { + // format - hhh:mm:ss,hhh:mm + pat := "^\\d{1,3}(:\\d{1,2}){1,2}$" + match, err := regexp.MatchString(pat, timeArgs) + if !match || err != nil { + return a.DefaultTimeValue(), nil + } + + timeArr := strings.Split(timeArgs, ":") + timeHour := 0 + if len(timeArr) >= 1 { + timeHour, _ = strconv.Atoi(timeArr[0]) + } + timeMinutes := 0 + if len(timeArr) >= 2 { + timeMinutes, _ = strconv.Atoi(timeArr[1]) + } + timeSecond := 0 + if len(timeArr) >= 3 { + timeSecond, _ = strconv.Atoi(timeArr[2]) + } + + if !a.IsHourValid(timeHour) { + return a.MaxTimeValue(), nil + } + if a.IsMinutesValid(timeMinutes) && a.IsSecondValid(timeSecond) { + timeStr := a.TimeOutput(timeHour, timeMinutes, timeSecond, nega, frac) + return timeStr, nil + } + } else { + // format - hhhmmss,mmss,ss + pat := "^\\d{1,7}$" + match, err := regexp.MatchString(pat, timeArgs) + if !match || err != nil { + return a.DefaultTimeValue(), nil + } + + timeInt, _ := strconv.Atoi(timeArgs) + timeSecond := timeInt % 100 + timeLeft := timeInt / 100 + timeMinutes := timeLeft % 100 + timeHour := timeLeft / 100 + + if !a.IsHourValid(timeHour) { + return a.MaxTimeValue(), nil + } + if a.IsMinutesValid(timeMinutes) && a.IsSecondValid(timeSecond) { + timeStr := a.TimeOutput(timeHour, timeMinutes, timeSecond, nega, frac) + return timeStr, nil + } + } + + return a.DefaultTimeValue(), nil +} + +func (a castTimeFunc) NumInput() int { + return 1 +} + +func (a castTimeFunc) TimeOutput(hour, minutes, second int, nega, frac bool) proto.Value { + if hour == 838 && minutes == 59 && second == 59 { + timeStr := a.MaxTimeValue() + if nega { + timeStr = a.MinTimeValue() + } + return timeStr + } + + if frac { + second += 1 + if second >= 60 { + minutes += 1 + second = 0 + } + if minutes >= 60 { + hour += 1 + minutes = 0 + } + } + + secondStr := fmt.Sprint(second) + if second < 10 { + secondStr = "0" + secondStr + } + minutesStr := fmt.Sprint(minutes) + if minutes < 10 { + minutesStr = "0" + minutesStr + } + hourStr := fmt.Sprint(hour) + if hour < 10 { + hourStr = "0" + hourStr + } + + timeStr := hourStr + ":" + minutesStr + ":" + secondStr + if nega { + timeStr = "-" + timeStr + } + return proto.NewValueString(timeStr) +} + +func (a castTimeFunc) IsDayValid(day int) bool { + if day >= 0 && day <= 34 { + return true + } + return false +} + +func (a castTimeFunc) IsHourValid(hour int) bool { + if hour >= 0 && hour <= 838 { + return true + } + return false +} + +func (a castTimeFunc) IsMinutesValid(minutes int) bool { + if minutes >= 0 && minutes <= 59 { + return true + } + return false +} + +func (a castTimeFunc) IsSecondValid(second int) bool { + if second >= 0 && second <= 59 { + return true + } + return false +} + +func (a castTimeFunc) MaxTimeValue() proto.Value { + return proto.NewValueString("838:59:59") +} + +func (a castTimeFunc) MinTimeValue() proto.Value { + return proto.NewValueString("-838:59:59") +} + +func (a castTimeFunc) DefaultTimeValue() proto.Value { + return proto.NewValueString("00:00:00") +} diff --git a/pkg/runtime/function/cast_time_test.go b/pkg/runtime/function/cast_time_test.go new file mode 100644 index 00000000..efffdaef --- /dev/null +++ b/pkg/runtime/function/cast_time_test.go @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "context" + "fmt" + "testing" +) + +import ( + "github.com/stretchr/testify/assert" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +func TestFuncCastTime(t *testing.T) { + fn := proto.MustGetFunc(FuncCastTime) + assert.Equal(t, 1, fn.NumInput()) + type tt struct { + inFirst string + want string + } + for _, v := range []tt{ + {"1 1:2:3.111111", "25:02:03"}, + {"2 1:2.599999", "49:02:01"}, + {"3 1", "73:00:00"}, + {"-34 22", "-838:00:00"}, + {"35 1", "838:59:59"}, + {"-838:59:59", "-838:59:59"}, + {"1:2:3", "01:02:03"}, + {"1:1", "01:01:00"}, + {"1:100", "00:00:00"}, + {"1:b", "00:00:00"}, + {"838:12:11", "838:12:11"}, + {"839:12:11", "838:59:59"}, + {"1", "00:00:01"}, + {"51219", "05:12:19"}, + {"173429", "17:34:29"}, + {"173470", "00:00:00"}, + {"176429", "00:00:00"}, + {"17ab10", "00:00:00"}, + {"8381211", "838:12:11"}, + {"8391211", "838:59:59"}, + } { + t.Run(v.want, func(t *testing.T) { + out, err := fn.Apply(context.Background(), proto.ToValuer(proto.NewValueString(v.inFirst))) + assert.NoError(t, err) + assert.Equal(t, v.want, fmt.Sprint(out)) + }) + } +} From 97ac206f35f7576bce8c25d3102d528462fcd614 Mon Sep 17 00:00:00 2001 From: csynineyang Date: Tue, 20 Dec 2022 19:39:25 +0800 Subject: [PATCH 05/20] Support MySQL CAST_DATE function. (#569) --- pkg/runtime/function/cast_date.go | 217 +++++++++++++++++++++++++ pkg/runtime/function/cast_date_test.go | 65 ++++++++ 2 files changed, 282 insertions(+) create mode 100644 pkg/runtime/function/cast_date.go create mode 100644 pkg/runtime/function/cast_date_test.go diff --git a/pkg/runtime/function/cast_date.go b/pkg/runtime/function/cast_date.go new file mode 100644 index 00000000..0ff32c1c --- /dev/null +++ b/pkg/runtime/function/cast_date.go @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "context" + "fmt" + "regexp" + "strconv" + "strings" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +// FuncCastDate is https://dev.mysql.com/doc/refman/5.6/en/cast-functions.html#function_cast +const FuncCastDate = "CAST_DATE" + +var _ proto.Func = (*castDateFunc)(nil) + +func init() { + proto.RegisterFunc(FuncCastDate, castDateFunc{}) +} + +type castDateFunc struct{} + +func (a castDateFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto.Value, error) { + // expr + val, err := inputs[0].Value(ctx) + if err != nil { + return nil, errors.WithStack(err) + } + + dateArgs := fmt.Sprint(val) + if strings.Compare(dateArgs, "") == 0 { + return a.DefaultDateValue(), nil + } + + // format - YY-MM-DD, YYYY-MM-DD + pat := "^\\d{1,4}[~!@#$%^&*_+\\-=:;,.|?/]{1}\\d{1,2}[~!@#$%^&*_+\\-=:;,.|?/]{1}\\d{1,2}$" + match, err := regexp.MatchString(pat, dateArgs) + if match && err == nil { + dateLen := len(dateArgs) + dateDayStr := string(dateArgs[dateLen-1 : dateLen]) + dateLeft := string(dateArgs[0 : dateLen-2]) + if a.IsDigitalValid(string(dateArgs[dateLen-2])) { + dateDayStr = string(dateArgs[dateLen-2 : dateLen]) + dateLeft = string(dateArgs[0 : dateLen-3]) + } + dateArgs = dateLeft + + dateLen = len(dateArgs) + dateMonthStr := string(dateArgs[dateLen-1 : dateLen]) + dateLeft = string(dateArgs[0 : dateLen-2]) + if a.IsDigitalValid(string(dateArgs[dateLen-2])) { + dateMonthStr = string(dateArgs[dateLen-2 : dateLen]) + dateLeft = string(dateArgs[0 : dateLen-3]) + } + dateYearStr := dateLeft + + dateYear, _ := strconv.Atoi(dateYearStr) + dateYear = a.amend4DigtalYear(dateYear) + dateMonth, _ := strconv.Atoi(dateMonthStr) + dateDay, _ := strconv.Atoi(dateDayStr) + + if a.IsYearValid(dateYear) && a.IsMonthValid(dateMonth) && a.IsDayValid(dateYear, dateMonth, dateDay) { + dateStr := a.DateOutput(dateYear, dateMonth, dateDay) + return dateStr, nil + } else { + return a.DefaultDateValue(), nil + } + } + // format - YYYYMMDD, YYMMDD + pat = "^\\d{5,8}$" + match, err = regexp.MatchString(pat, dateArgs) + if match && err == nil { + dateLen := len(dateArgs) + dateDayStr := string(dateArgs[dateLen-2 : dateLen]) + dateLeft := string(dateArgs[0 : dateLen-2]) + dateArgs = dateLeft + + dateLen = len(dateArgs) + dateMonthStr := string(dateArgs[dateLen-2 : dateLen]) + dateLeft = string(dateArgs[0 : dateLen-2]) + dateYearStr := dateLeft + + dateYear, _ := strconv.Atoi(dateYearStr) + dateYear = a.amend4DigtalYear(dateYear) + dateMonth, _ := strconv.Atoi(dateMonthStr) + dateDay, _ := strconv.Atoi(dateDayStr) + + if a.IsYearValid(dateYear) && a.IsMonthValid(dateMonth) && a.IsDayValid(dateYear, dateMonth, dateDay) { + dateStr := a.DateOutput(dateYear, dateMonth, dateDay) + return dateStr, nil + } else { + return a.DefaultDateValue(), nil + } + } + + return a.DefaultDateValue(), nil +} + +func (a castDateFunc) NumInput() int { + return 1 +} + +func (a castDateFunc) DateOutput(year, month, day int) proto.Value { + dayStr := fmt.Sprint(day) + if day < 10 { + dayStr = "0" + dayStr + } + monthStr := fmt.Sprint(month) + if month < 10 { + monthStr = "0" + monthStr + } + yearStr := fmt.Sprint(year) + if year >= 100 && year <= 999 { + yearStr = "0" + yearStr + } + + dateStr := yearStr + "-" + monthStr + "-" + dayStr + return proto.NewValueString(dateStr) +} + +func (a castDateFunc) amend4DigtalYear(year int) int { + if year >= 0 && year <= 69 { + year += 2000 + } + if year >= 70 && year <= 99 { + year += 1900 + } + return year +} + +func (a castDateFunc) IsDigitalValid(data string) bool { + if len(data) == 1 && data >= "0" && data <= "9" { + return true + } + return false +} + +func (a castDateFunc) IsYearValid(year int) bool { + if year >= 100 && year <= 9999 { + return true + } + return false +} + +func (a castDateFunc) IsMonthValid(month int) bool { + if month >= 1 && month <= 12 { + return true + } + return false +} + +func (a castDateFunc) IsDayValid(year, month, day int) bool { + if month == 1 || month == 3 || month == 5 || month == 7 || + month == 8 || month == 10 || month == 12 { + if day >= 1 && day <= 31 { + return true + } + return false + } + if month == 4 || month == 6 || month == 9 || month == 11 { + if day >= 1 && day <= 30 { + return true + } + return false + } + if month == 2 { + if (year%100 == 0 && year%400 == 0) || + (year%100 > 0 && year%4 == 0) { + if day >= 1 && day <= 29 { + return true + } + return false + } else { + if day >= 1 && day <= 28 { + return true + } + return false + } + } + return false +} + +func (a castDateFunc) MaxDateValue() proto.Value { + return proto.NewValueString("9999-12-31") +} + +func (a castDateFunc) MinDateValue() proto.Value { + return proto.NewValueString("0000-01-01") +} + +func (a castDateFunc) DefaultDateValue() proto.Value { + return proto.NewValueString("0000-00-00") +} diff --git a/pkg/runtime/function/cast_date_test.go b/pkg/runtime/function/cast_date_test.go new file mode 100644 index 00000000..9e68bd7c --- /dev/null +++ b/pkg/runtime/function/cast_date_test.go @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "context" + "fmt" + "testing" +) + +import ( + "github.com/stretchr/testify/assert" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +func TestFuncCastDate(t *testing.T) { + fn := proto.MustGetFunc(FuncCastDate) + assert.Equal(t, 1, fn.NumInput()) + type tt struct { + inFirst string + want string + } + for _, v := range []tt{ + {"99-12-2", "1999-12-02"}, + {"99-12-20", "1999-12-20"}, + {"5#2?2", "2005-02-02"}, + {"199#2?2", "0199-02-02"}, + {"12.2+29", "2012-02-29"}, + {"22.2+29", "0000-00-00"}, + {"2.15+20", "0000-00-00"}, + {"2002.5+20", "2002-05-20"}, + {"2002.-5+20", "0000-00-00"}, + {"991202", "1999-12-02"}, + {"19991202", "1999-12-02"}, + {"51202", "2005-12-02"}, + {"051202", "2005-12-02"}, + {"1991202", "0199-12-02"}, + {"20051202", "2005-12-02"}, + {"20051234", "0000-00-00"}, + } { + t.Run(v.want, func(t *testing.T) { + out, err := fn.Apply(context.Background(), proto.ToValuer(proto.NewValueString(v.inFirst))) + assert.NoError(t, err) + assert.Equal(t, v.want, fmt.Sprint(out)) + }) + } +} From e55e00fc63766253c027cf028a0d11bc3e8c711d Mon Sep 17 00:00:00 2001 From: csynineyang Date: Mon, 26 Dec 2022 17:56:26 +0800 Subject: [PATCH 06/20] Support MySQL CAST_DATETIME function. (#568) --- pkg/runtime/function/cast_date.go | 84 ++++--- pkg/runtime/function/cast_datetime.go | 272 +++++++++++++++++++++ pkg/runtime/function/cast_datetime_test.go | 67 +++++ pkg/runtime/function/cast_time.go | 71 +++--- pkg/runtime/function/cast_time_test.go | 1 + 5 files changed, 425 insertions(+), 70 deletions(-) create mode 100644 pkg/runtime/function/cast_datetime.go create mode 100644 pkg/runtime/function/cast_datetime_test.go diff --git a/pkg/runtime/function/cast_date.go b/pkg/runtime/function/cast_date.go index 0ff32c1c..2eb98a8d 100644 --- a/pkg/runtime/function/cast_date.go +++ b/pkg/runtime/function/cast_date.go @@ -60,28 +60,7 @@ func (a castDateFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto. pat := "^\\d{1,4}[~!@#$%^&*_+\\-=:;,.|?/]{1}\\d{1,2}[~!@#$%^&*_+\\-=:;,.|?/]{1}\\d{1,2}$" match, err := regexp.MatchString(pat, dateArgs) if match && err == nil { - dateLen := len(dateArgs) - dateDayStr := string(dateArgs[dateLen-1 : dateLen]) - dateLeft := string(dateArgs[0 : dateLen-2]) - if a.IsDigitalValid(string(dateArgs[dateLen-2])) { - dateDayStr = string(dateArgs[dateLen-2 : dateLen]) - dateLeft = string(dateArgs[0 : dateLen-3]) - } - dateArgs = dateLeft - - dateLen = len(dateArgs) - dateMonthStr := string(dateArgs[dateLen-1 : dateLen]) - dateLeft = string(dateArgs[0 : dateLen-2]) - if a.IsDigitalValid(string(dateArgs[dateLen-2])) { - dateMonthStr = string(dateArgs[dateLen-2 : dateLen]) - dateLeft = string(dateArgs[0 : dateLen-3]) - } - dateYearStr := dateLeft - - dateYear, _ := strconv.Atoi(dateYearStr) - dateYear = a.amend4DigtalYear(dateYear) - dateMonth, _ := strconv.Atoi(dateMonthStr) - dateDay, _ := strconv.Atoi(dateDayStr) + dateYear, dateMonth, dateDay := a.splitDateWithSep(dateArgs) if a.IsYearValid(dateYear) && a.IsMonthValid(dateMonth) && a.IsDayValid(dateYear, dateMonth, dateDay) { dateStr := a.DateOutput(dateYear, dateMonth, dateDay) @@ -94,20 +73,7 @@ func (a castDateFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto. pat = "^\\d{5,8}$" match, err = regexp.MatchString(pat, dateArgs) if match && err == nil { - dateLen := len(dateArgs) - dateDayStr := string(dateArgs[dateLen-2 : dateLen]) - dateLeft := string(dateArgs[0 : dateLen-2]) - dateArgs = dateLeft - - dateLen = len(dateArgs) - dateMonthStr := string(dateArgs[dateLen-2 : dateLen]) - dateLeft = string(dateArgs[0 : dateLen-2]) - dateYearStr := dateLeft - - dateYear, _ := strconv.Atoi(dateYearStr) - dateYear = a.amend4DigtalYear(dateYear) - dateMonth, _ := strconv.Atoi(dateMonthStr) - dateDay, _ := strconv.Atoi(dateDayStr) + dateYear, dateMonth, dateDay := a.splitDateWithoutSep(dateArgs) if a.IsYearValid(dateYear) && a.IsMonthValid(dateMonth) && a.IsDayValid(dateYear, dateMonth, dateDay) { dateStr := a.DateOutput(dateYear, dateMonth, dateDay) @@ -142,6 +108,52 @@ func (a castDateFunc) DateOutput(year, month, day int) proto.Value { return proto.NewValueString(dateStr) } +func (a castDateFunc) splitDateWithSep(dateArgs string) (year, month, day int) { + dateLen := len(dateArgs) + dateDayStr := string(dateArgs[dateLen-1 : dateLen]) + dateLeft := string(dateArgs[0 : dateLen-2]) + if a.IsDigitalValid(string(dateArgs[dateLen-2])) { + dateDayStr = string(dateArgs[dateLen-2 : dateLen]) + dateLeft = string(dateArgs[0 : dateLen-3]) + } + dateArgs = dateLeft + + dateLen = len(dateArgs) + dateMonthStr := string(dateArgs[dateLen-1 : dateLen]) + dateLeft = string(dateArgs[0 : dateLen-2]) + if a.IsDigitalValid(string(dateArgs[dateLen-2])) { + dateMonthStr = string(dateArgs[dateLen-2 : dateLen]) + dateLeft = string(dateArgs[0 : dateLen-3]) + } + dateYearStr := dateLeft + + dateYear, _ := strconv.Atoi(dateYearStr) + dateYear = a.amend4DigtalYear(dateYear) + dateMonth, _ := strconv.Atoi(dateMonthStr) + dateDay, _ := strconv.Atoi(dateDayStr) + + return dateYear, dateMonth, dateDay +} + +func (a castDateFunc) splitDateWithoutSep(dateArgs string) (year, month, day int) { + dateLen := len(dateArgs) + dateDayStr := string(dateArgs[dateLen-2 : dateLen]) + dateLeft := string(dateArgs[0 : dateLen-2]) + dateArgs = dateLeft + + dateLen = len(dateArgs) + dateMonthStr := string(dateArgs[dateLen-2 : dateLen]) + dateLeft = string(dateArgs[0 : dateLen-2]) + dateYearStr := dateLeft + + dateYear, _ := strconv.Atoi(dateYearStr) + dateYear = a.amend4DigtalYear(dateYear) + dateMonth, _ := strconv.Atoi(dateMonthStr) + dateDay, _ := strconv.Atoi(dateDayStr) + + return dateYear, dateMonth, dateDay +} + func (a castDateFunc) amend4DigtalYear(year int) int { if year >= 0 && year <= 69 { year += 2000 diff --git a/pkg/runtime/function/cast_datetime.go b/pkg/runtime/function/cast_datetime.go new file mode 100644 index 00000000..db7673f7 --- /dev/null +++ b/pkg/runtime/function/cast_datetime.go @@ -0,0 +1,272 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "context" + "fmt" + "regexp" + "strconv" + "strings" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +// FuncCastDatetime is https://dev.mysql.com/doc/refman/5.6/en/cast-functions.html#function_cast +const FuncCastDatetime = "CAST_DATETIME" + +var _ proto.Func = (*castDatetimeFunc)(nil) + +func init() { + proto.RegisterFunc(FuncCastDatetime, castDatetimeFunc{}) +} + +type castDatetimeFunc struct{} + +var castDate castDateFunc +var castTime castTimeFunc + +func (a castDatetimeFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto.Value, error) { + // expr + val, err := inputs[0].Value(ctx) + if err != nil { + return nil, errors.WithStack(err) + } + + datetimeArgs := fmt.Sprint(val) + if strings.Compare(datetimeArgs, "") == 0 { + return a.DefaultDatetimeValue(), nil + } + + // fractional seconds + frac := false + if strings.Contains(datetimeArgs, ".") { + datetimeArr := strings.Split(datetimeArgs, ".") + if len(datetimeArr) >= 2 && len(datetimeArr[1]) >= 1 { + datetimeFrac, _ := strconv.Atoi(string(datetimeArr[1][0])) + if datetimeFrac >= 5 { + frac = true + } + } + datetimeArgs = datetimeArr[0] + } + + if strings.Contains(datetimeArgs, " ") || strings.Contains(datetimeArgs, "T") { + // format - YYYY-MM-DD hh:mm:ss.ms or YYYY-MM-DDThh:mm:ss.ms + var datetimeArr []string + if strings.Contains(datetimeArgs, " ") { + datetimeArr = strings.Split(datetimeArgs, " ") + } else if strings.Contains(datetimeArgs, "T") { + datetimeArr = strings.Split(datetimeArgs, "T") + } else { + return a.DefaultDatetimeValue(), nil + } + + pat := "^\\d{1,4}[~!@#$%^&*_+\\-=:;,|?/]{1}\\d{1,2}[~!@#$%^&*_+\\-=:;,|?/]{1}\\d{1,2}$" + match, err := regexp.MatchString(pat, datetimeArr[0]) + if !match || err != nil { + return a.DefaultDatetimeValue(), nil + } + year, month, day := castDate.splitDateWithSep(datetimeArr[0]) + year = castDate.amend4DigtalYear(year) + + pat = "^\\d{1,2}[~!@#$%^&*_+\\-=:;,|?/]{1}\\d{1,2}[~!@#$%^&*_+\\-=:;,|?/]{1}\\d{1,2}$" + match, err = regexp.MatchString(pat, datetimeArr[1]) + if !match || err != nil { + return a.DefaultDatetimeValue(), nil + } + hour, minutes, second := a.splitDatetimeWithSep(datetimeArr[1]) + + if castDate.IsYearValid(year) && castDate.IsMonthValid(month) && castDate.IsDayValid(year, month, day) && + a.IsHourValid(hour) && castTime.IsMinutesValid(minutes) && castTime.IsSecondValid(second) { + return a.DatetimeOutput(year, month, day, hour, minutes, second, frac), nil + } else { + return a.DefaultDatetimeValue(), nil + } + } else { + // format - YYYYMMDDhhmmss.ms + pat := "^\\d{11,14}$" + match, err := regexp.MatchString(pat, datetimeArgs) + if !match || err != nil { + return a.DefaultDatetimeValue(), nil + } + + datetimeLen := len(datetimeArgs) + dateArgs := string(datetimeArgs[0 : datetimeLen-6]) + year, month, day := castDate.splitDateWithoutSep(dateArgs) + year = castDate.amend4DigtalYear(year) + + timeArgs := string(datetimeArgs[datetimeLen-6 : datetimeLen]) + hour, minutes, second := a.splitDatetimeWithoutSep(timeArgs) + + if castDate.IsYearValid(year) && castDate.IsMonthValid(month) && castDate.IsDayValid(year, month, day) && + a.IsHourValid(hour) && castTime.IsMinutesValid(minutes) && castTime.IsSecondValid(second) { + return a.DatetimeOutput(year, month, day, hour, minutes, second, frac), nil + } else { + return a.DefaultDatetimeValue(), nil + } + } +} + +func (a castDatetimeFunc) NumInput() int { + return 1 +} + +func (a castDatetimeFunc) DatetimeOutput(year, month, day, hour, minutes, second int, frac bool) proto.Value { + if frac { + second += 1 + if second >= 60 { + minutes += 1 + second = 0 + } + if minutes >= 60 { + hour += 1 + minutes = 0 + } + if hour >= 24 { + day += 1 + hour = 0 + } + + if month == 1 || month == 3 || month == 5 || month == 7 || + month == 8 || month == 10 || month == 12 { + if day >= 32 { + month += 1 + day = 1 + } + } + if month == 4 || month == 6 || month == 9 || month == 11 { + if day >= 31 { + month += 1 + day = 1 + } + } + if month == 2 { + if (year%100 == 0 && year%400 == 0) || + (year%100 > 0 && year%4 == 0) { + if day >= 30 { + month += 1 + day = 1 + } + } else { + if day >= 29 { + month += 1 + day = 1 + } + } + } + if month >= 13 { + year += 1 + month = 1 + } + if year >= 10000 { + return a.DefaultDatetimeValue() + } + } + + secondStr := fmt.Sprint(second) + if second < 10 { + secondStr = "0" + secondStr + } + minutesStr := fmt.Sprint(minutes) + if minutes < 10 { + minutesStr = "0" + minutesStr + } + hourStr := fmt.Sprint(hour) + if hour < 10 { + hourStr = "0" + hourStr + } + timeStr := hourStr + ":" + minutesStr + ":" + secondStr + + dayStr := fmt.Sprint(day) + if day < 10 { + dayStr = "0" + dayStr + } + monthStr := fmt.Sprint(month) + if month < 10 { + monthStr = "0" + monthStr + } + yearStr := fmt.Sprint(year) + if year >= 100 && year <= 999 { + yearStr = "0" + yearStr + } + dateStr := yearStr + "-" + monthStr + "-" + dayStr + + return proto.NewValueString(dateStr + " " + timeStr) +} + +func (a castDatetimeFunc) splitDatetimeWithSep(timeArgs string) (hour, minutes, second int) { + timeLen := len(timeArgs) + timeSecondStr := string(timeArgs[timeLen-1 : timeLen]) + timeLeft := string(timeArgs[0 : timeLen-2]) + if castDate.IsDigitalValid(string(timeArgs[timeLen-2])) { + timeSecondStr = string(timeArgs[timeLen-2 : timeLen]) + timeLeft = string(timeArgs[0 : timeLen-3]) + } + timeArgs = timeLeft + + timeLen = len(timeArgs) + timeMinutesStr := string(timeArgs[timeLen-1 : timeLen]) + timeLeft = string(timeArgs[0 : timeLen-2]) + if castDate.IsDigitalValid(string(timeArgs[timeLen-2])) { + timeMinutesStr = string(timeArgs[timeLen-2 : timeLen]) + timeLeft = string(timeArgs[0 : timeLen-3]) + } + timeHourStr := timeLeft + + timeHour, _ := strconv.Atoi(timeHourStr) + timeMinutes, _ := strconv.Atoi(timeMinutesStr) + timeSecond, _ := strconv.Atoi(timeSecondStr) + + return timeHour, timeMinutes, timeSecond +} + +func (a castDatetimeFunc) splitDatetimeWithoutSep(timeArgs string) (hour, minutes, second int) { + timeLen := len(timeArgs) + timeSecondStr := string(timeArgs[timeLen-2 : timeLen]) + timeLeft := string(timeArgs[0 : timeLen-2]) + timeArgs = timeLeft + + timeLen = len(timeArgs) + timeMinutesStr := string(timeArgs[timeLen-2 : timeLen]) + timeLeft = string(timeArgs[0 : timeLen-2]) + timeHourStr := timeLeft + + timeHour, _ := strconv.Atoi(timeHourStr) + timeMinutes, _ := strconv.Atoi(timeMinutesStr) + timeSecond, _ := strconv.Atoi(timeSecondStr) + + return timeHour, timeMinutes, timeSecond +} + +func (a castDatetimeFunc) IsHourValid(hour int) bool { + if hour >= 0 && hour <= 23 { + return true + } + return false +} + +func (a castDatetimeFunc) DefaultDatetimeValue() proto.Value { + return proto.NewValueString("0000-00-00 00:00:00") +} diff --git a/pkg/runtime/function/cast_datetime_test.go b/pkg/runtime/function/cast_datetime_test.go new file mode 100644 index 00000000..50c2f767 --- /dev/null +++ b/pkg/runtime/function/cast_datetime_test.go @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "context" + "fmt" + "testing" +) + +import ( + "github.com/stretchr/testify/assert" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +func TestFuncCastDatetime(t *testing.T) { + fn := proto.MustGetFunc(FuncCastDatetime) + assert.Equal(t, 1, fn.NumInput()) + type tt struct { + inFirst string + want string + } + for _, v := range []tt{ + {"99-12-2 1:2:3", "1999-12-02 01:02:03"}, + {"99-12-20T11:2:33", "1999-12-20 11:02:33"}, + {"5#2?2 17%33!24.486762", "2005-02-02 17:33:24"}, + {"199#2?2T23#16+44", "0199-02-02 23:16:44"}, + {"12=2+29 23=59+59.587425", "2012-03-01 00:00:00"}, + {"22=2+29 8=42+11", "0000-00-00 00:00:00"}, + {"2=15+20 11=29+56 ", "0000-00-00 00:00:00"}, + {"2002=5+20 12+34=59.986345", "2002-05-20 12:35:00"}, + {"2002=-5+20 7=32+11", "0000-00-00 00:00:00"}, + {"991202052317.342167", "1999-12-02 05:23:17"}, + {"19991202225959.734128", "1999-12-02 23:00:00"}, + {"51202000000", "2005-12-02 00:00:00"}, + {"051202122458", "2005-12-02 12:24:58"}, + {"1991202091245", "0199-12-02 09:12:45"}, + {"20051202123459.172124", "2005-12-02 12:34:59"}, + {"20051234193247", "0000-00-00 00:00:00"}, + {"20051234561324", "0000-00-00 00:00:00"}, + {"00000000000000", "0000-00-00 00:00:00"}, + } { + t.Run(v.want, func(t *testing.T) { + out, err := fn.Apply(context.Background(), proto.ToValuer(proto.NewValueString(v.inFirst))) + assert.NoError(t, err) + assert.Equal(t, v.want, fmt.Sprint(out)) + }) + } +} diff --git a/pkg/runtime/function/cast_time.go b/pkg/runtime/function/cast_time.go index 2d069e37..b6487b6d 100644 --- a/pkg/runtime/function/cast_time.go +++ b/pkg/runtime/function/cast_time.go @@ -62,6 +62,9 @@ func (a castTimeFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto. nega = true timeArgs = string(timeArgs[1:]) } + if strings.Compare(string(timeArgs[0]), "+") == 0 { + timeArgs = string(timeArgs[1:]) + } // fractional seconds frac := false @@ -77,7 +80,7 @@ func (a castTimeFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto. } if strings.Contains(timeArgs, " ") { - // format - D hhh:mm:ss,D hhh:mm,D hhh + // format - D hhh:mm:ss.ms,D hhh:mm.ms,D hhh.ms pat := "^\\d{1,2} \\d{1,3}(:\\d{1,2}){0,2}$" match, err := regexp.MatchString(pat, timeArgs) if !match || err != nil { @@ -86,20 +89,8 @@ func (a castTimeFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto. timeArrLeft := strings.Split(timeArgs, " ") timeDay, _ := strconv.Atoi(timeArrLeft[0]) - timeArrRight := strings.Split(timeArrLeft[1], ":") - timeHour := 0 - if len(timeArrRight) >= 1 { - timeHour, _ = strconv.Atoi(timeArrRight[0]) - } + timeHour, timeMinutes, timeSecond := a.splitTimeWithSep(timeArrLeft[1]) timeHour += timeDay * 24 - timeMinutes := 0 - if len(timeArrRight) >= 2 { - timeMinutes, _ = strconv.Atoi(timeArrRight[1]) - } - timeSecond := 0 - if len(timeArrRight) >= 3 { - timeSecond, _ = strconv.Atoi(timeArrRight[2]) - } if !a.IsHourValid(timeHour) { return a.MaxTimeValue(), nil @@ -109,26 +100,14 @@ func (a castTimeFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto. return timeStr, nil } } else if strings.Contains(timeArgs, ":") { - // format - hhh:mm:ss,hhh:mm + // format - hhh:mm:ss.ms,hhh:mm.ms pat := "^\\d{1,3}(:\\d{1,2}){1,2}$" match, err := regexp.MatchString(pat, timeArgs) if !match || err != nil { return a.DefaultTimeValue(), nil } - timeArr := strings.Split(timeArgs, ":") - timeHour := 0 - if len(timeArr) >= 1 { - timeHour, _ = strconv.Atoi(timeArr[0]) - } - timeMinutes := 0 - if len(timeArr) >= 2 { - timeMinutes, _ = strconv.Atoi(timeArr[1]) - } - timeSecond := 0 - if len(timeArr) >= 3 { - timeSecond, _ = strconv.Atoi(timeArr[2]) - } + timeHour, timeMinutes, timeSecond := a.splitTimeWithSep(timeArgs) if !a.IsHourValid(timeHour) { return a.MaxTimeValue(), nil @@ -138,18 +117,14 @@ func (a castTimeFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto. return timeStr, nil } } else { - // format - hhhmmss,mmss,ss + // format - hhhmmss.ms,mmss.ms,ss.ms pat := "^\\d{1,7}$" match, err := regexp.MatchString(pat, timeArgs) if !match || err != nil { return a.DefaultTimeValue(), nil } - timeInt, _ := strconv.Atoi(timeArgs) - timeSecond := timeInt % 100 - timeLeft := timeInt / 100 - timeMinutes := timeLeft % 100 - timeHour := timeLeft / 100 + timeHour, timeMinutes, timeSecond := a.splitTimeWithoutSep(timeArgs) if !a.IsHourValid(timeHour) { return a.MaxTimeValue(), nil @@ -208,6 +183,34 @@ func (a castTimeFunc) TimeOutput(hour, minutes, second int, nega, frac bool) pro return proto.NewValueString(timeStr) } +func (a castTimeFunc) splitTimeWithSep(timeArgs string) (hour, minutes, second int) { + timeArr := strings.Split(timeArgs, ":") + timeHour := 0 + if len(timeArr) >= 1 { + timeHour, _ = strconv.Atoi(timeArr[0]) + } + timeMinutes := 0 + if len(timeArr) >= 2 { + timeMinutes, _ = strconv.Atoi(timeArr[1]) + } + timeSecond := 0 + if len(timeArr) >= 3 { + timeSecond, _ = strconv.Atoi(timeArr[2]) + } + + return timeHour, timeMinutes, timeSecond +} + +func (a castTimeFunc) splitTimeWithoutSep(timeArgs string) (hour, minutes, second int) { + timeInt, _ := strconv.Atoi(timeArgs) + timeSecond := timeInt % 100 + timeLeft := timeInt / 100 + timeMinutes := timeLeft % 100 + timeHour := timeLeft / 100 + + return timeHour, timeMinutes, timeSecond +} + func (a castTimeFunc) IsDayValid(day int) bool { if day >= 0 && day <= 34 { return true diff --git a/pkg/runtime/function/cast_time_test.go b/pkg/runtime/function/cast_time_test.go index efffdaef..735cbd11 100644 --- a/pkg/runtime/function/cast_time_test.go +++ b/pkg/runtime/function/cast_time_test.go @@ -52,6 +52,7 @@ func TestFuncCastTime(t *testing.T) { {"838:12:11", "838:12:11"}, {"839:12:11", "838:59:59"}, {"1", "00:00:01"}, + {"102", "00:01:02"}, {"51219", "05:12:19"}, {"173429", "17:34:29"}, {"173470", "00:00:00"}, From 9ca3805e4d17105c6f7435e184b02ca999bcc5e8 Mon Sep 17 00:00:00 2001 From: csynineyang Date: Tue, 17 Jan 2023 15:56:15 +0800 Subject: [PATCH 07/20] Support MySQL CAST_TIME/CAST_DATE/CAST_DATETIME function --- pkg/runtime/function/cast_date.go | 7 ++- pkg/runtime/function/cast_date_test.go | 2 +- pkg/runtime/function/cast_datetime.go | 72 +++++----------------- pkg/runtime/function/cast_datetime_test.go | 2 +- 4 files changed, 23 insertions(+), 60 deletions(-) diff --git a/pkg/runtime/function/cast_date.go b/pkg/runtime/function/cast_date.go index 2eb98a8d..696601c6 100644 --- a/pkg/runtime/function/cast_date.go +++ b/pkg/runtime/function/cast_date.go @@ -36,6 +36,7 @@ import ( // FuncCastDate is https://dev.mysql.com/doc/refman/5.6/en/cast-functions.html#function_cast const FuncCastDate = "CAST_DATE" +var DateSep = "~!@#$%^&*_+=:;,.|/?\\(\\)\\[\\]\\{\\}\\-\\\\" var _ proto.Func = (*castDateFunc)(nil) func init() { @@ -57,10 +58,12 @@ func (a castDateFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto. } // format - YY-MM-DD, YYYY-MM-DD - pat := "^\\d{1,4}[~!@#$%^&*_+\\-=:;,.|?/]{1}\\d{1,2}[~!@#$%^&*_+\\-=:;,.|?/]{1}\\d{1,2}$" + pat := "^\\d{1,4}[" + DateSep + "]+\\d{1,2}[" + DateSep + "]+\\d{1,2}$" match, err := regexp.MatchString(pat, dateArgs) if match && err == nil { - dateYear, dateMonth, dateDay := a.splitDateWithSep(dateArgs) + rep := regexp.MustCompile(`[` + DateSep + `]+`) + dateArgsReplace := rep.ReplaceAllStringFunc(dateArgs, func(s string) string { return "-" }) + dateYear, dateMonth, dateDay := a.splitDateWithSep(dateArgsReplace) if a.IsYearValid(dateYear) && a.IsMonthValid(dateMonth) && a.IsDayValid(dateYear, dateMonth, dateDay) { dateStr := a.DateOutput(dateYear, dateMonth, dateDay) diff --git a/pkg/runtime/function/cast_date_test.go b/pkg/runtime/function/cast_date_test.go index 9e68bd7c..15501b04 100644 --- a/pkg/runtime/function/cast_date_test.go +++ b/pkg/runtime/function/cast_date_test.go @@ -47,7 +47,7 @@ func TestFuncCastDate(t *testing.T) { {"22.2+29", "0000-00-00"}, {"2.15+20", "0000-00-00"}, {"2002.5+20", "2002-05-20"}, - {"2002.-5+20", "0000-00-00"}, + {"2002.-5+20", "2002-05-20"}, {"991202", "1999-12-02"}, {"19991202", "1999-12-02"}, {"51202", "2005-12-02"}, diff --git a/pkg/runtime/function/cast_datetime.go b/pkg/runtime/function/cast_datetime.go index db7673f7..4113373d 100644 --- a/pkg/runtime/function/cast_datetime.go +++ b/pkg/runtime/function/cast_datetime.go @@ -23,6 +23,7 @@ import ( "regexp" "strconv" "strings" + "time" ) import ( @@ -36,6 +37,7 @@ import ( // FuncCastDatetime is https://dev.mysql.com/doc/refman/5.6/en/cast-functions.html#function_cast const FuncCastDatetime = "CAST_DATETIME" +var DatetimeSep = "~!@#$%^&*_+=:;,|/?\\(\\)\\[\\]\\{\\}\\-\\\\" var _ proto.Func = (*castDatetimeFunc)(nil) func init() { @@ -83,20 +85,23 @@ func (a castDatetimeFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (pr return a.DefaultDatetimeValue(), nil } - pat := "^\\d{1,4}[~!@#$%^&*_+\\-=:;,|?/]{1}\\d{1,2}[~!@#$%^&*_+\\-=:;,|?/]{1}\\d{1,2}$" + pat := "^\\d{1,4}[" + DatetimeSep + "]+\\d{1,2}[" + DatetimeSep + "]+\\d{1,2}$" match, err := regexp.MatchString(pat, datetimeArr[0]) if !match || err != nil { return a.DefaultDatetimeValue(), nil } - year, month, day := castDate.splitDateWithSep(datetimeArr[0]) + rep := regexp.MustCompile(`[` + DateSep + `]+`) + datetimeArrReplace := rep.ReplaceAllStringFunc(datetimeArr[0], func(s string) string { return "-" }) + year, month, day := castDate.splitDateWithSep(datetimeArrReplace) year = castDate.amend4DigtalYear(year) - pat = "^\\d{1,2}[~!@#$%^&*_+\\-=:;,|?/]{1}\\d{1,2}[~!@#$%^&*_+\\-=:;,|?/]{1}\\d{1,2}$" + pat = "^\\d{1,2}[" + DatetimeSep + "]+\\d{1,2}[" + DatetimeSep + "]+\\d{1,2}$" match, err = regexp.MatchString(pat, datetimeArr[1]) if !match || err != nil { return a.DefaultDatetimeValue(), nil } - hour, minutes, second := a.splitDatetimeWithSep(datetimeArr[1]) + datetimeArrReplace = rep.ReplaceAllStringFunc(datetimeArr[1], func(s string) string { return "-" }) + hour, minutes, second := a.splitDatetimeWithSep(datetimeArrReplace) if castDate.IsYearValid(year) && castDate.IsMonthValid(month) && castDate.IsDayValid(year, month, day) && a.IsHourValid(hour) && castTime.IsMinutesValid(minutes) && castTime.IsSecondValid(second) { @@ -134,57 +139,6 @@ func (a castDatetimeFunc) NumInput() int { } func (a castDatetimeFunc) DatetimeOutput(year, month, day, hour, minutes, second int, frac bool) proto.Value { - if frac { - second += 1 - if second >= 60 { - minutes += 1 - second = 0 - } - if minutes >= 60 { - hour += 1 - minutes = 0 - } - if hour >= 24 { - day += 1 - hour = 0 - } - - if month == 1 || month == 3 || month == 5 || month == 7 || - month == 8 || month == 10 || month == 12 { - if day >= 32 { - month += 1 - day = 1 - } - } - if month == 4 || month == 6 || month == 9 || month == 11 { - if day >= 31 { - month += 1 - day = 1 - } - } - if month == 2 { - if (year%100 == 0 && year%400 == 0) || - (year%100 > 0 && year%4 == 0) { - if day >= 30 { - month += 1 - day = 1 - } - } else { - if day >= 29 { - month += 1 - day = 1 - } - } - } - if month >= 13 { - year += 1 - month = 1 - } - if year >= 10000 { - return a.DefaultDatetimeValue() - } - } - secondStr := fmt.Sprint(second) if second < 10 { secondStr = "0" + secondStr @@ -213,7 +167,13 @@ func (a castDatetimeFunc) DatetimeOutput(year, month, day, hour, minutes, second } dateStr := yearStr + "-" + monthStr + "-" + dayStr - return proto.NewValueString(dateStr + " " + timeStr) + datetimeRet, _ := time.Parse("2006-01-02 15:04:05", dateStr+" "+timeStr) + + if frac { + datetimeRet = datetimeRet.Add(1 * time.Second) + } + + return proto.NewValueString(datetimeRet.Format("2006-01-02 15:04:05")) } func (a castDatetimeFunc) splitDatetimeWithSep(timeArgs string) (hour, minutes, second int) { diff --git a/pkg/runtime/function/cast_datetime_test.go b/pkg/runtime/function/cast_datetime_test.go index 50c2f767..a87d8c34 100644 --- a/pkg/runtime/function/cast_datetime_test.go +++ b/pkg/runtime/function/cast_datetime_test.go @@ -47,7 +47,7 @@ func TestFuncCastDatetime(t *testing.T) { {"22=2+29 8=42+11", "0000-00-00 00:00:00"}, {"2=15+20 11=29+56 ", "0000-00-00 00:00:00"}, {"2002=5+20 12+34=59.986345", "2002-05-20 12:35:00"}, - {"2002=-5+20 7=32+11", "0000-00-00 00:00:00"}, + {"2002=-5+20 7=32+11", "2002-05-20 07:32:11"}, {"991202052317.342167", "1999-12-02 05:23:17"}, {"19991202225959.734128", "1999-12-02 23:00:00"}, {"51202000000", "2005-12-02 00:00:00"}, From 9c083caccc4ba9a44dfb48bfc664098b235a8de0 Mon Sep 17 00:00:00 2001 From: csynineyang Date: Thu, 2 Feb 2023 14:24:02 +0800 Subject: [PATCH 08/20] Resolve Conversation --- pkg/runtime/function/cast_date.go | 73 ++++++++--------- pkg/runtime/function/cast_datetime.go | 108 ++++++++++++++------------ pkg/runtime/function/cast_time.go | 54 +++++++------ 3 files changed, 128 insertions(+), 107 deletions(-) diff --git a/pkg/runtime/function/cast_date.go b/pkg/runtime/function/cast_date.go index 696601c6..9ac960b5 100644 --- a/pkg/runtime/function/cast_date.go +++ b/pkg/runtime/function/cast_date.go @@ -36,7 +36,11 @@ import ( // FuncCastDate is https://dev.mysql.com/doc/refman/5.6/en/cast-functions.html#function_cast const FuncCastDate = "CAST_DATE" -var DateSep = "~!@#$%^&*_+=:;,.|/?\\(\\)\\[\\]\\{\\}\\-\\\\" +var DateSep = `[~!@#$%^&*_+=:;,.|/?\(\)\[\]\{\}\-\\]+` +var _dateReplace = regexp.MustCompile(DateSep) +var _dateMatchString = regexp.MustCompile(fmt.Sprintf(`^\d{1,4}%s\d{1,2}%s\d{1,2}$`, DateSep, DateSep)) +var _dateMatchInt = regexp.MustCompile(`^\d{5,8}$`) + var _ proto.Func = (*castDateFunc)(nil) func init() { @@ -58,31 +62,24 @@ func (a castDateFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto. } // format - YY-MM-DD, YYYY-MM-DD - pat := "^\\d{1,4}[" + DateSep + "]+\\d{1,2}[" + DateSep + "]+\\d{1,2}$" - match, err := regexp.MatchString(pat, dateArgs) - if match && err == nil { - rep := regexp.MustCompile(`[` + DateSep + `]+`) - dateArgsReplace := rep.ReplaceAllStringFunc(dateArgs, func(s string) string { return "-" }) + match := _dateMatchString.MatchString(dateArgs) + if match { + dateArgsReplace := _dateReplace.ReplaceAllStringFunc(dateArgs, func(s string) string { return "-" }) dateYear, dateMonth, dateDay := a.splitDateWithSep(dateArgsReplace) if a.IsYearValid(dateYear) && a.IsMonthValid(dateMonth) && a.IsDayValid(dateYear, dateMonth, dateDay) { dateStr := a.DateOutput(dateYear, dateMonth, dateDay) return dateStr, nil - } else { - return a.DefaultDateValue(), nil } } // format - YYYYMMDD, YYMMDD - pat = "^\\d{5,8}$" - match, err = regexp.MatchString(pat, dateArgs) - if match && err == nil { + match = _dateMatchInt.MatchString(dateArgs) + if match { dateYear, dateMonth, dateDay := a.splitDateWithoutSep(dateArgs) if a.IsYearValid(dateYear) && a.IsMonthValid(dateMonth) && a.IsDayValid(dateYear, dateMonth, dateDay) { dateStr := a.DateOutput(dateYear, dateMonth, dateDay) return dateStr, nil - } else { - return a.DefaultDateValue(), nil } } @@ -94,39 +91,45 @@ func (a castDateFunc) NumInput() int { } func (a castDateFunc) DateOutput(year, month, day int) proto.Value { - dayStr := fmt.Sprint(day) - if day < 10 { - dayStr = "0" + dayStr + var sb strings.Builder + + yearStr := strconv.FormatInt(int64(year), 10) + if year >= 100 && year <= 999 { + sb.WriteString("0") } - monthStr := fmt.Sprint(month) + sb.WriteString(yearStr) + sb.WriteString("-") + monthStr := strconv.FormatInt(int64(month), 10) if month < 10 { - monthStr = "0" + monthStr + sb.WriteString("0") } - yearStr := fmt.Sprint(year) - if year >= 100 && year <= 999 { - yearStr = "0" + yearStr + sb.WriteString(monthStr) + sb.WriteString("-") + dayStr := strconv.FormatInt(int64(day), 10) + if day < 10 { + sb.WriteString("0") } + sb.WriteString(dayStr) - dateStr := yearStr + "-" + monthStr + "-" + dayStr - return proto.NewValueString(dateStr) + return proto.NewValueString(sb.String()) } func (a castDateFunc) splitDateWithSep(dateArgs string) (year, month, day int) { dateLen := len(dateArgs) - dateDayStr := string(dateArgs[dateLen-1 : dateLen]) - dateLeft := string(dateArgs[0 : dateLen-2]) + dateDayStr := dateArgs[dateLen-1 : dateLen] + dateLeft := dateArgs[0 : dateLen-2] if a.IsDigitalValid(string(dateArgs[dateLen-2])) { - dateDayStr = string(dateArgs[dateLen-2 : dateLen]) - dateLeft = string(dateArgs[0 : dateLen-3]) + dateDayStr = dateArgs[dateLen-2 : dateLen] + dateLeft = dateArgs[0 : dateLen-3] } dateArgs = dateLeft dateLen = len(dateArgs) - dateMonthStr := string(dateArgs[dateLen-1 : dateLen]) - dateLeft = string(dateArgs[0 : dateLen-2]) + dateMonthStr := dateArgs[dateLen-1 : dateLen] + dateLeft = dateArgs[0 : dateLen-2] if a.IsDigitalValid(string(dateArgs[dateLen-2])) { - dateMonthStr = string(dateArgs[dateLen-2 : dateLen]) - dateLeft = string(dateArgs[0 : dateLen-3]) + dateMonthStr = dateArgs[dateLen-2 : dateLen] + dateLeft = dateArgs[0 : dateLen-3] } dateYearStr := dateLeft @@ -140,13 +143,13 @@ func (a castDateFunc) splitDateWithSep(dateArgs string) (year, month, day int) { func (a castDateFunc) splitDateWithoutSep(dateArgs string) (year, month, day int) { dateLen := len(dateArgs) - dateDayStr := string(dateArgs[dateLen-2 : dateLen]) - dateLeft := string(dateArgs[0 : dateLen-2]) + dateDayStr := dateArgs[dateLen-2 : dateLen] + dateLeft := dateArgs[0 : dateLen-2] dateArgs = dateLeft dateLen = len(dateArgs) - dateMonthStr := string(dateArgs[dateLen-2 : dateLen]) - dateLeft = string(dateArgs[0 : dateLen-2]) + dateMonthStr := dateArgs[dateLen-2 : dateLen] + dateLeft = dateArgs[0 : dateLen-2] dateYearStr := dateLeft dateYear, _ := strconv.Atoi(dateYearStr) diff --git a/pkg/runtime/function/cast_datetime.go b/pkg/runtime/function/cast_datetime.go index 4113373d..a0ac6c8a 100644 --- a/pkg/runtime/function/cast_datetime.go +++ b/pkg/runtime/function/cast_datetime.go @@ -37,7 +37,12 @@ import ( // FuncCastDatetime is https://dev.mysql.com/doc/refman/5.6/en/cast-functions.html#function_cast const FuncCastDatetime = "CAST_DATETIME" -var DatetimeSep = "~!@#$%^&*_+=:;,|/?\\(\\)\\[\\]\\{\\}\\-\\\\" +var DatetimeSep = `[~!@#$%^&*_+=:;,|/?\(\)\[\]\{\}\-\\]+` +var _datetimeReplace = regexp.MustCompile(DatetimeSep) +var _datetimeMatchUpperString = regexp.MustCompile(fmt.Sprintf(`^\d{1,4}%s\d{1,2}%s\d{1,2}$`, DatetimeSep, DatetimeSep)) +var _datetimeMatchLowerString = regexp.MustCompile(fmt.Sprintf(`^\d{1,2}%s\d{1,2}%s\d{1,2}$`, DatetimeSep, DatetimeSep)) +var _datetimeMatchInt = regexp.MustCompile(`^\d{11,14}$`) + var _ proto.Func = (*castDatetimeFunc)(nil) func init() { @@ -85,22 +90,19 @@ func (a castDatetimeFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (pr return a.DefaultDatetimeValue(), nil } - pat := "^\\d{1,4}[" + DatetimeSep + "]+\\d{1,2}[" + DatetimeSep + "]+\\d{1,2}$" - match, err := regexp.MatchString(pat, datetimeArr[0]) - if !match || err != nil { + match := _datetimeMatchUpperString.MatchString(datetimeArr[0]) + if !match { return a.DefaultDatetimeValue(), nil } - rep := regexp.MustCompile(`[` + DateSep + `]+`) - datetimeArrReplace := rep.ReplaceAllStringFunc(datetimeArr[0], func(s string) string { return "-" }) + datetimeArrReplace := _datetimeReplace.ReplaceAllStringFunc(datetimeArr[0], func(s string) string { return "-" }) year, month, day := castDate.splitDateWithSep(datetimeArrReplace) year = castDate.amend4DigtalYear(year) - pat = "^\\d{1,2}[" + DatetimeSep + "]+\\d{1,2}[" + DatetimeSep + "]+\\d{1,2}$" - match, err = regexp.MatchString(pat, datetimeArr[1]) - if !match || err != nil { + match = _datetimeMatchLowerString.MatchString(datetimeArr[1]) + if !match { return a.DefaultDatetimeValue(), nil } - datetimeArrReplace = rep.ReplaceAllStringFunc(datetimeArr[1], func(s string) string { return "-" }) + datetimeArrReplace = _datetimeReplace.ReplaceAllStringFunc(datetimeArr[1], func(s string) string { return "-" }) hour, minutes, second := a.splitDatetimeWithSep(datetimeArrReplace) if castDate.IsYearValid(year) && castDate.IsMonthValid(month) && castDate.IsDayValid(year, month, day) && @@ -111,18 +113,17 @@ func (a castDatetimeFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (pr } } else { // format - YYYYMMDDhhmmss.ms - pat := "^\\d{11,14}$" - match, err := regexp.MatchString(pat, datetimeArgs) - if !match || err != nil { + match := _datetimeMatchInt.MatchString(datetimeArgs) + if !match { return a.DefaultDatetimeValue(), nil } datetimeLen := len(datetimeArgs) - dateArgs := string(datetimeArgs[0 : datetimeLen-6]) + dateArgs := datetimeArgs[0 : datetimeLen-6] year, month, day := castDate.splitDateWithoutSep(dateArgs) year = castDate.amend4DigtalYear(year) - timeArgs := string(datetimeArgs[datetimeLen-6 : datetimeLen]) + timeArgs := datetimeArgs[datetimeLen-6 : datetimeLen] hour, minutes, second := a.splitDatetimeWithoutSep(timeArgs) if castDate.IsYearValid(year) && castDate.IsMonthValid(month) && castDate.IsDayValid(year, month, day) && @@ -139,35 +140,46 @@ func (a castDatetimeFunc) NumInput() int { } func (a castDatetimeFunc) DatetimeOutput(year, month, day, hour, minutes, second int, frac bool) proto.Value { - secondStr := fmt.Sprint(second) - if second < 10 { - secondStr = "0" + secondStr + var sb strings.Builder + + yearStr := strconv.FormatInt(int64(year), 10) + if year >= 100 && year <= 999 { + sb.WriteString("0") } - minutesStr := fmt.Sprint(minutes) - if minutes < 10 { - minutesStr = "0" + minutesStr + sb.WriteString(yearStr) + sb.WriteString("-") + monthStr := strconv.FormatInt(int64(month), 10) + if month < 10 { + sb.WriteString("0") } - hourStr := fmt.Sprint(hour) - if hour < 10 { - hourStr = "0" + hourStr + sb.WriteString(monthStr) + sb.WriteString("-") + dayStr := strconv.FormatInt(int64(day), 10) + if day < 10 { + sb.WriteString("0") } - timeStr := hourStr + ":" + minutesStr + ":" + secondStr + sb.WriteString(dayStr) + sb.WriteString(" ") - dayStr := fmt.Sprint(day) - if day < 10 { - dayStr = "0" + dayStr + hourStr := strconv.FormatInt(int64(hour), 10) + if hour < 10 { + sb.WriteString("0") } - monthStr := fmt.Sprint(month) - if month < 10 { - monthStr = "0" + monthStr + sb.WriteString(hourStr) + sb.WriteString(":") + minutesStr := strconv.FormatInt(int64(minutes), 10) + if minutes < 10 { + sb.WriteString("0") } - yearStr := fmt.Sprint(year) - if year >= 100 && year <= 999 { - yearStr = "0" + yearStr + sb.WriteString(minutesStr) + sb.WriteString(":") + secondStr := strconv.FormatInt(int64(second), 10) + if second < 10 { + sb.WriteString("0") } - dateStr := yearStr + "-" + monthStr + "-" + dayStr + sb.WriteString(secondStr) - datetimeRet, _ := time.Parse("2006-01-02 15:04:05", dateStr+" "+timeStr) + datetimeRet, _ := time.Parse("2006-01-02 15:04:05", sb.String()) if frac { datetimeRet = datetimeRet.Add(1 * time.Second) @@ -178,20 +190,20 @@ func (a castDatetimeFunc) DatetimeOutput(year, month, day, hour, minutes, second func (a castDatetimeFunc) splitDatetimeWithSep(timeArgs string) (hour, minutes, second int) { timeLen := len(timeArgs) - timeSecondStr := string(timeArgs[timeLen-1 : timeLen]) - timeLeft := string(timeArgs[0 : timeLen-2]) + timeSecondStr := timeArgs[timeLen-1 : timeLen] + timeLeft := timeArgs[0 : timeLen-2] if castDate.IsDigitalValid(string(timeArgs[timeLen-2])) { - timeSecondStr = string(timeArgs[timeLen-2 : timeLen]) - timeLeft = string(timeArgs[0 : timeLen-3]) + timeSecondStr = timeArgs[timeLen-2 : timeLen] + timeLeft = timeArgs[0 : timeLen-3] } timeArgs = timeLeft timeLen = len(timeArgs) - timeMinutesStr := string(timeArgs[timeLen-1 : timeLen]) - timeLeft = string(timeArgs[0 : timeLen-2]) + timeMinutesStr := timeArgs[timeLen-1 : timeLen] + timeLeft = timeArgs[0 : timeLen-2] if castDate.IsDigitalValid(string(timeArgs[timeLen-2])) { - timeMinutesStr = string(timeArgs[timeLen-2 : timeLen]) - timeLeft = string(timeArgs[0 : timeLen-3]) + timeMinutesStr = timeArgs[timeLen-2 : timeLen] + timeLeft = timeArgs[0 : timeLen-3] } timeHourStr := timeLeft @@ -204,13 +216,13 @@ func (a castDatetimeFunc) splitDatetimeWithSep(timeArgs string) (hour, minutes, func (a castDatetimeFunc) splitDatetimeWithoutSep(timeArgs string) (hour, minutes, second int) { timeLen := len(timeArgs) - timeSecondStr := string(timeArgs[timeLen-2 : timeLen]) - timeLeft := string(timeArgs[0 : timeLen-2]) + timeSecondStr := timeArgs[timeLen-2 : timeLen] + timeLeft := timeArgs[0 : timeLen-2] timeArgs = timeLeft timeLen = len(timeArgs) - timeMinutesStr := string(timeArgs[timeLen-2 : timeLen]) - timeLeft = string(timeArgs[0 : timeLen-2]) + timeMinutesStr := timeArgs[timeLen-2 : timeLen] + timeLeft = timeArgs[0 : timeLen-2] timeHourStr := timeLeft timeHour, _ := strconv.Atoi(timeHourStr) diff --git a/pkg/runtime/function/cast_time.go b/pkg/runtime/function/cast_time.go index b6487b6d..6553bf28 100644 --- a/pkg/runtime/function/cast_time.go +++ b/pkg/runtime/function/cast_time.go @@ -36,6 +36,10 @@ import ( // FuncCastTime is https://dev.mysql.com/doc/refman/5.6/en/cast-functions.html#function_cast const FuncCastTime = "CAST_TIME" +var _timeMatchDay = regexp.MustCompile(`^\d{1,2} \d{1,3}(:\d{1,2}){0,2}$`) +var _timeMatchString = regexp.MustCompile(`^\d{1,3}(:\d{1,2}){1,2}$`) +var _timeMatchInt = regexp.MustCompile(`^\d{1,7}$`) + var _ proto.Func = (*castTimeFunc)(nil) func init() { @@ -60,10 +64,10 @@ func (a castTimeFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto. nega := false if strings.Compare(string(timeArgs[0]), "-") == 0 { nega = true - timeArgs = string(timeArgs[1:]) + timeArgs = timeArgs[1:] } if strings.Compare(string(timeArgs[0]), "+") == 0 { - timeArgs = string(timeArgs[1:]) + timeArgs = timeArgs[1:] } // fractional seconds @@ -81,9 +85,8 @@ func (a castTimeFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto. if strings.Contains(timeArgs, " ") { // format - D hhh:mm:ss.ms,D hhh:mm.ms,D hhh.ms - pat := "^\\d{1,2} \\d{1,3}(:\\d{1,2}){0,2}$" - match, err := regexp.MatchString(pat, timeArgs) - if !match || err != nil { + match := _timeMatchDay.MatchString(timeArgs) + if !match { return a.DefaultTimeValue(), nil } @@ -101,9 +104,8 @@ func (a castTimeFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto. } } else if strings.Contains(timeArgs, ":") { // format - hhh:mm:ss.ms,hhh:mm.ms - pat := "^\\d{1,3}(:\\d{1,2}){1,2}$" - match, err := regexp.MatchString(pat, timeArgs) - if !match || err != nil { + match := _timeMatchString.MatchString(timeArgs) + if !match { return a.DefaultTimeValue(), nil } @@ -118,9 +120,8 @@ func (a castTimeFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto. } } else { // format - hhhmmss.ms,mmss.ms,ss.ms - pat := "^\\d{1,7}$" - match, err := regexp.MatchString(pat, timeArgs) - if !match || err != nil { + match := _timeMatchInt.MatchString(timeArgs) + if !match { return a.DefaultTimeValue(), nil } @@ -163,24 +164,29 @@ func (a castTimeFunc) TimeOutput(hour, minutes, second int, nega, frac bool) pro } } - secondStr := fmt.Sprint(second) - if second < 10 { - secondStr = "0" + secondStr + var sb strings.Builder + if nega { + sb.WriteString("-") + } + hourStr := strconv.FormatInt(int64(hour), 10) + if hour < 10 { + sb.WriteString("0") } - minutesStr := fmt.Sprint(minutes) + sb.WriteString(hourStr) + sb.WriteString(":") + minutesStr := strconv.FormatInt(int64(minutes), 10) if minutes < 10 { - minutesStr = "0" + minutesStr + sb.WriteString("0") } - hourStr := fmt.Sprint(hour) - if hour < 10 { - hourStr = "0" + hourStr + sb.WriteString(minutesStr) + sb.WriteString(":") + secondStr := strconv.FormatInt(int64(second), 10) + if second < 10 { + sb.WriteString("0") } + sb.WriteString(secondStr) - timeStr := hourStr + ":" + minutesStr + ":" + secondStr - if nega { - timeStr = "-" + timeStr - } - return proto.NewValueString(timeStr) + return proto.NewValueString(sb.String()) } func (a castTimeFunc) splitTimeWithSep(timeArgs string) (hour, minutes, second int) { From 7a6489019c923ce646140d1f7a12e8b2ee880710 Mon Sep 17 00:00:00 2001 From: csynineyang Date: Sat, 18 Mar 2023 13:38:14 +0800 Subject: [PATCH 09/20] Support CREATE TABLE --- pkg/executor/redirect.go | 2 +- pkg/runtime/ast/ast.go | 22 +++++ pkg/runtime/ast/create_table.go | 108 ++++++++++++++++++++++ pkg/runtime/ast/proto.go | 2 + pkg/runtime/optimize/ddl/create_table.go | 94 +++++++++++++++++++ pkg/runtime/plan/ddl/create_table.go | 113 +++++++++++++++++++++++ 6 files changed, 340 insertions(+), 1 deletion(-) create mode 100644 pkg/runtime/ast/create_table.go create mode 100644 pkg/runtime/optimize/ddl/create_table.go create mode 100644 pkg/runtime/plan/ddl/create_table.go diff --git a/pkg/executor/redirect.go b/pkg/executor/redirect.go index a7269adb..65258eaf 100644 --- a/pkg/executor/redirect.go +++ b/pkg/executor/redirect.go @@ -281,7 +281,7 @@ func (executor *RedirectExecutor) doExecutorComQuery(ctx *proto.Context, act ast err = errNoDatabaseSelected } case *ast.TruncateTableStmt, *ast.DropTableStmt, *ast.ExplainStmt, *ast.DropIndexStmt, *ast.CreateIndexStmt, - *ast.AnalyzeTableStmt, *ast.OptimizeTableStmt, *ast.CheckTableStmt, *ast.RenameTableStmt: + *ast.AnalyzeTableStmt, *ast.OptimizeTableStmt, *ast.CheckTableStmt, *ast.RenameTableStmt, *ast.CreateTableStmt: res, warn, err = executeStmt(ctx, schemaless, rt) case *ast.DropTriggerStmt, *ast.SetStmt, *ast.KillStmt: res, warn, err = rt.Execute(ctx) diff --git a/pkg/runtime/ast/ast.go b/pkg/runtime/ast/ast.go index e7e67665..6f703494 100644 --- a/pkg/runtime/ast/ast.go +++ b/pkg/runtime/ast/ast.go @@ -131,6 +131,8 @@ func FromStmtNode(node ast.StmtNode) (Statement, error) { return cc.convOptimizeTable(stmt), nil case *ast.CheckTableStmt: return cc.convCheckTableStmt(stmt), nil + case *ast.CreateTableStmt: + return cc.convCreateTableStmt(stmt), nil case *ast.RenameTableStmt: return cc.convRenameTableStmt(stmt), nil case *ast.KillStmt: @@ -1679,6 +1681,26 @@ func (cc *convCtx) convCheckTableStmt(stmt *ast.CheckTableStmt) Statement { return &CheckTableStmt{Tables: tables} } +func (cc *convCtx) convCreateTableStmt(stmt *ast.CreateTableStmt) Statement { + table := &TableName{ + stmt.Table.Name.String(), + } + var refTable *TableName + if stmt.ReferTable != nil { + refTable = &TableName{ + stmt.ReferTable.Name.String(), + } + } + + return &CreateTableStmt{ + Table: table, + ReferTable: refTable, + Cols: stmt.Cols, + Constraints: stmt.Constraints, + Options: stmt.Options, + } +} + func (cc *convCtx) convRenameTableStmt(stmt *ast.RenameTableStmt) Statement { tableToTables := make([]*TableToTable, len(stmt.TableToTables)) for i, tableToTable := range stmt.TableToTables { diff --git a/pkg/runtime/ast/create_table.go b/pkg/runtime/ast/create_table.go new file mode 100644 index 00000000..800ae841 --- /dev/null +++ b/pkg/runtime/ast/create_table.go @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ast + +import ( + "strings" +) + +import ( + "github.com/arana-db/parser/ast" + "github.com/arana-db/parser/format" + + "github.com/pkg/errors" +) + +var _ Statement = (*CreateTableStmt)(nil) + +type CreateTableStmt struct { + //IfNotExists bool + //TemporaryKeyword + // Meanless when TemporaryKeyword is not TemporaryGlobal. + // ON COMMIT DELETE ROWS => true + // ON COMMIT PRESERVE ROW => false + //OnCommitDelete bool + Table *TableName + ReferTable *TableName + Cols []*ast.ColumnDef + Constraints []*ast.Constraint + Options []*ast.TableOption + //Partition *PartitionOptions + //OnDuplicate OnDuplicateKeyHandlingType + //Select ResultSetNode +} + +func NewCreateTableStmt() *CreateTableStmt { + return &CreateTableStmt{} +} + +func (c *CreateTableStmt) CntParams() int { + return 1 +} + +func (c *CreateTableStmt) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error { + sb.WriteString("CREATE TABLE ") + if err := c.Table.Restore(flag, sb, args); err != nil { + return errors.Wrapf(err, "an error occurred while restore AnalyzeTableStatement.Tables[%s]", c.Table) + } + + if c.ReferTable != nil { + sb.WriteString(" LIKE ") + if err := c.ReferTable.Restore(flag, sb, args); err != nil { + return errors.Wrapf(err, "An error occurred while splicing CreateTableStmt ReferTable") + } + } + + rsCtx := format.NewRestoreCtx(format.RestoreFlags(flag), sb) + + lenCols := len(c.Cols) + lenConstraints := len(c.Constraints) + if lenCols+lenConstraints > 0 { + sb.WriteString(" (") + for i, col := range c.Cols { + if i > 0 { + sb.WriteString(",") + } + if err := col.Restore(rsCtx); err != nil { + return errors.Wrapf(err, "An error occurred while splicing CreateTableStmt ColumnDef: [%v]", i) + } + } + for i, constraint := range c.Constraints { + if i > 0 || lenCols >= 1 { + sb.WriteString(",") + } + if err := constraint.Restore(rsCtx); err != nil { + return errors.Wrapf(err, "An error occurred while splicing CreateTableStmt Constraints: [%v]", i) + } + } + sb.WriteString(")") + } + + for i, option := range c.Options { + sb.WriteString(" ") + if err := option.Restore(rsCtx); err != nil { + return errors.Wrapf(err, "An error occurred while splicing CreateTableStmt TableOption: [%v]", i) + } + } + + return nil +} + +func (c *CreateTableStmt) Mode() SQLType { + return SQLTypeCreateTable +} diff --git a/pkg/runtime/ast/proto.go b/pkg/runtime/ast/proto.go index 68bed1ae..f26e9c5a 100644 --- a/pkg/runtime/ast/proto.go +++ b/pkg/runtime/ast/proto.go @@ -60,6 +60,7 @@ const ( SQLTypeKill // KILL SQLTypeCheckTable // CHECK TABLE SQLTypeRenameTable // RENAME TABLE + SQLTypeCreateTable // CREATE TABLE ) var _sqlTypeNames = [...]string{ @@ -95,6 +96,7 @@ var _sqlTypeNames = [...]string{ SQLTypeKill: "KILL", SQLTypeCheckTable: "CHECK TABLE", SQLTypeRenameTable: "RENAME TABLE", + SQLTypeCreateTable: "CREATE TABLE", } // SQLType represents the type of SQL. diff --git a/pkg/runtime/optimize/ddl/create_table.go b/pkg/runtime/optimize/ddl/create_table.go new file mode 100644 index 00000000..b47c28e2 --- /dev/null +++ b/pkg/runtime/optimize/ddl/create_table.go @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ddl + +import ( + "context" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/optimize" + "github.com/arana-db/arana/pkg/runtime/plan/ddl" + "github.com/arana-db/arana/pkg/runtime/plan/dml" + "github.com/arana-db/arana/pkg/util/log" +) + +func init() { + optimize.Register(ast.SQLTypeCreateTable, optimizeCreateTable) +} + +func optimizeCreateTable(ctx context.Context, o *optimize.Optimizer) (proto.Plan, error) { + stmt := o.Stmt.(*ast.CreateTableStmt) + + var ( + shards rule.DatabaseTables + fullScan bool + ) + //if len(o.Hints) > 0 { + // if shards, err = optimize.Hints(*stmt.Table, o.Hints, o.Rule); err != nil { + // return nil, errors.Wrap(err, "calculate hints failed") + // } + //} + vt, ok := o.Rule.VTable(stmt.Table.Suffix()) //TODO + fullScan = ok + + log.Debugf("compute shards: result=%s, isFullScan=%v", shards, fullScan) + // return error if full-scan is disabled + ////if fullScan && (!vt.AllowFullScan() && !hint.Contains(hint.TypeFullScan, o.Hints)) { + // return nil, errors.WithStack(optimize.ErrDenyFullScan) + //} + + toSingle := func(db, tbl string) (proto.Plan, error) { + ret := &ddl.CreateTablePlan{ + Stmt: stmt, + Database: db, + Tables: []string{tbl}, + } + ret.BindArgs(o.Args) + + return ret, nil + } + + // Go through first table if not full scan. + if !fullScan { + return toSingle("", stmt.Table.Suffix()) + } + + // expand all shards if all shards matched + shards = vt.Topology().Enumerate() + + plans := make([]proto.Plan, 0, len(shards)) + for k, v := range shards { + next := &ddl.CreateTablePlan{ + Database: k, + Tables: v, + Stmt: stmt, + } + next.BindArgs(o.Args) + plans = append(plans, next) + } + + tmpPlan := &dml.CompositePlan{ + Plans: plans, + } + + return tmpPlan, nil +} diff --git a/pkg/runtime/plan/ddl/create_table.go b/pkg/runtime/plan/ddl/create_table.go new file mode 100644 index 00000000..b563eb7e --- /dev/null +++ b/pkg/runtime/plan/ddl/create_table.go @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ddl + +import ( + "context" + "strings" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/resultx" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" +) + +type CreateTablePlan struct { + plan.BasePlan + Stmt *ast.CreateTableStmt + Database string + Tables []string +} + +func NewCreateTablePlan( + stmt *ast.CreateTableStmt, + db string, + tb []string, +) *CreateTablePlan { + return &CreateTablePlan{ + Stmt: stmt, + Database: db, + Tables: tb, + } +} + +// Type get plan type +func (c *CreateTablePlan) Type() proto.PlanType { + return proto.PlanTypeExec +} + +func (c *CreateTablePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) { + var ( + sb strings.Builder + args []int + err error + ) + + ctx, span := plan.Tracer.Start(ctx, "CreateTable.ExecIn") + defer span.End() + + switch len(c.Tables) { + case 0: + // no table reset + return resultx.New(), nil + case 1: + // single shard table + if err := c.Stmt.Restore(ast.RestoreDefault, &sb, &args); err != nil { + return nil, err + } + if _, err = conn.Query(ctx, c.Database, sb.String(), c.ToArgs(args)...); err != nil { + return nil, err + } + default: + // multiple shard tables + stmt := new(ast.CreateTableStmt) + *stmt = *c.Stmt // do copy + + restore := func(table string) error { + sb.Reset() + if err = c.resetTable(stmt, table); err != nil { + return err + } + if err = stmt.Restore(ast.RestoreDefault, &sb, &args); err != nil { + return err + } + if _, err = conn.Query(ctx, c.Database, sb.String(), c.ToArgs(args)...); err != nil { + return err + } + return nil + } + + for i := 0; i < len(c.Tables); i++ { + if err := restore(c.Tables[i]); err != nil { + return nil, err + } + } + } + + return resultx.New(), nil +} + +func (c *CreateTablePlan) resetTable(stmt *ast.CreateTableStmt, table string) error { + stmt.Table = &ast.TableName{ + table, + } + + return nil +} From d02656f252a04c59d60b282955c3b42d8e1b8ccf Mon Sep 17 00:00:00 2001 From: csynineyang Date: Wed, 22 Mar 2023 17:18:15 +0800 Subject: [PATCH 10/20] add: IfNotExists --- pkg/runtime/ast/ast.go | 1 + pkg/runtime/ast/create_table.go | 9 ++++----- pkg/runtime/optimize/ddl/create_table.go | 9 --------- 3 files changed, 5 insertions(+), 14 deletions(-) diff --git a/pkg/runtime/ast/ast.go b/pkg/runtime/ast/ast.go index 72b6d367..92eb6b11 100644 --- a/pkg/runtime/ast/ast.go +++ b/pkg/runtime/ast/ast.go @@ -1695,6 +1695,7 @@ func (cc *convCtx) convCreateTableStmt(stmt *ast.CreateTableStmt) Statement { } return &CreateTableStmt{ + IfNotExists: stmt.IfNotExists, Table: table, ReferTable: refTable, Cols: stmt.Cols, diff --git a/pkg/runtime/ast/create_table.go b/pkg/runtime/ast/create_table.go index 800ae841..a8572dab 100644 --- a/pkg/runtime/ast/create_table.go +++ b/pkg/runtime/ast/create_table.go @@ -31,7 +31,7 @@ import ( var _ Statement = (*CreateTableStmt)(nil) type CreateTableStmt struct { - //IfNotExists bool + IfNotExists bool //TemporaryKeyword // Meanless when TemporaryKeyword is not TemporaryGlobal. // ON COMMIT DELETE ROWS => true @@ -51,12 +51,11 @@ func NewCreateTableStmt() *CreateTableStmt { return &CreateTableStmt{} } -func (c *CreateTableStmt) CntParams() int { - return 1 -} - func (c *CreateTableStmt) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error { sb.WriteString("CREATE TABLE ") + if c.IfNotExists { + sb.WriteString(" IF NOT EXISTS ") + } if err := c.Table.Restore(flag, sb, args); err != nil { return errors.Wrapf(err, "an error occurred while restore AnalyzeTableStatement.Tables[%s]", c.Table) } diff --git a/pkg/runtime/optimize/ddl/create_table.go b/pkg/runtime/optimize/ddl/create_table.go index b47c28e2..793b6edb 100644 --- a/pkg/runtime/optimize/ddl/create_table.go +++ b/pkg/runtime/optimize/ddl/create_table.go @@ -42,19 +42,10 @@ func optimizeCreateTable(ctx context.Context, o *optimize.Optimizer) (proto.Plan shards rule.DatabaseTables fullScan bool ) - //if len(o.Hints) > 0 { - // if shards, err = optimize.Hints(*stmt.Table, o.Hints, o.Rule); err != nil { - // return nil, errors.Wrap(err, "calculate hints failed") - // } - //} vt, ok := o.Rule.VTable(stmt.Table.Suffix()) //TODO fullScan = ok log.Debugf("compute shards: result=%s, isFullScan=%v", shards, fullScan) - // return error if full-scan is disabled - ////if fullScan && (!vt.AllowFullScan() && !hint.Contains(hint.TypeFullScan, o.Hints)) { - // return nil, errors.WithStack(optimize.ErrDenyFullScan) - //} toSingle := func(db, tbl string) (proto.Plan, error) { ret := &ddl.CreateTablePlan{ From 002330ae54425061bb4798122160c85487509c6c Mon Sep 17 00:00:00 2001 From: csynineyang Date: Wed, 22 Mar 2023 17:24:59 +0800 Subject: [PATCH 11/20] fix: reformat imports --- pkg/runtime/ast/create_table.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pkg/runtime/ast/create_table.go b/pkg/runtime/ast/create_table.go index a8572dab..2e438799 100644 --- a/pkg/runtime/ast/create_table.go +++ b/pkg/runtime/ast/create_table.go @@ -21,11 +21,13 @@ import ( "strings" ) +import ( + "github.com/pkg/errors" +) + import ( "github.com/arana-db/parser/ast" "github.com/arana-db/parser/format" - - "github.com/pkg/errors" ) var _ Statement = (*CreateTableStmt)(nil) From 4d9bf3fd5c64d0215f5b17817b4072d7874dd50f Mon Sep 17 00:00:00 2001 From: csynineyang Date: Mon, 17 Apr 2023 14:54:59 +0800 Subject: [PATCH 12/20] Resolve Conversation --- pkg/runtime/ast/create_table.go | 2 +- pkg/runtime/optimize/ddl/create_table.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/runtime/ast/create_table.go b/pkg/runtime/ast/create_table.go index 2e438799..e1e131e7 100644 --- a/pkg/runtime/ast/create_table.go +++ b/pkg/runtime/ast/create_table.go @@ -59,7 +59,7 @@ func (c *CreateTableStmt) Restore(flag RestoreFlag, sb *strings.Builder, args *[ sb.WriteString(" IF NOT EXISTS ") } if err := c.Table.Restore(flag, sb, args); err != nil { - return errors.Wrapf(err, "an error occurred while restore AnalyzeTableStatement.Tables[%s]", c.Table) + return errors.Wrapf(err, "An error occurred while restore AnalyzeTableStatement.Tables[%s]", c.Table) } if c.ReferTable != nil { diff --git a/pkg/runtime/optimize/ddl/create_table.go b/pkg/runtime/optimize/ddl/create_table.go index 793b6edb..81e1ef97 100644 --- a/pkg/runtime/optimize/ddl/create_table.go +++ b/pkg/runtime/optimize/ddl/create_table.go @@ -42,7 +42,7 @@ func optimizeCreateTable(ctx context.Context, o *optimize.Optimizer) (proto.Plan shards rule.DatabaseTables fullScan bool ) - vt, ok := o.Rule.VTable(stmt.Table.Suffix()) //TODO + vt, ok := o.Rule.VTable(stmt.Table.Suffix()) fullScan = ok log.Debugf("compute shards: result=%s, isFullScan=%v", shards, fullScan) From 742659fbb64bd4628165d0d461ffc21db257747e Mon Sep 17 00:00:00 2001 From: csynineyang Date: Thu, 18 May 2023 09:12:53 +0800 Subject: [PATCH 13/20] Support window function: CUME_DIST --- pkg/runtime/function/cume_dist.go | 69 ++++++++++++++ pkg/runtime/function/cume_dist_test.go | 127 +++++++++++++++++++++++++ 2 files changed, 196 insertions(+) create mode 100644 pkg/runtime/function/cume_dist.go create mode 100644 pkg/runtime/function/cume_dist_test.go diff --git a/pkg/runtime/function/cume_dist.go b/pkg/runtime/function/cume_dist.go new file mode 100644 index 00000000..604777d8 --- /dev/null +++ b/pkg/runtime/function/cume_dist.go @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "context" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +// FuncCumeDist is https://dev.mysql.com/doc/refman/8.0/en/window-function-descriptions.html +const FuncCumeDist = "CUME_DIST" + +var _ proto.Func = (*cumedistFunc)(nil) + +func init() { + proto.RegisterFunc(FuncCumeDist, cumedistFunc{}) +} + +type cumedistFunc struct{} + +func (a cumedistFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto.Value, error) { + first, err := inputs[0].Value(ctx) + if first == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncCumeDist) + } + firstDec, _ := first.Float64() + firstNum := 0 + + for _, it := range inputs[1:] { + val, err := it.Value(ctx) + if val == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncCumeDist) + } + + valDec, _ := val.Float64() + if valDec <= firstDec { + firstNum++ + } + } + + r := float64(firstNum) / float64(len(inputs)-1) + return proto.NewValueFloat64(r), nil +} + +func (a cumedistFunc) NumInput() int { + return 0 +} diff --git a/pkg/runtime/function/cume_dist_test.go b/pkg/runtime/function/cume_dist_test.go new file mode 100644 index 00000000..9de2cfe8 --- /dev/null +++ b/pkg/runtime/function/cume_dist_test.go @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "context" + "fmt" + "testing" +) + +import ( + "github.com/stretchr/testify/assert" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +func TestFuncCumeDist(t *testing.T) { + fn := proto.MustGetFunc(FuncCumeDist) + type tt struct { + inputs []proto.Value + want string + } + for _, v := range []tt{ + { + []proto.Value{ + proto.NewValueInt64(1), + proto.NewValueInt64(1), + proto.NewValueInt64(1), + proto.NewValueInt64(2), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(4), + proto.NewValueInt64(4), + proto.NewValueInt64(5), + }, + "0.2222222222222222", + }, + { + []proto.Value{ + proto.NewValueInt64(2), + proto.NewValueInt64(1), + proto.NewValueInt64(1), + proto.NewValueInt64(2), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(4), + proto.NewValueInt64(4), + proto.NewValueInt64(5), + }, + "0.3333333333333333", + }, + { + []proto.Value{ + proto.NewValueInt64(3), + proto.NewValueInt64(1), + proto.NewValueInt64(1), + proto.NewValueInt64(2), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(4), + proto.NewValueInt64(4), + proto.NewValueInt64(5), + }, + "0.6666666666666666", + }, + { + []proto.Value{ + proto.NewValueInt64(4), + proto.NewValueInt64(1), + proto.NewValueInt64(1), + proto.NewValueInt64(2), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(4), + proto.NewValueInt64(4), + proto.NewValueInt64(5), + }, + "0.8888888888888888", + }, + { + []proto.Value{ + proto.NewValueInt64(5), + proto.NewValueInt64(1), + proto.NewValueInt64(1), + proto.NewValueInt64(2), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(4), + proto.NewValueInt64(4), + proto.NewValueInt64(5), + }, + "1", + }, + } { + t.Run(v.want, func(t *testing.T) { + var inputs []proto.Valuer + for i := range v.inputs { + inputs = append(inputs, proto.ToValuer(v.inputs[i])) + } + out, err := fn.Apply(context.Background(), inputs...) + assert.NoError(t, err) + assert.Equal(t, v.want, fmt.Sprint(out)) + }) + } +} From c090c4f9a3acf613541ec01dcf7a58572967f9cc Mon Sep 17 00:00:00 2001 From: csynineyang Date: Sun, 21 May 2023 11:27:54 +0800 Subject: [PATCH 14/20] Support window function: PERCENT_RANK --- pkg/runtime/function/cume_dist.go | 7 +- pkg/runtime/function/percent_rank.go | 72 ++++++++++++ pkg/runtime/function/percent_rank_test.go | 127 ++++++++++++++++++++++ 3 files changed, 204 insertions(+), 2 deletions(-) create mode 100644 pkg/runtime/function/percent_rank.go create mode 100644 pkg/runtime/function/percent_rank_test.go diff --git a/pkg/runtime/function/cume_dist.go b/pkg/runtime/function/cume_dist.go index 604777d8..b07920a4 100644 --- a/pkg/runtime/function/cume_dist.go +++ b/pkg/runtime/function/cume_dist.go @@ -53,14 +53,17 @@ func (a cumedistFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto. if val == nil || err != nil { return nil, errors.Wrapf(err, "cannot eval %s", FuncCumeDist) } - valDec, _ := val.Float64() + if valDec <= firstDec { firstNum++ } } - r := float64(firstNum) / float64(len(inputs)-1) + r := 0.0 + if len(inputs) > 1 { + r = float64(firstNum) / float64(len(inputs)-1) + } return proto.NewValueFloat64(r), nil } diff --git a/pkg/runtime/function/percent_rank.go b/pkg/runtime/function/percent_rank.go new file mode 100644 index 00000000..9b5973ed --- /dev/null +++ b/pkg/runtime/function/percent_rank.go @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "context" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +// FuncPercentRank is https://dev.mysql.com/doc/refman/8.0/en/window-function-descriptions.html +const FuncPercentRank = "PERCENT_RANK" + +var _ proto.Func = (*percentrankFunc)(nil) + +func init() { + proto.RegisterFunc(FuncPercentRank, percentrankFunc{}) +} + +type percentrankFunc struct{} + +func (a percentrankFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto.Value, error) { + first, err := inputs[0].Value(ctx) + if first == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncCumeDist) + } + firstDec, _ := first.Float64() + firstNum := 0 + + for _, it := range inputs[1:] { + val, err := it.Value(ctx) + if val == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncPercentRank) + } + valDec, _ := val.Float64() + + if valDec < firstDec { + firstNum++ + } + } + + r := 0.0 + if len(inputs) > 2 { + r = float64(firstNum) / float64(len(inputs)-2) + } + return proto.NewValueFloat64(r), nil +} + +func (a percentrankFunc) NumInput() int { + return 0 +} diff --git a/pkg/runtime/function/percent_rank_test.go b/pkg/runtime/function/percent_rank_test.go new file mode 100644 index 00000000..64a36767 --- /dev/null +++ b/pkg/runtime/function/percent_rank_test.go @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "context" + "fmt" + "testing" +) + +import ( + "github.com/stretchr/testify/assert" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +func TestPercentRankDist(t *testing.T) { + fn := proto.MustGetFunc(FuncPercentRank) + type tt struct { + inputs []proto.Value + want string + } + for _, v := range []tt{ + { + []proto.Value{ + proto.NewValueInt64(1), + proto.NewValueInt64(1), + proto.NewValueInt64(1), + proto.NewValueInt64(2), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(4), + proto.NewValueInt64(4), + proto.NewValueInt64(5), + }, + "0", + }, + { + []proto.Value{ + proto.NewValueInt64(2), + proto.NewValueInt64(1), + proto.NewValueInt64(1), + proto.NewValueInt64(2), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(4), + proto.NewValueInt64(4), + proto.NewValueInt64(5), + }, + "0.25", + }, + { + []proto.Value{ + proto.NewValueInt64(3), + proto.NewValueInt64(1), + proto.NewValueInt64(1), + proto.NewValueInt64(2), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(4), + proto.NewValueInt64(4), + proto.NewValueInt64(5), + }, + "0.375", + }, + { + []proto.Value{ + proto.NewValueInt64(4), + proto.NewValueInt64(1), + proto.NewValueInt64(1), + proto.NewValueInt64(2), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(4), + proto.NewValueInt64(4), + proto.NewValueInt64(5), + }, + "0.75", + }, + { + []proto.Value{ + proto.NewValueInt64(5), + proto.NewValueInt64(1), + proto.NewValueInt64(1), + proto.NewValueInt64(2), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(4), + proto.NewValueInt64(4), + proto.NewValueInt64(5), + }, + "1", + }, + } { + t.Run(v.want, func(t *testing.T) { + var inputs []proto.Valuer + for i := range v.inputs { + inputs = append(inputs, proto.ToValuer(v.inputs[i])) + } + out, err := fn.Apply(context.Background(), inputs...) + assert.NoError(t, err) + assert.Equal(t, v.want, fmt.Sprint(out)) + }) + } +} From 80124ffd5486546c75aa3cacd933ed6c7a7eace9 Mon Sep 17 00:00:00 2001 From: csynineyang Date: Sun, 21 May 2023 11:45:07 +0800 Subject: [PATCH 15/20] Support window function: RANK --- pkg/runtime/function/rank.go | 68 ++++++++++++++++ pkg/runtime/function/rank_test.go | 127 ++++++++++++++++++++++++++++++ 2 files changed, 195 insertions(+) create mode 100644 pkg/runtime/function/rank.go create mode 100644 pkg/runtime/function/rank_test.go diff --git a/pkg/runtime/function/rank.go b/pkg/runtime/function/rank.go new file mode 100644 index 00000000..ecb303d0 --- /dev/null +++ b/pkg/runtime/function/rank.go @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "context" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +// FuncRank is https://dev.mysql.com/doc/refman/8.0/en/window-function-descriptions.html +const FuncRank = "RANK" + +var _ proto.Func = (*rankFunc)(nil) + +func init() { + proto.RegisterFunc(FuncRank, rankFunc{}) +} + +type rankFunc struct{} + +func (a rankFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto.Value, error) { + first, err := inputs[0].Value(ctx) + if first == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncCumeDist) + } + firstDec, _ := first.Float64() + firstNum := 0 + + for _, it := range inputs[1:] { + val, err := it.Value(ctx) + if val == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncCumeDist) + } + valDec, _ := val.Float64() + + if valDec < firstDec { + firstNum++ + } + } + + return proto.NewValueInt64(int64(firstNum) + 1), nil +} + +func (a rankFunc) NumInput() int { + return 0 +} diff --git a/pkg/runtime/function/rank_test.go b/pkg/runtime/function/rank_test.go new file mode 100644 index 00000000..da7607a6 --- /dev/null +++ b/pkg/runtime/function/rank_test.go @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "context" + "fmt" + "testing" +) + +import ( + "github.com/stretchr/testify/assert" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +func TestFuncRank(t *testing.T) { + fn := proto.MustGetFunc(FuncRank) + type tt struct { + inputs []proto.Value + want string + } + for _, v := range []tt{ + { + []proto.Value{ + proto.NewValueInt64(1), + proto.NewValueInt64(1), + proto.NewValueInt64(1), + proto.NewValueInt64(2), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(4), + proto.NewValueInt64(4), + proto.NewValueInt64(5), + }, + "1", + }, + { + []proto.Value{ + proto.NewValueInt64(2), + proto.NewValueInt64(1), + proto.NewValueInt64(1), + proto.NewValueInt64(2), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(4), + proto.NewValueInt64(4), + proto.NewValueInt64(5), + }, + "3", + }, + { + []proto.Value{ + proto.NewValueInt64(3), + proto.NewValueInt64(1), + proto.NewValueInt64(1), + proto.NewValueInt64(2), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(4), + proto.NewValueInt64(4), + proto.NewValueInt64(5), + }, + "4", + }, + { + []proto.Value{ + proto.NewValueInt64(4), + proto.NewValueInt64(1), + proto.NewValueInt64(1), + proto.NewValueInt64(2), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(4), + proto.NewValueInt64(4), + proto.NewValueInt64(5), + }, + "7", + }, + { + []proto.Value{ + proto.NewValueInt64(5), + proto.NewValueInt64(1), + proto.NewValueInt64(1), + proto.NewValueInt64(2), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(4), + proto.NewValueInt64(4), + proto.NewValueInt64(5), + }, + "9", + }, + } { + t.Run(v.want, func(t *testing.T) { + var inputs []proto.Valuer + for i := range v.inputs { + inputs = append(inputs, proto.ToValuer(v.inputs[i])) + } + out, err := fn.Apply(context.Background(), inputs...) + assert.NoError(t, err) + assert.Equal(t, v.want, fmt.Sprint(out)) + }) + } +} From 3dd886412fde19b1b836c390bd69689193523e57 Mon Sep 17 00:00:00 2001 From: csynineyang Date: Sun, 21 May 2023 12:09:25 +0800 Subject: [PATCH 16/20] Support window function: DENSE_RANK --- pkg/runtime/function/dense_rank.go | 70 +++++++++++++ pkg/runtime/function/dense_rank_test.go | 127 ++++++++++++++++++++++++ 2 files changed, 197 insertions(+) create mode 100644 pkg/runtime/function/dense_rank.go create mode 100644 pkg/runtime/function/dense_rank_test.go diff --git a/pkg/runtime/function/dense_rank.go b/pkg/runtime/function/dense_rank.go new file mode 100644 index 00000000..e441968a --- /dev/null +++ b/pkg/runtime/function/dense_rank.go @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "context" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +// FuncDenseRank is https://dev.mysql.com/doc/refman/8.0/en/window-function-descriptions.html +const FuncDenseRank = "DENSE_RANK" + +var _ proto.Func = (*denserankFunc)(nil) + +func init() { + proto.RegisterFunc(FuncDenseRank, denserankFunc{}) +} + +type denserankFunc struct{} + +func (a denserankFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto.Value, error) { + first, err := inputs[0].Value(ctx) + if first == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncCumeDist) + } + firstDec, _ := first.Float64() + secondDec := firstDec + firstNum := 0 + + for _, it := range inputs[1:] { + val, err := it.Value(ctx) + if val == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncCumeDist) + } + valDec, _ := val.Float64() + + if valDec < firstDec && valDec != secondDec { + firstNum++ + secondDec = valDec + } + } + + return proto.NewValueInt64(int64(firstNum) + 1), nil +} + +func (a denserankFunc) NumInput() int { + return 0 +} diff --git a/pkg/runtime/function/dense_rank_test.go b/pkg/runtime/function/dense_rank_test.go new file mode 100644 index 00000000..1daa818e --- /dev/null +++ b/pkg/runtime/function/dense_rank_test.go @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "context" + "fmt" + "testing" +) + +import ( + "github.com/stretchr/testify/assert" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +func TestFuncDenseRankt(t *testing.T) { + fn := proto.MustGetFunc(FuncDenseRank) + type tt struct { + inputs []proto.Value + want string + } + for _, v := range []tt{ + { + []proto.Value{ + proto.NewValueInt64(1), + proto.NewValueInt64(1), + proto.NewValueInt64(1), + proto.NewValueInt64(2), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(4), + proto.NewValueInt64(4), + proto.NewValueInt64(5), + }, + "1", + }, + { + []proto.Value{ + proto.NewValueInt64(2), + proto.NewValueInt64(1), + proto.NewValueInt64(1), + proto.NewValueInt64(2), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(4), + proto.NewValueInt64(4), + proto.NewValueInt64(5), + }, + "2", + }, + { + []proto.Value{ + proto.NewValueInt64(3), + proto.NewValueInt64(1), + proto.NewValueInt64(1), + proto.NewValueInt64(2), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(4), + proto.NewValueInt64(4), + proto.NewValueInt64(5), + }, + "3", + }, + { + []proto.Value{ + proto.NewValueInt64(4), + proto.NewValueInt64(1), + proto.NewValueInt64(1), + proto.NewValueInt64(2), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(4), + proto.NewValueInt64(4), + proto.NewValueInt64(5), + }, + "4", + }, + { + []proto.Value{ + proto.NewValueInt64(5), + proto.NewValueInt64(1), + proto.NewValueInt64(1), + proto.NewValueInt64(2), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(3), + proto.NewValueInt64(4), + proto.NewValueInt64(4), + proto.NewValueInt64(5), + }, + "5", + }, + } { + t.Run(v.want, func(t *testing.T) { + var inputs []proto.Valuer + for i := range v.inputs { + inputs = append(inputs, proto.ToValuer(v.inputs[i])) + } + out, err := fn.Apply(context.Background(), inputs...) + assert.NoError(t, err) + assert.Equal(t, v.want, fmt.Sprint(out)) + }) + } +} From 74c5d3094d9f5e07e7f2aaa9d09a756b86819633 Mon Sep 17 00:00:00 2001 From: csynineyang Date: Sat, 10 Jun 2023 16:06:53 +0800 Subject: [PATCH 17/20] Support window function: FIRST_VALUE/LAST_VALUE/LAG/LEAD --- pkg/runtime/function/first_value.go | 95 +++++++++++ pkg/runtime/function/first_value_test.go | 198 +++++++++++++++++++++++ pkg/runtime/function/lag.go | 117 ++++++++++++++ pkg/runtime/function/lag_test.go | 152 +++++++++++++++++ pkg/runtime/function/last_value.go | 108 +++++++++++++ pkg/runtime/function/last_value_test.go | 198 +++++++++++++++++++++++ pkg/runtime/function/lead.go | 117 ++++++++++++++ pkg/runtime/function/lead_test.go | 152 +++++++++++++++++ 8 files changed, 1137 insertions(+) create mode 100644 pkg/runtime/function/first_value.go create mode 100644 pkg/runtime/function/first_value_test.go create mode 100644 pkg/runtime/function/lag.go create mode 100644 pkg/runtime/function/lag_test.go create mode 100644 pkg/runtime/function/last_value.go create mode 100644 pkg/runtime/function/last_value_test.go create mode 100644 pkg/runtime/function/lead.go create mode 100644 pkg/runtime/function/lead_test.go diff --git a/pkg/runtime/function/first_value.go b/pkg/runtime/function/first_value.go new file mode 100644 index 00000000..be011d32 --- /dev/null +++ b/pkg/runtime/function/first_value.go @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "context" + "strings" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +// FuncFirstValue is https://dev.mysql.com/doc/refman/8.0/en/window-function-descriptions.html +const FuncFirstValue = "FIRST_VALUE" + +var _ proto.Func = (*firstvalueFunc)(nil) + +func init() { + proto.RegisterFunc(FuncFirstValue, firstvalueFunc{}) +} + +type firstvalueFunc struct{} + +func (a firstvalueFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto.Value, error) { + if len(inputs) < 3 { + return proto.NewValueString(""), nil + } + + // partition by this column + firstPartitionColumn, err := inputs[1].Value(ctx) + if firstPartitionColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncFirstValue) + } + firstPartitionColumnStr := firstPartitionColumn.String() + // output by this volumn + firstValueColumn, err := inputs[2].Value(ctx) + if firstValueColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncFirstValue) + } + firstValueColumnDec, _ := firstValueColumn.Float64() + firstValue := 0.0 + startOffset := 3 + + if len(inputs) < 6 { + return proto.NewValueFloat64(firstValueColumnDec), nil + } + + for { + partitionColumn, err := inputs[startOffset+1].Value(ctx) + if partitionColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncFirstValue) + } + partitionColumnStr := partitionColumn.String() + valueColumn, err := inputs[startOffset+2].Value(ctx) + if valueColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncFirstValue) + } + valueColumnDec, _ := valueColumn.Float64() + if strings.Compare(firstPartitionColumnStr, partitionColumnStr) == 0 { + firstValue = valueColumnDec + break + } + + startOffset += 3 + if startOffset >= len(inputs) { + break + } + } + + return proto.NewValueFloat64(firstValue), nil +} + +func (a firstvalueFunc) NumInput() int { + return 1 +} diff --git a/pkg/runtime/function/first_value_test.go b/pkg/runtime/function/first_value_test.go new file mode 100644 index 00000000..5473e7ab --- /dev/null +++ b/pkg/runtime/function/first_value_test.go @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "context" + "fmt" + "testing" +) + +import ( + "github.com/stretchr/testify/assert" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +func TestFuncFirstValue(t *testing.T) { + fn := proto.MustGetFunc(FuncFirstValue) + type tt struct { + inputs [][]proto.Value + want string + } + for _, v := range []tt{ + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("07:00:00"), proto.NewValueString("st113"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("st113"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("st113"), proto.NewValueFloat64(9)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("st113"), proto.NewValueFloat64(25)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("st113"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(5)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(30)}, + {proto.NewValueString("08:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(25)}, + }, + "10", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("07:15:00"), proto.NewValueString("st113"), proto.NewValueFloat64(9)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("st113"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("st113"), proto.NewValueFloat64(9)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("st113"), proto.NewValueFloat64(25)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("st113"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(5)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(30)}, + {proto.NewValueString("08:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(25)}, + }, + "10", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("07:30:00"), proto.NewValueString("st113"), proto.NewValueFloat64(25)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("st113"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("st113"), proto.NewValueFloat64(9)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("st113"), proto.NewValueFloat64(25)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("st113"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(5)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(30)}, + {proto.NewValueString("08:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(25)}, + }, + "10", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("07:45:00"), proto.NewValueString("st113"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("st113"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("st113"), proto.NewValueFloat64(9)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("st113"), proto.NewValueFloat64(25)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("st113"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(5)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(30)}, + {proto.NewValueString("08:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(25)}, + }, + "10", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("07:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("st113"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("st113"), proto.NewValueFloat64(9)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("st113"), proto.NewValueFloat64(25)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("st113"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(5)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(30)}, + {proto.NewValueString("08:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(25)}, + }, + "0", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("07:15:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("st113"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("st113"), proto.NewValueFloat64(9)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("st113"), proto.NewValueFloat64(25)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("st113"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(5)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(30)}, + {proto.NewValueString("08:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(25)}, + }, + "0", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("07:30:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(5)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("st113"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("st113"), proto.NewValueFloat64(9)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("st113"), proto.NewValueFloat64(25)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("st113"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(5)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(30)}, + {proto.NewValueString("08:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(25)}, + }, + "0", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("07:45:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(30)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("st113"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("st113"), proto.NewValueFloat64(9)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("st113"), proto.NewValueFloat64(25)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("st113"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(5)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(30)}, + {proto.NewValueString("08:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(25)}, + }, + "0", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("08:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(25)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("st113"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("st113"), proto.NewValueFloat64(9)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("st113"), proto.NewValueFloat64(25)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("st113"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(5)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(30)}, + {proto.NewValueString("08:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(25)}, + }, + "0", + }, + } { + t.Run(v.want, func(t *testing.T) { + var inputs []proto.Valuer + for i := range v.inputs { + for j := range v.inputs[i] { + inputs = append(inputs, proto.ToValuer(v.inputs[i][j])) + } + } + out, err := fn.Apply(context.Background(), inputs...) + assert.NoError(t, err) + assert.Equal(t, v.want, fmt.Sprint(out)) + }) + } +} diff --git a/pkg/runtime/function/lag.go b/pkg/runtime/function/lag.go new file mode 100644 index 00000000..f882c170 --- /dev/null +++ b/pkg/runtime/function/lag.go @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "context" + "strings" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +// FuncLag is https://dev.mysql.com/doc/refman/8.0/en/window-function-descriptions.html +const FuncLag = "LAG" + +var _ proto.Func = (*lagFunc)(nil) + +func init() { + proto.RegisterFunc(FuncLag, lagFunc{}) +} + +type lagFunc struct{} + +func (a lagFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto.Value, error) { + if len(inputs) < 6 { + return proto.NewValueString(""), nil + } + + // order by this column + firstOrderColumn, err := inputs[0].Value(ctx) + if firstOrderColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncLag) + } + firstOrderColumnStr := firstOrderColumn.String() + // partition by this column + firstPartitionColumn, err := inputs[1].Value(ctx) + if firstPartitionColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncLag) + } + firstPartitionColumnStr := firstPartitionColumn.String() + // output by this volumn + firstValueColumn, err := inputs[2].Value(ctx) + if firstValueColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncLag) + } + firstValueColumnDec, _ := firstValueColumn.Float64() + lagValue := 0.0 + lagIndex := -1 + startOffset := 3 + + for { + orderColumn, err := inputs[startOffset].Value(ctx) + if orderColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncLag) + } + orderColumnStr := orderColumn.String() + partitionColumn, err := inputs[startOffset+1].Value(ctx) + if partitionColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncLag) + } + partitionColumnStr := partitionColumn.String() + valueColumn, err := inputs[startOffset+2].Value(ctx) + if valueColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncLag) + } + valueColumnDec, _ := valueColumn.Float64() + if strings.Compare(firstOrderColumnStr, orderColumnStr) == 0 && + strings.Compare(firstPartitionColumnStr, partitionColumnStr) == 0 && + firstValueColumnDec == valueColumnDec { + if startOffset > 3 { + lagValueColumn, err := inputs[startOffset-1].Value(ctx) + if lagValueColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncLag) + } + lagValueColumnDec, _ := lagValueColumn.Float64() + lagValue = lagValueColumnDec + lagIndex = startOffset - 1 + } + break + } + + startOffset += 3 + if startOffset >= len(inputs) { + break + } + } + + if lagIndex < 0 { + return proto.NewValueString(""), nil + } else { + return proto.NewValueFloat64(lagValue), nil + } +} + +func (a lagFunc) NumInput() int { + return 1 +} diff --git a/pkg/runtime/function/lag_test.go b/pkg/runtime/function/lag_test.go new file mode 100644 index 00000000..1bd006f8 --- /dev/null +++ b/pkg/runtime/function/lag_test.go @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "context" + "fmt" + "testing" +) + +import ( + "github.com/stretchr/testify/assert" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +func TestFuncLag(t *testing.T) { + fn := proto.MustGetFunc(FuncLag) + type tt struct { + inputs [][]proto.Value + want string + } + for _, v := range []tt{ + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, + {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, + {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, + {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, + {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, + {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, + {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, + {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, + }, + "", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, + {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, + {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, + {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, + {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, + {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, + {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, + {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, + }, + "100", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, + {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, + {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, + {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, + {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, + {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, + {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, + {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, + }, + "125", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, + {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, + {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, + {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, + {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, + {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, + {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, + {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, + }, + "132", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, + {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, + {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, + {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, + {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, + {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, + {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, + {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, + }, + "145", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, + {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, + {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, + {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, + {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, + {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, + {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, + {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, + }, + "140", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, + {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, + {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, + {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, + {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, + {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, + {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, + {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, + }, + "150", + }, + } { + t.Run(v.want, func(t *testing.T) { + var inputs []proto.Valuer + for i := range v.inputs { + for j := range v.inputs[i] { + inputs = append(inputs, proto.ToValuer(v.inputs[i][j])) + } + } + out, err := fn.Apply(context.Background(), inputs...) + assert.NoError(t, err) + assert.Equal(t, v.want, fmt.Sprint(out)) + }) + } +} diff --git a/pkg/runtime/function/last_value.go b/pkg/runtime/function/last_value.go new file mode 100644 index 00000000..30a26a22 --- /dev/null +++ b/pkg/runtime/function/last_value.go @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "context" + "strings" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +// FuncLastValue is https://dev.mysql.com/doc/refman/8.0/en/window-function-descriptions.html +const FuncLastValue = "LAST_VALUE" + +var _ proto.Func = (*lastvalueFunc)(nil) + +func init() { + proto.RegisterFunc(FuncLastValue, lastvalueFunc{}) +} + +type lastvalueFunc struct{} + +func (a lastvalueFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto.Value, error) { + if len(inputs) < 3 { + return proto.NewValueString(""), nil + } + + // order by this column + firstOrderColumn, err := inputs[0].Value(ctx) + if firstOrderColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncLastValue) + } + firstOrderColumnStr := firstOrderColumn.String() + // partition by this column + firstPartitionColumn, err := inputs[1].Value(ctx) + if firstPartitionColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncLastValue) + } + firstPartitionColumnStr := firstPartitionColumn.String() + // output by this volumn + firstValueColumn, err := inputs[2].Value(ctx) + if firstValueColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncLastValue) + } + firstValueColumnDec, _ := firstValueColumn.Float64() + lastValue := 0.0 + startOffset := 3 + + if len(inputs) < 6 { + return proto.NewValueFloat64(firstValueColumnDec), nil + } + + for { + orderColumn, err := inputs[startOffset].Value(ctx) + if orderColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncLastValue) + } + orderColumnStr := orderColumn.String() + partitionColumn, err := inputs[startOffset+1].Value(ctx) + if partitionColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncLastValue) + } + partitionColumnStr := partitionColumn.String() + valueColumn, err := inputs[startOffset+2].Value(ctx) + if valueColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncLastValue) + } + valueColumnDec, _ := valueColumn.Float64() + if strings.Compare(firstOrderColumnStr, orderColumnStr) == 0 && + strings.Compare(firstPartitionColumnStr, partitionColumnStr) == 0 && + firstValueColumnDec == valueColumnDec { + lastValue = valueColumnDec + break + } + + startOffset += 3 + if startOffset >= len(inputs) { + break + } + } + + return proto.NewValueFloat64(lastValue), nil +} + +func (a lastvalueFunc) NumInput() int { + return 1 +} diff --git a/pkg/runtime/function/last_value_test.go b/pkg/runtime/function/last_value_test.go new file mode 100644 index 00000000..4ba6c567 --- /dev/null +++ b/pkg/runtime/function/last_value_test.go @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "context" + "fmt" + "testing" +) + +import ( + "github.com/stretchr/testify/assert" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +func TestFuncLastValue(t *testing.T) { + fn := proto.MustGetFunc(FuncLastValue) + type tt struct { + inputs [][]proto.Value + want string + } + for _, v := range []tt{ + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("07:00:00"), proto.NewValueString("st113"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("st113"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("st113"), proto.NewValueFloat64(9)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("st113"), proto.NewValueFloat64(25)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("st113"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(5)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(30)}, + {proto.NewValueString("08:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(25)}, + }, + "10", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("07:15:00"), proto.NewValueString("st113"), proto.NewValueFloat64(9)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("st113"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("st113"), proto.NewValueFloat64(9)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("st113"), proto.NewValueFloat64(25)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("st113"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(5)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(30)}, + {proto.NewValueString("08:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(25)}, + }, + "9", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("07:30:00"), proto.NewValueString("st113"), proto.NewValueFloat64(25)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("st113"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("st113"), proto.NewValueFloat64(9)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("st113"), proto.NewValueFloat64(25)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("st113"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(5)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(30)}, + {proto.NewValueString("08:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(25)}, + }, + "25", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("07:45:00"), proto.NewValueString("st113"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("st113"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("st113"), proto.NewValueFloat64(9)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("st113"), proto.NewValueFloat64(25)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("st113"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(5)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(30)}, + {proto.NewValueString("08:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(25)}, + }, + "0", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("07:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("st113"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("st113"), proto.NewValueFloat64(9)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("st113"), proto.NewValueFloat64(25)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("st113"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(5)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(30)}, + {proto.NewValueString("08:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(25)}, + }, + "0", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("07:15:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("st113"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("st113"), proto.NewValueFloat64(9)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("st113"), proto.NewValueFloat64(25)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("st113"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(5)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(30)}, + {proto.NewValueString("08:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(25)}, + }, + "10", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("07:30:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(5)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("st113"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("st113"), proto.NewValueFloat64(9)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("st113"), proto.NewValueFloat64(25)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("st113"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(5)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(30)}, + {proto.NewValueString("08:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(25)}, + }, + "5", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("07:45:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(30)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("st113"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("st113"), proto.NewValueFloat64(9)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("st113"), proto.NewValueFloat64(25)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("st113"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(5)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(30)}, + {proto.NewValueString("08:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(25)}, + }, + "30", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("08:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(25)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("st113"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("st113"), proto.NewValueFloat64(9)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("st113"), proto.NewValueFloat64(25)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("st113"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(5)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(30)}, + {proto.NewValueString("08:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(25)}, + }, + "25", + }, + } { + t.Run(v.want, func(t *testing.T) { + var inputs []proto.Valuer + for i := range v.inputs { + for j := range v.inputs[i] { + inputs = append(inputs, proto.ToValuer(v.inputs[i][j])) + } + } + out, err := fn.Apply(context.Background(), inputs...) + assert.NoError(t, err) + assert.Equal(t, v.want, fmt.Sprint(out)) + }) + } +} diff --git a/pkg/runtime/function/lead.go b/pkg/runtime/function/lead.go new file mode 100644 index 00000000..14da35f4 --- /dev/null +++ b/pkg/runtime/function/lead.go @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "context" + "strings" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +// FuncLead is https://dev.mysql.com/doc/refman/8.0/en/window-function-descriptions.html +const FuncLead = "LEAD" + +var _ proto.Func = (*leadFunc)(nil) + +func init() { + proto.RegisterFunc(FuncLead, leadFunc{}) +} + +type leadFunc struct{} + +func (a leadFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto.Value, error) { + if len(inputs) < 6 { + return proto.NewValueString(""), nil + } + + // order by this column + firstOrderColumn, err := inputs[0].Value(ctx) + if firstOrderColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncLead) + } + firstOrderColumnStr := firstOrderColumn.String() + // partition by this column + firstPartitionColumn, err := inputs[1].Value(ctx) + if firstPartitionColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncLead) + } + firstPartitionColumnStr := firstPartitionColumn.String() + // output by this volumn + firstValueColumn, err := inputs[2].Value(ctx) + if firstValueColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncLead) + } + firstValueColumnDec, _ := firstValueColumn.Float64() + lagValue := 0.0 + lagIndex := -1 + startOffset := 3 + + for { + orderColumn, err := inputs[startOffset].Value(ctx) + if orderColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncLead) + } + orderColumnStr := orderColumn.String() + partitionColumn, err := inputs[startOffset+1].Value(ctx) + if partitionColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncLead) + } + partitionColumnStr := partitionColumn.String() + valueColumn, err := inputs[startOffset+2].Value(ctx) + if valueColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncLead) + } + valueColumnDec, _ := valueColumn.Float64() + if strings.Compare(firstOrderColumnStr, orderColumnStr) == 0 && + strings.Compare(firstPartitionColumnStr, partitionColumnStr) == 0 && + firstValueColumnDec == valueColumnDec { + if startOffset+6 <= len(inputs) { + lagValueColumn, err := inputs[startOffset+5].Value(ctx) + if lagValueColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncLead) + } + lagValueColumnDec, _ := lagValueColumn.Float64() + lagValue = lagValueColumnDec + lagIndex = startOffset - 1 + } + break + } + + startOffset += 3 + if startOffset >= len(inputs) { + break + } + } + + if lagIndex < 0 { + return proto.NewValueString(""), nil + } else { + return proto.NewValueFloat64(lagValue), nil + } +} + +func (a leadFunc) NumInput() int { + return 1 +} diff --git a/pkg/runtime/function/lead_test.go b/pkg/runtime/function/lead_test.go new file mode 100644 index 00000000..53655189 --- /dev/null +++ b/pkg/runtime/function/lead_test.go @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "context" + "fmt" + "testing" +) + +import ( + "github.com/stretchr/testify/assert" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +func TestFuncLead(t *testing.T) { + fn := proto.MustGetFunc(FuncLead) + type tt struct { + inputs [][]proto.Value + want string + } + for _, v := range []tt{ + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, + {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, + {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, + {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, + {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, + {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, + {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, + {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, + }, + "125", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, + {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, + {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, + {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, + {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, + {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, + {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, + {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, + }, + "132", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, + {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, + {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, + {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, + {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, + {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, + {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, + {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, + }, + "145", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, + {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, + {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, + {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, + {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, + {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, + {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, + {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, + }, + "140", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, + {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, + {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, + {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, + {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, + {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, + {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, + {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, + }, + "150", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, + {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, + {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, + {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, + {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, + {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, + {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, + {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, + }, + "200", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, + {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, + {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, + {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, + {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, + {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, + {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, + {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, + }, + "", + }, + } { + t.Run(v.want, func(t *testing.T) { + var inputs []proto.Valuer + for i := range v.inputs { + for j := range v.inputs[i] { + inputs = append(inputs, proto.ToValuer(v.inputs[i][j])) + } + } + out, err := fn.Apply(context.Background(), inputs...) + assert.NoError(t, err) + assert.Equal(t, v.want, fmt.Sprint(out)) + }) + } +} From 9eb15eeb2234d423cf7b96c6e86a9f1add69daec Mon Sep 17 00:00:00 2001 From: csynineyang Date: Wed, 14 Jun 2023 11:22:41 +0800 Subject: [PATCH 18/20] Support window function: NTH_VALUE/NTILE/ROW_NUMBER --- pkg/runtime/function/nth_value.go | 123 ++++++++++++++ pkg/runtime/function/nth_value_test.go | 207 ++++++++++++++++++++++++ pkg/runtime/function/ntile.go | 124 ++++++++++++++ pkg/runtime/function/ntile_test.go | 207 ++++++++++++++++++++++++ pkg/runtime/function/row_number.go | 109 +++++++++++++ pkg/runtime/function/row_number_test.go | 152 +++++++++++++++++ 6 files changed, 922 insertions(+) create mode 100644 pkg/runtime/function/nth_value.go create mode 100644 pkg/runtime/function/nth_value_test.go create mode 100644 pkg/runtime/function/ntile.go create mode 100644 pkg/runtime/function/ntile_test.go create mode 100644 pkg/runtime/function/row_number.go create mode 100644 pkg/runtime/function/row_number_test.go diff --git a/pkg/runtime/function/nth_value.go b/pkg/runtime/function/nth_value.go new file mode 100644 index 00000000..f9f2a8f4 --- /dev/null +++ b/pkg/runtime/function/nth_value.go @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "context" + "strings" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +// FuncNthValue is https://dev.mysql.com/doc/refman/8.0/en/window-function-descriptions.html +const FuncNthValue = "NTH_VALUE" + +var _ proto.Func = (*nthvalueFunc)(nil) + +func init() { + proto.RegisterFunc(FuncNthValue, nthvalueFunc{}) +} + +type nthvalueFunc struct{} + +func (a nthvalueFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto.Value, error) { + if len(inputs) < 7 { + return proto.NewValueString(""), nil + } + + // nth + nth, err := inputs[0].Value(ctx) + if nth == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncNtile) + } + nthInt, _ := nth.Int64() + // order by this column + firstOrderColumn, err := inputs[1].Value(ctx) + if firstOrderColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncNtile) + } + firstOrderColumnStr := firstOrderColumn.String() + // partition by this column + firstPartitionColumn, err := inputs[2].Value(ctx) + if firstPartitionColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncFirstValue) + } + firstPartitionColumnStr := firstPartitionColumn.String() + // output by this volumn + firstValueColumn, err := inputs[3].Value(ctx) + if firstValueColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncNtile) + } + firstValueColumnDec, _ := firstValueColumn.Float64() + nthIndex := int64(0) + nthValue := 0.0 + startOffset := 4 + + for { + orderColumn, err := inputs[startOffset].Value(ctx) + if orderColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncNtile) + } + orderColumnStr := orderColumn.String() + partitionColumn, err := inputs[startOffset+1].Value(ctx) + if partitionColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncFirstValue) + } + partitionColumnStr := partitionColumn.String() + valueColumn, err := inputs[startOffset+2].Value(ctx) + if valueColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncFirstValue) + } + valueColumnDec, _ := valueColumn.Float64() + + if strings.Compare(firstPartitionColumnStr, partitionColumnStr) == 0 { + nthIndex += 1 + if nthIndex == nthInt { + nthValue = valueColumnDec + break + } + } + + if strings.Compare(firstOrderColumnStr, orderColumnStr) == 0 && + strings.Compare(firstPartitionColumnStr, partitionColumnStr) == 0 && + firstValueColumnDec == valueColumnDec { + break + } + + startOffset += 3 + if startOffset >= len(inputs) { + break + } + } + + if nthIndex < nthInt { + return proto.NewValueString(""), nil + } else { + return proto.NewValueFloat64(nthValue), nil + } +} + +func (a nthvalueFunc) NumInput() int { + return 1 +} diff --git a/pkg/runtime/function/nth_value_test.go b/pkg/runtime/function/nth_value_test.go new file mode 100644 index 00000000..60a895fd --- /dev/null +++ b/pkg/runtime/function/nth_value_test.go @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "context" + "fmt" + "testing" +) + +import ( + "github.com/stretchr/testify/assert" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +func TestFuncNthValue(t *testing.T) { + fn := proto.MustGetFunc(FuncNthValue) + type tt struct { + inputs [][]proto.Value + want string + } + for _, v := range []tt{ + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueInt64(4)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("st113"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("st113"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("st113"), proto.NewValueFloat64(9)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("st113"), proto.NewValueFloat64(25)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("st113"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(5)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(30)}, + {proto.NewValueString("08:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(25)}, + }, + "", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueInt64(4)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("st113"), proto.NewValueFloat64(9)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("st113"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("st113"), proto.NewValueFloat64(9)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("st113"), proto.NewValueFloat64(25)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("st113"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(5)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(30)}, + {proto.NewValueString("08:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(25)}, + }, + "", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueInt64(4)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("st113"), proto.NewValueFloat64(25)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("st113"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("st113"), proto.NewValueFloat64(9)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("st113"), proto.NewValueFloat64(25)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("st113"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(5)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(30)}, + {proto.NewValueString("08:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(25)}, + }, + "", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueInt64(4)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("st113"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("st113"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("st113"), proto.NewValueFloat64(9)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("st113"), proto.NewValueFloat64(25)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("st113"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(5)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(30)}, + {proto.NewValueString("08:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(25)}, + }, + "0", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueInt64(4)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("st113"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("st113"), proto.NewValueFloat64(9)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("st113"), proto.NewValueFloat64(25)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("st113"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(5)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(30)}, + {proto.NewValueString("08:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(25)}, + }, + "", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueInt64(4)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("st113"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("st113"), proto.NewValueFloat64(9)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("st113"), proto.NewValueFloat64(25)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("st113"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(5)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(30)}, + {proto.NewValueString("08:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(25)}, + }, + "", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueInt64(4)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(5)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("st113"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("st113"), proto.NewValueFloat64(9)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("st113"), proto.NewValueFloat64(25)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("st113"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(5)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(30)}, + {proto.NewValueString("08:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(25)}, + }, + "", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueInt64(4)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(30)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("st113"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("st113"), proto.NewValueFloat64(9)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("st113"), proto.NewValueFloat64(25)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("st113"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(5)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(30)}, + {proto.NewValueString("08:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(25)}, + }, + "30", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueInt64(4)}, + {proto.NewValueString("08:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(25)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("st113"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("st113"), proto.NewValueFloat64(9)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("st113"), proto.NewValueFloat64(25)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("st113"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(0)}, + {proto.NewValueString("07:15:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(10)}, + {proto.NewValueString("07:30:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(5)}, + {proto.NewValueString("07:45:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(30)}, + {proto.NewValueString("08:00:00"), proto.NewValueString("xh458"), proto.NewValueFloat64(25)}, + }, + "30", + }, + } { + t.Run(v.want, func(t *testing.T) { + var inputs []proto.Valuer + for i := range v.inputs { + for j := range v.inputs[i] { + inputs = append(inputs, proto.ToValuer(v.inputs[i][j])) + } + } + out, err := fn.Apply(context.Background(), inputs...) + assert.NoError(t, err) + assert.Equal(t, v.want, fmt.Sprint(out)) + }) + } +} diff --git a/pkg/runtime/function/ntile.go b/pkg/runtime/function/ntile.go new file mode 100644 index 00000000..e13ae0ce --- /dev/null +++ b/pkg/runtime/function/ntile.go @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "context" + "strings" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +// FuncNtile is https://dev.mysql.com/doc/refman/8.0/en/window-function-descriptions.html +const FuncNtile = "NTILE" + +var _ proto.Func = (*ntileFunc)(nil) + +func init() { + proto.RegisterFunc(FuncNtile, ntileFunc{}) +} + +type ntileFunc struct{} + +func (a ntileFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto.Value, error) { + if len(inputs) < 7 { + return proto.NewValueString(""), nil + } + + // bucket number + bucketNum, err := inputs[0].Value(ctx) + if bucketNum == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncNtile) + } + bucketNumInt, _ := bucketNum.Int64() + // order by this column + firstOrderColumn, err := inputs[1].Value(ctx) + if firstOrderColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncNtile) + } + firstOrderColumnStr := firstOrderColumn.String() + // partition by this column + firstPartitionColumn, err := inputs[2].Value(ctx) + if firstPartitionColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncNtile) + } + firstPartitionColumnStr := firstPartitionColumn.String() + // output by this volumn + firstValueColumn, err := inputs[3].Value(ctx) + if firstValueColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncNtile) + } + firstValueColumnDec, _ := firstValueColumn.Float64() + startOffset := 4 + bucketSeq := int64(1) + bucketIndex := int64(0) + bucketLeft := int64(0) + bucketDiv := int64((len(inputs)-4)/3) / bucketNumInt + bucketMod := int64((len(inputs)-4)/3) % bucketNumInt + + for { + orderColumn, err := inputs[startOffset].Value(ctx) + if orderColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncNtile) + } + orderColumnStr := orderColumn.String() + partitionColumn, err := inputs[startOffset+1].Value(ctx) + if partitionColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncNtile) + } + partitionColumnStr := partitionColumn.String() + valueColumn, err := inputs[startOffset+2].Value(ctx) + if valueColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncNtile) + } + valueColumnDec, _ := valueColumn.Float64() + + bucketIndex += 1 + if bucketIndex > bucketDiv { + if bucketIndex == bucketDiv+1 && bucketLeft < bucketMod { + bucketLeft += 1 + } else { + bucketIndex = int64(1) + bucketSeq += 1 + } + } + + if strings.Compare(firstOrderColumnStr, orderColumnStr) == 0 && + strings.Compare(firstPartitionColumnStr, partitionColumnStr) == 0 && + firstValueColumnDec == valueColumnDec { + break + } + + startOffset += 3 + if startOffset >= len(inputs) { + break + } + } + + return proto.NewValueInt64(int64(bucketSeq)), nil +} + +func (a ntileFunc) NumInput() int { + return 1 +} diff --git a/pkg/runtime/function/ntile_test.go b/pkg/runtime/function/ntile_test.go new file mode 100644 index 00000000..94b28964 --- /dev/null +++ b/pkg/runtime/function/ntile_test.go @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "context" + "fmt" + "testing" +) + +import ( + "github.com/stretchr/testify/assert" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +func TestFuncNtile(t *testing.T) { + fn := proto.MustGetFunc(FuncNtile) + type tt struct { + inputs [][]proto.Value + want string + } + for _, v := range []tt{ + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueInt64(2)}, + {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, + {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, + {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, + {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, + {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, + {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, + {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, + {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, + {proto.NewValueString("19:00:00"), proto.NewValueString(""), proto.NewValueFloat64(220)}, + {proto.NewValueString("20:00:00"), proto.NewValueString(""), proto.NewValueFloat64(260)}, + }, + "1", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueInt64(2)}, + {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, + {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, + {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, + {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, + {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, + {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, + {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, + {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, + {proto.NewValueString("19:00:00"), proto.NewValueString(""), proto.NewValueFloat64(220)}, + {proto.NewValueString("20:00:00"), proto.NewValueString(""), proto.NewValueFloat64(260)}, + }, + "1", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueInt64(2)}, + {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, + {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, + {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, + {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, + {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, + {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, + {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, + {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, + {proto.NewValueString("19:00:00"), proto.NewValueString(""), proto.NewValueFloat64(220)}, + {proto.NewValueString("20:00:00"), proto.NewValueString(""), proto.NewValueFloat64(260)}, + }, + "1", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueInt64(2)}, + {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, + {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, + {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, + {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, + {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, + {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, + {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, + {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, + {proto.NewValueString("19:00:00"), proto.NewValueString(""), proto.NewValueFloat64(220)}, + {proto.NewValueString("20:00:00"), proto.NewValueString(""), proto.NewValueFloat64(260)}, + }, + "1", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueInt64(2)}, + {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, + {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, + {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, + {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, + {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, + {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, + {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, + {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, + {proto.NewValueString("19:00:00"), proto.NewValueString(""), proto.NewValueFloat64(220)}, + {proto.NewValueString("20:00:00"), proto.NewValueString(""), proto.NewValueFloat64(260)}, + }, + "1", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueInt64(2)}, + {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, + {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, + {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, + {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, + {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, + {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, + {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, + {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, + {proto.NewValueString("19:00:00"), proto.NewValueString(""), proto.NewValueFloat64(220)}, + {proto.NewValueString("20:00:00"), proto.NewValueString(""), proto.NewValueFloat64(260)}, + }, + "2", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueInt64(2)}, + {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, + {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, + {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, + {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, + {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, + {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, + {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, + {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, + {proto.NewValueString("19:00:00"), proto.NewValueString(""), proto.NewValueFloat64(220)}, + {proto.NewValueString("20:00:00"), proto.NewValueString(""), proto.NewValueFloat64(260)}, + }, + "2", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueInt64(2)}, + {proto.NewValueString("19:00:00"), proto.NewValueString(""), proto.NewValueFloat64(220)}, + {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, + {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, + {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, + {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, + {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, + {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, + {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, + {proto.NewValueString("19:00:00"), proto.NewValueString(""), proto.NewValueFloat64(220)}, + {proto.NewValueString("20:00:00"), proto.NewValueString(""), proto.NewValueFloat64(260)}, + }, + "2", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueInt64(2)}, + {proto.NewValueString("20:00:00"), proto.NewValueString(""), proto.NewValueFloat64(260)}, + {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, + {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, + {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, + {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, + {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, + {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, + {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, + {proto.NewValueString("19:00:00"), proto.NewValueString(""), proto.NewValueFloat64(220)}, + {proto.NewValueString("20:00:00"), proto.NewValueString(""), proto.NewValueFloat64(260)}, + }, + "2", + }, + } { + t.Run(v.want, func(t *testing.T) { + var inputs []proto.Valuer + for i := range v.inputs { + for j := range v.inputs[i] { + inputs = append(inputs, proto.ToValuer(v.inputs[i][j])) + } + } + out, err := fn.Apply(context.Background(), inputs...) + assert.NoError(t, err) + assert.Equal(t, v.want, fmt.Sprint(out)) + }) + } +} diff --git a/pkg/runtime/function/row_number.go b/pkg/runtime/function/row_number.go new file mode 100644 index 00000000..1436133b --- /dev/null +++ b/pkg/runtime/function/row_number.go @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "context" + "strings" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +// FuncRowNumber is https://dev.mysql.com/doc/refman/8.0/en/window-function-descriptions.html +const FuncRowNumber = "ROW_NUMBER" + +var _ proto.Func = (*rownumberFunc)(nil) + +func init() { + proto.RegisterFunc(FuncRowNumber, rownumberFunc{}) +} + +type rownumberFunc struct{} + +func (a rownumberFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto.Value, error) { + if len(inputs) < 6 { + return proto.NewValueString(""), nil + } + + // order by this column + firstOrderColumn, err := inputs[0].Value(ctx) + if firstOrderColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncRowNumber) + } + firstOrderColumnStr := firstOrderColumn.String() + // partition by this column + firstPartitionColumn, err := inputs[1].Value(ctx) + if firstPartitionColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncRowNumber) + } + firstPartitionColumnStr := firstPartitionColumn.String() + // output by this volumn + firstValueColumn, err := inputs[2].Value(ctx) + if firstValueColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncRowNumber) + } + firstValueColumnDec, _ := firstValueColumn.Float64() + rowNumber := 0 + startOffset := 3 + + for { + orderColumn, err := inputs[startOffset].Value(ctx) + if orderColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncRowNumber) + } + orderColumnStr := orderColumn.String() + partitionColumn, err := inputs[startOffset+1].Value(ctx) + if partitionColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncRowNumber) + } + partitionColumnStr := partitionColumn.String() + valueColumn, err := inputs[startOffset+2].Value(ctx) + if valueColumn == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncRowNumber) + } + valueColumnDec, _ := valueColumn.Float64() + + rowNumber += 1 + if strings.Compare(firstOrderColumnStr, orderColumnStr) == 0 && + strings.Compare(firstPartitionColumnStr, partitionColumnStr) == 0 && + firstValueColumnDec == valueColumnDec { + break + } + + startOffset += 3 + if startOffset >= len(inputs) { + break + } + } + + if rowNumber <= 0 { + return proto.NewValueString(""), nil + } else { + return proto.NewValueInt64(int64(rowNumber)), nil + } +} + +func (a rownumberFunc) NumInput() int { + return 1 +} diff --git a/pkg/runtime/function/row_number_test.go b/pkg/runtime/function/row_number_test.go new file mode 100644 index 00000000..9b4a38ec --- /dev/null +++ b/pkg/runtime/function/row_number_test.go @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "context" + "fmt" + "testing" +) + +import ( + "github.com/stretchr/testify/assert" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +func TestFuncRowNumber(t *testing.T) { + fn := proto.MustGetFunc(FuncRowNumber) + type tt struct { + inputs [][]proto.Value + want string + } + for _, v := range []tt{ + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, + {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, + {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, + {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, + {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, + {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, + {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, + {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, + }, + "1", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, + {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, + {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, + {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, + {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, + {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, + {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, + {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, + }, + "2", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, + {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, + {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, + {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, + {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, + {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, + {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, + {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, + }, + "3", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, + {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, + {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, + {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, + {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, + {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, + {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, + {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, + }, + "4", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, + {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, + {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, + {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, + {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, + {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, + {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, + {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, + }, + "5", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, + {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, + {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, + {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, + {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, + {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, + {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, + {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, + }, + "6", + }, + { + [][]proto.Value{ + //order column, partition column, value column + {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, + {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, + {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, + {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, + {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, + {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, + {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, + {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, + }, + "7", + }, + } { + t.Run(v.want, func(t *testing.T) { + var inputs []proto.Valuer + for i := range v.inputs { + for j := range v.inputs[i] { + inputs = append(inputs, proto.ToValuer(v.inputs[i][j])) + } + } + out, err := fn.Apply(context.Background(), inputs...) + assert.NoError(t, err) + assert.Equal(t, v.want, fmt.Sprint(out)) + }) + } +} From a6c7ea076192003353c1a311613c2b160b4d378f Mon Sep 17 00:00:00 2001 From: csynineyang Date: Wed, 14 Jun 2023 15:18:11 +0800 Subject: [PATCH 19/20] support argument(n) in LAG/LEAD --- pkg/runtime/function/first_value.go | 2 +- pkg/runtime/function/lag.go | 24 +++++++++++++++--------- pkg/runtime/function/lag_test.go | 7 +++++++ pkg/runtime/function/last_value.go | 2 +- pkg/runtime/function/lead.go | 24 +++++++++++++++--------- pkg/runtime/function/lead_test.go | 7 +++++++ pkg/runtime/function/nth_value.go | 14 +++++++------- pkg/runtime/function/ntile.go | 3 +++ pkg/runtime/function/row_number.go | 2 +- 9 files changed, 57 insertions(+), 28 deletions(-) diff --git a/pkg/runtime/function/first_value.go b/pkg/runtime/function/first_value.go index be011d32..2883ce60 100644 --- a/pkg/runtime/function/first_value.go +++ b/pkg/runtime/function/first_value.go @@ -91,5 +91,5 @@ func (a firstvalueFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (prot } func (a firstvalueFunc) NumInput() int { - return 1 + return 0 } diff --git a/pkg/runtime/function/lag.go b/pkg/runtime/function/lag.go index f882c170..39ec052d 100644 --- a/pkg/runtime/function/lag.go +++ b/pkg/runtime/function/lag.go @@ -42,31 +42,37 @@ func init() { type lagFunc struct{} func (a lagFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto.Value, error) { - if len(inputs) < 6 { + if len(inputs) < 7 { return proto.NewValueString(""), nil } + // lag number + lagNum, err := inputs[0].Value(ctx) + if lagNum == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncLag) + } + lagNumInt, _ := lagNum.Int64() // order by this column - firstOrderColumn, err := inputs[0].Value(ctx) + firstOrderColumn, err := inputs[1].Value(ctx) if firstOrderColumn == nil || err != nil { return nil, errors.Wrapf(err, "cannot eval %s", FuncLag) } firstOrderColumnStr := firstOrderColumn.String() // partition by this column - firstPartitionColumn, err := inputs[1].Value(ctx) + firstPartitionColumn, err := inputs[2].Value(ctx) if firstPartitionColumn == nil || err != nil { return nil, errors.Wrapf(err, "cannot eval %s", FuncLag) } firstPartitionColumnStr := firstPartitionColumn.String() // output by this volumn - firstValueColumn, err := inputs[2].Value(ctx) + firstValueColumn, err := inputs[3].Value(ctx) if firstValueColumn == nil || err != nil { return nil, errors.Wrapf(err, "cannot eval %s", FuncLag) } firstValueColumnDec, _ := firstValueColumn.Float64() lagValue := 0.0 - lagIndex := -1 - startOffset := 3 + lagIndex := int64(-1) + startOffset := int64(4) for { orderColumn, err := inputs[startOffset].Value(ctx) @@ -87,8 +93,8 @@ func (a lagFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto.Value if strings.Compare(firstOrderColumnStr, orderColumnStr) == 0 && strings.Compare(firstPartitionColumnStr, partitionColumnStr) == 0 && firstValueColumnDec == valueColumnDec { - if startOffset > 3 { - lagValueColumn, err := inputs[startOffset-1].Value(ctx) + if startOffset >= 4+3*lagNumInt { + lagValueColumn, err := inputs[startOffset+2-3*lagNumInt].Value(ctx) if lagValueColumn == nil || err != nil { return nil, errors.Wrapf(err, "cannot eval %s", FuncLag) } @@ -100,7 +106,7 @@ func (a lagFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto.Value } startOffset += 3 - if startOffset >= len(inputs) { + if startOffset >= int64(len(inputs)) { break } } diff --git a/pkg/runtime/function/lag_test.go b/pkg/runtime/function/lag_test.go index 1bd006f8..1315e9e4 100644 --- a/pkg/runtime/function/lag_test.go +++ b/pkg/runtime/function/lag_test.go @@ -41,6 +41,7 @@ func TestFuncLag(t *testing.T) { { [][]proto.Value{ //order column, partition column, value column + {proto.NewValueInt64(1)}, {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, @@ -55,6 +56,7 @@ func TestFuncLag(t *testing.T) { { [][]proto.Value{ //order column, partition column, value column + {proto.NewValueInt64(1)}, {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, @@ -69,6 +71,7 @@ func TestFuncLag(t *testing.T) { { [][]proto.Value{ //order column, partition column, value column + {proto.NewValueInt64(1)}, {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, @@ -83,6 +86,7 @@ func TestFuncLag(t *testing.T) { { [][]proto.Value{ //order column, partition column, value column + {proto.NewValueInt64(1)}, {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, @@ -97,6 +101,7 @@ func TestFuncLag(t *testing.T) { { [][]proto.Value{ //order column, partition column, value column + {proto.NewValueInt64(1)}, {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, @@ -111,6 +116,7 @@ func TestFuncLag(t *testing.T) { { [][]proto.Value{ //order column, partition column, value column + {proto.NewValueInt64(1)}, {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, @@ -125,6 +131,7 @@ func TestFuncLag(t *testing.T) { { [][]proto.Value{ //order column, partition column, value column + {proto.NewValueInt64(1)}, {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, diff --git a/pkg/runtime/function/last_value.go b/pkg/runtime/function/last_value.go index 30a26a22..bbbb27fa 100644 --- a/pkg/runtime/function/last_value.go +++ b/pkg/runtime/function/last_value.go @@ -104,5 +104,5 @@ func (a lastvalueFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto } func (a lastvalueFunc) NumInput() int { - return 1 + return 0 } diff --git a/pkg/runtime/function/lead.go b/pkg/runtime/function/lead.go index 14da35f4..dd5d7ac4 100644 --- a/pkg/runtime/function/lead.go +++ b/pkg/runtime/function/lead.go @@ -42,31 +42,37 @@ func init() { type leadFunc struct{} func (a leadFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto.Value, error) { - if len(inputs) < 6 { + if len(inputs) < 7 { return proto.NewValueString(""), nil } + // lag number + leadNum, err := inputs[0].Value(ctx) + if leadNum == nil || err != nil { + return nil, errors.Wrapf(err, "cannot eval %s", FuncLead) + } + leadNumInt, _ := leadNum.Int64() // order by this column - firstOrderColumn, err := inputs[0].Value(ctx) + firstOrderColumn, err := inputs[1].Value(ctx) if firstOrderColumn == nil || err != nil { return nil, errors.Wrapf(err, "cannot eval %s", FuncLead) } firstOrderColumnStr := firstOrderColumn.String() // partition by this column - firstPartitionColumn, err := inputs[1].Value(ctx) + firstPartitionColumn, err := inputs[2].Value(ctx) if firstPartitionColumn == nil || err != nil { return nil, errors.Wrapf(err, "cannot eval %s", FuncLead) } firstPartitionColumnStr := firstPartitionColumn.String() // output by this volumn - firstValueColumn, err := inputs[2].Value(ctx) + firstValueColumn, err := inputs[3].Value(ctx) if firstValueColumn == nil || err != nil { return nil, errors.Wrapf(err, "cannot eval %s", FuncLead) } firstValueColumnDec, _ := firstValueColumn.Float64() lagValue := 0.0 - lagIndex := -1 - startOffset := 3 + lagIndex := int64(-1) + startOffset := int64(4) for { orderColumn, err := inputs[startOffset].Value(ctx) @@ -87,8 +93,8 @@ func (a leadFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto.Valu if strings.Compare(firstOrderColumnStr, orderColumnStr) == 0 && strings.Compare(firstPartitionColumnStr, partitionColumnStr) == 0 && firstValueColumnDec == valueColumnDec { - if startOffset+6 <= len(inputs) { - lagValueColumn, err := inputs[startOffset+5].Value(ctx) + if startOffset+2+3*leadNumInt < int64(len(inputs)) { + lagValueColumn, err := inputs[startOffset+2+3*leadNumInt].Value(ctx) if lagValueColumn == nil || err != nil { return nil, errors.Wrapf(err, "cannot eval %s", FuncLead) } @@ -100,7 +106,7 @@ func (a leadFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto.Valu } startOffset += 3 - if startOffset >= len(inputs) { + if startOffset >= int64(len(inputs)) { break } } diff --git a/pkg/runtime/function/lead_test.go b/pkg/runtime/function/lead_test.go index 53655189..16f68104 100644 --- a/pkg/runtime/function/lead_test.go +++ b/pkg/runtime/function/lead_test.go @@ -41,6 +41,7 @@ func TestFuncLead(t *testing.T) { { [][]proto.Value{ //order column, partition column, value column + {proto.NewValueInt64(1)}, {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, @@ -55,6 +56,7 @@ func TestFuncLead(t *testing.T) { { [][]proto.Value{ //order column, partition column, value column + {proto.NewValueInt64(1)}, {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, @@ -69,6 +71,7 @@ func TestFuncLead(t *testing.T) { { [][]proto.Value{ //order column, partition column, value column + {proto.NewValueInt64(1)}, {proto.NewValueString("14:00:00"), proto.NewValueString(""), proto.NewValueFloat64(132)}, {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, @@ -83,6 +86,7 @@ func TestFuncLead(t *testing.T) { { [][]proto.Value{ //order column, partition column, value column + {proto.NewValueInt64(1)}, {proto.NewValueString("15:00:00"), proto.NewValueString(""), proto.NewValueFloat64(145)}, {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, @@ -97,6 +101,7 @@ func TestFuncLead(t *testing.T) { { [][]proto.Value{ //order column, partition column, value column + {proto.NewValueInt64(1)}, {proto.NewValueString("16:00:00"), proto.NewValueString(""), proto.NewValueFloat64(140)}, {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, @@ -111,6 +116,7 @@ func TestFuncLead(t *testing.T) { { [][]proto.Value{ //order column, partition column, value column + {proto.NewValueInt64(1)}, {proto.NewValueString("17:00:00"), proto.NewValueString(""), proto.NewValueFloat64(150)}, {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, @@ -125,6 +131,7 @@ func TestFuncLead(t *testing.T) { { [][]proto.Value{ //order column, partition column, value column + {proto.NewValueInt64(1)}, {proto.NewValueString("18:00:00"), proto.NewValueString(""), proto.NewValueFloat64(200)}, {proto.NewValueString("12:00:00"), proto.NewValueString(""), proto.NewValueFloat64(100)}, {proto.NewValueString("13:00:00"), proto.NewValueString(""), proto.NewValueFloat64(125)}, diff --git a/pkg/runtime/function/nth_value.go b/pkg/runtime/function/nth_value.go index f9f2a8f4..574f9c75 100644 --- a/pkg/runtime/function/nth_value.go +++ b/pkg/runtime/function/nth_value.go @@ -49,25 +49,25 @@ func (a nthvalueFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto. // nth nth, err := inputs[0].Value(ctx) if nth == nil || err != nil { - return nil, errors.Wrapf(err, "cannot eval %s", FuncNtile) + return nil, errors.Wrapf(err, "cannot eval %s", FuncNthValue) } nthInt, _ := nth.Int64() // order by this column firstOrderColumn, err := inputs[1].Value(ctx) if firstOrderColumn == nil || err != nil { - return nil, errors.Wrapf(err, "cannot eval %s", FuncNtile) + return nil, errors.Wrapf(err, "cannot eval %s", FuncNthValue) } firstOrderColumnStr := firstOrderColumn.String() // partition by this column firstPartitionColumn, err := inputs[2].Value(ctx) if firstPartitionColumn == nil || err != nil { - return nil, errors.Wrapf(err, "cannot eval %s", FuncFirstValue) + return nil, errors.Wrapf(err, "cannot eval %s", FuncNthValue) } firstPartitionColumnStr := firstPartitionColumn.String() // output by this volumn firstValueColumn, err := inputs[3].Value(ctx) if firstValueColumn == nil || err != nil { - return nil, errors.Wrapf(err, "cannot eval %s", FuncNtile) + return nil, errors.Wrapf(err, "cannot eval %s", FuncNthValue) } firstValueColumnDec, _ := firstValueColumn.Float64() nthIndex := int64(0) @@ -77,17 +77,17 @@ func (a nthvalueFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto. for { orderColumn, err := inputs[startOffset].Value(ctx) if orderColumn == nil || err != nil { - return nil, errors.Wrapf(err, "cannot eval %s", FuncNtile) + return nil, errors.Wrapf(err, "cannot eval %s", FuncNthValue) } orderColumnStr := orderColumn.String() partitionColumn, err := inputs[startOffset+1].Value(ctx) if partitionColumn == nil || err != nil { - return nil, errors.Wrapf(err, "cannot eval %s", FuncFirstValue) + return nil, errors.Wrapf(err, "cannot eval %s", FuncNthValue) } partitionColumnStr := partitionColumn.String() valueColumn, err := inputs[startOffset+2].Value(ctx) if valueColumn == nil || err != nil { - return nil, errors.Wrapf(err, "cannot eval %s", FuncFirstValue) + return nil, errors.Wrapf(err, "cannot eval %s", FuncNthValue) } valueColumnDec, _ := valueColumn.Float64() diff --git a/pkg/runtime/function/ntile.go b/pkg/runtime/function/ntile.go index e13ae0ce..f75a3d19 100644 --- a/pkg/runtime/function/ntile.go +++ b/pkg/runtime/function/ntile.go @@ -52,6 +52,9 @@ func (a ntileFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto.Val return nil, errors.Wrapf(err, "cannot eval %s", FuncNtile) } bucketNumInt, _ := bucketNum.Int64() + if bucketNumInt <= 0 { + return proto.NewValueString(""), nil + } // order by this column firstOrderColumn, err := inputs[1].Value(ctx) if firstOrderColumn == nil || err != nil { diff --git a/pkg/runtime/function/row_number.go b/pkg/runtime/function/row_number.go index 1436133b..db45a3e0 100644 --- a/pkg/runtime/function/row_number.go +++ b/pkg/runtime/function/row_number.go @@ -105,5 +105,5 @@ func (a rownumberFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto } func (a rownumberFunc) NumInput() int { - return 1 + return 0 } From c0321c75351ef6e7689264ee85d8a79e887d0787 Mon Sep 17 00:00:00 2001 From: csynineyang Date: Wed, 14 Jun 2023 15:35:56 +0800 Subject: [PATCH 20/20] convert Int64 to Float64 in test case --- pkg/runtime/function/cume_dist_test.go | 100 +++++++++++----------- pkg/runtime/function/dense_rank.go | 4 +- pkg/runtime/function/dense_rank_test.go | 100 +++++++++++----------- pkg/runtime/function/percent_rank.go | 2 +- pkg/runtime/function/percent_rank_test.go | 100 +++++++++++----------- pkg/runtime/function/rank.go | 4 +- pkg/runtime/function/rank_test.go | 100 +++++++++++----------- 7 files changed, 205 insertions(+), 205 deletions(-) diff --git a/pkg/runtime/function/cume_dist_test.go b/pkg/runtime/function/cume_dist_test.go index 9de2cfe8..23eed956 100644 --- a/pkg/runtime/function/cume_dist_test.go +++ b/pkg/runtime/function/cume_dist_test.go @@ -40,76 +40,76 @@ func TestFuncCumeDist(t *testing.T) { for _, v := range []tt{ { []proto.Value{ - proto.NewValueInt64(1), - proto.NewValueInt64(1), - proto.NewValueInt64(1), - proto.NewValueInt64(2), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(4), - proto.NewValueInt64(4), - proto.NewValueInt64(5), + proto.NewValueFloat64(1), + proto.NewValueFloat64(1), + proto.NewValueFloat64(1), + proto.NewValueFloat64(2), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(4), + proto.NewValueFloat64(4), + proto.NewValueFloat64(5), }, "0.2222222222222222", }, { []proto.Value{ - proto.NewValueInt64(2), - proto.NewValueInt64(1), - proto.NewValueInt64(1), - proto.NewValueInt64(2), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(4), - proto.NewValueInt64(4), - proto.NewValueInt64(5), + proto.NewValueFloat64(2), + proto.NewValueFloat64(1), + proto.NewValueFloat64(1), + proto.NewValueFloat64(2), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(4), + proto.NewValueFloat64(4), + proto.NewValueFloat64(5), }, "0.3333333333333333", }, { []proto.Value{ - proto.NewValueInt64(3), - proto.NewValueInt64(1), - proto.NewValueInt64(1), - proto.NewValueInt64(2), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(4), - proto.NewValueInt64(4), - proto.NewValueInt64(5), + proto.NewValueFloat64(3), + proto.NewValueFloat64(1), + proto.NewValueFloat64(1), + proto.NewValueFloat64(2), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(4), + proto.NewValueFloat64(4), + proto.NewValueFloat64(5), }, "0.6666666666666666", }, { []proto.Value{ - proto.NewValueInt64(4), - proto.NewValueInt64(1), - proto.NewValueInt64(1), - proto.NewValueInt64(2), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(4), - proto.NewValueInt64(4), - proto.NewValueInt64(5), + proto.NewValueFloat64(4), + proto.NewValueFloat64(1), + proto.NewValueFloat64(1), + proto.NewValueFloat64(2), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(4), + proto.NewValueFloat64(4), + proto.NewValueFloat64(5), }, "0.8888888888888888", }, { []proto.Value{ - proto.NewValueInt64(5), - proto.NewValueInt64(1), - proto.NewValueInt64(1), - proto.NewValueInt64(2), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(4), - proto.NewValueInt64(4), - proto.NewValueInt64(5), + proto.NewValueFloat64(5), + proto.NewValueFloat64(1), + proto.NewValueFloat64(1), + proto.NewValueFloat64(2), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(4), + proto.NewValueFloat64(4), + proto.NewValueFloat64(5), }, "1", }, diff --git a/pkg/runtime/function/dense_rank.go b/pkg/runtime/function/dense_rank.go index e441968a..f125fdb7 100644 --- a/pkg/runtime/function/dense_rank.go +++ b/pkg/runtime/function/dense_rank.go @@ -43,7 +43,7 @@ type denserankFunc struct{} func (a denserankFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto.Value, error) { first, err := inputs[0].Value(ctx) if first == nil || err != nil { - return nil, errors.Wrapf(err, "cannot eval %s", FuncCumeDist) + return nil, errors.Wrapf(err, "cannot eval %s", FuncDenseRank) } firstDec, _ := first.Float64() secondDec := firstDec @@ -52,7 +52,7 @@ func (a denserankFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto for _, it := range inputs[1:] { val, err := it.Value(ctx) if val == nil || err != nil { - return nil, errors.Wrapf(err, "cannot eval %s", FuncCumeDist) + return nil, errors.Wrapf(err, "cannot eval %s", FuncDenseRank) } valDec, _ := val.Float64() diff --git a/pkg/runtime/function/dense_rank_test.go b/pkg/runtime/function/dense_rank_test.go index 1daa818e..646819a6 100644 --- a/pkg/runtime/function/dense_rank_test.go +++ b/pkg/runtime/function/dense_rank_test.go @@ -40,76 +40,76 @@ func TestFuncDenseRankt(t *testing.T) { for _, v := range []tt{ { []proto.Value{ - proto.NewValueInt64(1), - proto.NewValueInt64(1), - proto.NewValueInt64(1), - proto.NewValueInt64(2), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(4), - proto.NewValueInt64(4), - proto.NewValueInt64(5), + proto.NewValueFloat64(1), + proto.NewValueFloat64(1), + proto.NewValueFloat64(1), + proto.NewValueFloat64(2), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(4), + proto.NewValueFloat64(4), + proto.NewValueFloat64(5), }, "1", }, { []proto.Value{ - proto.NewValueInt64(2), - proto.NewValueInt64(1), - proto.NewValueInt64(1), - proto.NewValueInt64(2), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(4), - proto.NewValueInt64(4), - proto.NewValueInt64(5), + proto.NewValueFloat64(2), + proto.NewValueFloat64(1), + proto.NewValueFloat64(1), + proto.NewValueFloat64(2), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(4), + proto.NewValueFloat64(4), + proto.NewValueFloat64(5), }, "2", }, { []proto.Value{ - proto.NewValueInt64(3), - proto.NewValueInt64(1), - proto.NewValueInt64(1), - proto.NewValueInt64(2), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(4), - proto.NewValueInt64(4), - proto.NewValueInt64(5), + proto.NewValueFloat64(3), + proto.NewValueFloat64(1), + proto.NewValueFloat64(1), + proto.NewValueFloat64(2), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(4), + proto.NewValueFloat64(4), + proto.NewValueFloat64(5), }, "3", }, { []proto.Value{ - proto.NewValueInt64(4), - proto.NewValueInt64(1), - proto.NewValueInt64(1), - proto.NewValueInt64(2), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(4), - proto.NewValueInt64(4), - proto.NewValueInt64(5), + proto.NewValueFloat64(4), + proto.NewValueFloat64(1), + proto.NewValueFloat64(1), + proto.NewValueFloat64(2), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(4), + proto.NewValueFloat64(4), + proto.NewValueFloat64(5), }, "4", }, { []proto.Value{ - proto.NewValueInt64(5), - proto.NewValueInt64(1), - proto.NewValueInt64(1), - proto.NewValueInt64(2), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(4), - proto.NewValueInt64(4), - proto.NewValueInt64(5), + proto.NewValueFloat64(5), + proto.NewValueFloat64(1), + proto.NewValueFloat64(1), + proto.NewValueFloat64(2), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(4), + proto.NewValueFloat64(4), + proto.NewValueFloat64(5), }, "5", }, diff --git a/pkg/runtime/function/percent_rank.go b/pkg/runtime/function/percent_rank.go index 9b5973ed..eeed8411 100644 --- a/pkg/runtime/function/percent_rank.go +++ b/pkg/runtime/function/percent_rank.go @@ -43,7 +43,7 @@ type percentrankFunc struct{} func (a percentrankFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto.Value, error) { first, err := inputs[0].Value(ctx) if first == nil || err != nil { - return nil, errors.Wrapf(err, "cannot eval %s", FuncCumeDist) + return nil, errors.Wrapf(err, "cannot eval %s", FuncPercentRank) } firstDec, _ := first.Float64() firstNum := 0 diff --git a/pkg/runtime/function/percent_rank_test.go b/pkg/runtime/function/percent_rank_test.go index 64a36767..0621540c 100644 --- a/pkg/runtime/function/percent_rank_test.go +++ b/pkg/runtime/function/percent_rank_test.go @@ -40,76 +40,76 @@ func TestPercentRankDist(t *testing.T) { for _, v := range []tt{ { []proto.Value{ - proto.NewValueInt64(1), - proto.NewValueInt64(1), - proto.NewValueInt64(1), - proto.NewValueInt64(2), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(4), - proto.NewValueInt64(4), - proto.NewValueInt64(5), + proto.NewValueFloat64(1), + proto.NewValueFloat64(1), + proto.NewValueFloat64(1), + proto.NewValueFloat64(2), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(4), + proto.NewValueFloat64(4), + proto.NewValueFloat64(5), }, "0", }, { []proto.Value{ - proto.NewValueInt64(2), - proto.NewValueInt64(1), - proto.NewValueInt64(1), - proto.NewValueInt64(2), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(4), - proto.NewValueInt64(4), - proto.NewValueInt64(5), + proto.NewValueFloat64(2), + proto.NewValueFloat64(1), + proto.NewValueFloat64(1), + proto.NewValueFloat64(2), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(4), + proto.NewValueFloat64(4), + proto.NewValueFloat64(5), }, "0.25", }, { []proto.Value{ - proto.NewValueInt64(3), - proto.NewValueInt64(1), - proto.NewValueInt64(1), - proto.NewValueInt64(2), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(4), - proto.NewValueInt64(4), - proto.NewValueInt64(5), + proto.NewValueFloat64(3), + proto.NewValueFloat64(1), + proto.NewValueFloat64(1), + proto.NewValueFloat64(2), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(4), + proto.NewValueFloat64(4), + proto.NewValueFloat64(5), }, "0.375", }, { []proto.Value{ - proto.NewValueInt64(4), - proto.NewValueInt64(1), - proto.NewValueInt64(1), - proto.NewValueInt64(2), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(4), - proto.NewValueInt64(4), - proto.NewValueInt64(5), + proto.NewValueFloat64(4), + proto.NewValueFloat64(1), + proto.NewValueFloat64(1), + proto.NewValueFloat64(2), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(4), + proto.NewValueFloat64(4), + proto.NewValueFloat64(5), }, "0.75", }, { []proto.Value{ - proto.NewValueInt64(5), - proto.NewValueInt64(1), - proto.NewValueInt64(1), - proto.NewValueInt64(2), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(4), - proto.NewValueInt64(4), - proto.NewValueInt64(5), + proto.NewValueFloat64(5), + proto.NewValueFloat64(1), + proto.NewValueFloat64(1), + proto.NewValueFloat64(2), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(4), + proto.NewValueFloat64(4), + proto.NewValueFloat64(5), }, "1", }, diff --git a/pkg/runtime/function/rank.go b/pkg/runtime/function/rank.go index ecb303d0..9ff7e104 100644 --- a/pkg/runtime/function/rank.go +++ b/pkg/runtime/function/rank.go @@ -43,7 +43,7 @@ type rankFunc struct{} func (a rankFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto.Value, error) { first, err := inputs[0].Value(ctx) if first == nil || err != nil { - return nil, errors.Wrapf(err, "cannot eval %s", FuncCumeDist) + return nil, errors.Wrapf(err, "cannot eval %s", FuncRank) } firstDec, _ := first.Float64() firstNum := 0 @@ -51,7 +51,7 @@ func (a rankFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto.Valu for _, it := range inputs[1:] { val, err := it.Value(ctx) if val == nil || err != nil { - return nil, errors.Wrapf(err, "cannot eval %s", FuncCumeDist) + return nil, errors.Wrapf(err, "cannot eval %s", FuncRank) } valDec, _ := val.Float64() diff --git a/pkg/runtime/function/rank_test.go b/pkg/runtime/function/rank_test.go index da7607a6..b4befc7f 100644 --- a/pkg/runtime/function/rank_test.go +++ b/pkg/runtime/function/rank_test.go @@ -40,76 +40,76 @@ func TestFuncRank(t *testing.T) { for _, v := range []tt{ { []proto.Value{ - proto.NewValueInt64(1), - proto.NewValueInt64(1), - proto.NewValueInt64(1), - proto.NewValueInt64(2), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(4), - proto.NewValueInt64(4), - proto.NewValueInt64(5), + proto.NewValueFloat64(1), + proto.NewValueFloat64(1), + proto.NewValueFloat64(1), + proto.NewValueFloat64(2), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(4), + proto.NewValueFloat64(4), + proto.NewValueFloat64(5), }, "1", }, { []proto.Value{ - proto.NewValueInt64(2), - proto.NewValueInt64(1), - proto.NewValueInt64(1), - proto.NewValueInt64(2), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(4), - proto.NewValueInt64(4), - proto.NewValueInt64(5), + proto.NewValueFloat64(2), + proto.NewValueFloat64(1), + proto.NewValueFloat64(1), + proto.NewValueFloat64(2), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(4), + proto.NewValueFloat64(4), + proto.NewValueFloat64(5), }, "3", }, { []proto.Value{ - proto.NewValueInt64(3), - proto.NewValueInt64(1), - proto.NewValueInt64(1), - proto.NewValueInt64(2), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(4), - proto.NewValueInt64(4), - proto.NewValueInt64(5), + proto.NewValueFloat64(3), + proto.NewValueFloat64(1), + proto.NewValueFloat64(1), + proto.NewValueFloat64(2), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(4), + proto.NewValueFloat64(4), + proto.NewValueFloat64(5), }, "4", }, { []proto.Value{ - proto.NewValueInt64(4), - proto.NewValueInt64(1), - proto.NewValueInt64(1), - proto.NewValueInt64(2), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(4), - proto.NewValueInt64(4), - proto.NewValueInt64(5), + proto.NewValueFloat64(4), + proto.NewValueFloat64(1), + proto.NewValueFloat64(1), + proto.NewValueFloat64(2), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(4), + proto.NewValueFloat64(4), + proto.NewValueFloat64(5), }, "7", }, { []proto.Value{ - proto.NewValueInt64(5), - proto.NewValueInt64(1), - proto.NewValueInt64(1), - proto.NewValueInt64(2), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(3), - proto.NewValueInt64(4), - proto.NewValueInt64(4), - proto.NewValueInt64(5), + proto.NewValueFloat64(5), + proto.NewValueFloat64(1), + proto.NewValueFloat64(1), + proto.NewValueFloat64(2), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(3), + proto.NewValueFloat64(4), + proto.NewValueFloat64(4), + proto.NewValueFloat64(5), }, "9", },