ocr_evaluation.models.base 源代码

#!/usr/bin/env python3
"""
OCR评估器抽象基类
"""

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple
from dataclasses import dataclass
import logging


[文档] @dataclass class EvaluationResult: """单个样本评估结果""" image_path: Path ground_truth: str predicted: str accuracy: float exact_match: bool metadata: Optional[Dict[str, Any]] = None
[文档] @dataclass class DirectoryResult: """目录评估结果""" directory: Path total_images: int average_accuracy: float exact_match_count: int exact_match_rate: float results: List[EvaluationResult] metadata: Optional[Dict[str, Any]] = None
[文档] @dataclass class TestSummary: """完整测试汇总结果""" model_name: str test_timestamp: str total_images: int overall_accuracy: float overall_exact_match_rate: float directory_results: List[DirectoryResult] technical_details: Dict[str, Any] metadata: Optional[Dict[str, Any]] = None
[文档] class BaseEvaluator(ABC): """OCR评估器基类 定义了所有OCR模型评估器必须实现的接口 """
[文档] def __init__(self, config: Optional[Dict[str, Any]] = None): """初始化评估器 Args: config: 模型配置字典 """ self.config = config or {} self.logger = logging.getLogger(self.__class__.__name__) self.model = None self.is_initialized = False # 技术细节记录 self.technical_details = { 'model_name': self.get_model_name(), 'model_type': self.get_model_type(), 'config': self.config.copy(), 'initialization_time': None, 'total_processing_time': None, 'average_processing_time': None }
[文档] @abstractmethod def get_model_name(self) -> str: """获取模型名称""" pass
[文档] @abstractmethod def get_model_type(self) -> str: """获取模型类型""" pass
[文档] @abstractmethod def initialize(self) -> bool: """初始化模型 Returns: bool: 初始化是否成功 """ pass
[文档] @abstractmethod def recognize_image(self, image_path: Path) -> str: """识别单张图片 Args: image_path: 图片路径 Returns: str: 识别结果文本 """ pass
[文档] @abstractmethod def cleanup(self): """清理资源""" pass
[文档] def validate_image(self, image_path: Path) -> bool: """验证图片文件 Args: image_path: 图片路径 Returns: bool: 图片是否有效 """ if not image_path.exists(): self.logger.warning(f"图片文件不存在: {image_path}") return False if not image_path.is_file(): self.logger.warning(f"路径不是文件: {image_path}") return False # 检查文件扩展名 from ..config import SUPPORTED_IMAGE_FORMATS if image_path.suffix.lower() not in SUPPORTED_IMAGE_FORMATS: self.logger.warning(f"不支持的图片格式: {image_path.suffix}") return False # 检查文件大小 from ..config import PerformanceConstants if image_path.stat().st_size > PerformanceConstants.MAX_IMAGE_SIZE: self.logger.warning(f"图片文件过大: {image_path}") return False return True
[文档] def parse_label_file(self, label_file: Path) -> Dict[str, str]: """解析Label.txt文件 Args: label_file: 标签文件路径 Returns: Dict[str, str]: 图片名到标注文本的映射 """ import json import os labels = {} try: with open(label_file, 'r', encoding='utf-8') as f: for line_num, line in enumerate(f, 1): line = line.strip() if not line: continue # 解析格式: 图片路径\t[{"transcription": "文本", ...}] parts = line.split('\t', 1) if len(parts) != 2: self.logger.warning(f"标签文件第{line_num}行格式错误: {line}") continue image_path = parts[0] image_name = os.path.basename(image_path) # 解析JSON标注 try: annotations = json.loads(parts[1]) if annotations and isinstance(annotations, list) and len(annotations) > 0: transcription = annotations[0].get('transcription', '') if transcription: labels[image_name] = transcription except json.JSONDecodeError as e: self.logger.warning(f"解析JSON标注失败,第{line_num}行: {e}") continue except Exception as e: self.logger.error(f"解析标签文件失败 {label_file}: {e}") return labels
[文档] def calculate_accuracy(self, ground_truth: str, predicted: str) -> float: """计算准确率 Args: ground_truth: 标准答案 predicted: 预测结果 Returns: float: 准确率 (0.0 到 1.0) """ if not ground_truth and not predicted: return 1.0 if not ground_truth or not predicted: return 0.0 # 完全匹配 if ground_truth == predicted: return 1.0 # 基于编辑距离的准确率 return self._levenshtein_accuracy(ground_truth, predicted)
def _levenshtein_accuracy(self, s1: str, s2: str) -> float: """计算基于Levenshtein编辑距离的准确率""" distance = self._levenshtein_distance(s1, s2) max_len = max(len(s1), len(s2)) return max(0.0, 1.0 - distance / max_len) if max_len > 0 else 1.0 def _levenshtein_distance(self, s1: str, s2: str) -> int: """计算Levenshtein编辑距离""" if len(s1) < len(s2): return self._levenshtein_distance(s2, s1) if len(s2) == 0: return len(s1) previous_row = list(range(len(s2) + 1)) for i, c1 in enumerate(s1): current_row = [i + 1] for j, c2 in enumerate(s2): insertions = previous_row[j + 1] + 1 deletions = current_row[j] + 1 substitutions = previous_row[j] + (c1 != c2) current_row.append(min(insertions, deletions, substitutions)) previous_row = current_row return previous_row[-1]
[文档] def evaluate_directory(self, directory: Path) -> Optional[DirectoryResult]: """评估单个目录 Args: directory: 包含图片和Label.txt的目录 Returns: DirectoryResult: 目录评估结果,如果失败返回None """ if not self.is_initialized: self.logger.error("模型未初始化") return None self.logger.info(f"评估目录: {directory}") # 查找标签文件 label_file = directory / "Label.txt" if not label_file.exists(): self.logger.warning(f"未找到标签文件: {label_file}") return None # 解析标签 labels = self.parse_label_file(label_file) if not labels: self.logger.warning(f"标签文件为空: {label_file}") return None self.logger.info(f"找到 {len(labels)} 个标签") # 处理每个图片 results = [] total_accuracy = 0.0 exact_matches = 0 for image_name, ground_truth in labels.items(): image_path = directory / image_name if not self.validate_image(image_path): continue # 识别图片 try: predicted_text = self.recognize_image(image_path) accuracy = self.calculate_accuracy(ground_truth, predicted_text) exact_match = ground_truth == predicted_text result = EvaluationResult( image_path=image_path, ground_truth=ground_truth, predicted=predicted_text, accuracy=accuracy, exact_match=exact_match ) results.append(result) total_accuracy += accuracy if exact_match: exact_matches += 1 self.logger.debug(f"图片: {image_name}, 标准答案: {ground_truth}, " f"识别结果: {predicted_text}, 准确率: {accuracy:.4f}") except Exception as e: self.logger.error(f"处理图片失败 {image_path}: {e}") continue if not results: self.logger.error(f"目录 {directory} 中没有有效的识别结果") return None # 计算统计信息 avg_accuracy = total_accuracy / len(results) exact_match_rate = exact_matches / len(results) return DirectoryResult( directory=directory, total_images=len(results), average_accuracy=avg_accuracy, exact_match_count=exact_matches, exact_match_rate=exact_match_rate, results=results )
[文档] def evaluate_dataset(self, images_dir: Path) -> Optional[TestSummary]: """评估完整数据集 Args: images_dir: 包含所有测试图片目录的根目录 Returns: TestSummary: 完整测试汇总结果 """ from datetime import datetime if not self.initialize(): self.logger.error("模型初始化失败") return None start_time = datetime.now() self.logger.info(f"开始 {self.get_model_name()} 数据集评估") directory_results = [] total_images = 0 total_accuracy = 0.0 total_exact_matches = 0 # 遍历所有子目录 for item in images_dir.iterdir(): if not item.is_dir(): continue # 检查是否直接包含Label.txt if (item / "Label.txt").exists(): result = self.evaluate_directory(item) if result: directory_results.append(result) total_images += result.total_images total_accuracy += result.average_accuracy * result.total_images total_exact_matches += result.exact_match_count else: # 检查子目录 for subitem in item.iterdir(): if subitem.is_dir() and (subitem / "Label.txt").exists(): result = self.evaluate_directory(subitem) if result: directory_results.append(result) total_images += result.total_images total_accuracy += result.average_accuracy * result.total_images total_exact_matches += result.exact_match_count # 计算总体统计 overall_accuracy = total_accuracy / total_images if total_images > 0 else 0.0 overall_exact_match_rate = total_exact_matches / total_images if total_images > 0 else 0.0 # 更新技术细节 end_time = datetime.now() processing_time = (end_time - start_time).total_seconds() self.technical_details.update({ 'total_processing_time': processing_time, 'average_processing_time': processing_time / total_images if total_images > 0 else 0, 'initialization_time': self.technical_details.get('initialization_time', 0) }) self.logger.info(f"评估完成,总体准确率: {overall_accuracy:.4f} ({overall_accuracy*100:.2f}%)") # 清理资源 self.cleanup() return TestSummary( model_name=self.get_model_name(), test_timestamp=end_time.isoformat(), total_images=total_images, overall_accuracy=overall_accuracy, overall_exact_match_rate=overall_exact_match_rate, directory_results=directory_results, technical_details=self.technical_details.copy() )