import logging
import os
import tempfile
import threading
from io import BytesIO
from pathlib import Path

import httpx
from PIL import Image

from app.core.config import settings

# Headless server: keep OpenCV/Paddle from depending on X11.
os.environ.setdefault("OPENCV_IO_ENABLE_OPENEXR", "0")

logger = logging.getLogger(__name__)

_ocr_engine = None
# Guards lazy construction of ``_ocr_engine``: the warmup thread and a request
# thread can both reach ``get_ocr_engine`` before the engine exists.
_ocr_engine_lock = threading.Lock()
_ocr_warmup_started = False


def get_ocr_engine():
    """Return the shared PaddleOCR engine, building it on first use.

    Construction is serialized with a lock so that the background warmup
    thread and a concurrent caller never build two engines at once.
    ``paddleocr`` is imported lazily so importing this module stays cheap.
    """
    global _ocr_engine
    if _ocr_engine is None:
        with _ocr_engine_lock:
            # Re-check under the lock: another thread may have finished
            # building the engine while this one was waiting.
            if _ocr_engine is None:
                from paddleocr import PaddleOCR

                use_gpu = settings.OCR_USE_GPU
                _ocr_engine = PaddleOCR(
                    use_angle_cls=False,
                    lang="ch",
                    show_log=False,
                    use_gpu=use_gpu,
                    enable_mkldnn=not use_gpu,
                    det_limit_side_len=min(settings.OCR_MAX_SIDE, 1280),
                    rec_batch_num=8,
                )
    return _ocr_engine


def resolve_ocr_service_url(service_url: str | None = None) -> str | None:
    """Return the remote OCR base URL to use, or ``None`` if there is none.

    An explicit ``service_url`` wins over ``settings.OCR_SERVICE_URL``.
    Surrounding whitespace is stripped and an empty result becomes ``None``.
    """
    url = (service_url or settings.OCR_SERVICE_URL or "").strip()
    return url or None


def uses_remote_ocr(service_url: str | None = None) -> bool:
    """Return ``True`` when a remote OCR service URL is configured."""
    return resolve_ocr_service_url(service_url) is not None


def warmup_ocr_engine() -> None:
    """Preload the OCR model in the background.

    Avoids the first real image having to wait for model loading. Does
    nothing if warmup already started, if ``settings.OCR_WARMUP`` is falsy,
    or if a remote OCR service is configured.
    """
    global _ocr_warmup_started
    if _ocr_warmup_started or not settings.OCR_WARMUP or uses_remote_ocr():
        return
    _ocr_warmup_started = True

    def _run() -> None:
        # Best effort: a failed warmup is logged and must never take the
        # process down, hence the broad catch at this thread boundary.
        try:
            engine = get_ocr_engine()
            fd, tmp_name = tempfile.mkstemp(suffix=".jpg")
            os.close(fd)
            tmp_path = Path(tmp_name)
            try:
                blank = Image.new("RGB", (120, 40), color=(255, 255, 255))
                blank.save(tmp_path, format="JPEG")
                engine.ocr(str(tmp_path), cls=False)
                logger.info("OCR engine warmed up")
            finally:
                # Runs even if saving or OCR fails, so the file never leaks.
                tmp_path.unlink(missing_ok=True)
        except Exception as exc:
            logger.warning("OCR warmup failed: %s", exc)

    threading.Thread(target=_run, daemon=True, name="ocr-warmup").start()


def _bbox_from_box(box: list) -> list[float]:
    """Return ``[min_x, min_y, max_x, max_y]`` enclosing the points of ``box``.

    ``box`` is a sequence of points whose first two items are x and y.
    """
    xs = [float(p[0]) for p in box]
    ys = [float(p[1]) for p in box]
    return [min(xs), min(ys), max(xs), max(ys)]


def _scale_bbox(bbox: list[float], scale_x: float, scale_y: float) -> list[float]:
    """Scale an ``[x0, y0, x1, y1]`` box by per-axis factors."""
    return [bbox[0] * scale_x, bbox[1] * scale_y, bbox[2] * scale_x, bbox[3] * scale_y]


def _scale_box(box: list, scale_x: float, scale_y: float) -> list:
    """Scale every ``[x, y]`` point of ``box`` by per-axis factors."""
    return [[float(p[0]) * scale_x, float(p[1]) * scale_y] for p in box]


def _normalize_image_bytes(content: bytes, max_side: int) -> bytes:
    """Re-encode image bytes as an RGB JPEG no larger than ``max_side``.

    The image is downscaled with its aspect ratio preserved when its longest
    side exceeds ``max_side``; smaller images are only re-encoded.
    """
    with Image.open(BytesIO(content)) as src:
        img = src.convert("RGB")
    width, height = img.size
    longest = max(width, height)
    if longest > max_side:
        ratio = max_side / longest
        # Clamp to 1px: an extreme aspect ratio could otherwise truncate the
        # short side to 0, which makes ``resize`` raise.
        new_size = (max(1, int(width * ratio)), max(1, int(height * ratio)))
        img = img.resize(new_size, Image.Resampling.LANCZOS)
    buf = BytesIO()
    img.save(buf, format="JPEG", quality=88, optimize=True)
    return buf.getvalue()


