File size: 4,403 Bytes
86bd8c8
 
 
a613fde
3bc70e4
86bd8c8
 
 
 
27d3b45
86bd8c8
a613fde
86bd8c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3bc70e4
86bd8c8
 
 
 
 
31c417c
86bd8c8
 
31c417c
86bd8c8
 
 
 
 
b0fd395
a0c2fe8
440a3a7
a0c2fe8
 
 
31c417c
a0c2fe8
86bd8c8
a0c2fe8
86bd8c8
a0c2fe8
 
 
 
 
 
 
 
 
 
 
 
31c417c
 
 
a0c2fe8
 
 
31c417c
 
86bd8c8
b0fd395
86bd8c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31c417c
86bd8c8
 
 
 
31c417c
86bd8c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31c417c
86bd8c8
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import gradio as gr
from transformers import pipeline
import torch
import os
import spaces

# Load the model pipeline.
# NOTE(review): this runs at import time; the gated model repo needs a valid
# HF_TOKEN in the environment (otherwise None is passed and the download may
# fail for private/gated weights).
pipe = pipeline(
    task="multimodal_mt",
    model="BSC-LT/salamandra-TAV-7b",
    trust_remote_code=True,  # custom pipeline code ships with the model repo
    token=os.environ.get("HF_TOKEN"),
    device_map="auto",  # let accelerate place the weights (GPU when available)
    torch_dtype=torch.float16,  # half precision to reduce memory footprint
)

# Define the languages for the dropdowns: language code -> display name.
LANGUAGES = {
    "autodetect": "Autodetect",
    "en": "English",
    "es": "Spanish",
    "ca": "Catalan",
    "pt": "Portuguese",
    "gl": "Galician",
    "eu": "Basque",
}

# Invert the dictionary for easy lookup: display name -> language code
# (e.g. "English" -> "en").
# NOTE(review): the name is misleading — this maps NAME to CODE, not code to
# name; a rename (e.g. NAME_TO_CODE) would be clearer, but callers below use
# the current name.
LANG_TO_NAME = {v: k for k, v in LANGUAGES.items()}

@spaces.GPU
def process_audio(audio, source_lang_name, target_lang_name):
    """Transcribe or translate an audio file with the SalamandraTAV pipeline.

    Runs ASR first; if source and target languages differ, chains a
    text-to-text translation step on the resulting chat history. When the
    source language is autodetected, a final language-identification pass is
    run to report the detected language.

    Args:
        audio: Filesystem path to the audio clip (Gradio ``type="filepath"``),
            or ``None`` when nothing was provided.
        source_lang_name: Display name of the source language ("Autodetect",
            "English", ...), or ``None`` if the dropdown was cleared.
        target_lang_name: Display name of the target language, or ``None``.

    Returns:
        A ``(text, detected_language)`` tuple. ``text`` is the transcription
        or translation (or a user-facing error prompt), and
        ``detected_language`` is "" unless autodetection was used.
    """
    if audio is None:
        return "Please provide an audio file or record one.", ""

    if target_lang_name is None:
        return "Please select a target language.", ""

    # Fix: a cleared source dropdown yields None, which previously slipped
    # past the "autodetect" checks and forwarded src_lang=None to the
    # pipeline. Treat it as an explicit "Autodetect" instead.
    if source_lang_name is None:
        source_lang_name = LANGUAGES["autodetect"]

    # Map display names back to language codes ("English" -> "en").
    source_lang = LANG_TO_NAME.get(source_lang_name)
    target_lang = LANG_TO_NAME.get(target_lang_name)

    generation_kwargs = {"beam_size": 5, "max_new_tokens": 100}

    # Step 1: ASR. The pipeline receives the display name in ``src_lang``;
    # omit the key entirely when autodetecting.
    asr_kwargs = {"mode": "asr", "return_chat_history": True, **generation_kwargs}
    if source_lang != "autodetect":
        asr_kwargs["src_lang"] = source_lang_name

    history = pipe(audio, **asr_kwargs)

    # If source and target languages are the same, we're done (transcription).
    if source_lang == target_lang:
        text = history.get_assistant_messages()[-1]
    else:
        # Step 2: text-to-text translation of the ASR output, reusing the
        # chat history so the pipeline sees the transcription in context.
        t2tt_kwargs = {
            "mode": "t2tt",
            "tgt_lang": target_lang_name,
            "return_chat_history": True,
            **generation_kwargs,
        }
        if source_lang != "autodetect":
            t2tt_kwargs["src_lang"] = source_lang_name

        history = pipe(history, **t2tt_kwargs)
        text = history.get_assistant_messages()[-1]

    # Step 3 (autodetect only): language identification over the history.
    detected_language = ""
    if source_lang == "autodetect":
        lang_history = pipe(history, mode="lid", return_chat_history=True, **generation_kwargs)
        detected_language = lang_history.get_assistant_messages()[-1]

    return text, detected_language


# Create the Gradio interface: audio in, translated/transcribed text out,
# plus a read-only field showing the detected source language.
with gr.Blocks() as demo:
    gr.Markdown("# SalamandraTAV: Speech-to-Text Translation Demo")
    gr.Markdown(
        "A multilingual model for Speech-to-Text Translation (S2TT) and Automatic Speech Recognition (ASR) for Iberian languages."
    )

    with gr.Row():
        with gr.Column():
            # type="filepath" means process_audio receives a path string.
            audio_input = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Audio Input")
            with gr.Row():
                # Source defaults to "Autodetect"; display names are shown,
                # process_audio maps them back to language codes.
                source_lang_dropdown = gr.Dropdown(
                    choices=list(LANGUAGES.values()),
                    value=LANGUAGES["autodetect"],
                    label="Source Language (Optional)",
                )
                # Target list excludes "Autodetect" — a concrete target is required.
                target_lang_dropdown = gr.Dropdown(
                    choices=[lang for key, lang in LANGUAGES.items() if key != "autodetect"],
                    label="Target Language (Required)",
                )
            submit_button = gr.Button("Translate/Transcribe")

        with gr.Column():
            output_text = gr.Textbox(label="Output", lines=10, interactive=False)
            detected_lang_output = gr.Textbox(label="Detected Source Language", interactive=False)

    # Wire the button to the processing function; outputs map 1:1 to the
    # (text, detected_language) tuple returned by process_audio.
    submit_button.click(
        fn=process_audio,
        inputs=[audio_input, source_lang_dropdown, target_lang_dropdown],
        outputs=[output_text, detected_lang_output],
    )

    gr.Markdown("## Examples")
    # Clickable examples: one translation (en->es) and one transcription
    # (en->en) using the same remote sample clip.
    gr.Examples(
        examples=[
            [
                "https://github.com/voxserv/audio_quality_testing_samples/raw/refs/heads/master/orig/127389__acclivity__thetimehascome.wav",
                LANGUAGES["en"],
                LANGUAGES["es"],
            ],
            [
                "https://github.com/voxserv/audio_quality_testing_samples/raw/refs/heads/master/orig/127389__acclivity__thetimehascome.wav",
                LANGUAGES["en"],
                LANGUAGES["en"],
            ],
        ],
        inputs=[audio_input, source_lang_dropdown, target_lang_dropdown],
        outputs=[output_text, detected_lang_output],
        fn=process_audio,
    )


if __name__ == "__main__":
    demo.launch()