204 lines
6.9 KiB
Python
204 lines
6.9 KiB
Python
import logging
|
|
import os
|
|
import tempfile
|
|
import threading
|
|
from io import BytesIO
|
|
from pathlib import Path
|
|
|
|
import httpx
|
|
from PIL import Image
|
|
|
|
from app.core.config import settings
|
|
|
|
# Headless server: avoid OpenCV/Paddle depending on X11.
# NOTE(review): judging by its name, this variable toggles OpenCV's OpenEXR
# image codec rather than anything X11-related -- confirm it is the intended
# workaround. ``setdefault`` leaves an explicitly exported value untouched.
os.environ.setdefault("OPENCV_IO_ENABLE_OPENEXR", "0")

logger = logging.getLogger(__name__)

# Process-wide OCR engine; stays None until get_ocr_engine() builds it.
_ocr_engine = None
# Flipped to True by warmup_ocr_engine() so the warmup thread starts only once.
_ocr_warmup_started = False
|
|
|
|
|
|
# Serializes the first construction of the engine: the warmup thread and
# request handlers may call get_ocr_engine() concurrently.
_ocr_engine_lock = threading.Lock()


def get_ocr_engine():
    """Return the process-wide PaddleOCR engine, creating it on first use.

    ``paddleocr`` is imported lazily so that importing this module stays
    cheap. Construction is guarded by a lock and re-checked inside it, so
    concurrent first callers wait for a single engine instead of each
    building their own.

    The engine is configured for Chinese (``lang="ch"``) without angle
    classification. MKL-DNN is enabled only when the GPU is not used, and the
    detector's side limit is ``settings.OCR_MAX_SIDE`` capped at 1280.

    Returns:
        The shared ``PaddleOCR`` instance.
    """
    global _ocr_engine
    # Fast path: once built, no locking is needed to read the reference.
    if _ocr_engine is not None:
        return _ocr_engine

    with _ocr_engine_lock:
        # Another thread may have finished construction while we waited.
        if _ocr_engine is None:
            from paddleocr import PaddleOCR

            use_gpu = settings.OCR_USE_GPU
            _ocr_engine = PaddleOCR(
                use_angle_cls=False,
                lang="ch",
                show_log=False,
                use_gpu=use_gpu,
                enable_mkldnn=not use_gpu,
                det_limit_side_len=min(settings.OCR_MAX_SIDE, 1280),
                rec_batch_num=8,
            )
    return _ocr_engine
|
|
|
|
|
|
def resolve_ocr_service_url(service_url: str | None = None) -> str | None:
|
|
url = (service_url or settings.OCR_SERVICE_URL or "").strip()
|
|
return url or None
|
|
|
|
|
|
def uses_remote_ocr(service_url: str | None = None) -> bool:
    """Tell whether OCR requests would go to a remote service.

    True exactly when ``resolve_ocr_service_url(service_url)`` yields a URL.
    """
    resolved = resolve_ocr_service_url(service_url)
    return resolved is not None
|
|
|
|
|
|
def warmup_ocr_engine() -> None:
    """Preload the OCR model in the background so the first image does not wait minutes.

    Does nothing when a warmup was already started, when
    ``settings.OCR_WARMUP`` is falsy, or when a remote OCR service is
    configured. Otherwise a daemon thread named ``ocr-warmup`` builds the
    engine and runs it once on a small blank JPEG. Warmup is best-effort:
    any failure is logged as a warning and never propagated.
    """
    global _ocr_warmup_started
    if _ocr_warmup_started or not settings.OCR_WARMUP or uses_remote_ocr():
        return
    _ocr_warmup_started = True

    def _run() -> None:
        try:
            engine = get_ocr_engine()
            # Only reserve a unique name here; the handle is closed before the
            # image is written, because reopening a still-open named temporary
            # file by name is not portable.
            with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
                tmp_path = Path(tmp.name)
            try:
                Image.new("RGB", (120, 40), color=(255, 255, 255)).save(tmp_path, format="JPEG")
                engine.ocr(str(tmp_path), cls=False)
                logger.info("OCR engine warmed up")
            finally:
                # Covers both a failed save and a failed OCR run.
                tmp_path.unlink(missing_ok=True)
        except Exception as exc:
            logger.warning("OCR warmup failed: %s", exc)

    threading.Thread(target=_run, daemon=True, name="ocr-warmup").start()
