| |
|
| | import gc |
| | import os |
| | import sys |
| | import torch |
| | import pickle |
| | import numpy as np |
| | import pandas as pd |
| | import streamlit as st |
| | from torch.utils.data import DataLoader |
| |
|
| | from rdkit import Chem |
| | from rdkit.Chem import Draw |
| |
|
| | sys.path.insert(0, os.path.abspath("src/")) |
| | from src.dataset import DrugRetrieval, collate_target |
| | from hyper_dti.models.hyper_pcm import HyperPCM |
| |
|
# Locations of the bundled data and the HyperPCM checkpoint, resolved relative
# to the directory that contains this script.
base_path = os.path.dirname(__file__)
data_path = os.path.join(base_path, 'data')
checkpoint_path = os.path.join(base_path, 'checkpoints/lpo/cv2_test_fold6_1402/model_updated.t7')
# Use a GPU when torch can see one, otherwise run inference on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Page-level configuration, issued before any other st.* call below.
st.set_page_config(
    page_title='HyperPCM',
    layout='centered',
    menu_items={
        'About':
        '''
        # HyperPCM

        HyperNetworks have been established as an effective technique to achieve fast adaptation of parameters for
        neural networks. Recently, HyperNetwork predictions conditioned on descriptors of tasks have improved
        multi-task generalization in various domains, such as personalized federated learning and neural architecture
        search. Especially powerful results were achieved in few- and zero-shot settings, attributed to the increased
        information sharing by the HyperNetwork. With the rise of new diseases fast discovery of drugs is needed which
        requires models that can generalize drug-target interaction predictions in low-data scenarios.

        In this work, we propose the HyperPCM model, a task-conditioned HyperNetwork approach for the problem of
        predicting drug-target interactions in drug discovery. Our model learns to generate a QSAR model specialized on
        a given protein target. We demonstrate state-of-the-art performance over previous methods on multiple
        well-known benchmarks, particularly in zero-shot settings for unseen protein targets. This app demonstrates the
        model as a retrieval task of the top-k most active drug compounds predicted for a given query target.
        '''
    }
)


st.title('HyperPCM: Robust Task-Conditioned Modeling of Drug-Target Interactions\n')
st.markdown('')
# Header links to the code repository and the two publications.
st.markdown(
    """
    🧬 Github: [ml-jku/hyper-dti](https://github.com/ml-jku/hyper-dti) 📝 Paper: [JCIM 2024](https://pubs.acs.org/doi/10.1021/acs.jcim.3c01417); [NeurIPS 2022 AI4Science workshop](https://openreview.net/forum?id=dIX34JWnIAL) \n
    """
)
| |
|
def about_page():
    """Render the About tab: project summary, architecture figure and citation."""
    st.markdown(
        """
        ### About

        HyperNetworks have been established as an effective technique to achieve fast adaptation of parameters for
        neural networks. Recently, HyperNetwork predictions conditioned on descriptors of tasks have improved
        multi-task generalization in various domains, such as personalized federated learning and neural architecture
        search. Especially powerful results were achieved in few- and zero-shot settings, attributed to the increased
        information sharing by the HyperNetwork. With the rise of new diseases fast discovery of drugs is needed which
        requires models that can generalize drug-target interaction predictions in low-data scenarios.

        In this work, we propose the HyperPCM model, a task-conditioned HyperNetwork approach for the problem of
        predicting drug-target interactions in drug discovery. Our model learns to generate a QSAR model specialized on
        a given protein target. We demonstrate state-of-the-art performance over previous methods on multiple
        well-known benchmarks, particularly in zero-shot settings for unseen protein targets. This app demonstrates the
        model as a retrieval task of the top-k most active drug compounds predicted for a given query target.
        """
    )

    # Resolve the figure next to this script, like data_path / checkpoint_path.
    st.image(
        os.path.join(base_path, 'figures/hyper-dti.png'),
        caption='Overview of the HyperPCM architecture.',
        use_column_width='always',
    )

    # Raw string on purpose: in a normal string literal `\"` collapses to `"`,
    # which would display the author as G{"u}nter instead of the BibTeX
    # escape G{\"u}nter that users copy into their bibliography.
    st.markdown(
        r"""
        ### Citation

        Please cite our work using the following reference.
        ```bibtex
        @article{svensson2024hyperpcm,
          title={{HyperPCM: Robust Task-Conditioned Modeling of Drug--Target Interactions}},
          author={Svensson, Emma and Hoedt, Pieter-Jan and Hochreiter, Sepp and Klambauer, G{\"u}nter},
          journal={Journal of Chemical Information and Modeling},
          volume = {64},
          number = {7},
          pages = {2539-2553},
          year = {2024},
          doi = {10.1021/acs.jcim.3c01417},
          publisher={ACS Publications}
        }
        ```
        """
    )
