diff --git a/scripts/text_generation/model/gpt.py b/scripts/text_generation/model/gpt.py index a46c936ba6..da7bae0321 100644 --- a/scripts/text_generation/model/gpt.py +++ b/scripts/text_generation/model/gpt.py @@ -22,9 +22,8 @@ import os import mxnet as mx -from mxnet.gluon import Block, HybridBlock, nn +from mxnet.gluon import HybridBlock, nn from mxnet.gluon.model_zoo import model_store -import numpy as np from gluonnlp.base import get_home_dir from gluonnlp.model.attention_cell import DotProductAttentionCell @@ -32,7 +31,7 @@ from gluonnlp.model.utils import _load_pretrained_params, _load_vocab -class GPT2SelfAttentionLayer(Block): +class GPT2SelfAttentionLayer(HybridBlock): """Self-attention layer used in OpenAI GPT-2. Parameters @@ -88,28 +87,33 @@ def __init__(self, units, num_heads, dropout=0.0, bias_initializer=bias_initializer, prefix='out_proj_') - def forward(self, data, states=None): # pylint: disable=arguments-differ - batch_size = data.shape[0] - seq_len = data.shape[1] + def hybrid_forward(self, F, data, states=None): # pylint: disable=arguments-differ # Generate mask if states is not None: prev_key, prev_value = states - prev_len = prev_key.shape[2] + + prev_len_range = F.contrib.arange_like(prev_key, axis=2) + data_len_range = F.contrib.arange_like(data, axis=2) + prev_len = F.broadcast_add(F.slice_axis(prev_len_range, axis=0, begin=-1, end=None), + F.ones((1, ))) + + data_pos = F.broadcast_add(F.contrib.arange_like(data, axis=1), prev_len) + all_pos = F.contrib.arange_like(F.concat(prev_len_range, data_len_range, dim=0)) else: prev_key, prev_value = None, None - prev_len = 0 - data_pos = mx.nd.arange(prev_len, prev_len + seq_len, ctx=data.context, dtype=data.dtype) - all_pos = mx.nd.arange(seq_len + prev_len, ctx=data.context, dtype=data.dtype) - mask = mx.nd.broadcast_lesser_equal(all_pos.reshape((1, -1)), data_pos.reshape((-1, 1))) - mask = mx.nd.broadcast_axes(mx.nd.expand_dims(mask, axis=0), axis=0, - size=batch_size * self._num_heads) + data_pos = F.contrib.arange_like(data, axis=1) + all_pos = data_pos + + mask = F.broadcast_lesser_equal(all_pos.reshape((1, -1)), data_pos.reshape((-1, 1))) + mask = F.broadcast_like(F.expand_dims(mask, axis=0), data, lhs_axes=(0, ), rhs_axes=(0, )) + mask = F.concat(*[mask] * self._num_heads, dim=0) # Multi-head attention qkv = self._multi_head_qkv_proj(data) # Shape (batch_size, seq_len, 3 * units) - qkv = mx.nd.swapaxes(qkv, 1, 2) # Shape (batch_size, 3 * units, seq_len) + qkv = F.swapaxes(qkv, 1, 2) # Shape (batch_size, 3 * units, seq_len) # Each has shape (batch_size, units, seq_len) - query, key, value = mx.nd.split(qkv, num_outputs=3, axis=1) + query, key, value = F.split(qkv, num_outputs=3, axis=1) # Map each to have shape (batch_size * num_head, ele_units, seq_len) query = query.reshape(shape=(0, -4, self._num_heads, -1, 0)).reshape( shape=(-1, 0, 0), reverse=True) @@ -117,20 +121,20 @@ def forward(self, data, states=None): # pylint: disable=arguments-differ shape=(-1, 0, 0), reverse=True) value = value.reshape(shape=(0, -4, self._num_heads, -1, 0)).reshape( shape=(-1, 0, 0), reverse=True) - query = mx.nd.swapaxes(query, 1, 2) - key = mx.nd.swapaxes(key, 1, 2) - value = mx.nd.swapaxes(value, 1, 2) + query = F.swapaxes(query, 1, 2) + key = F.swapaxes(key, 1, 2) + value = F.swapaxes(value, 1, 2) if prev_key is not None: - key = mx.nd.concat(prev_key.reshape((-1, 0, 0), reverse=True), - key, dim=1) # Shape (batch_size * num_heads, all_len, ele_units) + # Shape (batch_size * num_heads, all_len, ele_units) + key = F.concat(prev_key.reshape((-1, 0, 0), reverse=True), key, dim=1) if prev_value is not None: - value = mx.nd.concat(prev_value.reshape((-1, 0, 0), reverse=True), - value, dim=1) + value = F.concat(prev_value.reshape((-1, 0, 0), reverse=True), + value, dim=1) # Shape (batch_size * num_heads, all_len, ele_units) out, _ = self._base_attn_cell(query, key, value, mask) - out = mx.nd.transpose(out.reshape((-1, self._num_heads, 0, 0), reverse=True), - axes=(0, 2, 1, 3)).reshape((0, 0, -1)) + out = F.transpose(out.reshape((-1, self._num_heads, 0, 0), reverse=True), + axes=(0, 2, 1, 3)).reshape((0, 0, -1)) out = self._out_proj(out) return out, [key.reshape((-1, self._num_heads, 0, 0), reverse=True), value.reshape((-1, self._num_heads, 0, 0), reverse=True)] @@ -186,7 +190,7 @@ def hybrid_forward(self, F, data): # pylint: disable=arguments-differ return out -class GPT2Model(Block): +class GPT2Model(HybridBlock): """Generic Model for GPT-2. Parameters @@ -223,7 +227,7 @@ def __init__(self, units, vocab_size, max_length, num_layers, num_heads, dropout weight_initializer=mx.init.Normal(0.02)) self._logits_proj = nn.Dense(units=vocab_size, in_units=units, use_bias=False, flatten=False, params=self._embed.params) - self._self_attention_layers = nn.Sequential() + self._self_attention_layers = nn.HybridSequential() self._ffn_layers = nn.HybridSequential() self._attn_ln = nn.HybridSequential() self._ffn_ln = nn.HybridSequential() @@ -237,7 +241,7 @@ def __init__(self, units, vocab_size, max_length, num_layers, num_heads, dropout self._ffn_ln.add(nn.LayerNorm(prefix='ffn_ln{}_'.format(i))) self._final_ln = nn.LayerNorm(prefix='final_ln{}_'.format(i)) - def forward(self, data, states=None): # pylint: disable=arguments-differ + def hybrid_forward(self, F, data, states=None): # pylint: disable=arguments-differ """ Parameters @@ -253,15 +257,18 @@ def forward(self, data, states=None): # pylint: disable=arguments-differ new_states : list of NDArray """ new_states = [] - batch_size, seq_len = data.shape[0], data.shape[1] if states is not None: - prev_len = states[0].shape[1] + prev_key, _ = states + prev_len_range = F.contrib.arange_like(prev_key, axis=2) + prev_len = F.broadcast_add(F.slice_axis(prev_len_range, axis=0, begin=-1, end=None), + F.ones((1, ))) + data_pos = F.broadcast_add(F.contrib.arange_like(data, axis=1), prev_len) else: - prev_len = 0 - assert seq_len + prev_len <= self._max_length - data_pos = mx.nd.arange(prev_len, prev_len + seq_len, ctx=data.context, dtype=np.float32) - data_pos = mx.nd.broadcast_axes(mx.nd.expand_dims(data_pos, axis=0), - axis=0, size=batch_size) + data_pos = F.contrib.arange_like(data, axis=1) + if F is mx.nd: + assert data.shape[1] + prev_key.shape[2] <= self._max_length + data_pos = F.broadcast_like(F.expand_dims(data_pos, axis=0), data, + lhs_axes=(0, ), rhs_axes=(0, )) out = self._embed(data) + self._pos_embed(data_pos) for i in range(self._num_layers): attn_layer = self._self_attention_layers[i]