6169fee7b9
Co-authored-by: Cursor <cursoragent@cursor.com>
436 lines
15 KiB
Python
436 lines
15 KiB
Python
"""大模型调用:OpenAI 兼容接口(默认)或本机 Ollama 二选一。
|
||
|
||
配置从 os.environ 惰性读取:各实例 app.py 在 import 本模块后才 load_env_file(.env),
|
||
若在 import 时缓存变量会导致 OPENAI_API_KEY 始终为空。
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import base64
|
||
import os
|
||
from typing import List, Optional, Sequence, Tuple
|
||
|
||
import requests
|
||
|
||
|
||
def _env_str(name: str, default: str = "") -> str:
    """Return environment variable *name* with surrounding whitespace removed.

    The environment is read on every call, so values placed into
    ``os.environ`` after this module was imported are still picked up.

    Args:
        name: Environment variable to look up.
        default: Value returned unchanged (not stripped) when the variable
            is unset.

    Returns:
        The stripped value, or *default* when the variable does not exist.
    """
    raw = os.getenv(name)
    return default if raw is None else str(raw).strip()
|
||
|
||
|
||
def _ai_timeout_seconds(*, image_count: int = 0, chat: bool = False) -> int:
    """Choose the HTTP timeout, in seconds, for one AI request.

    The setting consulted depends on the kind of request:

    * chat requests use ``CHAT_AI_TIMEOUT_SECONDS`` (default 300, minimum 30);
    * requests carrying images use ``AI_REVIEW_TIMEOUT_SECONDS``
      (default 300, minimum 30);
    * everything else uses ``AI_TIMEOUT_SECONDS`` (default 120, minimum 10).

    A value that cannot be parsed as an integer yields the default.
    """
    if chat:
        env_name, fallback, floor = "CHAT_AI_TIMEOUT_SECONDS", 300, 30
    elif image_count > 0:
        env_name, fallback, floor = "AI_REVIEW_TIMEOUT_SECONDS", 300, 30
    else:
        env_name, fallback, floor = "AI_TIMEOUT_SECONDS", 120, 10

    # An empty (whitespace-only) setting is treated like a missing one.
    raw = _env_str(env_name, str(fallback)) or str(fallback)
    try:
        return max(floor, int(raw))
    except ValueError:
        return fallback
|
||
|
||
|
||
def _ai_provider() -> str:
    """Return the lower-cased ``AI_PROVIDER`` setting, ``"openai"`` if unset or empty."""
    configured = _env_str("AI_PROVIDER", "openai")
    if not configured:
        configured = "openai"
    return configured.lower()
|
||
|
||
|
||
def _openai_api_base() -> str:
    """Return ``OPENAI_API_BASE`` (or the built-in default) without trailing slashes."""
    fallback = "https://op.bz121.com/v1"
    configured = _env_str("OPENAI_API_BASE", fallback)
    chosen = configured if configured else fallback
    return chosen.rstrip("/")
|
||
|
||
|
||
def _openai_api_key() -> str:
    """Return the API key from ``OPENAI_API_KEY``, else ``AI_API_KEY``, else ``""``."""
    primary = _env_str("OPENAI_API_KEY")
    if primary:
        return primary
    return _env_str("AI_API_KEY")
|
||
|
||
|
||
def _openai_model() -> str:
    """Return the ``OPENAI_MODEL`` setting, falling back to ``"gemma4:e4b"``."""
    fallback = "gemma4:e4b"
    configured = _env_str("OPENAI_MODEL", fallback)
    return configured if configured else fallback
|
||
|
||
|
||
def _ollama_api() -> str:
    """Return the ``OLLAMA_API`` endpoint URL, falling back to the local default."""
    fallback = "http://127.0.0.1:11434/api/generate"
    configured = _env_str("OLLAMA_API", fallback)
    return configured if configured else fallback
|
||
|
||
|
||
def _ollama_model() -> str:
    """Return the Ollama model name from ``AI_MODEL``, or the built-in default."""
    fallback = "huihui_ai/deepseek-r1-abliterated:latest"
    configured = _env_str("AI_MODEL", fallback)
    return configured if configured else fallback
|
||
|
||
|
||
def _use_openai() -> bool:
    """Tell whether the configured provider goes through the OpenAI-compatible API."""
    openai_like = ("openai", "openai_compatible", "gateway")
    provider = _ai_provider()
    return provider in openai_like
|
||
|
||
|
||
def _image_mime_for_path(path: str) -> str:
    """Map a file path to an image MIME type using its extension.

    The extension is compared case-insensitively. Unknown or missing
    extensions (including an empty or ``None`` path) map to ``image/jpeg``.
    """
    mime_by_suffix = {
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".webp": "image/webp",
        ".gif": "image/gif",
    }
    suffix = os.path.splitext(str(path or ""))[1].lower()
    return mime_by_suffix.get(suffix, "image/jpeg")
|
||
|
||
|
||
def _read_image_base64(image_path: str) -> Optional[tuple]:
    """Load an image file as a ``(base64_text, mime_type)`` pair.

    Returns ``None`` when the file cannot be read or encoded.
    """
    try:
        with open(image_path, "rb") as handle:
            raw = handle.read()
        encoded = base64.b64encode(raw).decode("utf-8")
        mime = _image_mime_for_path(image_path)
    except Exception:
        # Deliberately best-effort: an unreadable image is skipped by the
        # caller rather than failing the whole request.
        return None
    return encoded, mime
|
||
|
||
|
||
def _collect_images(
    image_paths: Optional[Sequence[str]] = None,
    images_b64: Optional[Sequence[str]] = None,
) -> List[tuple]:
    """Gather all images as ``(base64_text, mime_type)`` pairs.

    Images read from *image_paths* come first; paths that cannot be read
    are skipped. Entries of *images_b64* follow and are labelled
    ``image/jpeg``; falsy entries are dropped.
    """
    loaded = (_read_image_base64(path) for path in image_paths or [])
    collected: List[tuple] = [pair for pair in loaded if pair]
    collected.extend(
        (str(encoded), "image/jpeg") for encoded in images_b64 or [] if encoded
    )
    return collected
|
||
|
||
|
||
def _openai_chat_url() -> str:
    """Return the chat-completions URL derived from the configured API base.

    A base that already ends in ``/chat/completions`` is used unchanged.
    """
    suffix = "/chat/completions"
    root = _openai_api_base()
    return root if root.endswith(suffix) else root + suffix
|
||
|
||
|
||
def _openai_message_text(msg: dict) -> str:
    """Extract the text of a chat-completion message dict.

    ``content`` may be a plain string or a list of parts; for a list, the
    ``text`` of every ``{"type": "text"}`` part is concatenated. When the
    resulting text is empty, ``reasoning_content`` is used instead.

    Returns:
        The stripped text, or ``""`` when neither field holds any.
    """
    body = msg.get("content")
    if isinstance(body, list):
        body = "".join(
            str(piece.get("text") or "")
            for piece in body
            if isinstance(piece, dict) and piece.get("type") == "text"
        )
    answer = str(body or "").strip()
    if answer:
        return answer
    return str(msg.get("reasoning_content") or "").strip()
|
||
|
||
|
||
def _apply_max_tokens(body: dict, max_tokens: int | None) -> None:
    """Write the output-token limit into request *body*, in place.

    Nothing is written when *max_tokens* is ``None`` or not positive.
    Otherwise the limit is stored under both ``max_tokens`` and
    ``max_completion_tokens``.
    """
    if max_tokens is None or not max_tokens > 0:
        return
    limit = int(max_tokens)
    body.update(max_tokens=limit, max_completion_tokens=limit)
|
||
|
||
|
||
def _openai_chat_completion(
    messages: list[dict],
    *,
    temperature: float,
    max_tokens: int | None = None,
    image_count: int = 0,
    chat: bool = False,
) -> Tuple[str, str]:
    """Send one non-streaming chat-completion request.

    Args:
        messages: Chat messages, sent as the ``messages`` field.
        temperature: Sampling temperature, sent as-is.
        max_tokens: Optional output-token limit.
        image_count: Number of attached images; only used to pick the timeout.
        chat: Whether this is a chat request; only used to pick the timeout.

    Returns:
        ``(text, finish_reason)``. When no API key is configured, or the
        response has no choices or no text, ``text`` is a message starting
        with ``AI 调用失败`` / ``AI 生成失败`` instead of model output.

    Raises:
        requests.HTTPError: The server answered with an error status.
    """
    token = _openai_api_key()
    if not token:
        return "AI 调用失败:未配置 OPENAI_API_KEY(请在当前实例目录 .env 中设置,修改后需重启服务)", "error"

    payload: dict = {
        "model": _openai_model(),
        "messages": messages,
        "temperature": temperature,
        "stream": False,
    }
    _apply_max_tokens(payload, max_tokens)

    response = requests.post(
        _openai_chat_url(),
        headers={
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        },
        json=payload,
        timeout=_ai_timeout_seconds(image_count=image_count, chat=chat),
    )
    response.raise_for_status()

    candidates = response.json().get("choices") or []
    if not candidates:
        return "AI 生成失败:响应无 choices", "error"

    # Only the first choice is used.
    first = candidates[0] or {}
    answer = _openai_message_text(first.get("message") or {})
    finish = first.get("finish_reason")
    if answer:
        return answer, str(finish or "")
    return "AI 生成失败:空内容", finish or "error"
|
||
|
||
|
||
def _generate_openai(
    prompt: str,
    images: List[tuple],
    temperature: float,
    *,
    max_tokens: int | None = None,
) -> str:
    """Run a single-turn generation through the OpenAI-compatible API.

    With images, the user message is a multi-part list holding the prompt
    text followed by one ``image_url`` data-URL part per
    ``(base64_text, mime_type)`` pair; without images it is the plain
    prompt string.

    Returns:
        The text returned by ``_openai_chat_completion`` (the finish
        reason is discarded).
    """
    if images:
        user_content = [{"type": "text", "text": prompt}] + [
            {
                "type": "image_url",
                "image_url": {"url": f"data:{mime};base64,{b64}"},
            }
            for b64, mime in images
        ]
    else:
        user_content = prompt

    reply, _finish = _openai_chat_completion(
        [{"role": "user", "content": user_content}],
        temperature=temperature,
        max_tokens=max_tokens,
        image_count=len(images),
    )
    return reply
|
||
|
||
|
||
def _generate_ollama(
    prompt: str,
    images: List[tuple],
    temperature: float,
    *,
    max_tokens: int | None = None,
    chat: bool = False,
) -> Tuple[str, str]:
    """Send one non-streaming generation request to the Ollama endpoint.

    Args:
        prompt: Full prompt text.
        images: ``(base64_text, mime_type)`` pairs; only the base64 text
            is sent.
        temperature: Sampling temperature, passed in ``options``.
        max_tokens: Optional output limit, sent as ``num_predict`` when
            positive.
        chat: Whether this is a chat request; only used to pick the timeout.

    Returns:
        ``(text, done_reason)``. ``text`` is ``AI 生成失败`` when the
        response carries no text.

    Raises:
        requests.HTTPError: The server answered with an error status.
    """
    sampling: dict = {"temperature": temperature}
    if max_tokens is not None and max_tokens > 0:
        sampling["num_predict"] = int(max_tokens)

    request_body = {
        "model": _ollama_model(),
        "prompt": prompt,
        "stream": False,
        "options": sampling,
    }
    if images:
        request_body["images"] = [encoded for encoded, _ in images]

    response = requests.post(
        _ollama_api(),
        json=request_body,
        timeout=_ai_timeout_seconds(image_count=len(images), chat=chat),
    )
    response.raise_for_status()

    result = response.json()
    answer = (result.get("response") or "").strip()
    if not answer:
        answer = "AI 生成失败"
    return answer, str(result.get("done_reason") or "")
|
||
|
||
|
||
def ai_generate(
    prompt: str,
    *,
    image_paths: Optional[Sequence[str]] = None,
    images_b64: Optional[Sequence[str]] = None,
    temperature: float = 0.2,
    max_tokens: int | None = None,
) -> str:
    """Generate text with the configured provider; never raises on request errors.

    Args:
        prompt: Prompt text sent as a single user turn.
        image_paths: Optional image files to attach (unreadable ones are
            skipped).
        images_b64: Optional already-encoded base64 images to attach.
        temperature: Sampling temperature.
        max_tokens: Optional output-token limit.

    Returns:
        The generated text. When the request raises, a message starting
        with ``AI 调用失败`` that names the provider (and the HTTP status
        for HTTP errors). The backends may also return their own
        ``AI 调用失败`` / ``AI 生成失败`` messages, which are passed through.
    """
    images = _collect_images(image_paths, images_b64)
    try:
        if _use_openai():
            return _generate_openai(prompt, images, temperature, max_tokens=max_tokens)
        text, _reason = _generate_ollama(prompt, images, temperature, max_tokens=max_tokens)
        return text
    except requests.HTTPError as e:
        response = e.response
        status = "?"
        detail = ""
        # Compare with None explicitly: a requests.Response is falsy for
        # 4xx/5xx statuses, so a truthiness test would always hide the code.
        if response is not None:
            status = response.status_code
            try:
                detail = (response.text or "")[:500]
            except Exception:
                # Body could not be decoded; fall back to str(e) below.
                detail = ""
        prov = "OpenAI" if _use_openai() else "Ollama"
        return f"AI 调用失败({prov} HTTP {status}):{detail or str(e)}"
    except Exception as e:
        prov = "OpenAI" if _use_openai() else "Ollama"
        return f"AI 调用失败({prov}):{str(e)}"
|
||
|
||
|
||
# Follow-up user message asking the model to resume a reply that was cut off.
_CHAT_CONTINUE_USER = (
    "你上一条回复在中途截断了。请从断点处继续写完,不要重复已写内容,"
    "保持同一语气,写完给出完整结尾。"
)
# Final characters that count as a finished sentence or a closed quote/bracket.
_CHAT_END_CHARS = "。!?.!?\"」』))>】"


def _looks_truncated(text: str) -> bool:
    """Guess from its last characters whether *text* was cut off mid-reply.

    Rules, applied to the right-stripped text in this order:

    1. fewer than 48 characters: not truncated;
    2. ends with an ellipsis (``…`` or ``...``): truncated;
    3. ends with a character in ``_CHAT_END_CHARS``: not truncated;
    4. otherwise truncated, unless the last character is one of the
       separator characters listed in the final check.
    """
    t = (text or "").rstrip()
    if len(t) < 48:
        return False
    # The ellipsis test must run before the end-character test: "." is in
    # _CHAT_END_CHARS, so checking that first made the "..." case unreachable.
    if t.endswith(("…", "...")):
        return True
    if t[-1] in _CHAT_END_CHARS:
        return False
    # NOTE(review): text ending in a comma/colon-like separator is reported as
    # NOT truncated here; kept as in the original — confirm this is intended.
    return t[-1] not in ",、,;;::\n"
|
||
|
||
|
||
def _should_continue(reason: str, chunk: str) -> bool:
    """Decide whether another continuation request should be made.

    True when the backend reported the finish reason ``"length"``, or when
    *chunk* looks truncated according to ``_looks_truncated``.
    """
    return reason == "length" or _looks_truncated(chunk)
|
||
|
||
|
||
def ai_generate_chat(
    *,
    system: str,
    user: str,
    temperature: float = 0.5,
    images_b64: Optional[Sequence[str]] = None,
    max_tokens: int = 8192,
    max_continuations: int = 3,
) -> str:
    """Chat generation with separate system/user text and automatic continuation.

    When a reply hits the token limit or looks truncated, the model is asked
    to continue, and the pieces are concatenated.

    Args:
        system: System instructions (stripped before use).
        user: User message (stripped before use).
        temperature: Sampling temperature.
        images_b64: Optional base64-encoded images attached to the user turn.
        max_tokens: Output-token limit for each individual request.
        max_continuations: Maximum number of follow-up "continue" requests
            after the first one.

    Returns:
        The combined reply. If the very first request fails, the backend's
        failure message; if a later continuation fails, the text collected
        so far. When a request raises, a message starting with
        ``AI 调用失败``.
    """
    images = _collect_images(None, images_b64)
    try:
        # One initial request plus up to ``max_continuations`` follow-ups.
        attempts = max(1, int(max_continuations) + 1)

        if _use_openai():
            messages: list[dict] = [
                {"role": "system", "content": system.strip()},
            ]
            if images:
                content: List[dict] = [{"type": "text", "text": user.strip()}]
                for b64, mime in images:
                    content.append(
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:{mime};base64,{b64}"},
                        }
                    )
                messages.append({"role": "user", "content": content})
            else:
                messages.append({"role": "user", "content": user.strip()})

            parts: list[str] = []
            for _ in range(attempts):
                chunk, reason = _openai_chat_completion(
                    messages,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    image_count=len(images),
                    chat=True,
                )
                if chunk.startswith(("AI 调用失败", "AI 生成失败")):
                    # Keep what was already generated if a continuation fails.
                    return chunk if not parts else "".join(parts)
                parts.append(chunk)
                if not _should_continue(reason, chunk):
                    break
                # Replay the partial answer and ask the model to carry on.
                messages.append({"role": "assistant", "content": chunk})
                messages.append({"role": "user", "content": _CHAT_CONTINUE_USER})
            return "".join(parts).strip() or "AI 生成失败:空内容"

        # Ollama branch: system and user text are merged into one prompt.
        prompt = f"{system.strip()}\n\n---\n\n{user.strip()}"
        parts = []
        current_prompt = prompt
        for _ in range(attempts):
            chunk, reason = _generate_ollama(
                current_prompt,
                # Images are only sent with the first request.
                images if not parts else [],
                temperature,
                max_tokens=max_tokens,
                chat=True,
            )
            if chunk.startswith("AI 生成失败"):
                if not parts:
                    return chunk
                # A failed continuation must not be appended to the reply;
                # stop and return the text collected so far.
                break
            parts.append(chunk)
            if not _should_continue(reason, chunk):
                break
            tail = chunk[-400:] if len(chunk) > 400 else chunk
            current_prompt = (
                f"{prompt}\n\n{''.join(parts)}\n\n"
                f"{_CHAT_CONTINUE_USER}\n\n"
                f"(已写结尾片段供衔接:…{tail})"
            )
        return "".join(parts).strip() or "AI 生成失败"
    except requests.HTTPError as e:
        response = e.response
        status = "?"
        detail = ""
        # Compare with None explicitly: a requests.Response is falsy for
        # 4xx/5xx statuses, so a truthiness test would always hide the code.
        if response is not None:
            status = response.status_code
            try:
                detail = (response.text or "")[:500]
            except Exception:
                # Body could not be decoded; fall back to str(e) below.
                detail = ""
        prov = "OpenAI" if _use_openai() else "Ollama"
        return f"AI 调用失败({prov} HTTP {status}):{detail or str(e)}"
    except Exception as e:
        prov = "OpenAI" if _use_openai() else "Ollama"
        return f"AI 调用失败({prov}):{str(e)}"
|
||
|
||
|
||
def ai_review(trades_text: str, period_title: str, image_paths=None) -> str:
    """Ask the model for a structured trading review of *trades_text*.

    Args:
        trades_text: The trade records, inserted verbatim into the prompt.
        period_title: Title of the period; the report is labelled weekly
            when it contains ``周`` and daily otherwise.
        image_paths: Optional chart/screenshot paths passed on to
            ``ai_generate``.

    Returns:
        A system note stating how many images were passed in, followed by
        whatever ``ai_generate`` returns.
    """
    image_total = len(image_paths or [])
    if "周" in str(period_title):
        period_label = "周"
    else:
        period_label = "日"

    if image_total:
        banner = f"ℹ️ 【系统说明:已向模型附带 {image_total} 张复盘附图(自动K线或上传截图),请结合附图分析第5节。】\n\n"
    else:
        banner = "ℹ️ 【系统说明:本次未附带复盘附图,第5节请写明「无附图,无法看图」;保存复盘记录时可勾选「自动生成K线图」。】\n\n"

    prompt = f"""
