Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(controller): support receiving MultipartFiles #3165

Merged
merged 15 commits into from
Jan 31, 2025
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,17 @@

import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.TreeMap;
import java.util.stream.Collectors;

import com.fasterxml.jackson.core.JsonPointer;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Import;
import org.springframework.http.HttpStatus;
Expand All @@ -33,6 +37,7 @@
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartHttpServletRequest;

import com.vaadin.flow.internal.CurrentInstance;
import com.vaadin.flow.server.VaadinRequest;
Expand Down Expand Up @@ -70,6 +75,8 @@ public class EndpointController {
private static final Logger LOGGER = LoggerFactory
.getLogger(EndpointController.class);

public static final String BODY_PART_NAME = "hilla_body_part";

static final String ENDPOINT_METHODS = "/{endpoint}/{method}";

/**
Expand All @@ -87,6 +94,8 @@ public class EndpointController {

private final EndpointInvoker endpointInvoker;

private final ObjectMapper objectMapper;

VaadinService vaadinService;

/**
Expand All @@ -101,13 +110,16 @@ public class EndpointController {
* @param csrfChecker
* the csrf checker to use
*/

public EndpointController(ApplicationContext context,
EndpointRegistry endpointRegistry, EndpointInvoker endpointInvoker,
CsrfChecker csrfChecker) {
CsrfChecker csrfChecker,
@Qualifier("hillaEndpointObjectMapper") ObjectMapper objectMapper) {
this.context = context;
this.endpointInvoker = endpointInvoker;
this.csrfChecker = csrfChecker;
this.endpointRegistry = endpointRegistry;
this.objectMapper = objectMapper;
}

/**
Expand Down Expand Up @@ -169,7 +181,7 @@ public void registerEndpoints() {
* the current response
* @return execution result as a JSON string or an error message string
*/
@PostMapping(path = ENDPOINT_METHODS, produces = MediaType.APPLICATION_JSON_UTF8_VALUE)
@PostMapping(path = ENDPOINT_METHODS, consumes = MediaType.APPLICATION_JSON_VALUE, produces = MediaType.APPLICATION_JSON_UTF8_VALUE)
public ResponseEntity<String> serveEndpoint(
@PathVariable("endpoint") String endpointName,
@PathVariable("method") String methodName,
Expand All @@ -179,6 +191,35 @@ public ResponseEntity<String> serveEndpoint(
response);
}

/**
* Captures and processes the Vaadin multipart endpoint requests. They are
* used when there are uploaded files.
* <p>
* This method works as
* {@link #serveEndpoint(String, String, ObjectNode, HttpServletRequest, HttpServletResponse)},
* but it also captures the files uploaded in the request.
*
* @param endpointName
* the name of an endpoint to address the calls to, not case
* sensitive
* @param methodName
* the method name to execute on an endpoint, not case sensitive
* @param request
* the current multipart request which triggers the endpoint call
* @param response
* the current response
* @return execution result as a JSON string or an error message string
*/
@PostMapping(path = ENDPOINT_METHODS, consumes = MediaType.MULTIPART_FORM_DATA_VALUE, produces = MediaType.APPLICATION_JSON_VALUE)
public ResponseEntity<String> serveMultipartEndpoint(
@PathVariable("endpoint") String endpointName,
@PathVariable("method") String methodName,
HttpServletRequest request, HttpServletResponse response)
throws IOException {
return doServeEndpoint(endpointName, methodName, null, request,
response);
}

/**
* Captures and processes the Vaadin endpoint requests.
* <p>
Expand Down Expand Up @@ -227,6 +268,45 @@ private ResponseEntity<String> doServeEndpoint(String endpointName,
if (enforcementResult.isEnforcementNeeded()) {
return buildEnforcementResponseEntity(enforcementResult);
}

if (isMultipartRequest(request)) {
var multipartRequest = (MultipartHttpServletRequest) request;

// retrieve the body from a part having the correct name
var bodyPart = multipartRequest.getParameter(BODY_PART_NAME);
if (bodyPart == null) {
return ResponseEntity.badRequest()
.body(endpointInvoker.createResponseErrorObject(
"Missing body part in multipart request"));
}

try {
body = objectMapper.readValue(bodyPart, ObjectNode.class);
} catch (IOException e) {
LOGGER.error("Request body does not contain valid JSON", e);
return ResponseEntity
.status(HttpStatus.INTERNAL_SERVER_ERROR)
.body(endpointInvoker.createResponseErrorObject(
"Request body does not contain valid JSON"));
}

// parse uploaded files and add them to the body
var fileMap = multipartRequest.getFileMap();
for (var entry : fileMap.entrySet()) {
var partName = entry.getKey();
var file = entry.getValue();

// parse the part name as JSON pointer, e.g.
// "/orders/1/invoice"
var pointer = JsonPointer.valueOf(partName);
// split the path in parent and property name
var parent = pointer.head();
var property = pointer.last().getMatchingProperty();
var parentObject = body.withObject(parent);
parentObject.putPOJO(property, file);
}
}

Object returnValue = endpointInvoker.invoke(endpointName,
methodName, body, request.getUserPrincipal(),
request::isUserInRole);
Expand Down Expand Up @@ -273,6 +353,12 @@ private ResponseEntity<String> doServeEndpoint(String endpointName,
}
}

private boolean isMultipartRequest(HttpServletRequest request) {
String contentType = request.getContentType();
return contentType != null
&& contentType.startsWith(MediaType.MULTIPART_FORM_DATA_VALUE);
}

private ResponseEntity<String> buildEnforcementResponseEntity(
DAUUtils.EnforcementResult enforcementResult) {
EnforcementNotificationMessages messages = enforcementResult.messages();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import com.vaadin.hilla.exception.EndpointException;
import com.vaadin.hilla.exception.EndpointValidationException;
import com.vaadin.hilla.exception.EndpointValidationException.ValidationErrorData;
import com.vaadin.hilla.parser.jackson.JacksonObjectMapperFactory;
import jakarta.servlet.ServletContext;
import jakarta.validation.ConstraintViolation;
import jakarta.validation.Validation;
Expand All @@ -46,10 +45,12 @@
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.lang.reflect.Type;
import java.security.Principal;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
Expand Down Expand Up @@ -303,13 +304,36 @@ private Method getMethod(String endpointName, String methodName) {
return endpointData.getMethod(methodName).orElse(null);
}

private Map<String, JsonNode> getRequestParameters(ObjectNode body) {
private Map<String, JsonNode> getRequestParameters(ObjectNode body,
List<String> parameterNames) {
// Respect the order of parameters in the request body
Map<String, JsonNode> parametersData = new LinkedHashMap<>();
if (body != null) {
body.fields().forEachRemaining(entry -> parametersData
.put(entry.getKey(), entry.getValue()));
}
return parametersData;

// Try to adapt to the order of parameters in the method
var orderedData = new LinkedHashMap<String, JsonNode>();
for (String parameterName : parameterNames) {
JsonNode parameterData = parametersData.get(parameterName);
if (parameterData != null) {
parametersData.remove(parameterName);
orderedData.put(parameterName, parameterData);
}
}
orderedData.putAll(parametersData);

if (getLogger().isDebugEnabled()) {
var returnedParameterNames = List.copyOf(orderedData.keySet());
if (!parameterNames.equals(returnedParameterNames)) {
getLogger().debug(
"The parameter names in the request body do not match the method parameters. Expected: {}, but got: {}",
parameterNames, returnedParameterNames);
}
}

return orderedData;
}

private Object[] getVaadinEndpointParameters(
Expand Down Expand Up @@ -404,7 +428,10 @@ private Object invokeVaadinEndpointMethod(String endpointName,
endpointName, methodName, checkError));
}

Map<String, JsonNode> requestParameters = getRequestParameters(body);
var parameterNames = Arrays.stream(methodToInvoke.getParameters())
.map(Parameter::getName).toList();
Map<String, JsonNode> requestParameters = getRequestParameters(body,
parameterNames);
Type[] javaParameters = getJavaParameters(methodToInvoke, ClassUtils
.getUserClass(vaadinEndpointData.getEndpointObject()));
if (javaParameters.length != requestParameters.size()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,26 @@
*/
package com.vaadin.hilla.endpointransfermapper;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;

import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JavaType;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.Module;
import com.fasterxml.jackson.databind.deser.std.StdDelegatingDeserializer;
import com.fasterxml.jackson.databind.module.SimpleModule;
import com.fasterxml.jackson.databind.node.POJONode;
import com.fasterxml.jackson.databind.ser.std.StdDelegatingSerializer;
import com.fasterxml.jackson.databind.type.TypeFactory;
import com.fasterxml.jackson.databind.util.StdConverter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.multipart.MultipartFile;

/**
* Defines mappings for certain endpoint types to corresponding transfer types.
Expand Down Expand Up @@ -174,6 +181,8 @@ public JavaType getOutputType(TypeFactory typeFactory) {
});

jacksonModule.addDeserializer(endpointType, deserializer);
jacksonModule.addDeserializer(MultipartFile.class,
new MultipartFileDeserializer());
}

/**
Expand Down Expand Up @@ -318,4 +327,28 @@ private Logger getLogger() {
return LoggerFactory.getLogger(getClass());
}

/**
* A deserializer for MultipartFile. It is needed because otherwise Jackson
* tries to deserialize the object which is already a POJO.
*/
public static class MultipartFileDeserializer
extends JsonDeserializer<MultipartFile> {

@Override
public MultipartFile deserialize(JsonParser p,
DeserializationContext ctxt) throws IOException {
JsonNode node = p.getCodec().readTree(p);

if (node instanceof POJONode) {
Object pojo = ((POJONode) node).getPojo();

if (pojo instanceof MultipartFile) {
return (MultipartFile) pojo;
}
}

throw new IOException(
"Expected a POJONode wrapping a MultipartFile");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,13 @@ public void setUp() {
EndpointRegistry endpointRegistry = new EndpointRegistry(
new EndpointNameChecker());
ApplicationContext appCtx = Mockito.mock(ApplicationContext.class);
ObjectMapper objectMapper = new JacksonObjectMapperFactory.Json()
.build();
EndpointInvoker endpointInvoker = new EndpointInvoker(appCtx,
new JacksonObjectMapperFactory.Json().build(),
new ExplicitNullableTypeChecker(), servletContext,
objectMapper, new ExplicitNullableTypeChecker(), servletContext,
endpointRegistry);
controller = new EndpointController(appCtx, endpointRegistry,
endpointInvoker, csrfChecker);
endpointInvoker, csrfChecker, objectMapper);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ public EndpointController build() {
EndpointInvoker invoker = Mockito.spy(
new EndpointInvoker(applicationContext, endpointObjectMapper,
explicitNullableTypeChecker, servletContext, registry));
EndpointController controller = Mockito.spy(new EndpointController(
applicationContext, registry, invoker, csrfChecker));
EndpointController controller = Mockito
.spy(new EndpointController(applicationContext, registry,
invoker, csrfChecker, endpointObjectMapper));
Mockito.doReturn(mock(EndpointAccessChecker.class)).when(invoker)
.getAccessChecker();
return controller;
Expand Down
Loading