加速错题 OCR:上传压缩、识别缩放、启动预热与 MKL-DNN。
This commit is contained in:
@@ -19,9 +19,12 @@ class Settings(BaseSettings):
|
|||||||
FRONTEND_DIST: str = "../frontend/dist"
|
FRONTEND_DIST: str = "../frontend/dist"
|
||||||
ADMIN_DEFAULT_USERNAME: str = "admin"
|
ADMIN_DEFAULT_USERNAME: str = "admin"
|
||||||
ADMIN_DEFAULT_PASSWORD: str = "admin123"
|
ADMIN_DEFAULT_PASSWORD: str = "admin123"
|
||||||
OCR_TIMEOUT_SECONDS: int = 300
|
OCR_TIMEOUT_SECONDS: int = 180
|
||||||
AI_TIMEOUT_SECONDS: int = 600
|
AI_TIMEOUT_SECONDS: int = 600
|
||||||
PROCESS_STALE_MINUTES: int = 20
|
PROCESS_STALE_MINUTES: int = 20
|
||||||
|
OCR_MAX_SIDE: int = 1280
|
||||||
|
UPLOAD_MAX_SIDE: int = 2048
|
||||||
|
OCR_WARMUP: bool = True
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
env_file = ".env"
|
env_file = ".env"
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from fastapi.staticfiles import StaticFiles
|
|||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.database import Base, SessionLocal, engine
|
from app.core.database import Base, SessionLocal, engine
|
||||||
from app.routers import admin, auth, exams, export, settings as settings_router, students, subjects, wrong_questions
|
from app.routers import admin, auth, exams, export, settings as settings_router, students, subjects, wrong_questions
|
||||||
|
from app.services import ocr as ocr_service
|
||||||
from app.services.migrate import run_migrations
|
from app.services.migrate import run_migrations
|
||||||
from app.services.seed import seed_admin_and_settings, seed_subjects
|
from app.services.seed import seed_admin_and_settings, seed_subjects
|
||||||
|
|
||||||
@@ -34,6 +35,7 @@ async def lifespan(app: FastAPI):
|
|||||||
seed_admin_and_settings(db)
|
seed_admin_and_settings(db)
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
ocr_service.warmup_ocr_engine()
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -157,8 +157,7 @@ def _process_wrong_question(question_id: uuid.UUID):
|
|||||||
except FuturesTimeout:
|
except FuturesTimeout:
|
||||||
wq.status = WrongQuestionStatus.failed
|
wq.status = WrongQuestionStatus.failed
|
||||||
wq.error_message = (
|
wq.error_message = (
|
||||||
f"OCR 识别超时(>{settings.OCR_TIMEOUT_SECONDS}秒)。"
|
f"OCR 识别超时(>{settings.OCR_TIMEOUT_SECONDS}秒),请点「重新识别标注」重试"
|
||||||
" 首次加载模型较慢,请稍后点「重新识别标注」重试"
|
|
||||||
)
|
)
|
||||||
db.commit()
|
db.commit()
|
||||||
return
|
return
|
||||||
|
|||||||
+101
-15
@@ -1,6 +1,9 @@
|
|||||||
from pathlib import Path
|
import logging
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import tempfile
|
||||||
|
import threading
|
||||||
|
from io import BytesIO
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
@@ -9,7 +12,10 @@ from app.core.config import settings
|
|||||||
# 无图形界面服务器:避免 OpenCV/Paddle 依赖 X11
|
# 无图形界面服务器:避免 OpenCV/Paddle 依赖 X11
|
||||||
os.environ.setdefault("OPENCV_IO_ENABLE_OPENEXR", "0")
|
os.environ.setdefault("OPENCV_IO_ENABLE_OPENEXR", "0")
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_ocr_engine = None
|
_ocr_engine = None
|
||||||
|
_ocr_warmup_started = False
|
||||||
|
|
||||||
|
|
||||||
def get_ocr_engine():
|
def get_ocr_engine():
|
||||||
@@ -17,20 +23,103 @@ def get_ocr_engine():
|
|||||||
if _ocr_engine is None:
|
if _ocr_engine is None:
|
||||||
from paddleocr import PaddleOCR
|
from paddleocr import PaddleOCR
|
||||||
|
|
||||||
_ocr_engine = PaddleOCR(use_angle_cls=True, lang="ch", show_log=False)
|
_ocr_engine = PaddleOCR(
|
||||||
|
use_angle_cls=False,
|
||||||
|
lang="ch",
|
||||||
|
show_log=False,
|
||||||
|
use_gpu=False,
|
||||||
|
enable_mkldnn=True,
|
||||||
|
det_limit_side_len=min(settings.OCR_MAX_SIDE, 1280),
|
||||||
|
rec_batch_num=8,
|
||||||
|
)
|
||||||
return _ocr_engine
|
return _ocr_engine
|
||||||
|
|
||||||
|
|
||||||
|
def warmup_ocr_engine() -> None:
|
||||||
|
"""后台预加载 OCR 模型,避免首张图片等待数分钟。"""
|
||||||
|
global _ocr_warmup_started
|
||||||
|
if _ocr_warmup_started or not settings.OCR_WARMUP:
|
||||||
|
return
|
||||||
|
_ocr_warmup_started = True
|
||||||
|
|
||||||
|
def _run() -> None:
|
||||||
|
try:
|
||||||
|
engine = get_ocr_engine()
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
|
||||||
|
Image.new("RGB", (120, 40), color=(255, 255, 255)).save(tmp.name, format="JPEG")
|
||||||
|
tmp_path = tmp.name
|
||||||
|
try:
|
||||||
|
engine.ocr(tmp_path, cls=False)
|
||||||
|
logger.info("OCR engine warmed up")
|
||||||
|
finally:
|
||||||
|
Path(tmp_path).unlink(missing_ok=True)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("OCR warmup failed: %s", exc)
|
||||||
|
|
||||||
|
threading.Thread(target=_run, daemon=True, name="ocr-warmup").start()
|
||||||
|
|
||||||
|
|
||||||
def _bbox_from_box(box: list) -> list[float]:
|
def _bbox_from_box(box: list) -> list[float]:
|
||||||
xs = [float(p[0]) for p in box]
|
xs = [float(p[0]) for p in box]
|
||||||
ys = [float(p[1]) for p in box]
|
ys = [float(p[1]) for p in box]
|
||||||
return [min(xs), min(ys), max(xs), max(ys)]
|
return [min(xs), min(ys), max(xs), max(ys)]
|
||||||
|
|
||||||
|
|
||||||
|
def _scale_bbox(bbox: list[float], scale_x: float, scale_y: float) -> list[float]:
|
||||||
|
return [bbox[0] * scale_x, bbox[1] * scale_y, bbox[2] * scale_x, bbox[3] * scale_y]
|
||||||
|
|
||||||
|
|
||||||
|
def _scale_box(box: list, scale_x: float, scale_y: float) -> list:
|
||||||
|
return [[float(p[0]) * scale_x, float(p[1]) * scale_y] for p in box]
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_image_bytes(content: bytes, max_side: int) -> bytes:
|
||||||
|
with Image.open(BytesIO(content)) as img:
|
||||||
|
img = img.convert("RGB")
|
||||||
|
width, height = img.size
|
||||||
|
longest = max(width, height)
|
||||||
|
if longest > max_side:
|
||||||
|
ratio = max_side / longest
|
||||||
|
img = img.resize((int(width * ratio), int(height * ratio)), Image.Resampling.LANCZOS)
|
||||||
|
buf = BytesIO()
|
||||||
|
img.save(buf, format="JPEG", quality=88, optimize=True)
|
||||||
|
return buf.getvalue()
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_ocr_image(image_path: str) -> tuple[str, float, float, int, int, Path | None]:
|
||||||
|
"""若图片过大则生成临时缩小图供 OCR,返回缩放比例与原始尺寸。"""
|
||||||
|
with Image.open(image_path) as img:
|
||||||
|
orig_w, orig_h = img.size
|
||||||
|
|
||||||
|
max_side = settings.OCR_MAX_SIDE
|
||||||
|
longest = max(orig_w, orig_h)
|
||||||
|
if longest <= max_side:
|
||||||
|
return image_path, 1.0, 1.0, orig_w, orig_h, None
|
||||||
|
|
||||||
|
with Image.open(image_path) as img:
|
||||||
|
img = img.convert("RGB")
|
||||||
|
ratio = max_side / longest
|
||||||
|
new_w = max(1, int(orig_w * ratio))
|
||||||
|
new_h = max(1, int(orig_h * ratio))
|
||||||
|
resized = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
|
||||||
|
tmp = Path(tempfile.gettempdir()) / f"ocr_{Path(image_path).stem}_{os.getpid()}.jpg"
|
||||||
|
resized.save(tmp, format="JPEG", quality=85, optimize=True)
|
||||||
|
|
||||||
|
scale_x = orig_w / new_w
|
||||||
|
scale_y = orig_h / new_h
|
||||||
|
return str(tmp), scale_x, scale_y, orig_w, orig_h, tmp
|
||||||
|
|
||||||
|
|
||||||
def run_ocr_with_regions(image_path: str) -> dict:
|
def run_ocr_with_regions(image_path: str) -> dict:
|
||||||
"""Return OCR text plus line-level bounding boxes for annotation."""
|
"""Return OCR text plus line-level bounding boxes for annotation."""
|
||||||
engine = get_ocr_engine()
|
engine = get_ocr_engine()
|
||||||
result = engine.ocr(image_path, cls=True)
|
ocr_path, scale_x, scale_y, orig_w, orig_h, tmp_path = _prepare_ocr_image(image_path)
|
||||||
|
try:
|
||||||
|
result = engine.ocr(ocr_path, cls=False)
|
||||||
|
finally:
|
||||||
|
if tmp_path is not None:
|
||||||
|
tmp_path.unlink(missing_ok=True)
|
||||||
|
|
||||||
lines: list[dict] = []
|
lines: list[dict] = []
|
||||||
if result and result[0]:
|
if result and result[0]:
|
||||||
for item in result[0]:
|
for item in result[0]:
|
||||||
@@ -41,27 +130,23 @@ def run_ocr_with_regions(image_path: str) -> dict:
|
|||||||
conf = float(rec[1]) if rec and len(rec) > 1 else 0.0
|
conf = float(rec[1]) if rec and len(rec) > 1 else 0.0
|
||||||
if not text:
|
if not text:
|
||||||
continue
|
continue
|
||||||
|
if scale_x != 1.0 or scale_y != 1.0:
|
||||||
|
box = _scale_box(box, scale_x, scale_y)
|
||||||
|
bbox = _bbox_from_box(box)
|
||||||
lines.append(
|
lines.append(
|
||||||
{
|
{
|
||||||
"text": text,
|
"text": text,
|
||||||
"confidence": conf,
|
"confidence": conf,
|
||||||
"box": box,
|
"box": box,
|
||||||
"bbox": _bbox_from_box(box),
|
"bbox": bbox,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
width, height = 0, 0
|
|
||||||
try:
|
|
||||||
with Image.open(image_path) as img:
|
|
||||||
width, height = img.size
|
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"text": "\n".join(line["text"] for line in lines),
|
"text": "\n".join(line["text"] for line in lines),
|
||||||
"lines": lines,
|
"lines": lines,
|
||||||
"width": width,
|
"width": orig_w,
|
||||||
"height": height,
|
"height": orig_h,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -77,7 +162,8 @@ def save_upload_file(user_id: str, question_id: str, filename: str, content: byt
|
|||||||
user_dir.mkdir(parents=True, exist_ok=True)
|
user_dir.mkdir(parents=True, exist_ok=True)
|
||||||
rel_path = f"{user_id}/{question_id}{ext}"
|
rel_path = f"{user_id}/{question_id}{ext}"
|
||||||
full_path = Path(settings.UPLOAD_DIR) / rel_path
|
full_path = Path(settings.UPLOAD_DIR) / rel_path
|
||||||
full_path.write_bytes(content)
|
normalized = _normalize_image_bytes(content, settings.UPLOAD_MAX_SIDE)
|
||||||
|
full_path.write_bytes(normalized)
|
||||||
return rel_path
|
return rel_path
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
+2
-2
File diff suppressed because one or more lines are too long
Vendored
+1
-1
@@ -9,7 +9,7 @@
|
|||||||
<meta name="author" content="马建军" />
|
<meta name="author" content="马建军" />
|
||||||
<meta name="copyright" content="Copyright (c) 马建军. All rights reserved." />
|
<meta name="copyright" content="Copyright (c) 马建军. All rights reserved." />
|
||||||
<title>中学成绩档案</title>
|
<title>中学成绩档案</title>
|
||||||
<script type="module" crossorigin src="/assets/index-_1CtLpiP.js"></script>
|
<script type="module" crossorigin src="/assets/index-DzzkB1zh.js"></script>
|
||||||
<link rel="stylesheet" crossorigin href="/assets/index-GY2etMYN.css">
|
<link rel="stylesheet" crossorigin href="/assets/index-GY2etMYN.css">
|
||||||
</head>
|
</head>
|
||||||
<body>
|
<body>
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ export default function WrongQuestionDetail({
|
|||||||
try {
|
try {
|
||||||
const { data } = await wrongQuestionApi.get(questionId)
|
const { data } = await wrongQuestionApi.get(questionId)
|
||||||
setWq(data)
|
setWq(data)
|
||||||
setQuestionText(data.question_text || '')
|
setQuestionText(data.question_text || data.ocr_raw_text || '')
|
||||||
setApproachText(data.solution_approach || '')
|
setApproachText(data.solution_approach || '')
|
||||||
setSolutionText(data.solution_text || '')
|
setSolutionText(data.solution_text || '')
|
||||||
setImageMode(data.has_annotated_image ? 'annotated' : 'original')
|
setImageMode(data.has_annotated_image ? 'annotated' : 'original')
|
||||||
@@ -56,7 +56,7 @@ export default function WrongQuestionDetail({
|
|||||||
try {
|
try {
|
||||||
const { data } = await wrongQuestionApi.get(questionId)
|
const { data } = await wrongQuestionApi.get(questionId)
|
||||||
setWq(data)
|
setWq(data)
|
||||||
setQuestionText(data.question_text || '')
|
setQuestionText(data.question_text || data.ocr_raw_text || '')
|
||||||
setApproachText(data.solution_approach || '')
|
setApproachText(data.solution_approach || '')
|
||||||
setSolutionText(data.solution_text || '')
|
setSolutionText(data.solution_text || '')
|
||||||
if (data.has_annotated_image) setImageMode('annotated')
|
if (data.has_annotated_image) setImageMode('annotated')
|
||||||
@@ -88,7 +88,7 @@ export default function WrongQuestionDetail({
|
|||||||
try {
|
try {
|
||||||
const { data } = await wrongQuestionApi.regenerate(questionId)
|
const { data } = await wrongQuestionApi.regenerate(questionId)
|
||||||
setWq(data)
|
setWq(data)
|
||||||
setQuestionText(data.question_text || '')
|
setQuestionText(data.question_text || data.ocr_raw_text || '')
|
||||||
setApproachText(data.solution_approach || '')
|
setApproachText(data.solution_approach || '')
|
||||||
setSolutionText(data.solution_text || '')
|
setSolutionText(data.solution_text || '')
|
||||||
message.success('解题思路已重新生成')
|
message.success('解题思路已重新生成')
|
||||||
|
|||||||
@@ -9,10 +9,10 @@ export function isWrongQuestionProcessing(wq: WrongQuestion): boolean {
|
|||||||
|
|
||||||
export function processingHint(wq: WrongQuestion): string {
|
export function processingHint(wq: WrongQuestion): string {
|
||||||
if (wq.status === 'pending') {
|
if (wq.status === 'pending') {
|
||||||
return '正在 OCR 识别(首次约 1–5 分钟,请稍候)…'
|
return '正在识别文字(约 10–30 秒)…'
|
||||||
}
|
}
|
||||||
if (wq.status === 'ocr_done') {
|
if (wq.status === 'ocr_done') {
|
||||||
return '正在标注错题并生成解题思路…'
|
return '正在标注错题并生成解题思路(约 30–90 秒)…'
|
||||||
}
|
}
|
||||||
return '正在识别、标注并生成解题思路…'
|
return '正在识别、标注并生成解题思路…'
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user