From e64c78608a0cae9552e35bd4fea3efbe75e83049 Mon Sep 17 00:00:00 2001 From: Chandra Sanapala Date: Thu, 13 Feb 2025 09:56:22 +0530 Subject: [PATCH] feat: add support for NewDecimalFromApdDecimal (#119) --- expr/decimal_util.go | 21 +++++++++++++++------ expr/literals.go | 3 +++ expr/literals_test.go | 27 +++++++++++++++++++++++++++ literal/utils.go | 9 +++++++++ literal/utils_test.go | 42 ++++++++++++++++++++++++++++++++++-------- 5 files changed, 88 insertions(+), 14 deletions(-) diff --git a/expr/decimal_util.go b/expr/decimal_util.go index 524d056..53886a1 100644 --- a/expr/decimal_util.go +++ b/expr/decimal_util.go @@ -17,11 +17,7 @@ var decimalPattern = regexp.MustCompile(`^[+-]?\d*(\.\d*)?([eE][+-]?\d*)?$`) // The precision is the total number of digits in the decimal value. The precision is limited to 38 digits. // The scale is the number of digits to the right of the decimal point. The scale is limited to the precision. func DecimalStringToBytes(decimalStr string) ([16]byte, int32, int32, error) { - var ( - result [16]byte - precision int32 - scale int32 - ) + var result [16]byte strings.Trim(decimalStr, " ") if !decimalPattern.MatchString(decimalStr) { @@ -34,6 +30,19 @@ func DecimalStringToBytes(decimalStr string) ([16]byte, int32, int32, error) { return result, 0, 0, fmt.Errorf("invalid decimal string %s: %v", decimalStr, err) } + return DecimalToBytes(dec) +} + +// DecimalToBytes converts apd.Decimal to a 16-byte byte array. +// 16-byte bytes represents a little-endian 128-bit integer, to be divided by 10^Scale to get the decimal value. +// This function also returns the precision and scale of the decimal value. +func DecimalToBytes(dec *apd.Decimal) ([16]byte, int32, int32, error) { + var ( + result [16]byte + precision int32 + scale int32 + ) + if dec.Exponent > 0 { precision = int32(apd.NumDigits(&dec.Coeff)) + dec.Exponent scale = 0 @@ -42,7 +51,7 @@ func DecimalStringToBytes(decimalStr string) ([16]byte, int32, int32, error) { precision = max(int32(apd.NumDigits(&dec.Coeff)), scale+1) } if precision > 38 { - return result, precision, scale, fmt.Errorf("number %s exceeds maximum precision of 38 (%d)", decimalStr, precision) + return result, precision, scale, fmt.Errorf("number %s exceeds maximum precision of 38 (%d)", dec.String(), precision) } coefficient := dec.Coeff diff --git a/expr/literals.go b/expr/literals.go index 9212782..6d886b8 100644 --- a/expr/literals.go +++ b/expr/literals.go @@ -909,6 +909,9 @@ func NewLiteral[T allLiteralTypes](val T, nullable bool) (Literal, error) { }, }, nil case *types.Decimal: + if len(v.Value) != 16 { + return nil, fmt.Errorf("decimal value must be 16 bytes") + } return &ProtoLiteral{ Value: v.Value, Type: &types.DecimalType{ diff --git a/expr/literals_test.go b/expr/literals_test.go index 49ec7ad..59de399 100644 --- a/expr/literals_test.go +++ b/expr/literals_test.go @@ -41,6 +41,33 @@ func TestNewDecimalWithType(t *testing.T) { } } +func TestNewLiteralWithDecimalBytes(t *testing.T) { + tests := []struct { + name string + value []byte + precision int32 + scale int32 + expectedToFail bool + }{ + {"[0]byte", []byte{}, 2, 0, true}, + {"[2]byte", []byte{0x1, 0x0}, 2, 0, true}, + {"[3]byte", []byte{0x1, 0x2, 0x0}, 3, 0, true}, + {"[17]byte", []byte{0x1, 0x0, 0x0, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, 16, 0, true}, + + {"[16]byte", []byte{0x1, 0x0, 0x0, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, 16, 0, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := expr.NewLiteral[*types.Decimal](&types.Decimal{Value: tt.value, Precision: tt.precision, Scale: tt.scale}, false) + if tt.expectedToFail { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + func TestNewFixedLenWithType(t *testing.T) { tests := []struct { name string diff --git a/literal/utils.go b/literal/utils.go index fb602a5..eff2bef 100644 --- a/literal/utils.go +++ b/literal/utils.go @@ -7,6 +7,7 @@ import ( "strconv" "time" + "github.com/cockroachdb/apd/v3" "github.com/google/uuid" "github.com/substrait-io/substrait-go/v3/expr" "github.com/substrait-io/substrait-go/v3/proto" @@ -321,6 +322,14 @@ func NewDecimalFromString(value string) (expr.Literal, error) { return expr.NewLiteral[*types.Decimal](&types.Decimal{Value: v[:16], Precision: precision, Scale: scale}, false) } +func NewDecimalFromApdDecimal(value *apd.Decimal, nullable bool) (expr.Literal, error) { + v, precision, scale, err := expr.DecimalToBytes(value) + if err != nil { + return nil, err + } + return expr.NewLiteral[*types.Decimal](&types.Decimal{Value: v[:16], Precision: precision, Scale: scale}, nullable) +} + // NewPrecisionTimestampFromTime creates a new PrecisionTimestamp literal from a time.Time timestamp value with given precision. func NewPrecisionTimestampFromTime(precision types.TimePrecision, tm time.Time) (expr.Literal, error) { return NewPrecisionTimestamp(precision, types.GetTimeValueByPrecision(tm, precision)) diff --git a/literal/utils_test.go b/literal/utils_test.go index 65d774a..8e03e78 100644 --- a/literal/utils_test.go +++ b/literal/utils_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/cockroachdb/apd/v3" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -58,22 +59,36 @@ func TestNewDecimalFromString(t *testing.T) { tests := []struct { value string want expr.Literal - wantErr assert.ErrorAssertionFunc + wantErr bool }{ - {"0", createDecimalLiteral([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1, 0, false), assert.NoError}, - {"111111.222222", createDecimalLiteral([]byte{0xce, 0xb3, 0xbe, 0xde, 0x19, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 12, 6, false), assert.NoError}, - {"-111111.222222", createDecimalLiteral([]byte{0x32, 0x4c, 0x41, 0x21, 0xe6, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 12, 6, false), assert.NoError}, - {"+1", createDecimalLiteral([]byte{0x1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1, 0, false), assert.NoError}, - {"-1", createDecimalLiteral([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 1, 0, false), assert.NoError}, - {"not a decimal", nil, assert.Error}, + {"0", createDecimalLiteral([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1, 0, false), false}, + {"111111.222222", createDecimalLiteral([]byte{0xce, 0xb3, 0xbe, 0xde, 0x19, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 12, 6, false), false}, + {"-111111.222222", createDecimalLiteral([]byte{0x32, 0x4c, 0x41, 0x21, 0xe6, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 12, 6, false), false}, + {"+1", createDecimalLiteral([]byte{0x1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1, 0, false), false}, + {"-1", createDecimalLiteral([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 1, 0, false), false}, + {"not a decimal", nil, true}, } for _, tt := range tests { t.Run(tt.value, func(t *testing.T) { got, err := NewDecimalFromString(tt.value) - if !tt.wantErr(t, err, fmt.Sprintf("NewDecimalFromString(%v)", tt.value)) { + if tt.wantErr { + require.Error(t, err, fmt.Sprintf("NewDecimalFromString(%v) expected error", tt.value)) return } + require.NoError(t, err) assert.Equalf(t, tt.want, got, "NewDecimalFromString(%v)", tt.value) + + dec, _, err := apd.NewFromString(tt.value) + require.NoError(t, err) + got, err = NewDecimalFromApdDecimal(dec, false) + require.NoError(t, err) + assert.Equal(t, tt.want, got, "NewDecimalFromApdDecimal(%v)", tt.value) + + got, err = NewDecimalFromApdDecimal(dec, true) + require.NoError(t, err) + expected, err := protoLiteralWithNullability(tt.want.(*expr.ProtoLiteral), true) + require.NoError(t, err) + assert.Equal(t, expected, got, "NewDecimalFromApdDecimal(%v)", tt.value) }) } } @@ -92,6 +107,17 @@ func createDecimalLiteral(value []byte, precision int32, scale int32, isNullable }, } } + +func protoLiteralWithNullability(lit *expr.ProtoLiteral, nullable bool) (expr.Literal, error) { + nullability := proto.Type_NULLABILITY_REQUIRED + if nullable { + nullability = proto.Type_NULLABILITY_NULLABLE + } + decType := lit.GetType().WithNullability(nullability) + + return lit.WithType(decType) +} + func TestNewDecimalFromTwosComplement(t *testing.T) { type args struct { twosComplement []byte