Skip to content

Commit

Permalink
finish adapting to persistent memory
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Sep 24, 2024
1 parent ee36ef0 commit 4eb89ff
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 99 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import io.bioimage.modelrunner.system.PlatformDetection;
import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
import io.bioimage.modelrunner.tensorflow.v1.shm.ShmBuilder;
import io.bioimage.modelrunner.tensorflow.v1.tensor.ImgLib2Builder;
import io.bioimage.modelrunner.tensorflow.v1.tensor.TensorBuilder;
import io.bioimage.modelrunner.utils.CommonUtils;
Expand All @@ -45,28 +46,17 @@
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Cast;
import net.imglib2.util.Util;

import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.RandomAccessFile;
import java.io.UnsupportedEncodingException;
import java.net.URISyntaxException;
import java.net.URL;
import java.net.URLDecoder;
import java.nio.ByteBuffer;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.security.ProtectionDomain;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
Expand Down Expand Up @@ -332,7 +322,7 @@ protected void runFromShmas(List<String> inputs, List<String> outputs) throws IO
for (String ee : inputs) {
Map<String, Object> decoded = Types.decode(ee);
SharedMemoryArray shma = SharedMemoryArray.read((String) decoded.get(MEM_NAME_KEY));
org.tensorflow.Tensor<?> inT = io.bioimage.modelrunner.tensorflow.v2.api030.shm.TensorBuilder.build(shma);
org.tensorflow.Tensor<?> inT = io.bioimage.modelrunner.tensorflow.v1.shm.TensorBuilder.build(shma);
if (PlatformDetection.isWindows()) shma.close();
inTensors.add(inT);
String inputName = getModelInputName((String) decoded.get(NAME_KEY), c ++);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,7 @@
import java.util.Arrays;

import org.tensorflow.Tensor;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TFloat64;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.TUint8;
import org.tensorflow.types.family.TType;
import org.tensorflow.types.UInt8;

import net.imglib2.RandomAccessibleInterval;
import net.imglib2.type.numeric.integer.IntType;
Expand Down Expand Up @@ -70,20 +65,20 @@ private ShmBuilder()
* @throws IOException
*/
@SuppressWarnings("unchecked")
public static void build(Tensor<? extends TType> tensor, String memoryName) throws IllegalArgumentException, IOException
public static void build(Tensor<?> tensor, String memoryName) throws IllegalArgumentException, IOException
{
switch (tensor.dataType().name())
switch (tensor.dataType())
{
case TUint8.NAME:
buildFromTensorUByte((Tensor<TUint8>) tensor, memoryName);
case TInt32.NAME:
buildFromTensorInt((Tensor<TInt32>) tensor, memoryName);
case TFloat32.NAME:
buildFromTensorFloat((Tensor<TFloat32>) tensor, memoryName);
case TFloat64.NAME:
buildFromTensorDouble((Tensor<TFloat64>) tensor, memoryName);
case TInt64.NAME:
buildFromTensorLong((Tensor<TInt64>) tensor, memoryName);
case UINT8:
buildFromTensorUByte((Tensor<UInt8>) tensor, memoryName);
case INT32:
buildFromTensorInt((Tensor<Integer>) tensor, memoryName);
case FLOAT:
buildFromTensorFloat((Tensor<Float>) tensor, memoryName);
case DOUBLE:
buildFromTensorDouble((Tensor<Double>) tensor, memoryName);
case INT64:
buildFromTensorLong((Tensor<Long>) tensor, memoryName);
default:
throw new IllegalArgumentException("Unsupported tensor type: " + tensor.dataType().name());
}
Expand All @@ -97,20 +92,15 @@ public static void build(Tensor<? extends TType> tensor, String memoryName) thr
* @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link UnsignedByteType}.
* @throws IOException
*/
private static void buildFromTensorUByte(Tensor<TUint8> tensor, String memoryName) throws IOException
private static void buildFromTensorUByte(Tensor<UInt8> tensor, String memoryName) throws IOException
{
long[] arrayShape = tensor.shape().asArray();
long[] arrayShape = tensor.shape();
if (CommonUtils.int32Overflows(arrayShape, 1))
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1);
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
ByteBuffer buff = shma.getDataBuffer();
int totalSize = 1;
for (long i : arrayShape) {totalSize *= i;}
byte[] flatArr = new byte[buff.capacity()];
buff.get(flatArr);
tensor.rawData().read(flatArr, flatArr.length - totalSize, totalSize);
shma.setBuffer(ByteBuffer.wrap(flatArr));
tensor.writeTo(buff);
if (PlatformDetection.isWindows()) shma.close();
}

Expand All @@ -122,21 +112,16 @@ private static void buildFromTensorUByte(Tensor<TUint8> tensor, String memoryNam
* @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link IntType}.
* @throws IOException
*/
private static void buildFromTensorInt(Tensor<TInt32> tensor, String memoryName) throws IOException
private static void buildFromTensorInt(Tensor<Integer> tensor, String memoryName) throws IOException
{
long[] arrayShape = tensor.shape().asArray();
long[] arrayShape = tensor.shape();
if (CommonUtils.int32Overflows(arrayShape, 4))
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4);

SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true);
ByteBuffer buff = shma.getDataBuffer();
int totalSize = 4;
for (long i : arrayShape) {totalSize *= i;}
byte[] flatArr = new byte[buff.capacity()];
buff.get(flatArr);
tensor.rawData().read(flatArr, flatArr.length - totalSize, totalSize);
shma.setBuffer(ByteBuffer.wrap(flatArr));
tensor.writeTo(buff);
if (PlatformDetection.isWindows()) shma.close();
}

Expand All @@ -148,21 +133,16 @@ private static void buildFromTensorInt(Tensor<TInt32> tensor, String memoryName)
* @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link FloatType}.
* @throws IOException
*/
private static void buildFromTensorFloat(Tensor<TFloat32> tensor, String memoryName) throws IOException
private static void buildFromTensorFloat(Tensor<Float> tensor, String memoryName) throws IOException
{
long[] arrayShape = tensor.shape().asArray();
long[] arrayShape = tensor.shape();
if (CommonUtils.int32Overflows(arrayShape, 4))
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4);

SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new FloatType(), false, true);
ByteBuffer buff = shma.getDataBuffer();
int totalSize = 4;
for (long i : arrayShape) {totalSize *= i;}
byte[] flatArr = new byte[buff.capacity()];
buff.get(flatArr);
tensor.rawData().read(flatArr, flatArr.length - totalSize, totalSize);
shma.setBuffer(ByteBuffer.wrap(flatArr));
tensor.writeTo(buff);
if (PlatformDetection.isWindows()) shma.close();
}

Expand All @@ -174,21 +154,16 @@ private static void buildFromTensorFloat(Tensor<TFloat32> tensor, String memoryN
* @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link DoubleType}.
* @throws IOException
*/
private static void buildFromTensorDouble(Tensor<TFloat64> tensor, String memoryName) throws IOException
private static void buildFromTensorDouble(Tensor<Double> tensor, String memoryName) throws IOException
{
long[] arrayShape = tensor.shape().asArray();
long[] arrayShape = tensor.shape();
if (CommonUtils.int32Overflows(arrayShape, 8))
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per double output tensor supported: " + Integer.MAX_VALUE / 8);

SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new DoubleType(), false, true);
ByteBuffer buff = shma.getDataBuffer();
int totalSize = 8;
for (long i : arrayShape) {totalSize *= i;}
byte[] flatArr = new byte[buff.capacity()];
buff.get(flatArr);
tensor.rawData().read(flatArr, flatArr.length - totalSize, totalSize);
shma.setBuffer(ByteBuffer.wrap(flatArr));
tensor.writeTo(buff);
if (PlatformDetection.isWindows()) shma.close();
}

Expand All @@ -200,22 +175,17 @@ private static void buildFromTensorDouble(Tensor<TFloat64> tensor, String memory
* @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link LongType}.
* @throws IOException
*/
private static void buildFromTensorLong(Tensor<TInt64> tensor, String memoryName) throws IOException
private static void buildFromTensorLong(Tensor<Long> tensor, String memoryName) throws IOException
{
long[] arrayShape = tensor.shape().asArray();
long[] arrayShape = tensor.shape();
if (CommonUtils.int32Overflows(arrayShape, 8))
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per long output tensor supported: " + Integer.MAX_VALUE / 8);


SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new LongType(), false, true);
ByteBuffer buff = shma.getDataBuffer();
int totalSize = 8;
for (long i : arrayShape) {totalSize *= i;}
byte[] flatArr = new byte[buff.capacity()];
buff.get(flatArr);
tensor.rawData().read(flatArr, flatArr.length - totalSize, totalSize);
shma.setBuffer(ByteBuffer.wrap(flatArr));
tensor.writeTo(buff);
if (PlatformDetection.isWindows()) shma.close();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,7 @@
import java.util.Arrays;

import org.tensorflow.Tensor;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.ByteDataBuffer;
import org.tensorflow.ndarray.buffer.DoubleDataBuffer;
import org.tensorflow.ndarray.buffer.FloatDataBuffer;
import org.tensorflow.ndarray.buffer.IntDataBuffer;
import org.tensorflow.ndarray.buffer.LongDataBuffer;
import org.tensorflow.ndarray.impl.buffer.raw.RawDataBufferFactory;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TFloat64;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.TUint8;
import org.tensorflow.types.family.TType;
import org.tensorflow.types.UInt8;

