上传前人工裁剪错题区域,OCR 原文排除手写作答。

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
dekun
2026-06-28 16:01:46 +08:00
parent 23be608521
commit acfe002fbf
18 changed files with 975 additions and 448 deletions
+2
View File
@@ -28,6 +28,8 @@ class Settings(BaseSettings):
OCR_SERVICE_URL: str = "http://127.0.0.1:23567"
OCR_API_KEY: str = ""
OCR_USE_GPU: bool = False
OCR_FILTER_HANDWRITING: bool = True
OCR_ANSWER_ZONE_RATIO: float = 0.45
class Config:
env_file = ".env"
+1
View File
@@ -133,6 +133,7 @@ class WrongQuestion(Base):
solution_text: Mapped[str | None] = mapped_column(Text, nullable=True)
mark_regions_json: Mapped[str | None] = mapped_column(Text, nullable=True)
annotated_image_path: Mapped[str | None] = mapped_column(String(512), nullable=True)
cropped_image_path: Mapped[str | None] = mapped_column(String(512), nullable=True)
error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
status: Mapped[WrongQuestionStatus] = mapped_column(
Enum(WrongQuestionStatus), default=WrongQuestionStatus.pending
+69 -7
View File
@@ -16,6 +16,7 @@ from app.schemas import WrongQuestionCategoryEnum, WrongQuestionOut, WrongQuesti
from app.services import annotation as annotation_service
from app.services import llm as llm_service
from app.services import ocr as ocr_service
from app.services import ocr_filter as ocr_filter_service
from app.services.student_access import get_student_for_user
router = APIRouter(tags=["wrong_questions"])
@@ -81,13 +82,21 @@ def _wq_to_out(wq: WrongQuestion) -> WrongQuestionOut:
solution_text=wq.solution_text,
mark_regions=_parse_mark_regions(wq.mark_regions_json),
has_annotated_image=bool(wq.annotated_image_path),
has_cropped_image=bool(wq.cropped_image_path),
error_message=wq.error_message,
status=wq.status,
created_at=wq.created_at,
)
async def _run_ai_pipeline(wq: WrongQuestion, db: Session, ocr_lines: list[dict], ocr_text: str):
async def _run_ai_pipeline(
wq: WrongQuestion,
db: Session,
ocr_lines: list[dict],
printed_text: str,
*,
hw_indices: list[int] | None = None,
):
import asyncio
subject_name = wq.subject.name if wq.subject else "综合"
@@ -97,14 +106,22 @@ async def _run_ai_pipeline(wq: WrongQuestion, db: Session, ocr_lines: list[dict]
image_full = str(Path(settings.UPLOAD_DIR) / wq.image_path)
timeout = settings.AI_TIMEOUT_SECONDS
candidate_indices = hw_indices if hw_indices else list(range(len(ocr_lines)))
candidate_lines = [ocr_lines[i] for i in candidate_indices if 0 <= i < len(ocr_lines)]
if not candidate_lines:
candidate_lines = ocr_lines
candidate_indices = list(range(len(ocr_lines)))
try:
detect_resp = await asyncio.wait_for(
llm_service.detect_wrong_line_ids(ai_cfg, subject_name, ocr_lines, school_level),
llm_service.detect_wrong_line_ids(ai_cfg, subject_name, candidate_lines, school_level),
timeout=min(90, timeout),
)
wrong_ids = annotation_service.parse_wrong_line_ids(detect_resp, ocr_lines)
local_wrong = annotation_service.parse_wrong_line_ids(detect_resp, candidate_lines)
wrong_ids = [candidate_indices[i] for i in local_wrong if i < len(candidate_indices)]
except Exception:
wrong_ids = annotation_service.heuristic_wrong_line_ids(ocr_lines)
local_wrong = annotation_service.heuristic_wrong_line_ids(candidate_lines)
wrong_ids = [candidate_indices[i] for i in local_wrong if i < len(candidate_indices)]
regions = annotation_service.regions_from_lines(ocr_lines, wrong_ids)
wq.mark_regions_json = json.dumps(regions, ensure_ascii=False)
@@ -115,7 +132,7 @@ async def _run_ai_pipeline(wq: WrongQuestion, db: Session, ocr_lines: list[dict]
db.commit()
question_text = await asyncio.wait_for(
llm_service.format_question(ai_cfg, subject_name, ocr_text, school_level),
llm_service.format_question(ai_cfg, subject_name, printed_text, school_level),
timeout=timeout,
)
solution_full = await asyncio.wait_for(
@@ -156,7 +173,17 @@ def _process_wrong_question(question_id: uuid.UUID):
ocr_result = future.result(timeout=settings.OCR_TIMEOUT_SECONDS)
ocr_text = ocr_result["text"]
ocr_lines = ocr_result["lines"]
wq.ocr_raw_text = ocr_text or None
img_height = int(ocr_result.get("height") or 0)
printed_ids, hw_ids = ocr_filter_service.split_printed_handwriting(
ocr_lines,
img_height,
answer_zone_ratio=settings.OCR_ANSWER_ZONE_RATIO,
enabled=settings.OCR_FILTER_HANDWRITING,
)
printed_text = ocr_filter_service.text_from_indices(ocr_lines, printed_ids)
if not printed_text:
printed_text = ocr_text
wq.ocr_raw_text = printed_text or None
if not ocr_text:
wq.status = WrongQuestionStatus.failed
wq.error_message = "OCR 未识别到文字,请拍摄更清晰、光线充足的题目照片"
@@ -188,7 +215,15 @@ def _process_wrong_question(question_id: uuid.UUID):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(_run_ai_pipeline(wq, db, ocr_lines, ocr_text))
loop.run_until_complete(
_run_ai_pipeline(
wq,
db,
ocr_lines,
printed_text,
hw_indices=hw_ids,
)
)
db.commit()
except Exception as exc:
wq.status = WrongQuestionStatus.failed
@@ -359,12 +394,15 @@ def delete_wrong_question(
image_full = Path(settings.UPLOAD_DIR) / wq.image_path
ann_full = Path(settings.UPLOAD_DIR) / wq.annotated_image_path if wq.annotated_image_path else None
crop_full = Path(settings.UPLOAD_DIR) / wq.cropped_image_path if wq.cropped_image_path else None
db.delete(wq)
db.commit()
if image_full.exists():
image_full.unlink()
if ann_full and ann_full.exists():
ann_full.unlink()
if crop_full and crop_full.exists():
crop_full.unlink()
@router.post("/wrong-questions/{question_id}/retry-ocr", response_model=WrongQuestionOut)
@@ -385,6 +423,7 @@ def retry_ocr(
wq.status = WrongQuestionStatus.pending
wq.error_message = None
wq.cropped_image_path = None
db.commit()
background_tasks.add_task(_process_wrong_question, wq.id)
return _wq_to_out(wq)
@@ -483,3 +522,26 @@ def get_wrong_question_annotated_image(
if not image_full.exists():
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="标注图不存在")
return FileResponse(image_full)
@router.get("/wrong-questions/{question_id}/cropped-image")
def get_wrong_question_cropped_image(
question_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
wq = (
db.query(WrongQuestion)
.options(joinedload(WrongQuestion.student))
.filter(WrongQuestion.id == question_id)
.first()
)
if wq is None or wq.student.user_id != current_user.id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="错题不存在")
if not wq.cropped_image_path:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="裁剪图尚未生成")
image_full = Path(settings.UPLOAD_DIR) / wq.cropped_image_path
if not image_full.exists():
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="裁剪图不存在")
return FileResponse(image_full)
+1
View File
@@ -239,6 +239,7 @@ class WrongQuestionOut(BaseModel):
solution_text: str | None
mark_regions: list[dict] | None = None
has_annotated_image: bool = False
has_cropped_image: bool = False
error_message: str | None = None
status: WrongQuestionStatusEnum
created_at: datetime
+48
View File
@@ -107,3 +107,51 @@ def split_solution_sections(text: str) -> tuple[str | None, str]:
approach = parts[0].replace("## 解题思路", "").strip()
rest = "## " + parts[1]
return approach or None, rest.strip()
def union_bbox(bboxes: list[list[float]], img_w: int, img_h: int, padding_ratio: float = 0.06) -> list[int]:
if not bboxes:
return [0, 0, img_w, img_h]
x1 = min(b[0] for b in bboxes)
y1 = min(b[1] for b in bboxes)
x2 = max(b[2] for b in bboxes)
y2 = max(b[3] for b in bboxes)
pad_x = max(8, (x2 - x1) * padding_ratio)
pad_y = max(8, (y2 - y1) * padding_ratio)
return [
int(max(0, x1 - pad_x)),
int(max(0, y1 - pad_y)),
int(min(img_w, x2 + pad_x)),
int(min(img_h, y2 + pad_y)),
]
def cropped_rel_path(original_rel: str) -> str:
p = Path(original_rel)
return str(p.parent / f"{p.stem}_crop.jpg")
def crop_wrong_region(
src_path: str,
lines: list[dict],
wrong_ids: list[int],
dest_rel_path: str,
img_width: int,
img_height: int,
) -> str | None:
if not wrong_ids:
return None
bboxes = [lines[i].get("bbox") or [0, 0, 0, 0] for i in wrong_ids if i < len(lines)]
if not bboxes:
return None
box = union_bbox(bboxes, img_width, img_height, padding_ratio=0.12)
x1, y1, x2, y2 = box
if x2 <= x1 or y2 <= y1:
return None
img = Image.open(src_path).convert("RGB")
cropped = img.crop((x1, y1, x2, y2))
full_path = Path(settings.UPLOAD_DIR) / dest_rel_path
full_path.parent.mkdir(parents=True, exist_ok=True)
cropped.save(full_path, format="JPEG", quality=92)
return dest_rel_path
+2
View File
@@ -77,6 +77,8 @@ def run_migrations() -> None:
wq_alters.append("ADD COLUMN mark_regions_json TEXT")
if "annotated_image_path" not in wq_columns:
wq_alters.append("ADD COLUMN annotated_image_path VARCHAR(512)")
if "cropped_image_path" not in wq_columns:
wq_alters.append("ADD COLUMN cropped_image_path VARCHAR(512)")
if "error_message" not in wq_columns:
wq_alters.append("ADD COLUMN error_message TEXT")
if wq_alters:
+103
View File
@@ -0,0 +1,103 @@
"""OCR 行分类:区分印刷题干与手写作答。"""
import re
# 印刷体/题干常见特征
_PRINTED_RE = re.compile(
r"(第\s*[0-9一二三四五六七八九十百]+题|"
r"[(]\s*[0-9一二三四五六七八九十]+\s*[)]|"
r"^\s*[0-9]{1,2}\s*[\..、\)]|"
r"^[A-Da-d]\s*[\..、]|"
r"选择题|填空题|解答题|证明题|计算题|应用题|"
r"下列|以下|正确|错误|不正确|单选|多选|"
r"已知|求证|设|若|求|如图|如图所示)",
re.MULTILINE,
)
# 手写作答常见特征(算式、短碎片)
_HANDWRITE_RE = re.compile(
r"^[0-9\s+\-×÷*/=≈<>()\[\].,,、%°]+$|"
r"^[xXyYzZ]\s*[=]|"
r"^\s*\d+\s*[\.]\s*\d*\s*$"
)
def _line_center_y(line: dict) -> float:
bbox = line.get("bbox") or [0, 0, 0, 0]
return (float(bbox[1]) + float(bbox[3])) / 2.0
def _looks_printed(text: str) -> bool:
t = text.strip()
if len(t) >= 12 and _PRINTED_RE.search(t):
return True
if _PRINTED_RE.match(t):
return True
return False
def _looks_handwritten(text: str, confidence: float) -> bool:
t = text.strip()
if not t:
return True
if _looks_printed(t):
return False
if _HANDWRITE_RE.match(t):
return True
if len(t) <= 6 and confidence < 0.92:
return True
digit_ratio = sum(c.isdigit() or c in "+-×÷*/=≈.%" for c in t) / max(len(t), 1)
if digit_ratio > 0.55 and len(t) < 20:
return True
return False
def split_printed_handwriting(
lines: list[dict],
img_height: int,
*,
answer_zone_ratio: float = 0.45,
enabled: bool = True,
) -> tuple[list[int], list[int]]:
"""
返回 (印刷题干行编号, 手写作答行编号),编号为 lines 列表下标。
answer_zone_ratio: 图片高度比例,低于此 y 中心视为题干区,高于视为作答区。
"""
if not lines or not enabled or img_height <= 0:
return list(range(len(lines))), []
split_y = img_height * answer_zone_ratio
printed_ids: list[int] = []
handwriting_ids: list[int] = []
for i, line in enumerate(lines):
text = line.get("text", "")
conf = float(line.get("confidence") or 0.0)
cy = _line_center_y(line)
if _looks_printed(text):
printed_ids.append(i)
continue
in_answer_zone = cy >= split_y
if in_answer_zone and _looks_handwritten(text, conf):
handwriting_ids.append(i)
elif not in_answer_zone:
printed_ids.append(i)
elif in_answer_zone:
handwriting_ids.append(i)
if not printed_ids and lines:
printed_ids = list(range(min(3, len(lines))))
if not handwriting_ids and len(lines) >= 2:
handwriting_ids = list(range(max(0, len(lines) - 3), len(lines)))
return printed_ids, handwriting_ids
def lines_by_indices(lines: list[dict], indices: list[int]) -> list[dict]:
return [lines[i] for i in indices if 0 <= i < len(lines)]
def text_from_indices(lines: list[dict], indices: list[int]) -> str:
return "\n".join(lines[i].get("text", "") for i in indices if 0 <= i < len(lines)).strip()