from df.enhance import enhance, init_df, save_audio
from df.utils import download_file
import os
import pathlib
import librosa
import numpy as np
import soundfile as sf
import torch
import argparse

def process_audio(input_file=None, output_dir="./output_denoised"):
    """Denoise one audio file with DeepFilterNet2 and write the result to disk.

    The audio is loaded as mono, resampled to the model's sample rate when the
    two differ, enhanced, and saved as ``enhanced_<input name>`` in
    *output_dir*.

    Args:
        input_file: Path of the audio file to denoise. When it is ``None``,
            empty, or does not exist, a sample recording is downloaded into
            *output_dir* and processed instead.
        output_dir: Directory that receives the enhanced file (and the
            downloaded sample, if any). It is created when missing.

    Returns:
        Path of the enhanced audio file that was written. The input file's
        extension is kept when soundfile can write that format; otherwise the
        file is saved with a ``.wav`` extension.
    """
    # Load the DeepFilterNet2 model; the remaining return values are not needed here.
    model, df_state, _, _ = init_df(model_base_dir="DeepFilterNet2")
    # The model's sample rate is used for resampling and for writing the output.
    model_sr = df_state.sr()

    os.makedirs(output_dir, exist_ok=True)

    # Use the input file if it is usable, otherwise fall back to a sample.
    if input_file and os.path.exists(input_file):
        audio_path = input_file
        print(f"Using provided audio file: {audio_path}")
    else:
        if input_file:
            # A path was given explicitly but is missing: name it so a typo
            # does not silently turn into "denoise the sample instead".
            print(f"Input file not found: {input_file}")
        print("No valid input file provided. Downloading sample file...")
        audio_url = "https://github.com/Rikorose/DeepFilterNet/raw/e031053/assets/noisy_snr0.wav"
        audio_path = download_file(audio_url, download_dir=output_dir)
        print(f"Downloaded audio file: {audio_path}")

    print(f"File exists: {os.path.exists(audio_path)}")

    # sr=None keeps the file's native sample rate; mono=True downmixes.
    print("Loading audio file with librosa...")
    audio_data, sr = librosa.load(audio_path, sr=None, mono=True)
    audio_np = np.float32(audio_data)
    print(f"Audio shape: {audio_np.shape}, Sample rate: {sr}")

    # The audio handed to enhance() must be at the model's sample rate.
    if sr != model_sr:
        print(f"Resampling from {sr} to {model_sr}")
        audio_np = librosa.resample(audio_np, orig_sr=sr, target_sr=model_sr)

    # enhance() takes audio shaped [channels, samples]; the mono signal gets
    # a leading channel dimension of size 1.
    audio = torch.tensor(audio_np, dtype=torch.float32).unsqueeze(0)
    print(f"Tensor shape after reshaping: {audio.shape}")

    print("Enhancing audio...")
    enhanced = enhance(model, df_state, audio)

    # soundfile picks the output format from the file extension. librosa can
    # read formats soundfile cannot write, so fall back to WAV rather than
    # failing after the enhancement has already been computed.
    stem, ext = os.path.splitext(os.path.basename(audio_path))
    if not sf.check_format(ext[1:]):
        print(f"Cannot write '{ext}' files with soundfile; saving as WAV instead")
        ext = ".wav"
    output_path = os.path.join(output_dir, f"enhanced_{stem}{ext}")

    print(f"Saving enhanced audio to: {output_path}")

    # Drop the leading channel dimension and hand soundfile a NumPy array.
    if isinstance(enhanced, torch.Tensor):
        enhanced = enhanced.squeeze(0).cpu().numpy()

    sf.write(output_path, enhanced, model_sr)
    print(f"Enhanced audio saved to: {output_path}")

    return output_path

if __name__ == "__main__":
    # Command-line interface: an optional input path and an output directory.
    cli = argparse.ArgumentParser(description="Audio denoising with DeepFilterNet")
    cli.add_argument(
        "--input",
        "-i",
        type=str,
        help="Path to input audio file",
    )
    cli.add_argument(
        "--output-dir",
        "-o",
        type=str,
        default="./output_denoised",
        help="Directory to save enhanced audio (default: ./output_denoised)",
    )
    options = cli.parse_args()

    # Run the denoising pipeline with the parsed options.
    enhanced_file = process_audio(
        input_file=options.input,
        output_dir=options.output_dir,
    )