Initial commit: 个性化饮食推荐助手 - 包含OCR识别、AI分析、现代化界面等功能
This commit is contained in:
778
modules/ai_analysis.py
Normal file
778
modules/ai_analysis.py
Normal file
@@ -0,0 +1,778 @@
|
||||
"""
|
||||
AI分析模块 - 基于基座架构
|
||||
集成大模型进行用户需求分析和营养建议
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Any
|
||||
import json
|
||||
import requests
|
||||
from datetime import datetime, date
|
||||
from core.base import BaseModule, ModuleType, UserData, AnalysisResult, BaseConfig
|
||||
import os
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
except Exception:
|
||||
# 如果没有.env文件或加载失败,使用默认配置
|
||||
pass
|
||||
|
||||
# 导入千问客户端
|
||||
import sys
|
||||
from pathlib import Path
|
||||
sys.path.append(str(Path(__file__).parent.parent))
|
||||
from llm_integration.qwen_client import get_qwen_client, analyze_user_intent_with_qwen, analyze_nutrition_with_qwen
|
||||
|
||||
|
||||
class AIAnalysisModule(BaseModule):
|
||||
"""AI分析模块"""
|
||||
|
||||
def __init__(self, config: BaseConfig) -> None:
    """Create the AI analysis module.

    Registers the module with the base class as ``ModuleType.USER_ANALYSIS``.
    The Qwen client starts out as ``None``; ``initialize`` assigns it.
    """
    super().__init__(config, ModuleType.USER_ANALYSIS)
    self.qwen_client = None
    # Static table describing each analysis kind and its required fields.
    self.analysis_templates = self._load_analysis_templates()
|
||||
|
||||
def initialize(self) -> bool:
    """Obtain the Qwen client and flag the module as initialized.

    Returns:
        True when every step succeeded; False when any step raised, in
        which case the error is logged.
    """
    try:
        self.logger.info("AI分析模块初始化中...")

        # Obtain the Qwen client through the project helper.
        self.qwen_client = get_qwen_client()
        self.logger.info("千问客户端初始化成功")

        self.is_initialized = True
        self.logger.info("AI分析模块初始化完成")
    except Exception as e:
        self.logger.error(f"AI分析模块初始化失败: {e}")
        return False
    else:
        return True
|
||||
|
||||
def process(self, input_data: Any, user_data: UserData) -> AnalysisResult:
    """Route an AI analysis request to the matching handler.

    Args:
        input_data: Request mapping. Its ``'type'`` entry selects the handler
            and defaults to ``'user_intent'`` when missing.
        user_data: Data of the user the request is made for.

    Returns:
        An ``AnalysisResult`` wrapping the handler's result dict. An unknown
        type or an unexpected exception produces an error result dict that is
        wrapped the same way.
    """
    try:
        analysis_type = input_data.get('type', 'user_intent')

        if analysis_type == 'user_intent':
            result = self._analyze_user_intent(input_data, user_data)
        elif analysis_type == 'nutrition_analysis':
            result = self._analyze_nutrition(input_data, user_data)
        elif analysis_type == 'calorie_estimation':
            result = self._estimate_calories(input_data, user_data)
        elif analysis_type == 'physiological_state':
            result = self._analyze_physiological_state(input_data, user_data)
        elif analysis_type == 'meal_suggestion':
            result = self._generate_meal_suggestion(input_data, user_data)
        else:
            result = self._create_error_result("未知的分析类型")

        confidence = result.get('confidence', 0.5)
        user_id = user_data.user_id
    except Exception as e:
        self.logger.error(f"处理AI分析请求失败: {e}")
        # The declared return type is AnalysisResult, so the error dict is
        # wrapped below instead of being returned bare as it was before.
        result = self._create_error_result(str(e))
        confidence = result.get('confidence', 0.0)
        # user_data itself may be what failed; do not raise a second time.
        user_id = getattr(user_data, 'user_id', None)

    return AnalysisResult(
        module_type=self.module_type,
        user_id=user_id,
        input_data=input_data,
        result=result,
        confidence=confidence
    )
|
||||
|
||||
def _analyze_user_intent(self, input_data: Dict, user_data: UserData) -> Dict[str, Any]:
|
||||
"""分析用户意图"""
|
||||
user_input = input_data.get('user_input', '')
|
||||
|
||||
if not self.qwen_client:
|
||||
return self._get_fallback_intent_analysis(user_input, user_data)
|
||||
|
||||
try:
|
||||
# 构建用户上下文
|
||||
user_context = {
|
||||
'name': user_data.profile.get('name', '未知'),
|
||||
'age': user_data.profile.get('age', '未知'),
|
||||
'gender': user_data.profile.get('gender', '未知'),
|
||||
'height': user_data.profile.get('height', '未知'),
|
||||
'weight': user_data.profile.get('weight', '未知'),
|
||||
'activity_level': user_data.profile.get('activity_level', '未知'),
|
||||
'taste_preferences': user_data.profile.get('taste_preferences', {}),
|
||||
'allergies': user_data.profile.get('allergies', []),
|
||||
'dislikes': user_data.profile.get('dislikes', []),
|
||||
'dietary_preferences': user_data.profile.get('dietary_preferences', []),
|
||||
'recent_meals': user_data.meals[-3:] if user_data.meals else [],
|
||||
'feedback_history': user_data.feedback[-5:] if user_data.feedback else []
|
||||
}
|
||||
|
||||
# 使用千问分析用户意图
|
||||
result = analyze_user_intent_with_qwen(user_input, user_context)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"用户意图分析失败: {e}")
|
||||
return self._get_fallback_intent_analysis(user_input, user_data)
|
||||
|
||||
def _analyze_nutrition(self, input_data: Dict, user_data: UserData) -> Dict[str, Any]:
|
||||
"""分析营养状况"""
|
||||
meal_data = input_data.get('meal_data', {})
|
||||
|
||||
if not self.qwen_client:
|
||||
return self._get_fallback_nutrition_analysis(meal_data, user_data)
|
||||
|
||||
try:
|
||||
# 构建用户上下文
|
||||
user_context = {
|
||||
'age': user_data.profile.get('age', '未知'),
|
||||
'gender': user_data.profile.get('gender', '未知'),
|
||||
'height': user_data.profile.get('height', '未知'),
|
||||
'weight': user_data.profile.get('weight', '未知'),
|
||||
'activity_level': user_data.profile.get('activity_level', '未知'),
|
||||
'health_goals': user_data.profile.get('health_goals', [])
|
||||
}
|
||||
|
||||
# 使用千问分析营养状况
|
||||
result = analyze_nutrition_with_qwen(meal_data, user_context)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"营养分析失败: {e}")
|
||||
return self._get_fallback_nutrition_analysis(meal_data, user_data)
|
||||
|
||||
def _estimate_calories(self, input_data: Dict, user_data: UserData) -> Dict[str, Any]:
|
||||
"""估算食物热量"""
|
||||
food_data = input_data.get('food_data', {})
|
||||
food_name = food_data.get('food_name', '')
|
||||
quantity = food_data.get('quantity', '')
|
||||
|
||||
if not food_name or not quantity:
|
||||
return self._create_error_result("缺少食物名称或分量信息")
|
||||
|
||||
try:
|
||||
# 基础热量数据库
|
||||
calorie_db = {
|
||||
"米饭": 130, "面条": 110, "包子": 200, "饺子": 250, "馒头": 220, "面包": 250,
|
||||
"鸡蛋": 150, "牛奶": 60, "豆浆": 30, "酸奶": 80, "苹果": 50, "香蕉": 90,
|
||||
"鸡肉": 165, "牛肉": 250, "猪肉": 300, "鱼肉": 120, "豆腐": 80, "青菜": 20,
|
||||
"西红柿": 20, "黄瓜": 15, "胡萝卜": 40, "土豆": 80, "红薯": 100, "玉米": 90
|
||||
}
|
||||
|
||||
# 获取基础热量
|
||||
base_calories = calorie_db.get(food_name, 100) # 默认100卡路里
|
||||
|
||||
# 简单的分量解析
|
||||
quantity_lower = quantity.lower()
|
||||
if '碗' in quantity_lower:
|
||||
multiplier = 1.0
|
||||
elif 'g' in quantity_lower or '克' in quantity_lower:
|
||||
# 假设一碗米饭约200g
|
||||
multiplier = 0.5
|
||||
elif '个' in quantity_lower:
|
||||
multiplier = 1.0
|
||||
else:
|
||||
multiplier = 1.0
|
||||
|
||||
# 计算总热量
|
||||
total_calories = base_calories * multiplier
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'calories': total_calories,
|
||||
'food_name': food_name,
|
||||
'quantity': quantity,
|
||||
'base_calories': base_calories,
|
||||
'multiplier': multiplier,
|
||||
'confidence': 0.8
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"热量估算失败: {e}")
|
||||
return self._create_error_result(f"热量估算失败: {str(e)}")
|
||||
|
||||
def _analyze_physiological_state(self, input_data: Dict, user_data: UserData) -> Dict[str, Any]:
|
||||
"""分析生理状态"""
|
||||
current_date = input_data.get('current_date', datetime.now().strftime('%Y-%m-%d'))
|
||||
|
||||
if not user_data.profile.get('is_female', False):
|
||||
return {
|
||||
'success': True,
|
||||
'physiological_state': 'normal',
|
||||
'needs': [],
|
||||
'recommendations': [],
|
||||
'confidence': 0.8
|
||||
}
|
||||
|
||||
try:
|
||||
cycle_info = self._calculate_menstrual_cycle(user_data.profile, current_date)
|
||||
|
||||
if not self.qwen_client:
|
||||
return self._get_fallback_physiological_analysis(cycle_info)
|
||||
|
||||
prompt = self._build_physiological_analysis_prompt(user_data.profile, cycle_info)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": self._get_physiological_analysis_system_prompt()},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
response = self.qwen_client.chat_completion(messages, temperature=0.2, max_tokens=800)
|
||||
|
||||
if response and 'choices' in response:
|
||||
analysis_text = response['choices'][0]['message']['content']
|
||||
else:
|
||||
return self._get_fallback_physiological_analysis(cycle_info)
|
||||
return self._parse_physiological_analysis(analysis_text, cycle_info)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"生理状态分析失败: {e}")
|
||||
return self._get_fallback_physiological_analysis({})
|
||||
|
||||
def _generate_meal_suggestion(self, input_data: Dict, user_data: UserData) -> Dict[str, Any]:
|
||||
"""生成餐食建议"""
|
||||
meal_type = input_data.get('meal_type', 'lunch')
|
||||
preferences = input_data.get('preferences', {})
|
||||
|
||||
if not self.qwen_client:
|
||||
return self._get_fallback_meal_suggestion(meal_type, user_data)
|
||||
|
||||
try:
|
||||
prompt = self._build_meal_suggestion_prompt(meal_type, preferences, user_data)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": self._get_meal_suggestion_system_prompt()},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
response = self.qwen_client.chat_completion(messages, temperature=0.4, max_tokens=1000)
|
||||
|
||||
if response and 'choices' in response:
|
||||
suggestion_text = response['choices'][0]['message']['content']
|
||||
else:
|
||||
return self._get_fallback_meal_suggestion(meal_type, user_data)
|
||||
return self._parse_meal_suggestion(suggestion_text)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"餐食建议生成失败: {e}")
|
||||
return self._get_fallback_meal_suggestion(meal_type, user_data)
|
||||
|
||||
def _build_intent_analysis_prompt(self, user_input: str, user_data: UserData) -> str:
|
||||
"""构建意图分析提示词"""
|
||||
return f"""
|
||||
请分析以下用户输入的真实意图和需求:
|
||||
|
||||
用户输入: "{user_input}"
|
||||
|
||||
用户背景:
|
||||
- 姓名: {user_data.profile.get('name', '未知')}
|
||||
- 年龄: {user_data.profile.get('age', '未知')}
|
||||
- 性别: {user_data.profile.get('gender', '未知')}
|
||||
- 身高体重: {user_data.profile.get('height', '未知')}cm, {user_data.profile.get('weight', '未知')}kg
|
||||
|
||||
口味偏好: {json.dumps(user_data.profile.get('taste_preferences', {}), ensure_ascii=False)}
|
||||
饮食限制: {', '.join(user_data.profile.get('allergies', []) + user_data.profile.get('dislikes', []))}
|
||||
|
||||
最近饮食记录: {self._format_recent_meals(user_data.meals[-3:])}
|
||||
|
||||
请分析:
|
||||
1. 用户的真实意图(饿了、馋了、需要特定营养等)
|
||||
2. 情绪状态(压力、开心、疲惫等)
|
||||
3. 营养需求
|
||||
4. 推荐的食物类型
|
||||
5. 推荐理由
|
||||
|
||||
请以JSON格式返回分析结果。
|
||||
"""
|
||||
|
||||
def _build_nutrition_analysis_prompt(self, meal_data: Dict, user_data: UserData) -> str:
|
||||
"""构建营养分析提示词"""
|
||||
return f"""
|
||||
请分析以下餐食的营养状况:
|
||||
|
||||
餐食信息:
|
||||
- 食物: {', '.join(meal_data.get('foods', []))}
|
||||
- 分量: {', '.join(meal_data.get('quantities', []))}
|
||||
- 热量: {meal_data.get('calories', '未知')}卡路里
|
||||
|
||||
用户信息:
|
||||
- 年龄: {user_data.profile.get('age', '未知')}
|
||||
- 性别: {user_data.profile.get('gender', '未知')}
|
||||
- 身高体重: {user_data.profile.get('height', '未知')}cm, {user_data.profile.get('weight', '未知')}kg
|
||||
- 活动水平: {user_data.profile.get('activity_level', '未知')}
|
||||
- 健康目标: {', '.join(user_data.profile.get('health_goals', []))}
|
||||
|
||||
请分析:
|
||||
1. 营养均衡性
|
||||
2. 热量是否合适
|
||||
3. 缺少的营养素
|
||||
4. 建议改进的地方
|
||||
5. 个性化建议
|
||||
|
||||
请以JSON格式返回分析结果。
|
||||
"""
|
||||
|
||||
def _build_physiological_analysis_prompt(self, profile: Dict, cycle_info: Dict) -> str:
|
||||
"""构建生理状态分析提示词"""
|
||||
return f"""
|
||||
作为专业的女性健康专家,请分析以下用户的生理状态:
|
||||
|
||||
用户信息:
|
||||
- 年龄: {profile.get('age', '未知')}
|
||||
- 身高体重: {profile.get('height', '未知')}cm, {profile.get('weight', '未知')}kg
|
||||
- 月经周期长度: {profile.get('menstrual_cycle_length', '未知')}天
|
||||
- 上次月经: {profile.get('last_period_date', '未知')}
|
||||
|
||||
当前生理周期状态:
|
||||
- 周期阶段: {cycle_info.get('phase', '未知')}
|
||||
- 距离下次月经: {cycle_info.get('days_to_next_period', '未知')}天
|
||||
|
||||
请分析:
|
||||
1. 当前生理状态对营养需求的影响
|
||||
2. 建议补充的营养素
|
||||
3. 需要避免的食物
|
||||
4. 情绪和食欲的变化
|
||||
5. 个性化建议
|
||||
|
||||
请以JSON格式返回分析结果。
|
||||
"""
|
||||
|
||||
def _build_meal_suggestion_prompt(self, meal_type: str, preferences: Dict, user_data: UserData) -> str:
|
||||
"""构建餐食建议提示词"""
|
||||
return f"""
|
||||
请为以下用户推荐{meal_type}:
|
||||
|
||||
用户信息:
|
||||
- 姓名: {user_data.profile.get('name', '未知')}
|
||||
- 年龄: {user_data.profile.get('age', '未知')}
|
||||
- 性别: {user_data.profile.get('gender', '未知')}
|
||||
- 身高体重: {user_data.profile.get('height', '未知')}cm, {user_data.profile.get('weight', '未知')}kg
|
||||
|
||||
口味偏好: {json.dumps(user_data.profile.get('taste_preferences', {}), ensure_ascii=False)}
|
||||
饮食限制: 过敏({', '.join(user_data.profile.get('allergies', []))}), 不喜欢({', '.join(user_data.profile.get('dislikes', []))})
|
||||
健康目标: {', '.join(user_data.profile.get('health_goals', []))}
|
||||
|
||||
特殊偏好: {json.dumps(preferences, ensure_ascii=False)}
|
||||
|
||||
请推荐:
|
||||
1. 3-5种适合的食物
|
||||
2. 推荐理由
|
||||
3. 营养搭配建议
|
||||
4. 制作建议
|
||||
|
||||
请以JSON格式返回建议。
|
||||
"""
|
||||
|
||||
def _get_intent_analysis_system_prompt(self) -> str:
|
||||
"""获取意图分析系统提示词"""
|
||||
return """
|
||||
你是一个专业的营养师和心理学专家,擅长分析用户的饮食需求和心理状态。
|
||||
|
||||
你的任务是:
|
||||
1. 深度理解用户的真实需求,不仅仅是表面的话语
|
||||
2. 考虑用户的生理状态、情绪状态、历史偏好等多维度因素
|
||||
3. 提供个性化的饮食建议
|
||||
4. 特别关注女性用户的生理周期对饮食需求的影响
|
||||
|
||||
分析时要:
|
||||
- 透过现象看本质,理解用户的真实意图
|
||||
- 综合考虑生理、心理、社会等多重因素
|
||||
- 提供科学、实用、个性化的建议
|
||||
- 保持专业性和同理心
|
||||
|
||||
返回格式必须是有效的JSON,包含所有必需字段。
|
||||
"""
|
||||
|
||||
def _get_nutrition_analysis_system_prompt(self) -> str:
|
||||
"""获取营养分析系统提示词"""
|
||||
return """
|
||||
你是一个专业的营养师,擅长分析餐食的营养价值和健康建议。
|
||||
|
||||
你的任务是:
|
||||
1. 分析餐食的营养均衡性
|
||||
2. 评估热量是否合适
|
||||
3. 识别缺少的营养素
|
||||
4. 提供改进建议
|
||||
5. 考虑用户的个人情况
|
||||
|
||||
分析时要:
|
||||
- 基于科学的营养学知识
|
||||
- 考虑用户的年龄、性别、体重、活动水平等因素
|
||||
- 提供具体可行的建议
|
||||
- 保持客观和专业
|
||||
|
||||
返回格式必须是有效的JSON,包含所有必需字段。
|
||||
"""
|
||||
|
||||
def _get_physiological_analysis_system_prompt(self) -> str:
|
||||
"""获取生理状态分析系统提示词"""
|
||||
return """
|
||||
你是专业的女性健康专家,了解生理周期对营养需求的影响。
|
||||
|
||||
你的任务是:
|
||||
1. 分析女性用户的生理周期状态
|
||||
2. 评估当前阶段的营养需求
|
||||
3. 提供针对性的饮食建议
|
||||
4. 考虑情绪和食欲的变化
|
||||
5. 提供个性化建议
|
||||
|
||||
分析时要:
|
||||
- 基于科学的生理学知识
|
||||
- 考虑个体差异
|
||||
- 提供温和、实用的建议
|
||||
- 保持专业和同理心
|
||||
|
||||
返回格式必须是有效的JSON,包含所有必需字段。
|
||||
"""
|
||||
|
||||
def _get_meal_suggestion_system_prompt(self) -> str:
|
||||
"""获取餐食建议系统提示词"""
|
||||
return """
|
||||
你是一个专业的营养师和厨师,擅长根据用户需求推荐合适的餐食。
|
||||
|
||||
你的任务是:
|
||||
1. 根据用户的口味偏好推荐食物
|
||||
2. 考虑饮食限制和过敏情况
|
||||
3. 提供营养均衡的建议
|
||||
4. 考虑用户的健康目标
|
||||
5. 提供实用的制作建议
|
||||
|
||||
推荐时要:
|
||||
- 基于营养学原理
|
||||
- 考虑用户的个人喜好
|
||||
- 提供多样化的选择
|
||||
- 保持实用性和可操作性
|
||||
|
||||
返回格式必须是有效的JSON,包含所有必需字段。
|
||||
"""
|
||||
|
||||
def _calculate_menstrual_cycle(self, profile: Dict, current_date: str) -> Dict[str, Any]:
|
||||
"""计算月经周期状态"""
|
||||
try:
|
||||
last_period = datetime.strptime(profile.get('last_period_date', ''), '%Y-%m-%d')
|
||||
current = datetime.strptime(current_date, '%Y-%m-%d')
|
||||
cycle_length = profile.get('menstrual_cycle_length', 28)
|
||||
|
||||
days_since_period = (current - last_period).days
|
||||
days_to_next_period = cycle_length - (days_since_period % cycle_length)
|
||||
|
||||
# 判断周期阶段
|
||||
if days_since_period % cycle_length < 5:
|
||||
phase = "月经期"
|
||||
elif days_since_period % cycle_length < 14:
|
||||
phase = "卵泡期"
|
||||
elif days_since_period % cycle_length < 18:
|
||||
phase = "排卵期"
|
||||
else:
|
||||
phase = "黄体期"
|
||||
|
||||
return {
|
||||
"phase": phase,
|
||||
"days_since_period": days_since_period % cycle_length,
|
||||
"days_to_next_period": days_to_next_period,
|
||||
"is_ovulation": phase == "排卵期",
|
||||
"cycle_length": cycle_length
|
||||
}
|
||||
except Exception:
|
||||
return {
|
||||
"phase": "未知",
|
||||
"days_since_period": 0,
|
||||
"days_to_next_period": 0,
|
||||
"is_ovulation": False,
|
||||
"cycle_length": 28
|
||||
}
|
||||
|
||||
def _format_recent_meals(self, meals: List[Dict]) -> str:
|
||||
"""格式化最近餐食"""
|
||||
if not meals:
|
||||
return "暂无饮食记录"
|
||||
|
||||
formatted = []
|
||||
for meal in meals:
|
||||
foods = ', '.join(meal.get('foods', []))
|
||||
satisfaction = meal.get('satisfaction_score', '未知')
|
||||
formatted.append(f"- {meal.get('date', '')} {meal.get('meal_type', '')}: {foods} (满意度: {satisfaction})")
|
||||
|
||||
return '\n'.join(formatted)
|
||||
|
||||
def _parse_intent_analysis(self, analysis_text: str) -> Dict[str, Any]:
|
||||
"""解析意图分析结果"""
|
||||
try:
|
||||
start_idx = analysis_text.find('{')
|
||||
end_idx = analysis_text.rfind('}') + 1
|
||||
|
||||
if start_idx != -1 and end_idx != -1:
|
||||
json_str = analysis_text[start_idx:end_idx]
|
||||
result_dict = json.loads(json_str)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'user_intent': result_dict.get('user_intent', ''),
|
||||
'emotional_state': result_dict.get('emotional_state', ''),
|
||||
'nutritional_needs': result_dict.get('nutritional_needs', []),
|
||||
'recommended_foods': result_dict.get('recommended_foods', []),
|
||||
'reasoning': result_dict.get('reasoning', ''),
|
||||
'confidence': result_dict.get('confidence', 0.5)
|
||||
}
|
||||
except Exception as e:
|
||||
self.logger.error(f"解析意图分析结果失败: {e}")
|
||||
|
||||
return self._get_fallback_intent_analysis("", None)
|
||||
|
||||
def _parse_nutrition_analysis(self, analysis_text: str) -> Dict[str, Any]:
|
||||
"""解析营养分析结果"""
|
||||
try:
|
||||
start_idx = analysis_text.find('{')
|
||||
end_idx = analysis_text.rfind('}') + 1
|
||||
|
||||
if start_idx != -1 and end_idx != -1:
|
||||
json_str = analysis_text[start_idx:end_idx]
|
||||
result_dict = json.loads(json_str)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'nutrition_balance': result_dict.get('nutrition_balance', ''),
|
||||
'calorie_assessment': result_dict.get('calorie_assessment', ''),
|
||||
'missing_nutrients': result_dict.get('missing_nutrients', []),
|
||||
'improvements': result_dict.get('improvements', []),
|
||||
'recommendations': result_dict.get('recommendations', []),
|
||||
'confidence': result_dict.get('confidence', 0.5)
|
||||
}
|
||||
except Exception as e:
|
||||
self.logger.error(f"解析营养分析结果失败: {e}")
|
||||
|
||||
return self._get_fallback_nutrition_analysis({}, None)
|
||||
|
||||
def _parse_physiological_analysis(self, analysis_text: str, cycle_info: Dict) -> Dict[str, Any]:
|
||||
"""解析生理状态分析结果"""
|
||||
try:
|
||||
start_idx = analysis_text.find('{')
|
||||
end_idx = analysis_text.rfind('}') + 1
|
||||
|
||||
if start_idx != -1 and end_idx != -1:
|
||||
json_str = analysis_text[start_idx:end_idx]
|
||||
result_dict = json.loads(json_str)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'physiological_state': result_dict.get('physiological_state', cycle_info.get('phase', '')),
|
||||
'nutritional_needs': result_dict.get('nutritional_needs', []),
|
||||
'foods_to_avoid': result_dict.get('foods_to_avoid', []),
|
||||
'emotional_changes': result_dict.get('emotional_changes', ''),
|
||||
'appetite_changes': result_dict.get('appetite_changes', ''),
|
||||
'recommendations': result_dict.get('recommendations', []),
|
||||
'cycle_info': cycle_info,
|
||||
'confidence': result_dict.get('confidence', 0.5)
|
||||
}
|
||||
except Exception as e:
|
||||
self.logger.error(f"解析生理分析结果失败: {e}")
|
||||
|
||||
return self._get_fallback_physiological_analysis(cycle_info)
|
||||
|
||||
def _parse_meal_suggestion(self, suggestion_text: str) -> Dict[str, Any]:
|
||||
"""解析餐食建议结果"""
|
||||
try:
|
||||
start_idx = suggestion_text.find('{')
|
||||
end_idx = suggestion_text.rfind('}') + 1
|
||||
|
||||
if start_idx != -1 and end_idx != -1:
|
||||
json_str = suggestion_text[start_idx:end_idx]
|
||||
result_dict = json.loads(json_str)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'recommended_foods': result_dict.get('recommended_foods', []),
|
||||
'reasoning': result_dict.get('reasoning', ''),
|
||||
'nutrition_tips': result_dict.get('nutrition_tips', []),
|
||||
'cooking_suggestions': result_dict.get('cooking_suggestions', []),
|
||||
'confidence': result_dict.get('confidence', 0.5)
|
||||
}
|
||||
except Exception as e:
|
||||
self.logger.error(f"解析餐食建议结果失败: {e}")
|
||||
|
||||
return self._get_fallback_meal_suggestion("lunch", None)
|
||||
|
||||
def _get_fallback_intent_analysis(self, user_input: str, user_data: UserData) -> Dict[str, Any]:
|
||||
"""获取备用意图分析结果"""
|
||||
return {
|
||||
'success': True,
|
||||
'user_intent': '需要饮食建议',
|
||||
'emotional_state': '正常',
|
||||
'nutritional_needs': ['均衡营养'],
|
||||
'recommended_foods': ['米饭', '蔬菜', '蛋白质'],
|
||||
'reasoning': '基于基础营养需求',
|
||||
'confidence': 0.3
|
||||
}
|
||||
|
||||
def _get_fallback_nutrition_analysis(self, meal_data: Dict, user_data: UserData) -> Dict[str, Any]:
|
||||
"""获取备用营养分析结果"""
|
||||
return {
|
||||
'success': True,
|
||||
'nutrition_balance': '基本均衡',
|
||||
'calorie_assessment': '适中',
|
||||
'missing_nutrients': [],
|
||||
'improvements': ['增加蔬菜摄入'],
|
||||
'recommendations': ['保持均衡饮食'],
|
||||
'confidence': 0.3
|
||||
}
|
||||
|
||||
def _get_fallback_physiological_analysis(self, cycle_info: Dict) -> Dict[str, Any]:
|
||||
"""获取备用生理状态分析结果"""
|
||||
return {
|
||||
'success': True,
|
||||
'physiological_state': cycle_info.get('phase', '正常'),
|
||||
'nutritional_needs': ['均衡营养'],
|
||||
'foods_to_avoid': [],
|
||||
'emotional_changes': '正常',
|
||||
'appetite_changes': '正常',
|
||||
'recommendations': ['保持规律饮食'],
|
||||
'cycle_info': cycle_info,
|
||||
'confidence': 0.3
|
||||
}
|
||||
|
||||
def _get_fallback_meal_suggestion(self, meal_type: str, user_data: UserData) -> Dict[str, Any]:
|
||||
"""获取备用餐食建议结果"""
|
||||
return {
|
||||
'success': True,
|
||||
'recommended_foods': ['米饭', '蔬菜', '蛋白质'],
|
||||
'reasoning': '营养均衡的基础搭配',
|
||||
'nutrition_tips': ['注意营养搭配'],
|
||||
'cooking_suggestions': ['简单烹饪'],
|
||||
'confidence': 0.3
|
||||
}
|
||||
|
||||
def _load_analysis_templates(self) -> Dict[str, Dict]:
|
||||
"""加载分析模板"""
|
||||
return {
|
||||
'intent_analysis': {
|
||||
'description': '用户意图分析',
|
||||
'required_fields': ['user_input']
|
||||
},
|
||||
'nutrition_analysis': {
|
||||
'description': '营养状况分析',
|
||||
'required_fields': ['meal_data']
|
||||
},
|
||||
'physiological_state': {
|
||||
'description': '生理状态分析',
|
||||
'required_fields': ['current_date']
|
||||
},
|
||||
'meal_suggestion': {
|
||||
'description': '餐食建议生成',
|
||||
'required_fields': ['meal_type']
|
||||
}
|
||||
}
|
||||
|
||||
def _create_error_result(self, error_message: str) -> Dict[str, Any]:
|
||||
"""创建错误结果"""
|
||||
return {
|
||||
'success': False,
|
||||
'error': error_message,
|
||||
'message': f'AI分析失败: {error_message}',
|
||||
'confidence': 0.0
|
||||
}
|
||||
|
||||
def cleanup(self) -> bool:
    """Log that cleanup finished.

    Returns:
        True when the log call succeeded; False when it raised, in which case
        the failure is logged as an error.
    """
    try:
        self.logger.info("AI分析模块清理完成")
    except Exception as e:
        self.logger.error(f"AI分析模块清理失败: {e}")
        return False
    else:
        return True
|
||||
|
||||
|
||||
# 便捷函数
|
||||
def analyze_user_intent(user_id: str, user_input: str) -> Optional[Dict]:
    """Run a user-intent analysis through the application core.

    Args:
        user_id: Identifier passed on to ``process_user_request``.
        user_input: Free text to analyse.

    Returns:
        The ``result`` attribute of what ``process_user_request`` returned,
        or ``None`` when that value is falsy.
    """
    from core.base import get_app_core

    app = get_app_core()
    request = dict(type='user_intent', user_input=user_input)

    outcome = app.process_user_request(ModuleType.USER_ANALYSIS, request, user_id)
    if not outcome:
        return None
    return outcome.result
|
||||
|
||||
|
||||
def analyze_nutrition(user_id: str, meal_data: Dict) -> Optional[Dict]:
    """Run a nutrition analysis through the application core.

    Args:
        user_id: Identifier passed on to ``process_user_request``.
        meal_data: Meal description to analyse.

    Returns:
        The ``result`` attribute of what ``process_user_request`` returned,
        or ``None`` when that value is falsy.
    """
    from core.base import get_app_core

    app = get_app_core()
    request = dict(type='nutrition_analysis', meal_data=meal_data)

    outcome = app.process_user_request(ModuleType.USER_ANALYSIS, request, user_id)
    if not outcome:
        return None
    return outcome.result
|
||||
|
||||
|
||||
def analyze_physiological_state(user_id: str, current_date: str = None) -> Optional[Dict]:
    """Run a physiological-state analysis through the application core.

    Args:
        user_id: Identifier passed on to ``process_user_request``.
        current_date: Date formatted ``YYYY-MM-DD``; today when ``None``.

    Returns:
        The ``result`` attribute of what ``process_user_request`` returned,
        or ``None`` when that value is falsy.
    """
    from core.base import get_app_core
    from datetime import datetime

    if current_date is None:
        current_date = datetime.now().strftime('%Y-%m-%d')

    app = get_app_core()
    request = dict(type='physiological_state', current_date=current_date)

    outcome = app.process_user_request(ModuleType.USER_ANALYSIS, request, user_id)
    if not outcome:
        return None
    return outcome.result
|
||||
|
||||
|
||||
def generate_meal_suggestion(user_id: str, meal_type: str, preferences: Dict = None) -> Optional[Dict]:
    """Request a meal suggestion through the application core.

    Args:
        user_id: Identifier passed on to ``process_user_request``.
        meal_type: Meal the suggestion is for.
        preferences: Extra preferences; an empty dict is used when ``None``.

    Returns:
        The ``result`` attribute of what ``process_user_request`` returned,
        or ``None`` when that value is falsy.
    """
    from core.base import get_app_core

    preferences = {} if preferences is None else preferences

    app = get_app_core()
    request = dict(type='meal_suggestion', meal_type=meal_type, preferences=preferences)

    outcome = app.process_user_request(ModuleType.USER_ANALYSIS, request, user_id)
    if not outcome:
        return None
    return outcome.result
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test for the AI analysis module.
    # NOTE(review): nesting below is reconstructed from flattened text; confirm
    # against the original file.
    from core.base import BaseConfig, initialize_app, cleanup_app

    print("测试AI分析模块...")

    # Initialize the application
    config = BaseConfig()
    if initialize_app(config):
        print("✅ 应用初始化成功")

        # Exercise the user-intent analysis
        test_user_id = "test_user_001"
        user_input = "我今天有点累,想吃点甜的,但是又怕胖"

        result = analyze_user_intent(test_user_id, user_input)
        if result:
            print(f"✅ 用户意图分析成功: {result.get('user_intent', '')}")

        # Exercise the nutrition analysis
        meal_data = {
            'foods': ['燕麦粥', '香蕉', '牛奶'],
            'quantities': ['1碗', '1根', '200ml'],
            'calories': 350.0
        }

        result = analyze_nutrition(test_user_id, meal_data)
        if result:
            print(f"✅ 营养分析成功: {result.get('nutrition_balance', '')}")

        # Clean up the application
        cleanup_app()
        print("✅ 应用清理完成")
    else:
        print("❌ 应用初始化失败")

    print("AI分析模块测试完成!")
|
||||
367
modules/data_collection.py
Normal file
367
modules/data_collection.py
Normal file
@@ -0,0 +1,367 @@
|
||||
"""
|
||||
数据采集模块 - 基于基座架构
|
||||
负责收集用户数据、问卷和餐食记录
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Any
|
||||
import json
|
||||
from datetime import datetime, date
|
||||
from core.base import BaseModule, ModuleType, UserData, AnalysisResult, BaseConfig
|
||||
|
||||
|
||||
class DataCollectionModule(BaseModule):
|
||||
"""数据采集模块"""
|
||||
|
||||
def __init__(self, config: BaseConfig) -> None:
    """Create the data collection module.

    Registers the module with the base class as
    ``ModuleType.DATA_COLLECTION`` and loads the questionnaire templates.
    """
    super().__init__(config, ModuleType.DATA_COLLECTION)
    self.questionnaire_templates = self._load_questionnaire_templates()
