|
|
"""Neural network architecture for the flow model.""" |
|
|
import torch |
|
|
from torch import nn |
|
|
from models.node_embedder import NodeEmbedder |
|
|
from models.edge_embedder import EdgeEmbedder |
|
|
from models.add_module.structure_module import Node_update, Pair_update, E3_transformer, TensorProductConvLayer, E3_GNNlayer, E3_transformer_no_adjacency, E3_transformer_test, Energy_Adapter_Node |
|
|
from models.add_module.egnn import EGNN |
|
|
from models.add_module.model_utils import get_sub_graph |
|
|
from models import ipa_pytorch |
|
|
from data import utils as du |
|
|
from data import all_atom |
|
|
import e3nn |
|
|
|
|
|
class FlowModel(nn.Module):
    """Main denoising network for the SE(3) flow model.

    Embeds per-residue and per-pair features, then runs a stack of
    attention blocks (IPA or an E3-equivariant transformer variant),
    updating backbone rigid frames along the way. Optionally predicts
    side-chain torsion angles and conditions on an energy signal via
    adapter layers.

    Configuration flags (all read from ``model_conf``):
      use_torsions          -- also predict 7 torsion angles per residue.
      use_adapter_node      -- inject an energy-conditioned adapter before
                               and inside every trunk block.
      use_mid_bb_update     -- apply a backbone rigid update after every
                               trunk block (not only at the end).
      use_mid_bb_update_e3  -- use an E3 GNN layer for backbone updates.
      use_e3_transformer    -- replace IPA with the E3 transformer and
                               carry type-1 (vector) features through
                               the trunk.
    """

    def __init__(self, model_conf):
        """Build all sub-modules from the model configuration.

        Args:
            model_conf: configuration object with ``ipa``, ``node_features``,
                ``edge_features``, ``node_embed_size``, ``dropout`` and the
                boolean feature flags listed in the class docstring.
        """
        super().__init__()

        self._model_conf = model_conf
        self._ipa_conf = model_conf.ipa
        # Rigid-frame translation rescaling between angstroms and nanometers;
        # the trunk operates in nm, inputs/outputs are in angstroms.
        self.rigids_ang_to_nm = lambda r: r.apply_trans_fn(lambda t: t * du.ANG_TO_NM_SCALE)
        self.rigids_nm_to_ang = lambda r: r.apply_trans_fn(lambda t: t * du.NM_TO_ANG_SCALE)
        self.node_embedder = NodeEmbedder(model_conf.node_features)
        self.edge_embedder = EdgeEmbedder(model_conf.edge_features)

        self.use_torsions = model_conf.use_torsions
        self.use_adapter_node = model_conf.use_adapter_node
        self.use_mid_bb_update = model_conf.use_mid_bb_update
        self.use_mid_bb_update_e3 = model_conf.use_mid_bb_update_e3
        self.use_e3_transformer = model_conf.use_e3_transformer

        if self.use_adapter_node:
            # Adapter applied to the initial node embedding (per-block
            # adapters are created in the trunk loop below).
            self.energy_adapter = Energy_Adapter_Node(
                d_node=model_conf.node_embed_size,
                n_head=model_conf.ipa.no_heads,
                p_drop=model_conf.dropout)

        if self.use_torsions:
            # 7 torsions per residue, each predicted as a (sin, cos) pair.
            self.num_torsions = 7
            self.torsions_pred_layer1 = nn.Sequential(
                nn.Linear(self._ipa_conf.c_s, self._ipa_conf.c_s),
                nn.ReLU(),
                nn.Linear(self._ipa_conf.c_s, self._ipa_conf.c_s),
            )
            self.torsions_pred_layer2 = nn.Linear(
                self._ipa_conf.c_s, self.num_torsions * 2)

        if self.use_e3_transformer or self.use_mid_bb_update_e3:
            self.max_dist = self._model_conf.edge_features.max_dist
            # Spherical harmonics up to l=2 for edge features.
            self.irreps_sh = e3nn.o3.Irreps.spherical_harmonics(2)
            # Three input vectors per residue (N-CA, CA, C-CA; see forward)
            # projected to e3_d_node_l1 type-1 (odd-parity vector) channels.
            input_d_l1 = 3
            e3_d_node_l1 = 32
            irreps_l1_in = e3nn.o3.Irreps(f"{input_d_l1}x1o")
            irreps_l1_out = e3nn.o3.Irreps(f"{e3_d_node_l1}x1o")
            self.proj_l1_init = e3nn.o3.Linear(irreps_l1_in, irreps_l1_out)

        # Trunk: num_blocks repetitions of attention -> layer norm ->
        # (optional energy adapter) -> EGNN -> (optional backbone update).
        self.trunk = nn.ModuleDict()
        for b in range(self._ipa_conf.num_blocks):
            if self.use_e3_transformer:
                self.trunk[f'ipa_{b}'] = E3_transformer_test(
                    self._ipa_conf, e3_d_node_l1=e3_d_node_l1)
            else:
                self.trunk[f'ipa_{b}'] = ipa_pytorch.InvariantPointAttention(
                    self._ipa_conf)
            self.trunk[f'ipa_ln_{b}'] = nn.LayerNorm(self._ipa_conf.c_s)

            if self.use_adapter_node:
                self.trunk[f'energy_adapter_{b}'] = Energy_Adapter_Node(
                    d_node=model_conf.node_embed_size,
                    n_head=model_conf.ipa.no_heads,
                    p_drop=model_conf.dropout)

            self.trunk[f'egnn_{b}'] = EGNN(hidden_dim=model_conf.node_embed_size)

            if self.use_mid_bb_update:
                if self.use_mid_bb_update_e3:
                    self.trunk[f'bb_update_{b}'] = E3_GNNlayer(
                        e3_d_node_l0=4 * e3_d_node_l1,
                        e3_d_node_l1=e3_d_node_l1, final=True,
                        init='final', dropout=0.0)
                else:
                    # NOTE(review): when use_e3_transformer is also set, the
                    # forward pass feeds this layer concat([node_embed,
                    # l1_feats]) (width c_s + 3*e3_d_node_l1) but it is built
                    # with input width c_s — confirm this flag combination is
                    # never used, or widen the layer.
                    self.trunk[f'bb_update_{b}'] = ipa_pytorch.BackboneUpdate(
                        self._ipa_conf.c_s, use_rot_updates=True, dropout=0.0)

        # Final backbone update applied once after the trunk.
        if self.use_mid_bb_update_e3:
            self.bb_update_layer = E3_GNNlayer(
                e3_d_node_l0=4 * e3_d_node_l1,
                e3_d_node_l1=e3_d_node_l1, final=True,
                init='final', dropout=0.0)
        elif self.use_e3_transformer:
            # Tensor-product layer consuming scalar + vector node features
            # and emitting a 6-dim scalar rigid update.
            irreps_1 = e3nn.o3.Irreps(
                f"{self._ipa_conf.c_s}x0e + {e3_d_node_l1}x1o")
            irreps_2 = e3nn.o3.Irreps(
                f"{self._ipa_conf.c_s}x0e + {e3_d_node_l1}x1o")
            irreps_3 = e3nn.o3.Irreps("6x0e")
            self.bb_tpc = TensorProductConvLayer(
                irreps_1, irreps_2, irreps_3, self._ipa_conf.c_s,
                fc_factor=1, init='final', dropout=0.1)
        else:
            self.bb_update_layer = ipa_pytorch.BackboneUpdate(
                self._ipa_conf.c_s, use_rot_updates=True, dropout=0.0)

    def forward(self, input_feats, use_mask_aatype=False):
        '''Run one denoising pass and predict clean backbone frames.

        note: B and L are changing during training
        input_feats.keys():
            'aatype' (B,L)
            'res_mask' (B,L)
            't' (B,1)
            'trans_1' (B,L,3)
            'rotmats_1' (B,L,3,3)
            'trans_t' (B,L,3)
            'rotmats_t' (B,L,3,3)
        Also required: 'node_repr_pre', 'pair_repr_pre' (precomputed
        representations), 'energy' when use_adapter_node, and optionally
        'trans_sc' (self-conditioning translations; zeros when absent).

        Args:
            input_feats: feature dict described above.
            use_mask_aatype: unused here; kept for caller compatibility.

        Returns:
            dict with 'pred_trans' (B,L,3), 'pred_rotmats' (B,L,3,3) and,
            only when use_torsions is enabled, 'pred_torsions_with_CB'
            (B,L,8,2).
        '''
        node_mask = input_feats['res_mask']
        edge_mask = node_mask[:, None] * node_mask[:, :, None]
        continuous_t = input_feats['t']
        trans_t = input_feats['trans_t']
        rotmats_t = input_feats['rotmats_t']
        aatype = input_feats['aatype']

        if self.use_adapter_node:
            energy = input_feats['energy']

        if self.use_e3_transformer or self.use_mid_bb_update_e3:
            # Backbone N/CA/C coordinates; index 1 is CA — TODO confirm
            # atom ordering against all_atom.to_atom37.
            xyz_t = all_atom.to_atom37(trans_t, rotmats_t)[:, :, :3, :]
            # Radius/kNN graph over CA positions (flattened across batch).
            edge_src, edge_dst, pair_index = get_sub_graph(
                xyz_t[:, :, 1, :], mask=node_mask,
                dist_max=self.max_dist, kmin=32)
            xyz_graph = xyz_t[:, :, 1, :].reshape(-1, 3)
            edge_vec = xyz_graph[edge_src] - xyz_graph[edge_dst]
            edge_sh = e3nn.o3.spherical_harmonics(
                self.irreps_sh, edge_vec,
                normalize=True, normalization='component')

        node_repr_pre = input_feats['node_repr_pre']
        pair_repr_pre = input_feats['pair_repr_pre']

        # Self-conditioning translations default to zeros when not provided.
        if 'trans_sc' not in input_feats:
            trans_sc = torch.zeros_like(trans_t)
        else:
            trans_sc = input_feats['trans_sc']

        init_node_embed = self.node_embedder(
            continuous_t, aatype, node_repr_pre, node_mask)
        init_node_embed = init_node_embed * node_mask[..., None]

        if self.use_adapter_node:
            init_node_embed = self.energy_adapter(
                init_node_embed, energy, mask=node_mask)
            init_node_embed = init_node_embed * node_mask[..., None]

        init_edge_embed = self.edge_embedder(
            init_node_embed, trans_t, trans_sc, pair_repr_pre, edge_mask)
        init_edge_embed = init_edge_embed * edge_mask[..., None]

        if self.use_e3_transformer or self.use_mid_bb_update_e3:
            # Initial type-1 features: N-CA and C-CA offsets plus the CA
            # position itself, projected to the trunk's vector channels.
            l1_feats = torch.cat(
                [xyz_t[:, :, 0, :] - xyz_t[:, :, 1, :],
                 xyz_t[:, :, 1, :],
                 xyz_t[:, :, 2, :] - xyz_t[:, :, 1, :]], dim=-1)
            l1_feats = self.proj_l1_init(l1_feats)
            l1_feats = l1_feats * node_mask[..., None]

        # Work in nanometers inside the trunk.
        curr_rigids = du.create_rigid(rotmats_t, trans_t)
        curr_rigids = self.rigids_ang_to_nm(curr_rigids)

        node_embed = init_node_embed
        edge_embed = init_edge_embed
        for b in range(self._ipa_conf.num_blocks):
            if self.use_e3_transformer:
                ipa_embed, l1_feats = self.trunk[f'ipa_{b}'](
                    node_embed, edge_embed, l1_feats, curr_rigids, node_mask)
                l1_feats = l1_feats * node_mask[..., None]
            else:
                ipa_embed = self.trunk[f'ipa_{b}'](
                    node_embed, edge_embed, curr_rigids, node_mask)
            # Residual connection around the attention block.
            ipa_embed = node_embed + ipa_embed
            ipa_embed = ipa_embed * node_mask[..., None]
            node_embed = self.trunk[f'ipa_ln_{b}'](ipa_embed)

            if self.use_adapter_node:
                node_embed = self.trunk[f'energy_adapter_{b}'](
                    node_embed, energy, mask=node_mask)
                node_embed = node_embed * node_mask[..., None]

            # EGNN refinement; coordinate output is discarded. Note this
            # always uses the *input* translations trans_t, not the current
            # rigid frames — presumably intentional; verify.
            __, node_embed = self.trunk[f'egnn_{b}'](node_embed, trans_t)
            node_embed = node_embed * node_mask[..., None]

            if self.use_mid_bb_update:
                if self.use_mid_bb_update_e3:
                    rigid_update = self.trunk[f'bb_update_{b}'](
                        node_embed, edge_embed, l1_feats, pair_index,
                        edge_src, edge_dst, edge_sh, mask=node_mask)
                elif self.use_e3_transformer:
                    rigid_update = self.trunk[f'bb_update_{b}'](
                        torch.cat([node_embed, l1_feats], dim=-1))
                else:
                    rigid_update = self.trunk[f'bb_update_{b}'](node_embed)
                curr_rigids = curr_rigids.compose_q_update_vec(
                    rigid_update, node_mask[..., None])

        # Final backbone update after the trunk.
        if self.use_mid_bb_update_e3:
            rigid_update = self.bb_update_layer(
                node_embed, edge_embed, l1_feats, pair_index,
                edge_src, edge_dst, edge_sh, mask=node_mask)
        elif self.use_e3_transformer:
            node_feats_total = torch.cat([node_embed, l1_feats], dim=-1)
            rigid_update = self.bb_tpc(
                node_feats_total, node_feats_total, node_embed)
        else:
            rigid_update = self.bb_update_layer(node_embed)

        curr_rigids = curr_rigids.compose_q_update_vec(
            rigid_update, node_mask[..., None])

        # Back to angstroms for the outputs.
        curr_rigids = self.rigids_nm_to_ang(curr_rigids)
        pred_trans = curr_rigids.get_trans()
        pred_rotmats = curr_rigids.get_rots().get_rot_mats()

        out = {
            'pred_trans': pred_trans,
            'pred_rotmats': pred_rotmats,
        }

        # BUGFIX: previously 'pred_torsions_with_CB' was referenced in the
        # return dict unconditionally, raising NameError whenever
        # use_torsions was disabled. The key is now emitted only when the
        # torsions are actually predicted.
        if self.use_torsions:
            # Residual MLP head predicting (sin, cos) per torsion.
            pred_torsions = node_embed + self.torsions_pred_layer1(node_embed)
            pred_torsions = self.torsions_pred_layer2(pred_torsions).reshape(
                input_feats['aatype'].shape + (self.num_torsions, 2))

            # Normalize each (sin, cos) pair onto the unit circle.
            # NOTE(review): a zero-norm prediction would divide by zero here;
            # consider clamping if this ever occurs in practice.
            norm_torsions = torch.sqrt(
                torch.sum(pred_torsions ** 2, dim=-1, keepdim=True))
            pred_torsions = pred_torsions / norm_torsions

            # Prepend identity rotations (angle 0: sin=0, cos=1) for the
            # 8 - num_torsions leading slots (e.g. the CB-placement angle).
            add_rot = pred_torsions.new_zeros(
                (1,) * len(pred_torsions.shape[:-2])
                + (8 - self.num_torsions, 2))
            add_rot[..., 1] = 1
            add_rot = add_rot.expand(*pred_torsions.shape[:-2], -1, -1)
            pred_torsions_with_CB = torch.cat([add_rot, pred_torsions], dim=-2)

            out['pred_torsions_with_CB'] = pred_torsions_with_CB

        return out
|
|
|