# test1: MJ17 direct
# test2: "A1YU101" thailand cross-ref
# test3: "EBK109" thailand cross-ref
# test4: "OQ731952"/"BST115" for search query title: "South Asian maternal and paternal lineages in southern Thailand and"
from iterate3 import data_preprocess, model
import mtdna_classifier
import app
import pandas as pd
from pathlib import Path
import subprocess
from NER.html import extractHTML
import os
import google.generativeai as genai
import re
import standardize_location
# Helper functions for this pipeline
# Track time
import time
import multiprocessing
def run_with_timeout(func, args=(), kwargs=None, timeout=20):
    """
    Run `func` in a separate process with a timeout in seconds; kill it if it exceeds the limit.
    Returns: (success, result or None)
    """
    kwargs = kwargs or {}  # avoid sharing a mutable default argument between calls
    def wrapper(q, *args, **kwargs):
        try:
            q.put(func(*args, **kwargs))
        except Exception as e:
            q.put(e)
    q = multiprocessing.Queue()
    p = multiprocessing.Process(target=wrapper, args=(q, *args), kwargs=kwargs)
    p.start()
    p.join(timeout)
    if p.is_alive():
        p.terminate()
        p.join()
        print(f"⏱️ Timeout exceeded ({timeout} sec) — function killed.")
        return False, None
    else:
        result = q.get()
        if isinstance(result, Exception):
            raise result
        return True, result
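# Usage sketch (illustration only): the wrapped callable below is a hypothetical
# stand-in for a long-running step such as document preprocessing. Like the pipeline
# itself, it assumes the default "fork" start method (Linux/Colab).
def _demo_run_with_timeout():
    def _slow_step(seconds):
        time.sleep(seconds)
        return seconds
    ok, value = run_with_timeout(_slow_step, args=(60,), timeout=5)
    # ok is False and value is None when the worker was killed;
    # otherwise value holds _slow_step's return value.
    return ok, value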
def time_it(func, *args, **kwargs):
    """
    Measure how long a function takes to run and return its result + elapsed time.
    """
    start = time.time()
    result = func(*args, **kwargs)
    end = time.time()
    elapsed = end - start
    print(f"⏱️ '{func.__name__}' took {elapsed:.3f} seconds")
    return result, elapsed
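# Usage sketch (illustration only): time_it wraps a single call and reports
# wall-clock seconds; the builtin used here is an arbitrary cheap example.
def _demo_time_it():
    result, elapsed = time_it(sorted, [3, 1, 2])
    return result, elapsed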
# --- Define Pricing Constants (for Gemini 1.5 Flash & text-embedding-004) ---
def track_gemini_cost():
    # Placeholder: documents the per-1,000-token prices reused inside pipeline_with_gemini.
    PRICE_PER_1K_INPUT_LLM = 0.000075        # $0.075 per 1M tokens
    PRICE_PER_1K_OUTPUT_LLM = 0.0003         # $0.30 per 1M tokens
    PRICE_PER_1K_EMBEDDING_INPUT = 0.000025  # $0.025 per 1M tokens
    return True
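# Cost-arithmetic sketch (illustration only): shows how the per-1,000-token prices
# above turn token counts into dollars, mirroring the embedding-cost computation
# inside pipeline_with_gemini. The token counts passed in are hypothetical.
def _estimate_gemini_cost(input_tokens=0, output_tokens=0, embedding_tokens=0):
    PRICE_PER_1K_INPUT_LLM = 0.000075        # $0.075 per 1M tokens
    PRICE_PER_1K_OUTPUT_LLM = 0.0003         # $0.30 per 1M tokens
    PRICE_PER_1K_EMBEDDING_INPUT = 0.000025  # $0.025 per 1M tokens
    return ((input_tokens / 1000) * PRICE_PER_1K_INPUT_LLM
            + (output_tokens / 1000) * PRICE_PER_1K_OUTPUT_LLM
            + (embedding_tokens / 1000) * PRICE_PER_1K_EMBEDDING_INPUT)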
def unique_preserve_order(seq):
    seen = set()
    return [x for x in seq if not (x in seen or seen.add(x))]
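# Usage sketch (illustration only): unique_preserve_order keeps the first occurrence
# of each item and drops later duplicates, which is how the pipeline de-duplicates
# article and supplementary-material links; the URLs here are made up.
def _demo_unique_preserve_order():
    return unique_preserve_order(["https://doi.org/x", "https://doi.org/y", "https://doi.org/x"])
    # -> ["https://doi.org/x", "https://doi.org/y"]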
# Main execution
def pipeline_with_gemini(accessions):
    # output: country, sample_type, ethnic, location, money_cost, time_cost, explain
    # the accessions list may contain a single accession number
    # Prices are per 1,000 tokens
    PRICE_PER_1K_INPUT_LLM = 0.000075        # $0.075 per 1M tokens
    PRICE_PER_1K_OUTPUT_LLM = 0.0003         # $0.30 per 1M tokens
    PRICE_PER_1K_EMBEDDING_INPUT = 0.000025  # $0.025 per 1M tokens
    if not accessions:
        print("no input")
        return None
    else:
        accs_output = {}
        # GOOGLE_API_KEY is expected to be set in the environment; never hard-code credentials here.
        genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
        for acc in accessions:
            start = time.time()
            total_cost_title = 0
            jsonSM, links, article_text = {}, [], ""
            acc_score = {"isolate": "",
                         "country": {},
                         "sample_type": {},
                         #"specific_location": {},
                         #"ethnicity": {},
                         "query_cost": total_cost_title,
                         "time_cost": None,
                         "source": links}
            meta = mtdna_classifier.fetch_ncbi_metadata(acc)
            country, spe_loc, ethnic, sample_type, col_date = meta["country"], meta["specific_location"], meta["ethnicity"], meta["sample_type"], meta["collection_date"]
            iso, title, doi, pudID, features = meta["isolate"], meta["title"], meta["doi"], meta["pubmed_id"], meta["all_features"]
            acc_score["isolate"] = iso
            # set up step: create the folder to save documents
            chunk, all_output = "", ""
            if pudID:
                id = pudID
                saveTitle = title
            else:
                saveTitle = title + "_" + col_date
                id = "DirectSubmission"
            folder_path = Path("/content/drive/MyDrive/CollectData/MVP/mtDNA-Location-Classifier/data/" + str(id))
            if not folder_path.exists():
                folder_path.mkdir(parents=True, exist_ok=True)
                print("data/" + str(id) + " created.")
            else:
                print("data/" + str(id) + " already exists.")
            saveLinkFolder = str(folder_path)
            # first way: ncbi method
            if country.lower() != "unknown":
                stand_country = standardize_location.smart_country_lookup(country.lower())
                if stand_country.lower() != "not found":
                    acc_score["country"][stand_country.lower()] = ["ncbi"]
                else: acc_score["country"][country.lower()] = ["ncbi"]
            # if spe_loc.lower() != "unknown":
            #     acc_score["specific_location"][spe_loc.lower()] = ["ncbi"]
            # if ethnic.lower() != "unknown":
            #     acc_score["ethnicity"][ethnic.lower()] = ["ncbi"]
            if sample_type.lower() != "unknown":
                acc_score["sample_type"][sample_type.lower()] = ["ncbi"]
            # second way: LLM model
            # Preprocess the input token
            accession, isolate = None, None
            if acc != "unknown": accession = acc
            if iso != "unknown": isolate = iso
            # check doi first
            if doi != "unknown":
                link = 'https://doi.org/' + doi
                # get the file to create listOfFile for each id
                html = extractHTML.HTML("", link)
                jsonSM = html.getSupMaterial()
                article_text = html.getListSection()
                if article_text:
                    # keep the link only if it is neither a Cloudflare "Just a moment" page nor a 403 response
                    if ("Just a moment...Enable JavaScript and cookies to continue".lower() not in article_text.lower()
                            and "403 Forbidden Request".lower() not in article_text.lower()):
                        links.append(link)
                if jsonSM:
                    links += sum((jsonSM[key] for key in jsonSM), [])
            # no usable doi text, so fall back to the Google Custom Search API
            if (len(article_text) == 0
                    or "Just a moment...Enable JavaScript and cookies to continue".lower() in article_text.lower()
                    or "403 Forbidden Request".lower() in article_text.lower()):
                # might find the article
                tem_links = mtdna_classifier.search_google_custom(title, 2)
                # get supplementary material of that article
                for link in tem_links:
                    html = extractHTML.HTML("", link)
                    jsonSM = html.getSupMaterial()
                    article_text_tem = html.getListSection()
                    if article_text_tem:
                        # keep the link only if it is neither a Cloudflare "Just a moment" page nor a 403 response
                        if ("Just a moment...Enable JavaScript and cookies to continue".lower() not in article_text_tem.lower()
                                and "403 Forbidden Request".lower() not in article_text_tem.lower()):
                            links.append(link)
                    if jsonSM:
                        links += sum((jsonSM[key] for key in jsonSM), [])
            print(links)
            links = unique_preserve_order(links)
            acc_score["source"] = links
            chunk_path = "/" + saveTitle + "_merged_document.docx"
            all_path = "/" + saveTitle + "_all_merged_document.docx"
            # reuse the chunk and all-output documents if they already exist
            file_chunk_path = saveLinkFolder + chunk_path
            file_all_path = saveLinkFolder + all_path
            if os.path.exists(file_chunk_path):
                print("File chunk exists!")
                if not chunk:
                    text, table, document_title = model.read_docx_text(file_chunk_path)
                    chunk = data_preprocess.normalize_for_overlap(text) + "\n" + data_preprocess.normalize_for_overlap(". ".join(table))
            if os.path.exists(file_all_path):
                print("File all output exists!")
                if not all_output:
                    text_all, table_all, document_title_all = model.read_docx_text(file_all_path)
                    all_output = data_preprocess.normalize_for_overlap(text_all) + "\n" + data_preprocess.normalize_for_overlap(". ".join(table_all))
            if not chunk and not all_output:
                # otherwise, check whether we can reuse the chunk and all_output already built for an existing accession
                if links:
                    for link in links:
                        print(link)
                        # if len(all_output) > 1000*1000:
                        #     all_output = data_preprocess.normalize_for_overlap(all_output)
                        #     print("after normalizing all output: ", len(all_output))
                        if len(data_preprocess.normalize_for_overlap(all_output)) > 600000:
                            print("break here")
                            break
                        if iso != "unknown": query_kw = iso
                        else: query_kw = acc
                        #text_link, tables_link, final_input_link = data_preprocess.preprocess_document(link, saveLinkFolder, isolate=query_kw)
                        success_process, output_process = run_with_timeout(data_preprocess.preprocess_document, args=(link, saveLinkFolder), kwargs={"isolate": query_kw}, timeout=180)
                        if success_process:
                            text_link, tables_link, final_input_link = output_process[0], output_process[1], output_process[2]
                            print("yes succeed for process document")
                        else: text_link, tables_link, final_input_link = "", "", ""
                        context = data_preprocess.extract_context(final_input_link, query_kw)
                        if context != "Sample ID not found.":
                            if len(data_preprocess.normalize_for_overlap(chunk)) < 1000*1000:
                                success_chunk, the_output_chunk = run_with_timeout(data_preprocess.merge_texts_skipping_overlap, args=(chunk, context))
                                if success_chunk:
                                    chunk = the_output_chunk  # data_preprocess.merge_texts_skipping_overlap(chunk, context)
                                    print("yes succeed for chunk")
                                else:
                                    chunk += context
                                    print("len context: ", len(context))
                                    print("basic fall back")
                            print("len chunk after: ", len(chunk))
                        if len(final_input_link) > 1000*1000:
                            if context != "Sample ID not found.":
                                final_input_link = context
                            else:
                                final_input_link = data_preprocess.normalize_for_overlap(final_input_link)
                                if len(final_input_link) > 1000*1000:
                                    final_input_link = final_input_link[:100000]
                        if len(data_preprocess.normalize_for_overlap(all_output)) < 1000*1000:
                            success, the_output = run_with_timeout(data_preprocess.merge_texts_skipping_overlap, args=(all_output, final_input_link))
                            if success:
                                all_output = the_output  # data_preprocess.merge_texts_skipping_overlap(all_output, final_input_link)
                                print("yes succeed")
                            else:
                                all_output += final_input_link
                                print("len final input: ", len(final_input_link))
                                print("basic fall back")
                        print("len all output after: ", len(all_output))
                    #country_pro, chunk, all_output = data_preprocess.process_inputToken(links, saveLinkFolder, accession=accession, isolate=isolate)
                else:
                    chunk = "Collection_date: " + col_date + ". Isolate: " + iso + ". Title: " + title + ". Features: " + features
                    all_output = "Collection_date: " + col_date + ". Isolate: " + iso + ". Title: " + title + ". Features: " + features
                if not chunk: chunk = "Collection_date: " + col_date + ". Isolate: " + iso + ". Title: " + title + ". Features: " + features
                if not all_output: all_output = "Collection_date: " + col_date + ". Isolate: " + iso + ". Title: " + title + ". Features: " + features
                if len(all_output) > 1*1024*1024:
                    all_output = data_preprocess.normalize_for_overlap(all_output)
                    if len(all_output) > 1*1024*1024:
                        all_output = all_output[:1*1024*1024]
                print("chunk len: ", len(chunk))
                print("all output len: ", len(all_output))
                data_preprocess.save_text_to_docx(chunk, file_chunk_path)
                data_preprocess.save_text_to_docx(all_output, file_all_path)
            # else:
            #     final_input = ""
            #     if all_output:
            #         final_input = all_output
            #     else:
            #         if chunk: final_input = chunk
            #     #data_preprocess.merge_texts_skipping_overlap(final_input, all_output)
            #     if final_input:
            #         keywords = []
            #         if iso != "unknown": keywords.append(iso)
            #         if acc != "unknown": keywords.append(acc)
            #         for keyword in keywords:
            #             chunkBFS = data_preprocess.get_contextual_sentences_BFS(final_input, keyword)
            #             countryDFS, chunkDFS = data_preprocess.get_contextual_sentences_DFS(final_input, keyword)
            #             chunk = data_preprocess.merge_texts_skipping_overlap(chunk, chunkDFS)
            #             chunk = data_preprocess.merge_texts_skipping_overlap(chunk, chunkBFS)
            # Define paths for cached RAG assets
            faiss_index_path = saveLinkFolder + "/faiss_index.bin"
            document_chunks_path = saveLinkFolder + "/document_chunks.json"
            structured_lookup_path = saveLinkFolder + "/structured_lookup.json"
            master_structured_lookup, faiss_index, document_chunks = model.load_rag_assets(
                faiss_index_path, document_chunks_path, structured_lookup_path)
            global_llm_model_for_counting_tokens = genai.GenerativeModel('gemini-1.5-flash-latest')
            if not all_output:
                if chunk: all_output = chunk
                else: all_output = "Collection_date: " + col_date + ". Isolate: " + iso + ". Title: " + title + ". Features: " + features
            if faiss_index is None:
                print("\nBuilding RAG assets (structured lookup, FAISS index, chunks)...")
                total_doc_embedding_tokens = global_llm_model_for_counting_tokens.count_tokens(
                    all_output).total_tokens
                initial_embedding_cost = (total_doc_embedding_tokens / 1000) * PRICE_PER_1K_EMBEDDING_INPUT
                total_cost_title += initial_embedding_cost
                print(f"Initial one-time embedding cost for '{file_all_path}' ({total_doc_embedding_tokens} tokens): ${initial_embedding_cost:.6f}")
                master_structured_lookup, faiss_index, document_chunks, plain_text_content = model.build_vector_index_and_data(
                    file_all_path, faiss_index_path, document_chunks_path, structured_lookup_path)
            else:
                print("\nRAG assets loaded from file. No re-embedding of entire document will occur.")
                plain_text_content_all, table_strings_all, document_title_all = model.read_docx_text(file_all_path)
                master_structured_lookup['document_title'] = master_structured_lookup.get('document_title', document_title_all)
            primary_word = iso
            alternative_word = acc
            print(f"\n--- General Query: Primary='{primary_word}' (Alternative='{alternative_word}') ---")
            if features.lower() not in all_output.lower():
                all_output += ". NCBI Features: " + features
            # country, sample_type, method_used, ethnic, spe_loc, total_query_cost = model.query_document_info(
            #     primary_word, alternative_word, meta, master_structured_lookup, faiss_index, document_chunks,
            #     model.call_llm_api, chunk=chunk, all_output=all_output)
            country, sample_type, method_used, country_explanation, sample_type_explanation, total_query_cost = model.query_document_info(
                primary_word, alternative_word, meta, master_structured_lookup, faiss_index, document_chunks,
                model.call_llm_api, chunk=chunk, all_output=all_output)
            if len(country) == 0: country = "unknown"
            if len(sample_type) == 0: sample_type = "unknown"
            if country_explanation: country_explanation = "-" + country_explanation
            else: country_explanation = ""
            if sample_type_explanation: sample_type_explanation = "-" + sample_type_explanation
            else: sample_type_explanation = ""
            if method_used == "unknown": method_used = ""
            if country.lower() != "unknown":
                stand_country = standardize_location.smart_country_lookup(country.lower())
                if stand_country.lower() != "not found":
                    if stand_country.lower() in acc_score["country"]:
                        if country_explanation:
                            acc_score["country"][stand_country.lower()].append(method_used + country_explanation)
                    else:
                        acc_score["country"][stand_country.lower()] = [method_used + country_explanation]
                else:
                    if country.lower() in acc_score["country"]:
                        if country_explanation:
                            if len(method_used + country_explanation) > 0:
                                acc_score["country"][country.lower()].append(method_used + country_explanation)
                    else:
                        if len(method_used + country_explanation) > 0:
                            acc_score["country"][country.lower()] = [method_used + country_explanation]
            # if spe_loc.lower() != "unknown":
            #     if spe_loc.lower() in acc_score["specific_location"]:
            #         acc_score["specific_location"][spe_loc.lower()].append(method_used)
            #     else:
            #         acc_score["specific_location"][spe_loc.lower()] = [method_used]
            # if ethnic.lower() != "unknown":
            #     if ethnic.lower() in acc_score["ethnicity"]:
            #         acc_score["ethnicity"][ethnic.lower()].append(method_used)
            #     else:
            #         acc_score["ethnicity"][ethnic.lower()] = [method_used]
            if sample_type.lower() != "unknown":
                if sample_type.lower() in acc_score["sample_type"]:
                    if len(method_used + sample_type_explanation) > 0:
                        acc_score["sample_type"][sample_type.lower()].append(method_used + sample_type_explanation)
                else:
                    if len(method_used + sample_type_explanation) > 0:
                        acc_score["sample_type"][sample_type.lower()] = [method_used + sample_type_explanation]
            end = time.time()
            total_cost_title += total_query_cost
            acc_score["query_cost"] = total_cost_title
            elapsed = end - start
            acc_score["time_cost"] = f"{elapsed:.3f} seconds"
            accs_output[acc] = acc_score
            print(accs_output[acc])
        return accs_output
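# Example invocation (illustration only): a sketch of how the test accessions listed
# at the top of this file could be fed to the pipeline. It assumes GOOGLE_API_KEY is
# exported and that the Colab/Drive data folders used above are available.
if __name__ == "__main__":
    test_accessions = ["MJ17", "A1YU101", "EBK109", "OQ731952"]
    results = pipeline_with_gemini(test_accessions)
    if results:
        for acc, score in results.items():
            print(acc, score["country"], score["sample_type"], score["time_cost"])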