From fae7fe2c3a06b92fa80b5f4166b818a16625b181 Mon Sep 17 00:00:00 2001 From: Tharsanan1 Date: Fri, 31 Jan 2025 16:10:52 +0530 Subject: [PATCH] Add backend JWt support to go enforcer --- gateway/enforcer/go.mod | 1 - gateway/enforcer/go.sum | 2 - .../internal/authorization/authorization.go | 2 +- gateway/enforcer/internal/config/config.go | 2 + .../internal/config/enforcer_config.go | 10 +- .../enforcer/internal/datastore/api_store.go | 96 ++++--- gateway/enforcer/internal/dto/claim_value.go | 2 +- .../internal/dto/jwt_configuration.go | 30 +-- gateway/enforcer/internal/extproc/ext_proc.go | 20 +- .../internal/jwtbackend/jwtGenerator.go | 239 ----------------- .../internal/jwtbackend/jwtInfoDTO.go | 71 ----- .../internal/jwtbackend/jwtValidationInfo.go | 57 ----- .../internal/jwtbackend/jwt_generator.go | 93 +++++++ .../enforcer/internal/requestconfig/api.go | 58 ++--- gateway/enforcer/internal/util/cert.go | 21 ++ gateway/enforcer/internal/util/conversion.go | 3 +- gateway/enforcer/internal/util/jwt.go | 242 ++++++++++++++++++ .../enforcer/internal/xds/client_manager.go | 2 +- 18 files changed, 484 insertions(+), 467 deletions(-) delete mode 100644 gateway/enforcer/internal/jwtbackend/jwtGenerator.go delete mode 100644 gateway/enforcer/internal/jwtbackend/jwtInfoDTO.go delete mode 100644 gateway/enforcer/internal/jwtbackend/jwtValidationInfo.go create mode 100644 gateway/enforcer/internal/jwtbackend/jwt_generator.go create mode 100644 gateway/enforcer/internal/util/jwt.go diff --git a/gateway/enforcer/go.mod b/gateway/enforcer/go.mod index 06a88b63b..312a89036 100644 --- a/gateway/enforcer/go.mod +++ b/gateway/enforcer/go.mod @@ -7,7 +7,6 @@ require ( github.com/envoyproxy/go-control-plane v0.13.1 github.com/go-logr/logr v1.4.2 github.com/go-logr/zapr v1.3.0 - github.com/golang-jwt/jwt/v4 v4.5.1 github.com/google/uuid v1.6.0 github.com/kelseyhightower/envconfig v1.4.0 github.com/prometheus/client_golang v1.20.5 diff --git a/gateway/enforcer/go.sum b/gateway/enforcer/go.sum index b4ec224eb..ca759eefa 100644 --- a/gateway/enforcer/go.sum +++ b/gateway/enforcer/go.sum @@ -35,8 +35,6 @@ github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1v github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= -github.com/golang-jwt/jwt/v4 v4.5.1 h1:JdqV9zKUdtaa9gdPlywC3aeoEsR681PlKC+4F5gQgeo= -github.com/golang-jwt/jwt/v4 v4.5.1/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/gnostic-models v0.6.8 h1:yo/ABAfM5IMRsS1VnXjTBvUb61tFIHozhlYvRgGre9I= diff --git a/gateway/enforcer/internal/authorization/authorization.go b/gateway/enforcer/internal/authorization/authorization.go index b14361832..5d4748017 100644 --- a/gateway/enforcer/internal/authorization/authorization.go +++ b/gateway/enforcer/internal/authorization/authorization.go @@ -36,6 +36,6 @@ func Validate(rch *requestconfig.Holder, subAppDataStore *datastore.Subscription return immediateResponse } cfg.Logger.Info(fmt.Sprintf("Subscription validation successful for the request: %s", rch.MatchedResource.Path)) - + return nil } diff --git a/gateway/enforcer/internal/config/config.go b/gateway/enforcer/internal/config/config.go index 2691fa1fd..4e3781127 100644 --- a/gateway/enforcer/internal/config/config.go +++ b/gateway/enforcer/internal/config/config.go @@ -69,6 +69,8 @@ type Server struct { ExternalProcessingMaxHeaderLimit int `envconfig:"EXTERNAL_PROCESSING_MAX_HEADER_LIMIT" default:"8192"` Logger logging.Logger Metrics metrics + JWTGeneratorPublicKeyPath string `envconfig:"JWT_GENERATOR_PUBLIC_CERTIFICATE_PATH" default:"/home/wso2/security/keystore/mg.pem"` + JWTGeneratorPrivateKeyPath string `envconfig:"JWT_GENERATOR_PRIVATE_KEY_PATH" default:"/home/wso2/security/keystore/mg.key"` } type metrics struct { diff --git a/gateway/enforcer/internal/config/enforcer_config.go b/gateway/enforcer/internal/config/enforcer_config.go index 7bab257cd..cd0f3d646 100644 --- a/gateway/enforcer/internal/config/enforcer_config.go +++ b/gateway/enforcer/internal/config/enforcer_config.go @@ -14,16 +14,16 @@ * limitations under the License. * */ - + package config import ( - "github.com/wso2/apk/gateway/enforcer/internal/dto" config_from_adapter "github.com/wso2/apk/adapter/pkg/discovery/api/wso2/discovery/config/enforcer" + "github.com/wso2/apk/gateway/enforcer/internal/dto" ) // EnforcerConfig is a struct that holds the enforcer configuration. type EnforcerConfig struct { - JWTConfiguration *dto.JWTConfiguration - Analytics *config_from_adapter.Analytics -} \ No newline at end of file + JWTConfiguration *dto.BackendJWTConfiguration + Analytics *config_from_adapter.Analytics +} diff --git a/gateway/enforcer/internal/datastore/api_store.go b/gateway/enforcer/internal/datastore/api_store.go index a0e795168..2d5a40c68 100644 --- a/gateway/enforcer/internal/datastore/api_store.go +++ b/gateway/enforcer/internal/datastore/api_store.go @@ -18,10 +18,11 @@ package datastore import ( - "log" + "fmt" "sync" api "github.com/wso2/apk/adapter/pkg/discovery/api/wso2/discovery/api" + "github.com/wso2/apk/gateway/enforcer/internal/config" "github.com/wso2/apk/gateway/enforcer/internal/dto" "github.com/wso2/apk/gateway/enforcer/internal/requestconfig" "github.com/wso2/apk/gateway/enforcer/internal/util" @@ -32,13 +33,15 @@ type APIStore struct { apis map[string]*requestconfig.API mu sync.RWMutex configStore *ConfigStore + cfg *config.Server } // NewAPIStore creates a new instance of APIStore. -func NewAPIStore(configStore *ConfigStore) *APIStore { +func NewAPIStore(configStore *ConfigStore, cfg *config.Server) *APIStore { return &APIStore{ configStore: configStore, // apis: make(map[string]*api.Api, 0), + cfg: cfg, } } @@ -50,29 +53,29 @@ func (s *APIStore) AddAPIs(apis []*api.Api) { s.apis = make(map[string]*requestconfig.API, len(apis)) for _, api := range apis { customAPI := requestconfig.API{ - Name: api.Title, - Version: api.Version, - Vhost: api.Vhost, - BasePath: api.BasePath, - APIType: api.ApiType, - EnvType: api.EnvType, - APILifeCycleState: api.ApiLifeCycleState, - AuthorizationHeader: "", // You might want to set this field if applicable - OrganizationID: api.OrganizationId, - UUID: api.Id, - Tier: api.Tier, - DisableAuthentication: api.DisableAuthentications, - DisableScopes: api.DisableScopes, - Resources: make([]requestconfig.Resource, 0), - IsMockedAPI: false, // You can add logic to determine if the API is mocked - MutualSSL: api.MutualSSL, - TransportSecurity: api.TransportSecurity, - ApplicationSecurity: api.ApplicationSecurity, - JwtConfigurationDto: convertBackendJWTTokenInfoToJWTConfig(api.BackendJWTTokenInfo), - SystemAPI: api.SystemAPI, - APIDefinition: api.ApiDefinitionFile, - Environment: api.Environment, - SubscriptionValidation: api.SubscriptionValidation, + Name: api.Title, + Version: api.Version, + Vhost: api.Vhost, + BasePath: api.BasePath, + APIType: api.ApiType, + EnvType: api.EnvType, + APILifeCycleState: api.ApiLifeCycleState, + AuthorizationHeader: "", // You might want to set this field if applicable + OrganizationID: api.OrganizationId, + UUID: api.Id, + Tier: api.Tier, + DisableAuthentication: api.DisableAuthentications, + DisableScopes: api.DisableScopes, + Resources: make([]requestconfig.Resource, 0), + IsMockedAPI: false, // You can add logic to determine if the API is mocked + MutualSSL: api.MutualSSL, + TransportSecurity: api.TransportSecurity, + ApplicationSecurity: api.ApplicationSecurity, + BackendJwtConfiguration: convertBackendJWTTokenInfoToJWTConfig(api.BackendJWTTokenInfo, s.cfg, fmt.Sprintf("%s-%s", api.Title, api.Version)), + SystemAPI: api.SystemAPI, + APIDefinition: api.ApiDefinitionFile, + Environment: api.Environment, + SubscriptionValidation: api.SubscriptionValidation, // Endpoints: api.Endpoints, // EndpointSecurity: convertSecurityInfoToEndpointSecurity(api.EndpointSecurity), AiProvider: convertAIProviderToDTO(api.Aiprovider), @@ -98,7 +101,7 @@ func (s *APIStore) AddAPIs(apis []*api.Api) { customAPI.Resources = append(customAPI.Resources, resource) } } - log.Printf("Adding API: %+v", customAPI.JwtConfigurationDto) + s.cfg.Logger.Info(fmt.Sprintf("Adding API: %+v", customAPI.BackendJwtConfiguration)) s.apis[util.PrepareAPIKey(api.Vhost, api.BasePath, api.Version)] = &customAPI } } @@ -173,33 +176,42 @@ func (s *APIStore) GetMatchedAPI(apiKey string) *requestconfig.API { } // ConvertBackendJWTTokenInfoToJWTConfig converts BackendJWTTokenInfo to JWTConfiguration. -func convertBackendJWTTokenInfoToJWTConfig(info *api.BackendJWTTokenInfo) *dto.JWTConfiguration { +func convertBackendJWTTokenInfoToJWTConfig(info *api.BackendJWTTokenInfo, cfg *config.Server, apiName string) *dto.BackendJWTConfiguration { if info == nil { return nil } // Convert CustomClaims from map[string]*Claim to map[string]ClaimValue - customClaims := make(map[string]dto.ClaimValue) + customClaims := make(map[string]*dto.ClaimValue) for key, claim := range info.CustomClaims { if claim != nil { - customClaims[key] = dto.ClaimValue{ + customClaims[key] = &dto.ClaimValue{ Value: claim.Value, Type: claim.Type, } } } - - return &dto.JWTConfiguration{ - Enabled: info.Enabled, - JWTHeader: info.Header, - ConsumerDialectURI: "", // Add a default value or fetch if needed - SignatureAlgorithm: info.SigningAlgorithm, - Encoding: info.Encoding, - TokenIssuerDtoMap: make(map[string]dto.TokenIssuer), // Populate if required - JwtExcludedClaims: make(map[string]bool), // Populate if required - PublicCert: nil, // Add conversion logic if needed - PrivateKey: nil, // Add conversion logic if needed - TTL: int64(info.TokenTTL), // Convert int32 to int64 - CustomClaims: customClaims, + publicCert, err := util.LoadCertificate(cfg.JWTGeneratorPublicKeyPath) + if err != nil { + cfg.Logger.Error(err, fmt.Sprintf("Error loading public cert. Marking API %s as backend jwt disabled.", apiName)) + info.Enabled = false + } + privateKey, err := util.LoadPrivateKey(cfg.JWTGeneratorPrivateKeyPath) + if err != nil { + cfg.Logger.Error(err, fmt.Sprintf("Error loading private key. Marking API %s as backend jwt disabled. Path: %s", apiName, cfg.JWTGeneratorPrivateKeyPath)) + info.Enabled = false + } + return &dto.BackendJWTConfiguration{ + Enabled: info.Enabled, + JWTHeader: info.Header, + ConsumerDialectURI: "", // Add a default value or fetch if needed + SignatureAlgorithm: info.SigningAlgorithm, + Encoding: info.Encoding, + TokenIssuerDtoMap: make(map[string]dto.TokenIssuer), // Populate if required + JwtExcludedClaims: make(map[string]bool), // Populate if required + PublicCert: publicCert, // Add conversion logic if needed + PrivateKey: privateKey, // Add conversion logic if needed + TTL: int64(info.TokenTTL), // Convert int32 to int64 + CustomClaims: customClaims, } } diff --git a/gateway/enforcer/internal/dto/claim_value.go b/gateway/enforcer/internal/dto/claim_value.go index 0c4df85dd..57bd89655 100644 --- a/gateway/enforcer/internal/dto/claim_value.go +++ b/gateway/enforcer/internal/dto/claim_value.go @@ -19,6 +19,6 @@ package dto // ClaimValue represents the claim value type ClaimValue struct { - Value interface{} `json:"value"` // Value of the claim (can be any type) + Value string `json:"value"` // Value of the claim (can be any type) Type string `json:"type"` // Type of the claim } \ No newline at end of file diff --git a/gateway/enforcer/internal/dto/jwt_configuration.go b/gateway/enforcer/internal/dto/jwt_configuration.go index 5c4bee4b8..4c4d68354 100644 --- a/gateway/enforcer/internal/dto/jwt_configuration.go +++ b/gateway/enforcer/internal/dto/jwt_configuration.go @@ -18,22 +18,22 @@ package dto import ( - "crypto/ecdsa" + "crypto/rsa" "crypto/x509" ) -// JWTConfiguration represents the JWT configuration -type JWTConfiguration struct { - Enabled bool `json:"enabled"` // Whether JWT is enabled - JWTHeader string `json:"jwtHeader"` // JWT header name - ConsumerDialectURI string `json:"consumerDialectUri"` // URI for the consumer dialect - SignatureAlgorithm string `json:"signatureAlgorithm"` // Algorithm for signature - Encoding string `json:"encoding"` // Encoding type - TokenIssuerDtoMap map[string]TokenIssuer `json:"tokenIssuerDtoMap"` // Map of token issuers - JwtExcludedClaims map[string]bool `json:"jwtExcludedClaims"` // Excluded claims in JWT - PublicCert *x509.Certificate `json:"publicCert"` // Public certificate - PrivateKey *ecdsa.PrivateKey `json:"privateKey"` // Private key for signing JWT - TTL int64 `json:"ttl"` // Time to live for the JWT - CustomClaims map[string]ClaimValue `json:"customClaims"` // Custom claims - UseKid bool `json:"useKid"` // Whether to use kid +// BackendJWTConfiguration represents the JWT configuration +type BackendJWTConfiguration struct { + Enabled bool `json:"enabled"` // Whether JWT is enabled + JWTHeader string `json:"jwtHeader"` // JWT header name + ConsumerDialectURI string `json:"consumerDialectUri"` // URI for the consumer dialect + SignatureAlgorithm string `json:"signatureAlgorithm"` // Algorithm for signature + Encoding string `json:"encoding"` // Encoding type + TokenIssuerDtoMap map[string]TokenIssuer `json:"tokenIssuerDtoMap"` // Map of token issuers + JwtExcludedClaims map[string]bool `json:"jwtExcludedClaims"` // Excluded claims in JWT + PublicCert *x509.Certificate `json:"publicCert"` // Public certificate + PrivateKey *rsa.PrivateKey `json:"privateKey"` // Private key for signing JWT + TTL int64 `json:"ttl"` // Time to live for the JWT + CustomClaims map[string]*ClaimValue `json:"customClaims"` // Custom claims + UseKid bool `json:"useKid"` // Whether to use kid } diff --git a/gateway/enforcer/internal/extproc/ext_proc.go b/gateway/enforcer/internal/extproc/ext_proc.go index 43b62a40d..250c8f06b 100644 --- a/gateway/enforcer/internal/extproc/ext_proc.go +++ b/gateway/enforcer/internal/extproc/ext_proc.go @@ -31,6 +31,7 @@ import ( "github.com/wso2/apk/gateway/enforcer/internal/config" "github.com/wso2/apk/gateway/enforcer/internal/datastore" "github.com/wso2/apk/gateway/enforcer/internal/dto" + "github.com/wso2/apk/gateway/enforcer/internal/jwtbackend" "github.com/wso2/apk/gateway/enforcer/internal/logging" "github.com/wso2/apk/gateway/enforcer/internal/ratelimit" "github.com/wso2/apk/gateway/enforcer/internal/requestconfig" @@ -189,6 +190,7 @@ func (s *ExternalProcessingServer) Process(srv envoy_service_proc_v3.ExternalPro } s.requestConfigHolder.ExternalProcessingEnvoyMetadata = metadata s.requestConfigHolder.MatchedResource = httpHandler.GetMatchedResource(s.requestConfigHolder.MatchedAPI, *s.requestConfigHolder.ExternalProcessingEnvoyAttributes) + s.log.Info(fmt.Sprintf("Matched api bjc: %v", s.requestConfigHolder.MatchedAPI.BackendJwtConfiguration)) s.log.Info(fmt.Sprintf("Matched Resource: %v", s.requestConfigHolder.MatchedResource)) s.log.Info(fmt.Sprintf("req holder: %+v\n s: %+v", &s.requestConfigHolder, &s)) if !s.requestConfigHolder.MatchedResource.AuthenticationConfig.Disabled && !s.requestConfigHolder.MatchedAPI.DisableAuthentication { @@ -216,6 +218,11 @@ func (s *ExternalProcessingServer) Process(srv envoy_service_proc_v3.ExternalPro dynamicMetadataKeyValuePairs[orgAndRLPolicyMetadataKey] = fmt.Sprintf("%s-%s", s.requestConfigHolder.MatchedAPI.OrganizationID, s.requestConfigHolder.MatchedSubscription.RatelimitTier) } } + backendJWT := "" + if s.requestConfigHolder.MatchedAPI.BackendJwtConfiguration != nil && s.requestConfigHolder.MatchedAPI.BackendJwtConfiguration.Enabled { + backendJWT = jwtbackend.CreateBackendJWT(s.requestConfigHolder, s.cfg) + s.log.Sugar().Infof("generated backendJWT==%v", backendJWT) + } rhq := &envoy_service_proc_v3.HeadersResponse{ Response: &envoy_service_proc_v3.CommonResponse{ HeaderMutation: &envoy_service_proc_v3.HeaderMutation{ @@ -232,6 +239,17 @@ func (s *ExternalProcessingServer) Process(srv envoy_service_proc_v3.ExternalPro ClearRouteCache: true, }, } + if backendJWT != "" { + rhq.Response.HeaderMutation.SetHeaders = append(rhq.Response.HeaderMutation.SetHeaders, &corev3.HeaderValueOption{ + Header: &corev3.HeaderValue{ + Key: s.requestConfigHolder.MatchedAPI.BackendJwtConfiguration.JWTHeader, + RawValue: []byte(attributes.ClusterName), + }, + + }) + s.cfg.Logger.Info(fmt.Sprintf("Added backend JWT to the header: %s, header name: %s", backendJWT, s.requestConfigHolder.MatchedAPI.BackendJwtConfiguration.JWTHeader)) + } + resp.Response = &envoy_service_proc_v3.ProcessingResponse_RequestHeaders{ RequestHeaders: rhq, } @@ -796,7 +814,7 @@ func buildDynamicMetadata(keyValuePairs *map[string]string) (*structpb.Struct, e } func (s *ExternalProcessingServer) prepareMetadataKeyValuePairAndAddTo(metadataKeyValuePair map[string]string) *map[string]string { - if s.requestConfigHolder.MatchedAPI != nil { + if s.requestConfigHolder != nil && s.requestConfigHolder.MatchedAPI != nil { metadataKeyValuePair[analytics.APIIDKey] = s.requestConfigHolder.MatchedAPI.UUID metadataKeyValuePair[analytics.APIContextKey] = s.requestConfigHolder.MatchedAPI.BasePath metadataKeyValuePair[organizationMetadataKey] = s.requestConfigHolder.MatchedAPI.OrganizationID diff --git a/gateway/enforcer/internal/jwtbackend/jwtGenerator.go b/gateway/enforcer/internal/jwtbackend/jwtGenerator.go deleted file mode 100644 index 284c481fc..000000000 --- a/gateway/enforcer/internal/jwtbackend/jwtGenerator.go +++ /dev/null @@ -1,239 +0,0 @@ -package jwtbackend - -import ( - "crypto/sha1" - "crypto/x509" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "strings" - "sync" - "time" - - dto "github.com/wso2/apk/gateway/enforcer/internal/dto" -) - -const ( - none = "NONE" - sha256withRSA = "SHA256withRSA" -) - -// AbstractAPIMgtGatewayJWTGenerator is an interface for generating JWT tokens. -type AbstractAPIMgtGatewayJWTGenerator struct { - JwtConfigurationDto *dto.JWTConfiguration - DialectURI string - SignatureAlgorithm string - mutex sync.Mutex -} - -// SetJWTConfigurationDto sets the JWT configuration DTO. -func (g *AbstractAPIMgtGatewayJWTGenerator) SetJWTConfigurationDto(jwtConfigurationDto *dto.JWTConfiguration) { - g.JwtConfigurationDto = jwtConfigurationDto - g.DialectURI = jwtConfigurationDto.ConsumerDialectURI - if g.DialectURI == "" { - g.DialectURI = "http://wso2.org/claims" - } - g.SignatureAlgorithm = jwtConfigurationDto.SignatureAlgorithm - if g.SignatureAlgorithm != none && g.SignatureAlgorithm != sha256withRSA { - g.SignatureAlgorithm = sha256withRSA - } -} - -// GetJWTConfigurationDto gets the JWT configuration DTO. -func (g *AbstractAPIMgtGatewayJWTGenerator) GetJWTConfigurationDto() *dto.JWTConfiguration { - return g.JwtConfigurationDto -} - -// GenerateToken generates a JWT token. -func (g *AbstractAPIMgtGatewayJWTGenerator) GenerateToken(jwtInfoDto *JWTInfoDto, signatureAlgorithm string, signJWT func(string) ([]byte, error)) (string, error) { - jwtHeader, err := g.BuildHeader(nil, signatureAlgorithm) - if err != nil { - return "", fmt.Errorf("error building JWT header: %w", err) - } - - jwtBody, err := g.BuildBody(jwtInfoDto) - if err != nil { - return "", fmt.Errorf("error building JWT body: %w", err) - } - - base64UrlEncodedHeader := Encode([]byte(jwtHeader)) - base64UrlEncodedBody := Encode([]byte(jwtBody)) - - if signatureAlgorithm == "SHA256withRSA" { - assertion := base64UrlEncodedHeader + "." + base64UrlEncodedBody - - signedAssertion, err := signJWT(assertion) - if err != nil { - return "", fmt.Errorf("error signing JWT: %w", err) - } - - base64UrlEncodedAssertion := Encode(signedAssertion) - return base64UrlEncodedHeader + "." + base64UrlEncodedBody + "." + base64UrlEncodedAssertion, nil - } - - return base64UrlEncodedHeader + "." + base64UrlEncodedBody + ".", nil -} - -func (g *AbstractAPIMgtGatewayJWTGenerator) populateStandardClaims(jwtInfoDto *JWTInfoDto) map[string]dto.ClaimValue { - claims := make(map[string]dto.ClaimValue) - for key, value := range jwtInfoDto.Claims { - claims[key] = dto.ClaimValue{ - Value: value.Value, - Type: value.Type, // Ensure `Type` is also assigned if it exists in `ClaimValueDTO`. - } - } - return claims -} - -func (g *AbstractAPIMgtGatewayJWTGenerator) populateCustomClaims(jwtInfoDto *JWTInfoDto) map[string]dto.ClaimValue { - claims := make(map[string]dto.ClaimValue) - for key, value := range jwtInfoDto.Claims { - claims[key] = dto.ClaimValue{ - Value: value.Value, - Type: value.Type, // Ensure `Type` is also assigned if it exists in `ClaimValueDTO`. - } - } - return claims -} - -// Hexify converts a byte slice to a hex string. -func Hexify(bytes []byte) string { - hexDigits := []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'} - var builder strings.Builder - builder.Grow(len(bytes) * 2) - - for _, b := range bytes { - builder.WriteRune(hexDigits[(b&0xf0)>>4]) - builder.WriteRune(hexDigits[b&0x0f]) - } - - return builder.String() -} - -// GenerateThumbprint generates a thumbprint of the public certificate. -func GenerateThumbprint(hashType string, publicCert *x509.Certificate, usePadding bool) (string, error) { - if publicCert == nil { - return "", errors.New("public certificate is nil") - } - - hash := sha1.New() - if hashType != "SHA-1" { - return "", errors.New("unsupported hash type") - } - - hash.Write(publicCert.Raw) - digestInBytes := hash.Sum(nil) - publicCertThumbprint := Hexify(digestInBytes) - - var base64UrlEncodedThumbPrint string - if usePadding { - base64UrlEncodedThumbPrint = base64.URLEncoding.EncodeToString([]byte(publicCertThumbprint)) - } else { - base64UrlEncodedThumbPrint = base64.RawURLEncoding.EncodeToString([]byte(publicCertThumbprint)) - } - - return base64UrlEncodedThumbPrint, nil -} - -// GenerateHeader generates the JWT header. -func GenerateHeader(jwtConfigurationDto *dto.JWTConfiguration, signatureAlgorithm string) (string, error) { - if signatureAlgorithm == "NONE" { - return `{"typ":"JWT","alg":"NONE"}`, nil - } - - header := fmt.Sprintf(`{"typ":"JWT","alg":"RS256"`) - - if jwtConfigurationDto.UseKid { - header += fmt.Sprintf(`,"kid":"%v"`, jwtConfigurationDto.UseKid) - } else { - thumbprint, err := GenerateThumbprint("SHA-1", jwtConfigurationDto.PublicCert, true) - if err != nil { - return "", fmt.Errorf("error in generating public certificate thumbprint: %w", err) - } - header += fmt.Sprintf(`,"x5t":"%s"`, thumbprint) - } - - header += "}" - return header, nil -} - -// AddCertToHeader adds the certificate to the JWT header. -func AddCertToHeader(jwtConfigurationDto *dto.JWTConfiguration, signatureAlgorithm string) (string, error) { - header, err := GenerateHeader(jwtConfigurationDto, signatureAlgorithm) - if err != nil { - return "", fmt.Errorf("error in obtaining keystore: %w", err) - } - return header, nil -} - -// BuildHeader builds the JWT header. -func (g *AbstractAPIMgtGatewayJWTGenerator) BuildHeader(jwtConfigurationDto *dto.JWTConfiguration, signatureAlgorithm string) (string, error) { - var jwtHeader string - if signatureAlgorithm == "NONE" { - jwtHeader = `{"typ":"JWT","alg":"NONE"}` - } else if signatureAlgorithm == "SHA256withRSA" { - header, err := AddCertToHeader(jwtConfigurationDto, signatureAlgorithm) - if err != nil { - return "", err - } - jwtHeader = header - } - return jwtHeader, nil -} - -// BuildBody builds the JWT body. -func (g *AbstractAPIMgtGatewayJWTGenerator) BuildBody(jwtInfoDto *JWTInfoDto) (string, error) { - claims := make(map[string]interface{}) - - // Populate standard claims - for key, value := range g.populateStandardClaims(jwtInfoDto) { - claims[key] = value - } - - // Populate custom claims - for key, claim := range g.populateCustomClaims(jwtInfoDto) { - var finalValue interface{} = claim.Value - if strVal, ok := claim.Value.(string); ok { - switch strings.ToLower(claim.Type) { - case "bool": - finalValue = strVal == "true" - case "int": - finalValue = parseToInt(strVal) - case "long": - finalValue = parseToInt(strVal) - case "float": - finalValue = parseToFloat(strVal) - case "date": - parsedDate, err := time.Parse("2006-01-02", strVal) - if err == nil { - finalValue = parsedDate - } - } - } - claims[key] = finalValue - } - - // Convert claims to JSON - jsonClaims, err := json.Marshal(claims) - if err != nil { - return "", fmt.Errorf("error marshaling claims to JSON: %w", err) - } - - return string(jsonClaims), nil -} - -func parseToInt(value string) int64 { - parsed, _ := time.ParseDuration(value + "s") - return int64(parsed.Seconds()) -} - -func parseToFloat(value string) float64 { - parsed, _ := time.ParseDuration(value + "s") - return float64(parsed.Seconds()) -} - -// Encode encodes a byte slice to a base64 URL encoded string. -func Encode(stringToBeEncoded []byte) string { - return base64.RawURLEncoding.EncodeToString(stringToBeEncoded) -} diff --git a/gateway/enforcer/internal/jwtbackend/jwtInfoDTO.go b/gateway/enforcer/internal/jwtbackend/jwtInfoDTO.go deleted file mode 100644 index 444070a13..000000000 --- a/gateway/enforcer/internal/jwtbackend/jwtInfoDTO.go +++ /dev/null @@ -1,71 +0,0 @@ -package jwtbackend - -import "github.com/wso2/apk/gateway/enforcer/internal/dto" - -// JWTInfoDto holds information related to JWT tokens. -type JWTInfoDto struct { - ApplicationTier string `json:"applicationTier"` - KeyType string `json:"keyType"` - Version string `json:"version"` - ApplicationName string `json:"applicationName"` - EndUser string `json:"endUser"` - EndUserTenantID int `json:"endUserTenantId"` - ApplicationUUID string `json:"applicationUUId"` - Subscriber string `json:"subscriber"` - SubscriptionTier string `json:"subscriptionTier"` - ApplicationID string `json:"applicationId"` - APIContext string `json:"apiContext"` - APIName string `json:"apiName"` - JwtValidationInfo *JWTValidationInfo `json:"jwtValidationInfo"` - AppAttributes map[string]string `json:"appAttributes"` - Sub string `json:"sub"` - Organizations []string `json:"organizations"` - Claims map[string]*dto.ClaimValue `json:"claims"` -} - -// NewJWTInfoDto creates a new JWTInfoDto instance. -func NewJWTInfoDto() *JWTInfoDto { - return &JWTInfoDto{ - AppAttributes: make(map[string]string), - Claims: make(map[string]*dto.ClaimValue), - Organizations: make([]string, 0), - } -} - -// Clone creates a deep copy of the JWTInfoDto. -func (j *JWTInfoDto) Clone() *JWTInfoDto { - clone := *j - clone.AppAttributes = CloneStringMap(j.AppAttributes) - clone.Claims = CloneClaimsMap(j.Claims) - clone.Organizations = copyStringSlice(j.Organizations) - if j.JwtValidationInfo != nil { - clone.JwtValidationInfo = j.JwtValidationInfo.Clone() - } - return &clone -} - -// CloneStringMap creates a copy of a string map. -func CloneStringMap(original map[string]string) map[string]string { - clone := make(map[string]string) - for k, v := range original { - clone[k] = v - } - return clone -} - -// CloneClaimsMap creates a deep copy of a map with ClaimValueDTO pointers. -func CloneClaimsMap(original map[string]*dto.ClaimValue) map[string]*dto.ClaimValue { - clonedMap := make(map[string]*dto.ClaimValue) - for k, v := range original { - if v != nil { - valueCopy := *v // Create a copy of the struct value - clonedMap[k] = &valueCopy - } - } - return clonedMap -} - -// Helper function to copy a string slice. -func copyStringSlice(original []string) []string { - return append([]string{}, original...) -} diff --git a/gateway/enforcer/internal/jwtbackend/jwtValidationInfo.go b/gateway/enforcer/internal/jwtbackend/jwtValidationInfo.go deleted file mode 100644 index b6a99515c..000000000 --- a/gateway/enforcer/internal/jwtbackend/jwtValidationInfo.go +++ /dev/null @@ -1,57 +0,0 @@ -package jwtbackend - -import ( - "github.com/golang-jwt/jwt/v4" -) - -// JWTValidationInfo holds JWT validation related information. -type JWTValidationInfo struct { - User string `json:"user"` - ExpiryTime int64 `json:"expiryTime"` - ConsumerKey string `json:"consumerKey"` - Valid bool `json:"valid"` - Scopes []string `json:"scopes"` - Claims map[string]interface{} `json:"claims"` - ValidationCode int `json:"validationCode"` - KeyManager string `json:"keyManager"` - Identifier string `json:"identifier"` - JWTClaimsSet *jwt.MapClaims `json:"jwtClaimsSet"` - Token string `json:"token"` - Audience []string `json:"audience"` -} - -// NewJWTValidationInfo creates a new instance of JWTValidationInfo. -func NewJWTValidationInfo() *JWTValidationInfo { - return &JWTValidationInfo{ - Scopes: make([]string, 0), - Claims: make(map[string]interface{}), - Audience: make([]string, 0), - } -} - -// Clone creates a copy of an existing JWTValidationInfo. -func (j *JWTValidationInfo) Clone() *JWTValidationInfo { - return &JWTValidationInfo{ - User: j.User, - ExpiryTime: j.ExpiryTime, - ConsumerKey: j.ConsumerKey, - Valid: j.Valid, - Scopes: append([]string{}, j.Scopes...), - Claims: CloneMap(j.Claims), - ValidationCode: j.ValidationCode, - KeyManager: j.KeyManager, - Identifier: j.Identifier, - JWTClaimsSet: j.JWTClaimsSet, - Token: j.Token, - Audience: append([]string{}, j.Audience...), - } -} - -// CloneMap creates a shallow copy of a map[string]interface{}. -func CloneMap(original map[string]interface{}) map[string]interface{} { - clonedMap := make(map[string]interface{}) - for k, v := range original { - clonedMap[k] = v - } - return clonedMap -} diff --git a/gateway/enforcer/internal/jwtbackend/jwt_generator.go b/gateway/enforcer/internal/jwtbackend/jwt_generator.go new file mode 100644 index 000000000..aee07cc10 --- /dev/null +++ b/gateway/enforcer/internal/jwtbackend/jwt_generator.go @@ -0,0 +1,93 @@ +package jwtbackend + +import ( + "fmt" + "time" + + "github.com/wso2/apk/gateway/enforcer/internal/config" + "github.com/wso2/apk/gateway/enforcer/internal/dto" + "github.com/wso2/apk/gateway/enforcer/internal/requestconfig" + "github.com/wso2/apk/gateway/enforcer/internal/util" +) + +const ( + apiGatewayID = "wso2.org/products/am" + dialectURI = "http://wso2.org/claims/" + sha256WithRSA = "SHA256withRSA" +) + +// CreateBackendJWT creates a JWT token for the backend. +func CreateBackendJWT(rch *requestconfig.Holder, cfg *config.Server) string { + api := rch.MatchedAPI + application := rch.MatchedApplication + subscription := rch.MatchedSubscription + + if api != nil && api.BackendJwtConfiguration != nil && api.BackendJwtConfiguration.Enabled { + bjc := api.BackendJwtConfiguration + customClaims := bjc.CustomClaims + if customClaims == nil { + customClaims = make(map[string]*dto.ClaimValue) + } + customClaims["iss"] = &dto.ClaimValue{ + Value: apiGatewayID, + Type: "string", + } + currentTime := time.Now().Unix() + expireIn := currentTime + bjc.TTL + customClaims["exp"] = &dto.ClaimValue{ + Value: fmt.Sprintf("%d", expireIn), + Type: "int", + } + customClaims["iat"] = &dto.ClaimValue{ + Value: fmt.Sprintf("%d", currentTime), + Type: "int", + } + customClaims[dialectURI+"apiname"] = &dto.ClaimValue{ + Value: api.Name, + Type: "string", + } + customClaims[dialectURI+"apicontext"] = &dto.ClaimValue{ + Value: api.BasePath, + Type: "string", + } + customClaims[dialectURI+"version"] = &dto.ClaimValue{ + Value: api.Version, + Type: "string", + } + customClaims[dialectURI+"keytype"] = &dto.ClaimValue{ + Value: api.EnvType, + Type: "string", + } + if application != nil { + customClaims[dialectURI+"subscriber"] = &dto.ClaimValue{ + Value: application.Owner, + Type: "string", + } + customClaims[dialectURI+"applicationid"] = &dto.ClaimValue{ + Value: application.UUID, + Type: "string", + } + customClaims[dialectURI+"applicationname"] = &dto.ClaimValue{ + Value: application.Name, + Type: "string", + } + customClaims[dialectURI+"applicationtier"] = &dto.ClaimValue{ + Value: subscription.RatelimitTier, + Type: "string", + } + } + if subscription != nil { + customClaims[dialectURI+"tier"] = &dto.ClaimValue{ + Value: subscription.RatelimitTier, + Type: "string", + } + } + signatureAlgorithm := bjc.SignatureAlgorithm + if signatureAlgorithm != "NONE" && signatureAlgorithm != sha256WithRSA { + signatureAlgorithm = sha256WithRSA + } + + return util.GenerateJWTToken(signatureAlgorithm, true, bjc.PublicCert, customClaims, bjc.PrivateKey) + } + return "" +} diff --git a/gateway/enforcer/internal/requestconfig/api.go b/gateway/enforcer/internal/requestconfig/api.go index 7ba5a163d..b198440ff 100644 --- a/gateway/enforcer/internal/requestconfig/api.go +++ b/gateway/enforcer/internal/requestconfig/api.go @@ -23,33 +23,33 @@ import ( // API is a struct that represents an API type API struct { - Name string `json:"name"` // Name of the API - Version string `json:"version"` // API version - Vhost string `json:"vhost"` // Virtual host for the API - BasePath string `json:"basePath"` // Base path for the API - APIType string `json:"apiType"` // Type of the API - EnvType string `json:"envType"` // Environment type (e.g., production, sandbox) - APILifeCycleState string `json:"apiLifeCycleState"` // Lifecycle state of the API - AuthorizationHeader string `json:"authorizationHeader"` // Authorization header used by the API - OrganizationID string `json:"organizationId"` // Organization ID for the API - UUID string `json:"uuid"` // Unique identifier for the API - Tier string `json:"tier"` // API tier (e.g., Unlimited) - DisableAuthentication bool `json:"disableAuthentication"` // Whether authentication is disabled - DisableScopes bool `json:"disableScopes"` // Whether scopes are disabled - Resources []Resource `json:"resources"` // List of resources for the API - IsMockedAPI bool `json:"isMockedApi"` // Whether the API is mocked - MutualSSL string `json:"mutualSSL"` // Mutual SSL configuration - TransportSecurity bool `json:"transportSecurity"` // Whether transport security is enabled - ApplicationSecurity map[string]bool `json:"applicationSecurity"` // Application security settings - JwtConfigurationDto *dto.JWTConfiguration `json:"jwtConfigurationDto"` // JWT configuration DTO - SystemAPI bool `json:"systemAPI"` // Whether the API is a system API - APIDefinition []byte `json:"apiDefinition"` // API definition (e.g., Swagger) - Environment string `json:"environment"` // API environment (e.g., development, production) - SubscriptionValidation bool `json:"subscriptionValidation"` // Whether subscription validation is enabled - EndpointSecurity []EndpointSecurity `json:"endpointSecurity"` // Endpoint security configurations - Endpoints EndpointCluster `json:"endpoints"` // Endpoint cluster for the API - AiProvider *dto.AIProvider `json:"aiProvider"` // AI provider configuration - AIModelBasedRoundRobin *dto.AIModelBasedRoundRobin `json:"aiModelBasedRoundRobin"` // AI model-based round robin configuration - DoSubscriptionAIRLInHeaderReponse bool `json:"doSubscriptionAIRLInHeaderReponse"` // Whether to include subscription AIRL in header response - DoSubscriptionAIRLInBodyReponse bool `json:"doSubscriptionAIRLInBodyReponse"` // Whether to include subscription AIRL in body response + Name string `json:"name"` // Name of the API + Version string `json:"version"` // API version + Vhost string `json:"vhost"` // Virtual host for the API + BasePath string `json:"basePath"` // Base path for the API + APIType string `json:"apiType"` // Type of the API + EnvType string `json:"envType"` // Environment type (e.g., production, sandbox) + APILifeCycleState string `json:"apiLifeCycleState"` // Lifecycle state of the API + AuthorizationHeader string `json:"authorizationHeader"` // Authorization header used by the API + OrganizationID string `json:"organizationId"` // Organization ID for the API + UUID string `json:"uuid"` // Unique identifier for the API + Tier string `json:"tier"` // API tier (e.g., Unlimited) + DisableAuthentication bool `json:"disableAuthentication"` // Whether authentication is disabled + DisableScopes bool `json:"disableScopes"` // Whether scopes are disabled + Resources []Resource `json:"resources"` // List of resources for the API + IsMockedAPI bool `json:"isMockedApi"` // Whether the API is mocked + MutualSSL string `json:"mutualSSL"` // Mutual SSL configuration + TransportSecurity bool `json:"transportSecurity"` // Whether transport security is enabled + ApplicationSecurity map[string]bool `json:"applicationSecurity"` // Application security settings + BackendJwtConfiguration *dto.BackendJWTConfiguration `json:"jwtConfigurationDto"` // JWT configuration DTO + SystemAPI bool `json:"systemAPI"` // Whether the API is a system API + APIDefinition []byte `json:"apiDefinition"` // API definition (e.g., Swagger) + Environment string `json:"environment"` // API environment (e.g., development, production) + SubscriptionValidation bool `json:"subscriptionValidation"` // Whether subscription validation is enabled + EndpointSecurity []EndpointSecurity `json:"endpointSecurity"` // Endpoint security configurations + Endpoints EndpointCluster `json:"endpoints"` // Endpoint cluster for the API + AiProvider *dto.AIProvider `json:"aiProvider"` // AI provider configuration + AIModelBasedRoundRobin *dto.AIModelBasedRoundRobin `json:"aiModelBasedRoundRobin"` // AI model-based round robin configuration + DoSubscriptionAIRLInHeaderReponse bool `json:"doSubscriptionAIRLInHeaderReponse"` // Whether to include subscription AIRL in header response + DoSubscriptionAIRLInBodyReponse bool `json:"doSubscriptionAIRLInBodyReponse"` // Whether to include subscription AIRL in body response } diff --git a/gateway/enforcer/internal/util/cert.go b/gateway/enforcer/internal/util/cert.go index 47d3f3a67..396dfdd8a 100644 --- a/gateway/enforcer/internal/util/cert.go +++ b/gateway/enforcer/internal/util/cert.go @@ -20,6 +20,7 @@ package util import ( "crypto/tls" "crypto/x509" + "encoding/pem" "fmt" "io/fs" "io/ioutil" @@ -94,3 +95,23 @@ func CreateTLSConfig(cert tls.Certificate, certPool *x509.CertPool) *tls.Config RootCAs: certPool, } } + +// LoadCertificate loads an x509 certificate from a file path +func LoadCertificate(path string) (*x509.Certificate, error) { + data, err := ioutil.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("failed to read certificate file: %w", err) + } + + block, _ := pem.Decode(data) + if block == nil { + return nil, fmt.Errorf("failed to decode PEM block") + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse certificate: %w", err) + } + + return cert, nil +} diff --git a/gateway/enforcer/internal/util/conversion.go b/gateway/enforcer/internal/util/conversion.go index a134d4835..ee45021ed 100644 --- a/gateway/enforcer/internal/util/conversion.go +++ b/gateway/enforcer/internal/util/conversion.go @@ -5,7 +5,6 @@ import ( "strconv" ) - // ConvertBytesToInt converts a []byte to an int. // It assumes the []byte contains a valid numeric string (e.g., "123"). func ConvertBytesToInt(data []byte) (int, error) { @@ -29,4 +28,4 @@ func ConvertStringToInt(input string) (int, error) { return 0, fmt.Errorf("invalid input: %s, error: %w", input, err) } return num, nil -} +} \ No newline at end of file diff --git a/gateway/enforcer/internal/util/jwt.go b/gateway/enforcer/internal/util/jwt.go new file mode 100644 index 000000000..d77d6a900 --- /dev/null +++ b/gateway/enforcer/internal/util/jwt.go @@ -0,0 +1,242 @@ +package util + +import ( + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/sha1" + "crypto/sha256" + "crypto/x509" + "encoding/base64" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "os" + "strconv" + "strings" + "time" + + "github.com/wso2/apk/gateway/enforcer/internal/dto" +) + +// GenerateHeader generates the JWT header. +func GenerateHeader(useKid bool, pubCert *x509.Certificate, signatureAlgorithm string) (string, error) { + if signatureAlgorithm == "NONE" { + return `{"typ":"JWT","alg":"NONE"}`, nil + } + + header := fmt.Sprintf(`{"typ":"JWT","alg":"RS256"`) + + if useKid { + header += fmt.Sprintf(`,"kid":"%v"`, useKid) + } else { + thumbprint, err := GenerateThumbprint("SHA-1", pubCert, true) + if err != nil { + return "", fmt.Errorf("error in generating public certificate thumbprint: %w", err) + } + header += fmt.Sprintf(`,"x5t":"%s"`, thumbprint) + } + + header += "}" + return header, nil +} + +// GenerateThumbprint generates a thumbprint of the public certificate. +func GenerateThumbprint(hashType string, publicCert *x509.Certificate, usePadding bool) (string, error) { + if publicCert == nil { + return "", errors.New("public certificate is nil") + } + + hash := sha1.New() + if hashType != "SHA-1" { + return "", errors.New("unsupported hash type") + } + + hash.Write(publicCert.Raw) + digestInBytes := hash.Sum(nil) + publicCertThumbprint := Hexify(digestInBytes) + + var base64UrlEncodedThumbPrint string + if usePadding { + base64UrlEncodedThumbPrint = base64.URLEncoding.EncodeToString([]byte(publicCertThumbprint)) + } else { + base64UrlEncodedThumbPrint = base64.RawURLEncoding.EncodeToString([]byte(publicCertThumbprint)) + } + + return base64UrlEncodedThumbPrint, nil +} + +// Hexify converts a byte slice to a hex string. +func Hexify(bytes []byte) string { + hexDigits := []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'} + var builder strings.Builder + builder.Grow(len(bytes) * 2) + + for _, b := range bytes { + builder.WriteRune(hexDigits[(b&0xf0)>>4]) + builder.WriteRune(hexDigits[b&0x0f]) + } + + return builder.String() +} + +// BuildBody builds the JWT body. +func BuildBody(claims map[string]*dto.ClaimValue) (string, error) { + customClaims := make(map[string]interface{}) + + // Populate custom claims + for key, claim := range claims { + var finalValue interface{} = claim.Value + switch strings.ToLower(claim.Type) { + case "string": + finalValue = claim.Value + case "bool": + finalValue = claim.Value == "true" + case "int": + finalValue = parseToInt(claim.Value) + case "long": + finalValue = parseToInt(claim.Value) + case "float": + finalValue = parseToFloat(claim.Value) + case "date": + parsedDate, err := time.Parse("2006-01-02", claim.Value) + if err == nil { + finalValue = parsedDate + } + } + customClaims[key] = finalValue + } + + // Convert claims to JSON + jsonClaims, err := json.Marshal(customClaims) + if err != nil { + return "", fmt.Errorf("error marshaling claims to JSON: %w", err) + } + + return string(jsonClaims), nil +} + +func parseToInt(value string) int64 { + parsed, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return 0 + } + return parsed +} + +func parseToFloat(value string) float64 { + parsed, err := strconv.ParseFloat(value, 64) + if err != nil { + return 0 + } + return parsed +} + +// BuildHeader builds the JWT header. +func BuildHeader(publicCert *x509.Certificate, useKid bool, signatureAlgorithm string) (string, error) { + if signatureAlgorithm == "NONE" { + return `{"typ":"JWT","alg":"NONE"}`, nil + } + + header := `{"typ":"JWT","alg":"RS256"` + + if useKid { + header += `,"kid":"true"` + } else { + thumbprint, err := GenerateThumbprint("SHA-1", publicCert, true) + if err != nil { + return "", fmt.Errorf("error generating public certificate thumbprint: %w", err) + } + header += fmt.Sprintf(`,"x5t":"%s"`, thumbprint) + } + + header += "}" + return header, nil +} + +// GenerateJWTToken generates a JWT token. +func GenerateJWTToken(signatureAlgo string, useKid bool, publicCert *x509.Certificate, claims map[string]*dto.ClaimValue, privateKey *rsa.PrivateKey) string { + header, err := BuildHeader(publicCert, true, signatureAlgo) + if err != nil { + return "" + } + body, err := BuildBody(claims) + if err != nil { + return "" + } + // Base64 encode header and body + + if signatureAlgo != "SHA256withRSA" { + base64Header := base64.RawURLEncoding.EncodeToString([]byte(header)) + base64Body := base64.RawURLEncoding.EncodeToString([]byte(body)) + + // Concatenate header and body with a period + unsignedToken := fmt.Sprintf("%s.%s", base64Header, base64Body) + return fmt.Sprintf("%s.", unsignedToken) + } + // Sign the token + jwtToken, err := signJWT(header, body, privateKey) + if err != nil { + return "" + } + + return jwtToken + // Generate JWT token +} + +// LoadPrivateKey Read Private Key from a PEM file +func LoadPrivateKey(filename string) (*rsa.PrivateKey, error) { + keyBytes, err := os.ReadFile(filename) + if err != nil { + return nil, err + } + + block, _ := pem.Decode(keyBytes) + if block == nil { + return nil, fmt.Errorf("failed to decode PEM block") + } + + // Try parsing as PKCS#1 + if privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil { + return privateKey, nil + } + + // Try parsing as PKCS#8 + key, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse private key: %v", err) + } + + // Ensure it's an RSA key + rsaKey, ok := key.(*rsa.PrivateKey) + if !ok { + return nil, fmt.Errorf("private key is not an RSA key") + } + + return rsaKey, nil +} + +// Sign JWT using SHA256withRSA +func signJWT(header, payload string, privateKey *rsa.PrivateKey) (string, error) { + // Create signing string + signingInput := base64URLEncode([]byte(header)) + "." + base64URLEncode([]byte(payload)) + + // Hash the signing input + hashed := sha256.Sum256([]byte(signingInput)) + + // Sign with RSA private key + signature, err := rsa.SignPKCS1v15(rand.Reader, privateKey, crypto.SHA256, hashed[:]) + if err != nil { + return "", fmt.Errorf("failed to sign JWT: %v", err) + } + + // Encode the signature in Base64URL + signedJWT := signingInput + "." + base64URLEncode(signature) + return signedJWT, nil +} + +// base64URLEncode Base64 URL Encoding (JWT-safe) +func base64URLEncode(input []byte) string { + return base64.RawURLEncoding.EncodeToString(input) +} diff --git a/gateway/enforcer/internal/xds/client_manager.go b/gateway/enforcer/internal/xds/client_manager.go index dfc0c30b1..392eada12 100644 --- a/gateway/enforcer/internal/xds/client_manager.go +++ b/gateway/enforcer/internal/xds/client_manager.go @@ -48,7 +48,7 @@ func CreateXDSClients(cfg *config.Server) (*datastore.APIStore, *datastore.Confi tlsConfig := util.CreateTLSConfig(clientCert, certPool) configDatastore := datastore.NewConfigStore() jwtIssuerDatastore := datastore.NewJWTIssuerStore() - apiDatastore := datastore.NewAPIStore(configDatastore) + apiDatastore := datastore.NewAPIStore(configDatastore, cfg) // Initialize the tracker modelBasedRoundRobinTracker := datastore.NewModelBasedRoundRobinTracker() // Start the reactivation task