485 lines
18 KiB
Python
485 lines
18 KiB
Python
import json
|
||
import uuid
|
||
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeout
|
||
from datetime import datetime, timedelta, timezone
|
||
from pathlib import Path
|
||
|
||
from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, HTTPException, Query, UploadFile, status
|
||
from fastapi.responses import FileResponse
|
||
from sqlalchemy.orm import Session, joinedload
|
||
|
||
from app.core.config import settings
|
||
from app.core.database import SessionLocal, get_db
|
||
from app.core.deps import get_current_user
|
||
from app.models.user import Subject, SystemSettings, User, WrongQuestion, WrongQuestionCategory, WrongQuestionStatus
|
||
from app.schemas import WrongQuestionCategoryEnum, WrongQuestionOut, WrongQuestionUpdate
|
||
from app.services import annotation as annotation_service
|
||
from app.services import llm as llm_service
|
||
from app.services import ocr as ocr_service
|
||
from app.services.student_access import get_student_for_user
|
||
|
||
# API router holding every wrong-question endpoint declared in this module.
router = APIRouter(
    tags=["wrong_questions"],
)
|
||
|
||
|
||
def _short_error(exc: BaseException, prefix: str = "") -> str:
|
||
msg = str(exc).strip() or type(exc).__name__
|
||
if len(msg) > 500:
|
||
msg = msg[:500] + "…"
|
||
return f"{prefix}{msg}" if prefix else msg
|
||
|
||
|
||
def _is_still_processing(wq: WrongQuestion) -> bool:
    """Tell whether the background job for ``wq`` has not finished yet.

    True when the status is ``pending``, or when it is ``ocr_done`` and neither
    a formatted question text nor an error message has been recorded.
    """
    current = wq.status
    return current == WrongQuestionStatus.pending or (
        current == WrongQuestionStatus.ocr_done
        and not (wq.question_text or wq.error_message)
    )
|
||
|
||
|
||
def _expire_stale_processing(wq: WrongQuestion, db: Session) -> None:
    """Mark ``wq`` failed if it has looked "in progress" for too long.

    Does nothing unless ``_is_still_processing(wq)`` is true. Otherwise compares
    the elapsed time against ``settings.PROCESS_STALE_MINUTES``; when exceeded,
    sets status ``failed`` with a retry hint and commits immediately.

    Naive timestamps are treated as UTC.
    """
    if not _is_still_processing(wq):
        return

    # retry_ocr() resets status to pending without touching created_at, so
    # measuring from created_at alone would expire a freshly retried (but old)
    # question on the next poll. Prefer a last-modified timestamp when the model
    # has one. NOTE(review): assumes an `updated_at` column maintained on update
    # if present; the model is not visible here — confirm in app.models.user.
    started = getattr(wq, "updated_at", None) or wq.created_at
    if started is None:
        # No timestamp to judge staleness by; leave the row untouched.
        return
    if started.tzinfo is None:
        started = started.replace(tzinfo=timezone.utc)

    age = datetime.now(timezone.utc) - started
    if age <= timedelta(minutes=settings.PROCESS_STALE_MINUTES):
        return

    wq.status = WrongQuestionStatus.failed
    wq.error_message = f"处理超时(超过 {settings.PROCESS_STALE_MINUTES} 分钟),请点击「重新识别标注」重试"
    db.commit()
|
||
|
||
|
||
def _ocr_service_url(db: Session) -> str | None:
    """Return the OCR service base URL to use, or None.

    A non-blank ``ocr_service_url`` on the ``SystemSettings`` row with id 1 wins
    (returned stripped). Otherwise falls back to
    ``ocr_service.resolve_ocr_service_url()``.
    """
    row = db.get(SystemSettings, 1)
    configured = (row.ocr_service_url or "").strip() if row else ""
    if configured:
        return configured
    # Whitespace-only values used to return None here while an empty string fell
    # through to the fallback; treat both as "not configured".
    return ocr_service.resolve_ocr_service_url()