你是一位专业交易教练。下面是用户的{period_title}交易记录,请做简洁、可执行的复盘(中文)。

【硬性规则 — 必须遵守】
- 你只能根据「交易记录」里**明确出现的字段**陈述事实;禁止编造:是否触发止损、是否扛单、亏损是否扩大、图上具体结构/进出场点位等记录里**没有**的信息。
- 「平仓/离场」只是交易员自述摘要,不是客观成交明细;若记录未写明代币是否打到止损价、是否软件平仓等,不要断言执行路径,可用「在记录有限前提下,一种可能是……」或简短写「执行路径记录不足,无法判断」。
- 「提前离场」类结论必须优先依据记录中的「提前离场记录」字段;若该段全为「无」或未出现有效内容,不得写道「明显扛单」「拒不止损」「未执行硬止损」等。
- 实际RR为负只说明结果相对于预期RR不利,不等同于「风控失灵」或「止损纪律崩溃」,除非记录里另有依据。
- 禁止用语:人身攻击、夸张定性(如「致命伤」「灾难」);语气克制、对事不对人。
- 若有截图且你能辨认,再结合图讨论;看不清或无明确定位则明确说「无法从图确认」,不得虚构 K 线故事。

【输出格式 — Markdown,必须严格遵守】
- 第一行:**交易复盘报告({period_label}度)**
- 五个大节标题必须**完全一致**(含 emoji,不要用其它编号或改名):
**1. 📊 总体盈亏结构**
**2. 🧠 心态与执行**
**3. 🏷️ 行为标签**
**4. ✅ 改进建议**
**5. 📈 图表分析**
- 每节正文用 `- **子项名**:内容` 列表;第4节改进建议用有序列表 `1. 2. 3.`
- 第1节至少包含:**笔数/盈亏**、**风险回报比**、**总结**
- 第2节至少包含:**得分**(1–10)、**依据**(对应记录字段)
- 第5节至少包含:**趋势确认**、**执行路径**(记录不足则写明)
- 语气简洁,少形容词;不要输出代码块、不要表格

