# P2DFlow / analysis/src/eval.py | Holmes | test | ca7299e | 7.95 kB
from typing import Any, Dict, List, Tuple
import os
from time import strftime
import numpy as np
import pandas as pd
import torch
# import hydra
# import rootutils
# from lightning import LightningDataModule, LightningModule, Trainer
# from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig
# rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# ------------------------------------------------------------------------------------ #
# the setup_root above is equivalent to:
# - adding project root dir to PYTHONPATH
# (so users don't need to install the project as a package)
# (necessary before importing any local modules e.g. `from src import utils`)
# - setting up PROJECT_ROOT environment variable
# (which is used as a base for paths in "configs/paths/default.yaml")
# (this way all filepaths are the same no matter where you run the code)
# - loading environment variables from ".env" in root dir
#
# you can remove it if you:
# 1. either install project as a package or move entry files to project root dir
# 2. set `root_dir` to "." in "configs/paths/default.yaml"
#
# more info: https://github.com/ashleve/rootutils
# ------------------------------------------------------------------------------------ #
from src.utils import (
RankedLogger,
extras,
instantiate_loggers,
log_hyperparameters,
task_wrapper,
checkpoint_utils,
plot_utils,
)
from src.common.pdb_utils import extract_backbone_coords
from src.metrics import metrics
from src.common.geo_utils import _find_rigid_alignment
# Module-level logger used by evaluate_prediction for warnings.
# NOTE(review): rank_zero_only=True presumably restricts output to the rank-0
# process in multi-process runs; RankedLogger lives in src.utils — confirm.
log = RankedLogger(__name__, rank_zero_only=True)
def _list_targets(target_dir: str) -> List[str]:
    """Return the sorted names (``.pdb`` suffix stripped) of the PDB files in ``target_dir``."""
    # Sorting makes the row order of the metrics CSV deterministic; os.listdir
    # returns entries in arbitrary order. Only the trailing suffix is stripped.
    return sorted(
        fname[: -len(".pdb")] for fname in os.listdir(target_dir) if fname.endswith(".pdb")
    )


def _align_to_reference(ca_coords: Dict[str, Any], ref_key: str = "target") -> None:
    """Superimpose every frame of every entry of ``ca_coords`` onto frame 0 of ``ref_key``.

    Each frame ``x`` is transformed as ``R @ x + t``, where ``(R, t)`` come from
    ``_find_rigid_alignment(x, reference_frame)``. ``ca_coords`` is updated in
    place: each value is replaced by the aligned coordinates as a numpy array.
    """
    # Clone the reference frame. torch.as_tensor can share memory with the numpy
    # array, and the reference ensemble itself is overwritten below. Without a
    # copy, the reference frame would change while it is still being used.
    ref_frame = torch.as_tensor(ca_coords[ref_key][0]).clone()
    for key in list(ca_coords):
        frames = torch.as_tensor(ca_coords[key])
        for idx in range(frames.shape[0]):
            R, t = _find_rigid_alignment(frames[idx], ref_frame)
            frames[idx] = (torch.matmul(R, frames[idx].transpose(-2, -1))).transpose(-2, -1) + t.unsqueeze(0)
        ca_coords[key] = frames.numpy()