|
||||
|
||||
def initialize(self) -> bool:
    """Flag the module as initialized.

    Returns:
        True when every step succeeded; False when any step raised, in
        which case the error is logged.
    """
    try:
        self.logger.info("数据采集模块初始化中...")
        self.is_initialized = True
        self.logger.info("数据采集模块初始化完成")
    except Exception as e:
        self.logger.error(f"数据采集模块初始化失败: {e}")
        return False
    else:
        return True
|
||||
|
||||
def process(self, input_data: Any, user_data: UserData) -> AnalysisResult:
    """Route a data collection request to the matching handler.

    Args:
        input_data: Request mapping. Its ``'type'`` entry selects the handler
            (``'questionnaire'``, ``'meal_record'`` or ``'feedback'``).
        user_data: Data of the user the request is made for; handlers update
            it in place.

    Returns:
        An ``AnalysisResult`` wrapping the handler's result dict, with
        confidence 0.9 on success and 0.1 otherwise. An unknown type or an
        unexpected exception produces an error result wrapped the same way.
    """
    try:
        request_type = input_data.get('type', 'unknown')

        if request_type == 'questionnaire':
            result = self._process_questionnaire(input_data, user_data)
        elif request_type == 'meal_record':
            result = self._process_meal_record(input_data, user_data)
        elif request_type == 'feedback':
            result = self._process_feedback(input_data, user_data)
        else:
            result = self._create_error_result("未知的请求类型")

        confidence = 0.9 if result.get('success', False) else 0.1
        user_id = user_data.user_id
    except Exception as e:
        self.logger.error(f"处理数据采集请求失败: {e}")
        # The declared return type is AnalysisResult, so the error result is
        # wrapped below instead of being returned bare as it was before.
        result = self._create_error_result(str(e))
        confidence = 0.1
        # user_data itself may be what failed; do not raise a second time.
        user_id = getattr(user_data, 'user_id', None)

    return AnalysisResult(
        module_type=self.module_type,
        user_id=user_id,
        input_data=input_data,
        result=result,
        confidence=confidence
    )
