From e8cf97c4829e523b2c8def0d5ff9212c2695784a Mon Sep 17 00:00:00 2001 From: Gary Belvin Date: Wed, 7 Mar 2018 14:42:06 +0000 Subject: [PATCH] Add nil checks to types package --- types/logroot.go | 3 +++ types/logroot_test.go | 4 ++++ types/maproot.go | 3 +++ types/maproot_test.go | 37 +++++++++++++++++++++++++++++++++++++ 4 files changed, 47 insertions(+) diff --git a/types/logroot.go b/types/logroot.go index 69ed40be3b..2648a36166 100644 --- a/types/logroot.go +++ b/types/logroot.go @@ -40,6 +40,9 @@ type LogRoot struct { // ParseLogRoot verifies that b has the LOG_ROOT_FORMAT_V1 tag and returns a *LogRootV1 func ParseLogRoot(b []byte) (*LogRootV1, error) { + if b == nil { + return nil, fmt.Errorf("nil log root") + } // Verify version version := binary.BigEndian.Uint16(b) if version != uint16(trillian.LogRootFormat_LOG_ROOT_FORMAT_V1) { diff --git a/types/logroot_test.go b/types/logroot_test.go index 4a5609924a..c3a4567d07 100644 --- a/types/logroot_test.go +++ b/types/logroot_test.go @@ -63,6 +63,10 @@ func TestParseLogRoot(t *testing.T) { logRoot: []byte("foo"), wantErr: true, }, + { + logRoot: nil, + wantErr: true, + }, } { _, err := ParseLogRoot(tc.logRoot) if got, want := err != nil, tc.wantErr; got != want { diff --git a/types/maproot.go b/types/maproot.go index 9e390aab83..65d976f531 100644 --- a/types/maproot.go +++ b/types/maproot.go @@ -38,6 +38,9 @@ type MapRoot struct { // ParseMapRoot verifies that b has the MAP_ROOT_FORMAT_V1 tag and returns a *MapRootV1 func ParseMapRoot(b []byte) (*MapRootV1, error) { + if b == nil { + return nil, fmt.Errorf("nil map root") + } // Verify version version := binary.BigEndian.Uint16(b) if version != uint16(trillian.MapRootFormat_MAP_ROOT_FORMAT_V1) { diff --git a/types/maproot_test.go b/types/maproot_test.go index f996c06e57..6fe4b67d23 100644 --- a/types/maproot_test.go +++ b/types/maproot_test.go @@ -43,3 +43,40 @@ func TestMapRoot(t *testing.T) { } } } + +func MustSerializeMapRoot(root *MapRootV1) []byte { + b, err := SerializeMapRoot(root) + if err != nil { + panic(err) + } + return b +} + +func TestParseMapRoot(t *testing.T) { + for _, tc := range []struct { + mapRoot []byte + want *MapRootV1 + wantErr bool + }{ + { + want: &MapRootV1{ + RootHash: []byte("foo"), + Metadata: []byte{}, + }, + mapRoot: MustSerializeMapRoot(&MapRootV1{ + RootHash: []byte("foo"), + Metadata: []byte{}, + }), + }, + {mapRoot: []byte("foo"), wantErr: true}, + {mapRoot: nil, wantErr: true}, + } { + r, err := ParseMapRoot(tc.mapRoot) + if got, want := err != nil, tc.wantErr; got != want { + t.Errorf("ParseMapRoot(): %v, wantErr: %v", err, want) + } + if got, want := r, tc.want; !reflect.DeepEqual(got, want) { + t.Errorf("ParseMapRoot(): %v, want: %v", got, want) + } + } +}