# File: secondary-school-grade-archive/backend/app/services/ocr.py
# (214 lines, 7.3 KiB, Python)

import logging
import os
import tempfile
import threading
from io import BytesIO
from pathlib import Path
import httpx
from PIL import Image
from app.core.config import settings
# Headless server: intended to keep OpenCV/Paddle from depending on X11.
# NOTE(review): the variable set here is named after OpenCV's OpenEXR I/O
# support -- confirm it is really the switch needed for headless operation.
os.environ.setdefault("OPENCV_IO_ENABLE_OPENEXR", "0")
logger = logging.getLogger(__name__)
# Lazily created OCR engine shared by this module; populated by get_ocr_engine().
_ocr_engine = None
# Set to True once warmup_ocr_engine() has decided to launch its background thread.
_ocr_warmup_started = False
# Serialises first-time construction so that the warmup thread and any other
# caller arriving while the model is still loading never build two engines.
_ocr_engine_lock = threading.Lock()


def get_ocr_engine():
    """Return the module-wide PaddleOCR engine, creating it on first use.

    The engine is built once and cached in the module-level ``_ocr_engine``.
    ``paddleocr`` is imported lazily so that importing this module does not
    pull it in. Construction is guarded by a lock because
    ``warmup_ocr_engine`` calls this function from a background thread while
    other callers may request the engine concurrently.

    Returns:
        The shared ``PaddleOCR`` instance.
    """
    global _ocr_engine
    if _ocr_engine is None:
        with _ocr_engine_lock:
            # Re-check under the lock: another thread may have finished
            # building the engine while this one was waiting.
            if _ocr_engine is None:
                from paddleocr import PaddleOCR

                use_gpu = settings.OCR_USE_GPU
                _ocr_engine = PaddleOCR(
                    use_angle_cls=False,
                    lang="ch",
                    show_log=False,
                    use_gpu=use_gpu,
                    # MKL-DNN is only requested when running on CPU.
                    enable_mkldnn=not use_gpu,
                    # Detection side length follows the configured maximum
                    # but is capped at 1280.
                    det_limit_side_len=min(settings.OCR_MAX_SIDE, 1280),
                    rec_batch_num=8,
                )
    return _ocr_engine
def resolve_ocr_service_url(service_url: str | None = None) -> str | None:
url = (service_url or settings.OCR_SERVICE_URL or "").strip()
return url or None
def uses_remote_ocr(service_url: str | None = None) -> bool:
    """Tell whether OCR requests would be sent to a remote service.

    Args:
        service_url: Optional per-call override, passed straight to
            ``resolve_ocr_service_url``.

    Returns:
        ``True`` when ``resolve_ocr_service_url`` yields a URL, else ``False``.
    """
    resolved = resolve_ocr_service_url(service_url)
    if resolved is None:
        return False
    return True
def warmup_ocr_engine() -> None:
    """Preload the OCR model in the background so the first image does not wait minutes.

    Does nothing when a warmup has already been started, when
    ``settings.OCR_WARMUP`` is falsy, or when a remote OCR service is
    configured. Otherwise a daemon thread named ``ocr-warmup`` builds the
    engine and runs it once on a small blank JPEG. Any failure in that thread
    is logged as a warning and not propagated.
    """
    global _ocr_warmup_started
    if _ocr_warmup_started or not settings.OCR_WARMUP or uses_remote_ocr():
        return
    _ocr_warmup_started = True

    def _run() -> None:
        try:
            engine = get_ocr_engine()
            # mkstemp hands back a closed-by-us, uniquely named file, so the
            # image can be written and re-read by path on every platform.
            fd, tmp_name = tempfile.mkstemp(suffix=".jpg")
            os.close(fd)
            tmp_path = Path(tmp_name)
            try:
                Image.new("RGB", (120, 40), color=(255, 255, 255)).save(tmp_path, format="JPEG")
                engine.ocr(str(tmp_path), cls=False)
                logger.info("OCR engine warmed up")
            finally:
                # Covers a failed save as well as a failed OCR call.
                tmp_path.unlink(missing_ok=True)
        except Exception as exc:
            # Best effort: warmup must never take the application down.
            logger.warning("OCR warmup failed: %s", exc, exc_info=True)

    threading.Thread(target=_run, daemon=True, name="ocr-warmup").start()
