
Make GPT2Model a HybridBlock
leezu committed Nov 15, 2019
1 parent e09281c commit 9c3c32e
Showing 1 changed file with 41 additions and 34 deletions.
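Why this matters: a Gluon Block can only run imperatively, whereas a HybridBlock can additionally be traced into a symbolic graph with hybridize(), cutting Python overhead during generation and allowing the graph to be exported. Below is a minimal usage sketch of what the change enables; the gpt2_117m helper, its import path, and the 'openai_webtext' dataset name are assumptions about the surrounding script, not something shown in this diff.

# Hedged sketch: load the 117M GPT-2 model and hybridize it end to end.
import mxnet as mx
from gpt import gpt2_117m  # assumed import path for a helper defined elsewhere in gpt.py

model, vocab = gpt2_117m(dataset_name='openai_webtext', pretrained=True)
model.hybridize()  # only legal once GPT2Model and all of its sub-layers are HybridBlocks

token_ids = mx.nd.array([[0, 1, 2]])  # dummy token ids: batch_size=1, seq_len=3
logits, states = model(token_ids)     # logits: (1, 3, vocab_size); states cache per-layer keys/values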
scripts/text_generation/model/gpt.py (75 changes: 41 additions & 34 deletions)
@@ -22,17 +22,16 @@
import os

import mxnet as mx
-from mxnet.gluon import Block, HybridBlock, nn
+from mxnet.gluon import HybridBlock, nn
from mxnet.gluon.model_zoo import model_store
import numpy as np

from gluonnlp.base import get_home_dir
from gluonnlp.model.attention_cell import DotProductAttentionCell
from gluonnlp.model.block import GELU
from gluonnlp.model.utils import _load_pretrained_params, _load_vocab


-class GPT2SelfAttentionLayer(Block):
+class GPT2SelfAttentionLayer(HybridBlock):
"""Self-attention layer used in OpenAI GPT-2.
Parameters
@@ -88,49 +87,54 @@ def __init__(self, units, num_heads, dropout=0.0,
bias_initializer=bias_initializer,
prefix='out_proj_')

-    def forward(self, data, states=None): # pylint: disable=arguments-differ
-        batch_size = data.shape[0]
-        seq_len = data.shape[1]
+    def hybrid_forward(self, F, data, states=None): # pylint: disable=arguments-differ
# Generate mask
if states is not None:
prev_key, prev_value = states
-            prev_len = prev_key.shape[2]
+            prev_len_range = F.contrib.arange_like(prev_key, axis=2)
+            data_len_range = F.contrib.arange_like(data, axis=2)
+            prev_len = F.broadcast_add(F.slice_axis(prev_len_range, axis=0, begin=-1, end=None),
+                                       F.ones((1, )))
+
+            data_pos = F.broadcast_add(F.contrib.arange_like(data, axis=1), prev_len)
+            all_pos = F.contrib.arange_like(F.concat(prev_len_range, data_len_range, dim=0))
else:
prev_key, prev_value = None, None
-            prev_len = 0
-            data_pos = mx.nd.arange(prev_len, prev_len + seq_len, ctx=data.context, dtype=data.dtype)
-            all_pos = mx.nd.arange(seq_len + prev_len, ctx=data.context, dtype=data.dtype)
-        mask = mx.nd.broadcast_lesser_equal(all_pos.reshape((1, -1)), data_pos.reshape((-1, 1)))
-        mask = mx.nd.broadcast_axes(mx.nd.expand_dims(mask, axis=0), axis=0,
-                                    size=batch_size * self._num_heads)
+            data_pos = F.contrib.arange_like(data, axis=1)
+            all_pos = data_pos
+
+        mask = F.broadcast_lesser_equal(all_pos.reshape((1, -1)), data_pos.reshape((-1, 1)))
+        mask = F.broadcast_like(F.expand_dims(mask, axis=0), data, lhs_axes=(0, ), rhs_axes=(0, ))
+        mask = F.concat(*[mask] * self._num_heads, dim=0)

# Multi-head attention
qkv = self._multi_head_qkv_proj(data) # Shape (batch_size, seq_len, 3 * units)
-        qkv = mx.nd.swapaxes(qkv, 1, 2) # Shape (batch_size, 3 * units, seq_len)
+        qkv = F.swapaxes(qkv, 1, 2) # Shape (batch_size, 3 * units, seq_len)

# Each has shape (batch_size, units, seq_len)
-        query, key, value = mx.nd.split(qkv, num_outputs=3, axis=1)
+        query, key, value = F.split(qkv, num_outputs=3, axis=1)
# Map each to have shape (batch_size * num_head, ele_units, seq_len)
query = query.reshape(shape=(0, -4, self._num_heads, -1, 0)).reshape(
shape=(-1, 0, 0), reverse=True)
key = key.reshape(shape=(0, -4, self._num_heads, -1, 0)).reshape(
shape=(-1, 0, 0), reverse=True)
value = value.reshape(shape=(0, -4, self._num_heads, -1, 0)).reshape(
shape=(-1, 0, 0), reverse=True)
-        query = mx.nd.swapaxes(query, 1, 2)
-        key = mx.nd.swapaxes(key, 1, 2)
-        value = mx.nd.swapaxes(value, 1, 2)
+        query = F.swapaxes(query, 1, 2)
+        key = F.swapaxes(key, 1, 2)
+        value = F.swapaxes(value, 1, 2)
if prev_key is not None:
-            key = mx.nd.concat(prev_key.reshape((-1, 0, 0), reverse=True),
-                               key, dim=1) # Shape (batch_size * num_heads, all_len, ele_units)
+            # Shape (batch_size * num_heads, all_len, ele_units)
+            key = F.concat(prev_key.reshape((-1, 0, 0), reverse=True), key, dim=1)
if prev_value is not None:
-            value = mx.nd.concat(prev_value.reshape((-1, 0, 0), reverse=True),
-                                 value, dim=1)
+            value = F.concat(prev_value.reshape((-1, 0, 0), reverse=True),
+                             value, dim=1)

# Shape (batch_size * num_heads, all_len, ele_units)
out, _ = self._base_attn_cell(query, key, value, mask)
-        out = mx.nd.transpose(out.reshape((-1, self._num_heads, 0, 0), reverse=True),
-                              axes=(0, 2, 1, 3)).reshape((0, 0, -1))
+        out = F.transpose(out.reshape((-1, self._num_heads, 0, 0), reverse=True),
+                          axes=(0, 2, 1, 3)).reshape((0, 0, -1))
out = self._out_proj(out)
return out, [key.reshape((-1, self._num_heads, 0, 0), reverse=True),
value.reshape((-1, self._num_heads, 0, 0), reverse=True)]
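Aside on the new mask construction: the F.contrib.arange_like / F.broadcast_lesser_equal chain above computes the same causal mask the removed mx.nd code derived from .shape, but without ever reading concrete shapes. A small imperative sketch (prev_len=2 and seq_len=3 are arbitrary illustration values):

# Each of the 3 new queries may attend to the 2 cached positions and to itself,
# never to later tokens.
import mxnet as mx

prev_len, seq_len = 2, 3
data_pos = mx.nd.arange(seq_len) + prev_len   # query positions: [2. 3. 4.]
all_pos = mx.nd.arange(prev_len + seq_len)    # key positions:   [0. 1. 2. 3. 4.]
mask = mx.nd.broadcast_lesser_equal(all_pos.reshape((1, -1)),
                                    data_pos.reshape((-1, 1)))
print(mask)
# [[1. 1. 1. 0. 0.]
#  [1. 1. 1. 1. 0.]
#  [1. 1. 1. 1. 1.]]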
@@ -186,7 +190,7 @@ def hybrid_forward(self, F, data): # pylint: disable=arguments-differ
return out


-class GPT2Model(Block):
+class GPT2Model(HybridBlock):
"""Generic Model for GPT-2.
Parameters
@@ -223,7 +227,7 @@ def __init__(self, units, vocab_size, max_length, num_layers, num_heads, dropout
weight_initializer=mx.init.Normal(0.02))
self._logits_proj = nn.Dense(units=vocab_size, in_units=units, use_bias=False,
flatten=False, params=self._embed.params)
-        self._self_attention_layers = nn.Sequential()
+        self._self_attention_layers = nn.HybridSequential()
self._ffn_layers = nn.HybridSequential()
self._attn_ln = nn.HybridSequential()
self._ffn_ln = nn.HybridSequential()
@@ -237,7 +241,7 @@ def __init__(self, units, vocab_size, max_length, num_layers, num_heads, dropout
self._ffn_ln.add(nn.LayerNorm(prefix='ffn_ln{}_'.format(i)))
self._final_ln = nn.LayerNorm(prefix='final_ln{}_'.format(i))

-    def forward(self, data, states=None): # pylint: disable=arguments-differ
+    def hybrid_forward(self, F, data, states=None): # pylint: disable=arguments-differ
"""
Parameters
@@ -253,15 +257,18 @@ def forward(self, data, states=None): # pylint: disable=arguments-differ
new_states : list of NDArray
"""
new_states = []
-        batch_size, seq_len = data.shape[0], data.shape[1]
if states is not None:
-            prev_len = states[0].shape[1]
+            prev_key, _ = states
+            prev_len_range = F.contrib.arange_like(prev_key, axis=2)
+            prev_len = F.broadcast_add(F.slice_axis(prev_len_range, axis=0, begin=-1, end=None),
+                                       F.ones((1, )))
+            data_pos = F.broadcast_add(F.contrib.arange_like(data, axis=1), prev_len)
else:
-            prev_len = 0
-        assert seq_len + prev_len <= self._max_length
-        data_pos = mx.nd.arange(prev_len, prev_len + seq_len, ctx=data.context, dtype=np.float32)
-        data_pos = mx.nd.broadcast_axes(mx.nd.expand_dims(data_pos, axis=0),
-                                        axis=0, size=batch_size)
+            data_pos = F.contrib.arange_like(data, axis=1)
+        if F is mx.nd:
+            assert data.shape[1] + prev_key.shape[2] <= self._max_length
+        data_pos = F.broadcast_like(F.expand_dims(data_pos, axis=0), data,
+                                    lhs_axes=(0, ), rhs_axes=(0, ))
out = self._embed(data) + self._pos_embed(data_pos)
for i in range(self._num_layers):
attn_layer = self._self_attention_layers[i]
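Aside on the length computation: hybridizable code cannot call prev_key.shape, so the cached length is recovered symbolically as the last value of arange_like(prev_key, axis=2) plus one, and the max_length assert is kept only behind the F is mx.nd guard, where shapes are concrete. An imperative sketch of the trick (the (2, 12, 5, 64) cache shape is invented for illustration):

# Recover the cached sequence length without reading .shape.
import mxnet as mx

prev_key = mx.nd.zeros((2, 12, 5, 64))  # (batch, num_heads, prev_len, ele_units)
prev_len_range = mx.nd.contrib.arange_like(prev_key, axis=2)             # [0. 1. 2. 3. 4.]
prev_len = mx.nd.slice_axis(prev_len_range, axis=0, begin=-1, end=None)  # [4.]
prev_len = mx.nd.broadcast_add(prev_len, mx.nd.ones((1,)))               # [5.] == cached length

data = mx.nd.zeros((2, 3, 768))  # three new tokens per sequence
data_pos = mx.nd.broadcast_add(mx.nd.contrib.arange_like(data, axis=1), prev_len)
print(data_pos)  # [5. 6. 7.], the absolute positions fed to the position embedding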
