| | import torch |
| | import torchaudio |
| | import torch.nn as nn |
| | from transformers import PreTrainedModel, PretrainedConfig |
| | import torch |
| | from BigVGAN import bigvgan |
| | from BigVGAN.meldataset import get_mel_spectrogram |
| | from voice_restore import VoiceRestore |
| | import argparse |
| | from model import OptimizedAudioRestorationModel |
| | import librosa |
| | from inference_long import apply_overlap_windowing_waveform, reconstruct_waveform_from_windows |
| |
|
# Select the compute device once at import time; every model constructed
# below is moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
| |
|
| | |
class VoiceRestoreConfig(PretrainedConfig):
    """Hugging Face config for the VoiceRestore audio-restoration model.

    Stores the inference-time sampling hyper-parameters consumed by
    ``VoiceRestore`` (steps, guidance strength, and the windowing layout
    used for long-audio processing).
    """

    model_type = "voice_restore"

    def __init__(
        self,
        steps: int = 16,
        cfg_strength: float = 0.5,
        window_size_sec: float = 5.0,
        overlap: float = 0.5,
        **kwargs,
    ):
        """
        Args:
            steps: Number of sampling steps passed to the restoration model.
            cfg_strength: Classifier-free guidance strength.
            window_size_sec: Window length in seconds for long-audio chunking.
            overlap: Fractional overlap between consecutive windows (0-1).
            **kwargs: Remaining options forwarded to ``PretrainedConfig``.
        """
        super().__init__(**kwargs)
        # Explicit parameters (instead of the previous kwargs.get(...) pattern)
        # make the defaults discoverable and let IDEs/type checkers see them;
        # attribute values for any caller are unchanged.
        self.steps = steps
        self.cfg_strength = cfg_strength
        self.window_size_sec = window_size_sec
        self.overlap = overlap
