-
Notifications
You must be signed in to change notification settings - Fork 13
/
generate.py
90 lines (68 loc) · 2.61 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import fire
import os
from tqdm import tqdm
import pickle
from models.bart import BART
from models.gpt2 import GPT2
def get_model_list(dataset, prog_vocabs, first_model):
    """Load one trained model per progressive-generation stage.

    Stage ``k`` maps ``prog_vocabs[k]`` to ``prog_vocabs[k + 1]``, so
    ``len(prog_vocabs) - 1`` models are loaded (none when fewer than two
    vocab names are given).

    Args:
        dataset: Dataset name, used in the checkpoint directory names.
        prog_vocabs: Ordered sequence of stage vocab names.
        first_model: Name of the first-stage model. When it contains
            ``'gpt2'`` the first stage is a ``GPT2`` built with
            ``gpt2_type=first_model``; every other stage is a ``BART``.

    Returns:
        List of loaded models, in stage order.
    """
    model_list = []

    # Walk adjacent (source, target) vocab pairs instead of indexing by hand.
    stage_pairs = zip(prog_vocabs, prog_vocabs[1:])
    for stage, (src_vocab, dst_vocab) in enumerate(stage_pairs):
        if stage == 0 and 'gpt2' in first_model:
            model = GPT2(gpt2_type=first_model)
            # For the first stage dst_vocab is prog_vocabs[1].
            checkpoint = (f'training_logs/{first_model}_{dataset}_'
                          f'{dst_vocab}words/best_model.pt')
        else:
            model = BART()
            checkpoint = (f'training_logs/bart_{dataset}_{src_vocab}-'
                          f'{dst_vocab}/best_model.pt')

        model.load_model(checkpoint)
        model_list.append(model)

    return model_list
def generate(model, cond, top_k, top_p, max_tries=None):
    """Sample from ``model`` until the output contains a letter.

    Outputs without any alphabetic character are discarded and the model
    is sampled again with the same arguments.

    Args:
        model: Object exposing ``generate(cond=..., top_k=..., top_p=...)``.
        cond: Conditioning text forwarded to the model.
        top_k: Top-k sampling setting forwarded to the model.
        top_p: Top-p sampling setting forwarded to the model.
        max_tries: Optional cap on the number of sampling attempts. The
            default ``None`` retries without limit.

    Returns:
        The first generated text containing at least one alphabetic
        character.

    Raises:
        RuntimeError: If ``max_tries`` is set and every attempt produced
            text without an alphabetic character.
    """
    tries = 0
    while max_tries is None or tries < max_tries:
        gen_text = model.generate(cond=cond, top_k=top_k, top_p=top_p)
        # any() stops at the first letter instead of building a full list.
        if any(map(str.isalpha, gen_text)):
            return gen_text
        tries += 1

    raise RuntimeError(
        f'no alphabetic text generated after {max_tries} attempts')
def main(dataset,
         prog_steps,
         first_model,
         top_k=-1,
         top_p=0.95):
    """Run progressive generation on the test split and save the outputs.

    For every test example the first-stage model generates from the
    condition alone; each later stage generates from
    ``condition + ' [SEP] ' + <previous stage output>``. Results are
    written under
    ``generated_texts/{dataset}_first-{first_model}_{prog_steps}/{decoding}``
    as a readable ``gen.txt`` log and a ``gen.pickle`` list of dicts.

    Args:
        dataset: Dataset name; test examples are read from
            ``data/{dataset}/test.pickle``.
        prog_steps: Dash-separated stage names that must start with
            ``null`` and end with ``full``.
        first_model: Name of the first-stage model, forwarded to
            ``get_model_list`` and used in the output directory name.
        top_k: Top-k sampling setting; appears in the output directory
            name only when positive.
        top_p: Top-p sampling setting; appears in the output directory
            name only when positive.

    Raises:
        AssertionError: If ``prog_steps`` does not start with ``null``
            and end with ``full``.
    """
    prog_vocabs = prog_steps.split('-')
    # Raised explicitly (rather than via `assert`) so the check still runs
    # under `python -O`, while keeping the original exception type.
    if not (prog_vocabs[0] == 'null' and prog_vocabs[-1] == 'full'):
        raise AssertionError(
            f"prog_steps must start with 'null' and end with 'full', "
            f"got {prog_steps!r}")

    model_list = get_model_list(dataset, prog_vocabs, first_model)

    decoding = 'top_'
    if top_k > 0:
        decoding += f'k{top_k}'
    if top_p > 0:
        decoding += f'p{top_p}'

    # NOTE: pickle.load executes arbitrary code from the file; only point
    # this at dataset files you trust.
    with open(f'data/{dataset}/test.pickle', 'rb') as test_file:
        test_examples = pickle.load(test_file)

    gen_dir = f'generated_texts/{dataset}_first-{first_model}_{prog_steps}/' \
              f'{decoding}'
    os.makedirs(gen_dir, exist_ok=True)

    gens = []
    # Explicit UTF-8 so non-ASCII generations cannot fail on platforms whose
    # default encoding is narrower; `with` guarantees the log is closed.
    with open(f'{gen_dir}/gen.txt', 'w', encoding='utf-8') as log_file:
        for example in tqdm(test_examples, desc='Generating'):
            condition, truth = example['condition'], example['text']

            # Stage 0 sees only the condition; each later stage also sees
            # the previous stage's output.
            prog_gens = [generate(
                model=model_list[0], cond=condition, top_k=top_k, top_p=top_p)]
            for model in model_list[1:]:
                prog_gens.append(generate(
                    model=model,
                    cond=condition + ' [SEP] ' + prog_gens[-1],
                    top_k=top_k, top_p=top_p))

            gens.append({
                'condition': condition,
                'truth': truth,
                'prog_gens': prog_gens,
                'top_k': top_k,
                'top_p': top_p
            })

            print(f'CONDITION:\n{condition}\n', '-' * 50, '\n\n',
                  f'TRUTH:\n{truth}\n', '=' * 100, '\n\n', file=log_file)
            for step, text in enumerate(prog_gens):
                print(f'STEP_{step}:\n{text}\n', '-' * 50, '\n\n',
                      file=log_file)
            print('=' * 50, file=log_file)
            # Flush per example so the log can be followed during long runs.
            log_file.flush()

    with open(f'{gen_dir}/gen.pickle', 'wb') as out_file:
        pickle.dump(gens, out_file)
# Script entry point: hand main() to fire.Fire, which builds the command-line
# interface from it.
if __name__ == '__main__':
    fire.Fire(main)