360 lines
13 KiB
Python
360 lines
13 KiB
Python
"""
|
||
TTS Proxy Service - 小米 MiMo TTS 音频转换代理
|
||
核心功能: /api/tts 实时 TTS + 智能分段 + 自动重试
|
||
"""
|
||
|
||
import os
|
||
import json
|
||
import base64
|
||
import subprocess
|
||
import uuid
|
||
import asyncio
|
||
import logging
|
||
import time
|
||
from contextlib import asynccontextmanager
|
||
from pathlib import Path
|
||
|
||
import httpx
|
||
from fastapi import FastAPI, HTTPException, Request, Depends
|
||
from fastapi.responses import FileResponse, HTMLResponse, Response
|
||
from fastapi.staticfiles import StaticFiles
|
||
|
||
import config
|
||
|
||
# ── Logging ───────────────────────────────────────────────────────────────

# Compact console logging: wall-clock time (no date) followed by the message.
logging.basicConfig(
    datefmt="%H:%M:%S",
    format="%(asctime)s %(message)s",
    level=logging.INFO,
)

# Logger used by every TTS code path in this service.
logger = logging.getLogger("tts")

# httpx logs every request at INFO level; raise its threshold so only
# warnings and errors from the HTTP client reach the console.
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||
|
||
# ── Text Segmentation ─────────────────────────────────────────────────────

# Upper bound on the number of characters sent to the TTS API per request.
MAX_CHUNK_CHARS = 2000

# Candidate cut points in priority order. split_text walks this list and cuts
# after the LAST occurrence (within the window) of the first separator found.
# Full-width Chinese punctuation and ASCII punctuation are both listed so that
# mixed-language text is split on natural boundaries.
_SEGMENT_PATTERNS = [
    "\n\n",  # paragraph break
    "\n",  # line break
    "。", "!", "?", "…",  # full-width sentence enders
    ".", "!", "?",
    ";", ";",
    ",", ",",
]


def split_text(text: str, max_chars: int = MAX_CHUNK_CHARS) -> list[str]:
    """Split *text* into chunks of at most *max_chars* characters at natural boundaries.

    The input is stripped first. If it fits in *max_chars* it is returned as a
    single chunk. Otherwise, the first *max_chars* characters of what remains
    are scanned for the highest-priority separator in ``_SEGMENT_PATTERNS``
    and the text is cut just after its last occurrence. When no separator is
    usable, the cut is made at exactly *max_chars*.

    Args:
        text: The text to split.
        max_chars: Maximum length of each chunk; must be >= 1.

    Returns:
        Stripped chunks in original order, each at most *max_chars* long. Empty
        or whitespace-only input yields ``[""]`` (one empty chunk), so the
        result always contains at least one element.

    Raises:
        ValueError: If *max_chars* is less than 1 (the cut position could not
            advance, so the loop would never terminate).
    """
    if max_chars < 1:
        raise ValueError(f"max_chars must be >= 1, got {max_chars}")

    text = text.strip()
    if len(text) <= max_chars:
        return [text]

    chunks: list[str] = []
    remaining = text

    while remaining:
        if len(remaining) <= max_chars:
            chunks.append(remaining)
            break

        window = remaining[:max_chars]
        cut_pos = max_chars  # hard cut when no separator is usable
        for sep in _SEGMENT_PATTERNS:
            idx = window.rfind(sep)
            # idx == 0 would produce a chunk holding only the separator;
            # treat it as "not found" and try the next separator.
            if idx > 0:
                cut_pos = idx + len(sep)
                break

        chunk = remaining[:cut_pos].strip()
        if chunk:
            chunks.append(chunk)
        remaining = remaining[cut_pos:].strip()

    return chunks
|
||
|
||
|
||
# ── Auth ──────────────────────────────────────────────────────────────────

async def verify_token(request: Request):
    """FastAPI dependency enforcing ``Authorization: Bearer <token>``.

    Authentication is skipped entirely (every request passes) when
    ``config.API_TOKEN`` is empty. The scheme name is matched
    case-insensitively (RFC 7235). The token is compared in constant time so
    response timing does not reveal how much of a guess was correct.

    Args:
        request: The incoming request whose ``Authorization`` header is checked.

    Raises:
        HTTPException(401): Header missing or not a Bearer credential.
        HTTPException(403): Bearer token does not match ``config.API_TOKEN``.
    """
    import hmac  # local import: only this dependency needs constant-time compare

    if not config.API_TOKEN:
        return

    auth = request.headers.get("Authorization", "")
    scheme, sep, token = auth.partition(" ")
    if not sep or scheme.lower() != "bearer":
        raise HTTPException(401, "缺少 Authorization: Bearer <token>")
    # NOTE(review): assumes config.API_TOKEN is a str; compare as UTF-8 bytes
    # because compare_digest only accepts ASCII-only str arguments.
    if not hmac.compare_digest(token.encode("utf-8"), config.API_TOKEN.encode("utf-8")):
        raise HTTPException(403, "Token 无效")
