From ee36ef0a43bc30550f0bb0cccc617feb815e3fcb Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Tue, 24 Sep 2024 16:45:13 +0200 Subject: [PATCH] start adapting tf1 to persistent multiprocessing --- .../modelrunner/tensorflow/v1/JavaWorker.java | 150 ++++ .../tensorflow/v1/Tensorflow1Interface.java | 722 +++++++----------- .../tensorflow/v1/shm/ShmBuilder.java | 221 ++++++ .../tensorflow/v1/shm/TensorBuilder.java | 242 ++++++ .../mappedbuffer/ImgLib2ToMappedBuffer.java | 283 ------- .../mappedbuffer/MappedBufferToImgLib2.java | 328 -------- 6 files changed, 880 insertions(+), 1066 deletions(-) create mode 100644 src/main/java/io/bioimage/modelrunner/tensorflow/v1/JavaWorker.java create mode 100644 src/main/java/io/bioimage/modelrunner/tensorflow/v1/shm/ShmBuilder.java create mode 100644 src/main/java/io/bioimage/modelrunner/tensorflow/v1/shm/TensorBuilder.java delete mode 100644 src/main/java/io/bioimage/modelrunner/tensorflow/v1/tensor/mappedbuffer/ImgLib2ToMappedBuffer.java delete mode 100644 src/main/java/io/bioimage/modelrunner/tensorflow/v1/tensor/mappedbuffer/MappedBufferToImgLib2.java diff --git a/src/main/java/io/bioimage/modelrunner/tensorflow/v1/JavaWorker.java b/src/main/java/io/bioimage/modelrunner/tensorflow/v1/JavaWorker.java new file mode 100644 index 0000000..ce83609 --- /dev/null +++ b/src/main/java/io/bioimage/modelrunner/tensorflow/v1/JavaWorker.java @@ -0,0 +1,150 @@ +package io.bioimage.modelrunner.tensorflow.v1; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Scanner; + +import io.bioimage.modelrunner.apposed.appose.Types; +import io.bioimage.modelrunner.apposed.appose.Service.RequestType; +import io.bioimage.modelrunner.apposed.appose.Service.ResponseType; + +public class JavaWorker { + + private static LinkedHashMap tasks = new LinkedHashMap(); + + private final String uuid; + + private final Tensorflow1Interface ti; + + private boolean cancelRequested = false; + + public static void main(String[] args) { + + try(Scanner scanner = new Scanner(System.in)){ + Tensorflow1Interface ti; + try { + ti = new Tensorflow1Interface(false); + } catch (IOException | URISyntaxException e) { + return; + } + + while (true) { + String line; + try { + if (!scanner.hasNextLine()) break; + line = scanner.nextLine().trim(); + } catch (Exception e) { + break; + } + + if (line.isEmpty()) break; + Map request = Types.decode(line); + String uuid = (String) request.get("task"); + String requestType = (String) request.get("requestType"); + + if (requestType.equals(RequestType.EXECUTE.toString())) { + String script = (String) request.get("script"); + Map inputs = (Map) request.get("inputs"); + JavaWorker task = new JavaWorker(uuid, ti); + tasks.put(uuid, task); + task.start(script, inputs); + } else if (requestType.equals(RequestType.CANCEL.toString())) { + JavaWorker task = (JavaWorker) tasks.get(uuid); + if (task == null) { + System.err.println("No such task: " + uuid); + continue; + } + task.cancelRequested = true; + } else { + break; + } + } + } + + } + + private JavaWorker(String uuid, Tensorflow1Interface ti) { + this.uuid = uuid; + this.ti = ti; + } + + private void executeScript(String script, Map inputs) { + Map binding = new LinkedHashMap(); + binding.put("task", this); + if (inputs != null) + binding.putAll(binding); + + this.reportLaunch(); + try { + if (script.equals("loadModel")) { + ti.loadModel((String) inputs.get("modelFolder"), null); + } else if (script.equals("inference")) { + ti.runFromShmas((List) inputs.get("inputs"), (List) inputs.get("outputs")); + } else if (script.equals("close")) { + ti.closeModel(); + } + } catch(Exception ex) { + this.fail(Types.stackTrace(ex)); + return; + } + this.reportCompletion(); + } + + private void start(String script, Map inputs) { + new Thread(() -> executeScript(script, inputs), "Appose-" + this.uuid).start(); + } + + private void reportLaunch() { + respond(ResponseType.LAUNCH, null); + } + + private void reportCompletion() { + respond(ResponseType.COMPLETION, null); + } + + private void update(String message, Integer current, Integer maximum) { + LinkedHashMap args = new LinkedHashMap(); + + if (message != null) + args.put("message", message); + + if (current != null) + args.put("current", current); + + if (maximum != null) + args.put("maximum", maximum); + this.respond(ResponseType.UPDATE, args); + } + + private void respond(ResponseType responseType, Map args) { + Map response = new HashMap(); + response.put("task", uuid); + response.put("responseType", responseType); + if (args != null) + response.putAll(args); + try { + System.out.println(Types.encode(response)); + System.out.flush(); + } catch(Exception ex) { + this.fail(Types.stackTrace(ex.getCause())); + } + } + + private void cancel() { + this.respond(ResponseType.CANCELATION, null); + } + + private void fail(String error) { + Map args = null; + if (error != null) { + args = new HashMap(); + args.put("error", error); + } + respond(ResponseType.FAILURE, args); + } + +} diff --git a/src/main/java/io/bioimage/modelrunner/tensorflow/v1/Tensorflow1Interface.java b/src/main/java/io/bioimage/modelrunner/tensorflow/v1/Tensorflow1Interface.java index 388869d..806339a 100644 --- a/src/main/java/io/bioimage/modelrunner/tensorflow/v1/Tensorflow1Interface.java +++ b/src/main/java/io/bioimage/modelrunner/tensorflow/v1/Tensorflow1Interface.java @@ -20,9 +20,15 @@ */ package io.bioimage.modelrunner.tensorflow.v1; +import com.google.gson.Gson; import com.google.protobuf.InvalidProtocolBufferException; +import io.bioimage.modelrunner.apposed.appose.Service; +import io.bioimage.modelrunner.apposed.appose.Types; +import io.bioimage.modelrunner.apposed.appose.Service.Task; +import io.bioimage.modelrunner.apposed.appose.Service.TaskStatus; import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor; +import io.bioimage.modelrunner.bioimageio.description.ModelDescriptorFactory; import io.bioimage.modelrunner.bioimageio.download.DownloadModel; import io.bioimage.modelrunner.engine.DeepLearningEngineInterface; import io.bioimage.modelrunner.engine.EngineInfo; @@ -30,15 +36,18 @@ import io.bioimage.modelrunner.exceptions.RunModelException; import io.bioimage.modelrunner.system.PlatformDetection; import io.bioimage.modelrunner.tensor.Tensor; +import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray; import io.bioimage.modelrunner.tensorflow.v1.tensor.ImgLib2Builder; import io.bioimage.modelrunner.tensorflow.v1.tensor.TensorBuilder; -import io.bioimage.modelrunner.tensorflow.v1.tensor.mappedbuffer.ImgLib2ToMappedBuffer; -import io.bioimage.modelrunner.tensorflow.v1.tensor.mappedbuffer.MappedBufferToImgLib2; +import io.bioimage.modelrunner.utils.CommonUtils; import io.bioimage.modelrunner.utils.Constants; import io.bioimage.modelrunner.utils.ZipUtils; +import net.imglib2.RandomAccessibleInterval; import net.imglib2.type.NativeType; import net.imglib2.type.numeric.RealType; import net.imglib2.type.numeric.real.FloatType; +import net.imglib2.util.Cast; +import net.imglib2.util.Util; import java.io.BufferedReader; import java.io.File; @@ -60,8 +69,10 @@ import java.time.format.DateTimeFormatter; import java.util.ArrayList; import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.LinkedList; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; import org.tensorflow.SavedModelBundle; @@ -95,68 +106,25 @@ * @author Carlos Garcia Lopez de Haro and Daniel Felipe Gonzalez Obando */ public class Tensorflow1Interface implements DeepLearningEngineInterface { - - private static final String[] MODEL_TAGS = { "serve", "inference", "train", - "eval", "gpu", "tpu" }; - - private static final String[] TF_MODEL_TAGS = { - "tf.saved_model.tag_constants.SERVING", - "tf.saved_model.tag_constants.INFERENCE", - "tf.saved_model.tag_constants.TRAINING", - "tf.saved_model.tag_constants.EVAL", "tf.saved_model.tag_constants.GPU", - "tf.saved_model.tag_constants.TPU" }; - - private static final String[] SIGNATURE_CONSTANTS = { "serving_default", - "inputs", "tensorflow/serving/classify", "classes", "scores", "inputs", - "tensorflow/serving/predict", "outputs", "inputs", - "tensorflow/serving/regress", "outputs", "train", "eval", - "tensorflow/supervised/training", "tensorflow/supervised/eval" }; - - private static final String[] TF_SIGNATURE_CONSTANTS = { - "tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY", - "tf.saved_model.signature_constants.CLASSIFY_INPUTS", - "tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME", - "tf.saved_model.signature_constants.CLASSIFY_OUTPUT_CLASSES", - "tf.saved_model.signature_constants.CLASSIFY_OUTPUT_SCORES", - "tf.saved_model.signature_constants.PREDICT_INPUTS", - "tf.saved_model.signature_constants.PREDICT_METHOD_NAME", - "tf.saved_model.signature_constants.PREDICT_OUTPUTS", - "tf.saved_model.signature_constants.REGRESS_INPUTS", - "tf.saved_model.signature_constants.REGRESS_METHOD_NAME", - "tf.saved_model.signature_constants.REGRESS_OUTPUTS", - "tf.saved_model.signature_constants.DEFAULT_TRAIN_SIGNATURE_DEF_KEY", - "tf.saved_model.signature_constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY", - "tf.saved_model.signature_constants.SUPERVISED_TRAIN_METHOD_NAME", - "tf.saved_model.signature_constants.SUPERVISED_EVAL_METHOD_NAME" }; - - /** - * Idetifier for the files that contain the data of the inputs - */ - final private static String INPUT_FILE_TERMINATION = "_model_input"; - - /** - * Idetifier for the files that contain the data of the outputs - */ - final private static String OUTPUT_FILE_TERMINATION = "_model_output"; - /** - * Key for the inputs in the map that retrieves the file names for interprocess communication - */ - final private static String INPUTS_MAP_KEY = "inputs"; - /** - * Key for the outputs in the map that retrieves the file names for interprocess communication - */ - final private static String OUTPUTS_MAP_KEY = "outputs"; - /** - * File extension for the temporal files used for interprocessing - */ - final private static String FILE_EXTENSION = ".dat"; /** * Name without vesion of the JAR created for this library */ private static final String JAR_FILE_NAME = "dl-modelrunner-tensorflow-"; + private static final String NAME_KEY = "name"; + private static final String SHAPE_KEY = "shape"; + private static final String DTYPE_KEY = "dtype"; + private static final String IS_INPUT_KEY = "isInput"; + private static final String MEM_NAME_KEY = "memoryName"; + + private List shmaInputList = new ArrayList(); + + private List shmaOutputList = new ArrayList(); + + private List shmaNamesList = new ArrayList(); + /** - * The loaded Tensorflow 1 model + * The loaded Tensorflow 2 model */ private static SavedModelBundle model; /** @@ -167,61 +135,52 @@ public class Tensorflow1Interface implements DeepLearningEngineInterface { * Whether the execution needs interprocessing (MacOS Interl) or not */ private boolean interprocessing = false; - /** - * TEmporary dir where to store temporary files - */ - private String tmpDir; /** * Folde containing the model that is being executed */ private String modelFolder; /** - * List of temporary files used for interprocessing communication - */ - private List listTempFiles; - /** - * HashMap that maps tensor to the temporal file name for interprocessing + * Process where the model is being loaded and executed */ - private HashMap tensorFilenameMap; + Service runner; /** + * TODO the interprocessing is executed for every OS * Constructor that detects whether the operating system where it is being - * executed is MacOS Intel or not to know if it is going to need interprocessing + * executed is Windows or Mac or not to know if it is going to need interprocessing * or not * @throws IOException if the temporary dir is not found + * @throws URISyntaxException */ - public Tensorflow1Interface() throws IOException + public Tensorflow1Interface() throws IOException, URISyntaxException { - boolean isMac = PlatformDetection.isMacOS(); - boolean isIntel = PlatformDetection.getArch().equals(PlatformDetection.ARCH_X86_64); - if (false && isMac && isIntel) { - interprocessing = true; - tmpDir = getTemporaryDir(); - - } + this(true); } /** * Private constructor that can only be launched from the class to create a separate - * process to avoid the conflicts that occur in the same process between TF1 and TF2 - * in MacOS Intel + * process to avoid the conflicts that occur in the same process between TF2 and TF1/Pytorch + * in Windows and Mac * @param doInterprocessing * whether to do interprocessing or not * @throws IOException if the temp dir is not found + * @throws URISyntaxException */ - private Tensorflow1Interface(boolean doInterprocessing) throws IOException + protected Tensorflow1Interface(boolean doInterprocessing) throws IOException, URISyntaxException { - if (!doInterprocessing) { - interprocessing = false; - } else { - boolean isMac = PlatformDetection.isMacOS(); - boolean isIntel = new PlatformDetection().getArch().equals(PlatformDetection.ARCH_X86_64); - if (isMac && isIntel) { - interprocessing = true; - tmpDir = getTemporaryDir(); - - } - } + interprocessing = doInterprocessing; + if (this.interprocessing) { + runner = getRunner(); + runner.debug((text) -> System.err.println(text)); + } + } + + private Service getRunner() throws IOException, URISyntaxException { + List args = getProcessCommandsWithoutArgs(); + String[] argArr = new String[args.size()]; + args.toArray(argArr); + + return new Service(new File("."), argArr); } /** @@ -238,6 +197,11 @@ public void loadModel(String modelFolder, String modelSource) { this.modelFolder = modelFolder; if (interprocessing) { + try { + launchModelLoadOnProcess(); + } catch (IOException | InterruptedException e) { + throw new LoadModelException(Types.stackTrace(e)); + } return; } try { @@ -245,7 +209,7 @@ public void loadModel(String modelFolder, String modelSource) } catch (Exception e) { throw new LoadModelException(e.toString()); } - model = SavedModelBundle.load(modelFolder, "serve"); + model = SavedModelBundle.load(this.modelFolder, "serve"); byte[] byteGraph = model.metaGraphDef(); try { sig = MetaGraphDef.parseFrom(byteGraph).getSignatureDefOrThrow( @@ -253,10 +217,23 @@ public void loadModel(String modelFolder, String modelSource) } catch (InvalidProtocolBufferException e) { closeModel(); - throw new LoadModelException(e.toString()); + throw new LoadModelException(Types.stackTrace(e)); } } + private void launchModelLoadOnProcess() throws IOException, InterruptedException { + HashMap args = new HashMap(); + args.put("modelFolder", modelFolder); + Task task = runner.task("loadModel", args); + task.waitFor(); + if (task.status == TaskStatus.CANCELED) + throw new RuntimeException(); + else if (task.status == TaskStatus.FAILED) + throw new RuntimeException(); + else if (task.status == TaskStatus.CRASHED) + throw new RuntimeException(); + } + /** * Check if an unzipped tensorflow model exists in the model folder, * and if not look for it and unzip it @@ -268,7 +245,7 @@ private void checkModelUnzipped() throws LoadModelException, IOException, Except if (new File(modelFolder, "variables").isDirectory() && new File(modelFolder, "saved_model.pb").isFile()) return; - unzipTfWeights(ModelDescriptor.readFromLocalFile(modelFolder + File.separator + Constants.RDF_FNAME)); + unzipTfWeights(ModelDescriptorFactory.readFromLocalFile(modelFolder + File.separator + Constants.RDF_FNAME)); } /** @@ -286,9 +263,17 @@ private void unzipTfWeights(ModelDescriptor descriptor) throws LoadModelExceptio String source = descriptor.getWeights().gettAllSupportedWeightObjects().stream() .filter(ww -> ww.getFramework().equals(EngineInfo.getBioimageioTfKey())) .findFirst().get().getSource(); - source = DownloadModel.getFileNameFromURLString(source); - System.out.println("Unzipping model..."); - ZipUtils.unzipFolder(modelFolder + File.separator + source, modelFolder); + if (new File(source).isFile()) { + System.out.println("Unzipping model..."); + ZipUtils.unzipFolder(new File(source).getAbsolutePath(), modelFolder); + } else if (new File(modelFolder, source).isFile()) { + System.out.println("Unzipping model..."); + ZipUtils.unzipFolder(new File(modelFolder, source).getAbsolutePath(), modelFolder); + } else { + source = DownloadModel.getFileNameFromURLString(source); + System.out.println("Unzipping model..."); + ZipUtils.unzipFolder(modelFolder + File.separator + source, modelFolder); + } } else { throw new LoadModelException("No model file was found in the model folder"); } @@ -301,7 +286,8 @@ private void unzipTfWeights(ModelDescriptor descriptor) throws LoadModelExceptio * and modifies the output list with the results obtained */ @Override - public void run(List> inputTensors, List> outputTensors) + public & NativeType, R extends RealType & NativeType> + void run(List> inputTensors, List> outputTensors) throws RunModelException { if (interprocessing) { @@ -314,7 +300,7 @@ public void run(List> inputTensors, List> outputTensors) List> inTensors = new ArrayList>(); int c = 0; - for (Tensor tt : inputTensors) { + for (Tensor tt : inputTensors) { inputListNames.add(tt.getName()); org.tensorflow.Tensor inT = TensorBuilder.build(tt); inTensors.add(inT); @@ -322,7 +308,7 @@ public void run(List> inputTensors, List> outputTensors) runner.feed(inputName, inT); } c = 0; - for (Tensor tt : outputTensors) + for (Tensor tt : outputTensors) runner = runner.fetch(getModelOutputName(tt.getName(), c ++)); // Run runner List> resultPatchTensors = runner.run(); @@ -337,6 +323,43 @@ public void run(List> inputTensors, List> outputTensors) } } + protected void runFromShmas(List inputs, List outputs) throws IOException { + Session session = model.session(); + Session.Runner runner = session.runner(); + + List> inTensors = new ArrayList>(); + int c = 0; + for (String ee : inputs) { + Map decoded = Types.decode(ee); + SharedMemoryArray shma = SharedMemoryArray.read((String) decoded.get(MEM_NAME_KEY)); + org.tensorflow.Tensor inT = io.bioimage.modelrunner.tensorflow.v2.api030.shm.TensorBuilder.build(shma); + if (PlatformDetection.isWindows()) shma.close(); + inTensors.add(inT); + String inputName = getModelInputName((String) decoded.get(NAME_KEY), c ++); + runner.feed(inputName, inT); + } + + c = 0; + for (String ee : outputs) + runner = runner.fetch(getModelOutputName((String) Types.decode(ee).get(NAME_KEY), c ++)); + // Run runner + List> resultPatchTensors = runner.run(); + + // Fill the agnostic output tensors list with data from the inference result + c = 0; + for (String ee : outputs) { + Map decoded = Types.decode(ee); + ShmBuilder.build((org.tensorflow.Tensor) resultPatchTensors.get(c ++), (String) decoded.get(MEM_NAME_KEY)); + } + // Close the remaining resources + for (org.tensorflow.Tensor tt : inTensors) { + tt.close(); + } + for (org.tensorflow.Tensor tt : resultPatchTensors) { + tt.close(); + } + } + /** * MEthod only used in MacOS Intel systems that makes all the arangements * to create another process, communicate the model info and tensors to the other @@ -347,32 +370,110 @@ public void run(List> inputTensors, List> outputTensors) * expected results of the model * @throws RunModelException if there is any issue running the model */ - public void runInterprocessing(List> inputTensors, List> outputTensors) throws RunModelException { - createTensorsForInterprocessing(inputTensors); - createTensorsForInterprocessing(outputTensors); + public & NativeType, R extends RealType & NativeType> + void runInterprocessing(List> inputTensors, List> outputTensors) throws RunModelException { + shmaInputList = new ArrayList(); + shmaOutputList = new ArrayList(); + List encIns = modifyForWinCmd(encodeInputs(inputTensors)); + List encOuts = modifyForWinCmd(encodeOutputs(outputTensors)); + LinkedHashMap args = new LinkedHashMap(); + args.put("inputs", encIns); + args.put("outputs", encOuts); + try { - List args = getProcessCommandsWithoutArgs(); - for (Tensor tensor : inputTensors) {args.add(getFilename4Tensor(tensor.getName()) + INPUT_FILE_TERMINATION);} - for (Tensor tensor : outputTensors) {args.add(getFilename4Tensor(tensor.getName()) + OUTPUT_FILE_TERMINATION);} - ProcessBuilder builder = new ProcessBuilder(args); - builder.redirectOutput(ProcessBuilder.Redirect.INHERIT); - builder.redirectError(ProcessBuilder.Redirect.INHERIT); - Process process = builder.start(); - int result = process.waitFor(); - process.destroy(); - if (result != 0) - throw new RunModelException("Error executing the Tensorflow 1 model in" - + " a separate process. The process was not terminated correctly." - + System.lineSeparator() + readProcessStringOutput(process)); - } catch (RunModelException e) { - closeModel(); - throw e; + Task task = runner.task("inference", args); + task.waitFor(); + if (task.status == TaskStatus.CANCELED) + throw new RuntimeException(); + else if (task.status == TaskStatus.FAILED) + throw new RuntimeException(); + else if (task.status == TaskStatus.CRASHED) + throw new RuntimeException(); + for (int i = 0; i < outputTensors.size(); i ++) { + String name = (String) Types.decode(encOuts.get(i)).get(MEM_NAME_KEY); + SharedMemoryArray shm = shmaOutputList.stream() + .filter(ss -> ss.getName().equals(name)).findFirst().orElse(null); + if (shm == null) { + shm = SharedMemoryArray.read(name); + shmaOutputList.add(shm); + } + RandomAccessibleInterval rai = shm.getSharedRAI(); + outputTensors.get(i).setData(Tensor.createCopyOfRaiInWantedDataType(Cast.unchecked(rai), Util.getTypeFromInterval(Cast.unchecked(rai)))); + } } catch (Exception e) { - closeModel(); - throw new RunModelException(e.getCause().toString()); + closeShmas(); + if (e instanceof RunModelException) + throw (RunModelException) e; + throw new RunModelException(Types.stackTrace(e)); } - - retrieveInterprocessingTensors(outputTensors); + closeShmas(); + } + + private void closeShmas() { + shmaInputList.forEach(shm -> { + try { shm.close(); } catch (IOException e1) { e1.printStackTrace();} + }); + shmaInputList = null; + shmaOutputList.forEach(shm -> { + try { shm.close(); } catch (IOException e1) { e1.printStackTrace();} + }); + shmaOutputList = null; + } + + private static List modifyForWinCmd(List ins){ + if (!PlatformDetection.isWindows()) + return ins; + List newIns = new ArrayList(); + for (String ii : ins) + newIns.add("\"" + ii.replace("\"", "\\\"") + "\""); + return newIns; + } + + + private & NativeType> List encodeInputs(List> inputTensors) { + List encodedInputTensors = new ArrayList(); + Gson gson = new Gson(); + for (Tensor tt : inputTensors) { + SharedMemoryArray shma = SharedMemoryArray.createSHMAFromRAI(tt.getData(), false, true); + shmaInputList.add(shma); + HashMap map = new HashMap(); + map.put(NAME_KEY, tt.getName()); + map.put(SHAPE_KEY, tt.getShape()); + map.put(DTYPE_KEY, CommonUtils.getDataTypeFromRAI(tt.getData())); + map.put(IS_INPUT_KEY, true); + map.put(MEM_NAME_KEY, shma.getName()); + encodedInputTensors.add(gson.toJson(map)); + } + return encodedInputTensors; + } + + + private & NativeType> + List encodeOutputs(List> outputTensors) { + Gson gson = new Gson(); + List encodedOutputTensors = new ArrayList(); + for (Tensor tt : outputTensors) { + HashMap map = new HashMap(); + map.put(NAME_KEY, tt.getName()); + map.put(IS_INPUT_KEY, false); + if (!tt.isEmpty()) { + map.put(SHAPE_KEY, tt.getShape()); + map.put(DTYPE_KEY, CommonUtils.getDataTypeFromRAI(tt.getData())); + SharedMemoryArray shma = SharedMemoryArray.createSHMAFromRAI(tt.getData(), false, true); + shmaOutputList.add(shma); + map.put(MEM_NAME_KEY, shma.getName()); + } else if (PlatformDetection.isWindows()){ + SharedMemoryArray shma = SharedMemoryArray.create(0); + shmaOutputList.add(shma); + map.put(MEM_NAME_KEY, shma.getName()); + } else { + String memName = SharedMemoryArray.createShmName(); + map.put(MEM_NAME_KEY, memName); + shmaNamesList.add(memName); + } + encodedOutputTensors.add(gson.toJson(map)); + } + return encodedOutputTensors; } /** @@ -384,9 +485,9 @@ public void runInterprocessing(List> inputTensors, List> out * @throws RunModelException If the number of tensors expected is not the same * as the number of Tensors outputed by the model */ - public static void fillOutputTensors( - List> outputNDArrays, - List> outputTensors) throws RunModelException + public static & NativeType> void fillOutputTensors( + List> outputNDArrays, + List> outputTensors) throws RunModelException { if (outputNDArrays.size() != outputTensors.size()) throw new RunModelException(outputNDArrays.size(), outputTensors.size()); @@ -408,19 +509,31 @@ public static void fillOutputTensors( */ @Override public void closeModel() { + if (this.interprocessing && runner != null) { + Task task; + try { + task = runner.task("close"); + task.waitFor(); + } catch (IOException | InterruptedException e) { + throw new RuntimeException(Types.stackTrace(e)); + } + if (task.status == TaskStatus.CANCELED) + throw new RuntimeException(); + else if (task.status == TaskStatus.FAILED) + throw new RuntimeException(); + else if (task.status == TaskStatus.CRASHED) + throw new RuntimeException(); + this.runner.close(); + return; + } else if (this.interprocessing) { + return; + } sig = null; if (model != null) { model.session().close(); model.close(); } model = null; - if (listTempFiles == null) - return; - for (File ff : listTempFiles) { - if (ff.exists()) - ff.delete(); - } - listTempFiles = null; } // TODO make only one @@ -481,193 +594,6 @@ public static String getModelOutputName(String outputName, int i) { } } - - /** - * Methods to run interprocessing and bypass the errors that occur in MacOS intel - * with the compatibility between TF1 and TF2 - * This method checks that the arguments are correct, retrieves the input and output - * tensors, loads the model, makes inference with it and finally sends the tensors - * to the original process - * - * @param args - * arguments of the program: - * - Path to the model folder - * - Path to a temporary dir - * - Name of the input 0 - * - Name of the input 1 - * - ... - * - Name of the output n - * - Name of the output 0 - * - Name of the output 1 - * - ... - * - Name of the output n - * @throws LoadModelException if there is any error loading the model - * @throws IOException if there is any error reading or writing any file or with the paths - * @throws RunModelException if there is any error running the model - */ - public static void main(String[] args) throws LoadModelException, IOException, RunModelException { - if (args.length == 0) { - String modelFolder = "/home/carlos/git/deep-icy/models/stardist_1channel"; - Tensorflow1Interface ti = new Tensorflow1Interface(false); - ti.loadModel(modelFolder, modelFolder); - Tensor inp = Tensor.buildBlankTensor("in", "byxc", new long[] {1, 208, 208, 1}, new FloatType()); - Tensor out = Tensor.buildEmptyTensor("out", "byxc"); - List> inps = new ArrayList>(); - inps.add(inp); - List> outs = new ArrayList>(); - outs.add(out); - ti.run(inps, outs); - System.out.println(false); - return; - } - // Unpack the args needed - if (args.length < 4) - throw new IllegalArgumentException("Error exectuting Tensorflow 1, " - + "at least 5 arguments are required:" + System.lineSeparator() - + " - Folder where the model is located" + System.lineSeparator() - + " - Temporary dir where the memory mapped files are located" + System.lineSeparator() - + " - Name of the model input followed by the String + '_model_input'" + System.lineSeparator() - + " - Name of the second model input (if it exists) followed by the String + '_model_input'" + System.lineSeparator() - + " - ...." + System.lineSeparator() - + " - Name of the nth model input (if it exists) followed by the String + '_model_input'" + System.lineSeparator() - + " - Name of the model output followed by the String + '_model_output'" + System.lineSeparator() - + " - Name of the second model output (if it exists) followed by the String + '_model_output'" + System.lineSeparator() - + " - ...." + System.lineSeparator() - + " - Name of the nth model output (if it exists) followed by the String + '_model_output'" + System.lineSeparator() - ); - String modelFolder = args[0]; - if (!(new File(modelFolder).isDirectory())) { - throw new IllegalArgumentException("Argument 0 of the main method, '" + modelFolder + "' " - + "should be an existing directory containing a Tensorflow 1 model."); - } - - Tensorflow1Interface tfInterface = new Tensorflow1Interface(false); - tfInterface.tmpDir = args[1]; - if (!(new File(args[1]).isDirectory())) { - throw new IllegalArgumentException("Argument 1 of the main method, '" + args[1] + "' " - + "should be an existing directory."); - } - - tfInterface.loadModel(modelFolder, modelFolder); - - HashMap> map = tfInterface.getInputTensorsFileNames(args); - List inputNames = map.get(INPUTS_MAP_KEY); - List> inputList = inputNames.stream().map(n -> { - try { - return tfInterface.retrieveInterprocessingTensorsByName(n); - } catch (RunModelException e) { - return null; - } - }).collect(Collectors.toList()); - List outputNames = map.get(OUTPUTS_MAP_KEY); - List> outputList = outputNames.stream().map(n -> { - try { - return tfInterface.retrieveInterprocessingTensorsByName(n); - } catch (RunModelException e) { - return null; - } - }).collect(Collectors.toList()); - tfInterface.run(inputList, outputList); - tfInterface.createTensorsForInterprocessing(outputList); - } - - /** - * Get the name of teh temporary file associated to the tensor name - * @param name - * name of the tensor - * @return file name associated to the tensor - */ - private String getFilename4Tensor(String name) { - if (tensorFilenameMap == null) - tensorFilenameMap = new HashMap(); - if (tensorFilenameMap.get(name) != null) - return tensorFilenameMap.get(name); - LocalDateTime now = LocalDateTime.now(); - DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyyMMddHHmmssSSS"); - String newName = name + "_" + now.format(formatter); - tensorFilenameMap.put(name, newName); - return tensorFilenameMap.get(name); - } - - /** - * Create a temporary file for each of the tensors in the list to communicate with - * the separate process in MacOS Intel systems - * @param tensors - * list of tensors to be sent - * @throws RunModelException if there is any error converting the tensors - */ - private void createTensorsForInterprocessing(List> tensors) throws RunModelException{ - if (this.listTempFiles == null) - this.listTempFiles = new ArrayList(); - for (Tensor tensor : tensors) { - long lenFile = ImgLib2ToMappedBuffer.findTotalLengthFile(tensor); - File ff = new File(tmpDir + File.separator + getFilename4Tensor(tensor.getName()) + FILE_EXTENSION); - if (!ff.exists()) { - ff.deleteOnExit(); - this.listTempFiles.add(ff); - } - try (RandomAccessFile rd = - new RandomAccessFile(ff, "rw"); - FileChannel fc = rd.getChannel();) { - MappedByteBuffer mem = fc.map(FileChannel.MapMode.READ_WRITE, 0, lenFile); - ByteBuffer byteBuffer = mem.duplicate(); - ImgLib2ToMappedBuffer.build(tensor, byteBuffer); - } catch (IOException e) { - closeModel(); - throw new RunModelException(e.getCause().toString()); - } - } - } - - /** - * Retrieves the data of the tensors contained in the input list from the output - * generated by the independent process - * @param tensors - * list of tensors that are going to be filled - * @throws RunModelException if there is any issue retrieving the data from the other process - */ - private void retrieveInterprocessingTensors(List> tensors) throws RunModelException{ - for (Tensor tensor : tensors) { - try (RandomAccessFile rd = - new RandomAccessFile(tmpDir + File.separator - + this.getFilename4Tensor(tensor.getName()) + FILE_EXTENSION, "r"); - FileChannel fc = rd.getChannel();) { - MappedByteBuffer mem = fc.map(FileChannel.MapMode.READ_ONLY, 0, fc.size()); - ByteBuffer byteBuffer = mem.duplicate(); - tensor.setData(MappedBufferToImgLib2.build(byteBuffer)); - } catch (IOException e) { - closeModel(); - throw new RunModelException(e.getCause().toString()); - } - } - } - - /** - * Create a tensor from the data contained in a file named as the parameter - * provided as an input + the file extension {@link #FILE_EXTENSION}. - * This file is produced by another process to communicate with the current process - * @param - * generic type of the tensor - * @param name - * name of the file without the extension ({@link #FILE_EXTENSION}). - * @return a tensor created with the data in the file - * @throws RunModelException if there is any problem retrieving the data and cerating the tensor - */ - private < T extends RealType< T > & NativeType< T > > Tensor - retrieveInterprocessingTensorsByName(String name) throws RunModelException { - try (RandomAccessFile rd = - new RandomAccessFile(tmpDir + File.separator - + this.getFilename4Tensor(name) + FILE_EXTENSION, "r"); - FileChannel fc = rd.getChannel();) { - MappedByteBuffer mem = fc.map(FileChannel.MapMode.READ_ONLY, 0, fc.size()); - ByteBuffer byteBuffer = mem.duplicate(); - return MappedBufferToImgLib2.buildTensor(byteBuffer); - } catch (IOException e) { - closeModel(); - throw new RunModelException(e.getCause().toString()); - } - } - /** * if java bin dir contains any special char, surround it by double quotes * @param javaBin @@ -695,12 +621,7 @@ private List getProcessCommandsWithoutArgs() throws IOException, URISynt String javaHome = System.getProperty("java.home"); String javaBin = javaHome + File.separator + "bin" + File.separator + "java"; - String modelrunnerPath = getPathFromClass(DeepLearningEngineInterface.class); - String imglib2Path = getPathFromClass(NativeType.class); - if (modelrunnerPath == null || (modelrunnerPath.endsWith("DeepLearningEngineInterface.class") - && !modelrunnerPath.contains(File.pathSeparator))) - modelrunnerPath = System.getProperty("java.class.path"); - String classpath = modelrunnerPath + File.pathSeparator + imglib2Path + File.pathSeparator; + String classpath = getCurrentClasspath(); ProtectionDomain protectionDomain = Tensorflow1Interface.class.getProtectionDomain(); String codeSource = protectionDomain.getCodeSource().getLocation().getPath(); String f_name = URLDecoder.decode(codeSource, StandardCharsets.UTF_8.toString()); @@ -710,17 +631,34 @@ private List getProcessCommandsWithoutArgs() throws IOException, URISynt continue; classpath += ff.getAbsolutePath() + File.pathSeparator; } - String className = Tensorflow1Interface.class.getName(); + String className = JavaWorker.class.getName(); List command = new LinkedList(); command.add(padSpecialJavaBin(javaBin)); command.add("-cp"); command.add(classpath); command.add(className); - command.add(modelFolder); - command.add(this.tmpDir); return command; } + private static String getCurrentClasspath() throws UnsupportedEncodingException { + + String modelrunnerPath = getPathFromClass(DeepLearningEngineInterface.class); + String imglib2Path = getPathFromClass(NativeType.class); + String gsonPath = getPathFromClass(Gson.class); + String jnaPath = getPathFromClass(com.sun.jna.Library.class); + String jnaPlatformPath = getPathFromClass(com.sun.jna.platform.FileUtils.class); + if (modelrunnerPath == null || (modelrunnerPath.endsWith("DeepLearningEngineInterface.class") + && !modelrunnerPath.contains(File.pathSeparator))) + modelrunnerPath = System.getProperty("java.class.path"); + modelrunnerPath = System.getProperty("java.class.path"); + String classpath = modelrunnerPath + File.pathSeparator + imglib2Path + File.pathSeparator; + classpath = classpath + gsonPath + File.pathSeparator; + classpath = classpath + jnaPath + File.pathSeparator; + classpath = classpath + jnaPlatformPath + File.pathSeparator; + + return classpath; + } + /** * Method that gets the path to the JAR from where a specific class is being loaded * @param clazz @@ -750,130 +688,4 @@ private static String getPathFromClass(Class clazz) throws UnsupportedEncodin path = path.substring(0, path.lastIndexOf(".jar!")) + ".jar"; return path; } - - /** - * Get temporary directory to perform the interprocessing communication in MacOSX intel - * @return the tmp dir - * @throws IOException if the files cannot be written in any of the temp dirs - */ - private static String getTemporaryDir() throws IOException { - String tmpDir; - String enginesDir = getEnginesDir(); - if (enginesDir != null && Files.isWritable(Paths.get(enginesDir))) { - tmpDir = enginesDir + File.separator + "temp"; - if (!(new File(tmpDir).isDirectory()) && !(new File(tmpDir).mkdirs())) - tmpDir = enginesDir; - } else if (System.getenv("temp") != null - && Files.isWritable(Paths.get(System.getenv("temp")))) { - return System.getenv("temp"); - } else if (System.getenv("TEMP") != null - && Files.isWritable(Paths.get(System.getenv("TEMP")))) { - return System.getenv("TEMP"); - } else if (System.getenv("tmp") != null - && Files.isWritable(Paths.get(System.getenv("tmp")))) { - return System.getenv("tmp"); - } else if (System.getenv("TMP") != null - && Files.isWritable(Paths.get(System.getenv("TMP")))) { - return System.getenv("TMP"); - } else if (System.getProperty("java.io.tmpdir") != null - && Files.isWritable(Paths.get(System.getProperty("java.io.tmpdir")))) { - return System.getProperty("java.io.tmpdir"); - } else { - throw new IOException("Unable to find temporal directory with writting rights. " - + "Please either allow writting on the system temporal folder or on '" + enginesDir + "'."); - } - return tmpDir; - } - - /** - * GEt the directory where the TF2 engine is located if a temporary dir is not found - * @return directory of the engines - */ - private static String getEnginesDir() { - String dir; - try { - dir = getPathFromClass(Tensorflow1Interface.class); - } catch (UnsupportedEncodingException e) { - String classResource = Tensorflow1Interface.class.getName().replace('.', '/') + ".class"; - URL resourceUrl = Tensorflow1Interface.class.getClassLoader().getResource(classResource); - if (resourceUrl == null) { - return null; - } - String urlString = resourceUrl.toString(); - if (urlString.startsWith("jar:")) { - urlString = urlString.substring(4); - } - if (urlString.startsWith("file:/") && PlatformDetection.isWindows()) { - urlString = urlString.substring(6); - } else if (urlString.startsWith("file:/") && !PlatformDetection.isWindows()) { - urlString = urlString.substring(5); - } - File file = new File(urlString); - String path = file.getAbsolutePath(); - if (path.lastIndexOf(".jar!") != -1) - path = path.substring(0, path.lastIndexOf(".jar!")) + ".jar"; - dir = path; - } - return new File(dir).getParent(); - } - - /** - * Retrieve the file names used for interprocess communication - * @param args - * args provided to the main method - * @return a map with a list of input and output names - */ - private HashMap> getInputTensorsFileNames(String[] args) { - List inputNames = new ArrayList(); - List outputNames = new ArrayList(); - if (this.tensorFilenameMap == null) - this.tensorFilenameMap = new HashMap(); - for (int i = 2; i < args.length; i ++) { - if (args[i].endsWith(INPUT_FILE_TERMINATION)) { - String nameWTimestamp = args[i].substring(0, args[i].length() - INPUT_FILE_TERMINATION.length()); - String onlyName = nameWTimestamp.substring(0, nameWTimestamp.lastIndexOf("_")); - inputNames.add(onlyName); - tensorFilenameMap.put(onlyName, nameWTimestamp); - } else if (args[i].endsWith(OUTPUT_FILE_TERMINATION)) { - String nameWTimestamp = args[i].substring(0, args[i].length() - OUTPUT_FILE_TERMINATION.length()); - String onlyName = nameWTimestamp.substring(0, nameWTimestamp.lastIndexOf("_")); - outputNames.add(onlyName); - tensorFilenameMap.put(onlyName, nameWTimestamp); - - } - } - if (inputNames.size() == 0) - throw new IllegalArgumentException("The args to the main method of '" - + Tensorflow1Interface.class.toString() + "' should contain at " - + "least one input, defined as ' + '" + INPUT_FILE_TERMINATION + "'."); - if (outputNames.size() == 0) - throw new IllegalArgumentException("The args to the main method of '" - + Tensorflow1Interface.class.toString() + "' should contain at " - + "least one output, defined as ' + '" + OUTPUT_FILE_TERMINATION + "'."); - HashMap> map = new HashMap>(); - map.put(INPUTS_MAP_KEY, inputNames); - map.put(OUTPUTS_MAP_KEY, outputNames); - return map; - } - - /** - * MEthod to obtain the String output of the process in case something goes wrong - * @param process - * the process that executed the TF1 model - * @return the String output that we would have seen on the terminal - * @throws IOException if the output of the terminal cannot be seen - */ - private static String readProcessStringOutput(Process process) throws IOException { - BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(process.getInputStream())); - BufferedReader bufferedErrReader = new BufferedReader(new InputStreamReader(process.getErrorStream())); - String text = ""; - String line; - while ((line = bufferedErrReader.readLine()) != null) { - text += line + System.lineSeparator(); - } - while ((line = bufferedReader.readLine()) != null) { - text += line + System.lineSeparator(); - } - return text; - } } diff --git a/src/main/java/io/bioimage/modelrunner/tensorflow/v1/shm/ShmBuilder.java b/src/main/java/io/bioimage/modelrunner/tensorflow/v1/shm/ShmBuilder.java new file mode 100644 index 0000000..eb41299 --- /dev/null +++ b/src/main/java/io/bioimage/modelrunner/tensorflow/v1/shm/ShmBuilder.java @@ -0,0 +1,221 @@ +/*- + * #%L + * This project complements the DL-model runner acting as the engine that works loading models + * and making inference with Java 0.3.0 and newer API for Tensorflow 2. + * %% + * Copyright (C) 2022 - 2023 Institut Pasteur and BioImage.IO developers. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ +package io.bioimage.modelrunner.tensorflow.v1.shm; + +import io.bioimage.modelrunner.system.PlatformDetection; +import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray; +import io.bioimage.modelrunner.utils.CommonUtils; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; + +import org.tensorflow.Tensor; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.TUint8; +import org.tensorflow.types.family.TType; + +import net.imglib2.RandomAccessibleInterval; +import net.imglib2.type.numeric.integer.IntType; +import net.imglib2.type.numeric.integer.LongType; +import net.imglib2.type.numeric.integer.UnsignedByteType; +import net.imglib2.type.numeric.real.DoubleType; +import net.imglib2.type.numeric.real.FloatType; + +/** + * A {@link RandomAccessibleInterval} builder for TensorFlow {@link Tensor} objects. + * Build ImgLib2 objects (backend of {@link io.bioimage.modelrunner.tensor.Tensor}) + * from Tensorflow 2 {@link Tensor} + * + * @author Carlos Garcia Lopez de Haro and Daniel Felipe Gonzalez Obando + */ +public final class ShmBuilder +{ + /** + * Utility class. + */ + private ShmBuilder() + { + } + + /** + * Creates a {@link RandomAccessibleInterval} from a given {@link TType} tensor + * + * @param + * the possible ImgLib2 datatypes of the image + * @param tensor + * The {@link TType} tensor data is read from. + * @throws IllegalArgumentException If the {@link TType} tensor type is not supported. + * @throws IOException + */ + @SuppressWarnings("unchecked") + public static void build(Tensor tensor, String memoryName) throws IllegalArgumentException, IOException + { + switch (tensor.dataType().name()) + { + case TUint8.NAME: + buildFromTensorUByte((Tensor) tensor, memoryName); + case TInt32.NAME: + buildFromTensorInt((Tensor) tensor, memoryName); + case TFloat32.NAME: + buildFromTensorFloat((Tensor) tensor, memoryName); + case TFloat64.NAME: + buildFromTensorDouble((Tensor) tensor, memoryName); + case TInt64.NAME: + buildFromTensorLong((Tensor) tensor, memoryName); + default: + throw new IllegalArgumentException("Unsupported tensor type: " + tensor.dataType().name()); + } + } + + /** + * Builds a {@link RandomAccessibleInterval} from a unsigned byte-typed {@link TUint8} tensor. + * + * @param tensor + * The {@link TUint8} tensor data is read from. + * @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link UnsignedByteType}. + * @throws IOException + */ + private static void buildFromTensorUByte(Tensor tensor, String memoryName) throws IOException + { + long[] arrayShape = tensor.shape().asArray(); + if (CommonUtils.int32Overflows(arrayShape, 1)) + throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + + " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1); + SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true); + ByteBuffer buff = shma.getDataBuffer(); + int totalSize = 1; + for (long i : arrayShape) {totalSize *= i;} + byte[] flatArr = new byte[buff.capacity()]; + buff.get(flatArr); + tensor.rawData().read(flatArr, flatArr.length - totalSize, totalSize); + shma.setBuffer(ByteBuffer.wrap(flatArr)); + if (PlatformDetection.isWindows()) shma.close(); + } + + /** + * Builds a {@link RandomAccessibleInterval} from a unsigned int32-typed {@link TInt32} tensor. + * + * @param tensor + * The {@link TInt32} tensor data is read from. + * @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link IntType}. + * @throws IOException + */ + private static void buildFromTensorInt(Tensor tensor, String memoryName) throws IOException + { + long[] arrayShape = tensor.shape().asArray(); + if (CommonUtils.int32Overflows(arrayShape, 4)) + throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + + " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4); + + SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true); + ByteBuffer buff = shma.getDataBuffer(); + int totalSize = 4; + for (long i : arrayShape) {totalSize *= i;} + byte[] flatArr = new byte[buff.capacity()]; + buff.get(flatArr); + tensor.rawData().read(flatArr, flatArr.length - totalSize, totalSize); + shma.setBuffer(ByteBuffer.wrap(flatArr)); + if (PlatformDetection.isWindows()) shma.close(); + } + + /** + * Builds a {@link RandomAccessibleInterval} from a unsigned float32-typed {@link TFloat32} tensor. + * + * @param tensor + * The {@link TFloat32} tensor data is read from. + * @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link FloatType}. + * @throws IOException + */ + private static void buildFromTensorFloat(Tensor tensor, String memoryName) throws IOException + { + long[] arrayShape = tensor.shape().asArray(); + if (CommonUtils.int32Overflows(arrayShape, 4)) + throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + + " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4); + + SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new FloatType(), false, true); + ByteBuffer buff = shma.getDataBuffer(); + int totalSize = 4; + for (long i : arrayShape) {totalSize *= i;} + byte[] flatArr = new byte[buff.capacity()]; + buff.get(flatArr); + tensor.rawData().read(flatArr, flatArr.length - totalSize, totalSize); + shma.setBuffer(ByteBuffer.wrap(flatArr)); + if (PlatformDetection.isWindows()) shma.close(); + } + + /** + * Builds a {@link RandomAccessibleInterval} from a unsigned float64-typed {@link TFloat64} tensor. + * + * @param tensor + * The {@link TFloat64} tensor data is read from. + * @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link DoubleType}. + * @throws IOException + */ + private static void buildFromTensorDouble(Tensor tensor, String memoryName) throws IOException + { + long[] arrayShape = tensor.shape().asArray(); + if (CommonUtils.int32Overflows(arrayShape, 8)) + throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + + " is too big. Max number of elements per double output tensor supported: " + Integer.MAX_VALUE / 8); + + SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new DoubleType(), false, true); + ByteBuffer buff = shma.getDataBuffer(); + int totalSize = 8; + for (long i : arrayShape) {totalSize *= i;} + byte[] flatArr = new byte[buff.capacity()]; + buff.get(flatArr); + tensor.rawData().read(flatArr, flatArr.length - totalSize, totalSize); + shma.setBuffer(ByteBuffer.wrap(flatArr)); + if (PlatformDetection.isWindows()) shma.close(); + } + + /** + * Builds a {@link RandomAccessibleInterval} from a unsigned int64-typed {@link TInt64} tensor. + * + * @param tensor + * The {@link TInt64} tensor data is read from. + * @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link LongType}. + * @throws IOException + */ + private static void buildFromTensorLong(Tensor tensor, String memoryName) throws IOException + { + long[] arrayShape = tensor.shape().asArray(); + if (CommonUtils.int32Overflows(arrayShape, 8)) + throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + + " is too big. Max number of elements per long output tensor supported: " + Integer.MAX_VALUE / 8); + + + SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new LongType(), false, true); + ByteBuffer buff = shma.getDataBuffer(); + int totalSize = 8; + for (long i : arrayShape) {totalSize *= i;} + byte[] flatArr = new byte[buff.capacity()]; + buff.get(flatArr); + tensor.rawData().read(flatArr, flatArr.length - totalSize, totalSize); + shma.setBuffer(ByteBuffer.wrap(flatArr)); + if (PlatformDetection.isWindows()) shma.close(); + } +} diff --git a/src/main/java/io/bioimage/modelrunner/tensorflow/v1/shm/TensorBuilder.java b/src/main/java/io/bioimage/modelrunner/tensorflow/v1/shm/TensorBuilder.java new file mode 100644 index 0000000..3bb3410 --- /dev/null +++ b/src/main/java/io/bioimage/modelrunner/tensorflow/v1/shm/TensorBuilder.java @@ -0,0 +1,242 @@ +/*- + * #%L + * This project complements the DL-model runner acting as the engine that works loading models + * and making inference with Java 0.3.0 and newer API for Tensorflow 2. + * %% + * Copyright (C) 2022 - 2023 Institut Pasteur and BioImage.IO developers. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +package io.bioimage.modelrunner.tensorflow.v1.shm; + +import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray; +import io.bioimage.modelrunner.utils.CommonUtils; +import net.imglib2.RandomAccessibleInterval; +import net.imglib2.img.Img; +import net.imglib2.type.numeric.integer.IntType; +import net.imglib2.type.numeric.integer.LongType; +import net.imglib2.type.numeric.integer.UnsignedByteType; +import net.imglib2.type.numeric.real.DoubleType; +import net.imglib2.type.numeric.real.FloatType; +import net.imglib2.util.Cast; + +import java.nio.ByteBuffer; +import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; +import java.nio.IntBuffer; +import java.nio.LongBuffer; +import java.util.Arrays; + +import org.tensorflow.Tensor; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.ndarray.buffer.DoubleDataBuffer; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.ndarray.buffer.IntDataBuffer; +import org.tensorflow.ndarray.buffer.LongDataBuffer; +import org.tensorflow.ndarray.impl.buffer.raw.RawDataBufferFactory; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.TUint8; +import org.tensorflow.types.family.TType; + +/** + * A TensorFlow 2 {@link Tensor} builder from {@link Img} and + * {@link io.bioimage.modelrunner.tensor.Tensor} objects. + * + * @author Carlos Garcia Lopez de Haro and Daniel Felipe Gonzalez Obando + */ +public final class TensorBuilder { + + /** + * Utility class. + */ + private TensorBuilder() {} + + /** + * Creates {@link TType} instance with the same size and information as the + * given {@link RandomAccessibleInterval}. + * + * @param + * the ImgLib2 data types the {@link RandomAccessibleInterval} can be + * @param array + * the {@link RandomAccessibleInterval} that is going to be converted into + * a {@link TType} tensor + * @return a {@link TType} tensor + * @throws IllegalArgumentException if the type of the {@link RandomAccessibleInterval} + * is not supported + */ + public static Tensor build(SharedMemoryArray array) throws IllegalArgumentException + { + // Create an Icy sequence of the same type of the tensor + if (array.getOriginalDataType().equals("uint8")) { + return buildUByte(Cast.unchecked(array)); + } + else if (array.getOriginalDataType().equals("int32")) { + return buildInt(Cast.unchecked(array)); + } + else if (array.getOriginalDataType().equals("float32")) { + return buildFloat(Cast.unchecked(array)); + } + else if (array.getOriginalDataType().equals("float64")) { + return buildDouble(Cast.unchecked(array)); + } + else if (array.getOriginalDataType().equals("int64")) { + return buildLong(Cast.unchecked(array)); + } + else { + throw new IllegalArgumentException("Unsupported tensor type: " + array.getOriginalDataType()); + } + } + + /** + * Creates a {@link TType} tensor of type {@link TUint8} from an + * {@link RandomAccessibleInterval} of type {@link UnsignedByteType} + * + * @param tensor + * The {@link RandomAccessibleInterval} to fill the tensor with. + * @return The {@link TType} tensor filled with the {@link RandomAccessibleInterval} data. + * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is + * not compatible + */ + public static Tensor buildUByte(SharedMemoryArray tensor) + throws IllegalArgumentException + { + long[] ogShape = tensor.getOriginalShape(); + if (CommonUtils.int32Overflows(ogShape, 1)) + throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape) + + " is too big. Max number of elements per ubyte tensor supported: " + Integer.MAX_VALUE); + if (!tensor.isNumpyFormat()) + throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format."); + ByteBuffer buff = tensor.getDataBufferNoHeader(); + ByteDataBuffer dataBuffer = RawDataBufferFactory.create(buff.array(), false); + Tensor ndarray = Tensor.of(TUint8.DTYPE, Shape.of(ogShape), dataBuffer); + return ndarray; + } + + /** + * Creates a {@link TInt32} tensor of type {@link TInt32} from an + * {@link RandomAccessibleInterval} of type {@link IntType} + * + * @param tensor + * The {@link RandomAccessibleInterval} to fill the tensor with. + * @return The {@link TInt32} tensor filled with the {@link RandomAccessibleInterval} data. + * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is + * not compatible + */ + public static Tensor buildInt(SharedMemoryArray tensor) + throws IllegalArgumentException + { + long[] ogShape = tensor.getOriginalShape(); + if (CommonUtils.int32Overflows(ogShape, 1)) + throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape) + + " is too big. Max number of elements per ubyte tensor supported: " + Integer.MAX_VALUE); + if (!tensor.isNumpyFormat()) + throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format."); + ByteBuffer buff = tensor.getDataBufferNoHeader(); + IntBuffer intBuff = buff.asIntBuffer(); + int[] intArray = new int[intBuff.capacity()]; + intBuff.get(intArray); + IntDataBuffer dataBuffer = RawDataBufferFactory.create(intArray, false); + Tensor ndarray = TInt32.tensorOf(Shape.of(ogShape), dataBuffer); + return ndarray; + } + + /** + * Creates a {@link TInt64} tensor of type {@link TInt64} from an + * {@link RandomAccessibleInterval} of type {@link LongType} + * + * @param tensor + * The {@link RandomAccessibleInterval} to fill the tensor with. + * @return The {@link TInt64} tensor filled with the {@link RandomAccessibleInterval} data. + * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is + * not compatible + */ + private static Tensor buildLong(SharedMemoryArray tensor) + throws IllegalArgumentException + { + long[] ogShape = tensor.getOriginalShape(); + if (CommonUtils.int32Overflows(ogShape, 1)) + throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape) + + " is too big. Max number of elements per ubyte tensor supported: " + Integer.MAX_VALUE); + if (!tensor.isNumpyFormat()) + throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format."); + ByteBuffer buff = tensor.getDataBufferNoHeader(); + LongBuffer longBuff = buff.asLongBuffer(); + long[] longArray = new long[longBuff.capacity()]; + longBuff.get(longArray); + LongDataBuffer dataBuffer = RawDataBufferFactory.create(longArray, false); + Tensor ndarray = TInt64.tensorOf(Shape.of(ogShape), dataBuffer); + return ndarray; + } + + /** + * Creates a {@link TFloat32} tensor of type {@link TFloat32} from an + * {@link RandomAccessibleInterval} of type {@link FloatType} + * + * @param tensor + * The {@link RandomAccessibleInterval} to fill the tensor with. + * @return The {@link TFloat32} tensor filled with the {@link RandomAccessibleInterval} data. + * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is + * not compatible + */ + public static Tensor buildFloat(SharedMemoryArray tensor) + throws IllegalArgumentException + { + long[] ogShape = tensor.getOriginalShape(); + if (CommonUtils.int32Overflows(ogShape, 1)) + throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape) + + " is too big. Max number of elements per ubyte tensor supported: " + Integer.MAX_VALUE); + if (!tensor.isNumpyFormat()) + throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format."); + ByteBuffer buff = tensor.getDataBufferNoHeader(); + FloatBuffer floatBuff = buff.asFloatBuffer(); + float[] floatArray = new float[floatBuff.capacity()]; + floatBuff.get(floatArray); + FloatDataBuffer dataBuffer = RawDataBufferFactory.create(floatArray, false); + Tensor ndarray = TFloat32.tensorOf(Shape.of(ogShape), dataBuffer); + return ndarray; + } + + /** + * Creates a {@link TFloat64} tensor of type {@link TFloat64} from an + * {@link RandomAccessibleInterval} of type {@link DoubleType} + * + * @param tensor + * The {@link RandomAccessibleInterval} to fill the tensor with. + * @return The {@link TFloat64} tensor filled with the {@link RandomAccessibleInterval} data. + * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is + * not compatible + */ + private static Tensor buildDouble(SharedMemoryArray tensor) + throws IllegalArgumentException + { + long[] ogShape = tensor.getOriginalShape(); + if (CommonUtils.int32Overflows(ogShape, 1)) + throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape) + + " is too big. Max number of elements per ubyte tensor supported: " + Integer.MAX_VALUE); + if (!tensor.isNumpyFormat()) + throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format."); + ByteBuffer buff = tensor.getDataBufferNoHeader(); + DoubleBuffer doubleBuff = buff.asDoubleBuffer(); + double[] doubleArray = new double[doubleBuff.capacity()]; + doubleBuff.get(doubleArray); + DoubleDataBuffer dataBuffer = RawDataBufferFactory.create(doubleArray, false); + Tensor ndarray = TFloat64.tensorOf(Shape.of(ogShape), dataBuffer); + return ndarray; + } +} diff --git a/src/main/java/io/bioimage/modelrunner/tensorflow/v1/tensor/mappedbuffer/ImgLib2ToMappedBuffer.java b/src/main/java/io/bioimage/modelrunner/tensorflow/v1/tensor/mappedbuffer/ImgLib2ToMappedBuffer.java deleted file mode 100644 index ff44674..0000000 --- a/src/main/java/io/bioimage/modelrunner/tensorflow/v1/tensor/mappedbuffer/ImgLib2ToMappedBuffer.java +++ /dev/null @@ -1,283 +0,0 @@ -/*- - * #%L - * This project complements the DL-model runner acting as the engine that works loading models - * and making inference with Java API for Tensorflow 1. - * %% - * Copyright (C) 2022 - 2023 Institut Pasteur and BioImage.IO developers. - * %% - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * #L% - */ -package io.bioimage.modelrunner.tensorflow.v1.tensor.mappedbuffer; - -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import java.util.Arrays; - -import io.bioimage.modelrunner.tensor.Tensor; -import net.imglib2.Cursor; -import net.imglib2.RandomAccessibleInterval; -import net.imglib2.img.Img; -import net.imglib2.type.NativeType; -import net.imglib2.type.Type; -import net.imglib2.type.numeric.RealType; -import net.imglib2.type.numeric.integer.ByteType; -import net.imglib2.type.numeric.integer.IntType; -import net.imglib2.type.numeric.integer.LongType; -import net.imglib2.type.numeric.integer.UnsignedByteType; -import net.imglib2.type.numeric.real.DoubleType; -import net.imglib2.type.numeric.real.FloatType; -import net.imglib2.util.Util; -import net.imglib2.view.IntervalView; -import net.imglib2.view.Views; - -/** - * Class that maps {@link Tensor} objects to {@link ByteBuffer} objects. - * This is done to modify the files that are used to communicate between process - * in MacOS Intel to avoid the TF1-TF2 incompatibiity that happens in these systems - * - * @author Carlos Garcia Lopez de Haro - */ -public final class ImgLib2ToMappedBuffer -{ - /** - * Header used to identify files for interprocessing communication - */ - final public static byte[] MODEL_RUNNER_HEADER = - {(byte) 0x93, 'M', 'O', 'D', 'E', 'L', '-', 'R', 'U', 'N', 'N', 'E', 'R'}; - - /** - * Not used (Utility class). - */ - private ImgLib2ToMappedBuffer() - { - } - - /** - * Maps a {@link Tensor} to the provided {@link ByteBuffer} with all the information - * needed to reconstruct the tensor again - * - * @param - * the type of the tensor - * @param tensor - * tensor to be mapped into byte buffer - * @param byteBuffer - * target byte bufer - * @throws IllegalArgumentException - * If the {@link Tensor} ImgLib2 type is not supported. - */ - public static < T extends RealType< T > & NativeType< T > > void build(Tensor tensor, ByteBuffer byteBuffer) - { - byteBuffer.put(ImgLib2ToMappedBuffer.createFileHeader(tensor)); - if (tensor.isEmpty()) - return; - build(tensor.getData(), byteBuffer); - } - - /** - * Adds the {@link RandomAccessibleInterval} data to the {@link ByteBuffer} provided. - * The position of the ByteBuffer is kept in the same place as it was received. - * - * @param - * the type of the {@link RandomAccessibleInterval} - * @param rai - * {@link RandomAccessibleInterval} to be mapped into byte buffer - * @param byteBuffer - * target bytebuffer - * @throws IllegalArgumentException If the {@link RandomAccessibleInterval} type is not supported. - */ - private static > void build(RandomAccessibleInterval rai, ByteBuffer byteBuffer) - { - if (Util.getTypeFromInterval(rai) instanceof ByteType) { - buildByte((RandomAccessibleInterval) rai, byteBuffer); - } else if (Util.getTypeFromInterval(rai) instanceof IntType) { - buildInt((RandomAccessibleInterval) rai, byteBuffer); - } else if (Util.getTypeFromInterval(rai) instanceof FloatType) { - buildFloat((RandomAccessibleInterval) rai, byteBuffer); - } else if (Util.getTypeFromInterval(rai) instanceof DoubleType) { - buildDouble((RandomAccessibleInterval) rai, byteBuffer); - } else { - throw new IllegalArgumentException("The image has an unsupported type: " + Util.getTypeFromInterval(rai).getClass().toString()); - } - } - - /** - * Adds the ByteType {@link RandomAccessibleInterval} data to the {@link ByteBuffer} provided. - * The position of the ByteBuffer is kept in the same place as it was received. - * - * @param imgTensor - * {@link RandomAccessibleInterval} to be mapped into byte buffer - * @param byteBuffer - * target bytebuffer - */ - private static void buildByte(RandomAccessibleInterval imgTensor, ByteBuffer byteBuffer) - { - Cursor tensorCursor = Views.flatIterable(imgTensor).cursor(); - while (tensorCursor.hasNext()) { - tensorCursor.fwd(); - byteBuffer.put(tensorCursor.get().getByte()); - } - } - - /** - * Adds the IntType {@link RandomAccessibleInterval} data to the {@link ByteBuffer} provided. - * The position of the ByteBuffer is kept in the same place as it was received. - * - * @param imgTensor - * {@link RandomAccessibleInterval} to be mapped into byte buffer - * @param byteBuffer - * target bytebuffer - */ - private static void buildInt(RandomAccessibleInterval imgTensor, ByteBuffer byteBuffer) - { - Cursor tensorCursor = Views.flatIterable(imgTensor).cursor(); - while (tensorCursor.hasNext()) { - tensorCursor.fwd(); - byteBuffer.putInt(tensorCursor.get().getInt()); - } - } - - /** - * Adds the FloatType {@link RandomAccessibleInterval} data to the {@link ByteBuffer} provided. - * The position of the ByteBuffer is kept in the same place as it was received. - * - * @param imgTensor - * {@link RandomAccessibleInterval} to be mapped into byte buffer - * @param byteBuffer - * target bytebuffer - */ - private static void buildFloat(RandomAccessibleInterval imgTensor, ByteBuffer byteBuffer) - { - Cursor tensorCursor = Views.flatIterable(imgTensor).cursor(); - while (tensorCursor.hasNext()) { - tensorCursor.fwd(); - byteBuffer.putFloat(tensorCursor.get().getRealFloat()); - } - } - - /** - * Adds the DoubleType {@link RandomAccessibleInterval} data to the {@link ByteBuffer} provided. - * The position of the ByteBuffer is kept in the same place as it was received. - * - * @param imgTensor - * {@link RandomAccessibleInterval} to be mapped into byte buffer - * @param byteBuffer - * target bytebuffer - */ - private static void buildDouble(RandomAccessibleInterval imgTensor, ByteBuffer byteBuffer) - { - Cursor tensorCursor = Views.flatIterable(imgTensor).cursor(); - while (tensorCursor.hasNext()) { - tensorCursor.fwd(); - byteBuffer.putDouble(tensorCursor.get().getRealDouble()); - } - } - - /** - * Create header for the temp file that is used for interprocess communication. - * The header should contain the first key word as an array of bytes (MODEl-RUNNER) - * @param - * type of the tensor - * @param tensor - * tensor whose info is recorded - * @return byte array containing the header info for the file - */ - public static < T extends RealType< T > & NativeType< T > > byte[] - createFileHeader(io.bioimage.modelrunner.tensor.Tensor tensor) { - String dimsStr = - !tensor.isEmpty() ? Arrays.toString(tensor.getData().dimensionsAsLongArray()) : "[]"; - T dtype = !tensor.isEmpty() ? Util.getTypeFromInterval(tensor.getData()): (T) new FloatType(); - String descriptionStr = "{'dtype':'" - + getDataTypeString(dtype) + "','axes':'" - + tensor.getAxesOrderString() + "','name':'" + tensor.getName() + "','shape':'" - + dimsStr + "'}"; - - byte[] descriptionBytes = descriptionStr.getBytes(StandardCharsets.UTF_8); - int lenDescriptionBytes = descriptionBytes.length; - byte[] intAsBytes = ByteBuffer.allocate(4).putInt(lenDescriptionBytes).array(); - int totalHeaderLen = MODEL_RUNNER_HEADER.length + intAsBytes.length + lenDescriptionBytes; - byte[] byteHeader = new byte[totalHeaderLen]; - for (int i = 0; i < MODEL_RUNNER_HEADER.length; i ++) - byteHeader[i] = MODEL_RUNNER_HEADER[i]; - for (int i = MODEL_RUNNER_HEADER.length; i < MODEL_RUNNER_HEADER.length + intAsBytes.length; i ++) - byteHeader[i] = intAsBytes[i - MODEL_RUNNER_HEADER.length]; - for (int i = MODEL_RUNNER_HEADER.length + intAsBytes.length; i < totalHeaderLen; i ++) - byteHeader[i] = descriptionBytes[i - MODEL_RUNNER_HEADER.length - intAsBytes.length]; - - return byteHeader; - } - - /** - * Method that returns a Sting representing the datatype of T - * @param - * type of the tensor - * @param type - * pixel of an imglib2 object to get the info of teh data type - * @return String representation of the datatype - */ - public static< T extends RealType< T > & NativeType< T > > String getDataTypeString(T type) { - if (type instanceof ByteType) { - return "byte"; - } else if (type instanceof IntType) { - return "int32"; - } else if (type instanceof FloatType) { - return "float32"; - } else if (type instanceof DoubleType) { - return "float64"; - } else if (type instanceof LongType) { - return "int64"; - } else if (type instanceof UnsignedByteType) { - return "ubyte"; - } else { - throw new IllegalArgumentException("Unsupported data type. At the moment the only " - + "supported dtypes are: " + IntType.class + ", " + FloatType.class + ", " - + DoubleType.class + ", " + LongType.class + " and " + UnsignedByteType.class); - } - } - - /** - * Get the total byte size of the temp file that is going to be created to be - * able to reconstruct a {@link Tensor} to in the separate process in MacOS Intel - * systems - * - * @param - * type of the imglib2 object - * @param tensor - * tensor of interest - * @return number of bytes needed to create a file with the info of the tensor - */ - public static < T extends RealType< T > & NativeType< T > > long - findTotalLengthFile(io.bioimage.modelrunner.tensor.Tensor tensor) { - long startLen = createFileHeader(tensor).length; - long[] dimsArr = !tensor.isEmpty() ? tensor.getData().dimensionsAsLongArray() : null; - if (dimsArr == null) - return startLen; - long totSizeFlat = 1; - for (long i : dimsArr) {totSizeFlat *= i;} - long nBytesDt = 1; - Type dtype = !tensor.isEmpty() ? - Util.getTypeFromInterval(tensor.getData()) : (Type) new FloatType(); - if (dtype instanceof IntType) { - nBytesDt = 4; - } else if (dtype instanceof ByteType) { - nBytesDt = 1; - } else if (dtype instanceof FloatType) { - nBytesDt = 4; - } else if (dtype instanceof DoubleType) { - nBytesDt = 8; - } else { - throw new IllegalArgumentException("Unsupported tensor type: " + dtype.getClass()); - } - return startLen + nBytesDt * totSizeFlat; - } -} diff --git a/src/main/java/io/bioimage/modelrunner/tensorflow/v1/tensor/mappedbuffer/MappedBufferToImgLib2.java b/src/main/java/io/bioimage/modelrunner/tensorflow/v1/tensor/mappedbuffer/MappedBufferToImgLib2.java deleted file mode 100644 index 37acfc9..0000000 --- a/src/main/java/io/bioimage/modelrunner/tensorflow/v1/tensor/mappedbuffer/MappedBufferToImgLib2.java +++ /dev/null @@ -1,328 +0,0 @@ -/*- - * #%L - * This project complements the DL-model runner acting as the engine that works loading models - * and making inference with Java API for Tensorflow 1. - * %% - * Copyright (C) 2022 - 2023 Institut Pasteur and BioImage.IO developers. - * %% - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * #L% - */ -package io.bioimage.modelrunner.tensorflow.v1.tensor.mappedbuffer; - -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import java.util.Arrays; -import java.util.HashMap; -import java.util.regex.Matcher; -import java.util.regex.Pattern; - -import io.bioimage.modelrunner.tensor.Tensor; -import net.imglib2.Cursor; -import net.imglib2.RandomAccessibleInterval; -import net.imglib2.img.Img; -import net.imglib2.img.array.ArrayImgFactory; -import net.imglib2.type.NativeType; -import net.imglib2.type.Type; -import net.imglib2.type.numeric.RealType; -import net.imglib2.type.numeric.integer.ByteType; -import net.imglib2.type.numeric.integer.IntType; -import net.imglib2.type.numeric.real.DoubleType; -import net.imglib2.type.numeric.real.FloatType; - -/** - * A {@link Img} builder from {@link ByteBuffer} objects - * - * @author Carlos Garcia Lopez de Haro - */ -public final class MappedBufferToImgLib2 -{ - /** - * Pattern that matches the header of the temporal file for interprocess communication - * and retrieves data type, shape, name and axes - */ - private static final Pattern HEADER_PATTERN = Pattern.compile("'dtype':'([a-zA-Z0-9]+)'" - + ",'axes':'([a-zA-Z]+)'" - + ",'name':'([^']*)'" - + ",'shape':'(\\[\\s*(?:(?:[1-9]\\d*|0)\\s*,\\s*)*(?:[1-9]\\d*|0)?\\s*\\])'"); - /** - * Key for data type info - */ - private static final String DATA_TYPE_KEY = "dtype"; - /** - * Key for shape info - */ - private static final String SHAPE_KEY = "shape"; - /** - * Key for axes info - */ - private static final String AXES_KEY = "axes"; - /** - * Key for axes info - */ - private static final String NAME_KEY = "name"; - - /** - * Not used (Utility class). - */ - private MappedBufferToImgLib2() - { - } - - /** - * Creates a {@link Tensor} from the information stored in a {@link ByteBuffer} - * - * @param - * the type of the generated tensor - * @param buff - * byte buffer to get the tensor info from - * @return the tensor generated from the bytebuffer - * @throws IllegalArgumentException if the data type of the tensor saved in the bytebuffer is - * not supported - */ - @SuppressWarnings("unchecked") - public static < T extends RealType< T > & NativeType< T > > Tensor buildTensor(ByteBuffer buff) throws IllegalArgumentException - { - String infoStr = getTensorInfoFromBuffer(buff); - HashMap map = getInfoFromHeaderString(infoStr); - String dtype = (String) map.get(DATA_TYPE_KEY); - String axes = (String) map.get(AXES_KEY); - String name = (String) map.get(NAME_KEY); - long[] shape = (long[]) map.get(SHAPE_KEY); - if (shape.length == 0) - return Tensor.buildEmptyTensor(name, axes); - - Img data; - switch (dtype) - { - case "byte": - data = (Img) buildFromTensorByte(buff, shape); - break; - case "int32": - data = (Img) buildFromTensorInt(buff, shape); - break; - case "float32": - data = (Img) buildFromTensorFloat(buff, shape); - break; - case "float64": - data = (Img) buildFromTensorDouble(buff, shape); - break; - default: - throw new IllegalArgumentException("Unsupported tensor type: " + dtype); - } - return Tensor.build(name, axes, (RandomAccessibleInterval) data); - } - - /** - * Creates a {@link Img} from the information stored in a {@link ByteBuffer} - * - * @param - * data type of the image - * @param byteBuff - * The bytebyuffer that contains info to create a tenosr or a {@link Img} - * @return The imglib2 image {@link Img} built from the bytebuffer info. - * @throws IllegalArgumentException if the data type of the tensor saved in the bytebuffer is - * not supported - */ - @SuppressWarnings("unchecked") - public static > Img build(ByteBuffer byteBuff) throws IllegalArgumentException - { - String infoStr = getTensorInfoFromBuffer(byteBuff); - HashMap map = getInfoFromHeaderString(infoStr); - String dtype = (String) map.get(DATA_TYPE_KEY); - long[] shape = (long[]) map.get(SHAPE_KEY); - if (shape.length == 0) - return null; - - // Create an INDArray of the same type of the tensor - switch (dtype) - { - case "byte": - return (Img) buildFromTensorByte(byteBuff, shape); - case "int32": - return (Img) buildFromTensorInt(byteBuff, shape); - case "float32": - return (Img) buildFromTensorFloat(byteBuff, shape); - case "float64": - return (Img) buildFromTensorDouble(byteBuff, shape); - default: - throw new IllegalArgumentException("Unsupported tensor type: " + dtype); - } - } - - /** - * Builds a ByteType {@link Img} from the information stored in a byte buffer. - * The shape of the image that was previously retrieved from the buffer - * @param tensor - * byte buffer containing the information of the a tenosr, the position in the buffer - * should not be at zero but right after the header. - * @param tensorShape - * shape of the image to generate, it has been retrieved from the byte buffer - * @return image specified in the bytebuffer - */ - private static Img buildFromTensorByte(ByteBuffer tensor, long[] tensorShape) - { - final ArrayImgFactory< ByteType > factory = new ArrayImgFactory<>( new ByteType() ); - final Img< ByteType > outputImg = (Img) factory.create(tensorShape); - Cursor tensorCursor= outputImg.cursor(); - while (tensorCursor.hasNext()) { - tensorCursor.fwd(); - tensorCursor.get().set(tensor.get()); - } - return outputImg; - } - - /** - * Builds a IntType {@link Img} from the information stored in a byte buffer. - * The shape of the image that was previously retrieved from the buffer - * @param tensor - * byte buffer containing the information of the a tenosr, the position in the buffer - * should not be at zero but right after the header. - * @param tensorShape - * shape of the image to generate, it has been retrieved from the byte buffer - * @return image specified in the bytebuffer - */ - private static Img buildFromTensorInt(ByteBuffer tensor, long[] tensorShape) - { - final ArrayImgFactory< IntType > factory = new ArrayImgFactory<>( new IntType() ); - final Img< IntType > outputImg = (Img) factory.create(tensorShape); - Cursor tensorCursor= outputImg.cursor(); - byte[] bytes = new byte[4]; - while (tensorCursor.hasNext()) { - tensorCursor.fwd(); - tensor.get(bytes); - int val = ((int) (bytes[0] << 24)) + ((int) (bytes[1] << 16)) - + ((int) (bytes[2] << 8)) + ((int) (bytes[3])); - tensorCursor.get().set(val); - } - return outputImg; - } - - /** - * Builds a FloatType {@link Img} from the information stored in a byte buffer. - * The shape of the image that was previously retrieved from the buffer - * @param tensor - * byte buffer containing the information of the a tenosr, the position in the buffer - * should not be at zero but right after the header. - * @param tensorShape - * shape of the image to generate, it has been retrieved from the byte buffer - * @return image specified in the bytebuffer - */ - private static Img buildFromTensorFloat(ByteBuffer tensor, long[] tensorShape) - { - final ArrayImgFactory< FloatType > factory = new ArrayImgFactory<>( new FloatType() ); - final Img< FloatType > outputImg = (Img) factory.create(tensorShape); - Cursor tensorCursor= outputImg.cursor(); - byte[] bytes = new byte[4]; - while (tensorCursor.hasNext()) { - tensorCursor.fwd(); - tensor.get(bytes); - float val = ByteBuffer.wrap(bytes).getFloat(); - tensorCursor.get().set(val); - } - return outputImg; - } - - /** - * Builds a DoubleType {@link Img} from the information stored in a byte buffer. - * The shape of the image that was previously retrieved from the buffer - * @param tensor - * byte buffer containing the information of the a tenosr, the position in the buffer - * should not be at zero but right after the header. - * @param tensorShape - * shape of the image to generate, it has been retrieved from the byte buffer - * @return image specified in the bytebuffer - */ - private static Img buildFromTensorDouble(ByteBuffer tensor, long[] tensorShape) - { - final ArrayImgFactory< DoubleType > factory = new ArrayImgFactory<>( new DoubleType() ); - final Img< DoubleType > outputImg = (Img) factory.create(tensorShape); - Cursor tensorCursor= outputImg.cursor(); - byte[] bytes = new byte[8]; - while (tensorCursor.hasNext()) { - tensorCursor.fwd(); - tensor.get(bytes); - double val = ByteBuffer.wrap(bytes).getDouble(); - tensorCursor.get().set(val); - } - return outputImg; - } - - /** - * Method that returns the information about the tensor specified at the - * beginning of the {@link ByteBuffer} object created - * with {@link ImgLib2ToMappedBuffer#build(Tensor,ByteBuffer)}. - * This method reads the buffer from the beginning - * @param buff - * ByteBuffer containing the information about the tensor - * @return map containing the name, axes order, datatype and shape of the tensor - * stored in teh buffer - */ - public static HashMap readHeaderAndGetInfo(ByteBuffer buff) { - buff.clear(); - return getInfoFromHeaderString(getTensorInfoFromBuffer(buff)); - } - - /** - * GEt the String info stored at the beginning of the buffer that contains - * the data type, name of tensor, axes and shape info. - * @param buff - * buffer containing all the data to generate a tensor - * @return the String header of teh bytebuffer that contains the data about - * the tensor (name, data type, shape and axes) - */ - private static String getTensorInfoFromBuffer(ByteBuffer buff) { - byte[] arr = new byte[ImgLib2ToMappedBuffer.MODEL_RUNNER_HEADER.length]; - buff.get(arr); - if (!Arrays.equals(arr, ImgLib2ToMappedBuffer.MODEL_RUNNER_HEADER)) - throw new IllegalArgumentException("Error sending tensors between processes."); - byte[] lenInfoInBytes = new byte[4]; - buff.get(lenInfoInBytes); - int lenInfo = ByteBuffer.wrap(lenInfoInBytes).getInt(); - byte[] stringInfoBytes = new byte[lenInfo]; - buff.get(stringInfoBytes); - return new String(stringInfoBytes, StandardCharsets.UTF_8); - } - - /** - * MEthod that retrieves the data type string and shape long array representing - * the data type and dimensions of the tensor saved in the temp file - * @param infoStr - * string header of the file that contains the info about the tensor - * @return dictionary containins the name, dtype, shape and axes of the tensor - */ - private static HashMap getInfoFromHeaderString(String infoStr) { - Matcher matcher = HEADER_PATTERN.matcher(infoStr); - if (!matcher.find()) { - throw new IllegalArgumentException("Cannot find datatype, name, axes and dimensions " - + "info in file header: " + infoStr); - } - String typeStr = matcher.group(1); - String axesStr = matcher.group(2); - String nameStr = matcher.group(3); - String shapeStr = matcher.group(4); - - long[] shape = new long[0]; - if (!shapeStr.isEmpty() && !shapeStr.equals("[]")) { - shapeStr = shapeStr.substring(1, shapeStr.length() - 1); - String[] tokens = shapeStr.split(", ?"); - shape = Arrays.stream(tokens).mapToLong(Long::parseLong).toArray(); - } - HashMap map = new HashMap(); - map.put(DATA_TYPE_KEY, typeStr); - map.put(AXES_KEY, axesStr); - map.put(SHAPE_KEY, shape); - map.put(NAME_KEY, nameStr); - return map; - } -}