624 lines
26 KiB
Python
624 lines
26 KiB
Python
import base64
|
||
import json
|
||
import datetime
|
||
import time
|
||
import re
|
||
from typing import Dict, Optional, List
|
||
from openai import OpenAI
|
||
from openai import APIError, RateLimitError, AuthenticationError
|
||
|
||
|
||
class AIService:
|
||
def __init__(self, config):
|
||
"""
|
||
初始化 AI 服务
|
||
:param config: 包含 ai 配置的字典 (来自 config.yaml)
|
||
"""
|
||
self.api_key = config.get('api_key')
|
||
self.base_url = config.get('base_url')
|
||
self.model = config.get('model', 'gpt-4o')
|
||
self.max_retries = config.get('max_retries', 3)
|
||
self.retry_delay = config.get('retry_delay', 1.0)
|
||
|
||
# 初始化 OpenAI 客户端 (兼容所有支持 OpenAI 格式的 API)
|
||
self.client = OpenAI(
|
||
api_key=self.api_key,
|
||
base_url=self.base_url
|
||
)
|
||
|
||
# 支持的图片格式
|
||
self.supported_formats = {'.png', '.jpg', '.jpeg', '.bmp', '.gif'}
|
||
|
||
# Prompt模板缓存
|
||
self.prompt_templates = {}
|
||
|
||
def _encode_image(self, image_path):
|
||
"""将本地图片转换为 Base64 编码"""
|
||
try:
|
||
with open(image_path, "rb") as image_file:
|
||
return base64.b64encode(image_file.read()).decode('utf-8')
|
||
except Exception as e:
|
||
raise Exception(f"图片编码失败: {str(e)}")
|
||
|
||
def _encode_image_from_bytes(self, image_bytes: bytes) -> str:
|
||
"""将图片字节数据转换为 Base64 编码"""
|
||
try:
|
||
return base64.b64encode(image_bytes).decode('utf-8')
|
||
except Exception as e:
|
||
raise Exception(f"图片字节数据编码失败: {str(e)}")
|
||
|
||
def _validate_image_file(self, image_path):
|
||
"""验证图片文件"""
|
||
from pathlib import Path
|
||
|
||
file_path = Path(image_path)
|
||
|
||
# 检查文件是否存在
|
||
if not file_path.exists():
|
||
raise FileNotFoundError(f"图片文件不存在: {image_path}")
|
||
|
||
# 检查文件大小(限制为10MB)
|
||
file_size = file_path.stat().st_size
|
||
if file_size > 10 * 1024 * 1024: # 10MB
|
||
raise ValueError(f"图片文件过大: {file_size / 1024 / 1024:.2f}MB (最大10MB)")
|
||
|
||
# 检查文件格式
|
||
if file_path.suffix.lower() not in self.supported_formats:
|
||
raise ValueError(f"不支持的图片格式: {file_path.suffix}")
|
||
|
||
return True
|
||
|
||
def _build_prompt(self, current_date_str: str) -> str:
|
||
"""构建AI提示词"""
|
||
prompt = f"""
|
||
你是一个专业的项目管理助手。当前系统日期是:{current_date_str}。
|
||
你的任务是从用户上传的图片(可能是聊天记录、邮件、文档截图)中提取任务信息。
|
||
|
||
请严格按照以下 JSON 格式返回结果:
|
||
{{
|
||
"task_description": "任务的具体描述,简练概括",
|
||
"priority": "必须从以下选项中选择一个: ['紧急', '较紧急', '一般', '普通']",
|
||
"status": "必须从以下选项中选择一个: ['已停滞','待开始', '进行中', '已完成']",
|
||
"latest_progress": "图片中提到的最新进展,如果没有则留空字符串",
|
||
"initiator": "任务发起人姓名。请仔细识别图片中的发件人/发送者名字。如果是邮件截图,请识别发件人;如果是聊天记录,请识别发送者。如果没有明确的发起人,则留空字符串。",
|
||
"department": "发起人部门,如果没有则留空字符串",
|
||
"start_date": "YYYY-MM-DD 格式的日期字符串。如果提到'今天'就是当前日期,'下周一'请根据当前日期计算。如果没提到则返回 null",
|
||
"due_date": "YYYY-MM-DD 格式的截止日期字符串。逻辑同上,如果没提到则返回 null"
|
||
}}
|
||
|
||
注意:
|
||
1. 如果图片中包含多个任务,请只提取最核心的一个。
|
||
2. 请特别关注图片中的发件人/发送者信息,准确提取姓名。
|
||
3. 如果识别到的名字可能存在重名,请在任务描述中添加提示信息。
|
||
4. 不要返回 Markdown 代码块标记(如 ```json),直接返回纯 JSON 字符串。
|
||
5. 确保返回的 JSON 格式正确,可以被 Python 的 json.loads() 解析。
|
||
"""
|
||
return prompt
|
||
|
||
def _parse_ai_response(self, content: Optional[str]) -> Optional[Dict]:
|
||
"""解析AI响应"""
|
||
if not content:
|
||
raise ValueError("AI响应内容为空")
|
||
|
||
try:
|
||
# 清理可能的 markdown 标记
|
||
content = content.replace("```json", "").replace("```", "").strip()
|
||
|
||
# 尝试修复常见的JSON格式问题
|
||
# 1. 处理未闭合的引号
|
||
content = self._fix_json_string(content)
|
||
|
||
# 2. 处理转义字符问题
|
||
content = self._fix_json_escapes(content)
|
||
|
||
# 解析JSON
|
||
result_dict = json.loads(content)
|
||
|
||
# 验证必需的字段
|
||
required_fields = ['task_description', 'priority', 'status']
|
||
for field in required_fields:
|
||
if field not in result_dict:
|
||
raise ValueError(f"AI响应缺少必需字段: {field}")
|
||
|
||
# 验证选项值
|
||
valid_priorities = ['紧急', '较紧急', '一般', '普通']
|
||
valid_statuses = ['已停滞','待开始', '进行中', '已完成']
|
||
|
||
if result_dict.get('priority') not in valid_priorities:
|
||
raise ValueError(f"无效的优先级: {result_dict.get('priority')}")
|
||
|
||
if result_dict.get('status') not in valid_statuses:
|
||
raise ValueError(f"无效的状态: {result_dict.get('status')}")
|
||
|
||
return result_dict
|
||
|
||
except json.JSONDecodeError as e:
|
||
# 尝试更详细的错误修复
|
||
try:
|
||
# 如果标准解析失败,尝试使用更宽松的解析
|
||
content = self._aggressive_json_fix(content)
|
||
result_dict = json.loads(content)
|
||
|
||
# 重新验证字段
|
||
required_fields = ['task_description', 'priority', 'status']
|
||
for field in required_fields:
|
||
if field not in result_dict:
|
||
raise ValueError(f"AI响应缺少必需字段: {field}")
|
||
|
||
# 重新验证选项值
|
||
valid_priorities = ['紧急', '较紧急', '一般', '普通']
|
||
valid_statuses = ['已停滞','待开始', '进行中', '已完成']
|
||
|
||
if result_dict.get('priority') not in valid_priorities:
|
||
raise ValueError(f"无效的优先级: {result_dict.get('priority')}")
|
||
|
||
if result_dict.get('status') not in valid_statuses:
|
||
raise ValueError(f"无效的状态: {result_dict.get('status')}")
|
||
|
||
return result_dict
|
||
except Exception as retry_error:
|
||
raise ValueError(f"AI响应不是有效的JSON: {str(e)} (修复后错误: {str(retry_error)})")
|
||
except Exception as e:
|
||
raise ValueError(f"解析AI响应失败: {str(e)}")
|
||
|
||
def _fix_json_string(self, content: str) -> str:
|
||
"""修复JSON字符串中的未闭合引号问题"""
|
||
import re
|
||
|
||
# 查找可能未闭合的字符串
|
||
# 匹配模式:从引号开始,但没有对应的闭合引号
|
||
lines = content.split('\n')
|
||
fixed_lines = []
|
||
|
||
for line in lines:
|
||
# 检查行中是否有未闭合的引号
|
||
in_string = False
|
||
escaped = False
|
||
fixed_line = []
|
||
|
||
for char in line:
|
||
if escaped:
|
||
fixed_line.append(char)
|
||
escaped = False
|
||
continue
|
||
|
||
if char == '\\':
|
||
escaped = True
|
||
fixed_line.append(char)
|
||
continue
|
||
|
||
if char == '"':
|
||
if in_string:
|
||
in_string = False
|
||
else:
|
||
in_string = True
|
||
fixed_line.append(char)
|
||
else:
|
||
fixed_line.append(char)
|
||
|
||
# 如果行结束时仍在字符串中,添加闭合引号
|
||
if in_string:
|
||
fixed_line.append('"')
|
||
|
||
fixed_lines.append(''.join(fixed_line))
|
||
|
||
return '\n'.join(fixed_lines)
|
||
|
||
def _fix_json_escapes(self, content: str) -> str:
|
||
"""修复JSON转义字符问题"""
|
||
# 注意:我们不应该转义JSON结构中的引号,只转义字符串内容中的引号
|
||
# 这个函数暂时不处理换行符,因为JSON中的换行符是有效的
|
||
# 更复杂的转义修复应该在JSON解析后进行
|
||
|
||
return content
|
||
|
||
def _aggressive_json_fix(self, content: str) -> str:
|
||
"""更激进的JSON修复策略"""
|
||
# 1. 移除可能的非JSON内容
|
||
content = re.sub(r'^[^{]*', '', content) # 移除JSON前的非JSON内容
|
||
content = re.sub(r'[^}]*$', '', content) # 移除JSON后的非JSON内容
|
||
|
||
# 2. 确保JSON对象闭合
|
||
if not content.strip().endswith('}'):
|
||
content = content.strip() + '}'
|
||
|
||
# 3. 确保JSON对象开始
|
||
if not content.strip().startswith('{'):
|
||
content = '{' + content.strip()
|
||
|
||
# 4. 处理常见的AI响应格式问题
|
||
# 移除可能的Markdown代码块标记
|
||
content = content.replace('```json', '').replace('```', '')
|
||
|
||
# 5. 处理可能的多余空格和换行
|
||
content = ' '.join(content.split())
|
||
|
||
return content
|
||
|
||
def _call_ai_with_retry(self, image_path: str, prompt: str) -> Optional[str]:
|
||
"""调用AI API,带重试机制"""
|
||
base64_image = self._encode_image(image_path)
|
||
|
||
for attempt in range(self.max_retries):
|
||
try:
|
||
# 尝试使用response_format参数(适用于OpenAI格式的API)
|
||
try:
|
||
response = self.client.chat.completions.create(
|
||
model=self.model,
|
||
messages=[
|
||
{
|
||
"role": "user",
|
||
"content": [
|
||
{"type": "text", "text": prompt},
|
||
{
|
||
"type": "image_url",
|
||
"image_url": {
|
||
"url": f"data:image/jpeg;base64,{base64_image}"
|
||
}
|
||
}
|
||
]
|
||
}
|
||
],
|
||
response_format={"type": "json_object"},
|
||
max_tokens=2000
|
||
)
|
||
except Exception as e:
|
||
# 如果response_format参数不支持,尝试不使用该参数
|
||
print(f"⚠️ response_format参数不支持,尝试不使用该参数: {str(e)}")
|
||
response = self.client.chat.completions.create(
|
||
model=self.model,
|
||
messages=[
|
||
{
|
||
"role": "user",
|
||
"content": [
|
||
{"type": "text", "text": prompt},
|
||
{
|
||
"type": "image_url",
|
||
"image_url": {
|
||
"url": f"data:image/jpeg;base64,{base64_image}"
|
||
}
|
||
}
|
||
]
|
||
}
|
||
],
|
||
max_tokens=2000
|
||
)
|
||
|
||
if not response.choices:
|
||
return None
|
||
|
||
content = response.choices[0].message.content
|
||
return content if content else None
|
||
|
||
except RateLimitError as e:
|
||
if attempt < self.max_retries - 1:
|
||
wait_time = self.retry_delay * (2 ** attempt) # 指数退避
|
||
time.sleep(wait_time)
|
||
continue
|
||
raise Exception(f"API速率限制: {str(e)}")
|
||
|
||
except AuthenticationError as e:
|
||
raise Exception(f"API认证失败: {str(e)}")
|
||
|
||
except APIError as e:
|
||
if attempt < self.max_retries - 1:
|
||
wait_time = self.retry_delay * (2 ** attempt)
|
||
time.sleep(wait_time)
|
||
continue
|
||
raise Exception(f"API调用失败: {str(e)}")
|
||
|
||
except Exception as e:
|
||
if attempt < self.max_retries - 1:
|
||
wait_time = self.retry_delay * (2 ** attempt)
|
||
time.sleep(wait_time)
|
||
continue
|
||
raise Exception(f"未知错误: {str(e)}")
|
||
|
||
raise Exception(f"AI调用失败,已重试 {self.max_retries} 次")
|
||
|
||
def _call_ai_with_retry_from_bytes(self, image_bytes: bytes, prompt: str, image_name: str = "memory_image") -> Optional[str]:
|
||
"""调用AI API,带重试机制(从内存字节数据)"""
|
||
base64_image = self._encode_image_from_bytes(image_bytes)
|
||
|
||
for attempt in range(self.max_retries):
|
||
try:
|
||
# 尝试使用response_format参数(适用于OpenAI格式的API)
|
||
try:
|
||
response = self.client.chat.completions.create(
|
||
model=self.model,
|
||
messages=[
|
||
{
|
||
"role": "user",
|
||
"content": [
|
||
{"type": "text", "text": prompt},
|
||
{
|
||
"type": "image_url",
|
||
"image_url": {
|
||
"url": f"data:image/jpeg;base64,{base64_image}"
|
||
}
|
||
}
|
||
]
|
||
}
|
||
],
|
||
response_format={"type": "json_object"},
|
||
max_tokens=2000
|
||
)
|
||
except Exception as e:
|
||
# 如果response_format参数不支持,尝试不使用该参数
|
||
print(f"⚠️ response_format参数不支持,尝试不使用该参数: {str(e)}")
|
||
response = self.client.chat.completions.create(
|
||
model=self.model,
|
||
messages=[
|
||
{
|
||
"role": "user",
|
||
"content": [
|
||
{"type": "text", "text": prompt},
|
||
{
|
||
"type": "image_url",
|
||
"image_url": {
|
||
"url": f"data:image/jpeg;base64,{base64_image}"
|
||
}
|
||
}
|
||
]
|
||
}
|
||
],
|
||
max_tokens=2000
|
||
)
|
||
|
||
if not response.choices:
|
||
return None
|
||
|
||
content = response.choices[0].message.content
|
||
return content if content else None
|
||
|
||
except RateLimitError as e:
|
||
if attempt < self.max_retries - 1:
|
||
wait_time = self.retry_delay * (2 ** attempt) # 指数退避
|
||
time.sleep(wait_time)
|
||
continue
|
||
raise Exception(f"API速率限制: {str(e)}")
|
||
|
||
except AuthenticationError as e:
|
||
raise Exception(f"API认证失败: {str(e)}")
|
||
|
||
except APIError as e:
|
||
if attempt < self.max_retries - 1:
|
||
wait_time = self.retry_delay * (2 ** attempt)
|
||
time.sleep(wait_time)
|
||
continue
|
||
raise Exception(f"API调用失败: {str(e)}")
|
||
|
||
except Exception as e:
|
||
if attempt < self.max_retries - 1:
|
||
wait_time = self.retry_delay * (2 ** attempt)
|
||
time.sleep(wait_time)
|
||
continue
|
||
raise Exception(f"未知错误: {str(e)}")
|
||
|
||
raise Exception(f"AI调用失败,已重试 {self.max_retries} 次")
|
||
|
||
for attempt in range(self.max_retries):
|
||
try:
|
||
# 尝试使用response_format参数(适用于OpenAI格式的API)
|
||
try:
|
||
response = self.client.chat.completions.create(
|
||
model=self.model,
|
||
messages=[
|
||
{
|
||
"role": "user",
|
||
"content": [
|
||
{"type": "text", "text": prompt},
|
||
{
|
||
"type": "image_url",
|
||
"image_url": {
|
||
"url": f"data:image/jpeg;base64,{base64_image}"
|
||
}
|
||
}
|
||
]
|
||
}
|
||
],
|
||
response_format={"type": "json_object"},
|
||
max_tokens=2000
|
||
)
|
||
except Exception as e:
|
||
# 如果response_format参数不支持,尝试不使用该参数
|
||
print(f"⚠️ response_format参数不支持,尝试不使用该参数: {str(e)}")
|
||
response = self.client.chat.completions.create(
|
||
model=self.model,
|
||
messages=[
|
||
{
|
||
"role": "user",
|
||
"content": [
|
||
{"type": "text", "text": prompt},
|
||
{
|
||
"type": "image_url",
|
||
"image_url": {
|
||
"url": f"data:image/jpeg;base64,{base64_image}"
|
||
}
|
||
}
|
||
]
|
||
}
|
||
],
|
||
max_tokens=2000
|
||
)
|
||
|
||
if not response.choices:
|
||
return None
|
||
|
||
content = response.choices[0].message.content
|
||
return content if content else None
|
||
|
||
except RateLimitError as e:
|
||
if attempt < self.max_retries - 1:
|
||
wait_time = self.retry_delay * (2 ** attempt) # 指数退避
|
||
time.sleep(wait_time)
|
||
continue
|
||
raise Exception(f"API速率限制: {str(e)}")
|
||
|
||
except AuthenticationError as e:
|
||
raise Exception(f"API认证失败: {str(e)}")
|
||
|
||
except APIError as e:
|
||
if attempt < self.max_retries - 1:
|
||
wait_time = self.retry_delay * (2 ** attempt)
|
||
time.sleep(wait_time)
|
||
continue
|
||
raise Exception(f"API调用失败: {str(e)}")
|
||
|
||
except Exception as e:
|
||
if attempt < self.max_retries - 1:
|
||
wait_time = self.retry_delay * (2 ** attempt)
|
||
time.sleep(wait_time)
|
||
continue
|
||
raise Exception(f"未知错误: {str(e)}")
|
||
|
||
raise Exception(f"AI调用失败,已重试 {self.max_retries} 次")
|
||
|
||
def analyze_image(self, image_path):
|
||
"""
|
||
核心方法:发送图片到 AI 并获取结构化数据
|
||
:param image_path: 图片文件的路径
|
||
:return: 解析后的字典 (Dict)
|
||
"""
|
||
from pathlib import Path
|
||
|
||
file_path = Path(image_path)
|
||
# 使用sys.stdout.write替代print,避免编码问题
|
||
import sys
|
||
sys.stdout.write(f" [AI] 正在分析图片: {file_path.name} ...\n")
|
||
sys.stdout.flush()
|
||
|
||
try:
|
||
# 1. 验证图片文件
|
||
self._validate_image_file(image_path)
|
||
|
||
# 2. 获取当前日期 (用于辅助 AI 推断相对时间)
|
||
now = datetime.datetime.now()
|
||
current_date_str = now.strftime("%Y-%m-%d %A") # 例如: 2023-10-27 Sunday
|
||
|
||
# 3. 构建 Prompt
|
||
prompt = self._build_prompt(current_date_str)
|
||
|
||
# 4. 调用AI API(带重试机制)
|
||
content = self._call_ai_with_retry(image_path, prompt)
|
||
|
||
# 5. 解析结果
|
||
if not content:
|
||
import sys
|
||
sys.stdout.write(f" [AI] AI返回空内容\n")
|
||
sys.stdout.flush()
|
||
return None
|
||
|
||
result_dict = self._parse_ai_response(content)
|
||
|
||
# 记录成功日志
|
||
if result_dict:
|
||
task_desc = result_dict.get('task_description', '')
|
||
if task_desc and len(task_desc) > 30:
|
||
task_desc = task_desc[:30] + "..."
|
||
import sys
|
||
sys.stdout.write(f" [AI] 识别成功: {task_desc}\n")
|
||
sys.stdout.flush()
|
||
|
||
return result_dict
|
||
|
||
except Exception as e:
|
||
import sys
|
||
sys.stdout.write(f" [AI] 分析失败: {str(e)}\n")
|
||
sys.stdout.flush()
|
||
return None
|
||
|
||
def analyze_image_from_bytes(self, image_bytes: bytes, image_name: str = "memory_image"):
|
||
"""
|
||
核心方法:从内存中的图片字节数据发送到 AI 并获取结构化数据
|
||
:param image_bytes: 图片的字节数据
|
||
:param image_name: 图片名称(用于日志)
|
||
:return: 解析后的字典 (Dict)
|
||
"""
|
||
# 使用sys.stdout.write替代print,避免编码问题
|
||
import sys
|
||
sys.stdout.write(f" [AI] 正在分析内存图片: {image_name} ...\n")
|
||
sys.stdout.flush()
|
||
|
||
try:
|
||
# 1. 验证图片数据大小
|
||
if len(image_bytes) > 10 * 1024 * 1024: # 10MB
|
||
raise ValueError(f"图片数据过大: {len(image_bytes) / 1024 / 1024:.2f}MB (最大10MB)")
|
||
|
||
# 2. 获取当前日期 (用于辅助 AI 推断相对时间)
|
||
now = datetime.datetime.now()
|
||
current_date_str = now.strftime("%Y-%m-%d %A") # 例如: 2023-10-27 Sunday
|
||
|
||
# 3. 构建 Prompt
|
||
prompt = self._build_prompt(current_date_str)
|
||
|
||
# 4. 调用AI API(带重试机制)
|
||
content = self._call_ai_with_retry_from_bytes(image_bytes, prompt, image_name)
|
||
|
||
# 5. 解析结果
|
||
if not content:
|
||
import sys
|
||
sys.stdout.write(f" [AI] AI返回空内容\n")
|
||
sys.stdout.flush()
|
||
return None
|
||
|
||
result_dict = self._parse_ai_response(content)
|
||
|
||
# 记录成功日志
|
||
if result_dict:
|
||
task_desc = result_dict.get('task_description', '')
|
||
if task_desc and len(task_desc) > 30:
|
||
task_desc = task_desc[:30] + "..."
|
||
import sys
|
||
sys.stdout.write(f" [AI] 识别成功: {task_desc}\n")
|
||
sys.stdout.flush()
|
||
|
||
return result_dict
|
||
|
||
except Exception as e:
|
||
import sys
|
||
sys.stdout.write(f" [AI] 分析失败: {str(e)}\n")
|
||
sys.stdout.flush()
|
||
return None
|
||
|
||
def analyze_image_batch(self, image_paths: List[str]) -> Dict[str, Optional[Dict]]:
|
||
"""
|
||
批量分析图片
|
||
:param image_paths: 图片文件路径列表
|
||
:return: 字典,键为图片路径,值为分析结果
|
||
"""
|
||
results = {}
|
||
|
||
for image_path in image_paths:
|
||
try:
|
||
result = self.analyze_image(image_path)
|
||
results[image_path] = result
|
||
except Exception as e:
|
||
import sys
|
||
sys.stdout.write(f" [AI] 批量处理失败 {image_path}: {str(e)}\n")
|
||
sys.stdout.flush()
|
||
results[image_path] = None
|
||
|
||
return results
|
||
|
||
# ================= 单元测试 =================
|
||
if __name__ == "__main__":
|
||
# 在这里填入你的配置进行测试
|
||
test_config = {
|
||
"api_key": "sk-xxxxxxxxxxxxxxxxxxxxxxxx",
|
||
"base_url": "https://api.openai.com/v1",
|
||
"model": "gpt-4o"
|
||
}
|
||
|
||
# 确保目录下有一张名为 test_img.jpg 的图片
|
||
import os
|
||
if os.path.exists("test_img.jpg"):
|
||
ai = AIService(test_config)
|
||
res = ai.analyze_image("test_img.jpg")
|
||
import sys
|
||
sys.stdout.write(json.dumps(res, indent=2, ensure_ascii=False) + "\n")
|
||
sys.stdout.flush()
|
||
else:
|
||
import sys
|
||
sys.stdout.write("请在同级目录下放一张 test_img.jpg 用于测试\n")
|
||
sys.stdout.flush() |