#!/usr/bin/env python3
"""
CLI命令实现
"""
import argparse
from pathlib import Path
from typing import Dict, Any, List, Optional
import sys
import json
from ..config import (
get_config, load_config_from_file,
DATA_DIR, IMAGES_DIR, REPORTS_DIR, OUTPUTS_DIR,
ModelTypes
)
from ..models import create_evaluator, get_supported_models
from ..utils import ReportGenerator, setup_logging, get_logger, create_progress_logger
class BaseCommand:
    """Base class for the CLI sub-commands.

    Provides the shared ``--config`` / ``--log-level`` / ``--verbose``
    options, configuration loading and logging setup.  Subclasses call
    :meth:`load_config` and then :meth:`setup_logging` at the start of
    their ``run`` method.
    """

    def __init__(self):
        # Bound by setup_logging(); None until then.
        self.logger = None

    def setup_logging(self, args):
        """Configure the logging system and bind ``self.logger``.

        A log level given on the command line (``args.log_level``) is written
        into the global config under ``logging.level``.  The *whole* config
        is then handed to ``setup_logging`` so that the other ``logging``
        settings from the config are kept.  Previously a command-line level
        replaced the config with ``{'logging': {'level': ...}}`` and dropped
        everything else.
        """
        config = get_config()
        log_level = getattr(args, 'log_level', None)
        if log_level:
            # Same override load_config() applies; repeated so this method is
            # also correct when called without load_config().
            config.set('logging.level', log_level)
        setup_logging(config.config)
        self.logger = get_logger(self.__class__.__name__)

    def add_common_arguments(self, parser: argparse.ArgumentParser):
        """Register the options shared by every sub-command on *parser*."""
        parser.add_argument(
            '--config', '-c',
            type=Path,
            help='配置文件路径 (YAML或JSON格式)'
        )
        parser.add_argument(
            '--log-level',
            choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
            help='日志级别'
        )
        parser.add_argument(
            '--verbose', '-v',
            action='store_true',
            help='详细输出 (相当于 --log-level DEBUG)'
        )

    def load_config(self, args):
        """Apply the configuration options parsed from the command line.

        - ``--verbose`` is translated into ``--log-level DEBUG``.
        - ``--config`` is loaded when the file exists.  A missing file is
          reported on stderr instead of being ignored silently (the logger is
          not configured yet at this point).
        - A command-line log level is written to the global config under
          ``logging.level``.

        Attributes missing from *args* are treated as "not given".
        """
        if getattr(args, 'verbose', False):
            args.log_level = 'DEBUG'

        config_path = getattr(args, 'config', None)
        if config_path:
            config_path = Path(config_path)
            if config_path.exists():
                load_config_from_file(config_path)
            else:
                print(f"警告: 配置文件不存在, 已忽略: {config_path}", file=sys.stderr)

        log_level = getattr(args, 'log_level', None)
        if log_level:
            get_config().set('logging.level', log_level)
