加速错题 OCR:上传压缩、识别缩放、启动预热与 MKL-DNN。

This commit is contained in:
dekun
2026-06-28 14:12:01 +08:00
parent 6200dbb596
commit 14bf314544
8 changed files with 116 additions and 26 deletions
+4 -1
View File
@@ -19,9 +19,12 @@ class Settings(BaseSettings):
FRONTEND_DIST: str = "../frontend/dist"
ADMIN_DEFAULT_USERNAME: str = "admin"
ADMIN_DEFAULT_PASSWORD: str = "admin123"
OCR_TIMEOUT_SECONDS: int = 300
OCR_TIMEOUT_SECONDS: int = 180
AI_TIMEOUT_SECONDS: int = 600
PROCESS_STALE_MINUTES: int = 20
OCR_MAX_SIDE: int = 1280
UPLOAD_MAX_SIDE: int = 2048
OCR_WARMUP: bool = True
class Config:
env_file = ".env"
+2
View File
@@ -9,6 +9,7 @@ from fastapi.staticfiles import StaticFiles
from app.core.config import settings
from app.core.database import Base, SessionLocal, engine
from app.routers import admin, auth, exams, export, settings as settings_router, students, subjects, wrong_questions
from app.services import ocr as ocr_service
from app.services.migrate import run_migrations
from app.services.seed import seed_admin_and_settings, seed_subjects
@@ -34,6 +35,7 @@ async def lifespan(app: FastAPI):
seed_admin_and_settings(db)
finally:
db.close()
ocr_service.warmup_ocr_engine()
yield
+1 -2
View File
@@ -157,8 +157,7 @@ def _process_wrong_question(question_id: uuid.UUID):
except FuturesTimeout:
wq.status = WrongQuestionStatus.failed
wq.error_message = (
f"OCR 识别超时(>{settings.OCR_TIMEOUT_SECONDS}秒)"
" 首次加载模型较慢,请稍后点「重新识别标注」重试"
f"OCR 识别超时(>{settings.OCR_TIMEOUT_SECONDS}秒),请点「重新识别标注」重试"
)
db.commit()
return
+101 -15
View File
@@ -1,6 +1,9 @@
from pathlib import Path
import logging
import os
import tempfile
import threading
from io import BytesIO
from pathlib import Path
from PIL import Image
@@ -9,7 +12,10 @@ from app.core.config import settings
# 无图形界面服务器:避免 OpenCV/Paddle 依赖 X11
os.environ.setdefault("OPENCV_IO_ENABLE_OPENEXR", "0")
logger = logging.getLogger(__name__)
_ocr_engine = None
_ocr_warmup_started = False
def get_ocr_engine():
@@ -17,20 +23,103 @@ def get_ocr_engine():
if _ocr_engine is None:
from paddleocr import PaddleOCR
_ocr_engine = PaddleOCR(use_angle_cls=True, lang="ch", show_log=False)
_ocr_engine = PaddleOCR(
use_angle_cls=False,
lang="ch",
show_log=False,
use_gpu=False,
enable_mkldnn=True,
det_limit_side_len=min(settings.OCR_MAX_SIDE, 1280),
rec_batch_num=8,
)
return _ocr_engine
def warmup_ocr_engine() -> None:
    """Preload the OCR model in a background thread.

    Builds the shared engine via ``get_ocr_engine()`` and runs one
    recognition pass over a tiny blank JPEG so that the first real request
    does not have to wait for model loading. Does nothing when warmup was
    already started in this process or ``settings.OCR_WARMUP`` is false.
    Failures are logged and never propagated: warmup is best-effort.
    """
    global _ocr_warmup_started
    if _ocr_warmup_started or not settings.OCR_WARMUP:
        return
    _ocr_warmup_started = True

    def _run() -> None:
        try:
            engine = get_ocr_engine()
            tmp_path: Path | None = None
            try:
                with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
                    # Record the path first so cleanup still runs if the save fails.
                    tmp_path = Path(tmp.name)
                    # Write through the open handle; re-opening tmp.name while the
                    # handle is still open is not portable (fails on Windows).
                    Image.new("RGB", (120, 40), color=(255, 255, 255)).save(tmp, format="JPEG")
                engine.ocr(str(tmp_path), cls=False)
                logger.info("OCR engine warmed up")
            finally:
                if tmp_path is not None:
                    tmp_path.unlink(missing_ok=True)
        except Exception as exc:
            # Best-effort: keep the app running, but keep the traceback for diagnosis.
            logger.warning("OCR warmup failed: %s", exc, exc_info=True)

    # Daemon thread so a slow model load never blocks interpreter shutdown.
    threading.Thread(target=_run, daemon=True, name="ocr-warmup").start()
