Files
secondary-school-grade-archive/backend/app/services/llm.py
T

241 lines
8.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import httpx
from sqlalchemy.orm import Session
from app.core.config import settings as app_settings
from app.models.user import SchoolLevel, SystemSettings
from app.services.school_level import school_level_label
from app.services.url_sanitize import sanitize_http_url, sanitize_model_name
# Curriculum-scope text blocks. _curriculum_block() picks one of these and
# generate_solution() injects it into the solution prompt as {curriculum}.
# Each block has two lines: the topics that are in scope, then the methods
# the model is told it must not use.

# Junior-high, regular track.
CURRICULUM_JUNIOR = """初中课程标准:代数、几何(全等/相似/勾股)、一次函数与简单二次函数、基础概率统计。
严禁使用:高中导数、向量、解析几何、排列组合进阶、复数、微积分、大学线性代数等。"""
# Senior-high, regular track.
CURRICULUM_SENIOR = """高中课程标准:课内函数、三角、向量、解析几何、概率统计、导数(课内范围)等。
严禁使用:大学数学分析、抽象代数、高等几何、超出课内的竞赛高阶技巧。"""
# Junior-high, olympiad track.
CURRICULUM_JUNIOR_OLYMPIAD = """初中奥数培优范围:整数/整除、因数分解、简单数论、代数恒等变形、几何辅助线与全等相似、简单组合计数。
严禁使用:高中及以上方法(导数、向量、解析几何、微积分、复数运算等)。"""
# Senior-high, olympiad track.
CURRICULUM_SENIOR_OLYMPIAD = """高中奥数/竞赛入门范围:课内知识+常规竞赛技巧(不等式、构造、归纳、简单数论等)。
严禁使用:大学数学、超出高中奥数培优体系的 IMO 高阶理论。"""
def _curriculum_block(level: SchoolLevel | str | None, olympiad: bool) -> str:
    """Return the curriculum-scope text for a school level and track.

    Args:
        level: School level, either a ``SchoolLevel`` member or its raw
            string value. Only ``SchoolLevel.senior_high`` / ``"senior_high"``
            selects the senior-high text; every other value, including
            ``None``, selects the junior-high text.
        olympiad: When true, return the olympiad variant of the scope text.

    Returns:
        One of the four module-level ``CURRICULUM_*`` strings.
    """
    # The previous version also computed a display label here and never used
    # it; the scope text depends only on the two flags below.
    is_senior = level == SchoolLevel.senior_high or level == "senior_high"
    if olympiad:
        return CURRICULUM_SENIOR_OLYMPIAD if is_senior else CURRICULUM_JUNIOR_OLYMPIAD
    return CURRICULUM_SENIOR if is_senior else CURRICULUM_JUNIOR
