Skip to content

Commit

Permalink
Merge #26062
Browse files Browse the repository at this point in the history
26062: server, ui: logout r=vilterp a=vilterp

**implement logout endpoint on backend**

Because of grpc-ecosystem/grpc-gateway#470, had to do an end run around gRPC. We'll need to fix grpc-gateway to pass the context through eventually, but I don't want logout to be blocked on it.

**implement logout button on frontend**

Makes an RPC to invalidate the session. If successful, reloads the page.

Fixes #25784 
Produces awesomeness when combined with #26053 

Co-authored-by: Pete Vilter <[email protected]>
  • Loading branch information
craig[bot] and Pete Vilter committed May 31, 2018
2 parents a394a07 + eea0820 commit e6bbb2e
Show file tree
Hide file tree
Showing 11 changed files with 415 additions and 131 deletions.
99 changes: 84 additions & 15 deletions pkg/server/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"encoding/base64"
"fmt"
"net/http"
"strconv"
"time"

gwruntime "github.com/grpc-ecosystem/grpc-gateway/runtime"
Expand All @@ -44,7 +45,8 @@ import (
const (
// authPrefix is the prefix for RESTful endpoints used to provide
// authentication methods.
authPrefix = "/_auth/v1/"
loginPath = "/login"
logoutPath = "/logout"
// secretLength is the number of random bytes generated for session secrets.
secretLength = 16
sessionCookieName = "session"
Expand Down Expand Up @@ -72,15 +74,19 @@ func newAuthenticationServer(s *Server) *authenticationServer {

// RegisterService registers the GRPC service.
func (s *authenticationServer) RegisterService(g *grpc.Server) {
serverpb.RegisterAuthenticationServer(g, s)
serverpb.RegisterLogInServer(g, s)
serverpb.RegisterLogOutServer(g, s)
}

// RegisterGateway starts the gateway (i.e. reverse proxy) that proxies HTTP requests
// to the appropriate gRPC endpoints.
func (s *authenticationServer) RegisterGateway(
ctx context.Context, mux *gwruntime.ServeMux, conn *grpc.ClientConn,
) error {
return serverpb.RegisterAuthenticationHandler(ctx, mux, conn)
if err := serverpb.RegisterLogInHandler(ctx, mux, conn); err != nil {
return err
}
return serverpb.RegisterLogOutHandler(ctx, mux, conn)
}

// UserLogin verifies an incoming request by a user to create an web
Expand Down Expand Up @@ -150,7 +156,46 @@ func (s *authenticationServer) UserLogin(
func (s *authenticationServer) UserLogout(
ctx context.Context, req *serverpb.UserLogoutRequest,
) (*serverpb.UserLogoutResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "Logout method has not yet been implemented.")
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, apiInternalError(ctx, fmt.Errorf("couldn't get incoming context"))
}
sessionIDs := md.Get(webSessionIDKeyStr)
if len(sessionIDs) != 1 {
return nil, apiInternalError(ctx, fmt.Errorf("couldn't get incoming context"))
}

sessionID, err := strconv.Atoi(sessionIDs[0])
if err != nil {
return nil, fmt.Errorf("invalid session id: %d", sessionID)
}

// Revoke the session.
if n, err := s.server.internalExecutor.Exec(
ctx,
"revoke-auth-session",
nil, /* txn */
`UPDATE system.web_sessions SET "revokedAt" = now() WHERE id = $1`,
sessionID,
); err != nil {
return nil, apiInternalError(ctx, err)
} else if n == 0 {
msg := fmt.Sprintf("session with id %d nonexistent", sessionID)
log.Info(ctx, msg)
return nil, fmt.Errorf(msg)
}

// Send back a header which will cause the browser to destroy the cookie.
// See https://tools.ietf.org/search/rfc6265, page 7.
cookie := makeCookieWithValue("")
cookie.MaxAge = -1

// Set the cookie header on the outgoing response.
if err := grpc.SetHeader(ctx, metadata.Pairs("set-cookie", cookie.String())); err != nil {
return nil, apiInternalError(ctx, err)
}

return &serverpb.UserLogoutResponse{}, nil
}

// verifySession verifies the existence and validity of the session claimed by
Expand Down Expand Up @@ -303,17 +348,24 @@ func newAuthenticationMux(s *authenticationServer, inner http.Handler) *authenti
}
}

