package jwt

import (
	"context"
	"crypto/rsa"
	"io/ioutil"
	"net/http"
	"strings"
	"time"

	"github.com/dgrijalva/jwt-go"
	"github.com/gogf/gf/v2/crypto/gmd5"
	"github.com/gogf/gf/v2/frame/g"
	"github.com/gogf/gf/v2/net/ghttp"
	"github.com/gogf/gf/v2/os/gcache"
)

// MapClaims is a set of JWT claims keyed by claim name, backed by
// map[string]interface{} as produced by JSON decoding.
// It is the type returned by GetClaimsFromJWT and ExtractClaims, and the type
// a PayloadFunc callback has to return.
type MapClaims map[string]interface{}

// GfJWTMiddleware provides a JSON Web Token authentication implementation for
// ghttp handlers. When a check fails, the Unauthorized callback is invoked with
// an HTTP status code and a message. On success, the token claims are stored as
// the request parameter PayloadKey and, when the identity is not nil, the
// identity is stored under IdentityKey.
// Clients obtain a token through LoginHandler and then send it back as
// configured by TokenLookup, by default in the Authorization header.
// Example: Authorization:Bearer XXX_TOKEN_XXX
type GfJWTMiddleware struct {
	// Realm is appended to "JWT realm=" in the WWW-Authenticate header value.
	// Optional, MiddlewareInit defaults it to "gf jwt".
	Realm string

	// SigningAlgorithm names the signing algorithm.
	// RS256, RS384 and RS512 sign with the RSA key files below; any other value
	// (for example HS256, HS384, HS512) signs with Key.
	// Optional, default is HS256.
	SigningAlgorithm string

	// Key is the secret key used for signing.
	// Required unless one of the RS* algorithms is selected.
	Key []byte

	// Timeout is the duration that a jwt token is valid. Optional, defaults to one hour.
	Timeout time.Duration

	// MaxRefresh allows clients to refresh their token until MaxRefresh has passed
	// since the token was issued ("iat").
	// Note that clients can refresh their token in the last moment of MaxRefresh.
	// This means that the maximum validity timespan for a token is Timeout + MaxRefresh.
	// Optional, defaults to 0 meaning not refreshable.
	MaxRefresh time.Duration

	// Authenticator is the callback that performs the authentication of the user
	// based on the login request. The data it returns is handed to PayloadFunc to
	// build the token claims. Required by LoginHandler.
	// A non-nil error rejects the login; its message is produced by HTTPStatusMessageFunc.
	Authenticator func(r *ghttp.Request) (interface{}, error)

	// Authorizator is the callback that performs the authorization of the
	// authenticated user. It is called only after the token checks succeeded and
	// receives the value returned by IdentityHandler.
	// Must return true on success, false on failure.
	// Optional, default to success.
	Authorizator func(data interface{}, r *ghttp.Request) bool

	// PayloadFunc is called when a token is created (LoginHandler, TokenGenerator).
	// Every entry of the returned map is copied into the token claims; the claims
	// are available during requests via ExtractClaims.
	// LoginHandler rejects the login when the claims lack the IdentityKey entry.
	// Note that the payload is not encrypted.
	// "exp" and "iat" are overwritten after the payload is copied, so they cannot
	// be supplied here.
	// Optional, by default no additional data will be set.
	PayloadFunc func(data interface{}) MapClaims

	// Unauthorized is called with the HTTP status code and message whenever a
	// request is rejected. Optional, the default writes both as JSON.
	Unauthorized func(*ghttp.Request, int, string)

	// LoginResponse is called after a successful login with the status code, the
	// signed token and its expiry time. Optional, the default writes them as JSON.
	LoginResponse func(*ghttp.Request, int, string, time.Time)

	// RefreshResponse is called after a successful refresh with the status code,
	// the new signed token and its expiry time. Optional, the default writes them as JSON.
	RefreshResponse func(*ghttp.Request, int, string, time.Time)

	// LogoutResponse is called after a successful logout with the status code.
	// Optional, the default writes a JSON success message.
	LogoutResponse func(*ghttp.Request, int)

	// IdentityHandler resolves the identity of the current request.
	// Optional, the default returns the IdentityKey entry of the request's claims.
	IdentityHandler func(*ghttp.Request) interface{}

	// IdentityKey is the claim name (and request parameter name) of the identity.
	// Optional, defaults to the package-level IdentityKey.
	IdentityKey string

	// TokenLookup is a comma-separated list of "<source>:<name>" entries that are
	// tried in order to extract the token from the request.
	// Optional. Default value "header:Authorization".
	// Possible values:
	// - "header:<name>"
	// - "query:<name>"
	// - "cookie:<name>"
	// - "param:<name>"
	TokenLookup string

	// TokenHeadName is the scheme word that precedes the token in the header.
	// Default value is "Bearer"
	TokenHeadName string

	// TimeFunc provides the current time. You can override it to use another time value.
	// This is useful for testing. Optional, defaults to time.Now.
	TimeFunc func() time.Time

	// HTTPStatusMessageFunc converts an error raised by the middleware into the
	// message passed to Unauthorized.
	// Optional, the default returns e.Error().
	HTTPStatusMessageFunc func(e error, r *ghttp.Request) string

	// PrivKeyFile is the path of the PEM encoded RSA private key file, read when
	// an RS* algorithm is selected.
	PrivKeyFile string

	// PubKeyFile is the path of the PEM encoded RSA public key file, read when
	// an RS* algorithm is selected.
	PubKeyFile string

	// privKey is the RSA private key parsed from PrivKeyFile; used for signing.
	privKey *rsa.PrivateKey

	// pubKey is the RSA public key parsed from PubKeyFile; used for verification.
	pubKey *rsa.PublicKey

	// SendCookie makes LoginHandler and RefreshToken additionally return the
	// token as a cookie named CookieName.
	SendCookie bool

	// SecureCookie is meant to control the cookie's secure flag.
	// NOTE(review): not referenced by the code visible in this file - confirm it has an effect.
	SecureCookie bool

	// CookieHTTPOnly is meant to control the cookie's HTTP-only flag.
	// NOTE(review): not referenced by the code visible in this file - confirm it has an effect.
	CookieHTTPOnly bool

	// CookieDomain is the domain passed when the token cookie is set.
	CookieDomain string

	// SendAuthorization makes GetClaimsFromJWT set an Authorization header built
	// from TokenHeadName and the token for every authenticated request.
	SendAuthorization bool

	// DisabledAbort, when true, keeps the request running after Unauthorized was
	// called instead of exiting it.
	DisabledAbort bool

	// CookieName is the name of the token cookie. Optional, defaults to "jwt".
	CookieName string

	// CacheAdapter is an optional cache adapter that MiddlewareInit can install
	// on the package-level token blacklist.
	CacheAdapter gcache.Adapter
}

