diff --git a/pkg/runtime/function/cast_decimal.go b/pkg/runtime/function/cast_decimal.go new file mode 100644 index 00000000..8296c23f --- /dev/null +++ b/pkg/runtime/function/cast_decimal.go @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "context" + "fmt" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +// FuncCastDecimal is https://dev.mysql.com/doc/refman/5.6/en/cast-functions.html#function_cast +const FuncCastDecimal = "CAST_DECIMAL" + +var ( + _defaultDecimalPrecision proto.Value = proto.NewValueInt64(10) + _defaultDecimalScale proto.Value = proto.NewValueInt64(0) +) + +var _ proto.Func = (*castDecimalFunc)(nil) + +func init() { + proto.RegisterFunc(FuncCastDecimal, castDecimalFunc{}) +} + +type castDecimalFunc struct{} + +func (a castDecimalFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto.Value, error) { + if len(inputs) != 3 { + return nil, errors.New("The Decimal function must accept three parameters\n") + } + + val, err := inputs[0].Value(ctx) + if err != nil { + return nil, errors.WithStack(err) + } + d, err := val.Decimal() + if err != nil { + return proto.NewValueFloat64(0), nil + } + + precision, err := inputs[1].Value(ctx) + if err != nil { + return nil, errors.WithStack(err) + } + + if precision == nil { + precision = _defaultDecimalPrecision + } + + p, err := precision.Int64() + if err != nil { + return nil, errors.WithStack(err) + } + + scale, err := inputs[2].Value(ctx) + if err != nil { + return nil, errors.WithStack(err) + } + + if scale == nil { + scale = _defaultDecimalScale + } + s, err := scale.Int64() + if err != nil { + return nil, errors.WithStack(err) + } + + // M must be >= D + if p < s { + return nil, errors.WithStack(fmt.Errorf("for float(M,D), double(M,D) or decimal(M,D), M must be >= D")) + } + + return proto.NewValueString(d.StringFixed(int32(s))), nil + +} + +func (a castDecimalFunc) NumInput() int { + return 3 +} diff --git a/pkg/runtime/function/cast_decimal_test.go b/pkg/runtime/function/cast_decimal_test.go new file mode 100644 index 00000000..28ef42b6 --- /dev/null +++ b/pkg/runtime/function/cast_decimal_test.go @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package function + +import ( + "context" + "fmt" + "testing" +) + +import ( + "github.com/stretchr/testify/assert" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +func TestCastDecimal(t *testing.T) { + fn := proto.MustGetFunc(FuncCastDecimal) + assert.Equal(t, 3, fn.NumInput()) + + type tt struct { + inFirst proto.Value + inSecond proto.Value + intThird proto.Value + out string + } + + for _, it := range []tt{ + {proto.NewValueInt64(15), proto.NewValueInt64(4), proto.NewValueInt64(2), "15.00"}, + {proto.NewValueInt64(15), proto.NewValueInt64(4), proto.NewValueInt64(0), "15"}, + {proto.NewValueInt64(15), proto.NewValueInt64(4), (nil), ("15")}, + {proto.NewValueInt64(15), nil, nil, ("15")}, + {proto.NewValueInt64(15), proto.NewValueInt64(0), nil, ("15")}, + {proto.NewValueFloat64(8 / 5.0), proto.NewValueInt64(11), proto.NewValueInt64(4), "1.6000"}, + {proto.NewValueString(".885"), proto.NewValueInt64(11), proto.NewValueInt64(3), "0.885"}, + {proto.NewValueString(".885"), proto.NewValueInt64(11), proto.NewValueInt64(4), "0.8850"}, + {proto.NewValueString(".885"), proto.NewValueInt64(2), proto.NewValueInt64(1), "0.9"}, + {proto.NewValueString(".885"), proto.NewValueInt64(2), nil, "1"}, + {proto.NewValueString(".885"), proto.NewValueInt64(2), proto.NewValueInt64(0), "1"}, + {proto.NewValueString(".885"), proto.NewValueInt64(20), proto.NewValueInt64(0), "1"}, + } { + t.Run(it.out, func(t *testing.T) { + out, err := fn.Apply(context.Background(), proto.ToValuer(it.inFirst), proto.ToValuer(it.inSecond), proto.ToValuer(it.intThird)) + assert.NoError(t, err) + assert.Equal(t, it.out, fmt.Sprint(out)) + }) + } +}