加速错题 OCR:上传压缩、识别缩放、启动预热与 MKL-DNN。

This commit is contained in:
dekun
2026-06-28 14:12:01 +08:00
parent 6200dbb596
commit 14bf314544
8 changed files with 116 additions and 26 deletions
+4 -1
View File
@@ -19,9 +19,12 @@ class Settings(BaseSettings):
FRONTEND_DIST: str = "../frontend/dist" FRONTEND_DIST: str = "../frontend/dist"
ADMIN_DEFAULT_USERNAME: str = "admin" ADMIN_DEFAULT_USERNAME: str = "admin"
ADMIN_DEFAULT_PASSWORD: str = "admin123" ADMIN_DEFAULT_PASSWORD: str = "admin123"
OCR_TIMEOUT_SECONDS: int = 300 OCR_TIMEOUT_SECONDS: int = 180
AI_TIMEOUT_SECONDS: int = 600 AI_TIMEOUT_SECONDS: int = 600
PROCESS_STALE_MINUTES: int = 20 PROCESS_STALE_MINUTES: int = 20
OCR_MAX_SIDE: int = 1280
UPLOAD_MAX_SIDE: int = 2048
OCR_WARMUP: bool = True
class Config: class Config:
env_file = ".env" env_file = ".env"
+2
View File
@@ -9,6 +9,7 @@ from fastapi.staticfiles import StaticFiles
from app.core.config import settings from app.core.config import settings
from app.core.database import Base, SessionLocal, engine from app.core.database import Base, SessionLocal, engine
from app.routers import admin, auth, exams, export, settings as settings_router, students, subjects, wrong_questions from app.routers import admin, auth, exams, export, settings as settings_router, students, subjects, wrong_questions
from app.services import ocr as ocr_service
from app.services.migrate import run_migrations from app.services.migrate import run_migrations
from app.services.seed import seed_admin_and_settings, seed_subjects from app.services.seed import seed_admin_and_settings, seed_subjects
@@ -34,6 +35,7 @@ async def lifespan(app: FastAPI):
seed_admin_and_settings(db) seed_admin_and_settings(db)
finally: finally:
db.close() db.close()
ocr_service.warmup_ocr_engine()
yield yield
+1 -2
View File
@@ -157,8 +157,7 @@ def _process_wrong_question(question_id: uuid.UUID):
except FuturesTimeout: except FuturesTimeout:
wq.status = WrongQuestionStatus.failed wq.status = WrongQuestionStatus.failed
wq.error_message = ( wq.error_message = (
f"OCR 识别超时(>{settings.OCR_TIMEOUT_SECONDS}秒)" f"OCR 识别超时(>{settings.OCR_TIMEOUT_SECONDS}秒),请点「重新识别标注」重试"
" 首次加载模型较慢,请稍后点「重新识别标注」重试"
) )
db.commit() db.commit()
return return
+101 -15
View File
@@ -1,6 +1,9 @@
from pathlib import Path import logging
import os import os
import tempfile
import threading
from io import BytesIO
from pathlib import Path
from PIL import Image from PIL import Image
@@ -9,7 +12,10 @@ from app.core.config import settings
# 无图形界面服务器:避免 OpenCV/Paddle 依赖 X11 # 无图形界面服务器:避免 OpenCV/Paddle 依赖 X11
os.environ.setdefault("OPENCV_IO_ENABLE_OPENEXR", "0") os.environ.setdefault("OPENCV_IO_ENABLE_OPENEXR", "0")
logger = logging.getLogger(__name__)
_ocr_engine = None _ocr_engine = None
_ocr_warmup_started = False
def get_ocr_engine(): def get_ocr_engine():
@@ -17,20 +23,103 @@ def get_ocr_engine():
if _ocr_engine is None: if _ocr_engine is None:
from paddleocr import PaddleOCR from paddleocr import PaddleOCR
_ocr_engine = PaddleOCR(use_angle_cls=True, lang="ch", show_log=False) _ocr_engine = PaddleOCR(
use_angle_cls=False,
lang="ch",
show_log=False,
use_gpu=False,
enable_mkldnn=True,
det_limit_side_len=min(settings.OCR_MAX_SIDE, 1280),
rec_batch_num=8,
)
return _ocr_engine return _ocr_engine
def warmup_ocr_engine() -> None:
    """Preload the OCR model on a background thread.

    The intent is that the first real image does not have to wait for model
    loading. The call returns immediately and is a no-op when
    ``settings.OCR_WARMUP`` is false or when a warmup was already started in
    this process. Failures inside the worker are logged as warnings and never
    propagate to the caller.
    """
    global _ocr_warmup_started
    if _ocr_warmup_started or not settings.OCR_WARMUP:
        return
    _ocr_warmup_started = True

    def _run() -> None:
        try:
            engine = get_ocr_engine()
            tmp_path: Path | None = None
            try:
                with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
                    # Record the path before writing so the cleanup below
                    # also runs when encoding the image fails.
                    tmp_path = Path(tmp.name)
                    # Write through the already-open handle instead of
                    # reopening the file by name while it is still held open.
                    Image.new("RGB", (120, 40), color=(255, 255, 255)).save(tmp, format="JPEG")
                # One recognition pass over a tiny blank image, so the first
                # inference happens here rather than on a real upload.
                engine.ocr(str(tmp_path), cls=False)
                logger.info("OCR engine warmed up")
            finally:
                if tmp_path is not None:
                    tmp_path.unlink(missing_ok=True)
        except Exception as exc:
            # Warmup is best-effort: a failure here must not affect startup.
            logger.warning("OCR warmup failed: %s", exc)

    # Daemon thread: it must not keep the process alive on shutdown.
    threading.Thread(target=_run, daemon=True, name="ocr-warmup").start()