var (
	// TokenKey is the request parameter name under which ParseToken stores the
	// raw token string; GetToken reads it back.
	TokenKey = "JWT_TOKEN"
	// PayloadKey is the request parameter name under which the middleware stores
	// the token claims; ExtractClaims reads them back.
	PayloadKey = "JWT_PAYLOAD"
	// IdentityKey is the default identity key, used when
	// GfJWTMiddleware.IdentityKey is left empty.
	IdentityKey = "identity"
	// The blacklist stores tokens that have not expired but have been deactivated.
	// It is a package-level cache, so it is shared by every middleware value.
	blacklist = gcache.New()
)

// New initializes m through MiddlewareInit and hands it back.
// It returns a nil middleware together with the error when initialization fails.
func New(m *GfJWTMiddleware) (*GfJWTMiddleware, error) {
	err := m.MiddlewareInit()
	if err == nil {
		return m, nil
	}

	return nil, err
}

// readKeys loads the RSA private key first and then the public key, stopping
// at the first failure.
func (mw *GfJWTMiddleware) readKeys() error {
	if err := mw.privateKey(); err != nil {
		return err
	}
	if err := mw.publicKey(); err != nil {
		return err
	}
	return nil
}

// privateKey reads PrivKeyFile, parses it as a PEM encoded RSA private key and
// stores the result in mw.privKey.
// It returns ErrNoPrivKeyFile when the file cannot be read and
// ErrInvalidPrivKey when its content cannot be parsed.
func (mw *GfJWTMiddleware) privateKey() error {
	pemBytes, readErr := ioutil.ReadFile(mw.PrivKeyFile)
	if readErr != nil {
		return ErrNoPrivKeyFile
	}

	parsed, parseErr := jwt.ParseRSAPrivateKeyFromPEM(pemBytes)
	if parseErr != nil {
		return ErrInvalidPrivKey
	}

	mw.privKey = parsed
	return nil
}

