175 lines
5.8 KiB
Python
175 lines
5.8 KiB
Python
"""大模型解读(OpenAI 兼容接口 + 图表图片)。"""
|
||
|
||
import asyncio
|
||
import base64
|
||
import logging
|
||
from datetime import datetime
|
||
|
||
import httpx
|
||
|
||
from .chart_image import render_daily_chart_png_async
|
||
from .config import settings
|
||
from .db import get_llm_interpretation, save_llm_interpretation
|
||
from .stats import compute_three_day_stats
|
||
|
||
logger = logging.getLogger(__name__)

# Guards batch execution: run_interpretation_batch() refuses to start while it
# is held, so at most one interpretation batch runs at a time.
_interpret_lock = asyncio.Lock()

# Progress of the current (or most recent) batch. Updated in place by
# run_interpretation_batch(); exposed as a copy via get_interpret_state().
_interpret_state: dict = dict(
    running=False,
    current_symbol="",
    done=0,
    total=0,
    batch_id="",
    last_error="",
)
|
||
|
||
|
||
def get_interpret_state() -> dict:
    """Return a snapshot of the interpretation batch progress.

    The result is a shallow copy, so callers that modify the returned dict do
    not affect the live module-level state.
    """
    snapshot = {**_interpret_state}
    return snapshot
|
||
|
||
|
||
def _api_url() -> str:
    """Return the chat-completions endpoint derived from ``settings.llm_base_url``.

    Trailing slashes are stripped, and ``/v1`` is inserted only when the base
    URL does not already end with it.
    """
    root = settings.llm_base_url.rstrip("/")
    versioned = root if root.endswith("/v1") else f"{root}/v1"
    return f"{versioned}/chat/completions"
|
||
|
||
|
||
def _build_prompt(symbol: str, stats_row: dict | None) -> str:
|
||
lines = [
|
||
f"你是加密货币合约分析师。请根据附图({symbol} 近300日K+成交量)及数据给出中文简析。",
|
||
"关注:趋势、关键支撑阻力、成交量变化、资金费率含义、未来1-3日可能节奏。",
|
||
"控制在 200-350 字,条理清晰,不要废话。",
|
||
]
|
||
if stats_row:
|
||
t, y, b = stats_row.get("today", {}), stats_row.get("yesterday", {}), stats_row.get("daybefore", {})
|
||
lines.append(
|
||
f"\n三日均为成交额Top30交集:"
|
||
f"\n今日 排名{t.get('rank')} 涨跌{t.get('price_change_pct_fmt')} 额{t.get('quote_volume_fmt')}"
|
||
f"\n昨日 排名{y.get('rank')} 涨跌{y.get('price_change_pct_fmt')} 额{y.get('quote_volume_fmt')}"
|
||
f"\n前日 排名{b.get('rank')} 涨跌{b.get('price_change_pct_fmt')} 额{b.get('quote_volume_fmt')}"
|
||
f"\n资金费率(当前):{t.get('funding_rate_fmt', '—')}"
|
||
)
|
||
return "\n".join(lines)
|
||
|
||
|
||
async def interpret_symbol(
    symbol: str,
    stats_row: dict | None = None,
    batch_id: str | None = None,
) -> str:
    """Ask the configured LLM to interpret *symbol*'s daily chart, then save the reply.

    The daily K-line + volume chart is rendered to PNG and sent inline as a
    base64 data URL, together with a text prompt, to an OpenAI-compatible
    chat-completions endpoint. If that request returns an HTTP status >= 400,
    it is retried once as a text-only request.

    Args:
        symbol: Trading symbol to interpret.
        stats_row: Optional three-day statistics row forwarded to the prompt.
        batch_id: Batch identifier the result is stored under. Defaults to the
            current local time formatted as ``%Y-%m-%d-%H%M``.

    Returns:
        The model's text reply. The same text is also passed to
        ``save_llm_interpretation``.

    Raises:
        RuntimeError: if no API key is configured, or the reply has no text content.
        httpx.HTTPStatusError: if the final (possibly fallback) request fails.
            Transport errors from httpx (e.g. timeouts) propagate unchanged.
    """
    if not settings.llm_api_key.strip():
        raise RuntimeError("LLM_API_KEY 未配置")

    png = await render_daily_chart_png_async(symbol, settings.chart_kline_limit)
    b64 = base64.standard_b64encode(png).decode("ascii")
    prompt = _build_prompt(symbol, stats_row)

    payload = {
        "model": settings.llm_model,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{b64}"},
                    },
                ],
            }
        ],
        "max_tokens": 800,
        "temperature": 0.4,
    }
    headers = {
        "Authorization": f"Bearer {settings.llm_api_key}",
        "Content-Type": "application/json",
    }
    url = _api_url()

    async with httpx.AsyncClient(timeout=120.0) as client:
        resp = await client.post(url, json=payload, headers=headers)
        if resp.status_code >= 400:
            # Some models do not accept image input: retry once as text only.
            # A slice of the error body is logged because it usually says why.
            logger.warning(
                "LLM vision failed %s, fallback text: %s",
                resp.status_code,
                resp.text[:300],
            )
            payload["messages"] = [
                {"role": "user", "content": prompt + "\n(附图日K+成交量未能传入,请基于数据简析)"}
            ]
            resp = await client.post(url, json=payload, headers=headers)
        resp.raise_for_status()
        data = resp.json()

    # The chat-completions schema allows `message.content` to be null (e.g. a
    # refusal). Never persist or return a non-string/empty reply as a result.
    content = data["choices"][0]["message"].get("content")
    if not isinstance(content, str) or not content.strip():
        raise RuntimeError(f"LLM 返回内容为空: {symbol}")

    bid = batch_id or datetime.now().strftime("%Y-%m-%d-%H%M")
    save_llm_interpretation(symbol, bid, content)
    return content
