diff --git a/asn1_test.go b/asn1_test.go index eebb0a9..b54d1b1 100644 --- a/asn1_test.go +++ b/asn1_test.go @@ -4,6 +4,7 @@ import ( "math/big" "reflect" "testing" + "fmt" ) // isBytesEqual compares two byte arrays/slices. @@ -26,9 +27,14 @@ type testCase struct { expected []byte } +func (t testCase) String() string { + return fmt.Sprintf("testCase: value %#v (%T) expects %#v", t.value, t.value, t.expected) +} + // testEncode encodes an object and compares with the expected bytes. func testEncode(t *testing.T, ctx *Context, options string, tests ...testCase) { for _, test := range tests { + t.Logf("Testing case: %v", test) data, err := ctx.EncodeWithOptions(test.value, options) if err != nil { t.Fatal(err) @@ -670,3 +676,30 @@ func TestArraySlice(t *testing.T) { ctx := NewContext() testEncodeDecode(t, ctx, "", testCases...) } + +func TestPointerInterface(t *testing.T) { + type I interface {} + type Type struct { + A int + B string + C bool + } + var obj I + obj = &Type{1, "abc", true} + ctx := NewContext() + // We cannot use testSimple because the type is I + data, err := ctx.Encode(obj) + if err != nil { + t.Fatal(err) + } + value := new(Type) + rest, err := ctx.Decode(data, value) + if err != nil { + t.Fatal(err) + } + if len(rest) > 0 { + t.Fatalf("Unexpected remaining bytes when decoding \"%v\": %#v\n", + obj, rest) + } + checkEqual(t, obj, value) +} diff --git a/encode.go b/encode.go index 964441d..5a88cb6 100644 --- a/encode.go +++ b/encode.go @@ -37,10 +37,7 @@ func (ctx *Context) EncodeWithOptions(obj interface{}, options string) (data []b func (ctx *Context) encode(value reflect.Value, opts *fieldOptions) (*rawValue, error) { // Skip the interface type - switch value.Kind() { - case reflect.Interface: - value = value.Elem() - } + value = getActualType(value) // If a value is missing the default value is used empty := isEmpty(value) diff --git a/types.go b/types.go index 8e54ab5..7aec372 100644 --- a/types.go +++ b/types.go @@ -347,6 +347,21 @@ func (ctx *Context) decodeNull(data []byte, value reflect.Value) error { * Helper functions */ +// getActualType recursively gets the underlying type of Interfaces and Pointers. +func getActualType(value reflect.Value) reflect.Value { + for { + if value.Type() == bigIntType { + return value + } + switch value.Kind() { + case reflect.Interface, reflect.Ptr: + value = value.Elem() + default: + return value + } + } +} + func checkInt(ctx *Context, data []byte) error { if ctx.der.decoding { if len(data) >= 2 {