// publicKey reads PubKeyFile, parses it as a PEM encoded RSA public key and
// stores the result in mw.pubKey.
// It returns ErrNoPubKeyFile when the file cannot be read and
// ErrInvalidPubKey when its content cannot be parsed.
func (mw *GfJWTMiddleware) publicKey() error {
	pemBytes, readErr := ioutil.ReadFile(mw.PubKeyFile)
	if readErr != nil {
		return ErrNoPubKeyFile
	}

	parsed, parseErr := jwt.ParseRSAPublicKeyFromPEM(pemBytes)
	if parseErr != nil {
		return ErrInvalidPubKey
	}

	mw.pubKey = parsed
	return nil
}

// usingPublicKeyAlgo reports whether the configured signing algorithm is one
// of the RSA based ones (RS256, RS384, RS512).
func (mw *GfJWTMiddleware) usingPublicKeyAlgo() bool {
	alg := mw.SigningAlgorithm
	return alg == "RS256" || alg == "RS512" || alg == "RS384"
}

// MiddlewareInit fills in the default for every optional setting that was left
// empty and validates the signing configuration.
//
// It returns ErrInvalidSigningAlgorithm when SigningAlgorithm is not known to
// the jwt library, the error of the key loader when an RS* algorithm is
// selected and the key files cannot be loaded, and ErrMissingSecretKey when a
// shared-secret algorithm is selected without a Key.
// When CacheAdapter is set it is installed on the package-level blacklist,
// regardless of the signing algorithm, once the configuration is valid.
func (mw *GfJWTMiddleware) MiddlewareInit() error {
	if mw.TokenLookup == "" {
		mw.TokenLookup = "header:Authorization"
	}

	if mw.SigningAlgorithm == "" {
		mw.SigningAlgorithm = "HS256"
	}

	if mw.Timeout == 0 {
		mw.Timeout = time.Hour
	}

	if mw.TimeFunc == nil {
		mw.TimeFunc = time.Now
	}

	mw.TokenHeadName = strings.TrimSpace(mw.TokenHeadName)
	if len(mw.TokenHeadName) == 0 {
		mw.TokenHeadName = "Bearer"
	}

	if mw.Authorizator == nil {
		mw.Authorizator = func(data interface{}, r *ghttp.Request) bool {
			return true
		}
	}

	if mw.Unauthorized == nil {
		mw.Unauthorized = func(r *ghttp.Request, code int, message string) {
			r.Response.WriteJson(g.Map{
				"code":    code,
				"message": message,
			})
		}
	}

	if mw.LoginResponse == nil {
		mw.LoginResponse = func(r *ghttp.Request, code int, token string, expire time.Time) {
			r.Response.WriteJson(g.Map{
				"code":   http.StatusOK,
				"token":  token,
				"expire": expire.Format(time.RFC3339),
			})
		}
	}

	if mw.RefreshResponse == nil {
		mw.RefreshResponse = func(r *ghttp.Request, code int, token string, expire time.Time) {
			r.Response.WriteJson(g.Map{
				"code":   http.StatusOK,
				"token":  token,
				"expire": expire.Format(time.RFC3339),
			})
		}
	}

	if mw.LogoutResponse == nil {
		mw.LogoutResponse = func(r *ghttp.Request, code int) {
			r.Response.WriteJson(g.Map{
				"code":    http.StatusOK,
				"message": "success",
			})
		}
	}

	if mw.IdentityKey == "" {
		mw.IdentityKey = IdentityKey
	}

	if mw.IdentityHandler == nil {
		mw.IdentityHandler = func(r *ghttp.Request) interface{} {
			claims := ExtractClaims(r)
			return claims[mw.IdentityKey]
		}
	}

	if mw.HTTPStatusMessageFunc == nil {
		mw.HTTPStatusMessageFunc = func(e error, r *ghttp.Request) string {
			return e.Error()
		}
	}

	if mw.Realm == "" {
		mw.Realm = "gf jwt"
	}

	if mw.CookieName == "" {
		mw.CookieName = "jwt"
	}

	// Reject an unknown algorithm name up front: the jwt library has no signing
	// method for it, so token creation and verification could never succeed.
	if jwt.GetSigningMethod(mw.SigningAlgorithm) == nil {
		return ErrInvalidSigningAlgorithm
	}

	if mw.usingPublicKeyAlgo() {
		if err := mw.readKeys(); err != nil {
			return err
		}
	} else if len(mw.Key) == 0 {
		// An empty (but non-nil) secret is as unusable as a missing one.
		return ErrMissingSecretKey
	}

	// Install the adapter only after validation succeeded, and for every
	// algorithm family, so RSA setups get the custom blacklist storage too.
	if mw.CacheAdapter != nil {
		blacklist.SetAdapter(mw.CacheAdapter)
	}

	return nil
}