def _bbox_from_box(box: list) -> list[float]: def _bbox_from_box(box: list) -> list[float]:
xs = [float(p[0]) for p in box] xs = [float(p[0]) for p in box]
ys = [float(p[1]) for p in box] ys = [float(p[1]) for p in box]
return [min(xs), min(ys), max(xs), max(ys)] return [min(xs), min(ys), max(xs), max(ys)]
def _scale_bbox(bbox: list[float], scale_x: float, scale_y: float) -> list[float]:
return [bbox[0] * scale_x, bbox[1] * scale_y, bbox[2] * scale_x, bbox[3] * scale_y]
def _scale_box(box: list, scale_x: float, scale_y: float) -> list:
return [[float(p[0]) * scale_x, float(p[1]) * scale_y] for p in box]
def _normalize_image_bytes(content: bytes, max_side: int) -> bytes:
    """Re-encode an uploaded image as an RGB JPEG bounded by *max_side*.

    Args:
        content: Raw bytes of an image that Pillow can open.
        max_side: Maximum length, in pixels, of the longer edge of the
            result. Images already within the limit are not resized.

    Returns:
        The JPEG-encoded image bytes (quality 88, optimized).
    """
    # Local import: only this function needs ImageOps.
    from PIL import ImageOps

    with Image.open(BytesIO(content)) as img:
        # The re-encoded JPEG is saved without EXIF metadata, so bake the
        # EXIF Orientation tag into the pixels first; otherwise photos that
        # rely on it would be stored rotated.
        img = ImageOps.exif_transpose(img)
        img = img.convert("RGB")
        width, height = img.size
        longest = max(width, height)
        if longest > max_side:
            ratio = max_side / longest
            # Clamp to 1px: int() truncation can reach 0 on the short edge
            # of extremely elongated images, which resize() rejects.
            new_size = (max(1, int(width * ratio)), max(1, int(height * ratio)))
            img = img.resize(new_size, Image.Resampling.LANCZOS)
        buf = BytesIO()
        img.save(buf, format="JPEG", quality=88, optimize=True)
        return buf.getvalue()