| | |
| |
|
def _encode_query(sequence):
    """Return the SeqVec per-protein embedding of ``sequence``.

    Embeddings precomputed for the Lenselink test targets are reused when the
    sequence is among them; otherwise the sequence is embedded on the fly with
    bio_embeddings' ``SeqVecEmbedder``. Returns ``None`` if the embedder yields
    no embedding at all.
    """
    # NOTE(review): pickle.load is acceptable only because this file ships with
    # the app; never point it at user-supplied data.
    with open(os.path.join(data_path, 'Lenselink/processed/SeqVec_encoding_test.pickle'), 'rb') as handle:
        test_set = pickle.load(handle)

    # Direct dict membership (O(1)) instead of scanning list(test_set.keys()).
    if sequence in test_set:
        return test_set[sequence]

    # Imported lazily so the heavy embedder is only loaded for unseen targets.
    from bio_embeddings.embed import SeqVecEmbedder
    encoder = SeqVecEmbedder()
    # A single sequence is embedded, so only the first result is needed.
    embedding = next(iter(encoder.embed_batch([sequence])), None)
    if embedding is None:
        return None
    return encoder.reduce_per_protein(embedding)


def _predict(dataset, dataloader, batch_size, progress_bar, progress_text):
    """Run HyperPCM over every batch of ``dataloader``.

    Returns two lists with one entry per batch: the batch's ``'mids'`` and the
    model outputs as NumPy arrays. ``progress_bar`` is advanced after each batch.
    """
    # Free unreferenced objects and cached CUDA memory before building the model.
    gc.collect()
    torch.cuda.empty_cache()

    # NOTE(review): the drug dataset is passed as HyperPCM's `memory`; its role
    # inside the model is not visible from this file.
    model = HyperPCM(memory=dataset).to(device)
    # Presumably wrapped because the checkpoint was saved from a DataParallel
    # model (parameter names prefixed with 'module.') -- TODO confirm.
    model = torch.nn.DataParallel(model)
    # Tensors are loaded onto the CPU; load_state_dict copies them into the
    # model's parameters, which already live on `device`.
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    model.eval()

    mids, preds = [], []
    with torch.no_grad():
        for batch_idx, (batch, _labels) in enumerate(dataloader):
            logits = model(batch)
            mids.append(batch['mids'])
            preds.append(logits.detach().cpu().numpy())
            # Fraction of drugs scored so far, including the current batch;
            # st.progress rejects floats above 1.0, hence the clamp for the
            # (usually partial) last batch.
            done = min((batch_idx + 1) * batch_size / len(dataset), 1.0)
            progress_bar.progress(done, text=progress_text)
    return mids, preds


def _build_results(selected_database, mids, preds):
    """Collect per-batch outputs into a DataFrame ranked by descending prediction.

    Returns ``(results, id_column)`` where ``id_column`` names an extra column to
    show in captions (``'DrugBank ID'`` for DrugBank, otherwise ``None``).
    """
    predictions = np.concatenate(preds)
    if selected_database != 'DrugBank':
        # Outside DrugBank the loader's molecule ids are used directly as SMILES.
        results = pd.DataFrame({'SMILES': np.concatenate(mids), 'Prediction': predictions})
        id_column = None
    else:
        # DrugBank batches carry DrugBank ids; map them to structures with the
        # bundled lookup table (shipped with the app, see pickle note above).
        with open(os.path.join(data_path, f'{selected_database}/processed/drugbank.pickle'), 'rb') as handle:
            lookup = pickle.load(handle)
        drug_ids = np.concatenate(mids)
        results = pd.DataFrame({
            'SMILES': [lookup[drug_id] for drug_id in drug_ids],
            'DrugBank ID': drug_ids,
            'Prediction': predictions,
        })
        id_column = 'DrugBank ID'
    # drop=True keeps the pre-sort row numbers out of the downloadable CSV.
    results = results.sort_values(by='Prediction', ascending=False).reset_index(drop=True)
    return results, id_column


def _render_top_k(results, selected_k, id_column=None):
    """Draw the first ``selected_k`` rows of ``results`` as a 5-column molecule grid.

    ``results`` must already be in rank order and provide 'SMILES' and
    'Prediction' columns. Rank ``r`` goes to column ``r % 5``, row ``r // 5``.
    When ``id_column`` is given, its value is prepended to each caption.
    """
    # Never index past the end of a short result table.
    n_shown = min(selected_k, len(results))
    for col_idx, col in enumerate(st.columns(5)):
        with col:
            for rank in range(col_idx, n_shown, 5):
                row = results.iloc[rank]
                caption = f"{row['Prediction']:.2f}"
                if id_column is not None:
                    caption = f"{row[id_column]}:\n{caption}"
                mol = Chem.MolFromSmiles(row['SMILES'])
                if mol is None:
                    # MolFromSmiles returns None on unparsable input, and
                    # MolToImage would fail on it; show the text instead.
                    st.warning(f"Could not parse SMILES: {row['SMILES']}")
                    st.caption(caption)
                    continue
                st.image(Draw.MolToImage(mol), caption=caption)