// MiddlewareFunc exposes the JWT check as a ghttp.HandlerFunc bound to mw.
func (mw *GfJWTMiddleware) MiddlewareFunc() ghttp.HandlerFunc {
	return mw.middlewareImpl
}

// middlewareImpl authenticates and authorizes one request.
//
// It rejects the request through mw.unauthorized when the token cannot be
// parsed (401), when "exp" is missing or not a number (400), when the token is
// expired according to mw.TimeFunc (401), when it is blacklisted (401), or
// when Authorizator refuses the identity (403).
// On success the claims are stored under PayloadKey and a non-nil identity
// under mw.IdentityKey.
func (mw *GfJWTMiddleware) middlewareImpl(r *ghttp.Request) {
	// reject reports the failure with the message derived from cause.
	reject := func(code int, cause error) {
		mw.unauthorized(r, code, mw.HTTPStatusMessageFunc(cause, r))
	}

	claims, token, err := mw.GetClaimsFromJWT(r)
	if err != nil {
		reject(http.StatusUnauthorized, err)
		return
	}

	rawExp := claims["exp"]
	if rawExp == nil {
		reject(http.StatusBadRequest, ErrMissingExpField)
		return
	}

	exp, isNumber := rawExp.(float64)
	if !isNumber {
		reject(http.StatusBadRequest, ErrWrongFormatOfExp)
		return
	}

	if int64(exp) < mw.TimeFunc().Unix() {
		reject(http.StatusUnauthorized, ErrExpiredToken)
		return
	}

	revoked, err := mw.inBlacklist(token)
	switch {
	case err != nil:
		reject(http.StatusUnauthorized, err)
		return
	case revoked:
		reject(http.StatusUnauthorized, ErrInvalidToken)
		return
	}

	// The claims have to be stored before IdentityHandler runs, because the
	// default handler reads them back from the request.
	r.SetParam(PayloadKey, claims)

	identity := mw.IdentityHandler(r)
	if identity != nil {
		r.SetParam(mw.IdentityKey, identity)
	}

	if !mw.Authorizator(identity, r) {
		reject(http.StatusForbidden, ErrForbidden)
		return
	}

	// TODO: the next handler in the chain is not invoked from here.
}