def _prepare_ocr_image(image_path: str) -> tuple[str, float, float, int, int, Path | None]:
    """Produce a downscaled temporary copy for OCR when the image is too large.

    Args:
        image_path: Path of the image to recognise.

    Returns:
        A tuple ``(ocr_path, scale_x, scale_y, orig_w, orig_h, tmp_path)``:

        * ``ocr_path`` - the file to feed to OCR; equals *image_path* when the
          longer edge is within ``settings.OCR_MAX_SIDE``.
        * ``scale_x`` / ``scale_y`` - original size divided by the size of the
          OCR image (``1.0`` when no resize was done), i.e. the factors that
          map coordinates on the OCR image back to the original.
        * ``orig_w`` / ``orig_h`` - pixel size of the original image.
        * ``tmp_path`` - the temporary file that was created, or ``None``.
          The caller is responsible for deleting it.
    """
    max_side = settings.OCR_MAX_SIDE
    with Image.open(image_path) as img:
        orig_w, orig_h = img.size
        longest = max(orig_w, orig_h)
        if longest <= max_side:
            return image_path, 1.0, 1.0, orig_w, orig_h, None

        ratio = max_side / longest
        new_w = max(1, int(orig_w * ratio))
        new_h = max(1, int(orig_h * ratio))
        resized = img.convert("RGB").resize((new_w, new_h), Image.Resampling.LANCZOS)

    # mkstemp yields a unique name, so concurrent OCR runs for the same image
    # in one process cannot overwrite or delete each other's file (a
    # "<stem>_<pid>" name is identical across threads).
    fd, tmp_name = tempfile.mkstemp(prefix=f"ocr_{Path(image_path).stem}_", suffix=".jpg")
    os.close(fd)
    tmp = Path(tmp_name)
    try:
        resized.save(tmp, format="JPEG", quality=85, optimize=True)
    except Exception:
        # The caller never learns about the path on failure, so remove it here.
        tmp.unlink(missing_ok=True)
        raise

    scale_x = orig_w / new_w
    scale_y = orig_h / new_h
    return str(tmp), scale_x, scale_y, orig_w, orig_h, tmp
