Files
2026-05-28 21:43:23 +08:00

176 lines
7.0 KiB
Python

"""数据库连接与初始化"""
import os
from pathlib import Path
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker, declarative_base
# 数据库文件路径
BASE_DIR = Path(__file__).resolve().parent.parent
DB_PATH = BASE_DIR / "data" / "pretrade.db"
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
DATABASE_URL = f"sqlite:///{DB_PATH}"
engine = create_engine(
DATABASE_URL,
connect_args={"check_same_thread": False},
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
def get_db():
"""FastAPI 依赖:获取数据库会话"""
db = SessionLocal()
try:
yield db
finally:
db.close()
def init_database():
"""执行建表 SQL 并写入默认数据"""
from app.models import MarketRegime, Account, Strategy, RegimeMatch # noqa: F401
# 建表
Base.metadata.create_all(bind=engine)
# 若已有数据则跳过初始化
db = SessionLocal()
try:
if db.query(MarketRegime).count() > 0:
return
_seed_default_data(db)
finally:
db.close()
def _seed_default_data(db):
"""写入默认大盘阶段、账户、策略及匹配规则"""
from app.models import MarketRegime, Account, Strategy, RegimeMatch
# ── 大盘阶段(8 个固定) ──
regimes = [
MarketRegime(name="上涨初期", trade_type="顺势", allow_direction="做多", remark="顺势做多,禁止做空"),
MarketRegime(name="上涨中期", trade_type="顺势", allow_direction="做多", remark="顺势做多,禁止做空"),
MarketRegime(name="上涨末期", trade_type="反转", allow_direction="做空", remark="反转做空"),
MarketRegime(name="宽幅震荡", trade_type="观望", allow_direction="禁止", remark="观望,禁止交易"),
MarketRegime(name="宽幅震荡末期", trade_type="反转", allow_direction="多空均可", remark="反转交易(收敛突破)"),
MarketRegime(name="下跌初期", trade_type="顺势", allow_direction="做空", remark="顺势做空,禁止做多"),
MarketRegime(name="下跌中期", trade_type="顺势", allow_direction="做空", remark="顺势做空,禁止做多"),
MarketRegime(name="下跌末期", trade_type="反转", allow_direction="做多", remark="反转做多"),
]
db.add_all(regimes)
db.flush()
regime_map = {r.name: r.id for r in regimes}
# ── 账户(默认本金 100U) ──
accounts = [
Account(account_name="账户1-斐波回调", total_capital=100, trade_cycle="4H/1H", risk_ratio="5%~10%", remark="斐波回调专用"),
Account(account_name="账户2-箱体突破", total_capital=100, trade_cycle="日内", risk_ratio="0.5%~1%", remark="箱体顺势突破专用"),
Account(account_name="账户3-收敛突破", total_capital=100, trade_cycle="日内", risk_ratio="0.5%~1%", remark="收敛结构突破专用"),
Account(account_name="账户4-手工主观", total_capital=100, trade_cycle="灵活", risk_ratio="2%", remark="手工主观策略专用"),
]
db.add_all(accounts)
db.flush()
acc_map = {a.account_name: a.id for a in accounts}
# ── 策略(4 个内置) ──
strategies = [
Strategy(
strategy_name="斐波回调",
fit_cycle="4H/1H",
fit_trend_strength="",
trade_type="顺势",
strategy_rule=(
"入场:仅 0.618 / 0.786\n"
"适用:通道式上涨/下跌\n"
"适用:弱趋势\n"
"周期:4H/1H\n"
"注意:点位、触碰次数、突破条件均由人工手动筛选输入"
),
),
Strategy(
strategy_name="箱体顺势突破",
fit_cycle="日内(5分钟K线)",
fit_trend_strength="",
trade_type="顺势",
strategy_rule=(
"箱体时长要求:≥4小时(48根5分钟K线),优先8小时以上\n"
"成立条件:人工手动判断触碰上下沿次数、顺势/逆势箱体、突破确认条件\n"
"系统不校验,仅展示该策略可用\n"
"适用:强趋势顺势阶段"
),
),
Strategy(
strategy_name="收敛结构突破",
fit_cycle="日内",
fit_trend_strength="",
trade_type="反转",
strategy_rule=(
"适用:宽幅震荡末期收敛三角\n"
"人工确认结构,系统仅匹配可用\n"
"注意:结构细节、突破条件均由人工手动确认"
),
),
Strategy(
strategy_name="手工主观",
fit_cycle="灵活",
fit_trend_strength="全部",
trade_type="全部",
strategy_rule="全场景人工自主判断,系统仅做前置匹配",
),
]
db.add_all(strategies)
db.flush()
strat_map = {s.strategy_name: s.id for s in strategies}
# ── 匹配绑定规则 ──
cycles = ["日线", "4H", "1H"]
matches = []
# 顺势阶段(上涨初/中期、下跌初/中期):强→箱体,弱→斐波
shunshi_regimes = ["上涨初期", "上涨中期", "下跌初期", "下跌中期"]
for rname in shunshi_regimes:
rid = regime_map[rname]
for cycle in cycles:
matches.append(RegimeMatch(
market_regime_id=rid, market_cycle=cycle, trend_strength="",
account_id=acc_map["账户2-箱体突破"], strategy_id=strat_map["箱体顺势突破"],
))
matches.append(RegimeMatch(
market_regime_id=rid, market_cycle=cycle, trend_strength="",
account_id=acc_map["账户1-斐波回调"], strategy_id=strat_map["斐波回调"],
))
# 反转阶段(上涨末期、下跌末期):手工主观
fanzhuan_regimes = ["上涨末期", "下跌末期"]
for rname in fanzhuan_regimes:
rid = regime_map[rname]
for cycle in cycles:
for strength in ["", ""]:
matches.append(RegimeMatch(
market_regime_id=rid, market_cycle=cycle, trend_strength=strength,
account_id=acc_map["账户4-手工主观"], strategy_id=strat_map["手工主观"],
))
# 宽幅震荡末期:收敛突破
rid = regime_map["宽幅震荡末期"]
for cycle in cycles:
for strength in ["", ""]:
matches.append(RegimeMatch(
market_regime_id=rid, market_cycle=cycle, trend_strength=strength,
account_id=acc_map["账户3-收敛突破"], strategy_id=strat_map["收敛结构突破"],
))
# 宽幅震荡:不写入匹配(系统自动禁用)
db.add_all(matches)
db.commit()