def _bbox_from_box(box: list) -> list[float]:
xs = [float(p[0]) for p in box]
ys = [float(p[1]) for p in box]
return [min(xs), min(ys), max(xs), max(ys)]
def _scale_bbox(bbox: list[float], scale_x: float, scale_y: float) -> list[float]:
return [bbox[0] * scale_x, bbox[1] * scale_y, bbox[2] * scale_x, bbox[3] * scale_y]
def _scale_box(box: list, scale_x: float, scale_y: float) -> list:
return [[float(p[0]) * scale_x, float(p[1]) * scale_y] for p in box]
def _normalize_image_bytes(content: bytes, max_side: int) -> bytes:
    """Re-encode an uploaded image as an RGB JPEG no larger than ``max_side``.

    Args:
        content: Raw bytes of an image Pillow can open.
        max_side: Upper bound for the longer edge, in pixels. Larger images
            are shrunk proportionally; smaller ones keep their size.

    Returns:
        The JPEG-encoded bytes (quality 88, optimised).
    """
    # Function-scope import: only this helper needs ImageOps.
    from PIL import ImageOps

    with Image.open(BytesIO(content)) as img:
        # Apply the EXIF orientation to the pixels themselves; the save call
        # below is not asked to carry EXIF over, so a photo that relied on
        # the tag would otherwise come out rotated.
        img = ImageOps.exif_transpose(img).convert("RGB")
        width, height = img.size
        longest = max(width, height)
        if longest > max_side:
            ratio = max_side / longest
            # Clamp to one pixel so an extreme aspect ratio cannot produce a
            # zero-sized edge, matching _prepare_ocr_image.
            new_size = (max(1, int(width * ratio)), max(1, int(height * ratio)))
            img = img.resize(new_size, Image.Resampling.LANCZOS)
        buf = BytesIO()
        img.save(buf, format="JPEG", quality=88, optimize=True)
    return buf.getvalue()
def _prepare_ocr_image(image_path: str) -> tuple[str, float, float, int, int, Path | None]:
    """Produce a temporary downscaled copy for OCR when the image is too large.

    Args:
        image_path: Path of the image to inspect.

    Returns:
        A tuple ``(ocr_path, scale_x, scale_y, orig_w, orig_h, tmp_path)``.
        When the longer edge is within ``settings.OCR_MAX_SIDE`` the original
        path is returned with scales of ``1.0`` and ``tmp_path`` set to
        ``None``. Otherwise ``ocr_path`` points at a resized JPEG, the scales
        map coordinates on that JPEG back to the original, and ``tmp_path``
        is the file the caller must delete afterwards.
    """
    max_side = settings.OCR_MAX_SIDE
    with Image.open(image_path) as img:
        orig_w, orig_h = img.size
        longest = max(orig_w, orig_h)
        if longest <= max_side:
            return image_path, 1.0, 1.0, orig_w, orig_h, None
        ratio = max_side / longest
        new_w = max(1, int(orig_w * ratio))
        new_h = max(1, int(orig_h * ratio))
        resized = img.convert("RGB").resize((new_w, new_h), Image.Resampling.LANCZOS)

    # mkstemp picks a unique name, so concurrent calls for images that share
    # a stem cannot overwrite or delete each other's temporary file.
    fd, tmp_name = tempfile.mkstemp(prefix=f"ocr_{Path(image_path).stem}_", suffix=".jpg")
    os.close(fd)
    tmp = Path(tmp_name)
    try:
        resized.save(tmp, format="JPEG", quality=85, optimize=True)
    except Exception:
        # The caller never learns about this file, so remove it here.
        tmp.unlink(missing_ok=True)
        raise
    scale_x = orig_w / new_w
    scale_y = orig_h / new_h
    return str(tmp), scale_x, scale_y, orig_w, orig_h, tmp
def _run_remote_ocr(service_url: str, image_path: str) -> dict:
    """Send an image to the remote OCR service and return its JSON reply.

    The file is posted as multipart field ``file`` to
    ``<service_url>/api/ocr/regions``. When ``settings.OCR_API_KEY`` is set it
    is sent in the ``X-OCR-Key`` header.

    Args:
        service_url: Base URL of the OCR service; trailing slashes are ignored.
        image_path: Path of the image to upload.

    Returns:
        The decoded JSON body of a response with a status below 400.

    Raises:
        RuntimeError: If the service answers with status 400 or above. The
            message uses the body's ``detail`` field when it can be read,
            otherwise the raw response text.
    """
    endpoint = service_url.rstrip("/") + "/api/ocr/regions"
    request_headers: dict[str, str] = (
        {"X-OCR-Key": settings.OCR_API_KEY} if settings.OCR_API_KEY else {}
    )
    with open(image_path, "rb") as image_file, httpx.Client(
        timeout=settings.OCR_TIMEOUT_SECONDS
    ) as client:
        upload = {"file": (Path(image_path).name, image_file, "image/jpeg")}
        response = client.post(endpoint, files=upload, headers=request_headers)

    if response.status_code < 400:
        return response.json()

    message = response.text
    try:
        reported = response.json().get("detail")
        if isinstance(reported, str):
            message = reported
        elif isinstance(reported, list):
            message = str(reported)
    except Exception:
        # Best effort: an unparsable error body falls back to the raw text.
        pass
    raise RuntimeError(f"OCR 服务 {response.status_code}: {message}")
