diff --git a/servlet/src/main/java/io/undertow/servlet/handlers/security/ServletFormAuthenticationMechanism.java b/servlet/src/main/java/io/undertow/servlet/handlers/security/ServletFormAuthenticationMechanism.java index 50bb70a49a..7435539a8e 100644 --- a/servlet/src/main/java/io/undertow/servlet/handlers/security/ServletFormAuthenticationMechanism.java +++ b/servlet/src/main/java/io/undertow/servlet/handlers/security/ServletFormAuthenticationMechanism.java @@ -18,10 +18,14 @@ package io.undertow.servlet.handlers.security; +import static io.undertow.security.api.SecurityNotification.EventType.AUTHENTICATED; import static io.undertow.util.StatusCodes.OK; import io.undertow.security.api.AuthenticationMechanism; import io.undertow.security.api.AuthenticationMechanismFactory; +import io.undertow.security.api.SecurityContext; +import io.undertow.security.api.NotificationReceiver; +import io.undertow.security.api.SecurityNotification; import io.undertow.security.idm.IdentityManager; import io.undertow.security.impl.FormAuthenticationMechanism; import io.undertow.server.HttpServerExchange; @@ -153,6 +157,19 @@ public ServletFormAuthenticationMechanism(FormParserFactory formParserFactory, S this.overrideInitial = overrideInitial; } + @Override + public AuthenticationMechanismOutcome authenticate(final HttpServerExchange exchange, final SecurityContext securityContext) { + securityContext.registerNotificationReceiver(new NotificationReceiver() { + @Override + public void handleNotification(final SecurityNotification notification) { + if (notification.getEventType() == AUTHENTICATED) { + getAndInitializeSession(exchange, false); + } + } + }); + return super.authenticate(exchange, securityContext); + } + @Override protected Integer servePage(final HttpServerExchange exchange, final String location) { final ServletRequestContext servletRequestContext = exchange.getAttachment(ServletRequestContext.ATTACHMENT_KEY); @@ -195,27 +212,7 @@ protected void storeInitialLocation(final HttpServerExchange exchange, byte[] by if(!saveOriginalRequest) { return; } - final ServletRequestContext servletRequestContext = exchange.getAttachment(ServletRequestContext.ATTACHMENT_KEY); - final ServletContextImpl servletContextImpl = servletRequestContext.getCurrentServletContext(); - HttpSessionImpl httpSession = servletContextImpl.getSession(exchange, false); - boolean newSession = false; - if (httpSession == null) { - httpSession = servletContextImpl.getSession(exchange, true); - newSession = true; - } - Session session; - if (System.getSecurityManager() == null) { - session = httpSession.getSession(); - } else { - session = AccessController.doPrivileged(new HttpSessionImpl.UnwrapSessionAction(httpSession)); - } - if (newSession) { - int originalMaxInactiveInterval = session.getMaxInactiveInterval(); - if (originalMaxInactiveInterval > authenticationSessionTimeout) { - session.setAttribute(ORIGINAL_SESSION_TIMEOUT, session.getMaxInactiveInterval()); - session.setMaxInactiveInterval(authenticationSessionTimeout); - } - } + Session session = getAndInitializeSession(exchange, true); SessionManager manager = session.getSessionManager(); if (seenSessionManagers.add(manager)) { manager.registerSessionListener(LISTENER); @@ -230,25 +227,15 @@ protected void storeInitialLocation(final HttpServerExchange exchange, byte[] by @Override protected void handleRedirectBack(final HttpServerExchange exchange) { - final ServletRequestContext servletRequestContext = exchange.getAttachment(ServletRequestContext.ATTACHMENT_KEY); - HttpServletResponse resp = (HttpServletResponse) servletRequestContext.getServletResponse(); - HttpSessionImpl httpSession = servletRequestContext.getCurrentServletContext().getSession(exchange, false); - if (httpSession != null) { - Session session; - if (System.getSecurityManager() == null) { - session = httpSession.getSession(); - } else { - session = AccessController.doPrivileged(new HttpSessionImpl.UnwrapSessionAction(httpSession)); - } - Integer originalSessionTimeout = (Integer) session.removeAttribute(ORIGINAL_SESSION_TIMEOUT); - if (originalSessionTimeout != null) { - session.setMaxInactiveInterval(originalSessionTimeout); - } + final Session session = getAndInitializeSession(exchange, false); + if (session != null) { String path = (String) session.getAttribute(SESSION_KEY); if ((path == null || overrideInitial) && defaultPage != null) { path = defaultPage; } if (path != null) { + final ServletRequestContext servletRequestContext = exchange.getAttachment(ServletRequestContext.ATTACHMENT_KEY); + final HttpServletResponse resp = (HttpServletResponse) servletRequestContext.getServletResponse(); try { resp.sendRedirect(path); } catch (IOException e) { @@ -256,7 +243,41 @@ protected void handleRedirectBack(final HttpServerExchange exchange) { } } } + } + + private Session getAndInitializeSession(final HttpServerExchange exchange, final boolean createNewSession) { + final ServletRequestContext servletRequestContext = exchange.getAttachment(ServletRequestContext.ATTACHMENT_KEY); + final ServletContextImpl servletContextImpl = servletRequestContext.getCurrentServletContext(); + HttpSessionImpl httpSession = servletContextImpl.getSession(exchange, false); + if (httpSession == null && !createNewSession) return null; + + boolean newSession = false; + if (httpSession == null) { + httpSession = servletContextImpl.getSession(exchange, true); + newSession = true; + } + + Session session; + if (System.getSecurityManager() == null) { + session = httpSession.getSession(); + } else { + session = AccessController.doPrivileged(new HttpSessionImpl.UnwrapSessionAction(httpSession)); + } + + if (newSession) { + final int originalMaxInactiveInterval = session.getMaxInactiveInterval(); + if (originalMaxInactiveInterval > authenticationSessionTimeout) { + session.setAttribute(ORIGINAL_SESSION_TIMEOUT, session.getMaxInactiveInterval()); + session.setMaxInactiveInterval(authenticationSessionTimeout); + } + } else { + final Integer originalSessionTimeout = (Integer) session.removeAttribute(ORIGINAL_SESSION_TIMEOUT); + if (originalSessionTimeout != null) { + session.setMaxInactiveInterval(originalSessionTimeout); + } + } + return session; } private static class FormResponseWrapper extends HttpServletResponseWrapper {