# File: secondary-school-grade-archive/backend/app/services/ocr.py
# (173 lines, 5.7 KiB, Python)

import logging
import os
import tempfile
import threading
from io import BytesIO
from pathlib import Path
from PIL import Image
from app.core.config import settings
# Headless server: avoid OpenCV/Paddle depending on X11.
# NOTE(review): judging by its name, this variable concerns OpenCV's OpenEXR
# image I/O rather than X11 -- confirm it achieves the stated goal.
os.environ.setdefault("OPENCV_IO_ENABLE_OPENEXR", "0")
logger = logging.getLogger(__name__)
# Lazily created PaddleOCR instance; populated by get_ocr_engine() on first use.
_ocr_engine = None
# Set to True by warmup_ocr_engine() just before it launches its background
# thread, so that later calls return without starting another one.
_ocr_warmup_started = False
# Serializes first-time construction of the engine: the warm-up thread started
# by warmup_ocr_engine() and regular callers can reach get_ocr_engine() at the
# same time, and without the lock each of them would build its own PaddleOCR.
_ocr_engine_lock = threading.Lock()


def get_ocr_engine():
    """Return the shared PaddleOCR engine, creating it on first use.

    The ``paddleocr`` import is deferred until the engine is actually needed.
    Construction is guarded by a lock so that concurrent first callers end up
    sharing a single instance instead of each loading the models.

    Returns:
        The module-wide ``PaddleOCR`` instance.
    """
    global _ocr_engine
    if _ocr_engine is None:
        with _ocr_engine_lock:
            # Re-check under the lock: another thread may have finished
            # building the engine while this one was waiting.
            if _ocr_engine is None:
                from paddleocr import PaddleOCR

                _ocr_engine = PaddleOCR(
                    use_angle_cls=False,
                    lang="ch",
                    show_log=False,
                    use_gpu=False,
                    enable_mkldnn=True,
                    # Detection side length is capped at 1280 regardless of
                    # the configured OCR_MAX_SIDE.
                    det_limit_side_len=min(settings.OCR_MAX_SIDE, 1280),
                    rec_batch_num=8,
                )
    return _ocr_engine
def warmup_ocr_engine() -> None:
    """Preload the OCR model in a background thread.

    Avoids making the first real image wait for model loading. The call is a
    no-op when warm-up was already started or ``settings.OCR_WARMUP`` is falsy.
    Failures inside the background thread are logged and never propagated.
    """
    global _ocr_warmup_started
    if _ocr_warmup_started or not settings.OCR_WARMUP:
        return
    _ocr_warmup_started = True

    def _run() -> None:
        """Build the engine and push one small blank image through it."""
        try:
            engine = get_ocr_engine()
            # A private temporary directory is removed on exit even when
            # writing the sample or running OCR fails, and the sample is never
            # re-opened by name while another handle to it is still open.
            with tempfile.TemporaryDirectory(prefix="ocr_warmup_") as tmp_dir:
                sample_path = Path(tmp_dir) / "warmup.jpg"
                Image.new("RGB", (120, 40), color=(255, 255, 255)).save(
                    sample_path, format="JPEG"
                )
                engine.ocr(str(sample_path), cls=False)
            logger.info("OCR engine warmed up")
        except Exception as exc:
            # Best-effort: warm-up must never take the application down, but
            # keep the traceback so the cause can be diagnosed.
            logger.warning("OCR warmup failed: %s", exc, exc_info=True)

    threading.Thread(target=_run, daemon=True, name="ocr-warmup").start()
def _bbox_from_box(box: list) -> list[float]:
xs = [float(p[0]) for p in box]
ys = [float(p[1]) for p in box]
return [min(xs), min(ys), max(xs), max(ys)]
def _scale_bbox(bbox: list[float], scale_x: float, scale_y: float) -> list[float]:
return [bbox[0] * scale_x, bbox[1] * scale_y, bbox[2] * scale_x, bbox[3] * scale_y]
def _scale_box(box: list, scale_x: float, scale_y: float) -> list:
return [[float(p[0]) * scale_x, float(p[1]) * scale_y] for p in box]
def _normalize_image_bytes(content: bytes, max_side: int) -> bytes:
    """Re-encode an image as an RGB JPEG whose longest side is at most ``max_side``.

    Args:
        content: Raw bytes of an image that Pillow can open.
        max_side: Upper bound, in pixels, for the longer image side. Images
            already within the bound keep their dimensions.

    Returns:
        The JPEG-encoded bytes (quality 88, optimized).
    """
    with Image.open(BytesIO(content)) as source:
        rgb = source.convert("RGB")
        width, height = rgb.size
        longest = max(width, height)
        if longest > max_side:
            ratio = max_side / longest
            # Clamp each side to at least 1px: int() truncation would otherwise
            # produce a zero-length side for extreme aspect ratios (the same
            # guard _prepare_ocr_image applies).
            target_size = (
                max(1, int(width * ratio)),
                max(1, int(height * ratio)),
            )
            rgb = rgb.resize(target_size, Image.Resampling.LANCZOS)
        buf = BytesIO()
        rgb.save(buf, format="JPEG", quality=88, optimize=True)
        return buf.getvalue()
