Skip to content

Commit

Permalink
Added token caching functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
delinea-sagar committed Nov 22, 2024
1 parent 00beb33 commit a4fb757
Showing 1 changed file with 73 additions and 44 deletions.
117 changes: 73 additions & 44 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,13 @@ func (s Server) accessResource(method, resource, path string, input interface{})

log.Printf("[DEBUG] calling %s %s", method, req.URL.String())

data, _, err := handleResponse((&http.Client{}).Do(req))
data, statusCode, err := handleResponse((&http.Client{}).Do(req))

// Check for unauthorized or access denied
if statusCode.StatusCode == http.StatusUnauthorized || statusCode.StatusCode == http.StatusForbidden {
s.clearTokenCache()
log.Printf("[ERROR] Token cache cleared due to unauthorized or access denied response.")
}

return data, err
}
Expand Down Expand Up @@ -260,20 +266,20 @@ func (s Server) uploadFile(secretId int, fileField SecretField) error {
return err
}

func (s *Server) setCacheAccessToken(value string, expiresIn int) error {
func (s *Server) setCacheAccessToken(value string, expiresIn int, baseURL string) error {
cache := TokenCache{}
cache.AccessToken = value
cache.ExpiresIn = (int(time.Now().Unix()) + expiresIn) - int(math.Floor(float64(expiresIn)*0.9))

data, _ := json.Marshal(cache)
os.Setenv("SS_AT", string(data))
os.Setenv("SS_AT_"+url.QueryEscape(baseURL), string(data))
return nil
}

func (s *Server) getCacheAccessToken() (string, bool) {
data, ok := os.LookupEnv("SS_AT")
func (s *Server) getCacheAccessToken(baseURL string) (string, bool) {
data, ok := os.LookupEnv("SS_AT_" + url.QueryEscape(baseURL))
if !ok {
os.Setenv("SS_AT", "")
s.clearTokenCache()
return "", ok
}
cache := TokenCache{}
Expand All @@ -286,22 +292,43 @@ func (s *Server) getCacheAccessToken() (string, bool) {
return "", false
}

func (s *Server) clearTokenCache() {
var baseURL string

if s.ServerURL == "" {
baseURL = fmt.Sprintf(cloudBaseURLTemplate, s.Tenant, s.TLD)
} else {
baseURL = s.ServerURL
}

os.Setenv("SS_AT_"+url.QueryEscape(baseURL), "")
}

// getAccessToken gets an OAuth2 Access Grant and returns the token
// endpoint and get an accessGrant.
func (s *Server) getAccessToken() (string, error) {
if s.Credentials.Token != "" {
return s.Credentials.Token, nil
}
accessToken, found := s.getCacheAccessToken()
if found {
return accessToken, nil
var baseURL string

if s.ServerURL == "" {
baseURL = fmt.Sprintf(cloudBaseURLTemplate, s.Tenant, s.TLD)
} else {
baseURL = s.ServerURL
}

response, err := s.checkPlatformDetails()
response, err := s.checkPlatformDetails(baseURL)
if err != nil {
log.Print("Error while checking server details:", err)
return "", err
} else if err == nil && response == "" {

accessToken, found := s.getCacheAccessToken(baseURL)
if found {
return accessToken, nil
}

values := url.Values{
"username": {s.Credentials.Username},
"password": {s.Credentials.Password},
Expand Down Expand Up @@ -331,7 +358,7 @@ func (s *Server) getAccessToken() (string, error) {
log.Print("[ERROR] parsing grant response:", err)
return "", err
}
if err = s.setCacheAccessToken(grant.AccessToken, grant.ExpiresIn); err != nil {
if err = s.setCacheAccessToken(grant.AccessToken, grant.ExpiresIn, baseURL); err != nil {
log.Print("[ERROR] caching access token:", err)
return "", err
}
Expand All @@ -341,15 +368,7 @@ func (s *Server) getAccessToken() (string, error) {
}
}

func (s *Server) checkPlatformDetails() (string, error) {
var baseURL string

if s.ServerURL == "" {
baseURL = fmt.Sprintf(cloudBaseURLTemplate, s.Tenant, s.TLD)
} else {
baseURL = s.ServerURL
}

func (s *Server) checkPlatformDetails(baseURL string) (string, error) {
platformHelthCheckUrl := fmt.Sprintf("%s/%s", strings.Trim(baseURL, "/"), "health")
ssHealthCheckUrl := fmt.Sprintf("%s/%s", strings.Trim(baseURL, "/"), "healthcheck.aspx")

Expand All @@ -359,40 +378,50 @@ func (s *Server) checkPlatformDetails() (string, error) {
} else {
isHealthy := checkJSONResponse(platformHelthCheckUrl)
if isHealthy {
requestData := url.Values{}
requestData.Set("grant_type", "client_credentials")
requestData.Set("client_id", s.Credentials.Username)
requestData.Set("client_secret", s.Credentials.Password)
requestData.Set("scope", "xpmheadless")

req, err := http.NewRequest("POST", fmt.Sprintf("%s/%s", strings.Trim(baseURL, "/"), "identity/api/oauth2/token/xpmplatform"), bytes.NewBufferString(requestData.Encode()))
if err != nil {
log.Print("Error creating HTTP request:", err)
return "", err
}
accessToken, found := s.getCacheAccessToken(baseURL)
if !found {
requestData := url.Values{}
requestData.Set("grant_type", "client_credentials")
requestData.Set("client_id", s.Credentials.Username)
requestData.Set("client_secret", s.Credentials.Password)
requestData.Set("scope", "xpmheadless")

req, err := http.NewRequest("POST", fmt.Sprintf("%s/%s", strings.Trim(baseURL, "/"), "identity/api/oauth2/token/xpmplatform"), bytes.NewBufferString(requestData.Encode()))
if err != nil {
log.Print("Error creating HTTP request:", err)
return "", err
}

req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

data, _, err := handleResponse((&http.Client{}).Do(req))
if err != nil {
log.Print("[ERROR] get token response error:", err)
return "", err
}
data, _, err := handleResponse((&http.Client{}).Do(req))
if err != nil {
log.Print("[ERROR] get token response error:", err)
return "", err
}

var tokenjsonResponse OAuthTokens
if err = json.Unmarshal(data, &tokenjsonResponse); err != nil {
log.Print("[ERROR] parsing get token response:", err)
return "", err
var tokenjsonResponse OAuthTokens
if err = json.Unmarshal(data, &tokenjsonResponse); err != nil {
log.Print("[ERROR] parsing get token response:", err)
return "", err
}
accessToken = tokenjsonResponse.AccessToken

if err = s.setCacheAccessToken(tokenjsonResponse.AccessToken, tokenjsonResponse.ExpiresIn, baseURL); err != nil {
log.Print("[ERROR] caching access token:", err)
return "", err
}
}

req, err = http.NewRequest("GET", fmt.Sprintf("%s/%s", strings.Trim(baseURL, "/"), "vaultbroker/api/vaults"), bytes.NewBuffer([]byte{}))
req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", strings.Trim(baseURL, "/"), "vaultbroker/api/vaults"), bytes.NewBuffer([]byte{}))
if err != nil {
log.Print("Error creating HTTP request:", err)
return "", err
}
req.Header.Add("Authorization", "Bearer "+tokenjsonResponse.AccessToken)
req.Header.Add("Authorization", "Bearer "+accessToken)

data, _, err = handleResponse((&http.Client{}).Do(req))
data, _, err := handleResponse((&http.Client{}).Do(req))
if err != nil {
log.Print("[ERROR] get vaults response error:", err)
return "", err
Expand All @@ -417,7 +446,7 @@ func (s *Server) checkPlatformDetails() (string, error) {
return "", fmt.Errorf("no configured vault found")
}

return tokenjsonResponse.AccessToken, nil
return accessToken, nil
}
}
return "", fmt.Errorf("invalid URL")
Expand Down

0 comments on commit a4fb757

Please sign in to comment.