作业帮式错题标注:OCR 定位错误红框 + 解题思路。

- PaddleOCR 行级坐标 + AI 识别错答区域,生成标注图

- 解法拆分为「解题思路」与「详细解答」

- 详情页标注图/原图切换,列表显示标注缩略图
This commit is contained in:
dekun
2026-06-28 13:50:20 +08:00
parent c30e21b51e
commit a2a6d59f7c
16 changed files with 852 additions and 507 deletions
+109
View File
@@ -0,0 +1,109 @@
import json
import re
from pathlib import Path
from PIL import Image, ImageDraw, ImageFont
from app.core.config import settings
def _parse_llm_json(text: str) -> dict | None:
text = text.strip()
match = re.search(r"\{[\s\S]*\}", text)
if not match:
return None
try:
return json.loads(match.group())
except json.JSONDecodeError:
return None
def heuristic_wrong_line_ids(lines: list[dict]) -> list[int]:
wrong: list[int] = []
for i, line in enumerate(lines):
t = line.get("text", "")
if any(c in t for c in ("×", "", "", "")):
wrong.append(i)
continue
if re.search(r"[×xX]\s*$", t.strip()):
wrong.append(i)
if wrong:
return wrong
# 单题照片:标注最后几行作答区域
if len(lines) == 1:
return [0]
if len(lines) <= 4:
return list(range(max(0, len(lines) - 2), len(lines)))
return list(range(len(lines) - 2, len(lines)))
def parse_wrong_line_ids(llm_response: str, lines: list[dict]) -> list[int]:
data = _parse_llm_json(llm_response)
if data and isinstance(data.get("wrong_line_ids"), list):
ids = [int(x) for x in data["wrong_line_ids"] if isinstance(x, (int, float, str))]
ids = [i for i in ids if 0 <= i < len(lines)]
if ids:
return ids
return heuristic_wrong_line_ids(lines)
def regions_from_lines(lines: list[dict], wrong_ids: list[int]) -> list[dict]:
regions = []
for i in wrong_ids:
if i >= len(lines):
continue
line = lines[i]
bbox = line.get("bbox") or [0, 0, 0, 0]
regions.append(
{
"line_id": i,
"text": line.get("text", ""),
"bbox": bbox,
"type": "wrong",
"label": "",
}
)
return regions
def draw_annotated_image(
src_path: str,
lines: list[dict],
wrong_ids: list[int],
dest_rel_path: str,
) -> str:
img = Image.open(src_path).convert("RGBA")
overlay = Image.new("RGBA", img.size, (255, 255, 255, 0))
draw = ImageDraw.Draw(overlay)
try:
font = ImageFont.truetype("DejaVuSans-Bold.ttf", max(14, img.size[0] // 40))
except OSError:
font = ImageFont.load_default()
for i in wrong_ids:
if i >= len(lines):
continue
bbox = lines[i].get("bbox") or [0, 0, 0, 0]
x1, y1, x2, y2 = bbox
pad = 6
box = [x1 - pad, y1 - pad, x2 + pad, y2 + pad]
draw.rounded_rectangle(box, radius=4, fill=(255, 59, 48, 55), outline=(255, 59, 48, 220), width=3)
draw.text((x1, max(0, y1 - 18)), "×", fill=(255, 59, 48, 255), font=font)
combined = Image.alpha_composite(img, overlay).convert("RGB")
full_path = Path(settings.UPLOAD_DIR) / dest_rel_path
full_path.parent.mkdir(parents=True, exist_ok=True)
combined.save(full_path, format="JPEG", quality=92)
return dest_rel_path
def split_solution_sections(text: str) -> tuple[str | None, str]:
if "## 解题思路" not in text:
return None, text
parts = re.split(r"\n##\s*", text, maxsplit=1)
if len(parts) < 2:
return None, text
approach = parts[0].replace("## 解题思路", "").strip()
rest = "## " + parts[1]
return approach or None, rest.strip()
+48 -14
View File
@@ -1,5 +1,3 @@
import enum
import httpx
from sqlalchemy.orm import Session
@@ -36,7 +34,7 @@ OCR 原文:
{ocr_text}
"""
SOLUTION_PROMPT = """你是一位耐心的{stage}{subject}老师。请为以下题目给出详细解法
SOLUTION_PROMPT = """你是一位耐心的{stage}{subject}老师。请像「作业帮」一样,先讲清楚解题思路,再给出完整解答
【学段要求 — 严禁超纲】
{curriculum}
@@ -44,14 +42,31 @@ SOLUTION_PROMPT = """你是一位耐心的{stage}{subject}老师。请为以下
题目:
{question_text}
请按以下结构输出:
1. 考点分析({stage}范围内)
2. 解题步骤(逐步推导,每步说明依据)
3. 易错点提醒
4. 若必须使用超纲方法才能解,请改用{stage}可理解的方法重新解答,不得输出超纲解法。
严格按以下 Markdown 结构输出:
## 解题思路
(2-5 句话:这题考什么、从哪里入手、关键一步是什么,让学生先懂「怎么想」)
## 详细解答
(分步骤完整推导,每步说明依据)
## 易错点
(指出常见错误及正确做法)
严禁使用超纲方法;若原题超纲,请给出{stage}课内可理解的解法。
"""
OLYMPIAD_SOLUTION_PROMPT = """你是一位{stage}奥数教练。请为以下奥数题给出详细解题思路与完整解答
ERROR_DETECT_PROMPT = """你是{stage}{subject}老师。以下是试卷/作业 OCR 识别结果,每行前有编号
请找出「学生答错的部分」:错误答案、被打叉的作答、明显不正确的计算结果等。
{numbered_lines}
只输出 JSON,不要其他文字:
{{"wrong_line_ids": [行编号整数列表]}}
若整张图就是一道错题,请标注含有错误答案或作答的行;找不到则标注最后作答行。
"""
OLYMPIAD_SOLUTION_PROMPT = """你是一位{stage}奥数教练。请像优秀辅导老师一样,先讲解题思路,再完整解答。
【奥数学段要求 — 严禁超纲】
{curriculum}
@@ -59,11 +74,18 @@ OLYMPIAD_SOLUTION_PROMPT = """你是一位{stage}奥数教练。请为以下奥
题目:
{question_text}
请按以下结构输出:
1. 题型与思路切入点({stage}奥数常见技巧)
2. 详细解答步骤
3. 关键技巧总结(仅限{stage}奥数范围
4. 严禁使用超出上述范围的方法;若题目过难,给出{stage}可接受的培优思路。
严格按以下 Markdown 结构输出:
## 解题思路
(点明题型、突破口、{stage}奥数常用技巧
## 详细解答
(完整步骤)
## 关键技巧
(总结,仅限{stage}奥数范围)
严禁超纲;过难题给出{stage}可接受的培优思路。
"""
@@ -167,3 +189,15 @@ async def generate_solution(
question_text=question_text,
)
return await generate_text(prompt, cfg)
async def detect_wrong_line_ids(
cfg: AIConfig,
subject: str,
ocr_lines: list[dict],
school_level=None,
) -> str:
stage = school_level_label(school_level)
numbered = "\n".join(f"[{i}] {line.get('text', '')}" for i, line in enumerate(ocr_lines))
prompt = ERROR_DETECT_PROMPT.format(stage=stage, subject=subject, numbered_lines=numbered)
return await generate_text(prompt, cfg)
+14
View File
@@ -65,3 +65,17 @@ def run_migrations() -> None:
with engine.begin() as conn:
for clause in alters:
conn.execute(text(f"ALTER TABLE system_settings {clause}"))
if "wrong_questions" in tables:
wq_columns = {col["name"] for col in inspector.get_columns("wrong_questions")}
wq_alters: list[str] = []
if "solution_approach" not in wq_columns:
wq_alters.append("ADD COLUMN solution_approach TEXT")
if "mark_regions_json" not in wq_columns:
wq_alters.append("ADD COLUMN mark_regions_json TEXT")
if "annotated_image_path" not in wq_columns:
wq_alters.append("ADD COLUMN annotated_image_path VARCHAR(512)")
if wq_alters:
with engine.begin() as conn:
for clause in wq_alters:
conn.execute(text(f"ALTER TABLE wrong_questions {clause}"))
+51 -10
View File
@@ -1,5 +1,7 @@
from pathlib import Path
from PIL import Image
from app.core.config import settings
_ocr_engine = None
@@ -14,18 +16,52 @@ def get_ocr_engine():
return _ocr_engine
def run_ocr(image_path: str) -> str:
def _bbox_from_box(box: list) -> list[float]:
xs = [float(p[0]) for p in box]
ys = [float(p[1]) for p in box]
return [min(xs), min(ys), max(xs), max(ys)]
def run_ocr_with_regions(image_path: str) -> dict:
"""Return OCR text plus line-level bounding boxes for annotation."""
engine = get_ocr_engine()
result = engine.ocr(image_path, cls=True)
if not result or not result[0]:
return ""
lines = []
for line in result[0]:
if line and len(line) >= 2:
text = line[1][0]
if text:
lines.append(text)
return "\n".join(lines)
lines: list[dict] = []
if result and result[0]:
for item in result[0]:
if not item or len(item) < 2:
continue
box, rec = item[0], item[1]
text = rec[0] if rec else ""
conf = float(rec[1]) if rec and len(rec) > 1 else 0.0
if not text:
continue
lines.append(
{
"text": text,
"confidence": conf,
"box": box,
"bbox": _bbox_from_box(box),
}
)
width, height = 0, 0
try:
with Image.open(image_path) as img:
width, height = img.size
except OSError:
pass
return {
"text": "\n".join(line["text"] for line in lines),
"lines": lines,
"width": width,
"height": height,
}
def run_ocr(image_path: str) -> str:
return run_ocr_with_regions(image_path)["text"]
def save_upload_file(user_id: str, question_id: str, filename: str, content: bytes) -> str:
@@ -38,3 +74,8 @@ def save_upload_file(user_id: str, question_id: str, filename: str, content: byt
full_path = Path(settings.UPLOAD_DIR) / rel_path
full_path.write_bytes(content)
return rel_path
def annotated_rel_path(original_rel: str) -> str:
p = Path(original_rel)
return str(p.parent / f"{p.stem}_marked.jpg")