Make GPT2Model a HybridBlock (#1010)
* Make GPT2Model a HybridBlock

* Enable gpt2 test for sequence_sampling.py

* Fix

* Hybridize

* Workaround apache/mxnet#16851

* Fix #1015

* Enable test

* Ignore warning about package resolution using __spec__ or __package__
leezu authored Nov 20, 2019
1 parent 5e11334 commit ebfc920
Showing 4 changed files with 67 additions and 40 deletions.
5 changes: 4 additions & 1 deletion pytest.ini
@@ -12,4 +12,7 @@ env =
     MXNET_HOME=tests/data
 
 filterwarnings =
-    error
+    error
+    # ignore warning about package resolution using __spec__ or __package__
+    # can't reproduce locally
+    ignore:.*can't resolve package from __spec__ or __package__.*:ImportWarning
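With `error` as the first entry, pytest escalates every warning to a test failure unless a later entry matches it; entries use the `action:message_regex:category` form. A minimal sketch (hypothetical test, not part of this commit) of what the new rule permits:

```python
import warnings

def test_import_warning_is_ignored():
    # Under "filterwarnings = error" alone this warning would fail the run;
    # the added "ignore:...:ImportWarning" rule lets it pass.
    warnings.warn("can't resolve package from __spec__ or __package__",
                  ImportWarning)
```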
11 changes: 8 additions & 3 deletions scripts/tests/test_scripts.py
@@ -135,14 +135,19 @@ def test_sentiment_analysis_textcnn():
                           '--dropout', '0.5', '--model_mode', 'rand', '--data_name', 'MR'])
     time.sleep(5)
 
+@pytest.mark.skip_master
 @pytest.mark.remote_required
 @pytest.mark.gpu
 @pytest.mark.serial
 @pytest.mark.integration
 @pytest.mark.parametrize('method', ['beam_search', 'sampling'])
-def test_sampling(method):
-    args = ['--bos', 'I love it', '--beam-size', '2', '--print-num', '1', '--gpu', '0']
+@pytest.mark.parametrize('lmmodel', ['awd_lstm_lm_1150', 'gpt2_117m'])
+def test_sampling(method, lmmodel):
+    if 'gpt2' in lmmodel and method == 'beam_search':
+        return  # unsupported
+    args = [
+        '--bos', 'I love it', '--beam-size', '2', '--print-num', '1', '--gpu', '0', '--lm-model',
+        lmmodel
+    ]
     if method == 'beam_search':
         args.insert(0, 'beam-search')
         args.extend(['--k', '50'])
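Stacking two `parametrize` decorators runs the test over the cross product of the parameter lists, so the early `return` above is what skips the unsupported gpt2 + beam_search pair. A standalone sketch of the mechanism (toy test, not from this commit):

```python
import pytest

@pytest.mark.parametrize('method', ['beam_search', 'sampling'])
@pytest.mark.parametrize('lmmodel', ['awd_lstm_lm_1150', 'gpt2_117m'])
def test_cross_product(method, lmmodel):
    # pytest collects four cases, one per (method, lmmodel) pair
    assert method in ('beam_search', 'sampling')
    assert lmmodel in ('awd_lstm_lm_1150', 'gpt2_117m')
```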
86 changes: 51 additions & 35 deletions scripts/text_generation/model/gpt.py
@@ -22,17 +22,16 @@
 import os
 
 import mxnet as mx
-from mxnet.gluon import Block, HybridBlock, nn
+from mxnet.gluon import HybridBlock, nn
 from mxnet.gluon.model_zoo import model_store
-import numpy as np
 
 from gluonnlp.base import get_home_dir
 from gluonnlp.model.attention_cell import DotProductAttentionCell
 from gluonnlp.model.block import GELU
 from gluonnlp.model.utils import _load_pretrained_params, _load_vocab
 
 
-class GPT2SelfAttentionLayer(Block):
+class GPT2SelfAttentionLayer(HybridBlock):
     """Self-attention layer used in OpenAI GPT-2.
 
     Parameters
@@ -88,49 +87,54 @@ def __init__(self, units, num_heads, dropout=0.0,
                                   bias_initializer=bias_initializer,
                                   prefix='out_proj_')
 
-    def forward(self, data, states=None):  # pylint: disable=arguments-differ
-        batch_size = data.shape[0]
-        seq_len = data.shape[1]
+    def hybrid_forward(self, F, data, states=None):  # pylint: disable=arguments-differ
         # Generate mask
         if states is not None:
             prev_key, prev_value = states
-            prev_len = prev_key.shape[2]
+
+            prev_len_range = F.contrib.arange_like(prev_key, axis=2)
+            data_len_range = F.contrib.arange_like(data, axis=1)
+            prev_len = F.broadcast_add(F.slice_axis(prev_len_range, axis=0, begin=-1, end=None),
+                                       F.ones((1, )))
+
+            data_pos = F.broadcast_add(F.contrib.arange_like(data, axis=1), prev_len)
+            all_pos = F.contrib.arange_like(F.concat(prev_len_range, data_len_range, dim=0))
         else:
             prev_key, prev_value = None, None
-            prev_len = 0
-        data_pos = mx.nd.arange(prev_len, prev_len + seq_len, ctx=data.context, dtype=data.dtype)
-        all_pos = mx.nd.arange(seq_len + prev_len, ctx=data.context, dtype=data.dtype)
-        mask = mx.nd.broadcast_lesser_equal(all_pos.reshape((1, -1)), data_pos.reshape((-1, 1)))
-        mask = mx.nd.broadcast_axes(mx.nd.expand_dims(mask, axis=0), axis=0,
-                                    size=batch_size * self._num_heads)
+            data_pos = F.contrib.arange_like(data, axis=1)
+            all_pos = data_pos
+
+        mask = F.broadcast_lesser_equal(all_pos.reshape((1, -1)), data_pos.reshape((-1, 1)))
+        mask = F.broadcast_like(F.expand_dims(mask, axis=0), data, lhs_axes=(0, ), rhs_axes=(0, ))
+        mask = F.concat(*[mask] * self._num_heads, dim=0)
 
         # Multi-head attention
         qkv = self._multi_head_qkv_proj(data)  # Shape (batch_size, seq_len, 3 * units)
-        qkv = mx.nd.swapaxes(qkv, 1, 2)  # Shape (batch_size, 3 * units, seq_len)
+        qkv = F.swapaxes(qkv, 1, 2)  # Shape (batch_size, 3 * units, seq_len)
 
         # Each has shape (batch_size, units, seq_len)
-        query, key, value = mx.nd.split(qkv, num_outputs=3, axis=1)
+        query, key, value = F.split(qkv, num_outputs=3, axis=1)
         # Map each to have shape (batch_size * num_head, ele_units, seq_len)
         query = query.reshape(shape=(0, -4, self._num_heads, -1, 0)).reshape(
             shape=(-1, 0, 0), reverse=True)
         key = key.reshape(shape=(0, -4, self._num_heads, -1, 0)).reshape(
             shape=(-1, 0, 0), reverse=True)
         value = value.reshape(shape=(0, -4, self._num_heads, -1, 0)).reshape(
             shape=(-1, 0, 0), reverse=True)
-        query = mx.nd.swapaxes(query, 1, 2)
-        key = mx.nd.swapaxes(key, 1, 2)
-        value = mx.nd.swapaxes(value, 1, 2)
+        query = F.swapaxes(query, 1, 2)
+        key = F.swapaxes(key, 1, 2)
+        value = F.swapaxes(value, 1, 2)
         if prev_key is not None:
-            key = mx.nd.concat(prev_key.reshape((-1, 0, 0), reverse=True),
-                               key, dim=1)  # Shape (batch_size * num_heads, all_len, ele_units)
+            # Shape (batch_size * num_heads, all_len, ele_units)
+            key = F.concat(prev_key.reshape((-1, 0, 0), reverse=True), key, dim=1)
         if prev_value is not None:
-            value = mx.nd.concat(prev_value.reshape((-1, 0, 0), reverse=True),
-                                 value, dim=1)
+            value = F.concat(prev_value.reshape((-1, 0, 0), reverse=True),
+                             value, dim=1)
 
         # Shape (batch_size * num_heads, all_len, ele_units)
         out, _ = self._base_attn_cell(query, key, value, mask)
-        out = mx.nd.transpose(out.reshape((-1, self._num_heads, 0, 0), reverse=True),
-                              axes=(0, 2, 1, 3)).reshape((0, 0, -1))
+        out = F.transpose(out.reshape((-1, self._num_heads, 0, 0), reverse=True),
+                          axes=(0, 2, 1, 3)).reshape((0, 0, -1))
         out = self._out_proj(out)
         return out, [key.reshape((-1, self._num_heads, 0, 0), reverse=True),
                      value.reshape((-1, self._num_heads, 0, 0), reverse=True)]
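The mask construction above encodes causal attention: a query at absolute position i may attend to every key position j with j <= i. A standalone sketch of the comparison (toy length, imperative mode only):

```python
import mxnet as mx

seq_len = 4
pos = mx.nd.arange(seq_len)
# Row vector of key positions vs. column vector of query positions;
# broadcasting yields a (seq_len, seq_len) lower-triangular mask.
mask = mx.nd.broadcast_lesser_equal(pos.reshape((1, -1)),
                                    pos.reshape((-1, 1)))
print(mask)
# [[1. 0. 0. 0.]
#  [1. 1. 0. 0.]
#  [1. 1. 1. 0.]
#  [1. 1. 1. 1.]]
```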
@@ -186,7 +190,7 @@ def hybrid_forward(self, F, data):  # pylint: disable=arguments-differ
         return out
 
 
-class GPT2Model(Block):
+class GPT2Model(HybridBlock):
     """Generic Model for GPT-2.
 
     Parameters
@@ -223,7 +227,7 @@ def __init__(self, units, vocab_size, max_length, num_layers, num_heads, dropout
                                          weight_initializer=mx.init.Normal(0.02))
         self._logits_proj = nn.Dense(units=vocab_size, in_units=units, use_bias=False,
                                      flatten=False, params=self._embed.params)
-        self._self_attention_layers = nn.Sequential()
+        self._self_attention_layers = nn.HybridSequential()
         self._ffn_layers = nn.HybridSequential()
         self._attn_ln = nn.HybridSequential()
         self._ffn_ln = nn.HybridSequential()
@@ -237,8 +241,15 @@ def __init__(self, units, vocab_size, max_length, num_layers, num_heads, dropout
             self._ffn_ln.add(nn.LayerNorm(prefix='ffn_ln{}_'.format(i)))
         self._final_ln = nn.LayerNorm(prefix='final_ln{}_'.format(i))
 
-    def forward(self, data, states=None):  # pylint: disable=arguments-differ
-        """
+    def hybrid_forward(self, F, data, states=None):  # pylint: disable=arguments-differ
+        """Compute
+
+        Notes
+        -----
+        If you hybridized the GPT2Model by calling net.hybridize(), you cannot
+        switch between states=None and states=list_of_NDArray between calls to
+        the net. The hybridized model will only support the type of states used
+        during the first call after hybridization.
+
         Parameters
         ----------
@@ -253,15 +264,20 @@ def forward(self, data, states=None):  # pylint: disable=arguments-differ
         new_states : list of NDArray
         """
         new_states = []
-        batch_size, seq_len = data.shape[0], data.shape[1]
         if states is not None:
-            prev_len = states[0].shape[1]
+            prev_len_range = F.contrib.arange_like(states[0], axis=2).astype('int32')
+            prev_len = F.broadcast_add(F.slice_axis(prev_len_range, axis=0, begin=-1, end=None),
+                                       F.ones((1, ), dtype='int32'))
+            data_pos = F.broadcast_add(
+                F.contrib.arange_like(data, axis=1).astype('int32'), prev_len)
         else:
-            prev_len = 0
-        assert seq_len + prev_len <= self._max_length
-        data_pos = mx.nd.arange(prev_len, prev_len + seq_len, ctx=data.context, dtype=np.float32)
-        data_pos = mx.nd.broadcast_axes(mx.nd.expand_dims(data_pos, axis=0),
-                                        axis=0, size=batch_size)
+            data_pos = F.contrib.arange_like(data, axis=1).astype('int32')
+        if F is mx.nd:
+            length = data.shape[1] + (states[0].shape[2] if states is not None else 0)
+            assert length <= self._max_length
+        # astype cast to workaround https://github.com/apache/incubator-mxnet/issues/16851
+        data_pos = F.broadcast_like(F.expand_dims(data_pos, axis=0), data.astype('int32'),
+                                    lhs_axes=(0, ), rhs_axes=(0, ))
         out = self._embed(data) + self._pos_embed(data_pos)
         for i in range(self._num_layers):
             attn_layer = self._self_attention_layers[i]
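The recurring `F.contrib.arange_like` pattern is what makes these methods hybridizable: inside `hybrid_forward`, `F` may be `mx.sym`, where `data.shape` does not exist, so lengths must be derived as symbols from the arrays themselves. A minimal sketch of the trick in isolation (toy block, assumes MXNet 1.x with the contrib namespace):

```python
import mxnet as mx
from mxnet.gluon import HybridBlock

class LastPosition(HybridBlock):
    def hybrid_forward(self, F, data):
        # [0, 1, ..., seq_len - 1] along axis 1, without touching data.shape,
        # so the same code works for both mx.nd and mx.sym.
        pos = F.contrib.arange_like(data, axis=1)
        # Length = last index + 1, still computed symbolically.
        return F.broadcast_add(F.slice_axis(pos, axis=0, begin=-1, end=None),
                               F.ones((1, )))

net = LastPosition()
net.hybridize()
print(net(mx.nd.zeros((2, 5))))  # -> [5.]
```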
5 changes: 4 additions & 1 deletion scripts/text_generation/sequence_sampling.py
@@ -61,7 +61,7 @@
 p.add_argument('--lm-model', type=str, default='awd_lstm_lm_1150',
                help='type of the pre-trained model to load, can be "standard_lstm_lm_200", '
                     '"standard_lstm_lm_650", "standard_lstm_lm_1500", '
-                    '"awd_lstm_lm_1150", etc.')
+                    '"awd_lstm_lm_1150", "gpt2_117m", "gpt2_345m", etc.')
 p.add_argument('--max-length', type=int, default=20, help='Maximum sentence length.')
 p.add_argument('--print-num', type=int, default=3, help='Number of sentences to display.')
 p.add_argument('--bos', type=str, default='I think this works')
@@ -170,6 +170,9 @@ def generate():
                                           scorer=scorer,
                                           max_length=args.max_length - len(bos_tokens))
     inputs, begin_states = get_initial_input_state(decoder, bos_ids)
+
+    sampler._decoder.net.hybridize()  # Hybridize after we obtained the initial states
+
     # samples have shape (1, beam_size, length), scores have shape (1, beam_size)
     samples, scores, valid_lengths = sampler(inputs, begin_states)
     samples = samples[0].asnumpy()
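Hybridizing only after `get_initial_input_state` has run matches the Notes added to `GPT2Model.hybrid_forward`: the traced graph is specialized to whichever states signature the first post-hybridization call uses, here a list of NDArrays. A sketch of the calling pattern (the import path and the gpt2_117m-sized hyperparameters are illustrative, not taken from this diff):

```python
import mxnet as mx
from model.gpt import GPT2Model  # assumed path inside scripts/text_generation/

net = GPT2Model(units=768, vocab_size=50257, max_length=1024,
                num_layers=12, num_heads=12, dropout=0.0)
net.initialize()

data = mx.nd.array([[464, 3290, 318]])      # token ids, shape (batch, seq_len)
logits, states = net(data, None)            # first call builds the key/value caches
net.hybridize()                             # hybridize afterwards, so the traced
                                            # graph expects states as a list
logits, states = net(data[:, -1:], states)  # later calls feed one token plus caches
```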