|
|
|
|
|
|
def _bbox_from_box(box: list) -> list[float]:
|
|
xs = [float(p[0]) for p in box]
|
|
ys = [float(p[1]) for p in box]
|
|
return [min(xs), min(ys), max(xs), max(ys)]
|
|
|
|
|
|
def _scale_bbox(bbox: list[float], scale_x: float, scale_y: float) -> list[float]:
|
|
return [bbox[0] * scale_x, bbox[1] * scale_y, bbox[2] * scale_x, bbox[3] * scale_y]
|
|
|
|
|
|
def _scale_box(box: list, scale_x: float, scale_y: float) -> list:
|
|
return [[float(p[0]) * scale_x, float(p[1]) * scale_y] for p in box]
|
|
|
|
|
|
def _normalize_image_bytes(content: bytes, max_side: int) -> bytes:
    """Re-encode image bytes as an RGB JPEG no larger than ``max_side`` on its longest side.

    Args:
        content: Raw bytes of an image Pillow can decode.
        max_side: Maximum allowed length, in pixels, of the longer edge.
            Smaller images are re-encoded without resizing.

    Returns:
        JPEG-encoded bytes (quality 88, optimized).

    Errors raised by Pillow while decoding or encoding propagate to the
    caller.
    """
    with Image.open(BytesIO(content)) as source:
        rgb = source.convert("RGB")
        width, height = rgb.size
        longest = max(width, height)
        if longest > max_side:
            ratio = max_side / longest
            # Clamp each side to at least 1 px: for extreme aspect ratios the
            # short side would otherwise truncate to 0 and make resize fail.
            target = (max(1, int(width * ratio)), max(1, int(height * ratio)))
            rgb = rgb.resize(target, Image.Resampling.LANCZOS)
        buf = BytesIO()
        rgb.save(buf, format="JPEG", quality=88, optimize=True)
        return buf.getvalue()
|
|
|
|
|
|
def _prepare_ocr_image(image_path: str) -> tuple[str, float, float, int, int, Path | None]:
    """Produce the image the OCR engine should read, downscaling oversized inputs.

    When the longest side exceeds ``settings.OCR_MAX_SIDE`` a shrunken RGB
    JPEG copy is written to a uniquely named temporary file; otherwise the
    original path is used unchanged.

    Returns:
        ``(ocr_path, scale_x, scale_y, orig_w, orig_h, tmp_path)`` where
        ``scale_x`` / ``scale_y`` map coordinates found on the OCR image back
        to the original image, and ``tmp_path`` is the temporary file the
        caller must delete (``None`` when no copy was made).
    """
    max_side = settings.OCR_MAX_SIDE
    # Open the image once: read its size and, if needed, resize it in one go.
    with Image.open(image_path) as img:
        orig_w, orig_h = img.size
        longest = max(orig_w, orig_h)
        if longest <= max_side:
            return image_path, 1.0, 1.0, orig_w, orig_h, None

        ratio = max_side / longest
        new_w = max(1, int(orig_w * ratio))
        new_h = max(1, int(orig_h * ratio))
        resized = img.convert("RGB").resize((new_w, new_h), Image.Resampling.LANCZOS)

    # A unique name per call: a name derived only from the stem and the PID
    # would let concurrent calls on same-named images overwrite or delete
    # each other's file.
    handle = tempfile.NamedTemporaryFile(
        prefix=f"ocr_{Path(image_path).stem}_",
        suffix=".jpg",
        delete=False,
    )
    tmp = Path(handle.name)
    try:
        with handle:
            resized.save(handle, format="JPEG", quality=85, optimize=True)
    except Exception:
        # Do not leave a partial file behind when encoding fails.
        tmp.unlink(missing_ok=True)
        raise

    scale_x = orig_w / new_w
    scale_y = orig_h / new_h
    return str(tmp), scale_x, scale_y, orig_w, orig_h, tmp
|
|
|
|
|
|
def _run_remote_ocr(service_url: str, image_path: str) -> dict:
    """Upload the image to the remote OCR service and return its JSON reply.

    The file is POSTed as multipart field ``file`` to
    ``<service_url>/api/ocr/regions``. An ``X-OCR-Key`` header is sent when
    ``settings.OCR_API_KEY`` is set. A non-success HTTP status raises via
    ``raise_for_status``.
    """
    endpoint = f"{service_url.rstrip('/')}/api/ocr/regions"

    request_headers: dict[str, str] = {}
    api_key = settings.OCR_API_KEY
    if api_key:
        request_headers["X-OCR-Key"] = api_key

    upload_name = Path(image_path).name
    with open(image_path, "rb") as image_file:
        payload = {"file": (upload_name, image_file, "image/jpeg")}
        with httpx.Client(timeout=settings.OCR_TIMEOUT_SECONDS) as http:
            response = http.post(endpoint, files=payload, headers=request_headers)
            response.raise_for_status()
            return response.json()
