"""
Image OCR calorie-recognition module, built on the base-module architecture.

Supports several OCR engines for reading food calorie information from
images, with validation and user-correction mechanisms.
"""
|
||
|
||
import base64
import json
import logging
import re
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Any, Tuple

import cv2
import numpy as np
import pytesseract
import requests
from PIL import Image, ImageEnhance, ImageFilter

from core.base import BaseModule, ModuleType, UserData, AnalysisResult, BaseConfig
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
@dataclass
class OCRResult:
    """Outcome of one OCR engine run over a single image."""

    text: str  # recognised text
    confidence: float  # average confidence; the engines in this file store it on a 0-1 scale
    bounding_boxes: List[Dict[str, Any]]  # per-fragment dicts with 'text', 'confidence', 'bbox'
    processing_time: float  # seconds
    method: str  # engine name, e.g. 'tesseract'
|
||
|
||
|
||
@dataclass
class CalorieInfo:
    """Calorie information extracted for one food item."""

    food_name: str
    calories: Optional[float]  # None when no value could be read
    serving_size: Optional[str]
    confidence: float
    source: str  # 'ocr', 'database', 'user_confirmed'
    raw_text: str  # text fragment the entry was derived from
    validation_status: str  # 'pending', 'validated', 'corrected'
|
||
|
||
|
||
@dataclass
class FoodRecognitionResult:
    """Aggregated result of recognising one food image."""

    image_path: str
    ocr_results: List[OCRResult]  # one entry per OCR engine that produced output
    calorie_infos: List[CalorieInfo]
    overall_confidence: float
    processing_time: float  # seconds
    suggestions: List[str]  # user-facing hints
|
||
|
||
|
||
class OCRCalorieRecognitionModule(BaseModule):
|
||
"""OCR热量识别模块"""
|
||
|
||
def __init__(self, config: BaseConfig):
    """Configure recognition settings and load the food database.

    Args:
        config: Base configuration, forwarded to the parent constructor.
    """
    super().__init__(config, ModuleType.DATA_COLLECTION)

    # OCR settings.
    self.ocr_methods = ['tesseract', 'paddleocr', 'easyocr']
    self.min_confidence = 0.6
    self.max_processing_time = 30.0

    # Patterns capturing a number followed by an energy unit.
    self.calorie_patterns = [
        r'(\d+(?:\.\d+)?)\s*[kK]?[cC][aA][lL](?:ories?)?',
        r'(\d+(?:\.\d+)?)\s*[kK][cC][aA][lL]',
        r'(\d+(?:\.\d+)?)\s*卡路里',
        r'(\d+(?:\.\d+)?)\s*千卡',
        r'(\d+(?:\.\d+)?)\s*大卡',
        r'(\d+(?:\.\d+)?)\s*[kK][jJ]',  # kilojoules
    ]

    # Patterns pairing a name (Latin letters / CJK characters) with a number.
    self.food_patterns = [
        r'([a-zA-Z\u4e00-\u9fff]+)\s*(?:\d+(?:\.\d+)?)',
        r'(\d+(?:\.\d+)?)\s*([a-zA-Z\u4e00-\u9fff]+)',
    ]

    # Food database (name -> calories / serving size).
    self.food_database = self._load_food_database()

    # Per-user corrections learned so far.
    self.user_corrections = {}

    # OCR engine cache; engines are created on first use.
    self.ocr_engines = {}
    self._initialize_ocr_engines()
|
||
|
||
def initialize(self) -> bool:
    """Prepare the module for use.

    Creates the working directories and loads stored user corrections.

    Returns:
        True on success, False if any step raised (the error is logged).
    """
    try:
        self.logger.info("OCR热量识别模块初始化中...")

        # Working directories must exist before anything is read or written.
        self._create_directories()

        # Restore corrections persisted by earlier runs.
        self._load_user_corrections()

        self.is_initialized = True
        self.logger.info("OCR热量识别模块初始化完成")
        return True
    except Exception as e:
        self.logger.error(f"OCR热量识别模块初始化失败: {e}")
        return False