def _run_local_ocr(image_path: str) -> dict:
    """Run the in-process OCR engine on an image and collect its text lines.

    The image may be downscaled to a temporary file first; that file is
    removed once the engine has run, and boxes are scaled back to the
    original image's coordinates.

    Args:
        image_path: Path of the image to recognise.

    Returns:
        A dict with ``text`` (all line texts joined by newlines), ``lines``
        (one dict per line with ``text``, ``confidence``, ``box`` and
        ``bbox``), and the original ``width`` and ``height``.
    """
    engine = get_ocr_engine()
    prepared = _prepare_ocr_image(image_path)
    ocr_path, scale_x, scale_y, orig_w, orig_h, tmp_path = prepared
    try:
        result = engine.ocr(ocr_path, cls=False)
    finally:
        if tmp_path is not None:
            tmp_path.unlink(missing_ok=True)

    # Only the first entry of the result is read; it may be missing or empty.
    detections = result[0] if result and result[0] else []
    needs_rescale = not (scale_x == 1.0 and scale_y == 1.0)

    lines: list[dict] = []
    for detection in detections:
        if not detection or len(detection) < 2:
            continue
        polygon, recognized = detection[0], detection[1]
        text = recognized[0] if recognized else ""
        confidence = float(recognized[1]) if recognized and len(recognized) > 1 else 0.0
        if not text:
            continue
        if needs_rescale:
            polygon = _scale_box(polygon, scale_x, scale_y)
        lines.append(
            {
                "text": text,
                "confidence": confidence,
                "box": polygon,
                "bbox": _bbox_from_box(polygon),
            }
        )

    full_text = "\n".join(entry["text"] for entry in lines)
    return {
        "text": full_text,
        "lines": lines,
        "width": orig_w,
        "height": orig_h,
    }
def run_ocr_with_regions(image_path: str, service_url: str | None = None) -> dict:
    """Return OCR text plus line-level bounding boxes for annotation.

    Args:
        image_path: Path of the image to recognise.
        service_url: Optional remote OCR URL override. When neither it nor
            the configured URL resolves to a value, the local engine is used.

    Returns:
        The result dict produced by the remote or the local OCR path.
    """
    remote_url = resolve_ocr_service_url(service_url)
    if not remote_url:
        return _run_local_ocr(image_path)
    return _run_remote_ocr(remote_url, image_path)
def run_ocr(image_path: str) -> str:
    """Return only the recognised text for the image at ``image_path``."""
    ocr_result = run_ocr_with_regions(image_path)
    return ocr_result["text"]
def save_upload_file(user_id: str, question_id: str, filename: str, content: bytes) -> str:
    """Normalise an uploaded image and store it under the user's upload folder.

    The bytes are always re-encoded as JPEG by ``_normalize_image_bytes``, so
    the stored file is always given a ``.jpg`` suffix; keeping the upload's
    own ``.png``/``.webp`` suffix would mislabel the content.

    Args:
        user_id: Name of the per-user directory below ``settings.UPLOAD_DIR``.
        question_id: Base name of the stored file.
        filename: Original upload name. Kept for interface compatibility; it
            no longer influences the stored name.
        content: Raw bytes of the uploaded image.

    Returns:
        The path of the stored file relative to ``settings.UPLOAD_DIR``,
        written with forward slashes.
    """
    # Normalise first so an undecodable upload fails before anything is
    # created on disk.
    normalized = _normalize_image_bytes(content, settings.UPLOAD_MAX_SIDE)
    rel_path = f"{user_id}/{question_id}.jpg"
    full_path = Path(settings.UPLOAD_DIR) / rel_path
    full_path.parent.mkdir(parents=True, exist_ok=True)
    full_path.write_bytes(normalized)
    return rel_path
def annotated_rel_path(original_rel: str) -> str:
    """Derive the relative path of the annotated copy of an image.

    Args:
        original_rel: Relative path of the original image.

    Returns:
        The same directory with ``<stem>_marked.jpg`` as the file name. The
        result always uses forward slashes, matching the relative paths built
        by ``save_upload_file``.
    """
    original = Path(original_rel)
    marked = original.parent / f"{original.stem}_marked.jpg"
    # as_posix() keeps the separator stable across operating systems.
    return marked.as_posix()