Files
secondary-school-grade-archive/deploy/ocr-worker/app.py
T

128 lines
3.9 KiB
Python

"""局域网 OCR 服务:RapidOCR(ONNX),不依赖 Paddle,避免 SIGILL/cuDNN 问题。"""
import logging
import os
import tempfile
from io import BytesIO
from pathlib import Path
from fastapi import FastAPI, File, Header, HTTPException, UploadFile
from PIL import Image
# Configure root logging at import time: INFO level, timestamped records.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger("ocr-worker")

# Longest image side (pixels) accepted as-is; larger uploads are downscaled
# before OCR. Read from the OCR_MAX_SIDE environment variable, default 1280.
OCR_MAX_SIDE = int(os.getenv("OCR_MAX_SIDE", "1280"))

# Shared secret compared against the X-OCR-Key request header. An empty value
# (the default) disables the key check entirely.
OCR_API_KEY = os.getenv("OCR_API_KEY", "").strip()

app = FastAPI(title="Grade Archive OCR Worker", version="2.0.0")

# Process-wide OCR engine instance; stays None until get_engine() builds it.
_engine = None
def _check_key(key: str | None) -> None:
    """Reject the request unless it carries the configured API key.

    When ``OCR_API_KEY`` is empty the check is disabled and every request is
    accepted.

    Args:
        key: Value of the ``X-OCR-Key`` header, or ``None`` if it was absent.

    Raises:
        HTTPException: 401 when a key is configured and ``key`` is missing or
            does not match it.
    """
    import hmac  # function-scope import keeps this block self-contained

    if not OCR_API_KEY:
        return

    # Compare in constant time so response timing does not reveal how many
    # leading characters of a guessed key were correct. Bytes are used because
    # compare_digest() raises TypeError on non-ASCII str arguments.
    supplied = (key or "").encode("utf-8")
    expected = OCR_API_KEY.encode("utf-8")
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(status_code=401, detail="Invalid OCR API key")
def get_engine():
    """Return the shared RapidOCR engine, constructing it on the first call.

    The instance is cached in the module-level ``_engine`` variable, so later
    calls return the same object without reloading anything.
    """
    global _engine
    if _engine is not None:
        return _engine

    # Imported here rather than at module top, so the OCR package is only
    # loaded once an engine is actually requested.
    from rapidocr_onnxruntime import RapidOCR

    logger.info("Loading RapidOCR (ONNX CPU)…")
    _engine = RapidOCR()
    logger.info("RapidOCR ready")
    return _engine
def _bbox_from_box(box: list) -> list[float]:
xs = [float(p[0]) for p in box]
ys = [float(p[1]) for p in box]
return [min(xs), min(ys), max(xs), max(ys)]
def _scale_box(box: list, scale_x: float, scale_y: float) -> list:
return [[float(p[0]) * scale_x, float(p[1]) * scale_y] for p in box]
def _prepare_image_bytes(content: bytes) -> tuple[bytes, float, float, int, int]:
    """Decode an uploaded image and re-encode it as an RGB JPEG for OCR.

    Images whose longest side exceeds ``OCR_MAX_SIDE`` are shrunk with LANCZOS
    resampling so that side is about ``OCR_MAX_SIDE`` pixels; smaller images
    keep their dimensions and are only re-encoded.

    Args:
        content: Raw bytes of the uploaded image in any format Pillow can open.

    Returns:
        ``(jpeg_bytes, scale_x, scale_y, orig_w, orig_h)`` where the scale
        factors map coordinates in the returned JPEG back to the original
        image (both are ``1.0`` when no resize happened).
    """
    with Image.open(BytesIO(content)) as source:
        rgb = source.convert("RGB")
        orig_w, orig_h = rgb.size
        longest_side = max(orig_w, orig_h)
        encoded = BytesIO()

        if longest_side > OCR_MAX_SIDE:
            shrink = OCR_MAX_SIDE / longest_side
            # Clamp to 1 px so extreme aspect ratios never yield a zero side.
            new_w = max(1, int(orig_w * shrink))
            new_h = max(1, int(orig_h * shrink))
            downscaled = rgb.resize((new_w, new_h), Image.Resampling.LANCZOS)
            downscaled.save(encoded, format="JPEG", quality=85)
            # Per-axis factors are derived from the actual integer sizes so
            # the truncation above is compensated exactly.
            return encoded.getvalue(), orig_w / new_w, orig_h / new_h, orig_w, orig_h

        rgb.save(encoded, format="JPEG", quality=88)
        return encoded.getvalue(), 1.0, 1.0, orig_w, orig_h
