# NOTE: the following lines are page chrome captured from a web code viewer.
# Files
#
# 360 lines
# 13 KiB
# Python
# Raw Permalink Normal View History

"""
TTS Proxy Service - 小米 MiMo TTS 音频转换代理
核心功能: /api/tts 实时 TTS + 智能分段 + 自动重试
"""
import os
import json
import base64
import subprocess
import uuid
import asyncio
import logging
import time
from contextlib import asynccontextmanager
from pathlib import Path
import httpx
from fastapi import FastAPI, HTTPException, Request, Depends
from fastapi.responses import FileResponse, HTMLResponse, Response
from fastapi.staticfiles import StaticFiles
import config
# ── Logging ───────────────────────────────────────────────────────────────
logging.basicConfig(
    format="%(asctime)s %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
# Quiet httpx's own INFO-level request logging; only warnings and above pass.
logging.getLogger("httpx").setLevel(logging.WARNING)
# Logger used throughout this module.
logger = logging.getLogger("tts")
# ── Text Segmentation ─────────────────────────────────────────────────────
# Maximum number of characters sent to the TTS API in one request.
MAX_CHUNK_CHARS = 2000
# Preferred split points, tried in priority order: paragraph breaks, line
# breaks, sentence terminators, semicolons, then commas. Full-width (CJK)
# punctuation is written as escapes so the list survives tools that mangle
# non-ASCII text. An empty string must never appear here: rfind("") matches
# at the very end of every window and would disable all later separators.
_SEGMENT_PATTERNS = [
    "\n\n",  # paragraph
    "\n",    # line break
    "\u3002", "\uff01", "\uff1f", "\u2026",  # 。 ! ? …
    ".", "!", "?",
    "\uff1b", ";",  # ; ;
    "\uff0c", ",",  # , ,
]


def split_text(text: str, max_chars: int = MAX_CHUNK_CHARS) -> list[str]:
    """Split *text* into chunks of at most *max_chars* characters.

    Each chunk is cut at the right-most occurrence (within the window) of the
    highest-priority separator in ``_SEGMENT_PATTERNS``. The separator stays
    attached to the end of its chunk. When no separator occurs past the first
    character of the window, a hard cut at *max_chars* is used. Chunks are
    whitespace-stripped, and empty chunks are dropped.

    Args:
        text: input text; leading/trailing whitespace is removed first.
        max_chars: maximum chunk length; must be positive.

    Returns:
        The chunks in order. Text that already fits (including blank text)
        is returned as a single-element list, so blank input gives ``[""]``.

    Raises:
        ValueError: *max_chars* is not positive (it would never make progress).
    """
    if max_chars <= 0:
        raise ValueError("max_chars must be positive")
    text = text.strip()
    if len(text) <= max_chars:
        return [text]
    chunks: list[str] = []
    remaining = text
    while remaining:
        if len(remaining) <= max_chars:
            chunks.append(remaining)
            break
        window = remaining[:max_chars]
        cut_pos = -1
        for sep in _SEGMENT_PATTERNS:
            if not sep:
                continue  # defensive: rfind("") would always hit the window end
            idx = window.rfind(sep)
            # idx > 0 rather than >= 0: a cut at 0 would produce an empty chunk.
            if idx > 0:
                cut_pos = idx + len(sep)
                break
        if cut_pos <= 0:
            cut_pos = max_chars
        chunk = remaining[:cut_pos].strip()
        if chunk:
            chunks.append(chunk)
        remaining = remaining[cut_pos:].strip()
    return chunks
# ── Auth ──────────────────────────────────────────────────────────────────
async def verify_token(request: Request):
    """FastAPI dependency that enforces ``Authorization: Bearer <API_TOKEN>``.

    The check is skipped entirely when ``config.API_TOKEN`` is empty.

    Raises:
        HTTPException(401): the header is missing or is not a Bearer credential.
        HTTPException(403): the presented token does not match ``API_TOKEN``.
    """
    import hmac

    if not config.API_TOKEN:
        return
    auth = request.headers.get("Authorization", "")
    prefix = "Bearer "
    if not auth.startswith(prefix):
        raise HTTPException(401, "缺少 Authorization: Bearer <token>")
    presented = auth[len(prefix):]
    # Constant-time comparison, so response timing does not leak how much of
    # the token prefix an attacker has guessed correctly.
    if not hmac.compare_digest(presented.encode("utf-8"), config.API_TOKEN.encode("utf-8")):
        raise HTTPException(403, "Token 无效")
# ── Audio Utils ───────────────────────────────────────────────────────────
def wav_to_mp3(wav_path: str, mp3_path: str):
    """Transcode *wav_path* to MP3 at *mp3_path* using ffmpeg (libmp3lame, ``-qscale:a 2``).

    Blocking; ``mp3_path`` is overwritten if it exists (``-y``).

    Raises:
        RuntimeError: ffmpeg exited non-zero. The message carries the last
            300 characters of its stderr.
    """
    result = subprocess.run(
        ["ffmpeg", "-hide_banner", "-loglevel", "error", "-y", "-i", wav_path,
         "-codec:a", "libmp3lame", "-qscale:a", "2", mp3_path],
        # errors="replace": stderr may contain bytes that are not valid in the
        # locale encoding; never let decoding mask the real failure.
        capture_output=True, text=True, errors="replace",
    )
    if result.returncode != 0:
        # ffmpeg reports the actual failure at the end of stderr.
        raise RuntimeError(f"ffmpeg 转换失败: {result.stderr[-300:]}")
def concat_mp3_files(mp3_paths: list[str], output_path: str):
    """Concatenate *mp3_paths* in order into *output_path* with ffmpeg's concat demuxer.

    A temporary list file ``<output_path>.concat_list.txt`` is written and is
    always removed afterwards, even on failure. The joined audio is re-encoded
    with libmp3lame (``-qscale:a 2``).

    Raises:
        RuntimeError: ffmpeg exited non-zero. The message carries the last
            300 characters of its stderr.
    """
    list_path = output_path + ".concat_list.txt"
    try:
        with open(list_path, "w", encoding="utf-8") as f:
            for p in mp3_paths:
                # Inside single quotes, a literal quote is written as '\''.
                escaped = p.replace("'", "'\\''")
                f.write(f"file '{escaped}'\n")
        result = subprocess.run(
            ["ffmpeg", "-hide_banner", "-loglevel", "error", "-y",
             "-f", "concat", "-safe", "0", "-i", list_path,
             "-codec:a", "libmp3lame", "-qscale:a", "2", output_path],
            capture_output=True, text=True, errors="replace",
        )
        if result.returncode != 0:
            raise RuntimeError(f"ffmpeg 拼接失败: {result.stderr[-300:]}")
    finally:
        try:
            os.remove(list_path)
        except FileNotFoundError:
            pass  # open() itself failed before the file was created
# ── TTS Service ───────────────────────────────────────────────────────────
# Total number of attempts per call (first try included).
MAX_TTS_RETRIES = 3


async def call_mimo_tts(text: str, style: str = "", voice: str = "") -> bytes:
    """Synthesize *text* through the MiMo TTS API and return the WAV bytes.

    Args:
        text: text to speak.
        style: optional style hint, sent as a ``<style>…</style>`` prefix.
        voice: voice id; falls back to ``config.MIMO_VOICE`` when empty.

    Retry policy: HTTP 5xx responses and unexpected exceptions (transport
    errors, malformed JSON, missing response fields) are retried, up to
    ``MAX_TTS_RETRIES`` attempts in total. The sleep between attempts is
    ``1.5 * attempt`` seconds. Non-5xx error statuses and error payloads fail
    immediately.

    Raises:
        HTTPException(500): ``MIMO_API_KEY`` is not configured.
        HTTPException(502): upstream error status or error payload, or every
            attempt failed. In the all-failed case the last underlying
            exception is chained as ``__cause__``.
    """
    if not config.MIMO_API_KEY:
        raise HTTPException(500, "MIMO_API_KEY 未配置")
    content = f"<style>{style}</style>{text}" if style else text
    use_voice = voice or config.MIMO_VOICE
    payload = {
        "model": config.MIMO_TTS_MODEL,
        "audio": {"format": "wav", "voice": use_voice},
        "messages": [{"role": "assistant", "content": content}],
    }
    headers = {
        "Content-Type": "application/json",
        "api-key": config.MIMO_API_KEY,
    }
    # One client (and connection pool) shared by all attempts of this call.
    async with httpx.AsyncClient(timeout=120) as client:
        for attempt in range(1, MAX_TTS_RETRIES + 1):
            t0 = time.time()
            try:
                resp = await client.post(config.MIMO_API_ENDPOINT, json=payload, headers=headers)
                elapsed = round(time.time() - t0, 2)
                if resp.status_code != 200:
                    logger.error(f"TTS FAIL http={resp.status_code} {elapsed}s")
                    err = HTTPException(502, f"MiMo TTS API 错误: HTTP {resp.status_code}")
                    if resp.status_code >= 500 and attempt < MAX_TTS_RETRIES:
                        await asyncio.sleep(1.5 * attempt)
                        continue
                    raise err
                data = resp.json()
                if data.get("error"):
                    raise HTTPException(502, f"MiMo TTS 错误: {data['error']}")
                audio_b64 = data["choices"][0]["message"]["audio"]["data"]
                wav_bytes = base64.b64decode(audio_b64)
                logger.info(f"TTS OK {len(wav_bytes)//1024}KB {elapsed}s")
                return wav_bytes
            except HTTPException:
                # Deliberate failures above are final; don't retry them.
                raise
            except Exception as e:
                elapsed = round(time.time() - t0, 2)
                logger.error(f"TTS ERR {e} {elapsed}s")
                failure = HTTPException(502, f"MiMo TTS 异常: {e}")
                if attempt >= MAX_TTS_RETRIES:
                    # Keep the original exception attached for debugging.
                    raise failure from e
                await asyncio.sleep(1.5 * attempt)
    # Only reachable if MAX_TTS_RETRIES is misconfigured to be < 1.
    raise RuntimeError("MAX_TTS_RETRIES must be >= 1")
# ── Core: generate MP3 from text ──────────────────────────────────────────
async def _synthesize_chunk_mp3(chunk: str, style: str, voice: str,
                                tmp_dir: Path, scratch: list[Path]) -> Path:
    """Run TTS for one text chunk, transcode it, and return the MP3 path in *tmp_dir*.

    Every path created here is appended to *scratch* before it is written,
    so the caller can delete it even if a later step fails.
    """
    wav_bytes = await call_mimo_tts(chunk, style, voice)
    uid = uuid.uuid4().hex
    wav_path = tmp_dir / f"{uid}.wav"
    mp3_path = tmp_dir / f"{uid}.mp3"
    scratch.extend((wav_path, mp3_path))
    wav_path.write_bytes(wav_bytes)
    # ffmpeg is a blocking subprocess; run it in the default executor.
    await asyncio.get_running_loop().run_in_executor(
        None, wav_to_mp3, str(wav_path), str(mp3_path))
    wav_path.unlink(missing_ok=True)
    return mp3_path


async def generate_mp3(text: str, style: str = "", voice: str = "") -> bytes:
    """Convert *text* to MP3 bytes, splitting long text and concatenating the parts.

    The text is split with ``split_text``, and each chunk is synthesized in
    order. With more than one chunk, the MP3s are joined with
    ``concat_mp3_files``. Intermediate files live under
    ``<AUDIO_DIR>/_tmp`` and are always deleted, whether generation succeeds,
    fails, or is cancelled.

    Raises:
        Whatever ``call_mimo_tts``, ``wav_to_mp3`` or ``concat_mp3_files`` raise.
    """
    chunks = split_text(text)
    tmp_dir = Path(config.AUDIO_DIR) / "_tmp"
    tmp_dir.mkdir(parents=True, exist_ok=True)
    scratch: list[Path] = []
    try:
        if len(chunks) > 1:
            logger.info(f"SPLIT {len(text)}字 → {len(chunks)}")
        mp3_paths: list[Path] = []
        for chunk in chunks:
            mp3_paths.append(await _synthesize_chunk_mp3(chunk, style, voice, tmp_dir, scratch))
        if len(mp3_paths) == 1:
            return mp3_paths[0].read_bytes()
        merged_path = tmp_dir / f"{uuid.uuid4().hex}.mp3"
        scratch.append(merged_path)
        await asyncio.get_running_loop().run_in_executor(
            None, concat_mp3_files, [str(p) for p in mp3_paths], str(merged_path))
        return merged_path.read_bytes()
    finally:
        for path in scratch:
            path.unlink(missing_ok=True)
# ── App ───────────────────────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup hook: make sure the audio root and its _tmp/_preview subfolders exist."""
    audio_root = Path(config.AUDIO_DIR)
    audio_root.mkdir(parents=True, exist_ok=True)
    for subdir in ("_tmp", "_preview"):
        (audio_root / subdir).mkdir(exist_ok=True)
    yield


app = FastAPI(title="TTS Proxy Service", lifespan=lifespan)
# ── Health check ──────────────────────────────────────────────────────────
@app.get("/health")
async def health():
    """Liveness probe that also reports whether the API key and access token are set."""
    key_configured = bool(config.MIMO_API_KEY)
    token_configured = bool(config.API_TOKEN)
    return {"status": "ok", "api_key": key_configured, "token": token_configured}
# ── Core endpoint: real-time TTS ──────────────────────────────────────────
@app.post("/api/tts")
async def realtime_tts(request: Request):
    """Real-time TTS endpoint; responds with MP3 audio (``audio/mpeg``).

    Accepted bodies:
      * JSON (content-type contains "json"):
        ``{"text": "内容", "style": "开心", "voice": ""}``
      * otherwise form-urlencoded, Baidu-compatible: ``tex=内容``
        (``style``/``voice`` stay empty).

    Returns a JSON error body with HTTP 400 when no text could be extracted,
    or with HTTP 500 (message truncated to 300 chars) when synthesis fails.
    """
    text = style = voice = ""
    content_type = request.headers.get("content-type", "")
    try:
        if "json" in content_type:
            data = await request.json()
            text = (data.get("text") or "").strip()
            style = (data.get("style") or "").strip()
            voice = (data.get("voice") or "").strip()
        else:
            from urllib.parse import parse_qs, unquote
            body = await request.body()
            params = parse_qs(body.decode("utf-8"))
            # parse_qs already percent-decodes once. The two extra unquote()
            # passes presumably accept clients that encode `tex` twice
            # (Baidu-style). NOTE(review): a literal "%xx" in the text is
            # decoded too — confirm this is intended.
            text = unquote(unquote((params.get("tex", [""])[0]).strip()))
    except Exception as e:
        # Best-effort parsing: a malformed body falls through to the 400
        # below, but record why so the client error is diagnosable.
        logger.warning(f"TTS request parse failed: {e!r}")
    if not text:
        return Response(
            content=json.dumps({"status": 40000001, "message": "text 不能为空"}, ensure_ascii=False),
            media_type="application/json", status_code=400,
        )
    try:
        mp3_bytes = await generate_mp3(text, style, voice)
    except Exception as e:
        logger.exception("TTS request failed")
        return Response(
            content=json.dumps({"status": 500, "message": str(e)[:300]}, ensure_ascii=False),
            media_type="application/json", status_code=500,
        )
    return Response(content=mp3_bytes, media_type="audio/mpeg")
# ── Admin endpoints ───────────────────────────────────────────────────────
@app.post("/admin/api/preview")
async def preview(request: Request, _auth=Depends(verify_token)):
    """Synthesize a preview clip and return its URL under ``/audio/_preview/``.

    Body: JSON object ``{"text": ..., "style": ..., "voice": ...}``.

    Raises:
        HTTPException(400): the body is not a JSON object, or ``text`` is empty.
    """
    try:
        data = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError.
        raise HTTPException(400, "请求体必须是 JSON 对象") from None
    if not isinstance(data, dict):
        raise HTTPException(400, "请求体必须是 JSON 对象")
    text = (data.get("text") or "").strip()
    style = (data.get("style") or "").strip()
    voice = (data.get("voice") or "").strip()
    if not text:
        raise HTTPException(400, "文本不能为空")
    mp3_bytes = await generate_mp3(text, style, voice)
    preview_dir = Path(config.AUDIO_DIR) / "_preview"
    # Normally created at startup; recreate it if it was removed since.
    preview_dir.mkdir(parents=True, exist_ok=True)
    filename = f"{uuid.uuid4().hex}.mp3"
    # NOTE(review): preview files are never deleted by this service.
    (preview_dir / filename).write_bytes(mp3_bytes)
    return {"ok": True, "url": f"/audio/_preview/{filename}"}
@app.get("/admin/api/config")
async def config_info(_auth=Depends(verify_token)):
    """Summarize the active configuration; the API key is masked to its first 6 chars."""
    api_key = config.MIMO_API_KEY
    masked_key = f"{api_key[:6]}****" if api_key else "未配置"
    return {
        "endpoint": config.MIMO_API_ENDPOINT,
        "model": config.MIMO_TTS_MODEL,
        "voice": config.MIMO_VOICE,
        "api_key": masked_key,
        "max_chunk": MAX_CHUNK_CHARS,
        "token_set": bool(config.API_TOKEN),
    }
# ── Config-file download ──────────────────────────────────────────────────
@app.get("/httpTts.json")
async def serve_config():
    """Serve ``<BASE_DIR>/httpTts-mimo.json`` as ``/httpTts.json``; 404 when it is absent."""
    cfg_file = os.path.join(config.BASE_DIR, "httpTts-mimo.json")
    if not os.path.exists(cfg_file):
        raise HTTPException(404)
    return FileResponse(cfg_file, media_type="application/json")
# ── Static files & frontend ───────────────────────────────────────────────
# StaticFiles checks that its directory exists when it is constructed, which
# happens at import time, before the lifespan hook runs. Create the directory
# first so a fresh deployment does not crash on import.
os.makedirs(config.AUDIO_DIR, exist_ok=True)
app.mount("/audio", StaticFiles(directory=config.AUDIO_DIR), name="audio")
@app.get("/", response_class=HTMLResponse)
async def frontend():
    """Serve the web UI from ``<BASE_DIR>/static/index.html``."""
    index_file = Path(config.BASE_DIR) / "static" / "index.html"
    page = index_file.read_text(encoding="utf-8")
    return HTMLResponse(page)
# ── Main ──────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn

    # Development entry point: auto-reload requires the "module:attr" import string.
    uvicorn.run(
        "main:app",
        host=config.SERVER_HOST,
        port=config.SERVER_PORT,
        reload=True,
    )