| | import os |
| | import torch |
| | from transformers import logging |
| | from transformers import AutoTokenizer |
| | from wrapper import EvalWrapper |
| | from models_xin import CLAP |
| | from utils import compute_similarity |
| | import librosa |
| |
|
| |
|
if __name__ == '__main__':
    # Zero-shot emotion classification demo: embed one waveform and a set of
    # candidate text labels with CLAP, then print the best-matching label.
    logging.set_verbosity_error()

    # check_hash is deliberately not requested: torch.hub only verifies a
    # download when the file name follows the `name-<sha256>.ext` convention,
    # and `best.pth.tar` carries no such suffix, so there is nothing to check.
    ckpt = torch.hub.load_state_dict_from_url(
        url="https://huggingface.co/KeiKinn/paraclap/resolve/main/best.pth.tar?download=true",
        map_location="cpu",
    )

    text_model = 'bert-base-uncased'
    audio_model = 'audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim'

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    candidates = ['happy', 'sad', 'surprise', 'angry']
    # Placeholder: replace with the path of the audio file to classify.
    wavpath = '[Waveform path]'

    # Load the audio resampled to 16 kHz; the returned sample rate is not
    # needed afterwards because it is fixed by `sr`.
    waveform, _ = librosa.load(wavpath, sr=16000)
    x = torch.as_tensor(waveform, dtype=torch.float32)

    tokenizer = AutoTokenizer.from_pretrained(text_model)

    # Tokenize all candidate labels as one padded batch and move the encoded
    # tensors to the same device as the model, so that the text inputs and the
    # model weights do not end up on different devices when CUDA is available.
    candidate_tokens = tokenizer(
        candidates,
        padding=True,
        truncation=True,
        return_tensors='pt',
    ).to(device)

    model = CLAP(
        speech_name=audio_model,
        text_name=text_model,
        embedding_dim=768,
    )
    model.load_state_dict(ckpt)
    model.to(device)
    print('Checkpoint is loaded')
    model.eval()

    with torch.no_grad():
        # unsqueeze(0) adds a leading batch dimension for the single waveform.
        z = model(x.unsqueeze(0).to(device), candidate_tokens)

    # NOTE(review): z[0], z[1], z[2] are presumably the audio embedding, text
    # embedding and logit scale (matching how compute_similarity is called
    # here) -- confirm against models_xin.CLAP.forward.
    similarity = compute_similarity(z[2], z[0], z[1])

    # Assumes similarity.T has one row per input waveform and one column per
    # candidate -- TODO confirm. With a single waveform the argmax holds one
    # element, which .item() turns into a plain int usable as a list index.
    prediction = similarity.T.argmax(dim=1)
    result = candidates[prediction.item()]

    print(result)