def retrieval():
    """Render the Retrieval tab.

    The user enters a query protein sequence and picks a drug database. The
    sequence is embedded with SeqVec, HyperPCM scores every drug in the
    database, and the top-k compounds are drawn and offered as a CSV download.
    The built-in example target on Lenselink uses precomputed results instead
    of running the model.
    """
    st.markdown('## Retrieval of most active drug compounds')

    st.write('Use HyperPCM to generate a QSAR model for a selected query protein target and retrieve the top-k drug compounds predicted to have the highest activity toward the given protein target from the Lenselink datasets.')

    col1, col2 = st.columns(2)
    with col1:
        st.markdown('### Query Target')
    with col2:
        st.markdown('### Drug Database')

    batch_size = 2048
    # Number of compounds per database, shown in the loading messages.
    n_drugs = {
        'Lenselink': 314707,
        'Davis': 30056,
        'DUDE': 1434019,
        'DrugBank': 10681,
    }
    # These stay None until a sequence has been entered and encoded, so later
    # steps are skipped instead of failing on an unbound name.
    query_embedding = None
    dataset = dataloader = None

    col1, col2, col3, col4 = st.columns(4)
    with col1:
        ex_target = 'YTKMKTATNIYIFNLALADALATSTLPFQSVNYLMGTWPFGTILCKIVISIDYYNMFTSIFTLCTMSVDRYIAVCHPVKALDFRTPRNAKTVNVCNWI'
        sequence = st.text_input('Enter amino-acid sequence', value=ex_target, placeholder=ex_target)
        if sequence == ex_target:
            st.image(os.path.join(base_path, 'figures/lenselink_ex_target.jpeg'), use_column_width='always')
        elif sequence == 'HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA':
            st.image(os.path.join(base_path, 'figures/ex_protein.jpeg'), use_column_width='always')
        elif sequence:
            st.error('Visualization coming soon...')

    with col2:
        st.selectbox('Select target encoder', ['SeqVec'])
        if sequence:
            st.image(os.path.join(base_path, 'figures/target_encoder_done.png'), use_column_width='always')
            with st.spinner('Encoding in progress...'):
                query_embedding = _encode_query(sequence)
            if query_embedding is not None:
                st.success('Encoding complete.')
            else:
                st.error('Encoding failed.')

    with col3:
        selected_database = st.selectbox(
            'Select database', ('Lenselink', 'Davis', 'DUD-E', 'DrugBank')
        )
        # 'DUDE' is the key used for n_drugs and the data directory name.
        if selected_database == 'DUD-E':
            selected_database = 'DUDE'
        st.image(os.path.join(base_path, 'figures/multi_drugs.png'), use_column_width='always')
        if query_embedding is not None:
            with st.spinner(f'Loading {n_drugs[selected_database]} drugs...'):
                dataset = DrugRetrieval(os.path.join(data_path, selected_database), sequence, query_embedding)
                dataloader = DataLoader(dataset, num_workers=2, batch_size=batch_size, shuffle=False, collate_fn=collate_target)
            st.success(f'{n_drugs[selected_database]} drugs loaded.')

    with col4:
        st.selectbox('Select drug encoder', ['CDDD'])
        st.image(os.path.join(base_path, 'figures/drug_encoder_done.png'), use_column_width='always')
        st.success('Encoding complete.')

    progress_text = "HyperPCM is predicting the QSAR model for the query protein target. Please wait."
    done_text = "HyperPCM is predicting the QSAR model for the query protein target. Done."
    slider_label = f'Top-k most active drug compounds {selected_database} predicted by HyperPCM are, for k = '
    download_label = f'Download retrieved drug compounds from the {selected_database} database.'

    if sequence == ex_target and selected_database == 'Lenselink':
        # Showcase path: results for the example target are precomputed.
        st.markdown('### Inference')
        my_bar = st.progress(0, text=progress_text)
        my_bar.progress(100, text=done_text)

        st.markdown('### Retrieval')
        selected_k = st.slider(slider_label, 5, 20, 5, 5)

        results = pd.read_csv(os.path.join(data_path, 'Lenselink/processed/ex_results.csv'))
        _render_top_k(results, selected_k)

        st.download_button(download_label, results.head(selected_k).to_csv(index=False).encode('utf-8'), file_name='retrieved_drugs.csv')

    elif query_embedding is not None:
        st.markdown('### Inference')
        my_bar = st.progress(0, text=progress_text)
        mids, preds = _predict(dataset, dataloader, batch_size, my_bar, progress_text)
        my_bar.progress(100, text=done_text)

        if not preds:
            st.warning('No drug compounds were scored.')
            return

        st.markdown('### Retrieval')
        selected_k = st.slider(slider_label, 5, 20, 5, 5)

        results, id_column = _build_results(selected_database, mids, preds)
        _render_top_k(results, selected_k, id_column)

        st.download_button(download_label, results.head(selected_k).to_csv(index=False).encode('utf-8'), file_name='retrieved_drugs.csv')
| |
|
| |
|
| |
|
# Tab label -> function that renders the tab's content. Dict insertion order
# determines the order of the tabs.
page_names_to_func = {
    'Retrieval': retrieval,
    'About': about_page,
}

tab1, tab2 = st.tabs(list(page_names_to_func))

# Render each page inside its own tab container, in the same order as the labels.
for tab, render_page in zip((tab1, tab2), page_names_to_func.values()):
    with tab:
        render_page()
| |
|