|
||
|
||
def process(self, input_data: Any, user_data: UserData) -> AnalysisResult:
    """Dispatch an OCR request and wrap the outcome in an ``AnalysisResult``.

    Args:
        input_data: Request mapping; its ``'type'`` key selects the action
            (``'recognize_image'``, ``'validate_result'`` or
            ``'learn_correction'``).
        user_data: The user the request belongs to.

    Returns:
        An ``AnalysisResult`` whose ``result`` is the handler's dict. Failures
        are also returned as an ``AnalysisResult`` (carrying an error dict)
        so that the declared return type holds on every path.
    """
    try:
        request_type = input_data.get('type', 'unknown')
        handlers = {
            'recognize_image': self._recognize_image_calories,
            'validate_result': self._validate_recognition_result,
            'learn_correction': self._learn_from_correction,
        }
        handler = handlers.get(request_type)
        if handler is None:
            result = self._create_error_result("未知的请求类型")
        else:
            result = handler(input_data, user_data)
    except Exception as e:
        self.logger.error(f"OCR识别处理失败: {e}")
        result = self._create_error_result(f"处理失败: {str(e)}")

    return AnalysisResult(
        module_type=self.module_type,
        # getattr keeps the error path usable even when user_data is missing.
        user_id=getattr(user_data, 'user_id', None),
        input_data=input_data,
        result=result,
        confidence=result.get('confidence', 0.5),
    )
|
||
|
||
def _recognize_image_calories(self, input_data: Dict[str, Any], user_data: UserData) -> Dict[str, Any]:
|
||
"""识别图片中的热量信息"""
|
||
start_time = datetime.now()
|
||
|
||
try:
|
||
image_path = input_data.get('image_path')
|
||
if not image_path or not Path(image_path).exists():
|
||
return self._create_error_result("图片文件不存在")
|
||
|
||
# 预处理图片
|
||
processed_image = self._preprocess_image(image_path)
|
||
|
||
# 多OCR引擎识别
|
||
ocr_results = []
|
||
for method in self.ocr_methods:
|
||
try:
|
||
result = self._ocr_recognize(processed_image, method)
|
||
if result:
|
||
ocr_results.append(result)
|
||
except Exception as e:
|
||
self.logger.warning(f"OCR方法 {method} 失败: {e}")
|
||
|
||
# 合并和去重OCR结果
|
||
merged_text = self._merge_ocr_results(ocr_results)
|
||
|
||
# 提取热量信息
|
||
calorie_infos = self._extract_calorie_info(merged_text, user_data)
|
||
|
||
# 数据库匹配和验证
|
||
validated_infos = self._validate_with_database(calorie_infos, user_data)
|
||
|
||
# 生成建议
|
||
suggestions = self._generate_suggestions(validated_infos, user_data)
|
||
|
||
processing_time = (datetime.now() - start_time).total_seconds()
|
||
|
||
result = FoodRecognitionResult(
|
||
image_path=image_path,
|
||
ocr_results=ocr_results,
|
||
calorie_infos=validated_infos,
|
||
overall_confidence=self._calculate_overall_confidence(ocr_results, validated_infos),
|
||
processing_time=processing_time,
|
||
suggestions=suggestions
|
||
)
|
||
|
||
return {
|
||
'success': True,
|
||
'result': result,
|
||
'confidence': result.overall_confidence,
|
||
'message': f"识别完成,处理时间: {processing_time:.2f}秒"
|
||
}
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"图片热量识别失败: {e}")
|
||
return self._create_error_result(f"识别失败: {str(e)}")
|
||
|
||
def _preprocess_image(self, image_path: str) -> Optional[np.ndarray]:
    """Clean up an image to improve OCR accuracy.

    Pipeline: grayscale -> median blur -> CLAHE contrast enhancement ->
    Otsu binarisation -> morphological closing.

    Args:
        image_path: Path of the image file.

    Returns:
        The processed single-channel image. If any step fails, the error is
        logged and a plain grayscale read of the file is returned instead,
        which is ``None`` when OpenCV cannot read the file at all.
    """
    try:
        image = cv2.imread(image_path)
        if image is None:
            raise ValueError("无法读取图片")

        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        # Denoise.
        denoised = cv2.medianBlur(gray, 3)

        # Local contrast enhancement.
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        enhanced = clahe.apply(denoised)

        # Binarise with an automatically chosen (Otsu) threshold.
        _, binary = cv2.threshold(enhanced, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

        # Morphological closing with a small rectangular kernel.
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
        cleaned = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)

        return cleaned

    except Exception as e:
        self.logger.error(f"图片预处理失败: {e}")
        return cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
