上传前人工裁剪错题区域,OCR 原文排除手写作答。
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -16,6 +16,7 @@ from app.schemas import WrongQuestionCategoryEnum, WrongQuestionOut, WrongQuesti
|
||||
from app.services import annotation as annotation_service
|
||||
from app.services import llm as llm_service
|
||||
from app.services import ocr as ocr_service
|
||||
from app.services import ocr_filter as ocr_filter_service
|
||||
from app.services.student_access import get_student_for_user
|
||||
|
||||
router = APIRouter(tags=["wrong_questions"])
|
||||
@@ -81,13 +82,21 @@ def _wq_to_out(wq: WrongQuestion) -> WrongQuestionOut:
|
||||
solution_text=wq.solution_text,
|
||||
mark_regions=_parse_mark_regions(wq.mark_regions_json),
|
||||
has_annotated_image=bool(wq.annotated_image_path),
|
||||
has_cropped_image=bool(wq.cropped_image_path),
|
||||
error_message=wq.error_message,
|
||||
status=wq.status,
|
||||
created_at=wq.created_at,
|
||||
)
|
||||
|
||||
|
||||
async def _run_ai_pipeline(wq: WrongQuestion, db: Session, ocr_lines: list[dict], ocr_text: str):
|
||||
async def _run_ai_pipeline(
|
||||
wq: WrongQuestion,
|
||||
db: Session,
|
||||
ocr_lines: list[dict],
|
||||
printed_text: str,
|
||||
*,
|
||||
hw_indices: list[int] | None = None,
|
||||
):
|
||||
import asyncio
|
||||
|
||||
subject_name = wq.subject.name if wq.subject else "综合"
|
||||
@@ -97,14 +106,22 @@ async def _run_ai_pipeline(wq: WrongQuestion, db: Session, ocr_lines: list[dict]
|
||||
image_full = str(Path(settings.UPLOAD_DIR) / wq.image_path)
|
||||
timeout = settings.AI_TIMEOUT_SECONDS
|
||||
|
||||
candidate_indices = hw_indices if hw_indices else list(range(len(ocr_lines)))
|
||||
candidate_lines = [ocr_lines[i] for i in candidate_indices if 0 <= i < len(ocr_lines)]
|
||||
if not candidate_lines:
|
||||
candidate_lines = ocr_lines
|
||||
candidate_indices = list(range(len(ocr_lines)))
|
||||
|
||||
try:
|
||||
detect_resp = await asyncio.wait_for(
|
||||
llm_service.detect_wrong_line_ids(ai_cfg, subject_name, ocr_lines, school_level),
|
||||
llm_service.detect_wrong_line_ids(ai_cfg, subject_name, candidate_lines, school_level),
|
||||
timeout=min(90, timeout),
|
||||
)
|
||||
wrong_ids = annotation_service.parse_wrong_line_ids(detect_resp, ocr_lines)
|
||||
local_wrong = annotation_service.parse_wrong_line_ids(detect_resp, candidate_lines)
|
||||
wrong_ids = [candidate_indices[i] for i in local_wrong if i < len(candidate_indices)]
|
||||
except Exception:
|
||||
wrong_ids = annotation_service.heuristic_wrong_line_ids(ocr_lines)
|
||||
local_wrong = annotation_service.heuristic_wrong_line_ids(candidate_lines)
|
||||
wrong_ids = [candidate_indices[i] for i in local_wrong if i < len(candidate_indices)]
|
||||
|
||||
regions = annotation_service.regions_from_lines(ocr_lines, wrong_ids)
|
||||
wq.mark_regions_json = json.dumps(regions, ensure_ascii=False)
|
||||
@@ -115,7 +132,7 @@ async def _run_ai_pipeline(wq: WrongQuestion, db: Session, ocr_lines: list[dict]
|
||||
db.commit()
|
||||
|
||||
question_text = await asyncio.wait_for(
|
||||
llm_service.format_question(ai_cfg, subject_name, ocr_text, school_level),
|
||||
llm_service.format_question(ai_cfg, subject_name, printed_text, school_level),
|
||||
timeout=timeout,
|
||||
)
|
||||
solution_full = await asyncio.wait_for(
|
||||
@@ -156,7 +173,17 @@ def _process_wrong_question(question_id: uuid.UUID):
|
||||
ocr_result = future.result(timeout=settings.OCR_TIMEOUT_SECONDS)
|
||||
ocr_text = ocr_result["text"]
|
||||
ocr_lines = ocr_result["lines"]
|
||||
wq.ocr_raw_text = ocr_text or None
|
||||
img_height = int(ocr_result.get("height") or 0)
|
||||
printed_ids, hw_ids = ocr_filter_service.split_printed_handwriting(
|
||||
ocr_lines,
|
||||
img_height,
|
||||
answer_zone_ratio=settings.OCR_ANSWER_ZONE_RATIO,
|
||||
enabled=settings.OCR_FILTER_HANDWRITING,
|
||||
)
|
||||
printed_text = ocr_filter_service.text_from_indices(ocr_lines, printed_ids)
|
||||
if not printed_text:
|
||||
printed_text = ocr_text
|
||||
wq.ocr_raw_text = printed_text or None
|
||||
if not ocr_text:
|
||||
wq.status = WrongQuestionStatus.failed
|
||||
wq.error_message = "OCR 未识别到文字,请拍摄更清晰、光线充足的题目照片"
|
||||
@@ -188,7 +215,15 @@ def _process_wrong_question(question_id: uuid.UUID):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(_run_ai_pipeline(wq, db, ocr_lines, ocr_text))
|
||||
loop.run_until_complete(
|
||||
_run_ai_pipeline(
|
||||
wq,
|
||||
db,
|
||||
ocr_lines,
|
||||
printed_text,
|
||||
hw_indices=hw_ids,
|
||||
)
|
||||
)
|
||||
db.commit()
|
||||
except Exception as exc:
|
||||
wq.status = WrongQuestionStatus.failed
|
||||
@@ -359,12 +394,15 @@ def delete_wrong_question(
|
||||
|
||||
image_full = Path(settings.UPLOAD_DIR) / wq.image_path
|
||||
ann_full = Path(settings.UPLOAD_DIR) / wq.annotated_image_path if wq.annotated_image_path else None
|
||||
crop_full = Path(settings.UPLOAD_DIR) / wq.cropped_image_path if wq.cropped_image_path else None
|
||||
db.delete(wq)
|
||||
db.commit()
|
||||
if image_full.exists():
|
||||
image_full.unlink()
|
||||
if ann_full and ann_full.exists():
|
||||
ann_full.unlink()
|
||||
if crop_full and crop_full.exists():
|
||||
crop_full.unlink()
|
||||
|
||||
|
||||
@router.post("/wrong-questions/{question_id}/retry-ocr", response_model=WrongQuestionOut)
|
||||
@@ -385,6 +423,7 @@ def retry_ocr(
|
||||
|
||||
wq.status = WrongQuestionStatus.pending
|
||||
wq.error_message = None
|
||||
wq.cropped_image_path = None
|
||||
db.commit()
|
||||
background_tasks.add_task(_process_wrong_question, wq.id)
|
||||
return _wq_to_out(wq)
|
||||
@@ -483,3 +522,26 @@ def get_wrong_question_annotated_image(
|
||||
if not image_full.exists():
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="标注图不存在")
|
||||
return FileResponse(image_full)
|
||||
|
||||
|
||||
@router.get("/wrong-questions/{question_id}/cropped-image")
|
||||
def get_wrong_question_cropped_image(
|
||||
question_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
wq = (
|
||||
db.query(WrongQuestion)
|
||||
.options(joinedload(WrongQuestion.student))
|
||||
.filter(WrongQuestion.id == question_id)
|
||||
.first()
|
||||
)
|
||||
if wq is None or wq.student.user_id != current_user.id:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="错题不存在")
|
||||
if not wq.cropped_image_path:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="裁剪图尚未生成")
|
||||
|
||||
image_full = Path(settings.UPLOAD_DIR) / wq.cropped_image_path
|
||||
if not image_full.exists():
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="裁剪图不存在")
|
||||
return FileResponse(image_full)
|
||||
|
||||
Reference in New Issue
Block a user