|
||
|
||
|
||
def _parse_mark_regions(raw: str | None) -> list[dict] | None:
|
||
if not raw:
|
||
return None
|
||
try:
|
||
data = json.loads(raw)
|
||
return data if isinstance(data, list) else None
|
||
except json.JSONDecodeError:
|
||
return None
|
||
|
||
|
||
def _wq_to_out(wq: WrongQuestion) -> WrongQuestionOut:
    """Build the API response schema for a WrongQuestion row.

    ``subject_name`` comes from the ``subject`` relationship (None when unset),
    ``mark_regions`` is decoded from ``mark_regions_json``, and
    ``has_annotated_image`` reports whether ``annotated_image_path`` is set.
    """
    payload = {
        "id": wq.id,
        "student_id": wq.student_id,
        "subject_id": wq.subject_id,
        "subject_name": wq.subject.name if wq.subject else None,
        "category": wq.category,
        "image_path": wq.image_path,
        "ocr_raw_text": wq.ocr_raw_text,
        "question_text": wq.question_text,
        "solution_approach": wq.solution_approach,
        "solution_text": wq.solution_text,
        "mark_regions": _parse_mark_regions(wq.mark_regions_json),
        "has_annotated_image": bool(wq.annotated_image_path),
        "error_message": wq.error_message,
        "status": wq.status,
        "created_at": wq.created_at,
    }
    return WrongQuestionOut(**payload)
|
||
|
||
|
||
async def _run_ai_pipeline(wq: WrongQuestion, db: Session, ocr_lines: list[dict], ocr_text: str):
    """Annotate the photo and produce the formatted question and solution for ``wq``.

    Stage 1 asks the LLM which OCR lines are wrong, with a timeout of at most
    90 s. On any error it falls back to
    ``annotation_service.heuristic_wrong_line_ids``. It then stores the mark
    regions, draws the annotated image and commits.

    Stage 2 formats the question and generates the solution, each bounded by
    ``settings.AI_TIMEOUT_SECONDS``. Errors in stage 2 propagate. On success
    ``wq`` is marked ``solved`` with the error cleared, but is NOT committed
    here — the caller commits.
    """
    import asyncio

    subject = wq.subject.name if wq.subject else "综合"
    level = wq.student.school_level if wq.student else None
    is_olympiad = wq.category == WrongQuestionCategory.olympiad
    cfg = llm_service.load_ai_config(db)
    source_image = str(Path(settings.UPLOAD_DIR) / wq.image_path)
    limit = settings.AI_TIMEOUT_SECONDS

    # Stage 1: find wrong lines. Best effort — the heuristic keeps annotation
    # working even when the LLM call fails or times out.
    try:
        detection = await asyncio.wait_for(
            llm_service.detect_wrong_line_ids(cfg, subject, ocr_lines, level),
            timeout=min(90, limit),
        )
        flagged = annotation_service.parse_wrong_line_ids(detection, ocr_lines)
    except Exception:
        flagged = annotation_service.heuristic_wrong_line_ids(ocr_lines)

    wq.mark_regions_json = json.dumps(
        annotation_service.regions_from_lines(ocr_lines, flagged), ensure_ascii=False
    )
    wq.annotated_image_path = annotation_service.draw_annotated_image(
        source_image, ocr_lines, flagged, ocr_service.annotated_rel_path(wq.image_path)
    )
    # Persist the marks now so they survive a later stage-2 failure.
    db.commit()

    # Stage 2: formatted question, then solution built from it.
    formatted = await asyncio.wait_for(
        llm_service.format_question(cfg, subject, ocr_text, level),
        timeout=limit,
    )
    full_solution = await asyncio.wait_for(
        llm_service.generate_solution(cfg, subject, formatted, level, olympiad=is_olympiad),
        timeout=limit,
    )
    approach, body = annotation_service.split_solution_sections(full_solution)

    wq.question_text = formatted
    wq.solution_approach = approach
    # Without a separate approach section, keep the whole solution text.
    wq.solution_text = body if approach else full_solution
    wq.status = WrongQuestionStatus.solved
    wq.error_message = None
|
||
|
||
|
||
def _mark_failed(wq: WrongQuestion, db: Session, message: str, *, rollback: bool = False) -> None:
    """Persist ``wq`` as failed with ``message``.

    With ``rollback=True`` the session is rolled back first. Otherwise a failed
    flush/commit earlier in the same session would make this commit raise
    ``PendingRollbackError`` and hide the original error.
    """
    if rollback:
        db.rollback()
    wq.status = WrongQuestionStatus.failed
    wq.error_message = message
    db.commit()