def evaluate_prediction(pred_dir: str, target_dir: str = None, crystal_dir: str = None, tag: str = None):
    """Evaluate predicted ensembles against reference ensembles stored as PDB files.

    For every ``<target>.pdb`` in ``target_dir`` that has a same-named file in
    ``pred_dir``, coordinates are read with ``extract_backbone_coords`` and each
    metric below is computed. Per-target results are written to
    ``<pred_dir>/metrics_<tag>_<timestamp>.csv`` (rows = targets, columns = metrics).

    Args:
        pred_dir: Directory with predicted ``<target>.pdb`` files; the CSV is written here.
        target_dir: Directory with reference ``<target>.pdb`` files. If it is None
            or not a directory, evaluation is skipped.
        crystal_dir: Directory with ``<target>.pdb`` crystal structures, used by the
            ``pro_*`` contact metrics. If None, those metrics are skipped.
        tag: Label embedded in the CSV filename; defaults to ``"dev"``.

    Returns:
        ``{}`` when evaluation is skipped. Otherwise, a pandas Series mapping each
        metric name to its mean over the evaluated targets, rounded to 4 decimals.

    Raises:
        AssertionError: if ``pred_dir`` is not a directory.
    """
    if target_dir is None or not os.path.isdir(target_dir):
        log.warning(f"target_dir {target_dir} does not exist. Skip evaluation.")
        return {}
    assert os.path.isdir(pred_dir), f"pred_dir {pred_dir} is not a directory."

    targets = _list_targets(target_dir)
    output_dir = pred_dir
    tag = tag if tag is not None else "dev"
    timestamp = strftime("%m%d-%H-%M")

    fns = {
        'val_clash': metrics.validity,
        'val_bond': metrics.bonding_validity,
        'js_pwd': metrics.js_pwd,
        'js_rg': metrics.js_rg,
        # 'js_tica_pos': metrics.js_tica_pos,
        'w2_rmwd': metrics.w2_rmwd,
        # 'div_rmsd': metrics.div_rmsd,
        'div_rmsf': metrics.div_rmsf,
        'pro_w_contacks': metrics.pro_w_contacts,
        'pro_t_contacks': metrics.pro_t_contacts,
        # 'pro_c_contacks': metrics.pro_c_contacts,
    }
    # The pro_* metrics need crystal coordinates. Without crystal_dir they cannot
    # be computed; the previous code crashed on os.path.join(None, ...).
    if crystal_dir is None:
        log.warning("crystal_dir is not given. Skip pro_* metrics.")
        fns = {name: fn for name, fn in fns.items() if not name.startswith('pro_')}
    eval_res = {k: {} for k in fns}

    print(f"total_md_num = {len(targets)}")
    for count, target in enumerate(targets, start=1):
        print("")
        print(count, target)
        pred_file = os.path.join(pred_dir, f"{target}.pdb")
        # Targets without a prediction are skipped silently.
        if not os.path.isfile(pred_file):
            continue
        target_file = os.path.join(target_dir, f"{target}.pdb")
        ca_coords = {
            'target': extract_backbone_coords(target_file),
            'pred': extract_backbone_coords(pred_file),
        }
        cry_target_file = None
        cry_ca_coords = None
        if crystal_dir is not None:
            cry_target_file = os.path.join(crystal_dir, f"{target}.pdb")
            cry_ca_coords = extract_backbone_coords(cry_target_file)[0]

        for f_name, func in fns.items():
            print(f_name)
            if f_name == 'w2_rmwd':
                # Superimpose both ensembles onto the first reference frame.
                # The aligned coordinates are kept, so every metric after this
                # one in `fns` also sees aligned frames.
                _align_to_reference(ca_coords, ref_key='target')
            if f_name.startswith('js_'):
                res = func(ca_coords, ref_key='target')
            elif f_name == 'pro_c_contacks':
                res = func(target_file, pred_file, cry_target_file)
            elif f_name.startswith('pro_'):
                res = func(ca_coords, cry_ca_coords)
            else:
                res = func(ca_coords)
            if f_name in ('js_tica', 'js_tica_pos'):
                # NOTE(review): TICA results are not recorded; judging by the
                # disabled code below they return a tuple rather than a dict.
                pass
                # eval_res[f_name][target] = res[0]['pred']
                # save_to = os.path.join(output_dir, f"tica_{target}_{tag}_{timestamp}.png")
                # plot_utils.scatterplot_2d(res[1], save_to=save_to, ref_key='target')
            else:
                eval_res[f_name][target] = res['pred']

    csv_save_to = os.path.join(output_dir, f"metrics_{tag}_{timestamp}.csv")
    df = pd.DataFrame.from_dict(eval_res)  # row = target, col = metric name
    df.to_csv(csv_save_to)
    print(f"metrics saved to {csv_save_to}")
    mean_metrics = np.around(df.mean(), decimals=4)
    return mean_metrics
# @task_wrapper
# def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
# """Sample on a test set and report evaluation metrics.
# This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
# failure. Useful for multiruns, saving info about the crash, etc.
# :param cfg: DictConfig configuration composed by Hydra.
# :return: Tuple[dict, dict] with metrics and dict with all instantiated objects.
# """
# # assert cfg.ckpt_path
# pred_dir = cfg.get("pred_dir")
# if pred_dir and os.path.isdir(pred_dir):
# log.info(f"Found pre-computed prediction directory {pred_dir}.")
# metric_dict = evaluate_prediction(pred_dir, target_dir=cfg.target_dir)
# return metric_dict, None
# log.info(f"Instantiating datamodule <{cfg.data._target_}>")
# datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
# log.info(f"Instantiating model <{cfg.model._target_}>")
# model: LightningModule = hydra.utils.instantiate(cfg.model)
# log.info("Instantiating loggers...")
# logger: List[Logger] = instantiate_loggers(cfg.get("logger"))
# log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
# trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)
# object_dict = {
# "cfg": cfg,
# "datamodule": datamodule,
# "model": model,
# "logger": logger,
# "trainer": trainer,
# }
# if logger:
# log.info("Logging hyperparameters!")
# log_hyperparameters(object_dict)
# # Load checkpoint manually.
# model, ckpt_path = checkpoint_utils.load_model_checkpoint(model, cfg.ckpt_path)
# # log.info("Starting testing!")
# # trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)
# # Get dataloader for prediction.
# datamodule.setup(stage="predict")
# dataloaders = datamodule.test_dataloader()
# log.info("Starting predictions.")
# pred_dir = trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=ckpt_path)[-1]
# # metric_dict = trainer.callback_metrics
# log.info("Starting evaluations.")
# metric_dict = evaluate_prediction(pred_dir, target_dir=cfg.target_dir)
# return metric_dict, object_dict
# @hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml")
# def main(cfg: DictConfig) -> None:
# """Main entry point for evaluation.
# :param cfg: DictConfig configuration composed by Hydra.
# """
# # apply extra utilities
# # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
# extras(cfg)
# evaluate(cfg)
# if __name__ == "__main__":
# main()