|
||
|
||
|
||
# ── Audio Utils ───────────────────────────────────────────────────────────

def wav_to_mp3(wav_path: str, mp3_path: str):
    """Transcode the WAV file at *wav_path* into an MP3 at *mp3_path* using ffmpeg.

    Encodes with libmp3lame at VBR quality 2 and overwrites *mp3_path* if it
    already exists (``-y``). This call blocks until ffmpeg exits.

    Raises:
        RuntimeError: ffmpeg exited with a non-zero status. The message carries
            the last 300 characters of ffmpeg's stderr.
        FileNotFoundError: ffmpeg is not on PATH (raised by ``subprocess.run``).
    """
    result = subprocess.run(
        [
            "ffmpeg", "-hide_banner", "-loglevel", "error", "-y",
            "-i", wav_path,
            "-codec:a", "libmp3lame", "-qscale:a", "2",
            mp3_path,
        ],
        capture_output=True,
        text=True,
        errors="replace",  # never let an undecodable byte mask the real failure
    )
    if result.returncode != 0:
        # ffmpeg prints the actual error last; keep the tail, not the head.
        raise RuntimeError(f"ffmpeg 转换失败: {result.stderr[-300:]}")
|
||
|
||
|
||
def concat_mp3_files(mp3_paths: list[str], output_path: str):
    """Concatenate *mp3_paths* (in order) into *output_path* via ffmpeg's concat demuxer.

    A temporary list file ``<output_path>.concat_list.txt`` is written next to
    the output and is always removed afterwards, even if writing it or running
    ffmpeg fails. The joined audio is re-encoded (libmp3lame, quality 2), not
    stream-copied. This call blocks until ffmpeg exits.

    Raises:
        RuntimeError: ffmpeg exited with a non-zero status. The message carries
            the last 300 characters of ffmpeg's stderr.
    """
    list_path = Path(output_path + ".concat_list.txt")
    # Concat-list syntax: inside single quotes a literal ' must be written as '\''.
    entries = "".join(
        "file '{}'\n".format(p.replace("'", "'\\''")) for p in mp3_paths
    )
    try:
        # Write UTF-8 explicitly rather than the platform default encoding so
        # non-ASCII paths are encoded consistently.
        list_path.write_text(entries, encoding="utf-8")
        result = subprocess.run(
            [
                "ffmpeg", "-hide_banner", "-loglevel", "error", "-y",
                # -safe 0 allows absolute paths in the list file
                "-f", "concat", "-safe", "0", "-i", str(list_path),
                "-codec:a", "libmp3lame", "-qscale:a", "2",
                output_path,
            ],
            capture_output=True,
            text=True,
            errors="replace",
        )
        if result.returncode != 0:
            raise RuntimeError(f"ffmpeg 拼接失败: {result.stderr[-300:]}")
    finally:
        list_path.unlink(missing_ok=True)
|
||
|
||
|
||
# ── TTS Service ───────────────────────────────────────────────────────────

# Total number of attempts (first try included) per TTS request.
MAX_TTS_RETRIES = 3