def _run_ocr_with_timeout(image_full: Path, ocr_url: str | None) -> dict:
    """Run ``ocr_service.run_ocr_with_regions`` bounded by ``settings.OCR_TIMEOUT_SECONDS``.

    Raises ``FuturesTimeout`` when the deadline passes. The worker thread cannot
    be killed; it keeps running in the background and its result is discarded.
    """
    pool = ThreadPoolExecutor(max_workers=1)
    try:
        future = pool.submit(ocr_service.run_ocr_with_regions, str(image_full), ocr_url)
        return future.result(timeout=settings.OCR_TIMEOUT_SECONDS)
    finally:
        # wait=False is essential: a `with ThreadPoolExecutor()` block would call
        # shutdown(wait=True) and block until OCR finishes, defeating the timeout.
        pool.shutdown(wait=False)


def _ocr_failure_message(exc: Exception, ocr_url: str | None) -> str:
    """Build the user-facing OCR error text, with a remediation hint when one applies."""
    msg = _short_error(exc, "OCR 识别失败:")
    if "libGL" in str(exc):
        msg += " 请在服务器执行: sudo bash deploy/install-ocr-deps.sh && systemctl restart grade-archive"
    elif ocr_url:
        msg += f" 请检查 OCR 服务是否可达: {ocr_url} (可浏览器访问 {ocr_url.rstrip('/')}/health)"
    return msg


def _ai_failure_message(exc: Exception) -> str:
    """Build the user-facing AI error text; anything that looks like a timeout gets a fixed hint."""
    if "Timeout" in type(exc).__name__ or "timeout" in str(exc).lower():
        return "AI 处理超时,请检查 Ollama/OpenAI 是否可用后重试"
    return _short_error(exc, "AI 处理失败:")


def _process_wrong_question(question_id: uuid.UUID):
    """Background job: run OCR, then the AI pipeline, for one wrong question.

    Uses its own ``SessionLocal()`` session, closed on exit. If the row does not
    exist, returns without doing anything. Each failure stage is recorded on the
    row as ``status=failed`` with a user-facing ``error_message``:
    - OCR timeout / error,
    - no text recognised,
    - AI error,
    - any other error.

    Exceptions raised before the row is loaded are swallowed.
    """
    import asyncio

    db = SessionLocal()
    wq = None
    try:
        wq = (
            db.query(WrongQuestion)
            .options(joinedload(WrongQuestion.subject), joinedload(WrongQuestion.student))
            .filter(WrongQuestion.id == question_id)
            .first()
        )
        if wq is None:
            return

        wq.error_message = None
        image_full = Path(settings.UPLOAD_DIR) / wq.image_path
        ocr_url = _ocr_service_url(db)

        # --- OCR stage ---
        try:
            ocr_result = _run_ocr_with_timeout(image_full, ocr_url)
            ocr_text = ocr_result["text"]
            ocr_lines = ocr_result["lines"]
        except FuturesTimeout:
            _mark_failed(
                wq,
                db,
                f"OCR 识别超时(>{settings.OCR_TIMEOUT_SECONDS}秒),请点「重新识别标注」重试",
            )
            return
        except Exception as exc:
            _mark_failed(wq, db, _ocr_failure_message(exc, ocr_url))
            return

        wq.ocr_raw_text = ocr_text or None
        if not ocr_text:
            _mark_failed(wq, db, "OCR 未识别到文字,请拍摄更清晰、光线充足的题目照片")
            return
        wq.status = WrongQuestionStatus.ocr_done
        db.commit()

        # --- AI stage ---
        # asyncio.run creates, runs and closes a fresh loop and resets the
        # thread's current loop afterwards.
        try:
            asyncio.run(_run_ai_pipeline(wq, db, ocr_lines, ocr_text))
            db.commit()
        except Exception as exc:
            _mark_failed(wq, db, _ai_failure_message(exc), rollback=True)
    except Exception as exc:
        if wq is not None:
            _mark_failed(wq, db, _short_error(exc, "处理失败:"), rollback=True)
    finally:
        db.close()
