Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for NewDecimalFromApdDecimal #119

Merged
merged 1 commit into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions expr/decimal_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,7 @@ var decimalPattern = regexp.MustCompile(`^[+-]?\d*(\.\d*)?([eE][+-]?\d*)?$`)
// The precision is the total number of digits in the decimal value. The precision is limited to 38 digits.
// The scale is the number of digits to the right of the decimal point. The scale is limited to the precision.
func DecimalStringToBytes(decimalStr string) ([16]byte, int32, int32, error) {
var (
result [16]byte
precision int32
scale int32
)
var result [16]byte

strings.Trim(decimalStr, " ")
if !decimalPattern.MatchString(decimalStr) {
Expand All @@ -34,6 +30,19 @@ func DecimalStringToBytes(decimalStr string) ([16]byte, int32, int32, error) {
return result, 0, 0, fmt.Errorf("invalid decimal string %s: %v", decimalStr, err)
}

return DecimalToBytes(dec)
}

// DecimalToBytes converts apd.Decimal to a 16-byte byte array.
// 16-byte bytes represents a little-endian 128-bit integer, to be divided by 10^Scale to get the decimal value.
// This function also returns the precision and scale of the decimal value.
func DecimalToBytes(dec *apd.Decimal) ([16]byte, int32, int32, error) {
var (
result [16]byte
precision int32
scale int32
)

if dec.Exponent > 0 {
precision = int32(apd.NumDigits(&dec.Coeff)) + dec.Exponent
scale = 0
Expand All @@ -42,7 +51,7 @@ func DecimalStringToBytes(decimalStr string) ([16]byte, int32, int32, error) {
precision = max(int32(apd.NumDigits(&dec.Coeff)), scale+1)
}
if precision > 38 {
return result, precision, scale, fmt.Errorf("number %s exceeds maximum precision of 38 (%d)", decimalStr, precision)
return result, precision, scale, fmt.Errorf("number %s exceeds maximum precision of 38 (%d)", dec.String(), precision)
}

coefficient := dec.Coeff
Expand Down
3 changes: 3 additions & 0 deletions expr/literals.go
Original file line number Diff line number Diff line change
Expand Up @@ -909,6 +909,9 @@ func NewLiteral[T allLiteralTypes](val T, nullable bool) (Literal, error) {
},
}, nil
case *types.Decimal:
if len(v.Value) != 16 {
return nil, fmt.Errorf("decimal value must be 16 bytes")
}
return &ProtoLiteral{
Value: v.Value,
Type: &types.DecimalType{
Expand Down
27 changes: 27 additions & 0 deletions expr/literals_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,33 @@ func TestNewDecimalWithType(t *testing.T) {
}
}

func TestNewLiteralWithDecimalBytes(t *testing.T) {
tests := []struct {
name string
value []byte
precision int32
scale int32
expectedToFail bool
}{
{"[0]byte", []byte{}, 2, 0, true},
{"[2]byte", []byte{0x1, 0x0}, 2, 0, true},
{"[3]byte", []byte{0x1, 0x2, 0x0}, 3, 0, true},
{"[17]byte", []byte{0x1, 0x0, 0x0, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, 16, 0, true},

{"[16]byte", []byte{0x1, 0x0, 0x0, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, 16, 0, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := expr.NewLiteral[*types.Decimal](&types.Decimal{Value: tt.value, Precision: tt.precision, Scale: tt.scale}, false)
if tt.expectedToFail {
require.Error(t, err)
} else {
require.NoError(t, err)
}
})
}
}

func TestNewFixedLenWithType(t *testing.T) {
tests := []struct {
name string
Expand Down
9 changes: 9 additions & 0 deletions literal/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"strconv"
"time"

"github.com/cockroachdb/apd/v3"
"github.com/google/uuid"
"github.com/substrait-io/substrait-go/v3/expr"
"github.com/substrait-io/substrait-go/v3/proto"
Expand Down Expand Up @@ -321,6 +322,14 @@
return expr.NewLiteral[*types.Decimal](&types.Decimal{Value: v[:16], Precision: precision, Scale: scale}, false)
}

func NewDecimalFromApdDecimal(value *apd.Decimal, nullable bool) (expr.Literal, error) {
v, precision, scale, err := expr.DecimalToBytes(value)
if err != nil {
return nil, err
}

Check warning on line 329 in literal/utils.go

View check run for this annotation

Codecov / codecov/patch

literal/utils.go#L328-L329

Added lines #L328 - L329 were not covered by tests
return expr.NewLiteral[*types.Decimal](&types.Decimal{Value: v[:16], Precision: precision, Scale: scale}, nullable)
}

// NewPrecisionTimestampFromTime creates a new PrecisionTimestamp literal from a time.Time timestamp value with given precision.
func NewPrecisionTimestampFromTime(precision types.TimePrecision, tm time.Time) (expr.Literal, error) {
return NewPrecisionTimestamp(precision, types.GetTimeValueByPrecision(tm, precision))
Expand Down
42 changes: 34 additions & 8 deletions literal/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"testing"
"time"

"github.com/cockroachdb/apd/v3"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -58,22 +59,36 @@ func TestNewDecimalFromString(t *testing.T) {
tests := []struct {
value string
want expr.Literal
wantErr assert.ErrorAssertionFunc
wantErr bool
}{
{"0", createDecimalLiteral([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1, 0, false), assert.NoError},
{"111111.222222", createDecimalLiteral([]byte{0xce, 0xb3, 0xbe, 0xde, 0x19, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 12, 6, false), assert.NoError},
{"-111111.222222", createDecimalLiteral([]byte{0x32, 0x4c, 0x41, 0x21, 0xe6, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 12, 6, false), assert.NoError},
{"+1", createDecimalLiteral([]byte{0x1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1, 0, false), assert.NoError},
{"-1", createDecimalLiteral([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 1, 0, false), assert.NoError},
{"not a decimal", nil, assert.Error},
{"0", createDecimalLiteral([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1, 0, false), false},
{"111111.222222", createDecimalLiteral([]byte{0xce, 0xb3, 0xbe, 0xde, 0x19, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 12, 6, false), false},
{"-111111.222222", createDecimalLiteral([]byte{0x32, 0x4c, 0x41, 0x21, 0xe6, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 12, 6, false), false},
{"+1", createDecimalLiteral([]byte{0x1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1, 0, false), false},
{"-1", createDecimalLiteral([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 1, 0, false), false},
{"not a decimal", nil, true},
}
for _, tt := range tests {
t.Run(tt.value, func(t *testing.T) {
got, err := NewDecimalFromString(tt.value)
if !tt.wantErr(t, err, fmt.Sprintf("NewDecimalFromString(%v)", tt.value)) {
if tt.wantErr {
require.Error(t, err, fmt.Sprintf("NewDecimalFromString(%v) expected error", tt.value))
return
}
require.NoError(t, err)
assert.Equalf(t, tt.want, got, "NewDecimalFromString(%v)", tt.value)

dec, _, err := apd.NewFromString(tt.value)
require.NoError(t, err)
got, err = NewDecimalFromApdDecimal(dec, false)
require.NoError(t, err)
assert.Equal(t, tt.want, got, "NewDecimalFromApdDecimal(%v)", tt.value)

got, err = NewDecimalFromApdDecimal(dec, true)
require.NoError(t, err)
expected, err := protoLiteralWithNullability(tt.want.(*expr.ProtoLiteral), true)
require.NoError(t, err)
assert.Equal(t, expected, got, "NewDecimalFromApdDecimal(%v)", tt.value)
})
}
}
Expand All @@ -92,6 +107,17 @@ func createDecimalLiteral(value []byte, precision int32, scale int32, isNullable
},
}
}

func protoLiteralWithNullability(lit *expr.ProtoLiteral, nullable bool) (expr.Literal, error) {
nullability := proto.Type_NULLABILITY_REQUIRED
if nullable {
nullability = proto.Type_NULLABILITY_NULLABLE
}
decType := lit.GetType().WithNullability(nullability)

return lit.WithType(decType)
}

func TestNewDecimalFromTwosComplement(t *testing.T) {
type args struct {
twosComplement []byte
Expand Down
Loading