# P2DFlow/models/flow_model.py
"""Neural network architecture for the flow model."""
import torch
from torch import nn
from models.node_embedder import NodeEmbedder
from models.edge_embedder import EdgeEmbedder
from models.add_module.structure_module import (
    Energy_Adapter_Node,
    E3_GNNlayer,
    E3_transformer_test,
    TensorProductConvLayer,
)
from models.add_module.egnn import EGNN
from models.add_module.model_utils import get_sub_graph
from models import ipa_pytorch
from data import utils as du
from data import all_atom
import e3nn


class FlowModel(nn.Module):
    def __init__(self, model_conf):
        super().__init__()
self._model_conf = model_conf
self._ipa_conf = model_conf.ipa
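        # The trunk operates on coordinates in nanometers while the data is in
        # Angstroms; these helpers rescale the translation part of the rigid
        # frames in each direction (see the conversions around the main trunk).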
self.rigids_ang_to_nm = lambda x: x.apply_trans_fn(lambda x: x * du.ANG_TO_NM_SCALE)
self.rigids_nm_to_ang = lambda x: x.apply_trans_fn(lambda x: x * du.NM_TO_ANG_SCALE)
self.node_embedder = NodeEmbedder(model_conf.node_features)
self.edge_embedder = EdgeEmbedder(model_conf.edge_features)
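        # Optional components, toggled from the model config:
        #   use_torsions         - side-chain torsion prediction head
        #   use_adapter_node     - condition node features on a scalar energy
        #   use_mid_bb_update    - backbone frame update after every trunk block
        #   use_mid_bb_update_e3 / use_e3_transformer - E(3)-equivariant variants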
self.use_torsions = model_conf.use_torsions
self.use_adapter_node = model_conf.use_adapter_node
self.use_mid_bb_update = model_conf.use_mid_bb_update
self.use_mid_bb_update_e3 = model_conf.use_mid_bb_update_e3
self.use_e3_transformer = model_conf.use_e3_transformer
if self.use_adapter_node:
self.energy_adapter = Energy_Adapter_Node(d_node=model_conf.node_embed_size, n_head=model_conf.ipa.no_heads, p_drop=model_conf.dropout)
if self.use_torsions:
self.num_torsions = 7
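            # Each torsion angle is predicted as an unnormalized (sin, cos)
            # pair; the pairs are normalized onto the unit circle in forward().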
self.torsions_pred_layer1 = nn.Sequential(
nn.Linear(self._ipa_conf.c_s, self._ipa_conf.c_s),
nn.ReLU(),
nn.Linear(self._ipa_conf.c_s, self._ipa_conf.c_s),
)
self.torsions_pred_layer2 = nn.Linear(self._ipa_conf.c_s, self.num_torsions * 2)
if self.use_e3_transformer or self.use_mid_bb_update_e3:
self.max_dist = self._model_conf.edge_features.max_dist
            self.irreps_sh = e3nn.o3.Irreps.spherical_harmonics(2)
            input_d_l1 = 3
            e3_d_node_l1 = 32
            irreps_l1_in = e3nn.o3.Irreps(f"{input_d_l1}x1o")
            irreps_l1_out = e3nn.o3.Irreps(f"{e3_d_node_l1}x1o")
            self.proj_l1_init = e3nn.o3.Linear(irreps_l1_in, irreps_l1_out)
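            # Type-1 (vector) features: three vectors per residue (N-CA, the CA
            # position itself, and C-CA) are projected to e3_d_node_l1
            # equivariant vector channels.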
# Attention trunk
self.trunk = nn.ModuleDict()
for b in range(self._ipa_conf.num_blocks):
if self.use_e3_transformer:
                self.trunk[f'ipa_{b}'] = E3_transformer_test(self._ipa_conf, e3_d_node_l1=e3_d_node_l1)
else:
self.trunk[f'ipa_{b}'] = ipa_pytorch.InvariantPointAttention(self._ipa_conf)
self.trunk[f'ipa_ln_{b}'] = nn.LayerNorm(self._ipa_conf.c_s)
if self.use_adapter_node:
self.trunk[f'energy_adapter_{b}'] = Energy_Adapter_Node(d_node=model_conf.node_embed_size, n_head=model_conf.ipa.no_heads, p_drop=model_conf.dropout)
            self.trunk[f'egnn_{b}'] = EGNN(hidden_dim=model_conf.node_embed_size)
if self.use_mid_bb_update:
if self.use_mid_bb_update_e3:
                    self.trunk[f'bb_update_{b}'] = E3_GNNlayer(
                        e3_d_node_l0=4 * e3_d_node_l1, e3_d_node_l1=e3_d_node_l1,
                        final=True, init='final', dropout=0.0)
else:
self.trunk[f'bb_update_{b}'] = ipa_pytorch.BackboneUpdate(
self._ipa_conf.c_s, use_rot_updates=True, dropout=0.0)
if self.use_mid_bb_update_e3:
            self.bb_update_layer = E3_GNNlayer(
                e3_d_node_l0=4 * e3_d_node_l1, e3_d_node_l1=e3_d_node_l1,
                final=True, init='final', dropout=0.0)
elif self.use_e3_transformer:
            irreps_1 = e3nn.o3.Irreps(f"{self._ipa_conf.c_s}x0e + {e3_d_node_l1}x1o")
            irreps_2 = e3nn.o3.Irreps(f"{self._ipa_conf.c_s}x0e + {e3_d_node_l1}x1o")
            irreps_3 = e3nn.o3.Irreps("6x0e")
            self.bb_tpc = TensorProductConvLayer(irreps_1, irreps_2, irreps_3, self._ipa_conf.c_s,
                                                 fc_factor=1, init='final', dropout=0.1)
else:
self.bb_update_layer = ipa_pytorch.BackboneUpdate(
self._ipa_conf.c_s, use_rot_updates=True, dropout=0.0)
    def forward(self, input_feats, use_mask_aatype=False):
        """Predict denoised backbone frames and torsion angles at flow time t.

        Note: the batch size B and sequence length L vary between training
        batches. `use_mask_aatype` is currently unused.

        Expected keys in input_feats:
            'aatype'        (B, L)
            'res_mask'      (B, L)
            't'             (B, 1)
            'trans_1'       (B, L, 3)
            'rotmats_1'     (B, L, 3, 3)
            'trans_t'       (B, L, 3)
            'rotmats_t'     (B, L, 3, 3)
            'node_repr_pre', 'pair_repr_pre'  precomputed sequence representations
            'energy'        (B,)       required when use_adapter_node is set
            'trans_sc'      (B, L, 3)  optional self-conditioning translations
        """
node_mask = input_feats['res_mask']
edge_mask = node_mask[:, None] * node_mask[:, :, None]
continuous_t = input_feats['t']
trans_t = input_feats['trans_t']
rotmats_t = input_feats['rotmats_t']
aatype = input_feats['aatype']
if self.use_adapter_node:
energy = input_feats['energy'] # (B,)
if self.use_e3_transformer or self.use_mid_bb_update_e3:
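            # The first three atom37 slots are the backbone N, CA, C atoms.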
xyz_t = all_atom.to_atom37(trans_t, rotmats_t)[:, :, :3, :] # (B,L,3,3)
            edge_src, edge_dst, pair_index = get_sub_graph(xyz_t[:, :, 1, :], mask=node_mask,
                                                           dist_max=self.max_dist, kmin=32)
xyz_graph = xyz_t[:,:,1,:].reshape(-1,3)
edge_vec = xyz_graph[edge_src] - xyz_graph[edge_dst]
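            # Edge attributes: spherical harmonics (l <= 2) of the normalized
            # CA-CA displacement for every edge in the radius/kNN graph.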
edge_sh = e3nn.o3.spherical_harmonics(self.irreps_sh, edge_vec,
normalize=True, normalization='component') # (num_edges, irreps_sh)
node_repr_pre = input_feats['node_repr_pre']
pair_repr_pre = input_feats['pair_repr_pre']
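        # Self-conditioning translations from the previous prediction, if any.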
if 'trans_sc' not in input_feats:
trans_sc = torch.zeros_like(trans_t)
else:
trans_sc = input_feats['trans_sc']
        # Initialize node and edge embeddings: nodes carry sequence-position and
        # timestep information, edges carry pairwise-distance information.
init_node_embed = self.node_embedder(continuous_t, aatype, node_repr_pre, node_mask)
init_node_embed = init_node_embed * node_mask[..., None]
if self.use_adapter_node:
            init_node_embed = self.energy_adapter(init_node_embed, energy, mask=node_mask)
init_node_embed = init_node_embed * node_mask[..., None]
init_edge_embed = self.edge_embedder(
init_node_embed, trans_t, trans_sc, pair_repr_pre, edge_mask)
init_edge_embed = init_edge_embed * edge_mask[..., None]
        if self.use_e3_transformer or self.use_mid_bb_update_e3:
            l1_feats = torch.concat([xyz_t[:, :, 0, :] - xyz_t[:, :, 1, :],
                                     xyz_t[:, :, 1, :],
                                     xyz_t[:, :, 2, :] - xyz_t[:, :, 1, :]], dim=-1)  # (B, L, 3*3)
            l1_feats = self.proj_l1_init(l1_feats)  # (B, L, e3_d_node_l1*3)
            l1_feats = l1_feats * node_mask[..., None]
# Initial rigids
        curr_rigids = du.create_rigid(rotmats_t, trans_t)
# Main trunk
curr_rigids = self.rigids_ang_to_nm(curr_rigids)
node_embed = init_node_embed
edge_embed = init_edge_embed
for b in range(self._ipa_conf.num_blocks):
if self.use_e3_transformer:
ipa_embed, l1_feats = self.trunk[f'ipa_{b}'](node_embed,
edge_embed,
l1_feats,
curr_rigids,
node_mask) # (B,L,d_node)
l1_feats = l1_feats * node_mask[..., None]
else:
ipa_embed = self.trunk[f'ipa_{b}'](node_embed,
edge_embed,
curr_rigids,
node_mask) # (B,L,d_node)
ipa_embed = node_embed + ipa_embed
ipa_embed = ipa_embed * node_mask[..., None]
node_embed = self.trunk[f'ipa_ln_{b}'](ipa_embed)
if self.use_adapter_node:
                node_embed = self.trunk[f'energy_adapter_{b}'](node_embed, energy, mask=node_mask)
node_embed = node_embed * node_mask[..., None]
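            # The EGNN refines node features from the input CA translations; the
            # updated coordinates it returns are discarded.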
            _, node_embed = self.trunk[f'egnn_{b}'](node_embed, trans_t)
node_embed = node_embed * node_mask[..., None]
if self.use_mid_bb_update:
if self.use_mid_bb_update_e3:
                    rigid_update = self.trunk[f'bb_update_{b}'](node_embed, edge_embed, l1_feats, pair_index,
                                                                edge_src, edge_dst, edge_sh, mask=node_mask)
elif self.use_e3_transformer:
                    rigid_update = self.trunk[f'bb_update_{b}'](torch.concat([node_embed, l1_feats], dim=-1))
else:
rigid_update = self.trunk[f'bb_update_{b}'](node_embed)
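                # Fold the predicted (rotation, translation) update into the
                # current rigid frames, masked to valid residues.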
curr_rigids = curr_rigids.compose_q_update_vec(
rigid_update, node_mask[..., None])
if self.use_mid_bb_update_e3:
            rigid_update = self.bb_update_layer(node_embed, edge_embed, l1_feats, pair_index,
                                                edge_src, edge_dst, edge_sh, mask=node_mask)
elif self.use_e3_transformer:
            node_feats_total = torch.concat([node_embed, l1_feats], dim=-1)
rigid_update = self.bb_tpc(node_feats_total, node_feats_total, node_embed)
else:
rigid_update = self.bb_update_layer(node_embed)
curr_rigids = curr_rigids.compose_q_update_vec(
rigid_update, node_mask[..., None])
curr_rigids = self.rigids_nm_to_ang(curr_rigids)
pred_trans = curr_rigids.get_trans()
pred_rotmats = curr_rigids.get_rots().get_rot_mats()
        if self.use_torsions:
            pred_torsions = node_embed + self.torsions_pred_layer1(node_embed)
            pred_torsions = self.torsions_pred_layer2(pred_torsions).reshape(
                input_feats['aatype'].shape + (self.num_torsions, 2))  # (B, L, num_torsions, 2)
            # Normalize each (sin, cos) pair onto the unit circle; the small
            # epsilon guards against division by zero for near-zero outputs.
            norm_torsions = torch.sqrt(
                torch.sum(pred_torsions ** 2, dim=-1, keepdim=True) + 1e-8)  # (B, L, num_torsions, 1)
            pred_torsions = pred_torsions / norm_torsions
            # Pad from num_torsions to 8 angle slots with the zero-angle vector
            # (sin=0, cos=1); the name suggests the extra leading slot handles
            # CB placement downstream.
            add_rot = pred_torsions.new_zeros(
                (1,) * (pred_torsions.dim() - 2) + (8 - self.num_torsions, 2))
            add_rot[..., 1] = 1
            add_rot = add_rot.expand(*pred_torsions.shape[:-2], -1, -1)  # (B, L, 8-num_torsions, 2)
            pred_torsions_with_CB = torch.concat([add_rot, pred_torsions], dim=-2)  # (B, L, 8, 2)
        out = {
            'pred_trans': pred_trans,
            'pred_rotmats': pred_rotmats,
        }
        if self.use_torsions:
            out['pred_torsions_with_CB'] = pred_torsions_with_CB
        return out
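
# Example usage (a sketch; assumes `model_conf` is the model section of the
# project config with the fields accessed above, and `batch` holds the keys
# documented in FlowModel.forward):
#
#   model = FlowModel(model_conf)
#   out = model(batch)
#   out['pred_trans']             # (B, L, 3)
#   out['pred_rotmats']           # (B, L, 3, 3)
#   out['pred_torsions_with_CB']  # (B, L, 8, 2), when use_torsions is enabled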