""" 三所共用:账户级方向 / 币种白名单(.env 开关,默认关闭=不限制)。 """ from __future__ import annotations import os from dataclasses import dataclass from typing import Callable, FrozenSet, Optional, Sequence, Tuple DIR_BOTH = "both" DIR_LONG_ONLY = "long_only" DIR_SHORT_ONLY = "short_only" VALID_DIRECTION_MODES = frozenset({DIR_BOTH, DIR_LONG_ONLY, DIR_SHORT_ONLY}) _DIR_ALIASES = { "both": DIR_BOTH, "双向": DIR_BOTH, "long": DIR_LONG_ONLY, "long_only": DIR_LONG_ONLY, "多": DIR_LONG_ONLY, "仅多": DIR_LONG_ONLY, "做多": DIR_LONG_ONLY, "short": DIR_SHORT_ONLY, "short_only": DIR_SHORT_ONLY, "空": DIR_SHORT_ONLY, "仅空": DIR_SHORT_ONLY, "做空": DIR_SHORT_ONLY, } def _env_bool(raw: Optional[str], default: bool = False) -> bool: if raw is None: return default return (raw or "").strip().lower() in ("1", "true", "yes", "on") def normalize_direction_mode(raw: Optional[str]) -> str: v = (raw or DIR_BOTH).strip().lower() return _DIR_ALIASES.get(v, v if v in VALID_DIRECTION_MODES else DIR_BOTH) def symbol_base_coin(symbol: str) -> str: """BTC/USDT:USDT、BTC/USDT、BTC、btc -> BTC""" s = (symbol or "").strip().upper() if not s: return "" if ":" in s: s = s.split(":", 1)[0] if "/" in s: return s.split("/", 1)[0].strip() if s.endswith("USDT") and len(s) > 4: return s[:-4] return s def parse_symbol_whitelist(raw: Optional[str]) -> Tuple[str, ...]: if not raw or not str(raw).strip(): return () parts = [] for piece in str(raw).replace(";", ",").split(","): base = symbol_base_coin(piece.strip()) if base and base not in parts: parts.append(base) return tuple(parts) @dataclass(frozen=True) class TradePolicy: direction_restrict_enabled: bool direction_mode: str symbol_restrict_enabled: bool symbol_whitelist: Tuple[str, ...] @property def allows_long(self) -> bool: if not self.direction_restrict_enabled: return True return self.direction_mode in (DIR_BOTH, DIR_LONG_ONLY) @property def allows_short(self) -> bool: if not self.direction_restrict_enabled: return True return self.direction_mode in (DIR_BOTH, DIR_SHORT_ONLY) def load_trade_policy(env: Optional[dict] = None) -> TradePolicy: e = env if env is not None else os.environ direction_restrict = _env_bool(e.get("TRADE_DIRECTION_RESTRICT_ENABLED"), False) symbol_restrict = _env_bool(e.get("TRADE_SYMBOL_RESTRICT_ENABLED"), False) direction_mode = normalize_direction_mode(e.get("TRADE_DIRECTION")) whitelist = parse_symbol_whitelist(e.get("TRADE_SYMBOL_WHITELIST")) if symbol_restrict and not whitelist: symbol_restrict = False return TradePolicy( direction_restrict_enabled=direction_restrict, direction_mode=direction_mode, symbol_restrict_enabled=symbol_restrict, symbol_whitelist=whitelist, ) def direction_mode_label_zh(mode: str) -> str: m = normalize_direction_mode(mode) if m == DIR_LONG_ONLY: return "仅多" if m == DIR_SHORT_ONLY: return "仅空" return "双向" def trade_policy_badge_parts(policy: TradePolicy) -> Tuple[str, ...]: parts: list[str] = [] if policy.direction_restrict_enabled: if policy.direction_mode == DIR_LONG_ONLY: parts.append("仅多") elif policy.direction_mode == DIR_SHORT_ONLY: parts.append("仅空") if policy.symbol_restrict_enabled and policy.symbol_whitelist: parts.append("/".join(policy.symbol_whitelist)) return tuple(parts) def trade_policy_to_dict(policy: TradePolicy) -> dict: badges = trade_policy_badge_parts(policy) return { "direction_restrict_enabled": policy.direction_restrict_enabled, "direction_mode": policy.direction_mode, "direction_label_zh": ( direction_mode_label_zh(policy.direction_mode) if policy.direction_restrict_enabled else "双向" ), "allows_long": policy.allows_long, "allows_short": policy.allows_short, "symbol_restrict_enabled": policy.symbol_restrict_enabled, "symbol_whitelist": list(policy.symbol_whitelist), "badge_parts": list(badges), "badge_text": " · ".join(badges), } def normalize_open_direction(policy: TradePolicy, direction: str) -> str: d = (direction or "long").strip().lower() if d not in ("long", "short"): d = "long" if policy.direction_restrict_enabled: if policy.direction_mode == DIR_LONG_ONLY: return "long" if policy.direction_mode == DIR_SHORT_ONLY: return "short" return d def assert_direction_allowed(policy: TradePolicy, direction: str) -> Tuple[bool, str]: d = (direction or "").strip().lower() if d not in ("long", "short"): if d in ("watch", ""): return True, "" return False, "方向无效,请选择做多或做空" if d == "long" and not policy.allows_long: return False, "当前账户配置为仅做空,不允许做多" if d == "short" and not policy.allows_short: return False, "当前账户配置为仅做多,不允许做空" return True, "" def assert_symbol_allowed( policy: TradePolicy, symbol: str, *, normalize_symbol_fn: Optional[Callable[[str], str]] = None, ) -> Tuple[bool, str]: if not policy.symbol_restrict_enabled: return True, "" sym = (symbol or "").strip() if not sym: return False, "请选择币种" if normalize_symbol_fn is not None: sym_norm = (normalize_symbol_fn(sym) or "").strip() else: sym_norm = sym base = symbol_base_coin(sym_norm or sym) allowed: FrozenSet[str] = frozenset(policy.symbol_whitelist) if base not in allowed: allowed_txt = "、".join(policy.symbol_whitelist) return False, f"当前账户仅允许 {allowed_txt},不允许 {base or sym}" return True, "" def assert_trade_policy_open( policy: TradePolicy, symbol: str, direction: str, normalize_symbol_fn: Optional[Callable[[str], str]] = None, ) -> Tuple[bool, str]: ok_sym, msg_sym = assert_symbol_allowed( policy, symbol, normalize_symbol_fn=normalize_symbol_fn ) if not ok_sym: return False, msg_sym ok_dir, msg_dir = assert_direction_allowed(policy, direction) if not ok_dir: return False, msg_dir return True, ""