# Prompt template used by format_question() to clean up noisy OCR text into a
# readable question. Filled with str.format(); placeholders: {stage},
# {subject}, {ocr_text}.
QUESTION_PROMPT = """你是一位{stage}老师。以下是从试卷 OCR 识别出的文字,可能含有噪声。
科目:{subject}
请整理出清晰的题目内容(保留题号、选项、公式),只输出题目正文,不要解释。
OCR 原文:
{ocr_text}
"""
# Prompt template for regular (non-olympiad) worked solutions, used by
# generate_solution(). Filled with str.format(); placeholders: {stage},
# {subject}, {curriculum}, {question_text}. It asks the model for three
# Markdown sections: approach, detailed solution, common mistakes.
SOLUTION_PROMPT = """你是一位耐心的{stage}{subject}老师。请像「作业帮」一样,先讲清楚解题思路,再给出完整解答。
【学段要求 — 严禁超纲】
{curriculum}
题目:
{question_text}
请严格按以下 Markdown 结构输出:
## 解题思路
(2-5 句话:这题考什么、从哪里入手、关键一步是什么,让学生先懂「怎么想」)
## 详细解答
(分步骤完整推导,每步说明依据)
## 易错点
(指出常见错误及正确做法)
严禁使用超纲方法;若原题超纲,请给出{stage}课内可理解的解法。
"""
# Prompt template for olympiad-track worked solutions, used by
# generate_solution() when olympiad=True. Filled with str.format();
# placeholders: {stage}, {curriculum}, {question_text}. Unlike
# SOLUTION_PROMPT it has no {subject} field; str.format() simply ignores the
# extra keyword the caller passes.
OLYMPIAD_SOLUTION_PROMPT = """你是一位{stage}奥数教练。请像优秀辅导老师一样,先讲解题思路,再完整解答。
【奥数学段要求 — 严禁超纲】
{curriculum}
题目:
{question_text}
请严格按以下 Markdown 结构输出:
## 解题思路
(点明题型、突破口、{stage}奥数常用技巧)
## 详细解答
(完整步骤)
## 关键技巧
(总结,仅限{stage}奥数范围)
严禁超纲;过难题给出{stage}可接受的培优思路。
"""
# Prompt template used by detect_wrong_line_ids() to ask the model which
# numbered OCR lines contain the student's wrong answers. Filled with
# str.format(); placeholders: {stage}, {subject}, {numbered_lines}.
# The doubled braces {{ }} are str.format() escapes: they render as the
# literal { } of the JSON example shown to the model.
ERROR_DETECT_PROMPT = """你是{stage}{subject}老师。以下是试卷/作业 OCR 识别结果,每行前有编号。
请找出「学生答错的部分」:错误答案、被打叉的作答、明显不正确的计算结果等。
{numbered_lines}
只输出 JSON,不要其他文字:
{{"wrong_line_ids": [行编号整数列表]}}
若整张图就是一道错题,请标注含有错误答案或作答的行;找不到则标注最后作答行。
"""
# Prompt template used by generate_review_insight() to turn a student's exam
# review history into an interpretation plus actionable advice. Filled with
# str.format(); placeholders: {stage}, {subject}, {review_records}.
# Each "## ..." heading is followed by a parenthesised instruction telling the
# model what to write in that section. The instruction under the last heading
# was missing its opening parenthesis, which left the template unbalanced and
# inconsistent with the other two sections; it is restored here.
REVIEW_INSIGHT_PROMPT = """你是一位{stage}{subject}学习顾问。根据学生历次考试的复盘状态,给出解读与可执行建议。
【学段】{stage},科目:{subject}
历次复盘记录(按时间从新到旧):
{review_records}
状态说明:粗心=审题/计算失误;不会=知识点未掌握;紧张=心态影响发挥;正常发挥=状态良好。
请用 Markdown 输出,结构如下:
## 情况解读
(2-4 句话:从成绩与状态看出什么规律或趋势)
## 改进建议
(3-5 条具体可执行建议,针对出现最多的问题状态,适合{stage}学生)
## 近期重点
(1-2 条本周可落实的小目标)
语气亲切、务实,不要空泛鸡汤;不要超纲推荐学习内容。
"""
class AIConfig:
def __init__(
self,
provider: str,
ollama_base_url: str,
ollama_model: str,
openai_base_url: str,
openai_model: str,
openai_api_key: str | None,
):
self.provider = provider
self.ollama_base_url = ollama_base_url
self.ollama_model = ollama_model
self.openai_base_url = openai_base_url
self.openai_model = openai_model
self.openai_api_key = openai_api_key
def load_ai_config(db: Session) -> AIConfig:
    """Build the effective AI configuration from the stored system settings.

    Looks up the ``SystemSettings`` row with primary key 1. Each stored value
    that is missing or falsy falls back to the matching application default.
    When the row does not exist at all, the provider is ``"ollama"``, every
    URL and model comes from the application defaults, and no OpenAI key is
    set. URLs and model names are always passed through the sanitize helpers
    before being stored on the result.

    Args:
        db: Database session used for the lookup.

    Returns:
        A populated ``AIConfig``.
    """
    stored = db.get(SystemSettings, 1)

    # Step 1: resolve the raw (unsanitized) values.
    if stored is None:
        provider = "ollama"
        ollama_url = app_settings.OLLAMA_BASE_URL
        ollama_model = app_settings.OLLAMA_MODEL
        openai_url = app_settings.OPENAI_BASE_URL
        openai_model = app_settings.OPENAI_MODEL
        api_key = None
    else:
        provider = stored.ai_provider or "ollama"
        ollama_url = stored.ollama_base_url or app_settings.OLLAMA_BASE_URL
        ollama_model = stored.ollama_model or app_settings.OLLAMA_MODEL
        openai_url = stored.openai_base_url or app_settings.OPENAI_BASE_URL
        openai_model = stored.openai_model or app_settings.OPENAI_MODEL
        api_key = stored.openai_api_key

    # Step 2: sanitize and assemble.
    return AIConfig(
        provider=provider,
        ollama_base_url=sanitize_http_url(ollama_url),
        ollama_model=sanitize_model_name(ollama_model),
        openai_base_url=sanitize_http_url(openai_url),
        openai_model=sanitize_model_name(openai_model),
        openai_api_key=api_key,
    )
async def _ollama_generate(prompt: str, cfg: AIConfig) -> str:
    """Send ``prompt`` to the configured Ollama server and return its reply.

    POSTs a non-streaming request to ``<ollama_base_url>/api/generate`` with a
    180-second timeout.

    Args:
        prompt: Full prompt text.
        cfg: Configuration supplying ``ollama_base_url`` and ``ollama_model``.

    Returns:
        The ``"response"`` field of the JSON reply with surrounding whitespace
        stripped, or an empty string when that field is missing or empty.

    Raises:
        httpx.HTTPStatusError: If the server answers with an error status.
    """
    endpoint = cfg.ollama_base_url.rstrip("/") + "/api/generate"
    request_body = {"model": cfg.ollama_model, "prompt": prompt, "stream": False}

    async with httpx.AsyncClient(timeout=180.0) as http:
        reply = await http.post(endpoint, json=request_body)
        reply.raise_for_status()
        generated = reply.json().get("response")

    return (generated or "").strip()
