Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add the support for multipart/form-data in python service and engine #751

Merged
merged 4 commits into from
Aug 7, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions engine/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,10 @@
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-web</artifactId>
</dependency>
<dependency>
<groupId>com.h2database</groupId>
<artifactId>h2</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,16 @@
*******************************************************************************/
package io.seldon.engine.api.rest;

import java.io.IOException;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicBoolean;

import javax.annotation.PostConstruct;

import com.fasterxml.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
Expand All @@ -43,6 +48,8 @@
import io.seldon.engine.tracing.TracingProvider;
import io.seldon.protos.PredictionProtos.Feedback;
import io.seldon.protos.PredictionProtos.SeldonMessage;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.multipart.MultipartHttpServletRequest;

@RestController
public class RestClientController {
Expand Down Expand Up @@ -127,55 +134,111 @@ String unpause() {
@Timed
@CrossOrigin(origins = "*")
@RequestMapping(value = "/api/v0.1/predictions", method = RequestMethod.POST, consumes = "application/json; charset=utf-8", produces = "application/json; charset=utf-8")
public ResponseEntity<String> predictions(RequestEntity<String> requestEntity)
public ResponseEntity<String> predictions_json(RequestEntity<String> requestEntity)
{
logger.debug("Received predict request");
Scope tracingScope = null;
if (tracingProvider.isActive())
tracingScope = tracingProvider.getTracer().buildSpan("/api/v0.1/predictions").startActive(true);
try
{
SeldonMessage request;
try
{
SeldonMessage.Builder builder = SeldonMessage.newBuilder();
ProtoBufUtils.updateMessageBuilderFromJson(builder, requestEntity.getBody() );
request = builder.build();
}
catch (InvalidProtocolBufferException e)
{
logger.error("Bad request",e);
throw new APIException(ApiExceptionType.ENGINE_INVALID_JSON,requestEntity.getBody());
return _predictions(requestEntity.getBody());
}

try
finally
{
SeldonMessage response = predictionService.predict(request);
String responseJson = ProtoBufUtils.toJson(response);
return new ResponseEntity<String>(responseJson,HttpStatus.OK);
if (tracingScope != null)
tracingScope.close();
}
catch (InterruptedException e) {
throw new APIException(ApiExceptionType.ENGINE_INTERRUPTED,e.getMessage());
} catch (ExecutionException e) {
if (e.getCause().getClass() == APIException.class){
throw (APIException) e.getCause();

}


@Timed
@CrossOrigin(origins = "*")
@RequestMapping(value = "/api/v0.1/predictions", method = RequestMethod.POST, consumes = "multipart/form-data", produces = "application/json; charset=utf-8")
public ResponseEntity<String> predictions_multiform(MultipartHttpServletRequest requestEntity)
{
logger.debug("Received predict request");
Scope tracingScope = null;
if (tracingProvider.isActive())
tracingScope = tracingProvider.getTracer().buildSpan("/api/v0.1/predictions").startActive(true);
try {
ObjectMapper mapper = new ObjectMapper();
Map<String,Object> mergedParamMap = new HashMap<String,Object>();
if(requestEntity.getParameterMap() != null){
for(Map.Entry<String,String[]> formEntry: requestEntity.getParameterMap().entrySet()){
if(formEntry.getKey().equalsIgnoreCase(SeldonMessage.DataOneofCase.STRDATA.name())){
mergedParamMap.put(formEntry.getKey(),formEntry.getValue()[0]);
}else{
mergedParamMap.put(formEntry.getKey(),mapper.readTree(formEntry.getValue()[0]));
}
}
}
else
{
throw new APIException(ApiExceptionType.ENGINE_EXECUTION_FAILURE,e.getMessage());
if(requestEntity.getFileMap() != null){
for(Map.Entry<String ,MultipartFile> fileEntry: requestEntity.getFileMap().entrySet()){
if(fileEntry.getKey().equalsIgnoreCase(SeldonMessage.DataOneofCase.STRDATA.name())){
mergedParamMap.put(fileEntry.getKey(),new String(fileEntry.getValue().getBytes()));
}else{
mergedParamMap.put(fileEntry.getKey(),fileEntry.getValue().getBytes());
}
}
}
} catch (InvalidProtocolBufferException e) {
throw new APIException(ApiExceptionType.ENGINE_INVALID_JSON,"");
}
}
finally

return _predictions(mapper.writeValueAsString(mergedParamMap));
} catch (IOException e) {
logger.error("Bad request",e);
throw new APIException(ApiExceptionType.REQUEST_IO_EXCEPTION,e.getMessage());

} finally
{
if (tracingScope != null)
tracingScope.close();
}

}


/**
* It calls the prediction service for the input json.
* It is the base function for all forms of request Content-type
* @param json - Input JSON to predict REST api
* @return The response for prediction service
*/
private ResponseEntity<String> _predictions(String json)
{
SeldonMessage request;
try
{
SeldonMessage.Builder builder = SeldonMessage.newBuilder();
ProtoBufUtils.updateMessageBuilderFromJson(builder, json );
request = builder.build();
}
catch (InvalidProtocolBufferException e)
{
logger.error("Bad request",e);
throw new APIException(ApiExceptionType.ENGINE_INVALID_JSON,json);
}

try
{
SeldonMessage response = predictionService.predict(request);
String responseJson = ProtoBufUtils.toJson(response);
return new ResponseEntity<String>(responseJson,HttpStatus.OK);
}
catch (InterruptedException e) {
throw new APIException(ApiExceptionType.ENGINE_INTERRUPTED,e.getMessage());
} catch (ExecutionException e) {
if (e.getCause().getClass() == APIException.class){
throw (APIException) e.getCause();
}
else
{
throw new APIException(ApiExceptionType.ENGINE_EXECUTION_FAILURE,e.getMessage());
}
} catch (InvalidProtocolBufferException e) {
throw new APIException(ApiExceptionType.ENGINE_INVALID_JSON,"");
}
}

@Timed
@CrossOrigin(origins = "*")
@RequestMapping(value= "/api/v0.1/feedback", method = RequestMethod.POST, consumes = "application/json; charset=utf-8", produces = "application/json; charset=utf-8")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ public enum ApiExceptionType {
ENGINE_INVALID_COMBINER_RESPONSE(204,"Invalid number of predictions from combiner",500),
ENGINE_INTERRUPTED(205,"API call interrupted",500),
ENGINE_EXECUTION_FAILURE(206,"Execution failure",500),
ENGINE_INVALID_ROUTING(207,"Invalid Routing",500);
ENGINE_INVALID_ROUTING(207,"Invalid Routing",500),
REQUEST_IO_EXCEPTION(208,"IO Exception",500);

int id;
String message;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

import java.util.Arrays;

import com.google.protobuf.ByteString;
import io.seldon.protos.PredictionProtos;
import org.springframework.stereotype.Component;

import io.seldon.protos.PredictionProtos.DefaultData;
Expand All @@ -37,15 +39,25 @@ public SimpleModelUnit() {}

@Override
public SeldonMessage transformInput(SeldonMessage input, PredictiveUnitState state){
SeldonMessage output = SeldonMessage.newBuilder()
SeldonMessage.Builder builder = SeldonMessage.newBuilder()
.setStatus(Status.newBuilder().setStatus(Status.StatusFlag.SUCCESS).build())
.setMeta(Meta.newBuilder()
.addMetrics(Metric.newBuilder().setKey("mymetric_counter").setType(MetricType.COUNTER).setValue(1))
.addMetrics(Metric.newBuilder().setKey("mymetric_gauge").setType(MetricType.GAUGE).setValue(100))
.addMetrics(Metric.newBuilder().setKey("mymetric_timer").setType(MetricType.TIMER).setValue(22.1F)))
.setData(DefaultData.newBuilder().addAllNames(Arrays.asList(classes))
.addMetrics(Metric.newBuilder().setKey("mymetric_timer").setType(MetricType.TIMER).setValue(22.1F)));

// echo in case of strData and binData
if(input.getDataOneofCase().equals(SeldonMessage.DataOneofCase.BINDATA)){
builder.setBinData(input.getBinData());
} else if (input.getDataOneofCase().equals(SeldonMessage.DataOneofCase.STRDATA)){
builder.setStrData(input.getStrData());
}else{
builder.setData(DefaultData.newBuilder().addAllNames(Arrays.asList(classes))
.setTensor(Tensor.newBuilder().addShape(1).addShape(values.length)
.addAllValues(Arrays.asList(values)))).build();
.addAllValues(Arrays.asList(values))));
}

SeldonMessage output = builder.build();
System.out.println("Model " + state.name + " finishing computations");
return output;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,20 @@
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.context.SpringBootTest.WebEnvironment;
import org.springframework.http.MediaType;
import org.springframework.jmx.support.MetricType;
import org.springframework.test.context.junit4.SpringRunner;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult;
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders;
import org.springframework.test.web.servlet.setup.MockMvcBuilders;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.context.WebApplicationContext;

import io.seldon.engine.pb.ProtoBufUtils;
import io.seldon.protos.PredictionProtos.SeldonMessage;

import java.util.*;

@RunWith(SpringRunner.class)
@SpringBootTest(webEnvironment = WebEnvironment.RANDOM_PORT)
//@AutoConfigureMockMvc
Expand Down Expand Up @@ -135,4 +138,146 @@ public void testPredict_21dim_tensor() throws Exception
Assert.assertEquals("GAUGE", seldonMessage.getMeta().getMetrics(1).getType().toString());
Assert.assertEquals("TIMER", seldonMessage.getMeta().getMetrics(2).getType().toString());
}

@Test
public void testPredict_multiform_11dim_ndarry() throws Exception
{
final String predictJson = "{" +
"\"request\": {" +
"\"ndarray\": [[1.0]]}" +
"}";
final MultiValueMap<String,String> paramMap = new LinkedMultiValueMap<>();
paramMap.put("data", Arrays.asList(predictJson));
MvcResult res = mvc.perform(MockMvcRequestBuilders.post("/api/v0.1/predictions")
.accept(MediaType.APPLICATION_JSON_UTF8)
.params(paramMap)
.contentType(MediaType.MULTIPART_FORM_DATA)).andReturn();
String response = res.getResponse().getContentAsString();
System.out.println(response);
Assert.assertEquals(200, res.getResponse().getStatus());
}

@Test
public void testPredict_multiform_21dim_ndarry() throws Exception
{
final String predictJson = "{" +
"\"request\": {" +
"\"ndarray\": [[1.0],[2.0]]}" +
"}";
final MultiValueMap<String,String> paramMap = new LinkedMultiValueMap<>();
paramMap.put("data", Arrays.asList(predictJson));
MvcResult res = mvc.perform(MockMvcRequestBuilders.post("/api/v0.1/predictions")
.accept(MediaType.APPLICATION_JSON_UTF8)
.params(paramMap)
.contentType(MediaType.MULTIPART_FORM_DATA)).andReturn();
String response = res.getResponse().getContentAsString();
System.out.println(response);
Assert.assertEquals(200, res.getResponse().getStatus());
SeldonMessage.Builder builder = SeldonMessage.newBuilder();
ProtoBufUtils.updateMessageBuilderFromJson(builder, response );
SeldonMessage seldonMessage = builder.build();
Assert.assertEquals(3, seldonMessage.getMeta().getMetricsCount());
Assert.assertEquals("COUNTER", seldonMessage.getMeta().getMetrics(0).getType().toString());
Assert.assertEquals("GAUGE", seldonMessage.getMeta().getMetrics(1).getType().toString());
Assert.assertEquals("TIMER", seldonMessage.getMeta().getMetrics(2).getType().toString());
}

@Test
public void testPredict_multiform_21dim_tensor() throws Exception
{
final String predictJson = "{" +
"\"request\": {" +
"\"tensor\": {\"shape\":[2,1],\"values\":[1.0,2.0]}}" +
"}";
final MultiValueMap<String,String> paramMap = new LinkedMultiValueMap<>();
paramMap.put("data", Arrays.asList(predictJson));
MvcResult res = mvc.perform(MockMvcRequestBuilders.post("/api/v0.1/predictions")
.accept(MediaType.APPLICATION_JSON_UTF8)
.params(paramMap)
.contentType(MediaType.MULTIPART_FORM_DATA)).andReturn();
String response = res.getResponse().getContentAsString();
System.out.println(response);
Assert.assertEquals(200, res.getResponse().getStatus());
SeldonMessage.Builder builder = SeldonMessage.newBuilder();
ProtoBufUtils.updateMessageBuilderFromJson(builder, response );
SeldonMessage seldonMessage = builder.build();
Assert.assertEquals(3, seldonMessage.getMeta().getMetricsCount());
Assert.assertEquals("COUNTER", seldonMessage.getMeta().getMetrics(0).getType().toString());
Assert.assertEquals("GAUGE", seldonMessage.getMeta().getMetrics(1).getType().toString());
Assert.assertEquals("TIMER", seldonMessage.getMeta().getMetrics(2).getType().toString());
}
@Test
public void testPredict_multiform_binData() throws Exception
{
final String metaJson = "{\"puid\":\"1234\"}" ;
final MultiValueMap<String,String> paramMap = new LinkedMultiValueMap<>();
paramMap.put("meta", Arrays.asList(metaJson));
byte[] fileData = "test data".getBytes();
MvcResult res = mvc.perform(MockMvcRequestBuilders.fileUpload("/api/v0.1/predictions").file("binData",fileData)
.accept(MediaType.APPLICATION_JSON_UTF8)
.params(paramMap)
.contentType(MediaType.MULTIPART_FORM_DATA)).andReturn();
String response = res.getResponse().getContentAsString();
System.out.println(response);
Assert.assertEquals(200, res.getResponse().getStatus());
SeldonMessage.Builder builder = SeldonMessage.newBuilder();
ProtoBufUtils.updateMessageBuilderFromJson(builder, response );
SeldonMessage seldonMessage = builder.build();
Assert.assertEquals(3, seldonMessage.getMeta().getMetricsCount());
Assert.assertEquals("COUNTER", seldonMessage.getMeta().getMetrics(0).getType().toString());
Assert.assertEquals("GAUGE", seldonMessage.getMeta().getMetrics(1).getType().toString());
Assert.assertEquals("TIMER", seldonMessage.getMeta().getMetrics(2).getType().toString());
Assert.assertEquals(new String(fileData), seldonMessage.getBinData().toStringUtf8());
Assert.assertEquals("1234", seldonMessage.getMeta().getPuid());
}
@Test
public void testPredict_multiform_strData_as_file() throws Exception
{
final String metaJson = "{\"puid\":\"1234\"}" ;
final MultiValueMap<String,String> paramMap = new LinkedMultiValueMap<>();
paramMap.put("meta", Arrays.asList(metaJson));
byte[] fileData = "test data".getBytes();
MvcResult res = mvc.perform(MockMvcRequestBuilders.fileUpload("/api/v0.1/predictions").file("strData",fileData)
.accept(MediaType.APPLICATION_JSON_UTF8)
.params(paramMap)
.contentType(MediaType.MULTIPART_FORM_DATA)).andReturn();
String response = res.getResponse().getContentAsString();
System.out.println(response);
Assert.assertEquals(200, res.getResponse().getStatus());
SeldonMessage.Builder builder = SeldonMessage.newBuilder();
ProtoBufUtils.updateMessageBuilderFromJson(builder, response );
SeldonMessage seldonMessage = builder.build();
Assert.assertEquals(3, seldonMessage.getMeta().getMetricsCount());
Assert.assertEquals("COUNTER", seldonMessage.getMeta().getMetrics(0).getType().toString());
Assert.assertEquals("GAUGE", seldonMessage.getMeta().getMetrics(1).getType().toString());
Assert.assertEquals("TIMER", seldonMessage.getMeta().getMetrics(2).getType().toString());
Assert.assertEquals(new String(fileData), seldonMessage.getStrData());
Assert.assertEquals("1234", seldonMessage.getMeta().getPuid());

}
@Test
public void testPredict_multiform_strData_as_text() throws Exception
{
final String metaJson = "{\"puid\":\"1234\"}" ;
final MultiValueMap<String,String> paramMap = new LinkedMultiValueMap<>();
paramMap.put("meta", Arrays.asList(metaJson));
String strdata = "test data";
paramMap.put("strData",Arrays.asList(strdata));
MvcResult res = mvc.perform(MockMvcRequestBuilders.post("/api/v0.1/predictions")
.accept(MediaType.APPLICATION_JSON_UTF8)
.params(paramMap)
.contentType(MediaType.MULTIPART_FORM_DATA)).andReturn();
String response = res.getResponse().getContentAsString();
System.out.println(response);
Assert.assertEquals(200, res.getResponse().getStatus());
SeldonMessage.Builder builder = SeldonMessage.newBuilder();
ProtoBufUtils.updateMessageBuilderFromJson(builder, response );
SeldonMessage seldonMessage = builder.build();
Assert.assertEquals(3, seldonMessage.getMeta().getMetricsCount());
Assert.assertEquals("COUNTER", seldonMessage.getMeta().getMetrics(0).getType().toString());
Assert.assertEquals("GAUGE", seldonMessage.getMeta().getMetrics(1).getType().toString());
Assert.assertEquals("TIMER", seldonMessage.getMeta().getMetrics(2).getType().toString());
Assert.assertEquals(strdata, seldonMessage.getStrData());
Assert.assertEquals("1234", seldonMessage.getMeta().getPuid());
}
}
Loading