import asyncio
import json
import uuid
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeout
from datetime import datetime, timedelta, timezone
from pathlib import Path

from fastapi import (
    APIRouter,
    BackgroundTasks,
    Depends,
    File,
    Form,
    HTTPException,
    Query,
    UploadFile,
    status,
)
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session, joinedload

from app.core.config import settings
from app.core.database import SessionLocal, get_db
from app.core.deps import get_current_user
from app.models.user import (
    Subject,
    SystemSettings,
    User,
    WrongQuestion,
    WrongQuestionCategory,
    WrongQuestionStatus,
)
from app.schemas import WrongQuestionCategoryEnum, WrongQuestionOut, WrongQuestionUpdate
from app.services import annotation as annotation_service
from app.services import llm as llm_service
from app.services import ocr as ocr_service
from app.services import ocr_filter as ocr_filter_service
from app.services.student_access import get_student_for_user

router = APIRouter(tags=["wrong_questions"])

# Escape character used for user-supplied LIKE patterns in the search filter.
_LIKE_ESCAPE = "/"


def _short_error(exc: BaseException, prefix: str = "") -> str:
    """Return a user-displayable message for *exc*, capped at 500 characters.

    Falls back to the exception class name when ``str(exc)`` is blank, and
    prepends *prefix* when one is given.
    """
    msg = str(exc).strip() or type(exc).__name__
    if len(msg) > 500:
        msg = msg[:500] + "…"
    return f"{prefix}{msg}" if prefix else msg


def _is_still_processing(wq: WrongQuestion) -> bool:
    """Tell whether *wq* looks like it is still inside the OCR/AI pipeline.

    A record counts as in-flight when it is ``pending``, or when OCR finished
    but neither a formatted question nor an error message has been stored.
    """
    if wq.status == WrongQuestionStatus.pending:
        return True
    return (
        wq.status == WrongQuestionStatus.ocr_done
        and not wq.question_text
        and not wq.error_message
    )


def _mark_failed(wq: WrongQuestion, db: Session, message: str) -> None:
    """Persist the ``failed`` status together with *message* on *wq*."""
    wq.status = WrongQuestionStatus.failed
    wq.error_message = message
    db.commit()


def _expire_stale_processing(wq: WrongQuestion, db: Session) -> None:
    """Mark *wq* as failed when it has been processing for too long.

    Age is measured from ``wq.created_at``; naive timestamps are interpreted
    as UTC. Nothing happens for records that are not in-flight, that have no
    creation timestamp, or that are younger than
    ``settings.PROCESS_STALE_MINUTES``.
    """
    if not _is_still_processing(wq):
        return
    created = wq.created_at
    if created is None:
        # Without a timestamp the age cannot be computed; leave the row alone.
        return
    if created.tzinfo is None:
        created = created.replace(tzinfo=timezone.utc)
    age = datetime.now(timezone.utc) - created
    if age <= timedelta(minutes=settings.PROCESS_STALE_MINUTES):
        return
    _mark_failed(
        wq,
        db,
        f"处理超时(超过 {settings.PROCESS_STALE_MINUTES} 分钟),请点击「重新识别标注」重试",
    )


def _ocr_service_url(db: Session) -> str | None:
    """Resolve the OCR service URL, preferring the value stored in the DB.

    The ``SystemSettings`` row with id 1 wins when it holds a non-blank URL.
    Empty and whitespace-only values are both treated as "not configured" and
    fall back to ``ocr_service.resolve_ocr_service_url()``.
    """
    row = db.get(SystemSettings, 1)
    configured = (row.ocr_service_url or "").strip() if row else ""
    if configured:
        return configured
    return ocr_service.resolve_ocr_service_url()


def _parse_mark_regions(raw: str | None) -> list[dict] | None:
    """Decode the stored mark-region JSON.

    Returns ``None`` when *raw* is empty, is not valid JSON, or does not
    decode to a list.
    """
    if not raw:
        return None
    try:
        data = json.loads(raw)
    except json.JSONDecodeError:
        return None
    return data if isinstance(data, list) else None


