174 lines
4.5 KiB
Python
174 lines
4.5 KiB
Python
"""数据库 CRUD 操作"""
|
|
|
|
from typing import List, Optional
|
|
from sqlalchemy.orm import Session, joinedload
|
|
|
|
from app.models import MarketRegime, Account, Strategy, RegimeMatch
|
|
from app.schemas import (
|
|
MarketRegimeCreate, MarketRegimeUpdate,
|
|
AccountCreate, AccountUpdate,
|
|
StrategyCreate, StrategyUpdate,
|
|
RegimeMatchCreate, RegimeMatchUpdate,
|
|
)
|
|
|
|
|
|
# ── Market regimes ──
|
|
|
|
def get_regimes(db: Session) -> List[MarketRegime]:
    """Return all market regimes, sorted by ascending ``id``."""
    query = db.query(MarketRegime)
    ordered = query.order_by(MarketRegime.id)
    return ordered.all()
|
|
|
|
|
|
def get_regime(db: Session, regime_id: int) -> Optional[MarketRegime]:
    """Look up a single market regime by its ``id``.

    Returns the first matching ``MarketRegime``, or ``None`` when no row
    has the given ``regime_id``.
    """
    by_id = MarketRegime.id == regime_id
    query = db.query(MarketRegime).filter(by_id)
    return query.first()
|
|
|
|
|
|
def create_regime(db: Session, data: MarketRegimeCreate) -> MarketRegime:
    """Insert a new market regime and return the persisted object.

    Args:
        db: SQLAlchemy session used for the insert.
        data: Creation payload; its dumped fields are passed as keyword
            arguments to the ``MarketRegime`` constructor.

    Returns:
        The new ``MarketRegime``, refreshed from the database after commit.

    Raises:
        Exception: Any error raised by ``db.commit()`` is re-raised
            unchanged after the session has been rolled back.
    """
    obj = MarketRegime(**data.model_dump())
    db.add(obj)
    try:
        db.commit()
    except Exception:
        # A failed flush/commit leaves the session unusable until
        # rollback() is called, so reset it before propagating the error.
        db.rollback()
        raise
    db.refresh(obj)
    return obj
|
|
|
|
|
|
def update_regime(db: Session, regime_id: int, data: MarketRegimeUpdate) -> Optional[MarketRegime]:
    """Apply a partial update to an existing market regime.

    Only fields explicitly set on ``data`` are written
    (``model_dump(exclude_unset=True)``); unset fields keep their
    current values.

    Args:
        db: SQLAlchemy session used for the lookup and update.
        regime_id: ``id`` of the regime to modify.
        data: Update payload.

    Returns:
        The refreshed ``MarketRegime``, or ``None`` if no regime has the
        given id.

    Raises:
        Exception: Any error raised by ``db.commit()`` is re-raised
            unchanged after the session has been rolled back.
    """
    obj = get_regime(db, regime_id)
    if obj is None:
        return None
    for field, value in data.model_dump(exclude_unset=True).items():
        setattr(obj, field, value)
    try:
        db.commit()
    except Exception:
        # A failed flush/commit leaves the session unusable until
        # rollback() is called, so reset it before propagating the error.
        db.rollback()
        raise
    db.refresh(obj)
    return obj
|
|
|
|
|
|
def delete_regime(db: Session, regime_id: int) -> bool:
    """Delete the market regime with the given id.

    Args:
        db: SQLAlchemy session used for the lookup and delete.
        regime_id: ``id`` of the regime to remove.

    Returns:
        ``True`` if a regime was found and deleted, ``False`` if no regime
        has the given id.

    Raises:
        Exception: Any error raised by ``db.commit()`` (for example a
            constraint violation) is re-raised unchanged after the session
            has been rolled back.
    """
    obj = get_regime(db, regime_id)
    if obj is None:
        return False
    db.delete(obj)
    try:
        db.commit()
    except Exception:
        # A failed flush/commit leaves the session unusable until
        # rollback() is called, so reset it before propagating the error.
        db.rollback()
        raise
    return True
|
|
|
|
|
|
# ── Accounts ──
|
|
|
|
def get_accounts(db: Session) -> List[Account]:
    """Return all accounts, sorted by ascending ``id``."""
    query = db.query(Account)
    ordered = query.order_by(Account.id)
    return ordered.all()
|
|
|
|
|
|
def get_account(db: Session, account_id: int) -> Optional[Account]:
    """Look up a single account by its ``id``.

    Returns the first matching ``Account``, or ``None`` when no row has
    the given ``account_id``.
    """
    by_id = Account.id == account_id
    query = db.query(Account).filter(by_id)
    return query.first()