class EvaluateCommand(BaseCommand):
    """Evaluate a single OCR model on an image directory and write reports."""

    @staticmethod
    def add_arguments(parser: argparse.ArgumentParser):
        """Register the arguments of the evaluate sub-command on *parser*."""
        parser.add_argument(
            'model',
            choices=get_supported_models(),
            help='要使用的模型类型'
        )
        parser.add_argument(
            '--images-dir', '-i',
            type=Path,
            default=IMAGES_DIR,
            help=f'图片目录路径 (默认: {IMAGES_DIR})'
        )
        parser.add_argument(
            '--output-dir', '-o',
            type=Path,
            default=REPORTS_DIR,
            help=f'报告输出目录 (默认: {REPORTS_DIR})'
        )
        parser.add_argument(
            '--model-config',
            type=str,
            help='模型配置 (JSON格式字符串)'
        )
        parser.add_argument(
            '--report-format',
            choices=['markdown', 'json', 'both'],
            default='both',
            help='报告格式 (默认: both)'
        )
        parser.add_argument(
            '--report-name',
            type=str,
            help='自定义报告文件名前缀'
        )

    def run(self, args):
        """Run the evaluation.

        Validates the image directory and creates the output directory.
        Parses ``--model-config`` as a JSON object, evaluates the dataset,
        saves the requested reports and logs a summary.

        Returns:
            int: process exit code, 0 on success and 1 on any failure.
        """
        self.load_config(args)
        self.setup_logging(args)
        self.logger.info("🚀 开始OCR模型评估")
        self.logger.info(f"模型类型: {args.model}")
        self.logger.info(f"图片目录: {args.images_dir}")
        self.logger.info(f"输出目录: {args.output_dir}")
        try:
            if not args.images_dir.exists():
                self.logger.error(f"图片目录不存在: {args.images_dir}")
                return 1
            if not args.images_dir.is_dir():
                # exists() alone would also accept a regular file.
                self.logger.error(f"图片路径不是目录: {args.images_dir}")
                return 1
            args.output_dir.mkdir(parents=True, exist_ok=True)

            model_config = None
            if args.model_config:
                try:
                    model_config = json.loads(args.model_config)
                except json.JSONDecodeError as e:
                    self.logger.error(f"模型配置JSON格式错误: {e}")
                    return 1
                # Valid JSON may still be a list/number/string.
                # NOTE(review): assumes create_evaluator expects a mapping — confirm.
                if not isinstance(model_config, dict):
                    self.logger.error("模型配置必须是JSON对象")
                    return 1

            self.logger.info("正在创建评估器...")
            evaluator = create_evaluator(args.model, model_config)

            self.logger.info("正在执行数据集评估...")
            summary = evaluator.evaluate_dataset(args.images_dir)
            if summary is None:
                self.logger.error("评估失败")
                return 1

            self.logger.info("正在生成评估报告...")
            report_files = self._save_reports(summary, args)

            self._show_summary(summary)
            self.logger.info("\n📄 生成的报告文件:")
            for file_path in report_files:
                self.logger.info(f" - {file_path}")
            self.logger.info("✅ 评估完成!")
            return 0
        except Exception as e:
            # Top-level boundary of the command: log with traceback, exit code 1.
            self.logger.error(f"评估过程中发生错误: {e}", exc_info=True)
            return 1

    def _save_reports(self, summary, args):
        """Save the reports chosen by ``--report-format``; return their paths in write order.

        ``--report-name`` (when given) becomes the file-name prefix.  Otherwise
        ``None`` is passed and ReportGenerator picks the name.
        """
        report_generator = ReportGenerator(args.output_dir)
        prefix = args.report_name
        report_files = []
        if args.report_format in ('markdown', 'both'):
            report_files.append(report_generator.save_markdown_report(
                summary,
                f"{prefix}.md" if prefix else None
            ))
        if args.report_format in ('json', 'both'):
            report_files.append(report_generator.save_json_results(
                summary,
                f"{prefix}.json" if prefix else None
            ))
        return report_files

    def _show_summary(self, summary):
        """Log the overall and per-directory results of *summary*."""
        self.logger.info("\n📊 评估结果摘要:")
        self.logger.info(f" 模型: {summary.model_name}")
        self.logger.info(f" 总图片数: {summary.total_images}")
        self.logger.info(f" 总体准确率: {summary.overall_accuracy:.4f} ({summary.overall_accuracy*100:.2f}%)")
        self.logger.info(f" 完全匹配率: {summary.overall_exact_match_rate:.4f} ({summary.overall_exact_match_rate*100:.2f}%)")
        self.logger.info("\n📁 分目录结果:")
        for dir_result in summary.directory_results:
            dir_name = dir_result.directory.name
            self.logger.info(f" {dir_name}: {dir_result.total_images}张图片, "
                             f"准确率 {dir_result.average_accuracy:.4f} ({dir_result.average_accuracy*100:.2f}%)")
