""" Schema 提取器 —— 只提取表结构,不碰数据 """ import sqlite3 from typing import Any def extract_schema(db_path: str) -> dict[str, Any]: """从数据库提取 Schema,只返回结构信息""" conn = sqlite3.connect(db_path) conn.row_factory = sqlite3.Row cur = conn.cursor() cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'") tables = [row["name"] for row in cur.fetchall()] schema = {"tables": []} for table in tables: cur.execute(f"PRAGMA table_info('{table}')") columns = [] for col in cur.fetchall(): columns.append({ "name": col["name"], "type": col["type"], "nullable": col["notnull"] == 0, "is_primary_key": col["pk"] == 1, }) cur.execute(f"PRAGMA foreign_key_list('{table}')") fks = [ {"column": fk["from"], "references_table": fk["table"], "references_column": fk["to"]} for fk in cur.fetchall() ] cur.execute(f"SELECT COUNT(*) AS cnt FROM '{table}'") row_count = cur.fetchone()["cnt"] data_profile = {} for col in columns: col_name = col["name"] col_type = (col["type"] or "").upper() if any(t in col_type for t in ("VARCHAR", "TEXT", "CHAR")): cur.execute(f'SELECT DISTINCT "{col_name}" FROM "{table}" WHERE "{col_name}" IS NOT NULL LIMIT 20') vals = [row[0] for row in cur.fetchall()] if len(vals) <= 20: data_profile[col_name] = {"type": "enum", "distinct_count": len(vals), "values": vals} elif any(t in col_type for t in ("INT", "REAL", "FLOAT", "DOUBLE", "DECIMAL", "NUMERIC")): cur.execute(f''' SELECT MIN("{col_name}") AS min_val, MAX("{col_name}") AS max_val, AVG("{col_name}") AS avg_val, COUNT(DISTINCT "{col_name}") AS distinct_count FROM "{table}" WHERE "{col_name}" IS NOT NULL ''') row = cur.fetchone() if row and row["min_val"] is not None: data_profile[col_name] = { "type": "numeric", "min": round(row["min_val"], 2), "max": round(row["max_val"], 2), "avg": round(row["avg_val"], 2), "distinct_count": row["distinct_count"], } schema["tables"].append({ "name": table, "columns": columns, "foreign_keys": fks, "row_count": row_count, "data_profile": data_profile, }) conn.close() return schema def schema_to_text(schema: dict) -> str: """将 Schema 转为可读文本,供 LLM 理解""" lines = ["=== 数据库 Schema ===\n"] for table in schema["tables"]: lines.append(f"📋 表: {table['name']} (共 {table['row_count']} 行)") lines.append(" 列:") for col in table["columns"]: pk = " [PK]" if col["is_primary_key"] else "" null = " NULL" if col["nullable"] else " NOT NULL" lines.append(f' - {col["name"]}: {col["type"]}{pk}{null}') if table["foreign_keys"]: lines.append(" 外键:") for fk in table["foreign_keys"]: lines.append(f' - {fk["column"]} → {fk["references_table"]}.{fk["references_column"]}') if table["data_profile"]: lines.append(" 数据画像:") for col_name, profile in table["data_profile"].items(): if profile["type"] == "enum": vals = ", ".join(str(v) for v in profile["values"][:10]) lines.append(f' - {col_name}: 枚举值({profile["distinct_count"]}个) = [{vals}]') elif profile["type"] == "numeric": lines.append( f' - {col_name}: 范围[{profile["min"]} ~ {profile["max"]}], ' f'均值{profile["avg"]}, {profile["distinct_count"]}个不同值' ) lines.append("") return "\n".join(lines)