// GetClaimsFromJWT parses the token of the request and returns a copy of its
// claims together with the raw token string.
//
// When SendAuthorization is enabled, the token is echoed back to the client in
// the Authorization response header, prefixed with TokenHeadName.
// It returns the error of ParseToken when the token is missing or invalid and
// ErrInvalidToken when the claims are not of the expected map type.
func (mw *GfJWTMiddleware) GetClaimsFromJWT(r *ghttp.Request) (MapClaims, string, error) {
	token, err := mw.ParseToken(r)
	if err != nil {
		return nil, "", err
	}

	// Use the raw string of the token that was just verified instead of reading
	// it back from the request parameters, and write it to the response header:
	// a header set on the request would never reach the client.
	if mw.SendAuthorization && len(token.Raw) > 0 {
		r.Response.Header().Set("Authorization", mw.TokenHeadName+" "+token.Raw)
	}

	parsed, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		return nil, "", ErrInvalidToken
	}

	claims := make(MapClaims, len(parsed))
	for key, value := range parsed {
		claims[key] = value
	}

	return claims, token.Raw, nil
}

// LoginHandler can be used by clients to get a jwt token.
//
// The request is authenticated by the Authenticator callback, which defines
// the expected payload. The data it returns is turned into claims by
// PayloadFunc; the claims must contain the IdentityKey entry.
// On success the signed token and its expiry time are passed to LoginResponse
// and, when SendCookie is enabled, the token is also set as a cookie that
// lives as long as the token.
// Failures are reported through the Unauthorized callback: 500 when
// Authenticator is not configured or the identity claim is missing, 401 when
// authentication or signing fails.
func (mw *GfJWTMiddleware) LoginHandler(r *ghttp.Request) {
	if mw.Authenticator == nil {
		mw.unauthorized(r, http.StatusInternalServerError, mw.HTTPStatusMessageFunc(ErrMissingAuthenticatorFunc, r))
		return
	}

	data, err := mw.Authenticator(r)
	if err != nil {
		mw.unauthorized(r, http.StatusUnauthorized, mw.HTTPStatusMessageFunc(err, r))
		return
	}

	// Create the token
	token := jwt.New(jwt.GetSigningMethod(mw.SigningAlgorithm))
	claims := token.Claims.(jwt.MapClaims)

	if mw.PayloadFunc != nil {
		for key, value := range mw.PayloadFunc(data) {
			claims[key] = value
		}
	}

	if _, ok := claims[mw.IdentityKey]; !ok {
		mw.unauthorized(r, http.StatusInternalServerError, mw.HTTPStatusMessageFunc(ErrMissingIdentity, r))
		return
	}

	// Read the clock once so that "iat", "exp" and the cookie lifetime all
	// refer to the same instant.
	now := mw.TimeFunc()
	expire := now.Add(mw.Timeout)
	claims["exp"] = expire.Unix()
	claims["iat"] = now.Unix()

	tokenString, err := mw.signedString(token)
	if err != nil {
		mw.unauthorized(r, http.StatusUnauthorized, mw.HTTPStatusMessageFunc(ErrFailedTokenCreation, r))
		return
	}

	// set cookie
	if mw.SendCookie {
		// Derive the lifetime from the middleware clock rather than time.Now,
		// otherwise an overridden TimeFunc yields a wrong max-age.
		maxAge := time.Duration(expire.Unix()-now.Unix()) * time.Second
		r.Cookie.SetCookie(mw.CookieName, tokenString, mw.CookieDomain, "/", maxAge)
	}

	mw.LoginResponse(r, http.StatusOK, tokenString, expire)
}

// signedString serializes and signs token, using the RSA private key for the
// RS* algorithms and the shared secret Key for every other algorithm.
func (mw *GfJWTMiddleware) signedString(token *jwt.Token) (string, error) {
	if mw.usingPublicKeyAlgo() {
		return token.SignedString(mw.privKey)
	}
	return token.SignedString(mw.Key)
}

// LogoutHandler deactivates the token of the request by putting it on the
// blacklist. The token has to pass CheckIfTokenExpire first.
// Any failure is reported as 401 through the Unauthorized callback; on success
// LogoutResponse is called with 200.
func (mw *GfJWTMiddleware) LogoutHandler(r *ghttp.Request) {
	claims, token, err := mw.CheckIfTokenExpire(r)
	if err == nil {
		err = mw.setBlacklist(token, claims)
	}

	if err != nil {
		mw.unauthorized(r, http.StatusUnauthorized, mw.HTTPStatusMessageFunc(err, r))
		return
	}

	mw.LogoutResponse(r, http.StatusOK)
}

