From a08722814f6713f827e963f8d9d2e62c361250af Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 13 Nov 2018 11:03:43 -0800 Subject: [PATCH] update readme and script to run --- .../predictor/run_predictor_java_example.sh | 44 ++++++++++++++ .../infer/predictor/PredictorExample.java | 33 ++++++---- .../javaapi/infer/predictor/README.md | 60 +++++++++++++++++++ 3 files changed, 127 insertions(+), 10 deletions(-) create mode 100755 scala-package/examples/scripts/infer/predictor/run_predictor_java_example.sh create mode 100644 scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/README.md diff --git a/scala-package/examples/scripts/infer/predictor/run_predictor_java_example.sh b/scala-package/examples/scripts/infer/predictor/run_predictor_java_example.sh new file mode 100755 index 000000000000..4ebcc3076a78 --- /dev/null +++ b/scala-package/examples/scripts/infer/predictor/run_predictor_java_example.sh @@ -0,0 +1,44 @@ +#!/bin/bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +hw_type=cpu +if [[ $3 = gpu ]] +then + hw_type=gpu +fi + +platform=linux-x86_64 + +if [[ $OSTYPE = [darwin]* ]] +then + platform=osx-x86_64 +fi + +MXNET_ROOT=$(cd "$(dirname $0)/../../../../../"; pwd) +CLASS_PATH=$MXNET_ROOT/scala-package/assembly/$platform-$hw_type/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*:$MXNET_ROOT/scala-package/infer/target/* + +# model dir and prefix +MODEL_DIR=$1 +# input image +INPUT_IMG=$2 + +java -Xmx8G -cp $CLASS_PATH \ + org.apache.mxnetexamples.javaapi.infer.predictor.PredictorExample \ + --model-path-prefix $MODEL_DIR \ + --input-image $INPUT_IMG diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java index 3ab82de7f95c..01766e61e315 100644 --- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java @@ -18,10 +18,7 @@ package org.apache.mxnetexamples.javaapi.infer.predictor; import org.apache.mxnet.infer.javaapi.Predictor; -import org.apache.mxnet.javaapi.Context; -import org.apache.mxnet.javaapi.DType; -import org.apache.mxnet.javaapi.DataDesc; -import org.apache.mxnet.javaapi.Shape; +import org.apache.mxnet.javaapi.*; import org.kohsuke.args4j.CmdLineParser; import org.kohsuke.args4j.Option; import org.slf4j.Logger; @@ -42,8 +39,6 @@ public class PredictorExample { private String modelPathPrefix = "/model/ssd_resnet50_512"; @Option(name = "--input-image", usage = "the input image") private String inputImagePath = "/images/dog.jpg"; - @Option(name = "--input-dir", usage = "the input batch of images directory") - private String inputImageDir = "/images/"; final static Logger logger = LoggerFactory.getLogger(PredictorExample.class); @@ -97,9 +92,12 @@ private static float[] imagePreprocess(BufferedImage buf) { return result; } - private static String printMaximumClass(float[] probabilities) throws IOException { - BufferedReader reader = new BufferedReader(new FileReader("/tmp/resnet18/synset.txt")); - ArrayList list = new ArrayList(); + private static String printMaximumClass(float[] probabilities, + String modelPathPrefix) throws IOException { + String synsetFilePath = modelPathPrefix.substring(0, + 1 + modelPathPrefix.lastIndexOf(File.separator)) + "/synset.txt"; + BufferedReader reader = new BufferedReader(new FileReader(synsetFilePath)); + ArrayList list = new ArrayList<>(); String line = reader.readLine(); while (line != null){ @@ -147,7 +145,22 @@ public static void main(String[] args) { // predict float[][] result = predictor.predict(new float[][]{imagePreprocess(img)}); try { - System.out.println(printMaximumClass(result[0])); + System.out.println("Predict with Float input"); + System.out.println(printMaximumClass(result[0], inst.modelPathPrefix)); + } catch (Exception e) { + System.err.println(e); + } + // predict with NDArray + NDArray nd = new NDArray( + imagePreprocess(img), + new Shape(new int[]{1, 3, 224, 224}), + Context.cpu()); + List ndList = new ArrayList<>(); + ndList.add(nd); + List ndResult = predictor.predictWithNDArray(ndList); + try { + System.out.println("Predict with NDArray"); + System.out.println(printMaximumClass(ndResult.get(0).toArray(), inst.modelPathPrefix)); } catch (Exception e) { System.err.println(e); } diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/README.md b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/README.md new file mode 100644 index 000000000000..70d84a16e7ad --- /dev/null +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/README.md @@ -0,0 +1,60 @@ +# Image Classification using Java Predictor + +In this example, you will learn how to use Java Inference API to +build and run pre-trained Resnet 18 model. + +## Contents + +1. [Prerequisites](#prerequisites) +2. [Download artifacts](#download-artifacts) +3. [Setup datapath and parameters](#setup-datapath-and-parameters) +4. [Run the image inference example](#run-the-image-inference-example) + +## Prerequisites + +1. MXNet +2. MXNet Scala Package +3. [IntelliJ IDE (or alternative IDE) project setup](https://github.com/apache/incubator-mxnet/blob/master/docs/tutorials/java/mxnet_java_on_intellij.md) with the MXNet Java Package +4. wget + +## Download Artifacts + +For this tutorial, you can get the model and sample input image by running following bash file. This script will use `wget` to download these artifacts from AWS S3. + +From the `scala-package/examples/scripts/infer/imageclassifier/` folder run: + +```bash +./get_resnet_18_data.sh +``` + +**Note**: You may need to run `chmod +x get_resnet_18_data.sh` before running this script. + +### Setup Datapath and Parameters + +The available arguments are as follows: + +| Argument | Comments | +| ----------------------------- | ---------------------------------------- | +| `model-dir`                   | Folder path with prefix to the model (including json, params, and any synset file). | +| `input-image` | The image to run inference on. | + +## Run the image inference example + +After the previous steps, you should be able to run the code using the following script that will pass all of the required parameters to the Infer API. + +From the `scala-package/examples/scripts/infer/predictor/` folder run: + +```bash +bash run_predictor_java_example.sh ../models/resnet-18/resnet-18 ../images/kitten.jpg +``` + +**Notes**: +* These are relative paths to this script. +* You may need to run `chmod +x run_predictor_java_example.sh` before running this script. + +The example should give expected output as shown below: +``` +Probability : 0.30337515 Class : n02123159 tiger cat +Probability : 0.30337515 Class : n02123159 tiger cat +``` +the outputs come from the the input image, with top1 predictions picked. \ No newline at end of file