Skip to content

Commit

Permalink
From string validation (#75)
Browse files Browse the repository at this point in the history
* chore: address linter issues

Perform changes suggested by linter; no functional changes.

* error.go: make ErrInvalidID a const

Make ErrInvalidID a constant instead of a variable. This prevent it from
being changed by external packages; a behavior that although allowed by 
the compiler, should probably be considered an invalid operation.

* add benchmark and new failing test for FromString

* fix: let decode look for additional base32 padding

Update FromString and XID.TextUnmarshal so that it looks for discarded
bits in the final source character. This ensures that XIDs that have
been manually tampered with in a way that's ignored by base32 decode,
will not pass as valid.
  • Loading branch information
smyrman authored Mar 10, 2022
1 parent 1ac68e2 commit 66f8c42
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 11 deletions.
11 changes: 11 additions & 0 deletions error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package xid

const (
// ErrInvalidID is returned when trying to unmarshal an invalid ID.
ErrInvalidID strErr = "xid: invalid ID"
)

// strErr allows declaring errors as constants.
type strErr string

func (err strErr) Error() string { return string(err) }
22 changes: 15 additions & 7 deletions id.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ import (
"crypto/rand"
"database/sql/driver"
"encoding/binary"
"errors"
"fmt"
"hash/crc32"
"io/ioutil"
Expand All @@ -73,9 +72,6 @@ const (
)

var (
// ErrInvalidID is returned when trying to unmarshal an invalid ID
ErrInvalidID = errors.New("xid: invalid ID")

// objectIDCounter is atomically incremented when generating a new ObjectId
// using NewObjectId() function. It's used as a counter part of an id.
// This id is initialized with a random value.
Expand Down Expand Up @@ -242,7 +238,9 @@ func (id *ID) UnmarshalText(text []byte) error {
return ErrInvalidID
}
}
decode(id, text)
if !decode(id, text) {
return ErrInvalidID
}
return nil
}

Expand All @@ -260,8 +258,8 @@ func (id *ID) UnmarshalJSON(b []byte) error {
return id.UnmarshalText(b[1 : len(b)-1])
}

// decode by unrolling the stdlib base32 algorithm + removing all safe checks
func decode(id *ID, src []byte) {
// decode by unrolling the stdlib base32 algorithm + customized safe check.
func decode(id *ID, src []byte) bool {
_ = src[19]
_ = id[11]

Expand All @@ -277,6 +275,16 @@ func decode(id *ID, src []byte) {
id[2] = dec[src[3]]<<4 | dec[src[4]]>>1
id[1] = dec[src[1]]<<6 | dec[src[2]]<<1 | dec[src[3]]>>4
id[0] = dec[src[0]]<<3 | dec[src[1]]>>2

// Validate that there are no discarer bits (padding) in src that would
// cause the string-encoded id not to equal src.
var check [4]byte

check[3] = encoding[(id[11]<<4)&0x1F]
check[2] = encoding[(id[11]>>1)&0x1F]
check[1] = encoding[(id[11]>>6)&0x1F|(id[10]<<2)&0x1F]
check[0] = encoding[id[10]>>3]
return bytes.Equal([]byte(src[16:20]), check[:])
}

// Time returns the timestamp part of the id.
Expand Down
64 changes: 60 additions & 4 deletions id_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ import (
"encoding/json"
"errors"
"fmt"
"math/rand"
"reflect"
"testing"
"testing/quick"
"time"
)

Expand All @@ -19,21 +21,21 @@ type IDParts struct {
}

var IDs = []IDParts{
IDParts{
{
ID{0x4d, 0x88, 0xe1, 0x5b, 0x60, 0xf4, 0x86, 0xe4, 0x28, 0x41, 0x2d, 0xc9},
1300816219,
[]byte{0x60, 0xf4, 0x86},
0xe428,
4271561,
},
IDParts{
{
ID{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
0,
[]byte{0x00, 0x00, 0x00},
0x0000,
0,
},
IDParts{
{
ID{0x00, 0x00, 0x00, 0x00, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0x00, 0x00, 0x01},
0,
[]byte{0xaa, 0xbb, 0xcc},
Expand Down Expand Up @@ -252,6 +254,60 @@ func BenchmarkFromString(b *testing.B) {
})
}

func TestFromStringQuick(t *testing.T) {
f := func(id1 ID, c byte) bool {
s1 := id1.String()
for i := range s1 {
s2 := []byte(s1)
s2[i] = c
id2, err := FromString(string(s2))
if id1 == id2 && err == nil && c != s1[i] {
t.Logf("comparing XIDs:\na: %q\nb: %q (index %d changed to %c)", s1, s2, i, c)
return false
}
}
return true
}
err := quick.Check(f, &quick.Config{
Values: func(args []reflect.Value, r *rand.Rand) {
i := r.Intn(len(encoding))
args[0] = reflect.ValueOf(New())
args[1] = reflect.ValueOf(byte(encoding[i]))
},
MaxCount: 1000,
})
if err != nil {
t.Error(err)
}
}

func TestFromStringQuickInvalidChars(t *testing.T) {
f := func(id1 ID, c byte) bool {
s1 := id1.String()
for i := range s1 {
s2 := []byte(s1)
s2[i] = c
id2, err := FromString(string(s2))
if id1 == id2 && err == nil && c != s1[i] {
t.Logf("comparing XIDs:\na: %q\nb: %q (index %d changed to %c)", s1, s2, i, c)
return false
}
}
return true
}
err := quick.Check(f, &quick.Config{
Values: func(args []reflect.Value, r *rand.Rand) {
i := r.Intn(0xFF)
args[0] = reflect.ValueOf(New())
args[1] = reflect.ValueOf(byte(i))
},
MaxCount: 2000,
})
if err != nil {
t.Error(err)
}
}

// func BenchmarkUUIDv1(b *testing.B) {
// b.RunParallel(func(pb *testing.PB) {
// for pb.Next() {
Expand Down Expand Up @@ -329,7 +385,7 @@ func TestFromBytes_InvalidBytes(t *testing.T) {
{13, true},
}
for _, c := range cases {
b := make([]byte, c.length, c.length)
b := make([]byte, c.length)
_, err := FromBytes(b)
if got, want := err != nil, c.shouldFail; got != want {
t.Errorf("FromBytes() error got %v, want %v", got, want)
Expand Down

0 comments on commit 66f8c42

Please sign in to comment.