Skip to content

Commit

Permalink
[Performance #287] Sending messages to session, validation throttling
Browse files Browse the repository at this point in the history
  • Loading branch information
lukaskabc authored and ledsoft committed Sep 13, 2024
1 parent 347f000 commit c81d347
Show file tree
Hide file tree
Showing 13 changed files with 158 additions and 191 deletions.
Binary file modified doc/throttle-debounce.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import cz.cvut.kbss.termit.util.Constants;
import cz.cvut.kbss.termit.websocket.handler.StompExceptionHandler;
import cz.cvut.kbss.termit.websocket.handler.WebSocketExceptionHandler;
import cz.cvut.kbss.termit.websocket.handler.WebSocketMessageWithHeadersValueHandler;
import org.jetbrains.annotations.NotNull;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Configuration;
Expand All @@ -16,7 +15,6 @@
import org.springframework.messaging.converter.MessageConverter;
import org.springframework.messaging.converter.StringMessageConverter;
import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver;
import org.springframework.messaging.handler.invocation.HandlerMethodReturnValueHandler;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.messaging.simp.config.ChannelRegistration;
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
Expand Down Expand Up @@ -94,11 +92,6 @@ public void configureClientInboundChannel(@NotNull ChannelRegistration registrat
registration.interceptors(webSocketJwtAuthorizationInterceptor, new SecurityContextChannelInterceptor(), interceptor);
}

@Override
public void addReturnValueHandlers(List<HandlerMethodReturnValueHandler> returnValueHandlers) {
returnValueHandlers.add(new WebSocketMessageWithHeadersValueHandler(simpMessagingTemplate));
}

@Override
public void registerStompEndpoints(StompEndpointRegistry registry) {
registry.addEndpoint("/ws").setAllowedOrigins(allowedOrigins.split(","));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import cz.cvut.kbss.termit.util.Utils;
import cz.cvut.kbss.termit.util.throttle.CachableFuture;
import cz.cvut.kbss.termit.util.throttle.Throttle;
import cz.cvut.kbss.termit.util.throttle.ThrottledFuture;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
Expand Down Expand Up @@ -361,7 +362,6 @@ public void refreshLastModified(RefreshLastModifiedEvent event) {
refreshLastModified();
}

@Throttle("{#vocabulary}")
@Transactional
public CachableFuture<Collection<ValidationResult>> validateContents(URI vocabulary) {
final VocabularyContentValidator validator = context.getBean(VocabularyContentValidator.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@

import cz.cvut.kbss.termit.event.EvictCacheEvent;
import cz.cvut.kbss.termit.event.VocabularyContentModified;
import cz.cvut.kbss.termit.exception.TermItException;
import cz.cvut.kbss.termit.model.validation.ValidationResult;
import cz.cvut.kbss.termit.util.throttle.Throttle;
import cz.cvut.kbss.termit.util.throttle.ThrottledFuture;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
Expand All @@ -30,15 +32,18 @@
import org.springframework.context.annotation.Profile;
import org.springframework.context.event.EventListener;
import org.springframework.stereotype.Component;
import org.springframework.transaction.annotation.Transactional;

import java.net.URI;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;

@Component("cachingValidator")
@Primary
Expand All @@ -53,7 +58,7 @@ public class ResultCachingValidator implements VocabularyContentValidator {
*/
private final Map<URI, @NotNull Collection<URI>> vocabularyClosure = new ConcurrentHashMap<>();

private final Map<URI, @NotNull List<ValidationResult>> validationCache = new HashMap<>();
private final Map<URI, @NotNull Collection<ValidationResult>> validationCache = new HashMap<>();

/**
* @return true when the cache contents are dirty and should be refreshed; false otherwise.
Expand All @@ -62,12 +67,14 @@ public boolean isNotDirty(@NotNull URI originVocabularyIri) {
return vocabularyClosure.containsKey(originVocabularyIri);
}

private List<ValidationResult> getCached(@NotNull URI originVocabularyIri) {
private Optional<Collection<ValidationResult>> getCached(@NotNull URI originVocabularyIri) {
synchronized (validationCache) {
return validationCache.getOrDefault(originVocabularyIri, List.of());
return Optional.ofNullable(validationCache.get(originVocabularyIri));
}
}

@Throttle("{#originVocabularyIri}")
@Transactional
@Override
public @NotNull ThrottledFuture<Collection<ValidationResult>> validate(@NotNull URI originVocabularyIri, @NotNull Collection<URI> vocabularyIris) {
final Set<URI> iris = Set.copyOf(vocabularyIris);
Expand All @@ -76,25 +83,35 @@ private List<ValidationResult> getCached(@NotNull URI originVocabularyIri) {
return ThrottledFuture.done(List.of());
}

List<ValidationResult> cached = getCached(originVocabularyIri);
if (isNotDirty(originVocabularyIri)) {
return ThrottledFuture.done(cached);
Optional<Collection<ValidationResult>> cached = getCached(originVocabularyIri);
if (isNotDirty(originVocabularyIri) && cached.isPresent()) {
return ThrottledFuture.done(cached.get());
}

return ThrottledFuture.of(() -> runValidation(originVocabularyIri, iris)).setCachedResult(cached.isEmpty() ? null : cached);
return ThrottledFuture.of(() -> runValidation(originVocabularyIri, iris)).setCachedResult(cached.orElse(null));
}


private @NotNull Collection<ValidationResult> runValidation(@NotNull URI originVocabularyIri, @NotNull final Set<URI> iris) {
if (isNotDirty(originVocabularyIri)) {
return getCached(originVocabularyIri);
Optional<Collection<ValidationResult>> cached = getCached(originVocabularyIri);
if (isNotDirty(originVocabularyIri) && cached.isPresent()) {
return cached.get();
}

final List<ValidationResult> results = getValidator().runValidation(iris);
final Collection<ValidationResult> results;
try {
// executes real validation
results = getValidator().validate(originVocabularyIri, iris).get();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new TermItException(e);
} catch (ExecutionException e) {
throw new TermItException(e.getCause());
}

synchronized (validationCache) {
vocabularyClosure.put(originVocabularyIri, Collections.unmodifiableCollection(iris));
validationCache.put(originVocabularyIri, Collections.unmodifiableList(results));
validationCache.put(originVocabularyIri, Collections.unmodifiableCollection(results));
}

return results;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import cz.cvut.kbss.termit.persistence.context.VocabularyContextMapper;
import cz.cvut.kbss.termit.util.Configuration;
import cz.cvut.kbss.termit.util.Utils;
import cz.cvut.kbss.termit.util.throttle.Throttle;
import cz.cvut.kbss.termit.util.throttle.ThrottledFuture;
import org.apache.jena.rdf.model.Literal;
import org.apache.jena.rdf.model.Model;
Expand Down Expand Up @@ -142,6 +143,7 @@ private void loadOverrideRules(Model validationModel, String language) throws IO
}
}

@Throttle("{#originVocabularyIri}")
@Transactional(readOnly = true)
@Override
public @NotNull ThrottledFuture<Collection<ValidationResult>> validate(final @NotNull URI originVocabularyIri, final @NotNull Collection<URI> vocabularyIris) {
Expand All @@ -159,7 +161,9 @@ private void loadOverrideRules(Model validationModel, String language) throws IO
protected synchronized List<ValidationResult> runValidation(@NotNull Collection<URI> vocabularyIris) {
LOG.debug("Validating {}", vocabularyIris);
try {
LOG.trace("Constructing model from RDF4J repository...");
final Model dataModel = getModelFromRdf4jRepository(vocabularyIris);
LOG.trace("Model constructed, running validation...");
// TODO: would be better to cache the validator, but its not thread safe
org.topbraid.shacl.validation.ValidationReport report = new com.github.sgov.server.Validator()
.validate(dataModel, validationModel);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,21 +144,7 @@ public ResponseEntity<ErrorInfo> authorizationException(HttpServletRequest reque
public ResponseEntity<ErrorInfo> authenticationException(HttpServletRequest request, AuthenticationException e) {
LOG.warn("Authentication failure during HTTP request to {}: {}", request.getRequestURI(), e.getMessage());
LOG.atDebug().setCause(e).log(e.getMessage());
return new ResponseEntity<>(errorInfo(request, e), HttpStatus.FORBIDDEN);
}

/**
* Fired, for example, on method security violation
*/
@ExceptionHandler(AccessDeniedException.class)
public ResponseEntity<ErrorInfo> accessDeniedException(HttpServletRequest request, AccessDeniedException e) {
LOG.atWarn().setMessage("[{}] Unauthorized access: {}").addArgument(() -> {
if (request.getUserPrincipal() != null) {
return request.getUserPrincipal().getName();
}
return "(unknown user)";
}).addArgument(e.getMessage()).log();
return new ResponseEntity<>(errorInfo(request, e), HttpStatus.FORBIDDEN);
return new ResponseEntity<>(errorInfo(request, e), HttpStatus.UNAUTHORIZED);
}

@ExceptionHandler(ValidationException.class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ public class ThrottleAspect implements LongRunningTaskRegister {

private final Clock clock;

private final Executor transactionExecutor;
private final TransactionExecutor transactionExecutor;

private final @NotNull AtomicReference<Instant> lastClear;

Expand Down Expand Up @@ -317,16 +317,16 @@ private Runnable createRunnableToSchedule(ThrottledFuture<?> throttledFuture, Id
// restore the security context
SecurityContextHolder.setContext(securityContext.get());
try {
// update last run timestamp
synchronized (lastRun) {
lastRun.put(identifier, Instant.now(clock));
}
// fulfill the future
if (withTransaction) {
transactionExecutor.execute(throttledFuture::run);
} else {
throttledFuture.run();
}
// update last run timestamp
synchronized (lastRun) {
lastRun.put(identifier, Instant.now(clock));
}
} finally {
// clear the security context
SecurityContextHolder.clearContext();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,28 +96,19 @@ public boolean isDone() {
}

/**
* Does not execute the task, blocks the current thread until some result is available.
*
* @return cached result when available, otherwise awaits future resolution.
* Does not execute the task, blocks the current thread until the result is available.
*/
@Override
public T get() throws InterruptedException, ExecutionException {
if (!isDone() && this.cachedResult != null) {
return this.cachedResult;
}
return future.get();
}

/**
* Does not execute the task, blocks the current thread until some result is available.
* @return cached result when available, otherwise awaits future resolution.
* Does not execute the task, blocks the current thread until the result is available.
*/
@Override
public T get(long timeout, @NotNull TimeUnit unit)
throws InterruptedException, ExecutionException, TimeoutException {
if (!isDone() && this.cachedResult != null) {
return this.cachedResult;
}
return future.get(timeout, unit);
}
/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package cz.cvut.kbss.termit.websocket;

import cz.cvut.kbss.termit.rest.BaseController;
import cz.cvut.kbss.termit.service.IdentifierResolver;
import cz.cvut.kbss.termit.util.Configuration;
import org.jetbrains.annotations.NotNull;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.simp.user.DestinationUserNameProvider;

import java.security.Principal;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

public class BaseWebSocketController extends BaseController {

protected final SimpMessagingTemplate messagingTemplate;

protected BaseWebSocketController(IdentifierResolver idResolver, Configuration config,
SimpMessagingTemplate messagingTemplate) {
super(idResolver, config);
this.messagingTemplate = messagingTemplate;
}

/**
* Resolves session id, when present, and sends to the specific session.
* When session id is not present, sends it to all sessions of specific user.
*
* @param destination the destination (without user prefix)
* @param payload payload to send
* @param replyHeaders native headers for the reply
* @param sourceHeaders original headers containing session id or name of the user
*/
protected void sendToSession(@NotNull String destination, @NotNull Object payload,
@NotNull Map<String, Object> replyHeaders, @NotNull MessageHeaders sourceHeaders) {
getSessionId(sourceHeaders)
.ifPresentOrElse(sessionId -> { // session id present
StompHeaderAccessor headerAccessor = StompHeaderAccessor.create(StompCommand.MESSAGE);
// add reply headers as native headers
replyHeaders.forEach((name, value) -> headerAccessor.addNativeHeader(name, Objects.toString(value)));
headerAccessor.setSessionId(sessionId); // pass session id to new headers
// send to user session
messagingTemplate.convertAndSendToUser(sessionId, destination, payload, headerAccessor.toMessageHeaders());
},
// session id not present, send to all user sessions
() -> getUser(sourceHeaders).ifPresent(user -> messagingTemplate.convertAndSendToUser(user, destination, payload, replyHeaders))
);
}

/**
* Resolves name which can be used to send a message to the user with {@link SimpMessagingTemplate#convertAndSendToUser}.
*
* @return name or session id, or empty when information is not available.
*/
protected @NotNull Optional<String> getUser(@NotNull MessageHeaders messageHeaders) {
return getUserName(messageHeaders).or(() -> getSessionId(messageHeaders));
}

private @NotNull Optional<String> getSessionId(@NotNull MessageHeaders messageHeaders) {
return Optional.ofNullable(SimpMessageHeaderAccessor.getSessionId(messageHeaders));
}

/**
* Resolves the name of the user
*
* @return the name or null
*/
private @NotNull Optional<String> getUserName(MessageHeaders headers) {
Principal principal = SimpMessageHeaderAccessor.getUser(headers);
if (principal != null) {
final String name = (principal instanceof DestinationUserNameProvider provider ?
provider.getDestinationUserName() : principal.getName());
return Optional.ofNullable(name);
}
return Optional.empty();
}
}
Loading

0 comments on commit c81d347

Please sign in to comment.