# File: recommend/modules/ocr_calorie_recognition.py
# NOTE: the hosting code viewer flagged this file as containing ambiguous
# Unicode characters (characters that may be confused with others). Review the
# string literals if that is not intentional.
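"""
Image OCR calorie-recognition module, built on the base-module architecture.

Supports several OCR engines for reading food calorie information, and
includes intelligent validation and correction mechanisms.
"""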
import cv2
import numpy as np
import re
import json
import logging
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
import requests
import base64
from PIL import Image, ImageEnhance, ImageFilter
import pytesseract
from core.base import BaseModule, ModuleType, UserData, AnalysisResult, BaseConfig
logger = logging.getLogger(__name__)
@dataclass
class OCRResult:
    """Result produced by a single OCR engine run."""

    # Recognised text.
    text: str
    # Confidence reported for the run; the engines in this file store a 0-1 value.
    confidence: float
    # One dict per recognised item, with 'text', 'confidence' and 'bbox' keys.
    bounding_boxes: List[Dict[str, Any]]
    # Processing time in seconds.
    processing_time: float
    # Name of the OCR method that produced this result (e.g. 'tesseract').
    method: str
@dataclass
class CalorieInfo:
    """Calorie information extracted for one food item."""

    # Name of the food the figure belongs to.
    food_name: str
    # Calorie value, or None when no figure was recognised.
    calories: Optional[float]
    # Serving size description (e.g. "100g"), or None when unknown.
    serving_size: Optional[str]
    # Confidence assigned to this entry.
    confidence: float
    # Origin of the value: 'ocr', 'database' or 'user_confirmed'.
    source: str
    # The raw text fragment this entry was derived from.
    raw_text: str
    # One of 'pending', 'validated' or 'corrected'.
    validation_status: str
@dataclass
class FoodRecognitionResult:
    """Aggregated outcome of recognising the food items in one image."""

    # Path of the image that was analysed.
    image_path: str
    # Raw results, one per OCR engine that produced output.
    ocr_results: List[OCRResult]
    # Calorie entries extracted from the recognised text.
    calorie_infos: List[CalorieInfo]
    # Combined confidence for the whole recognition.
    overall_confidence: float
    # Total processing time in seconds.
    processing_time: float
    # Human-readable suggestions for the user.
    suggestions: List[str]
class OCRCalorieRecognitionModule(BaseModule):
    """OCR-based calorie recognition module.

    Runs the configured OCR engines over a food image, extracts calorie
    figures and food names from the recognised text, reconciles them with a
    small food database and per-user corrections, and produces suggestions.
    """

    # Energy conversion used when a label states its value in kilojoules.
    _KJ_PER_KCAL = 4.184
    # Energy-unit words (lower-cased) that must never be taken as a food name.
    _UNIT_TOKENS = frozenset({
        'cal', 'cals', 'kcal', 'kcals', 'calorie', 'calories', 'kj',
        '卡路里', '千卡', '大卡', '千焦',
    })

    def __init__(self, config: BaseConfig):
        """Set up OCR settings, regex patterns, the food database and engines."""
        super().__init__(config, ModuleType.DATA_COLLECTION)
        # OCR configuration.
        self.ocr_methods = ['tesseract', 'paddleocr', 'easyocr']
        self.min_confidence = 0.6
        self.max_processing_time = 30.0
        # Patterns for calorie figures; group 1 is always the numeric value.
        self.calorie_patterns = [
            r'(\d+(?:\.\d+)?)\s*[kK]?[cC][aA][lL](?:ories?)?',
            r'(\d+(?:\.\d+)?)\s*[kK][cC][aA][lL]',
            r'(\d+(?:\.\d+)?)\s*卡路里',
            r'(\d+(?:\.\d+)?)\s*千卡',
            r'(\d+(?:\.\d+)?)\s*大卡',
            r'(\d+(?:\.\d+)?)\s*[kK][jJ]',  # kilojoules
        ]
        # Patterns for food names.
        self.food_patterns = [
            r'([a-zA-Z\u4e00-\u9fff]+)\s*(?:\d+(?:\.\d+)?)',
            r'(\d+(?:\.\d+)?)\s*([a-zA-Z\u4e00-\u9fff]+)',
        ]
        # Food database.
        self.food_database = self._load_food_database()
        # Corrections learned from users.
        self.user_corrections = {}
        # OCR engine cache; heavy engines are created lazily on first use.
        self.ocr_engines = {}
        self._initialize_ocr_engines()

    def initialize(self) -> bool:
        """Create working directories and load user corrections.

        Returns:
            True on success, False if any step raised.
        """
        try:
            self.logger.info("OCR热量识别模块初始化中...")
            self._create_directories()
            self._load_user_corrections()
            self.is_initialized = True
            self.logger.info("OCR热量识别模块初始化完成")
            return True
        except Exception as e:
            self.logger.error(f"OCR热量识别模块初始化失败: {e}")
            return False

    def process(self, input_data: Any, user_data: UserData) -> AnalysisResult:
        """Dispatch an OCR request according to ``input_data['type']``.

        Supported types are 'recognize_image', 'validate_result' and
        'learn_correction'; anything else yields an error payload.

        Returns:
            An AnalysisResult wrapping the handler's result dict. Failures are
            wrapped the same way, so callers always receive the declared type.
        """
        try:
            handlers = {
                'recognize_image': self._recognize_image_calories,
                'validate_result': self._validate_recognition_result,
                'learn_correction': self._learn_from_correction,
            }
            request_type = input_data.get('type', 'unknown')
            handler = handlers.get(request_type)
            if handler is None:
                result = self._create_error_result("未知的请求类型")
            else:
                result = handler(input_data, user_data)
            return AnalysisResult(
                module_type=self.module_type,
                user_id=user_data.user_id,
                input_data=input_data,
                result=result,
                confidence=result.get('confidence', 0.5)
            )
        except Exception as e:
            self.logger.error(f"OCR识别处理失败: {e}")
            error = self._create_error_result(f"处理失败: {str(e)}")
            # Wrap the error so the return type matches the annotation; the
            # original returned the bare dict here.
            return AnalysisResult(
                module_type=self.module_type,
                user_id=getattr(user_data, 'user_id', None),
                input_data=input_data,
                result=error,
                confidence=error['confidence']
            )

    def _recognize_image_calories(self, input_data: Dict[str, Any], user_data: UserData) -> Dict[str, Any]:
        """Recognise calorie information in the image at ``input_data['image_path']``.

        Returns:
            A dict with 'success', 'result' (a FoodRecognitionResult),
            'confidence' and 'message', or an error dict.
        """
        start_time = datetime.now()
        try:
            image_path = input_data.get('image_path')
            if not image_path or not Path(image_path).exists():
                return self._create_error_result("图片文件不存在")
            processed_image = self._preprocess_image(image_path)
            if processed_image is None:
                # Even the grayscale fallback could not decode the file.
                return self._create_error_result("无法读取图片")
            # Try every configured engine; one failing must not stop the others.
            ocr_results = []
            for method in self.ocr_methods:
                try:
                    result = self._ocr_recognize(processed_image, method)
                    if result:
                        ocr_results.append(result)
                except Exception as e:
                    self.logger.warning(f"OCR方法 {method} 失败: {e}")
            merged_text = self._merge_ocr_results(ocr_results)
            calorie_infos = self._extract_calorie_info(merged_text, user_data)
            validated_infos = self._validate_with_database(calorie_infos, user_data)
            suggestions = self._generate_suggestions(validated_infos, user_data)
            processing_time = (datetime.now() - start_time).total_seconds()
            result = FoodRecognitionResult(
                image_path=image_path,
                ocr_results=ocr_results,
                calorie_infos=validated_infos,
                overall_confidence=self._calculate_overall_confidence(ocr_results, validated_infos),
                processing_time=processing_time,
                suggestions=suggestions
            )
            return {
                'success': True,
                'result': result,
                'confidence': result.overall_confidence,
                'message': f"识别完成,处理时间: {processing_time:.2f}"
            }
        except Exception as e:
            self.logger.error(f"图片热量识别失败: {e}")
            return self._create_error_result(f"识别失败: {str(e)}")

    def _preprocess_image(self, image_path: str) -> Optional[np.ndarray]:
        """Preprocess an image to improve OCR accuracy.

        Pipeline: grayscale, median blur, CLAHE contrast enhancement, Otsu
        binarisation, morphological close. If any step fails, the plain
        grayscale image is returned instead (None if the file cannot be
        decoded at all).
        """
        try:
            image = cv2.imread(image_path)
            if image is None:
                raise ValueError("无法读取图片")
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            denoised = cv2.medianBlur(gray, 3)
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
            enhanced = clahe.apply(denoised)
            _, binary = cv2.threshold(enhanced, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
            kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
            return cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)
        except Exception as e:
            self.logger.error(f"图片预处理失败: {e}")
            return cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)

    def _ocr_recognize(self, image: np.ndarray, method: str) -> Optional[OCRResult]:
        """Run the OCR engine named ``method`` on ``image``.

        Returns:
            The engine's OCRResult, or None for an unknown method, an engine
            failure or an empty recognition. Results that carry no timing of
            their own are stamped with the time measured here.
        """
        start_time = datetime.now()
        engines = {
            'tesseract': self._tesseract_ocr,
            'paddleocr': self._paddleocr_recognize,
            'easyocr': self._easyocr_recognize,
        }
        engine = engines.get(method)
        if engine is None:
            return None
        try:
            result = engine(image)
        except Exception as e:
            self.logger.error(f"OCR方法 {method} 失败: {e}")
            return None
        if result is not None and not result.processing_time:
            result.processing_time = (datetime.now() - start_time).total_seconds()
        return result

    def _tesseract_ocr(self, image: np.ndarray) -> Optional[OCRResult]:
        """Recognise text with Tesseract.

        Returns:
            An OCRResult with the mean word confidence scaled to 0-1, or None
            if Tesseract fails.
        """
        start_time = datetime.now()
        try:
            # No character whitelist: it is a literal character list, so the
            # former "\u4e00-\u9fff" entry did not denote a range and the list
            # also excluded '.', breaking Chinese text and decimal values.
            config = '--oem 3 --psm 6'
            text = pytesseract.image_to_string(image, config=config, lang='chi_sim+eng')
            data = pytesseract.image_to_data(
                image, output_type=pytesseract.Output.DICT, config=config, lang='chi_sim+eng'
            )
            confidences = []
            bounding_boxes = []
            for i, raw_conf in enumerate(data['conf']):
                # Confidence may arrive as int, float or a string such as
                # '96.3' or '-1'; int() would raise on the float form.
                try:
                    conf = float(raw_conf)
                except (TypeError, ValueError):
                    continue
                if conf <= 0:
                    continue
                confidences.append(conf)
                bounding_boxes.append({
                    'text': data['text'][i],
                    'confidence': conf / 100.0,
                    'bbox': [data['left'][i], data['top'][i], data['width'][i], data['height'][i]]
                })
            avg_confidence = sum(confidences) / len(confidences) / 100.0 if confidences else 0.0
            processing_time = (datetime.now() - start_time).total_seconds()
            return OCRResult(
                text=text.strip(),
                confidence=avg_confidence,
                bounding_boxes=bounding_boxes,
                processing_time=processing_time,
                method='tesseract'
            )
        except Exception as e:
            self.logger.error(f"Tesseract OCR失败: {e}")
            return None

    def _paddleocr_recognize(self, image: np.ndarray) -> Optional[OCRResult]:
        """Recognise text with PaddleOCR (optional dependency).

        Returns:
            An OCRResult, or None if PaddleOCR is not installed, fails, or
            finds no text.
        """
        try:
            # Requires: pip install paddleocr
            from paddleocr import PaddleOCR
            if 'paddleocr' not in self.ocr_engines:
                self.ocr_engines['paddleocr'] = PaddleOCR(use_angle_cls=True, lang='ch')
            ocr = self.ocr_engines['paddleocr']
            result = ocr.ocr(image, cls=True)
            if not result or not result[0]:
                return None
            texts = []
            confidences = []
            bounding_boxes = []
            for line in result[0]:
                if not line:
                    continue
                bbox, (text, confidence) = line
                texts.append(text)
                confidences.append(confidence)
                bounding_boxes.append({
                    'text': text,
                    'confidence': confidence,
                    'bbox': bbox
                })
            avg_confidence = sum(confidences) / len(confidences) if confidences else 0.0
            return OCRResult(
                text=' '.join(texts),
                confidence=avg_confidence,
                bounding_boxes=bounding_boxes,
                processing_time=0.0,
                method='paddleocr'
            )
        except ImportError:
            self.logger.warning("PaddleOCR未安装跳过此方法")
            return None
        except Exception as e:
            self.logger.error(f"PaddleOCR识别失败: {e}")
            return None

    def _easyocr_recognize(self, image: np.ndarray) -> Optional[OCRResult]:
        """Recognise text with EasyOCR (optional dependency).

        Returns:
            An OCRResult, or None if EasyOCR is not installed, fails, or finds
            no text.
        """
        try:
            # Requires: pip install easyocr
            import easyocr
            if 'easyocr' not in self.ocr_engines:
                self.ocr_engines['easyocr'] = easyocr.Reader(['ch_sim', 'en'])
            reader = self.ocr_engines['easyocr']
            result = reader.readtext(image)
            if not result:
                return None
            texts = []
            confidences = []
            bounding_boxes = []
            for bbox, text, confidence in result:
                texts.append(text)
                confidences.append(confidence)
                bounding_boxes.append({
                    'text': text,
                    'confidence': confidence,
                    'bbox': bbox
                })
            avg_confidence = sum(confidences) / len(confidences) if confidences else 0.0
            return OCRResult(
                text=' '.join(texts),
                confidence=avg_confidence,
                bounding_boxes=bounding_boxes,
                processing_time=0.0,
                method='easyocr'
            )
        except ImportError:
            self.logger.warning("EasyOCR未安装跳过此方法")
            return None
        except Exception as e:
            self.logger.error(f"EasyOCR识别失败: {e}")
            return None

    def _merge_ocr_results(self, ocr_results: List[OCRResult]) -> str:
        """Combine the text of several OCR results into one string.

        The highest-confidence text wins unless the runner-up's confidence
        exceeds 0.7, in which case the top three texts are merged.
        """
        if not ocr_results:
            return ""
        sorted_results = sorted(ocr_results, key=lambda x: x.confidence, reverse=True)
        if len(sorted_results) > 1 and sorted_results[1].confidence > 0.7:
            return self._smart_text_merge([r.text for r in sorted_results[:3]])
        return sorted_results[0].text

    def _smart_text_merge(self, texts: List[str]) -> str:
        """Merge candidate texts; currently this simply picks the longest one."""
        if not texts:
            return ""
        if len(texts) == 1:
            return texts[0]
        return max(texts, key=len)

    def _extract_calorie_info(self, text: str, user_data: UserData) -> List[CalorieInfo]:
        """Extract calorie entries from recognised text.

        Each distinct number followed by an energy unit yields one entry;
        kilojoule figures are converted to kilocalories. When no figure is
        found, up to five candidate food names are returned with
        ``calories=None``.
        """
        calorie_infos = []
        try:
            # Several patterns can match the same number (e.g. "100 kcal"
            # satisfies the first two), so de-duplicate on the number's span.
            seen_spans = set()
            for pattern in self.calorie_patterns:
                for match in re.finditer(pattern, text, re.IGNORECASE):
                    span = match.span(1)
                    if span in seen_spans:
                        continue
                    seen_spans.add(span)
                    calories = float(match.group(1))
                    if match.group(0).lower().endswith('kj'):
                        # The field holds kilocalories; convert from kilojoules.
                        calories = round(calories / self._KJ_PER_KCAL, 1)
                    calorie_infos.append(CalorieInfo(
                        food_name=self._extract_food_name(text, match.start()),
                        calories=calories,
                        serving_size=None,
                        confidence=0.8,  # base confidence for OCR-derived values
                        source='ocr',
                        raw_text=match.group(0),
                        validation_status='pending'
                    ))
            if not calorie_infos:
                # No figure found: fall back to listing possible food names.
                for food_name in self._extract_all_food_names(text):
                    calorie_infos.append(CalorieInfo(
                        food_name=food_name,
                        calories=None,
                        serving_size=None,
                        confidence=0.6,
                        source='ocr',
                        raw_text=food_name,
                        validation_status='pending'
                    ))
            return calorie_infos
        except Exception as e:
            self.logger.error(f"热量信息提取失败: {e}")
            return []

    def _extract_food_name(self, text: str, calorie_position: int) -> str:
        """Find a food name near a calorie figure.

        Looks at the 50 characters on either side of ``calorie_position`` and
        returns the first run of 2-20 letters/CJK characters that is not an
        energy unit, or "未知食物" when there is none.
        """
        try:
            context_start = max(0, calorie_position - 50)
            context_end = min(len(text), calorie_position + 50)
            context = text[context_start:context_end]
            food_pattern = r'([a-zA-Z\u4e00-\u9fff]{2,20})'
            for candidate in re.findall(food_pattern, context):
                # Skip the unit that belongs to the figure itself ("kcal", "千卡").
                if candidate.lower() not in self._UNIT_TOKENS:
                    return candidate
            return "未知食物"
        except Exception as e:
            self.logger.error(f"食物名称提取失败: {e}")
            return "未知食物"

    def _extract_all_food_names(self, text: str) -> List[str]:
        """Return up to five distinct candidate food names, in order of appearance."""
        try:
            food_pattern = r'([a-zA-Z\u4e00-\u9fff]{2,20})'
            matches = re.findall(food_pattern, text)
            # dict.fromkeys de-duplicates while keeping first-seen order, so
            # the truncation below is deterministic (a set is not).
            return list(dict.fromkeys(matches))[:5]
        except Exception as e:
            self.logger.error(f"食物名称提取失败: {e}")
            return []

    def _validate_with_database(self, calorie_infos: List[CalorieInfo], user_data: UserData) -> List[CalorieInfo]:
        """Reconcile entries with the food database and user corrections.

        Entries are updated in place: a database match overrides calories and
        serving size, and a stored user correction overrides calories again.
        Entries that fail to validate are passed through unchanged.
        """
        validated_infos = []
        for info in calorie_infos:
            try:
                db_match = self._find_database_match(info.food_name)
                if db_match:
                    info.calories = db_match.get('calories', info.calories)
                    info.serving_size = db_match.get('serving_size', info.serving_size)
                    info.confidence = max(info.confidence, 0.9)
                    info.source = 'database'
                # User corrections take precedence over the database.
                user_correction = self._get_user_correction(user_data.user_id, info.food_name)
                if user_correction:
                    info.calories = user_correction.get('calories', info.calories)
                    info.confidence = max(info.confidence, 0.95)
                    info.source = 'user_confirmed'
                validated_infos.append(info)
            except Exception as e:
                self.logger.error(f"数据库验证失败: {e}")
                validated_infos.append(info)
        return validated_infos

    def _find_database_match(self, food_name: str) -> Optional[Dict[str, Any]]:
        """Look up a food by exact name, then by substring in either direction.

        Returns:
            The database record, or None when nothing matches or the name is
            empty.
        """
        try:
            # An empty string is a substring of every key and would otherwise
            # "match" the first database entry.
            if not food_name:
                return None
            if food_name in self.food_database:
                return self.food_database[food_name]
            for db_food, info in self.food_database.items():
                if food_name in db_food or db_food in food_name:
                    return info
            return None
        except Exception as e:
            self.logger.error(f"数据库匹配失败: {e}")
            return None

    def _get_user_correction(self, user_id: str, food_name: str) -> Optional[Dict[str, Any]]:
        """Return the stored correction for this user and food, if any."""
        try:
            return self.user_corrections.get(user_id, {}).get(food_name)
        except Exception as e:
            self.logger.error(f"获取用户修正数据失败: {e}")
            return None

    def _generate_suggestions(self, calorie_infos: List[CalorieInfo], user_data: UserData) -> List[str]:
        """Build one user-facing suggestion per entry, or a hint when there are none."""
        suggestions = []
        try:
            for info in calorie_infos:
                if info.calories is None:
                    suggestions.append(f"未识别到 {info.food_name} 的热量信息,请手动输入")
                elif info.confidence < 0.8:
                    suggestions.append(f"{info.food_name} 的热量 {info.calories} 可能不准确,请确认")
                else:
                    suggestions.append(f"{info.food_name}: {info.calories} 卡路里")
            if not calorie_infos:
                suggestions.append("未识别到任何食物信息,请检查图片质量或手动输入")
            return suggestions
        except Exception as e:
            self.logger.error(f"生成建议失败: {e}")
            return ["识别过程中出现错误"]

    def _calculate_overall_confidence(self, ocr_results: List[OCRResult], calorie_infos: List[CalorieInfo]) -> float:
        """Blend mean OCR confidence (40%) with mean entry confidence (60%), capped at 1.0."""
        try:
            if not ocr_results and not calorie_infos:
                return 0.0
            ocr_confidence = (
                sum(r.confidence for r in ocr_results) / len(ocr_results) if ocr_results else 0.0
            )
            calorie_confidence = (
                sum(info.confidence for info in calorie_infos) / len(calorie_infos)
                if calorie_infos else 0.0
            )
            overall_confidence = ocr_confidence * 0.4 + calorie_confidence * 0.6
            return min(overall_confidence, 1.0)
        except Exception as e:
            self.logger.error(f"计算置信度失败: {e}")
            return 0.0

    @staticmethod
    def _parse_correction(input_data: Dict[str, Any]) -> Tuple[str, float]:
        """Read and validate 'food_name' and 'corrected_calories' from a request.

        Raises:
            ValueError: if either field is missing, the calorie value is not
                numeric, or it is negative.
        """
        food_name = input_data.get('food_name')
        corrected = input_data.get('corrected_calories')
        if not food_name:
            raise ValueError("缺少食物名称")
        if corrected is None:
            raise ValueError("缺少修正后的热量值")
        calories = float(corrected)
        if calories < 0:
            raise ValueError("热量值不能为负数")
        return food_name, calories

    def _validate_recognition_result(self, input_data: Dict[str, Any], user_data: UserData) -> Dict[str, Any]:
        """Record the user's verdict on a recognition result.

        When ``input_data['is_correct']`` is false, the corrected value is
        validated and stored; an invalid correction yields an error dict
        instead of being saved.
        """
        try:
            if not input_data.get('is_correct', True):
                food_name, corrected_calories = self._parse_correction(input_data)
                self._save_user_correction(user_data.user_id, food_name, corrected_calories)
            return {
                'success': True,
                'message': '验证结果已保存',
                'confidence': 1.0
            }
        except Exception as e:
            self.logger.error(f"验证识别结果失败: {e}")
            return self._create_error_result(f"验证失败: {str(e)}")

    def _learn_from_correction(self, input_data: Dict[str, Any], user_data: UserData) -> Dict[str, Any]:
        """Store a user-supplied correction; invalid input yields an error dict."""
        try:
            food_name, corrected_calories = self._parse_correction(input_data)
            self._save_user_correction(user_data.user_id, food_name, corrected_calories)
            return {
                'success': True,
                'message': '学习数据已保存',
                'confidence': 1.0
            }
        except Exception as e:
            self.logger.error(f"学习修正数据失败: {e}")
            return self._create_error_result(f"学习失败: {str(e)}")

    def _save_user_correction(self, user_id: str, food_name: str, calories: float):
        """Store a correction in memory and persist all corrections to disk.

        Failures are logged, not raised.
        """
        try:
            per_user = self.user_corrections.setdefault(user_id, {})
            previous_count = per_user.get(food_name, {}).get('correction_count', 0)
            per_user[food_name] = {
                'calories': calories,
                'timestamp': datetime.now().isoformat(),
                'correction_count': previous_count + 1
            }
            self._save_user_corrections_to_file()
        except Exception as e:
            self.logger.error(f"保存用户修正数据失败: {e}")

    def _load_food_database(self) -> Dict[str, Dict[str, Any]]:
        """Return the built-in food table, extended from data/food_database.json.

        A missing or unreadable extension file never discards the built-in
        entries; the failure is only logged.
        """
        # Built-in base table.
        food_db = {
            "米饭": {"calories": 130, "serving_size": "100g"},
            "面条": {"calories": 110, "serving_size": "100g"},
            "馒头": {"calories": 221, "serving_size": "100g"},
            "包子": {"calories": 250, "serving_size": "100g"},
            "饺子": {"calories": 250, "serving_size": "100g"},
            "鸡蛋": {"calories": 155, "serving_size": "100g"},
            "豆腐": {"calories": 76, "serving_size": "100g"},
            "鱼肉": {"calories": 206, "serving_size": "100g"},
            "鸡肉": {"calories": 165, "serving_size": "100g"},
            "瘦肉": {"calories": 250, "serving_size": "100g"},
            "青菜": {"calories": 25, "serving_size": "100g"},
            "西红柿": {"calories": 18, "serving_size": "100g"},
            "胡萝卜": {"calories": 41, "serving_size": "100g"},
            "土豆": {"calories": 77, "serving_size": "100g"},
            "西兰花": {"calories": 34, "serving_size": "100g"},
            "苹果": {"calories": 52, "serving_size": "100g"},
            "香蕉": {"calories": 89, "serving_size": "100g"},
            "橙子": {"calories": 47, "serving_size": "100g"},
            "葡萄": {"calories": 67, "serving_size": "100g"},
            "草莓": {"calories": 32, "serving_size": "100g"},
            "牛奶": {"calories": 42, "serving_size": "100ml"},
            "酸奶": {"calories": 59, "serving_size": "100g"},
            "豆浆": {"calories": 31, "serving_size": "100ml"},
            "坚果": {"calories": 607, "serving_size": "100g"},
            "红枣": {"calories": 264, "serving_size": "100g"},
        }
        # Optional extension file; only this part can fail.
        try:
            db_file = Path("data/food_database.json")
            if db_file.exists():
                with open(db_file, 'r', encoding='utf-8') as f:
                    extended_db = json.load(f)
                if not isinstance(extended_db, dict):
                    raise ValueError("food_database.json must contain a JSON object")
                food_db.update(extended_db)
        except Exception as e:
            self.logger.error(f"加载食物数据库失败: {e}")
        return food_db

    def _create_directories(self):
        """Create the data directories this module uses."""
        directories = [
            'data/ocr_cache',
            'data/user_corrections',
            'data/food_images'
        ]
        for directory in directories:
            Path(directory).mkdir(parents=True, exist_ok=True)

    def _load_user_corrections(self):
        """Load user corrections from data/user_corrections.json.

        Falls back to an empty mapping when the file is absent, unreadable or
        does not contain a JSON object.
        """
        try:
            corrections_file = Path("data/user_corrections.json")
            loaded = {}
            if corrections_file.exists():
                with open(corrections_file, 'r', encoding='utf-8') as f:
                    loaded = json.load(f)
            self.user_corrections = loaded if isinstance(loaded, dict) else {}
        except Exception as e:
            self.logger.error(f"加载用户修正数据失败: {e}")
            self.user_corrections = {}

    def _save_user_corrections_to_file(self):
        """Write all user corrections to data/user_corrections.json.

        Failures are logged, not raised.
        """
        try:
            corrections_file = Path("data/user_corrections.json")
            # The directory may not exist yet if initialize() was never run.
            corrections_file.parent.mkdir(parents=True, exist_ok=True)
            with open(corrections_file, 'w', encoding='utf-8') as f:
                json.dump(self.user_corrections, f, ensure_ascii=False, indent=2)
        except Exception as e:
            self.logger.error(f"保存用户修正数据失败: {e}")

    def _initialize_ocr_engines(self):
        """Check that Tesseract is available; other engines are created on demand."""
        try:
            try:
                pytesseract.get_tesseract_version()
                self.logger.info("Tesseract OCR引擎可用")
            except Exception:
                self.logger.warning("Tesseract OCR引擎不可用")
            self.logger.info("OCR引擎初始化完成")
        except Exception as e:
            self.logger.error(f"OCR引擎初始化失败: {e}")

    def _create_error_result(self, message: str) -> Dict[str, Any]:
        """Build the standard failure payload carrying ``message``."""
        return {
            'success': False,
            'error': message,
            'confidence': 0.0
        }

    def cleanup(self) -> bool:
        """Persist user corrections and release cached OCR engines.

        Returns:
            True on success, False if cleanup raised.
        """
        try:
            self._save_user_corrections_to_file()
            self.ocr_engines.clear()
            self.is_initialized = False
            self.logger.info("OCR热量识别模块清理完成")
            return True
        except Exception as e:
            self.logger.error(f"OCR热量识别模块清理失败: {e}")
            return False
if __name__ == "__main__":
    # Manual smoke test: construct the module and run its initialisation.
    from core.base import BaseConfig

    config = BaseConfig()
    ocr_module = OCRCalorieRecognitionModule(config)
    if ocr_module.initialize():
        print("OCR模块初始化成功")
        # Sample recognition request; a real test image must be supplied.
        test_data = {
            'type': 'recognize_image',
            'image_path': 'test_image.jpg'
        }
        # Running the request needs user data, so the actual call is skipped.
        print("OCR模块测试完成")
    else:
        print("OCR模块初始化失败")