-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfind_prompt_examples.py
More file actions
74 lines (65 loc) · 2.81 KB
/
find_prompt_examples.py
File metadata and controls
74 lines (65 loc) · 2.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from pathlib import Path
import json
import random
import util
from data import snli_labels, snli_keys_no_parses
def find_prompt_examples(data):
    """Return caption groups whose unanimous hypotheses cover every SNLI label.

    An item qualifies as "perfect" when it has at least five annotator labels
    and every annotator label equals its ``gold_label``. Perfect items are
    grouped by ``captionID`` with ``util.group_records_by``; a group is kept
    only if the set of its gold labels equals ``set(snli_labels)``, so that one
    example per label can later be drawn from it.

    A one-line summary of how many items were fully annotated / perfect is
    printed to stdout.

    Args:
        data: Sized collection of SNLI records (``len`` is taken) with at
            least the keys ``annotator_labels``, ``gold_label`` and
            ``captionID``.

    Returns:
        A list of the qualifying groups, each in whatever form
        ``util.group_records_by`` yields (iterable of records).
    """
    def is_fully_annotated(item):
        # The example was validated by at least 5 annotators.
        return len(item['annotator_labels']) >= 5

    def is_unanimous(item):
        # All the annotators agreed with the gold label.
        return all(label == item['gold_label'] for label in item['annotator_labels'])

    fully_annotated = [item for item in data if is_fully_annotated(item)]
    perfect_items = [item for item in fully_annotated if is_unanimous(item)]
    print(f"There are {len(fully_annotated)} fully annotated and "
          f"{len(perfect_items)} perfect items out of {len(data)}.")

    required_labels = set(snli_labels)
    grouped = util.group_records_by(perfect_items, 'captionID')
    # Keep only captions that have a perfect hypothesis for every label; the
    # previous debug breakpoint (IPython embed + bare `raise`) that made this
    # function unable to return has been removed.
    return [
        group for group in grouped
        if {item['gold_label'] for item in group} == required_labels
    ]
def sample_examples(grouped_examples, n):
    """Sample ``n`` groups and pick one random hypothesis per SNLI label from each.

    Args:
        grouped_examples: Population of groups (e.g. from
            ``find_prompt_examples``). Each group is an indexable sequence of
            records with keys ``captionID``, ``sentence1``, ``gold_label``,
            ``pairID`` and ``sentence2``; the first record supplies the
            ``captionID`` and ``sentence1`` of the result.
        n: Number of groups to draw, without replacement.

    Returns:
        A list of ``n`` dicts of the form
        ``{'captionID': ..., 'sentence1': ..., <label>: {'pairID': ...,
        'sentence2': ...}, ...}`` with one entry per label in ``snli_labels``.

    Raises:
        ValueError: from ``random.sample`` if ``n`` exceeds the population.
        IndexError: from ``random.choice`` if a group has no record for
            some label.
    """
    chosen_groups = random.sample(grouped_examples, n)

    def _build(group):
        # Premise-level fields come from the group's first record.
        head = group[0]
        entry = {'captionID': head['captionID'], 'sentence1': head['sentence1']}
        for label in snli_labels:
            candidates = [row for row in group if row['gold_label'] == label]
            pick = random.choice(candidates)
            entry[label] = {'pairID': pick['pairID'], 'sentence2': pick['sentence2']}
        return entry

    return [_build(group) for group in chosen_groups]
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('command', type=str, choices=['perfect-examples', 'random-examples'])
args = parser.parse_args()
random.seed(42)
match args.command:
case 'perfect-examples':
data_dir = Path('./data/snli_1.0')
train_data_path = data_dir/'snli_1.0_train.jsonl'
example_path = Path('./prompts/perfect-snli-examples.json')
data = util.load_jsonl(train_data_path, keys=snli_keys_no_parses)
grouped_examples = find_prompt_examples(data)
examples = sample_examples(grouped_examples, 10)
if not example_path.exists():
with example_path.open('w') as f:
json.dump(examples, f)
case 'random-examples':
data_dir = Path('./data/snli_1.0')
train_data_path = data_dir/'snli_1.0_train.jsonl'
example_path = Path('./prompts/random-snli-examples.json')
data = util.load_jsonl(train_data_path, keys=snli_keys_no_parses)
examples = random.sample(data, 10)
if not example_path.exists():
with example_path.open('w') as f:
for item in examples:
util.write_jsonl(item, f)