|
||
|
||
|
||
async def run_interpretation_batch(
    symbols: list[str] | None = None,
    *,
    batch_id: str | None = None,
) -> dict:
    """Interpret a batch of symbols sequentially, one LLM call per symbol.

    Only one batch runs at a time. If ``_interpret_lock`` is already held, the
    call returns immediately with ``ok=False``. Progress is published in
    ``_interpret_state`` (readable via ``get_interpret_state``). An exception
    from ``interpret_symbol`` does not abort the batch: it is recorded in
    ``last_error`` and saved as that symbol's interpretation. Symbols are spaced
    ``settings.llm_symbol_interval_sec`` seconds apart.

    Args:
        symbols: Symbols to interpret. When falsy, the symbols reported by
            ``compute_three_day_stats()`` are used.
        batch_id: Identifier the results are stored under. Defaults to the
            current local time formatted as ``%Y-%m-%d-%H%M``.

    Returns:
        ``{"ok": False, "message": ...}`` when the batch cannot start. Otherwise
        ``{"ok": True, "batch_id", "count", "interval_sec"}``.
    """
    if _interpret_lock.locked():
        return {"ok": False, "message": "解读任务进行中"}

    # NOTE(review): called without await, so this runs on the event-loop thread.
    stats = compute_three_day_stats()
    if not stats.get("ok"):
        return {"ok": False, "message": stats.get("message", "统计数据未就绪")}

    items = stats.get("items", [])
    sym_list = symbols or stats.get("symbols") or [x["symbol"] for x in items]
    if not sym_list:
        return {"ok": False, "message": "三日交集为空"}

    stats_map = {x["symbol"]: x for x in items}
    bid = batch_id or datetime.now().strftime("%Y-%m-%d-%H%M")
    interval = settings.llm_symbol_interval_sec
    total = len(sym_list)

    async with _interpret_lock:
        _interpret_state.update(
            {
                "running": True,
                "current_symbol": "",
                "done": 0,
                "total": total,
                "batch_id": bid,
                "last_error": "",
            }
        )
        try:
            for i, sym in enumerate(sym_list):
                _interpret_state["current_symbol"] = sym
                try:
                    await interpret_symbol(sym, stats_map.get(sym), bid)
                    logger.info("LLM interpreted %s (%d/%d)", sym, i + 1, total)
                except Exception as e:
                    _interpret_state["last_error"] = str(e)
                    logger.error("LLM %s failed: %s", sym, e)
                    save_llm_interpretation(sym, bid, f"[解读失败] {e}")
                _interpret_state["done"] = i + 1
                # Throttle between symbols, but not after the last one.
                if i < total - 1:
                    await asyncio.sleep(interval)
        finally:
            # Always clear the running flag -- also on task cancellation or an
            # error escaping the loop (e.g. the failure record cannot be
            # saved) -- so the state never reports a stuck batch.
            _interpret_state["running"] = False
            _interpret_state["current_symbol"] = ""

    return {
        "ok": True,
        "batch_id": bid,
        "count": total,
        "interval_sec": interval,
    }
|
||
|
||
|
||
# Strong references to in-flight background batch tasks. The event loop only
# keeps weak references to tasks, so an unreferenced task can be
# garbage-collected before it finishes.
_background_tasks: set[asyncio.Task] = set()


def schedule_interpret_background(symbols: list[str] | None = None) -> None:
    """Start an interpretation batch in the background without blocking the caller.

    Must be called while an asyncio event loop is running, because it uses
    ``asyncio.create_task``. Exceptions raised by the batch, and batches that
    decline to start (``ok=False``), are logged rather than raised.

    Args:
        symbols: Passed through to ``run_interpretation_batch``.
    """

    async def _run() -> None:
        try:
            result = await run_interpretation_batch(symbols)
        except Exception as e:
            logger.exception("Background LLM batch failed: %s", e)
            return
        if not result.get("ok"):
            # e.g. a batch is already running or the stats are not ready.
            logger.warning("Background LLM batch not started: %s", result.get("message"))

    task = asyncio.create_task(_run())
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
|