acfe002fbf
Co-authored-by: Cursor <cursoragent@cursor.com>
158 lines
4.7 KiB
Python
158 lines
4.7 KiB
Python
import json
|
||
import re
|
||
from pathlib import Path
|
||
|
||
from PIL import Image, ImageDraw, ImageFont
|
||
|
||
from app.core.config import settings
|
||
|
||
|
||
def _parse_llm_json(text: str) -> dict | None:
|
||
text = text.strip()
|
||
match = re.search(r"\{[\s\S]*\}", text)
|
||
if not match:
|
||
return None
|
||
try:
|
||
return json.loads(match.group())
|
||
except json.JSONDecodeError:
|
||
return None
|
||
|
||
|
||
def heuristic_wrong_line_ids(lines: list[dict], *, fallback_count: int = 2) -> list[int]:
    """Guess which OCR lines are wrong answers without LLM help.

    A line is flagged when its text contains a cross / "wrong" marker
    ("×", "✗", "❌", "错") or ends (ignoring trailing whitespace) with
    "×", "x" or "X". If nothing is flagged, the photo is treated as a single
    question and the last ``fallback_count`` lines (the answer area) are
    returned instead.

    Args:
        lines: OCR lines; each may carry a ``"text"`` value. A missing or
            ``None`` text is treated as an empty string.
        fallback_count: Number of trailing lines returned when no marker is
            found. Defaults to 2.

    Returns:
        Ascending indices into ``lines``. Empty when ``lines`` is empty (or
        ``fallback_count`` <= 0 and nothing was flagged).
    """
    wrong: list[int] = []
    for i, line in enumerate(lines):
        # "or" guards against an explicit None text from the OCR layer.
        text = line.get("text") or ""
        if any(mark in text for mark in ("×", "✗", "❌", "错")):
            wrong.append(i)
        elif re.search(r"[×xX]\s*$", text):
            wrong.append(i)
    if wrong:
        return wrong

    # Single-question photo: mark the last few lines, where the answer is written.
    start = max(0, len(lines) - fallback_count)
    return list(range(start, len(lines)))
|
||
|
||
|
||
def _coerce_line_id(value: object) -> int | None:
    """Convert one LLM-supplied line id to ``int``, or ``None`` if it is not a clean integer."""
    # bool is an int subclass; "true" is never a meaningful line index.
    if isinstance(value, bool):
        return None
    if isinstance(value, int):
        return value
    if isinstance(value, float):
        # Accept 2.0 but reject 1.5 / nan / inf rather than truncating silently.
        return int(value) if value.is_integer() else None
    if isinstance(value, str):
        try:
            return int(value.strip())
        except ValueError:
            return None
    return None


def parse_wrong_line_ids(llm_response: str, lines: list[dict]) -> list[int]:
    """Return the indices of wrong lines reported by the LLM, or a heuristic guess.

    The response should contain a JSON object with a ``"wrong_line_ids"``
    list. Entries that are not clean integers or that fall outside
    ``0 <= id < len(lines)`` are dropped, and duplicates are removed (first
    occurrence wins). If no usable id remains, the result of
    ``heuristic_wrong_line_ids(lines)`` is returned.

    Args:
        llm_response: Raw LLM output text.
        lines: OCR lines the ids index into.

    Returns:
        Indices into ``lines``.
    """
    data = _parse_llm_json(llm_response)
    raw_ids = data.get("wrong_line_ids") if data else None
    if isinstance(raw_ids, list):
        candidates = (_coerce_line_id(raw) for raw in raw_ids)
        # dict.fromkeys de-duplicates while keeping the LLM's order.
        ids = list(dict.fromkeys(i for i in candidates if i is not None and 0 <= i < len(lines)))
        if ids:
            return ids
    return heuristic_wrong_line_ids(lines)