|
|
|
|
|
|
def create_account(db: Session, data: AccountCreate) -> Account:
    """Insert a new account and return the persisted object.

    Args:
        db: SQLAlchemy session used for the insert.
        data: Creation payload; its dumped fields are passed as keyword
            arguments to the ``Account`` constructor.

    Returns:
        The new ``Account``, refreshed from the database after commit.

    Raises:
        Exception: Any error raised by ``db.commit()`` is re-raised
            unchanged after the session has been rolled back.
    """
    obj = Account(**data.model_dump())
    db.add(obj)
    try:
        db.commit()
    except Exception:
        # A failed flush/commit leaves the session unusable until
        # rollback() is called, so reset it before propagating the error.
        db.rollback()
        raise
    db.refresh(obj)
    return obj
|
|
|
|
|
|
def update_account(db: Session, account_id: int, data: AccountUpdate) -> Optional[Account]:
    """Apply a partial update to an existing account.

    Only fields explicitly set on ``data`` are written
    (``model_dump(exclude_unset=True)``); unset fields keep their
    current values.

    Args:
        db: SQLAlchemy session used for the lookup and update.
        account_id: ``id`` of the account to modify.
        data: Update payload.

    Returns:
        The refreshed ``Account``, or ``None`` if no account has the
        given id.

    Raises:
        Exception: Any error raised by ``db.commit()`` is re-raised
            unchanged after the session has been rolled back.
    """
    obj = get_account(db, account_id)
    if obj is None:
        return None
    for field, value in data.model_dump(exclude_unset=True).items():
        setattr(obj, field, value)
    try:
        db.commit()
    except Exception:
        # A failed flush/commit leaves the session unusable until
        # rollback() is called, so reset it before propagating the error.
        db.rollback()
        raise
    db.refresh(obj)
    return obj
|
|
|
|
|
|
def delete_account(db: Session, account_id: int) -> bool:
    """Delete the account with the given id.

    Args:
        db: SQLAlchemy session used for the lookup and delete.
        account_id: ``id`` of the account to remove.

    Returns:
        ``True`` if an account was found and deleted, ``False`` if no
        account has the given id.

    Raises:
        Exception: Any error raised by ``db.commit()`` (for example a
            constraint violation) is re-raised unchanged after the session
            has been rolled back.
    """
    obj = get_account(db, account_id)
    if obj is None:
        return False
    db.delete(obj)
    try:
        db.commit()
    except Exception:
        # A failed flush/commit leaves the session unusable until
        # rollback() is called, so reset it before propagating the error.
        db.rollback()
        raise
    return True
|
|
|
|
|
|
# ── Strategies ──
|
|
|
|
def get_strategies(db: Session) -> List[Strategy]:
    """Return all strategies, sorted by ascending ``id``."""
    query = db.query(Strategy)
    ordered = query.order_by(Strategy.id)
    return ordered.all()
|
|
|
|
|
|
def get_strategy(db: Session, strategy_id: int) -> Optional[Strategy]:
    """Look up a single strategy by its ``id``.

    Returns the first matching ``Strategy``, or ``None`` when no row has
    the given ``strategy_id``.
    """
    by_id = Strategy.id == strategy_id
    query = db.query(Strategy).filter(by_id)
    return query.first()
|
|
|
|
|
|
def create_strategy(db: Session, data: StrategyCreate) -> Strategy:
    """Insert a new strategy and return the persisted object.

    Args:
        db: SQLAlchemy session used for the insert.
        data: Creation payload; its dumped fields are passed as keyword
            arguments to the ``Strategy`` constructor.

    Returns:
        The new ``Strategy``, refreshed from the database after commit.

    Raises:
        Exception: Any error raised by ``db.commit()`` is re-raised
            unchanged after the session has been rolled back.
    """
    obj = Strategy(**data.model_dump())
    db.add(obj)
    try:
        db.commit()
    except Exception:
        # A failed flush/commit leaves the session unusable until
        # rollback() is called, so reset it before propagating the error.
        db.rollback()
        raise
    db.refresh(obj)
    return obj
|
|
|
|
|
|
def update_strategy(db: Session, strategy_id: int, data: StrategyUpdate) -> Optional[Strategy]:
    """Apply a partial update to an existing strategy.

    Only fields explicitly set on ``data`` are written
    (``model_dump(exclude_unset=True)``); unset fields keep their
    current values.

    Args:
        db: SQLAlchemy session used for the lookup and update.
        strategy_id: ``id`` of the strategy to modify.
        data: Update payload.

    Returns:
        The refreshed ``Strategy``, or ``None`` if no strategy has the
        given id.

    Raises:
        Exception: Any error raised by ``db.commit()`` is re-raised
            unchanged after the session has been rolled back.
    """
    obj = get_strategy(db, strategy_id)
    if obj is None:
        return None
    for field, value in data.model_dump(exclude_unset=True).items():
        setattr(obj, field, value)
    try:
        db.commit()
    except Exception:
        # A failed flush/commit leaves the session unusable until
        # rollback() is called, so reset it before propagating the error.
        db.rollback()
        raise
    db.refresh(obj)
    return obj
