Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
Parallelize detensorizing context over batch (#3729)
Browse files Browse the repository at this point in the history
* Batch context.tolist()

* Cleanup

* Suggestion
  • Loading branch information
EricMichaelSmith authored Jun 18, 2021
1 parent 76b8804 commit e853e7c
Showing 1 changed file with 31 additions and 1 deletion.
32 changes: 31 additions & 1 deletion parlai/core/torch_generator_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,6 +1005,19 @@ def _get_context(self, batch, batch_idx):
ctxt = batch.full_text_vec[batch_idx]
return ctxt

def _get_batch_context(self, batch):
    """
    Return the n-gram blocking context for every element of the batch at once.

    Batched counterpart of ``_get_context()``: instead of indexing out a single
    batch element, the context for the full batch is returned.

    :param batch:
        batch object; ``batchsize`` and ``text_vec`` are read, plus
        ``full_text_vec`` when ``self.beam_block_full_context`` is truthy

    :return:
        a ``(batchsize, 0)`` long tensor when
        ``self.beam_context_block_ngram <= 0``; otherwise
        ``batch.full_text_vec`` if ``self.beam_block_full_context`` is truthy,
        else ``batch.text_vec``
    """
    if self.beam_context_block_ngram <= 0:
        # Context blocking is off: one zero-length row per batch element.
        num_rows = batch.batchsize
        return torch.zeros(num_rows, 0, dtype=torch.long)

    # text_vec is the default; full_text_vec takes over only when requested.
    default_context = batch.text_vec
    return batch.full_text_vec if self.beam_block_full_context else default_context

def _get_initial_decoder_input(
self, bsz: int, beam_size: int, dev: torch.device
) -> torch.LongTensor:
Expand Down Expand Up @@ -1092,9 +1105,10 @@ def _generate(
bsz = batch.batchsize
if batch.text_vec is not None:
batchsize = batch.batchsize
batch_context_list = self._get_batch_context(batch).tolist()
beams = [
self._treesearch_factory(dev)
.set_context(self._get_context(batch, batch_idx))
.set_batch_context(batch_context_list, batch_idx)
.set_block_list(self.beam_block_list)
for batch_idx in range(batchsize)
]
Expand Down Expand Up @@ -1284,6 +1298,22 @@ def set_context(self: TSType, context: torch.LongTensor) -> TSType:
self.context = context.tolist()
return self

def set_batch_context(
    self: TSType, batch_context_list: List[List[int]], batch_idx: int
) -> TSType:
    """
    Store the context of one batch element on this object and return self.

    Counterpart of ``.set_context()`` that takes the context of the whole
    batch as nested lists and picks out a single element by index.

    :param batch_context_list:
        a list of lists, one per member of the batch, each holding that
        member's context
    :param batch_idx:
        index into ``batch_context_list`` of the element to use

    :return:
        ``self``, so that calls can be chained
    """
    element_context = batch_context_list[batch_idx]
    self.context = element_context
    return self

def set_block_list(self: TSType, block_list: Optional[SearchBlocklist]) -> TSType:
    """
    Store the given block list on this object and return self.

    :param block_list:
        value to assign to ``self.block_list``; may be ``None``

    :return:
        ``self``, so that calls can be chained
    """
    self.block_list = block_list
    return self
Expand Down

0 comments on commit e853e7c

Please sign in to comment.