import torch
from opt_einsum import contract as einsum
import esm

from data.residue_constants import order2restype_with_mask


def get_pre_repr(seqs, model, alphabet, batch_converter, device="cuda:0", *, repr_layer=33):
    """Run a pretrained ESM model on one sequence and return node and pair features.

    Args:
        seqs: A single sequence of residue indices. Despite the plural name, the
            whole object is treated as ONE sequence: each element is converted
            with ``int()`` and mapped through ``order2restype_with_mask`` to a
            residue character.
        model: ESM model, called as
            ``model(tokens, repr_layers=[repr_layer], return_contacts=True)``.
            Its result must contain ``"representations"`` and ``"attentions"``.
        alphabet: ESM alphabet. It is no longer used here (the old code only used
            it for a length computation whose result was discarded). It is kept
            so existing callers keep working.
        batch_converter: Callable that takes ``[(label, seq_string)]`` and returns
            ``(labels, strs, tokens)``.
        device: Device the token tensor is moved to. The outputs stay on this
            device.
        repr_layer: Layer whose representation and attention maps are returned.
            Defaults to 33, the value previously hard-coded. It must be >= 1
            because attentions are indexed at ``repr_layer - 1``.

    Returns:
        ``(node_repr, pair_repr)``:
        - ``node_repr`` is layer ``repr_layer``'s token representation without
          the first and last token.
        - ``pair_repr`` is that layer's attention maps, cropped the same way,
          with the head axis moved last.

    Raises:
        ValueError: If ``repr_layer < 1``.
    """
    if repr_layer < 1:
        # attentions[:, repr_layer - 1] would silently wrap to the last layer.
        raise ValueError(f"repr_layer must be >= 1, got {repr_layer}")

    seq_string = "".join(order2restype_with_mask[int(i)] for i in seqs)
    _, _, batch_tokens = batch_converter([("protein_0", seq_string)])

    with torch.no_grad():
        # return_contacts=True is presumably what makes ESM include "attentions"
        # in its output (it requests per-head weights) — verify against ESM
        # version in use before changing it.
        results = model(
            batch_tokens.to(device),
            repr_layers=[repr_layer],
            return_contacts=True,
        )

        # Drop the first and last token: token 0 is the beginning-of-sequence
        # token. The last token is presumably end-of-sequence (a one-sequence
        # batch has no padding).
        node_repr = results["representations"][repr_layer][:, 1:-1, :]

        # Assumes ESM's (batch, layer, head, T, T) attention layout — TODO confirm.
        # Layer indices here are 0-based, hence repr_layer - 1.
        # The permute yields (batch, L, L, heads).
        pair_repr = results["attentions"][:, repr_layer - 1, :, 1:-1, 1:-1].permute(0, 2, 3, 1)

    return node_repr, pair_repr