|
|
import os |
|
|
import json |
|
|
from tqdm import tqdm |
|
|
from loguru import logger |
|
|
from dataclasses import dataclass, asdict |
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed |
|
|
from models.Base import BaseModel |
|
|
from dataloaders.ProblemState import ProblemState |
|
|
from memories.Memory import BaseMemory |
|
|
|
|
|
|
|
|
|
|
|
class BaseAgent:
    """Run a model over a dataset, keeping one memory object per problem.

    The constructor stores the model and dataset and builds ``self.memories``
    via :meth:`memory_init`.  :meth:`run` then calls :meth:`run_single_pass`
    once for each memory, either on a small thread pool or sequentially.
    """

    def __init__(self, model: BaseModel, dataset):
        """Store the collaborators and build the per-problem memories.

        Args:
            model: Model handle, kept on ``self.model``.
            dataset: Object providing ``problem_states``, ``len()`` and
                ``write_file(path)``; those are the only members used by
                this class.
        """
        self.model = model
        self.dataset = dataset
        self.memories = self.memory_init()

    def memory_init(self):
        """Return a list holding one ``BaseMemory`` per dataset problem state."""
        return [BaseMemory(ps) for ps in self.dataset.problem_states]

    def run_single_pass(self, mem: BaseMemory, verbose=False, temperature=0):
        """Process a single memory item.

        This base implementation does nothing; it is the hook that
        :meth:`run` invokes for every memory (presumably overridden by
        concrete agents).
        """
        pass

    def run(self, output_path=None, multi_thread=True, verbose=False, datalen=None, mem_path=None, temperature=0):
        """Call :meth:`run_single_pass` for each memory, then persist results.

        Args:
            output_path: If not ``None``, ``self.dataset.write_file`` is
                called with this path after all passes finish.
            multi_thread: Use a thread pool when true, a plain loop otherwise.
            verbose: Forwarded to :meth:`run_single_pass`.
            datalen: Number of leading memories to process; a falsy value
                (``None`` or ``0``) means ``len(self.dataset)``.
            mem_path: If not ``None``, memories are dumped there through
                :meth:`write_memories`.
            temperature: Forwarded to :meth:`run_single_pass`.
        """
        data_len = datalen if datalen else len(self.dataset)
        selected = self.memories[:data_len]

        with tqdm(total=data_len) as pbar:
            if multi_thread:
                thread_num = 3
                with ThreadPoolExecutor(max_workers=thread_num) as executor:
                    # Pass ``verbose`` positionally and ``temperature`` by
                    # keyword so both land in the right parameters, exactly
                    # like the sequential branch below.
                    futures = [
                        executor.submit(self.run_single_pass, mem, verbose, temperature=temperature)
                        for mem in selected
                    ]
                    for future in as_completed(futures):
                        try:
                            # Retrieve the result so a worker failure is
                            # surfaced instead of vanishing with the future.
                            future.result()
                        except Exception:
                            # Best effort: report and keep processing the
                            # remaining items rather than aborting the run.
                            logger.exception("run_single_pass raised an exception")
                        pbar.update(1)
            else:
                for mem in selected:
                    self.run_single_pass(mem, verbose, temperature=temperature)
                    pbar.update(1)

        if output_path is not None:
            self.dataset.write_file(output_path)

        if mem_path is not None:
            self.write_memories(mem_path)

    def write_memories(self, file_path):
        """Write every memory to ``file_path`` as one JSON object per line.

        Each memory is converted with ``dataclasses.asdict``, so the memory
        objects must be dataclass instances.  All memories are written, not
        only those processed by the last :meth:`run`.
        """
        with open(file_path, "w") as f:
            for mem in self.memories:
                output = asdict(mem)
                f.write(json.dumps(output) + "\n")
|
|
|
|
|
class SequentialBaseAgent(BaseAgent):
    """Agent that repeats the full pass over the memories for several rounds."""

    def run(self, output_path=None, multi_thread=True, verbose=False, datalen=None, iteration_num=0, temperature=0):
        """Run ``iteration_num`` rounds of ``run_single_pass`` over the memories.

        After each round the dataset is written to a per-round file derived
        from ``output_path`` by inserting ``_<round>`` before the extension
        (for example ``out.jsonl`` becomes ``out_0.jsonl``).

        Args:
            output_path: Base path for the per-round output files; nothing
                is written when it is ``None``.
            multi_thread: Use a thread pool when true, a plain loop otherwise.
            verbose: Forwarded to ``run_single_pass``.
            datalen: Number of leading memories to process; a falsy value
                (``None`` or ``0``) means ``len(self.dataset)``.
            iteration_num: Number of rounds; the default of 0 runs nothing.
            temperature: Forwarded to ``run_single_pass``.
        """
        data_len = datalen if datalen else len(self.dataset)
        selected = self.memories[:data_len]

        for iteration in range(iteration_num):
            logger.info(f"\n=== Iteration {iteration} ===")

            # Derive the per-round path only when an output path was given;
            # splitting ``None`` would raise TypeError.
            iter_path = None
            if output_path is not None:
                root, extension = os.path.splitext(output_path)
                iter_path = f"{root}_{iteration}{extension}"

            with tqdm(total=data_len) as pbar:
                if multi_thread:
                    thread_num = 3
                    with ThreadPoolExecutor(max_workers=thread_num) as executor:
                        # Pass ``verbose`` positionally and ``temperature``
                        # by keyword so both land in the right parameters,
                        # exactly like the sequential branch below.
                        futures = [
                            executor.submit(self.run_single_pass, mem, verbose, temperature=temperature)
                            for mem in selected
                        ]
                        for future in as_completed(futures):
                            try:
                                # Retrieve the result so a worker failure is
                                # surfaced instead of vanishing silently.
                                future.result()
                            except Exception:
                                # Best effort: report and keep processing
                                # the remaining items of this round.
                                logger.exception("run_single_pass raised an exception")
                            pbar.update(1)
                else:
                    for mem in selected:
                        self.run_single_pass(mem, verbose, temperature=temperature)
                        pbar.update(1)

            if iter_path is not None:
                self.dataset.write_file(iter_path)