|
||
|
||
|
||
def regions_from_lines(lines: list[dict], wrong_ids: list[int]) -> list[dict]:
    """Build one "wrong" region record per valid index in ``wrong_ids``.

    Each record carries the line index, its text (``""`` if absent), its
    ``bbox`` (``[0, 0, 0, 0]`` if absent), ``"type": "wrong"`` and the label
    ``"错"`` ("wrong"). Indices outside ``0 <= i < len(lines)`` - including
    negative ones, which would otherwise index from the end - are skipped.

    Args:
        lines: OCR lines, each a dict that may contain ``"text"`` and ``"bbox"``.
        wrong_ids: Indices into ``lines`` to report.

    Returns:
        Region dicts in the order of ``wrong_ids``.
    """
    return [
        {
            "line_id": i,
            "text": lines[i].get("text", ""),
            "bbox": lines[i].get("bbox") or [0, 0, 0, 0],
            "type": "wrong",
            "label": "错",
        }
        for i in wrong_ids
        if 0 <= i < len(lines)
    ]
|
||
|
||
|
||
def draw_annotated_image(
    src_path: str,
    lines: list[dict],
    wrong_ids: list[int],
    dest_rel_path: str,
) -> str:
    """Highlight the wrong lines of an image and save the result as JPEG.

    Each wrong line gets a translucent red rounded box (padded by 6 px) and a
    red "×" drawn just above its top-left corner. The annotated image is
    written to ``Path(settings.UPLOAD_DIR) / dest_rel_path`` (parent
    directories are created).

    Args:
        src_path: Path of the source image on disk.
        lines: OCR lines; each may carry a ``"bbox"`` unpacked as
            ``x1, y1, x2, y2`` pixel coordinates.
        wrong_ids: Indices into ``lines`` to highlight. Indices outside
            ``0 <= i < len(lines)`` and lines without a bbox are skipped.
        dest_rel_path: Output path relative to ``settings.UPLOAD_DIR``.

    Returns:
        ``dest_rel_path`` unchanged.
    """
    pad = 6  # grow each box a little so the outline does not touch the glyphs
    box_fill = (255, 59, 48, 55)
    box_outline = (255, 59, 48, 220)
    mark_fill = (255, 59, 48, 255)

    # Close the source file as soon as its pixels are converted.
    with Image.open(src_path) as src:
        img = src.convert("RGBA")
    overlay = Image.new("RGBA", img.size, (255, 255, 255, 0))
    draw = ImageDraw.Draw(overlay)

    try:
        font = ImageFont.truetype("DejaVuSans-Bold.ttf", max(14, img.size[0] // 40))
    except OSError:
        font = ImageFont.load_default()

    for i in wrong_ids:
        if not 0 <= i < len(lines):
            continue
        bbox = lines[i].get("bbox")
        if not bbox:
            # No coordinates: a [0, 0, 0, 0] stand-in would only draw a stray
            # marker in the top-left corner.
            continue
        x1, y1, x2, y2 = bbox
        # Normalise so rounded_rectangle never receives an inverted box.
        left, right = sorted((x1, x2))
        top, bottom = sorted((y1, y2))
        box = [left - pad, top - pad, right + pad, bottom + pad]
        draw.rounded_rectangle(box, radius=4, fill=box_fill, outline=box_outline, width=3)
        draw.text((left, max(0, top - 18)), "×", fill=mark_fill, font=font)

    combined = Image.alpha_composite(img, overlay).convert("RGB")
    full_path = Path(settings.UPLOAD_DIR) / dest_rel_path
    full_path.parent.mkdir(parents=True, exist_ok=True)
    combined.save(full_path, format="JPEG", quality=92)
    return dest_rel_path
|
||
|
||
|
||
def split_solution_sections(text: str) -> tuple[str | None, str]:
|
||
if "## 解题思路" not in text:
|
||
return None, text
|
||
parts = re.split(r"\n##\s*", text, maxsplit=1)
|
||
if len(parts) < 2:
|
||
return None, text
|
||
approach = parts[0].replace("## 解题思路", "").strip()
|
||
rest = "## " + parts[1]
|
||
return approach or None, rest.strip()
|
||
|
||
|
||
def union_bbox(bboxes: list[list[float]], img_w: int, img_h: int, padding_ratio: float = 0.06) -> list[int]:
    """Return the padded box enclosing every box in ``bboxes``.

    Each input box is read as ``[x1, y1, x2, y2]`` (only the first four items
    are used). The union is grown on each axis by ``padding_ratio`` times its
    extent, but by at least 8 px. The near edges are then clamped at 0, the far
    edges at ``img_w`` / ``img_h``, and all four are truncated with ``int()``.
    An empty ``bboxes`` yields the whole image ``[0, 0, img_w, img_h]``.
    """
    if not bboxes:
        return [0, 0, img_w, img_h]

    first = bboxes[0]
    left, top, right, bottom = first[0], first[1], first[2], first[3]
    for box in bboxes[1:]:
        left = min(left, box[0])
        top = min(top, box[1])
        right = max(right, box[2])
        bottom = max(bottom, box[3])

    grow_x = max(8, (right - left) * padding_ratio)
    grow_y = max(8, (bottom - top) * padding_ratio)
    edges = (
        max(0, left - grow_x),
        max(0, top - grow_y),
        min(img_w, right + grow_x),
        min(img_h, bottom + grow_y),
    )
    return [int(edge) for edge in edges]
|
||
|
||
|
||
def cropped_rel_path(original_rel: str) -> str:
    """Return the relative path for the cropped copy of an uploaded image.

    The crop sits in the same directory as the original and is named
    ``<original stem>_crop.jpg``; e.g. ``"a/photo.png"`` -> ``"a/photo_crop.jpg"``.
    """
    source = Path(original_rel)
    crop_name = source.stem + "_crop.jpg"
    return str(source.parent.joinpath(crop_name))
|
||
|
||
|
||
def crop_wrong_region(
    src_path: str,
    lines: list[dict],
    wrong_ids: list[int],
    dest_rel_path: str,
    img_width: int,
    img_height: int,
) -> str | None:
    """Crop the image to the area covering all wrong lines and save it as JPEG.

    The crop box is ``union_bbox`` of the wrong lines' bboxes with
    ``padding_ratio=0.12``, clamped using ``img_width`` / ``img_height``. The
    result is written to ``Path(settings.UPLOAD_DIR) / dest_rel_path``
    (parent directories are created).

    Args:
        src_path: Path of the source image on disk.
        lines: OCR lines; each may carry a ``"bbox"`` of ``[x1, y1, x2, y2]``.
        wrong_ids: Indices into ``lines``. Indices outside
            ``0 <= i < len(lines)`` and lines without a bbox are ignored.
        dest_rel_path: Output path relative to ``settings.UPLOAD_DIR``.
        img_width: Width used for clamping - presumably the source image's
            width; it is not checked against the opened file.
        img_height: Height used for clamping (same caveat as ``img_width``).

    Returns:
        ``dest_rel_path`` on success, or ``None`` when no wrong line has a
        bbox or the clamped box is empty.
    """
    bboxes = []
    for i in wrong_ids:
        if not 0 <= i < len(lines):
            continue
        bbox = lines[i].get("bbox")
        # A [0, 0, 0, 0] stand-in for a missing bbox would stretch the union
        # all the way to the image's top-left corner, so skip such lines.
        if bbox:
            bboxes.append(bbox)
    if not bboxes:
        return None

    x1, y1, x2, y2 = union_bbox(bboxes, img_width, img_height, padding_ratio=0.12)
    if x2 <= x1 or y2 <= y1:
        return None

    # Close the source file as soon as its pixels are converted.
    with Image.open(src_path) as src:
        img = src.convert("RGB")
    cropped = img.crop((x1, y1, x2, y2))

    full_path = Path(settings.UPLOAD_DIR) / dest_rel_path
    full_path.parent.mkdir(parents=True, exist_ok=True)
    cropped.save(full_path, format="JPEG", quality=92)
    return dest_rel_path
|