Spaces:
Sleeping
Sleeping
| # Import the necessary components from whisper_online.py | |
| import logging | |
| import os | |
| from typing import Optional | |
| import librosa | |
| import soundfile | |
| import uvicorn | |
| from fastapi import FastAPI, WebSocket | |
| from pydantic import BaseModel, ConfigDict | |
| from starlette.websockets import WebSocketDisconnect | |
| from libs.whisper_streaming.whisper_online import ( | |
| ASRBase, | |
| OnlineASRProcessor, | |
| VACOnlineASRProcessor, | |
| add_shared_args, | |
| asr_factory, | |
| set_logging, | |
| create_tokenizer, | |
| load_audio, | |
| load_audio_chunk, OpenaiApiASR, | |
| set_logging | |
| ) | |
| import argparse | |
| import sys | |
| import numpy as np | |
| import io | |
| import soundfile | |
| import wave | |
| import requests | |
| import argparse | |
| # from libs.whisper_streaming.whisper_online_server import online | |
# --- Audio / warm-up constants -------------------------------------------
# Sample rate (Hz) used both as the resampling target and when decoding the
# raw PCM received over the WebSocket.
SAMPLING_RATE = 16000
# Mono 16 kHz clip read during model warm-up; produced by download_warmup_file().
WARMUP_FILE = "mono16k.test_hebrew.wav"
# Source recording fetched by download_warmup_file() before conversion.
AUDIO_FILE_URL = "https://raw.githubusercontent.com/AshDavid12/runpod-serverless-forked/main/test_hebrew.wav"

# --- Module-level singletons ---------------------------------------------
logger = logging.getLogger(__name__)
app = FastAPI()

# ``args`` starts out as the parser carrying the shared whisper_streaming
# options; main() rebinds this global before the server starts.
args = argparse.ArgumentParser()
add_shared_args(args)
def drop_option_from_parser(parser, option_name):
    """
    Build a fresh parser that carries every option of ``parser`` except one.

    Args:
        parser (argparse.ArgumentParser): Parser whose actions are copied.
        option_name (str): Option string to leave out (e.g., '--model').

    Returns:
        argparse.ArgumentParser: New parser sharing the original's
        description, epilog and formatter class, without ``option_name``.
    """
    rebuilt = argparse.ArgumentParser(
        formatter_class=parser.formatter_class,
        epilog=parser.epilog,
        description=parser.description,
    )

    # The fresh parser already registered its own help action, so the
    # original "-h" action is skipped together with the option being dropped.
    excluded = ("-h", option_name)
    kept_actions = [
        act
        for act in parser._actions
        if not any(flag in act.option_strings for flag in excluded)
    ]

    # The action objects themselves are re-registered (not copied).
    for act in kept_actions:
        rebuilt._add_action(act)
    return rebuilt
def convert_to_mono_16k(input_wav: str, output_wav: str) -> None:
    """
    Converts any .wav file to mono 16 kHz.

    Args:
        input_wav (str): Path to the input .wav file.
        output_wav (str): Path to save the output .wav file with mono 16 kHz.
    """
    # Step 1: Load the audio at its original sampling rate, keeping all channels
    audio_data, original_sr = librosa.load(input_wav, sr=None, mono=False)
    logger.info(
        "Loaded audio with shape: %s, original sampling rate: %d",
        audio_data.shape,
        original_sr,
    )

    # Step 2: If the audio has multiple channels, average them to make it mono
    if audio_data.ndim > 1:
        audio_data = librosa.to_mono(audio_data)

    # Step 3: Resample the audio to 16 kHz
    resampled_audio = librosa.resample(audio_data, orig_sr=original_sr, target_sr=SAMPLING_RATE)

    # Step 4: Save the resampled audio as a .wav file in mono at 16 kHz.
    # The module is imported as ``soundfile`` (there is no ``sf`` alias at
    # module level), so it must be referenced by its full name here.
    soundfile.write(output_wav, resampled_audio, SAMPLING_RATE)
    logger.info("Converted audio saved to %s", output_wav)
