增加大模型
This commit is contained in:
@@ -0,0 +1,174 @@
|
||||
"""大模型解读(OpenAI 兼容接口 + 图表图片)。"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
import httpx
|
||||
|
||||
from .chart_image import render_daily_chart_png_async
|
||||
from .config import settings
|
||||
from .db import get_llm_interpretation, save_llm_interpretation
|
||||
from .stats import compute_three_day_stats
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_interpret_lock = asyncio.Lock()
|
||||
_interpret_state: dict = {
|
||||
"running": False,
|
||||
"current_symbol": "",
|
||||
"done": 0,
|
||||
"total": 0,
|
||||
"batch_id": "",
|
||||
"last_error": "",
|
||||
}
|
||||
|
||||
|
||||
def get_interpret_state() -> dict:
|
||||
return dict(_interpret_state)
|
||||
|
||||
|
||||
def _api_url() -> str:
|
||||
base = settings.llm_base_url.rstrip("/")
|
||||
if base.endswith("/v1"):
|
||||
return f"{base}/chat/completions"
|
||||
return f"{base}/v1/chat/completions"
|
||||
|
||||
|
||||
def _build_prompt(symbol: str, stats_row: dict | None) -> str:
|
||||
lines = [
|
||||
f"你是加密货币合约分析师。请根据附图({symbol} 近300日K+成交量)及数据给出中文简析。",
|
||||
"关注:趋势、关键支撑阻力、成交量变化、资金费率含义、未来1-3日可能节奏。",
|
||||
"控制在 200-350 字,条理清晰,不要废话。",
|
||||
]
|
||||
if stats_row:
|
||||
t, y, b = stats_row.get("today", {}), stats_row.get("yesterday", {}), stats_row.get("daybefore", {})
|
||||
lines.append(
|
||||
f"\n三日均为成交额Top30交集:"
|
||||
f"\n今日 排名{t.get('rank')} 涨跌{t.get('price_change_pct_fmt')} 额{t.get('quote_volume_fmt')}"
|
||||
f"\n昨日 排名{y.get('rank')} 涨跌{y.get('price_change_pct_fmt')} 额{y.get('quote_volume_fmt')}"
|
||||
f"\n前日 排名{b.get('rank')} 涨跌{b.get('price_change_pct_fmt')} 额{b.get('quote_volume_fmt')}"
|
||||
f"\n资金费率(当前):{t.get('funding_rate_fmt', '—')}"
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
async def interpret_symbol(
|
||||
symbol: str,
|
||||
stats_row: dict | None = None,
|
||||
batch_id: str | None = None,
|
||||
) -> str:
|
||||
if not settings.llm_api_key.strip():
|
||||
raise RuntimeError("LLM_API_KEY 未配置")
|
||||
|
||||
png = await render_daily_chart_png_async(symbol, settings.chart_kline_limit)
|
||||
b64 = base64.standard_b64encode(png).decode("ascii")
|
||||
prompt = _build_prompt(symbol, stats_row)
|
||||
|
||||
payload = {
|
||||
"model": settings.llm_model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": prompt},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/png;base64,{b64}"},
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
"max_tokens": 800,
|
||||
"temperature": 0.4,
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {settings.llm_api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
resp = await client.post(_api_url(), json=payload, headers=headers)
|
||||
if resp.status_code >= 400:
|
||||
# 部分模型不支持 vision,降级纯文本
|
||||
logger.warning("LLM vision failed %s, fallback text", resp.status_code)
|
||||
payload["messages"] = [{"role": "user", "content": prompt + "\n(附图日K+成交量未能传入,请基于数据简析)"}]
|
||||
resp = await client.post(_api_url(), json=payload, headers=headers)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
content = data["choices"][0]["message"]["content"]
|
||||
bid = batch_id or datetime.now().strftime("%Y-%m-%d-%H%M")
|
||||
save_llm_interpretation(symbol, bid, content)
|
||||
return content
|
||||
|
||||
|
||||
async def run_interpretation_batch(
|
||||
symbols: list[str] | None = None,
|
||||
*,
|
||||
batch_id: str | None = None,
|
||||
) -> dict:
|
||||
global _interpret_state
|
||||
|
||||
if _interpret_lock.locked():
|
||||
return {"ok": False, "message": "解读任务进行中"}
|
||||
|
||||
stats = compute_three_day_stats()
|
||||
if not stats.get("ok"):
|
||||
return {"ok": False, "message": stats.get("message", "统计数据未就绪")}
|
||||
|
||||
sym_list = symbols or stats.get("symbols") or [x["symbol"] for x in stats.get("items", [])]
|
||||
if not sym_list:
|
||||
return {"ok": False, "message": "三日交集为空"}
|
||||
|
||||
stats_map = {x["symbol"]: x for x in stats.get("items", [])}
|
||||
bid = batch_id or datetime.now().strftime("%Y-%m-%d-%H%M")
|
||||
interval = settings.llm_symbol_interval_sec
|
||||
|
||||
async with _interpret_lock:
|
||||
_interpret_state.update(
|
||||
{
|
||||
"running": True,
|
||||
"current_symbol": "",
|
||||
"done": 0,
|
||||
"total": len(sym_list),
|
||||
"batch_id": bid,
|
||||
"last_error": "",
|
||||
}
|
||||
)
|
||||
for i, sym in enumerate(sym_list):
|
||||
_interpret_state["current_symbol"] = sym
|
||||
try:
|
||||
await interpret_symbol(sym, stats_map.get(sym), bid)
|
||||
logger.info("LLM interpreted %s (%d/%d)", sym, i + 1, len(sym_list))
|
||||
except Exception as e:
|
||||
_interpret_state["last_error"] = str(e)
|
||||
logger.error("LLM %s failed: %s", sym, e)
|
||||
save_llm_interpretation(sym, bid, f"[解读失败] {e}")
|
||||
_interpret_state["done"] = i + 1
|
||||
if i < len(sym_list) - 1:
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
_interpret_state["running"] = False
|
||||
_interpret_state["current_symbol"] = ""
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"batch_id": bid,
|
||||
"count": len(sym_list),
|
||||
"interval_sec": interval,
|
||||
}
|
||||
|
||||
|
||||
def schedule_interpret_background(symbols: list[str] | None = None) -> None:
|
||||
"""后台启动解读,不阻塞请求。"""
|
||||
|
||||
async def _run():
|
||||
try:
|
||||
await run_interpretation_batch(symbols)
|
||||
except Exception as e:
|
||||
logger.error("Background LLM batch failed: %s", e)
|
||||
|
||||
asyncio.create_task(_run())
|
||||
Reference in New Issue
Block a user