|
|
|
|
|
|
def _run_local_ocr(image_path: str) -> dict:
    """Run the in-process OCR engine and return text plus per-line geometry.

    The image is first passed through ``_prepare_ocr_image``. Any temporary
    downscaled copy is removed once the engine has run, even on failure.
    Detected polygons are scaled back to original-image coordinates.

    Returns:
        A dict with ``text`` (recognized lines joined by newlines), ``lines``
        (dicts with ``text``, ``confidence``, ``box`` and ``bbox``) and the
        original ``width`` / ``height``.
    """
    ocr_engine = get_ocr_engine()
    prepared = _prepare_ocr_image(image_path)
    source_path, factor_x, factor_y, full_width, full_height, scratch_file = prepared
    try:
        raw = ocr_engine.ocr(source_path, cls=False)
    finally:
        if scratch_file is not None:
            scratch_file.unlink(missing_ok=True)

    # Only the first page of the result is used; it may be missing or empty.
    detections = raw[0] if raw and raw[0] else []

    entries: list[dict] = []
    for detection in detections:
        if not detection or len(detection) < 2:
            continue
        polygon, recognition = detection[0], detection[1]
        if recognition:
            content = recognition[0]
            score = float(recognition[1]) if len(recognition) > 1 else 0.0
        else:
            content, score = "", 0.0
        if not content:
            continue
        # Map coordinates back to the original image when OCR ran on a copy.
        if factor_x != 1.0 or factor_y != 1.0:
            polygon = _scale_box(polygon, factor_x, factor_y)
        entries.append(
            {
                "text": content,
                "confidence": score,
                "box": polygon,
                "bbox": _bbox_from_box(polygon),
            }
        )

    joined_text = "\n".join(entry["text"] for entry in entries)
    return {
        "text": joined_text,
        "lines": entries,
        "width": full_width,
        "height": full_height,
    }
|
|
|
|
|
|
def run_ocr_with_regions(image_path: str, service_url: str | None = None) -> dict:
    """Run OCR on an image and return its text with line-level bounding boxes.

    Uses the remote service when ``resolve_ocr_service_url`` yields a URL,
    and the local engine otherwise.
    """
    remote_url = resolve_ocr_service_url(service_url)
    if not remote_url:
        return _run_local_ocr(image_path)
    return _run_remote_ocr(remote_url, image_path)
|
|
|
|
|
|
def run_ocr(image_path: str) -> str:
    """Return only the recognized text of ``run_ocr_with_regions`` for the image."""
    ocr_result = run_ocr_with_regions(image_path)
    return ocr_result["text"]
|
|
|
|
|
|
def save_upload_file(user_id: str, question_id: str, filename: str, content: bytes) -> str:
    """Normalize an uploaded image and store it under the user's upload directory.

    The file is written to ``<UPLOAD_DIR>/<user_id>/<question_id><ext>``.
    ``ext`` is the lower-cased suffix of ``filename`` when it is one of
    ``.jpg``, ``.jpeg``, ``.png`` or ``.webp``, and ``.jpg`` otherwise. The
    bytes are passed through ``_normalize_image_bytes`` with
    ``settings.UPLOAD_MAX_SIDE`` before being written.

    Returns:
        The path relative to ``settings.UPLOAD_DIR``, as
        ``"<user_id>/<question_id><ext>"``.
    """
    allowed_suffixes = {".jpg", ".jpeg", ".png", ".webp"}
    suffix = Path(filename).suffix.lower()
    if suffix not in allowed_suffixes:
        suffix = ".jpg"

    upload_root = Path(settings.UPLOAD_DIR)
    (upload_root / user_id).mkdir(parents=True, exist_ok=True)

    relative = f"{user_id}/{question_id}{suffix}"
    # NOTE(review): _normalize_image_bytes always encodes JPEG, so a kept
    # ".png" / ".webp" suffix does not match the stored bytes -- confirm that
    # this is intended.
    payload = _normalize_image_bytes(content, settings.UPLOAD_MAX_SIDE)
    (upload_root / relative).write_bytes(payload)
    return relative
|
|
|
|
|
|
def annotated_rel_path(original_rel: str) -> str:
    """Derive the path of the annotated copy: same directory, ``<stem>_marked.jpg``."""
    source = Path(original_rel)
    marked_name = f"{source.stem}_marked.jpg"
    return str(source.parent.joinpath(marked_name))
|