// RefreshHandler exchanges the token of the request for a fresh one via
// RefreshToken.
// On success the new token and its expiry time are passed to RefreshResponse
// with 200; any failure is reported as 401 through the Unauthorized callback.
func (mw *GfJWTMiddleware) RefreshHandler(r *ghttp.Request) {
	tokenString, expire, err := mw.RefreshToken(r)
	if err == nil {
		mw.RefreshResponse(r, http.StatusOK, tokenString, expire)
		return
	}

	mw.unauthorized(r, http.StatusUnauthorized, mw.HTTPStatusMessageFunc(err, r))
}

// RefreshToken issues a new token carrying the claims of the request's current
// token, with fresh "exp" and "iat" values, and blacklists the old token.
//
// It returns the new signed token and its expiry time. On failure it returns
// an empty token, the current time and the error: the token did not pass
// CheckIfTokenExpire, signing failed, or the old token could not be
// blacklisted.
// When SendCookie is enabled the new token is also set as a cookie, but only
// after the old token has been blacklisted successfully.
func (mw *GfJWTMiddleware) RefreshToken(r *ghttp.Request) (string, time.Time, error) {
	claims, token, err := mw.CheckIfTokenExpire(r)
	if err != nil {
		return "", time.Now(), err
	}

	// Create the token
	newToken := jwt.New(jwt.GetSigningMethod(mw.SigningAlgorithm))
	newClaims := newToken.Claims.(jwt.MapClaims)

	for key, value := range claims {
		newClaims[key] = value
	}

	// Read the clock once so that "iat", "exp" and the cookie lifetime all
	// refer to the same instant.
	now := mw.TimeFunc()
	expire := now.Add(mw.Timeout)
	newClaims["exp"] = expire.Unix()
	newClaims["iat"] = now.Unix()

	tokenString, err := mw.signedString(newToken)
	if err != nil {
		return "", time.Now(), err
	}

	// Blacklist the old token before anything is sent to the client, so a
	// failure here cannot leave a new cookie behind while an error is reported.
	if err = mw.setBlacklist(token, claims); err != nil {
		return "", time.Now(), err
	}

	// set cookie
	if mw.SendCookie {
		maxAge := time.Duration(expire.Unix()-now.Unix()) * time.Second
		r.Cookie.SetCookie(mw.CookieName, tokenString, mw.CookieDomain, "/", maxAge)
	}

	return tokenString, expire, nil
}

// CheckIfTokenExpire checks whether the token of the request may still be
// refreshed or logged out and returns its claims and raw string.
//
// A token whose only validation problem is that it has expired is accepted,
// because it can still be refreshed within MaxRefresh. The call fails with
//   - the error of ParseToken for any other parsing or validation problem,
//   - ErrInvalidToken when the token is blacklisted or has no usable numeric
//     "iat" claim,
//   - ErrExpiredToken when "iat" lies further back than MaxRefresh.
func (mw *GfJWTMiddleware) CheckIfTokenExpire(r *ghttp.Request) (jwt.MapClaims, string, error) {
	token, err := mw.ParseToken(r)

	if err != nil {
		// If we receive an error, and the error is anything other than a single
		// ValidationErrorExpired, we want to return the error.
		// If the error is just ValidationErrorExpired, we want to continue, as we can still
		// refresh the token if it's within the MaxRefresh time.
		// (see https://github.com/appleboy/gin-jwt/issues/176)
		validationErr, ok := err.(*jwt.ValidationError)
		if !ok || validationErr.Errors != jwt.ValidationErrorExpired {
			return nil, "", err
		}
	}

	// Defensive: never dereference a missing token.
	if token == nil {
		return nil, "", ErrInvalidToken
	}

	in, err := mw.inBlacklist(token.Raw)
	if err != nil {
		return nil, "", err
	}

	if in {
		return nil, "", ErrInvalidToken
	}

	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		return nil, "", ErrInvalidToken
	}

	// A correctly signed token does not necessarily carry a numeric "iat";
	// reject it instead of panicking on the type assertion.
	iat, ok := claims["iat"].(float64)
	if !ok {
		return nil, "", ErrInvalidToken
	}

	if int64(iat) < mw.TimeFunc().Add(-mw.MaxRefresh).Unix() {
		return nil, "", ErrExpiredToken
	}

	return claims, token.Raw, nil
}

