-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathgenerate_proteins.py
71 lines (49 loc) · 1.96 KB
/
generate_proteins.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from collections import Counter
import numpy as np
import sys
import torch
from tqdm import tqdm
from nano_transformer import PLNanoTransformer
from pytorch_lightning import seed_everything
######## CLI arguments #########
# Fail early with a usage message instead of an opaque IndexError when the
# script is invoked without the required arguments.
if len(sys.argv) < 3:
    sys.exit(f"usage: {sys.argv[0]} <n_prots> <checkpoint_file>")
n_prots = int(sys.argv[1])  # number of proteins to generate
checkpoint_file = sys.argv[2]  # checkpoint file to load
PROT_FNAME = "data/prot_seqs.txt"  # training sequences; only needed to rebuild the vocab
seed_everything(1337)  # fix all RNG seeds so generation is reproducible
######## build the vocab #########
# TODO: factor this out of train/generate files
with open(PROT_FNAME, "r") as f:
    lines = f.readlines()
pad = "!"  # padding character (assumed absent from the sequence data — TODO confirm)
# All characters seen in the corpus. Note this includes "\n" (kept by
# readlines), which doubles as the termination token during generation.
flat_text = [c for line in lines for c in line]
chars = sorted(set(flat_text)) + [pad]
vocab_size = len(chars)
stoi = {c: i for i, c in enumerate(chars)}  # char -> token id
itos = {i: c for i, c in enumerate(chars)}  # token id -> char
def decode(i):
    """Map an iterable of token ids back to the corresponding string."""
    return "".join(itos[token_id] for token_id in i)
# To sample a realistic first amino acid, estimate the empirical
# distribution of characters appearing at position 0 in the training data.
nr_times_first = Counter(line[0] for line in lines)
n_lines = len(lines)
aa_to_proba_first = {aa: count / n_lines for aa, count in nr_times_first.items()}
######## load the model (on CPU) #########
pl_model = PLNanoTransformer.load_from_checkpoint(checkpoint_file)
pl_model.eval()  # inference mode (disables dropout etc.)
######## generate proteins #########
def generate_protein_string():
    """Sample one protein sequence from the model and return it as a string.

    The first character is drawn from the empirical first-position
    distribution of the training data; the model then extends the sequence
    autoregressively until it emits the newline termination token, which is
    stripped from the returned string.
    """
    amino_acids = list(aa_to_proba_first.keys())
    probabilities = list(aa_to_proba_first.values())
    start_char = np.random.choice(amino_acids, p=probabilities)
    # Shape (1, 1): a batch of one, holding the single starting token.
    context = torch.tensor([[stoi[start_char]]], dtype=torch.long)
    token_ids = pl_model.model.generate_line(
        idx=context,
        termination_token_idx=stoi["\n"],
        pad_token_idx=stoi["!"],
    ).tolist()
    # NOTE(review): assumes generate_line returns a 1-D tensor ending in the
    # termination token — verify against its implementation.
    return decode(token_ids[:-1])  # drop the trailing newline token
# Sample n_prots sequences (with a progress bar) and write them out,
# one protein per line.
generated_seqs = [generate_protein_string() for _ in tqdm(range(n_prots))]
with open("generated_proteins.txt", "w") as f:
    f.write("\n".join(generated_seqs))