交易记录:
{trades_text}
""".strip()

    report = ai_generate(prompt, image_paths=image_paths, temperature=0.2)
    return banner + report
|
||
|
||
|
||
def ai_short_advice(prompt_text: str) -> str:
    """Ask the model for at most three short risk-control reminders.

    Args:
        prompt_text: Description of the scenario, inserted verbatim into
            the prompt.

    Returns:
        Whatever ``ai_generate`` returns for the assembled prompt.
    """
    request_text = f"""
你是交易风控助理。请用中文给出**最多 3 条**提醒,要求:
- 每条不超过 25 个字
- 语气克制、具体、可执行
- 不要输出 Markdown,不要编号前缀以外的废话

场景:
{prompt_text}
""".strip()
    advice = ai_generate(request_text, temperature=0.2)
    return advice
|
||
|
||
|
||
def ai_provider_label() -> str:
    """Return a short human-readable label for the active provider and model."""
    if not _use_openai():
        return f"Ollama · {_ollama_model()}"
    model = _openai_model()
    base = _openai_api_base()
    return f"OpenAI 兼容 · {model} @ {base}"
|
||
|
||
|
||
def ai_config_status() -> dict:
    """Debug helper: report the AI configuration seen by this process.

    The API key itself is never included; only whether one is configured.
    """
    status = {
        "provider": _ai_provider(),
        "openai_base": _openai_api_base(),
        "openai_model": _openai_model(),
        "openai_key_configured": bool(_openai_api_key()),
        "ollama_api": _ollama_api(),
        "ollama_model": _ollama_model(),
    }
    return status
|