From c73789ff4e3240bdb26d97ba160de48f861de3ec Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Fri, 15 Dec 2023 20:32:17 +0100 Subject: [PATCH] fallback if the tf names are not perfect --- .../tensorflow/v1/Tensorflow1Interface.java | 35 ++++++++++++++++--- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/src/main/java/io/bioimage/modelrunner/tensorflow/v1/Tensorflow1Interface.java b/src/main/java/io/bioimage/modelrunner/tensorflow/v1/Tensorflow1Interface.java index 643c758..bc0c4cd 100644 --- a/src/main/java/io/bioimage/modelrunner/tensorflow/v1/Tensorflow1Interface.java +++ b/src/main/java/io/bioimage/modelrunner/tensorflow/v1/Tensorflow1Interface.java @@ -38,6 +38,7 @@ import io.bioimage.modelrunner.utils.ZipUtils; import net.imglib2.type.NativeType; import net.imglib2.type.numeric.RealType; +import net.imglib2.type.numeric.real.FloatType; import java.io.BufferedReader; import java.io.File; @@ -312,15 +313,17 @@ public void run(List> inputTensors, List> outputTensors) List inputListNames = new ArrayList(); List> inTensors = new ArrayList>(); + int c = 0; for (Tensor tt : inputTensors) { inputListNames.add(tt.getName()); org.tensorflow.Tensor inT = TensorBuilder.build(tt); inTensors.add(inT); - runner.feed(getModelInputName(tt.getName()), inT); + String inputName = getModelInputName(tt.getName(), c ++); + runner.feed(inputName, inT); } - + c = 0; for (Tensor tt : outputTensors) - runner = runner.fetch(getModelOutputName(tt.getName())); + runner = runner.fetch(getModelOutputName(tt.getName(), c ++)); // Run runner List> resultPatchTensors = runner.run(); @@ -426,10 +429,14 @@ public void closeModel() { * the signature input name. * * @param inputName Signature input name. + * @param i position of the input of interest in the list of inputs * @return The readable input name. */ - public static String getModelInputName(String inputName) { + public static String getModelInputName(String inputName, int i) { TensorInfo inputInfo = sig.getInputsMap().getOrDefault(inputName, null); + if (inputInfo == null) { + inputInfo = sig.getInputsMap().values().stream().collect(Collectors.toList()).get(i); + } if (inputInfo != null) { String modelInputName = inputInfo.getName(); if (modelInputName != null) { @@ -452,10 +459,14 @@ public static String getModelInputName(String inputName) { * given the signature output name. * * @param outputName Signature output name. + * @param i position of the input of interest in the list of inputs * @return The readable output name. */ - public static String getModelOutputName(String outputName) { + public static String getModelOutputName(String outputName, int i) { TensorInfo outputInfo = sig.getOutputsMap().getOrDefault(outputName, null); + if (outputInfo == null) { + outputInfo = sig.getOutputsMap().values().stream().collect(Collectors.toList()).get(i); + } if (outputInfo != null) { String modelOutputName = outputInfo.getName(); if (modelOutputName.endsWith(":0")) { @@ -495,6 +506,20 @@ public static String getModelOutputName(String outputName) { * @throws RunModelException if there is any error running the model */ public static void main(String[] args) throws LoadModelException, IOException, RunModelException { + if (args.length == 0) { + String modelFolder = "/home/carlos/git/deep-icy/models/stardist_1channel"; + Tensorflow1Interface ti = new Tensorflow1Interface(false); + ti.loadModel(modelFolder, modelFolder); + Tensor inp = Tensor.buildBlankTensor("in", "byxc", new long[] {1, 208, 208, 1}, new FloatType()); + Tensor out = Tensor.buildEmptyTensor("out", "byxc"); + List> inps = new ArrayList>(); + inps.add(inp); + List> outs = new ArrayList>(); + outs.add(out); + ti.run(inps, outs); + System.out.println(false); + return; + } // Unpack the args needed if (args.length < 4) throw new IllegalArgumentException("Error exectuting Tensorflow 1, "