/**
* A TensorFlow 2 {@link Tensor} builder from {@link Img} and
Expand Down Expand Up @@ -80,7 +68,7 @@ private TensorBuilder() {}
* @throws IllegalArgumentException if the type of the {@link RandomAccessibleInterval}
* is not supported
*/
public static Tensor<? extends TType> build(SharedMemoryArray array) throws IllegalArgumentException
public static Tensor<?> build(SharedMemoryArray array) throws IllegalArgumentException
{
// Create an Icy sequence of the same type of the tensor
if (array.getOriginalDataType().equals("uint8")) {
Expand Down Expand Up @@ -113,7 +101,7 @@ else if (array.getOriginalDataType().equals("int64")) {
* @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
* not compatible
*/
public static Tensor<TUint8> buildUByte(SharedMemoryArray tensor)
public static Tensor<UInt8> buildUByte(SharedMemoryArray tensor)
throws IllegalArgumentException
{
long[] ogShape = tensor.getOriginalShape();
Expand All @@ -123,8 +111,7 @@ public static Tensor<TUint8> buildUByte(SharedMemoryArray tensor)
if (!tensor.isNumpyFormat())
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
ByteBuffer buff = tensor.getDataBufferNoHeader();
ByteDataBuffer dataBuffer = RawDataBufferFactory.create(buff.array(), false);
Tensor<TUint8> ndarray = Tensor.of(TUint8.DTYPE, Shape.of(ogShape), dataBuffer);
Tensor<UInt8> ndarray = Tensor.create(UInt8.class, ogShape, buff);
return ndarray;
}

Expand All @@ -138,7 +125,7 @@ public static Tensor<TUint8> buildUByte(SharedMemoryArray tensor)
* @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
* not compatible
*/
public static Tensor<TInt32> buildInt(SharedMemoryArray tensor)
public static Tensor<Integer> buildInt(SharedMemoryArray tensor)
throws IllegalArgumentException
{
long[] ogShape = tensor.getOriginalShape();
Expand All @@ -151,8 +138,7 @@ public static Tensor<TInt32> buildInt(SharedMemoryArray tensor)
IntBuffer intBuff = buff.asIntBuffer();
int[] intArray = new int[intBuff.capacity()];
intBuff.get(intArray);
IntDataBuffer dataBuffer = RawDataBufferFactory.create(intArray, false);
Tensor<TInt32> ndarray = TInt32.tensorOf(Shape.of(ogShape), dataBuffer);
Tensor<Integer> ndarray = Tensor.create(ogShape, intBuff);
return ndarray;
}

Expand All @@ -166,7 +152,7 @@ public static Tensor<TInt32> buildInt(SharedMemoryArray tensor)
* @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
* not compatible
*/
private static Tensor<TInt64> buildLong(SharedMemoryArray tensor)
private static Tensor<Long> buildLong(SharedMemoryArray tensor)
throws IllegalArgumentException
{
long[] ogShape = tensor.getOriginalShape();
Expand All @@ -179,8 +165,7 @@ private static Tensor<TInt64> buildLong(SharedMemoryArray tensor)
LongBuffer longBuff = buff.asLongBuffer();
long[] longArray = new long[longBuff.capacity()];
longBuff.get(longArray);
LongDataBuffer dataBuffer = RawDataBufferFactory.create(longArray, false);
Tensor<TInt64> ndarray = TInt64.tensorOf(Shape.of(ogShape), dataBuffer);
Tensor<Long> ndarray = Tensor.create(ogShape, longBuff);
return ndarray;
}

Expand All @@ -194,7 +179,7 @@ private static Tensor<TInt64> buildLong(SharedMemoryArray tensor)
* @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
* not compatible
*/
public static Tensor<TFloat32> buildFloat(SharedMemoryArray tensor)
public static Tensor<Float> buildFloat(SharedMemoryArray tensor)
throws IllegalArgumentException
{
long[] ogShape = tensor.getOriginalShape();
Expand All @@ -207,8 +192,7 @@ public static Tensor<TFloat32> buildFloat(SharedMemoryArray tensor)
FloatBuffer floatBuff = buff.asFloatBuffer();
float[] floatArray = new float[floatBuff.capacity()];
floatBuff.get(floatArray);
FloatDataBuffer dataBuffer = RawDataBufferFactory.create(floatArray, false);
Tensor<TFloat32> ndarray = TFloat32.tensorOf(Shape.of(ogShape), dataBuffer);
Tensor<Float> ndarray = Tensor.create(ogShape, floatBuff);
return ndarray;
}

Expand All @@ -222,7 +206,7 @@ public static Tensor<TFloat32> buildFloat(SharedMemoryArray tensor)
* @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
* not compatible
*/
private static Tensor<TFloat64> buildDouble(SharedMemoryArray tensor)
private static Tensor<Double> buildDouble(SharedMemoryArray tensor)
throws IllegalArgumentException
{
long[] ogShape = tensor.getOriginalShape();
Expand All @@ -235,8 +219,7 @@ private static Tensor<TFloat64> buildDouble(SharedMemoryArray tensor)
DoubleBuffer doubleBuff = buff.asDoubleBuffer();
double[] doubleArray = new double[doubleBuff.capacity()];
doubleBuff.get(doubleArray);
DoubleDataBuffer dataBuffer = RawDataBufferFactory.create(doubleArray, false);
Tensor<TFloat64> ndarray = TFloat64.tensorOf(Shape.of(ogShape), dataBuffer);
Tensor<Double> ndarray = Tensor.create(ogShape, doubleBuff);
return ndarray;
}
}

0 comments on commit 4eb89ff

Please sign in to comment.