-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathdataset.py
28 lines (19 loc) · 801 Bytes
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
class SquadDataset(torch.utils.data.Dataset):
    """Map-style dataset over a mapping of column name -> per-example values.

    ``encodings`` must expose ``.items()`` and contain an ``"input_ids"``
    entry; the number of examples is taken from the length of that entry.
    Values may be indexable sequences (e.g. nested lists) or tensors.
    """

    def __init__(self, encodings):
        """Keep a reference to ``encodings``; nothing is copied up front."""
        self.encodings = encodings

    def __getitem__(self, idx):
        """Return ``{key: tensor}`` for the example(s) selected by ``idx``.

        ``idx`` can be an int or a slice, e.g. ``1:100``. Each returned
        tensor is a copy, so mutating it does not affect ``self.encodings``.
        """
        item = {}
        for key, val in self.encodings.items():
            selected = val[idx]
            if isinstance(selected, torch.Tensor):
                # torch.tensor(<Tensor>) emits a UserWarning on every call;
                # detach().clone() is the documented equivalent copy.
                item[key] = selected.detach().clone()
            else:
                item[key] = torch.tensor(selected)
        return item

    def __len__(self):
        """Return the number of examples (length of the ``input_ids`` entry)."""
        # Key lookup rather than attribute access so plain dicts work as well
        # as mapping types that additionally expose keys as attributes.
        return len(self.encodings["input_ids"])
if __name__ == '__main__':
    # Smoke test: build encodings from a small random sample, wrap the
    # training split in SquadDataset, and iterate it once through a DataLoader.
    from preprocess import SquadPreprocessor
    from torch.utils.data import DataLoader

    preprocessor = SquadPreprocessor()
    sampled_encodings = preprocessor.get_encodings(
        random_sample_train=0.001,
        random_sample_val=0.1,
        return_tensors="pt",
    )
    # The validation encodings are produced but not used by this check.
    train_encodings, val_encodings = sampled_encodings

    train_loader = DataLoader(SquadDataset(train_encodings), batch_size=64)
    for batch in train_loader:
        # SquadDataset items are dicts, so this prints the size of the
        # collated batch container rather than the number of examples in it.
        print(len(batch))