def download_warmup_file():
    """
    Make sure the mono 16 kHz warm-up clip (``WARMUP_FILE``) exists on disk.

    When ``WARMUP_FILE`` is missing, the source recording is downloaded from
    ``AUDIO_FILE_URL`` (unless a local copy is already present) and converted
    with ``convert_to_mono_16k``. Nothing happens when ``WARMUP_FILE`` exists.

    Raises:
        requests.HTTPError: If the download returns an error status.
        requests.RequestException: On connection problems or timeout.
    """
    audio_file_path = "test_hebrew.wav"

    if os.path.exists(WARMUP_FILE):
        return

    # Download the audio file if not already present
    if not os.path.exists(audio_file_path):
        response = requests.get(AUDIO_FILE_URL, timeout=60)
        # Fail before writing anything: otherwise an HTTP error page would be
        # saved under the .wav name and never re-downloaded on later runs.
        response.raise_for_status()
        with open(audio_file_path, 'wb') as f:
            f.write(response.content)

    convert_to_mono_16k(audio_file_path, WARMUP_FILE)
class State(BaseModel):
    """Per-connection bundle passed between the WebSocket helper functions."""

    # WebSocket / ASR objects are not pydantic models, so arbitrary types
    # have to be allowed for the field declarations below.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Connection the audio bytes are received from.
    websocket: WebSocket
    # ASR backend and its streaming processor.
    asr: ASRBase
    online: OnlineASRProcessor
    # Number of samples receive_audio_chunk() collects before returning.
    min_limit: int
    # True until receive_audio_chunk() has returned its first chunk.
    is_first: bool = True
    # End timestamp (ms) of the last emitted segment; None before any output.
    last_end: Optional[float] = None
async def receive_audio_chunk(state: State) -> Optional[np.ndarray]:
    """
    Collect audio from the WebSocket until ``state.min_limit`` samples exist.

    Each binary message is decoded as raw little-endian 16-bit mono PCM at
    ``SAMPLING_RATE`` and converted to float32. Collection stops early when
    an empty message arrives.

    Returns:
        The concatenated samples, or ``None`` when nothing was received or
        when the very first chunk is shorter than ``state.min_limit``.
    """
    pieces = []
    buffered = 0  # running count of samples gathered so far

    # Waits on the socket until enough samples are available or an empty
    # message is received.
    while buffered < state.min_limit:
        payload = await state.websocket.receive_bytes()
        if not payload:
            break
        pcm_stream = soundfile.SoundFile(
            io.BytesIO(payload),
            channels=1,
            endian="LITTLE",
            samplerate=SAMPLING_RATE,
            subtype="PCM_16",
            format="RAW",
        )
        samples, _ = librosa.load(pcm_stream, sr=SAMPLING_RATE, dtype=np.float32)
        pieces.append(samples)
        buffered += len(samples)

    if not pieces:
        return None

    merged = np.concatenate(pieces)
    first_chunk_too_short = state.is_first and len(merged) < state.min_limit
    if first_chunk_too_short:
        return None

    state.is_first = False
    return merged
def format_output_transcript(state, o) -> Optional[dict]:
    """
    Turn one processor output into a JSON-ready segment.

    The segment is also echoed to stderr in the form::

        0 1500 Takhle to je

    where the first two numbers are the begin and end timestamps of the text
    segment in milliseconds, followed by the segment transcript.

    Succeeding [beg, end] intervals are kept non-overlapping: ``beg`` is the
    max of the previous segment's end and the begin reported for this one.

    Args:
        state: Object with a ``last_end`` attribute (``None`` or the end of
            the previous segment in ms); it is updated to this segment's end.
        o: Indexable ``(beg_seconds, end_seconds, text)``; ``o[0]`` is
            ``None`` when there is no text in the segment.

    Returns:
        ``{"start": ..., "end": ..., "text": ...}`` with the timestamps as
        whole-millisecond strings, or ``None`` when the segment has no text.
    """
    if o[0] is None:
        logger.debug("No text in this segment")
        return None

    # Seconds -> milliseconds.
    beg, end = o[0] * 1000, o[1] * 1000

    # Clamp the start so it never precedes the previously emitted end.
    if state.last_end is not None:
        beg = max(beg, state.last_end)
    state.last_end = end

    print("%1.0f %1.0f %s" % (beg, end, o[2]), flush=True, file=sys.stderr)
    return {
        "start": "%1.0f" % beg,
        "end": "%1.0f" % end,
        "text": "%s" % o[2],
    }