def run_ocr_on_bytes(content: bytes) -> dict:
    """Run OCR on an encoded image and report text lines in original-image coordinates.

    The image is normalised by ``_prepare_image_bytes`` (possibly downscaled),
    written to a temporary ``.jpg`` file, and passed to the engine by path.
    Detected boxes are scaled back to the original image's pixel space.

    Args:
        content: Raw bytes of the uploaded image.

    Returns:
        A dict with keys ``text`` (all lines joined by newlines), ``lines``
        (one dict per line with ``text``, ``confidence``, ``box`` as a list of
        ``[x, y]`` float pairs, and ``bbox`` as ``[x_min, y_min, x_max,
        y_max]``), ``width``/``height`` (original image size) and
        ``engine_mode``.
    """
    engine = get_engine()
    image_bytes, scale_x, scale_y, orig_w, orig_h = _prepare_image_bytes(content)

    tmp_path: Path | None = None
    try:
        # delete=False plus leaving the ``with`` closes (and flushes) the file
        # before the engine opens it by path. Creation and write sit inside
        # the ``try`` so a failed write cannot leave the file behind.
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
            tmp_path = Path(tmp.name)
            tmp.write(image_bytes)
        result, _elapsed = engine(str(tmp_path))
    finally:
        if tmp_path is not None:
            tmp_path.unlink(missing_ok=True)

    lines: list[dict] = []
    if result:
        for item in result:
            # Each item is expected to be (box, text[, confidence]).
            if not item or len(item) < 2:
                continue
            box, text = item[0], item[1]
            conf = float(item[2]) if len(item) > 2 else 0.0
            if not text:
                continue
            # Always normalise: with factors of 1.0 the coordinates are
            # unchanged, but ``box`` is consistently a list of [x, y] float
            # pairs whether or not the image was downscaled.
            box = _scale_box(box, scale_x, scale_y)
            lines.append(
                {
                    "text": str(text),
                    "confidence": conf,
                    "box": box,
                    "bbox": _bbox_from_box(box),
                }
            )

    return {
        "text": "\n".join(line["text"] for line in lines),
        "lines": lines,
        "width": orig_w,
        "height": orig_h,
        "engine_mode": "rapidocr-onnx",
    }
@app.get("/health")
def health():
    """Health probe: return a fixed OK status and the OCR backend name."""
    payload = {"status": "ok", "engine": "rapidocr-onnxruntime"}
    return payload
@app.post("/api/ocr/regions")
async def ocr_regions(
    file: UploadFile = File(...),
    x_ocr_key: str | None = Header(default=None, alias="X-OCR-Key"),
):
    """OCR an uploaded image and return the recognised text lines.

    Args:
        file: Multipart upload containing the image.
        x_ocr_key: Value of the ``X-OCR-Key`` header, checked by ``_check_key``.

    Returns:
        The dict produced by ``run_ocr_on_bytes``.

    Raises:
        HTTPException: 401 for a bad API key, 400 for an empty or undecodable
            upload, 500 for any other failure during OCR.
    """
    # Function-scope import: Pillow is already a dependency of this module.
    from PIL import UnidentifiedImageError

    _check_key(x_ocr_key)
    content = await file.read()
    if not content:
        raise HTTPException(status_code=400, detail="Empty image")
    try:
        return run_ocr_on_bytes(content)
    except UnidentifiedImageError as exc:
        # Bytes that Pillow cannot identify are a client error, not a server
        # fault: answer 400 and skip the stack trace.
        logger.warning("Rejected upload: not a decodable image (%s)", exc)
        raise HTTPException(status_code=400, detail="Invalid or unsupported image") from exc
    except Exception as exc:
        logger.exception("OCR failed")
        raise HTTPException(status_code=500, detail=str(exc)) from exc