diff --git a/common/environment.go b/common/environment.go index 1d8ffc84a..5a4b7f3d0 100644 --- a/common/environment.go +++ b/common/environment.go @@ -21,10 +21,13 @@ package common import ( - "github.com/JeffreyRichter/enum/enum" + "encoding/json" + "fmt" "reflect" "runtime" "strings" + + "github.com/JeffreyRichter/enum/enum" ) type EnvironmentVariable struct { @@ -115,6 +118,32 @@ func (d *AutoLoginType) Parse(s string) error { return err } +// MarshalJSON customizes the JSON encoding for AutoLoginType +func (d AutoLoginType) MarshalJSON() ([]byte, error) { + return json.Marshal(d.String()) +} + +// UnmarshalJSON customizes the JSON decoding for AutoLoginType +func (d *AutoLoginType) UnmarshalJSON(data []byte) error { + var v interface{} + if err := json.Unmarshal(data, &v); err != nil { + return err + } + if strValue, ok := v.(string); ok { + return d.Parse(strValue) + } + // Handle numeric values + if numValue, ok := v.(float64); ok { + if numValue < 0 || numValue > 255 { + return fmt.Errorf("value out of range for _token_source_refresh: %v", numValue) + } + *d = AutoLoginType(uint8(numValue)) + return nil + } + + return fmt.Errorf("unsupported type for AutoLoginType: %T", v) +} + func ValidAutoLoginTypes() []string { return []string{ EAutoLoginType.Device().String() + " (Device code workflow)", diff --git a/common/oauthTokenManager_test.go b/common/oauthTokenManager_test.go new file mode 100644 index 000000000..8f00fbffd --- /dev/null +++ b/common/oauthTokenManager_test.go @@ -0,0 +1,185 @@ +// Copyright © 2017 Microsoft +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package common + +import ( + "context" + "fmt" + "os" + "reflect" + "strconv" + "testing" +) + +const tokenInfoJson = `{ + "access_token": "dummy_access_token", + "refresh_token": "dummy_refresh_token", + "expires_in": 0, + "expires_on": 0, + "not_before": 0, + "resource": "dummy_resource", + "token_type": "dummy_token_type", + "_tenant": "dummy_tenant", + "_ad_endpoint": "dummy_ad_endpoint", + "_token_refresh_source": %v, + "_application_id": "dummy_application_id", + "IdentityInfo": { + "_identity_client_id": "dummy_identity_client_id", + "_identity_object_id": "dummy_identity_object_id", + "_identity_msi_res_id": "dummy_identity_msi_res_id" + }, + "SPNInfo": { + "_spn_secret": "dummy_spn_secret", + "_spn_cert_path": "dummy_spn_cert_path" + } + }` + +var oauthTokenInfo = OAuthTokenInfo{ + Tenant: "dummy_tenant", + ActiveDirectoryEndpoint: "dummy_ad_endpoint", + LoginType: 255, + ApplicationID: "dummy_application_id", + IdentityInfo: IdentityInfo{ + ClientID: "dummy_identity_client_id", + ObjectID: "dummy_identity_object_id", + MSIResID: "dummy_identity_msi_res_id", + }, + SPNInfo: SPNInfo{ + Secret: "dummy_spn_secret", + CertPath: "dummy_spn_cert_path", + }, +} + +func formatTokenInfo(value interface{}) string { + var formattedValue string + switch v := value.(type) { + case string: + formattedValue = fmt.Sprintf("\"%s\"", v) + case int: + formattedValue = strconv.Itoa(v) + default: + formattedValue = fmt.Sprintf("%v", v) + } + return fmt.Sprintf(tokenInfoJson, formattedValue) +} + +func TestUserOAuthTokenManager_GetTokenInfo(t *testing.T) { + type args struct { + ctx context.Context + } + tests := []struct { + name string + uotm *UserOAuthTokenManager + args args + setup func(t *testing.T) + want *OAuthTokenInfo + wantErr bool + }{ + { + name: "This UT tests if AutoLoginType filled is parsed properly from string to uint8 data type", + uotm: &UserOAuthTokenManager{}, + args: args{ + ctx: context.Background(), + }, + setup: func(t *testing.T) { + tokenInfo := formatTokenInfo("TokenStore") + fmt.Println(tokenInfo) + + // Set the environment variable AZCOPY_OAUTH_TOKEN_INFO + err := os.Setenv("AZCOPY_OAUTH_TOKEN_INFO", tokenInfo) + if err != nil { + t.Fatalf("Failed to set environment variable: %v", err) + } + }, + want: &oauthTokenInfo, + wantErr: false, + }, + { + name: "This UT tests if AutoLoginType filled is assigned properly to uint8 data type", + uotm: &UserOAuthTokenManager{}, + args: args{ + ctx: context.Background(), + }, + setup: func(t *testing.T) { + tokenInfo := formatTokenInfo(255) + fmt.Println(tokenInfo) + + // Set the environment variable AZCOPY_OAUTH_TOKEN_INFO + err := os.Setenv("AZCOPY_OAUTH_TOKEN_INFO", tokenInfo) + if err != nil { + t.Fatalf("Failed to set environment variable: %v", err) + } + }, + want: &oauthTokenInfo, + wantErr: false, + }, + { + name: "This UT tests if _token_refresh_source fails to parse due to invalid string type", + uotm: &UserOAuthTokenManager{}, + args: args{ + ctx: context.Background(), + }, + setup: func(t *testing.T) { + tokenInfo := formatTokenInfo("2gt5") + + // Set the environment variable AZCOPY_OAUTH_TOKEN_INFO + err := os.Setenv("AZCOPY_OAUTH_TOKEN_INFO", tokenInfo) + if err != nil { + t.Fatalf("Failed to set environment variable: %v", err) + } + }, + want: nil, + wantErr: true, + }, + { + name: "This UT tests if _token_refresh_source fails to parse due to value out of uint8 range", + uotm: &UserOAuthTokenManager{}, + args: args{ + ctx: context.Background(), + }, + setup: func(t *testing.T) { + tokenInfo := formatTokenInfo(847) + + // Set the environment variable AZCOPY_OAUTH_TOKEN_INFO + err := os.Setenv("AZCOPY_OAUTH_TOKEN_INFO", tokenInfo) + if err != nil { + t.Fatalf("Failed to set environment variable: %v", err) + } + }, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + tt.setup(t) + got, err := tt.uotm.GetTokenInfo(tt.args.ctx) + if (err != nil) != tt.wantErr { + t.Errorf("UserOAuthTokenManager.GetTokenInfo() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != nil && reflect.DeepEqual(got, tt.want) { + t.Errorf("UserOAuthTokenManager.GetTokenInfo() = %v, want %v", got, tt.want) + } + }) + } +}