diff --git a/builtin/credential/approle/path_role_test.go b/builtin/credential/approle/path_role_test.go index 24ce69211e51..d2c35c21afc6 100644 --- a/builtin/credential/approle/path_role_test.go +++ b/builtin/credential/approle/path_role_test.go @@ -1864,7 +1864,7 @@ func TestAppRole_TokenutilUpgrade(t *testing.T) { // Hand craft JSON because there is overlap between fields if err := s.Put(ctx, &logical.StorageEntry{ Key: "role/foo", - Value: []byte(`{"policies": ["foo"], "period": 300000000000, "token_bound_cidrs": ["127.0.0.1", "10.10.10.10/24"]}`), + Value: []byte(`{"policies": ["foo"], "period": 300000000000, "token_bound_cidrs": ["127.0.0.1", "10.10.10.10/24"], "token_type": "service"}`), }); err != nil { t.Fatal(err) } @@ -1882,6 +1882,7 @@ func TestAppRole_TokenutilUpgrade(t *testing.T) { TokenPolicies: []string{"foo"}, TokenPeriod: 300 * time.Second, TokenBoundCIDRs: []*sockaddr.SockAddrMarshaler{&sockaddr.SockAddrMarshaler{SockAddr: sockaddr.MustIPAddr("127.0.0.1")}, &sockaddr.SockAddrMarshaler{SockAddr: sockaddr.MustIPAddr("10.10.10.10/24")}}, + TokenType: logical.TokenTypeService, }, } if diff := deep.Equal(fooEntry, exp); diff != nil { diff --git a/sdk/logical/token.go b/sdk/logical/token.go index 242885e1ffc2..3646fee5388a 100644 --- a/sdk/logical/token.go +++ b/sdk/logical/token.go @@ -1,6 +1,7 @@ package logical import ( + "fmt" "time" sockaddr "github.com/hashicorp/go-sockaddr" @@ -28,6 +29,31 @@ const ( TokenTypeDefaultBatch ) +func (t *TokenType) UnmarshalJSON(b []byte) error { + if len(b) == 1 { + *t = TokenType(b[0] - '0') + return nil + } + + // Handle upgrade from pre-1.2 where we were serialized as string: + s := string(b) + switch s { + case `"default"`: + *t = TokenTypeDefault + case `"service"`: + *t = TokenTypeService + case `"batch"`: + *t = TokenTypeBatch + case `"default-service"`: + *t = TokenTypeDefaultService + case `"default-batch"`: + *t = TokenTypeDefaultBatch + default: + return fmt.Errorf("unknown token type %q", s) + } + return nil +} + func (t TokenType) String() string { switch t { case TokenTypeDefault: diff --git a/sdk/logical/token_test.go b/sdk/logical/token_test.go new file mode 100644 index 000000000000..412a7d4abfd4 --- /dev/null +++ b/sdk/logical/token_test.go @@ -0,0 +1,33 @@ +package logical + +import ( + "encoding/json" + "testing" +) + +func TestJSONSerialization(t *testing.T) { + tt := TokenTypeDefaultBatch + s, err := json.Marshal(tt) + if err != nil { + t.Fatal(err) + } + + var utt TokenType + err = json.Unmarshal(s, &utt) + if err != nil { + t.Fatal(err) + } + + if tt != utt { + t.Fatalf("expected %v, got %v", tt, utt) + } + + utt = TokenTypeDefault + err = json.Unmarshal([]byte(`"default-batch"`), &utt) + if err != nil { + t.Fatal(err) + } + if tt != utt { + t.Fatalf("expected %v, got %v", tt, utt) + } +}