async def _openai_generate(prompt: str, cfg: AIConfig) -> str:
if not cfg.openai_api_key:
raise ValueError("未配置 OpenAI API Key")
url = f"{cfg.openai_base_url.rstrip('/')}/chat/completions"
headers = {"Authorization": f"Bearer {cfg.openai_api_key}"}
payload = {
"model": cfg.openai_model,
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.3,
}
async with httpx.AsyncClient(timeout=180.0) as client:
response = await client.post(url, json=payload, headers=headers)
response.raise_for_status()
data = response.json()
return (data["choices"][0]["message"]["content"] or "").strip()
async def generate_text(prompt: str, cfg: AIConfig) -> str:
    """Generate text for ``prompt`` using the backend selected in ``cfg``.

    A provider of exactly ``"openai"`` uses the OpenAI-style backend; any
    other value uses Ollama.

    Args:
        prompt: Full prompt text.
        cfg: Backend configuration.

    Returns:
        The generated text as returned by the chosen backend.
    """
    backend = _openai_generate if cfg.provider == "openai" else _ollama_generate
    return await backend(prompt, cfg)
async def format_question(
    cfg: AIConfig,
    subject: str,
    ocr_text: str,
    school_level=None,
) -> str:
    """Ask the model to tidy raw OCR text into a clean question statement.

    Args:
        cfg: Backend configuration.
        subject: Subject name inserted into the prompt.
        ocr_text: Raw OCR output for the question.
        school_level: Value passed to ``school_level_label`` to obtain the
            stage wording used in the prompt.

    Returns:
        The model's reply.
    """
    filled_prompt = QUESTION_PROMPT.format(
        stage=school_level_label(school_level),
        subject=subject,
        ocr_text=ocr_text,
    )
    return await generate_text(filled_prompt, cfg)
async def generate_solution(
    cfg: AIConfig,
    subject: str,
    question_text: str,
    school_level=None,
    *,
    olympiad: bool = False,
) -> str:
    """Ask the model for a worked solution restricted to the student's level.

    Args:
        cfg: Backend configuration.
        subject: Subject name inserted into the prompt.
        question_text: The question to solve.
        school_level: Value used both for the stage wording and for choosing
            the curriculum-scope text.
        olympiad: Keyword-only. When true, the olympiad prompt and the
            olympiad curriculum scope are used instead of the regular ones.

    Returns:
        The model's reply.
    """
    if olympiad:
        chosen_template = OLYMPIAD_SOLUTION_PROMPT
    else:
        chosen_template = SOLUTION_PROMPT

    filled_prompt = chosen_template.format(
        stage=school_level_label(school_level),
        subject=subject,
        curriculum=_curriculum_block(school_level, olympiad),
        question_text=question_text,
    )
    return await generate_text(filled_prompt, cfg)
async def detect_wrong_line_ids(
    cfg: AIConfig,
    subject: str,
    ocr_lines: list[dict],
    school_level=None,
) -> str:
    """Ask the model which OCR lines contain the student's wrong answers.

    Each OCR line is rendered as ``[index] text`` using its zero-based
    position in ``ocr_lines``; a line without a ``"text"`` key is rendered
    with empty text.

    Args:
        cfg: Backend configuration.
        subject: Subject name inserted into the prompt.
        ocr_lines: OCR line records; only the ``"text"`` key is read.
        school_level: Value passed to ``school_level_label`` to obtain the
            stage wording used in the prompt.

    Returns:
        The model's raw reply. The prompt requests JSON, but this function
        returns the text as-is without parsing it.
    """
    stage_label = school_level_label(school_level)

    rendered_rows = []
    for index, entry in enumerate(ocr_lines):
        rendered_rows.append(f"[{index}] {entry.get('text', '')}")

    filled_prompt = ERROR_DETECT_PROMPT.format(
        stage=stage_label,
        subject=subject,
        numbered_lines="\n".join(rendered_rows),
    )
    return await generate_text(filled_prompt, cfg)
async def generate_review_insight(
    cfg: AIConfig,
    subject: str,
    review_records: str,
    school_level=None,
) -> str:
    """Ask the model to interpret a student's exam review history.

    Args:
        cfg: Backend configuration.
        subject: Subject name inserted into the prompt.
        review_records: Already-formatted review history text, inserted into
            the prompt verbatim.
        school_level: Value passed to ``school_level_label`` to obtain the
            stage wording used in the prompt.

    Returns:
        The model's reply.
    """
    template_fields = {
        "stage": school_level_label(school_level),
        "subject": subject,
        "review_records": review_records,
    }
    filled_prompt = REVIEW_INSIGHT_PROMPT.format(**template_fields)
    return await generate_text(filled_prompt, cfg)