Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Parallelize detensorizing context over batch #3729

Merged
merged 3 commits into from
Jun 18, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion parlai/core/torch_generator_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,6 +1005,19 @@ def _get_context(self, batch, batch_idx):
ctxt = batch.full_text_vec[batch_idx]
return ctxt

def _get_batch_context(self, batch):
"""
Version of TGA._get_context() that operates on full batches for speed.
"""
if self.beam_context_block_ngram <= 0:
# We aren't context blocking, return empty tensor of the correct size
return torch.zeros(batch.batchsize, 0, dtype=torch.long)

ctxt = batch.text_vec
if self.beam_block_full_context:
ctxt = batch.full_text_vec
return ctxt

def _get_initial_decoder_input(
self, bsz: int, beam_size: int, dev: torch.device
) -> torch.LongTensor:
Expand Down Expand Up @@ -1092,9 +1105,10 @@ def _generate(
bsz = batch.batchsize
if batch.text_vec is not None:
batchsize = batch.batchsize
batch_context_list = self._get_batch_context(batch).tolist()
beams = [
self._treesearch_factory(dev)
.set_context(self._get_context(batch, batch_idx))
.set_batch_context(batch_context_list, batch_idx)
.set_block_list(self.beam_block_list)
for batch_idx in range(batchsize)
]
Expand Down Expand Up @@ -1284,6 +1298,22 @@ def set_context(self: TSType, context: torch.LongTensor) -> TSType:
self.context = context.tolist()
return self

def set_batch_context(
self: TSType, batch_context_list: List[List[int]], batch_idx: int
) -> TSType:
"""
Version of .set_context() that operates on a single element of a batch.

Set the internal context representation and return self.

:param batch_context_list:
a list of lists, each one containing the context for one member of the batch
:param batch_idx:
index of the batch
"""
self.context = batch_context_list[batch_idx]
return self

def set_block_list(self: TSType, block_list: Optional[SearchBlocklist]) -> TSType:
self.block_list = block_list
return self
Expand Down