# Define WebSocket endpoint
# NOTE(review): no route decorator (e.g. ``@app.websocket(...)``) is visible on
# this coroutine in this file -- confirm it is registered with ``app``.
async def websocket_transcribe(websocket: WebSocket):
    """
    Serve one streaming-transcription session over ``websocket``.

    Accepts the connection, builds the ASR backend and online processor from
    the module-level ``args``, warms the model up on ``WARMUP_FILE``, then
    repeatedly receives audio, feeds it to the processor and sends each
    non-empty segment back as JSON. The loop ends when no audio is returned,
    when the client disconnects, or on an error (which is reported to the
    client when the socket is still usable).
    """
    logger.info("New WebSocket connection request received.")
    await websocket.accept()
    logger.info("WebSocket connection established successfully.")

    # initialize the ASR model
    logger.info("Loading whisper model...")
    asr, online = asr_factory(args)
    state = State(
        websocket=websocket,
        asr=asr,
        online=online,
        # The product is a float sample count while State.min_limit is
        # declared as int, so convert explicitly.
        min_limit=int(args.min_chunk_size * SAMPLING_RATE),
    )

    # warm up the ASR because the very first transcribe takes more time than the others.
    # Test results in https://github.com/ufal/whisper_streaming/pull/81
    # WARMUP_FILE must already exist on disk at this point.
    logger.info("Warming up the whisper model...")
    warmup_audio = load_audio_chunk(WARMUP_FILE, 0, 1)
    asr.transcribe(warmup_audio)
    logger.info("Whisper is warmed up.")

    try:
        while True:
            chunk = await receive_audio_chunk(state)
            if chunk is None:
                break
            # Use the processor held by ``state`` for both calls so that
            # insertion and processing always hit the same object.
            state.online.insert_audio_chunk(chunk)
            o = state.online.process_iter()
            if result := format_output_transcript(state, o):
                await websocket.send_json(result)
    except BrokenPipeError:
        logger.info("broken pipe -- connection closed?")
    except WebSocketDisconnect:
        # Also covers a disconnect raised while waiting for audio, which is
        # a normal end of session rather than an error.
        logger.info("WebSocket connection closed by the client.")
    except Exception as e:
        logger.exception("Unexpected error during WebSocket transcription: %s", e)
        try:
            await websocket.send_json({"error": str(e)})
        except (WebSocketDisconnect, RuntimeError):
            # The socket is already unusable; reporting must not raise a
            # second exception from inside the handler.
            logger.debug("Could not report the error to the client.")
    finally:
        logger.info("Cleaning up and closing WebSocket connection.")
def main():
    """
    Parse the fixed server configuration and start the uvicorn server.

    The module-level ``args`` is rebound to the parsed ``argparse.Namespace``
    so that ``websocket_transcribe`` can read option values from it
    (e.g. ``args.min_chunk_size``) and pass it to ``asr_factory``.
    """
    global args

    # Replace the shared '--model' option with a free-form string option.
    # NOTE(review): presumably the shared definition restricts the accepted
    # model names -- confirm against add_shared_args.
    parser = drop_option_from_parser(args, '--model')
    parser.add_argument('--model', type=str,
                        help="Name size of the Whisper model to use. The model is automatically downloaded from the model hub if not present in model cache dir.")

    # Keep the parsed result: parse_args() returns the Namespace and does not
    # modify the parser, so discarding it would leave ``args`` as a parser
    # with no option values on it.
    args = parser.parse_args([
        '--lan', 'he',
        '--model', 'ivrit-ai/faster-whisper-v2-d4',
        '--backend', 'faster-whisper',
        '--vad',
        # '--vac', '--buffer_trimming', 'segment', '--buffer_trimming_sec', '15', '--min_chunk_size', '1.0', '--vac_chunk_size', '0.04', '--start_at', '0.0', '--offline', '--comp_unaware', '--log_level', 'DEBUG'
    ])

    uvicorn.run(app)


if __name__ == "__main__":
    main()