diff --git a/backend/application/auth/refresh/request.go b/backend/application/auth/refresh/request.go new file mode 100644 index 00000000..1ba77367 --- /dev/null +++ b/backend/application/auth/refresh/request.go @@ -0,0 +1,11 @@ +package refresh + +type validationErrors map[string]string + +type Request struct { + Token string `json:"token"` +} + +func (r *Request) Validate() (bool, validationErrors) { + return true, nil +} diff --git a/backend/application/auth/refresh/response.go b/backend/application/auth/refresh/response.go new file mode 100644 index 00000000..f1f70a45 --- /dev/null +++ b/backend/application/auth/refresh/response.go @@ -0,0 +1,7 @@ +package refresh + +type RefreshResponse struct { + ValidationErrors validationErrors `json:"errors,omitempty"` + AccessToken string `json:"access_token,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` +} diff --git a/backend/application/auth/refresh/usecase.go b/backend/application/auth/refresh/usecase.go new file mode 100644 index 00000000..86e8c5f8 --- /dev/null +++ b/backend/application/auth/refresh/usecase.go @@ -0,0 +1,97 @@ +package refresh + +import ( + "time" + + "github.com/khanzadimahdi/testproject/application/auth" + "github.com/khanzadimahdi/testproject/domain/user" + "github.com/khanzadimahdi/testproject/infrastructure/jwt" +) + +type UseCase struct { + userRepository user.Repository + JWT *jwt.JWT +} + +func NewUseCase(userRepository user.Repository, JWT *jwt.JWT) *UseCase { + return &UseCase{ + userRepository: userRepository, + JWT: JWT, + } +} + +func (uc *UseCase) Login(request Request) (*RefreshResponse, error) { + if ok, validation := request.Validate(); !ok { + return &RefreshResponse{ + ValidationErrors: validation, + }, nil + } + + claims, err := uc.JWT.Verify(request.Token) + if err != nil { + return &RefreshResponse{ + ValidationErrors: validationErrors{ + "token": err.Error(), + }, + }, nil + } + + if audiences, err := claims.GetAudience(); err != nil || len(audiences) == 0 || audiences[0] != auth.RefreshToken { + return &RefreshResponse{ + ValidationErrors: validationErrors{ + "token": "refresh token is not valid", + }, + }, nil + } + + userUUID, err := claims.GetSubject() + if err != nil { + return &RefreshResponse{ + ValidationErrors: validationErrors{ + "token": err.Error(), + }, + }, nil + } + + u, err := uc.userRepository.GetOne(userUUID) + if err != nil { + return nil, err + } + + accessToken, err := uc.generateAccessToken(u) + if err != nil { + return nil, err + } + + refreshToken, err := uc.generateRefreshToken(u) + if err != nil { + return nil, err + } + + return &RefreshResponse{ + AccessToken: accessToken, + RefreshToken: refreshToken, + }, nil +} + +func (uc *UseCase) generateAccessToken(u user.User) (string, error) { + b := jwt.NewClaimsBuilder() + b.SetSubject(u.UUID) + b.SetNotBefore(time.Now()) + b.SetExpirationTime(time.Now().Add(15 * time.Minute)) + b.SetIssuedAt(time.Now()) + b.SetAudience([]string{auth.AccessToken}) + + return uc.JWT.Generate(b.Build()) +} + +func (uc *UseCase) generateRefreshToken(u user.User) (string, error) { + b := jwt.NewClaimsBuilder() + b.SetSubject(u.UUID) + b.SetNotBefore(time.Now()) + b.SetExpirationTime(time.Now().Add(2 * 24 * time.Hour)) + b.SetIssuedAt(time.Now()) + b.SetAudience([]string{auth.RefreshToken}) + + return uc.JWT.Generate(b.Build()) +} diff --git a/backend/application/auth/refresh/usecase_test.go b/backend/application/auth/refresh/usecase_test.go new file mode 100644 index 00000000..d8c53677 --- /dev/null +++ b/backend/application/auth/refresh/usecase_test.go @@ -0,0 +1 @@ +package refresh diff --git a/backend/main.go b/backend/main.go index dafb4911..d5f55891 100644 --- a/backend/main.go +++ b/backend/main.go @@ -17,6 +17,7 @@ import ( getArticles "github.com/khanzadimahdi/testproject/application/article/getArticles" "github.com/khanzadimahdi/testproject/application/article/getArticlesByHashtag" "github.com/khanzadimahdi/testproject/application/auth/login" + "github.com/khanzadimahdi/testproject/application/auth/refresh" dashboardCreateArticle "github.com/khanzadimahdi/testproject/application/dashboard/article/createArticle" dashboardDeleteArticle "github.com/khanzadimahdi/testproject/application/dashboard/article/deleteArticle" dashboardGetArticle "github.com/khanzadimahdi/testproject/application/dashboard/article/getArticle" @@ -118,6 +119,7 @@ func httpHandler() http.Handler { router := httprouter.New() log.SetFlags(log.LstdFlags | log.Llongfile) loginUseCase := login.NewUseCase(userRepository, j) + refreshUseCase := refresh.NewUseCase(userRepository, j) getArticleUsecase := getArticle.NewUseCase(articlesRepository, elementsRepository) getArticlesUsecase := getArticles.NewUseCase(articlesRepository) getArticlesByHashtagUseCase := getArticlesByHashtag.NewUseCase(articlesRepository) @@ -128,6 +130,7 @@ func httpHandler() http.Handler { // auth router.Handler(http.MethodPost, "/api/auth/login", auth.NewLoginHandler(loginUseCase)) + router.Handler(http.MethodPost, "/api/auth/token/refresh", auth.NewRefreshHandler(refreshUseCase)) // articles router.Handler(http.MethodGet, "/api/articles", articleAPI.NewIndexHandler(getArticlesUsecase)) diff --git a/backend/presentation/http/api/auth/refresh.go b/backend/presentation/http/api/auth/refresh.go new file mode 100644 index 00000000..ca4e8676 --- /dev/null +++ b/backend/presentation/http/api/auth/refresh.go @@ -0,0 +1,43 @@ +package auth + +import ( + "encoding/json" + "errors" + "net/http" + + "github.com/khanzadimahdi/testproject/application/auth/refresh" + "github.com/khanzadimahdi/testproject/domain" +) + +type refreshHandler struct { + refreshUseCase *refresh.UseCase +} + +func NewRefreshHandler(refreshUseCase *refresh.UseCase) *refreshHandler { + return &refreshHandler{ + refreshUseCase: refreshUseCase, + } +} + +func (h *refreshHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + var request refresh.Request + if err := json.NewDecoder(r.Body).Decode(&request); err != nil { + rw.WriteHeader(http.StatusBadRequest) + return + } + + response, err := h.refreshUseCase.Login(request) + switch true { + case errors.Is(err, domain.ErrNotExists): + rw.WriteHeader(http.StatusNotFound) + case err != nil: + rw.WriteHeader(http.StatusInternalServerError) + case len(response.ValidationErrors) > 0: + rw.WriteHeader(http.StatusBadRequest) + json.NewEncoder(rw).Encode(response) + default: + rw.Header().Add("Content-Type", "application/json") + rw.WriteHeader(http.StatusOK) + json.NewEncoder(rw).Encode(response) + } +}