type loggedInUserKey struct{}
type webSessionUserKey struct{}
type webSessionIDKey struct{}

const webSessionUserKeyStr = "webSessionUser"
const webSessionIDKeyStr = "webSessionID"

func (am *authenticationMux) ServeHTTP(w http.ResponseWriter, req *http.Request) {
username, err := am.getSession(w, req)
username, cookie, err := am.getSession(w, req)
if err != nil && !am.allowAnonymous {
log.Infof(req.Context(), "Web session error: %s", err)
http.Error(w, "a valid authentication cookie is required", http.StatusUnauthorized)
return
}

newCtx := context.WithValue(req.Context(), loggedInUserKey{}, username)
newCtx := context.WithValue(req.Context(), webSessionUserKey{}, username)
if cookie != nil {
newCtx = context.WithValue(newCtx, webSessionIDKey{}, cookie.ID)
}
newReq := req.WithContext(newCtx)

am.inner.ServeHTTP(w, newReq)
Expand All @@ -324,43 +376,49 @@ func encodeSessionCookie(sessionCookie *serverpb.SessionCookie) (*http.Cookie, e
if err != nil {
return nil, errors.Wrap(err, "session cookie could not be encoded")
}
value := base64.StdEncoding.EncodeToString(cookieValueBytes)
return makeCookieWithValue(value), nil
}

func makeCookieWithValue(value string) *http.Cookie {
return &http.Cookie{
Name: sessionCookieName,
Value: base64.StdEncoding.EncodeToString(cookieValueBytes),
Value: value,
Path: "/",
HttpOnly: true,
Secure: true,
}, nil
}
}

// getSession decodes the cookie from the request, looks up the corresponding session, and
// returns the logged in user name. If there's an error, it returns an error value and the
// HTTP error code.
func (am *authenticationMux) getSession(w http.ResponseWriter, req *http.Request) (string, error) {
func (am *authenticationMux) getSession(
w http.ResponseWriter, req *http.Request,
) (string, *serverpb.SessionCookie, error) {
// Validate the returned cookie.
rawCookie, err := req.Cookie(sessionCookieName)
if err != nil {
return "", err
return "", nil, err
}

cookie, err := decodeSessionCookie(rawCookie)
if err != nil {
err = errors.Wrap(err, "a valid authentication cookie is required")
return "", err
return "", nil, err
}

valid, username, err := am.server.verifySession(req.Context(), cookie)
if err != nil {
err := apiInternalError(req.Context(), err)
return "", err
return "", nil, err
}
if !valid {
err := errors.New("the provided authentication session could not be validated")
return "", err
return "", nil, err
}

return username, nil
return username, cookie, nil
}

func decodeSessionCookie(encodedCookie *http.Cookie) (*serverpb.SessionCookie, error) {
Expand Down Expand Up @@ -394,3 +452,14 @@ func authenticationHeaderMatcher(key string) (string, bool) {
// duplicated here.
return fmt.Sprintf("%s%s", gwruntime.MetadataHeaderPrefix, key), true
}

func forwardAuthenticationMetadata(ctx context.Context, _ *http.Request) metadata.MD {
md := metadata.MD{}
if user := ctx.Value(webSessionUserKey{}); user != nil {
md.Set(webSessionUserKeyStr, user.(string))
}
if sessionID := ctx.Value(webSessionIDKey{}); sessionID != nil {
md.Set(webSessionIDKeyStr, fmt.Sprintf("%v", sessionID))
}
return md
}
78 changes: 77 additions & 1 deletion pkg/server/authentication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"fmt"
"io/ioutil"
"net/http"
"net/http/cookiejar"
"net/url"
"testing"
"time"
Expand Down Expand Up @@ -436,7 +437,7 @@ func TestAuthenticationAPIUserLogin(t *testing.T) {
}
var resp serverpb.UserLoginResponse
return httputil.PostJSONWithRequest(
httpClient, ts.AdminURL()+authPrefix+"login", &req, &resp,
httpClient, ts.AdminURL()+loginPath, &req, &resp,
)
}

Expand Down Expand Up @@ -493,6 +494,81 @@ func TestAuthenticationAPIUserLogin(t *testing.T) {
}
}

