Spaces:
Running
Running
| import math | |
| import pandas as pd | |
| from graphgen.bases import BaseGraphStorage, BaseLLMWrapper, BaseOperator | |
| from graphgen.common import init_llm, init_storage | |
| from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT | |
| from graphgen.utils import logger, run_concurrent, yes_no_loss_entropy | |
class JudgeService(BaseOperator):
    """Service for judging graph edges and nodes using a trainee LLM.

    Each incoming item carries a ``description`` statement and an ``index``
    identifying the graph element it belongs to. The trainee LLM is asked to
    judge the statement, a yes/no loss is computed from the first token's top
    candidates, and the loss is written back onto the matching node or edge
    in graph storage.
    """

    # Probability assumed for the expected answer when judging fails; the
    # fallback loss stored on the item is the negative log of this value.
    _FALLBACK_PROBABILITY = 0.1

    def __init__(self, working_dir: str = "cache", graph_backend: str = "kuzu"):
        """Create the trainee LLM client and open the ``graph`` storage namespace.

        Args:
            working_dir: Directory passed to the base operator and to storage.
            graph_backend: Storage backend name forwarded to ``init_storage``.
        """
        super().__init__(working_dir=working_dir, op_name="judge_service")
        self.llm_client: BaseLLMWrapper = init_llm("trainee")
        self.graph_storage: BaseGraphStorage = init_storage(
            backend=graph_backend,
            working_dir=working_dir,
            namespace="graph",
        )

    def process(self, batch: pd.DataFrame) -> pd.DataFrame:
        """Judge every row of ``batch`` and persist the losses to graph storage.

        Args:
            batch: One row per item; each row needs ``description`` and
                ``index`` columns and may carry ``ground_truth``.

        Returns:
            A single-row DataFrame with ``status == "judging_completed"``.
        """
        items = batch.to_dict(orient="records")
        # Reload before writing so updates are applied to the current graph.
        self.graph_storage.reload()
        self.judge(items)
        return pd.DataFrame([{"status": "judging_completed"}])

    async def _process_single_judge(self, item: dict) -> dict:
        """Compute the judgement loss for one item and store it in ``item["loss"]``.

        The expected answer is ``item["ground_truth"]`` and defaults to
        ``"yes"``. Any failure while querying the LLM or computing the loss is
        logged and replaced by the fallback loss
        ``-log(_FALLBACK_PROBABILITY)``, so one bad item does not abort the
        whole batch.

        Args:
            item: Record with at least a ``description`` key.

        Returns:
            The same ``item`` dict, mutated in place with a ``loss`` key.
        """
        description = item["description"]
        try:
            judgement = await self.llm_client.generate_topk_per_token(
                STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=description)
            )
            # Only the first generated token's candidates are used for the loss.
            top_candidates = judgement[0].top_candidates
            gt = item.get("ground_truth", "yes")
            loss = yes_no_loss_entropy([top_candidates], [gt])
            logger.debug("Description: %s Loss: %s", description, loss)
            item["loss"] = loss
        except Exception as e:  # pylint: disable=broad-except
            fallback_loss = -math.log(self._FALLBACK_PROBABILITY)
            logger.error("Error in judging description: %s", e)
            # Report the value actually stored (about 2.3026), not the probability.
            logger.info(
                "Use default loss %.4f (-log(%s))",
                fallback_loss,
                self._FALLBACK_PROBABILITY,
            )
            item["loss"] = fallback_loss
        return item

    def judge(self, items: list[dict]) -> None:
        """
        Judge the description in each item and write its loss to graph storage.

        A ``str`` index is treated as a node id. A two-element ``tuple`` or
        ``list`` index is treated as an edge ``(source, target)``; lists are
        accepted because tuples may come back as lists after a DataFrame or
        serialization round-trip. Items whose index has another shape, or whose
        node/edge cannot be found in storage, are skipped with a warning so the
        remaining losses are still written and flushed.

        Args:
            items: Records with ``description`` and ``index`` keys.
        """
        results = run_concurrent(
            self._process_single_judge,
            items,
            desc="Judging descriptions",
            unit="description",
        )

        # Update the graph storage with the computed losses
        for item in results:
            index = item["index"]
            loss = item["loss"]

            if isinstance(index, str):
                node_id = index
                node_data = self.graph_storage.get_node(node_id)
                if node_data is None:
                    # Presumably storage returns None for unknown ids; do not
                    # let one missing node abort the remaining updates.
                    logger.warning(
                        "Node %s not found in graph storage; loss not stored",
                        node_id,
                    )
                    continue
                node_data["loss"] = loss
                self.graph_storage.update_node(node_id, node_data)
            elif isinstance(index, (tuple, list)) and len(index) == 2:
                edge_source, edge_target = index
                edge_data = self.graph_storage.get_edge(edge_source, edge_target)
                if edge_data is None:
                    logger.warning(
                        "Edge (%s, %s) not found in graph storage; loss not stored",
                        edge_source,
                        edge_target,
                    )
                    continue
                edge_data["loss"] = loss
                self.graph_storage.update_edge(edge_source, edge_target, edge_data)
            else:
                # Previously such items were dropped without any trace.
                logger.warning(
                    "Unsupported index %r (type %s); loss not stored",
                    index,
                    type(index).__name__,
                )

        self.graph_storage.index_done_callback()