Skip to content

Commit

Permalink
feat: add support for NewDecimalFromApdDecimal (#119)
Browse files Browse the repository at this point in the history
  • Loading branch information
scgkiran authored Feb 13, 2025
1 parent 261dc94 commit e64c786
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 14 deletions.
21 changes: 15 additions & 6 deletions expr/decimal_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,7 @@ var decimalPattern = regexp.MustCompile(`^[+-]?\d*(\.\d*)?([eE][+-]?\d*)?$`)
// The precision is the total number of digits in the decimal value. The precision is limited to 38 digits.
// The scale is the number of digits to the right of the decimal point. The scale is limited to the precision.
func DecimalStringToBytes(decimalStr string) ([16]byte, int32, int32, error) {
var (
result [16]byte
precision int32
scale int32
)
var result [16]byte

strings.Trim(decimalStr, " ")
if !decimalPattern.MatchString(decimalStr) {
Expand All @@ -34,6 +30,19 @@ func DecimalStringToBytes(decimalStr string) ([16]byte, int32, int32, error) {
return result, 0, 0, fmt.Errorf("invalid decimal string %s: %v", decimalStr, err)
}

return DecimalToBytes(dec)
}

// DecimalToBytes converts apd.Decimal to a 16-byte byte array.
// 16-byte bytes represents a little-endian 128-bit integer, to be divided by 10^Scale to get the decimal value.
// This function also returns the precision and scale of the decimal value.
func DecimalToBytes(dec *apd.Decimal) ([16]byte, int32, int32, error) {
var (
result [16]byte
precision int32
scale int32
)

if dec.Exponent > 0 {
precision = int32(apd.NumDigits(&dec.Coeff)) + dec.Exponent
scale = 0
Expand All @@ -42,7 +51,7 @@ func DecimalStringToBytes(decimalStr string) ([16]byte, int32, int32, error) {
precision = max(int32(apd.NumDigits(&dec.Coeff)), scale+1)
}
if precision > 38 {
return result, precision, scale, fmt.Errorf("number %s exceeds maximum precision of 38 (%d)", decimalStr, precision)
return result, precision, scale, fmt.Errorf("number %s exceeds maximum precision of 38 (%d)", dec.String(), precision)
}

coefficient := dec.Coeff
Expand Down
3 changes: 3 additions & 0 deletions expr/literals.go
Original file line number Diff line number Diff line change
Expand Up @@ -909,6 +909,9 @@ func NewLiteral[T allLiteralTypes](val T, nullable bool) (Literal, error) {
},
}, nil
case *types.Decimal:
if len(v.Value) != 16 {
return nil, fmt.Errorf("decimal value must be 16 bytes")
}
return &ProtoLiteral{
Value: v.Value,
Type: &types.DecimalType{
Expand Down
27 changes: 27 additions & 0 deletions expr/literals_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,33 @@ func TestNewDecimalWithType(t *testing.T) {
}
}

func TestNewLiteralWithDecimalBytes(t *testing.T) {
tests := []struct {
name string
value []byte
precision int32
scale int32
expectedToFail bool
}{
{"[0]byte", []byte{}, 2, 0, true},
{"[2]byte", []byte{0x1, 0x0}, 2, 0, true},
{"[3]byte", []byte{0x1, 0x2, 0x0}, 3, 0, true},
{"[17]byte", []byte{0x1, 0x0, 0x0, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, 16, 0, true},

{"[16]byte", []byte{0x1, 0x0, 0x0, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, 16, 0, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := expr.NewLiteral[*types.Decimal](&types.Decimal{Value: tt.value, Precision: tt.precision, Scale: tt.scale}, false)
if tt.expectedToFail {
require.Error(t, err)
} else {
require.NoError(t, err)
}
})
}
}

func TestNewFixedLenWithType(t *testing.T) {
tests := []struct {
name string
Expand Down
9 changes: 9 additions & 0 deletions literal/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"strconv"
"time"

"github.com/cockroachdb/apd/v3"
"github.com/google/uuid"
"github.com/substrait-io/substrait-go/v3/expr"
"github.com/substrait-io/substrait-go/v3/proto"
Expand Down Expand Up @@ -321,6 +322,14 @@ func NewDecimalFromString(value string) (expr.Literal, error) {
return expr.NewLiteral[*types.Decimal](&types.Decimal{Value: v[:16], Precision: precision, Scale: scale}, false)
}

func NewDecimalFromApdDecimal(value *apd.Decimal, nullable bool) (expr.Literal, error) {
v, precision, scale, err := expr.DecimalToBytes(value)
if err != nil {
return nil, err
}
return expr.NewLiteral[*types.Decimal](&types.Decimal{Value: v[:16], Precision: precision, Scale: scale}, nullable)
}

// NewPrecisionTimestampFromTime creates a new PrecisionTimestamp literal from a time.Time timestamp value with given precision.
func NewPrecisionTimestampFromTime(precision types.TimePrecision, tm time.Time) (expr.Literal, error) {
return NewPrecisionTimestamp(precision, types.GetTimeValueByPrecision(tm, precision))
Expand Down
42 changes: 34 additions & 8 deletions literal/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"testing"
"time"

"github.com/cockroachdb/apd/v3"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -58,22 +59,36 @@ func TestNewDecimalFromString(t *testing.T) {
tests := []struct {
value string
want expr.Literal
wantErr assert.ErrorAssertionFunc
wantErr bool
}{
{"0", createDecimalLiteral([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1, 0, false), assert.NoError},
{"111111.222222", createDecimalLiteral([]byte{0xce, 0xb3, 0xbe, 0xde, 0x19, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 12, 6, false), assert.NoError},
{"-111111.222222", createDecimalLiteral([]byte{0x32, 0x4c, 0x41, 0x21, 0xe6, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 12, 6, false), assert.NoError},
{"+1", createDecimalLiteral([]byte{0x1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1, 0, false), assert.NoError},
{"-1", createDecimalLiteral([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 1, 0, false), assert.NoError},
{"not a decimal", nil, assert.Error},
{"0", createDecimalLiteral([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1, 0, false), false},
{"111111.222222", createDecimalLiteral([]byte{0xce, 0xb3, 0xbe, 0xde, 0x19, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 12, 6, false), false},
{"-111111.222222", createDecimalLiteral([]byte{0x32, 0x4c, 0x41, 0x21, 0xe6, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 12, 6, false), false},
{"+1", createDecimalLiteral([]byte{0x1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1, 0, false), false},
{"-1", createDecimalLiteral([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 1, 0, false), false},
{"not a decimal", nil, true},
}
for _, tt := range tests {
t.Run(tt.value, func(t *testing.T) {
got, err := NewDecimalFromString(tt.value)
if !tt.wantErr(t, err, fmt.Sprintf("NewDecimalFromString(%v)", tt.value)) {
if tt.wantErr {
require.Error(t, err, fmt.Sprintf("NewDecimalFromString(%v) expected error", tt.value))
return
}
require.NoError(t, err)
assert.Equalf(t, tt.want, got, "NewDecimalFromString(%v)", tt.value)

dec, _, err := apd.NewFromString(tt.value)
require.NoError(t, err)
got, err = NewDecimalFromApdDecimal(dec, false)
require.NoError(t, err)
assert.Equal(t, tt.want, got, "NewDecimalFromApdDecimal(%v)", tt.value)

got, err = NewDecimalFromApdDecimal(dec, true)
require.NoError(t, err)
expected, err := protoLiteralWithNullability(tt.want.(*expr.ProtoLiteral), true)
require.NoError(t, err)
assert.Equal(t, expected, got, "NewDecimalFromApdDecimal(%v)", tt.value)
})
}
}
Expand All @@ -92,6 +107,17 @@ func createDecimalLiteral(value []byte, precision int32, scale int32, isNullable
},
}
}

func protoLiteralWithNullability(lit *expr.ProtoLiteral, nullable bool) (expr.Literal, error) {
nullability := proto.Type_NULLABILITY_REQUIRED
if nullable {
nullability = proto.Type_NULLABILITY_NULLABLE
}
decType := lit.GetType().WithNullability(nullability)

return lit.WithType(decType)
}

func TestNewDecimalFromTwosComplement(t *testing.T) {
type args struct {
twosComplement []byte
Expand Down

0 comments on commit e64c786

Please sign in to comment.