def _wq_to_out(wq: WrongQuestion) -> WrongQuestionOut:
    """Convert an ORM ``WrongQuestion`` into its API response schema."""
    return WrongQuestionOut(
        id=wq.id,
        student_id=wq.student_id,
        subject_id=wq.subject_id,
        subject_name=wq.subject.name if wq.subject else None,
        category=wq.category,
        image_path=wq.image_path,
        ocr_raw_text=wq.ocr_raw_text,
        question_text=wq.question_text,
        solution_approach=wq.solution_approach,
        solution_text=wq.solution_text,
        mark_regions=_parse_mark_regions(wq.mark_regions_json),
        has_annotated_image=bool(wq.annotated_image_path),
        has_cropped_image=bool(wq.cropped_image_path),
        error_message=wq.error_message,
        status=wq.status,
        created_at=wq.created_at,
    )


async def _run_ai_pipeline(
    wq: WrongQuestion,
    db: Session,
    ocr_lines: list[dict],
    printed_text: str,
    *,
    hw_indices: list[int] | None = None,
) -> None:
    """Annotate wrong lines, then format the question and generate a solution.

    Steps:
      1. Ask the LLM which candidate lines are wrong (falling back to the
         annotation service's heuristic when that call fails), store the
         resulting regions and the annotated image path, and commit.
      2. Format the question text from *printed_text*.
      3. Generate the solution and split it into approach/body.

    On success *wq* is left with status ``solved`` and its error cleared; the
    final commit is the caller's responsibility. Failures in steps 2-3
    (including ``asyncio.TimeoutError``) propagate to the caller.

    Args:
        wq: Record being processed; ``subject`` and ``student`` are read.
        db: Session used to load the AI config and commit the annotation.
        ocr_lines: OCR line entries as returned by the OCR service.
        printed_text: Text handed to the question formatter.
        hw_indices: Indices into *ocr_lines* to restrict wrong-line detection
            to. Out-of-range entries are ignored; when nothing usable remains
            every line is a candidate.
    """
    subject_name = wq.subject.name if wq.subject else "综合"
    school_level = wq.student.school_level if wq.student else None
    olympiad = wq.category == WrongQuestionCategory.olympiad
    ai_cfg = llm_service.load_ai_config(db)
    image_full = str(Path(settings.UPLOAD_DIR) / wq.image_path)
    timeout = settings.AI_TIMEOUT_SECONDS

    # Build indices and lines from the SAME filtered list so that local
    # position k in candidate_lines always maps back to candidate_indices[k].
    line_count = len(ocr_lines)
    candidate_indices = [i for i in (hw_indices or []) if 0 <= i < line_count]
    if not candidate_indices:
        candidate_indices = list(range(line_count))
    candidate_lines = [ocr_lines[i] for i in candidate_indices]

    try:
        detect_resp = await asyncio.wait_for(
            llm_service.detect_wrong_line_ids(
                ai_cfg, subject_name, candidate_lines, school_level
            ),
            timeout=min(90, timeout),
        )
        local_wrong = annotation_service.parse_wrong_line_ids(detect_resp, candidate_lines)
    except Exception:
        # Deliberate best-effort: detection problems must not abort the whole
        # pipeline, so fall back to the local heuristic.
        local_wrong = annotation_service.heuristic_wrong_line_ids(candidate_lines)
    wrong_ids = [
        candidate_indices[i] for i in local_wrong if 0 <= i < len(candidate_indices)
    ]

    regions = annotation_service.regions_from_lines(ocr_lines, wrong_ids)
    wq.mark_regions_json = json.dumps(regions, ensure_ascii=False)
    ann_rel = ocr_service.annotated_rel_path(wq.image_path)
    wq.annotated_image_path = annotation_service.draw_annotated_image(
        image_full, ocr_lines, wrong_ids, ann_rel
    )
    # Commit early so the annotation is visible while the LLM calls run.
    db.commit()

    question_text = await asyncio.wait_for(
        llm_service.format_question(ai_cfg, subject_name, printed_text, school_level),
        timeout=timeout,
    )
    solution_full = await asyncio.wait_for(
        llm_service.generate_solution(
            ai_cfg, subject_name, question_text, school_level, olympiad=olympiad
        ),
        timeout=timeout,
    )
    approach, solution_body = annotation_service.split_solution_sections(solution_full)

    wq.question_text = question_text
    wq.solution_approach = approach
    wq.solution_text = solution_body if approach else solution_full
    wq.status = WrongQuestionStatus.solved
    wq.error_message = None


