import torch
from opt_einsum import contract as einsum
import esm
from data.residue_constants import order2restype_with_mask


def get_pre_repr(seqs, model, alphabet, batch_converter, device="cuda:0"):
    # Prepare data (first 2 sequences from ESMStructuralSplitDataset superfamily / 4)
    # data = [
    #     ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
    #     ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    #     ("protein2 with mask", "KALTARQQEVFDLIRD<mask>ISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    #     ("protein3", "K A <mask> I S Q"),
    # ]
    # Decode integer-encoded residues into (label, sequence-string) pairs for the ESM batch converter.
    data = []
    for idx, seq in enumerate([seqs]):
        seq_string = ''.join([order2restype_with_mask[int(i)] for i in seq])
        data.append(("protein_" + str(idx), seq_string))
    batch_labels, batch_strs, batch_tokens = batch_converter(data)
    batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)
    # Extract per-residue representations and attention maps.
    with torch.no_grad():
        results = model(batch_tokens.to(device), repr_layers=[33], return_contacts=True)
    # Per-residue embeddings from layer 33, with the BOS and EOS tokens stripped.
    node_repr = results["representations"][33][:, 1:-1, :]
    # Last-layer attention maps (batch, layers, heads, tokens, tokens):
    # strip BOS/EOS and move the head dimension last -> (batch, L, L, heads).
    pair_repr = results["attentions"][:, 33 - 1, :, 1:-1, 1:-1].permute(0, 2, 3, 1)
    # Generate per-sequence representations via averaging.
    # NOTE: token 0 is always a beginning-of-sequence token, so the first residue is token 1.
    # sequence_representations = []
    # for i, tokens_len in enumerate(batch_lens):
    #     sequence_representations.append(token_representations[i, 1 : tokens_len - 1].mean(0))
    # Look at the unsupervised self-attention map contact predictions.
    # for (_, seq), tokens_len, attention_contacts in zip(data, batch_lens, results["contacts"]):
    #     plt.matshow(attention_contacts[: tokens_len, : tokens_len])
    return node_repr, pair_repr
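

# Usage sketch (illustrative, not part of the original pipeline): assumes the
# fair-esm package provides the esm2_t33_650M_UR50D checkpoint (33 layers,
# 1280-dim embeddings, 20 attention heads), which matches repr_layers=[33] above,
# and that `example_seq` below is a hypothetical integer-encoded sequence whose
# values are valid keys of order2restype_with_mask.
if __name__ == "__main__":
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    batch_converter = alphabet.get_batch_converter()
    model = model.eval().to(device)

    example_seq = torch.arange(6)  # hypothetical residue indices; replace with real data
    node_repr, pair_repr = get_pre_repr(example_seq, model, alphabet, batch_converter, device=device)
    print(node_repr.shape)  # (1, L, 1280)
    print(pair_repr.shape)  # (1, L, L, 20)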