# kokoro.axera / inference_utils.py
# Origin: HuggingFace upload by HY-2012 — "Update inference demo" (commit 6dde92c, verified)
import json
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import numpy as np
from loguru import logger
# 常量
SAMPLE_RATE = 24000
DEFAULT_SPEED = 1.0
DEFAULT_FADE_OUT = 0.05
DEFAULT_PAUSE = 0.05
ALIASES = {
'en-us': 'a', 'en-gb': 'b', 'es': 'e', 'fr-fr': 'f',
'hi': 'h', 'it': 'i', 'pt-br': 'p', 'ja': 'j', 'zh': 'z',
}
LANG_CODES = {
'a': 'American English', 'b': 'British English', 'e': 'es',
'f': 'fr-fr', 'h': 'hi', 'i': 'it', 'p': 'pt-br',
'j': 'Japanese', 'z': 'Mandarin Chinese',
}
@dataclass
class G2PContext:
"""G2P"""
g2p: any
g2p_type: str
vocab: Dict[str, int]
def clean_text(text: str) -> str:
text = re.sub(r'\s+', ' ', text)
text = text.strip()
text = ''.join(char for char in text if ord(char) >= 32 or char in '\n\r\t')
return text
def split_sentences(text: str, lang_code: str = 'a') -> List[str]:
if lang_code in ['z', 'j']:
sentences = re.split(r'([。!?;,、:""''()【】《》…\n])', text)
else:
sentences = re.split(r'([.!?;,:\n])', text)
result = []
#一句话带一个标点
for i in range(0, len(sentences)-1, 2):
if i+1 < len(sentences):
sentence = sentences[i] + sentences[i+1]
else:
sentence = sentences[i]
sentence = sentence.strip()
if sentence:
result.append(sentence)
# 处理最后没有标点的文本片段,添加默认结束标点
if len(sentences) % 2 == 1 and sentences[-1].strip():
last_text = sentences[-1].strip()
end_punctuation = '。' if lang_code in ['z', 'j'] else '.'
result.append(last_text + end_punctuation)
return result if result else [text]
def apply_fade_out(audio: np.ndarray, fade_samples: int) -> np.ndarray:
"""末尾淡出音频"""
if len(audio) <= fade_samples or fade_samples <= 0:
return audio
fade_out = np.linspace(1.0, 0.0, fade_samples).astype(np.float32)
audio_faded = audio.copy()
audio_faded[-fade_samples:] *= fade_out
return audio_faded
def audio_numpy_concat(segment_data_list: List[np.ndarray], sr: int = SAMPLE_RATE,
speed: float = DEFAULT_SPEED, pause_duration: float = DEFAULT_PAUSE) -> np.ndarray:
"""拼接音频片段"""
if not segment_data_list:
return np.array([], dtype=np.float32)
audio_segments = []
pause_samples = int((sr * pause_duration) / speed)
for i, segment_data in enumerate(segment_data_list):
audio_segments.append(segment_data.reshape(-1))
if i < len(segment_data_list) - 1 and pause_samples > 0:
audio_segments.append(np.zeros(pause_samples, dtype=np.float32))
return np.concatenate(audio_segments).astype(np.float32)
def load_vocab_from_config(config_path: str) -> Dict[str, int]:
with open(config_path, 'r', encoding='utf-8') as f:
config = json.load(f)
return config.get('vocab', {})
def init_g2p(lang_code: str, trf: bool = False, en_callable=None):
lang_code = lang_code.lower()
lang_code = ALIASES.get(lang_code, lang_code)
if lang_code not in LANG_CODES:
raise ValueError(f"不支持的语言代码: {lang_code}")
if lang_code in 'ab':
from misaki import en, espeak
try:
fallback = espeak.EspeakFallback(british=lang_code=='b')
except:
logger.warning("EspeakFallback 未启用")
fallback = None
g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback, unk='')
return g2p, 'en'
elif lang_code == 'j':
from misaki import ja
return ja.JAG2P(), 'ja'
elif lang_code == 'z':
from misaki import zh
return zh.ZHG2P(version=None, en_callable=en_callable), 'zh'
else:
from misaki import espeak
language = LANG_CODES[lang_code]
return espeak.EspeakG2P(language=language), 'espeak'
def text_to_phonemes(text: str, g2p, g2p_type: str) -> str:
if g2p_type == 'en':
_, tokens = g2p(text)
phonemes = ''.join(t.phonemes + (' ' if t.whitespace else '') for t in tokens).strip()
return phonemes
else:
phonemes, _ = g2p(text)
return phonemes
def phonemes_to_input_ids(phonemes: str, vocab: Dict[str, int], debug: bool = False) -> np.ndarray:
input_ids = []
skipped_phonemes = []
skipped_positions = []
for i, p in enumerate(phonemes):
if p in vocab:
input_ids.append(vocab[p])
else:
skipped_phonemes.append(p)
skipped_positions.append(i)
if debug:
start = max(0, i-5)
end = min(len(phonemes), i+6)
context = phonemes[start:end]
logger.warning(f"未知音素 '{p}' (ord={ord(p)}) 位置={i}, 上下文: ...{context}...")
if skipped_phonemes:
logger.error(f"总共跳过 {len(skipped_phonemes)} 个音素在位置 {skipped_positions}: {skipped_phonemes}")
return np.array(input_ids, dtype=np.int64)
def load_voice_embedding(voice_path: str, phoneme_len: Optional[int] = None) -> np.ndarray:
if "checkpoints/voices_npy" in voice_path:
pass
else:
voice_path = voice_path.replace('checkpoints/voices', 'checkpoints/voices_npy').replace('.pt', '.npy')
pack = np.load(voice_path).reshape(510,1,256)
if phoneme_len is not None:
ref_s = pack[phoneme_len:phoneme_len+1]
else:
idx = pack.shape[0] // 2
ref_s = pack[idx:idx+1]
return ref_s[0]
def split_input_ids_semantic(
input_ids: np.ndarray,
fixed_seq_len: int,
phonemes: str = None,
vocab: Dict[str, int] = None
) -> List[Dict]:
"""input_ids分割"""
content = input_ids[0, 1:-1]
chunk_with_special = np.concatenate([[0], content, [0]])
# 填充到固定长度
padding_len = fixed_seq_len - len(chunk_with_special)
if padding_len > 0:
chunk_padded = np.concatenate([chunk_with_special, np.zeros(padding_len, dtype=input_ids.dtype)])
else:
chunk_padded = chunk_with_special
return [{
'input_ids': chunk_padded.reshape(1, -1),
'actual_len': len(chunk_with_special),
# 'is_last': True,
# 'is_first': True,
'trim_end_chars': 0
}]
def generate_input_ids_from_text(text: str, lang_code: str = None, config_path: str = None,
g2p=None, g2p_type: str = None, vocab: Dict[str, int] = None,
g2p_context: G2PContext = None):
"""从文本生成input_ids"""
if g2p_context is not None:
g2p = g2p_context.g2p
g2p_type = g2p_context.g2p_type
vocab = g2p_context.vocab
else:
if vocab is None:
vocab = load_vocab_from_config(config_path)
if g2p is None:
g2p, g2p_type = init_g2p(lang_code)
phonemes = text_to_phonemes(text, g2p, g2p_type)
content_ids = phonemes_to_input_ids(phonemes, vocab, debug=False)
input_ids = np.concatenate([[0], content_ids, [0]]).reshape(1, -1)
return input_ids, phonemes
def split_long_sentence(sentence, lang_code, g2p, g2p_type, vocab, max_merge_len=78, depth=0):
try:
input_ids, phonemes = generate_input_ids_from_text(
sentence, g2p=g2p, g2p_type=g2p_type, vocab=vocab
)
content_len = input_ids.shape[1]
if content_len <= max_merge_len:
return [{
'sentence': sentence,
'input_ids': input_ids,
'phonemes': phonemes,
'content_len': content_len
}]
else:
# 中文/日文按字符长度一半分割
if lang_code in ['z', 'j']:
mid = len(sentence) // 2
first_half = sentence[:mid]
second_half = sentence[mid:]
else:
# 英文按单词个数一半分割
words = sentence.split()
mid_word = len(words) // 2
first_half = ' '.join(words[:mid_word])
second_half = ' '.join(words[mid_word:])
result_first = split_long_sentence(first_half, lang_code, g2p, g2p_type, vocab,
max_merge_len, depth + 1)
result_second = split_long_sentence(second_half, lang_code, g2p, g2p_type, vocab,
max_merge_len, depth + 1)
return result_first + result_second
except Exception:
return []
def concat_audios(sub_audios: List[Dict], sr: int = SAMPLE_RATE) -> np.ndarray:
"""拼接音频"""
if not sub_audios:
return np.array([], dtype=np.float32)
if len(sub_audios) == 1:
return sub_audios[0]['audio']
audio_segments = [sub['audio'] for sub in sub_audios]
return np.concatenate(audio_segments).astype(np.float32)
def process_and_merge_sentences(text: str, lang_code: str, g2p, g2p_type: str,
vocab: Dict[str, int], max_merge_len: int = 96) -> List[Dict]:
"""
处理文本:清理(待处理)、分句、生成input_ids、长句分割、短句合并
"""
cleaned_text = clean_text(text)
sentences = split_sentences(cleaned_text, lang_code=lang_code)
# 为每个句子生成 input_ids
sentence_data = []
for sentence in sentences:
try:
input_ids, phonemes = generate_input_ids_from_text(
sentence, g2p=g2p, g2p_type=g2p_type, vocab=vocab
)
content_len = input_ids.shape[1]
if content_len <= max_merge_len:
sentence_data.append({
'sentence': sentence,
'input_ids': input_ids,
'phonemes': phonemes,
'content_len': content_len,
'is_long': False
})
else:
sub_results = split_long_sentence(sentence, lang_code, g2p, g2p_type, vocab, max_merge_len)
sentence_data.append({
'sentence': sentence,
'sub_results': sub_results,
'is_long': True
})
except Exception as e:
logger.error(f"错误处理句子 '{sentence}': {e}")
if not sentence_data:
raise ValueError("没有生成任何 input_ids")
# 长句保持分割,短句合并
merged_groups = []
i = 0
while i < len(sentence_data):
if sentence_data[i]['is_long']:
sub_results = sentence_data[i]['sub_results']
merged_groups.append({'is_long_split': True, 'sub_results': sub_results})
i += 1
else:
merged_sentences = []
total_len = 0
j = i
while j < len(sentence_data) and not sentence_data[j]['is_long']:
next_len = sentence_data[j]['content_len']
if total_len + next_len < max_merge_len:
merged_sentences.append(sentence_data[j]['sentence'])
total_len += next_len
j += 1
else:
break
if j == i:
merged_sentences.append(sentence_data[i]['sentence'])
j = i + 1
# 重新生成合并后的 input_ids
merged_text = ' '.join(merged_sentences)
merged_input_ids, merged_phonemes = generate_input_ids_from_text(
merged_text, g2p=g2p, g2p_type=g2p_type, vocab=vocab
)
merged_groups.append({'input_ids': merged_input_ids, 'phonemes': merged_phonemes})
i = j
return merged_groups
def run_batch_inference(engine, merged_groups: List[Dict], voice_path: str,
vocab: Dict[str, int], speed: float = DEFAULT_SPEED,
fade_out_duration: float = DEFAULT_FADE_OUT,
sr: int = SAMPLE_RATE) -> List[np.ndarray]:
"""推理"""
audio_list = []
for group in merged_groups:
try:
if group.get('is_long_split', False):
# 长句分割:对每个子片段推理后拼接
sub_audios = []
for sub in group['sub_results']:
phoneme_len = sub['input_ids'].shape[1] - 2
ref_s = load_voice_embedding(voice_path, phoneme_len=phoneme_len)
audio = engine.inference(sub['input_ids'], ref_s, sub['phonemes'], vocab,
speed=speed, fade_out_duration=0)
sub_audios.append({'audio': audio})
combined_audio = concat_audios(sub_audios, sr=sr)
audio_list.append(combined_audio)
else:
# 短句或合并句:直接推理
phoneme_len = group['input_ids'].shape[1] - 2
ref_s = load_voice_embedding(voice_path, phoneme_len=phoneme_len)
audio = engine.inference(group['input_ids'], ref_s, group['phonemes'], vocab,
speed=speed, fade_out_duration=fade_out_duration)
audio_list.append(audio)
except Exception as e:
logger.error(f"推理错误: {e}")
if not audio_list:
raise ValueError("没有生成任何音频")
return audio_list