| | |
| |
|
| | import argparse |
| | import glob |
| | import os |
| | import sys |
| |
|
| | from fireredasr.models.fireredasr import FireRedAsr |
| |
|
| |
|
| | parser = argparse.ArgumentParser() |
| | parser.add_argument('--asr_type', type=str, required=True, choices=["aed", "llm"]) |
| | parser.add_argument('--model_dir', type=str, required=True) |
| |
|
| | |
| | parser.add_argument("--wav_path", type=str) |
| | parser.add_argument("--wav_paths", type=str, nargs="*") |
| | parser.add_argument("--wav_dir", type=str) |
| | parser.add_argument("--wav_scp", type=str) |
| | parser.add_argument("--output", type=str) |
| |
|
| | |
| | parser.add_argument('--use_gpu', type=int, default=1) |
| | parser.add_argument("--batch_size", type=int, default=1) |
| | parser.add_argument("--beam_size", type=int, default=1) |
| | parser.add_argument("--decode_max_len", type=int, default=0) |
| | |
| | parser.add_argument("--nbest", type=int, default=1) |
| | parser.add_argument("--softmax_smoothing", type=float, default=1.0) |
| | parser.add_argument("--aed_length_penalty", type=float, default=0.0) |
| | parser.add_argument("--eos_penalty", type=float, default=1.0) |
| | |
| | parser.add_argument("--decode_min_len", type=int, default=0) |
| | parser.add_argument("--repetition_penalty", type=float, default=1.0) |
| | parser.add_argument("--llm_length_penalty", type=float, default=0.0) |
| | parser.add_argument("--temperature", type=float, default=1.0) |
| |
|
| |
|
| | def main(args): |
| | wavs = get_wav_info(args) |
| | fout = open(args.output, "w") if args.output else None |
| |
|
| | model = FireRedAsr.from_pretrained(args.asr_type, args.model_dir) |
| |
|
| | batch_uttid = [] |
| | batch_wav_path = [] |
| | for i, wav in enumerate(wavs): |
| | uttid, wav_path = wav |
| | batch_uttid.append(uttid) |
| | batch_wav_path.append(wav_path) |
| | if len(batch_wav_path) < args.batch_size and i != len(wavs) - 1: |
| | continue |
| |
|
| | results = model.transcribe( |
| | batch_uttid, |
| | batch_wav_path, |
| | { |
| | "use_gpu": args.use_gpu, |
| | "beam_size": args.beam_size, |
| | "nbest": args.nbest, |
| | "decode_max_len": args.decode_max_len, |
| | "softmax_smoothing": args.softmax_smoothing, |
| | "aed_length_penalty": args.aed_length_penalty, |
| | "eos_penalty": args.eos_penalty, |
| | "decode_min_len": args.decode_min_len, |
| | "repetition_penalty": args.repetition_penalty, |
| | "llm_length_penalty": args.llm_length_penalty, |
| | "temperature": args.temperature |
| | } |
| | ) |
| |
|
| | for result in results: |
| | print(result) |
| | if fout is not None: |
| | fout.write(f"{result['uttid']}\t{result['text']}\n") |
| |
|
| | batch_uttid = [] |
| | batch_wav_path = [] |
| |
|
| |
|
| | def get_wav_info(args): |
| | """ |
| | Returns: |
| | wavs: list of (uttid, wav_path) |
| | """ |
| | base = lambda p: os.path.basename(p).replace(".wav", "") |
| | if args.wav_path: |
| | wavs = [(base(args.wav_path), args.wav_path)] |
| | elif args.wav_paths and len(args.wav_paths) >= 1: |
| | wavs = [(base(p), p) for p in sorted(args.wav_paths)] |
| | elif args.wav_scp: |
| | wavs = [line.strip().split() for line in open(args.wav_scp)] |
| | elif args.wav_dir: |
| | wavs = glob.glob(f"{args.wav_dir}/**/*.wav", recursive=True) |
| | wavs = [(base(p), p) for p in sorted(wavs)] |
| | else: |
| | raise ValueError("Please provide valid wav info") |
| | print(f"#wavs={len(wavs)}") |
| | return wavs |
| |
|
| |
|
| | if __name__ == "__main__": |
| | args = parser.parse_args() |
| | print(args) |
| | main(args) |
| |
|