"""Neural network for embedding node features."""

import torch
from torch import nn

from models.utils import get_index_embedding, get_time_embedding, add_RoPE

class NodeEmbedder(nn.Module):
    """Embeds per-residue node features: amino-acid type, diffusion timestep,
    and a precomputed per-residue representation."""

    def __init__(self, module_cfg):
        super(NodeEmbedder, self).__init__()
        self._cfg = module_cfg
        self.c_s = self._cfg.c_s
        self.c_pos_emb = self._cfg.c_pos_emb
        self.c_timestep_emb = self._cfg.c_timestep_emb
        self.c_node_pre = 1280
        self.aatype_emb_dim = self._cfg.c_pos_emb

        # 21 amino-acid types (20 standard residues + unknown).
        self.aatype_emb = nn.Embedding(21, self.aatype_emb_dim)

        total_node_feats = self.aatype_emb_dim + self._cfg.c_timestep_emb + self.c_node_pre

        self.linear = nn.Sequential(
            nn.Linear(total_node_feats, self.c_s),
            nn.ReLU(),
            nn.Dropout(self._cfg.dropout),
            nn.Linear(self.c_s, self.c_s),
        )
    def embed_t(self, timesteps, mask):
        """
        timesteps: [B, 1]
        mask: [B, L]
        Returns a masked timestep embedding of shape [B, L, c_timestep_emb].
        """
        timestep_emb = get_time_embedding(
            timesteps[:, 0],
            self.c_timestep_emb,
            max_positions=2056
        )[:, None, :].repeat(1, mask.shape[1], 1)
        return timestep_emb * mask.unsqueeze(-1)
    def forward(self, timesteps, aatype, node_repr_pre, mask):
        """
        timesteps: [B, 1]
        aatype: [B, L] integer amino-acid types
        node_repr_pre: [B, L, 1280] precomputed per-residue representation
        mask: [B, L]
        """
        b, num_res, device = mask.shape[0], mask.shape[1], mask.device

        # Amino-acid type embedding, zeroed at masked positions.
        aatype_emb = self.aatype_emb(aatype) * mask.unsqueeze(-1)
        input_feats = [aatype_emb]

        # Timestep embedding broadcast over all residues.
        time_emb = self.embed_t(timesteps, mask)
        input_feats.append(time_emb)

        # Precomputed per-residue representation.
        input_feats.append(node_repr_pre)

        # Project the concatenated features to c_s and apply rotary position embeddings.
        out = self.linear(torch.cat(input_feats, dim=-1))
        return add_RoPE(out)
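

# --- Usage sketch (illustrative only) ---
# A minimal, hedged example of how NodeEmbedder might be called. The config
# fields (c_s, c_pos_emb, c_timestep_emb, dropout) are the ones __init__ reads;
# the SimpleNamespace config, the concrete values, and the tensor shapes below
# are assumptions for illustration, not the project's actual training setup.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(c_s=256, c_pos_emb=128, c_timestep_emb=128, dropout=0.1)
    embedder = NodeEmbedder(cfg)

    B, L = 2, 50
    timesteps = torch.rand(B, 1)              # assumed continuous diffusion timesteps
    aatype = torch.randint(0, 21, (B, L))     # amino-acid type indices
    node_repr_pre = torch.randn(B, L, 1280)   # precomputed per-residue features
    mask = torch.ones(B, L)                   # all residues valid

    node_emb = embedder(timesteps, aatype, node_repr_pre, mask)
    print(node_emb.shape)  # expected [B, L, c_s], subject to add_RoPE's output shape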