def run_ocr_with_regions(image_path: str) -> dict: def run_ocr_with_regions(image_path: str) -> dict:
"""Return OCR text plus line-level bounding boxes for annotation.""" """Return OCR text plus line-level bounding boxes for annotation."""
engine = get_ocr_engine() engine = get_ocr_engine()
result = engine.ocr(image_path, cls=True) ocr_path, scale_x, scale_y, orig_w, orig_h, tmp_path = _prepare_ocr_image(image_path)
try:
result = engine.ocr(ocr_path, cls=False)
finally:
if tmp_path is not None:
tmp_path.unlink(missing_ok=True)
lines: list[dict] = [] lines: list[dict] = []
if result and result[0]: if result and result[0]:
for item in result[0]: for item in result[0]:
@@ -41,27 +130,23 @@ def run_ocr_with_regions(image_path: str) -> dict:
conf = float(rec[1]) if rec and len(rec) > 1 else 0.0 conf = float(rec[1]) if rec and len(rec) > 1 else 0.0
if not text: if not text:
continue continue
if scale_x != 1.0 or scale_y != 1.0:
box = _scale_box(box, scale_x, scale_y)
bbox = _bbox_from_box(box)
lines.append( lines.append(
{ {
"text": text, "text": text,
"confidence": conf, "confidence": conf,
"box": box, "box": box,
"bbox": _bbox_from_box(box), "bbox": bbox,
} }
) )
width, height = 0, 0
try:
with Image.open(image_path) as img:
width, height = img.size
except OSError:
pass
return { return {
"text": "\n".join(line["text"] for line in lines), "text": "\n".join(line["text"] for line in lines),
"lines": lines, "lines": lines,
"width": width, "width": orig_w,
"height": height, "height": orig_h,
} }
@@ -77,7 +162,8 @@ def save_upload_file(user_id: str, question_id: str, filename: str, content: byt
user_dir.mkdir(parents=True, exist_ok=True) user_dir.mkdir(parents=True, exist_ok=True)
rel_path = f"{user_id}/{question_id}{ext}" rel_path = f"{user_id}/{question_id}{ext}"
full_path = Path(settings.UPLOAD_DIR) / rel_path full_path = Path(settings.UPLOAD_DIR) / rel_path
full_path.write_bytes(content) normalized = _normalize_image_bytes(content, settings.UPLOAD_MAX_SIDE)
full_path.write_bytes(normalized)
return rel_path return rel_path
File diff suppressed because one or more lines are too long
+1 -1
View File
@@ -9,7 +9,7 @@
<meta name="author" content="马建军" /> <meta name="author" content="马建军" />
<meta name="copyright" content="Copyright (c) 马建军. All rights reserved." /> <meta name="copyright" content="Copyright (c) 马建军. All rights reserved." />
<title>中学成绩档案</title> <title>中学成绩档案</title>
<script type="module" crossorigin src="/assets/index-_1CtLpiP.js"></script> <script type="module" crossorigin src="/assets/index-DzzkB1zh.js"></script>
<link rel="stylesheet" crossorigin href="/assets/index-GY2etMYN.css"> <link rel="stylesheet" crossorigin href="/assets/index-GY2etMYN.css">
</head> </head>
<body> <body>
+3 -3
View File
@@ -37,7 +37,7 @@ export default function WrongQuestionDetail({
try { try {
const { data } = await wrongQuestionApi.get(questionId) const { data } = await wrongQuestionApi.get(questionId)
setWq(data) setWq(data)
setQuestionText(data.question_text || '') setQuestionText(data.question_text || data.ocr_raw_text || '')
setApproachText(data.solution_approach || '') setApproachText(data.solution_approach || '')
setSolutionText(data.solution_text || '') setSolutionText(data.solution_text || '')
setImageMode(data.has_annotated_image ? 'annotated' : 'original') setImageMode(data.has_annotated_image ? 'annotated' : 'original')
@@ -56,7 +56,7 @@ export default function WrongQuestionDetail({
try { try {
const { data } = await wrongQuestionApi.get(questionId) const { data } = await wrongQuestionApi.get(questionId)
setWq(data) setWq(data)
setQuestionText(data.question_text || '') setQuestionText(data.question_text || data.ocr_raw_text || '')
setApproachText(data.solution_approach || '') setApproachText(data.solution_approach || '')
setSolutionText(data.solution_text || '') setSolutionText(data.solution_text || '')
if (data.has_annotated_image) setImageMode('annotated') if (data.has_annotated_image) setImageMode('annotated')
@@ -88,7 +88,7 @@ export default function WrongQuestionDetail({
try { try {
const { data } = await wrongQuestionApi.regenerate(questionId) const { data } = await wrongQuestionApi.regenerate(questionId)
setWq(data) setWq(data)
setQuestionText(data.question_text || '') setQuestionText(data.question_text || data.ocr_raw_text || '')
setApproachText(data.solution_approach || '') setApproachText(data.solution_approach || '')
setSolutionText(data.solution_text || '') setSolutionText(data.solution_text || '')
message.success('解题思路已重新生成') message.success('解题思路已重新生成')
+2 -2
View File
@@ -9,10 +9,10 @@ export function isWrongQuestionProcessing(wq: WrongQuestion): boolean {
/**
 * Returns the progress hint shown while a wrong question is being processed.
 *
 * - `pending`: text recognition is still running.
 * - `ocr_done`: recognition finished; annotation and solution generation
 *   are running.
 * - any other status: a generic hint covering all stages.
 */
export function processingHint(wq: WrongQuestion): string {
  if (wq.status === 'pending') {
    // Range separator restored: the hint read "1030 秒" instead of "10–30 秒".
    return '正在识别文字(约 10–30 秒)…'
  }
  if (wq.status === 'ocr_done') {
    // Range separator restored: the hint read "3090 秒" instead of "30–90 秒".
    return '正在标注错题并生成解题思路(约 30–90 秒)…'
  }
  return '正在识别、标注并生成解题思路…'
}