-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset.py
64 lines (52 loc) · 1.9 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from pytorch_lightning import LightningDataModule
class GECDataset(Dataset):
    """Map-style dataset that turns per-example sequences into tensors.

    Each item is ``(input_ids, attention_mask, labels)`` when labels were
    supplied, or ``(input_ids, attention_mask)`` when ``labels`` is ``None``
    (e.g. for unlabelled / inference data).

    Args:
        input_ids: Indexable collection of per-example values accepted by
            ``torch.tensor`` (presumably token-id sequences — confirm with
            the tokenization step).
        attention_mask: Indexable collection aligned index-for-index with
            ``input_ids``.
        labels: Optional indexable collection aligned with ``input_ids``;
            each item is converted to a ``torch.long`` tensor.
    """

    def __init__(self, input_ids, attention_mask, labels=None):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.labels = labels

    def __len__(self):
        # Size by the inputs, which are always present; ``labels`` may be
        # None, and len(None) would raise TypeError.
        return len(self.input_ids)

    def __getitem__(self, idx):
        input_ids = torch.tensor(self.input_ids[idx])
        attention_mask = torch.tensor(self.attention_mask[idx])
        if self.labels is None:
            return input_ids, attention_mask
        labels = torch.tensor(self.labels[idx], dtype=torch.long)
        return input_ids, attention_mask, labels
class GECDataModule(LightningDataModule):
    """Lightning data module serving pre-tokenized train/validation frames.

    Args:
        train_df: Training table exposing ``input_ids``, ``attention_mask``
            and ``label`` columns via attribute access with ``.values``
            (presumably a pandas DataFrame — confirm with the caller).
        val_df: Validation table with the same columns as ``train_df``.
        config: Object indexed by key for ``"train_batch_size"``,
            ``"val_batch_size"`` and ``"num_workers"``.
    """

    def __init__(self, train_df, val_df, config):
        super().__init__()
        self.train_df = train_df
        self.val_df = val_df
        self.config = config

    @staticmethod
    def _build_dataset(df):
        """Wrap one table's input/mask/label columns in a ``GECDataset``."""
        return GECDataset(
            df.input_ids.values,
            df.attention_mask.values,
            df.label.values,
        )

    def setup(self, stage=None):
        # ``stage`` is accepted for the Lightning hook signature; both
        # datasets are built regardless of it.
        self.train_dataset = self._build_dataset(self.train_df)
        self.val_dataset = self._build_dataset(self.val_df)

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.config["train_batch_size"],
            num_workers=self.config["num_workers"],
            shuffle=True,
            pin_memory=False,
        )

    def val_dataloader(self):
        # Keep validation order fixed: shuffling gains nothing for
        # evaluation, makes per-batch logs non-reproducible, and Lightning
        # warns when a val loader shuffles.
        return DataLoader(
            self.val_dataset,
            batch_size=self.config["val_batch_size"],
            num_workers=self.config["num_workers"],
            shuffle=False,
            pin_memory=False,
        )