class CompareCommand(BaseCommand):
    """Evaluate several OCR models on the same images and compare them."""

    @staticmethod
    def add_arguments(parser: argparse.ArgumentParser):
        """Register the arguments of the compare sub-command on *parser*."""
        parser.add_argument(
            'models',
            nargs='+',
            choices=get_supported_models(),
            help='要对比的模型类型 (可指定多个)'
        )
        parser.add_argument(
            '--images-dir', '-i',
            type=Path,
            default=IMAGES_DIR,
            help=f'图片目录路径 (默认: {IMAGES_DIR})'
        )
        parser.add_argument(
            '--output-dir', '-o',
            type=Path,
            default=REPORTS_DIR,
            help=f'报告输出目录 (默认: {REPORTS_DIR})'
        )
        parser.add_argument(
            '--comparison-report',
            type=str,
            default='model_comparison',
            help='对比报告文件名前缀 (默认: model_comparison)'
        )

    def run(self, args):
        """Run the comparison.

        Each distinct model is evaluated independently.  A model that fails
        (returns ``None`` or raises) is logged and skipped.  The comparison
        needs at least two successful results.  A Markdown report is written
        to ``<output-dir>/<comparison-report>.md``.

        Returns:
            int: process exit code, 0 on success and 1 on failure.
        """
        self.load_config(args)
        self.setup_logging(args)
        self.logger.info("🔄 开始模型对比评估")
        self.logger.info(f"对比模型: {', '.join(args.models)}")
        try:
            # Order-preserving de-duplication: comparing a model with itself
            # must not satisfy the "two models" requirement.
            models = list(dict.fromkeys(args.models))
            if len(models) != len(args.models):
                self.logger.warning(f"已忽略重复指定的模型, 实际对比: {', '.join(models)}")
            if len(models) < 2:
                self.logger.error("至少需要指定两个模型进行对比")
                return 1
            if not args.images_dir.exists():
                self.logger.error(f"图片目录不存在: {args.images_dir}")
                return 1
            if not args.images_dir.is_dir():
                self.logger.error(f"图片路径不是目录: {args.images_dir}")
                return 1
            args.output_dir.mkdir(parents=True, exist_ok=True)

            summaries = []
            for model_type in models:
                self.logger.info(f"\n🔍 评估模型: {model_type}")
                summary = self._evaluate_model(model_type, args.images_dir)
                if summary:
                    summaries.append(summary)
                    self.logger.info(f" {model_type} 完成: 准确率 {summary.overall_accuracy:.4f}")
                else:
                    self.logger.warning(f" {model_type} 评估失败")

            if len(summaries) < 2:
                self.logger.error("至少需要两个成功的评估结果才能进行对比")
                return 1

            self.logger.info("\n📊 生成对比报告...")
            comparison_report = self._generate_comparison_report(summaries)
            report_file = args.output_dir / f"{args.comparison_report}.md"
            with open(report_file, 'w', encoding='utf-8') as f:
                f.write(comparison_report)
            self.logger.info(f"📄 对比报告已保存至: {report_file}")

            self._show_comparison_summary(summaries)
            return 0
        except Exception as e:
            # Top-level boundary of the command: log with traceback, exit code 1.
            self.logger.error(f"对比过程中发生错误: {e}", exc_info=True)
            return 1

    def _evaluate_model(self, model_type, images_dir):
        """Evaluate one model and return its summary, or ``None`` on failure.

        Exceptions are logged and turned into ``None`` so that one broken
        model does not abort the comparison of the others.
        """
        try:
            evaluator = create_evaluator(model_type)
            return evaluator.evaluate_dataset(images_dir)
        except Exception as e:
            self.logger.warning(f" {model_type} 评估异常: {e}", exc_info=True)
            return None

    def _generate_comparison_report(self, summaries: List) -> str:
        """Build the Markdown comparison report for *summaries*.

        Sections: overall table, one table per directory name, the best
        model by ``overall_accuracy``, and per-model-type notes.
        """
        from datetime import datetime
        report = []
        report.append("# OCR模型对比评估报告")
        report.append("")
        report.append(f"**生成时间**: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        report.append(f"**对比模型**: {', '.join([s.model_name for s in summaries])}")
        report.append("")

        # Overall comparison table.
        report.append("## 📊 总体性能对比")
        report.append("")
        report.append("| 模型 | 总图片数 | 总体准确率 | 完全匹配率 | 平均处理时间 |")
        report.append("|------|----------|------------|------------|--------------|")
        for summary in summaries:
            # `or 0` also covers an explicit None, which would break `:.3f`.
            processing_time = summary.technical_details.get('average_processing_time') or 0
            report.append(f"| {summary.model_name} | {summary.total_images} | "
                          f"{summary.overall_accuracy:.4f} ({summary.overall_accuracy*100:.2f}%) | "
                          f"{summary.overall_exact_match_rate:.4f} ({summary.overall_exact_match_rate*100:.2f}%) | "
                          f"{processing_time:.3f}s |")
        report.append("")

        # Per-directory comparison.  Build one name -> result map per summary
        # up front; setdefault keeps the first match, like the linear search did.
        report.append("## 📁 分目录性能对比")
        report.append("")
        results_by_name = []
        all_directories = set()
        for summary in summaries:
            by_name = {}
            for dir_result in summary.directory_results:
                by_name.setdefault(dir_result.directory.name, dir_result)
            results_by_name.append(by_name)
            all_directories.update(by_name)

        for directory in sorted(all_directories):
            report.append(f"### {directory}")
            report.append("")
            report.append("| 模型 | 图片数量 | 准确率 | 完全匹配率 |")
            report.append("|------|----------|--------|------------|")
            for summary, by_name in zip(summaries, results_by_name):
                dir_result = by_name.get(directory)
                if dir_result:
                    report.append(f"| {summary.model_name} | {dir_result.total_images} | "
                                  f"{dir_result.average_accuracy:.4f} ({dir_result.average_accuracy*100:.2f}%) | "
                                  f"{dir_result.exact_match_rate:.4f} ({dir_result.exact_match_rate*100:.2f}%) |")
                else:
                    report.append(f"| {summary.model_name} | - | - | - |")
            report.append("")

        # Conclusions: best overall model by accuracy.
        report.append("## 💡 结论和建议")
        report.append("")
        best_model = max(summaries, key=lambda s: s.overall_accuracy)
        report.append(f"### 最佳整体性能: {best_model.model_name}")
        report.append(f"- 准确率: {best_model.overall_accuracy:.4f} ({best_model.overall_accuracy*100:.2f}%)")
        report.append(f"- 完全匹配率: {best_model.overall_exact_match_rate:.4f} ({best_model.overall_exact_match_rate*100:.2f}%)")
        report.append("")

        # Static strengths/weaknesses per known model type; other types get only a blank line.
        report.append("### 模型特性分析")
        report.append("")
        for summary in summaries:
            model_type = summary.technical_details.get('model_type', 'unknown')
            if model_type == 'paddleocr':
                report.append(f"**{summary.model_name}** (专业OCR模型):")
                report.append("- ✅ 专门针对文字识别优化")
                report.append("- ✅ 处理速度快")
                report.append("- ❌ 功能相对单一")
            elif model_type == 'qwen_vl':
                report.append(f"**{summary.model_name}** (多模态LLM):")
                report.append("- ✅ 具备图像理解能力")
                report.append("- ✅ 可处理复杂场景")
                report.append("- ❌ 推理速度相对较慢")
            report.append("")
        return "\n".join(report)

    def _show_comparison_summary(self, summaries: List):
        """Log the models ranked by overall accuracy, best first."""
        self.logger.info("\n📊 模型对比摘要:")
        sorted_summaries = sorted(summaries, key=lambda s: s.overall_accuracy, reverse=True)
        for i, summary in enumerate(sorted_summaries, 1):
            self.logger.info(f" {i}. {summary.model_name}: "
                             f"准确率 {summary.overall_accuracy:.4f} ({summary.overall_accuracy*100:.2f}%), "
                             f"匹配率 {summary.overall_exact_match_rate:.4f} ({summary.overall_exact_match_rate*100:.2f}%)")
class ConfigCommand(BaseCommand):
    """Inspect, modify and export the configuration (``show`` / ``set`` / ``generate``)."""

    @staticmethod
    def add_arguments(parser: argparse.ArgumentParser):
        """Register the ``show``, ``set`` and ``generate`` sub-actions on *parser*."""
        subparsers = parser.add_subparsers(dest='config_action', help='配置操作')

        show_parser = subparsers.add_parser('show', help='显示当前配置')
        show_parser.add_argument(
            '--key', '-k',
            type=str,
            help='显示特定配置项 (支持点分割路径,如 models.paddleocr.lang)'
        )

        set_parser = subparsers.add_parser('set', help='设置配置项')
        set_parser.add_argument('key', help='配置项键名 (支持点分割路径)')
        set_parser.add_argument('value', help='配置项值')

        generate_parser = subparsers.add_parser('generate', help='生成默认配置文件')
        generate_parser.add_argument(
            '--output', '-o',
            type=Path,
            default='config.yaml',
            help='输出文件路径 (默认: config.yaml)'
        )
        generate_parser.add_argument(
            '--format',
            choices=['yaml', 'json'],
            default='yaml',
            help='配置文件格式 (默认: yaml)'
        )

    def run(self, args):
        """Dispatch to the selected config action.

        Returns:
            int: exit code of the action, or 1 when no action was given.
        """
        self.load_config(args)
        self.setup_logging(args)
        actions = {
            'show': self._show_config,
            'set': self._set_config,
            'generate': self._generate_config,
        }
        action = actions.get(getattr(args, 'config_action', None))
        if action is None:
            self.logger.error("未指定配置操作,使用 --help 查看可用操作")
            return 1
        return action(args)

    def _show_config(self, args) -> int:
        """Print the whole config as YAML, or log a single dotted key.

        Returns 1 when ``--key`` is given and ``config.get`` returns ``None``
        for it, otherwise 0.
        """
        config = get_config()
        if args.key:
            value = config.get(args.key)
            if value is not None:
                self.logger.info(f"{args.key}: {value}")
            else:
                self.logger.error(f"配置项 '{args.key}' 不存在")
                return 1
        else:
            import yaml
            config_yaml = yaml.dump(config.config, default_flow_style=False,
                                    allow_unicode=True, indent=2)
            print(config_yaml)
        return 0

    @staticmethod
    def _parse_value(raw: str):
        """Convert a command-line string to bool, int or float when it looks like one.

        ``true``/``false`` (any case) become booleans.  Optionally signed ASCII
        integers become ``int``, and decimals with a single dot become ``float``.
        Anything else is returned unchanged.  The former ``isdigit`` check
        crashed on inputs such as ``1.2.3`` or ``²``, where ``float``/``int`` raise.
        """
        import re
        lowered = raw.lower()
        if lowered in ('true', 'false'):
            return lowered == 'true'
        if re.fullmatch(r'[+-]?\d+', raw, flags=re.ASCII):
            return int(raw)
        if re.fullmatch(r'[+-]?(?:\d+\.\d*|\.\d+)', raw, flags=re.ASCII):
            return float(raw)
        return raw

    def _set_config(self, args) -> int:
        """Set ``args.key`` to the type-converted ``args.value`` in the global config."""
        config = get_config()
        value = self._parse_value(args.value)
        config.set(args.key, value)
        # NOTE(review): nothing here writes the change to disk; unless
        # config.set() persists by itself, the value only lives for this
        # process — confirm.
        self.logger.info(f"配置项 '{args.key}' 已设置为: {value}")
        return 0

    def _generate_config(self, args) -> int:
        """Write the current in-memory configuration to ``--output``.

        The written config is whatever ``get_config()`` holds after
        ``load_config()``, so ``--config`` / ``--log-level`` overrides are
        included.  ``--format`` was previously ignored.  Now the output suffix
        is switched when it names the other format (e.g. the default
        ``config.yaml`` with ``--format json`` becomes ``config.json``), and a
        suffix is added when there is none.  Other suffixes are left as given.
        """
        output = Path(args.output)
        fmt = getattr(args, 'format', 'yaml')
        suffix = output.suffix.lower()
        if fmt == 'json' and suffix in ('', '.yaml', '.yml'):
            output = output.with_suffix('.json')
        elif fmt == 'yaml' and suffix in ('', '.json'):
            output = output.with_suffix('.yaml')
        try:
            config = get_config()
            # NOTE(review): assumes save_config() picks YAML vs JSON from the
            # file suffix — confirm.
            config.save_config(output)
            self.logger.info(f"默认配置文件已生成: {output}")
            return 0
        except Exception as e:
            self.logger.error(f"生成配置文件失败: {e}")
            return 1