Skip to content
This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Commit

Permalink
Hybridize
Browse files Browse the repository at this point in the history
  • Loading branch information
leezu committed Nov 18, 2019
1 parent d22ca95 commit 3c929b3
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
9 changes: 8 additions & 1 deletion scripts/text_generation/model/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,14 @@ def __init__(self, units, vocab_size, max_length, num_layers, num_heads, dropout
self._final_ln = nn.LayerNorm(prefix='final_ln{}_'.format(i))

def hybrid_forward(self, F, data, states=None): # pylint: disable=arguments-differ
"""
"""Compute
Notes
-----
If you hybridized the GPT2Model by calling net.hybridize(), you cannot
switch between states=None, and states=list_of_NDArray between calls to
the net. The hybridized model will only support the type of states used
during the first call after hybridization.
Parameters
----------
Expand Down
3 changes: 3 additions & 0 deletions scripts/text_generation/sequence_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ def generate():
scorer=scorer,
max_length=args.max_length - len(bos_tokens))
inputs, begin_states = get_initial_input_state(decoder, bos_ids)

sampler._decoder.net.hybridize() # Hybridize after we obtained the initial states

# samples have shape (1, beam_size, length), scores have shape (1, beam_size)
samples, scores, valid_lengths = sampler(inputs, begin_states)
samples = samples[0].asnumpy()
Expand Down

0 comments on commit 3c929b3

Please sign in to comment.