|
||
|
||
def _ocr_recognize(self, image: np.ndarray, method: str) -> Optional[OCRResult]:
|
||
"""使用指定方法进行OCR识别"""
|
||
start_time = datetime.now()
|
||
|
||
try:
|
||
if method == 'tesseract':
|
||
return self._tesseract_ocr(image)
|
||
elif method == 'paddleocr':
|
||
return self._paddleocr_recognize(image)
|
||
elif method == 'easyocr':
|
||
return self._easyocr_recognize(image)
|
||
else:
|
||
return None
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"OCR方法 {method} 失败: {e}")
|
||
return None
|
||
|
||
def _tesseract_ocr(self, image: np.ndarray) -> Optional[OCRResult]:
    """Recognise text with Tesseract (Simplified Chinese + English).

    Args:
        image: Image to recognise.

    Returns:
        An ``OCRResult`` whose confidence is the mean of the positive
        per-fragment confidences scaled to 0-1, or ``None`` if Tesseract
        failed (the error is logged).
    """
    start_time = datetime.now()
    try:
        # No character whitelist: the previous one contained the literal
        # characters '一', '-' and '鿿' rather than a range, and lacked '.' and
        # the space, which blocks Chinese text and decimal values.
        config = '--oem 3 --psm 6'
        lang = 'chi_sim+eng'

        text = pytesseract.image_to_string(image, config=config, lang=lang)

        data = pytesseract.image_to_data(
            image, output_type=pytesseract.Output.DICT, config=config, lang=lang
        )

        confidences = []
        bounding_boxes = []
        for i, raw_conf in enumerate(data['conf']):
            # float() accepts values such as '95.3', which int() rejects.
            conf = float(raw_conf)
            if conf <= 0:
                # Non-positive confidences are skipped, as before.
                continue
            confidences.append(conf)
            bounding_boxes.append({
                'text': data['text'][i],
                'confidence': conf / 100.0,
                'bbox': [data['left'][i], data['top'][i], data['width'][i], data['height'][i]],
            })

        avg_confidence = sum(confidences) / len(confidences) / 100.0 if confidences else 0.0
        processing_time = (datetime.now() - start_time).total_seconds()

        return OCRResult(
            text=text.strip(),
            confidence=avg_confidence,
            bounding_boxes=bounding_boxes,
            processing_time=processing_time,
            method='tesseract',
        )

    except Exception as e:
        self.logger.error(f"Tesseract OCR失败: {e}")
        return None
