This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Make GPT2Model a HybridBlock #1010

Merged 8 commits on Nov 20, 2019
5 changes: 4 additions & 1 deletion pytest.ini
@@ -12,4 +12,7 @@ env =
     MXNET_HOME=tests/data

 filterwarnings =
-    error
+    error
+    # ignore warning about package resolution using __spec__ or __package__
+    # can't reproduce locally
+    ignore:.*can't resolve package from __spec__ or __package__.*:ImportWarning
11 changes: 8 additions & 3 deletions scripts/tests/test_scripts.py
@@ -135,14 +135,19 @@ def test_sentiment_analysis_textcnn():
                            '--dropout', '0.5', '--model_mode', 'rand', '--data_name', 'MR'])
     time.sleep(5)

-@pytest.mark.skip_master
 @pytest.mark.remote_required
 @pytest.mark.gpu
 @pytest.mark.serial
 @pytest.mark.integration
 @pytest.mark.parametrize('method', ['beam_search', 'sampling'])
-def test_sampling(method):
-    args = ['--bos', 'I love it', '--beam-size', '2', '--print-num', '1', '--gpu', '0']
+@pytest.mark.parametrize('lmmodel', ['awd_lstm_lm_1150', 'gpt2_117m'])
+def test_sampling(method, lmmodel):
+    if 'gpt2' in lmmodel and method == 'beam_search':
+        return  # unsupported
+    args = [
+        '--bos', 'I love it', '--beam-size', '2', '--print-num', '1', '--gpu', '0', '--lm-model',
+        lmmodel
+    ]
     if method == 'beam_search':
         args.insert(0, 'beam-search')
         args.extend(['--k', '50'])
86 changes: 51 additions & 35 deletions scripts/text_generation/model/gpt.py
@@ -22,17 +22,16 @@
 import os

 import mxnet as mx
-from mxnet.gluon import Block, HybridBlock, nn
+from mxnet.gluon import HybridBlock, nn
 from mxnet.gluon.model_zoo import model_store
-import numpy as np

 from gluonnlp.base import get_home_dir
 from gluonnlp.model.attention_cell import DotProductAttentionCell
 from gluonnlp.model.block import GELU
 from gluonnlp.model.utils import _load_pretrained_params, _load_vocab


-class GPT2SelfAttentionLayer(Block):
+class GPT2SelfAttentionLayer(HybridBlock):
     """Self-attention layer used in OpenAI GPT-2.

     Parameters
@@ -88,49 +87,54 @@ def __init__(self, units, num_heads, dropout=0.0,
                                   bias_initializer=bias_initializer,
                                   prefix='out_proj_')

-    def forward(self, data, states=None):  # pylint: disable=arguments-differ
-        batch_size = data.shape[0]
-        seq_len = data.shape[1]
+    def hybrid_forward(self, F, data, states=None):  # pylint: disable=arguments-differ
         # Generate mask
         if states is not None:
             prev_key, prev_value = states
-            prev_len = prev_key.shape[2]
+
+            prev_len_range = F.contrib.arange_like(prev_key, axis=2)
+            data_len_range = F.contrib.arange_like(data, axis=1)
+            prev_len = F.broadcast_add(F.slice_axis(prev_len_range, axis=0, begin=-1, end=None),
+                                       F.ones((1, )))
+
+            data_pos = F.broadcast_add(F.contrib.arange_like(data, axis=1), prev_len)
+            all_pos = F.contrib.arange_like(F.concat(prev_len_range, data_len_range, dim=0))
         else:
             prev_key, prev_value = None, None
-            prev_len = 0
-        data_pos = mx.nd.arange(prev_len, prev_len + seq_len, ctx=data.context, dtype=data.dtype)
-        all_pos = mx.nd.arange(seq_len + prev_len, ctx=data.context, dtype=data.dtype)
-        mask = mx.nd.broadcast_lesser_equal(all_pos.reshape((1, -1)), data_pos.reshape((-1, 1)))
-        mask = mx.nd.broadcast_axes(mx.nd.expand_dims(mask, axis=0), axis=0,
-                                    size=batch_size * self._num_heads)
+            data_pos = F.contrib.arange_like(data, axis=1)
+            all_pos = data_pos
+
+        mask = F.broadcast_lesser_equal(all_pos.reshape((1, -1)), data_pos.reshape((-1, 1)))
+        mask = F.broadcast_like(F.expand_dims(mask, axis=0), data, lhs_axes=(0, ), rhs_axes=(0, ))
+        mask = F.concat(*[mask] * self._num_heads, dim=0)

         # Multi-head attention
         qkv = self._multi_head_qkv_proj(data)  # Shape (batch_size, seq_len, 3 * units)
-        qkv = mx.nd.swapaxes(qkv, 1, 2)  # Shape (batch_size, 3 * units, seq_len)
+        qkv = F.swapaxes(qkv, 1, 2)  # Shape (batch_size, 3 * units, seq_len)

         # Each has shape (batch_size, units, seq_len)
-        query, key, value = mx.nd.split(qkv, num_outputs=3, axis=1)
+        query, key, value = F.split(qkv, num_outputs=3, axis=1)
         # Map each to have shape (batch_size * num_head, ele_units, seq_len)
         query = query.reshape(shape=(0, -4, self._num_heads, -1, 0)).reshape(
             shape=(-1, 0, 0), reverse=True)
         key = key.reshape(shape=(0, -4, self._num_heads, -1, 0)).reshape(
             shape=(-1, 0, 0), reverse=True)
         value = value.reshape(shape=(0, -4, self._num_heads, -1, 0)).reshape(
             shape=(-1, 0, 0), reverse=True)
-        query = mx.nd.swapaxes(query, 1, 2)
-        key = mx.nd.swapaxes(key, 1, 2)
-        value = mx.nd.swapaxes(value, 1, 2)
+        query = F.swapaxes(query, 1, 2)
+        key = F.swapaxes(key, 1, 2)
+        value = F.swapaxes(value, 1, 2)
         if prev_key is not None:
-            key = mx.nd.concat(prev_key.reshape((-1, 0, 0), reverse=True),
-                               key, dim=1)  # Shape (batch_size * num_heads, all_len, ele_units)
+            # Shape (batch_size * num_heads, all_len, ele_units)
+            key = F.concat(prev_key.reshape((-1, 0, 0), reverse=True), key, dim=1)
         if prev_value is not None:
-            value = mx.nd.concat(prev_value.reshape((-1, 0, 0), reverse=True),
-                                 value, dim=1)
+            value = F.concat(prev_value.reshape((-1, 0, 0), reverse=True),
+                             value, dim=1)

         # Shape (batch_size * num_heads, all_len, ele_units)
         out, _ = self._base_attn_cell(query, key, value, mask)
-        out = mx.nd.transpose(out.reshape((-1, self._num_heads, 0, 0), reverse=True),
-                              axes=(0, 2, 1, 3)).reshape((0, 0, -1))
+        out = F.transpose(out.reshape((-1, self._num_heads, 0, 0), reverse=True),
+                          axes=(0, 2, 1, 3)).reshape((0, 0, -1))
         out = self._out_proj(out)
         return out, [key.reshape((-1, self._num_heads, 0, 0), reverse=True),
                      value.reshape((-1, self._num_heads, 0, 0), reverse=True)]
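For intuition, the arange_like-based branch above builds exactly the causal mask the removed shape-based code produced: each new position may attend to every cached position and to new positions up to itself. A small imperative sketch (illustrative only, using mx.nd directly rather than the hybrid-safe F namespace; the lengths are made up):

import mxnet as mx

prev_len, seq_len = 2, 3                               # 2 cached positions, 3 new tokens
data_pos = mx.nd.arange(prev_len, prev_len + seq_len)  # positions of the new tokens: [2. 3. 4.]
all_pos = mx.nd.arange(prev_len + seq_len)             # all positions: [0. 1. 2. 3. 4.]
mask = mx.nd.broadcast_lesser_equal(all_pos.reshape((1, -1)),
                                    data_pos.reshape((-1, 1)))
print(mask)
# [[1. 1. 1. 0. 0.]
#  [1. 1. 1. 1. 0.]
#  [1. 1. 1. 1. 1.]]

The hybrid code then tiles this (seq_len, all_len) matrix over batch_size * num_heads with broadcast_like and concat, because the batch size is not available as a Python integer once the block is hybridized.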
@@ -186,7 +190,7 @@ def hybrid_forward(self, F, data):  # pylint: disable=arguments-differ
         return out


-class GPT2Model(Block):
+class GPT2Model(HybridBlock):
     """Generic Model for GPT-2.

     Parameters
@@ -223,7 +227,7 @@ def __init__(self, units, vocab_size, max_length, num_layers, num_heads, dropout
                                         weight_initializer=mx.init.Normal(0.02))
         self._logits_proj = nn.Dense(units=vocab_size, in_units=units, use_bias=False,
                                      flatten=False, params=self._embed.params)
-        self._self_attention_layers = nn.Sequential()
+        self._self_attention_layers = nn.HybridSequential()
         self._ffn_layers = nn.HybridSequential()
         self._attn_ln = nn.HybridSequential()
         self._ffn_ln = nn.HybridSequential()
@@ -237,8 +241,15 @@ def __init__(self, units, vocab_size, max_length, num_layers, num_heads, dropout
             self._ffn_ln.add(nn.LayerNorm(prefix='ffn_ln{}_'.format(i)))
         self._final_ln = nn.LayerNorm(prefix='final_ln{}_'.format(i))

-    def forward(self, data, states=None):  # pylint: disable=arguments-differ
-        """
+    def hybrid_forward(self, F, data, states=None):  # pylint: disable=arguments-differ
+        """Compute
+
+        Notes
+        -----
+        If you hybridized the GPT2Model by calling net.hybridize(), you cannot
+        switch between states=None, and states=list_of_NDArray between calls to
+        the net. The hybridized model will only support the type of states used
+        during the first call after hybridization.

         Parameters
         ----------
@@ -253,15 +264,20 @@ def forward(self, data, states=None):  # pylint: disable=arguments-differ
         new_states : list of NDArray
         """
         new_states = []
-        batch_size, seq_len = data.shape[0], data.shape[1]
         if states is not None:
-            prev_len = states[0].shape[1]
+            prev_len_range = F.contrib.arange_like(states[0], axis=2).astype('int32')
+            prev_len = F.broadcast_add(F.slice_axis(prev_len_range, axis=0, begin=-1, end=None),
+                                       F.ones((1, ), dtype='int32'))
+            data_pos = F.broadcast_add(
+                F.contrib.arange_like(data, axis=1).astype('int32'), prev_len)
         else:
-            prev_len = 0
-        assert seq_len + prev_len <= self._max_length
-        data_pos = mx.nd.arange(prev_len, prev_len + seq_len, ctx=data.context, dtype=np.float32)
-        data_pos = mx.nd.broadcast_axes(mx.nd.expand_dims(data_pos, axis=0),
-                                        axis=0, size=batch_size)
+            data_pos = F.contrib.arange_like(data, axis=1).astype('int32')
+        if F is mx.nd:
+            length = data.shape[1] + (states[0].shape[2] if states is not None else 0)
+            assert length <= self._max_length
+        # astype cast to workaround https://github.com/apache/incubator-mxnet/issues/16851
+        data_pos = F.broadcast_like(F.expand_dims(data_pos, axis=0), data.astype('int32'),
+                                    lhs_axes=(0, ), rhs_axes=(0, ))
         out = self._embed(data) + self._pos_embed(data_pos)
         for i in range(self._num_layers):
             attn_layer = self._self_attention_layers[i]
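The Notes added to hybrid_forward have a practical consequence that the sequence_sampling.py change below relies on: hybridize only once the call signature has settled. A minimal sketch of incremental decoding with the hybridized model (not part of this PR; the gpt2_117m helper, its import path, and its signature are assumptions based on the constructors that live alongside GPT2Model in this file):

import mxnet as mx
from model.gpt import gpt2_117m  # assumed import path within scripts/text_generation

model, vocab = gpt2_117m(dataset_name='openai_webtext', pretrained=True)

data = mx.nd.array([[464]])   # an arbitrary single prompt token id
logits, states = model(data)  # imperative call, takes the states=None path

model.hybridize()             # every traced call from here on uses the
                              # (data, states) signature, as the Notes require
for _ in range(5):
    next_id = logits[:, -1:, :].argmax(axis=-1)  # greedy pick of the next token
    logits, states = model(next_id, states)      # cached keys/values: only the new
                                                 # token is fed through the network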
5 changes: 4 additions & 1 deletion scripts/text_generation/sequence_sampling.py
@@ -61,7 +61,7 @@
 p.add_argument('--lm-model', type=str, default='awd_lstm_lm_1150',
                help='type of the pre-trained model to load, can be "standard_lstm_lm_200", '
                     '"standard_lstm_lm_650", "standard_lstm_lm_1500", '
-                    '"awd_lstm_lm_1150", etc.')
+                    '"awd_lstm_lm_1150", "gpt2_117m", "gpt2_345m", etc.')
 p.add_argument('--max-length', type=int, default=20, help='Maximum sentence length.')
 p.add_argument('--print-num', type=int, default=3, help='Number of sentences to display.')
 p.add_argument('--bos', type=str, default='I think this works')
@@ -170,6 +170,9 @@ def generate():
                                           scorer=scorer,
                                           max_length=args.max_length - len(bos_tokens))
     inputs, begin_states = get_initial_input_state(decoder, bos_ids)
+
+    sampler._decoder.net.hybridize()  # Hybridize after we obtained the initial states
+
     # samples have shape (1, beam_size, length), scores have shape (1, beam_size)
     samples, scores, valid_lengths = sampler(inputs, begin_states)
     samples = samples[0].asnumpy()
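For reference, the new option can be exercised with an invocation like the one below, mirroring the arguments used in test_scripts.py (the random-sample subcommand name is assumed from the script's existing interface; --beam-size doubles as the number of samples drawn):

python sequence_sampling.py random-sample --bos 'I love it' --beam-size 2 \
    --print-num 1 --lm-model gpt2_117m --gpu 0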