diff --git a/modules/io/src/main/java/deepboof/io/torch7/BaseParserTorch7.java b/modules/io/src/main/java/deepboof/io/torch7/BaseParserTorch7.java index 40aa790..15f98ac 100644 --- a/modules/io/src/main/java/deepboof/io/torch7/BaseParserTorch7.java +++ b/modules/io/src/main/java/deepboof/io/torch7/BaseParserTorch7.java @@ -217,9 +217,7 @@ private TorchStorage parseStorage( String name ) throws IOException { case "torch.DoubleStorage":{ TorchDoubleStorage t = new TorchDoubleStorage(size); - for (int i = 0; i < size; i++) { - t.data[i] = readDouble(); - } + readArrayDouble(size,t.data); out = t; }break; diff --git a/modules/io/src/main/java/deepboof/io/torch7/ParseAsciiTorch7.java b/modules/io/src/main/java/deepboof/io/torch7/ParseAsciiTorch7.java index 6ce3418..511e3f6 100644 --- a/modules/io/src/main/java/deepboof/io/torch7/ParseAsciiTorch7.java +++ b/modules/io/src/main/java/deepboof/io/torch7/ParseAsciiTorch7.java @@ -89,10 +89,14 @@ public int readU8() throws IOException { @Override public void readArrayDouble(int size, double[] storage) throws IOException { + String line = readInnerString(); + String words[] = line.split(" "); + if( words.length != size ) + throw new IOException("Unexpected number of words "+size+" found "+words.length); for (int i = 0; i < size; i++) { - storage[i] = readDouble(); + storage[i] = Double.parseDouble(words[i]); } - input.readByte(); +// int foo = input.readByte(); } @Override diff --git a/modules/io/src/main/java/deepboof/io/torch7/struct/TorchGeneric.java b/modules/io/src/main/java/deepboof/io/torch7/struct/TorchGeneric.java index bcb13e8..0bb7c8b 100644 --- a/modules/io/src/main/java/deepboof/io/torch7/struct/TorchGeneric.java +++ b/modules/io/src/main/java/deepboof/io/torch7/struct/TorchGeneric.java @@ -20,6 +20,7 @@ import deepboof.io.torch7.ConvertTorchToBoofForward; import deepboof.tensors.Tensor_F32; +import deepboof.tensors.Tensor_F64; import deepboof.tensors.Tensor_U8; import java.util.HashMap; @@ -50,4 +51,8 @@ public Tensor_U8 getTensorU8(String key ) { public Tensor_F32 getTensorF32(String key ) { return (Tensor_F32)ConvertTorchToBoofForward.convert(map.get(key)); } + + public Tensor_F64 getTensorF64(String key ) { + return (Tensor_F64)ConvertTorchToBoofForward.convert(map.get(key)); + } }