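"""Gradio demo for SalamandraTAV (BSC-LT/salamandra-TAV-7b): speech-to-text
translation (S2TT) and automatic speech recognition (ASR) for Iberian languages."""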
import gradio as gr
from transformers import pipeline
import torch
import os
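# `spaces` provides the @spaces.GPU decorator for Hugging Face ZeroGPU Spaces:
# it allocates a GPU only while the decorated function is running.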
import spaces
# Load the model pipeline
pipe = pipeline(
    task="multimodal_mt",
    model="BSC-LT/salamandra-TAV-7b",
    trust_remote_code=True,
    token=os.environ.get("HF_TOKEN"),
    device_map="auto",
    torch_dtype=torch.float16,
)
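# Note: "multimodal_mt" is a custom pipeline shipped with the model repo (hence
# trust_remote_code=True). As used in process_audio below, it accepts an audio
# file path or a prior chat history, a `mode` selector ("asr", "t2tt", "lid"),
# and returns a chat-history object when return_chat_history=True. A minimal
# sketch mirroring those calls (see the model card for the authoritative API):
#   history = pipe("sample.wav", mode="asr", return_chat_history=True)
#   print(history.get_assistant_messages()[-1])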
# Define the languages for the dropdowns
LANGUAGES = {
    "autodetect": "Autodetect",
    "en": "English",
    "es": "Spanish",
    "ca": "Catalan",
    "pt": "Portuguese",
    "gl": "Galician",
    "eu": "Basque",
}
# Invert the dictionary to map display names back to language codes
NAME_TO_CODE = {v: k for k, v in LANGUAGES.items()}
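# e.g. NAME_TO_CODE["Spanish"] == "es"; "Autodetect" maps back to "autodetect"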
@spaces.GPU
def process_audio(audio, source_lang_name, target_lang_name):
"""
Processes the audio input to perform speech-to-text translation or transcription.
"""
if audio is None:
return "Please provide an audio file or record one.", ""
if target_lang_name is None:
return "Please select a target language.", ""
source_lang = LANG_TO_NAME.get(source_lang_name)
target_lang = LANG_TO_NAME.get(target_lang_name)
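    # Decoding settings shared by the ASR, T2TT, and LID calls below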
    generation_kwargs = {"beam_size": 5, "max_new_tokens": 100}
    asr_kwargs = {"mode": "asr", "return_chat_history": True, **generation_kwargs}
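    # Step 1: transcribe the audio; add a source-language hint only when the
    # user picked one explicitly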
    if source_lang != "autodetect":
        asr_kwargs["src_lang"] = source_lang_name
    history = pipe(audio, **asr_kwargs)
    # If source and target languages are the same, we're done (transcription)
    if source_lang == target_lang:
        text = history.get_assistant_messages()[-1]
    else:
        # Otherwise, translate the transcript (text-to-text translation step)
        t2tt_kwargs = {
            "mode": "t2tt",
            "tgt_lang": target_lang_name,
            "return_chat_history": True,
            **generation_kwargs,
        }
        if source_lang != "autodetect":
            t2tt_kwargs["src_lang"] = source_lang_name
        history = pipe(history, **t2tt_kwargs)
        text = history.get_assistant_messages()[-1]
    detected_language = ""
    if source_lang == "autodetect":
        # Language identification step
        lang_history = pipe(history, mode="lid", return_chat_history=True, **generation_kwargs)
        detected_language = lang_history.get_assistant_messages()[-1]
    return text, detected_language
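# For quick local testing, process_audio can also be called directly
# ("sample.wav" is a placeholder path, not a file shipped with this Space):
#   text, detected = process_audio("sample.wav", "Autodetect", "Spanish")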
# Create the Gradio interface
with gr.Blocks() as demo:
gr.Markdown("# SalamandraTAV: Speech-to-Text Translation Demo")
gr.Markdown(
"A multilingual model for Speech-to-Text Translation (S2TT) and Automatic Speech Recognition (ASR) for Iberian languages."
)
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Audio Input")
            with gr.Row():
                source_lang_dropdown = gr.Dropdown(
                    choices=list(LANGUAGES.values()),
                    value=LANGUAGES["autodetect"],
                    label="Source Language (Optional)",
                )
                target_lang_dropdown = gr.Dropdown(
                    choices=[lang for key, lang in LANGUAGES.items() if key != "autodetect"],
                    label="Target Language (Required)",
                )
            submit_button = gr.Button("Translate/Transcribe")
        with gr.Column():
            output_text = gr.Textbox(label="Output", lines=10, interactive=False)
            detected_lang_output = gr.Textbox(label="Detected Source Language", interactive=False)
    submit_button.click(
        fn=process_audio,
        inputs=[audio_input, source_lang_dropdown, target_lang_dropdown],
        outputs=[output_text, detected_lang_output],
    )
gr.Markdown("## Examples")
gr.Examples(
examples=[
[
"https://github.com/voxserv/audio_quality_testing_samples/raw/refs/heads/master/orig/127389__acclivity__thetimehascome.wav",
LANGUAGES["en"],
LANGUAGES["es"],
],
[
"https://github.com/voxserv/audio_quality_testing_samples/raw/refs/heads/master/orig/127389__acclivity__thetimehascome.wav",
LANGUAGES["en"],
LANGUAGES["en"],
],
],
inputs=[audio_input, source_lang_dropdown, target_lang_dropdown],
outputs=[output_text, detected_lang_output],
fn=process_audio,
)
if __name__ == "__main__":
    demo.launch()