|
||
|
||
|
||
@router.get("/students/{student_id}/wrong-questions", response_model=list[WrongQuestionOut])
def list_wrong_questions(
    student_id: uuid.UUID,
    subject_id: int | None = Query(None),
    category: WrongQuestionCategoryEnum | None = Query(None),
    q: str | None = Query(None),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """List a student's wrong questions, newest first.

    Access is checked via ``get_student_for_user``.

    Optional filters:
    - ``subject_id``,
    - ``category``,
    - ``q``: a case-insensitive literal substring matched against question
      text, solution text and raw OCR text.

    Items that have been "processing" past the stale threshold are marked
    failed before being returned.
    """
    get_student_for_user(db, student_id, current_user.id)
    query = (
        db.query(WrongQuestion)
        .options(joinedload(WrongQuestion.subject))
        .filter(WrongQuestion.student_id == student_id)
    )
    if subject_id is not None:
        query = query.filter(WrongQuestion.subject_id == subject_id)
    if category is not None:
        query = query.filter(WrongQuestion.category == category.value)
    if q:
        # Escape LIKE metacharacters so user input such as "50%" or "a_b" is
        # matched literally instead of acting as wildcards.
        escaped = q.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        pattern = f"%{escaped}%"
        query = query.filter(
            WrongQuestion.question_text.ilike(pattern, escape="\\")
            | WrongQuestion.solution_text.ilike(pattern, escape="\\")
            | WrongQuestion.ocr_raw_text.ilike(pattern, escape="\\")
        )
    items = query.order_by(WrongQuestion.created_at.desc()).all()
    for w in items:
        _expire_stale_processing(w, db)
    return [_wq_to_out(w) for w in items]
|
||
|
||
|
||
@router.post(
    "/students/{student_id}/wrong-questions",
    response_model=WrongQuestionOut,
    status_code=status.HTTP_201_CREATED,
)
async def upload_wrong_question(
    student_id: uuid.UUID,
    background_tasks: BackgroundTasks,
    subject_id: int = Form(...),
    file: UploadFile = File(...),
    category: WrongQuestionCategoryEnum = Form(WrongQuestionCategoryEnum.regular),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Create a wrong question from an uploaded photo and queue background processing.

    Validation (each failure raises HTTP 400):
    - the subject must exist;
    - the olympiad category is only allowed for the subject named "数学";
    - the upload must not exceed ``settings.MAX_UPLOAD_SIZE`` bytes.

    The row is created as ``pending``, the file is stored via
    ``ocr_service.save_upload_file``, and ``_process_wrong_question`` is
    scheduled as a background task.
    """
    get_student_for_user(db, student_id, current_user.id)
    subject = db.get(Subject, subject_id)
    if subject is None:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="科目不存在")
    if category == WrongQuestionCategoryEnum.olympiad and subject.name != "数学":
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="奥数区仅支持数学科目")

    # Read at most one byte past the limit: enough to detect an oversize upload
    # without loading an arbitrarily large body into memory.
    content = await file.read(settings.MAX_UPLOAD_SIZE + 1)
    if len(content) > settings.MAX_UPLOAD_SIZE:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="文件超过10MB限制")

    wq = WrongQuestion(
        student_id=student_id,
        subject_id=subject_id,
        image_path="",
        category=WrongQuestionCategory(category.value),
        status=WrongQuestionStatus.pending,
    )
    db.add(wq)
    # Flush to obtain the id, which is passed to save_upload_file below.
    db.flush()
    question_id = wq.id

    wq.image_path = ocr_service.save_upload_file(
        str(current_user.id), str(question_id), file.filename or "image.jpg", content
    )
    db.commit()

    # Re-load with the subject eagerly joined so _wq_to_out can read its name.
    wq = (
        db.query(WrongQuestion)
        .options(joinedload(WrongQuestion.subject))
        .filter(WrongQuestion.id == question_id)
        .first()
    )
    background_tasks.add_task(_process_wrong_question, question_id)
    return _wq_to_out(wq)
|
||
|
||
|
||
@router.get("/wrong-questions/{question_id}", response_model=WrongQuestionOut)
def get_wrong_question(
    question_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return one wrong question owned by the current user.

    Responds 404 when the question is missing or belongs to another user. A
    question stuck in processing past the stale threshold is marked failed
    first.
    """
    lookup = db.query(WrongQuestion).options(
        joinedload(WrongQuestion.subject),
        joinedload(WrongQuestion.student),
    )
    wq = lookup.filter(WrongQuestion.id == question_id).first()

    owned = wq is not None and wq.student.user_id == current_user.id
    if not owned:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="错题不存在")

    _expire_stale_processing(wq, db)
    return _wq_to_out(wq)
|
||
|
||
|
||
@router.patch("/wrong-questions/{question_id}", response_model=WrongQuestionOut)
def update_wrong_question(
    question_id: uuid.UUID,
    data: WrongQuestionUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Apply a partial update to a wrong question owned by the current user.

    Only fields that are not None in ``data`` are written.

    Raises:
        HTTPException 404: question missing or owned by another user.
        HTTPException 400: the new subject does not exist, or it would move an
            olympiad question to a subject other than "数学".
    """
    wq = (
        db.query(WrongQuestion)
        .options(joinedload(WrongQuestion.subject), joinedload(WrongQuestion.student))
        .filter(WrongQuestion.id == question_id)
        .first()
    )
    if wq is None or wq.student.user_id != current_user.id:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="错题不存在")

    if data.subject_id is not None:
        subject = db.get(Subject, data.subject_id)
        if subject is None:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="科目不存在")
        # Same rule upload_wrong_question enforces at creation; without it an
        # olympiad question could be re-assigned to a non-math subject here.
        if wq.category == WrongQuestionCategory.olympiad and subject.name != "数学":
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="奥数区仅支持数学科目")
        wq.subject_id = data.subject_id
    if data.question_text is not None:
        wq.question_text = data.question_text
    if data.solution_text is not None:
        wq.solution_text = data.solution_text
    if data.solution_approach is not None:
        wq.solution_approach = data.solution_approach

    db.commit()
    db.refresh(wq)
    return _wq_to_out(wq)
|
||
|
||
|
||
@router.delete("/wrong-questions/{question_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_wrong_question(
    question_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Delete a wrong question owned by the current user, then remove its files.

    Responds 404 when the question is missing or owned by another user. File
    removal (original and annotated image) happens after the DB commit and is
    best-effort: a file that is already gone or cannot be removed does not fail
    the request.
    """
    wq = (
        db.query(WrongQuestion)
        .options(joinedload(WrongQuestion.student))
        .filter(WrongQuestion.id == question_id)
        .first()
    )
    if wq is None or wq.student.user_id != current_user.id:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="错题不存在")

    upload_dir = Path(settings.UPLOAD_DIR)
    # Skip empty paths: UPLOAD_DIR / "" is the upload directory itself.
    stored_files = [
        upload_dir / rel_path
        for rel_path in (wq.image_path, wq.annotated_image_path)
        if rel_path
    ]

    db.delete(wq)
    db.commit()

    for path in stored_files:
        try:
            path.unlink(missing_ok=True)
        except OSError:
            # The row is already deleted; an orphaned file is preferable to a 500.
            pass
|
||
|
||
|
||
@router.post("/wrong-questions/{question_id}/retry-ocr", response_model=WrongQuestionOut)
def retry_ocr(
    question_id: uuid.UUID,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Reset a question to ``pending`` and re-queue the full OCR + AI job.

    Responds 404 when the question is missing or owned by another user. The
    error message is cleared and the change committed before the background
    task is scheduled.
    """
    lookup = db.query(WrongQuestion).options(
        joinedload(WrongQuestion.subject),
        joinedload(WrongQuestion.student),
    )
    wq = lookup.filter(WrongQuestion.id == question_id).first()
    if wq is None or wq.student.user_id != current_user.id:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="错题不存在")

    wq.error_message = None
    wq.status = WrongQuestionStatus.pending
    db.commit()

    background_tasks.add_task(_process_wrong_question, wq.id)
    return _wq_to_out(wq)
|
||
|
||
|
||
@router.post("/wrong-questions/{question_id}/regenerate-solution", response_model=WrongQuestionOut)
async def regenerate_solution(
    question_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Re-run the AI solution step for a wrong question owned by the current user.

    If no formatted ``question_text`` exists yet, one is first produced from
    ``ocr_raw_text`` via ``llm_service.format_question``. Each LLM call is
    bounded by ``settings.AI_TIMEOUT_SECONDS``, as in the background pipeline.
    On success the question is marked ``solved`` and any earlier error message
    is cleared.

    Raises:
        HTTPException 404: question missing or owned by another user.
        HTTPException 400: neither question text nor OCR text is available.
        HTTPException 502: any failure while calling the AI service.
    """
    import asyncio

    wq = (
        db.query(WrongQuestion)
        .options(joinedload(WrongQuestion.subject), joinedload(WrongQuestion.student))
        .filter(WrongQuestion.id == question_id)
        .first()
    )
    if wq is None or wq.student.user_id != current_user.id:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="错题不存在")

    if not wq.question_text and not wq.ocr_raw_text:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="缺少题目内容")

    subject_name = wq.subject.name if wq.subject else "综合"
    school_level = wq.student.school_level if wq.student else None
    olympiad = wq.category == WrongQuestionCategory.olympiad
    ai_cfg = llm_service.load_ai_config(db)
    timeout = settings.AI_TIMEOUT_SECONDS

    try:
        if not wq.question_text:
            # Guarded above: raw OCR text is non-empty here.
            wq.question_text = await asyncio.wait_for(
                llm_service.format_question(ai_cfg, subject_name, wq.ocr_raw_text, school_level),
                timeout=timeout,
            )
        solution_full = await asyncio.wait_for(
            llm_service.generate_solution(
                ai_cfg,
                subject_name,
                wq.question_text,
                school_level,
                olympiad=olympiad,
            ),
            timeout=timeout,
        )
        approach, solution_body = annotation_service.split_solution_sections(solution_full)
        wq.solution_approach = approach
        wq.solution_text = solution_body if approach else solution_full
        wq.status = WrongQuestionStatus.solved
        # Otherwise a previous failure's message lingers next to a fresh solution.
        wq.error_message = None
    except Exception as exc:
        # _short_error keeps the detail bounded and non-empty (e.g. a timeout
        # exception with no message still reports its class name).
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=_short_error(exc, "AI 调用失败: "),
        ) from exc

    db.commit()
    db.refresh(wq)
    return _wq_to_out(wq)