|
||
|
||
def _paddleocr_recognize(self, image: np.ndarray) -> Optional[OCRResult]:
|
||
"""使用PaddleOCR进行识别"""
|
||
try:
|
||
# 这里需要安装paddleocr: pip install paddleocr
|
||
from paddleocr import PaddleOCR
|
||
|
||
if 'paddleocr' not in self.ocr_engines:
|
||
self.ocr_engines['paddleocr'] = PaddleOCR(use_angle_cls=True, lang='ch')
|
||
|
||
ocr = self.ocr_engines['paddleocr']
|
||
result = ocr.ocr(image, cls=True)
|
||
|
||
if not result or not result[0]:
|
||
return None
|
||
|
||
# 提取文本和置信度
|
||
texts = []
|
||
confidences = []
|
||
bounding_boxes = []
|
||
|
||
for line in result[0]:
|
||
if line:
|
||
bbox, (text, confidence) = line
|
||
texts.append(text)
|
||
confidences.append(confidence)
|
||
bounding_boxes.append({
|
||
'text': text,
|
||
'confidence': confidence,
|
||
'bbox': bbox
|
||
})
|
||
|
||
merged_text = ' '.join(texts)
|
||
avg_confidence = sum(confidences) / len(confidences) if confidences else 0.0
|
||
|
||
return OCRResult(
|
||
text=merged_text,
|
||
confidence=avg_confidence,
|
||
bounding_boxes=bounding_boxes,
|
||
processing_time=0.0,
|
||
method='paddleocr'
|
||
)
|
||
|
||
except ImportError:
|
||
self.logger.warning("PaddleOCR未安装,跳过此方法")
|
||
return None
|
||
except Exception as e:
|
||
self.logger.error(f"PaddleOCR识别失败: {e}")
|
||
return None
|
||
|
||
def _easyocr_recognize(self, image: np.ndarray) -> Optional[OCRResult]:
|
||
"""使用EasyOCR进行识别"""
|
||
try:
|
||
# 这里需要安装easyocr: pip install easyocr
|
||
import easyocr
|
||
|
||
if 'easyocr' not in self.ocr_engines:
|
||
self.ocr_engines['easyocr'] = easyocr.Reader(['ch_sim', 'en'])
|
||
|
||
reader = self.ocr_engines['easyocr']
|
||
result = reader.readtext(image)
|
||
|
||
if not result:
|
||
return None
|
||
|
||
# 提取文本和置信度
|
||
texts = []
|
||
confidences = []
|
||
bounding_boxes = []
|
||
|
||
for bbox, text, confidence in result:
|
||
texts.append(text)
|
||
confidences.append(confidence)
|
||
bounding_boxes.append({
|
||
'text': text,
|
||
'confidence': confidence,
|
||
'bbox': bbox
|
||
})
|
||
|
||
merged_text = ' '.join(texts)
|
||
avg_confidence = sum(confidences) / len(confidences) if confidences else 0.0
|
||
|
||
return OCRResult(
|
||
text=merged_text,
|
||
confidence=avg_confidence,
|
||
bounding_boxes=bounding_boxes,
|
||
processing_time=0.0,
|
||
method='easyocr'
|
||
)
|
||
|
||
except ImportError:
|
||
self.logger.warning("EasyOCR未安装,跳过此方法")
|
||
return None
|
||
except Exception as e:
|
||
self.logger.error(f"EasyOCR识别失败: {e}")
|
||
return None
|
||
|
||
def _merge_ocr_results(self, ocr_results: List[OCRResult]) -> str:
|
||
"""合并多个OCR结果"""
|
||
if not ocr_results:
|
||
return ""
|
||
|
||
# 按置信度排序
|
||
sorted_results = sorted(ocr_results, key=lambda x: x.confidence, reverse=True)
|
||
|
||
# 使用最高置信度的结果作为主要结果
|
||
primary_result = sorted_results[0]
|
||
|
||
# 如果有多个高置信度结果,尝试合并
|
||
if len(sorted_results) > 1 and sorted_results[1].confidence > 0.7:
|
||
# 简单的文本合并策略
|
||
merged_text = self._smart_text_merge([r.text for r in sorted_results[:3]])
|
||
return merged_text
|
||
|
||
return primary_result.text
|
||
|
||
def _smart_text_merge(self, texts: List[str]) -> str:
|
||
"""智能文本合并"""
|
||
if not texts:
|
||
return ""
|
||
|
||
if len(texts) == 1:
|
||
return texts[0]
|
||
|
||
# 简单的合并策略:选择最长的文本
|
||
return max(texts, key=len)
|
||
|
||
def _extract_calorie_info(self, text: str, user_data: UserData) -> List[CalorieInfo]:
    """Extract calorie entries from recognised text.

    Every number followed by an energy unit (see ``self.calorie_patterns``)
    becomes one entry. Overlapping patterns matching the same number yield a
    single entry, and kilojoule values are converted to kilocalories. When no
    energy value is found, up to five candidate food names are returned with
    ``calories=None``.

    Args:
        text: Merged OCR text.
        user_data: The requesting user (currently unused).

    Returns:
        The extracted entries, or ``[]`` on error (logged).
    """
    kj_per_kcal = 4.184
    calorie_infos = []

    try:
        seen_number_starts = set()
        for pattern in self.calorie_patterns:
            for match in re.finditer(pattern, text, re.IGNORECASE):
                # Several patterns can match the same reading (e.g. "kcal").
                number_start = match.start(1)
                if number_start in seen_number_starts:
                    continue
                seen_number_starts.add(number_start)

                calories = float(match.group(1))
                if match.group(0).lower().endswith('kj'):
                    # The value was printed in kilojoules.
                    calories = round(calories / kj_per_kcal, 1)

                calorie_infos.append(CalorieInfo(
                    food_name=self._extract_food_name(text, match.start()),
                    calories=calories,
                    serving_size=None,
                    confidence=0.8,  # base confidence for an OCR reading
                    source='ocr',
                    raw_text=match.group(0),
                    validation_status='pending',
                ))

        # No energy value found: fall back to food names only.
        if not calorie_infos:
            for food_name in self._extract_all_food_names(text):
                calorie_infos.append(CalorieInfo(
                    food_name=food_name,
                    calories=None,
                    serving_size=None,
                    confidence=0.6,
                    source='ocr',
                    raw_text=food_name,
                    validation_status='pending',
                ))

        return calorie_infos

    except Exception as e:
        self.logger.error(f"热量信息提取失败: {e}")
        return []
