This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

[BUGFIX] fix NTA implementation #1277

Merged: 5 commits, Aug 6, 2020
scripts/language_model/index.rst (28 changes: 14 additions & 14 deletions)
@@ -36,9 +36,9 @@ The dataset used for training the models is wikitext-2.
 +-------------+--------+--------+--------+--------+--------+
 | Weight_drop | 0.5    | 0.2    | 0      | 0      | 0      |
 +-------------+--------+--------+--------+--------+--------+
-| Val PPL     | 68.71  | 84.89  | 86.51  | 90.96  | 107.59 |
+| Val PPL     | 71.78  | 80.11  | 86.28  | 91.30  | 108.17 |
 +-------------+--------+--------+--------+--------+--------+
-| Test PPL    | 65.62  | 80.67  | 82.29  | 86.91  | 101.64 |
+| Test PPL    | 68.55  | 76.14  | 81.99  | 85.82  | 102.49 |
 +-------------+--------+--------+--------+--------+--------+
 | Command     | [1]    | [2]    | [3]    | [4]    | [5]    |
 +-------------+--------+--------+--------+--------+--------+
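For reference, the Weight_drop row above is the DropConnect rate that the AWD-LSTM applies to its hidden-to-hidden weight matrices (Merity et al., 2017). A minimal NumPy sketch of the idea, using a placeholder weight matrix rather than the script's actual MXNet parameters:

.. code-block:: python

    import numpy as np

    rng = np.random.default_rng(0)

    W_hh = rng.normal(size=(1150, 1150))  # placeholder recurrent weights (nhid=1150)
    weight_drop = 0.5

    # DropConnect: zero out individual recurrent weights and rescale the
    # survivors (inverted dropout) so the expected magnitude is unchanged.
    mask = rng.random(W_hh.shape) >= weight_drop
    W_dropped = W_hh * mask / (1.0 - weight_drop)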
@@ -47,31 +47,31 @@ The dataset used for training the models is wikitext-2.

For all the above model settings, we set Tied = True and NTASGD = True.

-[1] awd_lstm_lm_1150_wikitext-2 (Val PPL 68.71 Test PPL 65.62)
+[1] awd_lstm_lm_1150_wikitext-2 (Val PPL 71.78 Test PPL 68.55)

.. code-block:: console

    $ python word_language_model.py --gpu 0 --tied --ntasgd --lr_update_interval 30 --lr_update_factor 0.1 --save awd_lstm_lm_1150_wikitext-2

-[2] awd_lstm_lm_600_wikitext-2 (Val PPL 84.89 Test PPL 80.67)
+[2] awd_lstm_lm_600_wikitext-2 (Val PPL 80.11 Test PPL 76.14)

.. code-block:: console

    $ python word_language_model.py --gpu 0 --emsize 200 --nhid 600 --epochs 750 --dropout 0.2 --dropout_h 0.1 --dropout_i 0.3 --dropout_e 0.05 --weight_drop 0.2 --tied --ntasgd --lr_update_interval 30 --lr_update_factor 0.1 --save awd_lstm_lm_600_wikitext-2

-[3] standard_lstm_lm_1500_wikitext-2 (Val PPL 86.51 Test PPL 82.29)
+[3] standard_lstm_lm_1500_wikitext-2 (Val PPL 86.28 Test PPL 81.99)

.. code-block:: console

    $ python word_language_model.py --gpu 0 --emsize 1500 --nhid 1500 --nlayers 2 --lr 20 --epochs 750 --batch_size 20 --bptt 35 --dropout 0.65 --dropout_h 0 --dropout_i 0 --dropout_e 0 --weight_drop 0 --tied --wd 0 --alpha 0 --beta 0 --ntasgd --lr_update_interval 30 --lr_update_factor 0.1 --save standard_lstm_lm_1500_wikitext-2

-[4] standard_lstm_lm_650_wikitext-2 (Val PPL 90.96 Test PPL 86.91)
+[4] standard_lstm_lm_650_wikitext-2 (Val PPL 91.30 Test PPL 85.82)

.. code-block:: console

    $ python word_language_model.py --gpu 0 --emsize 650 --nhid 650 --nlayers 2 --lr 20 --epochs 750 --batch_size 20 --bptt 35 --dropout 0.5 --dropout_h 0 --dropout_i 0 --dropout_e 0 --weight_drop 0 --tied --wd 0 --alpha 0 --beta 0 --ntasgd --lr_update_interval 30 --lr_update_factor 0.1 --save standard_lstm_lm_650_wikitext-2

-[5] standard_lstm_lm_200_wikitext-2 (Val PPL 107.59 Test PPL 101.64)
+[5] standard_lstm_lm_200_wikitext-2 (Val PPL 108.17 Test PPL 102.49)

.. code-block:: console

@@ -93,9 +93,9 @@ The dataset used for training the models is wikitext-2.
 +=====================+============================================+============================================+============================================+============================================+============================================+
 | Pre-trained setting | Refer to: awd_lstm_lm_1150_wikitext-2      | Refer to: awd_lstm_lm_600_wikitext-2       | Refer to: standard_lstm_lm_1500_wikitext-2 | Refer to: standard_lstm_lm_650_wikitext-2  | Refer to: standard_lstm_lm_200_wikitext-2  |
 +---------------------+--------------------------------------------+--------------------------------------------+--------------------------------------------+--------------------------------------------+--------------------------------------------+
-| Val PPL             | 53.41                                      | 64.51                                      | 65.54                                      | 68.47                                      | 77.51                                      |
+| Val PPL             | 58.18                                      | 64.09                                      | 73.19                                      | 69.27                                      | 81.68                                      |
 +---------------------+--------------------------------------------+--------------------------------------------+--------------------------------------------+--------------------------------------------+--------------------------------------------+
-| Test PPL            | 51.46                                      | 62.19                                      | 62.79                                      | 65.85                                      | 73.74                                      |
+| Test PPL            | 56.08                                      | 61.62                                      | 70.91                                      | 66.39                                      | 77.83                                      |
 +---------------------+--------------------------------------------+--------------------------------------------+--------------------------------------------+--------------------------------------------+--------------------------------------------+
 | Command             | [1]                                        | [2]                                        | [3]                                        | [4]                                        | [5]                                        |
 +---------------------+--------------------------------------------+--------------------------------------------+--------------------------------------------+--------------------------------------------+--------------------------------------------+
@@ -104,31 +104,31 @@ The dataset used for training the models is wikitext-2.

For all the above model settings, we set lambdas = 0.1279, theta = 0.662, window = 2000 and bptt = 2000 (the cache interpolation is sketched after the commands below).

-[1] cache_awd_lstm_lm_1150_wikitext-2 (Val PPL 53.41 Test PPL 51.46)
+[1] cache_awd_lstm_lm_1150_wikitext-2 (Val PPL 58.18 Test PPL 56.08)

.. code-block:: console

    $ python cache_language_model.py --gpus 0 --model_name awd_lstm_lm_1150

-[2] cache_awd_lstm_lm_600_wikitext-2 (Val PPL 64.51 Test PPL 62.19)
+[2] cache_awd_lstm_lm_600_wikitext-2 (Val PPL 64.09 Test PPL 61.62)

.. code-block:: console

    $ python cache_language_model.py --gpus 0 --model_name awd_lstm_lm_600

-[3] cache_standard_lstm_lm_1500_wikitext-2 (Val PPL 65.54 Test PPL 62.79)
+[3] cache_standard_lstm_lm_1500_wikitext-2 (Val PPL 73.19 Test PPL 70.91)

.. code-block:: console

    $ python cache_language_model.py --gpus 0 --model_name standard_lstm_lm_1500

-[4] cache_standard_lstm_lm_650_wikitext-2 (Val PPL 68.47 Test PPL 65.85)
+[4] cache_standard_lstm_lm_650_wikitext-2 (Val PPL 69.27 Test PPL 66.39)

.. code-block:: console

    $ python cache_language_model.py --gpus 0 --model_name standard_lstm_lm_650

-[5] cache_standard_lstm_lm_200_wikitext-2 (Val PPL 77.51 Test PPL 73.74)
+[5] cache_standard_lstm_lm_200_wikitext-2 (Val PPL 81.68 Test PPL 77.83)

.. code-block:: console

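The cache models above interpolate the pre-trained model's softmax with a pointer-style distribution over the most recent window hidden states (Grave et al., 2016); lambdas weights the mixture and theta scales the dot-product scores. A minimal NumPy sketch of that interpolation, with placeholder hidden states and a uniform stand-in for the model's softmax:

.. code-block:: python

    import numpy as np

    rng = np.random.default_rng(0)

    theta, lam = 0.662, 0.1279                   # values quoted above
    window, dim, vocab_size = 2000, 32, 10       # toy sizes except window

    H = rng.normal(size=(window, dim))           # cached hidden states
    words = rng.integers(0, vocab_size, window)  # word id observed at each step
    h_t = rng.normal(size=dim)                   # current hidden state
    p_vocab = np.full(vocab_size, 1.0 / vocab_size)  # stand-in model softmax

    # Pointer distribution: attention of h_t over the cache, scaled by theta,
    # with scores accumulated onto the word id stored at each position.
    scores = np.exp(theta * (H @ h_t))
    p_cache = np.zeros(vocab_size)
    np.add.at(p_cache, words, scores)
    p_cache /= p_cache.sum()

    # Final next-word distribution: linear interpolation of the two.
    p = (1 - lam) * p_vocab + lam * p_cache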
scripts/language_model/word_language_model.py (2 changes: 1 addition & 1 deletion)
@@ -426,7 +426,7 @@ def train():
                                     trainer.learning_rate))

         if args.ntasgd and avg_trigger == 0:
-            if t > n and val_L > min(valid_losses[-n:]):
+            if t > n and val_L > min(valid_losses[:-n]):
                 if param_dict_avg is None:
                     param_dict_avg = {k.split(model._prefix)[1]: v.data(context[0]).copy()
                                       for k, v in parameters.items()}
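The one-line change above is the entire fix. In NT-ASGD (Merity et al., 2017), averaging should begin only after the validation loss has failed to improve for n consecutive checks, which means comparing the current loss against the best loss recorded more than n checks ago (valid_losses[:-n]); the old code compared against the most recent n losses (valid_losses[-n:]), so a single noisy uptick could start averaging prematurely. A standalone sketch of the corrected trigger, with a hypothetical evaluate standing in for the real validation pass:

.. code-block:: python

    import random

    random.seed(0)

    def evaluate(step):
        # Hypothetical validation loss: decays toward a noisy plateau.
        return 1.0 + 5.0 / (1 + step) + random.uniform(0.0, 0.05)

    n = 5                      # non-monotone interval
    valid_losses = []          # loss recorded at each validation check
    averaging_triggered = False

    for t in range(50):
        val_L = evaluate(t)
        # Trigger averaging once the current loss is worse than the best
        # loss from more than n checks ago, i.e. validation has not
        # improved over the last n checks.
        if not averaging_triggered and t > n and val_L > min(valid_losses[:-n]):
            averaging_triggered = True
            print('NT-ASGD averaging triggered at check', t)
        valid_losses.append(val_L)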