def _run_ocr_with_timeout(image_path: str, ocr_url: str | None) -> dict:
    """Run OCR in a worker thread, giving up after ``OCR_TIMEOUT_SECONDS``.

    Raises ``concurrent.futures.TimeoutError`` on timeout. The executor is
    shut down with ``wait=False``: using it as a context manager would call
    ``shutdown(wait=True)`` on exit and block until the hung OCR call
    returned, which would defeat the timeout. The worker thread itself
    cannot be killed and finishes in the background.
    """
    pool = ThreadPoolExecutor(max_workers=1)
    try:
        future = pool.submit(ocr_service.run_ocr_with_regions, image_path, ocr_url)
        return future.result(timeout=settings.OCR_TIMEOUT_SECONDS)
    finally:
        pool.shutdown(wait=False, cancel_futures=True)


def _process_wrong_question(question_id: uuid.UUID) -> None:
    """Background task: run OCR and then the AI pipeline for one question.

    Opens its own session (the request session is gone by the time this
    runs). Every failure is recorded on the row as status ``failed`` plus a
    user-facing ``error_message``. A missing row is silently ignored.
    """
    db = SessionLocal()
    wq = None
    try:
        wq = (
            db.query(WrongQuestion)
            .options(joinedload(WrongQuestion.subject), joinedload(WrongQuestion.student))
            .filter(WrongQuestion.id == question_id)
            .first()
        )
        if wq is None:
            return

        wq.error_message = None
        image_full = Path(settings.UPLOAD_DIR) / wq.image_path
        ocr_url = _ocr_service_url(db)

        # ---- Stage 1: OCR -------------------------------------------------
        try:
            ocr_result = _run_ocr_with_timeout(str(image_full), ocr_url)
            ocr_text = ocr_result["text"]
            ocr_lines = ocr_result["lines"]
            img_height = int(ocr_result.get("height") or 0)
            printed_ids, hw_ids = ocr_filter_service.split_printed_handwriting(
                ocr_lines,
                img_height,
                answer_zone_ratio=settings.OCR_ANSWER_ZONE_RATIO,
                enabled=settings.OCR_FILTER_HANDWRITING,
            )
            printed_text = ocr_filter_service.text_from_indices(ocr_lines, printed_ids)
            if not printed_text:
                printed_text = ocr_text
            wq.ocr_raw_text = printed_text or None
            if not ocr_text:
                _mark_failed(wq, db, "OCR 未识别到文字,请拍摄更清晰、光线充足的题目照片")
                return
            wq.status = WrongQuestionStatus.ocr_done
            db.commit()
        except FuturesTimeout:
            # Reset the session first: if the failure came from a flush or
            # commit, the session is unusable until rolled back.
            db.rollback()
            _mark_failed(
                wq,
                db,
                f"OCR 识别超时(>{settings.OCR_TIMEOUT_SECONDS}秒),请点「重新识别标注」重试",
            )
            return
        except Exception as exc:
            db.rollback()
            msg = _short_error(exc, "OCR 识别失败:")
            if "libGL" in str(exc):
                msg += " 请在服务器执行: sudo bash deploy/install-ocr-deps.sh && systemctl restart grade-archive"
            elif ocr_url and "OCR 服务" not in msg:
                msg += " 诊断: bash deploy/ocr-screen.sh status && bash deploy/ocr-worker/test-ocr.sh"
            _mark_failed(wq, db, msg)
            return

        # ---- Stage 2: AI pipeline -----------------------------------------
        try:
            # asyncio.run creates a fresh loop and always closes it, without
            # leaving a closed loop installed on this worker thread.
            asyncio.run(
                _run_ai_pipeline(wq, db, ocr_lines, printed_text, hw_indices=hw_ids)
            )
            db.commit()
        except Exception as exc:
            db.rollback()
            detail = _short_error(exc, "AI 处理失败:")
            if "Timeout" in type(exc).__name__ or "timeout" in str(exc).lower():
                detail = "AI 处理超时,请检查 Ollama/OpenAI 是否可用后重试"
            _mark_failed(wq, db, detail)
    except Exception as exc:
        if wq is not None:
            db.rollback()
            _mark_failed(wq, db, _short_error(exc, "处理失败:"))
    finally:
        db.close()


