"""局域网 OCR 服务:在带 NVIDIA 显卡的机器上运行,供成绩档案系统调用。

LAN OCR service: runs on a machine with an NVIDIA GPU and is called by the
grade-archive system.
"""

import asyncio
import hmac
import logging
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from pathlib import Path

from fastapi import FastAPI, File, Header, HTTPException, UploadFile
from PIL import Image, UnidentifiedImageError


# Presumably keeps OpenCV (loaded later via PaddleOCR) from enabling its
# OpenEXR codec; setdefault leaves an explicit operator setting untouched.
os.environ.setdefault("OPENCV_IO_ENABLE_OPENEXR", "0")

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)
logger = logging.getLogger("ocr-worker")

# Environment values accepted as "on" for boolean switches (compared lower-cased).
_TRUTHY = {"1", "true", "yes"}

# Longest image side (pixels) handed to OCR; larger uploads are downscaled first.
OCR_MAX_SIDE = int(os.getenv("OCR_MAX_SIDE", "1280"))
# Shared secret expected in the X-OCR-Key header; an empty value disables the check.
OCR_API_KEY = os.getenv("OCR_API_KEY", "").strip()
# Try GPU inference first when enabled (the default); CPU is the fallback.
OCR_USE_GPU = os.getenv("OCR_USE_GPU", "true").lower() in _TRUTHY

app = FastAPI(title="Grade Archive OCR Worker", version="1.0.0")

# Process-wide PaddleOCR instance, created lazily by get_engine(), and the
# device it was built for: "gpu", "cpu", or "none" while nothing is loaded.
_engine = None
_engine_mode: str = "none"


def _check_key(key: str | None) -> None:
    """Reject the request unless ``key`` matches the configured ``OCR_API_KEY``.

    Authentication is disabled when ``OCR_API_KEY`` is empty. The comparison
    uses ``hmac.compare_digest`` so response timing does not reveal how much of
    a guessed key was correct (a plain ``!=`` can stop at the first mismatch).

    Args:
        key: Value of the ``X-OCR-Key`` request header, or ``None`` if absent.

    Raises:
        HTTPException: 401 when a key is configured and ``key`` is missing or wrong.
    """
    if not OCR_API_KEY:
        return
    # compare_digest only accepts ASCII-only str, so compare UTF-8 bytes instead;
    # "surrogatepass" keeps env-derived strings with lone surrogates encodable.
    # UTF-8 is injective, so byte equality matches the original str equality.
    expected = OCR_API_KEY.encode("utf-8", "surrogatepass")
    supplied = (key or "").encode("utf-8", "surrogatepass")
    if key is None or not hmac.compare_digest(supplied, expected):
        raise HTTPException(status_code=401, detail="Invalid OCR API key")


def _create_engine(use_gpu: bool):
    """Build a new PaddleOCR instance configured for Chinese ("ch") text.

    Args:
        use_gpu: Run inference on the GPU. When False, MKL-DNN is enabled instead
            for CPU inference.

    Angle classification is disabled. The detector side limit is
    ``OCR_MAX_SIDE`` capped at 1280, and recognition runs in batches of 8.

    ``paddleocr`` is imported here rather than at module level, so a missing or
    broken installation raises from this call, where get_engine() catches it.
    """
    from paddleocr import PaddleOCR

    options = dict(
        use_angle_cls=False,
        lang="ch",
        show_log=False,
        use_gpu=use_gpu,
        enable_mkldnn=not use_gpu,
        det_limit_side_len=min(OCR_MAX_SIDE, 1280),
        rec_batch_num=8,
    )
    return PaddleOCR(**options)


def get_engine(force_cpu: bool = False):
    """Return the shared PaddleOCR engine, creating it on first use.

    When ``OCR_USE_GPU`` is set and ``force_cpu`` is False, GPU initialisation
    is attempted first and CPU is tried if it fails. With ``force_cpu=True`` the
    cached engine is ignored and a new CPU engine is built in its place.

    Side effects: sets the module globals ``_engine`` and ``_engine_mode``
    (reset to ``None`` / ``"none"`` after each failed attempt).

    Raises:
        RuntimeError: when every attempted mode fails. The message includes the
            last error plus an operator hint for libGL or CUDA/memory problems.
    """
    global _engine, _engine_mode

    if not force_cpu and _engine is not None:
        return _engine

    try_gpu = OCR_USE_GPU and not force_cpu
    candidates = (True, False) if try_gpu else (False,)
    failure: Exception | None = None

    for use_gpu in candidates:
        try:
            logger.info("Loading PaddleOCR use_gpu=%s", use_gpu)
            _engine = _create_engine(use_gpu)
            _engine_mode = "gpu" if use_gpu else "cpu"
            logger.info("PaddleOCR ready mode=%s", _engine_mode)
            return _engine
        except Exception as exc:
            # Any init failure (import, driver, memory) moves on to the next mode.
            failure = exc
            logger.warning("PaddleOCR init failed use_gpu=%s: %s", use_gpu, exc)
            _engine = None
            _engine_mode = "none"

    # Every mode failed: attach a remediation hint based on the last error text.
    message = str(failure or "")
    if "libGL" in message:
        hint = " 请执行: sudo bash deploy/install-ocr-deps.sh 后重启 OCR"
    elif any(marker in message.lower() for marker in ("cuda", "cudnn", "gpu", "out of memory")):
        hint = " 显存不足或 CUDA 异常,可设置 OCR_USE_GPU=false 用 CPU"
    else:
        hint = ""
    raise RuntimeError(f"PaddleOCR 初始化失败: {failure}{hint}") from failure