|
|
|
|
|
|
def delete_strategy(db: Session, strategy_id: int) -> bool:
    """Delete the strategy with the given id.

    Args:
        db: SQLAlchemy session used for the lookup and delete.
        strategy_id: ``id`` of the strategy to remove.

    Returns:
        ``True`` if a strategy was found and deleted, ``False`` if no
        strategy has the given id.

    Raises:
        Exception: Any error raised by ``db.commit()`` (for example a
            constraint violation) is re-raised unchanged after the session
            has been rolled back.
    """
    obj = get_strategy(db, strategy_id)
    if obj is None:
        return False
    db.delete(obj)
    try:
        db.commit()
    except Exception:
        # A failed flush/commit leaves the session unusable until
        # rollback() is called, so reset it before propagating the error.
        db.rollback()
        raise
    return True
|
|
|
|
|
|
# ── Match bindings ──
|
|
|
|
def get_matches(db: Session) -> List[RegimeMatch]:
    """Return all regime matches, sorted by ascending ``id``.

    The ``regime``, ``account`` and ``strategy`` relationships are loaded
    with ``joinedload`` as part of the same query.
    """
    eager_relations = (
        joinedload(RegimeMatch.regime),
        joinedload(RegimeMatch.account),
        joinedload(RegimeMatch.strategy),
    )
    query = db.query(RegimeMatch).options(*eager_relations)
    return query.order_by(RegimeMatch.id).all()
|
|
|
|
|
|
def get_match(db: Session, match_id: int) -> Optional[RegimeMatch]:
    """Look up a single regime match by its ``id``.

    Returns the first matching ``RegimeMatch``, or ``None`` when no row
    has the given ``match_id``. No loader options are applied here.
    """
    by_id = RegimeMatch.id == match_id
    query = db.query(RegimeMatch).filter(by_id)
    return query.first()
|
|
|
|
|
|
def create_match(db: Session, data: RegimeMatchCreate) -> RegimeMatch:
    """Insert a new regime match and return the persisted object.

    Args:
        db: SQLAlchemy session used for the insert.
        data: Creation payload; its dumped fields are passed as keyword
            arguments to the ``RegimeMatch`` constructor.

    Returns:
        The new ``RegimeMatch``, refreshed from the database after commit.

    Raises:
        Exception: Any error raised by ``db.commit()`` (for example a
            constraint violation) is re-raised unchanged after the session
            has been rolled back.
    """
    obj = RegimeMatch(**data.model_dump())
    db.add(obj)
    try:
        db.commit()
    except Exception:
        # A failed flush/commit leaves the session unusable until
        # rollback() is called, so reset it before propagating the error.
        db.rollback()
        raise
    db.refresh(obj)
    return obj
|
|
|
|
|
|
def update_match(db: Session, match_id: int, data: RegimeMatchUpdate) -> Optional[RegimeMatch]:
    """Apply a partial update to an existing regime match.

    Only fields explicitly set on ``data`` are written
    (``model_dump(exclude_unset=True)``); unset fields keep their
    current values.

    Args:
        db: SQLAlchemy session used for the lookup and update.
        match_id: ``id`` of the match to modify.
        data: Update payload.

    Returns:
        The refreshed ``RegimeMatch``, or ``None`` if no match has the
        given id.

    Raises:
        Exception: Any error raised by ``db.commit()`` (for example a
            constraint violation) is re-raised unchanged after the session
            has been rolled back.
    """
    obj = get_match(db, match_id)
    if obj is None:
        return None
    for field, value in data.model_dump(exclude_unset=True).items():
        setattr(obj, field, value)
    try:
        db.commit()
    except Exception:
        # A failed flush/commit leaves the session unusable until
        # rollback() is called, so reset it before propagating the error.
        db.rollback()
        raise
    db.refresh(obj)
    return obj
|
|
|
|
|
|
def delete_match(db: Session, match_id: int) -> bool:
    """Delete the regime match with the given id.

    Args:
        db: SQLAlchemy session used for the lookup and delete.
        match_id: ``id`` of the match to remove.

    Returns:
        ``True`` if a match was found and deleted, ``False`` if no match
        has the given id.

    Raises:
        Exception: Any error raised by ``db.commit()`` is re-raised
            unchanged after the session has been rolled back.
    """
    obj = get_match(db, match_id)
    if obj is None:
        return False
    db.delete(obj)
    try:
        db.commit()
    except Exception:
        # A failed flush/commit leaves the session unusable until
        # rollback() is called, so reset it before propagating the error.
        db.rollback()
        raise
    return True
|