Skip to content

Commit

Permalink
[djl-bench] Uses Shape.parseShapes() (#1849)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored Apr 30, 2024
1 parent f1f40bc commit 444fcbb
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 75 deletions.
2 changes: 1 addition & 1 deletion benchmark/src/main/java/ai/djl/benchmark/Arguments.java
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ public class Arguments {
}

String shape = cmd.getOptionValue("input-shapes");
inputShapes = NDListGenerator.parseShape(shape);
inputShapes = Shape.parseShapes(shape);
inputData = cmd.getOptionValue("input-data");
}

Expand Down
87 changes: 13 additions & 74 deletions benchmark/src/main/java/ai/djl/benchmark/NDListGenerator.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@
*/
package ai.djl.benchmark;

import ai.djl.Device;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDList.Encoding;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import ai.djl.util.passthrough.PassthroughNDManager;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.DefaultParser;
Expand All @@ -35,9 +34,6 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/** A class generates NDList files. */
final class NDListGenerator {
Expand Down Expand Up @@ -68,22 +64,20 @@ static boolean generate(String[] args) {
encoding = Encoding.ND_LIST;
}
Path path = Paths.get(output);

try (NDManager manager = NDManager.newBaseManager(Device.cpu(), "PyTorch")) {
NDList list = new NDList();
for (Pair<DataType, Shape> pair : parseShape(inputShapes)) {
DataType dataType = pair.getKey();
Shape shape = pair.getValue();
if (ones) {
list.add(manager.ones(shape, dataType));
} else {
list.add(manager.zeros(shape, dataType));
}
}
try (OutputStream os = new BufferedOutputStream(Files.newOutputStream(path))) {
list.encode(os, encoding);
NDManager manager = PassthroughNDManager.INSTANCE;
NDList list = new NDList();
for (Pair<DataType, Shape> pair : Shape.parseShapes(inputShapes)) {
DataType dataType = pair.getKey();
Shape shape = pair.getValue();
if (ones) {
list.add(manager.ones(shape, dataType));
} else {
list.add(manager.zeros(shape, dataType));
}
}
try (OutputStream os = new BufferedOutputStream(Files.newOutputStream(path))) {
list.encode(os, encoding);
}
logger.info("NDList file created: {}", path.toAbsolutePath());
return true;
} catch (ParseException e) {
Expand All @@ -94,61 +88,6 @@ static boolean generate(String[] args) {
return false;
}

static PairList<DataType, Shape> parseShape(String shape) {
PairList<DataType, Shape> inputShapes = new PairList<>();
if (shape != null) {
if (shape.contains("(")) {
Pattern pattern =
Pattern.compile("\\((\\s*(\\d+)([,\\s]+\\d+)*\\s*)\\)([sdubilBfS]?)");
Matcher matcher = pattern.matcher(shape);
while (matcher.find()) {
String[] tokens = matcher.group(1).split(",");
long[] array = Arrays.stream(tokens).mapToLong(Long::parseLong).toArray();
DataType dataType;
String dataTypeStr = matcher.group(4);
if (dataTypeStr == null || dataTypeStr.isEmpty()) {
dataType = DataType.FLOAT32;
} else {
switch (dataTypeStr) {
case "s":
dataType = DataType.FLOAT16;
break;
case "d":
dataType = DataType.FLOAT64;
break;
case "u":
dataType = DataType.UINT8;
break;
case "b":
dataType = DataType.INT8;
break;
case "i":
dataType = DataType.INT32;
break;
case "l":
dataType = DataType.INT64;
break;
case "B":
dataType = DataType.BOOLEAN;
break;
case "f":
dataType = DataType.FLOAT32;
break;
default:
throw new IllegalArgumentException("Invalid input-shape: " + shape);
}
}
inputShapes.add(dataType, new Shape(array));
}
} else {
String[] tokens = shape.split(",");
long[] shapes = Arrays.stream(tokens).mapToLong(Long::parseLong).toArray();
inputShapes.add(DataType.FLOAT32, new Shape(shapes));
}
}
return inputShapes;
}

private static Options getOptions() {
Options options = new Options();
options.addOption(
Expand Down

0 comments on commit 444fcbb

Please sign in to comment.