加速错题 OCR:上传压缩、识别缩放、启动预热与 MKL-DNN。
This commit is contained in:
@@ -19,9 +19,12 @@ class Settings(BaseSettings):
|
||||
FRONTEND_DIST: str = "../frontend/dist"
|
||||
ADMIN_DEFAULT_USERNAME: str = "admin"
|
||||
ADMIN_DEFAULT_PASSWORD: str = "admin123"
|
||||
OCR_TIMEOUT_SECONDS: int = 300
|
||||
OCR_TIMEOUT_SECONDS: int = 180
|
||||
AI_TIMEOUT_SECONDS: int = 600
|
||||
PROCESS_STALE_MINUTES: int = 20
|
||||
OCR_MAX_SIDE: int = 1280
|
||||
UPLOAD_MAX_SIDE: int = 2048
|
||||
OCR_WARMUP: bool = True
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
|
||||
@@ -9,6 +9,7 @@ from fastapi.staticfiles import StaticFiles
|
||||
from app.core.config import settings
|
||||
from app.core.database import Base, SessionLocal, engine
|
||||
from app.routers import admin, auth, exams, export, settings as settings_router, students, subjects, wrong_questions
|
||||
from app.services import ocr as ocr_service
|
||||
from app.services.migrate import run_migrations
|
||||
from app.services.seed import seed_admin_and_settings, seed_subjects
|
||||
|
||||
@@ -34,6 +35,7 @@ async def lifespan(app: FastAPI):
|
||||
seed_admin_and_settings(db)
|
||||
finally:
|
||||
db.close()
|
||||
ocr_service.warmup_ocr_engine()
|
||||
yield
|
||||
|
||||
|
||||
|
||||
@@ -157,8 +157,7 @@ def _process_wrong_question(question_id: uuid.UUID):
|
||||
except FuturesTimeout:
|
||||
wq.status = WrongQuestionStatus.failed
|
||||
wq.error_message = (
|
||||
f"OCR 识别超时(>{settings.OCR_TIMEOUT_SECONDS}秒)。"
|
||||
" 首次加载模型较慢,请稍后点「重新识别标注」重试"
|
||||
f"OCR 识别超时(>{settings.OCR_TIMEOUT_SECONDS}秒),请点「重新识别标注」重试"
|
||||
)
|
||||
db.commit()
|
||||
return
|
||||
|
||||
+101
-15
@@ -1,6 +1,9 @@
|
||||
from pathlib import Path
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
|
||||
@@ -9,7 +12,10 @@ from app.core.config import settings
|
||||
# 无图形界面服务器:避免 OpenCV/Paddle 依赖 X11
|
||||
os.environ.setdefault("OPENCV_IO_ENABLE_OPENEXR", "0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ocr_engine = None
|
||||
_ocr_warmup_started = False
|
||||
|
||||
|
||||
def get_ocr_engine():
|
||||
@@ -17,20 +23,103 @@ def get_ocr_engine():
|
||||
if _ocr_engine is None:
|
||||
from paddleocr import PaddleOCR
|
||||
|
||||
_ocr_engine = PaddleOCR(use_angle_cls=True, lang="ch", show_log=False)
|
||||
_ocr_engine = PaddleOCR(
|
||||
use_angle_cls=False,
|
||||
lang="ch",
|
||||
show_log=False,
|
||||
use_gpu=False,
|
||||
enable_mkldnn=True,
|
||||
det_limit_side_len=min(settings.OCR_MAX_SIDE, 1280),
|
||||
rec_batch_num=8,
|
||||
)
|
||||
return _ocr_engine
|
||||
|
||||
|
||||
def warmup_ocr_engine() -> None:
    """Preload the OCR model in a background thread.

    Avoids making the first uploaded image wait for the model to load.
    Does nothing when warmup was already started in this process or when
    ``settings.OCR_WARMUP`` is false. The warmup is best-effort: any failure
    is logged as a warning and never propagated to the caller.
    """
    global _ocr_warmup_started
    if _ocr_warmup_started or not settings.OCR_WARMUP:
        return
    # Set the flag before spawning so repeated calls do not start a second thread.
    _ocr_warmup_started = True

    def _run() -> None:
        try:
            engine = get_ocr_engine()
            # Reserve a unique file name and close the handle right away, so
            # the image is written and read through a single open at a time.
            with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
                tmp_path = Path(tmp.name)
            try:
                # A small blank white image is enough to run one inference pass.
                Image.new("RGB", (120, 40), color=(255, 255, 255)).save(tmp_path, format="JPEG")
                engine.ocr(str(tmp_path), cls=False)
                logger.info("OCR engine warmed up")
            finally:
                # Covers both the save and the OCR call, so the file never leaks.
                tmp_path.unlink(missing_ok=True)
        except Exception as exc:
            logger.warning("OCR warmup failed: %s", exc)

    threading.Thread(target=_run, daemon=True, name="ocr-warmup").start()
|
||||
|
||||
|
||||
def _bbox_from_box(box: list) -> list[float]:
|
||||
xs = [float(p[0]) for p in box]
|
||||
ys = [float(p[1]) for p in box]
|
||||
return [min(xs), min(ys), max(xs), max(ys)]
|
||||
|
||||
|
||||
def _scale_bbox(bbox: list[float], scale_x: float, scale_y: float) -> list[float]:
|
||||
return [bbox[0] * scale_x, bbox[1] * scale_y, bbox[2] * scale_x, bbox[3] * scale_y]
|
||||
|
||||
|
||||
def _scale_box(box: list, scale_x: float, scale_y: float) -> list:
|
||||
return [[float(p[0]) * scale_x, float(p[1]) * scale_y] for p in box]
|
||||
|
||||
|
||||
def _normalize_image_bytes(content: bytes, max_side: int) -> bytes:
    """Re-encode an uploaded image as a size-capped RGB JPEG.

    Args:
        content: Raw bytes of an image in any format Pillow can open.
        max_side: Maximum allowed length, in pixels, of the longer side.

    Returns:
        JPEG-encoded bytes (quality 88). The image is downscaled with
        LANCZOS resampling only when its longer side exceeds ``max_side``;
        smaller images keep their dimensions.

    Raises:
        PIL.UnidentifiedImageError / OSError: If ``content`` cannot be decoded.
    """
    # Local import: only ``Image`` is imported at module level.
    from PIL import ImageOps

    with Image.open(BytesIO(content)) as src:
        # Re-encoding drops the EXIF block, so bake the orientation tag into
        # the pixels first; otherwise rotated phone photos are stored sideways.
        img = ImageOps.exif_transpose(src).convert("RGB")

    width, height = img.size
    longest = max(width, height)
    if longest > max_side:
        ratio = max_side / longest
        # Clamp to 1 px so extreme aspect ratios cannot round down to a
        # zero-sized target, which Pillow rejects.
        new_size = (max(1, int(width * ratio)), max(1, int(height * ratio)))
        img = img.resize(new_size, Image.Resampling.LANCZOS)

    buf = BytesIO()
    img.save(buf, format="JPEG", quality=88, optimize=True)
    return buf.getvalue()
|
||||
|
||||
|
||||
def _prepare_ocr_image(image_path: str) -> tuple[str, float, float, int, int, Path | None]:
    """Produce a downscaled temporary copy of an image for OCR when it is too large.

    Args:
        image_path: Path of the image to be recognized.

    Returns:
        A tuple ``(ocr_path, scale_x, scale_y, orig_w, orig_h, tmp_path)``:

        * ``ocr_path`` is the file to feed to OCR: the original path when the
          longer side is within ``settings.OCR_MAX_SIDE``, otherwise a
          temporary JPEG.
        * ``scale_x`` / ``scale_y`` map coordinates in the OCR image back to
          the original image (both ``1.0`` when no resize happened).
        * ``orig_w`` / ``orig_h`` are the original pixel dimensions.
        * ``tmp_path`` is the temporary file the caller must delete, or
          ``None`` when no temporary file was created.
    """
    with Image.open(image_path) as img:
        orig_w, orig_h = img.size
        max_side = settings.OCR_MAX_SIDE
        longest = max(orig_w, orig_h)
        if longest <= max_side:
            return image_path, 1.0, 1.0, orig_w, orig_h, None

        ratio = max_side / longest
        new_w = max(1, int(orig_w * ratio))
        new_h = max(1, int(orig_h * ratio))
        resized = img.convert("RGB").resize((new_w, new_h), Image.Resampling.LANCZOS)

    # Use a unique name: a path built only from the stem and PID collides
    # when the same image is processed twice concurrently in one process,
    # and one run's cleanup would delete the file the other is reading.
    with tempfile.NamedTemporaryFile(
        prefix=f"ocr_{Path(image_path).stem}_", suffix=".jpg", delete=False
    ) as handle:
        tmp = Path(handle.name)
    try:
        resized.save(tmp, format="JPEG", quality=85, optimize=True)
    except Exception:
        # Do not leave a half-written temp file behind if encoding fails.
        tmp.unlink(missing_ok=True)
        raise

    scale_x = orig_w / new_w
    scale_y = orig_h / new_h
    return str(tmp), scale_x, scale_y, orig_w, orig_h, tmp
|
||||
|
||||
|
||||
def run_ocr_with_regions(image_path: str) -> dict:
|
||||
"""Return OCR text plus line-level bounding boxes for annotation."""
|
||||
engine = get_ocr_engine()
|
||||
result = engine.ocr(image_path, cls=True)
|
||||
ocr_path, scale_x, scale_y, orig_w, orig_h, tmp_path = _prepare_ocr_image(image_path)
|
||||
try:
|
||||
result = engine.ocr(ocr_path, cls=False)
|
||||
finally:
|
||||
if tmp_path is not None:
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
|
||||
lines: list[dict] = []
|
||||
if result and result[0]:
|
||||
for item in result[0]:
|
||||
@@ -41,27 +130,23 @@ def run_ocr_with_regions(image_path: str) -> dict:
|
||||
conf = float(rec[1]) if rec and len(rec) > 1 else 0.0
|
||||
if not text:
|
||||
continue
|
||||
if scale_x != 1.0 or scale_y != 1.0:
|
||||
box = _scale_box(box, scale_x, scale_y)
|
||||
bbox = _bbox_from_box(box)
|
||||
lines.append(
|
||||
{
|
||||
"text": text,
|
||||
"confidence": conf,
|
||||
"box": box,
|
||||
"bbox": _bbox_from_box(box),
|
||||
"bbox": bbox,
|
||||
}
|
||||
)
|
||||
|
||||
width, height = 0, 0
|
||||
try:
|
||||
with Image.open(image_path) as img:
|
||||
width, height = img.size
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
return {
|
||||
"text": "\n".join(line["text"] for line in lines),
|
||||
"lines": lines,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"width": orig_w,
|
||||
"height": orig_h,
|
||||
}
|
||||
|
||||
|
||||
@@ -77,7 +162,8 @@ def save_upload_file(user_id: str, question_id: str, filename: str, content: byt
|
||||
user_dir.mkdir(parents=True, exist_ok=True)
|
||||
rel_path = f"{user_id}/{question_id}{ext}"
|
||||
full_path = Path(settings.UPLOAD_DIR) / rel_path
|
||||
full_path.write_bytes(content)
|
||||
normalized = _normalize_image_bytes(content, settings.UPLOAD_MAX_SIDE)
|
||||
full_path.write_bytes(normalized)
|
||||
return rel_path
|
||||
|
||||
|
||||
|
||||
+2
-2
File diff suppressed because one or more lines are too long
Vendored
+1
-1
@@ -9,7 +9,7 @@
|
||||
<meta name="author" content="马建军" />
|
||||
<meta name="copyright" content="Copyright (c) 马建军. All rights reserved." />
|
||||
<title>中学成绩档案</title>
|
||||
<script type="module" crossorigin src="/assets/index-_1CtLpiP.js"></script>
|
||||
<script type="module" crossorigin src="/assets/index-DzzkB1zh.js"></script>
|
||||
<link rel="stylesheet" crossorigin href="/assets/index-GY2etMYN.css">
|
||||
</head>
|
||||
<body>
|
||||
|
||||
@@ -37,7 +37,7 @@ export default function WrongQuestionDetail({
|
||||
try {
|
||||
const { data } = await wrongQuestionApi.get(questionId)
|
||||
setWq(data)
|
||||
setQuestionText(data.question_text || '')
|
||||
setQuestionText(data.question_text || data.ocr_raw_text || '')
|
||||
setApproachText(data.solution_approach || '')
|
||||
setSolutionText(data.solution_text || '')
|
||||
setImageMode(data.has_annotated_image ? 'annotated' : 'original')
|
||||
@@ -56,7 +56,7 @@ export default function WrongQuestionDetail({
|
||||
try {
|
||||
const { data } = await wrongQuestionApi.get(questionId)
|
||||
setWq(data)
|
||||
setQuestionText(data.question_text || '')
|
||||
setQuestionText(data.question_text || data.ocr_raw_text || '')
|
||||
setApproachText(data.solution_approach || '')
|
||||
setSolutionText(data.solution_text || '')
|
||||
if (data.has_annotated_image) setImageMode('annotated')
|
||||
@@ -88,7 +88,7 @@ export default function WrongQuestionDetail({
|
||||
try {
|
||||
const { data } = await wrongQuestionApi.regenerate(questionId)
|
||||
setWq(data)
|
||||
setQuestionText(data.question_text || '')
|
||||
setQuestionText(data.question_text || data.ocr_raw_text || '')
|
||||
setApproachText(data.solution_approach || '')
|
||||
setSolutionText(data.solution_text || '')
|
||||
message.success('解题思路已重新生成')
|
||||
|
||||
@@ -9,10 +9,10 @@ export function isWrongQuestionProcessing(wq: WrongQuestion): boolean {
|
||||
|
||||
export function processingHint(wq: WrongQuestion): string {
|
||||
if (wq.status === 'pending') {
|
||||
return '正在 OCR 识别(首次约 1–5 分钟,请稍候)…'
|
||||
return '正在识别文字(约 10–30 秒)…'
|
||||
}
|
||||
if (wq.status === 'ocr_done') {
|
||||
return '正在标注错题并生成解题思路…'
|
||||
return '正在标注错题并生成解题思路(约 30–90 秒)…'
|
||||
}
|
||||
return '正在识别、标注并生成解题思路…'
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user