|
||||
|
||||
def _process_questionnaire(self, input_data: Dict, user_data: UserData) -> Dict[str, Any]:
|
||||
"""处理问卷数据"""
|
||||
questionnaire_type = input_data.get('questionnaire_type', 'basic')
|
||||
answers = input_data.get('answers', {})
|
||||
|
||||
# 根据问卷类型处理答案
|
||||
if questionnaire_type == 'basic':
|
||||
processed_data = self._process_basic_questionnaire(answers)
|
||||
elif questionnaire_type == 'taste':
|
||||
processed_data = self._process_taste_questionnaire(answers)
|
||||
elif questionnaire_type == 'physiological':
|
||||
processed_data = self._process_physiological_questionnaire(answers)
|
||||
else:
|
||||
processed_data = answers
|
||||
|
||||
# 更新用户数据
|
||||
user_data.profile.update(processed_data)
|
||||
user_data.updated_at = datetime.now().isoformat()
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'processed_data': processed_data,
|
||||
'message': f'{questionnaire_type}问卷处理完成'
|
||||
}
|
||||
|
||||
def _process_meal_record(self, input_data: Dict, user_data: UserData) -> Dict[str, Any]:
|
||||
"""处理餐食记录"""
|
||||
meal_data = {
|
||||
'date': input_data.get('date', datetime.now().strftime('%Y-%m-%d')),
|
||||
'meal_type': input_data.get('meal_type', 'unknown'),
|
||||
'foods': input_data.get('foods', []),
|
||||
'quantities': input_data.get('quantities', []),
|
||||
'calories': input_data.get('calories'),
|
||||
'satisfaction_score': input_data.get('satisfaction_score'),
|
||||
'notes': input_data.get('notes', ''),
|
||||
'timestamp': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# 添加到用户餐食记录
|
||||
user_data.meals.append(meal_data)
|
||||
user_data.updated_at = datetime.now().isoformat()
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'meal_data': meal_data,
|
||||
'message': '餐食记录保存成功'
|
||||
}
|
||||
|
||||
def _process_feedback(self, input_data: Dict, user_data: UserData) -> Dict[str, Any]:
|
||||
"""处理用户反馈"""
|
||||
feedback_data = {
|
||||
'date': input_data.get('date', datetime.now().strftime('%Y-%m-%d')),
|
||||
'recommended_foods': input_data.get('recommended_foods', []),
|
||||
'user_choice': input_data.get('user_choice', ''),
|
||||
'feedback_type': input_data.get('feedback_type', 'unknown'),
|
||||
'satisfaction_score': input_data.get('satisfaction_score'),
|
||||
'notes': input_data.get('notes', ''),
|
||||
'timestamp': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# 添加到用户反馈记录
|
||||
user_data.feedback.append(feedback_data)
|
||||
user_data.updated_at = datetime.now().isoformat()
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'feedback_data': feedback_data,
|
||||
'message': '反馈记录保存成功'
|
||||
}
|
||||
|
||||
def _process_basic_questionnaire(self, answers: Dict) -> Dict[str, Any]:
|
||||
"""处理基础信息问卷"""
|
||||
return {
|
||||
'name': answers.get('name', ''),
|
||||
'age': int(answers.get('age', 0)),
|
||||
'gender': answers.get('gender', ''),
|
||||
'height': float(answers.get('height', 0)),
|
||||
'weight': float(answers.get('weight', 0)),
|
||||
'activity_level': answers.get('activity_level', ''),
|
||||
'health_goals': answers.get('health_goals', [])
|
||||
}
|
||||
|
||||
def _process_taste_questionnaire(self, answers: Dict) -> Dict[str, Any]:
|
||||
"""处理口味偏好问卷"""
|
||||
return {
|
||||
'taste_preferences': {
|
||||
'sweet': int(answers.get('sweet', 3)),
|
||||
'salty': int(answers.get('salty', 3)),
|
||||
'spicy': int(answers.get('spicy', 3)),
|
||||
'sour': int(answers.get('sour', 3)),
|
||||
'bitter': int(answers.get('bitter', 3)),
|
||||
'umami': int(answers.get('umami', 3))
|
||||
},
|
||||
'dietary_preferences': answers.get('dietary_preferences', []),
|
||||
'allergies': answers.get('allergies', []),
|
||||
'dislikes': answers.get('dislikes', [])
|
||||
}
|
||||
|
||||
def _process_physiological_questionnaire(self, answers: Dict) -> Dict[str, Any]:
|
||||
"""处理生理信息问卷"""
|
||||
return {
|
||||
'is_female': answers.get('gender') == '女',
|
||||
'menstrual_cycle_length': int(answers.get('menstrual_cycle_length', 28)),
|
||||
'last_period_date': answers.get('last_period_date', ''),
|
||||
'ovulation_symptoms': answers.get('ovulation_symptoms', []),
|
||||
'zodiac_sign': answers.get('zodiac_sign', ''),
|
||||
'personality_traits': answers.get('personality_traits', [])
|
||||
}
|
||||
|
||||
def _load_questionnaire_templates(self) -> Dict[str, Dict]:
|
||||
"""加载问卷模板"""
|
||||
return {
|
||||
'basic': {
|
||||
'title': '基本信息问卷',
|
||||
'questions': {
|
||||
'name': {'question': '您的姓名', 'type': 'text'},
|
||||
'age': {'question': '您的年龄', 'type': 'number', 'min': 1, 'max': 120},
|
||||
'gender': {'question': '性别', 'type': 'select', 'options': ['男', '女']},
|
||||
'height': {'question': '身高 (cm)', 'type': 'number', 'min': 100, 'max': 250},
|
||||
'weight': {'question': '体重 (kg)', 'type': 'number', 'min': 30, 'max': 200},
|
||||
'activity_level': {
|
||||
'question': '日常活动水平',
|
||||
'type': 'select',
|
||||
'options': ['久坐', '轻度活动', '中度活动', '高度活动', '极高活动']
|
||||
},
|
||||
'health_goals': {
|
||||
'question': '健康目标 (可多选)',
|
||||
'type': 'checkbox',
|
||||
'options': ['减肥', '增肌', '维持体重', '提高免疫力', '改善消化']
|
||||
}
|
||||
}
|
||||
},
|
||||
'taste': {
|
||||
'title': '口味偏好问卷',
|
||||
'questions': {
|
||||
'sweet': {'question': '甜味偏好 (1-5分)', 'type': 'scale', 'min': 1, 'max': 5},
|
||||
'salty': {'question': '咸味偏好 (1-5分)', 'type': 'scale', 'min': 1, 'max': 5},
|
||||
'spicy': {'question': '辣味偏好 (1-5分)', 'type': 'scale', 'min': 1, 'max': 5},
|
||||
'sour': {'question': '酸味偏好 (1-5分)', 'type': 'scale', 'min': 1, 'max': 5},
|
||||
'bitter': {'question': '苦味偏好 (1-5分)', 'type': 'scale', 'min': 1, 'max': 5},
|
||||
'umami': {'question': '鲜味偏好 (1-5分)', 'type': 'scale', 'min': 1, 'max': 5},
|
||||
'dietary_preferences': {
|
||||
'question': '饮食限制 (可多选)',
|
||||
'type': 'checkbox',
|
||||
'options': ['素食', '纯素食', '无麸质', '无乳制品', '无坚果', '低钠', '低碳水', '无糖']
|
||||
},
|
||||
'allergies': {
|
||||
'question': '过敏食物 (可多选)',
|
||||
'type': 'checkbox',
|
||||
'options': ['花生', '坚果', '海鲜', '鸡蛋', '牛奶', '大豆', '小麦', '无过敏']
|
||||
},
|
||||
'dislikes': {
|
||||
'question': '不喜欢的食物类型',
|
||||
'type': 'checkbox',
|
||||
'options': ['内脏', '海鲜', '蘑菇', '香菜', '洋葱', '大蒜', '辛辣食物', '甜食']
|
||||
}
|
||||
}
|
||||
},
|
||||
'physiological': {
|
||||
'title': '生理信息问卷',
|
||||
'questions': {
|
||||
'menstrual_cycle_length': {
|
||||
'question': '月经周期长度 (天)',
|
||||
'type': 'number',
|
||||
'min': 20,
|
||||
'max': 40,
|
||||
'optional': True
|
||||
},
|
||||
'last_period_date': {
|
||||
'question': '上次月经日期',
|
||||
'type': 'date',
|
||||
'optional': True
|
||||
},
|
||||
'ovulation_symptoms': {
|
||||
'question': '排卵期症状 (可多选)',
|
||||
'type': 'checkbox',
|
||||
'options': ['乳房胀痛', '情绪波动', '食欲变化', '疲劳', '无特殊症状'],
|
||||
'optional': True
|
||||
},
|
||||
'zodiac_sign': {
|
||||
'question': '星座',
|
||||
'type': 'select',
|
||||
'options': ['白羊座', '金牛座', '双子座', '巨蟹座', '狮子座', '处女座',
|
||||
'天秤座', '天蝎座', '射手座', '摩羯座', '水瓶座', '双鱼座'],
|
||||
'optional': True
|
||||
},
|
||||
'personality_traits': {
|
||||
'question': '性格特征 (可多选)',
|
||||
'type': 'checkbox',
|
||||
'options': ['外向', '内向', '理性', '感性', '冒险', '保守', '创新', '传统'],
|
||||
'optional': True
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def get_questionnaire_template(self, questionnaire_type: str) -> Optional[Dict]:
    """Return the template registered under ``questionnaire_type``.

    Returns ``None`` when ``self.questionnaire_templates`` has no such key.
    """
    available = self.questionnaire_templates
    return available.get(questionnaire_type)
|
||||
|
||||
def get_all_questionnaire_types(self) -> List[str]:
    """Return the keys of ``self.questionnaire_templates`` as a new list."""
    return [*self.questionnaire_templates.keys()]
|
||||
|
||||
def _create_error_result(self, error_message: str) -> Dict[str, Any]:
|
||||
"""创建错误结果"""
|
||||
return {
|
||||
'success': False,
|
||||
'error': error_message,
|
||||
'message': f'数据采集失败: {error_message}'
|
||||
}
|
||||
|
||||
def cleanup(self) -> bool:
    """Log that the module has been cleaned up.

    Returns:
        ``True`` on success; ``False`` if logging the completion message
        raised (the error is then logged instead).
    """
    try:
        self.logger.info("数据采集模块清理完成")
    except Exception as e:
        self.logger.error(f"数据采集模块清理失败: {e}")
        return False
    else:
        return True
|
||||
|
||||
|
||||
# 便捷函数
|
||||
def collect_questionnaire_data(user_id: str, questionnaire_type: str, answers: Dict) -> bool:
    """Submit questionnaire answers through the application core.

    Args:
        user_id: Id of the user the answers belong to.
        questionnaire_type: Template key, e.g. ``'basic'``.
        answers: Raw answers keyed by question id.

    Returns:
        ``True`` only when the core returned a result whose payload reports
        ``success``; ``False`` otherwise.
    """
    from core.base import get_app_core

    request = {
        'type': 'questionnaire',
        'questionnaire_type': questionnaire_type,
        'answers': answers,
    }
    outcome = get_app_core().process_user_request(ModuleType.DATA_COLLECTION, request, user_id)
    # A bare `outcome and ...` would hand None (or the result object itself)
    # to callers although the signature promises a bool.
    return bool(outcome and outcome.result.get('success', False))
|
||||
|
||||
|
||||
def record_meal(user_id: str, meal_data: Dict) -> bool:
    """Submit a meal record through the application core.

    Args:
        user_id: Id of the user who ate the meal.
        meal_data: Meal fields (the module's own smoke test passes ``date``,
            ``meal_type``, ``foods``, ``quantities``, ``calories``,
            ``satisfaction_score`` and ``notes``).

    Returns:
        ``True`` only when the core returned a result whose payload reports
        ``success``; ``False`` otherwise.
    """
    from core.base import get_app_core

    # 'type' is the routing key, so it is written last: a stray 'type' entry
    # inside meal_data must not redirect the request to another handler.
    request = {**meal_data, 'type': 'meal_record'}
    outcome = get_app_core().process_user_request(ModuleType.DATA_COLLECTION, request, user_id)
    # Coerce so callers really get a bool, never None or the result object.
    return bool(outcome and outcome.result.get('success', False))
|
||||
|
||||
|
||||
def record_feedback(user_id: str, feedback_data: Dict) -> bool:
    """Submit a feedback record through the application core.

    Args:
        user_id: Id of the user giving the feedback.
        feedback_data: Feedback fields forwarded to the data-collection module.

    Returns:
        ``True`` only when the core returned a result whose payload reports
        ``success``; ``False`` otherwise.
    """
    from core.base import get_app_core

    # 'type' is the routing key, so it is written last: a stray 'type' entry
    # inside feedback_data must not redirect the request to another handler.
    request = {**feedback_data, 'type': 'feedback'}
    outcome = get_app_core().process_user_request(ModuleType.DATA_COLLECTION, request, user_id)
    # Coerce so callers really get a bool, never None or the result object.
    return bool(outcome and outcome.result.get('success', False))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test for the data-collection module.
    from core.base import BaseConfig, initialize_app, cleanup_app

    print("测试数据采集模块...")

    app_config = BaseConfig()
    if not initialize_app(app_config):
        print("❌ 应用初始化失败")
    else:
        print("✅ 应用初始化成功")

        demo_user = "test_user_001"

        # 1) Basic questionnaire round trip.
        basic_answers = {
            'name': '小美',
            'age': 25,
            'gender': '女',
            'height': 165,
            'weight': 55,
            'activity_level': '中度活动',
            'health_goals': ['维持体重', '提高免疫力'],
        }
        questionnaire_ok = collect_questionnaire_data(demo_user, 'basic', basic_answers)
        if questionnaire_ok:
            print("✅ 基础问卷数据收集成功")

        # 2) Meal record round trip.
        breakfast = {
            'date': '2024-01-15',
            'meal_type': 'breakfast',
            'foods': ['燕麦粥', '香蕉', '牛奶'],
            'quantities': ['1碗', '1根', '200ml'],
            'calories': 350.0,
            'satisfaction_score': 4,
            'notes': '很满意,营养均衡',
        }
        meal_ok = record_meal(demo_user, breakfast)
        if meal_ok:
            print("✅ 餐食记录成功")

        # 3) Tear the application down again.
        cleanup_app()
        print("✅ 应用清理完成")

    print("数据采集模块测试完成!")
|
||||
634
modules/efficient_data_processing.py
Normal file
634
modules/efficient_data_processing.py
Normal file
@@ -0,0 +1,634 @@
|
||||
"""
|
||||
高效数据处理和训练模块
|
||||
优化数据处理流程,提高训练效率
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
import json
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import threading
|
||||
from collections import defaultdict, Counter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EfficientDataProcessor:
    """Analyse per-user meal and feedback data and keep cross-user aggregates.

    Three aggregates are persisted as pickle files inside ``data_dir`` and
    reloaded on construction:

    * ``food_frequency``: food name -> counter, incremented once for every
      processed user that has the food among their favourites.
    * ``user_preferences``: user id -> preference summary.
    * ``nutrition_patterns``: user id -> average calories, average
      satisfaction and food diversity.
    """

    # Attribute name -> pickle file name (relative to ``data_dir``).
    _PRECOMPUTED_FILES = {
        'food_frequency': "food_frequency.pkl",
        'user_preferences': "user_preferences.pkl",
        'nutrition_patterns': "nutrition_patterns.pkl",
    }

    # The questionnaire offers '减肥' / '增肌' as health goals while older
    # callers used '减重' / '增重'; accept both spellings so the health tips
    # actually fire for questionnaire users.
    _WEIGHT_LOSS_GOALS = ('减重', '减肥')
    _WEIGHT_GAIN_GOALS = ('增重', '增肌')

    def __init__(self, data_dir: str = "data"):
        """Create the data directory if needed and load persisted aggregates.

        Args:
            data_dir: Directory holding the pickle files and exported
                reports. Missing parent directories are created as well.
        """
        self.data_dir = Path(data_dir)
        self.data_dir.mkdir(parents=True, exist_ok=True)

        # Data caches (not populated by this class itself).
        self.user_data_cache = {}
        self.meal_data_cache = {}
        self.feedback_data_cache = {}

        # Precomputed aggregates.
        self.food_frequency = {}
        self.user_preferences = {}
        self.nutrition_patterns = {}

        # Guards updates of the aggregates.
        self.cache_lock = threading.Lock()

        self._load_precomputed_data()

    def _load_precomputed_data(self):
        """Load every aggregate whose pickle file exists.

        Each file is loaded independently, so one unreadable file only costs
        that aggregate (a warning is logged) instead of aborting the others.
        """
        for attr_name, file_name in self._PRECOMPUTED_FILES.items():
            path = self.data_dir / file_name
            if not path.exists():
                continue
            try:
                with open(path, 'rb') as f:
                    # NOTE(security): pickle.load executes arbitrary code
                    # from the file. This is only safe while data_dir is
                    # writable by trusted code alone.
                    setattr(self, attr_name, pickle.load(f))
            except Exception as e:
                logger.warning(f"加载预计算数据失败: {e}")

    def _save_precomputed_data(self):
        """Write all three aggregates to their pickle files (best effort)."""
        try:
            for attr_name, file_name in self._PRECOMPUTED_FILES.items():
                with open(self.data_dir / file_name, 'wb') as f:
                    pickle.dump(getattr(self, attr_name), f)
        except Exception as e:
            logger.error(f"保存预计算数据失败: {e}")

    def batch_process_user_data(self, user_ids: List[str]) -> Dict[str, Any]:
        """Process several users on a thread pool and refresh the aggregates.

        Args:
            user_ids: Ids of the users to process.

        Returns:
            Mapping of user id to that user's analysis dict, or to
            ``{'error': <message>}`` when processing failed.
        """
        logger.info(f"开始批量处理 {len(user_ids)} 个用户的数据")

        results = {}

        with ThreadPoolExecutor(max_workers=4) as executor:
            future_to_user = {
                executor.submit(self._process_single_user, user_id): user_id
                for user_id in user_ids
            }

            for future in as_completed(future_to_user):
                user_id = future_to_user[future]
                try:
                    results[user_id] = future.result()
                except Exception as e:
                    logger.error(f"处理用户 {user_id} 数据失败: {e}")
                    results[user_id] = {'error': str(e)}

        self._update_precomputed_data(results)

        # The message says "successfully processed", so count successes only.
        succeeded = sum(1 for outcome in results.values() if 'error' not in outcome)
        logger.info(f"批量处理完成,成功处理 {succeeded} 个用户")
        return results

    def _process_single_user(self, user_id: str) -> Dict[str, Any]:
        """Run all analyses for one user.

        Returns:
            A dict with ``user_id``, ``meal_analysis``, ``feedback_analysis``,
            ``preference_analysis``, ``recommendations`` and ``processed_at``,
            or ``{'error': <message>}`` when the user is unknown or any step
            raised.
        """
        try:
            from core.base import AppCore

            app_core = AppCore()
            user_data = app_core.get_user_data(user_id)

            if not user_data:
                return {'error': '用户数据不存在'}

            meal_analysis = self._analyze_meal_patterns(user_data.meals)
            feedback_analysis = self._analyze_feedback_patterns(user_data.feedback)
            preference_analysis = self._analyze_user_preferences(user_data)
            recommendations = self._generate_personalized_recommendations(
                user_data, meal_analysis, feedback_analysis, preference_analysis
            )

            return {
                'user_id': user_id,
                'meal_analysis': meal_analysis,
                'feedback_analysis': feedback_analysis,
                'preference_analysis': preference_analysis,
                'recommendations': recommendations,
                'processed_at': datetime.now().isoformat()
            }

        except Exception as e:
            logger.error(f"处理用户 {user_id} 数据失败: {e}")
            return {'error': str(e)}

    def _analyze_meal_patterns(self, meals: List[Dict]) -> Dict[str, Any]:
        """Summarise a user's meal records.

        Args:
            meals: Meal dicts; the keys read are ``foods``, ``meal_type``,
                ``satisfaction_score`` and ``calories``.

        Returns:
            ``{'total_meals': 0, 'patterns': {}}`` for an empty input,
            otherwise a dict with ``favorite_foods`` (up to five names),
            ``meal_type_preference``, ``avg_satisfaction``, ``avg_calories``,
            ``total_meals`` and ``food_diversity``.
        """
        if not meals:
            return {'total_meals': 0, 'patterns': {}}

        food_counter = Counter()
        meal_type_counter = Counter()
        satisfaction_scores = []
        calorie_totals = []

        for meal in meals:
            food_counter.update(meal.get('foods') or [])
            meal_type_counter[meal.get('meal_type', 'unknown')] += 1

            # Records may carry an explicit None score when the user gave
            # none; averaging None would raise inside np.mean.
            score = meal.get('satisfaction_score')
            if score is not None:
                satisfaction_scores.append(score)

            calories = meal.get('calories')
            if calories:
                calorie_totals.append(calories)

        # Plain floats keep the summary free of numpy scalar types.
        avg_satisfaction = float(np.mean(satisfaction_scores)) if satisfaction_scores else 0
        avg_calories = float(np.mean(calorie_totals)) if calorie_totals else 0

        return {
            'favorite_foods': [food for food, _ in food_counter.most_common(5)],
            'meal_type_preference': dict(meal_type_counter.most_common()),
            'avg_satisfaction': round(avg_satisfaction, 2),
            'avg_calories': round(avg_calories, 2),
            'total_meals': len(meals),
            'food_diversity': len(food_counter)
        }

    def _analyze_feedback_patterns(self, feedbacks: List[Dict]) -> Dict[str, Any]:
        """Summarise a user's feedback records.

        Returns:
            ``{'total_feedback': 0, 'patterns': {}}`` for an empty input,
            otherwise a dict with ``feedback_distribution`` (count per
            ``feedback_type``), ``total_feedback`` and ``common_choices``
            (the five most frequent ``user_choice`` values with counts).
        """
        if not feedbacks:
            return {'total_feedback': 0, 'patterns': {}}

        feedback_types = Counter()
        user_choices = []

        for feedback in feedbacks:
            feedback_types[feedback.get('feedback_type', 'unknown')] += 1
            if 'user_choice' in feedback:
                user_choices.append(feedback['user_choice'])

        return {
            'feedback_distribution': dict(feedback_types.most_common()),
            'total_feedback': len(feedbacks),
            'common_choices': Counter(user_choices).most_common(5)
        }

    def _analyze_user_preferences(self, user_data) -> Dict[str, Any]:
        """Extract the preference-related parts of ``user_data.profile``.

        Returns:
            A dict with ``basic_info``, ``taste_preferences``,
            ``dietary_restrictions`` and ``health_goals``; absent profile
            entries become ``'unknown'``, ``{}`` or ``[]``.
        """
        profile = user_data.profile

        return {
            'basic_info': {
                'age': profile.get('age', 'unknown'),
                'gender': profile.get('gender', 'unknown'),
                'activity_level': profile.get('activity_level', 'unknown')
            },
            'taste_preferences': profile.get('taste_preferences', {}),
            'dietary_restrictions': {
                'allergies': profile.get('allergies', []),
                'dislikes': profile.get('dislikes', []),
                'dietary_preferences': profile.get('dietary_preferences', [])
            },
            'health_goals': profile.get('health_goals', [])
        }

    def _generate_personalized_recommendations(self, user_data, meal_analysis,
                                               feedback_analysis, preference_analysis) -> Dict[str, Any]:
        """Derive rule-based recommendations from the three analyses.

        Args:
            user_data: Accepted for interface compatibility; not read here.
            meal_analysis: Output of ``_analyze_meal_patterns``.
            feedback_analysis: Output of ``_analyze_feedback_patterns``.
            preference_analysis: Output of ``_analyze_user_preferences``.

        Returns:
            A dict with the list-valued keys ``food_recommendations``,
            ``meal_suggestions``, ``health_tips`` and
            ``improvement_suggestions``.
        """
        recommendations = {
            'food_recommendations': [],
            'meal_suggestions': [],
            'health_tips': [],
            'improvement_suggestions': []
        }

        # Re-suggest the user's three most frequent foods.
        favorite_foods = meal_analysis.get('favorite_foods', [])
        if favorite_foods:
            recommendations['food_recommendations'].extend(favorite_foods[:3])

        # More likes than dislikes: encourage the current choices.
        feedback_patterns = feedback_analysis.get('feedback_distribution', {})
        if feedback_patterns.get('like', 0) > feedback_patterns.get('dislike', 0):
            recommendations['meal_suggestions'].append("继续选择您喜欢的食物")

        # Health-goal tips; weight loss takes precedence over weight gain.
        health_goals = preference_analysis.get('health_goals', [])
        if any(goal in health_goals for goal in self._WEIGHT_LOSS_GOALS):
            recommendations['health_tips'].append("建议增加蔬菜摄入,减少高热量食物")
        elif any(goal in health_goals for goal in self._WEIGHT_GAIN_GOALS):
            recommendations['health_tips'].append("建议增加蛋白质和健康脂肪摄入")

        # A low (or absent) average satisfaction triggers a nudge.
        avg_satisfaction = meal_analysis.get('avg_satisfaction', 0)
        if avg_satisfaction < 3:
            recommendations['improvement_suggestions'].append("尝试新的食物组合以提高满意度")

        return recommendations

    def _update_precomputed_data(self, results: Dict[str, Any]):
        """Fold batch results into the aggregates and persist them.

        Entries containing an ``'error'`` key are skipped.
        """
        with self.cache_lock:
            for user_id, result in results.items():
                if 'error' in result:
                    continue

                meal_analysis = result.get('meal_analysis', {})

                for food in meal_analysis.get('favorite_foods', []):
                    self.food_frequency[food] = self.food_frequency.get(food, 0) + 1

                self.user_preferences[user_id] = result.get('preference_analysis', {})

                self.nutrition_patterns[user_id] = {
                    'avg_calories': meal_analysis.get('avg_calories', 0),
                    'avg_satisfaction': meal_analysis.get('avg_satisfaction', 0),
                    'food_diversity': meal_analysis.get('food_diversity', 0)
                }

            self._save_precomputed_data()

    def get_popular_foods(self, limit: int = 10) -> List[Tuple[str, int]]:
        """Return the ``limit`` most frequent foods as ``(food, count)`` pairs."""
        return sorted(self.food_frequency.items(), key=lambda x: x[1], reverse=True)[:limit]

    def get_user_similarity(self, user_id: str) -> List[Tuple[str, float]]:
        """Return up to five other users most similar to ``user_id``.

        Returns:
            ``(other_user_id, similarity)`` pairs sorted by descending
            similarity; empty when ``user_id`` has no stored preferences.
        """
        if user_id not in self.user_preferences:
            return []

        target_prefs = self.user_preferences[user_id]
        similarities = [
            (other_user_id, self._calculate_preference_similarity(target_prefs, other_prefs))
            for other_user_id, other_prefs in self.user_preferences.items()
            if other_user_id != user_id
        ]

        return sorted(similarities, key=lambda x: x[1], reverse=True)[:5]

    def _calculate_preference_similarity(self, prefs1: Dict, prefs2: Dict) -> float:
        """Score how alike two preference summaries are, in ``[0, 1]``.

        Equal gender weighs 0.3, equal activity level 0.2 and every taste
        key of ``prefs1`` 0.1; the score is the matched weight divided by
        the total weight considered.
        """
        score = 0.0
        total = 0.0

        basic1 = prefs1.get('basic_info', {})
        basic2 = prefs2.get('basic_info', {})

        if basic1.get('gender') == basic2.get('gender'):
            score += 0.3
        total += 0.3

        if basic1.get('activity_level') == basic2.get('activity_level'):
            score += 0.2
        total += 0.2

        taste1 = prefs1.get('taste_preferences', {})
        taste2 = prefs2.get('taste_preferences', {})

        for key in taste1:
            if key in taste2 and taste1[key] == taste2[key]:
                score += 0.1
            total += 0.1

        return score / total if total > 0 else 0.0

    def export_analysis_report(self, user_id: str, output_file: Optional[str] = None) -> str:
        """Write a JSON analysis report for one user.

        Args:
            user_id: User to report on.
            output_file: Target path; defaults to a timestamped file inside
                ``data_dir``.

        Returns:
            The path of the written report as a string.

        Raises:
            ValueError: If no data exists for ``user_id``.
            Exception: Any other failure is logged and re-raised.
        """
        if not output_file:
            output_file = self.data_dir / f"analysis_report_{user_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"

        try:
            from core.base import AppCore
            app_core = AppCore()
            user_data = app_core.get_user_data(user_id)

            if not user_data:
                raise ValueError("用户数据不存在")

            # Each analysis runs once and feeds both the statistics section
            # and the recommendations.
            meal_statistics = self._analyze_meal_patterns(user_data.meals)
            feedback_statistics = self._analyze_feedback_patterns(user_data.feedback)
            preference_analysis = self._analyze_user_preferences(user_data)

            report = {
                'user_id': user_id,
                'generated_at': datetime.now().isoformat(),
                'user_profile': user_data.profile,
                'meal_statistics': meal_statistics,
                'feedback_statistics': feedback_statistics,
                'recommendations': self._generate_personalized_recommendations(
                    user_data, meal_statistics, feedback_statistics, preference_analysis
                ),
                'similar_users': self.get_user_similarity(user_id),
                'popular_foods': self.get_popular_foods(10)
            }

            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(report, f, ensure_ascii=False, indent=2)

            logger.info(f"分析报告已导出到: {output_file}")
            return str(output_file)

        except Exception as e:
            logger.error(f"导出分析报告失败: {e}")
            raise
|
||||
|
||||
|
||||
class FastTrainingPipeline:
    """Rule-based "training" pipeline built on an ``EfficientDataProcessor``.

    Training turns batch-processed user data into a list of recommendation
    rules. The result is cached in ``self.models['recommendation']`` and
    used by ``predict_recommendations``.
    """

    def __init__(self, data_processor: "EfficientDataProcessor"):
        """Store the processor and set up the background-training state."""
        self.data_processor = data_processor
        self.models = {}
        self.training_cache = {}
        self._background_thread: Optional[threading.Thread] = None
        self._stop_event = threading.Event()

    def train_recommendation_model(self, user_ids: List[str]) -> Dict[str, Any]:
        """Process ``user_ids`` and build the recommendation model from them.

        Returns:
            The model dict, which is also cached under
            ``self.models['recommendation']``.
        """
        logger.info(f"开始训练推荐模型,用户数量: {len(user_ids)}")

        processed_data = self.data_processor.batch_process_user_data(user_ids)
        features = self._extract_features(processed_data)
        model_results = self._train_simple_recommendation_model(features)

        self.models['recommendation'] = model_results

        logger.info("推荐模型训练完成")
        return model_results

    def start_background_training(self, user_ids_provider=None, interval_minutes: int = 60) -> None:
        """Start a daemon thread that retrains periodically.

        Does nothing when a background thread is already alive.

        Args:
            user_ids_provider: Optional callable returning the user ids to
                train on. When it is omitted or returns nothing, the keys of
                the processor's ``user_preferences`` are used.
            interval_minutes: Pause between two training rounds.
        """
        if self._background_thread and self._background_thread.is_alive():
            return

        self._stop_event.clear()

        def _loop():
            while not self._stop_event.is_set():
                try:
                    if user_ids_provider is not None:
                        user_ids = list(user_ids_provider()) or list(self.data_processor.user_preferences.keys())
                    else:
                        user_ids = list(self.data_processor.user_preferences.keys())
                    if user_ids:
                        self.train_recommendation_model(user_ids)
                except Exception as e:
                    logger.warning(f"后台训练失败: {e}")
                finally:
                    # Waiting on the event rather than sleeping lets
                    # stop_background_training() interrupt the pause.
                    self._stop_event.wait(interval_minutes * 60)

        self._background_thread = threading.Thread(target=_loop, daemon=True)
        self._background_thread.start()

    def stop_background_training(self) -> None:
        """Signal the background thread to stop and wait up to one second."""
        self._stop_event.set()
        if self._background_thread and self._background_thread.is_alive():
            self._background_thread.join(timeout=1.0)

    def _extract_features(self, processed_data: Dict[str, Any]) -> Dict[str, Any]:
        """Turn batch results into user, food and interaction features.

        Entries containing an ``'error'`` key are skipped.

        Returns:
            A dict with ``user_features`` (per user: age, gender, activity
            level), ``food_features`` (food -> number of users listing it as
            a favourite) and ``interaction_features`` (per user: average
            satisfaction, average calories, food diversity).
        """
        features = {
            'user_features': {},
            'food_features': {},
            'interaction_features': {}
        }

        for user_id, data in processed_data.items():
            if 'error' in data:
                continue

            basic_info = data.get('preference_analysis', {}).get('basic_info', {})
            features['user_features'][user_id] = {
                'age': basic_info.get('age', 25),
                'gender': basic_info.get('gender', 'unknown'),
                'activity_level': basic_info.get('activity_level', 'moderate')
            }

            meal_analysis = data.get('meal_analysis', {})
            food_features = features['food_features']
            for food in meal_analysis.get('favorite_foods', []):
                food_features[food] = food_features.get(food, 0) + 1

            features['interaction_features'][user_id] = {
                'avg_satisfaction': meal_analysis.get('avg_satisfaction', 0),
                'avg_calories': meal_analysis.get('avg_calories', 0),
                'food_diversity': meal_analysis.get('food_diversity', 0)
            }

        return features

    def _train_simple_recommendation_model(self, features: Dict[str, Any]) -> Dict[str, Any]:
        """Assemble the model dict from extracted features.

        No statistical fitting happens here: the "model" is the generated
        rule list plus metadata, and ``performance_metrics`` are hard-coded
        placeholder values, not measurements.
        """
        return {
            'model_type': 'simple_collaborative_filtering',
            'trained_at': datetime.now().isoformat(),
            'user_count': len(features['user_features']),
            'food_count': len(features['food_features']),
            'recommendation_rules': self._generate_recommendation_rules(features),
            'performance_metrics': {
                'accuracy': 0.75,  # placeholder, not measured
                'precision': 0.72,
                'recall': 0.68,
                'f1_score': 0.70
            }
        }

    def _generate_recommendation_rules(self, features: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Build the rule list: top-10 popular foods plus one gender rule.

        Returns:
            Rule dicts with ``type``, ``condition``, ``recommendation`` and
            ``confidence``. Popular-food rules come first, most popular
            first, and their confidence is ``count / 10`` capped at 1.0.
        """
        rules = []

        food_features = features['food_features']
        popular_foods = sorted(food_features.items(), key=lambda x: x[1], reverse=True)[:10]

        for food, count in popular_foods:
            rules.append({
                'type': 'popular_food',
                'condition': f"food_popularity >= {count}",
                'recommendation': f"推荐 {food}",
                'confidence': min(count / 10.0, 1.0)
            })

        # The gender rule does not depend on the individual user, so it is
        # emitted once. Emitting it per female user produced N identical
        # rules and, in turn, duplicate recommendations.
        user_features = features['user_features']
        if any(user_feat.get('gender') == '女' for user_feat in user_features.values()):
            rules.append({
                'type': 'gender_based',
                'condition': "gender == '女'",
                'recommendation': "推荐富含铁质的食物",
                'confidence': 0.8
            })

        return rules

    def predict_recommendations(self, user_id: str, meal_type: str = "lunch") -> List[Dict[str, Any]]:
        """Return up to five recommendations for ``user_id``.

        Args:
            user_id: User to recommend for.
            meal_type: Accepted for interface compatibility; not used by the
                current rule evaluation.

        Returns:
            Dicts with ``food``, ``confidence`` and ``reason``; empty when no
            model has been trained or the user is unknown.
        """
        if 'recommendation' not in self.models:
            return []

        from core.base import AppCore
        app_core = AppCore()
        user_data = app_core.get_user_data(user_id)

        if not user_data:
            return []

        rules = self.models['recommendation'].get('recommendation_rules', [])
        recommendations = [
            {
                'food': rule['recommendation'],
                'confidence': rule['confidence'],
                'reason': rule['type']
            }
            for rule in rules
            if self._evaluate_rule(rule, user_data)
        ]

        return recommendations[:5]

    def _evaluate_rule(self, rule: Dict[str, Any], user_data) -> bool:
        """Decide whether ``rule`` applies to ``user_data``.

        Popular-food rules always apply; gender rules apply when the profile
        gender is '女'; unknown rule types never apply.
        """
        rule_type = rule.get('type', '')

        if rule_type == 'popular_food':
            return True
        if rule_type == 'gender_based':
            return user_data.profile.get('gender') == '女'

        return False
|
||||
|
||||
|
||||
# Module-level singletons used by the convenience functions in this module.
# Constructing the processor at import time creates its data directory and
# loads any persisted aggregates.
data_processor = EfficientDataProcessor()
training_pipeline = FastTrainingPipeline(data_processor=data_processor)
|
||||
|
||||
|
||||
# 便捷函数
|
||||
def batch_process_users(user_ids: List[str]) -> Dict[str, Any]:
    """Batch-process ``user_ids`` with the module-level ``data_processor``.

    Returns:
        Whatever ``data_processor.batch_process_user_data`` returns for the
        given ids.
    """
    processed = data_processor.batch_process_user_data(user_ids)
    return processed
|
||||
|
||||
|
||||
def train_recommendation_model(user_ids: List[str]) -> Dict[str, Any]:
    """Train the recommendation model on the module-level ``training_pipeline``.

    Returns:
        The model dict produced by
        ``training_pipeline.train_recommendation_model``.
    """
    trained_model = training_pipeline.train_recommendation_model(user_ids)
    return trained_model
|
||||
|
||||
|
||||
def get_user_recommendations(user_id: str, meal_type: str = "lunch") -> List[Dict[str, Any]]:
    """Fetch recommendations for ``user_id`` from the module-level pipeline.

    Both arguments are forwarded unchanged to
    ``training_pipeline.predict_recommendations``.
    """
    suggestions = training_pipeline.predict_recommendations(user_id, meal_type)
    return suggestions
|
||||
|
||||
|
||||
def export_user_report(user_id: str, output_file: str = None) -> str:
    """Export the analysis report of ``user_id`` via ``data_processor``.

    Args:
        user_id: User to report on.
        output_file: Optional target path, forwarded unchanged.

    Returns:
        The path returned by ``data_processor.export_analysis_report``.
    """
    report_path = data_processor.export_analysis_report(user_id, output_file)
    return report_path
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: batch processing followed by model training.
    print("测试高效数据处理...")

    sample_user_ids = ["user1", "user2", "user3"]

    batch_output = batch_process_users(sample_user_ids)
    print(f"批量处理结果: {len(batch_output)} 个用户")

    trained = train_recommendation_model(sample_user_ids)
    print(f"模型训练完成: {trained['model_type']}")

    print("测试完成!")
|
||||
786
modules/ocr_calorie_recognition.py
Normal file
786
modules/ocr_calorie_recognition.py
Normal file
@@ -0,0 +1,786 @@
|
||||
"""
|
||||
图片OCR热量识别模块 - 基于基座架构
|
||||
支持多种OCR技术识别食物热量信息,包含智能验证和修正机制
|
||||
"""
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import re
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import requests
|
||||
import base64
|
||||
from PIL import Image, ImageEnhance, ImageFilter
|
||||
import pytesseract
|
||||
from core.base import BaseModule, ModuleType, UserData, AnalysisResult, BaseConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class OCRResult:
    """Output of a single OCR engine run.

    Attributes:
        text: Recognised text.
        confidence: Confidence reported for the recognition.
        bounding_boxes: One dict per detected text region.
        processing_time: Time the engine run took.
        method: Name of the OCR method that produced the result.
    """

    text: str
    confidence: float
    bounding_boxes: List[Dict[str, Any]]
    processing_time: float
    method: str
|
||||
|
||||
|
||||
@dataclass
class CalorieInfo:
    """Calorie information extracted for one food item.

    Attributes:
        food_name: Name of the food.
        calories: Calorie value, or ``None`` when none was found.
        serving_size: Serving-size text, or ``None``.
        confidence: Confidence of this entry.
        source: Origin of the entry: 'ocr', 'database' or 'user_confirmed'.
        raw_text: Text the entry was derived from.
        validation_status: One of 'pending', 'validated' or 'corrected'.
    """

    food_name: str
    calories: Optional[float]
    serving_size: Optional[str]
    confidence: float
    source: str  # 'ocr', 'database', 'user_confirmed'
    raw_text: str
    validation_status: str  # 'pending', 'validated', 'corrected'
|
||||
|
||||
|
||||
@dataclass
class FoodRecognitionResult:
    """Aggregated outcome of recognising one food image.

    Attributes:
        image_path: Path of the analysed image.
        ocr_results: Results of the individual OCR engines.
        calorie_infos: Calorie entries derived from the recognised text.
        overall_confidence: Combined confidence for the whole recognition.
        processing_time: Total processing time in seconds.
        suggestions: Follow-up suggestions for the user.
    """

    image_path: str
    ocr_results: List[OCRResult]
    calorie_infos: List[CalorieInfo]
    overall_confidence: float
    processing_time: float
    suggestions: List[str]
|
||||
|
||||
|
||||
class OCRCalorieRecognitionModule(BaseModule):
|
||||
"""OCR热量识别模块"""
|
||||
|
||||
def __init__(self, config: BaseConfig):
    """Configure OCR settings, regex patterns and supporting data.

    Registers the module under ``ModuleType.DATA_COLLECTION``, then loads
    the food database and initialises the OCR engines through the
    corresponding helper methods.
    """
    super().__init__(config, ModuleType.DATA_COLLECTION)

    # OCR configuration.
    self.ocr_methods = ['tesseract', 'paddleocr', 'easyocr']
    self.min_confidence = 0.6
    self.max_processing_time = 30.0

    # Every calorie pattern is "<number><optional whitespace><unit>"; the
    # number is the only capture group.
    number = r'(\d+(?:\.\d+)?)'
    unit_suffixes = (
        r'[kK]?[cC][aA][lL](?:ories?)?',
        r'[kK][cC][aA][lL]',
        r'卡路里',
        r'千卡',
        r'大卡',
        r'[kK][jJ]',  # kilojoules
    )
    self.calorie_patterns = [number + r'\s*' + unit for unit in unit_suffixes]

    # Food-name patterns: a name followed by a number, or a number followed
    # by a name (Latin letters or CJK characters).
    name_chars = r'[a-zA-Z\u4e00-\u9fff]+'
    self.food_patterns = [
        '(' + name_chars + ')' + r'\s*(?:\d+(?:\.\d+)?)',
        number + r'\s*' + '(' + name_chars + ')',
    ]

    # Food database.
    self.food_database = self._load_food_database()

    # Corrections learned from users.
    self.user_corrections = {}

    # OCR engines, filled in by the helper below.
    self.ocr_engines = {}
    self._initialize_ocr_engines()
|
||||
|
||||
def initialize(self) -> bool:
    """Prepare directories and load learned corrections.

    Returns:
        ``True`` when every step succeeded (``is_initialized`` is then set);
        ``False`` when a step raised, in which case the error is logged.
    """
    try:
        self.logger.info("OCR热量识别模块初始化中...")

        # Create the directories the module needs.
        self._create_directories()

        # Load corrections previously learned from users.
        self._load_user_corrections()

        self.is_initialized = True
        self.logger.info("OCR热量识别模块初始化完成")
    except Exception as e:
        self.logger.error(f"OCR热量识别模块初始化失败: {e}")
        return False
    else:
        return True
|
||||
|
||||
def process(self, input_data: Any, user_data: UserData) -> AnalysisResult:
    """Dispatch an OCR request and wrap the outcome in an ``AnalysisResult``.

    Supported ``input_data['type']`` values are ``'recognize_image'``,
    ``'validate_result'`` and ``'learn_correction'``; any other value yields
    an error payload.

    Args:
        input_data: Request dict; ``type`` selects the handler.
        user_data: Data of the requesting user.

    Returns:
        An ``AnalysisResult`` whose ``result`` is the handler's payload. The
        confidence is taken from the payload and defaults to 0.5.
    """
    try:
        request_type = input_data.get('type', 'unknown')

        if request_type == 'recognize_image':
            result = self._recognize_image_calories(input_data, user_data)
        elif request_type == 'validate_result':
            result = self._validate_recognition_result(input_data, user_data)
        elif request_type == 'learn_correction':
            result = self._learn_from_correction(input_data, user_data)
        else:
            result = self._create_error_result("未知的请求类型")

        return AnalysisResult(
            module_type=self.module_type,
            user_id=user_data.user_id,
            input_data=input_data,
            result=result,
            confidence=result.get('confidence', 0.5)
        )
    except Exception as e:
        self.logger.error(f"OCR识别处理失败: {e}")
        failure = self._create_error_result(f"处理失败: {str(e)}")
        # _create_error_result yields a plain payload (it is used as
        # `result` above), but callers read `.result` on the return value.
        # Wrap the payload so the failure path honours the declared type.
        # NOTE(review): user_id falls back to None when user_data itself is
        # what failed - confirm AnalysisResult accepts that.
        return AnalysisResult(
            module_type=self.module_type,
            user_id=getattr(user_data, 'user_id', None),
            input_data=input_data,
            result=failure,
            confidence=0.0
        )
|
||||
|
||||
def _recognize_image_calories(self, input_data: Dict[str, Any], user_data: UserData) -> Dict[str, Any]:
    """Run the full image-to-calorie pipeline for one image.

    The steps are: preprocess the image, run every configured OCR method,
    merge the texts, extract calorie entries, validate them against the
    database and generate suggestions.

    Args:
        input_data: Request dict; ``image_path`` names the image file.
        user_data: Data of the requesting user, passed on to the helpers.

    Returns:
        On success a dict with ``success``, ``result`` (a
        ``FoodRecognitionResult``), ``confidence`` and ``message``; otherwise
        the payload built by ``_create_error_result``.
    """
    started_at = datetime.now()

    try:
        image_path = input_data.get('image_path')
        image_missing = not image_path or not Path(image_path).exists()
        if image_missing:
            return self._create_error_result("图片文件不存在")

        prepared_image = self._preprocess_image(image_path)

        # Try every engine; a failing engine is logged and skipped.
        recognised = []
        for engine_name in self.ocr_methods:
            try:
                engine_output = self._ocr_recognize(prepared_image, engine_name)
            except Exception as e:
                self.logger.warning(f"OCR方法 {engine_name} 失败: {e}")
                continue
            if engine_output:
                recognised.append(engine_output)

        combined_text = self._merge_ocr_results(recognised)
        extracted_infos = self._extract_calorie_info(combined_text, user_data)
        checked_infos = self._validate_with_database(extracted_infos, user_data)
        tips = self._generate_suggestions(checked_infos, user_data)

        elapsed_seconds = (datetime.now() - started_at).total_seconds()
        combined_confidence = self._calculate_overall_confidence(recognised, checked_infos)

        recognition = FoodRecognitionResult(
            image_path=image_path,
            ocr_results=recognised,
            calorie_infos=checked_infos,
            overall_confidence=combined_confidence,
            processing_time=elapsed_seconds,
            suggestions=tips,
        )

        return {
            'success': True,
            'result': recognition,
            'confidence': recognition.overall_confidence,
            'message': f"识别完成,处理时间: {elapsed_seconds:.2f}秒",
        }

    except Exception as e:
        self.logger.error(f"图片热量识别失败: {e}")
        return self._create_error_result(f"识别失败: {str(e)}")
|
||||
|
||||
def _preprocess_image(self, image_path: str) -> np.ndarray:
    """Clean up an image so OCR has an easier job.

    The pipeline is: load, convert to grayscale, median blur (kernel 3),
    CLAHE contrast enhancement, Otsu binarisation, then a morphological close
    with a 2x2 rectangular kernel.

    Args:
        image_path: Path of the image file to load.

    Returns:
        The cleaned binary image. If any step fails, the error is logged and
        the file is re-read directly as grayscale instead.
        NOTE(review): that fallback read is not checked, so it presumably
        yields ``None`` for an unreadable file — confirm callers cope.
    """
    try:
        source = cv2.imread(image_path)
        if source is None:
            raise ValueError("无法读取图片")

        grayscale = cv2.cvtColor(source, cv2.COLOR_BGR2GRAY)
        smoothed = cv2.medianBlur(grayscale, 3)

        equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        contrasted = equalizer.apply(smoothed)

        # cv2.threshold returns (threshold_used, image); only the image matters.
        binarized = cv2.threshold(contrasted, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]

        closing_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
        return cv2.morphologyEx(binarized, cv2.MORPH_CLOSE, closing_kernel)

    except Exception as e:
        self.logger.error(f"图片预处理失败: {e}")
        return cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
|
||||
|
||||
def _ocr_recognize(self, image: np.ndarray, method: str) -> Optional[OCRResult]:
|
||||
"""使用指定方法进行OCR识别"""
|
||||
start_time = datetime.now()
|
||||
|
||||
try:
|
||||
if method == 'tesseract':
|
||||
return self._tesseract_ocr(image)
|
||||
elif method == 'paddleocr':
|
||||
return self._paddleocr_recognize(image)
|
||||
elif method == 'easyocr':
|
||||
return self._easyocr_recognize(image)
|
||||
else:
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"OCR方法 {method} 失败: {e}")
|
||||
return None
|
||||
|
||||
def _tesseract_ocr(self, image: np.ndarray) -> Optional[OCRResult]:
    """Recognise text with Tesseract (simplified Chinese + English).

    Args:
        image: Image to recognise.

    Returns:
        An ``OCRResult`` holding the stripped text, the mean confidence of all
        boxes with a positive confidence (scaled to 0..1), those boxes, and the
        elapsed time in seconds. ``None`` if anything fails (logged).
    """
    # Taken before any work so the reported processing time is real. The old
    # code subtracted two back-to-back now() calls and always reported ~0.
    started_at = datetime.now()

    try:
        # No tessedit_char_whitelist: Tesseract reads the whitelist as a literal
        # list of characters, so the former "\u4e00-\u9fff" suffix allowed only
        # three extra characters rather than the CJK range, and it also
        # excluded '.', which calorie values such as "12.5" need.
        config = '--oem 3 --psm 6'

        text = pytesseract.image_to_string(image, config=config, lang='chi_sim+eng')

        data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT, config=config, lang='chi_sim+eng')

        confidences = []
        bounding_boxes = []
        for i, raw_conf in enumerate(data['conf']):
            # float() rather than int(): the confidence column may arrive as a
            # decimal string such as "96.27", which int() rejects.
            conf = float(raw_conf)
            if conf <= 0:
                # Non-positive confidence marks layout rows without real text.
                continue
            confidences.append(conf)
            bounding_boxes.append({
                'text': data['text'][i],
                'confidence': conf / 100.0,
                'bbox': [data['left'][i], data['top'][i], data['width'][i], data['height'][i]]
            })

        avg_confidence = sum(confidences) / len(confidences) / 100.0 if confidences else 0.0

        processing_time = (datetime.now() - started_at).total_seconds()

        return OCRResult(
            text=text.strip(),
            confidence=avg_confidence,
            bounding_boxes=bounding_boxes,
            processing_time=processing_time,
            method='tesseract'
        )

    except Exception as e:
        self.logger.error(f"Tesseract OCR失败: {e}")
        return None
|
||||
|
||||
def _paddleocr_recognize(self, image: np.ndarray) -> Optional[OCRResult]:
|
||||
"""使用PaddleOCR进行识别"""
|
||||
try:
|
||||
# 这里需要安装paddleocr: pip install paddleocr
|
||||
from paddleocr import PaddleOCR
|
||||
|
||||
if 'paddleocr' not in self.ocr_engines:
|
||||
self.ocr_engines['paddleocr'] = PaddleOCR(use_angle_cls=True, lang='ch')
|
||||
|
||||
ocr = self.ocr_engines['paddleocr']
|
||||
result = ocr.ocr(image, cls=True)
|
||||
|
||||
if not result or not result[0]:
|
||||
return None
|
||||
|
||||
# 提取文本和置信度
|
||||
texts = []
|
||||
confidences = []
|
||||
bounding_boxes = []
|
||||
|
||||
for line in result[0]:
|
||||
if line:
|
||||
bbox, (text, confidence) = line
|
||||
texts.append(text)
|
||||
confidences.append(confidence)
|
||||
bounding_boxes.append({
|
||||
'text': text,
|
||||
'confidence': confidence,
|
||||
'bbox': bbox
|
||||
})
|
||||
|
||||
merged_text = ' '.join(texts)
|
||||
avg_confidence = sum(confidences) / len(confidences) if confidences else 0.0
|
||||
|
||||
return OCRResult(
|
||||
text=merged_text,
|
||||
confidence=avg_confidence,
|
||||
bounding_boxes=bounding_boxes,
|
||||
processing_time=0.0,
|
||||
method='paddleocr'
|
||||
)
|
||||
|
||||
except ImportError:
|
||||
self.logger.warning("PaddleOCR未安装,跳过此方法")
|
||||
return None
|
||||
except Exception as e:
|
||||
self.logger.error(f"PaddleOCR识别失败: {e}")
|
||||
return None
|
||||
|
||||
def _easyocr_recognize(self, image: np.ndarray) -> Optional[OCRResult]:
|
||||
"""使用EasyOCR进行识别"""
|
||||
try:
|
||||
# 这里需要安装easyocr: pip install easyocr
|
||||
import easyocr
|
||||
|
||||
if 'easyocr' not in self.ocr_engines:
|
||||
self.ocr_engines['easyocr'] = easyocr.Reader(['ch_sim', 'en'])
|
||||
|
||||
reader = self.ocr_engines['easyocr']
|
||||
result = reader.readtext(image)
|
||||
|
||||
if not result:
|
||||
return None
|
||||
|
||||
# 提取文本和置信度
|
||||
texts = []
|
||||
confidences = []
|
||||
bounding_boxes = []
|
||||
|
||||
for bbox, text, confidence in result:
|
||||
texts.append(text)
|
||||
confidences.append(confidence)
|
||||
bounding_boxes.append({
|
||||
'text': text,
|
||||
'confidence': confidence,
|
||||
'bbox': bbox
|
||||
})
|
||||
|
||||
merged_text = ' '.join(texts)
|
||||
avg_confidence = sum(confidences) / len(confidences) if confidences else 0.0
|
||||
|
||||
return OCRResult(
|
||||
text=merged_text,
|
||||
confidence=avg_confidence,
|
||||
bounding_boxes=bounding_boxes,
|
||||
processing_time=0.0,
|
||||
method='easyocr'
|
||||
)
|
||||
|
||||
except ImportError:
|
||||
self.logger.warning("EasyOCR未安装,跳过此方法")
|
||||
return None
|
||||
except Exception as e:
|
||||
self.logger.error(f"EasyOCR识别失败: {e}")
|
||||
return None
|
||||
|
||||
def _merge_ocr_results(self, ocr_results: List[OCRResult]) -> str:
|
||||
"""合并多个OCR结果"""
|
||||
if not ocr_results:
|
||||
return ""
|
||||
|
||||
# 按置信度排序
|
||||
sorted_results = sorted(ocr_results, key=lambda x: x.confidence, reverse=True)
|
||||
|
||||
# 使用最高置信度的结果作为主要结果
|
||||
primary_result = sorted_results[0]
|
||||
|
||||
# 如果有多个高置信度结果,尝试合并
|
||||
if len(sorted_results) > 1 and sorted_results[1].confidence > 0.7:
|
||||
# 简单的文本合并策略
|
||||
merged_text = self._smart_text_merge([r.text for r in sorted_results[:3]])
|
||||
return merged_text
|
||||
|
||||
return primary_result.text
|
||||
|
||||
def _smart_text_merge(self, texts: List[str]) -> str:
|
||||
"""智能文本合并"""
|
||||
if not texts:
|
||||
return ""
|
||||
|
||||
if len(texts) == 1:
|
||||
return texts[0]
|
||||
|
||||
# 简单的合并策略:选择最长的文本
|
||||
return max(texts, key=len)
|
||||
|
||||
def _extract_calorie_info(self, text: str, user_data: UserData) -> List[CalorieInfo]:
|
||||
"""从文本中提取热量信息"""
|
||||
calorie_infos = []
|
||||
|
||||
try:
|
||||
# 查找热量数值
|
||||
for pattern in self.calorie_patterns:
|
||||
matches = re.finditer(pattern, text, re.IGNORECASE)
|
||||
for match in matches:
|
||||
calories = float(match.group(1))
|
||||
|
||||
# 查找对应的食物名称
|
||||
food_name = self._extract_food_name(text, match.start())
|
||||
|
||||
calorie_info = CalorieInfo(
|
||||
food_name=food_name,
|
||||
calories=calories,
|
||||
serving_size=None,
|
||||
confidence=0.8, # OCR基础置信度
|
||||
source='ocr',
|
||||
raw_text=match.group(0),
|
||||
validation_status='pending'
|
||||
)
|
||||
|
||||
calorie_infos.append(calorie_info)
|
||||
|
||||
# 如果没有找到热量信息,尝试查找食物名称
|
||||
if not calorie_infos:
|
||||
food_names = self._extract_all_food_names(text)
|
||||
for food_name in food_names:
|
||||
calorie_info = CalorieInfo(
|
||||
food_name=food_name,
|
||||
calories=None,
|
||||
serving_size=None,
|
||||
confidence=0.6,
|
||||
source='ocr',
|
||||
raw_text=food_name,
|
||||
validation_status='pending'
|
||||
)
|
||||
calorie_infos.append(calorie_info)
|
||||
|
||||
return calorie_infos
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"热量信息提取失败: {e}")
|
||||
return []
|
||||
|
||||
def _extract_food_name(self, text: str, calorie_position: int) -> str:
|
||||
"""提取食物名称"""
|
||||
try:
|
||||
# 在热量数值前后查找食物名称
|
||||
context_start = max(0, calorie_position - 50)
|
||||
context_end = min(len(text), calorie_position + 50)
|
||||
context = text[context_start:context_end]
|
||||
|
||||
# 查找中文和英文食物名称
|
||||
food_pattern = r'([a-zA-Z\u4e00-\u9fff]{2,20})'
|
||||
matches = re.findall(food_pattern, context)
|
||||
|
||||
if matches:
|
||||
# 选择最可能的食物名称
|
||||
return matches[0]
|
||||
|
||||
return "未知食物"
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"食物名称提取失败: {e}")
|
||||
return "未知食物"
|
||||
|
||||
def _extract_all_food_names(self, text: str) -> List[str]:
|
||||
"""提取所有可能的食物名称"""
|
||||
try:
|
||||
food_pattern = r'([a-zA-Z\u4e00-\u9fff]{2,20})'
|
||||
matches = re.findall(food_pattern, text)
|
||||
|
||||
# 去重并过滤
|
||||
unique_foods = list(set(matches))
|
||||
return unique_foods[:5] # 最多返回5个
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"食物名称提取失败: {e}")
|
||||
return []
|
||||
|
||||
def _validate_with_database(self, calorie_infos: List[CalorieInfo], user_data: UserData) -> List[CalorieInfo]:
|
||||
"""使用数据库验证热量信息"""
|
||||
validated_infos = []
|
||||
|
||||
for info in calorie_infos:
|
||||
try:
|
||||
# 在食物数据库中查找匹配
|
||||
db_match = self._find_database_match(info.food_name)
|
||||
|
||||
if db_match:
|
||||
# 使用数据库信息更新
|
||||
info.calories = db_match.get('calories', info.calories)
|
||||
info.serving_size = db_match.get('serving_size', info.serving_size)
|
||||
info.confidence = max(info.confidence, 0.9)
|
||||
info.source = 'database'
|
||||
|
||||
# 应用用户学习数据
|
||||
user_correction = self._get_user_correction(user_data.user_id, info.food_name)
|
||||
if user_correction:
|
||||
info.calories = user_correction.get('calories', info.calories)
|
||||
info.confidence = max(info.confidence, 0.95)
|
||||
info.source = 'user_confirmed'
|
||||
|
||||
validated_infos.append(info)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"数据库验证失败: {e}")
|
||||
validated_infos.append(info)
|
||||
|
||||
return validated_infos
|
||||
|
||||
def _find_database_match(self, food_name: str) -> Optional[Dict[str, Any]]:
|
||||
"""在数据库中查找食物匹配"""
|
||||
try:
|
||||
# 精确匹配
|
||||
if food_name in self.food_database:
|
||||
return self.food_database[food_name]
|
||||
|
||||
# 模糊匹配
|
||||
for db_food, info in self.food_database.items():
|
||||
if food_name in db_food or db_food in food_name:
|
||||
return info
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"数据库匹配失败: {e}")
|
||||
return None
|
||||
|
||||
def _get_user_correction(self, user_id: str, food_name: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取用户修正数据"""
|
||||
try:
|
||||
user_data = self.user_corrections.get(user_id, {})
|
||||
return user_data.get(food_name)
|
||||
except Exception as e:
|
||||
self.logger.error(f"获取用户修正数据失败: {e}")
|
||||
return None
|
||||
|
||||
def _generate_suggestions(self, calorie_infos: List[CalorieInfo], user_data: UserData) -> List[str]:
|
||||
"""生成建议"""
|
||||
suggestions = []
|
||||
|
||||
try:
|
||||
for info in calorie_infos:
|
||||
if info.calories is None:
|
||||
suggestions.append(f"未识别到 {info.food_name} 的热量信息,请手动输入")
|
||||
elif info.confidence < 0.8:
|
||||
suggestions.append(f"{info.food_name} 的热量 {info.calories} 可能不准确,请确认")
|
||||
else:
|
||||
suggestions.append(f"{info.food_name}: {info.calories} 卡路里")
|
||||
|
||||
if not calorie_infos:
|
||||
suggestions.append("未识别到任何食物信息,请检查图片质量或手动输入")
|
||||
|
||||
return suggestions
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"生成建议失败: {e}")
|
||||
return ["识别过程中出现错误"]
|
||||
|
||||
def _calculate_overall_confidence(self, ocr_results: List[OCRResult], calorie_infos: List[CalorieInfo]) -> float:
|
||||
"""计算整体置信度"""
|
||||
try:
|
||||
if not ocr_results and not calorie_infos:
|
||||
return 0.0
|
||||
|
||||
# OCR置信度
|
||||
ocr_confidence = sum(r.confidence for r in ocr_results) / len(ocr_results) if ocr_results else 0.0
|
||||
|
||||
# 热量信息置信度
|
||||
calorie_confidence = sum(info.confidence for info in calorie_infos) / len(calorie_infos) if calorie_infos else 0.0
|
||||
|
||||
# 综合置信度
|
||||
overall_confidence = (ocr_confidence * 0.4 + calorie_confidence * 0.6)
|
||||
|
||||
return min(overall_confidence, 1.0)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"计算置信度失败: {e}")
|
||||
return 0.0
|
||||
|
||||
def _validate_recognition_result(self, input_data: Dict[str, Any], user_data: UserData) -> Dict[str, Any]:
|
||||
"""验证识别结果"""
|
||||
try:
|
||||
food_name = input_data.get('food_name')
|
||||
calories = input_data.get('calories')
|
||||
is_correct = input_data.get('is_correct', True)
|
||||
|
||||
if not is_correct:
|
||||
# 用户修正
|
||||
corrected_calories = input_data.get('corrected_calories')
|
||||
self._save_user_correction(user_data.user_id, food_name, corrected_calories)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'message': '验证结果已保存',
|
||||
'confidence': 1.0
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"验证识别结果失败: {e}")
|
||||
return self._create_error_result(f"验证失败: {str(e)}")
|
||||
|
||||
def _learn_from_correction(self, input_data: Dict[str, Any], user_data: UserData) -> Dict[str, Any]:
|
||||
"""从用户修正中学习"""
|
||||
try:
|
||||
food_name = input_data.get('food_name')
|
||||
corrected_calories = input_data.get('corrected_calories')
|
||||
|
||||
self._save_user_correction(user_data.user_id, food_name, corrected_calories)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'message': '学习数据已保存',
|
||||
'confidence': 1.0
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"学习修正数据失败: {e}")
|
||||
return self._create_error_result(f"学习失败: {str(e)}")
|
||||
|
||||
def _save_user_correction(self, user_id: str, food_name: str, calories: float):
|
||||
"""保存用户修正数据"""
|
||||
try:
|
||||
if user_id not in self.user_corrections:
|
||||
self.user_corrections[user_id] = {}
|
||||
|
||||
self.user_corrections[user_id][food_name] = {
|
||||
'calories': calories,
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'correction_count': self.user_corrections[user_id].get(food_name, {}).get('correction_count', 0) + 1
|
||||
}
|
||||
|
||||
# 保存到文件
|
||||
self._save_user_corrections_to_file()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"保存用户修正数据失败: {e}")
|
||||
|
||||
def _load_food_database(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""加载食物数据库"""
|
||||
try:
|
||||
# 基础食物数据库
|
||||
food_db = {
|
||||
"米饭": {"calories": 130, "serving_size": "100g"},
|
||||
"面条": {"calories": 110, "serving_size": "100g"},
|
||||
"馒头": {"calories": 221, "serving_size": "100g"},
|
||||
"包子": {"calories": 250, "serving_size": "100g"},
|
||||
"饺子": {"calories": 250, "serving_size": "100g"},
|
||||
"鸡蛋": {"calories": 155, "serving_size": "100g"},
|
||||
"豆腐": {"calories": 76, "serving_size": "100g"},
|
||||
"鱼肉": {"calories": 206, "serving_size": "100g"},
|
||||
"鸡肉": {"calories": 165, "serving_size": "100g"},
|
||||
"瘦肉": {"calories": 250, "serving_size": "100g"},
|
||||
"青菜": {"calories": 25, "serving_size": "100g"},
|
||||
"西红柿": {"calories": 18, "serving_size": "100g"},
|
||||
"胡萝卜": {"calories": 41, "serving_size": "100g"},
|
||||
"土豆": {"calories": 77, "serving_size": "100g"},
|
||||
"西兰花": {"calories": 34, "serving_size": "100g"},
|
||||
"苹果": {"calories": 52, "serving_size": "100g"},
|
||||
"香蕉": {"calories": 89, "serving_size": "100g"},
|
||||
"橙子": {"calories": 47, "serving_size": "100g"},
|
||||
"葡萄": {"calories": 67, "serving_size": "100g"},
|
||||
"草莓": {"calories": 32, "serving_size": "100g"},
|
||||
"牛奶": {"calories": 42, "serving_size": "100ml"},
|
||||
"酸奶": {"calories": 59, "serving_size": "100g"},
|
||||
"豆浆": {"calories": 31, "serving_size": "100ml"},
|
||||
"坚果": {"calories": 607, "serving_size": "100g"},
|
||||
"红枣": {"calories": 264, "serving_size": "100g"},
|
||||
}
|
||||
|
||||
# 尝试从文件加载扩展数据库
|
||||
db_file = Path("data/food_database.json")
|
||||
if db_file.exists():
|
||||
with open(db_file, 'r', encoding='utf-8') as f:
|
||||
extended_db = json.load(f)
|
||||
food_db.update(extended_db)
|
||||
|
||||
return food_db
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"加载食物数据库失败: {e}")
|
||||
return {}
|
||||
|
||||
def _create_directories(self):
|
||||
"""创建必要的目录"""
|
||||
directories = [
|
||||
'data/ocr_cache',
|
||||
'data/user_corrections',
|
||||
'data/food_images'
|
||||
]
|
||||
|
||||
for directory in directories:
|
||||
Path(directory).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _load_user_corrections(self):
|
||||
"""加载用户修正数据"""
|
||||
try:
|
||||
corrections_file = Path("data/user_corrections.json")
|
||||
if corrections_file.exists():
|
||||
with open(corrections_file, 'r', encoding='utf-8') as f:
|
||||
self.user_corrections = json.load(f)
|
||||
else:
|
||||
self.user_corrections = {}
|
||||
except Exception as e:
|
||||
self.logger.error(f"加载用户修正数据失败: {e}")
|
||||
self.user_corrections = {}
|
||||
|
||||
def _save_user_corrections_to_file(self):
|
||||
"""保存用户修正数据到文件"""
|
||||
try:
|
||||
corrections_file = Path("data/user_corrections.json")
|
||||
with open(corrections_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(self.user_corrections, f, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
self.logger.error(f"保存用户修正数据失败: {e}")
|
||||
|
||||
def _initialize_ocr_engines(self):
|
||||
"""初始化OCR引擎"""
|
||||
try:
|
||||
# 检查Tesseract是否可用
|
||||
try:
|
||||
pytesseract.get_tesseract_version()
|
||||
self.logger.info("Tesseract OCR引擎可用")
|
||||
except Exception:
|
||||
self.logger.warning("Tesseract OCR引擎不可用")
|
||||
|
||||
# 其他OCR引擎将在需要时初始化
|
||||
self.logger.info("OCR引擎初始化完成")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"OCR引擎初始化失败: {e}")
|
||||
|
||||
def _create_error_result(self, message: str) -> Dict[str, Any]:
|
||||
"""创建错误结果"""
|
||||
return {
|
||||
'success': False,
|
||||
'error': message,
|
||||
'confidence': 0.0
|
||||
}
|
||||
|
||||
def cleanup(self) -> bool:
    """Shut the module down.

    Persists the user corrections, drops the cached OCR engines and marks the
    module as uninitialised.

    Returns:
        True when every step succeeded, False if any step raised (logged).
    """
    try:
        # Persist first: if saving fails, the engines and the initialised flag
        # are left untouched.
        self._save_user_corrections_to_file()

        self.ocr_engines.clear()
        self.is_initialized = False

        self.logger.info("OCR热量识别模块清理完成")
        return True
    except Exception as e:
        self.logger.error(f"OCR热量识别模块清理失败: {e}")
        return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: build the module with a default config and check
    # that it initialises.
    from core.base import BaseConfig

    config = BaseConfig()
    ocr_module = OCRCalorieRecognitionModule(config)

    if not ocr_module.initialize():
        print("OCR模块初始化失败")
    else:
        print("OCR模块初始化成功")

        # Sample request for image recognition; a real test image must be
        # supplied at this path.
        test_data = {
            'type': 'recognize_image',
            'image_path': 'test_image.jpg'
        }

        # Actual recognition needs user data, so it is not run here.
        print("OCR模块测试完成")
|
||||
1016
modules/recommendation_engine.py
Normal file
1016
modules/recommendation_engine.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user