def _prepare_ocr_image(image_path: str) -> tuple[str, float, float, int, int, Path | None]:
    """Produce a downscaled temporary copy for OCR when the image is too large.

    Args:
        image_path: Path of the original image file.

    Returns:
        A tuple ``(ocr_path, scale_x, scale_y, orig_w, orig_h, tmp_path)``:
        ``ocr_path`` is the file to feed to the OCR engine, ``scale_x`` and
        ``scale_y`` map coordinates found on that file back to the original
        image, ``orig_w``/``orig_h`` are the original dimensions, and
        ``tmp_path`` is the temporary file the caller must delete, or ``None``
        when the original image is used as-is.
    """
    max_side = settings.OCR_MAX_SIDE
    with Image.open(image_path) as img:
        orig_w, orig_h = img.size
        longest = max(orig_w, orig_h)
        if longest <= max_side:
            return image_path, 1.0, 1.0, orig_w, orig_h, None
        ratio = max_side / longest
        new_w = max(1, int(orig_w * ratio))
        new_h = max(1, int(orig_h * ratio))
        resized = img.convert("RGB").resize((new_w, new_h), Image.Resampling.LANCZOS)

    # mkstemp yields a unique, freshly created file. A name derived only from
    # the stem and the PID would be shared by concurrent calls in this process
    # that handle same-named images, letting one call delete or overwrite the
    # file another is still reading.
    fd, tmp_name = tempfile.mkstemp(prefix=f"ocr_{Path(image_path).stem}_", suffix=".jpg")
    tmp = Path(tmp_name)
    try:
        with os.fdopen(fd, "wb") as handle:
            resized.save(handle, format="JPEG", quality=85, optimize=True)
    except Exception:
        # Do not leave a half-written file behind when encoding fails.
        tmp.unlink(missing_ok=True)
        raise

    scale_x = orig_w / new_w
    scale_y = orig_h / new_h
    return str(tmp), scale_x, scale_y, orig_w, orig_h, tmp
def run_ocr_with_regions(image_path: str) -> dict:
    """Run OCR on an image and return its text together with per-line geometry.

    Args:
        image_path: Path of the image file to recognize.

    Returns:
        A dict with ``"text"`` (the recognized lines joined by newlines),
        ``"lines"`` (one dict per non-empty line holding ``"text"``,
        ``"confidence"``, ``"box"`` and ``"bbox"``) and the original image's
        ``"width"`` and ``"height"``. When OCR ran on a downscaled copy, the
        boxes are scaled back to the original image's coordinates.
    """
    engine = get_ocr_engine()
    prepared = _prepare_ocr_image(image_path)
    ocr_path, scale_x, scale_y, orig_w, orig_h, tmp_path = prepared
    try:
        result = engine.ocr(ocr_path, cls=False)
    finally:
        # The downscaled copy, if one was made, is only needed for this call.
        if tmp_path is not None:
            tmp_path.unlink(missing_ok=True)

    needs_rescale = scale_x != 1.0 or scale_y != 1.0
    detections = result[0] if result and result[0] else []

    lines: list[dict] = []
    for detection in detections:
        if not detection or len(detection) < 2:
            continue
        box, rec = detection[0], detection[1]
        text = rec[0] if rec else ""
        conf = float(rec[1]) if rec and len(rec) > 1 else 0.0
        if not text:
            continue
        if needs_rescale:
            box = _scale_box(box, scale_x, scale_y)
        lines.append(
            {
                "text": text,
                "confidence": conf,
                "box": box,
                "bbox": _bbox_from_box(box),
            }
        )

    texts = [entry["text"] for entry in lines]
    return {
        "text": "\n".join(texts),
        "lines": lines,
        "width": orig_w,
        "height": orig_h,
    }
def run_ocr(image_path: str) -> str:
    """Return only the recognized text of an image, lines joined by newlines.

    Args:
        image_path: Path of the image file to recognize.

    Returns:
        The ``"text"`` entry of ``run_ocr_with_regions(image_path)``.
    """
    ocr_result = run_ocr_with_regions(image_path)
    return ocr_result["text"]
def save_upload_file(user_id: str, question_id: str, filename: str, content: bytes) -> str:
    """Normalize an uploaded image and store it under the user's upload folder.

    Args:
        user_id: Name of the per-user directory below ``settings.UPLOAD_DIR``.
        question_id: Base name (without extension) of the stored file.
        filename: Original upload name; only its extension is used.
        content: Raw bytes of the uploaded image.

    Returns:
        The stored file's path relative to ``settings.UPLOAD_DIR``, in the
        form ``"<user_id>/<question_id><ext>"``.
    """
    allowed_suffixes = {".jpg", ".jpeg", ".png", ".webp"}
    suffix = Path(filename).suffix.lower()
    # Missing or unsupported extensions fall back to ".jpg".
    ext = suffix if suffix in allowed_suffixes else ".jpg"

    upload_root = Path(settings.UPLOAD_DIR)
    (upload_root / user_id).mkdir(parents=True, exist_ok=True)

    rel_path = f"{user_id}/{question_id}{ext}"
    # NOTE(review): _normalize_image_bytes always encodes JPEG, so a ".png" or
    # ".webp" name ends up holding JPEG data -- confirm whether the stored
    # extension should always be ".jpg".
    normalized = _normalize_image_bytes(content, settings.UPLOAD_MAX_SIDE)
    (upload_root / rel_path).write_bytes(normalized)
    return rel_path
def annotated_rel_path(original_rel: str) -> str:
    """Derive the relative path of the annotated copy of an uploaded image.

    Args:
        original_rel: Relative path of the original image.

    Returns:
        A path in the same directory whose file name is the original stem
        followed by ``_marked.jpg``.
    """
    original = Path(original_rel)
    marked_name = f"{original.stem}_marked.jpg"
    marked_path = original.parent / marked_name
    return str(marked_path)