diff --git a/src/ring/middleware/oauth2.clj b/src/ring/middleware/oauth2.clj index a46f813..5d1c177 100644 --- a/src/ring/middleware/oauth2.clj +++ b/src/ring/middleware/oauth2.clj @@ -55,14 +55,19 @@ (base64url (random/base64 63))) (defn- make-launch-handler [{:keys [pkce?] :as profile}] - (fn [{:keys [session] :or {session {}} :as request}] - (let [state (random-state) - verifier (when pkce? (random-code-verifier)) - session' (-> session - (assoc ::state state) - (cond-> pkce? (assoc ::code-verifier verifier)))] - (-> (resp/redirect (authorize-uri profile request state verifier)) - (assoc :session session'))))) + (fn handler + ([{:keys [session] :or {session {}} :as request}] + (let [state (random-state) + verifier (when pkce? (random-code-verifier)) + session' (-> session + (assoc ::state state) + (cond-> pkce? (assoc ::code-verifier verifier)))] + (-> (resp/redirect (authorize-uri profile request state verifier)) + (assoc :session session')))) + ([request respond raise] + (when-let [response (try (handler request) + (catch Exception e (raise e) false))] + (respond response))))) (defn- state-matches? [request] (= (get-in request [:session ::state]) @@ -107,41 +112,77 @@ (merge {:client_id id :client_secret secret})))) -(defn- get-access-token +(defn- access-token-http-options [{:keys [access-token-uri client-id client-secret basic-auth?] - :or {basic-auth? false} :as profile} request] - (format-access-token - (http/post access-token-uri - (cond-> {:accept :json, :as :json, - :form-params (request-params profile request)} - basic-auth? (add-header-credentials client-id client-secret) - (not basic-auth?) (add-form-credentials client-id client-secret))))) - -(defn state-mismatch-handler [_] - {:status 400, :headers {}, :body "State mismatch"}) - -(defn no-auth-code-handler [_] - {:status 400, :headers {}, :body "No authorization code"}) - -(defn- make-redirect-handler [{:keys [id landing-uri] :as profile}] - (let [state-mismatch-handler (:state-mismatch-handler - profile state-mismatch-handler) - no-auth-code-handler (:no-auth-code-handler - profile no-auth-code-handler)] - (fn [{:keys [session] :or {session {}} :as request}] - (cond - (not (state-matches? request)) - (state-mismatch-handler request) - - (nil? (get-authorization-code request)) - (no-auth-code-handler request) - - :else - (let [access-token (get-access-token profile request)] - (-> (resp/redirect landing-uri) - (assoc :session (-> session - (assoc-in [::access-tokens id] access-token) - (dissoc ::state ::code-verifier))))))))) + :or {basic-auth? false} :as profile} + request] + (let [opts {:method :post + :url access-token-uri + :accept :json + :as :json + :form-params (request-params profile request)}] + (if basic-auth? + (add-header-credentials opts client-id client-secret) + (add-form-credentials opts client-id client-secret)))) + +(defn- get-access-token + ([profile request] + (-> (http/request (access-token-http-options profile request)) + (format-access-token))) + ([profile request respond raise] + (http/request (-> (access-token-http-options profile request) + (assoc :async? true)) + (comp respond format-access-token) + raise))) + +(defn state-mismatch-handler + ([_] + {:status 400, :headers {}, :body "State mismatch"}) + ([request respond _] + (respond (state-mismatch-handler request)))) + +(defn no-auth-code-handler + ([_] + {:status 400, :headers {}, :body "No authorization code"}) + ([request respond _] + (respond (no-auth-code-handler request)))) + +(defn- redirect-response [{:keys [id landing-uri]} session access-token] + (-> (resp/redirect landing-uri) + (assoc :session (-> session + (assoc-in [::access-tokens id] access-token) + (dissoc ::state ::code-verifier))))) + +(defn- make-redirect-handler + [{:keys [state-mismatch-handler no-auth-code-handler] + :or {state-mismatch-handler state-mismatch-handler + no-auth-code-handler no-auth-code-handler} + :as profile}] + (fn + ([{:keys [session] :or {session {}} :as request}] + (cond + (not (state-matches? request)) + (state-mismatch-handler request) + + (nil? (get-authorization-code request)) + (no-auth-code-handler request) + + :else + (let [access-token (get-access-token profile request)] + (redirect-response profile session access-token)))) + ([{:keys [session] :or {session {}} :as request} respond raise] + (cond + (not (state-matches? request)) + (state-mismatch-handler request respond raise) + + (nil? (get-authorization-code request)) + (no-auth-code-handler request respond raise) + + :else + (get-access-token profile request + (fn [token] + (respond (redirect-response profile session token))) + raise))))) (defn- assoc-access-tokens [request] (if-let [tokens (-> request :session ::access-tokens)] @@ -151,7 +192,7 @@ (defn- parse-redirect-url [{:keys [redirect-uri]}] (.getPath (java.net.URI. redirect-uri))) -(defn- valid-profile? [{:keys [client-id client-secret] :as profile}] +(defn- valid-profile? [{:keys [client-id client-secret]}] (and (some? client-id) (some? client-secret))) (defn wrap-oauth2 [handler profiles] @@ -159,9 +200,17 @@ (let [profiles (for [[k v] profiles] (assoc v :id k)) launches (into {} (map (juxt :launch-uri identity)) profiles) redirects (into {} (map (juxt parse-redirect-url identity)) profiles)] - (fn [{:keys [uri] :as request}] - (if-let [profile (launches uri)] - ((make-launch-handler profile) request) - (if-let [profile (redirects uri)] - ((:redirect-handler profile (make-redirect-handler profile)) request) - (handler (assoc-access-tokens request))))))) + (fn + ([{:keys [uri] :as request}] + (if-let [profile (launches uri)] + ((make-launch-handler profile) request) + (if-let [profile (redirects uri)] + ((:redirect-handler profile (make-redirect-handler profile)) request) + (handler (assoc-access-tokens request))))) + ([{:keys [uri] :as request} respond raise] + (if-let [profile (launches uri)] + ((make-launch-handler profile) request respond raise) + (if-let [profile (redirects uri)] + ((:redirect-handler profile (make-redirect-handler profile)) + request respond raise) + (handler (assoc-access-tokens request) respond raise))))))) diff --git a/test/ring/middleware/oauth2_test.clj b/test/ring/middleware/oauth2_test.clj index 7464122..cc6df10 100644 --- a/test/ring/middleware/oauth2_test.clj +++ b/test/ring/middleware/oauth2_test.clj @@ -23,8 +23,11 @@ (def test-profile-pkce (assoc test-profile :pkce? true)) -(defn- token-handler [{:keys [oauth2/access-tokens]}] - {:status 200, :headers {}, :body access-tokens}) +(defn- token-handler + ([{:keys [oauth2/access-tokens]}] + {:status 200, :headers {}, :body access-tokens}) + ([request respond _raise] + (respond (token-handler request)))) (def test-handler (wrap-oauth2 token-handler {:test test-profile})) @@ -33,20 +36,43 @@ (wrap-oauth2 token-handler {:test test-profile-pkce})) (deftest test-launch-uri - (let [response (test-handler (mock/request :get "/oauth2/test")) - location (get-in response [:headers "Location"]) - [_ query] (str/split location #"\?" 2) - params (codec/form-decode query)] - (is (= 302 (:status response))) - (is (.startsWith ^String location "https://example.com/oauth2/authorize?")) - (is (= {"response_type" "code" - "client_id" "abcdef" - "redirect_uri" "http://localhost/oauth2/test/callback" - "scope" "user project"} - (dissoc params "state"))) - (is (re-matches #"[A-Za-z0-9_-]{12}" (params "state"))) - (is (= {::oauth2/state (params "state")} - (:session response))))) + (testing "sync handlers" + (let [response (test-handler (mock/request :get "/oauth2/test")) + location (get-in response [:headers "Location"]) + [_ query] (str/split location #"\?" 2) + params (codec/form-decode query)] + (is (= 302 (:status response))) + (is (.startsWith ^String location "https://example.com/oauth2/authorize?")) + (is (= {"response_type" "code" + "client_id" "abcdef" + "redirect_uri" "http://localhost/oauth2/test/callback" + "scope" "user project"} + (dissoc params "state"))) + (is (re-matches #"[A-Za-z0-9_-]{12}" (params "state"))) + (is (= {::oauth2/state (params "state")} + (:session response))))) + + (testing "async handlers" + (let [respond (promise) + raise (promise)] + (test-handler (mock/request :get "/oauth2/test") respond raise) + (let [response (deref respond 100 :empty) + error (deref raise 100 :empty)] + (is (not= response :empty)) + (is (= error :empty)) + (let [location (get-in response [:headers "Location"]) + [_ query] (str/split location #"\?" 2) + params (codec/form-decode query)] + (is (= 302 (:status response))) + (is (.startsWith ^String location "https://example.com/oauth2/authorize?")) + (is (= {"response_type" "code" + "client_id" "abcdef" + "redirect_uri" "http://localhost/oauth2/test/callback" + "scope" "user project"} + (dissoc params "state"))) + (is (re-matches #"[A-Za-z0-9_-]{12}" (params "state"))) + (is (= {::oauth2/state (params "state")} + (:session response)))))))) (deftest test-launch-uri-pkce (let [response (test-handler-pkce (mock/request :get "/oauth2/test")) @@ -248,7 +274,35 @@ :session ::oauth2/access-tokens :test :id-token))) (is (approx-eq expires (-> response - :session ::oauth2/access-tokens :test :expires))))))) + :session ::oauth2/access-tokens :test :expires))))) + + (testing "async handler" + (let [request (-> (mock/request :get "/oauth2/test/callback") + (assoc :session {::oauth2/state "xyzxyz"}) + (assoc :query-params {"code" "abcabc" + "state" "xyzxyz"})) + respond (promise) + raise (promise) + expires (seconds-from-now-to-date 3600)] + (test-handler request respond raise) + (let [response (deref respond 100 :empty) + error (deref raise 100 :empty)] + (is (not= response :empty) "timeout getting response") + (is (= error :empty)) + (is (= 302 (:status response))) + (is (= "/" (get-in response [:headers "Location"]))) + (is (map? (-> response :session ::oauth2/access-tokens))) + (is (= "defdef" + (-> response :session ::oauth2/access-tokens :test :token))) + (is (= "ghighi" + (-> response + :session ::oauth2/access-tokens :test :refresh-token))) + (is (= "abc.def.ghi" + (-> response + :session ::oauth2/access-tokens :test :id-token))) + (is (approx-eq expires + (-> response + :session ::oauth2/access-tokens :test :expires)))))))) (def openid-response-with-string-expires {:status 200 @@ -315,3 +369,20 @@ response (handler request) body (:body response)] (is (= "redirect-handler-response-body" body)))) + +(deftest test-handler-passthrough + (let [tokens {:test "tttkkkk"} + request (-> (mock/request :get "/example") + (assoc :session {::oauth2/access-tokens tokens}))] + (testing "sync handler" + (is (= {:status 200, :headers {}, :body tokens} + (test-handler request)))) + + (testing "async handler" + (let [respond (promise) + raise (promise)] + (test-handler request respond raise) + (is (= :empty + (deref raise 100 :empty))) + (is (= {:status 200, :headers {}, :body tokens} + (deref respond 100 :empty)))))))