def _prepare_ocr_image(image_path: str) -> tuple[str, float, float, int, int, Path | None]:
    """Produce a downscaled temporary copy for OCR when the image is too large.

    Returns ``(ocr_path, scale_x, scale_y, orig_w, orig_h, tmp_path)``.
    ``scale_x``/``scale_y`` map coordinates in the OCR image back to the
    original. When the longest side is within ``settings.OCR_MAX_SIDE`` the
    original path is returned with scales of 1.0 and ``tmp_path`` is ``None``;
    otherwise ``tmp_path`` is a temporary JPEG the caller must delete.
    """
    with Image.open(image_path) as img:
        orig_w, orig_h = img.size
        max_side = settings.OCR_MAX_SIDE
        longest = max(orig_w, orig_h)
        if longest <= max_side:
            return image_path, 1.0, 1.0, orig_w, orig_h, None

        ratio = max_side / longest
        new_w = max(1, int(orig_w * ratio))
        new_h = max(1, int(orig_h * ratio))
        resized = img.convert("RGB").resize((new_w, new_h), Image.Resampling.LANCZOS)

    # mkstemp yields a unique name, so concurrent calls on images that share
    # a stem cannot overwrite or delete each other's temporary file.
    fd, tmp_name = tempfile.mkstemp(prefix=f"ocr_{Path(image_path).stem}_", suffix=".jpg")
    os.close(fd)
    tmp = Path(tmp_name)
    try:
        resized.save(tmp, format="JPEG", quality=85, optimize=True)
    except Exception:
        # The caller never learns about this file, so remove it here.
        tmp.unlink(missing_ok=True)
        raise

    scale_x = orig_w / new_w
    scale_y = orig_h / new_h
    return str(tmp), scale_x, scale_y, orig_w, orig_h, tmp


def _run_remote_ocr(service_url: str, image_path: str) -> dict:
    """Upload the image to the remote OCR service and return its JSON reply.

    Posts to ``<service_url>/api/ocr/regions`` as a multipart upload, adding
    an ``X-OCR-Key`` header when ``settings.OCR_API_KEY`` is set. A non-2xx
    response raises through ``raise_for_status``.
    """
    url = f"{service_url.rstrip('/')}/api/ocr/regions"
    headers: dict[str, str] = {}
    if settings.OCR_API_KEY:
        headers["X-OCR-Key"] = settings.OCR_API_KEY
    with open(image_path, "rb") as handle:
        files = {"file": (Path(image_path).name, handle, "image/jpeg")}
        with httpx.Client(timeout=settings.OCR_TIMEOUT_SECONDS) as client:
            resp = client.post(url, files=files, headers=headers)
    resp.raise_for_status()
    return resp.json()


def _run_local_ocr(image_path: str) -> dict:
    """Run the local engine on an image and collect line-level results.

    Returns a dict with ``text`` (lines joined by newlines), ``lines`` (one
    dict per recognized line with ``text``, ``confidence``, ``box`` and
    ``bbox``), and the original image ``width`` and ``height``. Boxes are
    expressed in original-image coordinates even when OCR ran on a
    downscaled copy.
    """
    engine = get_ocr_engine()
    ocr_path, scale_x, scale_y, orig_w, orig_h, tmp_path = _prepare_ocr_image(image_path)
    try:
        result = engine.ocr(ocr_path, cls=False)
    finally:
        if tmp_path is not None:
            tmp_path.unlink(missing_ok=True)

    needs_rescale = scale_x != 1.0 or scale_y != 1.0
    lines: list[dict] = []
    # ``result[0]`` holds the detections for the single input image and may
    # be empty or None when nothing was recognized.
    detections = result[0] if result and result[0] else []
    for item in detections:
        if not item or len(item) < 2:
            continue
        box, rec = item[0], item[1]
        text = rec[0] if rec else ""
        conf = float(rec[1]) if rec and len(rec) > 1 else 0.0
        if not text:
            continue
        if needs_rescale:
            box = _scale_box(box, scale_x, scale_y)
        lines.append(
            {
                "text": text,
                "confidence": conf,
                "box": box,
                "bbox": _bbox_from_box(box),
            }
        )

    return {
        "text": "\n".join(line["text"] for line in lines),
        "lines": lines,
        "width": orig_w,
        "height": orig_h,
    }


def run_ocr_with_regions(image_path: str, service_url: str | None = None) -> dict:
    """Return OCR text plus line-level bounding boxes for annotation.

    Uses the remote service when one is resolved from ``service_url`` or the
    settings; otherwise runs the local engine.
    """
    remote = resolve_ocr_service_url(service_url)
    if remote:
        return _run_remote_ocr(remote, image_path)
    return _run_local_ocr(image_path)


def run_ocr(image_path: str) -> str:
    """Return only the recognized text for the image."""
    return run_ocr_with_regions(image_path)["text"]


def save_upload_file(user_id: str, question_id: str, filename: str, content: bytes) -> str:
    """Normalize an uploaded image and store it under the user's directory.

    The file is written to ``settings.UPLOAD_DIR/<user_id>/<question_id><ext>``
    and the path relative to ``settings.UPLOAD_DIR`` is returned. ``ext`` is
    taken from ``filename`` when it is one of the accepted image suffixes,
    otherwise ``.jpg``.
    """
    ext = Path(filename).suffix.lower() or ".jpg"
    if ext not in {".jpg", ".jpeg", ".png", ".webp"}:
        ext = ".jpg"
    user_dir = Path(settings.UPLOAD_DIR) / user_id
    user_dir.mkdir(parents=True, exist_ok=True)
    rel_path = f"{user_id}/{question_id}{ext}"
    full_path = Path(settings.UPLOAD_DIR) / rel_path
    # NOTE(review): the stored bytes are always JPEG-encoded, so a ".png" or
    # ".webp" name will not match the content; confirm whether callers rely
    # on the original extension before changing it.
    normalized = _normalize_image_bytes(content, settings.UPLOAD_MAX_SIDE)
    full_path.write_bytes(normalized)
    return rel_path


def annotated_rel_path(original_rel: str) -> str:
    """Return the sibling path ``<stem>_marked.jpg`` for ``original_rel``."""
    p = Path(original_rel)
    return str(p.parent / f"{p.stem}_marked.jpg")