Skip to content

Commit

Permalink
Add user otp tests
Browse files Browse the repository at this point in the history
  • Loading branch information
xorkevin committed Oct 16, 2023
1 parent 7b7d306 commit 3fd8db1
Show file tree
Hide file tree
Showing 14 changed files with 193 additions and 34 deletions.
7 changes: 3 additions & 4 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,7 @@ func TestClient(t *testing.T) {

err = clientC.fail(nil)
assert.Error(err)
assert.ErrorIs(err, ErrServerRes)
var kerr *kerrors.Error
assert.ErrorAs(err, &kerr)
assert.Equal("Test fail", kerr.Message)
var errres *ErrorServerRes
assert.ErrorAs(err, &errres)
assert.Equal("Test fail", errres.Message)
}
39 changes: 34 additions & 5 deletions clienthttp.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"errors"
"io"
"net/http"
"strconv"
"strings"
"time"

"xorkevin.dev/governor/util/kjson"
Expand Down Expand Up @@ -41,8 +43,6 @@ var (
ErrSendClientReq errSendClientReq
// ErrInvalidServerRes is returned on an invalid server response
ErrInvalidServerRes errInvalidServerRes
// ErrServerRes is a returned server error
ErrServerRes errServerRes
)

type (
Expand All @@ -64,8 +64,33 @@ func (e errInvalidServerRes) Error() string {
return "Invalid server response"
}

func (e errServerRes) Error() string {
return "Error server response"
type (
// ErrorServerRes is a returned server error
ErrorServerRes struct {
Status int
Code string
Message string
}
)

// WriteError implements [xorkevin.dev/kerrors.ErrorWriter]
func (e *ErrorServerRes) WriteError(b io.Writer) {
io.WriteString(b, "(")
io.WriteString(b, strconv.Itoa(e.Status))
io.WriteString(b, ") ")
io.WriteString(b, e.Message)
if e.Code != "" {
io.WriteString(b, " [")
io.WriteString(b, e.Code)
io.WriteString(b, "]")
}
}

// Error implements error
func (e *ErrorServerRes) Error() string {
var b strings.Builder
e.WriteError(&b)
return b.String()
}

func newHTTPClient(c configHTTPClient) *httpClient {
Expand Down Expand Up @@ -120,7 +145,11 @@ func (c *httpClient) Do(ctx context.Context, r *http.Request) (_ *http.Response,
if err := kjson.Unmarshal(b, &errres); err != nil {
return res, kerrors.WithKind(err, ErrInvalidServerRes, "Failed reading response")
}
return res, kerrors.WithKind(nil, ErrServerRes, errres.Message)
return res, kerrors.WithKind(nil, &ErrorServerRes{
Status: res.StatusCode,
Code: errres.Code,
Message: errres.Message,
}, errres.Message)
}
return res, nil
}
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ require (
gopkg.in/square/go-jose.v2 v2.6.0
nhooyr.io/websocket v1.8.7
xorkevin.dev/forge v0.5.2
xorkevin.dev/hunter2 v0.2.10
xorkevin.dev/hunter2 v0.2.11
xorkevin.dev/kerrors v0.1.5
xorkevin.dev/kfs v0.1.2
xorkevin.dev/klog v0.1.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -713,8 +713,8 @@ rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0=
rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA=
xorkevin.dev/forge v0.5.2 h1:bUNNV5wo5JJHxMmLeTY4sO2pdb5wOCY6SyZKEQA10hU=
xorkevin.dev/forge v0.5.2/go.mod h1:RzKyWpvft8khuCAljvwB9Q1k5jjZMtMRZesXOCNseyo=
xorkevin.dev/hunter2 v0.2.10 h1:0Gl2hqve7FIMiug1vmSGy3EKLaUskcm/JYu6kUS5Qc0=
xorkevin.dev/hunter2 v0.2.10/go.mod h1:RLMY1vygxaWYUzMrfVE06aWimgaBqc7w2nBjtXvkFK4=
xorkevin.dev/hunter2 v0.2.11 h1:djdZ1/BfHxXtugwT/CscDXlPX5muuy/Prk7YeEQfbnM=
xorkevin.dev/hunter2 v0.2.11/go.mod h1:RLMY1vygxaWYUzMrfVE06aWimgaBqc7w2nBjtXvkFK4=
xorkevin.dev/kerrors v0.1.5 h1:vszMcLbQjGn2DuqaQZNZNTbGw17qKvfH3X6E1s+1Iac=
xorkevin.dev/kerrors v0.1.5/go.mod h1:HnfCCdUvvu5aXhHxO8TPa8KU3r5X2c2QZRbfE+HBLvc=
xorkevin.dev/kfs v0.1.2 h1:118Y6/QP/jVSEjJdGtDqv/YLTkrcnWWfMP4yjUwqtw0=
Expand Down
4 changes: 0 additions & 4 deletions goverror_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,6 @@ func TestError(t *testing.T) {
Err: ErrInvalidServerRes,
String: "Invalid server response",
},
{
Err: ErrServerRes,
String: "Error server response",
},
} {
tc := tc
assert.Equal(tc.String, tc.Err.Error())
Expand Down
3 changes: 2 additions & 1 deletion service/conduit/service_presence.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"strings"
"time"

"xorkevin.dev/governor/service/kvstore"
"xorkevin.dev/governor/service/ws"
Expand All @@ -24,7 +25,7 @@ func (s *Service) presenceHandler(ctx context.Context, props ws.PresenceEventPro
default:
return nil
}
if err := s.kvpresence.Set(ctx, props.Userid, subloc, 60); err != nil {
if err := s.kvpresence.Set(ctx, props.Userid, subloc, time.Minute); err != nil {
return kerrors.WithMsg(err, "Failed to set presence")
}
return nil
Expand Down
17 changes: 17 additions & 0 deletions service/kvstore/map.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,23 @@ func NewMap() *Map {
}
}

func (s *Map) Dump() string {
s.mu.Lock()
defer s.mu.Unlock()

var b strings.Builder
for k, v := range s.store {
if v.expire.Before(time.Now()) {
continue
}
b.WriteString(k)
b.WriteString(":")
b.WriteString(v.val)
b.WriteString("\n")
}
return b.String()
}

func (s *Map) Ping(ctx context.Context) error {
return nil
}
Expand Down
5 changes: 3 additions & 2 deletions service/user/route_editsecure.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ type (
Userid string `valid:"userid,has" json:"-"`
Alg string `valid:"OTPAlg" json:"alg"`
Digits int `valid:"OTPDigits" json:"digits"`
Period int `valid:"OTPPeriod" json:"period"`
}
)

Expand All @@ -180,7 +181,7 @@ func (s *router) addOTP(c *governor.Context) {
return
}

res, err := s.s.addOTP(c.Ctx(), req.Userid, req.Alg, req.Digits)
res, err := s.s.addOTP(c.Ctx(), req.Userid, req.Alg, req.Digits, req.Period)
if err != nil {
c.WriteError(err)
return
Expand All @@ -192,7 +193,7 @@ type (
//forge:valid
reqOTPCode struct {
Userid string `valid:"userid,has" json:"-"`
Code string `valid:"OTPCode,opt" json:"code"`
Code string `valid:"OTPCode,has" json:"code"`
}
)

Expand Down
6 changes: 2 additions & 4 deletions service/user/service_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func (s *Service) checkOTPCode(ctx context.Context, m *usermodel.Model, code str
}

func (s *Service) markOTPCode(ctx context.Context, userid string, code string) {
if err := s.kvotpcodes.Set(ctx, s.kvotpcodes.Subkey(userid, code), "-", 120); err != nil {
if err := s.kvotpcodes.Set(ctx, s.kvotpcodes.Subkey(userid, code), "-", 4*time.Minute); err != nil {
s.log.Err(ctx, kerrors.WithMsg(err, "Failed to mark otp code as used"))
}
}
Expand Down Expand Up @@ -188,8 +188,6 @@ func (s *Service) login(ctx context.Context, userid, password, code, backup, ses

if m.OTPEnabled {
if code == "" && backup == "" {
// must make a best effort to increment login failures
s.incrLoginFailCount(klog.ExtendCtx(context.Background(), ctx), m, ipaddr, useragent)
return nil, governor.ErrWithRes(nil, http.StatusBadRequest, "otp_required", "OTP code required")
}

Expand Down Expand Up @@ -277,7 +275,7 @@ func (s *Service) login(ctx context.Context, userid, password, code, backup, ses
}
}

if m.OTPEnabled {
if m.OTPEnabled && code != "" {
// must make a best effort to mark otp code as used
s.markOTPCode(klog.ExtendCtx(context.Background(), ctx), userid, code)
}
Expand Down
4 changes: 2 additions & 2 deletions service/user/service_editsecure.go
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ type (
}
)

func (s *Service) addOTP(ctx context.Context, userid string, alg string, digits int) (*resAddOTP, error) {
func (s *Service) addOTP(ctx context.Context, userid string, alg string, digits int, period int) (*resAddOTP, error) {
m, err := s.users.GetByID(ctx, userid)
if err != nil {
if errors.Is(err, dbsql.ErrNotFound) {
Expand All @@ -453,7 +453,7 @@ func (s *Service) addOTP(ctx context.Context, userid string, alg string, digits
if err != nil {
return nil, err
}
uri, backup, err := s.users.GenerateOTPSecret(ctx, cipher.cipher, m, s.authSettings.otpIssuer, alg, digits)
uri, backup, err := s.users.GenerateOTPSecret(ctx, cipher.cipher, m, s.authSettings.otpIssuer, alg, digits, period)
if err != nil {
return nil, kerrors.WithMsg(err, "Failed to generate otp secret")
}
Expand Down
107 changes: 103 additions & 4 deletions service/user/user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ package user
import (
"bytes"
"context"
"encoding/base32"
"fmt"
"net/http"
"net/http/cookiejar"
"net/url"
"strconv"
"strings"
"testing"
"time"
Expand All @@ -30,9 +32,13 @@ import (
"xorkevin.dev/governor/service/user/sessionmodel"
"xorkevin.dev/governor/service/user/usermodel"
"xorkevin.dev/governor/util/kjson"
"xorkevin.dev/hunter2/h2cipher/xchacha20poly1305"
"xorkevin.dev/hunter2/h2otp"
"xorkevin.dev/klog"
)

var base32RawEncoding = base32.StdEncoding.WithPadding(base32.NoPadding)

func TestUsers(t *testing.T) {
if testing.Short() {
t.Skip("relies on db")
Expand All @@ -45,6 +51,9 @@ func TestUsers(t *testing.T) {
gateClient, err := gatetest.NewClient()
assert.NoError(err)

otpcipherconfig, err := xchacha20poly1305.NewConfig()
assert.NoError(err)

{
systoken, err := gateClient.GenToken(gate.KeySubSystem, time.Hour, "")
assert.NoError(err)
Expand All @@ -68,7 +77,7 @@ func TestUsers(t *testing.T) {
"extkeys": []string{gateClient.ExtKeyStr},
},
"otpkey": map[string]any{
"secrets": []string{},
"secrets": []string{otpcipherconfig.String()},
},
},
}, nil)
Expand Down Expand Up @@ -547,8 +556,10 @@ func TestUsers(t *testing.T) {
assert.Equal(template.KindLocal, maillog.Records[0].Tpl.Kind)
assert.Equal("passreset", maillog.Records[0].Tpl.Name)
maillog.Reset()
}

r, err = httpc.ReqJSON(http.MethodPost, "/u/auth/login", reqUserAuth{
{
r, err := httpc.ReqJSON(http.MethodPost, "/u/auth/login", reqUserAuth{
Username: "xorkevin2",
Password: "password3",
})
Expand All @@ -560,8 +571,96 @@ func TestUsers(t *testing.T) {

assert.True(authbody.Valid)
gateClient.Token = authbody.AccessToken
}

r, err = httpc.ReqJSON(http.MethodPost, fmt.Sprintf("/u/auth/id/%s/logout", regularUserid), nil)
{
r, err := httpc.ReqJSON(http.MethodPut, "/u/user/otp", reqAddOTP{
Alg: h2otp.AlgSHA512,
Digits: h2otp.OTPDigitsDefault,
Period: int(h2otp.TOTPPeriodDefault),
})
assert.NoError(err)

var body resAddOTP
_, err = httpc.DoJSON(context.Background(), r, &body)
assert.NoError(err)

totpuri, err := url.Parse(body.URI)
assert.NoError(err)
assert.Equal("otpauth", totpuri.Scheme)
assert.Equal("totp", totpuri.Host)
assert.Equal(h2otp.AlgSHA512, totpuri.Query().Get("algorithm"))
assert.Equal(strconv.Itoa(h2otp.OTPDigitsDefault), totpuri.Query().Get("digits"))
assert.Equal(strconv.FormatUint(h2otp.TOTPPeriodDefault, 10), totpuri.Query().Get("period"))
secret, err := base32RawEncoding.DecodeString(totpuri.Query().Get("secret"))
assert.NoError(err)
sha512, ok := h2otp.DefaultHashes.Get(h2otp.AlgSHA512)
assert.True(ok)
code, err := h2otp.TOTPNow(secret, h2otp.TOTPOpts{
Alg: sha512,
Digits: h2otp.OTPDigitsDefault,
Period: h2otp.TOTPPeriodDefault,
})
assert.NoError(err)

r, err = httpc.ReqJSON(http.MethodPut, "/u/user/otp/verify", reqOTPCode{
Code: code,
})
assert.NoError(err)

_, err = httpc.DoNoContent(context.Background(), r)
assert.NoError(err)

{
r, err := httpc.ReqJSON(http.MethodPost, "/u/auth/login", reqUserAuth{
Username: "xorkevin2",
Password: "password3",
})
assert.NoError(err)

res, err := httpc.DoJSON(context.Background(), r, nil)
assert.Error(err)
assert.Equal(http.StatusBadRequest, res.StatusCode)
var errres *governor.ErrorServerRes
assert.ErrorAs(err, &errres)
assert.Equal("otp_required", errres.Code)
}

{
r, err := httpc.ReqJSON(http.MethodPost, "/u/auth/login", reqUserAuth{
Username: "xorkevin2",
Password: "password3",
Code: code,
})
assert.NoError(err)

var authbody resUserAuth
_, err = httpc.DoJSON(context.Background(), r, &authbody)
assert.NoError(err)

assert.True(authbody.Valid)
gateClient.Token = authbody.AccessToken
}

{
r, err := httpc.ReqJSON(http.MethodPost, "/u/auth/login", reqUserAuth{
Username: "xorkevin2",
Password: "password3",
Code: code,
})
assert.NoError(err)

res, err := httpc.DoJSON(context.Background(), r, nil)
assert.Error(err)
assert.Equal(http.StatusBadRequest, res.StatusCode)
var errres *governor.ErrorServerRes
assert.ErrorAs(err, &errres)
assert.Equal("OTP code already used", errres.Message)
}
}

{
r, err := httpc.ReqJSON(http.MethodPost, fmt.Sprintf("/u/auth/id/%s/logout", regularUserid), nil)
assert.NoError(err)

_, err = httpc.DoNoContent(context.Background(), r)
Expand Down Expand Up @@ -593,7 +692,7 @@ func TestUsers(t *testing.T) {
CreationTime: body.CreationTime,
},
Email: "[email protected]",
OTPEnabled: false,
OTPEnabled: true,
}, body)
}
}
Loading

0 comments on commit 3fd8db1

Please sign in to comment.