async def call_mimo_tts(text: str, style: str = "", voice: str = "") -> bytes:
    """Synthesize *text* with the MiMo TTS API and return the decoded WAV bytes.

    Args:
        text: Text to speak.
        style: Optional style hint. When non-empty it is prepended to the
            message content as ``<style>...</style>``.
        voice: Voice id; falls back to ``config.MIMO_VOICE`` when empty.

    Retry policy: up to ``MAX_TTS_RETRIES`` attempts in total, sleeping
    1.5 s * attempt between them. Retried failures are HTTP 5xx, HTTP 429
    (rate limited), and transport/response-parsing errors. Other HTTP errors
    and an ``error`` field in the JSON body fail immediately.

    Raises:
        HTTPException(500): ``MIMO_API_KEY`` is not configured.
        HTTPException(502): Non-retryable HTTP status, API-reported error, or
            all attempts exhausted.
    """
    if not config.MIMO_API_KEY:
        raise HTTPException(500, "MIMO_API_KEY 未配置")

    content = f"<style>{style}</style>{text}" if style else text
    use_voice = voice or config.MIMO_VOICE

    payload = {
        "model": config.MIMO_TTS_MODEL,
        "audio": {"format": "wav", "voice": use_voice},
        "messages": [{"role": "assistant", "content": content}],
    }
    headers = {
        "Content-Type": "application/json",
        "api-key": config.MIMO_API_KEY,
    }

    last_exc = None
    # One client for every attempt so pooled connections are reused on retry.
    async with httpx.AsyncClient(timeout=120) as client:
        for attempt in range(1, MAX_TTS_RETRIES + 1):
            t0 = time.monotonic()  # monotonic: immune to wall-clock jumps
            try:
                resp = await client.post(config.MIMO_API_ENDPOINT, json=payload, headers=headers)
                elapsed = round(time.monotonic() - t0, 2)

                if resp.status_code != 200:
                    logger.error(f"TTS FAIL http={resp.status_code} {elapsed}s")
                    err = HTTPException(502, f"MiMo TTS API 错误: HTTP {resp.status_code}")
                    # 5xx = upstream fault, 429 = rate limited: both may succeed later.
                    retryable = resp.status_code >= 500 or resp.status_code == 429
                    if retryable and attempt < MAX_TTS_RETRIES:
                        last_exc = err
                        await asyncio.sleep(1.5 * attempt)
                        continue
                    raise err

                data = resp.json()
                if data.get("error"):
                    raise HTTPException(502, f"MiMo TTS 错误: {data['error']}")

                audio_b64 = data["choices"][0]["message"]["audio"]["data"]
                wav_bytes = base64.b64decode(audio_b64)
                logger.info(f"TTS OK {len(wav_bytes)//1024}KB {elapsed}s")
                return wav_bytes

            except HTTPException:
                # Deliberate, already-classified failures are not retried.
                raise
            except Exception as e:
                # Network errors, invalid JSON, unexpected response shape: retry.
                elapsed = round(time.monotonic() - t0, 2)
                logger.error(f"TTS ERR {e} {elapsed}s")
                last_exc = HTTPException(502, f"MiMo TTS 异常: {e}")
                last_exc.__cause__ = e  # keep the original traceback reachable
                if attempt < MAX_TTS_RETRIES:
                    await asyncio.sleep(1.5 * attempt)

    if last_exc is None:
        # Only reachable if MAX_TTS_RETRIES < 1 (no attempt was made);
        # avoids `raise None`.
        last_exc = HTTPException(502, "MiMo TTS 未执行任何请求")
    raise last_exc
|
||
|
||
|
||
# ── Core: generate MP3 from text ──────────────────────────────────────────

async def _synthesize_chunk_to_mp3(chunk: str, style: str, voice: str, tmp_dir: Path) -> str:
    """Run TTS for one chunk and transcode it to an MP3 file in *tmp_dir*; return the MP3 path.

    The intermediate WAV is always deleted. The MP3 is also deleted if the
    conversion fails, so a failure leaves nothing behind for this chunk.
    """
    wav_bytes = await call_mimo_tts(chunk, style, voice)
    uid = uuid.uuid4().hex
    wav_path = str(tmp_dir / f"{uid}.wav")
    mp3_path = str(tmp_dir / f"{uid}.mp3")
    try:
        with open(wav_path, "wb") as f:
            f.write(wav_bytes)
        loop = asyncio.get_running_loop()
        # ffmpeg runs as a blocking subprocess; keep it off the event loop.
        await loop.run_in_executor(None, wav_to_mp3, wav_path, mp3_path)
    except BaseException:
        # A failed or cancelled conversion may leave a partial MP3 behind.
        Path(mp3_path).unlink(missing_ok=True)
        raise
    finally:
        Path(wav_path).unlink(missing_ok=True)
    return mp3_path


