feat: API文档、文本自动分段、音色配置、批量并发
- 新增 API.md 完整接口文档 - 智能文本分段:长文本按段落/句子/标点边界自动切分(≤2000字/段),逐段TTS后ffmpeg拼接 - /api/tts 支持 voice 参数指定音色 - httpTts JSON 配置增加 style 和 voice 字段 - 批量生成改用并发(Semaphore 3路) - 新增 /health 健康检查端点 - TTS 试听前端增加音色输入 - 清理 import,修复端口不一致
This commit is contained in:
236
app/main.py
236
app/main.py
@@ -16,11 +16,12 @@ from pathlib import Path
|
||||
|
||||
import httpx
|
||||
from fastapi import FastAPI, HTTPException, Query, Request
|
||||
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
|
||||
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, Response
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, func, select, delete
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from urllib.parse import parse_qs, unquote
|
||||
|
||||
import config
|
||||
|
||||
@@ -33,6 +34,83 @@ logging.basicConfig(
|
||||
)
|
||||
logger = logging.getLogger("tts-service")
|
||||
|
||||
# ── Text Segmentation ─────────────────────────────────────────────────────
|
||||
|
||||
# Upper bound on the text sent to MiMo TTS in one request. Deliberately
# conservative; the actual service limit is reportedly around 5000.
MAX_CHUNK_CHARS = 2000

# Separators in split priority order: paragraph > line break > sentence end >
# semicolon > comma. split_text() hard-cuts at the length limit when none match.
# Full-width punctuation is written as escapes so it cannot be silently folded
# into its ASCII look-alike (which would leave CJK text without split points).
_SEGMENT_PATTERNS = [
    "\n\n",  # paragraph break
    "\n",  # line break
    "。", "\uff01", "\uff1f", "…",  # CJK sentence end (full stop, full-width ! and ?, ellipsis)
    ".", "!", "?",  # ASCII sentence end
    "\uff1b", ";",  # semicolon (full-width, ASCII)
    "\uff0c", ",",  # comma (full-width, ASCII)
]


def split_text(text: str, max_chars: int = MAX_CHUNK_CHARS) -> list[str]:
    """Split text into chunks of at most ``max_chars`` characters.

    Cuts are placed at the most natural boundary available inside each
    ``max_chars`` window, trying the separators of ``_SEGMENT_PATTERNS`` in
    priority order and using the last occurrence of the first separator type
    that is present. If the window contains no separator, the text is cut at
    exactly ``max_chars``.

    Args:
        text: Text to split. Leading/trailing whitespace is stripped, both on
            the whole input and on every chunk.
        max_chars: Maximum length of each returned chunk; must be >= 1.

    Returns:
        The chunks in order. Text that already fits is returned as a
        single-element list (``[""]`` for empty or whitespace-only input).

    Raises:
        ValueError: If ``max_chars`` is less than 1. Without this guard a
            value of 0 would never consume any input and loop forever.
    """
    if max_chars < 1:
        raise ValueError(f"max_chars must be >= 1, got {max_chars}")

    text = text.strip()
    if len(text) <= max_chars:
        return [text]

    chunks: list[str] = []
    remaining = text

    while len(remaining) > max_chars:
        window = remaining[:max_chars]
        # Default to a hard cut; overridden by the best separator found.
        cut_pos = max_chars

        for sep in _SEGMENT_PATTERNS:
            idx = window.rfind(sep)
            # idx == 0 would produce an empty leading chunk, so it is ignored.
            if idx > 0:
                # Keep the separator with the preceding chunk.
                cut_pos = idx + len(sep)
                break

        chunk = remaining[:cut_pos].strip()
        if chunk:
            chunks.append(chunk)
        remaining = remaining[cut_pos:].strip()

    # Whatever is left now fits in a single chunk.
    if remaining:
        chunks.append(remaining)

    return chunks
