"""数据库连接与初始化""" import os from pathlib import Path from sqlalchemy import create_engine, text from sqlalchemy.orm import sessionmaker, declarative_base # 数据库文件路径 BASE_DIR = Path(__file__).resolve().parent.parent DB_PATH = BASE_DIR / "data" / "pretrade.db" DB_PATH.parent.mkdir(parents=True, exist_ok=True) DATABASE_URL = f"sqlite:///{DB_PATH}" engine = create_engine( DATABASE_URL, connect_args={"check_same_thread": False}, ) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) Base = declarative_base() def get_db(): """FastAPI 依赖:获取数据库会话""" db = SessionLocal() try: yield db finally: db.close() def init_database(): """执行建表 SQL 并写入默认数据""" from app.models import MarketRegime, Account, Strategy, RegimeMatch # noqa: F401 # 建表 Base.metadata.create_all(bind=engine) # 若已有数据则跳过初始化 db = SessionLocal() try: if db.query(MarketRegime).count() > 0: return _seed_default_data(db) finally: db.close() def _seed_default_data(db): """写入默认大盘阶段、账户、策略及匹配规则""" from app.models import MarketRegime, Account, Strategy, RegimeMatch # ── 大盘阶段(8 个固定) ── regimes = [ MarketRegime(name="上涨初期", trade_type="顺势", allow_direction="做多", remark="顺势做多,禁止做空"), MarketRegime(name="上涨中期", trade_type="顺势", allow_direction="做多", remark="顺势做多,禁止做空"), MarketRegime(name="上涨末期", trade_type="反转", allow_direction="做空", remark="反转做空"), MarketRegime(name="宽幅震荡", trade_type="观望", allow_direction="禁止", remark="观望,禁止交易"), MarketRegime(name="宽幅震荡末期", trade_type="反转", allow_direction="多空均可", remark="反转交易(收敛突破)"), MarketRegime(name="下跌初期", trade_type="顺势", allow_direction="做空", remark="顺势做空,禁止做多"), MarketRegime(name="下跌中期", trade_type="顺势", allow_direction="做空", remark="顺势做空,禁止做多"), MarketRegime(name="下跌末期", trade_type="反转", allow_direction="做多", remark="反转做多"), ] db.add_all(regimes) db.flush() regime_map = {r.name: r.id for r in regimes} # ── 账户(默认本金 100U) ── accounts = [ Account(account_name="账户1-斐波回调", total_capital=100, trade_cycle="4H/1H", risk_ratio="5%~10%", remark="斐波回调专用"), Account(account_name="账户2-箱体突破", total_capital=100, trade_cycle="日内", risk_ratio="0.5%~1%", remark="箱体顺势突破专用"), Account(account_name="账户3-收敛突破", total_capital=100, trade_cycle="日内", risk_ratio="0.5%~1%", remark="收敛结构突破专用"), Account(account_name="账户4-手工主观", total_capital=100, trade_cycle="灵活", risk_ratio="2%", remark="手工主观策略专用"), ] db.add_all(accounts) db.flush() acc_map = {a.account_name: a.id for a in accounts} # ── 策略(4 个内置) ── strategies = [ Strategy( strategy_name="斐波回调", fit_cycle="4H/1H", fit_trend_strength="弱", trade_type="顺势", strategy_rule=( "入场:仅 0.618 / 0.786\n" "适用:通道式上涨/下跌\n" "适用:弱趋势\n" "周期:4H/1H\n" "注意:点位、触碰次数、突破条件均由人工手动筛选输入" ), ), Strategy( strategy_name="箱体顺势突破", fit_cycle="日内(5分钟K线)", fit_trend_strength="强", trade_type="顺势", strategy_rule=( "箱体时长要求:≥4小时(48根5分钟K线),优先8小时以上\n" "成立条件:人工手动判断触碰上下沿次数、顺势/逆势箱体、突破确认条件\n" "系统不校验,仅展示该策略可用\n" "适用:强趋势顺势阶段" ), ), Strategy( strategy_name="收敛结构突破", fit_cycle="日内", fit_trend_strength="强", trade_type="反转", strategy_rule=( "适用:宽幅震荡末期收敛三角\n" "人工确认结构,系统仅匹配可用\n" "注意:结构细节、突破条件均由人工手动确认" ), ), Strategy( strategy_name="手工主观", fit_cycle="灵活", fit_trend_strength="全部", trade_type="全部", strategy_rule="全场景人工自主判断,系统仅做前置匹配", ), ] db.add_all(strategies) db.flush() strat_map = {s.strategy_name: s.id for s in strategies} # ── 匹配绑定规则 ── cycles = ["日线", "4H", "1H"] matches = [] # 顺势阶段(上涨初/中期、下跌初/中期):强→箱体,弱→斐波 shunshi_regimes = ["上涨初期", "上涨中期", "下跌初期", "下跌中期"] for rname in shunshi_regimes: rid = regime_map[rname] for cycle in cycles: matches.append(RegimeMatch( market_regime_id=rid, market_cycle=cycle, trend_strength="强", account_id=acc_map["账户2-箱体突破"], strategy_id=strat_map["箱体顺势突破"], )) matches.append(RegimeMatch( market_regime_id=rid, market_cycle=cycle, trend_strength="弱", account_id=acc_map["账户1-斐波回调"], strategy_id=strat_map["斐波回调"], )) # 反转阶段(上涨末期、下跌末期):手工主观 fanzhuan_regimes = ["上涨末期", "下跌末期"] for rname in fanzhuan_regimes: rid = regime_map[rname] for cycle in cycles: for strength in ["强", "弱"]: matches.append(RegimeMatch( market_regime_id=rid, market_cycle=cycle, trend_strength=strength, account_id=acc_map["账户4-手工主观"], strategy_id=strat_map["手工主观"], )) # 宽幅震荡末期:收敛突破 rid = regime_map["宽幅震荡末期"] for cycle in cycles: for strength in ["强", "弱"]: matches.append(RegimeMatch( market_regime_id=rid, market_cycle=cycle, trend_strength=strength, account_id=acc_map["账户3-收敛突破"], strategy_id=strat_map["收敛结构突破"], )) # 宽幅震荡:不写入匹配(系统自动禁用) db.add_all(matches) db.commit()