# initial fix for RNN
lanking520 committed Jul 13, 2018
Commit 62bf495 (1 parent: 5b4d528)
Showing 5 changed files with 114 additions and 74 deletions.
## scala-package/core/src/main/scala/org/apache/mxnet/IO.scala (4 additions, 1 deletion)

```diff
@@ -356,6 +356,9 @@ object DataDesc {
   }

   implicit def ListMap2Descs(shapes: ListMap[String, Shape]): IndexedSeq[DataDesc] = {
-    shapes.map { case (k, s) => new DataDesc(k, s) }.toIndexedSeq
+    if (shapes.toIndexedSeq(0)._2.length == 2) {
+      shapes.map { case (k, s) => new DataDesc(k, s, layout = "NT") }.toIndexedSeq
+    }
+    else shapes.map { case (k, s) => new DataDesc(k, s) }.toIndexedSeq
   }
 }
```
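The new branch treats any two-dimensional shape in the `ListMap` as (batch, time) and tags it with the "NT" layout; other ranks keep the default. A minimal sketch of the implicit at work, with illustrative shape values:

```scala
import scala.collection.immutable.ListMap
import org.apache.mxnet.{DataDesc, Shape}
import org.apache.mxnet.DataDesc._ // brings the ListMap2Descs implicit into scope

// A 2-D shape such as (batchSize, seqLen) is now described with layout "NT".
val descs: IndexedSeq[DataDesc] = ListMap("data" -> Shape(32, 60))
println(descs.head.layout) // expected: NT
```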
## org/apache/mxnetexamples/rnn/BucketIo.scala

```diff
@@ -26,9 +26,6 @@ import scala.io.Source
 import scala.util.Random
 import scala.collection.mutable

-/**
- * @author Depeng Liang
- */
 object BucketIo {

   type Text2Id = (String, Map[String, Int]) => Array[Int]
@@ -57,7 +54,7 @@ object BucketIo {
       val tmp = sentence.split(" ").filter(_.length() > 0)
       for (w <- tmp) yield theVocab(w)
     }
-    words.toArray
+    words
  }

  def defaultGenBuckets(sentences: Array[String], batchSize: Int,
@@ -160,8 +157,6 @@
       labelBuffer.append(NDArray.zeros(_batchSize, buckets(iBucket)))
     }

-    private val initStateArrays = initStates.map(x => NDArray.zeros(x._2._1, x._2._2))
-
     private val _provideData = { val tmp = ListMap("data" -> Shape(_batchSize, _defaultBucketKey))
       tmp ++ initStates.map(x => x._1 -> Shape(x._2._1, x._2._2))
     }
@@ -192,12 +187,13 @@
       tmp ++ initStates.map(x => x._1 -> Shape(x._2._1, x._2._2))
     }
     val batchProvideLabel = ListMap("softmax_label" -> labelBuf.shape)
-    new DataBatch(IndexedSeq(dataBuf) ++ initStateArrays,
-      IndexedSeq(labelBuf),
-      getIndex(),
-      getPad(),
-      this.buckets(bucketIdx).asInstanceOf[AnyRef],
-      batchProvideData, batchProvideLabel)
+    val initStateArrays = initStates.map(x => NDArray.zeros(x._2._1, x._2._2))
+    new DataBatch(IndexedSeq(dataBuf.copy()) ++ initStateArrays,
+      IndexedSeq(labelBuf.copy()),
+      getIndex(),
+      getPad(),
+      this.buckets(bucketIdx).asInstanceOf[AnyRef],
+      batchProvideData, batchProvideLabel)
   }

   /**
```
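The substantive change is in the final hunk: the init-state NDArrays are now allocated per batch, and `dataBuf`/`labelBuf` are passed to the `DataBatch` as copies, so each batch owns its arrays independently of the iterator's reusable buffers. A minimal sketch of the ownership pattern this enables, with illustrative shapes:

```scala
import org.apache.mxnet.NDArray

// Iterator-owned buffer, overwritten on every next() call.
val buffer = NDArray.zeros(2, 3)

// Batch-owned copy handed to the consumer, as the new code does.
val batchData = buffer.copy()

// The consumer can free its batch without corrupting the iterator's state.
batchData.dispose()
buffer.set(1f) // the iterator keeps reusing its own buffer safely
```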
## org/apache/mxnetexamples/rnn/Lstm.scala

```diff
@@ -18,13 +18,10 @@

 package org.apache.mxnetexamples.rnn

-import org.apache.mxnet.Symbol
+import org.apache.mxnet.{Shape, Symbol}

 import scala.collection.mutable.ArrayBuffer

-/**
- * @author Depeng Liang
- */
 object Lstm {

   final case class LSTMState(c: Symbol, h: Symbol)
@@ -35,27 +32,22 @@ object Lstm {
   def lstm(numHidden: Int, inData: Symbol, prevState: LSTMState,
            param: LSTMParam, seqIdx: Int, layerIdx: Int, dropout: Float = 0f): LSTMState = {
     val inDataa = {
-      if (dropout > 0f) Symbol.Dropout()()(Map("data" -> inData, "p" -> dropout))
+      if (dropout > 0f) Symbol.api.Dropout(data = Some(inData), p = Some(dropout))
       else inData
     }
-    val i2h = Symbol.FullyConnected(s"t${seqIdx}_l${layerIdx}_i2h")()(Map("data" -> inDataa,
-      "weight" -> param.i2hWeight,
-      "bias" -> param.i2hBias,
-      "num_hidden" -> numHidden * 4))
-    val h2h = Symbol.FullyConnected(s"t${seqIdx}_l${layerIdx}_h2h")()(Map("data" -> prevState.h,
-      "weight" -> param.h2hWeight,
-      "bias" -> param.h2hBias,
-      "num_hidden" -> numHidden * 4))
+    val i2h = Symbol.api.FullyConnected(data = Some(inDataa), weight = Some(param.i2hWeight),
+      bias = Some(param.i2hBias), num_hidden = numHidden * 4, name = s"t${seqIdx}_l${layerIdx}_i2h")
+    val h2h = Symbol.api.FullyConnected(data = Some(prevState.h), weight = Some(param.h2hWeight),
+      bias = Some(param.h2hBias), num_hidden = numHidden * 4, name = s"t${seqIdx}_l${layerIdx}_h2h")
     val gates = i2h + h2h
-    val sliceGates = Symbol.SliceChannel(s"t${seqIdx}_l${layerIdx}_slice")(
-      gates)(Map("num_outputs" -> 4))
-    val ingate = Symbol.Activation()()(Map("data" -> sliceGates.get(0), "act_type" -> "sigmoid"))
-    val inTransform = Symbol.Activation()()(Map("data" -> sliceGates.get(1), "act_type" -> "tanh"))
-    val forgetGate = Symbol.Activation()()(
-      Map("data" -> sliceGates.get(2), "act_type" -> "sigmoid"))
-    val outGate = Symbol.Activation()()(Map("data" -> sliceGates.get(3), "act_type" -> "sigmoid"))
+    val sliceGates = Symbol.api.SliceChannel(data = Some(gates), num_outputs = 4,
+      name = s"t${seqIdx}_l${layerIdx}_slice")
+    val ingate = Symbol.api.Activation(data = Some(sliceGates.get(0)), act_type = "sigmoid")
+    val inTransform = Symbol.api.Activation(data = Some(sliceGates.get(1)), act_type = "tanh")
+    val forgetGate = Symbol.api.Activation(data = Some(sliceGates.get(2)), act_type = "sigmoid")
+    val outGate = Symbol.api.Activation(data = Some(sliceGates.get(3)), act_type = "sigmoid")
     val nextC = (forgetGate * prevState.c) + (ingate * inTransform)
-    val nextH = outGate * Symbol.Activation()()(Map("data" -> nextC, "act_type" -> "tanh"))
+    val nextH = outGate * Symbol.api.Activation(data = Some(nextC), act_type = "tanh")
     LSTMState(c = nextC, h = nextH)
   }
```
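Throughout this file the untyped `Symbol.<Op>(name)()(Map(...))` calls are replaced with the generated `Symbol.api.<op>` functions, whose named `Option` parameters are checked by the compiler. A minimal before/after sketch of one such call; the variable names are illustrative, the two signatures are the ones used in the diff:

```scala
import org.apache.mxnet.Symbol

val x = Symbol.Variable("x")
val w = Symbol.Variable("w")
val b = Symbol.Variable("b")

// Old style: arguments travel through a Map[String, Any], so a misspelled
// key or a wrongly typed value only fails when the graph is bound.
val fcOld = Symbol.FullyConnected("fc")()(Map(
  "data" -> x, "weight" -> w, "bias" -> b, "num_hidden" -> 128))

// New type-safe style: missing or mistyped arguments fail at compile time.
val fcNew = Symbol.api.FullyConnected(data = Some(x), weight = Some(w),
  bias = Some(b), num_hidden = 128, name = "fc")
```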

```diff
@@ -74,11 +66,11 @@ object Lstm {
     val lastStatesBuf = ArrayBuffer[LSTMState]()
     for (i <- 0 until numLstmLayer) {
       paramCellsBuf.append(LSTMParam(i2hWeight = Symbol.Variable(s"l${i}_i2h_weight"),
-                                     i2hBias = Symbol.Variable(s"l${i}_i2h_bias"),
-                                     h2hWeight = Symbol.Variable(s"l${i}_h2h_weight"),
-                                     h2hBias = Symbol.Variable(s"l${i}_h2h_bias")))
+        i2hBias = Symbol.Variable(s"l${i}_i2h_bias"),
+        h2hWeight = Symbol.Variable(s"l${i}_h2h_weight"),
+        h2hBias = Symbol.Variable(s"l${i}_h2h_bias")))
       lastStatesBuf.append(LSTMState(c = Symbol.Variable(s"l${i}_init_c_beta"),
-                                     h = Symbol.Variable(s"l${i}_init_h_beta")))
+        h = Symbol.Variable(s"l${i}_init_h_beta")))
     }
     val paramCells = paramCellsBuf.toArray
     val lastStates = lastStatesBuf.toArray
@@ -87,10 +79,9 @@ object Lstm {
     // embeding layer
     val data = Symbol.Variable("data")
     var label = Symbol.Variable("softmax_label")
-    val embed = Symbol.Embedding("embed")()(Map("data" -> data, "input_dim" -> inputSize,
-      "weight" -> embedWeight, "output_dim" -> numEmbed))
-    val wordvec = Symbol.SliceChannel()()(
-      Map("data" -> embed, "num_outputs" -> seqLen, "squeeze_axis" -> 1))
+    val embed = Symbol.api.Embedding(data = Some(data), input_dim = inputSize, weight = Some(embedWeight),
+      output_dim = numEmbed, name = "embed")
+    val wordvec = Symbol.api.SliceChannel(data = Some(embed), num_outputs = seqLen, squeeze_axis = Some(true))

     val hiddenAll = ArrayBuffer[Symbol]()
     var dpRatio = 0f
@@ -101,22 +92,23 @@
       for (i <- 0 until numLstmLayer) {
         if (i == 0) dpRatio = 0f else dpRatio = dropout
         val nextState = lstm(numHidden, inData = hidden,
-                             prevState = lastStates(i),
-                             param = paramCells(i),
-                             seqIdx = seqIdx, layerIdx = i, dropout = dpRatio)
+          prevState = lastStates(i),
+          param = paramCells(i),
+          seqIdx = seqIdx, layerIdx = i, dropout = dpRatio)
         hidden = nextState.h
         lastStates(i) = nextState
       }
       // decoder
-      if (dropout > 0f) hidden = Symbol.Dropout()()(Map("data" -> hidden, "p" -> dropout))
+      if (dropout > 0f) hidden = Symbol.api.Dropout(data = Some(hidden), p = Some(dropout))
       hiddenAll.append(hidden)
     }
-    val hiddenConcat = Symbol.Concat()(hiddenAll: _*)(Map("dim" -> 0))
-    val pred = Symbol.FullyConnected("pred")()(Map("data" -> hiddenConcat, "num_hidden" -> numLabel,
-      "weight" -> clsWeight, "bias" -> clsBias))
-    label = Symbol.transpose()(label)()
-    label = Symbol.Reshape()()(Map("data" -> label, "target_shape" -> "(0,)"))
-    val sm = Symbol.SoftmaxOutput("softmax")()(Map("data" -> pred, "label" -> label))
+    val hiddenConcat = Symbol.api.Concat(data = hiddenAll.toArray, num_args = hiddenAll.length,
+      dim = Some(0))
+    val pred = Symbol.api.FullyConnected(data = Some(hiddenConcat), num_hidden = numLabel,
+      weight = Some(clsWeight), bias = Some(clsBias))
+    label = Symbol.api.transpose(data = Some(label))
+    label = Symbol.api.Reshape(data = Some(label), target_shape = Some(Shape(0)))
+    val sm = Symbol.api.SoftmaxOutput(data = Some(pred), label = Some(label), name = "softmax")
     sm
   }
@@ -131,35 +123,35 @@
     var lastStates = Array[LSTMState]()
     for (i <- 0 until numLstmLayer) {
       paramCells = paramCells :+ LSTMParam(i2hWeight = Symbol.Variable(s"l${i}_i2h_weight"),
-                                           i2hBias = Symbol.Variable(s"l${i}_i2h_bias"),
-                                           h2hWeight = Symbol.Variable(s"l${i}_h2h_weight"),
-                                           h2hBias = Symbol.Variable(s"l${i}_h2h_bias"))
+        i2hBias = Symbol.Variable(s"l${i}_i2h_bias"),
+        h2hWeight = Symbol.Variable(s"l${i}_h2h_weight"),
+        h2hBias = Symbol.Variable(s"l${i}_h2h_bias"))
       lastStates = lastStates :+ LSTMState(c = Symbol.Variable(s"l${i}_init_c_beta"),
-                                           h = Symbol.Variable(s"l${i}_init_h_beta"))
+        h = Symbol.Variable(s"l${i}_init_h_beta"))
     }
     assert(lastStates.length == numLstmLayer)

     val data = Symbol.Variable("data")

-    var hidden = Symbol.Embedding("embed")()(Map("data" -> data, "input_dim" -> inputSize,
-      "weight" -> embedWeight, "output_dim" -> numEmbed))
+    var hidden = Symbol.api.Embedding(data = Some(data), input_dim = inputSize, weight = Some(embedWeight),
+      output_dim = numEmbed, name = "embed")

     var dpRatio = 0f
     // stack LSTM
     for (i <- 0 until numLstmLayer) {
       if (i == 0) dpRatio = 0f else dpRatio = dropout
       val nextState = lstm(numHidden, inData = hidden,
-                           prevState = lastStates(i),
-                           param = paramCells(i),
-                           seqIdx = seqIdx, layerIdx = i, dropout = dpRatio)
+        prevState = lastStates(i),
+        param = paramCells(i),
+        seqIdx = seqIdx, layerIdx = i, dropout = dpRatio)
       hidden = nextState.h
       lastStates(i) = nextState
     }
     // decoder
-    if (dropout > 0f) hidden = Symbol.Dropout()()(Map("data" -> hidden, "p" -> dropout))
-    val fc = Symbol.FullyConnected("pred")()(Map("data" -> hidden, "num_hidden" -> numLabel,
-      "weight" -> clsWeight, "bias" -> clsBias))
-    val sm = Symbol.SoftmaxOutput("softmax")()(Map("data" -> fc))
+    if (dropout > 0f) hidden = Symbol.api.Dropout(data = Some(hidden), p = Some(dropout))
+    val fc = Symbol.api.FullyConnected(data = Some(hidden), num_hidden = numLabel, weight = Some(clsWeight),
+      bias = Some(clsBias))
+    val sm = Symbol.api.SoftmaxOutput(data = Some(fc), name = "softmax")
     var output = Array(sm)
     for (state <- lastStates) {
       output = output :+ state.c
```
## org/apache/mxnetexamples/rnn/LstmBucketing.scala

```diff
@@ -20,19 +20,17 @@ package org.apache.mxnetexamples.rnn

 import org.apache.mxnet.Callback.Speedometer
 import org.apache.mxnet._
+import BucketIo.BucketSentenceIter
+import org.apache.mxnet.module.{BucketingModule, FitParams}
 import org.apache.mxnet.optimizer.SGD
 import org.kohsuke.args4j.{CmdLineParser, Option}
 import org.slf4j.{Logger, LoggerFactory}
-import BucketIo.BucketSentenceIter

 import scala.collection.JavaConverters._
-import org.apache.mxnet.module.BucketingModule
-import org.apache.mxnet.module.FitParams

 /**
- * Bucketing LSTM examples
- * @author Yizhi Liu
- */
+ * Bucketing LSTM examples
+ */
 class LstmBucketing {
   @Option(name = "--data-train", usage = "training set")
   private val dataTrain: String = "example/rnn/sherlockholmes.train.txt"
@@ -55,9 +53,11 @@ object LstmBucketing {
     pred.waitToRead()
     val labelArr = label.T.toArray.map(_.toInt)
     var loss = .0
-    (0 until pred.shape(0)).foreach(i =>
-      loss -= Math.log(Math.max(1e-10f, pred.slice(i).toArray(labelArr(i))))
-    )
+    (0 until pred.shape(0)).foreach(i => {
+      val temp = pred.slice(i)
+      loss -= Math.log(Math.max(1e-10f, temp.toArray(labelArr(i))))
+      temp.dispose()
+    })
     Math.exp(loss / labelArr.length).toFloat
   }
```
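`perplexity` is exp of the average per-token cross-entropy. The rewritten loop binds each `pred.slice(i)` to a local and disposes it, releasing native memory eagerly instead of leaking row views until garbage collection. A hedged sketch of how such a function is typically wired in as an evaluation metric; the `fit` call is left commented because its exact arguments depend on the surrounding example:

```scala
import org.apache.mxnet.CustomMetric
import org.apache.mxnet.module.FitParams

// Wrap the (label, pred) => Float function as an EvalMetric.
val metric = new CustomMetric(LstmBucketing.perplexity, name = "perplexity")
val fitParams = new FitParams().setEvalMetric(metric)
// model.fit(trainData, evalData = Some(valData), fitParams = fitParams)
```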
## org/apache/mxnetexamples/rnn/README.md (new file, 49 lines)
# RNN Example for MXNet Scala
This folder contains the following examples, written in the new Scala type-safe API:
- [x] LSTM Bucketing
- [ ] CharRNN Inference (still fixing issues)
- [x] CharRNN Training

These examples are for illustration only and are not tuned for the best accuracy.

## Setup
### Download the source files
`obama.zip` contains the required files for the CharRNN examples, and the `sherlockholmes` files contain the data for LSTM Bucketing. Download them, for example with `wget`:
```bash
wget https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/RNN/obama.zip
wget https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/RNN/sherlockholmes.train.txt
wget https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/RNN/sherlockholmes.valid.txt
```
### Unzip the file
```bash
unzip obama.zip
```
### Argument Configuration
Then define the arguments that you would like to pass to the model:

#### LSTM Bucketing
```bash
--data-train
<path>/sherlockholmes.train.txt
--data-val
<path>/sherlockholmes.valid.txt
--cpus
<num_cpus>
--gpus
<num_gpu>
```
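For reference, a hypothetical launch command; the classpath depends on your local build, and only the flags above are defined by the example (TrainCharRnn and TestCharRnn follow the same pattern with their own flags):
```bash
# Hypothetical invocation; adjust the jar/classes paths to your build output.
java -Xmx4G \
  -cp "scala-package/examples/target/classes:scala-package/assembly/target/*" \
  org.apache.mxnetexamples.rnn.LstmBucketing \
  --data-train ./sherlockholmes.train.txt \
  --data-val ./sherlockholmes.valid.txt \
  --cpus 4
```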
#### TrainCharRnn
```bash
--data-path
<path>/obama.txt
--save-model-path
<path>/
```
#### TestCharRnn
This model does not currently work; the issues are still being fixed.
```bash
--data-path
<path>/obama.txt
--model-prefix
<path>/obama
```