def _bbox_from_box(box: list) -> list[float]:
xs = [float(p[0]) for p in box]
ys = [float(p[1]) for p in box]
return [min(xs), min(ys), max(xs), max(ys)]
def _scale_bbox(bbox: list[float], scale_x: float, scale_y: float) -> list[float]:
return [bbox[0] * scale_x, bbox[1] * scale_y, bbox[2] * scale_x, bbox[3] * scale_y]
def _scale_box(box: list, scale_x: float, scale_y: float) -> list:
return [[float(p[0]) * scale_x, float(p[1]) * scale_y] for p in box]
def _normalize_image_bytes(content: bytes, max_side: int) -> bytes:
    """Re-encode an uploaded image as an RGB JPEG no larger than ``max_side``.

    Args:
        content: Raw bytes of an image in any format Pillow can open.
        max_side: Upper bound, in pixels, for the longer edge of the result.

    Returns:
        JPEG-encoded bytes (quality 88). Images whose longer edge already
        fits within ``max_side`` are re-encoded without resizing.

    Raises:
        Whatever ``PIL.Image.open`` raises for unreadable or non-image data.
    """
    # Function-scope import: the module only imports ``Image`` at file level.
    from PIL import ImageOps

    with Image.open(BytesIO(content)) as src:
        # Re-encoding drops EXIF, so bake the orientation tag into the pixels
        # first; otherwise phone photos would be stored rotated.
        img = ImageOps.exif_transpose(src).convert("RGB")

    width, height = img.size
    longest = max(width, height)
    if longest > max_side:
        ratio = max_side / longest
        # Clamp to 1px so extreme aspect ratios cannot produce a zero-sized
        # edge (same guard as _prepare_ocr_image).
        new_size = (max(1, int(width * ratio)), max(1, int(height * ratio)))
        img = img.resize(new_size, Image.Resampling.LANCZOS)

    buf = BytesIO()
    img.save(buf, format="JPEG", quality=88, optimize=True)
    return buf.getvalue()
def _prepare_ocr_image(image_path: str) -> tuple[str, float, float, int, int, Path | None]:
    """Produce a downscaled temporary copy for OCR when the image is too large.

    Args:
        image_path: Path of the image to recognise.

    Returns:
        A tuple ``(ocr_path, scale_x, scale_y, orig_w, orig_h, tmp_path)``:

        * ``ocr_path`` - file to feed to the OCR engine (the original path
          when no resize was needed).
        * ``scale_x`` / ``scale_y`` - factors that map coordinates measured
          on ``ocr_path`` back onto the original image (1.0 when unchanged).
        * ``orig_w`` / ``orig_h`` - pixel size of the original image.
        * ``tmp_path`` - the temporary file the caller must delete, or
          ``None`` when the original was used as-is.
    """
    max_side = settings.OCR_MAX_SIDE
    with Image.open(image_path) as img:
        orig_w, orig_h = img.size
        longest = max(orig_w, orig_h)
        if longest <= max_side:
            return image_path, 1.0, 1.0, orig_w, orig_h, None

        ratio = max_side / longest
        # Clamp to 1px so extreme aspect ratios never yield a zero-sized edge.
        new_w = max(1, int(orig_w * ratio))
        new_h = max(1, int(orig_h * ratio))
        resized = img.convert("RGB").resize((new_w, new_h), Image.Resampling.LANCZOS)

    # mkstemp gives a unique, exclusively-created file, so concurrent runs on
    # the same image (same stem, same pid) cannot overwrite each other's copy.
    fd, tmp_name = tempfile.mkstemp(prefix=f"ocr_{Path(image_path).stem}_", suffix=".jpg")
    tmp = Path(tmp_name)
    try:
        with os.fdopen(fd, "wb") as fh:
            resized.save(fh, format="JPEG", quality=85, optimize=True)
    except Exception:
        # Do not leave a partial file behind when encoding fails.
        tmp.unlink(missing_ok=True)
        raise

    scale_x = orig_w / new_w
    scale_y = orig_h / new_h
    return str(tmp), scale_x, scale_y, orig_w, orig_h, tmp