async def generate_mp3(text: str, style: str = "", voice: str = "") -> bytes:
    """Convert *text* to MP3 bytes, splitting long text and concatenating the parts.

    The text is split with ``split_text``. A single chunk is synthesized and
    transcoded directly. Multiple chunks are synthesized one at a time, in
    order, and joined with ``concat_mp3_files``. Every intermediate file under
    ``<AUDIO_DIR>/_tmp`` is removed whether generation succeeds or fails.

    Args:
        text: Text to synthesize.
        style: Optional style hint forwarded to ``call_mimo_tts``.
        voice: Optional voice id forwarded to ``call_mimo_tts``.

    Returns:
        The complete MP3 file contents.

    Raises:
        Whatever ``call_mimo_tts``, ``wav_to_mp3`` or ``concat_mp3_files``
        raise (``HTTPException`` for API failures, ``RuntimeError`` for ffmpeg
        failures).
    """
    chunks = split_text(text)
    tmp_dir = Path(config.AUDIO_DIR) / "_tmp"
    tmp_dir.mkdir(parents=True, exist_ok=True)

    if len(chunks) > 1:
        logger.info(f"SPLIT {len(text)}字 → {len(chunks)}段")

    mp3_paths: list[str] = []
    merged_path = ""
    try:
        for chunk in chunks:
            mp3_paths.append(await _synthesize_chunk_to_mp3(chunk, style, voice, tmp_dir))

        if len(mp3_paths) == 1:
            result_path = mp3_paths[0]
        else:
            merged_path = str(tmp_dir / f"{uuid.uuid4().hex}.mp3")
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, concat_mp3_files, mp3_paths, merged_path)
            result_path = merged_path

        with open(result_path, "rb") as f:
            return f.read()
    finally:
        # Runs on success and on failure: no temp MP3 outlives this call.
        for p in mp3_paths:
            Path(p).unlink(missing_ok=True)
        if merged_path:
            Path(merged_path).unlink(missing_ok=True)
