forked from Versent/saml2aws
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request Versent#720 from logingood/logingood/add-jumpcloud…
…-protect Add JumpCloud Protect (PUSH) MFA support
- Loading branch information
Showing
5 changed files
with
284 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
package jumpcloud | ||
|
||
import ( | ||
"encoding/json" | ||
"fmt" | ||
"io" | ||
"io/ioutil" | ||
"net/http" | ||
"net/url" | ||
"path" | ||
"time" | ||
|
||
"github.com/pkg/errors" | ||
) | ||
|
||
type JumpCloudPushResponse struct { | ||
ID string `json:"id"` | ||
ExpiresAt time.Time `json:"expiresAt"` | ||
InitiatedAt time.Time `json:"initiatedAt"` | ||
Status string `json:"status"` | ||
UserId string `json:"userId"` | ||
} | ||
|
||
func (jc *Client) jumpCloudProtectAuth(submitUrl string, xsrfToken string) (*http.Response, error) { | ||
jumpCloudParsedURL, err := url.Parse(submitUrl) | ||
if err != nil { | ||
return nil, errors.Wrap(err, fmt.Sprintf("unable to parse submit url, url=%s", jumpCloudProtectSubmitURL)) | ||
} | ||
|
||
req, err := http.NewRequest("POST", jumpCloudParsedURL.String(), emptyJSONIOReader()) | ||
if err != nil { | ||
return nil, errors.Wrap(err, "error building jumpcloud protect auth request") | ||
} | ||
ensureHeaders(xsrfToken, req) | ||
|
||
res, err := jc.client.Do(req) | ||
if err != nil { | ||
return nil, errors.Wrap(err, "error retrieving JumpCloud PUSH payload") | ||
} | ||
defer res.Body.Close() | ||
|
||
if res.StatusCode != 200 { | ||
return nil, errors.New("error retrieving JumpCloud PUSH payload, non 200 status returned") | ||
} | ||
|
||
jpResp, err := ioutil.ReadAll(res.Body) | ||
if err != nil { | ||
return nil, errors.Wrap(err, "error retrieving JumpCloud PUSH payload") | ||
} | ||
|
||
jp := JumpCloudPushResponse{} | ||
if err := json.Unmarshal(jpResp, &jp); err != nil { | ||
return nil, errors.Wrap(err, "failed to unmarshal JumpCloud PUSH payload to struct") | ||
} | ||
|
||
jumpCloudParsedURL.Path = path.Join(jumpCloudParsedURL.Path, jp.ID) | ||
req, err = http.NewRequest("GET", jumpCloudParsedURL.String(), nil) | ||
ensureHeaders(xsrfToken, req) | ||
|
||
if err != nil { | ||
return nil, errors.Wrap(err, "failed to build JumpCoud PUSH polling request") | ||
} | ||
|
||
// Stay in the loop until we get something else other than "pending". | ||
// jp.Status can be: | ||
// * accepted | ||
// * expired | ||
// * denied | ||
|
||
for jp.Status == "pending" { | ||
if time.Now().UTC().After(jp.ExpiresAt) { | ||
return nil, errors.New("the session is expired try again") | ||
} | ||
|
||
resp, err := jc.client.Do(req) | ||
if err != nil { | ||
return nil, errors.Wrap(err, "error retrieving verify response") | ||
} | ||
defer resp.Body.Close() | ||
if resp.StatusCode != http.StatusOK { | ||
return nil, errors.New(fmt.Sprintf("received non 200 http code, http code = %d", resp.StatusCode)) | ||
} | ||
|
||
bytes, err := io.ReadAll(resp.Body) | ||
if err != nil { | ||
return nil, errors.Wrap(err, "failed to unmarshal JumpCloud PUSH body") | ||
} | ||
|
||
if err := json.Unmarshal(bytes, &jp); err != nil { | ||
return nil, errors.Wrap(err, "failed to unmarshal poll result json into struct") | ||
} | ||
|
||
// sleep for 500ms before next request | ||
time.Sleep(500 * time.Millisecond) | ||
} | ||
|
||
if jp.Status != "accepted" { | ||
return nil, errors.New(fmt.Sprintf("didn't receive accepted, status=%s", jp.Status)) | ||
} | ||
|
||
jumpCloudParsedURL.Path = path.Join(jumpCloudParsedURL.Path, "login") | ||
req, err = http.NewRequest("POST", jumpCloudParsedURL.String(), emptyJSONIOReader()) | ||
if err != nil { | ||
return nil, errors.Wrap(err, "failed to build JumpCoud login request") | ||
} | ||
|
||
ensureHeaders(xsrfToken, req) | ||
return jc.client.Do(req) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
package jumpcloud | ||
|
||
import ( | ||
"encoding/json" | ||
"fmt" | ||
"net/http" | ||
"net/http/httptest" | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/require" | ||
"github.com/versent/saml2aws/v2/pkg/cfg" | ||
) | ||
|
||
type test struct { | ||
code int | ||
err string | ||
testCase string | ||
} | ||
|
||
func Test_jumpCloudProtectAuth(t *testing.T) { | ||
jumpCloudPushResp := JumpCloudPushResponse{ | ||
ExpiresAt: time.Now().Add(1 * time.Minute).UTC(), | ||
ID: "foo", | ||
} | ||
|
||
pendingCnt := 1 | ||
maxPending := 2 | ||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
// using token here as a clue to mock responses | ||
switch token := r.Header.Get("X-Xsrftoken"); { | ||
case token == "happy": | ||
switch r.URL.Path { | ||
case "/": | ||
returnResp(t, "pending", 200, &jumpCloudPushResp, w) | ||
case fmt.Sprintf("/%s", jumpCloudPushResp.ID): | ||
returnResp(t, "accepted", 200, &jumpCloudPushResp, w) | ||
case fmt.Sprintf("/%s/login", jumpCloudPushResp.ID): | ||
_, err := w.Write([]byte(`{}`)) | ||
require.Nil(t, err) | ||
} | ||
|
||
case token == "loop twice until accepted": | ||
switch r.URL.Path { | ||
case "/": | ||
returnResp(t, "pending", 200, &jumpCloudPushResp, w) | ||
case fmt.Sprintf("/%s", jumpCloudPushResp.ID): | ||
pendingCnt += 1 | ||
if pendingCnt == maxPending { | ||
returnResp(t, "accepted", 200, &jumpCloudPushResp, w) | ||
} else { | ||
returnResp(t, "pending", 200, &jumpCloudPushResp, w) | ||
} | ||
case fmt.Sprintf("/%s/login", jumpCloudPushResp.ID): | ||
_, err := w.Write([]byte(`{}`)) | ||
require.Nil(t, err) | ||
} | ||
|
||
case token == "payload error": | ||
w.WriteHeader(http.StatusBadRequest) | ||
_, err := w.Write([]byte(`{}`)) | ||
require.Nil(t, err) | ||
|
||
case token == "received expired": | ||
switch r.URL.Path { | ||
case "/": | ||
jumpCloudPushResp.Status = "pending" | ||
bytes, err := json.Marshal(&jumpCloudPushResp) | ||
require.Nil(t, err) | ||
_, _ = w.Write(bytes) | ||
case fmt.Sprintf("/%s", jumpCloudPushResp.ID): | ||
returnResp(t, "expired", http.StatusOK, &jumpCloudPushResp, w) | ||
} | ||
|
||
case token == "received denied": | ||
switch r.URL.Path { | ||
case "/": | ||
jumpCloudPushResp.Status = "pending" | ||
bytes, err := json.Marshal(&jumpCloudPushResp) | ||
require.Nil(t, err) | ||
_, _ = w.Write(bytes) | ||
case fmt.Sprintf("/%s", jumpCloudPushResp.ID): | ||
returnResp(t, "denied", http.StatusUnauthorized, &jumpCloudPushResp, w) | ||
} | ||
|
||
case token == "login error": | ||
switch r.URL.Path { | ||
case "/": | ||
jumpCloudPushResp.Status = "pending" | ||
bytes, err := json.Marshal(&jumpCloudPushResp) | ||
require.Nil(t, err) | ||
_, _ = w.Write(bytes) | ||
case fmt.Sprintf("/%s", jumpCloudPushResp.ID): | ||
jumpCloudPushResp.Status = "accepted" | ||
bytes, err := json.Marshal(&jumpCloudPushResp) | ||
require.Nil(t, err) | ||
_, _ = w.Write(bytes) | ||
case fmt.Sprintf("/%s/login", jumpCloudPushResp.ID): | ||
w.WriteHeader(http.StatusInternalServerError) | ||
} | ||
} | ||
})) | ||
defer ts.Close() | ||
|
||
client, err := New(&cfg.IDPAccount{Provider: "JumpCloud", MFA: "PUSH"}) | ||
require.Nil(t, err) | ||
|
||
tests := []test{ | ||
{testCase: "happy", code: http.StatusOK}, | ||
{testCase: "loop twice until accepted", code: http.StatusOK}, | ||
{testCase: "login error", code: http.StatusInternalServerError}, | ||
{testCase: "payload error", code: http.StatusInternalServerError, err: "error retrieving JumpCloud PUSH payload, non 200 status returned"}, | ||
{testCase: "received expired", err: "didn't receive accepted, status=expired"}, | ||
{testCase: "received denied", err: "received non 200 http code, http code = 401"}, | ||
} | ||
|
||
for _, test := range tests { | ||
t.Run(test.testCase, func(t *testing.T) { | ||
resp, err := client.jumpCloudProtectAuth(ts.URL, test.testCase) | ||
if test.err == "" { | ||
require.Nil(t, err) | ||
require.Equal(t, test.code, resp.StatusCode) | ||
} else { | ||
require.EqualError(t, err, test.err) | ||
} | ||
}) | ||
} | ||
|
||
require.Equal(t, pendingCnt, maxPending) | ||
} | ||
|
||
func returnResp(t *testing.T, status string, statusCode int, j *JumpCloudPushResponse, w http.ResponseWriter) { | ||
j.Status = status | ||
bytes, err := json.Marshal(j) | ||
require.Nil(t, err) | ||
w.WriteHeader(statusCode) | ||
_, err = w.Write(bytes) | ||
require.Nil(t, err) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
package jumpcloud | ||
|
||
import ( | ||
"bytes" | ||
"io" | ||
"net/http" | ||
) | ||
|
||
func ensureHeaders(xsrfToken string, req *http.Request) { | ||
req.Header.Add("X-Xsrftoken", xsrfToken) | ||
req.Header.Add("Accept", "application/json") | ||
req.Header.Add("Content-Type", "application/json") | ||
} | ||
|
||
func emptyJSONIOReader() io.Reader { | ||
return bytes.NewReader([]byte(`{}`)) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters