import numpy as np  # used to clamp the LoRA start-layer index
import torch
import torch.nn as nn
import torch.nn.functional as F
import esm
import esm_adapterH
from peft import PeftModel, LoraConfig, get_peft_model


class ESM2(nn.Module):  # ESM-2 encoder wrapper; the embedding table stays fixed unless configs.tune_ESM_table is set
    def __init__(self, configs):
        super(ESM2, self).__init__()
        esm2_dict = {}
        if configs.adapter_h.enable:
            # Load the adapter-augmented ESM-2 variant.
            adapter_args = configs.adapter_h
            if configs.encoder_name == "esm2_t36_3B_UR50D":
                esm2_dict["esm2_t36_3B_UR50D"] = esm_adapterH.pretrained.esm2_t36_3B_UR50D(adapter_args)
            elif configs.encoder_name == "esm2_t33_650M_UR50D":
                esm2_dict["esm2_t33_650M_UR50D"] = esm_adapterH.pretrained.esm2_t33_650M_UR50D(adapter_args)
            elif configs.encoder_name == "esm2_t30_150M_UR50D":
                esm2_dict["esm2_t30_150M_UR50D"] = esm_adapterH.pretrained.esm2_t30_150M_UR50D(adapter_args)
            elif configs.encoder_name == "esm2_t12_35M_UR50D":
                esm2_dict["esm2_t12_35M_UR50D"] = esm_adapterH.pretrained.esm2_t12_35M_UR50D(adapter_args)
            elif configs.encoder_name == "esm2_t6_8M_UR50D":
                esm2_dict["esm2_t6_8M_UR50D"] = esm_adapterH.pretrained.esm2_t6_8M_UR50D(adapter_args)
            else:
                raise ValueError(f"Unknown encoder name: {configs.encoder_name}")
        else:
            # Load the vanilla fair-esm checkpoint.
            if configs.encoder_name == "esm2_t36_3B_UR50D":
                esm2_dict["esm2_t36_3B_UR50D"] = esm.pretrained.esm2_t36_3B_UR50D()
            elif configs.encoder_name == "esm2_t33_650M_UR50D":
                esm2_dict["esm2_t33_650M_UR50D"] = esm.pretrained.esm2_t33_650M_UR50D()
            elif configs.encoder_name == "esm2_t30_150M_UR50D":
                esm2_dict["esm2_t30_150M_UR50D"] = esm.pretrained.esm2_t30_150M_UR50D()
            elif configs.encoder_name == "esm2_t12_35M_UR50D":
                esm2_dict["esm2_t12_35M_UR50D"] = esm.pretrained.esm2_t12_35M_UR50D()
            elif configs.encoder_name == "esm2_t6_8M_UR50D":
                esm2_dict["esm2_t6_8M_UR50D"] = esm.pretrained.esm2_t6_8M_UR50D()
            else:
                raise ValueError(f"Unknown encoder name: {configs.encoder_name}")

        self.esm2, self.alphabet = esm2_dict[configs.encoder_name]
        self.num_layers = self.esm2.num_layers

        # Freeze the whole backbone by default; selected parameters are
        # re-enabled below depending on the chosen tuning strategy.
        for p in self.esm2.parameters():
            p.requires_grad = False

        if configs.adapter_h.enable:
            # Only the inserted adapter layers are trainable.
            for name, param in self.esm2.named_parameters():
                if "adapter_layer" in name:
                    param.requires_grad = True

        if configs.lora.enable:
            # Attach LoRA to the attention projections of the last
            # configs.lora.esm_num_end_lora transformer layers.
            lora_targets = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj", "self_attn.out_proj"]
            target_modules = []
            if configs.lora.esm_num_end_lora > 0:
                start_layer_idx = np.max([self.num_layers - configs.lora.esm_num_end_lora, 0])
                for idx in range(start_layer_idx, self.num_layers):
                    for layer_name in lora_targets:
                        target_modules.append(f"layers.{idx}.{layer_name}")
            peft_config = LoraConfig(
                inference_mode=False,
                r=configs.lora.r,
                lora_alpha=configs.lora.alpha,
                target_modules=target_modules,
                lora_dropout=configs.lora.dropout,
                bias="none",
            )
            self.peft_model = get_peft_model(self.esm2, peft_config)
        elif configs.fine_tuning.enable:
            # unfix_last_layer: the number of trailing transformer layers to fine-tune.
            unfix_last_layer = configs.fine_tuning.unfix_last_layer
            fix_layer_num = self.num_layers - unfix_last_layer
            fix_layer_index = 0
            for layer in self.esm2.layers:  # only fine-tune transformer layers, not the contact head or other parameters
                if fix_layer_index < fix_layer_num:
                    fix_layer_index += 1  # keep these layers frozen
                    continue
                for p in layer.parameters():
                    p.requires_grad = True
            if unfix_last_layer != 0:
                # When the last layer is fine-tuned, emb_layer_norm_after
                # (applied to the final representation) must also be updated.
                for p in self.esm2.emb_layer_norm_after.parameters():
                    p.requires_grad = True

        if configs.tune_ESM_table:
            # Optionally unfreeze the token embedding table as well.
            for p in self.esm2.embed_tokens.parameters():
                p.requires_grad = True

    def forward(self, x):
        outputs = self.esm2(x, repr_layers=[self.num_layers], return_contacts=False)
        residue_feature = outputs['representations'][self.num_layers]
        return residue_feature
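

# Note: ESM2.forward returns per-residue embeddings of shape
# (batch_size, num_tokens, embedding_dim), where embedding_dim depends on the
# checkpoint (see prepare_models below, e.g. 1280 for esm2_t33_650M_UR50D).
# LayerNormNet then projects these embeddings to out_dim.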


class LayerNormNet(nn.Module):  # projection head: two hidden layers with LayerNorm, dropout and ReLU
    def __init__(self, embedding_dim, hidden_dim, out_dim, drop_out):
        super(LayerNormNet, self).__init__()
        self.hidden_dim1 = hidden_dim
        self.out_dim = out_dim
        self.drop_out = drop_out

        self.fc1 = nn.Linear(embedding_dim, hidden_dim)
        self.ln1 = nn.LayerNorm(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.ln2 = nn.LayerNorm(hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, out_dim)
        self.dropout = nn.Dropout(p=drop_out)

    def forward(self, x):
        x = self.dropout(self.ln1(self.fc1(x)))
        x = torch.relu(x)
        x = self.dropout(self.ln2(self.fc2(x)))
        x = torch.relu(x)
        x = self.fc3(x)
        return x


def prepare_models(configs, log_path):
    # Use ESM-2 as the sequence encoder.
    encoder = ESM2(configs)
    # Per-residue embedding width of each ESM-2 checkpoint.
    if configs.encoder_name == "esm2_t36_3B_UR50D":
        embedding_dim = 2560
    elif configs.encoder_name == "esm2_t33_650M_UR50D":
        embedding_dim = 1280
    elif configs.encoder_name == "esm2_t30_150M_UR50D":
        embedding_dim = 640
    elif configs.encoder_name == "esm2_t12_35M_UR50D":
        embedding_dim = 480
    elif configs.encoder_name == "esm2_t6_8M_UR50D":
        embedding_dim = 320
    else:
        raise ValueError(f"Unknown encoder name: {configs.encoder_name}")
    if configs.projection_head_name == "LayerNorm":
        projection_head = LayerNormNet(embedding_dim=embedding_dim, hidden_dim=configs.hidden_dim,
                                       out_dim=configs.out_dim, drop_out=configs.drop_out)
    else:
        raise ValueError(f"Unknown projection head name: {configs.projection_head_name}")
    return encoder, projection_head


if __name__ == '__main__':
    print('test')
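    # Minimal, hedged usage sketch (not part of the original project):
    # smoke-test the projection head alone with illustrative dimensions.
    # A real run would instead build a config object exposing the attributes
    # read above (encoder_name, adapter_h, lora, fine_tuning, tune_ESM_table,
    # projection_head_name, hidden_dim, out_dim, drop_out) and call
    # prepare_models(configs, log_path), which downloads ESM-2 weights.
    head = LayerNormNet(embedding_dim=1280, hidden_dim=512, out_dim=128, drop_out=0.1)
    dummy_embeddings = torch.randn(2, 10, 1280)  # (batch, tokens, embedding_dim)
    print(head(dummy_embeddings).shape)  # expected: torch.Size([2, 10, 128])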