import ast
import os
from pathlib import Path
from typing import Optional

from datasets import DatasetDict, load_dataset
from loguru import logger

import turing.config as config

class DatasetManager:
    """
    Manages the loading, transformation, and access of project datasets.
    """

    def __init__(self, dataset_path: Optional[Path] = None):
        self.hf_id = config.DATASET_HF_ID
        self.raw_data_dir = config.RAW_DATA_DIR
        self.interim_data_dir = config.INTERIM_DATA_DIR
        self.base_interim_path = self.interim_data_dir / "base"
        if dataset_path is not None:
            self.dataset_path = dataset_path
        else:
            self.dataset_path = self.base_interim_path

    def _format_labels_for_csv(self, example: dict) -> dict:
        """
        Formats the labels list as a string for CSV storage.
        (Private helper method.)

        Args:
            example (dict): A single example from the dataset.

        Returns:
            dict: The example with labels converted to a string.
        """
        labels = example.get("labels")
        if isinstance(labels, list):
            example["labels"] = str(labels)
        return example

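    # Illustrative sketch of the label round-trip (the example values below are
    # assumed, not taken from the real dataset): CSV cannot hold Python lists,
    # so labels are serialized with str() here and parsed back with
    # ast.literal_eval() in get_dataset().
    #
    #   _format_labels_for_csv({"labels": [1, 0, 1]})  -> {"labels": "[1, 0, 1]"}
    #   ast.literal_eval("[1, 0, 1]")                  -> [1, 0, 1]
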
    def download_dataset(self):
        """
        Loads the dataset from Hugging Face and saves it into the "raw" folder.
        """
        logger.info(f"Loading dataset: {self.hf_id}")
        try:
            ds = load_dataset(self.hf_id)
            logger.success("Dataset loaded successfully.")
            logger.info(f"Dataset splits: {ds}")

            self.raw_data_dir.mkdir(parents=True, exist_ok=True)
            for split_name, dataset_split in ds.items():
                output_path = self.raw_data_dir / f"{split_name.replace('-', '_')}.parquet"
                dataset_split.to_parquet(str(output_path))
            logger.success(f"Dataset saved to {self.raw_data_dir}.")
        except Exception as e:
            logger.error(f"Failed to download or save the dataset: {e}")

    def parquet_to_csv(self):
        """
        Converts all parquet files in the raw data directory
        to CSV format in the interim data directory.
        """
        logger.info("Starting Parquet to CSV conversion...")
        self.base_interim_path.mkdir(parents=True, exist_ok=True)

        for file_name in os.listdir(self.raw_data_dir):
            if file_name.endswith(".parquet"):
                part_name = file_name.replace(".parquet", "").replace("-", "_")

                # Load the parquet file
                dataset = load_dataset(
                    "parquet", data_files={part_name: str(self.raw_data_dir / file_name)}
                )

                # Map and format labels
                dataset[part_name] = dataset[part_name].map(self._format_labels_for_csv)

                # Save to CSV
                csv_output_path = str(self.base_interim_path / f"{part_name}.csv")
                dataset[part_name].to_csv(csv_output_path)
                logger.info(f"Converted {file_name} to {csv_output_path}")

        logger.success("Parquet -> CSV conversion complete.")

    def get_dataset_name(self) -> str:
        """
        Returns the name of the current dataset being used.

        Returns:
            str: The name of the dataset (e.g., 'clean-aug-soft-k5000').
        """
        return self.dataset_path.name

    def get_dataset(self) -> DatasetDict:
        """
        Returns the processed dataset from the interim data directory
        as a DatasetDict (loaded from CSVs).

        Returns:
            DatasetDict: The complete dataset with train and test splits for each language.
        """
        dataset_path = self.dataset_path

        # Define the base filenames
        data_files = {
            "java_train": str(dataset_path / "java_train.csv"),
            "java_test": str(dataset_path / "java_test.csv"),
            "python_train": str(dataset_path / "python_train.csv"),
            "python_test": str(dataset_path / "python_test.csv"),
            "pharo_train": str(dataset_path / "pharo_train.csv"),
            "pharo_test": str(dataset_path / "pharo_test.csv"),
        }

        # Verify file existence before loading
        logger.info("Loading CSV dataset from splits...")
        existing_data_files = {}
        for key, path in data_files.items():
            if not os.path.exists(path):
                found = False
                if os.path.exists(dataset_path):
                    for f in os.listdir(dataset_path):
                        if f.startswith(key) and f.endswith(".csv"):
                            existing_data_files[key] = str(dataset_path / f)
                            found = True
                            break
                if not found:
                    logger.warning(f"File not found for split '{key}': {path}")
            else:
                existing_data_files[key] = path

        if not existing_data_files:
            logger.error("No dataset CSV files found. Run 'parquet-to-csv' first.")
            raise FileNotFoundError("Dataset CSV files not found.")

        logger.info(f"Found files: {list(existing_data_files.keys())}")
        full_dataset = load_dataset("csv", data_files=existing_data_files)

        logger.info("Formatting labels (from string back to list)...")
        for split in full_dataset:
            full_dataset[split] = full_dataset[split].map(
                lambda x: {
                    "labels": ast.literal_eval(x["labels"])
                    if isinstance(x["labels"], str)
                    else x["labels"]
                }
            )

        logger.success("Dataset is ready for use.")
        return full_dataset

    def get_raw_dataset_from_hf(self) -> Optional[DatasetDict]:
        """
        Loads the raw dataset directly from Hugging Face without saving.

        Returns:
            DatasetDict: The raw dataset from Hugging Face, or None if loading fails.
        """
        logger.info(f"Loading raw dataset '{self.hf_id}' from Hugging Face...")
        try:
            ds = load_dataset(self.hf_id)
            logger.success(f"Successfully loaded '{self.hf_id}'.")
            return ds
        except Exception as e:
            logger.error(f"Failed to load dataset from Hugging Face: {e}")
            return None

    def search_file(self, file_name: str, search_directory: Optional[Path] = None) -> list:
        """
        Recursively searches for a file by name within a specified data directory.

        Args:
            file_name (str): The name of the file to search for (e.g., "java_train.csv").
            search_directory (Path, optional): The directory to search in.
                Defaults to self.raw_data_dir.

        Returns:
            list: A list of Path objects for all found files.
        """
        if search_directory is None:
            search_directory = self.raw_data_dir
            logger.info(f"Defaulting search to raw data directory: {search_directory}")

        if not search_directory.is_dir():
            logger.error(f"Search directory not found: {search_directory}")
            return []

        logger.info(f"Searching for '{file_name}' in '{search_directory}'...")
        found_files = []
        for root, _dirs, files in os.walk(search_directory):
            for file in files:
                if file == file_name:
                    found_files.append(Path(root) / file)

        if not found_files:
            logger.warning(f"No files named '{file_name}' found in '{search_directory}'.")
        else:
            logger.success(f"Found {len(found_files)} matching file(s).")
        return found_files
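

# Minimal usage sketch (an assumed entry point, not part of the deployed service):
# it exercises the DatasetManager flow end to end using only the methods defined
# above, and assumes turing.config provides DATASET_HF_ID, RAW_DATA_DIR and
# INTERIM_DATA_DIR as elsewhere in this module.
if __name__ == "__main__":
    manager = DatasetManager()

    # Fetch the dataset from Hugging Face and store each split as a parquet file.
    manager.download_dataset()

    # Convert the raw parquet splits to CSV in the interim "base" directory.
    manager.parquet_to_csv()

    # Reload the CSVs as a DatasetDict, with labels parsed back into lists.
    dataset = manager.get_dataset()
    logger.info(
        f"Loaded dataset '{manager.get_dataset_name()}' with splits: {list(dataset.keys())}"
    )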