|
||||
|
||||
|
||||
# ── Audio Concatenation ───────────────────────────────────────────────────
|
||||
|
||||
def concat_mp3_files(mp3_paths: list[str], output_path: str):
    """Concatenate several MP3 files into one using ffmpeg's concat demuxer.

    A temporary list file (``output_path + ".concat_list.txt"``) is written
    for ffmpeg and removed again whether or not the run succeeds. The inputs
    are re-encoded with libmp3lame into ``output_path``.

    Args:
        mp3_paths: Input MP3 files, in playback order. Must not be empty.
        output_path: Destination MP3 file; overwritten if it exists (``-y``).

    Raises:
        ValueError: If ``mp3_paths`` is empty.
        RuntimeError: If ffmpeg exits with a non-zero status.
    """
    if not mp3_paths:
        raise ValueError("mp3_paths must contain at least one file")

    list_path = output_path + ".concat_list.txt"

    try:
        # The concat demuxer reads the list as UTF-8 regardless of locale.
        with open(list_path, "w", encoding="utf-8") as f:
            for p in mp3_paths:
                # ffmpeg resolves relative entries against the directory of
                # the list file, not the process CWD, so write absolute paths.
                abs_path = os.path.abspath(p)
                # Inside single quotes the only way to emit a quote is to
                # close the quoting, add an escaped quote, and reopen it.
                escaped = abs_path.replace("'", "'\\''")
                f.write(f"file '{escaped}'\n")

        result = subprocess.run(
            ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", list_path,
             "-codec:a", "libmp3lame", "-qscale:a", "2", output_path],
            capture_output=True, text=True,
        )
        if result.returncode != 0:
            # stderr begins with the ffmpeg banner; the actual failure is at
            # the end, so report the tail rather than the head.
            raise RuntimeError(f"ffmpeg 拼接失败: {result.stderr[-300:]}")
    finally:
        try:
            os.remove(list_path)
        except FileNotFoundError:
            # The list file was never created (open() itself failed).
            pass
|
||||
|
||||
|
||||
# ── Database ──────────────────────────────────────────────────────────────
|
||||
|
||||
engine = create_async_engine(config.DATABASE_URL, echo=False)
|
||||
@@ -69,16 +147,17 @@ class Chapter(Base):
|
||||
|
||||
# ── TTS Service ───────────────────────────────────────────────────────────
|
||||
|
||||
async def call_mimo_tts(text: str, style: str = "") -> bytes:
|
||||
async def call_mimo_tts(text: str, style: str = "", voice: str = "") -> bytes:
|
||||
"""调用小米 MiMo TTS API,返回 WAV 音频字节"""
|
||||
if not config.MIMO_API_KEY:
|
||||
raise HTTPException(500, "MIMO_API_KEY 未配置,请设置环境变量")
|
||||
|
||||
content = f"<style>{style}</style>{text}" if style else text
|
||||
use_voice = voice or config.MIMO_VOICE
|
||||
|
||||
payload = {
|
||||
"model": config.MIMO_TTS_MODEL,
|
||||
"audio": {"format": "wav", "voice": config.MIMO_VOICE},
|
||||
"audio": {"format": "wav", "voice": use_voice},
|
||||
"messages": [{"role": "assistant", "content": content}],
|
||||
}
|
||||
|
||||
@@ -125,7 +204,7 @@ def wav_to_mp3(wav_path: str, mp3_path: str):
|
||||
|
||||
|
||||
async def generate_chapter_audio(chapter_id_str: str):
|
||||
"""为指定章节生成音频(WAV → MP3)"""
|
||||
"""为指定章节生成音频(支持长文本自动分段拼接)"""
|
||||
async with async_session() as db:
|
||||
result = await db.execute(select(Chapter).where(Chapter.chapter_id == chapter_id_str))
|
||||
chapter = result.scalar_one_or_none()
|
||||
@@ -145,20 +224,41 @@ async def generate_chapter_audio(chapter_id_str: str):
|
||||
audio_dir = Path(config.AUDIO_DIR) / chapter.book_id
|
||||
audio_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
wav_path = str(audio_dir / f"{chapter.chapter_id}.wav")
|
||||
mp3_path = str(audio_dir / f"{chapter.chapter_id}.mp3")
|
||||
chunks = split_text(chapter.text_content)
|
||||
|
||||
# MiMo TTS 生成 WAV
|
||||
wav_bytes = await call_mimo_tts(chapter.text_content)
|
||||
with open(wav_path, "wb") as f:
|
||||
f.write(wav_bytes)
|
||||
if len(chunks) == 1:
|
||||
# 单段:直接生成
|
||||
wav_bytes = await call_mimo_tts(chapter.text_content)
|
||||
wav_path = str(audio_dir / f"{chapter.chapter_id}.wav")
|
||||
with open(wav_path, "wb") as f:
|
||||
f.write(wav_bytes)
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, wav_to_mp3, wav_path, mp3_path)
|
||||
os.remove(wav_path)
|
||||
else:
|
||||
# 多段:逐段生成 → 拼接
|
||||
logger.info(f"章节 {chapter_id_str}: 文本 {len(chapter.text_content)} 字, 分 {len(chunks)} 段生成")
|
||||
tmp_mp3_paths = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
wav_bytes = await call_mimo_tts(chunk)
|
||||
tmp_id = f"{chapter.chapter_id}_part{i}"
|
||||
wav_path = str(audio_dir / f"{tmp_id}.wav")
|
||||
tmp_mp3 = str(audio_dir / f"{tmp_id}.mp3")
|
||||
|
||||
# WAV → MP3
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, wav_to_mp3, wav_path, mp3_path)
|
||||
with open(wav_path, "wb") as f:
|
||||
f.write(wav_bytes)
|
||||
|
||||
# 删除 WAV 源文件,只保留 MP3
|
||||
os.remove(wav_path)
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, wav_to_mp3, wav_path, tmp_mp3)
|
||||
os.remove(wav_path)
|
||||
tmp_mp3_paths.append(tmp_mp3)
|
||||
|
||||
# 拼接
|
||||
await loop.run_in_executor(None, concat_mp3_files, tmp_mp3_paths, mp3_path)
|
||||
for p in tmp_mp3_paths:
|
||||
os.remove(p)
|
||||
logger.info(f"章节 {chapter_id_str}: {len(chunks)} 段拼接完成")
|
||||
|
||||
chapter.audio_file = mp3_path
|
||||
chapter.status = "ready"
|
||||
@@ -166,6 +266,7 @@ async def generate_chapter_audio(chapter_id_str: str):
|
||||
except Exception as e:
|
||||
chapter.status = "error"
|
||||
chapter.error_msg = str(e)[:500]
|
||||
logger.error(f"章节 {chapter_id_str} 生成失败: {e}")
|
||||
|
||||
await db.commit()
|
||||
|
||||
@@ -184,6 +285,12 @@ async def lifespan(app: FastAPI):
|
||||
app = FastAPI(title="TTS Book Service", lifespan=lifespan)
|
||||
|
||||
|
||||
@app.get("/health")
async def health_check():
    """Health check: report service identity and whether an API key is set."""
    # Only the presence of the key is exposed, never its value.
    key_present = bool(config.MIMO_API_KEY)
    payload = {
        "status": "ok",
        "service": "TTS Book Service",
        "api_key_configured": key_present,
    }
    return payload
|
||||
|
||||
|
||||
# ── 听书 App 音频接入接口 ─────────────────────────────────────────────────
|
||||
|
||||
@app.get("/api/book/{book_id}")
|
||||
@@ -242,9 +349,6 @@ async def get_chapter_audio(book_id: str, chapter_id: str):
|
||||
|
||||
# ── 实时 TTS 接口(兼容听书 App 格式)─────────────────────────────────────
|
||||
|
||||
from fastapi.responses import Response
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
@app.post("/api/tts")
|
||||
async def realtime_tts(request: Request):
|
||||
"""
|
||||
@@ -256,6 +360,7 @@ async def realtime_tts(request: Request):
|
||||
"""
|
||||
text = ""
|
||||
style = ""
|
||||
voice = ""
|
||||
content_type = request.headers.get("content-type", "")
|
||||
|
||||
try:
|
||||
@@ -263,13 +368,13 @@ async def realtime_tts(request: Request):
|
||||
data = await request.json()
|
||||
text = data.get("text", "").strip()
|
||||
style = data.get("style", "").strip()
|
||||
voice = data.get("voice", "").strip()
|
||||
else:
|
||||
# form-urlencoded (百度风格)
|
||||
body_bytes = await request.body()
|
||||
params = parse_qs(body_bytes.decode("utf-8"))
|
||||
text = (params.get("tex", [""])[0]).strip()
|
||||
# URL 解码(百度会 double-encode)
|
||||
from urllib.parse import unquote
|
||||
text = unquote(unquote(text))
|
||||
except Exception:
|
||||
pass
|
||||
@@ -282,28 +387,60 @@ async def realtime_tts(request: Request):
|
||||
)
|
||||
|
||||
try:
|
||||
# MiMo TTS 生成 WAV
|
||||
wav_bytes = await call_mimo_tts(text, style)
|
||||
# 文本分段
|
||||
chunks = split_text(text)
|
||||
logger.info(f"实时 TTS: text_len={len(text)}, chunks={len(chunks)}, style={style or '(默认)'}, voice={voice or '(默认)'}")
|
||||
|
||||
# WAV → MP3(临时文件)
|
||||
tmp_dir = Path(config.AUDIO_DIR) / "_tmp"
|
||||
tmp_dir.mkdir(parents=True, exist_ok=True)
|
||||
tmp_id = uuid.uuid4().hex
|
||||
wav_path = str(tmp_dir / f"{tmp_id}.wav")
|
||||
mp3_path = str(tmp_dir / f"{tmp_id}.mp3")
|
||||
|
||||
with open(wav_path, "wb") as f:
|
||||
f.write(wav_bytes)
|
||||
if len(chunks) == 1:
|
||||
# 单段:直接生成
|
||||
wav_bytes = await call_mimo_tts(text, style, voice)
|
||||
tmp_id = uuid.uuid4().hex
|
||||
wav_path = str(tmp_dir / f"{tmp_id}.wav")
|
||||
mp3_path = str(tmp_dir / f"{tmp_id}.mp3")
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, wav_to_mp3, wav_path, mp3_path)
|
||||
with open(wav_path, "wb") as f:
|
||||
f.write(wav_bytes)
|
||||
|
||||
with open(mp3_path, "rb") as f:
|
||||
mp3_bytes = f.read()
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, wav_to_mp3, wav_path, mp3_path)
|
||||
|
||||
# 清理临时文件
|
||||
os.remove(wav_path)
|
||||
os.remove(mp3_path)
|
||||
with open(mp3_path, "rb") as f:
|
||||
mp3_bytes = f.read()
|
||||
|
||||
os.remove(wav_path)
|
||||
os.remove(mp3_path)
|
||||
else:
|
||||
# 多段:逐段生成 → 拼接
|
||||
mp3_paths = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
wav_bytes = await call_mimo_tts(chunk, style, voice)
|
||||
chunk_id = uuid.uuid4().hex
|
||||
wav_path = str(tmp_dir / f"{chunk_id}.wav")
|
||||
mp3_path = str(tmp_dir / f"{chunk_id}.mp3")
|
||||
|
||||
with open(wav_path, "wb") as f:
|
||||
f.write(wav_bytes)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, wav_to_mp3, wav_path, mp3_path)
|
||||
os.remove(wav_path)
|
||||
mp3_paths.append(mp3_path)
|
||||
|
||||
# 拼接所有 MP3
|
||||
merged_id = uuid.uuid4().hex
|
||||
merged_path = str(tmp_dir / f"{merged_id}.mp3")
|
||||
await loop.run_in_executor(None, concat_mp3_files, mp3_paths, merged_path)
|
||||
|
||||
with open(merged_path, "rb") as f:
|
||||
mp3_bytes = f.read()
|
||||
|
||||
# 清理
|
||||
for p in mp3_paths:
|
||||
os.remove(p)
|
||||
os.remove(merged_path)
|
||||
|
||||
return Response(content=mp3_bytes, media_type="audio/mpeg")
|
||||
|
||||
@@ -459,7 +596,7 @@ async def generate_audio(book_id: str, chapter_id: str):
|
||||
|
||||
@app.post("/admin/api/books/{book_id}/generate-all")
|
||||
async def generate_all_chapters(book_id: str):
|
||||
"""批量生成书籍所有章节音频"""
|
||||
"""批量生成书籍所有章节音频(并发,限制 3 路)"""
|
||||
async with async_session() as db:
|
||||
result = await db.execute(
|
||||
select(Chapter).where(Chapter.book_id == book_id, Chapter.status != "ready")
|
||||
@@ -467,10 +604,31 @@ async def generate_all_chapters(book_id: str):
|
||||
chapters = result.scalars().all()
|
||||
|
||||
chapter_ids = [ch.chapter_id for ch in chapters]
|
||||
for cid in chapter_ids:
|
||||
await generate_chapter_audio(cid)
|
||||
if not chapter_ids:
|
||||
return {"ok": True, "total": 0, "chapter_ids": [], "message": "没有需要生成的章节"}
|
||||
|
||||
return {"ok": True, "total": len(chapter_ids), "chapter_ids": chapter_ids}
|
||||
# 并发生成,限制同时 3 个请求避免过载
|
||||
sem = asyncio.Semaphore(3)
|
||||
|
||||
async def _gen(cid: str):
|
||||
async with sem:
|
||||
await generate_chapter_audio(cid)
|
||||
|
||||
tasks = [_gen(cid) for cid in chapter_ids]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 统计结果
|
||||
errors = [str(r) for r in results if isinstance(r, Exception)]
|
||||
success_count = len(chapter_ids) - len(errors)
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"total": len(chapter_ids),
|
||||
"success": success_count,
|
||||
"failed": len(errors),
|
||||
"errors": errors[:10] if errors else [],
|
||||
"chapter_ids": chapter_ids,
|
||||
}
|
||||
|
||||
|
||||
# --- TTS 试听 ---
|
||||
@@ -481,11 +639,12 @@ async def tts_preview(request: Request):
|
||||
data = await request.json()
|
||||
text = data.get("text", "").strip()
|
||||
style = data.get("style", "").strip()
|
||||
voice = data.get("voice", "").strip()
|
||||
|
||||
if not text:
|
||||
raise HTTPException(400, "文本不能为空")
|
||||
|
||||
wav_bytes = await call_mimo_tts(text, style)
|
||||
wav_bytes = await call_mimo_tts(text, style, voice)
|
||||
audio_dir = Path(config.AUDIO_DIR) / "_preview"
|
||||
audio_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -510,6 +669,7 @@ async def get_config():
|
||||
"model": config.MIMO_TTS_MODEL,
|
||||
"voice": config.MIMO_VOICE,
|
||||
"api_key_masked": config.MIMO_API_KEY[:6] + "****" if config.MIMO_API_KEY else "未配置",
|
||||
"max_chunk_chars": MAX_CHUNK_CHARS,
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user