// TokenGenerator builds and signs a token for data without going through an
// HTTP request.
// The claims are taken from PayloadFunc (when set) and completed with "exp"
// and "iat". It returns the signed token and its expiry time (in UTC), or an
// empty token, the zero time and the error when signing fails.
func (mw *GfJWTMiddleware) TokenGenerator(data interface{}) (string, time.Time, error) {
	token := jwt.New(jwt.GetSigningMethod(mw.SigningAlgorithm))
	claims := token.Claims.(jwt.MapClaims)

	if buildPayload := mw.PayloadFunc; buildPayload != nil {
		for name, value := range buildPayload(data) {
			claims[name] = value
		}
	}

	expire := mw.TimeFunc().UTC().Add(mw.Timeout)
	claims["exp"] = expire.Unix()
	claims["iat"] = mw.TimeFunc().Unix()

	signed, err := mw.signedString(token)
	if err == nil {
		return signed, expire, nil
	}

	return "", time.Time{}, err
}

// jwtFromHeader extracts the token from the request header named key, which
// has to look like "<TokenHeadName> <token>".
// The scheme word is matched case-insensitively, so "bearer" and "Bearer" are
// both accepted.
// It returns ErrEmptyAuthHeader when the header is absent or empty and
// ErrInvalidAuthHeader when it does not have the expected form.
func (mw *GfJWTMiddleware) jwtFromHeader(r *ghttp.Request, key string) (string, error) {
	authHeader := r.Header.Get(key)
	if authHeader == "" {
		return "", ErrEmptyAuthHeader
	}

	parts := strings.SplitN(authHeader, " ", 2)
	if len(parts) != 2 || !strings.EqualFold(parts[0], mw.TokenHeadName) {
		return "", ErrInvalidAuthHeader
	}

	return parts[1], nil
}

// jwtFromQuery returns the request value named key as the token, or
// ErrEmptyQueryToken when that value is empty.
func (mw *GfJWTMiddleware) jwtFromQuery(r *ghttp.Request, key string) (string, error) {
	if token := r.Get(key).String(); token != "" {
		return token, nil
	}

	return "", ErrEmptyQueryToken
}

// jwtFromCookie returns the value of the cookie named key as the token, or
// ErrEmptyCookieToken when that value is empty.
func (mw *GfJWTMiddleware) jwtFromCookie(r *ghttp.Request, key string) (string, error) {
	if value := r.Cookie.Get(key).String(); value != "" {
		return value, nil
	}

	return "", ErrEmptyCookieToken
}

// jwtFromParam returns the request value named key as the token, or
// ErrEmptyParamToken when that value is empty.
func (mw *GfJWTMiddleware) jwtFromParam(r *ghttp.Request, key string) (string, error) {
	if token := r.Get(key).String(); token != "" {
		return token, nil
	}

	return "", ErrEmptyParamToken
}

