Add dropouts to GPT-NeoX (#24680)
* add attention dropout, post attention dropout, post mlp dropout to gpt-neox

* fix typo

* add documentation

* fix too long line

* ran the repository style and consistency checks; output:
Checking/fixing src/transformers/models/gpt_neox/configuration_gpt_neox.py src/transformers/models/gpt_neox/modeling_gpt_neox.py
python utils/custom_init_isort.py
python utils/sort_auto_mappings.py
doc-builder style src/transformers docs/source --max_len 119 --path_to_docs docs/source
python utils/check_doc_toc.py --fix_and_overwrite
running deps_table_update
updating src/transformers/dependency_versions_table.py
python utils/check_copies.py
python utils/check_table.py
python utils/check_dummies.py
python utils/check_repo.py
Checking all models are included.
Checking all models are public.
Checking all models are properly tested.
Checking all objects are properly documented.
Checking all models are in at least one auto class.
Checking all names in auto name mappings are defined.
Checking all keys in auto name mappings are defined in `CONFIG_MAPPING_NAMES`.
Checking all auto mappings could be imported.
Checking all objects are equally (across frameworks) in the main __init__.
python utils/check_inits.py
python utils/check_config_docstrings.py
python utils/check_config_attributes.py
python utils/check_doctest_list.py
python utils/update_metadata.py --check-only
python utils/check_task_guides.py
ZHAOTING authored Jul 6, 2023
1 parent fb3b22c commit 3927404
Showing 2 changed files with 20 additions and 1 deletion.
9 changes: 9 additions & 0 deletions src/transformers/models/gpt_neox/configuration_gpt_neox.py
@@ -56,6 +56,11 @@ class GPTNeoXConfig(PretrainedConfig):
percentage of hidden dimensions to allocate to rotary embeddings
rotary_emb_base (`int`, *optional*, defaults to 10000)
base for computing rotary embeddings frequency
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout probability for the attention scores.
hidden_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio of (1) the word embeddings, (2) the post-attention hidden states, and (3) the post-mlp
hidden states.
classifier_dropout (`float`, *optional*, defaults to 0.1):
Argument used when doing token classification, used in the model [`GPTNeoXForTokenClassification`].
@@ -99,6 +104,8 @@ def __init__(
hidden_act="gelu",
rotary_pct=0.25,
rotary_emb_base=10000,
attention_dropout=0.0,
hidden_dropout=0.0,
classifier_dropout=0.1,
max_position_embeddings=2048,
initializer_range=0.02,
@@ -120,6 +127,8 @@
self.hidden_act = hidden_act
self.rotary_pct = rotary_pct
self.rotary_emb_base = rotary_emb_base
self.attention_dropout = attention_dropout
self.hidden_dropout = hidden_dropout
self.classifier_dropout = classifier_dropout
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
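For reference, a minimal usage sketch (not part of the diff) showing how the two new config arguments could be set; the small model dimensions below are made up for illustration:

```python
from transformers import GPTNeoXConfig, GPTNeoXForCausalLM

# Hypothetical small config; only attention_dropout and hidden_dropout are new in this commit.
config = GPTNeoXConfig(
    vocab_size=50432,
    hidden_size=512,
    num_hidden_layers=4,
    num_attention_heads=8,
    intermediate_size=2048,
    attention_dropout=0.1,  # dropout on the attention probabilities
    hidden_dropout=0.1,     # dropout on embeddings, post-attention and post-MLP hidden states
)

model = GPTNeoXForCausalLM(config)
model.train()  # the new dropouts are only active in training mode
```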
12 changes: 11 additions & 1 deletion src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -114,6 +114,8 @@ def __init__(self, config):
self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size)
self.dense = nn.Linear(config.hidden_size, config.hidden_size)

self.attention_dropout = nn.Dropout(config.attention_dropout)

def forward(
self,
hidden_states: torch.FloatTensor,
@@ -245,6 +247,8 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
if head_mask is not None:
attn_weights = attn_weights * head_mask

attn_weights = self.attention_dropout(attn_weights)

attn_output = torch.matmul(attn_weights, value)
return attn_output, attn_weights

@@ -320,6 +324,8 @@ def __init__(self, config):
self.use_parallel_residual = config.use_parallel_residual
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.post_attention_dropout = nn.Dropout(config.hidden_dropout)
self.post_mlp_dropout = nn.Dropout(config.hidden_dropout)
self.attention = GPTNeoXAttention(config)
self.mlp = GPTNeoXMLP(config)

@@ -343,19 +349,22 @@ def forward(
output_attentions=output_attentions,
)
attn_output = attention_layer_outputs[0] # output_attn: attn_output, present, (attn_weights)
attn_output = self.post_attention_dropout(attn_output)
outputs = attention_layer_outputs[1:]

if self.use_parallel_residual:
# pseudocode:
# x = x + attn(ln1(x)) + mlp(ln2(x))
mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
mlp_output = self.post_mlp_dropout(mlp_output)
hidden_states = mlp_output + attn_output + hidden_states
else:
# pseudocode:
# x = x + attn(ln1(x))
# x = x + mlp(ln2(x))
attn_output = attn_output + hidden_states
mlp_output = self.mlp(self.post_attention_layernorm(attn_output))
mlp_output = self.post_mlp_dropout(mlp_output)
hidden_states = mlp_output + attn_output

if use_cache:
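Taken together, a condensed sketch (illustrative only, not the real `GPTNeoXLayer`) of where the two new hidden-state dropouts sit in the parallel-residual path, i.e. `x = x + dropout(attn(ln1(x))) + dropout(mlp(ln2(x)))`:

```python
import torch
from torch import nn

class ParallelResidualSketch(nn.Module):
    """Illustrative block with the same dropout placement as the patched GPTNeoXLayer."""

    def __init__(self, hidden_size: int = 64, num_heads: int = 4, hidden_dropout: float = 0.1):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size)
        self.attention = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size), nn.GELU(), nn.Linear(4 * hidden_size, hidden_size)
        )
        self.post_attention_dropout = nn.Dropout(hidden_dropout)
        self.post_mlp_dropout = nn.Dropout(hidden_dropout)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        normed = self.input_layernorm(hidden_states)
        attn_output, _ = self.attention(normed, normed, normed)
        attn_output = self.post_attention_dropout(attn_output)  # new post-attention dropout
        mlp_output = self.post_mlp_dropout(self.mlp(self.post_attention_layernorm(hidden_states)))  # new post-MLP dropout
        return hidden_states + attn_output + mlp_output
```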
@@ -429,6 +438,7 @@ def __init__(self, config):
self.config = config

self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
self.emb_dropout = nn.Dropout(config.hidden_dropout)
self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

@@ -533,7 +543,7 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_in(input_ids)

- hidden_states = inputs_embeds
+ hidden_states = self.emb_dropout(inputs_embeds)

if self.gradient_checkpointing and self.training:
if use_cache:
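Finally, a toy sketch of the new embedding dropout (driven by the same `config.hidden_dropout`), applied to the word embeddings before the first transformer layer:

```python
import torch
from torch import nn

vocab_size, hidden_size = 1000, 64  # toy sizes for illustration
embed_in = nn.Embedding(vocab_size, hidden_size)
emb_dropout = nn.Dropout(p=0.1)  # corresponds to config.hidden_dropout

input_ids = torch.randint(0, vocab_size, (2, 8))
inputs_embeds = embed_in(input_ids)
hidden_states = emb_dropout(inputs_embeds)  # previously just `hidden_states = inputs_embeds`
```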
