data.py
from pathlib import Path
import random

import torch
from torch.utils.data import ConcatDataset, Dataset, Subset

import util

# SNLI record keys, excluding the parse-tree fields.
snli_keys_no_parses = [
    'captionID', 'pairID',
    'annotator_labels', 'gold_label',
    'sentence1', 'sentence2'
]
snli_labels = ['entailment', 'contradiction', 'neutral']


class SNLIDataset(Dataset):
    def __init__(self, path, label_key, exclude_no_gold=True, exclude_ids=None):
        self.path = Path(path)
        self.label_key = label_key
        self.items = util.load_jsonl(self.path)
        if label_key and exclude_no_gold:
            # SNLI marks items without annotator consensus with the label '-'.
            self.items = [item for item in self.items if item[label_key] != '-']
        if exclude_ids:
            self.items = [item for item in self.items if item['pairID'] not in exclude_ids]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        raw_item = self.items[idx]
        item = {k: raw_item[k] for k in ['pairID', 'sentence1', 'sentence2']}
        item['label'] = raw_item[self.label_key] if self.label_key else None
        return item
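# Usage sketch (the file path is hypothetical; assumes util.load_jsonl returns
# a list of dicts, one per JSON line):
#
#   train_set = SNLIDataset('snli_1.0/snli_1.0_train.jsonl', label_key='gold_label')
#   print(len(train_set), train_set[0])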
def add_roberta_se_fields(batch):
    """
    Add batch fields needed by the RoBERTa Self-Explaining model.
    Replicates this:
    https://github.com/ShannonAI/Self_Explaining_Structures_Improve_NLP_Models/blob/master/datasets/collate_functions.py#L19
    """
    lengths = batch['attention_mask'].sum(dim=1)
    max_sentence_length = lengths.max()
    device = max_sentence_length.device
    # Enumerate every candidate span (i, j) with 1 <= i <= j, skipping
    # position 0 (<s>) and the final position (</s>).
    start_indexs = []
    end_indexs = []
    for i in range(1, max_sentence_length - 1):
        for j in range(i, max_sentence_length - 1):
            # # limit spans to at most 10 tokens
            # if j - i > 10:
            #     continue
            start_indexs.append(i)
            end_indexs.append(j)
    # Generate the span mask: 0 keeps a span, 1e6 acts as a large additive
    # penalty that effectively removes it. Token id 2 is RoBERTa's </s>,
    # used to locate the boundary between sentence1 and sentence2, so only
    # spans lying entirely within one sentence (and within the unpadded
    # length) survive.
    span_masks = []
    for input_ids, length in zip(batch['input_ids'], lengths):
        span_mask = []
        middle_index = input_ids.tolist().index(2)
        for start_index, end_index in zip(start_indexs, end_indexs):
            if 1 <= start_index <= length.item() - 2 and 1 <= end_index <= length.item() - 2 and (
                    start_index > middle_index or end_index < middle_index):
                span_mask.append(0)
            else:
                span_mask.append(1e6)
        span_masks.append(span_mask)
    # Add to output.
    batch['start_indexs'] = torch.LongTensor(start_indexs).to(device)
    batch['end_indexs'] = torch.LongTensor(end_indexs).to(device)
    batch['span_masks'] = torch.LongTensor(span_masks).to(device)
    return batch  # (input_ids, labels, length, start_indexs, end_indexs, span_masks)
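# Note: for a batch padded to length L, the double loop above enumerates
# (L - 2) * (L - 1) / 2 spans, so 'span_masks' has shape (batch_size, num_spans).
# A quick sanity check:
#
#   num_spans = len(batch['start_indexs'])
#   assert batch['span_masks'].shape == (batch['input_ids'].shape[0], num_spans)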
def create_collate_fn(tokenizer, label_stoi, device,
                      roberta_se=False, hypothesis_only=False):
    def collate_fn(batch):
        item_ids = [item['pairID'] for item in batch]
        if hypothesis_only:
            inputs = [item['sentence2'] for item in batch]
        else:
            inputs = [(item['sentence1'], item['sentence2']) for item in batch]
        sentence_pair_tokenized = tokenizer(inputs, padding=True, return_tensors='pt').to(device)
        if roberta_se:
            sentence_pair_tokenized = add_roberta_se_fields(sentence_pair_tokenized)
        # Unlabeled items (label None) are mapped to -1.
        label_idxs = torch.LongTensor([label_stoi[item['label']] if item['label'] else -1 for item in batch]).to(device)
        return item_ids, sentence_pair_tokenized, label_idxs
    return collate_fn
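# Sketch of wiring the collate function into a DataLoader (the tokenizer is
# assumed to be a Hugging Face tokenizer, e.g. RobertaTokenizerFast; the label
# mapping built from snli_labels above is an assumption):
#
#   from torch.utils.data import DataLoader
#   label_stoi = {label: i for i, label in enumerate(snli_labels)}
#   collate_fn = create_collate_fn(tokenizer, label_stoi, device='cpu')
#   loader = DataLoader(train_set, batch_size=32, shuffle=True, collate_fn=collate_fn)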
def sample_and_supplement(dataset_a, dataset_b, sample_n, replace=False):
    """
    Samples `sample_n` items from `dataset_b` and adds them to `dataset_a`.
    If `replace` is True, `sample_n` items are also removed from `dataset_a`,
    so the combined dataset keeps the original size of `dataset_a`.
    """
    b_ids = random.sample(range(len(dataset_b)), sample_n)
    if replace:
        a_ids = random.sample(range(len(dataset_a)), len(dataset_a) - sample_n)
        assert len(b_ids) + len(a_ids) == len(dataset_a)
        dataset_a = Subset(dataset_a, a_ids)
    dataset_b = Subset(dataset_b, b_ids)
    return ConcatDataset([dataset_a, dataset_b])
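# Usage sketch (dataset names are hypothetical): replace 1000 items of
# train_set with 1000 sampled from an augmentation set, keeping the combined
# size equal to len(train_set):
#
#   mixed = sample_and_supplement(train_set, aug_set, sample_n=1000, replace=True)
#   assert len(mixed) == len(train_set)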