| |
|
| | |
# NOTE(review): this class shadows the `VoiceRestore` imported from the
# `voice_restore` module at the top of the file — confirm the import is
# intentional (e.g. needed for side effects) or remove it upstream.
class VoiceRestore(PreTrainedModel):
    """HF ``PreTrainedModel`` wrapper that restores degraded speech audio.

    Composes a BigVGAN vocoder with an ``OptimizedAudioRestorationModel``
    and exposes file-in / file-out restoration via :meth:`forward`.
    """

    config_class = VoiceRestoreConfig

    def __init__(self, config: VoiceRestoreConfig):
        """Build the vocoder and restoration model and load local weights.

        Side effects: downloads/loads the BigVGAN checkpoint from the HF hub
        and reads ``./pytorch_model.bin`` from the current working directory.
        """
        super().__init__(config)
        # Sampling hyper-parameters, taken from the config.
        self.steps = config.steps
        self.cfg_strength = config.cfg_strength
        self.window_size_sec = config.window_size_sec
        self.overlap = config.overlap

        # 24 kHz, 100-band BigVGAN vocoder (mel -> waveform).
        self.bigvgan_model = bigvgan.BigVGAN.from_pretrained(
            'nvidia/bigvgan_v2_24khz_100band_256x',
            use_cuda_kernel=False,
            force_download=False
        ).to(device)
        # Weight norm is only needed for training; removing it is the
        # standard BigVGAN inference-time step.
        self.bigvgan_model.remove_weight_norm()

        self.optimized_model = OptimizedAudioRestorationModel(device=device, bigvgan_model=self.bigvgan_model)
        # NOTE(review): checkpoint path is hard-coded relative to the CWD —
        # consider making it configurable. torch.load without
        # `weights_only=True` unpickles arbitrary objects; only load trusted
        # checkpoints.
        save_path = "./pytorch_model.bin"
        state_dict = torch.load(save_path, map_location=torch.device(device))
        # Some checkpoints wrap the weights in a 'model_state_dict' key.
        if 'model_state_dict' in state_dict:
            state_dict = state_dict['model_state_dict']

        self.optimized_model.voice_restore.load_state_dict(state_dict, strict=True)
        self.optimized_model.eval()

    def forward(self, input_path, output_path, short=True):
        """Restore the audio file at ``input_path`` and write ``output_path``.

        Args:
            input_path: Path to the degraded input audio file.
            output_path: Path where the restored audio is saved.
            short: If True, process the whole clip in one pass; otherwise use
                overlapping windows (for long recordings).
        """
        if short:
            self.restore_audio_short(self.optimized_model, input_path, output_path, self.steps, self.cfg_strength)
        else:
            self.restore_audio_long(self.optimized_model, input_path, output_path, self.steps, self.cfg_strength, self.window_size_sec, self.overlap)

    def restore_audio_short(self, model, input_path, output_path, steps, cfg_strength):
        """
        Short inference for audio restoration.

        Loads the clip, resamples it to ``model.target_sample_rate``,
        downmixes to mono, runs one restoration pass under autocast, and
        saves the result.
        """
        device_type = device.type
        audio, sr = torchaudio.load(input_path)
        # Resample to the model's expected rate if needed.
        if sr != model.target_sample_rate:
            audio = torchaudio.functional.resample(audio, sr, model.target_sample_rate)

        # Downmix multi-channel input to mono by channel averaging.
        audio = audio.mean(dim=0, keepdim=True) if audio.dim() > 1 else audio

        # NOTE(review): `audio` is never moved to `device` here (unlike the
        # long path) — presumably the model moves its input internally;
        # confirm, otherwise this fails on CUDA.
        with torch.inference_mode():
            with torch.autocast(device_type):
                restored_wav = model(audio, steps=steps, cfg_strength=cfg_strength)
                restored_wav = restored_wav.squeeze(0).float().cpu()

        torchaudio.save(output_path, restored_wav, model.target_sample_rate)

    def restore_audio_long(self, model, input_path, output_path, steps, cfg_strength, window_size_sec, overlap):
        """
        Long inference for audio restoration using overlapping windows.

        Splits the waveform into overlapping windows, restores each window
        independently (mel -> sample -> BigVGAN vocoder), then overlap-adds
        the windows back into one waveform.
        """
        # Load at a fixed 24 kHz mono to match the BigVGAN vocoder.
        wav, sr = librosa.load(input_path, sr=24000, mono=True)
        wav = torch.FloatTensor(wav).unsqueeze(0)

        window_size_samples = int(window_size_sec * sr)
        wav_windows = apply_overlap_windowing_waveform(wav, window_size_samples, overlap)

        restored_wav_windows = []
        for wav_window in wav_windows:
            wav_window = wav_window.to(device)
            # Mel spectrogram computed with the vocoder's own settings (`h`)
            # so the restored mel can be vocoded directly.
            processed_mel = get_mel_spectrogram(wav_window, self.bigvgan_model.h).to(device)

            with torch.no_grad():
                with torch.autocast(device.type):
                    # sample() expects (batch, time, mels); transpose in/out.
                    restored_mel = model.voice_restore.sample(processed_mel.transpose(1, 2), steps=steps, cfg_strength=cfg_strength)
                    restored_mel = restored_mel.squeeze(0).transpose(0, 1)

            # Vocode the restored mel back to a waveform on CPU float32.
            restored_wav = self.bigvgan_model(restored_mel.unsqueeze(0)).squeeze(0).float().cpu()
            restored_wav_windows.append(restored_wav)

            # Free per-window CUDA allocations to keep peak memory bounded.
            torch.cuda.empty_cache()

        restored_wav_windows = torch.stack(restored_wav_windows)
        restored_wav = reconstruct_waveform_from_windows(restored_wav_windows, window_size_samples, overlap)

        torchaudio.save(output_path, restored_wav.unsqueeze(0), 24000)
| |
|
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|