def _reset_engine():
    """Drop the cached engine so the next get_engine() call builds a new one."""
    global _engine, _engine_mode
    _engine, _engine_mode = None, "none"


def _bbox_from_box(box: list) -> list[float]:
    """Return the axis-aligned rectangle ``[x_min, y_min, x_max, y_max]`` around ``box``.

    ``box`` is a sequence of points indexed as ``point[0]`` (x) and
    ``point[1]`` (y). Coordinates are coerced to ``float``. An empty ``box``
    raises ``ValueError`` from ``min``.
    """
    x_coords = [float(point[0]) for point in box]
    y_coords = [float(point[1]) for point in box]
    left, right = min(x_coords), max(x_coords)
    top, bottom = min(y_coords), max(y_coords)
    return [left, top, right, bottom]


def _scale_box(box: list, scale_x: float, scale_y: float) -> list:
    """Return a new list of ``[x * scale_x, y * scale_y]`` float pairs for ``box``.

    Only the first two items of each point are used. ``box`` itself is not
    modified.
    """
    scaled = []
    for point in box:
        x, y = float(point[0]), float(point[1])
        scaled.append([x * scale_x, y * scale_y])
    return scaled


def _prepare_image_bytes(content: bytes) -> tuple[bytes, float, float, int, int]:
    """Decode an uploaded image and re-encode it as an RGB JPEG for OCR.

    Images with transparency are flattened onto a white background first. A
    plain ``convert("RGB")`` discards alpha and exposes whatever colour is
    stored under transparent pixels, often black, which can hide dark text.
    Images whose longest side exceeds ``OCR_MAX_SIDE`` are downscaled
    proportionally with Lanczos resampling.

    Args:
        content: Raw bytes of any image format Pillow can open.

    Returns:
        ``(jpeg_bytes, scale_x, scale_y, orig_w, orig_h)``. The scale factors
        map coordinates in the encoded image back to the original image
        (``1.0`` when no resize happened). ``orig_w`` / ``orig_h`` are the
        original pixel dimensions.

    Raises:
        PIL.UnidentifiedImageError: if ``content`` is not a recognised image.
    """
    with Image.open(BytesIO(content)) as src:
        img = _flatten_to_rgb(src)

    orig_w, orig_h = img.size
    longest = max(orig_w, orig_h)
    if longest <= OCR_MAX_SIDE:
        return _encode_jpeg(img, quality=88), 1.0, 1.0, orig_w, orig_h

    ratio = OCR_MAX_SIDE / longest
    new_w = max(1, int(orig_w * ratio))
    new_h = max(1, int(orig_h * ratio))
    resized = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
    # Derive the scale from the integer target size, not from `ratio`, so box
    # coordinates map back exactly despite the truncation in int().
    scale_x = orig_w / new_w
    scale_y = orig_h / new_h
    return _encode_jpeg(resized, quality=85), scale_x, scale_y, orig_w, orig_h


def _flatten_to_rgb(img: Image.Image) -> Image.Image:
    """Return an RGB copy of ``img``, compositing any transparency onto white."""
    has_alpha = img.mode in ("RGBA", "LA") or (
        img.mode in ("P", "L", "RGB") and "transparency" in img.info
    )
    if not has_alpha:
        return img.convert("RGB")
    rgba = img.convert("RGBA")
    background = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
    return Image.alpha_composite(background, rgba).convert("RGB")


def _encode_jpeg(img: Image.Image, quality: int) -> bytes:
    """Encode ``img`` (RGB) as JPEG bytes at the given quality."""
    buf = BytesIO()
    img.save(buf, format="JPEG", quality=quality)
    return buf.getvalue()


