#1. Fix RNNSharp crash when dense features aren't used
#2. Improve BiRNN learning process
#3. Support training a model without a validation corpus
zhongkaifu committed Feb 15, 2016
1 parent ad9a7cf commit 1dc9759
Showing 12 changed files with 113 additions and 124 deletions.
ConvertCorpus/Program.cs (2 changes: 1 addition & 1 deletion)
@@ -19,7 +19,7 @@ static int ArgPos(string str, string[] args)
{
if (a == args.Length - 1)
{
Logger.WriteLine(Logger.Level.info, "Argument missing for {0}", str);
Logger.WriteLine("Argument missing for {0}", str);
return -1;
}
return a;
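Most of this commit's smaller edits repeat the pattern in the hunk above: the explicit Logger.Level.info argument is dropped from Logger.WriteLine calls throughout. That implies a WriteLine overload defaulting to the info level. A minimal sketch of such an overload, assuming names consistent with the call sites (the Logger class itself is not part of this diff):

    // Hypothetical convenience overload; assumes Level.info is the intended default.
    public static void WriteLine(string format, params object[] args)
    {
        WriteLine(Level.info, format, args);
    }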
RNNSharp/BiRNN.cs (33 changes: 12 additions & 21 deletions)
@@ -177,11 +177,6 @@ public override int DenseFeatureSize
}
}

- public override void SetHiddenLayer(SimpleCell[] cells)
- {
-     throw new NotImplementedException("SetHiddenLayer is not implemented in BiRNN");
- }

public override SimpleCell[] GetHiddenLayer()
{
throw new NotImplementedException("GetHiddenLayer is not implemented in BiRNN");
@@ -204,7 +199,7 @@ public override void initMem()
}
}

- public SimpleCell[][] InnerDecode(Sequence pSequence, out SimpleCell[][] outputHiddenLayer, out Matrix<double> rawOutputLayer, out SimpleCell[][] forwardHidden, out SimpleCell[][] backwardHidden)
+ public SimpleCell[][] InnerDecode(Sequence pSequence, out SimpleCell[][] outputHiddenLayer, out Matrix<double> rawOutputLayer)
{
int numStates = pSequence.States.Length;
SimpleCell[][] mForward = null;
@@ -275,8 +270,6 @@ public SimpleCell[][] InnerDecode(Sequence pSequence, out SimpleCell[][] outputH

outputHiddenLayer = mergedHiddenLayer;
rawOutputLayer = tmp_rawOutputLayer;
- forwardHidden = mForward;
- backwardHidden = mBackward;

return seqOutput;
}
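Note: the forwardHidden and backwardHidden out-parameters were removed because, as the LearnTwoRNN changes below show, the learning pass now recomputes each direction's hidden activations with computeHiddenLayer instead of restoring cached cells, so the per-timestep hidden-state snapshots no longer need to leave this method.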
@@ -288,9 +281,7 @@ public override int[] PredictSentenceCRF(Sequence pSequence, RunningMode running
//Predict output
SimpleCell[][] mergedHiddenLayer = null;
Matrix<double> rawOutputLayer = null;
- SimpleCell[][] forwardHidden = null;
- SimpleCell[][] backwardHidden = null;
- SimpleCell[][] seqOutput = InnerDecode(pSequence, out mergedHiddenLayer, out rawOutputLayer, out forwardHidden, out backwardHidden);
+ SimpleCell[][] seqOutput = InnerDecode(pSequence, out mergedHiddenLayer, out rawOutputLayer);

ForwardBackward(numStates, rawOutputLayer);

@@ -324,7 +315,7 @@ public override int[] PredictSentenceCRF(Sequence pSequence, RunningMode running
layer[label].er = 1 - CRFOutputLayer[label];
}

- LearnTwoRNN(pSequence, mergedHiddenLayer, seqOutput, forwardHidden, backwardHidden);
+ LearnTwoRNN(pSequence, mergedHiddenLayer, seqOutput);
}

return predict;
@@ -338,9 +329,7 @@ public override Matrix<double> PredictSentence(Sequence pSequence, RunningMode r
//Predict output
SimpleCell[][] mergedHiddenLayer = null;
Matrix<double> rawOutputLayer = null;
- SimpleCell[][] forwardHidden = null;
- SimpleCell[][] backwardHidden = null;
- SimpleCell[][] seqOutput = InnerDecode(pSequence, out mergedHiddenLayer, out rawOutputLayer, out forwardHidden, out backwardHidden);
+ SimpleCell[][] seqOutput = InnerDecode(pSequence, out mergedHiddenLayer, out rawOutputLayer);

if (runningMode != RunningMode.Test)
{
@@ -367,13 +356,13 @@ public override Matrix<double> PredictSentence(Sequence pSequence, RunningMode r
layer[label].er = 1.0 - layer[label].cellOutput;
}

- LearnTwoRNN(pSequence, mergedHiddenLayer, seqOutput, forwardHidden, backwardHidden);
+ LearnTwoRNN(pSequence, mergedHiddenLayer, seqOutput);
}

return rawOutputLayer;
}

- private void LearnTwoRNN(Sequence pSequence, SimpleCell[][] mergedHiddenLayer, SimpleCell[][] seqOutput, SimpleCell[][] forwardHidden, SimpleCell[][] backwardHidden)
+ private void LearnTwoRNN(Sequence pSequence, SimpleCell[][] mergedHiddenLayer, SimpleCell[][] seqOutput)
{
int numStates = pSequence.States.Length;

@@ -418,10 +407,11 @@ private void LearnTwoRNN(Sequence pSequence, SimpleCell[][] mergedHiddenLayer, S
State state = pSequence.States[curState];

forwardRNN.setInputLayer(state, curState, numStates, null);
- forwardRNN.SetHiddenLayer(forwardHidden[curState]);

forwardRNN.computeHiddenLayer(state, true);

+ //Copy output result to forward network's output
+ forwardRNN.OutputLayer = seqOutput[curState];

forwardRNN.ComputeHiddenLayerErr();

forwardRNN.learnNet(state);
@@ -439,10 +429,11 @@ private void LearnTwoRNN(Sequence pSequence, SimpleCell[][] mergedHiddenLayer, S
State state2 = pSequence.States[curState2];

backwardRNN.setInputLayer(state2, curState2, numStates, null, false);
- backwardRNN.SetHiddenLayer(backwardHidden[curState2]);

backwardRNN.computeHiddenLayer(state2, true);

+ //Copy output result to backward network's output
+ backwardRNN.OutputLayer = seqOutput[curState2];

backwardRNN.ComputeHiddenLayerErr();

backwardRNN.learnNet(state2);
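Both loops in the new LearnTwoRNN follow the same four-step update per time step: rebuild the input layer, recompute the hidden layer, point OutputLayer at the output errors already computed during decoding, then backpropagate. A condensed sketch of the shared pattern; the unified helper and its direction flag are illustrative (the forward call site above omits the final setInputLayer argument):

    // Illustrative only: recompute the hidden state rather than restore it,
    // and reuse the per-timestep output errors cached in seqOutput.
    void LearnOneDirection(RNN net, Sequence seq, SimpleCell[][] seqOutput, bool forward)
    {
        int numStates = seq.States.Length;
        for (int i = 0; i < numStates; i++)
        {
            int curState = forward ? i : numStates - 1 - i;
            State state = seq.States[curState];

            net.setInputLayer(state, curState, numStates, null, forward);
            net.computeHiddenLayer(state, true);   // recompute instead of SetHiddenLayer(...)
            net.OutputLayer = seqOutput[curState]; // errors were set during decoding
            net.ComputeHiddenLayerErr();
            net.learnNet(state);
        }
    }

Recomputing trades a second forward pass for not keeping both directions' hidden cells alive per time step, which is also why SetHiddenLayer could be deleted from BiRNN, LSTMRNN, and the RNN base class.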
RNNSharp/Featurizer.cs (16 changes: 8 additions & 8 deletions)
@@ -62,25 +62,25 @@ public void LoadFeatureConfigFromFile(string strFileName)
string strValue = kv[1].Trim().ToLower();
if (strKey == WORDEMBEDDING_FILENAME)
{
Logger.WriteLine(Logger.Level.info, "Loading word embedding feature set...");
Logger.WriteLine("Loading word embedding feature set...");
m_WordEmbedding = new WordEMWrapFeaturizer(strValue);
continue;
}
else if (strKey == TFEATURE_FILENAME)
{
Logger.WriteLine(Logger.Level.info, "Loading template feature set...");
Logger.WriteLine("Loading template feature set...");
m_TFeaturizer = new TemplateFeaturizer(strValue);
continue;
}
else if (strKey == WORDEMBEDDING_COLUMN)
{
m_WordEmbeddingCloumn = int.Parse(strValue);
Logger.WriteLine(Logger.Level.info, "Word embedding feature column: {0}", m_WordEmbeddingCloumn);
Logger.WriteLine("Word embedding feature column: {0}", m_WordEmbeddingCloumn);
continue;
}
else if (strKey == TFEATURE_WEIGHT_TYPE)
{
Logger.WriteLine(Logger.Level.info, "TFeature weighting type: {0}", strValue);
Logger.WriteLine("TFeature weighting type: {0}", strValue);
if (strValue == "binary")
{
m_TFeatureWeightType = TFEATURE_WEIGHT_TYPE_ENUM.BINARY;
@@ -160,16 +160,16 @@ public void ShowFeatureSize()
var fc = m_FeatureConfiguration;

if (m_TFeaturizer != null)
Logger.WriteLine(Logger.Level.info, "Template feature size: {0}", m_TFeaturizer.GetFeatureSize());
Logger.WriteLine("Template feature size: {0}", m_TFeaturizer.GetFeatureSize());

if (fc.ContainsKey(TFEATURE_CONTEXT) == true)
Logger.WriteLine(Logger.Level.info, "Template feature context size: {0}", m_TFeaturizer.GetFeatureSize() * fc[TFEATURE_CONTEXT].Count);
Logger.WriteLine("Template feature context size: {0}", m_TFeaturizer.GetFeatureSize() * fc[TFEATURE_CONTEXT].Count);

if (fc.ContainsKey(RT_FEATURE_CONTEXT) == true)
Logger.WriteLine(Logger.Level.info, "Run time feature size: {0}", TagSet.GetSize() * fc[RT_FEATURE_CONTEXT].Count);
Logger.WriteLine("Run time feature size: {0}", TagSet.GetSize() * fc[RT_FEATURE_CONTEXT].Count);

if (fc.ContainsKey(WORDEMBEDDING_CONTEXT) == true)
Logger.WriteLine(Logger.Level.info, "Word embedding feature size: {0}", m_WordEmbedding.GetDimension() * fc[WORDEMBEDDING_CONTEXT].Count);
Logger.WriteLine("Word embedding feature size: {0}", m_WordEmbedding.GetDimension() * fc[WORDEMBEDDING_CONTEXT].Count);
}

void ExtractSparseFeature(int currentState, int numStates, List<string[]> features, State pState)
RNNSharp/LSTMRNN.cs (16 changes: 5 additions & 11 deletions)
@@ -71,11 +71,6 @@ public LSTMRNN()
ModelType = MODELTYPE.LSTM;
}

- public override void SetHiddenLayer(SimpleCell[] cells)
- {
-     neuHidden = (LSTMCell[])cells;
- }

public override SimpleCell[] GetHiddenLayer()
{
LSTMCell[] m = new LSTMCell[L1];
@@ -244,7 +239,7 @@ private void saveLSTMWeight(LSTMWeight[][] weight, BinaryWriter fo)

public override void loadNetBin(string filename)
{
Logger.WriteLine(Logger.Level.info, "Loading LSTM-RNN model: {0}", filename);
Logger.WriteLine("Loading LSTM-RNN model: {0}", filename);

StreamReader sr = new StreamReader(filename);
BinaryReader br = new BinaryReader(sr.BaseStream);
@@ -485,7 +480,7 @@ public override void initMem()
}
}

Logger.WriteLine(Logger.Level.info, "[TRACE] Initializing weights, random value is {0}", rand.NextDouble());// yy debug
Logger.WriteLine("[TRACE] Initializing weights, random value is {0}", rand.NextDouble());// yy debug
initWeights();
}

@@ -549,7 +544,6 @@ public override void ComputeHiddenLayerErr()
{
cell.er += OutputLayer[k].er * Hidden2OutputWeight[k][i];
}
- //cell.er = NormalizeErr(cell.er);
}
});
}
@@ -562,7 +556,7 @@ public override void LearnOutputWeight()
double cellOutput = neuHidden[i].cellOutput;
for (int k = 0; k < L2; k++)
{
- Hidden2OutputWeight[k][i] += LearningRate * NormalizeErr(cellOutput * OutputLayer[k].er);
+ Hidden2OutputWeight[k][i] += LearningRate * NormalizeGradient(cellOutput * OutputLayer[k].er);
}
});
}
@@ -579,10 +573,10 @@ public override void learnNet(State state)
LSTMCell c = neuHidden[i];

//using the error find the gradient of the output gate
- var gradientOutputGate = LearningRate * NormalizeErr(SigmoidDerivative(c.netOut) * c.cellState * c.er);
+ var gradientOutputGate = LearningRate * NormalizeGradient(SigmoidDerivative(c.netOut) * c.cellState * c.er);

//internal cell state error
- var cellStateError = LearningRate * NormalizeErr(c.yOut * c.er);
+ var cellStateError = LearningRate * NormalizeGradient(c.yOut * c.er);

LSTMWeight[] w_i = input2hidden[i];
LSTMWeightDerivative[] wd_i = input2hiddenDeri[i];
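For reference, the two gradients above match the standard LSTM chain rule. With output gate activation yOut = sigmoid(netOut) and cell output cellOutput = yOut * cellState (ignoring any output squashing not visible in this hunk), the error c.er at the cell output propagates as sigmoid'(netOut) * cellState * c.er into the output gate and as yOut * c.er into the cell state, which is exactly what the two expressions compute before scaling by the learning rate and clipping with NormalizeGradient.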
RNNSharp/ModelSetting.cs (24 changes: 12 additions & 12 deletions)
@@ -20,34 +20,34 @@ public class ModelSetting

public void DumpSetting()
{
Logger.WriteLine(Logger.Level.info, "Model File: {0}", ModelFile);
Logger.WriteLine("Model File: {0}", ModelFile);
if (ModelType == 0)
{
Logger.WriteLine(Logger.Level.info, "Model Structure: Simple RNN");
Logger.WriteLine(Logger.Level.info, "BPTT: {0}", Bptt);
Logger.WriteLine("Model Structure: Simple RNN");
Logger.WriteLine("BPTT: {0}", Bptt);
}
else if (ModelType == 1)
{
Logger.WriteLine(Logger.Level.info, "Model Structure: LSTM-RNN");
Logger.WriteLine("Model Structure: LSTM-RNN");
}

if (ModelDirection == 0)
{
Logger.WriteLine(Logger.Level.info, "RNN Direction: Forward");
Logger.WriteLine("RNN Direction: Forward");
}
else
{
Logger.WriteLine(Logger.Level.info, "RNN Direction: Bi-directional");
Logger.WriteLine("RNN Direction: Bi-directional");
}

Logger.WriteLine(Logger.Level.info, "Learning rate: {0}", LearningRate);
Logger.WriteLine(Logger.Level.info, "Dropout: {0}", Dropout);
Logger.WriteLine(Logger.Level.info, "Max Iteration: {0}", MaxIteration);
Logger.WriteLine(Logger.Level.info, "Hidden layer size: {0}", NumHidden);
Logger.WriteLine(Logger.Level.info, "RNN-CRF: {0}", IsCRFTraining);
Logger.WriteLine("Learning rate: {0}", LearningRate);
Logger.WriteLine("Dropout: {0}", Dropout);
Logger.WriteLine("Max Iteration: {0}", MaxIteration);
Logger.WriteLine("Hidden layer size: {0}", NumHidden);
Logger.WriteLine("RNN-CRF: {0}", IsCRFTraining);
if (SaveStep > 0)
{
Logger.WriteLine(Logger.Level.info, "Save temporary model after every {0} sentences", SaveStep);
Logger.WriteLine("Save temporary model after every {0} sentences", SaveStep);
}
}

RNNSharp/RNN.cs (34 changes: 16 additions & 18 deletions)
@@ -489,7 +489,7 @@ public float RandInitWeight()
public virtual double TrainNet(DataSet trainingSet, int iter)
{
DateTime start = DateTime.Now;
Logger.WriteLine(Logger.Level.info, "[TRACE] Iter " + iter + " begins with learning rate alpha = " + LearningRate + " ...");
Logger.WriteLine("[TRACE] Iter " + iter + " begins with learning rate alpha = " + LearningRate + " ...");

//Initialize variables
logp = 0;
@@ -501,7 +501,7 @@ public virtual double TrainNet(DataSet trainingSet, int iter)
int wordCnt = 0;
int tknErrCnt = 0;
int sentErrCnt = 0;
Logger.WriteLine(Logger.Level.info, "[TRACE] Progress = 0/" + numSequence / 1000.0 + "K\r");
Logger.WriteLine("[TRACE] Progress = 0/" + numSequence / 1000.0 + "K\r");
for (int curSequence = 0; curSequence < numSequence; curSequence++)
{
Sequence pSequence = trainingSet.SequenceList[curSequence];
@@ -528,16 +528,16 @@ public virtual double TrainNet(DataSet trainingSet, int iter)

if ((curSequence + 1) % 1000 == 0)
{
Logger.WriteLine(Logger.Level.info, "[TRACE] Progress = {0} ", (curSequence + 1) / 1000 + "K/" + numSequence / 1000.0 + "K");
Logger.WriteLine(Logger.Level.info, " train cross-entropy = {0} ", -logp / Math.Log10(2.0) / wordCnt);
Logger.WriteLine(Logger.Level.info, " Error token ratio = {0}%", (double)tknErrCnt / (double)wordCnt * 100.0);
Logger.WriteLine(Logger.Level.info, " Error sentence ratio = {0}%", (double)sentErrCnt / (double)curSequence * 100.0);
Logger.WriteLine("[TRACE] Progress = {0} ", (curSequence + 1) / 1000 + "K/" + numSequence / 1000.0 + "K");
Logger.WriteLine(" train cross-entropy = {0} ", -logp / Math.Log10(2.0) / wordCnt);
Logger.WriteLine(" Error token ratio = {0}%", (double)tknErrCnt / (double)wordCnt * 100.0);
Logger.WriteLine(" Error sentence ratio = {0}%", (double)sentErrCnt / (double)curSequence * 100.0);
}

if (SaveStep > 0 && (curSequence + 1) % SaveStep == 0)
{
//After processed every m_SaveStep sentences, save current model into a temporary file
Logger.WriteLine(Logger.Level.info, "Saving temporary model into file...");
Logger.WriteLine("Saving temporary model into file...");
saveNetBin(ModelTempFile);
}
}
@@ -547,9 +547,9 @@ public virtual double TrainNet(DataSet trainingSet, int iter)

double entropy = -logp / Math.Log10(2.0) / wordCnt;
double ppl = exp_10(-logp / wordCnt);
Logger.WriteLine(Logger.Level.info, "[TRACE] Iter " + iter + " completed");
Logger.WriteLine(Logger.Level.info, "[TRACE] Sentences = " + numSequence + ", time escape = " + duration + "s, speed = " + numSequence / duration.TotalSeconds);
Logger.WriteLine(Logger.Level.info, "[TRACE] In training: log probability = " + logp + ", cross-entropy = " + entropy + ", perplexity = " + ppl);
Logger.WriteLine("[TRACE] Iter " + iter + " completed");
Logger.WriteLine("[TRACE] Sentences = " + numSequence + ", time escape = " + duration + "s, speed = " + numSequence / duration.TotalSeconds);
Logger.WriteLine("[TRACE] In training: log probability = " + logp + ", cross-entropy = " + entropy + ", perplexity = " + ppl);

return ppl;
}
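The reported numbers follow from the accumulated logp, which (given the base-10 conversions above) is a sum of base-10 log probabilities over wordCnt tokens: cross-entropy = -log10(P) / (wordCnt * log10(2)) is bits per token, and perplexity = 10^(-log10(P) / wordCnt), so ppl = 2^entropy.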
@@ -561,8 +561,6 @@ public virtual double TrainNet(DataSet trainingSet, int iter)

public abstract SimpleCell[] GetHiddenLayer();

- public abstract void SetHiddenLayer(SimpleCell[] cells);

public static void CheckModelFileType(string filename, out MODELTYPE modelType, out MODELDIRECTION modelDir)
{
using (StreamReader sr = new StreamReader(filename))
@@ -576,7 +574,7 @@ public static void CheckModelFileType(string filename, out MODELTYPE modelType,
}


- protected double NormalizeErr(double err)
+ protected double NormalizeGradient(double err)
{
if (err > GradientCutoff)
err = GradientCutoff;
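The hunk is cut off after the positive-side clamp. Presumably the renamed method is a symmetric gradient clip; a minimal sketch of the assumed full body:

    // Assumed completion of the truncated hunk: clamp into [-GradientCutoff, GradientCutoff].
    protected double NormalizeGradient(double err)
    {
        if (err > GradientCutoff)
            err = GradientCutoff;
        else if (err < -GradientCutoff)
            err = -GradientCutoff;
        return err;
    }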
@@ -614,7 +612,7 @@ public void matrixXvectorADD(SimpleCell[] dest, SimpleCell[] srcvec, Matrix<doub
cell.er += srcvec[j].er * srcmatrix[j][i];
}

- cell.er = NormalizeErr(cell.er);
+ cell.er = NormalizeGradient(cell.er);
});
}
}
@@ -829,7 +827,7 @@ public void ComputeOutputLayerErr(State state, int timeat)

public virtual bool ValidateNet(DataSet validationSet, int iter)
{
Logger.WriteLine(Logger.Level.info, "[TRACE] Start validation ...");
Logger.WriteLine("[TRACE] Start validation ...");
int wordcn = 0;
int tknErrCnt = 0;
int sentErrCnt = 0;
Expand Down Expand Up @@ -867,9 +865,9 @@ public virtual bool ValidateNet(DataSet validationSet, int iter)
double tknErrRatio = (double)tknErrCnt / (double)wordcn * 100.0;
double sentErrRatio = (double)sentErrCnt / (double)numSequence * 100.0;

Logger.WriteLine(Logger.Level.info, "[TRACE] In validation: error token ratio = {0}% error sentence ratio = {1}%", tknErrRatio, sentErrRatio);
Logger.WriteLine(Logger.Level.info, "[TRACE] In training: log probability = " + logp + ", cross-entropy = " + entropy + ", perplexity = " + ppl);
Logger.WriteLine(Logger.Level.info, "");
Logger.WriteLine("[TRACE] In validation: error token ratio = {0}% error sentence ratio = {1}%", tknErrRatio, sentErrRatio);
Logger.WriteLine("[TRACE] In training: log probability = " + logp + ", cross-entropy = " + entropy + ", perplexity = " + ppl);
Logger.WriteLine("");

bool bUpdate = false;
if (tknErrRatio < minTknErrRatio)
(Diff truncated here: the remaining 6 changed files are not shown.)
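The change for item #3 of the commit message (training without a validation corpus) is presumably in one of the files not shown above. A hypothetical sketch of how a training driver might branch when no validation set is supplied; every name here is illustrative rather than taken from this diff:

    // Illustrative only: without a validation corpus, save unconditionally
    // instead of gating the save on ValidateNet's result.
    for (int iter = 0; iter < maxIteration; iter++)
    {
        rnn.TrainNet(trainingSet, iter);

        bool betterModel = true;
        if (validationSet != null)
        {
            betterModel = rnn.ValidateNet(validationSet, iter);
        }

        if (betterModel)
        {
            rnn.saveNetBin(modelFile);
        }
    }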
