Commit

Merge pull request #5 from dmlc/master
[BUGFIX] fix NTA implementation (dmlc#1277)
jamiekang authored Aug 11, 2020
2 parents 67ef280 + 528283d commit 55a1d46
Showing 2 changed files with 15 additions and 15 deletions.
28 changes: 14 additions & 14 deletions scripts/language_model/index.rst
@@ -36,9 +36,9 @@ The dataset used for training the models is wikitext-2.
+---------------+----------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+
| Weight_drop | 0.5 | 0.2 | 0 | 0 | 0 |
+---------------+----------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+
-| Val PPL | 68.71 | 84.89 | 86.51 | 90.96 | 107.59 |
+| Val PPL | 71.78 | 80.11 | 86.28 | 91.30 | 108.17 |
+---------------+----------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+
-| Test PPL | 65.62 | 80.67 | 82.29 | 86.91 | 101.64 |
+| Test PPL | 68.55 | 76.14 | 81.99 | 85.82 | 102.49 |
+---------------+----------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+
| Command | [1] | [2] | [3] | [4] | [5] |
+---------------+----------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+
@@ -47,31 +47,31 @@ The dataset used for training the models is wikitext-2.

For all the above model settings, we set Tied = True and NTASGD = True.
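The Tied = True setting (the `--tied` flag) reuses the input embedding matrix as the output decoder weight, which is why weight tying requires the decoder's input dimension to match the embedding size. A minimal numpy sketch of the idea (all names here are illustrative, not the script's API):

```python
import numpy as np

vocab_size, emsize = 10, 4
embed = np.random.randn(vocab_size, emsize)   # input embedding, reused below

h = np.random.randn(emsize)                   # final hidden state of the LM
logits = embed @ h                            # decoder weight is embed itself
probs = np.exp(logits - logits.max())
probs /= probs.sum()                          # softmax over the vocabulary
```

Tying halves the number of softmax-related parameters, which matters on a small corpus like wikitext-2.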

-[1] awd_lstm_lm_1150_wikitext-2 (Val PPL 68.71 Test PPL 65.62)
+[1] awd_lstm_lm_1150_wikitext-2 (Val PPL 71.78 Test PPL 68.55)

.. code-block:: console

    $ python word_language_model.py --gpu 0 --tied --ntasgd --lr_update_interval 30 --lr_update_factor 0.1 --save awd_lstm_lm_1150_wikitext-2

-[2] awd_lstm_lm_600_wikitext-2 (Val PPL 84.89 Test PPL 80.67)
+[2] awd_lstm_lm_600_wikitext-2 (Val PPL 80.11 Test PPL 76.14)

.. code-block:: console

    $ python word_language_model.py --gpu 0 --emsize 200 --nhid 600 --epochs 750 --dropout 0.2 --dropout_h 0.1 --dropout_i 0.3 --dropout_e 0.05 --weight_drop 0.2 --tied --ntasgd --lr_update_interval 30 --lr_update_factor 0.1 --save awd_lstm_lm_600_wikitext-2

-[3] standard_lstm_lm_1500_wikitext-2 (Val PPL 86.51 Test PPL 82.29)
+[3] standard_lstm_lm_1500_wikitext-2 (Val PPL 86.28 Test PPL 81.99)

.. code-block:: console

    $ python word_language_model.py --gpu 0 --emsize 1500 --nhid 1500 --nlayers 2 --lr 20 --epochs 750 --batch_size 20 --bptt 35 --dropout 0.65 --dropout_h 0 --dropout_i 0 --dropout_e 0 --weight_drop 0 --tied --wd 0 --alpha 0 --beta 0 --ntasgd --lr_update_interval 30 --lr_update_factor 0.1 --save standard_lstm_lm_1500_wikitext-2

-[4] standard_lstm_lm_650_wikitext-2 (Val PPL 90.96 Test PPL 86.91)
+[4] standard_lstm_lm_650_wikitext-2 (Val PPL 91.30 Test PPL 85.82)

.. code-block:: console

    $ python word_language_model.py --gpu 0 --emsize 650 --nhid 650 --nlayers 2 --lr 20 --epochs 750 --batch_size 20 --bptt 35 --dropout 0.5 --dropout_h 0 --dropout_i 0 --dropout_e 0 --weight_drop 0 --tied --wd 0 --alpha 0 --beta 0 --ntasgd --lr_update_interval 30 --lr_update_factor 0.1 --save standard_lstm_lm_650_wikitext-2

-[5] standard_lstm_lm_200_wikitext-2 (Val PPL 107.59 Test PPL 101.64)
+[5] standard_lstm_lm_200_wikitext-2 (Val PPL 108.17 Test PPL 102.49)

.. code-block:: console
@@ -93,9 +93,9 @@ The dataset used for training the models is wikitext-2.
+=====================+===================================================================================================================================+==================================================================================================================================+========================================================================================================================================+=======================================================================================================================================+=======================================================================================================================================+
| Pre-trained setting | Refer to: awd_lstm_lm_1150_wikitext-2 | Refer to: awd_lstm_lm_600_wikitext-2 | Refer to: standard_lstm_lm_1500_wikitext-2 | Refer to: standard_lstm_lm_650_wikitext-2 | Refer to: standard_lstm_lm_200_wikitext-2 |
+---------------------+-----------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| Val PPL | 53.41 | 64.51 | 65.54 | 68.47 | 77.51 |
+| Val PPL | 58.18 | 64.09 | 73.19 | 69.27 | 81.68 |
+---------------------+-----------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| Test PPL | 51.46 | 62.19 | 62.79 | 65.85 | 73.74 |
+| Test PPL | 56.08 | 61.62 | 70.91 | 66.39 | 77.83 |
+---------------------+-----------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| Command | [1] | [2] | [3] | [4] | [5] |
+---------------------+-----------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
@@ -104,31 +104,31 @@ The dataset used for training the models is wikitext-2.

For all the above model settings, we set lambdas = 0.1279, theta = 0.662, window = 2000 and bptt = 2000.
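These hyperparameters belong to the continuous cache pointer of Grave et al.: theta scales the dot-product between the current hidden state and the `window` cached states, and lambdas is the interpolation weight given to the resulting pointer distribution. A hedged numpy sketch of the mixing step (shapes and names are assumptions for illustration, not this script's API):

```python
import numpy as np

def cache_probs(p_vocab, hiddens, targets, h_t, lambdas=0.1279, theta=0.662):
    """Mix the base LM distribution with a continuous-cache pointer
    distribution built from cached hidden states and their next words."""
    scores = theta * (hiddens @ h_t)          # similarity to each cached state
    attn = np.exp(scores - scores.max())
    attn /= attn.sum()                        # softmax over the cache window
    p_cache = np.zeros_like(p_vocab)
    for w, a in zip(targets, attn):           # scatter attention onto word ids
        p_cache[w] += a
    return (1.0 - lambdas) * p_vocab + lambdas * p_cache
```

Since both components are probability distributions, the mixture still sums to one, and no retraining of the base model is needed.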

-[1] cache_awd_lstm_lm_1150_wikitext-2 (Val PPL 53.41 Test PPL 51.46)
+[1] cache_awd_lstm_lm_1150_wikitext-2 (Val PPL 58.18 Test PPL 56.08)

.. code-block:: console

    $ python cache_language_model.py --gpus 0 --model_name awd_lstm_lm_1150

-[2] cache_awd_lstm_lm_600_wikitext-2 (Val PPL 64.51 Test PPL 62.19)
+[2] cache_awd_lstm_lm_600_wikitext-2 (Val PPL 64.09 Test PPL 61.62)

.. code-block:: console

    $ python cache_language_model.py --gpus 0 --model_name awd_lstm_lm_600

-[3] cache_standard_lstm_lm_1500_wikitext-2 (Val PPL 65.54 Test PPL 62.79)
+[3] cache_standard_lstm_lm_1500_wikitext-2 (Val PPL 73.19 Test PPL 70.91)

.. code-block:: console

    $ python cache_language_model.py --gpus 0 --model_name standard_lstm_lm_1500

-[4] cache_standard_lstm_lm_650_wikitext-2 (Val PPL 68.47 Test PPL 65.85)
+[4] cache_standard_lstm_lm_650_wikitext-2 (Val PPL 69.27 Test PPL 66.39)

.. code-block:: console

    $ python cache_language_model.py --gpus 0 --model_name standard_lstm_lm_650

-[5] cache_standard_lstm_lm_200_wikitext-2 (Val PPL 77.51 Test PPL 73.74)
+[5] cache_standard_lstm_lm_200_wikitext-2 (Val PPL 81.68 Test PPL 77.83)

.. code-block:: console
2 changes: 1 addition & 1 deletion scripts/language_model/word_language_model.py
@@ -426,7 +426,7 @@ def train():
trainer.learning_rate))

if args.ntasgd and avg_trigger == 0:
-if t > n and val_L > min(valid_losses[-n:]):
+if t > n and val_L > min(valid_losses[:-n]):
if param_dict_avg is None:
param_dict_avg = {k.split(model._prefix)[1]: v.data(context[0]).copy()
for k, v in parameters.items()}
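The one-line fix above changes which slice of the validation-loss history the non-monotonic trigger inspects: NT-ASGD should start averaging when the new loss fails to improve on the best loss observed more than n evaluations ago (`valid_losses[:-n]`), whereas the buggy `valid_losses[-n:]` compared against only the n most recent losses. A standalone sketch of the corrected trigger, with hypothetical names:

```python
def ntasgd_should_trigger(valid_losses, val_L, n=5):
    """Non-monotonic trigger from NT-ASGD: start averaging once the new
    validation loss is worse than the best loss recorded more than n
    evaluations ago."""
    t = len(valid_losses)
    # Need more than n past evaluations so valid_losses[:-n] is non-empty.
    return t > n and val_L > min(valid_losses[:-n])

# Steadily improving losses never trigger averaging:
print(ntasgd_should_trigger([10, 9, 8, 7, 6, 5.5], 5.0, n=5))  # False
# A new loss worse than the old best (min of [4] here) triggers it:
print(ntasgd_should_trigger([4, 9, 8, 7, 6, 5.5], 5.0, n=5))   # True
```

With the buggy slice, a model whose loss was still improving could trigger averaging almost immediately, since recent losses are the ones closest to the current value.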