def _run_ocr_impl(content: bytes) -> dict:
    """Run OCR once on ``content`` with the current engine (no GPU->CPU retry).

    The image is normalised by ``_prepare_image_bytes``, written to a temporary
    ``.jpg`` file and passed to ``engine.ocr`` by path. The temp file is always
    deleted afterwards, even if writing or OCR fails.

    Returns:
        A dict with these keys:
        - ``text``: recognised lines joined by newlines.
        - ``lines``: one dict per line (``text``, ``confidence``, ``box``
          polygon, ``bbox`` rectangle), in original-image coordinates.
        - ``width`` / ``height``: size of the original image.
        - ``engine_mode``: the device the engine was built for.

    Raises:
        RuntimeError: from get_engine() if no engine can be initialised.
        Any exception from image decoding or from ``engine.ocr``.
    """
    engine = get_engine()
    image_bytes, scale_x, scale_y, orig_w, orig_h = _prepare_image_bytes(content)

    # Create and close the temp file first so its path is known before writing.
    # A failed write (e.g. disk full) is then still cleaned up by the finally.
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
        tmp_path = Path(tmp.name)
    try:
        tmp_path.write_bytes(image_bytes)
        result = engine.ocr(str(tmp_path), cls=False)
    finally:
        tmp_path.unlink(missing_ok=True)

    lines = _parse_ocr_lines(result, scale_x, scale_y)
    return {
        "text": "\n".join(line["text"] for line in lines),
        "lines": lines,
        "width": orig_w,
        "height": orig_h,
        "engine_mode": _engine_mode,
    }


def _parse_ocr_lines(result, scale_x: float, scale_y: float) -> list[dict]:
    """Convert an ``engine.ocr`` result into line dicts in original-image coordinates.

    Reads only ``result[0]``, expecting items shaped ``[box, (text, confidence)]``
    (the layout this indexing relies on; NOTE(review): confirm against the
    installed paddleocr version). Malformed items and empty texts are skipped,
    and a missing confidence becomes 0.0. Boxes are always rebuilt as plain
    ``[float, float]`` lists so the JSON output has the same types whether or
    not the image was resized (multiplying by 1.0 leaves values unchanged).
    """
    lines: list[dict] = []
    if not result or not result[0]:
        return lines
    for item in result[0]:
        if not item or len(item) < 2:
            continue
        raw_box, rec = item[0], item[1]
        text = rec[0] if rec else ""
        conf = float(rec[1]) if rec and len(rec) > 1 else 0.0
        if not text:
            continue
        box = _scale_box(raw_box, scale_x, scale_y)
        lines.append(
            {
                "text": text,
                "confidence": conf,
                "box": box,
                "bbox": _bbox_from_box(box),
            }
        )
    return lines


def run_ocr_on_bytes(content: bytes) -> dict:
    """OCR an encoded image, retrying once on CPU after a GPU runtime failure.

    If the first attempt fails while the GPU engine is active (and GPU use is
    enabled), and the lower-cased error text contains a GPU-related marker
    (cuda, cudnn, gpu, out of memory, resource exhausted, precondition), the
    engine is discarded, a CPU engine is loaded and OCR runs again. Any other
    error is re-raised unchanged. After such a fallback the cached engine is the
    CPU one, so later calls keep using CPU.
    """
    try:
        return _run_ocr_impl(content)
    except Exception as exc:
        message = str(exc).lower()
        gpu_markers = ("cuda", "cudnn", "gpu", "out of memory", "resource exhausted", "precondition")
        looks_like_gpu_error = any(marker in message for marker in gpu_markers)
        if not (OCR_USE_GPU and _engine_mode == "gpu" and looks_like_gpu_error):
            raise
        logger.warning("GPU OCR runtime failed, retry CPU: %s", exc)
        _reset_engine()
        get_engine(force_cpu=True)
        return _run_ocr_impl(content)


@app.get("/health")
def health():
    """Health check reporting whether GPU was requested and which engine is loaded.

    This reads module state only and never loads the engine. ``engine_mode``
    therefore stays ``"none"`` until an OCR request has loaded it.
    """
    return dict(status="ok", gpu_requested=OCR_USE_GPU, engine_mode=_engine_mode)


# One dedicated worker thread for all OCR work. It keeps blocking model loads
# and inference off the asyncio event loop, so /health and other requests stay
# responsive. Every engine call still runs one at a time, always on the same
# thread, because the shared engine and the module globals managed by
# get_engine() are not synchronised.
_ocr_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="ocr-worker")


@app.post("/api/ocr/regions")
async def ocr_regions(
    file: UploadFile = File(...),
    x_ocr_key: str | None = Header(default=None, alias="X-OCR-Key"),
):
    """OCR an uploaded image and return recognised lines with their positions.

    Requires a matching ``X-OCR-Key`` header when ``OCR_API_KEY`` is configured.
    The blocking OCR call runs on the dedicated OCR thread, not the event loop.

    Raises:
        HTTPException: 401 for a bad key; 400 for an empty or undecodable
            image; 500 for any other OCR failure (detail carries the error text).
    """
    _check_key(x_ocr_key)
    content = await file.read()
    if not content:
        raise HTTPException(status_code=400, detail="Empty image")
    loop = asyncio.get_running_loop()
    try:
        return await loop.run_in_executor(_ocr_executor, run_ocr_on_bytes, content)
    except UnidentifiedImageError as exc:
        # Pillow cannot decode the upload: a client error, not a server fault.
        raise HTTPException(status_code=400, detail="Unsupported or corrupt image") from exc
    except Exception as exc:
        logger.exception("OCR failed")
        raise HTTPException(status_code=500, detail=str(exc)) from exc