@router.get("/students/{student_id}/wrong-questions", response_model=list[WrongQuestionOut])
def list_wrong_questions(
    student_id: uuid.UUID,
    subject_id: int | None = Query(None),
    category: WrongQuestionCategoryEnum | None = Query(None),
    q: str | None = Query(None),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """List a student's wrong questions, newest first.

    Optional filters: *subject_id*, *category*, and a case-insensitive
    substring search *q* over the question text, solution text and raw OCR
    text. Access is checked through ``get_student_for_user``. Stale in-flight
    records are expired before serialisation.
    """
    get_student_for_user(db, student_id, current_user.id)
    query = (
        db.query(WrongQuestion)
        .options(joinedload(WrongQuestion.subject))
        .filter(WrongQuestion.student_id == student_id)
    )
    if subject_id is not None:
        query = query.filter(WrongQuestion.subject_id == subject_id)
    if category is not None:
        query = query.filter(WrongQuestion.category == category.value)
    if q:
        # Escape LIKE metacharacters so the search text is matched literally
        # ("100%" must not behave like the wildcard pattern "100<anything>").
        escaped = (
            q.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
            .replace("%", f"{_LIKE_ESCAPE}%")
            .replace("_", f"{_LIKE_ESCAPE}_")
        )
        pattern = f"%{escaped}%"
        query = query.filter(
            WrongQuestion.question_text.ilike(pattern, escape=_LIKE_ESCAPE)
            | WrongQuestion.solution_text.ilike(pattern, escape=_LIKE_ESCAPE)
            | WrongQuestion.ocr_raw_text.ilike(pattern, escape=_LIKE_ESCAPE)
        )
    items = query.order_by(WrongQuestion.created_at.desc()).all()
    for item in items:
        _expire_stale_processing(item, db)
    return [_wq_to_out(item) for item in items]


@router.post(
    "/students/{student_id}/wrong-questions",
    response_model=WrongQuestionOut,
    status_code=status.HTTP_201_CREATED,
)
async def upload_wrong_question(
    student_id: uuid.UUID,
    background_tasks: BackgroundTasks,
    subject_id: int = Form(...),
    file: UploadFile = File(...),
    category: WrongQuestionCategoryEnum = Form(WrongQuestionCategoryEnum.regular),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Store an uploaded question photo and schedule OCR/AI processing.

    Raises:
        HTTPException(400): unknown subject, olympiad category on a subject
            other than "数学", or an upload larger than
            ``settings.MAX_UPLOAD_SIZE``.
    """
    get_student_for_user(db, student_id, current_user.id)
    subject = db.get(Subject, subject_id)
    if subject is None:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="科目不存在")
    if category == WrongQuestionCategoryEnum.olympiad and subject.name != "数学":
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="奥数区仅支持数学科目")

    # NOTE(review): the whole upload is read into memory before the size
    # check, so the limit bounds what is stored rather than what is buffered.
    content = await file.read()
    if len(content) > settings.MAX_UPLOAD_SIZE:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="文件超过10MB限制")

    wq = WrongQuestion(
        student_id=student_id,
        subject_id=subject_id,
        image_path="",
        category=WrongQuestionCategory(category.value),
        status=WrongQuestionStatus.pending,
    )
    db.add(wq)
    # Flush to obtain the generated id, which is passed to save_upload_file.
    db.flush()

    rel_path = ocr_service.save_upload_file(
        str(current_user.id), str(wq.id), file.filename or "image.jpg", content
    )
    wq.image_path = rel_path
    db.commit()
    db.refresh(wq)

    # Reload with the subject eagerly joined for the response payload.
    wq = (
        db.query(WrongQuestion)
        .options(joinedload(WrongQuestion.subject))
        .filter(WrongQuestion.id == wq.id)
        .first()
    )
    background_tasks.add_task(_process_wrong_question, wq.id)
    return _wq_to_out(wq)
def _get_owned_question(
    db: Session,
    question_id: uuid.UUID,
    current_user: User,
    *,
    with_subject: bool = False,
) -> WrongQuestion:
    """Load a question owned by *current_user*, or raise HTTP 404.

    The student relationship is always eagerly joined because ownership is
    checked through ``student.user_id``; the subject is joined only when
    *with_subject* is true. A missing row, a row without a student, and a row
    owned by someone else all yield the same 404, so the existence of other
    users' records is not revealed.
    """
    loaders = [joinedload(WrongQuestion.student)]
    if with_subject:
        loaders.append(joinedload(WrongQuestion.subject))
    wq = (
        db.query(WrongQuestion)
        .options(*loaders)
        .filter(WrongQuestion.id == question_id)
        .first()
    )
    if wq is None or wq.student is None or wq.student.user_id != current_user.id:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="错题不存在")
    return wq


def _stored_file_response(rel_path: str | None, missing_detail: str) -> FileResponse:
    """Serve a file stored under ``settings.UPLOAD_DIR``, or raise HTTP 404.

    ``is_file()`` is used instead of ``exists()`` so that an empty relative
    path (which resolves to the upload directory itself) yields a 404 rather
    than handing a directory to ``FileResponse``.
    """
    full_path = Path(settings.UPLOAD_DIR) / rel_path if rel_path else None
    if full_path is None or not full_path.is_file():
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=missing_detail)
    return FileResponse(full_path)


def _remove_stored_file(rel_path: str | None) -> None:
    """Delete a file under ``settings.UPLOAD_DIR`` if it is a regular file."""
    if not rel_path:
        return
    full_path = Path(settings.UPLOAD_DIR) / rel_path
    if full_path.is_file():
        # missing_ok covers the file vanishing between the check and unlink.
        full_path.unlink(missing_ok=True)


@router.get("/wrong-questions/{question_id}", response_model=WrongQuestionOut)
def get_wrong_question(
    question_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return one wrong question, expiring it first if processing went stale."""
    wq = _get_owned_question(db, question_id, current_user, with_subject=True)
    _expire_stale_processing(wq, db)
    return _wq_to_out(wq)


@router.patch("/wrong-questions/{question_id}", response_model=WrongQuestionOut)
def update_wrong_question(
    question_id: uuid.UUID,
    data: WrongQuestionUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Apply the non-null fields of *data* to a wrong question.

    Raises:
        HTTPException(404): question missing or not owned by the caller.
        HTTPException(400): unknown subject, or moving an olympiad question
            to a subject other than "数学" (the same rule the upload
            endpoint enforces).
    """
    wq = _get_owned_question(db, question_id, current_user, with_subject=True)

    if data.subject_id is not None:
        subject = db.get(Subject, data.subject_id)
        if subject is None:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="科目不存在")
        if wq.category == WrongQuestionCategory.olympiad and subject.name != "数学":
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST, detail="奥数区仅支持数学科目"
            )
        wq.subject_id = data.subject_id
    if data.question_text is not None:
        wq.question_text = data.question_text
    if data.solution_text is not None:
        wq.solution_text = data.solution_text
    if data.solution_approach is not None:
        wq.solution_approach = data.solution_approach

    db.commit()
    db.refresh(wq)
    return _wq_to_out(wq)


@router.delete("/wrong-questions/{question_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_wrong_question(
    question_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Delete a wrong question and its stored image files.

    The row is removed and committed first; the original, annotated and
    cropped images are unlinked afterwards.
    """
    wq = _get_owned_question(db, question_id, current_user)

    # Capture the paths before the delete: the instance is not usable for
    # attribute access once the deletion has been committed.
    stored_paths = (wq.image_path, wq.annotated_image_path, wq.cropped_image_path)

    db.delete(wq)
    db.commit()

    for rel_path in stored_paths:
        _remove_stored_file(rel_path)


@router.post("/wrong-questions/{question_id}/retry-ocr", response_model=WrongQuestionOut)
def retry_ocr(
    question_id: uuid.UUID,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Reset a question to ``pending`` and re-run the OCR/AI background task."""
    wq = _get_owned_question(db, question_id, current_user, with_subject=True)

    # NOTE(review): _expire_stale_processing measures age from created_at, so
    # a question older than PROCESS_STALE_MINUTES that is retried here can be
    # flagged as timed out by the next read while the retry is still running.
    # Fixing this needs a "processing started" timestamp on the model.
    wq.status = WrongQuestionStatus.pending
    wq.error_message = None
    wq.cropped_image_path = None
    db.commit()

    background_tasks.add_task(_process_wrong_question, wq.id)
    return _wq_to_out(wq)


@router.post("/wrong-questions/{question_id}/regenerate-solution", response_model=WrongQuestionOut)
async def regenerate_solution(
    question_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Regenerate the solution (and, if missing, the formatted question).

    Each LLM call is bounded by ``settings.AI_TIMEOUT_SECONDS``, matching the
    background pipeline. On success the record becomes ``solved`` and any
    previous error message is cleared.

    Raises:
        HTTPException(404): question missing or not owned by the caller.
        HTTPException(400): neither question text nor OCR text is available.
        HTTPException(502): the LLM call failed or timed out.
    """
    import asyncio

    wq = _get_owned_question(db, question_id, current_user, with_subject=True)
    if not wq.question_text and not wq.ocr_raw_text:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="缺少题目内容")

    subject_name = wq.subject.name if wq.subject else "综合"
    school_level = wq.student.school_level if wq.student else None
    olympiad = wq.category == WrongQuestionCategory.olympiad
    ai_cfg = llm_service.load_ai_config(db)
    timeout = settings.AI_TIMEOUT_SECONDS

    try:
        if not wq.question_text:
            # The guard above guarantees ocr_raw_text is present here.
            wq.question_text = await asyncio.wait_for(
                llm_service.format_question(
                    ai_cfg, subject_name, wq.ocr_raw_text, school_level
                ),
                timeout=timeout,
            )
        solution_full = await asyncio.wait_for(
            llm_service.generate_solution(
                ai_cfg,
                subject_name,
                wq.question_text,
                school_level,
                olympiad=olympiad,
            ),
            timeout=timeout,
        )
        approach, solution_body = annotation_service.split_solution_sections(solution_full)
        wq.solution_approach = approach
        wq.solution_text = solution_body if approach else solution_full
        wq.status = WrongQuestionStatus.solved
        # A solved record must not keep showing a stale failure message.
        wq.error_message = None
    except Exception as exc:
        # _short_error truncates long provider errors and falls back to the
        # exception class name when str(exc) is empty (e.g. TimeoutError).
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=_short_error(exc, "AI 调用失败: "),
        ) from exc

    db.commit()
    db.refresh(wq)
    return _wq_to_out(wq)


@router.get("/wrong-questions/{question_id}/image")
def get_wrong_question_image(
    question_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Serve the originally uploaded image of a wrong question."""
    wq = _get_owned_question(db, question_id, current_user)
    return _stored_file_response(wq.image_path, "图片不存在")


@router.get("/wrong-questions/{question_id}/annotated-image")
def get_wrong_question_annotated_image(
    question_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Serve the annotated image; 404 if it has not been generated or is gone."""
    wq = _get_owned_question(db, question_id, current_user)
    if not wq.annotated_image_path:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="标注图尚未生成")
    return _stored_file_response(wq.annotated_image_path, "标注图不存在")


@router.get("/wrong-questions/{question_id}/cropped-image")
def get_wrong_question_cropped_image(
    question_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Serve the cropped image; 404 if it has not been generated or is gone."""
    wq = _get_owned_question(db, question_id, current_user)
    if not wq.cropped_image_path:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="裁剪图尚未生成")
    return _stored_file_response(wq.cropped_image_path, "裁剪图不存在")