|
||
|
||
|
||
@router.get("/wrong-questions/{question_id}/image")
def get_wrong_question_image(
    question_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Serve the original uploaded photo of a wrong question owned by the current user.

    Responds 404 when the question is missing, owned by another user, or its
    image file is not present on disk.
    """
    wq = (
        db.query(WrongQuestion)
        .options(joinedload(WrongQuestion.student))
        .filter(WrongQuestion.id == question_id)
        .first()
    )
    if wq is None or wq.student.user_id != current_user.id:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="错题不存在")

    image_full = Path(settings.UPLOAD_DIR) / wq.image_path
    # is_file() rather than exists(): an empty image_path resolves to the upload
    # directory itself, which exists but cannot be served as a file.
    if not image_full.is_file():
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="图片不存在")
    return FileResponse(image_full)
|
||
|
||
|
||
@router.get("/wrong-questions/{question_id}/annotated-image")
def get_wrong_question_annotated_image(
    question_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Serve the annotated (wrong-lines-marked) image of a question owned by the current user.

    Responds 404 in three cases:
    - the question is missing or owned by another user;
    - no annotated image path has been recorded yet;
    - the recorded file is not on disk.
    """
    lookup = db.query(WrongQuestion).options(joinedload(WrongQuestion.student))
    wq = lookup.filter(WrongQuestion.id == question_id).first()
    if wq is None or wq.student.user_id != current_user.id:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="错题不存在")

    relative = wq.annotated_image_path
    if not relative:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="标注图尚未生成")

    target = Path(settings.UPLOAD_DIR) / relative
    if target.exists():
        return FileResponse(target)
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="标注图不存在")
|