diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala index 607a0c782619..57eca2140664 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala @@ -106,7 +106,10 @@ object IO { checkCall(_LIB.mxDataIterCreateIter(handle, keys, vals, out)) val dataName = params.getOrElse("data_name", "data") val labelName = params.getOrElse("label_name", "label") - new MXDataIter(out.value, dataName, labelName) + val dataLayout = params.getOrElse("dataLayout", "NCHW") + val labelLayout = params.getOrElse("labelLayout", "N") + new MXDataIter(out.value, dataName, labelName, + dataLayout = dataLayout, labelLayout = labelLayout) } // Convert data into canonical form. @@ -142,7 +145,8 @@ class DataBatch(val data: IndexedSeq[NDArray], private val providedData: ListMap[String, Shape] = null, private val providedLabel: ListMap[String, Shape] = null, val dtype: DType = Base.MX_REAL_TYPE, - val layout: String = "NCHW") { + val dataLayout: String = "NCHW", + val labelLayout: String = "N") { /** * Dispose its data and labels * The object shall never be used after it is disposed. @@ -172,7 +176,8 @@ object DataBatch { private var label: IndexedSeq[NDArray] = null private var index: IndexedSeq[Long] = null private var pad: Int = 0 - private var layout: String = "NCHW" + private var dataLayout: String = "NCHW" + private var labelLayout: String = "N" private var dtype: DType = Base.MX_REAL_TYPE private var bucketKey: AnyRef = null private var datatShapes: ListMap[String, Shape] = null @@ -232,11 +237,13 @@ object DataBatch { /** * Set the layout. - * @param layout The layout of the label, default is NCHW + * @param dataLayout The layout of the data, default is NCHW + * @param labelLayout The layout of the label, default is N * @return this */ - def setLayout(layout: String): Builder = { - this.layout = layout + def setLayout(dataLayout: String, labelLayout: String): Builder = { + this.dataLayout = dataLayout + this.labelLayout = labelLayout this } @@ -282,7 +289,8 @@ object DataBatch { def build(): DataBatch = { require(data != null, "data is required.") - new DataBatch(data, label, index, pad, bucketKey, datatShapes, labelShapes, dtype, layout) + new DataBatch(data, label, index, pad, bucketKey, datatShapes, labelShapes, + dtype, dataLayout, labelLayout) } } } @@ -305,7 +313,7 @@ abstract class DataIter extends Iterator[DataBatch] { @throws(classOf[NoSuchElementException]) def next(): DataBatch = { new DataBatch(getData(), getLabel(), getIndex(), getPad(), - dtype = getDType(), layout = getLayout()) + dtype = getDType(), dataLayout = getLayout()._1, labelLayout = getLayout()._2) } /** @@ -335,9 +343,9 @@ abstract class DataIter extends Iterator[DataBatch] { /** * Get the layout - * @return layout of the DataIter + * @return data and label layout of the DataIter */ - def getLayout(): String + def getLayout(): (String, String) /** * Get the index of current batch diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala index 4c6a64d99034..5d8fd6d07512 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala @@ -32,7 +32,10 @@ import scala.collection.mutable.ListBuffer */ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, dataName: String = "data", - labelName: String = "label") + labelName: String = "label", + dtype: DType = DType.Float32, + dataLayout: String = "NCHW", + labelLayout: String = "N") extends DataIter with WarnIfNotDisposed { private val logger = LoggerFactory.getLogger(classOf[MXDataIter]) @@ -65,10 +68,11 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, val data = currentBatch.data(0) val label = currentBatch.label(0) val dType = currentBatch.dtype - val layout = currentBatch.layout + val dataLayout = currentBatch.dataLayout + val labelLayout = currentBatch.labelLayout // properties - val res = (IndexedSeq(new DataDesc(dataName, data.shape, dType, layout)), - IndexedSeq(new DataDesc(labelName, label.shape, dType, layout)), + val res = (IndexedSeq(new DataDesc(dataName, data.shape, dType, dataLayout)), + IndexedSeq(new DataDesc(labelName, label.shape, dType, labelLayout)), data.shape(0)) currentBatch.dispose() reset() @@ -126,7 +130,8 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, if (next.value > 0) { currentBatch = new DataBatch(data = getData(), label = getLabel(), index = getIndex(), pad = getPad(), - dtype = currentBatch.dtype, layout = currentBatch.layout) + dtype = getDType(), dataLayout = getLayout()._1, + labelLayout = getLayout()._2) } else { currentBatch = null } @@ -179,17 +184,13 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, * Get the DType * @return DType */ - def getDType(): DType = { - currentBatch.dtype - } + def getDType(): DType = dtype /** * Get the layout * @return layout */ - def getLayout(): String = { - currentBatch.layout - } + def getLayout(): (String, String) = (dataLayout, labelLayout) // The name and shape of data provided by this iterator override def provideData: ListMap[String, Shape] = _provideData diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala index d34390bacdca..53f7c352c636 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala @@ -43,7 +43,8 @@ import scala.collection.immutable.ListMap class NDArrayIter(data: IndexedSeq[(String, NDArray)], label: IndexedSeq[(String, NDArray)], private val dataBatchSize: Int, shuffle: Boolean, - lastBatchHandle: String, dtype: DType, layout: String) extends DataIter { + lastBatchHandle: String, + dtype: DType, dataLayout: String, labelLayout: String) extends DataIter { /** * @param data Specify the data. Data names will be data_0, data_1, ..., etc. @@ -61,10 +62,11 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], dataBatchSize: Int = 1, shuffle: Boolean = false, lastBatchHandle: String = "pad", dataName: String = "data", labelName: String = "label", - dType: DType = MX_REAL_TYPE, layout: String = "NCHW") { + dType: DType = MX_REAL_TYPE, dataLayout: String = "NCHW", + labelLayout: String = "N") { this(IO.initData(data, allowEmpty = false, dataName), IO.initData(label, allowEmpty = true, labelName), - dataBatchSize, shuffle, lastBatchHandle, dType, layout) + dataBatchSize, shuffle, lastBatchHandle, dType, dataLayout, labelLayout) } private val logger = LoggerFactory.getLogger(classOf[NDArrayIter]) @@ -111,8 +113,8 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], private val (_provideDataDesc: IndexedSeq[DataDesc], _provideLabelDesc: IndexedSeq[DataDesc]) = { - val pData = initData.map(ele => new DataDesc(ele._1, getShape(ele)._2, dtype, layout)) - val pLabel = initLabel.map(ele => new DataDesc(ele._1, getShape(ele)._2, dtype, layout)) + val pData = initData.map(ele => new DataDesc(ele._1, getShape(ele)._2, dtype, dataLayout)) + val pLabel = initLabel.map(ele => new DataDesc(ele._1, getShape(ele)._2, dtype, labelLayout)) (pData, pLabel) } @@ -158,7 +160,7 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], if (hasNext) { cursor += dataBatchSize new DataBatch(getData(), getLabel(), getIndex(), getPad(), - dtype = getDType(), layout = getLayout()) + dtype = getDType(), dataLayout = getLayout()._1, labelLayout = getLayout()._2) } else { throw new NoSuchElementException } @@ -245,8 +247,8 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], * Get the layout * @return layout */ - def getLayout(): String = { - layout + def getLayout(): (String, String) = { + (dataLayout, labelLayout) } // The name and shape of data provided by this iterator @@ -274,7 +276,8 @@ object NDArrayIter { private var label: IndexedSeq[(String, NDArray)] = IndexedSeq.empty private var dataBatchSize: Int = 1 private var lastBatchHandle: String = "pad" - private var layout: String = "NCHW" + private var dataLayout: String = "NCHW" + private var labelLayout: String = "N" private var dtype: DType = Base.MX_REAL_TYPE /** @@ -331,11 +334,13 @@ object NDArrayIter { /** * Set the layout. - * @param layout The layout of the label, default is NCHW + * @param dataLayout The layout of the data, default is NCHW + * @param labelLayout The layout of the label, default is N * @return this */ - def setLayout(layout: String): Builder = { - this.layout = layout + def setLayout(dataLayout: String, labelLayout: String): Builder = { + this.dataLayout = dataLayout + this.labelLayout = labelLayout this } @@ -344,7 +349,8 @@ object NDArrayIter { * @return the built object. */ def build(): NDArrayIter = { - new NDArrayIter(data, label, dataBatchSize, false, lastBatchHandle, dtype, layout) + new NDArrayIter(data, label, dataBatchSize, false, lastBatchHandle, + dtype, dataLayout, labelLayout) } } } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala index f8f589f5faa5..3097948b5de9 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala @@ -189,8 +189,8 @@ class PrefetchingIter( * Get the layout * @return layout */ - def getLayout(): String = { - currentBatch.layout + def getLayout(): (String, String) = { + (currentBatch.dataLayout, currentBatch.labelLayout) } // The name and shape of label provided by this iterator @@ -224,7 +224,8 @@ class PrefetchingIter( labels.toIndexedSeq.flatten, nextBatch(0).index, nextBatch(0).pad, - layout = nextBatch(0).layout, + dataLayout = nextBatch(0).dataLayout, + labelLayout = nextBatch(0).labelLayout, dtype = nextBatch(0).dtype) for (e <- dataTaken) e.release() true diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala index 228ba72c97ed..6e521c219128 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala @@ -141,8 +141,8 @@ class ResizeIter( * Get the layout * @return layout */ - def getLayout(): String = { - currentBatch.layout + def getLayout(): (String, String) = { + (currentBatch.dataLayout, currentBatch.labelLayout) } override def batchSize: Int = { diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala index 1b922b3c05b6..478f834d210b 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala @@ -38,7 +38,9 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { "shuffle" -> "1", "flat" -> "1", "silent" -> "0", - "seed" -> "10" + "seed" -> "10", + "dataLayout" -> "NT", + "labelLayout" -> "N" ) val mnistPack = IO.MNISTPack(params) @@ -99,7 +101,9 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { "data_shape" -> "(3,28,28)", "batch_size" -> "100", "preprocess_threads" -> "4", - "prefetch_buffer" -> "1" + "prefetch_buffer" -> "1", + "dataLayout" -> "NCHW", + "labelLayout" -> "N" ) val imgRecIter = IO.ImageRecordIter(params) val nBatch = 500 @@ -145,7 +149,9 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { "shuffle" -> "1", "flat" -> "1", "silent" -> "0", - "seed" -> "10" + "seed" -> "10", + "dataLayout" -> "NT", + "labelLayout" -> "N" ) val mnistIter = IO.MNISTIter(params) @@ -182,7 +188,9 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { "shuffle" -> "1", "flat" -> "1", "silent" -> "0", - "seed" -> "10" + "seed" -> "10", + "dataLayout" -> "NT", + "labelLayout" -> "N" ) val mnistPack1 = IO.MNISTPack(params) @@ -243,7 +251,8 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { val batchLabel = NDArray.ones(Shape(Array(128, 1))) // test pad - val dataIter0 = new NDArrayIter(data, label, 128, false, "pad") + val dataIter0 = new NDArrayIter(data, label, 128, false, "pad", + dataLayout = "NTC", labelLayout = "NT") var batchCount = 0 val nBatch0 = 8 while(dataIter0.hasNext) { @@ -262,6 +271,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { .addData("data0", data(0)).addData("data1", data(1)) .addLabel("label", label(0)) .setBatchSize(128) + .setLayout("NTC", "NT") .setLastBatchHandle("discard").build() val nBatch1 = 7 batchCount = 0 @@ -277,7 +287,8 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { assert(batchCount === nBatch1) // test empty label (for prediction) - val dataIter2 = new NDArrayIter(data = data, dataBatchSize = 128, lastBatchHandle = "discard") + val dataIter2 = new NDArrayIter(data = data, dataBatchSize = 128, lastBatchHandle = "discard", + dataLayout = "NTC") batchCount = 0 while(dataIter2.hasNext) { val tBatch = dataIter2.next() diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala index 8234568d7d9f..a73195f77207 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala @@ -184,7 +184,8 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { val data = NDArray.array(Array(0.05f, 0.1f), Shape(1, 1, 1, 2)) val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 1, 1, 2)) val trainData = new NDArrayIter( - IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label") + IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label", + dataLayout = "NCHW", labelLayout = "NCHW") // symbols var x = Symbol.Variable("data") @@ -234,7 +235,8 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { val data = NDArray.array(Array(0.05f, 0.1f), Shape(1, 1, 1, 2)) val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 1, 1, 2)) val trainData = new NDArrayIter( - IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label") + IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label", + dataLayout = "NCHW", labelLayout = "NCHW") // symbols var x = Symbol.Variable("data") diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala index 51de9164c328..1a10d91ee44b 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala @@ -68,7 +68,8 @@ object ExampleMultiTask { new DataBatch(batch.data, IndexedSeq(label, label), batch.index, - batch.pad, dtype = batch.dtype, layout = batch.layout) + batch.pad, dtype = batch.dtype, dataLayout = batch.dataLayout, + labelLayout = batch.labelLayout) } else { throw new NoSuchElementException } @@ -129,7 +130,7 @@ object ExampleMultiTask { override def getDType(): DType = this.dataIter.getDType() - override def getLayout(): String = this.dataIter.getLayout() + override def getLayout(): (String, String) = this.dataIter.getLayout() // The name and shape of data provided by this iterator override def provideData: ListMap[String, Shape] = this.dataIter.provideData diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala index e37a3265d322..21cc146a1b32 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala @@ -97,7 +97,9 @@ object BucketIo { path: String, vocab: Map[String, Int], var buckets: IndexedSeq[Int], _batchSize: Int, private val initStates: IndexedSeq[(String, (Int, Int))], seperateChar: String = " ", text2Id: Text2Id = defaultText2Id, - readContent: ReadContent = defaultReadContent, layout: String = "NT", + readContent: ReadContent = defaultReadContent, + dataLayout: String = "NT", + labelLayout: String = "N", dtype : DType = DType.Float32) extends DataIter { private val logger = LoggerFactory.getLogger(classOf[BucketSentenceIter]) @@ -173,12 +175,12 @@ object BucketIo { private val _provideDataDesc = { val tmp = IndexedSeq(new DataDesc("data", - Shape(_batchSize, _defaultBucketKey), dtype, layout)) - tmp ++ initStates.map(x => new DataDesc(x._1, Shape(x._2._1, x._2._2), dtype, layout)) + Shape(_batchSize, _defaultBucketKey), dtype, dataLayout)) + tmp ++ initStates.map(x => new DataDesc(x._1, Shape(x._2._1, x._2._2), dtype, dataLayout)) } private val _provideLabelDesc = IndexedSeq(new DataDesc("softmax_label", - Shape(_batchSize, _defaultBucketKey), dtype, layout)) + Shape(_batchSize, _defaultBucketKey), dtype, labelLayout)) private var iBucket = 0 @@ -210,7 +212,8 @@ object BucketIo { getIndex(), getPad(), this.buckets(bucketIdx).asInstanceOf[AnyRef], - batchProvideData, batchProvideLabel, dtype, layout) + batchProvideData, batchProvideLabel, getDType(), + getLayout()._1, getLayout()._2) } /** @@ -250,7 +253,7 @@ object BucketIo { override def getDType(): DType = dtype - override def getLayout(): String = layout + override def getLayout(): (String, String) = (dataLayout, labelLayout) // The name and shape of label provided by this iterator override def provideLabel: ListMap[String, Shape] = this._provideLabel diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala index db491b497b93..7fda57a4e3e7 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala @@ -35,7 +35,8 @@ class LabeledPointIter private[mxnet]( private val dataName: String = "data", private val labelName: String = "label", private val dtype: DType = DType.Float32, - private val layout: String = "NCHW") extends DataIter { + private val dataLayout: String = "NCHW", + private val labelLayout: String = "N") extends DataIter { private val cache: ArrayBuffer[DataBatch] = ArrayBuffer.empty[DataBatch] private var index: Int = -1 @@ -76,7 +77,7 @@ class LabeledPointIter private[mxnet]( val pad = batchSize - instNum val dataBatch = new LongLivingDataBatch( IndexedSeq(dataBuilder), IndexedSeq(labelBuilder), null, pad, - layout, dtype) + dataLayout, labelLayout, dtype) cache += dataBatch dataBatch } @@ -129,11 +130,11 @@ class LabeledPointIter private[mxnet]( } override def provideDataDesc: IndexedSeq[DataDesc] = { - IndexedSeq(new DataDesc(dataName, dataShape, dtype, layout)) + IndexedSeq(new DataDesc(dataName, dataShape, dtype, dataLayout)) } override def provideLabelDesc: IndexedSeq[DataDesc] = { - IndexedSeq(new DataDesc(labelName, Shape(_batchSize), dtype, layout)) + IndexedSeq(new DataDesc(labelName, Shape(_batchSize), dtype, labelLayout)) } /** @@ -145,7 +146,7 @@ class LabeledPointIter private[mxnet]( override def getDType(): DType = dtype - override def getLayout(): String = layout + override def getLayout(): (String, String) = (dataLayout, labelLayout) override def batchSize: Int = _batchSize diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala index d062a81b9bcc..acbcbc7c2b84 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala @@ -29,9 +29,10 @@ class LongLivingDataBatch( override val label: IndexedSeq[NDArray], override val index: IndexedSeq[Long], override val pad: Int, - override val layout: String, + override val dataLayout: String, + override val labelLayout: String, override val dtype: DType) extends DataBatch(data, label, index, pad, - layout = layout, dtype = dtype) { + dataLayout = dataLayout, labelLayout = labelLayout, dtype = dtype) { override def dispose(): Unit = {} def disposeForce(): Unit = super.dispose() } diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala index c0d898c72fc7..d239e5c641fa 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala @@ -35,7 +35,8 @@ class PointIter private[mxnet]( private val dataName: String = "data", private val labelName: String = "label", private val dtype: DType = DType.Float32, - private val layout: String = "NCHW") extends DataIter { + private val dataLayout: String = "NCHW", + private val labelLayout: String = "N") extends DataIter { private val cache: ArrayBuffer[DataBatch] = ArrayBuffer.empty[DataBatch] private var index: Int = -1 @@ -75,7 +76,7 @@ class PointIter private[mxnet]( val pad = batchSize - instNum val dataBatch = new LongLivingDataBatch( IndexedSeq(dataBuilder), IndexedSeq(labelBuilder), null, pad, - layout, dtype) + dataLayout, labelLayout, dtype) cache += dataBatch dataBatch } @@ -128,11 +129,11 @@ class PointIter private[mxnet]( } override def provideDataDesc: IndexedSeq[DataDesc] = { - IndexedSeq(new DataDesc(dataName, dataShape, dtype, layout)) + IndexedSeq(new DataDesc(dataName, dataShape, dtype, dataLayout)) } override def provideLabelDesc: IndexedSeq[DataDesc] = { - IndexedSeq(new DataDesc(labelName, Shape(_batchSize), dtype, layout)) + IndexedSeq(new DataDesc(labelName, Shape(_batchSize), dtype, labelLayout)) } /** @@ -144,7 +145,7 @@ class PointIter private[mxnet]( override def getDType(): DType = dtype - override def getLayout(): String = layout + override def getLayout(): (String, String) = (dataLayout, labelLayout) override def batchSize: Int = _batchSize