// ParseToken extracts the token from the request and parses it.
//
// The entries of TokenLookup are tried in order until one of them yields a
// non-empty token; entries that are not of the form "<source>:<name>" are
// skipped. When no token was found, the error of the last lookup is returned.
// Otherwise the token is parsed with the configured key: tokens signed with an
// algorithm other than SigningAlgorithm are refused with
// ErrInvalidSigningAlgorithm. The raw token string is stored under the request
// parameter TokenKey for both shared-secret and RSA algorithms.
func (mw *GfJWTMiddleware) ParseToken(r *ghttp.Request) (*jwt.Token, error) {
	var token string
	var err error

	methods := strings.Split(mw.TokenLookup, ",")
	for _, method := range methods {
		if len(token) > 0 {
			break
		}
		parts := strings.Split(strings.TrimSpace(method), ":")
		if len(parts) < 2 {
			// Malformed entry without a ":" separator: skip it rather than
			// panicking on the missing name part.
			continue
		}
		k := strings.TrimSpace(parts[0])
		v := strings.TrimSpace(parts[1])
		switch k {
		case "header":
			token, err = mw.jwtFromHeader(r, v)
		case "query":
			token, err = mw.jwtFromQuery(r, v)
		case "cookie":
			token, err = mw.jwtFromCookie(r, v)
		case "param":
			token, err = mw.jwtFromParam(r, v)
		}
	}

	if err != nil {
		return nil, err
	}

	return jwt.Parse(token, func(t *jwt.Token) (interface{}, error) {
		if jwt.GetSigningMethod(mw.SigningAlgorithm) != t.Method {
			return nil, ErrInvalidSigningAlgorithm
		}

		// Remember the token string once the algorithm matched, independent of
		// the key type, so GetToken works for RSA signed tokens as well.
		r.SetParam(TokenKey, token)

		if mw.usingPublicKeyAlgo() {
			return mw.pubKey, nil
		}

		return mw.Key, nil
	})
}

// unauthorized rejects the request: it announces the realm in the
// WWW-Authenticate response header, hands code and message to the Unauthorized
// callback and, unless DisabledAbort is set, exits the request.
func (mw *GfJWTMiddleware) unauthorized(r *ghttp.Request, code int, message string) {
	// WWW-Authenticate is a response header; setting it on the request headers
	// would never reach the client.
	r.Response.Header().Set("WWW-Authenticate", "JWT realm="+mw.Realm)
	mw.Unauthorized(r, code, message)
	if !mw.DisabledAbort {
		r.ExitAll()
	}
}

// setBlacklist records token as deactivated in the package-level blacklist.
//
// The entry is kept until the token could no longer be used in any way, that
// is until exp + MaxRefresh, but for at least one second.
// It returns ErrMissingExpField or ErrWrongFormatOfExp when claims carry no
// numeric "exp" value, and any error of the hashing or the cache.
func (mw *GfJWTMiddleware) setBlacklist(token string, claims jwt.MapClaims) error {
	// The goal of MD5 is to reduce the key length.
	key, err := gmd5.EncryptString(token)
	if err != nil {
		return err
	}

	rawExp := claims["exp"]
	if rawExp == nil {
		return ErrMissingExpField
	}

	exp, ok := rawExp.(float64)
	if !ok {
		return ErrWrongFormatOfExp
	}

	// save duration time = (exp + max_refresh) - now
	duration := time.Unix(int64(exp), 0).Add(mw.MaxRefresh).Sub(mw.TimeFunc()).Truncate(time.Second)

	// The cache treats a zero duration as "never expires" and a negative one as
	// a deletion. Clamp to one second so the deactivation is always recorded
	// and never kept forever.
	if duration < time.Second {
		duration = time.Second
	}

	// global gcache
	return blacklist.Set(context.Background(), key, true, duration)
}

// inBlacklist reports whether token has been deactivated.
// Errors of the hashing or of the cache lookup are returned to the caller
// instead of being treated as "not blacklisted", so a failing cache cannot
// silently re-enable deactivated tokens.
func (mw *GfJWTMiddleware) inBlacklist(token string) (bool, error) {
	// The goal of MD5 is to reduce the key length.
	key, err := gmd5.EncryptString(token)
	if err != nil {
		return false, err
	}

	// Global gcache
	in, err := blacklist.Contains(context.Background(), key)
	if err != nil {
		return false, err
	}

	return in, nil
}

// ExtractClaims returns the JWT claims that were stored on the request under
// PayloadKey.
func ExtractClaims(r *ghttp.Request) MapClaims {
	return r.GetParam(PayloadKey).Map()
}

// GetToken returns the JWT token string found on the request under TokenKey,
// or an empty string when there is none.
func GetToken(r *ghttp.Request) string {
	return r.Get(TokenKey).String()
}