def run_ocr_with_regions(image_path: str) -> dict:
"""Return OCR text plus line-level bounding boxes for annotation."""
engine = get_ocr_engine()
result = engine.ocr(image_path, cls=True)
ocr_path, scale_x, scale_y, orig_w, orig_h, tmp_path = _prepare_ocr_image(image_path)
try:
result = engine.ocr(ocr_path, cls=False)
finally:
if tmp_path is not None:
tmp_path.unlink(missing_ok=True)
lines: list[dict] = []
if result and result[0]:
for item in result[0]:
@@ -41,27 +130,23 @@ def run_ocr_with_regions(image_path: str) -> dict:
conf = float(rec[1]) if rec and len(rec) > 1 else 0.0
if not text:
continue
if scale_x != 1.0 or scale_y != 1.0:
box = _scale_box(box, scale_x, scale_y)
bbox = _bbox_from_box(box)
lines.append(
{
"text": text,
"confidence": conf,
"box": box,
"bbox": _bbox_from_box(box),
"bbox": bbox,
}
)
width, height = 0, 0
try:
with Image.open(image_path) as img:
width, height = img.size
except OSError:
pass
return {
"text": "\n".join(line["text"] for line in lines),
"lines": lines,
"width": width,
"height": height,
"width": orig_w,
"height": orig_h,
}
@@ -77,7 +162,8 @@ def save_upload_file(user_id: str, question_id: str, filename: str, content: byt
user_dir.mkdir(parents=True, exist_ok=True)
rel_path = f"{user_id}/{question_id}{ext}"
full_path = Path(settings.UPLOAD_DIR) / rel_path
full_path.write_bytes(content)
normalized = _normalize_image_bytes(content, settings.UPLOAD_MAX_SIDE)
full_path.write_bytes(normalized)
return rel_path
File diff suppressed because one or more lines are too long
+1 -1
View File
@@ -9,7 +9,7 @@
<meta name="author" content="马建军" />
<meta name="copyright" content="Copyright (c) 马建军. All rights reserved." />
<title>中学成绩档案</title>
<script type="module" crossorigin src="/assets/index-_1CtLpiP.js"></script>
<script type="module" crossorigin src="/assets/index-DzzkB1zh.js"></script>
<link rel="stylesheet" crossorigin href="/assets/index-GY2etMYN.css">
</head>
<body>
+3 -3
View File
@@ -37,7 +37,7 @@ export default function WrongQuestionDetail({
try {
const { data } = await wrongQuestionApi.get(questionId)
setWq(data)
setQuestionText(data.question_text || '')
setQuestionText(data.question_text || data.ocr_raw_text || '')
setApproachText(data.solution_approach || '')
setSolutionText(data.solution_text || '')
setImageMode(data.has_annotated_image ? 'annotated' : 'original')
@@ -56,7 +56,7 @@ export default function WrongQuestionDetail({
try {
const { data } = await wrongQuestionApi.get(questionId)
setWq(data)
setQuestionText(data.question_text || '')
setQuestionText(data.question_text || data.ocr_raw_text || '')
setApproachText(data.solution_approach || '')
setSolutionText(data.solution_text || '')
if (data.has_annotated_image) setImageMode('annotated')
@@ -88,7 +88,7 @@ export default function WrongQuestionDetail({
try {
const { data } = await wrongQuestionApi.regenerate(questionId)
setWq(data)
setQuestionText(data.question_text || '')
setQuestionText(data.question_text || data.ocr_raw_text || '')
setApproachText(data.solution_approach || '')
setSolutionText(data.solution_text || '')
message.success('解题思路已重新生成')
+2 -2
View File
@@ -9,10 +9,10 @@ export function isWrongQuestionProcessing(wq: WrongQuestion): boolean {
export function processingHint(wq: WrongQuestion): string {
if (wq.status === 'pending') {
return '正在 OCR 识别(首次约 1–5 分钟,请稍候)…'
return '正在识别文字(约 10–30 秒)…'
}
if (wq.status === 'ocr_done') {
return '正在标注错题并生成解题思路…'
return '正在标注错题并生成解题思路(约 30–90 秒)…'
}
return '正在识别、标注并生成解题思路…'
}