# GraphGen/graphgen/operators/judge/judge_service.py
# github-actions[bot]
# Auto-sync from demo at Tue Dec 16 08:21:05 UTC 2025
# 31086ae
import math
import pandas as pd
from graphgen.bases import BaseGraphStorage, BaseLLMWrapper, BaseOperator
from graphgen.common import init_llm, init_storage
from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT
from graphgen.utils import logger, run_concurrent, yes_no_loss_entropy
class JudgeService(BaseOperator):
    """Service for judging graph edges and nodes using a trainee LLM.

    Each record handled by this service carries a ``description`` statement
    and an ``index`` naming the graph element it belongs to: a string for a
    node, or a ``(source, target)`` pair for an edge. The trainee LLM judges
    the statement, a yes/no loss is computed from its top token candidates,
    and the loss is stored on the node or edge under the ``"loss"`` key.
    """

    # Loss stored on an item when judging it raises: the negative log of 0.1.
    _FALLBACK_LOSS = -math.log(0.1)

    def __init__(self, working_dir: str = "cache", graph_backend: str = "kuzu"):
        """Create the trainee LLM client and open the ``graph`` storage namespace.

        Args:
            working_dir: Directory handed to the base operator and the storage.
            graph_backend: Storage backend name passed to ``init_storage``.
        """
        super().__init__(working_dir=working_dir, op_name="judge_service")
        self.llm_client: BaseLLMWrapper = init_llm("trainee")
        self.graph_storage: BaseGraphStorage = init_storage(
            backend=graph_backend,
            working_dir=working_dir,
            namespace="graph",
        )

    def process(self, batch: pd.DataFrame) -> pd.DataFrame:
        """Judge every record of ``batch`` and write the losses to the graph.

        Args:
            batch: One row per description to judge; rows are converted to
                dicts and must provide ``description`` and ``index``.

        Returns:
            A one-row DataFrame whose ``status`` is ``"judging_completed"``.
        """
        items = batch.to_dict(orient="records")
        # Reload so the updates below are applied to the current stored graph.
        self.graph_storage.reload()
        self.judge(items)
        return pd.DataFrame([{"status": "judging_completed"}])

    async def _process_single_judge(self, item: dict) -> dict:
        """Compute the yes/no loss for one item and store it in ``item["loss"]``.

        Any exception raised while querying the LLM or computing the loss is
        logged and replaced by the fallback loss, so one failing description
        does not abort the whole batch.

        Args:
            item: Record with a ``description`` and an optional
                ``ground_truth`` (defaults to ``"yes"``).

        Returns:
            The same ``item`` dict, mutated in place with a ``loss`` entry.
        """
        description = item["description"]
        try:
            judgement = await self.llm_client.generate_topk_per_token(
                STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=description)
            )
            top_candidates = judgement[0].top_candidates

            # Records built from a DataFrame always contain every column, so a
            # missing ground truth shows up as None/NaN rather than an absent
            # key; treat those the same as "not provided".
            gt = item.get("ground_truth")
            if gt is None or (isinstance(gt, float) and math.isnan(gt)):
                gt = "yes"

            loss = yes_no_loss_entropy([top_candidates], [gt])
            logger.debug("Description: %s Loss: %s", description, loss)
            item["loss"] = loss
        except Exception as e:  # pylint: disable=broad-except
            logger.error("Error in judging description: %s", e)
            # Report the value that is actually stored (-log(0.1), not 0.1).
            logger.info("Use default loss %.4f (-log(0.1))", self._FALLBACK_LOSS)
            item["loss"] = self._FALLBACK_LOSS
        return item

    @staticmethod
    def _edge_endpoints(index):
        """Return ``(source, target)`` when ``index`` is a two-element pair, else ``None``.

        Tuples and lists are accepted; objects exposing ``tolist()`` are
        converted first. NOTE(review): array-like indices are assumed to come
        from DataFrame serialization of the original tuple - confirm upstream.
        """
        to_list = getattr(index, "tolist", None)
        if callable(to_list):
            index = to_list()
        if isinstance(index, (tuple, list)) and len(index) == 2:
            return index[0], index[1]
        return None

    def judge(self, items: list[dict]) -> None:
        """
        Judge the description in each item, compute its loss, and save the
        loss on the matching node or edge in graph storage.

        Items whose graph element cannot be found, or whose ``index`` is
        neither a node id string nor a two-element edge pair, are logged and
        skipped so that the remaining losses are still persisted.
        """
        results = run_concurrent(
            self._process_single_judge,
            items,
            desc="Judging descriptions",
            unit="description",
        )

        # Update the graph storage with the computed losses
        for item in results:
            index = item["index"]
            loss = item["loss"]

            if isinstance(index, str):
                node_data = self.graph_storage.get_node(index)
                if node_data is None:
                    logger.error(
                        "Node %s not found in graph storage; loss not saved", index
                    )
                    continue
                node_data["loss"] = loss
                self.graph_storage.update_node(index, node_data)
                continue

            endpoints = self._edge_endpoints(index)
            if endpoints is None:
                logger.error("Unsupported index %r; loss not saved", index)
                continue

            edge_source, edge_target = endpoints
            edge_data = self.graph_storage.get_edge(edge_source, edge_target)
            if edge_data is None:
                logger.error(
                    "Edge (%s, %s) not found in graph storage; loss not saved",
                    edge_source,
                    edge_target,
                )
                continue
            edge_data["loss"] = loss
            self.graph_storage.update_edge(edge_source, edge_target, edge_data)

        self.graph_storage.index_done_callback()