func TestLogout(t *testing.T) {
defer leaktest.AfterTest(t)()
s, db, _ := serverutils.StartServer(t, base.TestServerArgs{})
defer s.Stopper().Stop(context.TODO())
ts := s.(*TestServer)

// Log in.
authHTTPClient, cookie, err := ts.getAuthenticatedHTTPClientAndCookie()
if err != nil {
t.Fatal("error opening HTTP client", err)
}

// Log out.
var resp serverpb.UserLogoutResponse
if err := httputil.GetJSON(authHTTPClient, ts.AdminURL()+logoutPath, &resp); err != nil {
t.Fatal("logout request failed:", err)
}

// Verify that revokedAt has been set in the DB.
query := `SELECT "revokedAt" FROM system.web_sessions WHERE id = $1`
result := db.QueryRow(query, cookie.ID)
var revokedAt string
if err := result.Scan(&revokedAt); err != nil {
t.Fatalf("error querying auth session: %s", err)
}

if revokedAt == "" {
t.Fatal("expected revoked at to not be empty; was empty")
}

databasesURL := ts.AdminURL() + "/_admin/v1/databases"

// Verify that we're unauthorized after logout.
response, err := authHTTPClient.Get(databasesURL)
if err != nil {
t.Fatal(err)
}
defer response.Body.Close()

if response.StatusCode != http.StatusUnauthorized {
t.Fatal("expected unauthorized response after logout; got", response.StatusCode)
}

// Try to use the revoked cookie; verify that it doesn't work.
parsedURL, err := url.Parse(s.AdminURL())
if err != nil {
t.Fatal(err)
}
encodedCookie, err := encodeSessionCookie(cookie)
if err != nil {
t.Fatal(err)
}

invalidAuthClient, err := s.GetHTTPClient()
if err != nil {
t.Fatal(err)
}
jar, err := cookiejar.New(nil)
if err != nil {
t.Fatal(err)
}
invalidAuthClient.Jar = jar
invalidAuthClient.Jar.SetCookies(parsedURL, []*http.Cookie{encodedCookie})

invalidAuthResp, err := invalidAuthClient.Get(databasesURL)
if err != nil {
t.Fatal(err)
}
defer invalidAuthResp.Body.Close()

if invalidAuthResp.StatusCode != 401 {
t.Fatal("expected unauthorized error; got", invalidAuthResp.StatusCode)
}
}

// TestAuthenticationMux verifies that the authentication handler is used by all
// of the APIs it should be protecting. Authentication is enabled by default for
// the test server, and every test which accesses APIs uses an authenticated
Expand Down
6 changes: 4 additions & 2 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1203,6 +1203,7 @@ func (s *Server) Start(ctx context.Context) error {
gwruntime.WithMarshalerOption(httputil.ProtoContentType, protopb),
gwruntime.WithMarshalerOption(httputil.AltProtoContentType, protopb),
gwruntime.WithOutgoingHeaderMatcher(authenticationHeaderMatcher),
gwruntime.WithMetadata(forwardAuthenticationMetadata),
)
gwCtx, gwCancel := context.WithCancel(s.AnnotateCtx(context.Background()))
s.stopper.AddCloser(stop.CloserFn(gwCancel))
Expand Down Expand Up @@ -1478,7 +1479,8 @@ If problems persist, please see ` + base.DocsURL("cluster-setup-troubleshooting.
s.mux.Handle("/_admin/v1/health", gwMux)
s.mux.Handle(ts.URLPrefix, authHandler)
s.mux.Handle(statusPrefix, authHandler)
s.mux.Handle(authPrefix, gwMux)
s.mux.Handle(loginPath, gwMux)
s.mux.Handle(logoutPath, authHandler)
s.mux.Handle(statusVars, http.HandlerFunc(s.status.handleVars))
log.Event(ctx, "added http endpoints")

Expand Down Expand Up @@ -1928,7 +1930,7 @@ func serveUIAssets(fileServer http.Handler, cfg Config) http.Handler {
LoginEnabled: cfg.RequireWebSession(),
Version: build.VersionPrefix(),
}
loggedInUser, ok := request.Context().Value(loggedInUserKey{}).(string)
loggedInUser, ok := request.Context().Value(webSessionUserKey{}).(string)
if ok && loggedInUser != "" {
tmplArgs.LoggedInUser = &loggedInUser
}
Expand Down
Loading

0 comments on commit e6bbb2e

Please sign in to comment.