Skip to content

Commit

Permalink
GH-1018 Ensure AWS adapter can pass raw InputStream
Browse files Browse the repository at this point in the history
Resolves #1018
  • Loading branch information
olegz committed Mar 30, 2023
1 parent e76cca0 commit d8a03d7
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package org.springframework.cloud.function.adapter.aws;

import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Type;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
Expand All @@ -31,7 +33,9 @@
import org.springframework.http.HttpStatus;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.support.GenericMessage;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.StreamUtils;

/**
*
Expand Down Expand Up @@ -77,6 +81,23 @@ static boolean isSupportedAWSType(Type inputType) {
|| typeName.equals("com.amazonaws.services.lambda.runtime.events.KinesisEvent");
}

@SuppressWarnings("rawtypes")
public static Message generateMessage(InputStream payload, Type inputType, boolean isSupplier, JsonMapper jsonMapper, Context context) throws IOException {
if (inputType != null && FunctionTypeUtils.isMessage(inputType)) {
inputType = FunctionTypeUtils.getImmediateGenericType(inputType, 0);
}
if (inputType != null && InputStream.class.isAssignableFrom(FunctionTypeUtils.getRawType(inputType))) {
MessageBuilder msgBuilder = MessageBuilder.withPayload(payload);
if (context != null) {
msgBuilder.setHeader(AWSLambdaUtils.AWS_CONTEXT, context);
}
return msgBuilder.build();
}
else {
return generateMessage(StreamUtils.copyToByteArray(payload), inputType, isSupplier, jsonMapper, context);
}
}

public static Message<byte[]> generateMessage(byte[] payload, Type inputType, boolean isSupplier, JsonMapper jsonMapper) {
return generateMessage(payload, inputType, isSupplier, jsonMapper, null);
}
Expand All @@ -87,6 +108,7 @@ public static Message<byte[]> generateMessage(byte[] payload, Type inputType, bo
logger.info("Received: " + new String(payload, StandardCharsets.UTF_8));
}


Object structMessage = jsonMapper.fromJson(payload, Object.class);
boolean isApiGateway = structMessage instanceof Map
&& (((Map<String, Object>) structMessage).containsKey("httpMethod") ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public FunctionInvoker() {
@Override
public void handleRequest(InputStream input, OutputStream output, Context context) throws IOException {
Message requestMessage = AWSLambdaUtils
.generateMessage(StreamUtils.copyToByteArray(input), this.function.getInputType(), this.function.isSupplier(), jsonMapper, context);
.generateMessage(input, this.function.getInputType(), this.function.isSupplier(), jsonMapper, context);

Object response = this.function.apply(requestMessage);
byte[] responseBytes = this.buildResult(requestMessage, response);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import org.springframework.messaging.converter.AbstractMessageConverter;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.MimeType;
import org.springframework.util.StreamUtils;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.fail;
Expand Down Expand Up @@ -989,6 +990,40 @@ public void testApiGatewayAsSupplier() throws Exception {
assertThat(result.get("body")).isEqualTo("\"boom\"");
}

@SuppressWarnings({ "rawtypes", "unchecked" })
@Test
public void testApiGatewayInAndOutInputStream() throws Exception {
System.setProperty("MAIN_CLASS", ApiGatewayConfiguration.class.getName());
System.setProperty("spring.cloud.function.definition", "echoInputStreamToString");
FunctionInvoker invoker = new FunctionInvoker();

InputStream targetStream = new ByteArrayInputStream(this.apiGatewayEvent.getBytes());
ByteArrayOutputStream output = new ByteArrayOutputStream();
invoker.handleRequest(targetStream, output, null);

Map result = mapper.readValue(output.toByteArray(), Map.class);
assertThat(result.get("body")).isEqualTo("hello");
Map headers = (Map) result.get("headers");
assertThat(headers).isNotEmpty();
}

@SuppressWarnings({ "rawtypes", "unchecked" })
@Test
public void testApiGatewayInAndOutInputStreamMsg() throws Exception {
System.setProperty("MAIN_CLASS", ApiGatewayConfiguration.class.getName());
System.setProperty("spring.cloud.function.definition", "echoInputStreamMsgToString");
FunctionInvoker invoker = new FunctionInvoker();

InputStream targetStream = new ByteArrayInputStream(this.apiGatewayEvent.getBytes());
ByteArrayOutputStream output = new ByteArrayOutputStream();
invoker.handleRequest(targetStream, output, null);

Map result = mapper.readValue(output.toByteArray(), Map.class);
assertThat(result.get("body")).isEqualTo("hello");
Map headers = (Map) result.get("headers");
assertThat(headers).isNotEmpty();
}

@SuppressWarnings("rawtypes")
@Test
public void testApiGatewayInAndOut() throws Exception {
Expand Down Expand Up @@ -1426,6 +1461,34 @@ public Function<APIGatewayProxyRequestEvent, String> inputApiEvent() {
};
}

@Bean

public Function<InputStream, String> echoInputStreamToString() {
return is -> {
try {
String result = StreamUtils.copyToString(is, StandardCharsets.UTF_8);
return result;
}
catch (Exception e) {
throw new RuntimeException(e);
}
};
}

@Bean

public Function<Message<InputStream>, String> echoInputStreamMsgToString() {
return msg -> {
try {
String result = StreamUtils.copyToString(msg.getPayload(), StandardCharsets.UTF_8);
return result;
}
catch (Exception e) {
throw new RuntimeException(e);
}
};
}

@Bean
public Function<APIGatewayProxyRequestEvent, APIGatewayProxyResponseEvent> inputOutputApiEvent() {
return v -> {
Expand Down

0 comments on commit d8a03d7

Please sign in to comment.