Skip to content

Commit

Permalink
Support comparing byte slice (#1202)
Browse files Browse the repository at this point in the history
* support comparing byte slice

Signed-off-by: Ryan Leung <[email protected]>

* address the comment

Signed-off-by: Ryan Leung <[email protected]>
  • Loading branch information
rleungx authored Jun 21, 2022
1 parent 48391ba commit c31ea03
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 1 deletion.
24 changes: 23 additions & 1 deletion assert/assertion_compare.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package assert

import (
"bytes"
"fmt"
"reflect"
"time"
Expand Down Expand Up @@ -32,7 +33,8 @@ var (

stringType = reflect.TypeOf("")

timeType = reflect.TypeOf(time.Time{})
timeType = reflect.TypeOf(time.Time{})
bytesType = reflect.TypeOf([]byte{})
)

func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
Expand Down Expand Up @@ -323,6 +325,26 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {

return compare(timeObj1.UnixNano(), timeObj2.UnixNano(), reflect.Int64)
}
case reflect.Slice:
{
// We only care about the []byte type.
if !canConvert(obj1Value, bytesType) {
break
}

// []byte can be compared!
bytesObj1, ok := obj1.([]byte)
if !ok {
bytesObj1 = obj1Value.Convert(bytesType).Interface().([]byte)

}
bytesObj2, ok := obj2.([]byte)
if !ok {
bytesObj2 = obj2Value.Convert(bytesType).Interface().([]byte)
}

return CompareType(bytes.Compare(bytesObj1, bytesObj2)), true
}
}

return compareEqual, false
Expand Down
128 changes: 128 additions & 0 deletions assert/assertion_compare_go1.17_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,24 @@
package assert

import (
"bytes"
"reflect"
"testing"
"time"
)

func TestCompare17(t *testing.T) {
type customTime time.Time
type customBytes []byte
for _, currCase := range []struct {
less interface{}
greater interface{}
cType string
}{
{less: time.Now(), greater: time.Now().Add(time.Hour), cType: "time.Time"},
{less: customTime(time.Now()), greater: customTime(time.Now().Add(time.Hour)), cType: "time.Time"},
{less: []byte{1, 1}, greater: []byte{1, 2}, cType: "[]byte"},
{less: customBytes([]byte{1, 1}), greater: customBytes([]byte{1, 2}), cType: "[]byte"},
} {
resLess, isComparable := compare(currCase.less, currCase.greater, reflect.ValueOf(currCase.less).Kind())
if !isComparable {
Expand Down Expand Up @@ -52,3 +56,127 @@ func TestCompare17(t *testing.T) {
}
}
}

func TestGreater17(t *testing.T) {
mockT := new(testing.T)

if !Greater(mockT, 2, 1) {
t.Error("Greater should return true")
}

if Greater(mockT, 1, 1) {
t.Error("Greater should return false")
}

if Greater(mockT, 1, 2) {
t.Error("Greater should return false")
}

// Check error report
for _, currCase := range []struct {
less interface{}
greater interface{}
msg string
}{
{less: []byte{1, 1}, greater: []byte{1, 2}, msg: `"[1 1]" is not greater than "[1 2]"`},
{less: time.Time{}, greater: time.Time{}.Add(time.Hour), msg: `"0001-01-01 00:00:00 +0000 UTC" is not greater than "0001-01-01 01:00:00 +0000 UTC"`},
} {
out := &outputT{buf: bytes.NewBuffer(nil)}
False(t, Greater(out, currCase.less, currCase.greater))
Contains(t, out.buf.String(), currCase.msg)
Contains(t, out.helpers, "github.com/stretchr/testify/assert.Greater")
}
}

func TestGreaterOrEqual17(t *testing.T) {
mockT := new(testing.T)

if !GreaterOrEqual(mockT, 2, 1) {
t.Error("GreaterOrEqual should return true")
}

if !GreaterOrEqual(mockT, 1, 1) {
t.Error("GreaterOrEqual should return true")
}

if GreaterOrEqual(mockT, 1, 2) {
t.Error("GreaterOrEqual should return false")
}

// Check error report
for _, currCase := range []struct {
less interface{}
greater interface{}
msg string
}{
{less: []byte{1, 1}, greater: []byte{1, 2}, msg: `"[1 1]" is not greater than or equal to "[1 2]"`},
{less: time.Time{}, greater: time.Time{}.Add(time.Hour), msg: `"0001-01-01 00:00:00 +0000 UTC" is not greater than or equal to "0001-01-01 01:00:00 +0000 UTC"`},
} {
out := &outputT{buf: bytes.NewBuffer(nil)}
False(t, GreaterOrEqual(out, currCase.less, currCase.greater))
Contains(t, out.buf.String(), currCase.msg)
Contains(t, out.helpers, "github.com/stretchr/testify/assert.GreaterOrEqual")
}
}

func TestLess17(t *testing.T) {
mockT := new(testing.T)

if !Less(mockT, 1, 2) {
t.Error("Less should return true")
}

if Less(mockT, 1, 1) {
t.Error("Less should return false")
}

if Less(mockT, 2, 1) {
t.Error("Less should return false")
}

// Check error report
for _, currCase := range []struct {
less interface{}
greater interface{}
msg string
}{
{less: []byte{1, 1}, greater: []byte{1, 2}, msg: `"[1 2]" is not less than "[1 1]"`},
{less: time.Time{}, greater: time.Time{}.Add(time.Hour), msg: `"0001-01-01 01:00:00 +0000 UTC" is not less than "0001-01-01 00:00:00 +0000 UTC"`},
} {
out := &outputT{buf: bytes.NewBuffer(nil)}
False(t, Less(out, currCase.greater, currCase.less))
Contains(t, out.buf.String(), currCase.msg)
Contains(t, out.helpers, "github.com/stretchr/testify/assert.Less")
}
}

func TestLessOrEqual17(t *testing.T) {
mockT := new(testing.T)

if !LessOrEqual(mockT, 1, 2) {
t.Error("LessOrEqual should return true")
}

if !LessOrEqual(mockT, 1, 1) {
t.Error("LessOrEqual should return true")
}

if LessOrEqual(mockT, 2, 1) {
t.Error("LessOrEqual should return false")
}

// Check error report
for _, currCase := range []struct {
less interface{}
greater interface{}
msg string
}{
{less: []byte{1, 1}, greater: []byte{1, 2}, msg: `"[1 2]" is not less than or equal to "[1 1]"`},
{less: time.Time{}, greater: time.Time{}.Add(time.Hour), msg: `"0001-01-01 01:00:00 +0000 UTC" is not less than or equal to "0001-01-01 00:00:00 +0000 UTC"`},
} {
out := &outputT{buf: bytes.NewBuffer(nil)}
False(t, LessOrEqual(out, currCase.greater, currCase.less))
Contains(t, out.buf.String(), currCase.msg)
Contains(t, out.helpers, "github.com/stretchr/testify/assert.LessOrEqual")
}
}

0 comments on commit c31ea03

Please sign in to comment.