| | import json |
| | import numpy as np |
| | from typing import Any, Dict, List |
| | from rank_bm25 import BM25Okapi |
| |
|
class BM25Retriever:
    """Lexical retriever that ranks code snippets with BM25 (``BM25Okapi``).

    Entries are loaded from a JSON file by :meth:`process`. Each entry
    contributes a code snippet and a natural-language description. Depending
    on ``mode``, the index is built over the descriptions (``"instruction"``)
    or over the code itself (``"code"``).
    """

    _MODES = ("instruction", "code")

    def __init__(self, mode="instruction"):
        """Create an empty retriever.

        Args:
            mode: ``"instruction"`` indexes the descriptions, ``"code"``
                indexes the code snippets.

        Raises:
            ValueError: If ``mode`` is not one of the supported modes.
        """
        # Explicit check rather than ``assert`` so validation survives ``-O``.
        if mode not in self._MODES:
            raise ValueError(
                f"mode must be one of {self._MODES}, got {mode!r}."
            )
        self.bm25: "BM25Okapi | None" = None
        self.content_input_path: str = ""
        self.mode = mode
        # Parallel lists: chunks[i] is the code whose description is corpus[i].
        self.chunks: List[str] = []
        self.corpus: List[str] = []

    @staticmethod
    def _tokenize(text: str) -> List[str]:
        """Split ``text`` on runs of whitespace (spaces, tabs and newlines)."""
        return text.split()

    @staticmethod
    def _top_indices(scores: np.ndarray, top_k: int) -> np.ndarray:
        """Return the indices of the ``top_k`` largest ``scores``, best first.

        ``top_k`` is clamped to ``len(scores)``. The relative order of equal
        scores is unspecified.
        """
        scores = np.asarray(scores)
        n = scores.shape[0]
        k = min(top_k, n)
        if k <= 0:
            return np.empty(0, dtype=np.intp)
        if k < n:
            # Partial selection: positions n-k.. hold the k largest scores.
            candidates = np.argpartition(scores, n - k)[n - k:]
        else:
            candidates = np.arange(n)
        # Sort only the selected candidates, highest score first.
        return candidates[np.argsort(-scores[candidates], kind="stable")]

    def process(
        self,
        content_input_path: str,
        *,
        code_key: str = "code",
        description_key: str = "description_1",
    ) -> None:
        """Load entries from a JSON file and build the BM25 index.

        The file is expected to hold a JSON array of objects, each carrying a
        code snippet under ``code_key`` and a description under
        ``description_key``.

        Args:
            content_input_path: Path to the UTF-8 encoded JSON file.
            code_key: Key of the code snippet in each entry.
            description_key: Key of the description in each entry.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            KeyError: If an entry lacks ``code_key`` or ``description_key``.
        """
        with open(content_input_path, "r", encoding="utf-8") as f:
            content = json.load(f)

        # Build everything locally first so a malformed entry cannot leave the
        # retriever with half-updated lists sitting next to a stale index.
        chunks = [entry[code_key] for entry in content]
        corpus = [entry[description_key] for entry in content]
        documents = corpus if self.mode == "instruction" else chunks
        tokenized = [self._tokenize(doc) for doc in documents]

        # Nothing to index when no document contains a single token.
        bm25 = BM25Okapi(tokenized) if any(tokenized) else None

        self.content_input_path = content_input_path
        self.chunks = chunks
        self.corpus = corpus
        self.bm25 = bm25

    def query(
        self,
        query: str,
        top_k: int = 1
    ) -> List[Dict[str, Any]]:
        """Return up to ``top_k`` best-matching entries for ``query``.

        Args:
            query: Free-text query, tokenized the same way as the indexed
                documents.
            top_k: Maximum number of results. If it exceeds the number of
                indexed entries, every entry is returned.

        Returns:
            A list of dicts sorted by descending score. Each dict has the keys
            ``"similarity score"``, ``"original instruction"`` (the entry's
            description) and ``"code"``.

        Raises:
            ValueError: If ``top_k`` is not positive, or if no index has been
                built (``process`` not called, or nothing to index).
        """
        if top_k <= 0:
            raise ValueError("top_k must be a positive integer.")
        if self.bm25 is None or not self.chunks:
            raise ValueError(
                "BM25 model is not initialized. Call `process` first."
            )

        scores = np.asarray(self.bm25.get_scores(self._tokenize(query)))

        return [
            {
                "similarity score": scores[i],
                "original instruction": self.corpus[i],
                "code": self.chunks[i],
            }
            for i in self._top_indices(scores, top_k)
        ]