|
||
|
||
def _extract_food_name(self, text: str, calorie_position: int) -> str:
|
||
"""提取食物名称"""
|
||
try:
|
||
# 在热量数值前后查找食物名称
|
||
context_start = max(0, calorie_position - 50)
|
||
context_end = min(len(text), calorie_position + 50)
|
||
context = text[context_start:context_end]
|
||
|
||
# 查找中文和英文食物名称
|
||
food_pattern = r'([a-zA-Z\u4e00-\u9fff]{2,20})'
|
||
matches = re.findall(food_pattern, context)
|
||
|
||
if matches:
|
||
# 选择最可能的食物名称
|
||
return matches[0]
|
||
|
||
return "未知食物"
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"食物名称提取失败: {e}")
|
||
return "未知食物"
|
||
|
||
def _extract_all_food_names(self, text: str) -> List[str]:
|
||
"""提取所有可能的食物名称"""
|
||
try:
|
||
food_pattern = r'([a-zA-Z\u4e00-\u9fff]{2,20})'
|
||
matches = re.findall(food_pattern, text)
|
||
|
||
# 去重并过滤
|
||
unique_foods = list(set(matches))
|
||
return unique_foods[:5] # 最多返回5个
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"食物名称提取失败: {e}")
|
||
return []
|
||
|
||
def _validate_with_database(self, calorie_infos: List[CalorieInfo], user_data: UserData) -> List[CalorieInfo]:
|
||
"""使用数据库验证热量信息"""
|
||
validated_infos = []
|
||
|
||
for info in calorie_infos:
|
||
try:
|
||
# 在食物数据库中查找匹配
|
||
db_match = self._find_database_match(info.food_name)
|
||
|
||
if db_match:
|
||
# 使用数据库信息更新
|
||
info.calories = db_match.get('calories', info.calories)
|
||
info.serving_size = db_match.get('serving_size', info.serving_size)
|
||
info.confidence = max(info.confidence, 0.9)
|
||
info.source = 'database'
|
||
|
||
# 应用用户学习数据
|
||
user_correction = self._get_user_correction(user_data.user_id, info.food_name)
|
||
if user_correction:
|
||
info.calories = user_correction.get('calories', info.calories)
|
||
info.confidence = max(info.confidence, 0.95)
|
||
info.source = 'user_confirmed'
|
||
|
||
validated_infos.append(info)
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"数据库验证失败: {e}")
|
||
validated_infos.append(info)
|
||
|
||
return validated_infos
|
||
|
||
def _find_database_match(self, food_name: str) -> Optional[Dict[str, Any]]:
|
||
"""在数据库中查找食物匹配"""
|
||
try:
|
||
# 精确匹配
|
||
if food_name in self.food_database:
|
||
return self.food_database[food_name]
|
||
|
||
# 模糊匹配
|
||
for db_food, info in self.food_database.items():
|
||
if food_name in db_food or db_food in food_name:
|
||
return info
|
||
|
||
return None
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"数据库匹配失败: {e}")
|
||
return None
|
||
|
||
def _get_user_correction(self, user_id: str, food_name: str) -> Optional[Dict[str, Any]]:
|
||
"""获取用户修正数据"""
|
||
try:
|
||
user_data = self.user_corrections.get(user_id, {})
|
||
return user_data.get(food_name)
|
||
except Exception as e:
|
||
self.logger.error(f"获取用户修正数据失败: {e}")
|
||
return None
|
||
|
||
def _generate_suggestions(self, calorie_infos: List[CalorieInfo], user_data: UserData) -> List[str]:
|
||
"""生成建议"""
|
||
suggestions = []
|
||
|
||
try:
|
||
for info in calorie_infos:
|
||
if info.calories is None:
|
||
suggestions.append(f"未识别到 {info.food_name} 的热量信息,请手动输入")
|
||
elif info.confidence < 0.8:
|
||
suggestions.append(f"{info.food_name} 的热量 {info.calories} 可能不准确,请确认")
|
||
else:
|
||
suggestions.append(f"{info.food_name}: {info.calories} 卡路里")
|
||
|
||
if not calorie_infos:
|
||
suggestions.append("未识别到任何食物信息,请检查图片质量或手动输入")
|
||
|
||
return suggestions
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"生成建议失败: {e}")
|
||
return ["识别过程中出现错误"]
|
||
|
||
def _calculate_overall_confidence(self, ocr_results: List[OCRResult], calorie_infos: List[CalorieInfo]) -> float:
|
||
"""计算整体置信度"""
|
||
try:
|
||
if not ocr_results and not calorie_infos:
|
||
return 0.0
|
||
|
||
# OCR置信度
|
||
ocr_confidence = sum(r.confidence for r in ocr_results) / len(ocr_results) if ocr_results else 0.0
|
||
|
||
# 热量信息置信度
|
||
calorie_confidence = sum(info.confidence for info in calorie_infos) / len(calorie_infos) if calorie_infos else 0.0
|
||
|
||
# 综合置信度
|
||
overall_confidence = (ocr_confidence * 0.4 + calorie_confidence * 0.6)
|
||
|
||
return min(overall_confidence, 1.0)
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"计算置信度失败: {e}")
|
||
return 0.0
|
||
|
||
def _validate_recognition_result(self, input_data: Dict[str, Any], user_data: UserData) -> Dict[str, Any]:
|
||
"""验证识别结果"""
|
||
try:
|
||
food_name = input_data.get('food_name')
|
||
calories = input_data.get('calories')
|
||
is_correct = input_data.get('is_correct', True)
|
||
|
||
if not is_correct:
|
||
# 用户修正
|
||
corrected_calories = input_data.get('corrected_calories')
|
||
self._save_user_correction(user_data.user_id, food_name, corrected_calories)
|
||
|
||
return {
|
||
'success': True,
|
||
'message': '验证结果已保存',
|
||
'confidence': 1.0
|
||
}
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"验证识别结果失败: {e}")
|
||
return self._create_error_result(f"验证失败: {str(e)}")
|
||
|
||
def _learn_from_correction(self, input_data: Dict[str, Any], user_data: UserData) -> Dict[str, Any]:
|
||
"""从用户修正中学习"""
|
||
try:
|
||
food_name = input_data.get('food_name')
|
||
corrected_calories = input_data.get('corrected_calories')
|
||
|
||
self._save_user_correction(user_data.user_id, food_name, corrected_calories)
|
||
|
||
return {
|
||
'success': True,
|
||
'message': '学习数据已保存',
|
||
'confidence': 1.0
|
||
}
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"学习修正数据失败: {e}")
|
||
return self._create_error_result(f"学习失败: {str(e)}")
|
||
|
||
def _save_user_correction(self, user_id: str, food_name: str, calories: float):
|
||
"""保存用户修正数据"""
|
||
try:
|
||
if user_id not in self.user_corrections:
|
||
self.user_corrections[user_id] = {}
|
||
|
||
self.user_corrections[user_id][food_name] = {
|
||
'calories': calories,
|
||
'timestamp': datetime.now().isoformat(),
|
||
'correction_count': self.user_corrections[user_id].get(food_name, {}).get('correction_count', 0) + 1
|
||
}
|
||
|
||
# 保存到文件
|
||
self._save_user_corrections_to_file()
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"保存用户修正数据失败: {e}")
|
||
|
||
def _load_food_database(self) -> Dict[str, Dict[str, Any]]:
|
||
"""加载食物数据库"""
|
||
try:
|
||
# 基础食物数据库
|
||
food_db = {
|
||
"米饭": {"calories": 130, "serving_size": "100g"},
|
||
"面条": {"calories": 110, "serving_size": "100g"},
|
||
"馒头": {"calories": 221, "serving_size": "100g"},
|
||
"包子": {"calories": 250, "serving_size": "100g"},
|
||
"饺子": {"calories": 250, "serving_size": "100g"},
|
||
"鸡蛋": {"calories": 155, "serving_size": "100g"},
|
||
"豆腐": {"calories": 76, "serving_size": "100g"},
|
||
"鱼肉": {"calories": 206, "serving_size": "100g"},
|
||
"鸡肉": {"calories": 165, "serving_size": "100g"},
|
||
"瘦肉": {"calories": 250, "serving_size": "100g"},
|
||
"青菜": {"calories": 25, "serving_size": "100g"},
|
||
"西红柿": {"calories": 18, "serving_size": "100g"},
|
||
"胡萝卜": {"calories": 41, "serving_size": "100g"},
|
||
"土豆": {"calories": 77, "serving_size": "100g"},
|
||
"西兰花": {"calories": 34, "serving_size": "100g"},
|
||
"苹果": {"calories": 52, "serving_size": "100g"},
|
||
"香蕉": {"calories": 89, "serving_size": "100g"},
|
||
"橙子": {"calories": 47, "serving_size": "100g"},
|
||
"葡萄": {"calories": 67, "serving_size": "100g"},
|
||
"草莓": {"calories": 32, "serving_size": "100g"},
|
||
"牛奶": {"calories": 42, "serving_size": "100ml"},
|
||
"酸奶": {"calories": 59, "serving_size": "100g"},
|
||
"豆浆": {"calories": 31, "serving_size": "100ml"},
|
||
"坚果": {"calories": 607, "serving_size": "100g"},
|
||
"红枣": {"calories": 264, "serving_size": "100g"},
|
||
}
|
||
|
||
# 尝试从文件加载扩展数据库
|
||
db_file = Path("data/food_database.json")
|
||
if db_file.exists():
|
||
with open(db_file, 'r', encoding='utf-8') as f:
|
||
extended_db = json.load(f)
|
||
food_db.update(extended_db)
|
||
|
||
return food_db
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"加载食物数据库失败: {e}")
|
||
return {}
|
||
|
||
def _create_directories(self):
|
||
"""创建必要的目录"""
|
||
directories = [
|
||
'data/ocr_cache',
|
||
'data/user_corrections',
|
||
'data/food_images'
|
||
]
|
||
|
||
for directory in directories:
|
||
Path(directory).mkdir(parents=True, exist_ok=True)
|
||
|
||
def _load_user_corrections(self):
|
||
"""加载用户修正数据"""
|
||
try:
|
||
corrections_file = Path("data/user_corrections.json")
|
||
if corrections_file.exists():
|
||
with open(corrections_file, 'r', encoding='utf-8') as f:
|
||
self.user_corrections = json.load(f)
|
||
else:
|
||
self.user_corrections = {}
|
||
except Exception as e:
|
||
self.logger.error(f"加载用户修正数据失败: {e}")
|
||
self.user_corrections = {}
|
||
|
||
def _save_user_corrections_to_file(self):
|
||
"""保存用户修正数据到文件"""
|
||
try:
|
||
corrections_file = Path("data/user_corrections.json")
|
||
with open(corrections_file, 'w', encoding='utf-8') as f:
|
||
json.dump(self.user_corrections, f, ensure_ascii=False, indent=2)
|
||
except Exception as e:
|
||
self.logger.error(f"保存用户修正数据失败: {e}")
|
||
|
||
def _initialize_ocr_engines(self):
|
||
"""初始化OCR引擎"""
|
||
try:
|
||
# 检查Tesseract是否可用
|
||
try:
|
||
pytesseract.get_tesseract_version()
|
||
self.logger.info("Tesseract OCR引擎可用")
|
||
except Exception:
|
||
self.logger.warning("Tesseract OCR引擎不可用")
|
||
|
||
# 其他OCR引擎将在需要时初始化
|
||
self.logger.info("OCR引擎初始化完成")
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"OCR引擎初始化失败: {e}")
|
||
|
||
def _create_error_result(self, message: str) -> Dict[str, Any]:
|
||
"""创建错误结果"""
|
||
return {
|
||
'success': False,
|
||
'error': message,
|
||
'confidence': 0.0
|
||
}
|
||
|
||
def cleanup(self) -> bool:
    """Persist state and release OCR engines.

    Returns:
        True on success, False if any step raised (the error is logged).
    """
    try:
        # Persist corrections before dropping anything.
        self._save_user_corrections_to_file()

        # Drop cached engine instances.
        self.ocr_engines.clear()

        self.is_initialized = False
        self.logger.info("OCR热量识别模块清理完成")
        return True
    except Exception as e:
        self.logger.error(f"OCR热量识别模块清理失败: {e}")
        return False
|
||
|
||
|
||
if __name__ == "__main__":
    # Manual smoke test: only checks that the module can be initialised.
    from core.base import BaseConfig

    config = BaseConfig()
    ocr_module = OCRCalorieRecognitionModule(config)

    if ocr_module.initialize():
        print("OCR模块初始化成功")

        # A real recognition run needs a UserData instance and a test image,
        # e.g. {'type': 'recognize_image', 'image_path': 'test_image.jpg'};
        # it is skipped here.
        print("OCR模块测试完成")
    else:
        print("OCR模块初始化失败")
|