import json
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import numpy as np
from loguru import logger
|
|
|
|
| |
| SAMPLE_RATE = 24000 |
| DEFAULT_SPEED = 1.0 |
| DEFAULT_FADE_OUT = 0.05 |
| DEFAULT_PAUSE = 0.05 |
|
|
| ALIASES = { |
| 'en-us': 'a', 'en-gb': 'b', 'es': 'e', 'fr-fr': 'f', |
| 'hi': 'h', 'it': 'i', 'pt-br': 'p', 'ja': 'j', 'zh': 'z', |
| } |
|
|
| LANG_CODES = { |
| 'a': 'American English', 'b': 'British English', 'e': 'es', |
| 'f': 'fr-fr', 'h': 'hi', 'i': 'it', 'p': 'pt-br', |
| 'j': 'Japanese', 'z': 'Mandarin Chinese', |
| } |
|
|
|
|
| @dataclass |
| class G2PContext: |
| """G2P""" |
| g2p: any |
| g2p_type: str |
| vocab: Dict[str, int] |
|
|
|
|
| def clean_text(text: str) -> str: |
| text = re.sub(r'\s+', ' ', text) |
| text = text.strip() |
| text = ''.join(char for char in text if ord(char) >= 32 or char in '\n\r\t') |
| return text |
|
|
|
|
| def split_sentences(text: str, lang_code: str = 'a') -> List[str]: |
| if lang_code in ['z', 'j']: |
| sentences = re.split(r'([。!?;,、:""''()【】《》…\n])', text) |
| else: |
| sentences = re.split(r'([.!?;,:\n])', text) |
| |
| result = [] |
| |
| for i in range(0, len(sentences)-1, 2): |
| if i+1 < len(sentences): |
| sentence = sentences[i] + sentences[i+1] |
| else: |
| sentence = sentences[i] |
| sentence = sentence.strip() |
| if sentence: |
| result.append(sentence) |
| |
| |
| if len(sentences) % 2 == 1 and sentences[-1].strip(): |
| last_text = sentences[-1].strip() |
| end_punctuation = '。' if lang_code in ['z', 'j'] else '.' |
| result.append(last_text + end_punctuation) |
| |
| return result if result else [text] |
|
|
|
|
| def apply_fade_out(audio: np.ndarray, fade_samples: int) -> np.ndarray: |
| """末尾淡出音频""" |
| if len(audio) <= fade_samples or fade_samples <= 0: |
| return audio |
| fade_out = np.linspace(1.0, 0.0, fade_samples).astype(np.float32) |
| audio_faded = audio.copy() |
| audio_faded[-fade_samples:] *= fade_out |
| return audio_faded |
|
|
|
|
| def audio_numpy_concat(segment_data_list: List[np.ndarray], sr: int = SAMPLE_RATE, |
| speed: float = DEFAULT_SPEED, pause_duration: float = DEFAULT_PAUSE) -> np.ndarray: |
| """拼接音频片段""" |
| if not segment_data_list: |
| return np.array([], dtype=np.float32) |
| |
| audio_segments = [] |
| pause_samples = int((sr * pause_duration) / speed) |
| |
| for i, segment_data in enumerate(segment_data_list): |
| audio_segments.append(segment_data.reshape(-1)) |
| if i < len(segment_data_list) - 1 and pause_samples > 0: |
| audio_segments.append(np.zeros(pause_samples, dtype=np.float32)) |
| |
| return np.concatenate(audio_segments).astype(np.float32) |
|
|
|
|
| def load_vocab_from_config(config_path: str) -> Dict[str, int]: |
| with open(config_path, 'r', encoding='utf-8') as f: |
| config = json.load(f) |
| return config.get('vocab', {}) |
|
|
|
|
| def init_g2p(lang_code: str, trf: bool = False, en_callable=None): |
| lang_code = lang_code.lower() |
| lang_code = ALIASES.get(lang_code, lang_code) |
| |
| if lang_code not in LANG_CODES: |
| raise ValueError(f"不支持的语言代码: {lang_code}") |
| |
| if lang_code in 'ab': |
| from misaki import en, espeak |
| try: |
| fallback = espeak.EspeakFallback(british=lang_code=='b') |
| except: |
| logger.warning("EspeakFallback 未启用") |
| fallback = None |
| g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback, unk='') |
| return g2p, 'en' |
| elif lang_code == 'j': |
| from misaki import ja |
| return ja.JAG2P(), 'ja' |
| elif lang_code == 'z': |
| from misaki import zh |
| return zh.ZHG2P(version=None, en_callable=en_callable), 'zh' |
| else: |
| from misaki import espeak |
| language = LANG_CODES[lang_code] |
| return espeak.EspeakG2P(language=language), 'espeak' |
|
|
|
|
| def text_to_phonemes(text: str, g2p, g2p_type: str) -> str: |
| if g2p_type == 'en': |
| _, tokens = g2p(text) |
| phonemes = ''.join(t.phonemes + (' ' if t.whitespace else '') for t in tokens).strip() |
| return phonemes |
| else: |
| phonemes, _ = g2p(text) |
| return phonemes |
|
|
|
|
| def phonemes_to_input_ids(phonemes: str, vocab: Dict[str, int], debug: bool = False) -> np.ndarray: |
| input_ids = [] |
| skipped_phonemes = [] |
| skipped_positions = [] |
| |
| for i, p in enumerate(phonemes): |
| if p in vocab: |
| input_ids.append(vocab[p]) |
| else: |
| skipped_phonemes.append(p) |
| skipped_positions.append(i) |
| if debug: |
| start = max(0, i-5) |
| end = min(len(phonemes), i+6) |
| context = phonemes[start:end] |
| logger.warning(f"未知音素 '{p}' (ord={ord(p)}) 位置={i}, 上下文: ...{context}...") |
| |
| if skipped_phonemes: |
| logger.error(f"总共跳过 {len(skipped_phonemes)} 个音素在位置 {skipped_positions}: {skipped_phonemes}") |
| |
| return np.array(input_ids, dtype=np.int64) |
|
|
|
|
| def load_voice_embedding(voice_path: str, phoneme_len: Optional[int] = None) -> np.ndarray: |
| if "checkpoints/voices_npy" in voice_path: |
| pass |
| else: |
| voice_path = voice_path.replace('checkpoints/voices', 'checkpoints/voices_npy').replace('.pt', '.npy') |
| pack = np.load(voice_path).reshape(510,1,256) |
| if phoneme_len is not None: |
| ref_s = pack[phoneme_len:phoneme_len+1] |
| else: |
| idx = pack.shape[0] // 2 |
| ref_s = pack[idx:idx+1] |
| return ref_s[0] |
|
|
|
|
| def split_input_ids_semantic( |
| input_ids: np.ndarray, |
| fixed_seq_len: int, |
| phonemes: str = None, |
| vocab: Dict[str, int] = None |
| ) -> List[Dict]: |
| """input_ids分割""" |
| content = input_ids[0, 1:-1] |
| chunk_with_special = np.concatenate([[0], content, [0]]) |
| |
| |
| padding_len = fixed_seq_len - len(chunk_with_special) |
| if padding_len > 0: |
| chunk_padded = np.concatenate([chunk_with_special, np.zeros(padding_len, dtype=input_ids.dtype)]) |
| else: |
| chunk_padded = chunk_with_special |
| |
| return [{ |
| 'input_ids': chunk_padded.reshape(1, -1), |
| 'actual_len': len(chunk_with_special), |
| |
| |
| 'trim_end_chars': 0 |
| }] |
|
|
|
|
| def generate_input_ids_from_text(text: str, lang_code: str = None, config_path: str = None, |
| g2p=None, g2p_type: str = None, vocab: Dict[str, int] = None, |
| g2p_context: G2PContext = None): |
| """从文本生成input_ids""" |
| if g2p_context is not None: |
| g2p = g2p_context.g2p |
| g2p_type = g2p_context.g2p_type |
| vocab = g2p_context.vocab |
| else: |
| if vocab is None: |
| vocab = load_vocab_from_config(config_path) |
| if g2p is None: |
| g2p, g2p_type = init_g2p(lang_code) |
| |
| phonemes = text_to_phonemes(text, g2p, g2p_type) |
| content_ids = phonemes_to_input_ids(phonemes, vocab, debug=False) |
| input_ids = np.concatenate([[0], content_ids, [0]]).reshape(1, -1) |
| return input_ids, phonemes |
|
|
|
|
| def split_long_sentence(sentence, lang_code, g2p, g2p_type, vocab, max_merge_len=78, depth=0): |
| try: |
| input_ids, phonemes = generate_input_ids_from_text( |
| sentence, g2p=g2p, g2p_type=g2p_type, vocab=vocab |
| ) |
| content_len = input_ids.shape[1] |
| |
| if content_len <= max_merge_len: |
| return [{ |
| 'sentence': sentence, |
| 'input_ids': input_ids, |
| 'phonemes': phonemes, |
| 'content_len': content_len |
| }] |
| else: |
| |
| if lang_code in ['z', 'j']: |
| mid = len(sentence) // 2 |
| first_half = sentence[:mid] |
| second_half = sentence[mid:] |
| else: |
| |
| words = sentence.split() |
| mid_word = len(words) // 2 |
| first_half = ' '.join(words[:mid_word]) |
| second_half = ' '.join(words[mid_word:]) |
| |
| result_first = split_long_sentence(first_half, lang_code, g2p, g2p_type, vocab, |
| max_merge_len, depth + 1) |
| result_second = split_long_sentence(second_half, lang_code, g2p, g2p_type, vocab, |
| max_merge_len, depth + 1) |
| |
| return result_first + result_second |
| except Exception: |
| return [] |
|
|
|
|
| def concat_audios(sub_audios: List[Dict], sr: int = SAMPLE_RATE) -> np.ndarray: |
| """拼接音频""" |
| if not sub_audios: |
| return np.array([], dtype=np.float32) |
| |
| if len(sub_audios) == 1: |
| return sub_audios[0]['audio'] |
| |
| audio_segments = [sub['audio'] for sub in sub_audios] |
| return np.concatenate(audio_segments).astype(np.float32) |
|
|
|
|
| def process_and_merge_sentences(text: str, lang_code: str, g2p, g2p_type: str, |
| vocab: Dict[str, int], max_merge_len: int = 96) -> List[Dict]: |
| """ |
| 处理文本:清理(待处理)、分句、生成input_ids、长句分割、短句合并 |
| |
| """ |
| cleaned_text = clean_text(text) |
| sentences = split_sentences(cleaned_text, lang_code=lang_code) |
| |
| |
| sentence_data = [] |
| for sentence in sentences: |
| try: |
| input_ids, phonemes = generate_input_ids_from_text( |
| sentence, g2p=g2p, g2p_type=g2p_type, vocab=vocab |
| ) |
| content_len = input_ids.shape[1] |
| |
| if content_len <= max_merge_len: |
| sentence_data.append({ |
| 'sentence': sentence, |
| 'input_ids': input_ids, |
| 'phonemes': phonemes, |
| 'content_len': content_len, |
| 'is_long': False |
| }) |
| else: |
| sub_results = split_long_sentence(sentence, lang_code, g2p, g2p_type, vocab, max_merge_len) |
| sentence_data.append({ |
| 'sentence': sentence, |
| 'sub_results': sub_results, |
| 'is_long': True |
| }) |
| except Exception as e: |
| logger.error(f"错误处理句子 '{sentence}': {e}") |
| |
| if not sentence_data: |
| raise ValueError("没有生成任何 input_ids") |
| |
| |
| merged_groups = [] |
| i = 0 |
| |
| while i < len(sentence_data): |
| if sentence_data[i]['is_long']: |
| sub_results = sentence_data[i]['sub_results'] |
| merged_groups.append({'is_long_split': True, 'sub_results': sub_results}) |
| i += 1 |
| else: |
| merged_sentences = [] |
| total_len = 0 |
| j = i |
| |
| while j < len(sentence_data) and not sentence_data[j]['is_long']: |
| next_len = sentence_data[j]['content_len'] |
| |
| if total_len + next_len < max_merge_len: |
| merged_sentences.append(sentence_data[j]['sentence']) |
| total_len += next_len |
| j += 1 |
| else: |
| break |
| |
| if j == i: |
| merged_sentences.append(sentence_data[i]['sentence']) |
| j = i + 1 |
| |
| |
| merged_text = ' '.join(merged_sentences) |
| merged_input_ids, merged_phonemes = generate_input_ids_from_text( |
| merged_text, g2p=g2p, g2p_type=g2p_type, vocab=vocab |
| ) |
| |
| merged_groups.append({'input_ids': merged_input_ids, 'phonemes': merged_phonemes}) |
| i = j |
| |
| return merged_groups |
|
|
|
|
| def run_batch_inference(engine, merged_groups: List[Dict], voice_path: str, |
| vocab: Dict[str, int], speed: float = DEFAULT_SPEED, |
| fade_out_duration: float = DEFAULT_FADE_OUT, |
| sr: int = SAMPLE_RATE) -> List[np.ndarray]: |
| """推理""" |
| audio_list = [] |
| |
| for group in merged_groups: |
| try: |
| if group.get('is_long_split', False): |
| |
| sub_audios = [] |
| for sub in group['sub_results']: |
| phoneme_len = sub['input_ids'].shape[1] - 2 |
| ref_s = load_voice_embedding(voice_path, phoneme_len=phoneme_len) |
| audio = engine.inference(sub['input_ids'], ref_s, sub['phonemes'], vocab, |
| speed=speed, fade_out_duration=0) |
| sub_audios.append({'audio': audio}) |
| |
| combined_audio = concat_audios(sub_audios, sr=sr) |
| audio_list.append(combined_audio) |
| else: |
| |
| phoneme_len = group['input_ids'].shape[1] - 2 |
| ref_s = load_voice_embedding(voice_path, phoneme_len=phoneme_len) |
| audio = engine.inference(group['input_ids'], ref_s, group['phonemes'], vocab, |
| speed=speed, fade_out_duration=fade_out_duration) |
| audio_list.append(audio) |
| except Exception as e: |
| logger.error(f"推理错误: {e}") |
| |
| if not audio_list: |
| raise ValueError("没有生成任何音频") |
| |
| return audio_list |