This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

[torchscript] use device type when concatenating tensors (#4554)
Co-authored-by: Maryam Daneshi <[email protected]>
2 people authored and kushalarora committed Jun 15, 2022
1 parent 147a954 commit 0d2e906
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions parlai/torchscript/modules.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from collections import defaultdict
-from typing import List, Dict, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import torch.jit
 from parlai.agents.bart.bart import BartAgent
@@ -218,9 +218,13 @@ def forward(self, context: str, max_len: int = 128) -> str:
         if self.is_bart:
             flattened_text_vec = torch.cat(
                 [
-                    torch.tensor([self.start_idx], dtype=torch.long),
-                    flattened_text_vec,
-                    torch.tensor([self.end_idx], dtype=torch.long),
+                    torch.tensor([self.start_idx], dtype=torch.long).to(
+                        self.get_device()
+                    ),
+                    flattened_text_vec.to(self.get_device()),
+                    torch.tensor([self.end_idx], dtype=torch.long).to(
+                        self.get_device()
+                    ),
                 ],
                 dim=0,
             )
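
Why this change is needed: torch.cat requires every input tensor to be on the same device. In the scripted BART path, the start/end marker tensors are created on the fly and default to CPU, while flattened_text_vec may already live on the GPU, so the concatenation can fail at runtime with a device-mismatch error. The fix moves each piece onto the module's device (via self.get_device()) before concatenating. The snippet below is a minimal, standalone sketch of the failure mode and the fix; the variable names and token ids are illustrative, not taken from ParlAI.

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

start_idx, end_idx = 0, 2                  # hypothetical special-token ids
text_vec = torch.arange(5, device=device)  # stands in for flattened_text_vec

# Before the fix (raises a device-mismatch RuntimeError when text_vec is on CUDA):
#   torch.cat([torch.tensor([start_idx]), text_vec, torch.tensor([end_idx])], dim=0)

# After the fix, every tensor is moved to the same device first:
wrapped = torch.cat(
    [
        torch.tensor([start_idx], dtype=torch.long).to(device),
        text_vec.to(device),
        torch.tensor([end_idx], dtype=torch.long).to(device),
    ],
    dim=0,
)
print(wrapped)  # tensor([0, 0, 1, 2, 3, 4, 2]) on the chosen device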