|
||
|
||
|
||
# ── App ───────────────────────────────────────────────────────────────────

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup hook: ensure the audio root and its working subdirectories exist.

    Creates ``AUDIO_DIR`` (with parents) plus its ``_tmp`` and ``_preview``
    subdirectories. Nothing is done on shutdown.
    """
    os.makedirs(config.AUDIO_DIR, exist_ok=True)
    audio_root = Path(config.AUDIO_DIR)
    for subdir in ("_tmp", "_preview"):
        (audio_root / subdir).mkdir(exist_ok=True)
    yield


app = FastAPI(title="TTS Proxy Service", lifespan=lifespan)
|
||
|
||
|
||
# ── Health check ──────────────────────────────────────────────────────────

@app.get("/health")
async def health():
    """Liveness probe.

    Reports, as booleans only, whether the MiMo API key and the access token
    are configured; the secret values themselves are never returned.
    """
    key_configured = bool(config.MIMO_API_KEY)
    token_configured = bool(config.API_TOKEN)
    return dict(status="ok", api_key=key_configured, token=token_configured)
|
||
|
||
|
||
# ── Core endpoint: real-time TTS ──────────────────────────────────────────

@app.post("/api/tts")
async def realtime_tts(request: Request):
    """
    Real-time TTS: synthesize the request text and return it as MP3 (audio/mpeg).

    Accepted request bodies:
      JSON: {"text": "内容", "style": "开心", "voice": ""}
      Form: tex=内容 (Baidu-compatible; on top of the form encoding the value
            is URL-decoded up to twice more)

    Errors are returned as JSON bodies rather than raised:
      400 {"status": 40000001, ...} when no text could be extracted;
      500 {"status": 500, "message": <error text, max 300 chars>} when
      synthesis fails.
    """
    from urllib.parse import parse_qs, unquote

    text = style = voice = ""
    content_type = request.headers.get("content-type", "")

    try:
        if "json" in content_type:
            data = await request.json()
            text = (data.get("text") or "").strip()
            style = (data.get("style") or "").strip()
            voice = (data.get("voice") or "").strip()
        else:
            body = await request.body()
            params = parse_qs(body.decode("utf-8"))
            raw = params.get("tex", [""])[0]
            # Strip only AFTER decoding so encoded padding (e.g. %20) is
            # removed too and whitespace-only text is rejected below.
            text = unquote(unquote(raw)).strip()
    except Exception as e:
        # Best effort: a malformed body falls through to the empty-text check,
        # but the reason is logged so misbehaving clients can be diagnosed.
        logger.warning(f"TTS bad request body: {e!r}")

    if not text:
        return Response(
            content=json.dumps({"status": 40000001, "message": "text 不能为空"}, ensure_ascii=False),
            media_type="application/json", status_code=400,
        )

    try:
        mp3_bytes = await generate_mp3(text, style, voice)
    except Exception as e:
        # Covers failures not logged elsewhere (e.g. ffmpeg) with a traceback.
        logger.exception(f"TTS request failed: {e}")
        return Response(
            content=json.dumps({"status": 500, "message": str(e)[:300]}, ensure_ascii=False),
            media_type="application/json", status_code=500,
        )
    return Response(content=mp3_bytes, media_type="audio/mpeg")
|
||
|
||
|
||
# ── Admin endpoints ───────────────────────────────────────────────────────

@app.post("/admin/api/preview")
async def preview(request: Request, _auth=Depends(verify_token)):
    """Synthesize a preview clip and return the URL it can be fetched from.

    Body: JSON object ``{"text": ..., "style": ..., "voice": ...}``; only
    ``text`` is required. The MP3 is saved as
    ``<AUDIO_DIR>/_preview/<uuid>.mp3`` and the response is
    ``{"ok": True, "url": "/audio/_preview/<uuid>.mp3"}``.
    NOTE(review): nothing visible in this module deletes old preview files.

    Raises:
        HTTPException(400): Body is not a JSON object, or text is blank.
    """
    try:
        data = await request.json()
    except ValueError:  # json.JSONDecodeError / UnicodeDecodeError
        data = None
    if not isinstance(data, dict):
        raise HTTPException(400, "请求体必须是 JSON 对象")

    text = (data.get("text") or "").strip()
    style = (data.get("style") or "").strip()
    voice = (data.get("voice") or "").strip()
    if not text:
        raise HTTPException(400, "文本不能为空")

    mp3_bytes = await generate_mp3(text, style, voice)

    preview_dir = Path(config.AUDIO_DIR) / "_preview"
    # Recreate if removed after startup (lifespan only creates it once).
    preview_dir.mkdir(parents=True, exist_ok=True)
    filename = f"{uuid.uuid4().hex}.mp3"
    (preview_dir / filename).write_bytes(mp3_bytes)
    return {"ok": True, "url": f"/audio/_preview/{filename}"}
|
||
|
||
|
||
@app.get("/admin/api/config")
async def config_info(_auth=Depends(verify_token)):
    """Return the effective TTS configuration for the admin UI.

    The API key is masked: only its first 6 characters followed by "****" are
    shown, or "未配置" when no key is set.
    """
    api_key = config.MIMO_API_KEY
    masked_key = f"{api_key[:6]}****" if api_key else "未配置"
    return {
        "endpoint": config.MIMO_API_ENDPOINT,
        "model": config.MIMO_TTS_MODEL,
        "voice": config.MIMO_VOICE,
        "api_key": masked_key,
        "max_chunk": MAX_CHUNK_CHARS,
        "token_set": bool(config.API_TOKEN),
    }
|
||
|
||
|
||
# ── Config file download ──────────────────────────────────────────────────

@app.get("/httpTts.json")
async def serve_config():
    """Serve ``<BASE_DIR>/httpTts-mimo.json`` as ``/httpTts.json``; 404 when it is absent."""
    config_file = os.path.join(config.BASE_DIR, "httpTts-mimo.json")
    if not os.path.exists(config_file):
        raise HTTPException(404)
    return FileResponse(config_file, media_type="application/json")
|
||
|
||
|
||
# ── Static files & frontend ───────────────────────────────────────────────

# check_dir=False: AUDIO_DIR is created by `lifespan` at startup, which runs
# after this module is imported. The default eager directory check would
# raise RuntimeError on a fresh install where the directory does not exist
# yet; with the check deferred, it happens on the first /audio request.
app.mount(
    "/audio",
    StaticFiles(directory=config.AUDIO_DIR, check_dir=False),
    name="audio",
)


@app.get("/", response_class=HTMLResponse)
async def frontend():
    """Serve ``<BASE_DIR>/static/index.html``, read from disk on every request."""
    index_path = os.path.join(config.BASE_DIR, "static", "index.html")
    with open(index_path, "r", encoding="utf-8") as f:
        return HTMLResponse(f.read())
|
||
|
||
|
||
# ── Main ──────────────────────────────────────────────────────────────────

if __name__ == "__main__":
    # Development entry point: serve with uvicorn and auto-reload enabled.
    # The "main:app" import string assumes this file is the module `main`.
    import uvicorn

    uvicorn.run(
        "main:app",
        host=config.SERVER_HOST,
        port=config.SERVER_PORT,
        reload=True,
    )
|