支持局域网 GPU OCR 服务,配置方式类似 Ollama。
This commit is contained in:
@@ -25,6 +25,9 @@ class Settings(BaseSettings):
|
||||
OCR_MAX_SIDE: int = 1280
|
||||
UPLOAD_MAX_SIDE: int = 2048
|
||||
OCR_WARMUP: bool = True
|
||||
OCR_SERVICE_URL: str = ""
|
||||
OCR_API_KEY: str = ""
|
||||
OCR_USE_GPU: bool = False
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
|
||||
@@ -159,6 +159,7 @@ class SystemSettings(Base):
|
||||
openai_base_url: Mapped[str | None] = mapped_column(String(256), nullable=True)
|
||||
openai_model: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
openai_api_key: Mapped[str | None] = mapped_column(String(512), nullable=True)
|
||||
ocr_service_url: Mapped[str | None] = mapped_column(String(256), nullable=True)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
@@ -31,6 +31,7 @@ def settings_to_out(row: SystemSettings) -> SystemSettingsOut:
|
||||
openai_base_url=row.openai_base_url,
|
||||
openai_model=row.openai_model,
|
||||
openai_api_key_set=bool(row.openai_api_key),
|
||||
ocr_service_url=row.ocr_service_url,
|
||||
updated_at=row.updated_at,
|
||||
)
|
||||
|
||||
@@ -74,6 +75,8 @@ def update_settings(
|
||||
row.openai_model = data.openai_model or None
|
||||
if data.openai_api_key is not None and data.openai_api_key.strip():
|
||||
row.openai_api_key = data.openai_api_key.strip()
|
||||
if data.ocr_service_url is not None:
|
||||
row.ocr_service_url = data.ocr_service_url.strip() or None
|
||||
row.updated_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
db.refresh(row)
|
||||
|
||||
@@ -11,7 +11,7 @@ from sqlalchemy.orm import Session, joinedload
|
||||
from app.core.config import settings
|
||||
from app.core.database import SessionLocal, get_db
|
||||
from app.core.deps import get_current_user
|
||||
from app.models.user import Subject, User, WrongQuestion, WrongQuestionCategory, WrongQuestionStatus
|
||||
from app.models.user import Subject, SystemSettings, User, WrongQuestion, WrongQuestionCategory, WrongQuestionStatus
|
||||
from app.schemas import WrongQuestionCategoryEnum, WrongQuestionOut, WrongQuestionUpdate
|
||||
from app.services import annotation as annotation_service
|
||||
from app.services import llm as llm_service
|
||||
@@ -50,6 +50,13 @@ def _expire_stale_processing(wq: WrongQuestion, db: Session) -> None:
|
||||
db.commit()
|
||||
|
||||
|
||||
def _ocr_service_url(db: Session) -> str | None:
    """Return the OCR service URL that should be used for this request.

    The URL stored on the ``SystemSettings`` row with primary key 1 takes
    precedence. When that row is missing, or its URL is empty or
    whitespace-only, the result of ``ocr_service.resolve_ocr_service_url()``
    is returned instead.

    Args:
        db: Open database session used to load the settings row.

    Returns:
        The stripped URL, or ``None`` when no URL is configured anywhere.
    """
    row = db.get(SystemSettings, 1)
    stored = (row.ocr_service_url or "").strip() if row else ""
    if stored:
        return stored
    # A blank stored value must not mask the configured default, so fall
    # through to the shared resolver instead of returning None directly.
    return ocr_service.resolve_ocr_service_url()
|
||||
|
||||
|
||||
def _parse_mark_regions(raw: str | None) -> list[dict] | None:
|
||||
if not raw:
|
||||
return None
|
||||
@@ -140,9 +147,12 @@ def _process_wrong_question(question_id: uuid.UUID):
|
||||
|
||||
wq.error_message = None
|
||||
image_full = Path(settings.UPLOAD_DIR) / wq.image_path
|
||||
ocr_url = _ocr_service_url(db)
|
||||
try:
|
||||
with ThreadPoolExecutor(max_workers=1) as pool:
|
||||
future = pool.submit(ocr_service.run_ocr_with_regions, str(image_full))
|
||||
future = pool.submit(
|
||||
ocr_service.run_ocr_with_regions, str(image_full), ocr_url
|
||||
)
|
||||
ocr_result = future.result(timeout=settings.OCR_TIMEOUT_SECONDS)
|
||||
ocr_text = ocr_result["text"]
|
||||
ocr_lines = ocr_result["lines"]
|
||||
@@ -166,6 +176,8 @@ def _process_wrong_question(question_id: uuid.UUID):
|
||||
msg = _short_error(exc, "OCR 识别失败:")
|
||||
if "libGL" in str(exc):
|
||||
msg += " 请在服务器执行: sudo bash deploy/install-ocr-deps.sh && systemctl restart grade-archive"
|
||||
elif ocr_url:
|
||||
msg += f" 请检查 OCR 服务是否可达: {ocr_url} (可浏览器访问 {ocr_url.rstrip('/')}/health)"
|
||||
wq.error_message = msg
|
||||
db.commit()
|
||||
return
|
||||
|
||||
@@ -74,6 +74,7 @@ class SystemSettingsOut(BaseModel):
|
||||
openai_base_url: str | None = None
|
||||
openai_model: str | None = None
|
||||
openai_api_key_set: bool = False
|
||||
ocr_service_url: str | None = None
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
@@ -87,6 +88,7 @@ class SystemSettingsUpdate(BaseModel):
|
||||
openai_base_url: str | None = None
|
||||
openai_model: str | None = None
|
||||
openai_api_key: str | None = None
|
||||
ocr_service_url: str | None = None
|
||||
|
||||
|
||||
class AdminProfileUpdate(BaseModel):
|
||||
|
||||
@@ -61,6 +61,8 @@ def run_migrations() -> None:
|
||||
alters.append("ADD COLUMN openai_model VARCHAR(128)")
|
||||
if "openai_api_key" not in ss_columns:
|
||||
alters.append("ADD COLUMN openai_api_key VARCHAR(512)")
|
||||
if "ocr_service_url" not in ss_columns:
|
||||
alters.append("ADD COLUMN ocr_service_url VARCHAR(256)")
|
||||
if alters:
|
||||
with engine.begin() as conn:
|
||||
for clause in alters:
|
||||
|
||||
@@ -5,6 +5,7 @@ import threading
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
from PIL import Image
|
||||
|
||||
from app.core.config import settings
|
||||
@@ -23,22 +24,32 @@ def get_ocr_engine():
|
||||
if _ocr_engine is None:
|
||||
from paddleocr import PaddleOCR
|
||||
|
||||
use_gpu = settings.OCR_USE_GPU
|
||||
_ocr_engine = PaddleOCR(
|
||||
use_angle_cls=False,
|
||||
lang="ch",
|
||||
show_log=False,
|
||||
use_gpu=False,
|
||||
enable_mkldnn=True,
|
||||
use_gpu=use_gpu,
|
||||
enable_mkldnn=not use_gpu,
|
||||
det_limit_side_len=min(settings.OCR_MAX_SIDE, 1280),
|
||||
rec_batch_num=8,
|
||||
)
|
||||
return _ocr_engine
|
||||
|
||||
|
||||
def resolve_ocr_service_url(service_url: str | None = None) -> str | None:
|
||||
url = (service_url or settings.OCR_SERVICE_URL or "").strip()
|
||||
return url or None
|
||||
|
||||
|
||||
def uses_remote_ocr(service_url: str | None = None) -> bool:
    """Tell whether a remote OCR service URL resolves for ``service_url``."""
    resolved = resolve_ocr_service_url(service_url)
    return resolved is not None
|
||||
|
||||
|
||||
def warmup_ocr_engine() -> None:
|
||||
"""后台预加载 OCR 模型,避免首张图片等待数分钟。"""
|
||||
global _ocr_warmup_started
|
||||
if _ocr_warmup_started or not settings.OCR_WARMUP:
|
||||
if _ocr_warmup_started or not settings.OCR_WARMUP or uses_remote_ocr():
|
||||
return
|
||||
_ocr_warmup_started = True
|
||||
|
||||
@@ -110,8 +121,20 @@ def _prepare_ocr_image(image_path: str) -> tuple[str, float, float, int, int, Pa
|
||||
return str(tmp), scale_x, scale_y, orig_w, orig_h, tmp
|
||||
|
||||
|
||||
def run_ocr_with_regions(image_path: str) -> dict:
|
||||
"""Return OCR text plus line-level bounding boxes for annotation."""
|
||||
def _run_remote_ocr(service_url: str, image_path: str) -> dict:
    """Upload an image to the remote OCR service and return its JSON result.

    The file is posted as multipart field ``file`` to
    ``<service_url>/api/ocr/regions``. When ``settings.OCR_API_KEY`` is set
    it is sent in the ``X-OCR-Key`` header.

    Args:
        service_url: Base URL of the OCR service; a trailing slash is ignored.
        image_path: Path of the image file to upload.

    Returns:
        The decoded JSON object, which contains at least ``text`` and
        ``lines``.

    Raises:
        httpx.HTTPStatusError: The service answered with an error status.
        ValueError: The response is not a JSON object with ``text`` and
            ``lines`` keys.
    """
    import mimetypes

    url = f"{service_url.rstrip('/')}/api/ocr/regions"
    headers: dict[str, str] = {}
    if settings.OCR_API_KEY:
        headers["X-OCR-Key"] = settings.OCR_API_KEY

    # Advertise the real image type (PNG, WebP, ...) instead of always
    # claiming JPEG; keep JPEG as the fallback for unknown extensions.
    content_type = mimetypes.guess_type(image_path)[0] or "image/jpeg"

    with open(image_path, "rb") as handle:
        files = {"file": (Path(image_path).name, handle, content_type)}
        with httpx.Client(timeout=settings.OCR_TIMEOUT_SECONDS) as client:
            resp = client.post(url, files=files, headers=headers)
            resp.raise_for_status()
            payload = resp.json()

    # Callers index the result by "text" and "lines"; fail here with a clear
    # message rather than with a bare KeyError further up the stack.
    if not isinstance(payload, dict) or "text" not in payload or "lines" not in payload:
        raise ValueError(f"OCR 服务返回格式无效: {url}")
    return payload
|
||||
|
||||
|
||||
def _run_local_ocr(image_path: str) -> dict:
|
||||
engine = get_ocr_engine()
|
||||
ocr_path, scale_x, scale_y, orig_w, orig_h, tmp_path = _prepare_ocr_image(image_path)
|
||||
try:
|
||||
@@ -150,6 +173,14 @@ def run_ocr_with_regions(image_path: str) -> dict:
|
||||
}
|
||||
|
||||
|
||||
def run_ocr_with_regions(image_path: str, service_url: str | None = None) -> dict:
    """Run OCR on an image and return its text plus line-level boxes.

    The work is delegated to the remote service when a URL resolves from
    ``service_url`` (via ``resolve_ocr_service_url``); otherwise the local
    engine handles the image.
    """
    endpoint = resolve_ocr_service_url(service_url)
    if not endpoint:
        return _run_local_ocr(image_path)
    return _run_remote_ocr(endpoint, image_path)
|
||||
|
||||
|
||||
def run_ocr(image_path: str, service_url: str | None = None) -> str:
    """Return only the recognized text for an image.

    Args:
        image_path: Path of the image to recognize.
        service_url: Optional remote OCR service URL, forwarded unchanged to
            ``run_ocr_with_regions`` so text-only callers can target the same
            service as region-aware callers. Defaults to ``None``, which
            keeps the previous behavior.

    Returns:
        The ``text`` entry of the OCR result.
    """
    return run_ocr_with_regions(image_path, service_url)["text"]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user