From e5a586f90351c087cb368bd254df1157e6294bf8 Mon Sep 17 00:00:00 2001 From: dekun Date: Wed, 1 Jul 2026 14:42:16 +0800 Subject: [PATCH] Restructure into modules/ with single-process CTP and config/ layout. Move business code under modules/, env template to config/, PM2 single qihuo process, and _legacy shims for old imports. Co-authored-by: Cursor --- .env.example | 63 +- .gitignore | 1 + _legacy/admin_settings.py | 2 + _legacy/ai_client.py | 2 + _legacy/ai_messages.py | 2 + _legacy/ai_worker.py | 2 + _legacy/contract_profile.py | 2 + _legacy/contract_specs.py | 2 + _legacy/ctp_entry_price.py | 2 + _legacy/ctp_fee_sync.py | 2 + _legacy/ctp_fee_worker.py | 2 + _legacy/ctp_ipc_client.py | 2 + _legacy/ctp_kline.py | 2 + _legacy/ctp_premarket_connect.py | 2 + _legacy/ctp_reconnect.py | 2 + _legacy/ctp_settings.py | 2 + _legacy/ctp_symbol.py | 2 + _legacy/ctp_trade_sync.py | 2 + _legacy/ctp_trading_state.py | 2 + _legacy/ctp_worker.py | 2 + _legacy/dashboard_lib.py | 2 + _legacy/db_backup.py | 2 + _legacy/db_conn.py | 2 + _legacy/doc_render.py | 2 + _legacy/env_file.py | 2 + _legacy/fee_specs.py | 2 + _legacy/fee_sync.py | 2 + _legacy/key_monitor_lib.py | 2 + _legacy/kline_chart.py | 2 + _legacy/kline_store.py | 2 + _legacy/kline_stream.py | 2 + _legacy/locale_fix.py | 2 + _legacy/market.py | 2 + _legacy/market_sessions.py | 2 + _legacy/nav_settings.py | 2 + _legacy/order_pending.py | 2 + _legacy/pending_order_worker.py | 2 + _legacy/position_sizing.py | 2 + _legacy/position_stream.py | 2 + _legacy/product_recommend.py | 2 + _legacy/recommend_store.py | 2 + _legacy/recommend_stream.py | 2 + _legacy/recommend_trend.py | 2 + {risk => _legacy/risk}/__init__.py | 0 _legacy/risk/account_risk_lib.py | 2 + _legacy/sl_tp_guard.py | 2 + _legacy/stats_engine.py | 2 + {strategy => _legacy/strategy}/__init__.py | 0 _legacy/strategy/fib_lib.py | 2 + _legacy/strategy/strategy_db.py | 2 + _legacy/strategy/strategy_roll_lib.py | 2 + _legacy/strategy/strategy_roll_monitor_lib.py | 2 + _legacy/strategy/strategy_snapshot_lib.py | 2 + _legacy/strategy/strategy_trend_lib.py | 2 + _legacy/symbols.py | 2 + _legacy/trade_log_lib.py | 2 + _legacy/trade_notify.py | 2 + _legacy/trading_context.py | 2 + _legacy/vnpy_bridge.py | 2 + _legacy/wechat_notify.py | 2 + app.py | 3127 +++------- config/.env.example | 61 + deploy.sh | 28 +- docs/ARCHITECTURE.md | 50 + docs/DEPLOY.md | 31 +- ecosystem.config.cjs | 28 +- install_trading.py | 4687 +------------- modules/__init__.py | 3 + modules/backup/__init__.py | 5 + db_backup.py => modules/backup/db_backup.py | 805 +-- modules/backup/routes.py | 78 + modules/core/__init__.py | 8 + modules/core/bootstrap.py | 55 + .../core/contract_profile.py | 560 +- .../core/contract_specs.py | 332 +- db_conn.py => modules/core/db_conn.py | 4 +- modules/core/deps.py | 46 + doc_render.py => modules/core/doc_render.py | 0 env_file.py => modules/core/env_file.py | 18 +- locale_fix.py => modules/core/locale_fix.py | 0 modules/core/paths.py | 37 + symbols.py => modules/core/symbols.py | 1366 ++--- .../core/trading_context.py | 368 +- modules/ctp/__init__.py | 10 + .../ctp/ctp_entry_price.py | 126 +- .../ctp/ctp_fee_sync.py | 288 +- .../ctp/ctp_fee_worker.py | 262 +- .../ctp/ctp_ipc_client.py | 0 ctp_kline.py => modules/ctp/ctp_kline.py | 178 +- .../ctp/ctp_premarket_connect.py | 232 +- .../ctp/ctp_reconnect.py | 118 +- .../ctp/ctp_settings.py | 308 +- ctp_symbol.py => modules/ctp/ctp_symbol.py | 132 +- .../ctp/ctp_trade_sync.py | 674 +- .../ctp/ctp_trading_state.py | 540 +- ctp_worker.py => modules/ctp/ctp_worker.py | 988 +-- vnpy_bridge.py => modules/ctp/vnpy_bridge.py | 5408 +++++++++-------- modules/fees/__init__.py | 5 + fee_specs.py => modules/fees/fee_specs.py | 770 +-- fee_sync.py => modules/fees/fee_sync.py | 182 +- modules/fees/routes.py | 95 + modules/keys/__init__.py | 5 + .../keys/key_monitor_lib.py | 812 +-- modules/keys/routes.py | 185 + modules/market/__init__.py | 10 + .../market/kline_chart.py | 1116 ++-- .../market/kline_store.py | 0 .../market/kline_stream.py | 278 +- market.py => modules/market/market.py | 0 .../market/market_sessions.py | 0 modules/market/routes.py | 230 + modules/notify/__init__.py | 5 + ai_client.py => modules/notify/ai_client.py | 0 .../notify/ai_messages.py | 0 ai_worker.py => modules/notify/ai_worker.py | 346 +- modules/notify/routes.py | 65 + .../notify/wechat_notify.py | 0 modules/plans/__init__.py | 5 + modules/plans/routes.py | 167 + modules/records/__init__.py | 5 + modules/records/routes.py | 554 ++ modules/risk/__init__.py | 12 + {risk => modules/risk}/account_risk_lib.py | 900 +-- modules/settings/__init__.py | 5 + .../settings/admin_settings.py | 172 +- .../settings/nav_settings.py | 0 modules/settings/routes.py | 314 + modules/stats/__init__.py | 5 + .../stats/dashboard_lib.py | 576 +- modules/stats/routes.py | 174 + .../stats/stats_engine.py | 1136 ++-- modules/strategy/__init__.py | 10 + {strategy => modules/strategy}/fib_lib.py | 0 {strategy => modules/strategy}/strategy_db.py | 338 +- .../strategy}/strategy_roll_lib.py | 740 +-- .../strategy}/strategy_roll_monitor_lib.py | 316 +- .../strategy}/strategy_snapshot_lib.py | 0 .../strategy}/strategy_trend_lib.py | 466 +- modules/trading/__init__.py | 19 + modules/trading/install.py | 4685 ++++++++++++++ .../trading/order_pending.py | 568 +- .../trading/pending_order_worker.py | 164 +- .../trading/position_sizing.py | 540 +- .../trading/position_stream.py | 226 +- .../trading/product_recommend.py | 670 +- .../trading/recommend_store.py | 798 +-- .../trading/recommend_stream.py | 326 +- .../trading/recommend_trend.py | 678 +-- .../trading/sl_tp_guard.py | 2116 +++---- .../trading/trade_log_lib.py | 436 +- .../trading/trade_notify.py | 450 +- modules/web/__init__.py | 5 + modules/web/routes.py | 108 + .../web/static}/css/ai_messages.css | 0 {static => modules/web/static}/css/base.css | 0 .../web/static}/css/dashboard.css | 0 {static => modules/web/static}/css/doc.css | 0 {static => modules/web/static}/css/keys.css | 0 {static => modules/web/static}/css/mobile.css | 0 .../web/static}/css/records.css | 0 .../web/static}/css/responsive.css | 0 {static => modules/web/static}/css/tech.css | 0 {static => modules/web/static}/css/trade.css | 0 .../web/static}/icons/icon-192.png | Bin .../web/static}/icons/icon-512.png | Bin {static => modules/web/static}/icons/icon.svg | 0 {static => modules/web/static}/js/calendar.js | 0 {static => modules/web/static}/js/contract.js | 0 .../web/static}/js/dashboard.js | 0 .../web/static}/js/equity_curve.js | 0 {static => modules/web/static}/js/keys.js | 0 {static => modules/web/static}/js/lunar.js | 0 {static => modules/web/static}/js/market.js | 0 {static => modules/web/static}/js/nav.js | 0 .../web/static}/js/orientation.js | 0 {static => modules/web/static}/js/page.js | 0 {static => modules/web/static}/js/plans.js | 0 .../web/static}/js/positions.js | 0 {static => modules/web/static}/js/pwa.js | 0 {static => modules/web/static}/js/records.js | 0 {static => modules/web/static}/js/review.js | 0 {static => modules/web/static}/js/settings.js | 0 {static => modules/web/static}/js/stats.js | 0 {static => modules/web/static}/js/strategy.js | 0 {static => modules/web/static}/js/symbol.js | 0 {static => modules/web/static}/js/theme.js | 0 {static => modules/web/static}/js/trade.js | 0 {static => modules/web/static}/js/trades.js | 0 {static => modules/web/static}/manifest.json | 0 {static => modules/web/static}/sw.js | 0 .../web/templates}/ai_messages.html | 0 .../web/templates}/base.html | 0 .../web/templates}/calendar.html | 0 .../web/templates}/contract.html | 0 .../web/templates}/dashboard.html | 0 .../web/templates}/fees.html | 0 .../web/templates}/keys.html | 0 .../web/templates}/login.html | 0 .../web/templates}/market.html | 0 .../web/templates}/plans.html | 0 .../web/templates}/positions.html | 0 .../web/templates}/recommend.html | 0 .../web/templates}/records.html | 0 .../web/templates}/risk_guide.html | 0 .../web/templates}/settings.html | 0 .../web/templates}/stats.html | 0 .../web/templates}/strategy.html | 0 .../web/templates}/strategy_records.html | 0 .../web/templates}/trade.html | 0 209 files changed, 21962 insertions(+), 20963 deletions(-) create mode 100644 _legacy/admin_settings.py create mode 100644 _legacy/ai_client.py create mode 100644 _legacy/ai_messages.py create mode 100644 _legacy/ai_worker.py create mode 100644 _legacy/contract_profile.py create mode 100644 _legacy/contract_specs.py create mode 100644 _legacy/ctp_entry_price.py create mode 100644 _legacy/ctp_fee_sync.py create mode 100644 _legacy/ctp_fee_worker.py create mode 100644 _legacy/ctp_ipc_client.py create mode 100644 _legacy/ctp_kline.py create mode 100644 _legacy/ctp_premarket_connect.py create mode 100644 _legacy/ctp_reconnect.py create mode 100644 _legacy/ctp_settings.py create mode 100644 _legacy/ctp_symbol.py create mode 100644 _legacy/ctp_trade_sync.py create mode 100644 _legacy/ctp_trading_state.py create mode 100644 _legacy/ctp_worker.py create mode 100644 _legacy/dashboard_lib.py create mode 100644 _legacy/db_backup.py create mode 100644 _legacy/db_conn.py create mode 100644 _legacy/doc_render.py create mode 100644 _legacy/env_file.py create mode 100644 _legacy/fee_specs.py create mode 100644 _legacy/fee_sync.py create mode 100644 _legacy/key_monitor_lib.py create mode 100644 _legacy/kline_chart.py create mode 100644 _legacy/kline_store.py create mode 100644 _legacy/kline_stream.py create mode 100644 _legacy/locale_fix.py create mode 100644 _legacy/market.py create mode 100644 _legacy/market_sessions.py create mode 100644 _legacy/nav_settings.py create mode 100644 _legacy/order_pending.py create mode 100644 _legacy/pending_order_worker.py create mode 100644 _legacy/position_sizing.py create mode 100644 _legacy/position_stream.py create mode 100644 _legacy/product_recommend.py create mode 100644 _legacy/recommend_store.py create mode 100644 _legacy/recommend_stream.py create mode 100644 _legacy/recommend_trend.py rename {risk => _legacy/risk}/__init__.py (100%) create mode 100644 _legacy/risk/account_risk_lib.py create mode 100644 _legacy/sl_tp_guard.py create mode 100644 _legacy/stats_engine.py rename {strategy => _legacy/strategy}/__init__.py (100%) create mode 100644 _legacy/strategy/fib_lib.py create mode 100644 _legacy/strategy/strategy_db.py create mode 100644 _legacy/strategy/strategy_roll_lib.py create mode 100644 _legacy/strategy/strategy_roll_monitor_lib.py create mode 100644 _legacy/strategy/strategy_snapshot_lib.py create mode 100644 _legacy/strategy/strategy_trend_lib.py create mode 100644 _legacy/symbols.py create mode 100644 _legacy/trade_log_lib.py create mode 100644 _legacy/trade_notify.py create mode 100644 _legacy/trading_context.py create mode 100644 _legacy/vnpy_bridge.py create mode 100644 _legacy/wechat_notify.py create mode 100644 config/.env.example create mode 100644 docs/ARCHITECTURE.md create mode 100644 modules/__init__.py create mode 100644 modules/backup/__init__.py rename db_backup.py => modules/backup/db_backup.py (96%) create mode 100644 modules/backup/routes.py create mode 100644 modules/core/__init__.py create mode 100644 modules/core/bootstrap.py rename contract_profile.py => modules/core/contract_profile.py (95%) rename contract_specs.py => modules/core/contract_specs.py (96%) rename db_conn.py => modules/core/db_conn.py (99%) create mode 100644 modules/core/deps.py rename doc_render.py => modules/core/doc_render.py (100%) rename env_file.py => modules/core/env_file.py (82%) rename locale_fix.py => modules/core/locale_fix.py (100%) create mode 100644 modules/core/paths.py rename symbols.py => modules/core/symbols.py (95%) rename trading_context.py => modules/core/trading_context.py (90%) create mode 100644 modules/ctp/__init__.py rename ctp_entry_price.py => modules/ctp/ctp_entry_price.py (88%) rename ctp_fee_sync.py => modules/ctp/ctp_fee_sync.py (92%) rename ctp_fee_worker.py => modules/ctp/ctp_fee_worker.py (94%) rename ctp_ipc_client.py => modules/ctp/ctp_ipc_client.py (100%) rename ctp_kline.py => modules/ctp/ctp_kline.py (92%) rename ctp_premarket_connect.py => modules/ctp/ctp_premarket_connect.py (95%) rename ctp_reconnect.py => modules/ctp/ctp_reconnect.py (86%) rename ctp_settings.py => modules/ctp/ctp_settings.py (95%) rename ctp_symbol.py => modules/ctp/ctp_symbol.py (94%) rename ctp_trade_sync.py => modules/ctp/ctp_trade_sync.py (92%) rename ctp_trading_state.py => modules/ctp/ctp_trading_state.py (96%) rename ctp_worker.py => modules/ctp/ctp_worker.py (90%) rename vnpy_bridge.py => modules/ctp/vnpy_bridge.py (95%) create mode 100644 modules/fees/__init__.py rename fee_specs.py => modules/fees/fee_specs.py (95%) rename fee_sync.py => modules/fees/fee_sync.py (93%) create mode 100644 modules/fees/routes.py create mode 100644 modules/keys/__init__.py rename key_monitor_lib.py => modules/keys/key_monitor_lib.py (95%) create mode 100644 modules/keys/routes.py create mode 100644 modules/market/__init__.py rename kline_chart.py => modules/market/kline_chart.py (94%) rename kline_store.py => modules/market/kline_store.py (100%) rename kline_stream.py => modules/market/kline_stream.py (91%) rename market.py => modules/market/market.py (100%) rename market_sessions.py => modules/market/market_sessions.py (100%) create mode 100644 modules/market/routes.py create mode 100644 modules/notify/__init__.py rename ai_client.py => modules/notify/ai_client.py (100%) rename ai_messages.py => modules/notify/ai_messages.py (100%) rename ai_worker.py => modules/notify/ai_worker.py (91%) create mode 100644 modules/notify/routes.py rename wechat_notify.py => modules/notify/wechat_notify.py (100%) create mode 100644 modules/plans/__init__.py create mode 100644 modules/plans/routes.py create mode 100644 modules/records/__init__.py create mode 100644 modules/records/routes.py create mode 100644 modules/risk/__init__.py rename {risk => modules/risk}/account_risk_lib.py (95%) create mode 100644 modules/settings/__init__.py rename admin_settings.py => modules/settings/admin_settings.py (95%) rename nav_settings.py => modules/settings/nav_settings.py (100%) create mode 100644 modules/settings/routes.py create mode 100644 modules/stats/__init__.py rename dashboard_lib.py => modules/stats/dashboard_lib.py (93%) create mode 100644 modules/stats/routes.py rename stats_engine.py => modules/stats/stats_engine.py (96%) create mode 100644 modules/strategy/__init__.py rename {strategy => modules/strategy}/fib_lib.py (100%) rename {strategy => modules/strategy}/strategy_db.py (95%) rename {strategy => modules/strategy}/strategy_roll_lib.py (96%) rename {strategy => modules/strategy}/strategy_roll_monitor_lib.py (96%) rename {strategy => modules/strategy}/strategy_snapshot_lib.py (100%) rename {strategy => modules/strategy}/strategy_trend_lib.py (95%) create mode 100644 modules/trading/__init__.py create mode 100644 modules/trading/install.py rename order_pending.py => modules/trading/order_pending.py (94%) rename pending_order_worker.py => modules/trading/pending_order_worker.py (94%) rename position_sizing.py => modules/trading/position_sizing.py (96%) rename position_stream.py => modules/trading/position_stream.py (94%) rename product_recommend.py => modules/trading/product_recommend.py (94%) rename recommend_store.py => modules/trading/recommend_store.py (92%) rename recommend_stream.py => modules/trading/recommend_stream.py (94%) rename recommend_trend.py => modules/trading/recommend_trend.py (96%) rename sl_tp_guard.py => modules/trading/sl_tp_guard.py (94%) rename trade_log_lib.py => modules/trading/trade_log_lib.py (93%) rename trade_notify.py => modules/trading/trade_notify.py (93%) create mode 100644 modules/web/__init__.py create mode 100644 modules/web/routes.py rename {static => modules/web/static}/css/ai_messages.css (100%) rename {static => modules/web/static}/css/base.css (100%) rename {static => modules/web/static}/css/dashboard.css (100%) rename {static => modules/web/static}/css/doc.css (100%) rename {static => modules/web/static}/css/keys.css (100%) rename {static => modules/web/static}/css/mobile.css (100%) rename {static => modules/web/static}/css/records.css (100%) rename {static => modules/web/static}/css/responsive.css (100%) rename {static => modules/web/static}/css/tech.css (100%) rename {static => modules/web/static}/css/trade.css (100%) rename {static => modules/web/static}/icons/icon-192.png (100%) rename {static => modules/web/static}/icons/icon-512.png (100%) rename {static => modules/web/static}/icons/icon.svg (100%) rename {static => modules/web/static}/js/calendar.js (100%) rename {static => modules/web/static}/js/contract.js (100%) rename {static => modules/web/static}/js/dashboard.js (100%) rename {static => modules/web/static}/js/equity_curve.js (100%) rename {static => modules/web/static}/js/keys.js (100%) rename {static => modules/web/static}/js/lunar.js (100%) rename {static => modules/web/static}/js/market.js (100%) rename {static => modules/web/static}/js/nav.js (100%) rename {static => modules/web/static}/js/orientation.js (100%) rename {static => modules/web/static}/js/page.js (100%) rename {static => modules/web/static}/js/plans.js (100%) rename {static => modules/web/static}/js/positions.js (100%) rename {static => modules/web/static}/js/pwa.js (100%) rename {static => modules/web/static}/js/records.js (100%) rename {static => modules/web/static}/js/review.js (100%) rename {static => modules/web/static}/js/settings.js (100%) rename {static => modules/web/static}/js/stats.js (100%) rename {static => modules/web/static}/js/strategy.js (100%) rename {static => modules/web/static}/js/symbol.js (100%) rename {static => modules/web/static}/js/theme.js (100%) rename {static => modules/web/static}/js/trade.js (100%) rename {static => modules/web/static}/js/trades.js (100%) rename {static => modules/web/static}/manifest.json (100%) rename {static => modules/web/static}/sw.js (100%) rename {templates => modules/web/templates}/ai_messages.html (100%) rename {templates => modules/web/templates}/base.html (100%) rename {templates => modules/web/templates}/calendar.html (100%) rename {templates => modules/web/templates}/contract.html (100%) rename {templates => modules/web/templates}/dashboard.html (100%) rename {templates => modules/web/templates}/fees.html (100%) rename {templates => modules/web/templates}/keys.html (100%) rename {templates => modules/web/templates}/login.html (100%) rename {templates => modules/web/templates}/market.html (100%) rename {templates => modules/web/templates}/plans.html (100%) rename {templates => modules/web/templates}/positions.html (100%) rename {templates => modules/web/templates}/recommend.html (100%) rename {templates => modules/web/templates}/records.html (100%) rename {templates => modules/web/templates}/risk_guide.html (100%) rename {templates => modules/web/templates}/settings.html (100%) rename {templates => modules/web/templates}/stats.html (100%) rename {templates => modules/web/templates}/strategy.html (100%) rename {templates => modules/web/templates}/strategy_records.html (100%) rename {templates => modules/web/templates}/trade.html (100%) diff --git a/.env.example b/.env.example index cc498b4..86329b5 100644 --- a/.env.example +++ b/.env.example @@ -1,61 +1,2 @@ -# 服务配置 -HOST=0.0.0.0 -PORT=6600 -DEBUG=false - -SECRET_KEY=change-this-to-a-random-secret-key - -ADMIN_USERNAME=admin -ADMIN_PASSWORD=change-me-on-first-login -ADMIN_SYNC_FROM_ENV=false - -WECHAT_WEBHOOK= - -QUOTE_SOURCE=sina -THS_REFRESH_TOKEN= - -# 交易模式:simulation=SimNow,live=期货公司(系统设置页可改) -TRADING_MODE=simulation -POSITION_SIZING_MODE=risk -RISK_PERCENT=1 - -# CTP 断线后后台自动重连(true/false) -CTP_AUTO_RECONNECT=true - -# —— SimNow 模拟盘(也可在「系统设置 → CTP 连接」配置,优先于本文件)—— -SIMNOW_USER= -SIMNOW_PASSWORD= -SIMNOW_BROKER_ID=9999 -# 7×24 / 日盘前置(deploy.sh 会自动 nc 探测并写入可用线路) -SIMNOW_TD_ADDRESS=tcp://180.168.146.187:10201 -SIMNOW_MD_ADDRESS=tcp://180.168.146.187:10211 -SIMNOW_APP_ID=simnow_client_test -SIMNOW_AUTH_CODE=0000000000000000 -# SimNow 看穿式前置固定用「实盘」;仅穿透式测评才用「测试」 -SIMNOW_ENV=实盘 - -# —— 期货公司实盘(后期接入)—— -CTP_LIVE_USER= -CTP_LIVE_PASSWORD= -CTP_LIVE_BROKER_ID= -CTP_LIVE_TD_ADDRESS= -CTP_LIVE_MD_ADDRESS= -CTP_LIVE_APP_ID= -CTP_LIVE_AUTH_CODE= -CTP_LIVE_PRODUCT_INFO= - -# 账户冷静期 -RISK_CONTROL_ENABLED=true -RISK_COOLING_HOURS_MANUAL=4 -RISK_COOLING_HOURS_MANUAL_JOURNAL=1 -RISK_MANUAL_CLOSE_DAILY_LIMIT=2 -MAX_ACTIVE_POSITIONS=1 -RISK_DAILY_POSITION_LIMIT=5 -RISK_DAILY_TRADING_RISK_PCT=2 -TRADING_DAY_RESET_HOUR=8 - -# —— 数据库(生产推荐 PostgreSQL,见 docs/POSTGRES.md)—— -# 未配置 DATABASE_URL 时使用本地 SQLite futures.db -# DATABASE_URL=postgresql://qihuo:your_password@127.0.0.1:5432/qihuo -# PG_POOL_MIN=2 -# PG_POOL_MAX=20 +# 环境变量模板已迁移至 config/.env.example +# 使用: cp config/.env.example config/.env diff --git a/.gitignore b/.gitignore index 1ec9647..68f2a24 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .env +config/.env *.db __pycache__/ *.py[cod] diff --git a/_legacy/admin_settings.py b/_legacy/admin_settings.py new file mode 100644 index 0000000..52ec248 --- /dev/null +++ b/_legacy/admin_settings.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.settings.admin_settings +from modules.settings.admin_settings import * # noqa: F401,F403 diff --git a/_legacy/ai_client.py b/_legacy/ai_client.py new file mode 100644 index 0000000..c36e071 --- /dev/null +++ b/_legacy/ai_client.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.notify.ai_client +from modules.notify.ai_client import * # noqa: F401,F403 diff --git a/_legacy/ai_messages.py b/_legacy/ai_messages.py new file mode 100644 index 0000000..9f61dfd --- /dev/null +++ b/_legacy/ai_messages.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.notify.ai_messages +from modules.notify.ai_messages import * # noqa: F401,F403 diff --git a/_legacy/ai_worker.py b/_legacy/ai_worker.py new file mode 100644 index 0000000..e69a4d0 --- /dev/null +++ b/_legacy/ai_worker.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.notify.ai_worker +from modules.notify.ai_worker import * # noqa: F401,F403 diff --git a/_legacy/contract_profile.py b/_legacy/contract_profile.py new file mode 100644 index 0000000..aab689f --- /dev/null +++ b/_legacy/contract_profile.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.core.contract_profile +from modules.core.contract_profile import * # noqa: F401,F403 diff --git a/_legacy/contract_specs.py b/_legacy/contract_specs.py new file mode 100644 index 0000000..99b6f9f --- /dev/null +++ b/_legacy/contract_specs.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.core.contract_specs +from modules.core.contract_specs import * # noqa: F401,F403 diff --git a/_legacy/ctp_entry_price.py b/_legacy/ctp_entry_price.py new file mode 100644 index 0000000..9959397 --- /dev/null +++ b/_legacy/ctp_entry_price.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.ctp.ctp_entry_price +from modules.ctp.ctp_entry_price import * # noqa: F401,F403 diff --git a/_legacy/ctp_fee_sync.py b/_legacy/ctp_fee_sync.py new file mode 100644 index 0000000..2d0f774 --- /dev/null +++ b/_legacy/ctp_fee_sync.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.ctp.ctp_fee_sync +from modules.ctp.ctp_fee_sync import * # noqa: F401,F403 diff --git a/_legacy/ctp_fee_worker.py b/_legacy/ctp_fee_worker.py new file mode 100644 index 0000000..c057e71 --- /dev/null +++ b/_legacy/ctp_fee_worker.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.ctp.ctp_fee_worker +from modules.ctp.ctp_fee_worker import * # noqa: F401,F403 diff --git a/_legacy/ctp_ipc_client.py b/_legacy/ctp_ipc_client.py new file mode 100644 index 0000000..238a6fa --- /dev/null +++ b/_legacy/ctp_ipc_client.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.ctp.ctp_ipc_client +from modules.ctp.ctp_ipc_client import * # noqa: F401,F403 diff --git a/_legacy/ctp_kline.py b/_legacy/ctp_kline.py new file mode 100644 index 0000000..5d71e14 --- /dev/null +++ b/_legacy/ctp_kline.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.ctp.ctp_kline +from modules.ctp.ctp_kline import * # noqa: F401,F403 diff --git a/_legacy/ctp_premarket_connect.py b/_legacy/ctp_premarket_connect.py new file mode 100644 index 0000000..9ed7f0d --- /dev/null +++ b/_legacy/ctp_premarket_connect.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.ctp.ctp_premarket_connect +from modules.ctp.ctp_premarket_connect import * # noqa: F401,F403 diff --git a/_legacy/ctp_reconnect.py b/_legacy/ctp_reconnect.py new file mode 100644 index 0000000..a8ce7f3 --- /dev/null +++ b/_legacy/ctp_reconnect.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.ctp.ctp_reconnect +from modules.ctp.ctp_reconnect import * # noqa: F401,F403 diff --git a/_legacy/ctp_settings.py b/_legacy/ctp_settings.py new file mode 100644 index 0000000..6fbc300 --- /dev/null +++ b/_legacy/ctp_settings.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.ctp.ctp_settings +from modules.ctp.ctp_settings import * # noqa: F401,F403 diff --git a/_legacy/ctp_symbol.py b/_legacy/ctp_symbol.py new file mode 100644 index 0000000..47b998a --- /dev/null +++ b/_legacy/ctp_symbol.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.ctp.ctp_symbol +from modules.ctp.ctp_symbol import * # noqa: F401,F403 diff --git a/_legacy/ctp_trade_sync.py b/_legacy/ctp_trade_sync.py new file mode 100644 index 0000000..f5d64e0 --- /dev/null +++ b/_legacy/ctp_trade_sync.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.ctp.ctp_trade_sync +from modules.ctp.ctp_trade_sync import * # noqa: F401,F403 diff --git a/_legacy/ctp_trading_state.py b/_legacy/ctp_trading_state.py new file mode 100644 index 0000000..cf7621f --- /dev/null +++ b/_legacy/ctp_trading_state.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.ctp.ctp_trading_state +from modules.ctp.ctp_trading_state import * # noqa: F401,F403 diff --git a/_legacy/ctp_worker.py b/_legacy/ctp_worker.py new file mode 100644 index 0000000..426d71c --- /dev/null +++ b/_legacy/ctp_worker.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.ctp.ctp_worker +from modules.ctp.ctp_worker import * # noqa: F401,F403 diff --git a/_legacy/dashboard_lib.py b/_legacy/dashboard_lib.py new file mode 100644 index 0000000..e5e48df --- /dev/null +++ b/_legacy/dashboard_lib.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.stats.dashboard_lib +from modules.stats.dashboard_lib import * # noqa: F401,F403 diff --git a/_legacy/db_backup.py b/_legacy/db_backup.py new file mode 100644 index 0000000..c8a7aba --- /dev/null +++ b/_legacy/db_backup.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.backup.db_backup +from modules.backup.db_backup import * # noqa: F401,F403 diff --git a/_legacy/db_conn.py b/_legacy/db_conn.py new file mode 100644 index 0000000..4cc5607 --- /dev/null +++ b/_legacy/db_conn.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.core.db_conn +from modules.core.db_conn import * # noqa: F401,F403 diff --git a/_legacy/doc_render.py b/_legacy/doc_render.py new file mode 100644 index 0000000..b174e2f --- /dev/null +++ b/_legacy/doc_render.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.core.doc_render +from modules.core.doc_render import * # noqa: F401,F403 diff --git a/_legacy/env_file.py b/_legacy/env_file.py new file mode 100644 index 0000000..cb9015e --- /dev/null +++ b/_legacy/env_file.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.core.env_file +from modules.core.env_file import * # noqa: F401,F403 diff --git a/_legacy/fee_specs.py b/_legacy/fee_specs.py new file mode 100644 index 0000000..f379087 --- /dev/null +++ b/_legacy/fee_specs.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.fees.fee_specs +from modules.fees.fee_specs import * # noqa: F401,F403 diff --git a/_legacy/fee_sync.py b/_legacy/fee_sync.py new file mode 100644 index 0000000..3a385f5 --- /dev/null +++ b/_legacy/fee_sync.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.fees.fee_sync +from modules.fees.fee_sync import * # noqa: F401,F403 diff --git a/_legacy/key_monitor_lib.py b/_legacy/key_monitor_lib.py new file mode 100644 index 0000000..5629511 --- /dev/null +++ b/_legacy/key_monitor_lib.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.keys.key_monitor_lib +from modules.keys.key_monitor_lib import * # noqa: F401,F403 diff --git a/_legacy/kline_chart.py b/_legacy/kline_chart.py new file mode 100644 index 0000000..4c04d78 --- /dev/null +++ b/_legacy/kline_chart.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.market.kline_chart +from modules.market.kline_chart import * # noqa: F401,F403 diff --git a/_legacy/kline_store.py b/_legacy/kline_store.py new file mode 100644 index 0000000..42aaf52 --- /dev/null +++ b/_legacy/kline_store.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.market.kline_store +from modules.market.kline_store import * # noqa: F401,F403 diff --git a/_legacy/kline_stream.py b/_legacy/kline_stream.py new file mode 100644 index 0000000..60e46e0 --- /dev/null +++ b/_legacy/kline_stream.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.market.kline_stream +from modules.market.kline_stream import * # noqa: F401,F403 diff --git a/_legacy/locale_fix.py b/_legacy/locale_fix.py new file mode 100644 index 0000000..ff7f775 --- /dev/null +++ b/_legacy/locale_fix.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.core.locale_fix +from modules.core.locale_fix import * # noqa: F401,F403 diff --git a/_legacy/market.py b/_legacy/market.py new file mode 100644 index 0000000..f49f557 --- /dev/null +++ b/_legacy/market.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.market.market +from modules.market.market import * # noqa: F401,F403 diff --git a/_legacy/market_sessions.py b/_legacy/market_sessions.py new file mode 100644 index 0000000..3bb2848 --- /dev/null +++ b/_legacy/market_sessions.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.market.market_sessions +from modules.market.market_sessions import * # noqa: F401,F403 diff --git a/_legacy/nav_settings.py b/_legacy/nav_settings.py new file mode 100644 index 0000000..a9998bc --- /dev/null +++ b/_legacy/nav_settings.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.settings.nav_settings +from modules.settings.nav_settings import * # noqa: F401,F403 diff --git a/_legacy/order_pending.py b/_legacy/order_pending.py new file mode 100644 index 0000000..6f25f52 --- /dev/null +++ b/_legacy/order_pending.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.trading.order_pending +from modules.trading.order_pending import * # noqa: F401,F403 diff --git a/_legacy/pending_order_worker.py b/_legacy/pending_order_worker.py new file mode 100644 index 0000000..ed6c34d --- /dev/null +++ b/_legacy/pending_order_worker.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.trading.pending_order_worker +from modules.trading.pending_order_worker import * # noqa: F401,F403 diff --git a/_legacy/position_sizing.py b/_legacy/position_sizing.py new file mode 100644 index 0000000..dfc3b58 --- /dev/null +++ b/_legacy/position_sizing.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.trading.position_sizing +from modules.trading.position_sizing import * # noqa: F401,F403 diff --git a/_legacy/position_stream.py b/_legacy/position_stream.py new file mode 100644 index 0000000..405be7d --- /dev/null +++ b/_legacy/position_stream.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.trading.position_stream +from modules.trading.position_stream import * # noqa: F401,F403 diff --git a/_legacy/product_recommend.py b/_legacy/product_recommend.py new file mode 100644 index 0000000..f5499af --- /dev/null +++ b/_legacy/product_recommend.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.trading.product_recommend +from modules.trading.product_recommend import * # noqa: F401,F403 diff --git a/_legacy/recommend_store.py b/_legacy/recommend_store.py new file mode 100644 index 0000000..c18df73 --- /dev/null +++ b/_legacy/recommend_store.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.trading.recommend_store +from modules.trading.recommend_store import * # noqa: F401,F403 diff --git a/_legacy/recommend_stream.py b/_legacy/recommend_stream.py new file mode 100644 index 0000000..e3eec3a --- /dev/null +++ b/_legacy/recommend_stream.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.trading.recommend_stream +from modules.trading.recommend_stream import * # noqa: F401,F403 diff --git a/_legacy/recommend_trend.py b/_legacy/recommend_trend.py new file mode 100644 index 0000000..662d960 --- /dev/null +++ b/_legacy/recommend_trend.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.trading.recommend_trend +from modules.trading.recommend_trend import * # noqa: F401,F403 diff --git a/risk/__init__.py b/_legacy/risk/__init__.py similarity index 100% rename from risk/__init__.py rename to _legacy/risk/__init__.py diff --git a/_legacy/risk/account_risk_lib.py b/_legacy/risk/account_risk_lib.py new file mode 100644 index 0000000..6533861 --- /dev/null +++ b/_legacy/risk/account_risk_lib.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.risk.account_risk_lib +from modules.risk.account_risk_lib import * # noqa: F401,F403 diff --git a/_legacy/sl_tp_guard.py b/_legacy/sl_tp_guard.py new file mode 100644 index 0000000..ef9ad0a --- /dev/null +++ b/_legacy/sl_tp_guard.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.trading.sl_tp_guard +from modules.trading.sl_tp_guard import * # noqa: F401,F403 diff --git a/_legacy/stats_engine.py b/_legacy/stats_engine.py new file mode 100644 index 0000000..25fd319 --- /dev/null +++ b/_legacy/stats_engine.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.stats.stats_engine +from modules.stats.stats_engine import * # noqa: F401,F403 diff --git a/strategy/__init__.py b/_legacy/strategy/__init__.py similarity index 100% rename from strategy/__init__.py rename to _legacy/strategy/__init__.py diff --git a/_legacy/strategy/fib_lib.py b/_legacy/strategy/fib_lib.py new file mode 100644 index 0000000..1bf3f70 --- /dev/null +++ b/_legacy/strategy/fib_lib.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.strategy.fib_lib +from modules.strategy.fib_lib import * # noqa: F401,F403 diff --git a/_legacy/strategy/strategy_db.py b/_legacy/strategy/strategy_db.py new file mode 100644 index 0000000..ca645cc --- /dev/null +++ b/_legacy/strategy/strategy_db.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.strategy.strategy_db +from modules.strategy.strategy_db import * # noqa: F401,F403 diff --git a/_legacy/strategy/strategy_roll_lib.py b/_legacy/strategy/strategy_roll_lib.py new file mode 100644 index 0000000..121dfc5 --- /dev/null +++ b/_legacy/strategy/strategy_roll_lib.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.strategy.strategy_roll_lib +from modules.strategy.strategy_roll_lib import * # noqa: F401,F403 diff --git a/_legacy/strategy/strategy_roll_monitor_lib.py b/_legacy/strategy/strategy_roll_monitor_lib.py new file mode 100644 index 0000000..9a19d4d --- /dev/null +++ b/_legacy/strategy/strategy_roll_monitor_lib.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.strategy.strategy_roll_monitor_lib +from modules.strategy.strategy_roll_monitor_lib import * # noqa: F401,F403 diff --git a/_legacy/strategy/strategy_snapshot_lib.py b/_legacy/strategy/strategy_snapshot_lib.py new file mode 100644 index 0000000..f45e60d --- /dev/null +++ b/_legacy/strategy/strategy_snapshot_lib.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.strategy.strategy_snapshot_lib +from modules.strategy.strategy_snapshot_lib import * # noqa: F401,F403 diff --git a/_legacy/strategy/strategy_trend_lib.py b/_legacy/strategy/strategy_trend_lib.py new file mode 100644 index 0000000..55647d8 --- /dev/null +++ b/_legacy/strategy/strategy_trend_lib.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.strategy.strategy_trend_lib +from modules.strategy.strategy_trend_lib import * # noqa: F401,F403 diff --git a/_legacy/symbols.py b/_legacy/symbols.py new file mode 100644 index 0000000..19b3664 --- /dev/null +++ b/_legacy/symbols.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.core.symbols +from modules.core.symbols import * # noqa: F401,F403 diff --git a/_legacy/trade_log_lib.py b/_legacy/trade_log_lib.py new file mode 100644 index 0000000..44dc075 --- /dev/null +++ b/_legacy/trade_log_lib.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.trading.trade_log_lib +from modules.trading.trade_log_lib import * # noqa: F401,F403 diff --git a/_legacy/trade_notify.py b/_legacy/trade_notify.py new file mode 100644 index 0000000..a9339d7 --- /dev/null +++ b/_legacy/trade_notify.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.trading.trade_notify +from modules.trading.trade_notify import * # noqa: F401,F403 diff --git a/_legacy/trading_context.py b/_legacy/trading_context.py new file mode 100644 index 0000000..9c9230b --- /dev/null +++ b/_legacy/trading_context.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.core.trading_context +from modules.core.trading_context import * # noqa: F401,F403 diff --git a/_legacy/vnpy_bridge.py b/_legacy/vnpy_bridge.py new file mode 100644 index 0000000..673a145 --- /dev/null +++ b/_legacy/vnpy_bridge.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.ctp.vnpy_bridge +from modules.ctp.vnpy_bridge import * # noqa: F401,F403 diff --git a/_legacy/wechat_notify.py b/_legacy/wechat_notify.py new file mode 100644 index 0000000..5b831b4 --- /dev/null +++ b/_legacy/wechat_notify.py @@ -0,0 +1,2 @@ +# Compatibility shim — use modules.notify.wechat_notify +from modules.notify.wechat_notify import * # noqa: F401,F403 diff --git a/app.py b/app.py index 6177ba2..2e22f7a 100644 --- a/app.py +++ b/app.py @@ -1,2259 +1,868 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -import os - -from locale_fix import ensure_process_locale - -ensure_process_locale() - -import time -import threading -import requests -from datetime import date, datetime, timedelta -from typing import Optional -from functools import wraps -from zoneinfo import ZoneInfo - -from werkzeug.utils import secure_filename - -from dotenv import load_dotenv -from flask import ( - Flask, render_template, request, redirect, url_for, - flash, session, jsonify, Response, stream_with_context, -) -from werkzeug.security import check_password_hash, generate_password_hash - -from functools import wraps - -from symbols import ( - search_symbols, - ths_to_codes, - list_main_contracts_grouped, - list_recommended_symbols_grouped, - refresh_main_index, -) -from contract_specs import calc_position_metrics -from fee_specs import ( - calc_fee_breakdown, - calc_round_trip_fee, - list_fee_rates_for_ui, - count_fee_rates_by_source, - purge_non_ctp_fee_rates, -) -from nav_settings import NAV_TOGGLES, get_nav_items, nav_enabled, save_nav_items -from stats_engine import ( - STATS_VIEWS, - build_all_stats, - get_calendar_day, - get_calendar_month, - load_stats_cache, - refresh_stats_cache, -) -from kline_store import ensure_kline_tables -from kline_stream import kline_hub, sse_format -from kline_chart import generate_review_kline_chart, fetch_market_klines, MARKET_PERIODS -from market import get_price as market_get_price, set_ths_refresh_token, get_quote_source_label -from db_conn import OperationalError, connect_db, database_label, is_benign_migration_error, is_db_contention_error, is_schema_migration_error, rollback_if_postgres -from admin_settings import save_admin_credentials -from db_backup import ( - backup_dir, - backup_in_progress, - default_restore_dir, - get_backup_last_at, - list_backups, - resolve_backup_file, - schedule_backup, - start_backup_worker, -) -from strategy.strategy_db import init_strategy_tables -from install_trading import install_trading -from vnpy_bridge import try_init_vnpy - -load_dotenv(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".env")) - -app = Flask(__name__) -app.secret_key = os.getenv("SECRET_KEY", "futures_monitor_default_secret") - -HOST = os.getenv("HOST", "0.0.0.0") -PORT = int(os.getenv("PORT", "6600")) -DEBUG = os.getenv("DEBUG", "false").lower() in ("1", "true", "yes") - -DB_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "futures.db") -UPLOAD_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "uploads") -TZ = ZoneInfo("Asia/Shanghai") - -OPEN_TYPES = ["突破开仓", "回调开仓", "追涨杀跌", "计划内开仓", "震荡摸顶底", "其他"] -EXIT_TRIGGERS = ["止盈", "止损", "手工平仓", "移动止损", "时间离场", "其他"] -BEHAVIOR_TAGS = ["怕踏空", "报复开仓", "盈利飘了", "拿不住单", "扛单", "重仓违规"] -KLINE_PERIODS = ["1m", "3m", "5m", "15m", "30m", "1h", "4h", "1d"] -KLINE_CUTOFFS = ["平仓时间", "开仓时间", "当前时间"] - - -def today_str() -> str: - return datetime.now(TZ).date().isoformat() - - -def calc_holding_duration(open_time: str, close_time: str) -> str: - try: - o = datetime.fromisoformat(open_time.strip().replace(" ", "T")[:19]) - c = datetime.fromisoformat(close_time.strip().replace(" ", "T")[:19]) - delta = c - o - if delta.total_seconds() < 0: - return "" - secs = int(delta.total_seconds()) - h, rem = divmod(secs, 3600) - m, _ = divmod(rem, 60) - if h: - return f"{h}小时{m}分钟" - return f"{m}分钟" - except Exception: - return "" - - -def holding_to_minutes(open_time: str, close_time: str) -> int: - try: - o = datetime.fromisoformat(open_time.strip().replace(" ", "T")) - c = datetime.fromisoformat(close_time.strip().replace(" ", "T")) - secs = int((c - o).total_seconds()) - return max(0, secs // 60) - except Exception: - return 0 - - -def classify_close_result(direction: str, close: float, sl: float, tp: float) -> str: - """根据平仓价与止损/止盈距离判断结果。""" - if close is None: - return "手动平仓" - tol = max(abs(close) * 0.002, 1.0) - if abs(close - tp) <= tol: - return "止盈" - if abs(close - sl) <= tol: - return "止损" - return "手动平仓" - - -def calc_rr_ratio(direction: str, entry: float, stop: float, target: float) -> Optional[float]: - """盈亏比 = 盈利空间 / 风险空间。""" - if entry is None or stop is None or target is None: - return None - if direction == "long": - risk = entry - stop - if risk <= 0: - return None - return round((target - entry) / risk, 2) - if direction == "short": - risk = stop - entry - if risk <= 0: - return None - return round((entry - target) / risk, 2) - return None - - -def calc_theoretical_pnl(direction: str, entry: float, target: float, lots: float) -> Optional[float]: - if entry is None or target is None or lots is None: - return None - if direction == "long": - return round((target - entry) * lots, 2) - if direction == "short": - return round((entry - target) * lots, 2) - return None - - -def parse_review_date_filter(preset: str, start: str, end: str) -> tuple[str, str]: - today = datetime.now(TZ).date() - if preset == "today": - s = today.isoformat() - return s, s - if preset == "week": - monday = today - timedelta(days=today.weekday()) - return monday.isoformat(), today.isoformat() - if preset == "month": - return today.replace(day=1).isoformat(), today.isoformat() - return start.strip(), end.strip() - - -def expire_old_plans(): - """当日结束后计划自动失效,保留历史。""" - today = today_str() - conn = get_db() - conn.execute( - "UPDATE order_plans SET status='expired' WHERE plan_date < ? AND status IN ('planned', 'active')", - (today,), - ) - conn.execute( - "UPDATE order_plans SET plan_date=date(created_at) WHERE plan_date IS NULL OR plan_date=''" - ) - conn.commit() - conn.close() - - -def get_db(): - return connect_db() - - -def get_setting(key: str, default: str = "") -> str: - conn = get_db() - row = conn.execute("SELECT value FROM settings WHERE key=?", (key,)).fetchone() - conn.close() - return row["value"] if row else default - - -def set_setting(key: str, value: str): - conn = get_db() - conn.execute( - "INSERT INTO settings (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value=?", - (key, value, value), - ) - conn.commit() - conn.close() - - -def require_nav(key: str): - """导航项关闭时拒绝访问对应页面。""" - def decorator(f): - @wraps(f) - def wrapped(*args, **kwargs): - if not nav_enabled(get_setting, key): - flash("该页面已在系统设置中关闭") - return redirect(url_for("positions")) - return f(*args, **kwargs) - return wrapped - return decorator - - -def _static_asset_v() -> str: - base = os.path.dirname(os.path.abspath(__file__)) - rels = ( - "static/js/trade.js", - "static/js/dashboard.js", - "static/js/orientation.js", - "static/css/records.css", - "static/js/records.js", - "static/js/settings.js", - "static/css/mobile.css", - "static/css/responsive.css", - "static/css/trade.css", - "static/css/dashboard.css", - "static/css/doc.css", - "static/css/base.css", - ) - mtimes = [] - for rel in rels: - path = os.path.join(base, rel.replace("/", os.sep)) - if os.path.isfile(path): - mtimes.append(os.path.getmtime(path)) - return str(int(max(mtimes))) if mtimes else "0" - - -def _ua_is_phone(ua: str) -> bool: - ua_l = (ua or "").lower() - if "ipad" in ua_l: - return False - if "android" in ua_l and "mobile" not in ua_l: - return False - if any(x in ua_l for x in ("iphone", "ipod", "windows phone", "iemobile")): - return True - if "android" in ua_l and "mobile" in ua_l: - return True - if "mobile" in ua_l or "harmonyos" in ua_l or "openharmony" in ua_l: - return True - return False - - -@app.context_processor -def inject_globals(): - return {"nav_items": get_nav_items(get_setting), "asset_v": _static_asset_v()} - - -def _trading_mode() -> str: - return (get_setting("trading_mode", "simulation") or "simulation").strip() - - -def touch_stats_cache(): - try: - conn = get_db() - capital = float(get_setting("live_capital", "0") or 0) - refresh_stats_cache(conn, capital) - conn.close() - except Exception as exc: - app.logger.warning("stats cache refresh failed: %s", exc) - - -def get_stats_data() -> dict: - conn = get_db() - try: - capital = float(get_setting("live_capital", "0") or 0) - data = load_stats_cache(conn) - if data: - return data - try: - return refresh_stats_cache(conn, capital) - except OperationalError as exc: - if not is_db_contention_error(exc): - raise - app.logger.warning("stats cache refresh contention, compute without save: %s", exc) - return build_all_stats(conn, capital) - finally: - conn.close() - - -def init_db(): - import strategy.strategy_db as strategy_db - import risk.account_risk_lib as account_risk_lib - - strategy_db._TABLES_READY = False - account_risk_lib._SCHEMA_READY = False - - conn = get_db() - c = conn.cursor() - c.execute("CREATE TABLE IF NOT EXISTS settings (key TEXT PRIMARY KEY, value TEXT)") - c.execute('''CREATE TABLE IF NOT EXISTS order_plans - (id INTEGER PRIMARY KEY AUTOINCREMENT, - symbol TEXT, symbol_name TEXT, direction TEXT, - zone_upper REAL, zone_lower REAL, - stop_loss REAL, take_profit REAL, - status TEXT DEFAULT 'planned', - triggered_at TIMESTAMP, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)''') - c.execute('''CREATE TABLE IF NOT EXISTS key_monitors - (id INTEGER PRIMARY KEY AUTOINCREMENT, - symbol TEXT, symbol_name TEXT, monitor_type TEXT, direction TEXT, - upper REAL, lower REAL, - upper_triggered INTEGER DEFAULT 0, - lower_triggered INTEGER DEFAULT 0, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)''') - c.execute('''CREATE TABLE IF NOT EXISTS trade_records - (id INTEGER PRIMARY KEY AUTOINCREMENT, - symbol TEXT, symbol_name TEXT, monitor_type TEXT, direction TEXT, - trigger_price REAL, stop_loss REAL, take_profit REAL, - result TEXT, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)''') - conn.commit() - migrations = [ - "ALTER TABLE key_monitors ADD COLUMN symbol_name TEXT", - "ALTER TABLE key_monitors ADD COLUMN upper_triggered INTEGER DEFAULT 0", - "ALTER TABLE key_monitors ADD COLUMN lower_triggered INTEGER DEFAULT 0", - "ALTER TABLE trade_records ADD COLUMN symbol_name TEXT", - "ALTER TABLE order_plans ADD COLUMN sina_code TEXT", - "ALTER TABLE order_plans ADD COLUMN market_code TEXT", - "ALTER TABLE key_monitors ADD COLUMN market_code TEXT", - "ALTER TABLE key_monitors ADD COLUMN sina_code TEXT", - "ALTER TABLE trade_records ADD COLUMN market_code TEXT", - "ALTER TABLE order_plans ADD COLUMN plan_date TEXT", - "ALTER TABLE order_plans ADD COLUMN decision_reason TEXT", - "ALTER TABLE key_monitors ADD COLUMN status TEXT DEFAULT 'active'", - "ALTER TABLE key_monitors ADD COLUMN archived_at TEXT", - "ALTER TABLE key_monitors ADD COLUMN trade_mode TEXT DEFAULT '顺势'", - "ALTER TABLE key_monitors ADD COLUMN risk_reward REAL DEFAULT 2", - "ALTER TABLE key_monitors ADD COLUMN trailing_be INTEGER DEFAULT 0", - "ALTER TABLE key_monitors ADD COLUMN last_trigger_bar TEXT", - "ALTER TABLE key_monitors ADD COLUMN alert_push_count INTEGER DEFAULT 0", - "ALTER TABLE key_monitors ADD COLUMN alert_last_push_at TEXT", - "ALTER TABLE key_monitors ADD COLUMN alert_break_side TEXT", - "ALTER TABLE key_monitors ADD COLUMN breakout_bar_time TEXT", - "ALTER TABLE key_monitors ADD COLUMN alert_close_price REAL", - "ALTER TABLE key_monitors ADD COLUMN bar_period TEXT DEFAULT '5m'", - "ALTER TABLE review_records ADD COLUMN direction TEXT", - "ALTER TABLE review_records ADD COLUMN entry_price REAL", - "ALTER TABLE review_records ADD COLUMN stop_loss REAL", - "ALTER TABLE review_records ADD COLUMN take_profit REAL", - "ALTER TABLE review_records ADD COLUMN close_price REAL", - "ALTER TABLE review_records ADD COLUMN lots REAL", - "ALTER TABLE review_records ADD COLUMN holding_duration TEXT", - "ALTER TABLE review_records ADD COLUMN initial_pnl REAL", - "ALTER TABLE review_records ADD COLUMN actual_pnl REAL", - "ALTER TABLE review_records ADD COLUMN is_emotion INTEGER DEFAULT 0", - "ALTER TABLE review_records ADD COLUMN symbol_name TEXT", - "ALTER TABLE review_records ADD COLUMN market_code TEXT", - "ALTER TABLE review_records ADD COLUMN sina_code TEXT", - "ALTER TABLE trade_logs ADD COLUMN fee REAL", - "ALTER TABLE trade_logs ADD COLUMN pnl_net REAL", - "ALTER TABLE trade_logs ADD COLUMN margin_pct REAL", - "ALTER TABLE trade_logs ADD COLUMN equity_after REAL", - "ALTER TABLE review_records ADD COLUMN fee REAL", - "ALTER TABLE review_records ADD COLUMN pnl_net REAL", - ] - for sql in migrations: - try: - c.execute(sql) - conn.commit() - except Exception as exc: - if not is_schema_migration_error(exc): - raise - rollback_if_postgres(conn) - c.execute('''CREATE TABLE IF NOT EXISTS review_records - (id INTEGER PRIMARY KEY AUTOINCREMENT, - open_time TEXT, close_time TEXT, - symbol TEXT, timeframe TEXT, - pnl REAL, - open_type TEXT, expected_rr REAL, actual_rr REAL, - exit_trigger TEXT, exit_supplement TEXT, - watch_after_breakeven TEXT, new_position_while_occupied TEXT, - screenshot TEXT, - auto_kline INTEGER DEFAULT 0, - kline_period1 TEXT, kline_period2 TEXT, - kline_count INTEGER, kline_cutoff TEXT, - behavior_tags TEXT, notes TEXT, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)''') - c.execute('''CREATE TABLE IF NOT EXISTS position_monitors - (id INTEGER PRIMARY KEY AUTOINCREMENT, - symbol TEXT, symbol_name TEXT, market_code TEXT, sina_code TEXT, - direction TEXT, lots REAL, entry_price REAL, - stop_loss REAL, take_profit REAL, open_time TEXT, - status TEXT DEFAULT 'active', - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)''') - c.execute('''CREATE TABLE IF NOT EXISTS trade_logs - (id INTEGER PRIMARY KEY AUTOINCREMENT, - symbol TEXT, symbol_name TEXT, market_code TEXT, sina_code TEXT, - monitor_type TEXT, direction TEXT, - entry_price REAL, stop_loss REAL, take_profit REAL, close_price REAL, - lots REAL, margin REAL, holding_minutes INTEGER, - open_time TEXT, close_time TEXT, - pnl REAL, result TEXT, - verified INTEGER DEFAULT 0, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)''') - c.execute('''CREATE TABLE IF NOT EXISTS fee_rates - (product TEXT PRIMARY KEY, - exchange TEXT, - mult INTEGER, - open_fixed REAL DEFAULT 0, - open_ratio REAL DEFAULT 0, - close_yesterday_fixed REAL DEFAULT 0, - close_yesterday_ratio REAL DEFAULT 0, - close_today_fixed REAL DEFAULT 0, - close_today_ratio REAL DEFAULT 0, - updated_at TEXT)''') - c.execute('''CREATE TABLE IF NOT EXISTS stats_cache - (key TEXT PRIMARY KEY, - data_json TEXT NOT NULL, - updated_at TEXT NOT NULL)''') - conn.commit() - for sql in ( - "ALTER TABLE fee_rates ADD COLUMN source TEXT DEFAULT 'local'", - ): - try: - c.execute(sql) - conn.commit() - except Exception as exc: - if not is_schema_migration_error(exc): - raise - rollback_if_postgres(conn) - ensure_kline_tables(conn) - init_strategy_tables(conn) - from risk.account_risk_lib import ensure_account_risk_schema - from recommend_store import ensure_recommend_tables - - ensure_account_risk_schema(conn) - ensure_recommend_tables(conn) - from ai_messages import ensure_ai_messages_table - - ensure_ai_messages_table(conn) - conn.commit() - conn.close() - - sync_admin_from_env() - - if not get_setting("wechat_webhook") and os.getenv("WECHAT_WEBHOOK"): - set_setting("wechat_webhook", os.getenv("WECHAT_WEBHOOK")) - - if not get_setting("ths_refresh_token") and os.getenv("THS_REFRESH_TOKEN"): - set_setting("ths_refresh_token", os.getenv("THS_REFRESH_TOKEN")) - - from ctp_settings import seed_ctp_settings_from_env - seed_ctp_settings_from_env(set_setting) - - os.makedirs(UPLOAD_DIR, exist_ok=True) - expire_old_plans() - - if not get_setting("fee_multiplier"): - set_setting("fee_multiplier", "2") - if not get_setting("trading_mode"): - set_setting("trading_mode", "simulation") - if not get_setting("position_sizing_mode"): - set_setting("position_sizing_mode", "fixed") - if not get_setting("fixed_lots"): - set_setting("fixed_lots", "1") - if not get_setting("fixed_amount"): - set_setting("fixed_amount", "5000") - if not get_setting("risk_percent"): - set_setting("risk_percent", "1") - if not get_setting("max_margin_pct"): - set_setting("max_margin_pct", "30") - if not get_setting("roll_max_margin_pct"): - set_setting("roll_max_margin_pct", "50") - if not get_setting("trailing_be_tick_buffer"): - set_setting("trailing_be_tick_buffer", "2") - if not get_setting("pending_order_timeout_min"): - set_setting("pending_order_timeout_min", "5") - if not get_setting("ai_enabled"): - set_setting("ai_enabled", "0") - if not get_setting("ai_provider"): - set_setting("ai_provider", "ollama") - if not get_setting("ai_ollama_base_url"): - set_setting("ai_ollama_base_url", "http://127.0.0.1:11434") - if not get_setting("ai_ollama_model"): - set_setting("ai_ollama_model", "qwen2.5:7b") - if not get_setting("ai_openai_base_url"): - set_setting("ai_openai_base_url", "https://api.openai.com/v1") - if not get_setting("ai_openai_model"): - set_setting("ai_openai_model", "gpt-4o-mini") - if not get_setting("ai_daily_report_enabled"): - set_setting("ai_daily_report_enabled", "1") - if not get_setting("ai_daily_report_hour"): - set_setting("ai_daily_report_hour", "15") - if not get_setting("ai_daily_report_minute"): - set_setting("ai_daily_report_minute", "5") - if not get_setting("backup_auto_enabled"): - set_setting("backup_auto_enabled", "1") - if not get_setting("backup_auto_hour"): - set_setting("backup_auto_hour", "3") - if not get_setting("backup_keep_count"): - set_setting("backup_keep_count", "30") - if not get_setting("fee_source_mode"): - set_setting("fee_source_mode", "ctp") - set_setting("fee_source_mode", "ctp") - try: - purge_non_ctp_fee_rates() - except Exception: - pass - - -def sync_admin_from_env(): - """ - 从 .env 同步管理员账号。 - - 首次建库:自动写入 ADMIN_USERNAME / ADMIN_PASSWORD - - 已建库后改 .env:需设 ADMIN_SYNC_FROM_ENV=true 并重启服务 - """ - sync = os.getenv("ADMIN_SYNC_FROM_ENV", "false").lower() in ("1", "true", "yes") - env_username = os.getenv("ADMIN_USERNAME", "").strip() - env_password = os.getenv("ADMIN_PASSWORD", "").strip() - placeholder_passwords = {"", "change-me-on-first-login", "admin123"} - - if not get_setting("admin_username"): - username = env_username or "admin" - password = env_password if env_password not in placeholder_passwords else "admin123" - set_setting("admin_username", username) - set_setting("admin_password_hash", generate_password_hash(password)) - return - - if not sync: - return - - if env_username: - set_setting("admin_username", env_username) - if env_password and env_password not in placeholder_passwords: - set_setting("admin_password_hash", generate_password_hash(env_password)) - - -if os.getenv("QIHUO_SKIP_INIT_DB") != "1": - init_db() - app.logger.info("数据库: %s", database_label()) - - -def sync_ths_token(): - set_ths_refresh_token(get_setting("ths_refresh_token")) - - -if os.getenv("QIHUO_INIT_ONLY") != "1": - sync_ths_token() - - -def build_market_quote_payload( - symbol: str, - market_code: str = "", - sina_code: str = "", - *, - prefer_sina: bool = False, -) -> dict: - if not market_code or not sina_code: - codes = ths_to_codes(symbol) - if codes: - market_code = codes.get("market_code", "") or market_code - sina_code = codes.get("sina_code", "") or sina_code - quote_source = "sina" - price = None - prev_close = None - if not prefer_sina: - try: - from vnpy_bridge import ctp_status, ctp_get_tick_detail - from trading_context import get_trading_mode - - mode = get_trading_mode(get_setting) - if ctp_status(mode).get("connected"): - detail = ctp_get_tick_detail(mode, symbol) - if detail.get("price"): - price = detail["price"] - quote_source = "ctp" - if detail.get("pre_close") is not None: - prev_close = detail["pre_close"] - except Exception: - pass - if price is None: - price = fetch_price(symbol, market_code, sina_code) - name = symbol - codes = ths_to_codes(symbol) - if codes: - name = codes.get("name", symbol) - if prev_close is None and sina_code: - from market import fetch_raw_for_volume - raw = fetch_raw_for_volume(sina_code) - if raw and raw.get("prev_close") is not None: - prev_close = raw["prev_close"] - return { - "symbol": symbol, - "name": name, - "price": price, - "prev_close": prev_close, - "quote_source": quote_source, - } - - -# —————————————— 推送 —————————————— - -def send_wechat_msg(content: str): - webhook = get_setting("wechat_webhook") - if not webhook: - return - full = f"【国内期货】\n{content}" - data = {"msgtype": "text", "text": {"content": full}} - try: - requests.post(webhook, json=data, timeout=10) - except Exception: - pass - -# —————————————— 行情 —————————————— - -def resolve_market_codes(ths_code: str, market_code: str = "", sina_code: str = "") -> tuple[str, str]: - """返回 (market_code, sina_code) 用于行情拉取。""" - if market_code: - return market_code, sina_code - if sina_code and "." in sina_code: - return sina_code, "" - codes = ths_to_codes(ths_code) - if codes: - return codes["market_code"], codes["sina_code"] - if ths_code.startswith("nf_") or ths_code.startswith("CFF_RE_"): - return ths_code, ths_code - return "", sina_code or "" - - -def fetch_price(ths_code: str, market_code: str = "", sina_code: str = "") -> Optional[float]: - sym = (ths_code or "").strip() - if sym: - try: - from vnpy_bridge import ctp_status, ctp_get_tick_price - from trading_context import get_trading_mode - - mode = get_trading_mode(get_setting) - if ctp_status(mode).get("connected"): - p = ctp_get_tick_price(mode, sym) - if p and p > 0: - return p - except Exception: - pass - mc, sc = resolve_market_codes(sym, market_code, sina_code) - if not mc and not sc: - return None - return market_get_price(mc, sc) - -# —————————————— 监控逻辑 —————————————— - -def check_order_plans(): - expire_old_plans() - today = today_str() - conn = get_db() - rows = conn.execute( - "SELECT * FROM order_plans WHERE plan_date=? AND status IN ('planned', 'active')", - (today,), - ).fetchall() - - for r in rows: - sym = r["symbol"] - sina = r["sina_code"] if "sina_code" in r.keys() else "" - market = r["market_code"] if "market_code" in r.keys() else "" - p = fetch_price(sym, market, sina) - if not p: - continue - - direction = r["direction"] - zone_upper = r["zone_upper"] - zone_lower = r["zone_lower"] - stop_loss = r["stop_loss"] - take_profit = r["take_profit"] - status = r["status"] - pid = r["id"] - name = r["symbol_name"] or sym - reason = r["decision_reason"] if "decision_reason" in r.keys() and r["decision_reason"] else "—" - - # 计划状态:价格进入决策区间则激活并通知 - if status == "planned": - in_zone = zone_lower <= p <= zone_upper - if in_zone: - msg = ( - f"【开单计划触发】{name} ({sym})\n" - f"方向:{'做多' if direction == 'long' else '做空'}\n" - f"决策区间:{zone_lower} ~ {zone_upper}\n" - f"决策理由:{reason}\n" - f"当前价:{p}\n" - f"止损:{stop_loss} 止盈:{take_profit}" - ) - send_wechat_msg(msg) - conn.execute( - "UPDATE order_plans SET status='active', triggered_at=? WHERE id=?", - (datetime.now().isoformat(), pid), - ) - status = "active" - - # 激活状态:监控止盈止损 - if status == "active": - res = None - if direction == "long": - if p >= take_profit: - res = "止盈" - elif p <= stop_loss: - res = "止损" - elif direction == "short": - if p <= take_profit: - res = "止盈" - elif p >= stop_loss: - res = "止损" - - if res: - msg = ( - f"[{'做多' if direction == 'long' else '做空'}] {name} 已{res}\n" - f"决策区间:{zone_lower} ~ {zone_upper}\n" - f"止损:{stop_loss} 止盈:{take_profit}\n" - f"当前价:{p}" - ) - send_wechat_msg(msg) - conn.execute( - """INSERT INTO trade_records - (symbol, symbol_name, monitor_type, direction, - trigger_price, stop_loss, take_profit, result) - VALUES (?,?,?,?,?,?,?,?)""", - (sym, name, "开单计划", direction, p, stop_loss, take_profit, res), - ) - conn.execute( - "UPDATE order_plans SET status='closed' WHERE id=?", (pid,) - ) - - conn.commit() - conn.close() - - -def check_key_monitors(): - from db_conn import DB_PATH - from key_monitor_lib import run_key_monitor_check - from trading_context import get_trading_mode - - conn = get_db() - try: - execute_fn = getattr(app, "_execute_key_breakout", None) - run_key_monitor_check( - conn, - db_path=DB_PATH, - get_trading_mode_fn=lambda: get_trading_mode(get_setting), - send_wechat=send_wechat_msg, - execute_breakout_fn=execute_fn, - ) - conn.commit() - finally: - conn.close() - - -def background_task(): - while True: - try: - expire_old_plans() - check_key_monitors() - fn_roll = getattr(app, "_check_roll_monitors", None) - if fn_roll: - fn_roll() - check_order_plans() - fn = getattr(app, "_check_trend_plans", None) - if fn: - fn(app) - except Exception: - pass - time.sleep(3) - - -def start_background_threads(): - from trading_context import get_trading_mode - - threading.Thread(target=background_task, daemon=True).start() - threading.Thread( - target=lambda: kline_hub.worker_loop( - DB_PATH, - lambda sym, mc, sc: build_market_quote_payload( - sym, mc, sc, prefer_sina=True, - ), - get_mode_fn=lambda: get_trading_mode(get_setting), - ), - daemon=True, - ).start() - threading.Thread(target=refresh_main_index, daemon=True).start() - start_backup_worker(get_setting_fn=get_setting, set_setting_fn=set_setting) - - -# —————————————— 登录 —————————————— - -def login_required(f): - @wraps(f) - def wrap(*args, **kwargs): - if not session.get("logged_in"): - return redirect(url_for("login")) - return f(*args, **kwargs) - return wrap - - -@app.route("/") -def index(): - if session.get("logged_in"): - return redirect(url_for("positions")) - return redirect(url_for("login")) - - -@app.route("/manifest.webmanifest") -def web_manifest(): - import json - - manifest_path = os.path.join(app.static_folder, "manifest.json") - with open(manifest_path, encoding="utf-8") as fh: - data = json.load(fh) - if _ua_is_phone(request.headers.get("User-Agent", "")): - data["orientation"] = "portrait-primary" - else: - data["orientation"] = "any" - response = app.make_response(json.dumps(data, ensure_ascii=False)) - response.mimetype = "application/manifest+json" - response.headers["Cache-Control"] = "no-cache" - return response - - -@app.route("/sw.js") -def service_worker(): - response = app.send_static_file("sw.js") - response.headers["Cache-Control"] = "no-cache" - response.headers["Service-Worker-Allowed"] = "/" - return response - - -@app.route("/login", methods=["GET", "POST"]) -def login(): - if request.method == "POST": - u = request.form.get("username", "").strip() - p = request.form.get("password", "") - admin_u = get_setting("admin_username") - admin_hash = get_setting("admin_password_hash") - if u == admin_u and check_password_hash(admin_hash, p): - session["logged_in"] = True - session["username"] = u - return redirect(url_for("positions")) - flash("账号或密码错误") - return render_template("login.html") - - -@app.route("/logout") -def logout(): - session.clear() - return redirect(url_for("login")) - -# —————————————— API —————————————— - -@app.route("/api/symbols/search") -@login_required -def api_symbol_search(): - q = request.args.get("q", "") - conn = get_db() - try: - from trading_context import get_account_capital, is_ctp_connected - capital = get_account_capital(conn, get_setting) - ctp_connected = is_ctp_connected(get_setting) - finally: - conn.close() - return jsonify(search_symbols(q, capital=capital, ctp_connected=ctp_connected)) - - -@app.route("/api/symbols/mains") -@login_required -def api_symbols_mains(): - return jsonify(list_main_contracts_grouped()) - - -@app.route("/api/symbols/recommended") -@login_required -def api_symbols_recommended(): - """品种下拉:仅展示当前资金下可开仓品种(与下方可开仓品种表一致)。""" - from recommend_store import recommend_payload - from trading_context import ( - get_fixed_lots, - get_max_margin_pct, - get_recommend_capital, - get_sizing_mode, - get_trading_mode, - ) - - conn = get_db() - try: - capital = get_recommend_capital(conn, get_setting) - payload = recommend_payload( - conn, - live_capital=capital, - max_margin_pct=get_max_margin_pct(get_setting), - trading_mode=get_trading_mode(get_setting), - sizing_mode=get_sizing_mode(get_setting), - fixed_lots=get_fixed_lots(get_setting), - ) - return jsonify(list_recommended_symbols_grouped(payload.get("rows") or [])) - finally: - conn.close() - - -@app.route("/api/key_prices") -@login_required -def api_key_prices(): - """关键位监控列表:批量现价与距上/下沿距离。""" - conn = get_db() - rows = conn.execute( - "SELECT id, symbol, market_code, sina_code, upper, lower " - "FROM key_monitors WHERE status='active' OR status IS NULL" - ).fetchall() - conn.close() - out = [] - for r in rows: - sym = r["symbol"] - market = r["market_code"] or "" - sina = r["sina_code"] or "" - upper = float(r["upper"]) - lower = float(r["lower"]) - price = fetch_price(sym, market, sina) - dist_upper = None - dist_lower = None - if price is not None: - dist_upper = round(upper - price, 2) - dist_lower = round(price - lower, 2) - out.append({ - "id": r["id"], - "price": price, - "dist_upper": dist_upper, - "dist_lower": dist_lower, - }) - return jsonify(out) - - -@app.route("/api/plan_prices") -@login_required -def api_plan_prices(): - """今日计划:批量现价与距决策区间上/下沿距离。""" - today = today_str() - conn = get_db() - rows = conn.execute( - "SELECT id, symbol, market_code, sina_code, zone_upper, zone_lower " - "FROM order_plans WHERE plan_date=? AND status IN ('planned', 'active')", - (today,), - ).fetchall() - conn.close() - out = [] - for r in rows: - sym = r["symbol"] - market = r["market_code"] or "" - sina = r["sina_code"] or "" - upper = float(r["zone_upper"]) - lower = float(r["zone_lower"]) - price = fetch_price(sym, market, sina) - dist_upper = None - dist_lower = None - in_zone = False - if price is not None: - dist_upper = round(upper - price, 2) - dist_lower = round(price - lower, 2) - in_zone = lower <= price <= upper - out.append({ - "id": r["id"], - "price": price, - "dist_upper": dist_upper, - "dist_lower": dist_lower, - "in_zone": in_zone, - }) - return jsonify(out) - - -@app.route("/api/position_live") -@login_required -def api_position_live(): - capital = float(get_setting("live_capital", "0") or 0) - now_iso = datetime.now(TZ).strftime("%Y-%m-%dT%H:%M") - conn = get_db() - rows = conn.execute( - "SELECT * FROM position_monitors WHERE status='active' ORDER BY id DESC" - ).fetchall() - conn.close() - out = [] - for r in rows: - sym = r["symbol"] - market = r["market_code"] or "" - sina = r["sina_code"] or "" - direction = r["direction"] - entry = float(r["entry_price"]) - sl = float(r["stop_loss"]) - tp = float(r["take_profit"]) - lots = float(r["lots"] or 1) - mark = fetch_price(sym, market, sina) - metrics = calc_position_metrics( - direction, entry, sl, tp, lots, mark, capital, sym, - ) - holding = calc_holding_duration(r["open_time"] or "", now_iso) - close_est = mark if mark is not None else entry - fee_info = calc_fee_breakdown( - sym, entry, close_est, lots, r["open_time"] or "", now_iso, - trading_mode=_trading_mode(), - ) - est_net = None - if metrics.get("float_pnl") is not None: - est_net = round(metrics["float_pnl"] - fee_info["total_fee"], 2) - out.append({ - "id": r["id"], - "symbol": r["symbol_name"] or sym, - "symbol_code": sym, - "direction": "做多" if direction == "long" else "做空", - "lots": lots, - "entry_price": entry, - "stop_loss": sl, - "take_profit": tp, - "open_time": r["open_time"], - "mark_price": mark, - "holding_duration": holding, - "est_fee": fee_info["total_fee"], - "est_fee_open": fee_info["open_fee"], - "est_fee_close": fee_info["close_fee"], - "est_fee_close_type": fee_info["close_type"], - "est_pnl_net": est_net, - **metrics, - }) - return jsonify(out) - - -@app.route("/plans") -@login_required -@require_nav("plans") -def plans(): - today = today_str() - start = request.args.get("start", "") - end = request.args.get("end", "") - - conn = get_db() - plan_list = conn.execute( - "SELECT * FROM order_plans WHERE plan_date=? AND status IN ('planned', 'active') ORDER BY id DESC", - (today,), - ).fetchall() - - sql = "SELECT * FROM order_plans WHERE plan_date < ? OR status IN ('closed', 'expired')" - params: list = [today] - if start: - sql += " AND plan_date >= ?" - params.append(start) - if end: - sql += " AND plan_date <= ?" - params.append(end) - sql += " ORDER BY plan_date DESC, id DESC LIMIT 200" - history = conn.execute(sql, params).fetchall() - conn.close() - return render_template( - "plans.html", - plans=plan_list, - history=history, - today=today, - start=start, - end=end, - ) - - -@app.route("/add_plan", methods=["POST"]) -@login_required -def add_plan(): - d = request.form - direction = d.get("direction") - symbol = d.get("symbol", "").strip() - symbol_name = d.get("symbol_name", "").strip() - market_code = d.get("market_code", "").strip() - sina_code = d.get("sina_code", "").strip() - if not direction: - flash("请选择多空方向") - return redirect(url_for("plans")) - if not symbol or not market_code: - flash("请从下拉列表选择品种(同花顺合约代码)") - return redirect(url_for("plans")) - conn = get_db() - conn.execute( - """INSERT INTO order_plans - (symbol, symbol_name, market_code, sina_code, direction, - zone_upper, zone_lower, stop_loss, take_profit, plan_date, decision_reason) - VALUES (?,?,?,?,?,?,?,?,?,?,?)""", - ( - symbol, symbol_name, market_code, sina_code, direction, - float(d["zone_upper"]), float(d["zone_lower"]), - float(d["stop_loss"]), float(d["take_profit"]), - today_str(), - d.get("decision_reason", "").strip(), - ), - ) - conn.commit() - conn.close() - flash("开单计划已添加") - return redirect(url_for("plans")) - - -@app.route("/del_plan/") -@login_required -def del_plan(pid): - conn = get_db() - conn.execute("DELETE FROM order_plans WHERE id=?", (pid,)) - conn.commit() - conn.close() - flash("已删除") - return redirect(url_for("plans")) - - -@app.route("/ai") -@login_required -@require_nav("ai") -def ai_messages_page(): - from ai_messages import list_ai_messages - - conn = get_db() - try: - messages = list_ai_messages(conn, limit=100) - finally: - conn.close() - return render_template("ai_messages.html", messages=messages) - - -@app.route("/keys") -@login_required -def keys(): - from key_monitor_lib import key_monitor_periods - - conn = get_db() - key_list = conn.execute( - "SELECT * FROM key_monitors WHERE status='active' OR status IS NULL ORDER BY id DESC" - ).fetchall() - history = conn.execute( - "SELECT * FROM key_monitors WHERE status='archived' ORDER BY archived_at DESC LIMIT 100" - ).fetchall() - conn.close() - return render_template( - "keys.html", - keys=key_list, - history=history, - key_periods=key_monitor_periods(), - ) - - - -@app.route("/add_key", methods=["POST"]) -@login_required -def add_key(): - d = request.form - symbol = d.get("symbol", "").strip() - symbol_name = d.get("symbol_name", "").strip() - market_code = d.get("market_code", "").strip() - sina_code = d.get("sina_code", "").strip() - monitor_type = (d.get("type") or "").strip() - if not symbol or not market_code: - flash("请从下拉列表选择品种(同花顺合约代码)") - return redirect(url_for("keys")) - try: - upper = float(d.get("upper") or 0) - lower = float(d.get("lower") or 0) - except (TypeError, ValueError): - flash("上沿/下沿价格无效") - return redirect(url_for("keys")) - if upper <= lower: - flash("上沿必须大于下沿") - return redirect(url_for("keys")) - - trade_mode = (d.get("trade_mode") or "顺势").strip() - if trade_mode not in ("顺势", "反转"): - trade_mode = "顺势" - try: - risk_reward = float(d.get("risk_reward") or 2) - except (TypeError, ValueError): - risk_reward = 2.0 - risk_reward = max(0.5, min(10.0, risk_reward)) - trailing_be = 1 if d.get("trailing_be") else 0 - if trailing_be and risk_reward < 3: - risk_reward = 3.0 - - from key_monitor_lib import normalize_bar_period - - bar_period = normalize_bar_period(d.get("bar_period") or "5m") - direction = (d.get("direction") or "").strip().lower() - if monitor_type == "箱体突破": - if direction not in ("long", "short"): - flash("箱体突破须选择上方向(做多/做空)") - return redirect(url_for("keys")) - else: - direction = "" - - conn = get_db() - conn.execute( - """INSERT INTO key_monitors - (symbol, symbol_name, market_code, sina_code, monitor_type, direction, - upper, lower, trade_mode, risk_reward, trailing_be, bar_period) - VALUES (?,?,?,?,?,?,?,?,?,?,?,?)""", - ( - symbol, symbol_name, market_code, sina_code, monitor_type, direction, - upper, lower, trade_mode, risk_reward, trailing_be, bar_period, - ), - ) - conn.commit() - conn.close() - flash("关键位监控已添加") - return redirect(url_for("keys")) - - -@app.route("/add_position", methods=["POST"]) -@login_required -def add_position(): - flash("持仓由策略交易或 CTP 自动同步,无需手工录入") - return redirect(url_for("positions")) - - -@app.route("/del_position/") -@login_required -def del_position(pid): - return close_position(pid) - - -@app.route("/close_position/", methods=["POST"]) -@login_required -def close_position(pid): - conn = get_db() - row = conn.execute("SELECT * FROM position_monitors WHERE id=?", (pid,)).fetchone() - if not row: - conn.close() - flash("持仓不存在") - return redirect(url_for("positions")) - sym = row["symbol"] - market = row["market_code"] or "" - sina = row["sina_code"] or "" - direction = row["direction"] - entry = float(row["entry_price"]) - sl = float(row["stop_loss"]) - tp = float(row["take_profit"]) - lots = float(row["lots"] or 1) - open_time = row["open_time"] or "" - close_time = datetime.now(TZ).strftime("%Y-%m-%dT%H:%M") - close_price = fetch_price(sym, market, sina) - if close_price is None: - conn.close() - flash("无法获取现价,平仓失败") - return redirect(url_for("positions")) - capital = float(get_setting("live_capital", "0") or 0) - metrics = calc_position_metrics(direction, entry, sl, tp, lots, close_price, capital, sym) - pnl = metrics.get("float_pnl") or 0.0 - fee = calc_round_trip_fee(sym, entry, close_price, lots, open_time, close_time, trading_mode=_trading_mode()) - pnl_net = round(pnl - fee, 2) - result = classify_close_result(direction, close_price, sl, tp) - minutes = holding_to_minutes(open_time, close_time) - margin_pct = metrics.get("position_pct") - from trade_log_lib import calc_equity_after, refresh_trade_log_equity_chain - equity_after = calc_equity_after(capital, pnl_net) - conn.execute( - """INSERT INTO trade_logs - (symbol, symbol_name, market_code, sina_code, monitor_type, direction, - entry_price, stop_loss, take_profit, close_price, lots, margin, - margin_pct, holding_minutes, open_time, close_time, pnl, fee, pnl_net, - equity_after, result) - VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""", - ( - sym, row["symbol_name"], market, sina, "持仓监控", direction, - entry, sl, tp, close_price, lots, metrics["margin"], - margin_pct, - minutes, open_time, close_time, pnl, fee, pnl_net, equity_after, result, - ), - ) - conn.execute("DELETE FROM position_monitors WHERE id=?", (pid,)) - try: - refresh_trade_log_equity_chain(conn, capital if capital > 0 else None) - except Exception as exc: - app.logger.debug("equity chain refresh after close: %s", exc) - conn.commit() - conn.close() - touch_stats_cache() - flash(f"已平仓,盈亏 {pnl:.2f} 元(扣费后 {pnl_net:.2f} 元),已记入交易记录") - return redirect(url_for("positions")) - - -@app.route("/trades") -@login_required -def trades(): - return redirect(url_for("records")) - - -@app.route("/update_trade/", methods=["POST"]) -@login_required -def update_trade(tid): - d = request.form - conn = get_db() - row = conn.execute("SELECT * FROM trade_logs WHERE id=?", (tid,)).fetchone() - if not row: - conn.close() - flash("记录不存在") - return redirect(url_for("records")) - row = dict(row) - entry = float(d.get("entry_price") or 0) - close_px = float(d.get("close_price") or 0) - lots = float(d.get("lots") or 0) - sl_raw = d.get("stop_loss") - tp_raw = d.get("take_profit") - stop_loss = float(sl_raw) if sl_raw not in (None, "") else None - take_profit = float(tp_raw) if tp_raw not in (None, "") else None - open_time = (d.get("open_time") or row.get("open_time") or "").strip() - close_time = (d.get("close_time") or row.get("close_time") or "").strip() - direction = (d.get("direction") or row.get("direction") or "long").strip() - - from trade_log_lib import recalc_trade_log_pnl, refresh_trade_log_equity_chain, _read_initial_capital - from trading_context import get_trading_mode - - pnl = float(row.get("pnl") or 0) - fee = float(row.get("fee") or 0) - pnl_net = float(row.get("pnl_net") or 0) - old_entry = float(row.get("entry_price") or 0) - old_close = float(row.get("close_price") or 0) - old_lots = float(row.get("lots") or 0) - prices_changed = ( - abs(entry - old_entry) > 0.0001 - or abs(close_px - old_close) > 0.0001 - or abs(lots - old_lots) > 0.0001 - ) - if prices_changed and close_px > 0 and entry > 0 and lots > 0: - calc = recalc_trade_log_pnl( - symbol=row.get("symbol") or "", - direction=direction, - entry_price=entry, - close_price=close_px, - lots=lots, - stop_loss=stop_loss, - take_profit=take_profit, - open_time=open_time, - close_time=close_time, - trading_mode=get_trading_mode(get_setting), - ) - pnl = calc["pnl"] - fee = calc["fee"] - pnl_net = calc["pnl_net"] - - form_pnl_raw = d.get("pnl") - if form_pnl_raw not in (None, ""): - pnl = float(form_pnl_raw) - pnl_net = round(pnl - fee, 2) - - try: - from app import holding_to_minutes - minutes = int(holding_to_minutes(open_time, close_time) or 0) - except Exception: - minutes = int(d.get("holding_minutes") or row.get("holding_minutes") or 0) - - conn.execute( - """UPDATE trade_logs SET - symbol_name=?, monitor_type=?, direction=?, - entry_price=?, stop_loss=?, take_profit=?, close_price=?, - lots=?, margin=?, holding_minutes=?, open_time=?, close_time=?, - pnl=?, fee=?, pnl_net=?, result=?, verified=1 - WHERE id=?""", - ( - d.get("symbol_name", "").strip(), - d.get("monitor_type", "").strip(), - direction, - entry, - stop_loss, - take_profit, - close_px, - lots, - float(d.get("margin") or 0), - minutes, - open_time, - close_time, - pnl, - fee, - pnl_net, - d.get("result", "").strip(), - tid, - ), - ) - try: - refresh_trade_log_equity_chain(conn, _read_initial_capital(conn)) - except Exception as exc: - app.logger.debug("equity chain refresh after trade edit: %s", exc) - conn.commit() - conn.close() - touch_stats_cache() - flash("交易记录已核对保存") - return redirect(url_for("records")) - - -@app.route("/del_trade/") -@login_required -def del_trade(tid): - conn = get_db() - conn.execute("DELETE FROM trade_logs WHERE id=?", (tid,)) - conn.commit() - conn.close() - touch_stats_cache() - flash("已删除") - return redirect(url_for("records")) - - -@app.route("/fill_review/") -@login_required -def fill_review_from_trade(tid): - conn = get_db() - row = conn.execute("SELECT * FROM trade_logs WHERE id=?", (tid,)).fetchone() - conn.close() - if not row: - flash("记录不存在") - return redirect(url_for("records")) - q = { - "symbol": row["symbol"], - "symbol_name": row["symbol_name"] or row["symbol"], - "market_code": row["market_code"] or "", - "sina_code": row["sina_code"] or "", - "direction": row["direction"], - "entry_price": row["entry_price"], - "stop_loss": row["stop_loss"], - "take_profit": row["take_profit"], - "close_price": row["close_price"], - "lots": row["lots"], - "open_time": row["open_time"], - "close_time": row["close_time"], - "pnl": row["pnl"], - } - params = {k: v for k, v in q.items() if v is not None} - return redirect(url_for("records", **params) + "#review-panel") - - -@app.route("/del_key/") -@login_required -def del_key(pid): - conn = get_db() - conn.execute( - "UPDATE key_monitors SET status='archived', archived_at=? WHERE id=?", - (datetime.now(TZ).isoformat(), pid), - ) - conn.commit() - conn.close() - flash("已移入监控历史") - return redirect(url_for("keys")) - - -@app.route("/records") -@login_required -def records(): - preset = request.args.get("preset", "") - start = request.args.get("start", "") - end = request.args.get("end", "") - if preset: - start, end = parse_review_date_filter(preset, start, end) - - conn = get_db() - ctp_sync_info = None - sql = "SELECT * FROM review_records WHERE 1=1" - params: list = [] - if start: - sql += " AND date(close_time) >= ?" - params.append(start) - if end: - sql += " AND date(close_time) <= ?" - params.append(end) - sql += " ORDER BY id DESC LIMIT 200" - review_list = conn.execute(sql, params).fetchall() - - auto_list = conn.execute( - "SELECT * FROM trade_records ORDER BY id DESC LIMIT 30" - ).fetchall() - trade_list = conn.execute( - "SELECT * FROM trade_logs ORDER BY id DESC LIMIT 500" - ).fetchall() - from trade_log_lib import enrich_trades_for_records, _read_initial_capital - try: - initial_capital = _read_initial_capital(conn) - except Exception: - initial_capital = 100_000.0 - trades, equity_curve = enrich_trades_for_records( - [dict(r) for r in trade_list], - initial_capital=initial_capital, - ) - conn.close() - - trade_prefill_keys = ( - "symbol", "symbol_name", "market_code", "sina_code", "direction", - "entry_price", "stop_loss", "take_profit", "close_price", - "lots", "open_time", "close_time", "pnl", - ) - prefill = {k: request.args.get(k) for k in trade_prefill_keys if request.args.get(k)} - - return render_template( - "records.html", - reviews=review_list, - trades=trades, - equity_curve=equity_curve, - auto_records=auto_list, - ctp_sync_info=ctp_sync_info, - preset=preset, - start=start, - end=end, - prefill=prefill, - open_types=OPEN_TYPES, - exit_triggers=EXIT_TRIGGERS, - behavior_tags=BEHAVIOR_TAGS, - kline_periods=KLINE_PERIODS, - kline_cutoffs=KLINE_CUTOFFS, - ) - - -@app.route("/add_review", methods=["POST"]) -@login_required -def add_review(): - d = request.form - open_type = d.get("open_type", "").strip() - exit_trigger = d.get("exit_trigger", "").strip() - if not open_type: - flash("请选择开仓类型") - return redirect(url_for("records")) - if not exit_trigger: - flash("请选择离场触发") - return redirect(url_for("records")) - - symbol = d.get("symbol", "").strip() - symbol_name = d.get("symbol_name", "").strip() - market_code = d.get("market_code", "").strip() - sina_code = d.get("sina_code", "").strip() - if not symbol or not market_code: - flash("请从下拉列表选择品种(同花顺合约代码)") - return redirect(url_for("records")) - - screenshot = "" - f = request.files.get("screenshot") - if f and f.filename: - fname = secure_filename(f.filename) - ts = datetime.now(TZ).strftime("%Y%m%d%H%M%S") - screenshot = f"{ts}_{fname}" - f.save(os.path.join(UPLOAD_DIR, screenshot)) - - tags = [t for t in BEHAVIOR_TAGS if d.get(f"tag_{t}")] - is_emotion = 1 if tags else 0 - - def num(key: str) -> Optional[float]: - v = d.get(key, "").strip() - if not v: - return None - return float(v) - - open_time = d.get("open_time", "").strip() - close_time = d.get("close_time", "").strip() - direction = d.get("direction", "").strip() - entry_price = num("entry_price") - stop_loss = num("stop_loss") - take_profit = num("take_profit") - close_price = num("close_price") - lots = num("lots") or 1.0 - - holding = calc_holding_duration(open_time, close_time) - initial_pnl = calc_rr_ratio(direction, entry_price, stop_loss, take_profit) - actual_pnl = calc_rr_ratio(direction, entry_price, stop_loss, close_price) - - gross_pnl = num("pnl") - if gross_pnl is None and entry_price and close_price: - spec_mult = calc_position_metrics( - direction, entry_price, stop_loss, take_profit, - lots, close_price, 0, symbol, - ) - gross_pnl = spec_mult.get("float_pnl") - fee = calc_round_trip_fee( - symbol, entry_price or 0, close_price or 0, lots, open_time, close_time, - trading_mode=_trading_mode(), - ) - pnl_net = round((gross_pnl or 0) - fee, 2) if gross_pnl is not None else None - - auto_kline = bool(d.get("auto_kline")) - if auto_kline and not screenshot: - try: - generated = generate_review_kline_chart( - symbol=symbol, - periods=[d.get("kline_period1", "15m"), d.get("kline_period2", "1h")], - count=int(d.get("kline_count") or 300), - cutoff_label=d.get("kline_cutoff", "平仓时间"), - open_time=open_time, - close_time=close_time, - entry_price=entry_price, - stop_loss=stop_loss, - take_profit=take_profit, - close_price=close_price, - upload_dir=UPLOAD_DIR, - ) - if generated: - screenshot = generated - except Exception as exc: - app.logger.warning("auto kline failed: %s", exc) - - conn = get_db() - conn.execute( - """INSERT INTO review_records - (open_time, close_time, symbol, symbol_name, market_code, sina_code, - timeframe, direction, - entry_price, stop_loss, take_profit, close_price, lots, - holding_duration, initial_pnl, actual_pnl, pnl, fee, pnl_net, - open_type, expected_rr, actual_rr, exit_trigger, exit_supplement, - watch_after_breakeven, new_position_while_occupied, screenshot, - auto_kline, kline_period1, kline_period2, kline_count, kline_cutoff, - behavior_tags, is_emotion, notes) - VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""", - ( - open_time, close_time, - symbol, symbol_name, market_code, sina_code, - d.get("timeframe", "").strip(), - direction, - entry_price, stop_loss, take_profit, close_price, lots, - holding, initial_pnl, actual_pnl, gross_pnl, fee, pnl_net, - open_type, - None, - None, - exit_trigger, - d.get("exit_supplement", "").strip(), - d.get("watch_after_breakeven", "否"), - d.get("new_position_while_occupied", "否"), - screenshot, - 1 if auto_kline else 0, - d.get("kline_period1", "15m"), - d.get("kline_period2", "1h"), - int(d.get("kline_count") or 300), - d.get("kline_cutoff", "平仓时间"), - ",".join(tags), - is_emotion, - d.get("notes", "").strip(), - ), - ) - hook = getattr(app, "_risk_review_hook", None) - if hook: - hook( - conn, - ",".join(tags), - exit_trigger, - d.get("exit_supplement", "").strip(), - ) - conn.commit() - conn.close() - touch_stats_cache() - flash("复盘记录已保存") - return redirect(url_for("records")) - - -@app.route("/del_review/") -@login_required -def del_review(rid): - conn = get_db() - row = conn.execute("SELECT screenshot FROM review_records WHERE id=?", (rid,)).fetchone() - if row and row["screenshot"]: - path = os.path.join(UPLOAD_DIR, row["screenshot"]) - if os.path.isfile(path): - os.remove(path) - conn.execute("DELETE FROM review_records WHERE id=?", (rid,)) - conn.commit() - conn.close() - touch_stats_cache() - flash("已删除") - return redirect(url_for("records")) - - -@app.route("/uploads/") -@login_required -def uploaded_file(filename): - from flask import send_from_directory - return send_from_directory(UPLOAD_DIR, filename) - - -@app.route("/del_record/") -@login_required -def del_record(rid): - conn = get_db() - conn.execute("DELETE FROM trade_records WHERE id=?", (rid,)) - conn.commit() - conn.close() - flash("已删除") - return redirect(url_for("records")) - - -@app.route("/stats") -@login_required -def stats(): - return render_template("stats.html") - - -@app.route("/calendar") -@login_required -def trade_calendar(): - return render_template("calendar.html") - - -@app.route("/api/stats") -@login_required -def api_stats(): - return jsonify(get_stats_data()) - - -@app.route("/api/stats/views") -@login_required -def api_stats_views(): - return jsonify({"views": STATS_VIEWS}) - - -@app.route("/api/stats/refresh", methods=["POST"]) -@login_required -def api_stats_refresh(): - conn = get_db() - capital = float(get_setting("live_capital", "0") or 0) - data = refresh_stats_cache(conn, capital) - conn.close() - return jsonify(data) - - -@app.route("/api/stats/calendar") -@login_required -def api_stats_calendar(): - now = datetime.now(TZ) - year = request.args.get("year", type=int) or now.year - month = request.args.get("month", type=int) or now.month - if month < 1 or month > 12: - return jsonify({"error": "invalid month"}), 400 - conn = get_db() - try: - data = get_calendar_month(conn, year, month) - finally: - conn.close() - return jsonify(data) - - -@app.route("/api/stats/calendar/day") -@login_required -def api_stats_calendar_day(): - day = (request.args.get("date") or "").strip() - if not day: - return jsonify({"error": "date required"}), 400 - try: - date.fromisoformat(day) - except ValueError: - return jsonify({"error": "invalid date"}), 400 - conn = get_db() - try: - data = get_calendar_day(conn, day) - finally: - conn.close() - return jsonify(data) - - -_dashboard_sync_tick = {"n": 0} - - -@app.route("/dashboard") -@login_required -@require_nav("dashboard") -def dashboard(): - return render_template("dashboard.html") - - -@app.route("/risk-guide") -@login_required -@require_nav("risk_guide") -def risk_guide(): - from doc_render import read_doc, render_markdown - - try: - _title, raw = read_doc("risk-guide") - except FileNotFoundError: - flash("文档不存在") - return redirect(url_for("positions")) - return render_template("risk_guide.html", doc_html=render_markdown(raw)) - - -@app.route("/api/dashboard/live") -@login_required -def api_dashboard_live(): - if not nav_enabled(get_setting, "dashboard"): - return jsonify({"ok": False, "error": "数据看板已在系统设置中关闭"}), 403 - from dashboard_lib import build_dashboard_payload - - _dashboard_sync_tick["n"] += 1 - sync_trades = _dashboard_sync_tick["n"] % 15 == 0 - try: - payload = build_dashboard_payload( - get_db=get_db, - get_setting=get_setting, - fetch_price=fetch_price, - sync_ctp_trades=sync_trades, - ) - return jsonify(payload) - except Exception as exc: - app.logger.exception("dashboard live: %s", exc) - return jsonify({"ok": False, "error": "看板数据暂时不可用"}), 503 - - -@app.route("/market") -@login_required -@require_nav("market") -def market_page(): - symbol = request.args.get("symbol", "").strip() - period = request.args.get("period", "15m").strip() - valid = {p["key"] for p in MARKET_PERIODS} - if period not in valid: - period = "15m" - ctp_st = {} - try: - from vnpy_bridge import ctp_status - from trading_context import get_trading_mode - - ctp_st = ctp_status(get_trading_mode(get_setting)) - except Exception: - pass - return render_template( - "market.html", - symbol=symbol, - period=period, - market_periods=MARKET_PERIODS, - quote_label=get_quote_source_label(ctp_connected=bool(ctp_st.get("connected"))), - ctp_connected=bool(ctp_st.get("connected")), - ) - - -@app.route("/api/kline") -@login_required -def api_kline(): - symbol = request.args.get("symbol", "").strip() - period = request.args.get("period", "15m").strip() - if not symbol: - return jsonify({"error": "请提供合约代码"}), 400 - try: - from trading_context import get_trading_mode - - data = fetch_market_klines( - symbol, period, DB_PATH, prefer_ctp=False, - ) - except Exception as exc: - app.logger.warning("kline api failed: %s", exc) - return jsonify({"error": str(exc)}), 500 - if not data.get("chart_symbol"): - return jsonify({"error": "无法识别合约代码"}), 400 - if not data.get("bars"): - return jsonify({"error": "未获取到K线数据,请稍后重试或更换合约"}), 404 - return jsonify(data) - - -@app.route("/api/kline/stream") -@login_required -def api_kline_stream(): - from queue import Empty - - symbol = request.args.get("symbol", "").strip() - period = request.args.get("period", "15m").strip() - market_code = request.args.get("market_code", "").strip() - sina_code = request.args.get("sina_code", "").strip() - if not symbol: - return jsonify({"error": "请提供合约代码"}), 400 - - def generate(): - sub = kline_hub.subscribe(symbol, period, market_code, sina_code) - try: - kline_data = fetch_market_klines( - symbol, period, DB_PATH, prefer_ctp=False, - ) - if kline_data.get("bars"): - yield sse_format("kline", kline_data) - yield sse_format( - "quote", - build_market_quote_payload( - symbol, market_code, sina_code, prefer_sina=True, - ), - ) - while True: - try: - msg = sub.queue.get(timeout=20) - yield sse_format(msg["event"], msg["data"]) - except Empty: - yield ": heartbeat\n\n" - finally: - kline_hub.unsubscribe(sub) - - return Response( - stream_with_context(generate()), - mimetype="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - }, - ) - - -@app.route("/api/market_quote") -@login_required -def api_market_quote(): - symbol = request.args.get("symbol", "").strip() - market_code = request.args.get("market_code", "").strip() - sina_code = request.args.get("sina_code", "").strip() - if not symbol and not market_code: - return jsonify({"error": "请提供合约"}), 400 - return jsonify(build_market_quote_payload( - symbol, market_code, sina_code, prefer_sina=True, - )) - - -@app.route("/contract") -@login_required -def contract_profile_page(): - return redirect(url_for("positions")) - - -@app.route("/api/contract_profile") -@login_required -def api_contract_profile(): - return jsonify({"error": "品种简介功能已移除"}), 404 - - -@app.route("/fees", methods=["GET", "POST"]) -@login_required -@require_nav("fees") -def fees(): - from trading_context import get_trading_mode - from ctp_fee_worker import ( - schedule_ctp_fee_sync, - get_fee_last_sync, - fees_synced_today, - fee_sync_in_progress, - ) - from vnpy_bridge import ctp_status - - mode = get_trading_mode(get_setting) - if request.method == "POST": - action = request.form.get("action") - if action == "sync_ctp": - force = request.form.get("force") == "1" - _, msg = schedule_ctp_fee_sync( - mode, - get_setting=get_setting, - set_setting=set_setting, - force=force, - ) - flash(msg) - return redirect(url_for("fees")) - - rates = list_fee_rates_for_ui() - fee_counts = count_fee_rates_by_source() - ctp_st = ctp_status(mode) - return render_template( - "fees.html", - rates=rates, - fee_counts=fee_counts, - fee_last_sync=get_fee_last_sync(get_setting), - fee_synced_today=fees_synced_today(get_setting), - fee_sync_running=fee_sync_in_progress(), - ctp_connected=bool(ctp_st.get("connected")), - ) - - -@app.route("/api/backup/list") -@login_required -def api_backup_list(): - return jsonify( - { - "dir": str(backup_dir()), - "last_at": get_backup_last_at(get_setting), - "running": backup_in_progress(), - "items": list_backups(), - } - ) - - -@app.route("/api/backup/download/") -@login_required -def api_backup_download(filename): - from flask import send_file - - try: - path = resolve_backup_file(filename) - except (ValueError, FileNotFoundError) as exc: - return jsonify({"error": str(exc)}), 404 - return send_file(path, as_attachment=True, download_name=path.name) - - -@app.route("/settings", methods=["GET", "POST"]) -@login_required -def settings(): - if request.method == "POST": - action = request.form.get("action") - if action == "backup_now": - ok, msg = schedule_backup( - get_setting=get_setting, - set_setting=set_setting, - include_uploads=True, - ) - flash(msg if ok else msg) - elif action == "backup_config": - auto = request.form.get("backup_auto_enabled") == "1" - set_setting("backup_auto_enabled", "1" if auto else "0") - try: - hour = int(request.form.get("backup_auto_hour", "3") or 3) - set_setting("backup_auto_hour", str(max(0, min(23, hour)))) - except ValueError: - flash("自动备份小时无效") - return redirect(url_for("settings")) - try: - keep = int(request.form.get("backup_keep_count", "30") or 30) - set_setting("backup_keep_count", str(max(5, min(200, keep)))) - except ValueError: - flash("保留份数无效") - return redirect(url_for("settings")) - flash("备份策略已保存") - elif action == "wechat": - webhook = request.form.get("wechat_webhook", "").strip() - set_setting("wechat_webhook", webhook) - flash("企业微信配置已保存") - elif action == "ai": - set_setting("ai_enabled", "1" if request.form.get("ai_enabled") else "0") - provider = (request.form.get("ai_provider") or "ollama").strip().lower() - if provider not in ("ollama", "openai"): - provider = "ollama" - set_setting("ai_provider", provider) - set_setting("ai_ollama_base_url", (request.form.get("ai_ollama_base_url") or "").strip()) - set_setting("ai_ollama_model", (request.form.get("ai_ollama_model") or "").strip()) - set_setting("ai_openai_base_url", (request.form.get("ai_openai_base_url") or "").strip()) - key = (request.form.get("ai_openai_api_key") or "").strip() - if key: - set_setting("ai_openai_api_key", key) - set_setting("ai_openai_model", (request.form.get("ai_openai_model") or "").strip()) - set_setting("ai_daily_report_enabled", "1" if request.form.get("ai_daily_report_enabled") else "0") - try: - set_setting("ai_daily_report_hour", str(max(0, min(23, int(request.form.get("ai_daily_report_hour", "15") or 15))))) - except ValueError: - pass - try: - set_setting("ai_daily_report_minute", str(max(0, min(59, int(request.form.get("ai_daily_report_minute", "5") or 5))))) - except ValueError: - pass - flash("AI 配置已保存") - elif action == "trading": - mode = request.form.get("trading_mode", "simulation").strip() - if mode not in ("simulation", "live"): - mode = "simulation" - sizing = request.form.get("position_sizing_mode", "fixed").strip() - if sizing == "risk": - sizing = "amount" - if sizing not in ("fixed", "amount"): - sizing = "fixed" - set_setting("trading_mode", mode) - set_setting("position_sizing_mode", sizing) - try: - fl = int(float(request.form.get("fixed_lots", "1") or 1)) - set_setting("fixed_lots", str(max(1, fl))) - except ValueError: - flash("固定手数无效") - return redirect(url_for("settings")) - try: - fa = float(request.form.get("fixed_amount", "5000") or 5000) - set_setting("fixed_amount", str(max(1.0, fa))) - except ValueError: - flash("固定金额无效") - return redirect(url_for("settings")) - try: - rp = float(request.form.get("risk_percent", "1") or 1) - set_setting("risk_percent", str(max(0.1, min(100.0, rp)))) - except ValueError: - pass - try: - mp = float(request.form.get("max_margin_pct", "30") or 30) - set_setting("max_margin_pct", str(max(1.0, min(100.0, mp)))) - except ValueError: - flash("保证金比例无效") - return redirect(url_for("settings")) - try: - rmp = float(request.form.get("roll_max_margin_pct", "50") or 50) - set_setting("roll_max_margin_pct", str(max(1.0, min(100.0, rmp)))) - except ValueError: - flash("滚仓保证金比例无效") - return redirect(url_for("settings")) - try: - tb = int(float(request.form.get("trailing_be_tick_buffer", "2") or 2)) - set_setting("trailing_be_tick_buffer", str(max(1, min(20, tb)))) - except ValueError: - flash("移动保本缓冲无效") - return redirect(url_for("settings")) - try: - pt = int(float(request.form.get("pending_order_timeout_min", "5") or 5)) - set_setting("pending_order_timeout_min", str(max(1, min(60, pt)))) - except ValueError: - flash("挂单超时无效") - return redirect(url_for("settings")) - flash("交易模式已保存") - elif action == "ctp": - from ctp_settings import save_ctp_auto_connect, is_ctp_auto_connect_enabled - from ctp_settings import save_ctp_settings_from_form - from vnpy_bridge import ctp_disconnect - - was_enabled = is_ctp_auto_connect_enabled(get_setting) - auto_enabled = save_ctp_auto_connect(request.form, set_setting) - save_result = save_ctp_settings_from_form(request.form, set_setting) - if not auto_enabled: - ctp_disconnect(set_disabled_hint=True) - elif not was_enabled and auto_enabled: - try: - from vnpy_bridge import get_bridge - from trading_context import get_trading_mode - - mode = get_trading_mode(get_setting) - get_bridge().reconnect_after_settings_saved(mode) - except Exception as exc: - app.logger.debug("CTP connect after enable auto: %s", exc) - pwd_updated = save_result.get("passwords_updated") or [] - pwd_empty = save_result.get("passwords_submitted_empty") or [] - simnow_pwd_len = len((request.form.get("simnow_password") or "").strip()) - live_pwd_len = len((request.form.get("ctp_live_password") or "").strip()) - print( - f"CTP settings save: simnow_password_len={simnow_pwd_len} " - f"live_password_len={live_pwd_len} updated={pwd_updated}", - flush=True, - ) - app.logger.info( - "CTP settings save: simnow_password_len=%s live_password_len=%s updated=%s", - simnow_pwd_len, - live_pwd_len, - pwd_updated, - ) - if "simnow_password" in pwd_updated: - pwd_note = f"SimNow 交易密码已更新({simnow_pwd_len} 位)" - elif "simnow_password" in pwd_empty: - pwd_note = "SimNow 交易密码未改:提交为空,请在「交易密码」框手打后再保存" - elif "ctp_live_password" in pwd_updated: - pwd_note = "实盘交易密码已更新" - elif "ctp_live_password" in pwd_empty: - pwd_note = "实盘交易密码未改(提交为空)" - else: - pwd_note = "" - if not auto_enabled: - flash("CTP 配置已保存;自动连接已关闭,所有 CTP 连接已断开") - return redirect(url_for("settings")) - if not was_enabled: - flash("CTP 配置已保存;自动连接已开启,正在连接…") - return redirect(url_for("settings")) - flash_msg = "CTP 配置已保存,正在使用新地址重连…" - if pwd_note: - flash_msg = f"CTP 配置已保存;{pwd_note},正在重连…" - try: - from vnpy_bridge import get_bridge - from trading_context import get_trading_mode - - b = get_bridge() - if pwd_updated: - b._clear_login_cooldown() - mode = get_trading_mode(get_setting) - info = b.reconnect_after_settings_saved(mode) - if info.get("cooldown"): - flash_msg = f"CTP 配置已保存;{pwd_note or '请稍后再连'}" - elif not info.get("started") and info.get("connected"): - flash_msg = f"CTP 配置已保存;{pwd_note or '当前连接正常'}" - except Exception as exc: - app.logger.warning("CTP reconnect after settings save: %s", exc) - flash_msg = f"CTP 配置已保存;{pwd_note or '请稍后在持仓监控页重连'}" - flash(flash_msg) - elif action == "nav": - items = {k: request.form.get(f"nav_{k}") == "on" for k in NAV_TOGGLES} - save_nav_items(set_setting, items) - flash("导航显示已保存") - elif action == "password": - ok, msg, _ = save_admin_credentials( - username=request.form.get("admin_username", ""), - old_password=request.form.get("old_password", ""), - new_password=request.form.get("new_password", ""), - new_password2=request.form.get("new_password2", ""), - get_setting=get_setting, - set_setting=set_setting, - ) - if ok and session.get("logged_in"): - session["username"] = (request.form.get("admin_username") or "").strip() - flash(msg) - return redirect(url_for("settings")) - - webhook = get_setting("wechat_webhook") - username = get_setting("admin_username") - ctp_st = {} - try: - from vnpy_bridge import ctp_status - from trading_context import get_trading_mode - - ctp_st = ctp_status(get_trading_mode(get_setting)) - except Exception: - pass - from ctp_settings import get_ctp_settings_for_ui, is_ctp_auto_connect_enabled - from product_recommend import small_account_margin_recommendations - - return render_template( - "settings.html", - webhook=webhook, - username=username, - quote_label=get_quote_source_label(ctp_connected=bool(ctp_st.get("connected"))), - ctp_status=ctp_st, - ctp_cfg=get_ctp_settings_for_ui(), - ctp_auto_connect=is_ctp_auto_connect_enabled(get_setting), - trading_mode=get_setting("trading_mode", "simulation"), - position_sizing_mode=get_setting("position_sizing_mode", "fixed"), - fixed_lots=get_setting("fixed_lots", "1"), - fixed_amount=get_setting("fixed_amount", "5000"), - risk_percent=get_setting("risk_percent", "1"), - max_margin_pct=get_setting("max_margin_pct", "30"), - roll_max_margin_pct=get_setting("roll_max_margin_pct", "50"), - small_account_margin_rec=small_account_margin_recommendations(), - trailing_be_tick_buffer=get_setting("trailing_be_tick_buffer", "2"), - pending_order_timeout_min=get_setting("pending_order_timeout_min", "5"), - nav_items=get_nav_items(get_setting), - nav_toggles=NAV_TOGGLES, - backup_dir=str(backup_dir()), - backup_last_at=get_backup_last_at(get_setting), - backup_running=backup_in_progress(), - backup_items=list_backups(), - backup_auto_enabled=get_setting("backup_auto_enabled", "1") == "1", - backup_auto_hour=get_setting("backup_auto_hour", "3"), - backup_keep_count=get_setting("backup_keep_count", "30"), - backup_restore_dir=default_restore_dir(), - ai_enabled=get_setting("ai_enabled", "0") == "1", - ai_provider=get_setting("ai_provider", "ollama"), - ai_ollama_base_url=get_setting("ai_ollama_base_url", "http://127.0.0.1:11434"), - ai_ollama_model=get_setting("ai_ollama_model", "qwen2.5:7b"), - ai_openai_base_url=get_setting("ai_openai_base_url", "https://api.openai.com/v1"), - ai_openai_api_key=get_setting("ai_openai_api_key", ""), - ai_openai_model=get_setting("ai_openai_model", "gpt-4o-mini"), - ai_daily_report_enabled=get_setting("ai_daily_report_enabled", "1") == "1", - ai_daily_report_hour=get_setting("ai_daily_report_hour", "15"), - ai_daily_report_minute=get_setting("ai_daily_report_minute", "5"), - ) - - -if os.getenv("QIHUO_INIT_ONLY") != "1": - install_trading( - app, - login_required=login_required, - require_nav=require_nav, - get_db=get_db, - get_setting=get_setting, - set_setting=set_setting, - fetch_price=fetch_price, - send_wechat_msg=send_wechat_msg, - ) - try_init_vnpy({}) - start_background_threads() - -# —————————————— 启动 —————————————— - -if __name__ == "__main__": - app.run(host=HOST, port=PORT, debug=DEBUG, threaded=True) +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +import os +import sys + +_ROOT = os.path.dirname(os.path.abspath(__file__)) +_legacy = os.path.join(_ROOT, "_legacy") +if _legacy not in sys.path: + sys.path.insert(0, _legacy) + +from modules.core.paths import ROOT, UPLOADS_DIR, DB_PATH, ensure_runtime_dirs, resolve_env_file +from locale_fix import ensure_process_locale + +ensure_process_locale() +ensure_runtime_dirs() + +import time +import threading +import requests +from datetime import date, datetime, timedelta +from typing import Optional +from functools import wraps +from zoneinfo import ZoneInfo + +from werkzeug.utils import secure_filename + +from dotenv import load_dotenv +from flask import ( + Flask, render_template, request, redirect, url_for, + flash, session, jsonify, Response, stream_with_context, +) +from werkzeug.security import check_password_hash, generate_password_hash + +from functools import wraps + +from symbols import ( + search_symbols, + ths_to_codes, + list_main_contracts_grouped, + list_recommended_symbols_grouped, + refresh_main_index, +) +from contract_specs import calc_position_metrics +from fee_specs import ( + calc_fee_breakdown, + calc_round_trip_fee, + list_fee_rates_for_ui, + count_fee_rates_by_source, + purge_non_ctp_fee_rates, +) +from nav_settings import NAV_TOGGLES, get_nav_items, nav_enabled, save_nav_items +from stats_engine import ( + STATS_VIEWS, + build_all_stats, + get_calendar_day, + get_calendar_month, + load_stats_cache, + refresh_stats_cache, +) +from kline_store import ensure_kline_tables +from kline_stream import kline_hub, sse_format +from kline_chart import generate_review_kline_chart, fetch_market_klines, MARKET_PERIODS +from market import get_price as market_get_price, set_ths_refresh_token, get_quote_source_label +from db_conn import OperationalError, connect_db, database_label, is_benign_migration_error, is_db_contention_error, is_schema_migration_error, rollback_if_postgres +from admin_settings import save_admin_credentials +from db_backup import ( + backup_dir, + backup_in_progress, + default_restore_dir, + get_backup_last_at, + list_backups, + resolve_backup_file, + schedule_backup, + start_backup_worker, +) +from strategy.strategy_db import init_strategy_tables + +load_dotenv(resolve_env_file()) +load_dotenv(os.path.join(ROOT, ".env")) # 兼容旧路径 + +app = Flask( + __name__, + template_folder=os.path.join(ROOT, "modules", "web", "templates"), + static_folder=os.path.join(ROOT, "modules", "web", "static"), +) +app.secret_key = os.getenv("SECRET_KEY", "futures_monitor_default_secret") + +HOST = os.getenv("HOST", "0.0.0.0") +PORT = int(os.getenv("PORT", "6600")) +DEBUG = os.getenv("DEBUG", "false").lower() in ("1", "true", "yes") + +UPLOAD_DIR = str(UPLOADS_DIR) +TZ = ZoneInfo("Asia/Shanghai") + +OPEN_TYPES = ["突破开仓", "回调开仓", "追涨杀跌", "计划内开仓", "震荡摸顶底", "其他"] +EXIT_TRIGGERS = ["止盈", "止损", "手工平仓", "移动止损", "时间离场", "其他"] +BEHAVIOR_TAGS = ["怕踏空", "报复开仓", "盈利飘了", "拿不住单", "扛单", "重仓违规"] +KLINE_PERIODS = ["1m", "3m", "5m", "15m", "30m", "1h", "4h", "1d"] +KLINE_CUTOFFS = ["平仓时间", "开仓时间", "当前时间"] + + +def today_str() -> str: + return datetime.now(TZ).date().isoformat() + + +def calc_holding_duration(open_time: str, close_time: str) -> str: + try: + o = datetime.fromisoformat(open_time.strip().replace(" ", "T")[:19]) + c = datetime.fromisoformat(close_time.strip().replace(" ", "T")[:19]) + delta = c - o + if delta.total_seconds() < 0: + return "" + secs = int(delta.total_seconds()) + h, rem = divmod(secs, 3600) + m, _ = divmod(rem, 60) + if h: + return f"{h}小时{m}分钟" + return f"{m}分钟" + except Exception: + return "" + + +def holding_to_minutes(open_time: str, close_time: str) -> int: + try: + o = datetime.fromisoformat(open_time.strip().replace(" ", "T")) + c = datetime.fromisoformat(close_time.strip().replace(" ", "T")) + secs = int((c - o).total_seconds()) + return max(0, secs // 60) + except Exception: + return 0 + + +def classify_close_result(direction: str, close: float, sl: float, tp: float) -> str: + """根据平仓价与止损/止盈距离判断结果。""" + if close is None: + return "手动平仓" + tol = max(abs(close) * 0.002, 1.0) + if abs(close - tp) <= tol: + return "止盈" + if abs(close - sl) <= tol: + return "止损" + return "手动平仓" + + +def calc_rr_ratio(direction: str, entry: float, stop: float, target: float) -> Optional[float]: + """盈亏比 = 盈利空间 / 风险空间。""" + if entry is None or stop is None or target is None: + return None + if direction == "long": + risk = entry - stop + if risk <= 0: + return None + return round((target - entry) / risk, 2) + if direction == "short": + risk = stop - entry + if risk <= 0: + return None + return round((entry - target) / risk, 2) + return None + + +def calc_theoretical_pnl(direction: str, entry: float, target: float, lots: float) -> Optional[float]: + if entry is None or target is None or lots is None: + return None + if direction == "long": + return round((target - entry) * lots, 2) + if direction == "short": + return round((entry - target) * lots, 2) + return None + + +def parse_review_date_filter(preset: str, start: str, end: str) -> tuple[str, str]: + today = datetime.now(TZ).date() + if preset == "today": + s = today.isoformat() + return s, s + if preset == "week": + monday = today - timedelta(days=today.weekday()) + return monday.isoformat(), today.isoformat() + if preset == "month": + return today.replace(day=1).isoformat(), today.isoformat() + return start.strip(), end.strip() + + +def expire_old_plans(): + """当日结束后计划自动失效,保留历史。""" + today = today_str() + conn = get_db() + conn.execute( + "UPDATE order_plans SET status='expired' WHERE plan_date < ? AND status IN ('planned', 'active')", + (today,), + ) + conn.execute( + "UPDATE order_plans SET plan_date=date(created_at) WHERE plan_date IS NULL OR plan_date=''" + ) + conn.commit() + conn.close() + + +def get_db(): + return connect_db() + + +def get_setting(key: str, default: str = "") -> str: + conn = get_db() + row = conn.execute("SELECT value FROM settings WHERE key=?", (key,)).fetchone() + conn.close() + return row["value"] if row else default + + +def set_setting(key: str, value: str): + conn = get_db() + conn.execute( + "INSERT INTO settings (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value=?", + (key, value, value), + ) + conn.commit() + conn.close() + + +def require_nav(key: str): + """导航项关闭时拒绝访问对应页面。""" + def decorator(f): + @wraps(f) + def wrapped(*args, **kwargs): + if not nav_enabled(get_setting, key): + flash("该页面已在系统设置中关闭") + return redirect(url_for("positions")) + return f(*args, **kwargs) + return wrapped + return decorator + + +def _static_asset_v() -> str: + base = os.path.dirname(os.path.abspath(__file__)) + rels = ( + "static/js/trade.js", + "static/js/dashboard.js", + "static/js/orientation.js", + "static/css/records.css", + "static/js/records.js", + "static/js/settings.js", + "static/css/mobile.css", + "static/css/responsive.css", + "static/css/trade.css", + "static/css/dashboard.css", + "static/css/doc.css", + "static/css/base.css", + ) + mtimes = [] + for rel in rels: + path = os.path.join(base, rel.replace("/", os.sep)) + if os.path.isfile(path): + mtimes.append(os.path.getmtime(path)) + return str(int(max(mtimes))) if mtimes else "0" + + +def _ua_is_phone(ua: str) -> bool: + ua_l = (ua or "").lower() + if "ipad" in ua_l: + return False + if "android" in ua_l and "mobile" not in ua_l: + return False + if any(x in ua_l for x in ("iphone", "ipod", "windows phone", "iemobile")): + return True + if "android" in ua_l and "mobile" in ua_l: + return True + if "mobile" in ua_l or "harmonyos" in ua_l or "openharmony" in ua_l: + return True + return False + + +@app.context_processor +def inject_globals(): + return {"nav_items": get_nav_items(get_setting), "asset_v": _static_asset_v()} + + +def _trading_mode() -> str: + return (get_setting("trading_mode", "simulation") or "simulation").strip() + + +def touch_stats_cache(): + try: + conn = get_db() + capital = float(get_setting("live_capital", "0") or 0) + refresh_stats_cache(conn, capital) + conn.close() + except Exception as exc: + app.logger.warning("stats cache refresh failed: %s", exc) + + +def get_stats_data() -> dict: + conn = get_db() + try: + capital = float(get_setting("live_capital", "0") or 0) + data = load_stats_cache(conn) + if data: + return data + try: + return refresh_stats_cache(conn, capital) + except OperationalError as exc: + if not is_db_contention_error(exc): + raise + app.logger.warning("stats cache refresh contention, compute without save: %s", exc) + return build_all_stats(conn, capital) + finally: + conn.close() + + +def init_db(): + import strategy.strategy_db as strategy_db + import risk.account_risk_lib as account_risk_lib + + strategy_db._TABLES_READY = False + account_risk_lib._SCHEMA_READY = False + + conn = get_db() + c = conn.cursor() + c.execute("CREATE TABLE IF NOT EXISTS settings (key TEXT PRIMARY KEY, value TEXT)") + c.execute('''CREATE TABLE IF NOT EXISTS order_plans + (id INTEGER PRIMARY KEY AUTOINCREMENT, + symbol TEXT, symbol_name TEXT, direction TEXT, + zone_upper REAL, zone_lower REAL, + stop_loss REAL, take_profit REAL, + status TEXT DEFAULT 'planned', + triggered_at TIMESTAMP, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)''') + c.execute('''CREATE TABLE IF NOT EXISTS key_monitors + (id INTEGER PRIMARY KEY AUTOINCREMENT, + symbol TEXT, symbol_name TEXT, monitor_type TEXT, direction TEXT, + upper REAL, lower REAL, + upper_triggered INTEGER DEFAULT 0, + lower_triggered INTEGER DEFAULT 0, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)''') + c.execute('''CREATE TABLE IF NOT EXISTS trade_records + (id INTEGER PRIMARY KEY AUTOINCREMENT, + symbol TEXT, symbol_name TEXT, monitor_type TEXT, direction TEXT, + trigger_price REAL, stop_loss REAL, take_profit REAL, + result TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)''') + conn.commit() + migrations = [ + "ALTER TABLE key_monitors ADD COLUMN symbol_name TEXT", + "ALTER TABLE key_monitors ADD COLUMN upper_triggered INTEGER DEFAULT 0", + "ALTER TABLE key_monitors ADD COLUMN lower_triggered INTEGER DEFAULT 0", + "ALTER TABLE trade_records ADD COLUMN symbol_name TEXT", + "ALTER TABLE order_plans ADD COLUMN sina_code TEXT", + "ALTER TABLE order_plans ADD COLUMN market_code TEXT", + "ALTER TABLE key_monitors ADD COLUMN market_code TEXT", + "ALTER TABLE key_monitors ADD COLUMN sina_code TEXT", + "ALTER TABLE trade_records ADD COLUMN market_code TEXT", + "ALTER TABLE order_plans ADD COLUMN plan_date TEXT", + "ALTER TABLE order_plans ADD COLUMN decision_reason TEXT", + "ALTER TABLE key_monitors ADD COLUMN status TEXT DEFAULT 'active'", + "ALTER TABLE key_monitors ADD COLUMN archived_at TEXT", + "ALTER TABLE key_monitors ADD COLUMN trade_mode TEXT DEFAULT '顺势'", + "ALTER TABLE key_monitors ADD COLUMN risk_reward REAL DEFAULT 2", + "ALTER TABLE key_monitors ADD COLUMN trailing_be INTEGER DEFAULT 0", + "ALTER TABLE key_monitors ADD COLUMN last_trigger_bar TEXT", + "ALTER TABLE key_monitors ADD COLUMN alert_push_count INTEGER DEFAULT 0", + "ALTER TABLE key_monitors ADD COLUMN alert_last_push_at TEXT", + "ALTER TABLE key_monitors ADD COLUMN alert_break_side TEXT", + "ALTER TABLE key_monitors ADD COLUMN breakout_bar_time TEXT", + "ALTER TABLE key_monitors ADD COLUMN alert_close_price REAL", + "ALTER TABLE key_monitors ADD COLUMN bar_period TEXT DEFAULT '5m'", + "ALTER TABLE review_records ADD COLUMN direction TEXT", + "ALTER TABLE review_records ADD COLUMN entry_price REAL", + "ALTER TABLE review_records ADD COLUMN stop_loss REAL", + "ALTER TABLE review_records ADD COLUMN take_profit REAL", + "ALTER TABLE review_records ADD COLUMN close_price REAL", + "ALTER TABLE review_records ADD COLUMN lots REAL", + "ALTER TABLE review_records ADD COLUMN holding_duration TEXT", + "ALTER TABLE review_records ADD COLUMN initial_pnl REAL", + "ALTER TABLE review_records ADD COLUMN actual_pnl REAL", + "ALTER TABLE review_records ADD COLUMN is_emotion INTEGER DEFAULT 0", + "ALTER TABLE review_records ADD COLUMN symbol_name TEXT", + "ALTER TABLE review_records ADD COLUMN market_code TEXT", + "ALTER TABLE review_records ADD COLUMN sina_code TEXT", + "ALTER TABLE trade_logs ADD COLUMN fee REAL", + "ALTER TABLE trade_logs ADD COLUMN pnl_net REAL", + "ALTER TABLE trade_logs ADD COLUMN margin_pct REAL", + "ALTER TABLE trade_logs ADD COLUMN equity_after REAL", + "ALTER TABLE review_records ADD COLUMN fee REAL", + "ALTER TABLE review_records ADD COLUMN pnl_net REAL", + ] + for sql in migrations: + try: + c.execute(sql) + conn.commit() + except Exception as exc: + if not is_schema_migration_error(exc): + raise + rollback_if_postgres(conn) + c.execute('''CREATE TABLE IF NOT EXISTS review_records + (id INTEGER PRIMARY KEY AUTOINCREMENT, + open_time TEXT, close_time TEXT, + symbol TEXT, timeframe TEXT, + pnl REAL, + open_type TEXT, expected_rr REAL, actual_rr REAL, + exit_trigger TEXT, exit_supplement TEXT, + watch_after_breakeven TEXT, new_position_while_occupied TEXT, + screenshot TEXT, + auto_kline INTEGER DEFAULT 0, + kline_period1 TEXT, kline_period2 TEXT, + kline_count INTEGER, kline_cutoff TEXT, + behavior_tags TEXT, notes TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)''') + c.execute('''CREATE TABLE IF NOT EXISTS position_monitors + (id INTEGER PRIMARY KEY AUTOINCREMENT, + symbol TEXT, symbol_name TEXT, market_code TEXT, sina_code TEXT, + direction TEXT, lots REAL, entry_price REAL, + stop_loss REAL, take_profit REAL, open_time TEXT, + status TEXT DEFAULT 'active', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)''') + c.execute('''CREATE TABLE IF NOT EXISTS trade_logs + (id INTEGER PRIMARY KEY AUTOINCREMENT, + symbol TEXT, symbol_name TEXT, market_code TEXT, sina_code TEXT, + monitor_type TEXT, direction TEXT, + entry_price REAL, stop_loss REAL, take_profit REAL, close_price REAL, + lots REAL, margin REAL, holding_minutes INTEGER, + open_time TEXT, close_time TEXT, + pnl REAL, result TEXT, + verified INTEGER DEFAULT 0, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)''') + c.execute('''CREATE TABLE IF NOT EXISTS fee_rates + (product TEXT PRIMARY KEY, + exchange TEXT, + mult INTEGER, + open_fixed REAL DEFAULT 0, + open_ratio REAL DEFAULT 0, + close_yesterday_fixed REAL DEFAULT 0, + close_yesterday_ratio REAL DEFAULT 0, + close_today_fixed REAL DEFAULT 0, + close_today_ratio REAL DEFAULT 0, + updated_at TEXT)''') + c.execute('''CREATE TABLE IF NOT EXISTS stats_cache + (key TEXT PRIMARY KEY, + data_json TEXT NOT NULL, + updated_at TEXT NOT NULL)''') + conn.commit() + for sql in ( + "ALTER TABLE fee_rates ADD COLUMN source TEXT DEFAULT 'local'", + ): + try: + c.execute(sql) + conn.commit() + except Exception as exc: + if not is_schema_migration_error(exc): + raise + rollback_if_postgres(conn) + ensure_kline_tables(conn) + init_strategy_tables(conn) + from risk.account_risk_lib import ensure_account_risk_schema + from recommend_store import ensure_recommend_tables + + ensure_account_risk_schema(conn) + ensure_recommend_tables(conn) + from ai_messages import ensure_ai_messages_table + + ensure_ai_messages_table(conn) + conn.commit() + conn.close() + + sync_admin_from_env() + + if not get_setting("wechat_webhook") and os.getenv("WECHAT_WEBHOOK"): + set_setting("wechat_webhook", os.getenv("WECHAT_WEBHOOK")) + + if not get_setting("ths_refresh_token") and os.getenv("THS_REFRESH_TOKEN"): + set_setting("ths_refresh_token", os.getenv("THS_REFRESH_TOKEN")) + + from ctp_settings import seed_ctp_settings_from_env + seed_ctp_settings_from_env(set_setting) + + os.makedirs(UPLOAD_DIR, exist_ok=True) + expire_old_plans() + + if not get_setting("fee_multiplier"): + set_setting("fee_multiplier", "2") + if not get_setting("trading_mode"): + set_setting("trading_mode", "simulation") + if not get_setting("position_sizing_mode"): + set_setting("position_sizing_mode", "fixed") + if not get_setting("fixed_lots"): + set_setting("fixed_lots", "1") + if not get_setting("fixed_amount"): + set_setting("fixed_amount", "5000") + if not get_setting("risk_percent"): + set_setting("risk_percent", "1") + if not get_setting("max_margin_pct"): + set_setting("max_margin_pct", "30") + if not get_setting("roll_max_margin_pct"): + set_setting("roll_max_margin_pct", "50") + if not get_setting("trailing_be_tick_buffer"): + set_setting("trailing_be_tick_buffer", "2") + if not get_setting("pending_order_timeout_min"): + set_setting("pending_order_timeout_min", "5") + if not get_setting("ai_enabled"): + set_setting("ai_enabled", "0") + if not get_setting("ai_provider"): + set_setting("ai_provider", "ollama") + if not get_setting("ai_ollama_base_url"): + set_setting("ai_ollama_base_url", "http://127.0.0.1:11434") + if not get_setting("ai_ollama_model"): + set_setting("ai_ollama_model", "qwen2.5:7b") + if not get_setting("ai_openai_base_url"): + set_setting("ai_openai_base_url", "https://api.openai.com/v1") + if not get_setting("ai_openai_model"): + set_setting("ai_openai_model", "gpt-4o-mini") + if not get_setting("ai_daily_report_enabled"): + set_setting("ai_daily_report_enabled", "1") + if not get_setting("ai_daily_report_hour"): + set_setting("ai_daily_report_hour", "15") + if not get_setting("ai_daily_report_minute"): + set_setting("ai_daily_report_minute", "5") + if not get_setting("backup_auto_enabled"): + set_setting("backup_auto_enabled", "1") + if not get_setting("backup_auto_hour"): + set_setting("backup_auto_hour", "3") + if not get_setting("backup_keep_count"): + set_setting("backup_keep_count", "30") + if not get_setting("fee_source_mode"): + set_setting("fee_source_mode", "ctp") + set_setting("fee_source_mode", "ctp") + try: + purge_non_ctp_fee_rates() + except Exception: + pass + + +def sync_admin_from_env(): + """ + 从 .env 同步管理员账号。 + - 首次建库:自动写入 ADMIN_USERNAME / ADMIN_PASSWORD + - 已建库后改 .env:需设 ADMIN_SYNC_FROM_ENV=true 并重启服务 + """ + sync = os.getenv("ADMIN_SYNC_FROM_ENV", "false").lower() in ("1", "true", "yes") + env_username = os.getenv("ADMIN_USERNAME", "").strip() + env_password = os.getenv("ADMIN_PASSWORD", "").strip() + placeholder_passwords = {"", "change-me-on-first-login", "admin123"} + + if not get_setting("admin_username"): + username = env_username or "admin" + password = env_password if env_password not in placeholder_passwords else "admin123" + set_setting("admin_username", username) + set_setting("admin_password_hash", generate_password_hash(password)) + return + + if not sync: + return + + if env_username: + set_setting("admin_username", env_username) + if env_password and env_password not in placeholder_passwords: + set_setting("admin_password_hash", generate_password_hash(env_password)) + + +if os.getenv("QIHUO_SKIP_INIT_DB") != "1": + init_db() + app.logger.info("数据库: %s", database_label()) + + +def sync_ths_token(): + set_ths_refresh_token(get_setting("ths_refresh_token")) + + +if os.getenv("QIHUO_INIT_ONLY") != "1": + sync_ths_token() + + +def build_market_quote_payload( + symbol: str, + market_code: str = "", + sina_code: str = "", + *, + prefer_sina: bool = False, +) -> dict: + if not market_code or not sina_code: + codes = ths_to_codes(symbol) + if codes: + market_code = codes.get("market_code", "") or market_code + sina_code = codes.get("sina_code", "") or sina_code + quote_source = "sina" + price = None + prev_close = None + if not prefer_sina: + try: + from vnpy_bridge import ctp_status, ctp_get_tick_detail + from trading_context import get_trading_mode + + mode = get_trading_mode(get_setting) + if ctp_status(mode).get("connected"): + detail = ctp_get_tick_detail(mode, symbol) + if detail.get("price"): + price = detail["price"] + quote_source = "ctp" + if detail.get("pre_close") is not None: + prev_close = detail["pre_close"] + except Exception: + pass + if price is None: + price = fetch_price(symbol, market_code, sina_code) + name = symbol + codes = ths_to_codes(symbol) + if codes: + name = codes.get("name", symbol) + if prev_close is None and sina_code: + from market import fetch_raw_for_volume + raw = fetch_raw_for_volume(sina_code) + if raw and raw.get("prev_close") is not None: + prev_close = raw["prev_close"] + return { + "symbol": symbol, + "name": name, + "price": price, + "prev_close": prev_close, + "quote_source": quote_source, + } + + +# —————————————— 推送 —————————————— + +def send_wechat_msg(content: str): + webhook = get_setting("wechat_webhook") + if not webhook: + return + full = f"【国内期货】\n{content}" + data = {"msgtype": "text", "text": {"content": full}} + try: + requests.post(webhook, json=data, timeout=10) + except Exception: + pass + +# —————————————— 行情 —————————————— + +def resolve_market_codes(ths_code: str, market_code: str = "", sina_code: str = "") -> tuple[str, str]: + """返回 (market_code, sina_code) 用于行情拉取。""" + if market_code: + return market_code, sina_code + if sina_code and "." in sina_code: + return sina_code, "" + codes = ths_to_codes(ths_code) + if codes: + return codes["market_code"], codes["sina_code"] + if ths_code.startswith("nf_") or ths_code.startswith("CFF_RE_"): + return ths_code, ths_code + return "", sina_code or "" + + +def fetch_price(ths_code: str, market_code: str = "", sina_code: str = "") -> Optional[float]: + sym = (ths_code or "").strip() + if sym: + try: + from vnpy_bridge import ctp_status, ctp_get_tick_price + from trading_context import get_trading_mode + + mode = get_trading_mode(get_setting) + if ctp_status(mode).get("connected"): + p = ctp_get_tick_price(mode, sym) + if p and p > 0: + return p + except Exception: + pass + mc, sc = resolve_market_codes(sym, market_code, sina_code) + if not mc and not sc: + return None + return market_get_price(mc, sc) + +# —————————————— 监控逻辑 —————————————— + +def check_order_plans(): + expire_old_plans() + today = today_str() + conn = get_db() + rows = conn.execute( + "SELECT * FROM order_plans WHERE plan_date=? AND status IN ('planned', 'active')", + (today,), + ).fetchall() + + for r in rows: + sym = r["symbol"] + sina = r["sina_code"] if "sina_code" in r.keys() else "" + market = r["market_code"] if "market_code" in r.keys() else "" + p = fetch_price(sym, market, sina) + if not p: + continue + + direction = r["direction"] + zone_upper = r["zone_upper"] + zone_lower = r["zone_lower"] + stop_loss = r["stop_loss"] + take_profit = r["take_profit"] + status = r["status"] + pid = r["id"] + name = r["symbol_name"] or sym + reason = r["decision_reason"] if "decision_reason" in r.keys() and r["decision_reason"] else "—" + + # 计划状态:价格进入决策区间则激活并通知 + if status == "planned": + in_zone = zone_lower <= p <= zone_upper + if in_zone: + msg = ( + f"【开单计划触发】{name} ({sym})\n" + f"方向:{'做多' if direction == 'long' else '做空'}\n" + f"决策区间:{zone_lower} ~ {zone_upper}\n" + f"决策理由:{reason}\n" + f"当前价:{p}\n" + f"止损:{stop_loss} 止盈:{take_profit}" + ) + send_wechat_msg(msg) + conn.execute( + "UPDATE order_plans SET status='active', triggered_at=? WHERE id=?", + (datetime.now().isoformat(), pid), + ) + status = "active" + + # 激活状态:监控止盈止损 + if status == "active": + res = None + if direction == "long": + if p >= take_profit: + res = "止盈" + elif p <= stop_loss: + res = "止损" + elif direction == "short": + if p <= take_profit: + res = "止盈" + elif p >= stop_loss: + res = "止损" + + if res: + msg = ( + f"[{'做多' if direction == 'long' else '做空'}] {name} 已{res}\n" + f"决策区间:{zone_lower} ~ {zone_upper}\n" + f"止损:{stop_loss} 止盈:{take_profit}\n" + f"当前价:{p}" + ) + send_wechat_msg(msg) + conn.execute( + """INSERT INTO trade_records + (symbol, symbol_name, monitor_type, direction, + trigger_price, stop_loss, take_profit, result) + VALUES (?,?,?,?,?,?,?,?)""", + (sym, name, "开单计划", direction, p, stop_loss, take_profit, res), + ) + conn.execute( + "UPDATE order_plans SET status='closed' WHERE id=?", (pid,) + ) + + conn.commit() + conn.close() + + +def check_key_monitors(): + from db_conn import DB_PATH + from key_monitor_lib import run_key_monitor_check + from trading_context import get_trading_mode + + conn = get_db() + try: + execute_fn = getattr(app, "_execute_key_breakout", None) + run_key_monitor_check( + conn, + db_path=DB_PATH, + get_trading_mode_fn=lambda: get_trading_mode(get_setting), + send_wechat=send_wechat_msg, + execute_breakout_fn=execute_fn, + ) + conn.commit() + finally: + conn.close() + + +def background_task(): + while True: + try: + expire_old_plans() + check_key_monitors() + fn_roll = getattr(app, "_check_roll_monitors", None) + if fn_roll: + fn_roll() + check_order_plans() + fn = getattr(app, "_check_trend_plans", None) + if fn: + fn(app) + except Exception: + pass + time.sleep(3) + + +def start_background_threads(): + from trading_context import get_trading_mode + + threading.Thread(target=background_task, daemon=True).start() + threading.Thread( + target=lambda: kline_hub.worker_loop( + DB_PATH, + lambda sym, mc, sc: build_market_quote_payload( + sym, mc, sc, prefer_sina=True, + ), + get_mode_fn=lambda: get_trading_mode(get_setting), + ), + daemon=True, + ).start() + threading.Thread(target=refresh_main_index, daemon=True).start() + start_backup_worker(get_setting_fn=get_setting, set_setting_fn=set_setting) + + +# —————————————— 登录 —————————————— + +def login_required(f): + @wraps(f) + def wrap(*args, **kwargs): + if not session.get("logged_in"): + return redirect(url_for("login")) + return f(*args, **kwargs) + return wrap + + +from modules.core import AppDeps, register_all_modules, start_module_workers + +if os.getenv("QIHUO_INIT_ONLY") != "1": + _deps = AppDeps( + app=app, + get_db=get_db, + get_setting=get_setting, + set_setting=set_setting, + login_required=login_required, + require_nav=require_nav, + fetch_price=fetch_price, + send_wechat_msg=send_wechat_msg, + touch_stats_cache=touch_stats_cache, + get_stats_data=get_stats_data, + build_market_quote_payload=build_market_quote_payload, + today_str=today_str, + expire_old_plans=expire_old_plans, + check_order_plans=check_order_plans, + check_key_monitors=check_key_monitors, + background_task=background_task, + start_background_threads=start_background_threads, + tz=TZ, + db_path=DB_PATH, + upload_dir=UPLOAD_DIR, + open_types=OPEN_TYPES, + exit_triggers=EXIT_TRIGGERS, + behavior_tags=BEHAVIOR_TAGS, + kline_periods=KLINE_PERIODS, + kline_cutoffs=KLINE_CUTOFFS, + calc_holding_duration=calc_holding_duration, + holding_to_minutes=holding_to_minutes, + classify_close_result=classify_close_result, + calc_rr_ratio=calc_rr_ratio, + calc_theoretical_pnl=calc_theoretical_pnl, + parse_review_date_filter=parse_review_date_filter, + trading_mode=_trading_mode, + static_asset_v=_static_asset_v, + ua_is_phone=_ua_is_phone, + ) + register_all_modules(_deps) + start_module_workers(_deps) + +# —————————————— 启动 —————————————— + +if __name__ == "__main__": + app.run(host=HOST, port=PORT, debug=DEBUG, threaded=True) diff --git a/config/.env.example b/config/.env.example new file mode 100644 index 0000000..28c7309 --- /dev/null +++ b/config/.env.example @@ -0,0 +1,61 @@ +# 服务配置 +HOST=0.0.0.0 +PORT=6600 +DEBUG=false + +SECRET_KEY=change-this-to-a-random-secret-key + +ADMIN_USERNAME=admin +ADMIN_PASSWORD=change-me-on-first-login +ADMIN_SYNC_FROM_ENV=false + +WECHAT_WEBHOOK= + +QUOTE_SOURCE=sina +THS_REFRESH_TOKEN= + +# 交易模式:simulation=SimNow,live=期货公司(系统设置页可改) +TRADING_MODE=simulation +POSITION_SIZING_MODE=risk +RISK_PERCENT=1 + +# CTP 断线后后台自动重连(true/false) +CTP_AUTO_RECONNECT=true + +# —— SimNow 模拟盘(也可在「系统设置 → CTP 连接」配置,优先于本文件)—— +SIMNOW_USER= +SIMNOW_PASSWORD= +SIMNOW_BROKER_ID=9999 +# 7×24 / 日盘前置(deploy.sh 会自动 nc 探测并写入可用线路) +SIMNOW_TD_ADDRESS=tcp://180.168.146.187:10201 +SIMNOW_MD_ADDRESS=tcp://180.168.146.187:10211 +SIMNOW_APP_ID=simnow_client_test +SIMNOW_AUTH_CODE=0000000000000000 +# SimNow 看穿式前置固定用「实盘」;仅穿透式测评才用「测试」 +SIMNOW_ENV=实盘 + +# —— 期货公司实盘(后期接入)—— +CTP_LIVE_USER= +CTP_LIVE_PASSWORD= +CTP_LIVE_BROKER_ID= +CTP_LIVE_TD_ADDRESS= +CTP_LIVE_MD_ADDRESS= +CTP_LIVE_APP_ID= +CTP_LIVE_AUTH_CODE= +CTP_LIVE_PRODUCT_INFO= + +# 账户冷静期 +RISK_CONTROL_ENABLED=true +RISK_COOLING_HOURS_MANUAL=4 +RISK_COOLING_HOURS_MANUAL_JOURNAL=1 +RISK_MANUAL_CLOSE_DAILY_LIMIT=2 +MAX_ACTIVE_POSITIONS=1 +RISK_DAILY_POSITION_LIMIT=5 +RISK_DAILY_TRADING_RISK_PCT=2 +TRADING_DAY_RESET_HOUR=8 + +# —— 数据库(生产推荐 PostgreSQL,见 docs/POSTGRES.md)—— +# 未配置 DATABASE_URL 时使用本地 SQLite futures.db +# DATABASE_URL=postgresql://qihuo:your_password@127.0.0.1:5432/qihuo +# PG_POOL_MIN=2 +# PG_POOL_MAX=20 diff --git a/deploy.sh b/deploy.sh index db6153b..823a1a8 100644 --- a/deploy.sh +++ b/deploy.sh @@ -161,25 +161,27 @@ pip install --upgrade pip -q pip install -r "$APP_DIR/requirements.txt" python -c "from vnpy_ctp import CtpGateway; print('vnpy_ctp OK')" -echo "==> 配置 .env..." -if [ ! -f "$APP_DIR/.env" ]; then - cp "$APP_DIR/.env.example" "$APP_DIR/.env" +echo "==> 配置 config/.env..." +ENV_FILE="$APP_DIR/config/.env" +mkdir -p "$APP_DIR/config" +if [ ! -f "$ENV_FILE" ]; then + cp "$APP_DIR/config/.env.example" "$ENV_FILE" RAND_KEY=$(python3 -c "import secrets; print(secrets.token_hex(32))") - sed -i "s/change-this-to-a-random-secret-key/${RAND_KEY}/" "$APP_DIR/.env" - echo " 已生成 .env,请编辑 SIMNOW_USER / ADMIN_PASSWORD" + sed -i "s/change-this-to-a-random-secret-key/${RAND_KEY}/" "$ENV_FILE" + echo " 已生成 config/.env,请编辑 SIMNOW_USER / ADMIN_PASSWORD" fi -ensure_env_key "$APP_DIR/.env" "SIMNOW_ENV" "实盘" -ensure_env_key "$APP_DIR/.env" "CTP_AUTO_RECONNECT" "true" -ensure_env_key "$APP_DIR/.env" "SIMNOW_BROKER_ID" "9999" -ensure_env_key "$APP_DIR/.env" "SIMNOW_APP_ID" "simnow_client_test" -ensure_env_key "$APP_DIR/.env" "SIMNOW_AUTH_CODE" "0000000000000000" -update_simnow_front_in_env "$APP_DIR/.env" || true +ensure_env_key "$ENV_FILE" "SIMNOW_ENV" "实盘" +ensure_env_key "$ENV_FILE" "CTP_AUTO_RECONNECT" "true" +ensure_env_key "$ENV_FILE" "SIMNOW_BROKER_ID" "9999" +ensure_env_key "$ENV_FILE" "SIMNOW_APP_ID" "simnow_client_test" +ensure_env_key "$ENV_FILE" "SIMNOW_AUTH_CODE" "0000000000000000" +update_simnow_front_in_env "$ENV_FILE" || true mkdir -p "$APP_DIR/logs" echo "==> 验证 CTP 环境..." -if grep -q "^SIMNOW_USER=.\+" "$APP_DIR/.env" 2>/dev/null && \ - grep -q "^SIMNOW_PASSWORD=.\+" "$APP_DIR/.env" 2>/dev/null; then +if grep -q "^SIMNOW_USER=.\+" "$ENV_FILE" 2>/dev/null && \ + grep -q "^SIMNOW_PASSWORD=.\+" "$ENV_FILE" 2>/dev/null; then set +e python "$APP_DIR/scripts/test_simnow.py" CTP_TEST=$? diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md new file mode 100644 index 0000000..f716800 --- /dev/null +++ b/docs/ARCHITECTURE.md @@ -0,0 +1,50 @@ +# 主目录结构 + +``` +qihuo/ # 主文件夹(仓库根) +├── app.py # 主程序入口(Flask 启动) +├── requirements.txt +├── deploy.sh # 一键部署脚本 +├── ecosystem.config.cjs # PM2 启动配置 +├── config/ +│ ├── .env.example # 环境变量模板 +│ └── .env # 运行时配置(git 忽略) +├── modules/ # 业务模块(每个模块 register(deps)) +│ ├── core/ # DB、路径、公共工具 +│ ├── web/ # 页面路由 + static/ + templates/ +│ ├── trading/ # 下单监控、持仓、推荐 +│ ├── ctp/ # vn.py / CTP 连接与报单 +│ ├── risk/ # 账户风控 +│ ├── strategy/ # 趋势、滚仓策略 +│ ├── keys/ # 关键位 +│ ├── plans/ # 开单计划 +│ ├── market/ # 行情、K 线 +│ ├── records/ # 交易记录、复盘 +│ ├── stats/ # 统计、看板 +│ ├── settings/ # 系统设置 +│ ├── notify/ # 微信、AI 消息 +│ ├── fees/ # 手续费 +│ └── backup/ # 备份 +├── _legacy/ # 旧 import 兼容 shim(PM2 PYTHONPATH) +├── data/ # 静态数据(如 fee_rates.json) +├── docs/ # 文档 +├── scripts/ # 运维/诊断脚本(非运行时) +├── futures.db # SQLite(未配 PG 时) +├── uploads/ +└── logs/ +``` + +根目录 `_legacy/` 为旧 `import db_conn` 等路径的兼容层;新代码请 `from modules.xxx import ...`。 + +## 进程模型 + +- **单进程**:PM2 仅 `qihuo`(`app.py` + CTP 同进程) +- 详见 [DEPLOY.md](./DEPLOY.md) + +## 模块契约 + +每个 `modules//` 提供 `register(deps: AppDeps)`;主程序 `app.py` 只做串联,不写业务。 + +## 发布 + +见 [DEPLOY.md](./DEPLOY.md):**本地修改 → git push → 服务器 git pull**,禁止 SCP。 diff --git a/docs/DEPLOY.md b/docs/DEPLOY.md index 77ba2fa..10321b9 100644 --- a/docs/DEPLOY.md +++ b/docs/DEPLOY.md @@ -43,7 +43,7 @@ pm2 save 以下文件 **不** 随 `git pull` 更新,卸载/重装时须 **单独备份与恢复**: -- `/opt/qihuo/.env` +- `/opt/qihuo/config/.env`(兼容旧版 `/opt/qihuo/.env`) - `/opt/qihuo/futures.db`(SQLite)或 PostgreSQL 数据 - `/opt/qihuo/uploads/` - `/opt/qihuo/backups/`(若有) @@ -58,18 +58,17 @@ pm2 save | 运行用户 | `root`(与 `deploy.sh` / PM2 配置一致) | | Web 端口 | `6600`(对外) | | CTP Worker 端口 | `6601`(仅 `127.0.0.1`,Web 进程 IPC 调用,勿对外开放) | -| 进程管理 | PM2:`qihuo`(Flask Web)+ `qihuo-ctp`(CTP / vn.py 独立进程) | +| 进程管理 | PM2:**仅** `qihuo`(Flask + CTP 单进程) | | 数据库 | **生产推荐 PostgreSQL**(见 [POSTGRES.md](./POSTGRES.md));未配置 `DATABASE_URL` 时使用 SQLite `futures.db` | | 仓库 | https://git.bz121.com/dekun/qihuo.git | -### 进程架构(2026-03 起) +### 进程架构(2026-07 起:单进程) -| PM2 应用 | 角色 | 说明 | -|----------|------|------| -| `qihuo` | Web(`QIHUO_CTP_ROLE=client`) | Flask、页面、API、数据库;通过 HTTP 调用本机 Worker | -| `qihuo-ctp` | Worker(`QIHUO_CTP_ROLE=worker`) | **唯一** 加载 vn.py / vnpy_ctp;CTP 连接、报单、持仓回调、止盈止损 tick、滚仓监控 | +| PM2 应用 | 说明 | +|----------|------| +| `qihuo` | Flask Web + **vn.py / CTP 同进程**(`vnpy_bridge.CtpBridge`) | -Web 进程崩溃或重启 **不会** 直接带走 CTP 原生连接;Worker 重启后 Web 会自动通过 IPC 恢复读写。两个进程的 Token 须一致(见 `ecosystem.config.cjs` 中 `QIHUO_CTP_WORKER_TOKEN`)。 +详见 [ARCHITECTURE.md](./ARCHITECTURE.md)。旧版 `qihuo-ctp` 独立 Worker **已废弃**,`ecosystem.config.cjs` 不再启动该进程。 --- @@ -110,7 +109,7 @@ bash deploy.sh 6. 首次生成 `.env`,并补全 `SIMNOW_ENV=实盘`、`CTP_AUTO_RECONNECT=true` 等缺项 7. **自动探测 SimNow 前置**(`nc` 测端口),写入可用的 `SIMNOW_TD/MD_ADDRESS`(优先 `182.254.243.31`,其次 `180.168.146.187`) 8. 若已配置 SimNow 账号,运行 `scripts/test_simnow.py` 验证连接 -9. `pm2 restart ecosystem.config.cjs --update-env` 或首次 `pm2 start ecosystem.config.cjs`,并 `pm2 save`(同时启动 **`qihuo`** 与 **`qihuo-ctp`**) +9. `pm2 restart ecosystem.config.cjs --update-env` 或首次 `pm2 start ecosystem.config.cjs`,并 `pm2 save`(仅 **`qihuo`** 一个进程) 部署完成后访问:`http://<服务器IP>:6600` @@ -141,7 +140,7 @@ MIGRATE_SQLITE=1 sudo bash scripts/deploy_postgres.sh ```bash # 在服务器上 -cp /opt/qihuo/.env /root/qihuo.env.bak +cp /opt/qihuo/config/.env /root/qihuo.env.bak 2>/dev/null || cp /opt/qihuo/.env /root/qihuo.env.bak 2>/dev/null || true # SQLite cp /opt/qihuo/futures.db /root/futures.db.bak 2>/dev/null || true # PostgreSQL 见 POSTGRES.md 备份命令 @@ -151,8 +150,8 @@ tar czf /root/qihuo_uploads.bak.tar.gz -C /opt/qihuo uploads 2>/dev/null || true ### 2. 卸载 PM2 与代码目录 ```bash -pm2 stop qihuo qihuo-ctp 2>/dev/null || true -pm2 delete qihuo qihuo-ctp 2>/dev/null || true +pm2 stop qihuo 2>/dev/null || true +pm2 delete qihuo 2>/dev/null || true pm2 save rm -rf /opt/qihuo ``` @@ -174,7 +173,7 @@ bash deploy.sh ```bash cd /opt/qihuo && git log -1 --oneline # 须与远端 main 最新提交一致 -pm2 status # qihuo、qihuo-ctp 均为 online +pm2 status # qihuo 为 online ``` 浏览器访问 `http://<服务器IP>:6600` 登录验证。 @@ -225,8 +224,8 @@ python -c "from vnpy_ctp import CtpGateway; print('vnpy_ctp OK')" 若提示找不到模块,查看本文「CTP / vnpy 故障排查」一节。 ```bash -cp .env.example .env -nano .env +cp config/.env.example config/.env +nano config/.env ``` | 变量 | 说明 | @@ -318,7 +317,7 @@ pm2 restart ecosystem.config.cjs --update-env pm2 save ``` -> 须 **同时重启** `qihuo` 与 `qihuo-ctp`。仅 `pm2 restart qihuo` 会导致 Web 与 Worker 代码/协议不一致。 +> 更新后执行 `pm2 restart ecosystem.config.cjs --update-env` 即可(仅 `qihuo`)。 若服务器曾用 SCP 覆盖文件导致 `git pull` 冲突,用 `git reset --hard origin/main` 与远端对齐。 diff --git a/ecosystem.config.cjs b/ecosystem.config.cjs index 24e8545..5049f45 100644 --- a/ecosystem.config.cjs +++ b/ecosystem.config.cjs @@ -21,12 +21,10 @@ module.exports = { max_memory_restart: "8192M", env: { NODE_ENV: "production", + PYTHONPATH: path.join(ROOT, "_legacy"), LANG: "zh_CN.UTF-8", LC_ALL: "zh_CN.UTF-8", LC_CTYPE: "zh_CN.UTF-8", - QIHUO_CTP_ROLE: "client", - QIHUO_CTP_WORKER_URL: "http://127.0.0.1:6601", - QIHUO_CTP_WORKER_TOKEN: "qihuo-local-ctp", QIHUO_STARTUP_WORKERS: "8", QIHUO_MEMORY_MB: "8192", }, @@ -34,29 +32,5 @@ module.exports = { out_file: path.join(ROOT, "logs", "pm2-out.log"), time: true, }, - { - name: "qihuo-ctp", - script: "ctp_worker.py", - cwd: ROOT, - interpreter, - instances: 1, - autorestart: true, - watch: false, - max_memory_restart: "8192M", - env: { - NODE_ENV: "production", - LANG: "zh_CN.UTF-8", - LC_ALL: "zh_CN.UTF-8", - LC_CTYPE: "zh_CN.UTF-8", - QIHUO_CTP_ROLE: "worker", - QIHUO_CTP_WORKER_HOST: "127.0.0.1", - QIHUO_CTP_WORKER_PORT: "6601", - QIHUO_CTP_WORKER_TOKEN: "qihuo-local-ctp", - QIHUO_MEMORY_MB: "8192", - }, - error_file: path.join(ROOT, "logs", "pm2-ctp-error.log"), - out_file: path.join(ROOT, "logs", "pm2-ctp-out.log"), - time: true, - }, ], }; diff --git a/install_trading.py b/install_trading.py index 7750efd..cdbb6cb 100644 --- a/install_trading.py +++ b/install_trading.py @@ -1,4687 +1,6 @@ # Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md +"""Backward-compatible shim — implementation in modules.trading.install.""" -"""期货下单、可开仓品种、策略交易路由注册。""" -from __future__ import annotations +from modules.trading.install import install_trading -import json -import logging -import os -import threading -import time -from datetime import datetime -from typing import Any, Callable, Optional - -from flask import flash, jsonify, redirect, render_template, request, url_for, Response, stream_with_context - -from contract_specs import calc_position_metrics, get_contract_spec -from fee_specs import calc_fee_breakdown -from kline_stream import sse_format -from market_sessions import is_night_trading_session, is_trading_session, trading_session_clock -from position_sizing import ( - MODE_AMOUNT, - MODE_FIXED, - DEFAULT_MAX_ORDER_LOTS, - calc_lots_by_amount, - calc_lots_by_risk, - calc_margin_usage_pct, - cap_lots_for_margin_budget, - calc_order_tick_metrics, - normalize_sizing_mode, -) -from product_recommend import ( - assert_product_allowed_for_capital, - should_apply_small_account_scope, - small_account_margin_recommendations, - small_account_scope_hint, - SMALL_ACCOUNT_SCOPE_LABEL, -) -from recommend_store import ( - recommend_payload, - refresh_recommend_cache, -) -from recommend_stream import recommend_hub, schedule_recommend_refresh, start_recommend_worker -from position_stream import position_hub, start_position_worker -from ctp_settings import is_ctp_auto_connect_enabled -from ctp_reconnect import start_ctp_reconnect_worker -from ctp_premarket_connect import start_ctp_premarket_connect_worker -from ctp_fee_worker import start_ctp_fee_worker -from pending_order_worker import start_pending_order_worker -from order_pending import ( - cancel_pending_monitor, - pending_auto_cancel_remaining, - pending_monitor_has_live_order, - reconcile_pending_orders, -) -from db_conn import commit_retry, execute_retry -from sl_tp_guard import ( - cancel_monitor_exit_orders, - ensure_monitor_order_columns, - monitor_order_status, - monitor_source_label, - place_monitor_exit_orders, - reconcile_monitors_without_position, - start_sl_tp_guard_worker, - write_manual_close_trade_log, -) -from risk.account_risk_lib import ( - assert_can_open, - count_active_trade_monitors, - get_risk_status, - on_mood_journal_freeze, - on_user_initiated_close, - parse_mood_issues, - trading_day_label, -) -from strategy.strategy_db import init_strategy_tables -from strategy.strategy_roll_lib import ( - ADD_MODE_BREAKOUT, - ADD_MODE_MARKET, - FIB_MODES, - LEG_STATUS_CANCELLED, - LEG_STATUS_FILLED, - LEG_STATUS_PENDING, - PENDING_MODES, - add_mode_label, - avg_entry_after_add, - preview_roll, - roll_eligibility_error, -) -from strategy.strategy_roll_monitor_lib import ( - cancel_roll_leg, - check_roll_monitors, - roll_sync_after_external_close, -) -from strategy.strategy_snapshot_lib import list_snapshots, save_snapshot -from strategy.strategy_trend_lib import ( - compute_trend_plan_futures, - enrich_trend_plan_preview, - normalize_trend_period, - trend_dca_level_reached, - trend_period_label, - trend_strategy_periods, -) -from strategy.strategy_snapshot_lib import STRATEGY_ROLL, STRATEGY_TREND -from symbols import ths_to_codes, resolve_main_contract, PRODUCTS, PRODUCT_CATEGORIES, position_symbol_meta -from trading_context import ( - TRADING_MODE_LIVE, - TRADING_MODE_SIM, - get_account_capital, - get_fixed_amount, - get_fixed_lots, - get_max_margin_pct, - get_pending_order_timeout_min, - get_pending_order_timeout_sec, - get_recommend_capital, - get_roll_max_margin_pct, - get_risk_percent, - get_sizing_mode, - get_trailing_be_tick_buffer, - get_trading_mode, - is_ctp_connected, - trading_mode_label, -) -from ctp_entry_price import round_to_tick -from ctp_symbol import ths_to_vnpy_symbol -from ctp_trading_state import position_key, trading_state -from vnpy_bridge import ( - _ctp_td_lock, - ctp_cancel_order, - ctp_connect, - ctp_account_margin_used, - ctp_estimate_margin_one_lot, - ctp_get_account, - ctp_get_tick_price, - ctp_list_active_orders, - ctp_list_positions, - ctp_list_trades, - ctp_status, - execute_order, - get_bridge, - set_position_refresh_callback, - set_tick_sl_tp_callback, - set_tick_quote_callback, - set_ctp_connected_callback, -) - - -logger = logging.getLogger(__name__) - - -def install_trading(app, *, login_required, require_nav, get_db, get_setting, set_setting, fetch_price, send_wechat_msg): - """注册交易相关路由。""" - _nav = require_nav - _live_refresh_lock = threading.Lock() - _ctp_status_cache: dict = {"mode": "", "status": {}, "ts": 0.0} - _ctp_status_cache_lock = threading.Lock() - _ctp_status_refresh_flag = {"busy": False} - - def _remember_ctp_status(mode: str, st: dict) -> None: - if not isinstance(st, dict) or not st: - return - with _ctp_status_cache_lock: - _ctp_status_cache["mode"] = mode - _ctp_status_cache["status"] = dict(st) - _ctp_status_cache["ts"] = time.time() - - def _schedule_ctp_status_refresh(mode: str) -> None: - with _ctp_status_cache_lock: - if _ctp_status_refresh_flag["busy"]: - return - _ctp_status_refresh_flag["busy"] = True - - def _run() -> None: - try: - st = dict(ctp_status(mode) or {}) - _remember_ctp_status(mode, st) - snap = position_hub.get_snapshot() - if snap: - merged = dict(snap) - merged["ctp_status"] = st - position_hub.set_snapshot(merged) - except Exception as exc: - logger.debug("ctp status refresh: %s", exc) - finally: - with _ctp_status_cache_lock: - _ctp_status_refresh_flag["busy"] = False - - threading.Thread( - target=_run, - daemon=True, - name="ctp-status-refresh", - ).start() - - def _cached_ctp_status(mode: str) -> dict: - """页面/SSE 优先读快照与内存缓存,避免同步 worker IPC 阻塞 HTTP 线程。""" - try: - snap = position_hub.get_snapshot() or {} - st = snap.get("ctp_status") - if isinstance(st, dict) and st: - _remember_ctp_status(mode, st) - return dict(st) - except Exception: - pass - with _ctp_status_cache_lock: - if _ctp_status_cache["mode"] == mode and _ctp_status_cache["status"]: - return dict(_ctp_status_cache["status"]) - _schedule_ctp_status_refresh(mode) - return { - "connected": False, - "connecting": True, - "last_error": "", - "mode_label": trading_mode_label(get_setting), - } - - def _sizing_mode_label(mode: str) -> str: - m = normalize_sizing_mode(mode) - if m == MODE_AMOUNT: - return "固定金额" - return "固定手数" - - def _symbol_display_fields(sym: str) -> dict: - meta = position_symbol_meta(sym) - name = meta.get("name") or sym - return { - "symbol": name, - "symbol_name": name, - "symbol_exchange": meta.get("exchange") or "", - "symbol_is_main": bool(meta.get("is_main")), - } - - def _breakeven_locked( - *, - entry: Optional[float], - stop_loss: Optional[float], - direction: str, - tick_size: Optional[float] = None, - trailing_r_locked: int = 0, - ) -> bool: - if int(trailing_r_locked or 0) >= 1: - return True - if entry is None or stop_loss is None: - return False - try: - entry_f = float(entry) - sl_f = float(stop_loss) - except (TypeError, ValueError): - return False - if entry_f <= 0: - return False - tick = float(tick_size or 0) or max(abs(entry_f) * 1e-6, 0.01) - be_mult = max(1, get_trailing_be_tick_buffer(get_setting)) - d = (direction or "long").strip().lower() - expected_be = entry_f + be_mult * tick if d == "long" else entry_f - be_mult * tick - tol = be_mult * tick + tick * 0.05 - if abs(sl_f - expected_be) <= tol: - return True - buf = tick * max(2, be_mult) - near = abs(sl_f - entry_f) <= buf + tick - if d == "long": - return near and sl_f >= entry_f - tick * 0.05 - return near and sl_f <= entry_f + tick * 0.05 - - def _schedule_recommend_refresh() -> None: - from db_conn import DB_PATH - - schedule_recommend_refresh( - db_path=DB_PATH, - get_capital_fn=_recommend_capital, - quote_fn=_main_quote, - init_tables_fn=lambda c: init_strategy_tables(c), - get_mode_fn=lambda: get_trading_mode(get_setting), - get_max_margin_pct_fn=lambda: get_max_margin_pct(get_setting), - get_sizing_mode_fn=lambda: get_sizing_mode(get_setting), - get_fixed_lots_fn=lambda: get_fixed_lots(get_setting), - ) - - def _recommend_payload(conn, *, use_ctp_margin: bool = True) -> dict: - mode = get_trading_mode(get_setting) - return recommend_payload( - conn, - live_capital=_recommend_capital(conn), - max_margin_pct=get_max_margin_pct(get_setting), - trading_mode=mode, - sizing_mode=get_sizing_mode(get_setting), - fixed_lots=get_fixed_lots(get_setting), - use_ctp_margin=use_ctp_margin, - ) - - def _recommend_capital(conn) -> float: - return get_recommend_capital(conn, get_setting) - - def _settings_dict() -> dict: - return { - "trading_mode": get_trading_mode(get_setting), - "position_sizing_mode": get_sizing_mode(get_setting), - "risk_percent": str(get_risk_percent(get_setting)), - "max_margin_pct": str(get_max_margin_pct(get_setting)), - } - - def _capital(conn) -> float: - return get_account_capital(conn, get_setting) - - def _main_quote(product_ths: str) -> Optional[dict]: - for p in PRODUCTS: - if p["ths"] == product_ths: - main = resolve_main_contract(p) - if not main: - return None - sym = main.get("ths_code") or "" - codes = ths_to_codes(sym) - price = None - if codes: - price = fetch_price( - sym, - codes.get("market_code", ""), - codes.get("sina_code", ""), - ) - return { - "ths_code": sym, - "price": price, - "display": main.get("display") or sym, - "name": main.get("name") or p.get("name"), - } - return None - - def _ctp_account(mode: str) -> dict: - try: - return ctp_get_account(mode) - except Exception: - return {} - - def _ctp_positions( - mode: str, - *, - refresh_if_empty: bool = True, - refresh_margin: bool = False, - ) -> list: - try: - return ctp_list_positions( - mode, - refresh_if_empty=refresh_if_empty, - refresh_margin=refresh_margin, - ) - except Exception: - return [] - - def _ctp_pos_to_ths_code(p: dict) -> str: - sym = (p.get("symbol") or "").strip() - ex = (p.get("exchange") or "").strip() - if not sym: - return "" - codes = ths_to_codes(sym) - if codes: - return codes.get("ths_code") or sym - if ex: - from vnpy_bridge import CtpBridge - ths = CtpBridge._vnpy_sym_to_ths(sym, ex) - if ths: - return ths - return sym - - def _resolve_position_margin( - *, - sym: str, - direction: str, - lots: int, - entry: float, - mode: str, - ctp: Optional[dict] = None, - mon_margin: Optional[float] = None, - est_margin: Optional[float] = None, - ) -> tuple[Optional[float], str]: - """占用保证金:柜台持仓 > CTP 合约率估算 > 本地规格估算 > 库内缓存。""" - ctp_margin = float(ctp.get("margin") or 0) if ctp else 0.0 - if ctp_margin > 0: - return round(ctp_margin, 2), "ctp" - connected = bool(ctp_status(mode).get("connected")) - ths_sym = sym - if ctp: - ths_sym = _ctp_pos_to_ths_code(ctp) or sym - else: - codes = ths_to_codes(sym) - if codes and codes.get("ths_code"): - ths_sym = codes["ths_code"] - if connected and ths_sym and entry > 0 and lots > 0: - per_lot = ctp_estimate_margin_one_lot( - mode, ths_sym, entry, direction=direction, - ) - if per_lot and per_lot > 0: - return round(per_lot * lots, 2), "ctp" - if est_margin and float(est_margin) > 0: - return round(float(est_margin), 2), "estimate" - if not connected and mon_margin and float(mon_margin) > 0: - return round(float(mon_margin), 2), "db" - return None, "estimate" - - def _apply_account_margin_to_rows( - rows: list[dict], - mode: str, - capital: float, - ) -> list[dict]: - """仅在持仓缺少柜台保证金时补全;已有 CTP 持仓保证金的行不覆盖。""" - if not ctp_status(mode).get("connected"): - return rows - active = [ - r for r in rows - if r.get("order_state") != "pending" and int(r.get("lots") or 0) > 0 - ] - if not active: - return rows - - def _has_ctp_margin(row: dict) -> bool: - return ( - float(row.get("margin") or 0) > 0 - and row.get("margin_source") == "ctp" - ) - - without_margin = [r for r in active if not _has_ctp_margin(r)] - for row in active: - if _has_ctp_margin(row) and capital > 0: - m = float(row.get("margin") or 0) - row["position_pct"] = round(m / capital * 100, 2) - if not without_margin: - return rows - - total_used = ctp_account_margin_used(mode) - if not total_used: - return rows - known_sum = sum( - float(r.get("margin") or 0) for r in active if _has_ctp_margin(r) - ) - pool = max(0.0, float(total_used) - known_sum) if known_sum > 0 else float(total_used) - if pool <= 0: - return rows - - weights: list[float] = [] - for row in without_margin: - sym = (row.get("symbol_code") or "").strip() - lots = int(row.get("lots") or 0) - entry = float(row.get("entry_price") or 0) - if sym and lots > 0 and entry > 0: - spec = get_contract_spec(sym) - weights.append(entry * spec["mult"] * lots) - else: - weights.append(0.0) - total_weight = sum(weights) - assigned = 0.0 - for i, row in enumerate(without_margin): - if total_weight <= 0: - margin = round(pool / len(without_margin), 2) - elif i == len(without_margin) - 1: - margin = round(pool - assigned, 2) - else: - margin = round(pool * weights[i] / total_weight, 2) - assigned += margin - row["margin"] = margin - row["margin_source"] = "ctp" - if capital > 0: - row["position_pct"] = round(margin / capital * 100, 2) - return rows - - def _persist_ctp_snapshot_to_monitors( - conn, - rows: list[dict], - mode: str, - ) -> None: - """将柜台校正后的均价、手数、现价、浮盈、保证金等写入 trade_order_monitors。""" - if not ctp_status(mode).get("connected"): - return - ensure_monitor_order_columns(conn) - for row in rows: - mid = row.get("monitor_id") - if not mid or row.get("order_state") == "pending": - continue - entry_price = row.get("entry_price") - lots = row.get("lots") - mark_price = row.get("mark_price") - if mark_price is None: - mark_price = row.get("current_price") - float_pnl = row.get("float_pnl") - margin = row.get("margin") - position_pct = row.get("position_pct") - open_fee = row.get("est_fee") - if ( - entry_price is None and lots is None and mark_price is None - and float_pnl is None and margin is None - and position_pct is None and open_fee is None - ): - continue - try: - execute_retry( - conn, - """UPDATE trade_order_monitors SET - entry_price=COALESCE(?, entry_price), - lots=COALESCE(?, lots), - mark_price=COALESCE(?, mark_price), - float_pnl=COALESCE(?, float_pnl), - margin=COALESCE(?, margin), - position_pct=COALESCE(?, position_pct), - open_fee=COALESCE(?, open_fee) - WHERE id=? AND status='active'""", - ( - entry_price, lots, mark_price, float_pnl, - margin, position_pct, open_fee, int(mid), - ), - ) - except Exception as exc: - logger.debug("persist monitor ctp snapshot %s: %s", mid, exc) - - def _positions_from_live_snapshot() -> list[dict]: - snap = position_hub.get_snapshot() or {} - out: list[dict] = [] - for row in snap.get("rows") or []: - lots = int(row.get("lots") or 0) - if lots <= 0 or row.get("order_state") == "pending": - continue - sym = ( - row.get("symbol_code") - or row.get("ths_code") - or row.get("symbol") - or "" - ) - if not sym: - continue - out.append({ - "symbol": sym, - "direction": row.get("direction") or "long", - "lots": lots, - "avg_price": row.get("entry_price") or row.get("avg_price") or 0, - "open_time": row.get("open_time") or "", - "margin": row.get("margin"), - "pnl": row.get("float_pnl"), - "mark_price": row.get("mark_price") or row.get("current_price"), - "exchange": row.get("exchange") or "", - }) - return out - - def _positions_for_monitor_restore(mode: str, *, allow_ctp: bool = True) -> list[dict]: - if allow_ctp: - positions = list(_ctp_positions(mode, refresh_if_empty=True) or []) - if positions: - return positions - positions = list(trading_state.get_positions() or []) - if positions: - return positions - positions = _positions_from_live_snapshot() - if not allow_ctp: - return positions - margin_used = float(ctp_account_margin_used(mode) or 0) - if margin_used <= 100 or not positions: - return [] - return positions - - def _cached_position_mark(sym: str, direction: str = "") -> Optional[float]: - sym_l = (sym or "").strip().lower() - direction_l = (direction or "").strip().lower() - for p in list(trading_state.get_positions() or []) + _positions_from_live_snapshot(): - if direction_l and (p.get("direction") or "long").strip().lower() != direction_l: - continue - ps = (p.get("symbol") or "").strip() - if not ps: - continue - if not _match_ctp_symbol(ps, sym_l): - continue - for key in ("mark_price", "current_price", "last_price"): - val = p.get(key) - try: - px = float(val or 0) - except (TypeError, ValueError): - px = 0.0 - if px > 0: - return px - snap = position_hub.get_snapshot() or {} - for row in snap.get("rows") or []: - rs = row.get("symbol_code") or row.get("symbol") or "" - if not rs or not _match_ctp_symbol(rs, sym_l): - continue - if direction_l and (row.get("direction") or "long").strip().lower() != direction_l: - continue - for key in ("mark_price", "current_price", "last_price", "entry_price"): - try: - px = float(row.get(key) or 0) - except (TypeError, ValueError): - px = 0.0 - if px > 0: - return px - return None - - def _ensure_monitors_from_ctp(conn, mode: str, *, allow_ctp: bool = True) -> None: - """CTP 有持仓但本地无监控时,自动补写一条 active 记录供展示。""" - if not ctp_status(mode).get("connected"): - return - ctp_positions = _positions_for_monitor_restore(mode, allow_ctp=allow_ctp) - for p in ctp_positions: - lots = int(p.get("lots") or 0) - if lots <= 0: - continue - direction = p.get("direction") or "long" - ths = _ctp_pos_to_ths_code(p) - if not ths: - continue - existing = _find_or_revive_monitor(conn, ths, direction) - if existing: - _sync_monitor_from_ctp( - conn, int(existing["id"]), ths, direction, mode, ctp=p, - capital=_capital(conn), - ) - continue - sl, tp, trailing_be, initial_sl = _restore_sl_tp_from_closed(conn, ths, direction) - ctp_open = (p.get("open_time") or "").strip() - mid = _upsert_open_monitor( - conn, - sym=ths, - direction=direction, - lots=lots, - price=float(p.get("avg_price") or 0), - sl=sl, - tp=tp, - trailing_be=trailing_be, - ctp_open_time=ctp_open or None, - monitor_type="ctp_sync", - ) - if initial_sl is not None and sl is not None: - conn.execute( - "UPDATE trade_order_monitors SET initial_stop_loss=? WHERE id=?", - (initial_sl, mid), - ) - if ctp_positions: - return - _ensure_monitors_from_sticky_state(conn, mode) - - def _ensure_monitors_from_sticky_state(conn, mode: str) -> None: - """vnpy 持仓空窗但账户仍有保证金时,恢复本地 active 监控。""" - if not ctp_status(mode).get("connected"): - return - margin_raw = ctp_account_margin_used(mode) - if margin_raw is None or float(margin_raw or 0) <= 0: - return - if count_active_trade_monitors(conn) > 0: - return - capital = _capital(conn) - for p in trading_state.get_positions() or []: - lots = int(p.get("lots") or 0) - if lots <= 0: - continue - direction = p.get("direction") or "long" - ths = _ctp_pos_to_ths_code(p) or (p.get("symbol") or "") - if not ths: - continue - existing = _find_or_revive_monitor(conn, ths, direction) - if existing: - _sync_monitor_from_ctp( - conn, int(existing["id"]), ths, direction, mode, ctp=p, - capital=capital, - ) - continue - sl, tp, trailing_be, initial_sl = _restore_sl_tp_from_closed(conn, ths, direction) - mid = _upsert_open_monitor( - conn, - sym=ths, - direction=direction, - lots=lots, - price=float(p.get("avg_price") or 0), - sl=sl, - tp=tp, - trailing_be=trailing_be, - ctp_open_time=(p.get("open_time") or "").strip() or None, - monitor_type="ctp_sync", - ) - if initial_sl is not None and sl is not None: - conn.execute( - "UPDATE trade_order_monitors SET initial_stop_loss=? WHERE id=?", - (initial_sl, mid), - ) - if count_active_trade_monitors(conn) > 0: - return - today = datetime.now().strftime("%Y-%m-%d") - for r in conn.execute( - "SELECT * FROM trade_order_monitors WHERE status='closed' " - "AND open_time LIKE ? ORDER BY id DESC LIMIT 5", - (f"{today}%",), - ).fetchall(): - mon = dict(r) - if int(mon.get("lots") or 0) <= 0: - continue - revived = _revive_closed_monitor( - conn, mon.get("symbol") or "", mon.get("direction") or "long", - ) - if revived: - logger.info( - "保证金占用下恢复监控 id=%s sym=%s", - revived.get("id"), revived.get("symbol"), - ) - break - - def _restore_recent_pending_monitors(conn, mode: str) -> None: - """重启或 vnpy 委托缓存丢失时,恢复当日最近一笔可能仍有效的开仓挂单。""" - if not ctp_status(mode).get("connected"): - return - if conn.execute("SELECT 1 FROM trade_order_monitors WHERE status='pending' LIMIT 1").fetchone(): - return - today = datetime.now().strftime("%Y-%m-%d") - row = conn.execute( - """SELECT * FROM trade_order_monitors - WHERE status='closed' AND monitor_type='manual' - AND vt_order_id IS NOT NULL AND vt_order_id != '' - AND open_time LIKE ? - ORDER BY id DESC LIMIT 1""", - (f"{today}%",), - ).fetchone() - if not row: - return - mon = dict(row) - sym = mon.get("symbol") or "" - direction = (mon.get("direction") or "long").strip().lower() - if _find_active_monitor(conn, sym, direction): - return - for p in _ctp_positions(mode, refresh_if_empty=False): - if int(p.get("lots") or 0) <= 0: - continue - if (p.get("direction") or "long") != direction: - continue - if _match_ctp_symbol(p.get("symbol") or "", sym): - return - conn.execute( - "UPDATE trade_order_monitors SET status='pending' WHERE id=?", - (mon["id"],), - ) - logger.info("恢复挂单监控 id=%s sym=%s", mon.get("id"), sym) - - def _match_ctp_symbol(ctp_sym: str, ths: str) -> bool: - a = (ctp_sym or "").lower() - b = (ths or "").lower() - if a == b: - return True - if a and b and a.split(".")[0] == b.split(".")[0]: - return True - try: - vnpy_sym, _ = ths_to_vnpy_symbol(ths) - if a == vnpy_sym.lower(): - return True - except Exception: - pass - try: - vnpy_sym, _ = ths_to_vnpy_symbol(ctp_sym) - if vnpy_sym.lower() == b.split(".")[0]: - return True - except Exception: - pass - return False - - def _live_entry_price( - sym: str, - direction: str, - mode: str, - fallback: float = 0.0, - *, - allow_ctp: bool = False, - ) -> float: - """滚仓/展示用均价:仅柜台持仓价。""" - if not ctp_status(mode).get("connected"): - return fallback - positions = list(trading_state.get_positions() or []) - if not positions: - positions = _positions_from_live_snapshot() - if not positions and allow_ctp: - positions = _ctp_positions(mode, refresh_if_empty=False) - for p in positions: - if (p.get("direction") or "long") != (direction or "long"): - continue - if not _match_ctp_symbol(p.get("symbol") or "", sym): - continue - avg = float(p.get("avg_price") or 0) - if avg > 0: - return avg - return fallback - - def _resolve_ctp_entry_price( - mode: str, - sym: str, - direction: str, - ctp: Optional[dict], - ) -> tuple[float, str]: - del mode, direction - if not ctp: - return 0.0, "none" - avg = float(ctp.get("avg_price") or 0) - if avg > 0: - return round_to_tick(avg, sym), "ctp" - return 0.0, "none" - - def _open_commission_from_ctp_trades( - mode: str, sym: str, direction: str, - ) -> Optional[float]: - """汇总该持仓开仓成交的柜台手续费(成交回报中的 commission)。""" - if not ctp_status(mode).get("connected"): - return None - try: - trades = ctp_list_trades(mode) - except Exception: - return None - total = 0.0 - has_commission = False - for t in trades: - if (t.get("offset") or "").strip().lower() != "open": - continue - pos_dir = ( - t.get("position_direction") or t.get("direction") or "long" - ).strip().lower() - if pos_dir != (direction or "long").strip().lower(): - continue - if not _match_ctp_symbol(t.get("symbol") or "", sym): - continue - comm = float(t.get("commission") or 0) - total += comm - if comm > 0: - has_commission = True - return round(total, 2) if has_commission else None - - def _time_str(val) -> str: - if val is None: - return "" - if isinstance(val, str): - return val.strip() - return str(val).strip() - - def _holding_duration(open_time: str, now_iso: str) -> str: - try: - from app import calc_holding_duration - open_s = _time_str(open_time).replace("T", " ")[:19] - now_s = (now_iso or "").strip().replace("T", " ")[:19] - if not open_s or not now_s: - return "" - return calc_holding_duration(open_s, now_s) - except Exception: - return "" - - def _restore_sl_tp_from_closed(conn, sym: str, direction: str) -> tuple: - """重启后从最近关闭的同品种监控恢复止盈止损。""" - direction = (direction or "long").strip().lower() - for r in conn.execute( - "SELECT symbol, direction, stop_loss, take_profit, trailing_be, initial_stop_loss " - "FROM trade_order_monitors WHERE status='closed' ORDER BY id DESC LIMIT 80" - ).fetchall(): - row = dict(r) - if (row.get("direction") or "long") != direction: - continue - if not _match_ctp_symbol(sym, row.get("symbol") or ""): - continue - if row.get("stop_loss") is None and row.get("take_profit") is None: - continue - return ( - row.get("stop_loss"), - row.get("take_profit"), - int(row.get("trailing_be") or 0), - row.get("initial_stop_loss"), - ) - return None, None, 0, None - - def _restore_monitor_sl_tp_if_missing( - conn, - mon: Optional[dict], - sym: str, - direction: str, - ) -> Optional[dict]: - """活跃监控缺少止盈止损时,从最近关闭的同品种记录恢复并写回数据库。""" - if not mon: - return None - sl = mon.get("stop_loss") - tp = mon.get("take_profit") - trailing = int(mon.get("trailing_be") or 0) - if sl is not None or tp is not None or trailing: - return mon - rsl, rtp, rtrail, rinitial = _restore_sl_tp_from_closed(conn, sym, direction) - if rsl is None and rtp is None and not rtrail: - return mon - execute_retry( - conn, - """UPDATE trade_order_monitors SET - stop_loss=?, take_profit=?, trailing_be=?, initial_stop_loss=? - WHERE id=? AND status='active'""", - (rsl, rtp, rtrail, rinitial, int(mon["id"])), - ) - row = conn.execute( - "SELECT * FROM trade_order_monitors WHERE id=?", (int(mon["id"]),), - ).fetchone() - if row: - logger.info( - "恢复止盈止损 monitor=%s sym=%s sl=%s tp=%s", - mon.get("id"), sym, rsl, rtp, - ) - return dict(row) - return mon - - def _ctp_position_keys(mode: str) -> set[tuple[str, str]]: - keys: set[tuple[str, str]] = set() - for p in _ctp_positions(mode): - lots = int(p.get("lots") or 0) - if lots <= 0: - continue - sym = (p.get("symbol") or "").lower() - direction = p.get("direction") or "long" - keys.add((sym, direction)) - return keys - - def _monitor_matches_ctp_position(mon: dict, position_keys: set[tuple[str, str]]) -> bool: - ms = mon.get("symbol") or "" - md = mon.get("direction") or "long" - for ps, pd in position_keys: - if pd != md: - continue - if _match_ctp_symbol(ps, ms): - return True - return False - - def _sync_trade_monitors_with_ctp(conn, mode: str) -> int: - """关闭无对应 CTP 持仓的监控,并撤销残留止盈止损挂单。""" - return reconcile_monitors_without_position(conn, mode) - - def _effective_active_position_count( - conn, - mode: str, - *, - ctp_connected: Optional[bool] = None, - ) -> int: - """风控持仓数以柜台/快照实际持仓优先,本地监控作兜底。""" - monitor_count = count_active_trade_monitors(conn) - if ctp_connected is None: - ctp_connected = bool(_cached_ctp_status(mode).get("connected")) - if not ctp_connected: - return monitor_count - keys: set[tuple[str, str]] = set() - for p in _positions_for_monitor_restore(mode, allow_ctp=False): - lots = int(p.get("lots") or 0) - if lots <= 0: - continue - sym = ( - p.get("symbol") - or p.get("symbol_code") - or p.get("ths_code") - or "" - ).strip().lower() - direction = (p.get("direction") or "long").strip().lower() - if sym: - keys.add((sym, direction)) - return max(monitor_count, len(keys)) - - def _build_pending_orders(conn, mode: str) -> list[dict]: - pending: list[dict] = [] - for r in conn.execute( - "SELECT * FROM trade_order_monitors WHERE status='active' ORDER BY id DESC" - ).fetchall(): - mon = dict(r) - sym = mon.get("symbol") or "" - direction = mon.get("direction") or "long" - lots = int(mon.get("lots") or 0) - base = { - "symbol_code": sym, - "direction": direction, - "direction_label": "做多" if direction == "long" else "做空", - "lots": lots, - "source": "monitor", - "monitor_id": mon.get("id"), - **_symbol_display_fields(sym), - } - sl = mon.get("stop_loss") - tp = mon.get("take_profit") - if sl is not None: - pending.append({ - **base, - "order_kind": "stop_loss", - "label": "止损监控", - "price": float(sl), - }) - if tp is not None: - pending.append({ - **base, - "order_kind": "take_profit", - "label": "止盈监控", - "price": float(tp), - }) - for r in conn.execute( - "SELECT * FROM trade_order_monitors WHERE status='pending' ORDER BY id DESC" - ).fetchall(): - mon = dict(r) - sym = mon.get("symbol") or "" - pending.append({ - "symbol_code": sym, - "direction": mon.get("direction") or "long", - "direction_label": "做多" if (mon.get("direction") or "long") == "long" else "做空", - "lots": int(mon.get("lots") or 0), - "price": float(mon.get("order_price") or mon.get("entry_price") or 0), - "order_kind": "open_pending", - "label": "开仓挂单中", - "source": "monitor", - "monitor_id": mon.get("id"), - "can_cancel_order": is_trading_session(), - "cancel_allowed": is_trading_session(), - **_symbol_display_fields(sym), - }) - ctp_st = ctp_status(mode) - if ctp_st.get("connected"): - for o in _ctp_active_orders(mode): - sym = _ctp_pos_to_ths_code(o) or (o.get("symbol") or "") - offset_s = (o.get("offset") or "").upper() - kind = "limit" - label = "委托挂单" - if "CLOSE" in offset_s: - label = "平仓委托" - pending.append({ - "symbol_code": sym, - "symbol": _symbol_display_fields(sym).get("symbol_name") or sym, - "direction": o.get("direction") or "long", - "direction_label": "做多" if o.get("direction") == "long" else "做空", - "lots": int(o.get("lots") or 0), - "price": float(o.get("price") or 0), - "order_kind": kind, - "label": label, - "source": "ctp", - "order_id": o.get("order_id"), - "vt_order_id": o.get("vt_order_id") or o.get("order_id"), - "can_cancel_order": is_trading_session(), - "cancel_allowed": is_trading_session(), - **_symbol_display_fields(sym), - }) - return pending - - def _ctp_active_orders(mode: str) -> list: - try: - return ctp_list_active_orders(mode) - except Exception: - return [] - - def _canonical_position_key(symbol: str, direction: str, exchange: str = "") -> str: - sym = (symbol or "").strip() - d = (direction or "long").strip().lower() - ex = (exchange or "").strip().upper() - try: - vnpy_sym, ex2 = ths_to_vnpy_symbol(sym) - sym = vnpy_sym - if not ex: - ex = ex2 - except Exception: - sym = sym.lower() - return position_key(ex, sym, d) - - def _position_key_from_ctp(p: dict) -> str: - return position_key( - p.get("exchange") or "", - p.get("symbol") or "", - p.get("direction") or "long", - ) - - def _monitor_position_key(mon: dict, exchange: str = "") -> str: - sym = (mon.get("symbol") or "").strip() - d = (mon.get("direction") or "long").strip().lower() - ex = (exchange or "").strip().upper() - try: - vnpy_sym, ex2 = ths_to_vnpy_symbol(sym) - sym = vnpy_sym - if not ex: - ex = ex2 - except Exception: - sym = sym.lower() - return position_key(ex, sym, d) - - def _monitors_by_position_key(conn) -> dict[str, dict]: - ensure_monitor_order_columns(conn) - out: dict[str, dict] = {} - for r in conn.execute( - "SELECT * FROM trade_order_monitors WHERE status='active' ORDER BY id DESC" - ).fetchall(): - mon = dict(r) - pk = _monitor_position_key(mon) - if pk not in out: - out[pk] = mon - return out - - def _find_active_monitor(conn, symbol: str, direction: str) -> Optional[dict]: - direction = (direction or "long").strip().lower() - for r in conn.execute( - "SELECT * FROM trade_order_monitors WHERE status='active' ORDER BY id DESC" - ).fetchall(): - row = dict(r) - if (row.get("direction") or "long") != direction: - continue - if _match_ctp_symbol(symbol, row.get("symbol") or ""): - return row - return None - - def _find_pending_monitor(conn, symbol: str, direction: str) -> Optional[dict]: - """开仓委托 pending 仍带止损/移动保本元数据,需与 CTP 持仓关联展示。""" - direction = (direction or "long").strip().lower() - for r in conn.execute( - "SELECT * FROM trade_order_monitors WHERE status='pending' ORDER BY id DESC" - ).fetchall(): - row = dict(r) - if (row.get("direction") or "long") != direction: - continue - if _match_ctp_symbol(symbol, row.get("symbol") or ""): - return row - return None - - def _has_pending_monitors(conn) -> bool: - return bool( - conn.execute( - "SELECT 1 FROM trade_order_monitors WHERE status='pending' LIMIT 1" - ).fetchone() - ) - - def _overlay_sl_tp_readonly( - conn, - mon: Optional[dict], - sym: str, - direction: str, - ) -> Optional[dict]: - """只读:从已关闭监控补全止盈止损,不写库。""" - if not mon: - rsl, rtp, rtrail, rinitial = _restore_sl_tp_from_closed(conn, sym, direction) - if rsl is None and rtp is None and not rtrail: - return {"symbol": sym, "direction": direction} - return { - "symbol": sym, - "direction": direction, - "stop_loss": rsl, - "take_profit": rtp, - "trailing_be": rtrail, - "initial_stop_loss": rinitial, - } - sl = mon.get("stop_loss") - tp = mon.get("take_profit") - trailing = int(mon.get("trailing_be") or 0) - if sl is not None or tp is not None or trailing: - return mon - rsl, rtp, rtrail, rinitial = _restore_sl_tp_from_closed(conn, sym, direction) - if rsl is None and rtp is None and not rtrail: - return mon - merged = dict(mon) - merged["stop_loss"] = rsl - merged["take_profit"] = rtp - merged["trailing_be"] = rtrail - merged["initial_stop_loss"] = rinitial - return merged - - def _revive_closed_monitor(conn, symbol: str, direction: str) -> Optional[dict]: - """柜台仍有持仓但本地监控被误关时,恢复最近一条同品种记录。""" - direction = (direction or "long").strip().lower() - for r in conn.execute( - "SELECT * FROM trade_order_monitors WHERE status='closed' ORDER BY id DESC LIMIT 40" - ).fetchall(): - row = dict(r) - if (row.get("direction") or "long") != direction: - continue - if not _match_ctp_symbol(symbol, row.get("symbol") or ""): - continue - if int(row.get("lots") or 0) <= 0: - continue - execute_retry( - conn, - "UPDATE trade_order_monitors SET status='active' WHERE id=?", - (row["id"],), - ) - row["status"] = "active" - logger.info( - "恢复误关闭监控 id=%s sym=%s dir=%s", - row.get("id"), row.get("symbol"), direction, - ) - return row - return None - - def _find_or_revive_monitor(conn, symbol: str, direction: str) -> Optional[dict]: - active = _find_active_monitor(conn, symbol, direction) - if active: - return active - return _revive_closed_monitor(conn, symbol, direction) - - def _close_all_monitors_for_sym_dir(conn, symbol: str, direction: str) -> None: - direction = (direction or "long").strip().lower() - for r in conn.execute( - "SELECT id, symbol, direction FROM trade_order_monitors " - "WHERE status IN ('active', 'pending')" - ).fetchall(): - if (r["direction"] or "long") != direction: - continue - if _match_ctp_symbol(symbol, r["symbol"] or ""): - conn.execute( - "UPDATE trade_order_monitors SET status='closed' WHERE id=?", - (r["id"],), - ) - - def _close_duplicate_monitors(conn, symbol: str, direction: str, keep_id: int) -> None: - direction = (direction or "long").strip().lower() - for r in conn.execute( - "SELECT id, symbol, direction FROM trade_order_monitors WHERE status='active'" - ).fetchall(): - if int(r["id"]) == int(keep_id): - continue - if (r["direction"] or "long") != direction: - continue - if _match_ctp_symbol(symbol, r["symbol"] or ""): - conn.execute( - "UPDATE trade_order_monitors SET status='closed' WHERE id=?", - (r["id"],), - ) - - def _upsert_open_monitor( - conn, - *, - sym: str, - direction: str, - lots: int, - price: float, - sl, - tp, - trailing_be: int, - ctp_open_time: Optional[str] = None, - open_time: Optional[str] = None, - monitor_type: str = "manual", - status: str = "active", - vt_order_id: Optional[str] = None, - order_price: Optional[float] = None, - ) -> int: - ensure_monitor_order_columns(conn) - codes = ths_to_codes(sym) or {} - sl_f = float(sl) if sl not in (None, "") else None - tp_f = float(tp) if tp not in (None, "") else None - now_s = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - status_val = status if status in ("pending", "active") else "active" - order_px = float(order_price if order_price is not None else price) - existing = _find_active_monitor(conn, sym, direction) - if not existing: - for r in conn.execute( - "SELECT * FROM trade_order_monitors WHERE status='pending' ORDER BY id DESC" - ).fetchall(): - row = dict(r) - if (row.get("direction") or "long") != (direction or "long").strip().lower(): - continue - if _match_ctp_symbol(sym, row.get("symbol") or ""): - existing = row - break - if existing: - mid = int(existing["id"]) - existing_status = (existing.get("status") or "active").strip().lower() - if existing_status == "active" and status_val == "pending": - status_val = "active" - initial_sl = existing.get("initial_stop_loss") - if sl_f is None: - sl_f = float(existing["stop_loss"]) if existing.get("stop_loss") is not None else None - if tp_f is None: - tp_f = float(existing["take_profit"]) if existing.get("take_profit") is not None else None - if sl_f is not None and initial_sl is None: - initial_sl = sl_f - if not trailing_be: - trailing_be = int(existing.get("trailing_be") or 0) - open_time_val = (existing.get("open_time") or "").strip() or now_s - if open_time: - open_time_val = open_time - elif monitor_type == "ctp_sync" and ctp_open_time: - open_time_val = ctp_open_time - vt_val = vt_order_id or existing.get("vt_order_id") - conn.execute( - """UPDATE trade_order_monitors SET - symbol=?, symbol_name=?, market_code=?, lots=?, entry_price=?, - stop_loss=?, take_profit=?, initial_stop_loss=?, trailing_be=?, open_time=?, - monitor_type=?, status=?, vt_order_id=?, order_price=?, risk_percent=COALESCE(risk_percent, ?) - WHERE id=?""", - ( - sym, - codes.get("name", sym), - codes.get("market_code", ""), - lots, - price, - sl_f, - tp_f, - initial_sl, - trailing_be, - open_time_val, - monitor_type if monitor_type != "manual" else (existing.get("monitor_type") or "manual"), - status_val, - vt_val, - order_px, - get_risk_percent(get_setting), - mid, - ), - ) - else: - if open_time: - open_time_val = open_time - elif monitor_type == "ctp_sync" and ctp_open_time: - open_time_val = ctp_open_time - else: - open_time_val = now_s - conn.execute( - """INSERT INTO trade_order_monitors ( - symbol, symbol_name, market_code, direction, lots, entry_price, - stop_loss, take_profit, initial_stop_loss, trailing_be, - open_time, monitor_type, status, vt_order_id, order_price, risk_percent - ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""", - ( - sym, - codes.get("name", sym), - codes.get("market_code", ""), - direction, - lots, - price, - sl_f, - tp_f, - sl_f, - trailing_be, - open_time_val, - monitor_type, - status_val, - vt_order_id, - order_px, - get_risk_percent(get_setting), - ), - ) - mid = int(conn.execute("SELECT last_insert_rowid()").fetchone()[0]) - if status_val == "active": - _close_duplicate_monitors(conn, sym, direction, mid) - return mid - - def _sync_monitor_from_ctp( - conn, - mid: int, - sym: str, - direction: str, - mode: str, - *, - ctp: Optional[dict] = None, - capital: float = 0.0, - ) -> None: - """CTP 同步:均价、现价、保证金、仓位占比写入数据库;不覆盖期货下单的开仓时间。""" - positions = [ctp] if ctp else _ctp_positions(mode, refresh_if_empty=False, refresh_margin=True) - for p in positions: - if not p or int(p.get("lots") or 0) <= 0: - continue - if (p.get("direction") or "long") != direction: - continue - if not _match_ctp_symbol(p.get("symbol") or "", sym): - continue - row = conn.execute( - "SELECT open_time, monitor_type FROM trade_order_monitors WHERE id=?", (mid,), - ).fetchone() - db_open = (row["open_time"] or "").strip() if row else "" - monitor_type = (row["monitor_type"] or "manual").strip().lower() if row else "manual" - ctp_open = (p.get("open_time") or "").strip() or None - open_time_val = db_open - if monitor_type == "ctp_sync" and ctp_open: - open_time_val = ctp_open - lots = int(p.get("lots") or 0) - entry = float(p.get("avg_price") or 0) - ctp_margin = float(p.get("margin") or 0) - mark = None - if ctp_status(mode).get("connected"): - mark = ctp_get_tick_price(mode, sym) - if mark is None or mark <= 0: - mark = entry if entry else None - resolved_entry, _src = _resolve_ctp_entry_price( - mode, sym, direction, p, - ) - if resolved_entry > 0: - entry = resolved_entry - float_pnl = None - if mark and entry and lots > 0: - float_pnl = calc_position_metrics( - direction, entry, entry, entry, lots, mark, capital, sym, - ).get("float_pnl") - est = calc_position_metrics( - direction, entry, entry, entry, lots, mark or entry, capital, sym, - ).get("margin") - margin, _src = _resolve_position_margin( - sym=sym, - direction=direction, - lots=lots, - entry=entry, - mode=mode, - ctp=p, - est_margin=est, - ) - position_pct = None - if margin and capital > 0: - position_pct = round(float(margin) / float(capital) * 100, 2) - open_commission = _open_commission_from_ctp_trades(mode, sym, direction) - if open_commission is None: - fee_info = calc_fee_breakdown( - sym, entry, entry, lots, open_time_val or "", "", - trading_mode=mode, - ) - open_commission = fee_info.get("open_fee") - execute_retry( - conn, - """UPDATE trade_order_monitors SET lots=?, entry_price=?, - open_time=?, margin=?, position_pct=?, mark_price=?, float_pnl=?, - open_fee=? - WHERE id=?""", - ( - lots, - entry, - open_time_val, - margin, - position_pct, - float(mark) if mark else None, - float_pnl, - open_commission, - mid, - ), - ) - return - - def _sync_monitor_lots_from_ctp( - conn, mid: int, sym: str, direction: str, mode: str, *, ctp: Optional[dict] = None, - ) -> None: - _sync_monitor_from_ctp( - conn, mid, sym, direction, mode, ctp=ctp, capital=_capital(conn), - ) - - def _compose_position_row( - conn, - *, - mon: Optional[dict], - ctp: Optional[dict], - mode: str, - capital: float, - now_iso: str, - fast: bool = False, - ) -> Optional[dict]: - if not mon and not ctp: - return None - - if mon: - sym = (mon.get("symbol") or "").strip() - direction = mon.get("direction") or "long" - lots = int(mon.get("lots") or 0) - entry = float(mon.get("entry_price") or 0) - source_label = monitor_source_label(mon.get("monitor_type")) - open_time = _time_str(mon.get("open_time")) - open_time_source = "order" - margin = mon.get("margin") - position_pct = mon.get("position_pct") - mark = mon.get("mark_price") - float_pnl = mon.get("float_pnl") - if float_pnl is not None: - float_pnl = round(float(float_pnl), 2) - else: - sym = (ctp.get("symbol") or "").strip() - direction = ctp.get("direction") or "long" - lots = int(ctp.get("lots") or 0) - entry = float(ctp.get("avg_price") or 0) - source_label = "CTP 柜台" - open_time = _time_str(ctp.get("open_time")) - open_time_source = "ctp" - margin = None - position_pct = None - mark = None - float_pnl = None - - if lots <= 0: - return None - - if ctp: - ctp_lots = int(ctp.get("lots") or 0) - if ctp_lots > 0: - lots = ctp_lots - ths_sym = _ctp_pos_to_ths_code(ctp) or sym - resolved_entry, _entry_src = _resolve_ctp_entry_price( - mode, ths_sym, direction, ctp, - ) - if resolved_entry > 0: - entry = resolved_entry - elif float(ctp.get("avg_price") or 0) > 0: - entry = float(ctp.get("avg_price") or 0) - ctp_margin = float(ctp.get("margin") or 0) - if (margin is None or float(margin or 0) <= 0) and ctp_margin > 0: - margin = ctp_margin - if ctp_status(mode).get("connected"): - source_label = "CTP 柜台" - - codes = ths_to_codes(sym) - tick = calc_order_tick_metrics(sym, lots, entry, trading_mode=mode) - sl = float(mon["stop_loss"]) if mon and mon.get("stop_loss") is not None else None - tp = float(mon["take_profit"]) if mon and mon.get("take_profit") is not None else None - holding = _holding_duration(open_time, now_iso) if open_time else "" - - if ctp_status(mode).get("connected"): - live_mark = ctp_get_tick_price(mode, sym) - if live_mark and live_mark > 0: - mark = live_mark - elif (mark is None or float(mark or 0) <= 0) and not fast and codes: - mark = fetch_price( - sym, - codes.get("market_code", ""), - codes.get("sina_code", ""), - ) - if mark is None or mark <= 0: - mark = entry if entry else None - close_est = float(mark) if mark and mark > 0 else entry - if mark and entry and lots > 0: - pos_tmp = calc_position_metrics( - direction, entry, sl or entry, tp or entry, lots, mark, capital, sym, - ) - float_pnl = pos_tmp.get("float_pnl") - if ctp and ctp_status(mode).get("connected"): - ctp_pnl = float(ctp.get("pnl") or 0) - if ctp_pnl != 0: - float_pnl = round(ctp_pnl, 2) - - fee_info = calc_fee_breakdown( - sym, entry, close_est, lots, open_time or now_iso, now_iso, trading_mode=mode, - ) - open_commission = _open_commission_from_ctp_trades(mode, sym, direction) - if open_commission is None and mon and mon.get("open_fee") is not None: - cached_fee = float(mon.get("open_fee") or 0) - if cached_fee > 0: - open_commission = cached_fee - if open_commission is not None: - display_fee = open_commission - fee_source = "ctp" - else: - display_fee = fee_info["open_fee"] - fee_source = fee_info.get("fee_source") or "local" - est_net = None - if float_pnl is not None: - est_net = round(float(float_pnl) - fee_info["close_fee"], 2) - pos_metrics = calc_position_metrics( - direction, entry, sl if sl is not None else entry, - tp if tp is not None else entry, lots, mark, capital, sym, - ) - mon_margin = margin - margin, margin_source = _resolve_position_margin( - sym=sym, - direction=direction, - lots=lots, - entry=entry, - mode=mode, - ctp=ctp, - mon_margin=mon_margin if isinstance(mon_margin, (int, float)) else None, - est_margin=pos_metrics.get("margin"), - ) - if margin and capital > 0: - position_pct = round(float(margin) / float(capital) * 100, 2) - elif position_pct is None or float(position_pct or 0) <= 0: - position_pct = pos_metrics.get("position_pct") - elif position_pct is not None: - position_pct = float(position_pct) - order_st = monitor_order_status( - mon or {}, mode=mode, ths_code=sym, direction=direction, - ) - pending_for_row: list[dict] = [] - if sl is not None: - pending_for_row.append({ - "order_kind": "stop_loss", - "label": "止损监控", - "price": sl, - "lots": lots, - "source": "monitor", - "monitor_id": mon["id"] if mon else None, - }) - if tp is not None: - pending_for_row.append({ - "order_kind": "take_profit", - "label": "止盈监控", - "price": tp, - "lots": lots, - "source": "monitor", - "monitor_id": mon["id"] if mon else None, - }) - row_key = _canonical_position_key( - sym, direction, (ctp or {}).get("exchange") or "", - ) - return { - "key": row_key, - "position_key": row_key, - "source": "ctp", - "source_label": source_label, - "sync_pending": False, - "monitor_id": mon["id"] if mon else None, - "symbol_code": sym, - **_symbol_display_fields(sym), - "direction": direction, - "direction_label": "做多" if direction == "long" else "做空", - "lots": lots, - "entry_price": entry, - "stop_loss": sl, - "take_profit": tp, - "open_time": open_time or None, - "open_time_source": open_time_source or None, - "holding_duration": holding or None, - "mark_price": mark, - "current_price": mark, - "margin": margin, - "margin_source": margin_source, - "position_pct": position_pct, - "risk_amount": pos_metrics.get("risk_amount") if sl is not None else None, - "reward_amount": pos_metrics.get("reward_amount") if tp is not None else None, - "risk_pct": pos_metrics.get("risk_pct") if sl is not None else None, - "rr_ratio": pos_metrics.get("rr_ratio") if sl is not None and tp is not None else None, - "float_pnl": float_pnl, - "est_fee": display_fee, - "est_fee_open": display_fee, - "est_fee_close": fee_info["close_fee"], - "est_fee_close_type": fee_info["close_type"], - "fee_source": fee_source, - "est_pnl_net": est_net, - "sl_order_active": order_st.get("sl_monitoring"), - "tp_order_active": order_st.get("tp_monitoring"), - "sl_monitoring": order_st.get("sl_monitoring"), - "tp_monitoring": order_st.get("tp_monitoring"), - "can_place_orders": False, - "tick_value_total": tick.get("tick_value_total"), - "price_precision": tick.get("price_precision"), - "tick_size": tick.get("tick_size"), - "can_close": True, - "close_allowed": is_trading_session(), - "pending_orders": pending_for_row, - "trailing_be": bool(mon.get("trailing_be")) if mon else False, - "trailing_r_locked": int(mon.get("trailing_r_locked") or 0) if mon else 0, - "breakeven_locked": _breakeven_locked( - entry=entry, - stop_loss=sl, - direction=direction, - tick_size=tick.get("tick_size"), - trailing_r_locked=int(mon.get("trailing_r_locked") or 0) if mon else 0, - ), - } - - def _compose_pending_row( - mon: dict, - *, - mode: str, - capital: float, - now_iso: str, - ) -> Optional[dict]: - sym = (mon.get("symbol") or "").strip() - direction = (mon.get("direction") or "long").strip().lower() - lots = int(mon.get("lots") or 0) - if not sym or lots <= 0: - return None - order_price = float(mon.get("order_price") or mon.get("entry_price") or 0) - codes = ths_to_codes(sym) - sl = float(mon["stop_loss"]) if mon.get("stop_loss") is not None else None - tp = float(mon["take_profit"]) if mon.get("take_profit") is not None else None - pos_metrics = calc_position_metrics( - direction, order_price, sl or order_price, tp or order_price, lots, order_price, capital, sym, - ) - open_time = _time_str(mon.get("open_time")) - timeout_sec = get_pending_order_timeout_sec(get_setting) - remain = pending_auto_cancel_remaining(mon, timeout_sec=timeout_sec) - return { - "key": f"{_canonical_position_key(sym, direction)}:pending:{mon.get('id')}", - "order_state": "pending", - "source": "pending", - "source_label": "委托挂单中", - "sync_pending": True, - "monitor_id": mon.get("id"), - "symbol_code": sym, - **_symbol_display_fields(sym), - "direction": direction, - "direction_label": "做多" if direction == "long" else "做空", - "lots": lots, - "entry_price": order_price, - "order_price": order_price, - "stop_loss": sl, - "take_profit": tp, - "open_time": open_time or None, - "holding_duration": _holding_duration(open_time, now_iso) if open_time else None, - "mark_price": order_price, - "current_price": order_price, - "margin": pos_metrics.get("margin"), - "margin_source": "estimate", - "position_pct": pos_metrics.get("position_pct"), - "risk_amount": pos_metrics.get("risk_amount") if sl is not None else None, - "reward_amount": pos_metrics.get("reward_amount") if tp is not None else None, - "rr_ratio": pos_metrics.get("rr_ratio") if sl is not None and tp is not None else None, - "float_pnl": None, - "est_fee": None, - "can_close": False, - "close_allowed": False, - "can_cancel_order": is_trading_session(), - "cancel_allowed": is_trading_session(), - "auto_cancel_sec": remain, - "pending_timeout_sec": timeout_sec, - "pending_timeout_min": max(1, timeout_sec // 60), - "vt_order_id": mon.get("vt_order_id"), - "sl_order_active": False, - "tp_order_active": False, - "sl_monitoring": bool(sl is not None), - "tp_monitoring": bool(tp is not None), - "can_place_orders": False, - "pending_orders": [], - "trailing_be": bool(mon.get("trailing_be")), - "trailing_r_locked": int(mon.get("trailing_r_locked") or 0), - } - - def _compose_ctp_open_order_row( - o: dict, - *, - mode: str, - capital: float, - now_iso: str, - ) -> Optional[dict]: - offset_u = (o.get("offset") or "").upper() - if offset_u and "OPEN" not in offset_u: - return None - sym = _ctp_pos_to_ths_code(o) or (o.get("symbol") or "").strip() - direction = (o.get("direction") or "long").strip().lower() - lots = int(o.get("lots") or 0) - if not sym or lots <= 0: - return None - order_price = float(o.get("price") or 0) - pos_metrics = calc_position_metrics( - direction, order_price, order_price, order_price, lots, order_price, capital, sym, - ) - timeout_sec = get_pending_order_timeout_sec(get_setting) - return { - "key": f"{_canonical_position_key(sym, direction)}:pending:ctp:{o.get('order_id') or ''}", - "order_state": "pending", - "source": "ctp", - "source_label": "委托挂单", - "sync_pending": True, - "monitor_id": None, - "order_id": o.get("order_id"), - "vt_order_id": o.get("vt_order_id") or o.get("order_id"), - "symbol_code": sym, - **_symbol_display_fields(sym), - "direction": direction, - "direction_label": "做多" if direction == "long" else "做空", - "lots": lots, - "entry_price": order_price, - "order_price": order_price, - "stop_loss": None, - "take_profit": None, - "open_time": now_iso, - "holding_duration": None, - "mark_price": order_price, - "current_price": order_price, - "margin": pos_metrics.get("margin"), - "margin_source": "estimate", - "position_pct": pos_metrics.get("position_pct"), - "float_pnl": None, - "est_fee": None, - "can_close": False, - "close_allowed": False, - "can_cancel_order": is_trading_session(), - "cancel_allowed": is_trading_session(), - "pending_timeout_sec": timeout_sec, - "pending_timeout_min": max(1, timeout_sec // 60), - "sl_order_active": False, - "tp_order_active": False, - "sl_monitoring": False, - "tp_monitoring": False, - "can_place_orders": False, - "pending_orders": [], - "trailing_be": False, - "trailing_r_locked": 0, - } - - def _reconcile_pending(conn, mode: str, *, capital: float = 0.0) -> dict[str, int]: - return reconcile_pending_orders( - conn, - mode, - match_symbol_fn=_match_ctp_symbol, - sync_monitor_fn=_sync_monitor_from_ctp, - capital=capital, - list_positions_fn=_ctp_positions, - timeout_sec=get_pending_order_timeout_sec(get_setting), - ) - - def _build_active_orders( - conn, - *, - mode: str, - capital: float, - now_iso: str, - ) -> list[dict]: - """当前委托:CTP 已连接时读柜台;未连接时不展示本地 pending。""" - orders: list[dict] = [] - seen_keys: set[str] = set() - connected = ctp_status(mode).get("connected") - - if connected: - ctp_orders = trading_state.get_active_orders() - if not ctp_orders: - ctp_orders = _ctp_active_orders(mode) - for o in ctp_orders: - try: - row = _compose_ctp_open_order_row( - o, mode=mode, capital=capital, now_iso=now_iso, - ) - if not row: - row = _compose_ctp_order_row_any( - o, mode=mode, capital=capital, now_iso=now_iso, - ) - if row: - orders.append(row) - seen_keys.add(row.get("key") or "") - except Exception as exc: - logger.warning("compose ctp order row failed: %s", exc) - - ctp_active_map: dict[str, dict] = {} - for o in ctp_orders or []: - for key in (o.get("order_id"), o.get("vt_order_id")): - if key: - ctp_active_map[str(key)] = o - - for r in conn.execute( - "SELECT * FROM trade_order_monitors WHERE status='pending' ORDER BY id DESC" - ).fetchall(): - mon = dict(r) - try: - if not pending_monitor_has_live_order( - mon, - active_orders=ctp_active_map, - active_order_list=ctp_orders or [], - ): - continue - prow = _compose_pending_row( - mon, mode=mode, capital=capital, now_iso=now_iso, - ) - if prow and prow.get("key") not in seen_keys: - pk = f"{prow.get('symbol_code') or ''}:{prow.get('direction') or ''}" - dup = any( - (x.get("symbol_code") or "") + ":" + (x.get("direction") or "") == pk - and x.get("order_state") == "pending" - for x in orders - ) - if not dup: - orders.append(prow) - except Exception as exc: - logger.warning("compose pending order row failed: %s", exc) - return orders - - def _compose_ctp_order_row_any( - o: dict, - *, - mode: str, - capital: float, - now_iso: str, - ) -> Optional[dict]: - """CTP 任意未成交委托(含平仓)。""" - sym = _ctp_pos_to_ths_code(o) or (o.get("symbol") or "").strip() - direction = (o.get("direction") or "long").strip().lower() - lots = int(o.get("lots") or 0) - if not sym or lots <= 0: - return None - offset_u = (o.get("offset") or "").upper() - is_open = not offset_u or "OPEN" in offset_u - order_price = float(o.get("price") or 0) - pos_metrics = calc_position_metrics( - direction, order_price, order_price, order_price, lots, order_price, capital, sym, - ) - label = "开仓委托" if is_open else "平仓委托" - timeout_sec = get_pending_order_timeout_sec(get_setting) - ex = o.get("exchange") or "" - pk = _canonical_position_key(sym, direction, ex) - return { - "key": f"{pk}:order:{o.get('order_id') or ''}", - "order_state": "pending", - "source": "ctp", - "source_label": label, - "sync_pending": False, - "monitor_id": None, - "order_id": o.get("order_id"), - "vt_order_id": o.get("vt_order_id") or o.get("order_id"), - "symbol_code": sym, - **_symbol_display_fields(sym), - "direction": direction, - "direction_label": "做多" if direction == "long" else "做空", - "lots": lots, - "entry_price": order_price, - "order_price": order_price, - "stop_loss": None, - "take_profit": None, - "open_time": now_iso, - "mark_price": order_price, - "current_price": order_price, - "margin": pos_metrics.get("margin"), - "margin_source": "estimate", - "position_pct": pos_metrics.get("position_pct"), - "float_pnl": None, - "can_close": False, - "close_allowed": False, - "can_cancel_order": is_trading_session(), - "cancel_allowed": is_trading_session(), - "pending_timeout_sec": timeout_sec if is_open else None, - "pending_timeout_min": max(1, timeout_sec // 60) if is_open else None, - "sl_order_active": False, - "tp_order_active": False, - "sl_monitoring": False, - "tp_monitoring": False, - "can_place_orders": False, - "pending_orders": [], - "trailing_be": False, - "trailing_r_locked": 0, - } - - def _build_trading_live_rows(conn, *, fast: bool = False) -> list[dict]: - """当前持仓:以 CTP 为准,SQLite 仅叠加 SL/TP 元数据。""" - from zoneinfo import ZoneInfo - tz = ZoneInfo("Asia/Shanghai") - now_iso = datetime.now(tz).strftime("%Y-%m-%dT%H:%M") - mode = get_trading_mode(get_setting) - capital = _capital(conn) - - ctp_list: list[dict] = [] - if ctp_status(mode).get("connected"): - merged: dict[str, dict] = {} - for p in list(_ctp_positions(mode) or []) + list(trading_state.get_positions() or []): - lots = int(p.get("lots") or 0) - if lots <= 0: - continue - pk = p.get("position_key") or _position_key_from_ctp(p) - merged[pk] = p - ctp_list = list(merged.values()) - - ensure_monitor_order_columns(conn) - monitor_by_pk = _monitors_by_position_key(conn) - - rows: list[dict] = [] - for p in ctp_list: - lots = int(p.get("lots") or 0) - if lots <= 0: - continue - pk = p.get("position_key") or _position_key_from_ctp(p) - mon = monitor_by_pk.get(pk) - if not mon: - for mk, mv in monitor_by_pk.items(): - if (mv.get("direction") or "long") != (p.get("direction") or "long"): - continue - if _match_ctp_symbol(p.get("symbol") or "", mv.get("symbol") or ""): - mon = mv - break - ths = _ctp_pos_to_ths_code(p) or (p.get("symbol") or "") - direction = p.get("direction") or "long" - if not mon: - mon = _find_pending_monitor(conn, ths, direction) - if not mon: - if fast: - mon = _find_active_monitor(conn, ths, direction) - else: - mon = _find_or_revive_monitor(conn, ths, direction) - if mon: - if fast: - mon = _overlay_sl_tp_readonly(conn, mon, ths, direction) or mon - else: - mon = _restore_monitor_sl_tp_if_missing(conn, mon, ths, direction) or mon - _sync_monitor_from_ctp( - conn, int(mon["id"]), mon.get("symbol") or ths, - mon.get("direction") or direction, - mode, ctp=p, capital=capital, - ) - mon = _find_active_monitor( - conn, mon.get("symbol") or ths, mon.get("direction") or direction, - ) or mon - mon = _restore_monitor_sl_tp_if_missing(conn, mon, ths, direction) or mon - elif fast: - mon = _overlay_sl_tp_readonly(conn, None, ths, direction) - try: - row = _compose_position_row( - conn, mon=mon, ctp=p, mode=mode, capital=capital, - now_iso=now_iso, fast=fast, - ) - if row: - rows.append(row) - except Exception as exc: - logger.warning("compose ctp position row failed: %s", exc) - - seen: set[str] = set() - deduped: list[dict] = [] - for row in rows: - rk = row.get("key") or row.get("position_key") or "" - if rk in seen: - continue - seen.add(rk) - deduped.append(row) - - if not deduped and ctp_status(mode).get("connected"): - margin_raw = ctp_account_margin_used(mode) - margin_used = float(margin_raw or 0) if margin_raw is not None else 0.0 - has_margin_hint = margin_raw is not None and margin_used > 0 - has_active_mon = any( - int(m.get("lots") or 0) > 0 for m in monitor_by_pk.values() - ) - since_connect = 9999.0 - try: - since_connect = time.time() - float( - getattr(get_bridge(), "_last_connect_ok_ts", 0) or 0, - ) - except Exception: - pass - if has_margin_hint or has_active_mon or since_connect < 300: - if not monitor_by_pk and has_margin_hint: - _ensure_monitors_from_sticky_state(conn, mode) - monitor_by_pk = _monitors_by_position_key(conn) - for mon in monitor_by_pk.values(): - lots = int(mon.get("lots") or 0) - if lots <= 0: - continue - sym = (mon.get("symbol") or "").strip() - direction = (mon.get("direction") or "long").strip().lower() - if fast: - mon = _overlay_sl_tp_readonly(conn, mon, sym, direction) or mon - else: - mon = ( - _restore_monitor_sl_tp_if_missing(conn, mon, sym, direction) - or mon - ) - try: - row = _compose_position_row( - conn, - mon=mon, - ctp=None, - mode=mode, - capital=capital, - now_iso=now_iso, - fast=fast, - ) - if not row: - continue - rk = row.get("key") or row.get("position_key") or "" - if rk and rk in seen: - continue - if rk: - seen.add(rk) - deduped.append(row) - except Exception as exc: - logger.warning("compose monitor fallback row failed: %s", exc) - - if not deduped and ctp_status(mode).get("connected"): - for r in conn.execute( - "SELECT * FROM trade_order_monitors WHERE status='active' ORDER BY id DESC" - ).fetchall(): - mon = dict(r) - lots = int(mon.get("lots") or 0) - if lots <= 0: - continue - sym = (mon.get("symbol") or "").strip() - direction = (mon.get("direction") or "long").strip().lower() - rk = _monitor_position_key(mon) - if rk in seen: - continue - if fast: - mon = _overlay_sl_tp_readonly(conn, mon, sym, direction) or mon - try: - row = _compose_position_row( - conn, - mon=mon, - ctp=None, - mode=mode, - capital=capital, - now_iso=now_iso, - fast=fast, - ) - if not row: - continue - row_key = row.get("key") or row.get("position_key") or rk - if row_key in seen: - continue - seen.add(row_key) - deduped.append(row) - except Exception as exc: - logger.warning("compose active monitor row failed: %s", exc) - - return deduped - - def _build_trading_live_payload(conn, *, fast: bool = False) -> dict: - from zoneinfo import ZoneInfo - tz = ZoneInfo("Asia/Shanghai") - now_iso = datetime.now(tz).strftime("%Y-%m-%dT%H:%M") - mode = get_trading_mode(get_setting) - ctp_st = ctp_status(mode) - _remember_ctp_status(mode, ctp_st) - capital = _capital(conn) - if ctp_st.get("connected") and not fast: - _reconcile_pending(conn, mode, capital=capital) - if ctp_st.get("connected"): - if not fast: - _ensure_monitors_from_ctp(conn, mode) - _sync_trade_monitors_with_ctp(conn, mode) - elif count_active_trade_monitors(conn) == 0: - margin_raw = ctp_account_margin_used(mode) - if margin_raw is not None and float(margin_raw) > 0: - _ensure_monitors_from_sticky_state(conn, mode) - if not fast: - _close_stale_roll_groups(conn) - rows = _build_trading_live_rows(conn, fast=fast) - active_orders = _build_active_orders( - conn, mode=mode, capital=capital, now_iso=now_iso, - ) - rows = _apply_account_margin_to_rows(rows, mode, capital) - if not fast: - _persist_ctp_snapshot_to_monitors(conn, rows, mode) - pending_orders = _build_pending_orders(conn, mode) - risk = get_risk_status( - conn, - active_count=_effective_active_position_count(conn, mode), - equity=capital, - ) - margin_used = ( - ctp_account_margin_used(mode) if ctp_st.get("connected") else None - ) - display_sync_state = "ready" if rows else trading_state.sync_state - display_sync_label = "已同步" if rows else trading_state.sync_label() - return { - "ok": True, - "rows": rows, - "active_orders": active_orders, - "pending_orders": pending_orders, - "capital": capital, - "margin_used": margin_used, - "ctp_status": ctp_st, - "trading_mode_label": trading_mode_label(get_setting), - "risk_status": risk, - "trading_session": is_trading_session(), - "night_session": is_night_trading_session(), - "session_clock": trading_session_clock(), - "pending_order_timeout_min": get_pending_order_timeout_min(get_setting), - "sync_state": display_sync_state, - "sync_label": display_sync_label, - } - - def _minimal_live_payload(conn) -> dict: - """零 IPC 兜底:仅读库 + 缓存 CTP 状态,持仓由后台 worker 补全。""" - mode = get_trading_mode(get_setting) - ctp_st = _cached_ctp_status(mode) - capital = _capital(conn) - risk = get_risk_status( - conn, - active_count=count_active_trade_monitors(conn), - equity=capital, - ) - syncing = bool(ctp_st.get("connected") or ctp_st.get("connecting")) - return { - "ok": True, - "rows": [], - "active_orders": [], - "pending_orders": [], - "capital": capital, - "ctp_status": ctp_st, - "trading_mode_label": trading_mode_label(get_setting), - "risk_status": risk, - "trading_session": is_trading_session(), - "night_session": is_night_trading_session(), - "session_clock": trading_session_clock(), - "pending_order_timeout_min": get_pending_order_timeout_min(get_setting), - "sync_state": "syncing" if syncing else trading_state.sync_state, - "sync_label": "加载中…" if syncing else trading_state.sync_label(), - } - - def _normalize_live_payload(payload: dict) -> dict: - if payload.get("rows"): - payload = dict(payload) - payload["sync_state"] = "ready" - payload["sync_label"] = "已同步" - return payload - - def _refresh_trading_live_snapshot(*, fast: bool = False) -> dict: - def _build() -> dict: - mode = get_trading_mode(get_setting) - if ctp_status(mode).get("connected") and not fast: - try: - with _ctp_td_lock: - get_bridge().calibrate_trading_state() - except Exception as exc: - logger.debug("refresh calibrate: %s", exc) - for p in trading_state.get_positions() or _ctp_positions(mode, refresh_if_empty=False): - ths = _ctp_pos_to_ths_code(p) - if ths: - try: - get_bridge().subscribe_symbol(ths) - except Exception: - pass - conn = get_db() - try: - init_strategy_tables(conn) - if not fast: - ensure_monitor_order_columns(conn, migrate=True) - payload = _build_trading_live_payload(conn, fast=fast) - commit_retry(conn) - prev = position_hub.get_snapshot() - active_n = int((payload.get("risk_status") or {}).get("active_count") or 0) - if ( - prev - and ctp_status(mode).get("connected") - and not (payload.get("rows") or []) - and (prev.get("rows") or []) - ): - margin_raw = payload.get("margin_used") - if margin_raw is None: - margin_raw = ctp_account_margin_used(mode) - margin_used_val = float(margin_raw or 0) if margin_raw is not None else 0.0 - if ( - (margin_raw is not None and margin_used_val > 0) - or trading_state.sync_state == "syncing" - or active_n > 0 - ): - payload = dict(payload) - payload["rows"] = prev["rows"] - if trading_state.sync_state == "syncing": - payload["sync_state"] = "syncing" - payload["sync_label"] = "同步中…" - elif ( - ctp_status(mode).get("connected") - and not (payload.get("rows") or []) - and active_n > 0 - ): - payload = dict(payload) - payload["rows"] = _build_trading_live_rows(conn, fast=fast) - elif ctp_status(mode).get("connected") and not (payload.get("rows") or []): - since_connect = time.time() - float( - getattr(get_bridge(), "_last_connect_ok_ts", 0) or 0, - ) - if since_connect < 180: - payload = dict(payload) - payload["sync_state"] = "syncing" - payload["sync_label"] = "持仓同步中…" - return _normalize_live_payload(payload) - finally: - conn.close() - - if fast: - snap = position_hub.get_snapshot() - if snap: - return snap - if _live_refresh_lock.acquire(blocking=False): - try: - return _build() - finally: - _live_refresh_lock.release() - conn = get_db() - try: - init_strategy_tables(conn) - return _minimal_live_payload(conn) - finally: - conn.close() - with _live_refresh_lock: - return _build() - - def _push_position_snapshot_async(*, fast: bool = True) -> None: - def _run() -> None: - try: - payload = _refresh_trading_live_snapshot(fast=fast) - position_hub.broadcast("positions", payload) - conn = get_db() - try: - rec = _recommend_payload(conn) - recommend_hub.broadcast("recommend", {"ok": True, **rec}) - finally: - conn.close() - except Exception as exc: - logger.debug("push position snapshot: %s", exc) - - threading.Thread(target=_run, daemon=True).start() - - def _build_position_quotes_payload(mode: str) -> dict: - """轻量现价/浮盈(仅读 tick 缓存,不走 SQLite)。""" - if not ctp_status(mode).get("connected"): - return {"ok": True, "quotes": []} - from contract_specs import get_contract_spec - - positions = trading_state.get_positions() - if not positions: - positions = _ctp_positions(mode, refresh_if_empty=False) - quotes: list[dict] = [] - for p in positions: - lots = int(p.get("lots") or 0) - if lots <= 0: - continue - ths = _ctp_pos_to_ths_code(p) or (p.get("symbol") or "") - if not ths: - continue - direction = (p.get("direction") or "long").strip().lower() - mark = ctp_get_tick_price(mode, ths) - if not mark or mark <= 0: - continue - entry, _ = _resolve_ctp_entry_price( - mode, ths, direction, p, - ) - if entry <= 0: - continue - mult = float(get_contract_spec(ths).get("mult") or 10) - ctp_pnl = float(p.get("pnl") or 0) - if ctp_pnl != 0: - float_pnl = round(ctp_pnl, 2) - elif direction == "long": - float_pnl = round((mark - entry) * mult * lots, 2) - else: - float_pnl = round((entry - mark) * mult * lots, 2) - row_key = _canonical_position_key( - ths, direction, (p.get("exchange") or ""), - ) - quotes.append({ - "key": row_key, - "position_key": row_key, - "mark_price": mark, - "current_price": mark, - "float_pnl": float_pnl, - }) - return {"ok": True, "quotes": quotes} - - def _push_position_quotes_async() -> None: - def _run() -> None: - try: - if not is_trading_session(): - return - mode = get_trading_mode(get_setting) - if trading_state.try_lock_entry_prices(): - _push_position_snapshot_async(fast=False) - return - payload = _build_position_quotes_payload(mode) - if payload.get("quotes"): - position_hub.push_event("position_quotes", payload) - except Exception as exc: - logger.debug("push position quotes: %s", exc) - - threading.Thread(target=_run, daemon=True, name="position-quotes").start() - - def _on_tick_sl_tp(exchange: str, symbol: str, price: float) -> None: - from sl_tp_guard import check_sl_tp_on_tick - from db_conn import DB_PATH, connect_db - - mode = get_trading_mode(get_setting) - if not ctp_status(mode).get("connected"): - return - conn = connect_db(DB_PATH) - try: - _init_tables(conn) - capital = _capital(conn) - n = check_sl_tp_on_tick( - conn, mode, exchange, symbol, price, - capital=capital, notify_fn=send_wechat_msg, - be_tick_mult=get_trailing_be_tick_buffer(get_setting), - ) - if n: - conn.commit() - _push_position_snapshot_async(fast=True) - except Exception as exc: - logger.debug("tick sl/tp: %s", exc) - finally: - conn.close() - - def _prime_position_snapshot() -> None: - """进程启动同步预热:优先写入持仓/权益快照,页面打开即可读。""" - try: - payload = _refresh_trading_live_snapshot(fast=True) - position_hub.set_snapshot(payload) - n = len(payload.get("rows") or []) - logger.info( - "持仓快照已预热 capital=%s rows=%d", - payload.get("capital"), - n, - ) - except Exception as exc: - logger.warning("prime position snapshot: %s", exc) - - def _bootstrap_trading_runtime() -> None: - """进程启动:并发预热持仓快照 + CTP 连接,不阻塞 HTTP 监听。""" - set_position_refresh_callback( - lambda: _push_position_snapshot_async(fast=True) - ) - set_tick_quote_callback(_push_position_quotes_async) - set_tick_sl_tp_callback(_on_tick_sl_tp) - set_ctp_connected_callback(_on_ctp_connected) - - def _warm() -> None: - try: - payload = _refresh_trading_live_snapshot(fast=True) - position_hub.set_snapshot(payload) - position_hub.broadcast("positions", payload) - mode = get_trading_mode(get_setting) - if ctp_status(mode).get("connected"): - try: - with _ctp_td_lock: - get_bridge().calibrate_trading_state() - get_bridge().request_position_snapshot(force=True) - except Exception as exc: - logger.debug("bootstrap calibrate: %s", exc) - payload = _refresh_trading_live_snapshot(fast=True) - position_hub.set_snapshot(payload) - position_hub.broadcast("positions", payload) - - def _slow_sync() -> None: - time.sleep(20) - try: - pl = _refresh_trading_live_snapshot(fast=False) - position_hub.set_snapshot(pl) - position_hub.broadcast("positions", pl) - except Exception as exc: - logger.warning("bootstrap slow sync: %s", exc) - - threading.Thread(target=_slow_sync, daemon=True, name="boot-slow-sync").start() - except Exception as exc: - logger.warning("bootstrap position snapshot: %s", exc) - - def _start_ctp() -> None: - try: - from ctp_premarket_connect import should_auto_connect_now - from vnpy_bridge import ctp_start_connect - - if should_auto_connect_now(): - mode = get_trading_mode(get_setting) - ctp_start_connect(mode, force=False, scheduled=True) - except Exception as exc: - logger.debug("bootstrap ctp connect: %s", exc) - - from concurrent.futures import ThreadPoolExecutor - - workers = max(2, int(os.getenv("QIHUO_STARTUP_WORKERS", "8") or 8)) - with ThreadPoolExecutor(max_workers=min(workers, 4), thread_name_prefix="boot") as pool: - pool.submit(_warm) - pool.submit(_start_ctp) - - def _on_ctp_connected(mode: str) -> None: - if mode != get_trading_mode(get_setting): - return - _schedule_recommend_refresh() - _push_position_snapshot_async(fast=True) - - def _after_connect() -> None: - try: - try: - with _ctp_td_lock: - get_bridge().request_position_snapshot(force=True) - get_bridge().calibrate_trading_state() - except Exception as exc: - logger.debug("ctp connected calibrate: %s", exc) - _push_position_snapshot_async(fast=True) - conn = get_db() - try: - init_strategy_tables(conn) - _ensure_monitors_from_ctp(conn, mode) - commit_retry(conn) - finally: - conn.close() - _push_position_snapshot_async(fast=False) - except Exception as exc: - logger.debug("ctp connected monitor restore: %s", exc) - - threading.Thread(target=_after_connect, daemon=True, name="ctp-monitor-restore").start() - - @app.route("/trade") - @login_required - def trade_page(): - return redirect(url_for("positions")) - - @app.route("/positions") - @login_required - def positions(): - conn = get_db() - try: - init_strategy_tables(conn) - mode = get_trading_mode(get_setting) - ctp_st = _cached_ctp_status(mode) - connected = bool(ctp_st.get("connected")) - capital = _capital(conn) - recommend_capital = _recommend_capital(conn) - risk = get_risk_status( - conn, - active_count=_effective_active_position_count( - conn, mode, ctp_connected=connected, - ), - equity=capital, - ) - ctp_acc = {} - bootstrap_live = position_hub.get_snapshot() - if connected and bootstrap_live and bootstrap_live.get("capital") is not None: - cap = float(bootstrap_live.get("capital") or 0) - ctp_acc = {"balance": cap, "available": cap} - active_trend = conn.execute( - "SELECT * FROM trend_pullback_plans WHERE status='active' ORDER BY id DESC LIMIT 1" - ).fetchone() - monitor_count = conn.execute( - "SELECT COUNT(*) AS n FROM trade_order_monitors WHERE status='active'" - ).fetchone()["n"] - roll_count = conn.execute( - "SELECT COUNT(*) AS n FROM roll_groups WHERE status='active'" - ).fetchone()["n"] - conn.commit() - sizing = get_sizing_mode(get_setting) - max_pct = get_max_margin_pct(get_setting) - rec_cache = _recommend_payload(conn, use_ctp_margin=False) - if rec_cache.get("needs_refresh"): - _schedule_recommend_refresh() - ctp_connected = connected - margin_rec = small_account_margin_recommendations() - if not bootstrap_live: - bootstrap_live = { - "ok": True, - "rows": [], - "active_orders": [], - "pending_orders": [], - "capital": capital, - "ctp_status": dict(ctp_st), - "risk_status": risk, - "trading_session": is_trading_session(), - "night_session": is_night_trading_session(), - "session_clock": trading_session_clock(), - "sync_state": trading_state.sync_state, - "sync_label": trading_state.sync_label(), - } - else: - bootstrap_live = dict(bootstrap_live) - bootstrap_live.setdefault("capital", capital) - bootstrap_live.setdefault("risk_status", risk) - bootstrap_live["ctp_status"] = dict(ctp_st) - return render_template( - "trade.html", - trading_mode=mode, - trading_mode_label=trading_mode_label(get_setting), - capital=capital, - recommend_capital=recommend_capital, - risk_status=risk, - ctp_status=ctp_st, - ctp_account=ctp_acc, - active_trend=dict(active_trend) if active_trend else None, - monitor_count=monitor_count, - roll_count=roll_count, - sizing_mode=sizing, - sizing_mode_label=_sizing_mode_label(sizing), - fixed_lots=get_fixed_lots(get_setting), - fixed_amount=get_fixed_amount(get_setting), - risk_percent=get_risk_percent(get_setting), - max_margin_pct=get_max_margin_pct(get_setting), - pending_order_timeout_min=get_pending_order_timeout_min(get_setting), - ctp_auto_connect=is_ctp_auto_connect_enabled(get_setting), - recommend_rows=rec_cache.get("rows") or [], - recommend_updated_at=rec_cache.get("updated_at"), - night_session=is_night_trading_session(), - small_account_scope=should_apply_small_account_scope( - capital, ctp_connected=ctp_connected, - ), - small_account_scope_hint=small_account_scope_hint(ctp_connected=ctp_connected), - small_account_margin_rec=margin_rec if should_apply_small_account_scope( - capital, ctp_connected=ctp_connected, - ) else None, - session_clock=trading_session_clock(), - roll_max_margin_pct=get_roll_max_margin_pct(get_setting), - product_categories=PRODUCT_CATEGORIES, - bootstrap_live=bootstrap_live, - ) - finally: - conn.close() - - @app.route("/recommend") - @login_required - def recommend_page(): - return redirect(url_for("positions") + "#recommend") - - @app.route("/api/trading/live") - @login_required - def api_trading_live(): - snap = position_hub.get_snapshot() - if snap: - return jsonify(_normalize_live_payload(snap)) - payload = _refresh_trading_live_snapshot(fast=True) - payload = _normalize_live_payload(payload) - position_hub.set_snapshot(payload) - return jsonify(payload) - - @app.route("/api/trading/stream") - @login_required - def api_trading_stream(): - from queue import Empty - - @stream_with_context - def generate(): - yield ": stream\n\n" - q = position_hub.subscribe() - try: - snap = position_hub.get_snapshot() - if not snap: - conn = get_db() - try: - init_strategy_tables(conn) - payload = _minimal_live_payload(conn) - finally: - conn.close() - position_hub.set_snapshot(payload) - yield sse_format("positions", payload) - _push_position_snapshot_async(fast=True) - else: - yield sse_format("positions", snap) - while True: - try: - msg = q.get(timeout=25) - yield sse_format(msg["event"], msg["data"]) - except Empty: - yield ": heartbeat\n\n" - finally: - position_hub.unsubscribe(q) - - return Response( - generate(), - mimetype="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "X-Accel-Buffering": "no", - }, - ) - - @app.route("/api/trading/monitor/upsert", methods=["POST"]) - @login_required - def api_trading_monitor_upsert(): - """为已有持仓补充/更新本地止盈止损监控。""" - d = request.get_json(silent=True) or {} - sym = (d.get("symbol_code") or d.get("symbol") or "").strip() - direction = (d.get("direction") or "long").strip().lower() - try: - lots = max(1, int(d.get("lots") or 1)) - entry = float(d.get("entry_price") or d.get("entry") or 0) - sl = float(d["stop_loss"]) if d.get("stop_loss") not in (None, "") else None - tp = float(d["take_profit"]) if d.get("take_profit") not in (None, "") else None - except (TypeError, ValueError, KeyError): - return jsonify({"ok": False, "error": "参数无效"}), 400 - if not sym: - return jsonify({"ok": False, "error": "缺少品种代码"}), 400 - if sl is None and tp is None: - return jsonify({"ok": False, "error": "请至少填写止损或止盈"}), 400 - trailing_on = bool(d.get("trailing_be")) - if trailing_on and sl is None: - return jsonify({"ok": False, "error": "移动保本须填写止损价"}), 400 - if trailing_on: - tp = None - mode = get_trading_mode(get_setting) - conn = get_db() - try: - init_strategy_tables(conn) - mon = _find_active_monitor(conn, sym, direction) - has_pos = bool(mon) - ths_sym = sym - if ctp_status(mode).get("connected"): - for p in _ctp_positions(mode, refresh_if_empty=False): - if int(p.get("lots") or 0) <= 0: - continue - if (p.get("direction") or "long") != direction: - continue - if _match_ctp_symbol(p.get("symbol") or "", sym): - has_pos = True - lots = int(p.get("lots") or lots) - entry = float(p.get("avg_price") or entry or 0) - ths_sym = _ctp_pos_to_ths_code(p) or sym - break - if not has_pos: - return jsonify({"ok": False, "error": "未找到对应持仓"}), 400 - trailing_be = 1 if trailing_on else ( - int(mon.get("trailing_be") or 0) if mon else 0 - ) - mid = _upsert_open_monitor( - conn, - sym=ths_sym, - direction=direction, - lots=lots, - price=entry, - sl=sl, - tp=tp, - trailing_be=trailing_be, - ) - if trailing_on and sl is not None: - conn.execute( - """UPDATE trade_order_monitors SET - take_profit=NULL, initial_stop_loss=?, trailing_r_locked=0 - WHERE id=?""", - (sl, mid), - ) - conn.commit() - _push_position_snapshot_async(fast=False) - return jsonify({ - "ok": True, - "monitor_id": mid, - "message": "止盈止损已保存,程序本地监控", - }) - finally: - conn.close() - - @app.route("/api/trading/monitor/place-orders", methods=["POST"]) - @login_required - def api_trading_monitor_place_orders(): - """本地监控模式:清理旧版柜台挂单,不再向交易所挂止盈止损。""" - d = request.get_json(silent=True) or {} - try: - monitor_id = int(d.get("monitor_id") or 0) - except (TypeError, ValueError): - monitor_id = 0 - conn = get_db() - try: - init_strategy_tables(conn) - ensure_monitor_order_columns(conn) - mode = get_trading_mode(get_setting) - if not ctp_status(mode).get("connected"): - return jsonify({"ok": False, "error": "请先连接 CTP"}), 400 - mon = None - if monitor_id > 0: - row = conn.execute( - "SELECT * FROM trade_order_monitors WHERE id=? AND status='active'", - (monitor_id,), - ).fetchone() - mon = dict(row) if row else None - if not mon: - sym = (d.get("symbol_code") or "").strip() - direction = (d.get("direction") or "long").strip().lower() - for r in conn.execute( - "SELECT * FROM trade_order_monitors WHERE status='active'" - ).fetchall(): - row = dict(r) - if row.get("direction") != direction: - continue - if _match_ctp_symbol(sym, row.get("symbol") or ""): - mon = row - break - if not mon: - return jsonify({"ok": False, "error": "未找到有效监控快照"}), 404 - result = place_monitor_exit_orders( - conn, mon, mode=mode, force=bool(d.get("force")), - ) - if not result.get("ok"): - return jsonify(result), 400 - return jsonify(result) - finally: - conn.close() - - @app.route("/api/trading/monitor/dismiss", methods=["POST"]) - @login_required - def api_trading_monitor_dismiss(): - d = request.get_json(silent=True) or {} - try: - monitor_id = int(d.get("monitor_id") or 0) - except (TypeError, ValueError): - monitor_id = 0 - if monitor_id <= 0: - return jsonify({"ok": False, "error": "无效的监控记录"}), 400 - conn = get_db() - try: - init_strategy_tables(conn) - mode = get_trading_mode(get_setting) - row = conn.execute( - "SELECT * FROM trade_order_monitors WHERE id=? AND status IN ('active', 'pending')", - (monitor_id,), - ).fetchone() - if not row: - return jsonify({"ok": False, "error": "记录不存在或已关闭"}), 404 - mon = dict(row) - if (mon.get("status") or "").strip().lower() == "pending": - if not is_trading_session(): - return jsonify({"ok": False, "error": "不在交易时间段,无法撤单"}), 403 - ok, msg = cancel_pending_monitor(conn, mon, mode) - _push_position_snapshot_async(fast=False) - return jsonify({"ok": ok, "message": msg}) - conn.execute( - "UPDATE trade_order_monitors SET status='closed' WHERE id=?", - (monitor_id,), - ) - conn.commit() - _push_position_snapshot_async(fast=False) - return jsonify({"ok": True, "message": "已取消本地止盈止损监控"}) - finally: - conn.close() - - @app.route("/api/trading/monitor/cancel-open", methods=["POST"]) - @login_required - def api_trading_monitor_cancel_open(): - """撤销 pending 开仓委托(柜台撤单 + 关闭本地记录)。""" - d = request.get_json(silent=True) or {} - try: - monitor_id = int(d.get("monitor_id") or 0) - except (TypeError, ValueError): - monitor_id = 0 - if monitor_id <= 0: - return jsonify({"ok": False, "error": "无效的委托记录"}), 400 - conn = get_db() - try: - init_strategy_tables(conn) - mode = get_trading_mode(get_setting) - if not ctp_status(mode).get("connected"): - return jsonify({"ok": False, "error": "请先连接 CTP"}), 400 - if not is_trading_session(): - return jsonify({"ok": False, "error": "不在交易时间段,无法撤单"}), 403 - row = conn.execute( - "SELECT * FROM trade_order_monitors WHERE id=? AND status='pending'", - (monitor_id,), - ).fetchone() - if not row: - return jsonify({"ok": False, "error": "未找到挂单中的开仓委托"}), 404 - ok, msg = cancel_pending_monitor(conn, dict(row), mode) - _push_position_snapshot_async(fast=False) - return jsonify({"ok": ok, "message": msg}) - finally: - conn.close() - - @app.route("/api/trading/order/cancel", methods=["POST"]) - @login_required - def api_trading_order_cancel(): - """撤销柜台未成交委托(按 vt_order_id)。""" - d = request.get_json(silent=True) or {} - order_id = (d.get("order_id") or "").strip() - if not order_id: - return jsonify({"ok": False, "error": "无效的委托号"}), 400 - mode = get_trading_mode(get_setting) - if not ctp_status(mode).get("connected"): - return jsonify({"ok": False, "error": "请先连接 CTP"}), 400 - if not is_trading_session(): - return jsonify({"ok": False, "error": "不在交易时间段,无法撤单"}), 403 - ok = ctp_cancel_order(mode, order_id) - _push_position_snapshot_async(fast=False) - if not ok: - return jsonify({"ok": False, "error": "撤单失败,委托可能已成交或已撤销"}), 400 - return jsonify({"ok": True, "message": "撤单已提交"}) - - @app.route("/api/trading/close", methods=["POST"]) - @login_required - def api_trading_close(): - d = request.get_json(silent=True) or {} - source = (d.get("source") or "").strip() - conn = get_db() - init_strategy_tables(conn) - mode = get_trading_mode(get_setting) - if not ctp_status(mode).get("connected") and source in ("ctp", "program"): - conn.close() - return jsonify({"ok": False, "error": "请先连接 CTP"}), 400 - sym = (d.get("symbol_code") or d.get("symbol") or "").strip() - direction = (d.get("direction") or "long").strip().lower() - try: - lots = max(1, int(d.get("lots") or 1)) - price = float(d.get("price") or 0) - except (TypeError, ValueError): - conn.close() - return jsonify({"ok": False, "error": "参数无效"}), 400 - if not sym or price <= 0: - conn.close() - return jsonify({"ok": False, "error": "品种或价格无效"}), 400 - offset = "close_long" if direction == "long" else "close_short" - capital = _capital(conn) - mon = None - mid = int(d.get("monitor_id") or 0) - if mid: - row = conn.execute( - "SELECT * FROM trade_order_monitors WHERE id=? AND status='active'", - (mid,), - ).fetchone() - if row: - mon = dict(row) - if not mon: - for r in conn.execute( - "SELECT * FROM trade_order_monitors WHERE status='active'" - ).fetchall(): - row = dict(r) - if row.get("direction") != direction: - continue - if _match_ctp_symbol(sym, row.get("symbol") or ""): - mon = row - mid = int(row["id"]) - break - entry = float(mon.get("entry_price") or 0) if mon else 0.0 - if entry <= 0: - for p in _ctp_positions(mode): - if int(p.get("lots") or 0) <= 0: - continue - if (p.get("direction") or "long") != direction: - continue - if _match_ctp_symbol(p.get("symbol") or "", sym): - entry = float(p.get("avg_price") or price) - break - try: - execute_order( - conn, mode=mode, offset=offset, symbol=sym, direction=direction, - lots=lots, price=price, settings=_settings_dict(), - order_type="market", - ) - # 始终写本地记录:CTP 同步依赖内存开平配对,重启后或成交回报延迟时会漏记 - write_manual_close_trade_log( - conn, - mon, - symbol=sym, - direction=direction, - lots=lots, - close_price=price, - entry_price=entry or price, - trading_mode=mode, - capital=capital, - stop_loss=float(mon["stop_loss"]) if mon and mon.get("stop_loss") is not None else None, - take_profit=float(mon["take_profit"]) if mon and mon.get("take_profit") is not None else None, - open_time=(mon.get("open_time") or "") if mon else "", - symbol_name=(mon.get("symbol_name") or "") if mon else "", - market_code=(mon.get("market_code") or "") if mon else "", - ) - _close_all_monitors_for_sym_dir(conn, sym, direction) - conn.commit() - try: - from ctp_trade_sync import sync_trade_logs_from_ctp - sync_trade_logs_from_ctp(conn, mode, capital=capital, trading_mode=mode) - conn.commit() - except Exception as exc: - logger.debug("sync trades after close: %s", exc) - conn.close() - _push_position_snapshot_async() - return jsonify({"ok": True, "message": "已平仓,交易记录已写入"}) - except ValueError as exc: - conn.close() - return jsonify({"ok": False, "error": str(exc)}), 400 - - - def _roll_ui_modes(): - return frozenset({ADD_MODE_MARKET, ADD_MODE_BREAKOUT}) - - def _roll_filled_lots_map(conn, group_ids: list[int]) -> dict[int, int]: - if not group_ids: - return {} - placeholders = ",".join("?" * len(group_ids)) - rows = conn.execute( - f"""SELECT roll_group_id, COALESCE(SUM(lots), 0) AS n - FROM roll_legs - WHERE roll_group_id IN ({placeholders}) AND status=? - GROUP BY roll_group_id""", - (*group_ids, LEG_STATUS_FILLED), - ).fetchall() - return {int(r["roll_group_id"]): int(r["n"] or 0) for r in rows} - - def _build_roll_context(conn) -> dict: - has_trend = bool(conn.execute( - "SELECT 1 FROM trend_pullback_plans WHERE status='active' LIMIT 1", - ).fetchone()) - groups_by_monitor: dict[int, dict] = {} - pending_monitors: set[int] = set() - for row in conn.execute( - "SELECT * FROM roll_groups WHERE status='active'", - ).fetchall(): - g = dict(row) - mid = int(g.get("order_monitor_id") or 0) - if mid: - groups_by_monitor[mid] = g - for row in conn.execute( - """SELECT g.order_monitor_id - FROM roll_legs l - JOIN roll_groups g ON g.id = l.roll_group_id - WHERE l.status=? AND g.status='active'""", - (LEG_STATUS_PENDING,), - ).fetchall(): - mid = int(row["order_monitor_id"] or 0) - if mid: - pending_monitors.add(mid) - return { - "has_trend": has_trend, - "groups_by_monitor": groups_by_monitor, - "pending_monitors": pending_monitors, - } - - def _roll_eligibility_with_ctx(conn, mon: dict, ctx: dict) -> Optional[str]: - mid = int(mon["id"]) - grp = ctx["groups_by_monitor"].get(mid) - legs_done = int(grp.get("leg_count") or 0) if grp else 0 - return roll_eligibility_error( - sizing_mode=get_sizing_mode(get_setting), - monitor=mon, - has_active_trend=ctx["has_trend"], - legs_done=legs_done, - has_pending_leg=mid in ctx["pending_monitors"], - ) - - def _enrich_roll_group_row_fast(row: dict, filled_map: dict[int, int]) -> dict: - out = dict(row) - lots = float(out.get("mon_lots") or 0) - entry = float(out.get("mon_entry") or 0) - tp = float(out.get("mon_tp") or out.get("initial_take_profit") or 0) - direction = (out.get("direction") or "long").strip().lower() - sym = (out.get("symbol") or "").strip() - mult = int(get_contract_spec(sym).get("mult") or 1) if sym else 1 - gid = int(out.get("id") or 0) - filled_add_lots = int(filled_map.get(gid) or 0) - out["add_lots_filled"] = filled_add_lots - out["first_lots"] = max(0, int(lots) - filled_add_lots) - out["total_lots"] = int(lots) - out["avg_entry"] = round(entry, 4) if entry > 0 else None - if lots > 0 and entry > 0 and tp > 0: - if direction == "long": - out["reward_at_tp"] = round((tp - entry) * lots * mult, 2) - else: - out["reward_at_tp"] = round((entry - tp) * lots * mult, 2) - else: - out["reward_at_tp"] = None - return out - - def _enrich_roll_group_row(conn, row: dict) -> dict: - gid = int(row.get("id") or 0) - filled_map = _roll_filled_lots_map(conn, [gid]) if gid > 0 else {} - return _enrich_roll_group_row_fast(row, filled_map) - - def _archive_roll_group( - conn, - grp: dict, - *, - result_label: str = "持仓已结束", - ) -> None: - from zoneinfo import ZoneInfo - - now_s = datetime.now(ZoneInfo("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S") - gid = int(grp.get("id") or 0) - if gid <= 0: - return - if conn.execute( - "SELECT 1 FROM strategy_trade_snapshots WHERE strategy_type=? AND source_id=? LIMIT 1", - (STRATEGY_ROLL, gid), - ).fetchone(): - conn.execute( - "UPDATE roll_groups SET status='closed', updated_at=? WHERE id=?", - (now_s, gid), - ) - return - legs = [ - dict(r) for r in conn.execute( - "SELECT * FROM roll_legs WHERE roll_group_id=? ORDER BY id", - (gid,), - ).fetchall() - ] - mon = None - mid = int(grp.get("order_monitor_id") or 0) - if mid: - row = conn.execute( - "SELECT * FROM trade_order_monitors WHERE id=?", - (mid,), - ).fetchone() - mon = dict(row) if row else None - payload = { - "group": dict(grp), - "legs": legs, - "monitor": mon, - } - save_snapshot( - conn, - strategy_type=STRATEGY_ROLL, - source_id=gid, - symbol=grp.get("symbol") or (mon or {}).get("symbol") or "", - direction=grp.get("direction") or (mon or {}).get("direction") or "", - result_label=result_label, - payload=payload, - opened_at=grp.get("created_at") or "", - ) - conn.execute( - "UPDATE roll_legs SET status=? WHERE roll_group_id=? AND status=?", - (LEG_STATUS_CANCELLED, gid, LEG_STATUS_PENDING), - ) - conn.execute( - "UPDATE roll_groups SET status='closed', updated_at=? WHERE id=?", - (now_s, gid), - ) - - def _close_stale_roll_groups(conn) -> int: - rows = conn.execute( - """SELECT g.*, m.status AS monitor_status - FROM roll_groups g - LEFT JOIN trade_order_monitors m ON m.id = g.order_monitor_id - WHERE g.status='active' - AND (m.id IS NULL OR m.status != 'active')""" - ).fetchall() - for r in rows: - _archive_roll_group(conn, dict(r), result_label="持仓已结束") - return len(rows) - - def _enrich_roll_leg_row(row: dict, mode: str) -> dict: - out = dict(row) - sym = (out.get("symbol") or "").strip() - mark = _cached_position_mark(sym, out.get("direction") or "") if sym else None - out["current_price"] = round(float(mark), 4) if mark and mark > 0 else None - return out - - def _enrich_roll_record_row(conn, row: dict) -> dict: - out = dict(row) - snap = out.get("snapshot") or {} - group = snap.get("group") or {} - legs = snap.get("legs") or [] - monitor = snap.get("monitor") or {} - filled_legs = [ - l for l in legs - if (l.get("status") or "").strip().lower() == LEG_STATUS_FILLED - ] - add_lots = sum(int(l.get("lots") or 0) for l in filled_legs) - total_lots = int((monitor or {}).get("lots") or 0) - first_lots = max(0, total_lots - add_lots) - latest_sl = ( - group.get("current_stop_loss") - or (monitor or {}).get("stop_loss") - or None - ) - close_log = None - try: - close_log = conn.execute( - """SELECT close_price, pnl, pnl_net, close_time, lots - FROM trade_logs - WHERE lower(symbol)=lower(?) AND direction=? - ORDER BY close_time DESC, id DESC LIMIT 1""", - (out.get("symbol") or "", out.get("direction") or ""), - ).fetchone() - except Exception: - close_log = None - close_d = dict(close_log) if close_log else {} - out["detail"] = { - "first_lots": first_lots if first_lots > 0 else None, - "add_count": len(filled_legs), - "add_lots": add_lots, - "total_lots": total_lots if total_lots > 0 else None, - "latest_stop_loss": latest_sl, - "close_price": close_d.get("close_price"), - "close_time": close_d.get("close_time") or out.get("closed_at"), - "pnl": close_d.get("pnl_net") if close_d.get("pnl_net") is not None else close_d.get("pnl"), - "legs": filled_legs, - "monitor": monitor, - "group": group, - } - return out - - def _roll_leg_trigger_price(leg: dict): - for key in ("breakthrough_price", "limit_price", "fill_price"): - val = leg.get(key) - if val not in (None, "", 0): - return val - return None - - @app.route("/strategy") - @login_required - @_nav("strategy") - def strategy_page(): - conn = get_db() - try: - init_strategy_tables(conn) - ensure_monitor_order_columns(conn) - capital = _capital(conn) - active_trend = conn.execute( - "SELECT * FROM trend_pullback_plans WHERE status='active' ORDER BY id DESC LIMIT 1" - ).fetchone() - monitors_raw = conn.execute( - "SELECT * FROM trade_order_monitors WHERE status='active' ORDER BY id DESC" - ).fetchall() - mode = get_trading_mode(get_setting) - roll_ctx = _build_roll_context(conn) - roll_groups = conn.execute( - """SELECT g.*, m.symbol_name, m.lots AS mon_lots, m.entry_price AS mon_entry, - m.take_profit AS mon_tp - FROM roll_groups g - LEFT JOIN trade_order_monitors m ON m.id = g.order_monitor_id - WHERE g.status='active' ORDER BY g.id DESC""" - ).fetchall() - roll_legs = conn.execute( - """SELECT l.*, g.symbol, g.direction, g.order_monitor_id - FROM roll_legs l - JOIN roll_groups g ON g.id = l.roll_group_id - WHERE l.status=? AND g.status='active' - ORDER BY l.id DESC LIMIT 30""", - (LEG_STATUS_PENDING,), - ).fetchall() - sizing = get_sizing_mode(get_setting) - roll_allowed = sizing == MODE_AMOUNT - monitors = [] - for m in monitors_raw: - row = dict(m) - err = _roll_eligibility_with_ctx(conn, row, roll_ctx) - row["roll_eligible"] = roll_allowed and err is None - if not roll_allowed: - row["roll_block_reason"] = "仅固定金额(以损定仓)模式可滚仓" - else: - row["roll_block_reason"] = err or "" - monitors.append(row) - active_trend_row = dict(active_trend) if active_trend else None - if active_trend_row: - active_trend_row["period_label"] = trend_period_label( - active_trend_row.get("period") or "15m", - ) - group_ids = [int(g["id"]) for g in roll_groups if g["id"]] - filled_map = _roll_filled_lots_map(conn, group_ids) - enriched_groups = [ - _enrich_roll_group_row_fast(dict(g), filled_map) for g in roll_groups - ] - enriched_legs = [_enrich_roll_leg_row(dict(l), mode) for l in roll_legs] - return render_template( - "strategy.html", - capital=capital, - fixed_amount=get_fixed_amount(get_setting), - sizing_mode=sizing, - sizing_mode_label=_sizing_mode_label(sizing), - roll_allowed=roll_allowed, - active_trend=active_trend_row, - monitors=monitors, - roll_groups=enriched_groups, - roll_legs=enriched_legs, - trading_session=is_trading_session(), - session_clock=trading_session_clock(), - trend_periods=trend_strategy_periods(), - add_mode_labels={ - "market": "市价加仓", - "breakout": "突破加仓", - }, - roll_leg_status_labels={ - "pending": "监控中", - "filled": "已成交", - "cancelled": "已取消", - }, - ) - finally: - conn.close() - - @app.route("/strategy/records") - @login_required - def strategy_records_page(): - conn = get_db() - init_strategy_tables(conn) - trend, roll = list_snapshots(conn) - roll = [_enrich_roll_record_row(conn, r) for r in roll] - conn.close() - return render_template("strategy_records.html", trend_rows=trend, roll_rows=roll) - - @app.route("/api/trade/quote") - @login_required - def api_trade_quote(): - sym = (request.args.get("symbol") or "").strip() - lots = request.args.get("lots") or "1" - if not sym: - return jsonify({"ok": False, "error": "缺少品种"}), 400 - codes = ths_to_codes(sym) - price = fetch_price(sym, codes.get("market_code", "") if codes else "", codes.get("sina_code", "") if codes else "") - try: - lots_f = max(1, int(float(lots))) - except (TypeError, ValueError): - lots_f = 1 - mode = get_trading_mode(get_setting) - metrics = calc_order_tick_metrics(sym, lots_f, price, trading_mode=mode) - spec = get_contract_spec(sym) - name = codes.get("name", sym) if codes else sym - pos_long = pos_short = 0 - ctp_st = ctp_status(mode) - if ctp_st.get("connected"): - for p in _ctp_positions(mode): - if not _match_ctp_symbol(p.get("symbol", ""), sym): - continue - if p["direction"] == "long": - pos_long = int(p["lots"]) - else: - pos_short = int(p["lots"]) - max_open = int(_capital(get_db()) / (metrics["margin_per_lot"] or 1)) if metrics.get("margin_per_lot") else 0 - return jsonify({ - "ok": True, - "symbol": sym, - "name": name, - "price": price, - "lots": lots_f, - "metrics": metrics, - "exchange": codes.get("exchange", "") if codes else "", - "pos_long": pos_long, - "pos_short": pos_short, - "max_open_long": max_open, - "max_open_short": max_open, - "footer_text": ( - f"*{name} 每手{spec['mult']}吨/点 最小变动{metrics['tick_size']} " - f"每跳{metrics['tick_value_per_lot']}元/手×{lots_f}={metrics['tick_value_total']}元 " - f"精度{metrics['price_precision']}位小数" - ), - }) - - @app.route("/api/trade/preview", methods=["POST"]) - @login_required - def api_trade_preview(): - d = request.get_json(silent=True) or {} - sym = (d.get("symbol") or "").strip() - direction = (d.get("direction") or "long").strip().lower() - try: - entry = float(d.get("entry") or d.get("price") or 0) - sl = float(d.get("stop_loss") or 0) - tp = float(d.get("take_profit") or 0) - except (TypeError, ValueError): - return jsonify({"ok": False, "error": "价格参数无效"}), 400 - conn = get_db() - capital = _capital(conn) - conn.close() - sizing = get_sizing_mode(get_setting) - margin_pct = get_max_margin_pct(get_setting) - sizing_info = {} - if sizing == MODE_AMOUNT: - lots, err, sizing_info = calc_lots_by_amount( - entry, sl, direction, get_fixed_amount(get_setting), sym, - capital=capital, max_margin_pct=margin_pct, - trading_mode=get_trading_mode(get_setting), - ) - if err: - return jsonify({"ok": False, "error": err}), 400 - elif sizing == MODE_FIXED: - lots = get_fixed_lots(get_setting) - else: - try: - lots = max(1, int(d.get("lots") or 1)) - except (TypeError, ValueError): - lots = 1 - metrics = calc_position_metrics(direction, entry, sl, tp, lots, entry, capital, sym) - tick = calc_order_tick_metrics( - sym, lots, entry, direction=direction, trading_mode=get_trading_mode(get_setting), - ) - return jsonify({ - "ok": True, "lots": lots, "sizing_mode": sizing, - "metrics": metrics, "tick": tick, "capital": capital, - "sizing_info": sizing_info, - }) - - @app.route("/api/trade/order", methods=["POST"]) - @login_required - def api_trade_order(): - d = request.get_json(silent=True) or {} - sym = (d.get("symbol") or "").strip() - offset = (d.get("offset") or "open").strip().lower() - direction = (d.get("direction") or "long").strip().lower() - try: - lots = max(1, int(d.get("lots") or 1)) - price = float(d.get("price") or 0) - except (TypeError, ValueError): - return jsonify({"ok": False, "error": "手数或价格无效"}), 400 - order_type = (d.get("order_type") or d.get("price_type") or "limit").strip().lower() - if order_type == "market" and price <= 0: - codes = ths_to_codes(sym) - price = fetch_price( - sym, - codes.get("market_code", "") if codes else "", - codes.get("sina_code", "") if codes else "", - ) or 0 - if not sym or price <= 0: - return jsonify({"ok": False, "error": "品种或价格无效"}), 400 - conn = get_db() - init_strategy_tables(conn) - mode = get_trading_mode(get_setting) - if offset.startswith("open"): - _sync_trade_monitors_with_ctp(conn, mode) - if not is_trading_session(): - conn.close() - return jsonify({"ok": False, "error": "不在交易时间段"}), 403 - if d.get("trailing_be") and not d.get("stop_loss"): - conn.close() - return jsonify({"ok": False, "error": "开启移动保本须填写止损价"}), 400 - err = assert_can_open( - conn, - active_count=_effective_active_position_count(conn, mode), - equity=_capital(conn), - ) - if err: - conn.close() - return jsonify({"ok": False, "error": err}), 403 - scope_err = assert_product_allowed_for_capital( - sym, _capital(conn), ctp_connected=is_ctp_connected(get_setting), - ) - if scope_err: - conn.close() - return jsonify({"ok": False, "error": scope_err}), 403 - ctp_st = ctp_status(mode) - if not ctp_st.get("connected"): - conn.close() - if get_bridge().connect_in_progress(): - return jsonify({"ok": False, "error": "CTP 连接中,请稍候再下单"}), 400 - return jsonify({"ok": False, "error": "请先连接 CTP"}), 400 - sizing = get_sizing_mode(get_setting) - if offset.startswith("open") and sizing == MODE_AMOUNT: - sl = float(d.get("stop_loss") or 0) - if sl <= 0: - conn.close() - return jsonify({"ok": False, "error": "固定金额模式须填写止损价"}), 400 - lots_calc, err, _sizing_info = calc_lots_by_amount( - price, sl, direction, get_fixed_amount(get_setting), sym, - capital=_capital(conn), max_margin_pct=get_max_margin_pct(get_setting), - trading_mode=mode, - ) - if err: - conn.close() - return jsonify({"ok": False, "error": err}), 400 - lots = lots_calc or lots - elif offset.startswith("open") and sizing == MODE_FIXED: - lots = get_fixed_lots(get_setting) - margin_pct = get_max_margin_pct(get_setting) - usage = calc_margin_usage_pct( - _ctp_positions(mode), - _capital(conn), - extra_symbol=sym if offset.startswith("open") else "", - extra_lots=lots if offset.startswith("open") else 0, - extra_price=price if offset.startswith("open") else 0, - extra_direction=direction if offset.startswith("open") else "long", - trading_mode=mode, - ) - if offset.startswith("open") and usage > margin_pct: - conn.close() - return jsonify({ - "ok": False, - "error": f"保证金占用 {usage:.1f}% 超过上限 {margin_pct:g}%(可在系统设置修改)", - }), 403 - if lots > DEFAULT_MAX_ORDER_LOTS: - conn.close() - return jsonify({ - "ok": False, - "error": f"单笔手数 {lots} 超过上限 {DEFAULT_MAX_ORDER_LOTS},请加大止损距离或改固定手数", - }), 400 - try: - result = execute_order( - conn, - mode=mode, - offset=offset, - symbol=sym, - direction=direction, - lots=lots, - price=price, - settings=_settings_dict(), - order_type=order_type, - ) - if offset.startswith("open") and d.get("trailing_be") and not d.get("stop_loss"): - conn.close() - return jsonify({"ok": False, "error": "开启移动保本须填写止损价"}), 400 - if offset.startswith("open"): - from zoneinfo import ZoneInfo - sl = d.get("stop_loss") - trailing_be = 1 if d.get("trailing_be") else 0 - tp = None if trailing_be else d.get("take_profit") - open_ts = datetime.now(ZoneInfo("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S") - vt_order_id = str(result.get("order_id") or "") - mid = _upsert_open_monitor( - conn, - sym=sym, - direction=direction, - lots=lots, - price=price, - sl=sl, - tp=tp, - trailing_be=trailing_be, - open_time=open_ts, - monitor_type="manual", - status="pending", - vt_order_id=vt_order_id or None, - order_price=price, - ) - conn.commit() - try: - with _ctp_td_lock: - get_bridge().refresh_positions() - except Exception: - pass - _reconcile_pending(conn, mode, capital=_capital(conn)) - st_row = conn.execute( - "SELECT status FROM trade_order_monitors WHERE id=?", (mid,), - ).fetchone() - filled = st_row and (st_row["status"] or "").strip().lower() == "active" - rejected = st_row and (st_row["status"] or "").strip().lower() == "closed" - if rejected: - conn.commit() - conn.close() - _push_position_snapshot_async(fast=False) - return jsonify({ - "ok": False, - "error": "委托已被柜台拒绝或撤销(请确认合约状态与交易时段)", - "lots": lots, - "filled": False, - }), 400 - if not filled: - try: - get_bridge().refresh_positions() - except Exception: - pass - _reconcile_pending(conn, mode, capital=_capital(conn)) - st_row = conn.execute( - "SELECT status FROM trade_order_monitors WHERE id=?", (mid,), - ).fetchone() - filled = st_row and (st_row["status"] or "").strip().lower() == "active" - rejected = st_row and (st_row["status"] or "").strip().lower() == "closed" - if rejected: - conn.commit() - conn.close() - _push_position_snapshot_async(fast=False) - return jsonify({ - "ok": False, - "error": "委托已被柜台拒绝或撤销(请确认合约状态与交易时段)", - "lots": lots, - "filled": False, - }), 400 - if filled: - _sync_monitor_from_ctp( - conn, mid, sym, direction, mode, capital=_capital(conn), - ) - mon_row = conn.execute( - "SELECT * FROM trade_order_monitors WHERE id=?", (mid,), - ).fetchone() - if mon_row and (sl or tp): - try: - ensure_monitor_order_columns(conn) - cancel_monitor_exit_orders(conn, dict(mon_row), mode=mode) - except Exception as exc: - logger.warning("清理旧版止盈止损挂单失败: %s", exc) - conn.commit() - _push_position_snapshot_async(fast=False) - msg = ( - f"开仓成功 · {lots} 手" - if filled - else ( - f"委托已提交 · {lots} 手挂单中" - f"({get_pending_order_timeout_sec(get_setting) // 60} 分钟未成交自动撤单)" - ) - ) - conn.commit() - if offset.startswith("open"): - from db_conn import DB_PATH - from ai_worker import schedule_ai_event_analysis - from trade_notify import notify_manual_open_filled - - if filled: - open_sl = float(d.get("stop_loss") or 0) if d.get("stop_loss") else None - open_tp = None if d.get("trailing_be") else d.get("take_profit") - if open_tp is not None: - try: - open_tp = float(open_tp) - except (TypeError, ValueError): - open_tp = None - codes = ths_to_codes(sym) or {} - if open_sl and open_sl > 0: - notify_manual_open_filled( - send_wechat=send_wechat_msg, - get_setting=get_setting, - mode_label=trading_mode_label(get_setting), - sym=sym, - symbol_name=codes.get("name") or sym, - direction=direction, - entry=price, - sl=open_sl, - tp=open_tp, - lots=lots, - capital=_capital(conn), - order_id=str(result.get("order_id") or ""), - trailing_be=bool(d.get("trailing_be")), - be_tick_buffer=get_trailing_be_tick_buffer(get_setting), - schedule_ai_fn=schedule_ai_event_analysis, - db_path=DB_PATH, - ) - else: - send_wechat_msg( - f"{trading_mode_label(get_setting)} 开仓 {sym} {direction} {lots}手 @{price}" - ) - elif not filled: - send_wechat_msg( - f"委托已提交 · {sym} {direction} {lots}手挂单中" - f"({get_pending_order_timeout_sec(get_setting) // 60} 分钟未成交自动撤单)" - ) - elif not offset.startswith("open"): - send_wechat_msg( - f"{trading_mode_label(get_setting)} {offset} {sym} {direction} {lots}手 @{price}" - ) - conn.close() - _push_position_snapshot_async(fast=False) - return jsonify({ - "ok": True, - "result": result, - "lots": lots, - "message": msg if offset.startswith("open") else "委托已提交柜台", - "filled": filled if offset.startswith("open") else None, - }) - except (ValueError, RuntimeError) as exc: - conn.close() - return jsonify({"ok": False, "error": str(exc)}), 400 - except Exception as exc: - conn.close() - return jsonify({"ok": False, "error": str(exc)}), 500 - - @app.route("/api/ctp/connect", methods=["POST"]) - @login_required - def api_ctp_connect(): - from vnpy_bridge import ctp_start_connect - from ctp_settings import CTP_DISABLED_HINT - - if not is_ctp_auto_connect_enabled(get_setting): - mode = get_trading_mode(get_setting) - st = ctp_status(mode) - return jsonify({ - "ok": False, - "disabled": True, - "error": CTP_DISABLED_HINT, - "status": st, - }), 400 - mode = get_trading_mode(get_setting) - body = request.get_json(silent=True) or {} - force = bool(body.get("force")) - auto = bool(body.get("auto")) - # 自动连接仅由 qihuo-ctp 后台 worker 发起;Web 只读状态,避免换页重复 connect。 - if auto and not force: - st = ctp_status(mode) - acc = _ctp_account(mode) if st.get("connected") else {} - return jsonify({ - "ok": True, - "connecting": bool(st.get("connecting")), - "backend_managed": True, - "status": st, - "account": acc, - }) - info = ctp_start_connect(mode, force=force) - st = info.get("status") or ctp_status(mode) - acc = _ctp_account(mode) if st.get("connected") else {} - if st.get("connected"): - return jsonify({"ok": True, "status": st, "account": acc}) - if info.get("connecting") or info.get("started"): - return jsonify({ - "ok": True, - "connecting": True, - "status": st, - "account": acc, - }) - if info.get("cooldown"): - return jsonify({ - "ok": False, - "cooldown": True, - "error": st.get("last_error") or "CTP 登录冷却中", - "status": st, - "account": acc, - }), 400 - return jsonify({ - "ok": False, - "error": st.get("last_error") or "CTP 连接未启动", - "status": st, - "account": acc, - }), 400 - - @app.route("/api/ctp/status") - @login_required - def api_ctp_status(): - mode = get_trading_mode(get_setting) - st = ctp_status(mode) - acc = {} - if st.get("connected"): - try: - acc = _ctp_account(mode) - except Exception: - acc = {} - return jsonify({"ok": True, "status": st, "account": acc}) - - @app.route("/api/account_snapshot") - @login_required - def api_account_snapshot(): - conn = get_db() - try: - init_strategy_tables(conn) - mode = get_trading_mode(get_setting) - ctp_st = ctp_status(mode) - capital = _capital(conn) - risk = get_risk_status( - conn, - active_count=_effective_active_position_count(conn, mode), - equity=capital, - ) - conn.commit() - ctp_acc = _ctp_account(mode) if ctp_st.get("connected") else {} - positions = _ctp_positions(mode) if ctp_st.get("connected") else [] - if ctp_st.get("connected") and not positions: - positions = _positions_for_monitor_restore(mode) - return jsonify({ - "capital": capital, - "trading_mode": mode, - "trading_mode_label": trading_mode_label(get_setting), - "sizing_mode": get_sizing_mode(get_setting), - "risk_status": risk, - "ctp_status": ctp_st, - "ctp_account": ctp_acc, - "positions": positions, - }) - finally: - conn.close() - - @app.route("/api/recommend/list") - @login_required - def api_recommend_list(): - """只读数据库缓存,不在请求时拉行情。""" - conn = get_db() - try: - payload = _recommend_payload(conn) - return jsonify({"ok": True, **payload}) - finally: - conn.close() - - @app.route("/api/recommend/stream") - @login_required - def api_recommend_stream(): - from queue import Empty - - def generate(): - q = recommend_hub.subscribe() - try: - conn = get_db() - try: - payload = _recommend_payload(conn) - finally: - conn.close() - yield sse_format("recommend", {"ok": True, **payload}) - while True: - try: - msg = q.get(timeout=25) - yield sse_format(msg["event"], msg["data"]) - except Empty: - yield ": heartbeat\n\n" - finally: - recommend_hub.unsubscribe(q) - - return Response( - stream_with_context(generate()), - mimetype="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - }, - ) - - @app.route("/api/recommend/refresh", methods=["POST"]) - @login_required - def api_recommend_refresh(): - """手动触发一次后台刷新(仍写入数据库)。""" - conn = get_db() - try: - init_strategy_tables(conn) - capital = _recommend_capital(conn) - mode = get_trading_mode(get_setting) - rows = refresh_recommend_cache( - conn, capital, _main_quote, trading_mode=mode, - max_margin_pct=get_max_margin_pct(get_setting), - ) - max_pct = get_max_margin_pct(get_setting) - payload = _recommend_payload(conn) - recommend_hub.broadcast("recommend", {"ok": True, **payload}) - return jsonify({"ok": True, "count": len(rows), **payload}) - finally: - conn.close() - - @app.route("/api/strategy/trend/preview", methods=["POST"]) - @login_required - def api_trend_preview(): - d = request.get_json(silent=True) or {} - sym = (d.get("symbol") or "").strip() - conn = get_db() - if conn.execute("SELECT id FROM trend_pullback_plans WHERE status='active'").fetchone(): - conn.close() - return jsonify({"ok": False, "error": "已有运行中趋势计划"}), 400 - capital = _capital(conn) - codes = ths_to_codes(sym) - price = fetch_price(sym, codes.get("market_code", "") if codes else "", codes.get("sina_code", "") if codes else "") - conn.close() - if not price: - return jsonify({"ok": False, "error": "无法获取现价"}), 400 - plan, err = compute_trend_plan_futures( - direction=d.get("direction") or "long", - stop_loss=float(d.get("stop_loss") or 0), - add_upper=float(d.get("add_upper") or 0), - take_profit=float(d.get("take_profit") or 0), - risk_percent=float(d.get("risk_percent") or get_risk_percent(get_setting)), - capital=capital, - live_price=price, - ths_code=sym, - dca_legs=int(d.get("dca_legs") or 5), - ) - if err: - return jsonify({"ok": False, "error": err}), 400 - period = normalize_trend_period(d.get("period")) - sym_name = (d.get("symbol_name") or "").strip() - if not sym_name and codes: - sym_name = codes.get("name") or sym - plan = enrich_trend_plan_preview( - plan, symbol=sym, symbol_name=sym_name, period=period, - ) - return jsonify({"ok": True, "plan": plan}) - - @app.route("/api/strategy/trend/execute", methods=["POST"]) - @login_required - def api_trend_execute(): - d = request.get_json(silent=True) or {} - sym = (d.get("symbol") or "").strip() - conn = get_db() - init_strategy_tables(conn) - capital = _capital(conn) - err = assert_can_open(conn, equity=capital) - if err: - conn.close() - return jsonify({"ok": False, "error": err}), 403 - scope_err = assert_product_allowed_for_capital( - sym, capital, ctp_connected=is_ctp_connected(get_setting), - ) - if scope_err: - conn.close() - return jsonify({"ok": False, "error": scope_err}), 403 - codes = ths_to_codes(sym) - price = fetch_price(sym, codes.get("market_code", "") if codes else "", codes.get("sina_code", "") if codes else "") - plan, perr = compute_trend_plan_futures( - direction=d.get("direction") or "long", - stop_loss=float(d.get("stop_loss") or 0), - add_upper=float(d.get("add_upper") or 0), - take_profit=float(d.get("take_profit") or 0), - risk_percent=float(d.get("risk_percent") or get_risk_percent(get_setting)), - capital=capital, - live_price=price or float(d.get("live_price") or 0), - ths_code=sym, - ) - if perr: - conn.close() - return jsonify({"ok": False, "error": perr}), 400 - period = normalize_trend_period(d.get("period")) - sym_name = (d.get("symbol_name") or "").strip() - if not sym_name and codes: - sym_name = codes.get("name") or sym - plan = enrich_trend_plan_preview( - plan, symbol=sym, symbol_name=sym_name, period=period, - ) - mode = get_trading_mode(get_setting) - try: - execute_order( - conn, mode=mode, offset="open", symbol=sym, - direction=plan["direction"], lots=plan["first_lots"], price=price, settings=_settings_dict(), - ) - except ValueError as exc: - conn.close() - return jsonify({"ok": False, "error": str(exc)}), 400 - now = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - cur = conn.execute( - """INSERT INTO trend_pullback_plans ( - status, symbol, symbol_name, direction, stop_loss, add_upper, take_profit, - risk_percent, capital_snapshot, plan_margin, target_lots, first_lots, remainder_lots, - dca_legs, leg_amounts_json, grid_prices_json, first_order_done, avg_entry_price, - lots_open, opened_at, period - ) VALUES ('active',?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,1,?,?,?,?) RETURNING id""", - ( - sym, sym_name or (codes.get("name", sym) if codes else sym), plan["direction"], - plan["stop_loss"], plan["add_upper"], plan["take_profit"], - plan["risk_percent"], plan["capital_snapshot"], plan["plan_margin"], - plan["target_lots"], plan["first_lots"], plan["remainder_lots"], - plan["dca_legs"], plan["leg_amounts_json"], plan["grid_prices_json"], - price, plan["first_lots"], now, plan["period"], - ), - ) - row = cur.fetchone() - plan_id = int(row["id"] if isinstance(row, dict) else row[0]) - conn.commit() - conn.close() - send_wechat_msg(f"趋势回调首仓 {sym} {plan['first_lots']}手") - return jsonify({"ok": True, "plan_id": plan_id, "plan": plan}) - - def _roll_group_for_monitor(conn, monitor_id: int): - return conn.execute( - "SELECT * FROM roll_groups WHERE order_monitor_id=? AND status='active'", - (int(monitor_id),), - ).fetchone() - - def _roll_filled_legs(conn, monitor_id: int) -> int: - grp = _roll_group_for_monitor(conn, monitor_id) - if grp: - return int(grp["leg_count"] or 0) - return 0 - - def _roll_has_pending(conn, monitor_id: int) -> bool: - grp = _roll_group_for_monitor(conn, monitor_id) - if not grp: - return False - return bool(conn.execute( - "SELECT 1 FROM roll_legs WHERE roll_group_id=? AND status=? LIMIT 1", - (int(grp["id"]), LEG_STATUS_PENDING), - ).fetchone()) - - def _roll_eligibility(conn, mon: dict, ctx: Optional[dict] = None) -> Optional[str]: - if ctx is None: - ctx = _build_roll_context(conn) - return _roll_eligibility_with_ctx(conn, mon, ctx) - - def _roll_monitor_for_request(conn, mon_id: int) -> Optional[dict]: - row = conn.execute( - "SELECT * FROM trade_order_monitors WHERE id=?", - (int(mon_id),), - ).fetchone() - if not row: - return None - mon = dict(row) - if (mon.get("status") or "").strip().lower() == "active": - return mon - mode = get_trading_mode(get_setting) - if not _cached_ctp_status(mode).get("connected"): - return None - sym = (mon.get("symbol") or "").strip() - direction = (mon.get("direction") or "long").strip().lower() - for p in _positions_for_monitor_restore(mode, allow_ctp=False): - if int(p.get("lots") or 0) <= 0: - continue - if (p.get("direction") or "long").strip().lower() != direction: - continue - if not _match_ctp_symbol(p.get("symbol") or "", sym): - continue - execute_retry( - conn, - "UPDATE trade_order_monitors SET status='active' WHERE id=?", - (int(mon_id),), - ) - mon["status"] = "active" - _sync_monitor_from_ctp( - conn, - int(mon_id), - sym, - direction, - mode, - ctp=p, - capital=_capital(conn), - ) - fresh = conn.execute( - "SELECT * FROM trade_order_monitors WHERE id=?", - (int(mon_id),), - ).fetchone() - return dict(fresh) if fresh else mon - return None - - def _roll_mark_price( - sym: str, - mon: dict, - mode: str, - *, - allow_ctp: bool = False, - ) -> float: - mark = _cached_position_mark(sym, (mon or {}).get("direction") or "") - if mark and mark > 0: - return float(mark) - mark = ( - ctp_get_tick_price(mode, sym) - if allow_ctp and ctp_status(mode).get("connected") - else None - ) - if mark and mark > 0: - return float(mark) - px = fetch_price(sym) - if px and px > 0: - return float(px) - return float(mon.get("entry_price") or 0) - - def _build_roll_preview(conn, d: dict, mon: dict, *, mode: str): - sym = mon["symbol"] - spec = get_contract_spec(sym) - capital = _capital(conn) - add_mode = (d.get("add_mode") or ADD_MODE_MARKET).strip().lower() - off_session_breakout = add_mode == ADD_MODE_BREAKOUT and not is_trading_session() - mark = _roll_mark_price(sym, mon, mode, allow_ctp=not off_session_breakout) - if (not mark or mark <= 0) and off_session_breakout: - bt = float(d.get("breakthrough_price") or 0) - mark = bt if bt > 0 else float(mon.get("entry_price") or 0) - entry_existing = _live_entry_price( - sym, mon["direction"], mode, float(mon.get("entry_price") or 0), - allow_ctp=False, - ) - if add_mode in FIB_MODES: - return None, "斐波加仓已停用,请选市价或突破" - if add_mode not in _roll_ui_modes(): - return None, "仅支持市价加仓或突破加仓" - risk_budget = get_fixed_amount(get_setting) - legs_done = _roll_filled_legs(conn, int(mon["id"])) - preview, err = preview_roll( - direction=mon["direction"], - symbol=sym, - qty_existing=float(mon["lots"]), - entry_existing=entry_existing, - initial_take_profit=float(mon["take_profit"] or 0), - add_mode=add_mode, - new_stop_loss=float(d.get("new_stop_loss") or 0), - risk_budget=risk_budget, - mult=int(spec["mult"]), - mark_price=mark, - add_price=float(d.get("add_price") or 0) or mark, - limit_price=d.get("limit_price"), - breakthrough_price=d.get("breakthrough_price"), - fib_upper=d.get("fib_upper"), - fib_lower=d.get("fib_lower"), - legs_done=legs_done, - off_session_pending=off_session_breakout, - ) - if err: - return None, err - preview, merr = _apply_roll_margin_cap( - preview, conn=conn, mode=mode, mon=dict(mon), capital=capital, - ) - if merr: - return None, merr - return preview, None - - def _commit_roll_fill( - conn, - *, - mon: dict, - preview: dict, - add_mode: str, - mode: str, - pending_leg_id: Optional[int] = None, - ) -> tuple[bool, str]: - sym = mon["symbol"] - mon_id = int(mon["id"]) - price = float(preview["add_price"]) - try: - execute_order( - conn, mode=mode, offset="open", symbol=sym, - direction=mon["direction"], lots=int(preview["add_lots"]), price=price, - settings=_settings_dict(), - ) - except ValueError as exc: - return False, str(exc) - new_lots = int(mon["lots"]) + int(preview["add_lots"]) - new_avg = preview["avg_entry_after"] - new_sl = preview["new_stop_loss"] - conn.execute( - "UPDATE trade_order_monitors SET lots=?, entry_price=?, stop_loss=? WHERE id=?", - (new_lots, new_avg, new_sl, mon_id), - ) - grp = _roll_group_for_monitor(conn, mon_id) - now = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - risk_budget = float(preview.get("risk_budget") or get_fixed_amount(get_setting)) - if grp: - gid = int(grp["id"]) - leg_n = int(grp["leg_count"] or 0) + 1 - conn.execute( - "UPDATE roll_groups SET leg_count=?, current_stop_loss=?, updated_at=? WHERE id=?", - (leg_n, new_sl, now, gid), - ) - else: - cur = conn.execute( - """INSERT INTO roll_groups ( - order_monitor_id, symbol, direction, initial_take_profit, initial_stop_loss, - current_stop_loss, risk_percent, leg_count, status, created_at, updated_at - ) VALUES (?,?,?,?,?,?,?,1,'active',?,?) RETURNING id""", - ( - mon_id, sym, mon["direction"], mon["take_profit"], mon["stop_loss"], - new_sl, risk_budget, now, now, - ), - ) - row = cur.fetchone() - gid = int(row["id"] if isinstance(row, dict) else row[0]) - leg_n = 1 - if pending_leg_id: - conn.execute( - """UPDATE roll_legs SET status=?, fill_price=?, lots=?, new_stop_loss=?, created_at=? - WHERE id=?""", - ( - LEG_STATUS_FILLED, price, int(preview["add_lots"]), new_sl, now, - int(pending_leg_id), - ), - ) - else: - conn.execute( - """INSERT INTO roll_legs ( - roll_group_id, leg_index, add_mode, fill_price, lots, new_stop_loss, - status, created_at, limit_price, breakthrough_price, last_mark_price, capital_snapshot - ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?)""", - ( - gid, leg_n, add_mode, price, int(preview["add_lots"]), new_sl, - LEG_STATUS_FILLED, now, - preview.get("limit_price"), preview.get("breakthrough_price"), - preview.get("mark_price"), _capital(conn), - ), - ) - conn.commit() - send_wechat_msg( - f"滚仓成交 {sym} {add_mode_label(add_mode)} +{preview['add_lots']}手 " - f"新止损 {new_sl} 合计 {new_lots}手" - ) - _schedule_roll_entry_sync(mon_id, sym, mon["direction"], mode) - return True, "成交" - - def _schedule_roll_entry_sync( - mon_id: int, sym: str, direction: str, mode: str, - ) -> None: - """滚仓成交后从柜台同步加权均价到手数监控。""" - def _run() -> None: - import time as _time - - _time.sleep(1.5) - try: - conn = get_db() - try: - init_strategy_tables(conn) - capital = _capital(conn) - synced = False - for p in trading_state.get_positions() or _ctp_positions(mode): - if (p.get("direction") or "long") != (direction or "long"): - continue - if not _match_ctp_symbol(p.get("symbol") or "", sym): - continue - _sync_monitor_from_ctp( - conn, mon_id, sym, direction, mode, ctp=p, capital=capital, - ) - synced = True - break - if synced: - commit_retry(conn) - finally: - conn.close() - if synced: - _push_position_snapshot_async(fast=False) - except Exception as exc: - logger.debug("roll entry sync: %s", exc) - - threading.Thread(target=_run, daemon=True, name="roll-entry-sync").start() - - def _submit_roll_pending( - conn, - *, - mon: dict, - preview: dict, - add_mode: str, - ) -> tuple[bool, str]: - mon_id = int(mon["id"]) - grp = _roll_group_for_monitor(conn, mon_id) - now = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - capital = _capital(conn) - risk_budget = float(preview.get("risk_budget") or get_fixed_amount(get_setting)) - if grp: - gid = int(grp["id"]) - else: - cur = conn.execute( - """INSERT INTO roll_groups ( - order_monitor_id, symbol, direction, initial_take_profit, initial_stop_loss, - current_stop_loss, risk_percent, leg_count, status, created_at, updated_at - ) VALUES (?,?,?,?,?,?,?,0,'active',?,?) RETURNING id""", - ( - mon_id, mon["symbol"], mon["direction"], mon["take_profit"], mon["stop_loss"], - preview["new_stop_loss"], risk_budget, now, now, - ), - ) - row = cur.fetchone() - gid = int(row["id"] if isinstance(row, dict) else row[0]) - leg_n = int(conn.execute( - "SELECT COUNT(*) AS n FROM roll_legs WHERE roll_group_id=? AND status=?", - (gid, LEG_STATUS_FILLED), - ).fetchone()["n"]) + 1 - pending_n = conn.execute( - "SELECT COUNT(*) AS n FROM roll_legs WHERE roll_group_id=? AND status=?", - (gid, LEG_STATUS_PENDING), - ).fetchone()["n"] - if int(pending_n or 0) > 0: - return False, "已有监控中的加仓腿" - conn.execute( - """INSERT INTO roll_legs ( - roll_group_id, leg_index, add_mode, lots, new_stop_loss, status, created_at, - limit_price, breakthrough_price, last_mark_price, capital_snapshot - ) VALUES (?,?,?,?,?,?,?,?,?,?,?)""", - ( - gid, leg_n, add_mode, int(preview["add_lots"]), preview["new_stop_loss"], - LEG_STATUS_PENDING, now, - preview.get("limit_price"), preview.get("breakthrough_price"), - preview.get("mark_price"), capital, - ), - ) - conn.commit() - return True, "已提交监控,触价后自动市价加仓" - - def _fill_roll_leg_cb(mon: dict, grp: dict, leg: dict, preview: dict) -> tuple[bool, str]: - conn = get_db() - init_strategy_tables(conn) - mode = get_trading_mode(get_setting) - ok, msg = _commit_roll_fill( - conn, mon=mon, preview=preview, add_mode=leg.get("add_mode") or ADD_MODE_MARKET, - mode=mode, pending_leg_id=int(leg["id"]), - ) - conn.close() - return ok, msg - - def _check_roll_monitors(): - conn = get_db() - init_strategy_tables(conn) - mode = get_trading_mode(get_setting) - try: - check_roll_monitors( - conn, - get_mark_price_fn=lambda sym: _roll_mark_price(sym, {}, mode, allow_ctp=True), - fill_roll_leg_fn=_fill_roll_leg_cb, - is_trading_session_fn=is_trading_session, - get_risk_budget_fn=lambda: get_fixed_amount(get_setting), - get_entry_price_fn=lambda sym, d, fb: _live_entry_price( - sym, d, mode, fb, allow_ctp=True, - ), - ) - conn.commit() - finally: - conn.close() - - app._check_roll_monitors = _check_roll_monitors - - def _apply_roll_margin_cap( - preview: dict, - *, - conn, - mode: str, - mon: dict, - capital: float, - ) -> tuple[dict, Optional[str]]: - """滚仓:风险算手数后再按滚仓保证金上限收紧。""" - if not preview: - return preview, "预览无效" - sym = mon["symbol"] - direction = (mon.get("direction") or "long").strip().lower() - price = float(preview.get("add_price") or 0) - qty_existing = float(mon.get("lots") or 0) - entry_existing = _live_entry_price( - sym, direction, mode, float(mon.get("entry_price") or 0), - allow_ctp=False, - ) - mult = int(get_contract_spec(sym).get("mult") or 1) - roll_pct = get_roll_max_margin_pct(get_setting) - add_lots = int(preview.get("add_lots") or 0) - positions = _positions_for_monitor_restore(mode, allow_ctp=False) - capped, usage = cap_lots_for_margin_budget( - positions, capital, sym, direction, price, add_lots, roll_pct, trading_mode=mode, - ) - if capped < 1: - return preview, f"滚仓后保证金占用将超过上限 {roll_pct:g}%" - out = dict(preview) - if capped < add_lots: - out["add_lots"] = capped - out["qty_after"] = int(qty_existing + capped) - out["avg_entry_after"] = round( - avg_entry_after_add(qty_existing, entry_existing, capped, price), 4, - ) - sl = float(out.get("new_stop_loss") or 0) - tp = float(out.get("initial_take_profit") or 0) - new_avg = float(out["avg_entry_after"]) - new_qty = float(out["qty_after"]) - if direction == "long": - out["loss_at_sl"] = round((new_avg - sl) * new_qty * mult, 2) - out["reward_at_tp"] = round((tp - new_avg) * new_qty * mult, 2) - else: - out["loss_at_sl"] = round((sl - new_avg) * new_qty * mult, 2) - out["reward_at_tp"] = round((new_avg - tp) * new_qty * mult, 2) - out["margin_capped"] = True - out["margin_cap_note"] = ( - f"按滚仓保证金上限 {roll_pct:g}% 收紧:" - f"风险算 {add_lots} 手 → 实际 {capped} 手" - ) - out["margin_usage_pct"] = round(usage, 2) - out["roll_max_margin_pct"] = roll_pct - return out, None - - @app.route("/api/strategy/roll/preview", methods=["POST"]) - @login_required - def api_roll_preview(): - d = request.get_json(silent=True) or {} - conn = get_db() - init_strategy_tables(conn) - ensure_monitor_order_columns(conn) - mon_id = int(d.get("monitor_id") or 0) - roll_ctx = _build_roll_context(conn) - mon = _roll_monitor_for_request(conn, mon_id) - if not mon: - conn.close() - return jsonify({"ok": False, "error": "无有效持仓监控"}), 400 - conn.commit() - mon_d = dict(mon) - err = _roll_eligibility(conn, mon_d, roll_ctx) - if err: - conn.close() - return jsonify({"ok": False, "error": err}), 400 - mode = get_trading_mode(get_setting) - preview, perr = _build_roll_preview(conn, d, mon_d, mode=mode) - conn.close() - if perr: - return jsonify({"ok": False, "error": perr}), 400 - return jsonify({"ok": True, "preview": preview}) - - @app.route("/api/strategy/roll/execute", methods=["POST"]) - @login_required - def api_roll_execute(): - d = request.get_json(silent=True) or {} - conn = get_db() - init_strategy_tables(conn) - ensure_monitor_order_columns(conn) - mon_id = int(d.get("monitor_id") or 0) - roll_ctx = _build_roll_context(conn) - mon = _roll_monitor_for_request(conn, mon_id) - if not mon: - conn.close() - return jsonify({"ok": False, "error": "无有效持仓监控"}), 400 - conn.commit() - mon_d = dict(mon) - err = _roll_eligibility(conn, mon_d, roll_ctx) - if err: - conn.close() - return jsonify({"ok": False, "error": err}), 400 - mode = get_trading_mode(get_setting) - preview, perr = _build_roll_preview(conn, d, mon_d, mode=mode) - if perr: - conn.close() - return jsonify({"ok": False, "error": perr}), 400 - add_mode = (d.get("add_mode") or ADD_MODE_MARKET).strip().lower() - if add_mode in PENDING_MODES: - ok, msg = _submit_roll_pending(conn, mon=mon_d, preview=preview, add_mode=add_mode) - conn.close() - if not ok: - return jsonify({"ok": False, "error": msg}), 400 - note = "已提交监控,开盘触价后自动市价加仓" if not is_trading_session() else msg - return jsonify({"ok": True, "message": note, "pending": True}) - if not is_trading_session(): - conn.close() - return jsonify({"ok": False, "error": "不在交易时间段"}), 403 - if not _cached_ctp_status(mode).get("connected"): - conn.close() - return jsonify({"ok": False, "error": "请先连接 CTP"}), 400 - ok, msg = _commit_roll_fill( - conn, mon=mon_d, preview=preview, add_mode=add_mode, mode=mode, - ) - conn.close() - if not ok: - return jsonify({"ok": False, "error": msg}), 400 - return jsonify({"ok": True, "message": msg, "preview": preview}) - - @app.route("/api/strategy/roll/cancel/", methods=["POST"]) - @login_required - def api_roll_cancel(leg_id: int): - conn = get_db() - init_strategy_tables(conn) - ok, msg = cancel_roll_leg(conn, leg_id) - if ok: - conn.commit() - conn.close() - if not ok: - return jsonify({"ok": False, "error": msg}), 400 - return jsonify({"ok": True, "message": msg}) - - @app.route("/api/strategy/trend/stop", methods=["POST"]) - @login_required - def api_trend_stop(): - d = request.get_json(silent=True) or {} - plan_id = int(d.get("plan_id") or 0) - conn = get_db() - plan = conn.execute("SELECT * FROM trend_pullback_plans WHERE id=? AND status='active'", (plan_id,)).fetchone() - if not plan: - conn.close() - return jsonify({"ok": False, "error": "计划不存在"}), 404 - mode = get_trading_mode(get_setting) - price = fetch_price(plan["symbol"]) or float(plan["avg_entry_price"] or 0) - try: - if int(plan["lots_open"] or 0) > 0: - execute_order( - conn, mode=mode, offset="close", symbol=plan["symbol"], - direction=plan["direction"], lots=int(plan["lots_open"]), price=price, settings=_settings_dict(), - ) - except ValueError: - pass - now = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - conn.execute( - "UPDATE trend_pullback_plans SET status='stopped_manual', message=?, opened_at=opened_at WHERE id=?", - ("手动结束", plan_id), - ) - save_snapshot( - conn, strategy_type=STRATEGY_TREND, source_id=plan_id, - symbol=plan["symbol"], direction=plan["direction"], result_label="手动结束", - payload=dict(plan), opened_at=plan["opened_at"] or "", - ) - on_user_initiated_close(conn, trading_day=trading_day_label()) - conn.commit() - conn.close() - return jsonify({"ok": True}) - - def check_trend_plans(app_ref): - """后台:趋势补仓与止盈。""" - conn = get_db() - init_strategy_tables(conn) - rows = conn.execute("SELECT * FROM trend_pullback_plans WHERE status='active'").fetchall() - mode = get_trading_mode(get_setting) - for plan in rows: - sym = plan["symbol"] - price = fetch_price(sym) - if not price: - continue - direction = plan["direction"] - tp = float(plan["take_profit"] or 0) - if tp > 0: - hit_tp = (direction == "long" and price >= tp) or (direction == "short" and price <= tp) - if hit_tp: - try: - execute_order( - conn, mode=mode, offset="close", symbol=sym, direction=direction, - lots=int(plan["lots_open"] or 0), price=price, settings=_settings_dict(), - ) - except ValueError: - pass - conn.execute( - "UPDATE trend_pullback_plans SET status='stopped_tp', message=? WHERE id=?", - ("程序止盈", plan["id"]), - ) - save_snapshot( - conn, strategy_type=STRATEGY_TREND, source_id=plan["id"], - symbol=sym, direction=direction, result_label="止盈", - payload=dict(plan), opened_at=plan["opened_at"] or "", - ) - send_wechat_msg(f"趋势回调止盈 {sym}") - continue - try: - grid = json.loads(plan["grid_prices_json"] or "[]") - legs = json.loads(plan["leg_amounts_json"] or "[]") - except Exception: - grid, legs = [], [] - done = int(plan["legs_done"] or 0) - if done < len(grid) and done < len(legs): - level = float(grid[done]) - if trend_dca_level_reached(direction, price, level): - add_lots = int(legs[done]) - try: - execute_order( - conn, mode=mode, offset="open", symbol=sym, direction=direction, - lots=add_lots, price=price, settings=_settings_dict(), - ) - new_open = int(plan["lots_open"] or 0) + add_lots - old_avg = float(plan["avg_entry_price"] or price) - new_avg = (old_avg * int(plan["lots_open"] or 0) + price * add_lots) / new_open if new_open else price - conn.execute( - """UPDATE trend_pullback_plans SET legs_done=?, lots_open=?, avg_entry_price=? WHERE id=?""", - (done + 1, new_open, new_avg, plan["id"]), - ) - send_wechat_msg(f"趋势回调补仓 {sym} +{add_lots}手 @档位{done+1}") - except ValueError: - pass - conn.commit() - conn.close() - - app._check_trend_plans = check_trend_plans - - def _execute_key_breakout(conn, row, bar, break_side): - """关键位箱体/收敛:5m 收盘突破后自动市价开仓。""" - from key_monitor_lib import ( - TYPE_BOX, - calc_breakout_sl_tp, - format_auto_breakout_msg, - normalize_monitor_type, - resolve_order_direction, - ) - - sym = (row.get("symbol") or "").strip() - bar_time = str(bar.get("time") or "")[:19] - monitor_type = normalize_monitor_type(row.get("monitor_type") or "") - trade_mode = row.get("trade_mode") or "顺势" - direction = resolve_order_direction(break_side, trade_mode) - trailing_be = int(row.get("trailing_be") or 0) - try: - rr = float(row.get("risk_reward") or (3 if trailing_be else 2)) - except (TypeError, ValueError): - rr = 3.0 if trailing_be else 2.0 - if trailing_be and rr < 3: - rr = 3.0 - - def _notify(ok: bool, detail: str, **kw): - send_wechat_msg(format_auto_breakout_msg( - row, - break_side=break_side, - direction=direction, - entry=kw.get("entry", 0), - sl=kw.get("sl", 0), - tp=kw.get("tp", 0), - lots=kw.get("lots", 0), - bar_time=bar_time, - ok=ok, - detail=detail, - )) - - if monitor_type == TYPE_BOX: - cfg_dir = (row.get("direction") or "").strip().lower() - if cfg_dir in ("long", "short") and direction != cfg_dir: - dir_cn = "做多" if cfg_dir == "long" else "做空" - _notify(False, f"突破方向与上方向({dir_cn})不一致", entry=0, sl=0, tp=0, lots=0) - return False, "突破方向与上方向不一致" - - try: - init_strategy_tables(conn) - mode = get_trading_mode(get_setting) - if not ctp_status(mode).get("connected"): - _notify(False, "CTP 未连接") - return False, "CTP 未连接" - if not is_trading_session(): - _notify(False, "非交易时段") - return False, "非交易时段" - - try: - entry = float(bar.get("close") or 0) - except (TypeError, ValueError): - _notify(False, "K 线收盘价无效") - return False, "K 线收盘价无效" - if entry <= 0: - _notify(False, "K 线收盘价无效") - return False, "K 线收盘价无效" - - sl, tp = calc_breakout_sl_tp( - sym=sym, direction=direction, entry=entry, bar=bar, risk_reward=rr, - ) - err = assert_can_open( - conn, - active_count=_effective_active_position_count(conn, mode), - equity=_capital(conn), - ) - if err: - _notify(False, err, entry=entry, sl=sl, tp=tp, lots=0) - return False, err - - capital = _capital(conn) - lots, lot_err = calc_lots_by_risk( - entry, sl, direction, capital, get_risk_percent(get_setting), sym, - max_margin_pct=get_max_margin_pct(get_setting), trading_mode=mode, - ) - if lot_err or not lots: - msg = lot_err or "手数计算失败" - _notify(False, msg, entry=entry, sl=sl, tp=tp, lots=0) - return False, msg - - result = execute_order( - conn, - mode=mode, - offset="open", - symbol=sym, - direction=direction, - lots=lots, - price=entry, - settings=_settings_dict(), - order_type="market", - ) - open_ts = bar_time.replace("T", " ") if bar_time else datetime.now().strftime("%Y-%m-%d %H:%M:%S") - vt_order_id = str(result.get("order_id") or "") - mid = _upsert_open_monitor( - conn, - sym=sym, - direction=direction, - lots=lots, - price=entry, - sl=sl, - tp=tp, - trailing_be=trailing_be, - open_time=open_ts, - monitor_type=monitor_type, - status="pending", - vt_order_id=vt_order_id or None, - order_price=entry, - ) - _reconcile_pending(conn, mode, capital=capital) - st_row = conn.execute( - "SELECT status FROM trade_order_monitors WHERE id=?", (mid,), - ).fetchone() - filled = st_row and (st_row["status"] or "").strip().lower() == "active" - rejected = st_row and (st_row["status"] or "").strip().lower() == "closed" - if rejected: - conn.commit() - _notify(False, "委托被柜台拒绝或撤销", entry=entry, sl=sl, tp=tp, lots=lots) - return False, "委托被拒绝" - if filled: - _sync_monitor_from_ctp( - conn, mid, sym, direction, mode, capital=capital, - ) - conn.commit() - if filled: - from db_conn import DB_PATH - from ai_worker import schedule_ai_event_analysis - from trade_notify import notify_key_breakout_open - - notify_key_breakout_open( - send_wechat=send_wechat_msg, - get_setting=get_setting, - mode_label=trading_mode_label(get_setting), - row=row, - break_side=break_side, - bar_time=bar_time, - direction=direction, - entry=entry, - sl=sl, - tp=tp, - lots=lots, - capital=capital, - order_id=vt_order_id, - schedule_ai_fn=schedule_ai_event_analysis, - db_path=DB_PATH, - ) - else: - _notify(True, "委托已提交,待成交", entry=entry, sl=sl, tp=tp, lots=lots) - _push_position_snapshot_async(fast=False) - return True, "已下单" if filled else "委托已提交" - except Exception as exc: - logger.warning("key breakout auto order: %s", exc) - _notify(False, str(exc)) - return False, str(exc) - - app._execute_key_breakout = _execute_key_breakout - - @app.route("/settings/trading", methods=["POST"]) - @login_required - def settings_trading_post(): - return redirect(url_for("settings")) - - def hook_review_mood(conn, behavior_tags: str, exit_trigger: str, exit_supplement: str): - if parse_mood_issues(behavior_tags): - on_mood_journal_freeze(conn, trading_day=trading_day_label()) - - app._risk_review_hook = hook_review_mood - - from db_conn import DB_PATH - - def _init_tables(conn): - init_strategy_tables(conn) - - threading.Thread( - target=_prime_position_snapshot, - daemon=True, - name="position-prime", - ).start() - - _pos_refresh_tick = {"n": 0} - _last_full_calibrate = {"ts": 0.0} - - def _position_worker_refresh() -> dict: - import time as _time - from ctp_trading_state import CALIBRATE_INTERVAL_SEC - - _pos_refresh_tick["n"] += 1 - mode = get_trading_mode(get_setting) - connected = bool(ctp_status(mode).get("connected")) - now = _time.time() - since_connect = now - float( - getattr(get_bridge(), "_last_connect_ok_ts", 0) or 0, - ) - if connected and since_connect < 45: - return _refresh_trading_live_snapshot(fast=True) - need_full = ( - connected - and ( - trading_state.needs_calibrate() - or (now - _last_full_calibrate["ts"]) >= CALIBRATE_INTERVAL_SEC - ) - ) - if need_full: - _last_full_calibrate["ts"] = now - return _refresh_trading_live_snapshot(fast=False) - return _refresh_trading_live_snapshot(fast=True) - - start_position_worker( - refresh_fn=_position_worker_refresh, - interval=1, - idle_interval=3, - ) - if os.getenv("QIHUO_CTP_ROLE", "client").strip().lower() == "worker": - _bootstrap_trading_runtime() - start_ctp_reconnect_worker( - get_mode_fn=lambda: get_trading_mode(get_setting), - get_setting_fn=get_setting, - ) - start_ctp_premarket_connect_worker( - get_mode_fn=lambda: get_trading_mode(get_setting), - get_setting_fn=get_setting, - ) - start_sl_tp_guard_worker( - db_path=DB_PATH, - get_mode_fn=lambda: get_trading_mode(get_setting), - init_tables_fn=_init_tables, - get_capital_fn=_capital, - get_be_tick_buffer_fn=lambda: get_trailing_be_tick_buffer(get_setting), - notify_fn=send_wechat_msg, - interval=1, - ) - start_pending_order_worker( - db_path=DB_PATH, - get_mode_fn=lambda: get_trading_mode(get_setting), - init_tables_fn=_init_tables, - get_capital_fn=_capital, - reconcile_fn=_reconcile_pending, - on_changed_fn=lambda: _push_position_snapshot_async(fast=False), - ) - - def _start_deferred_workers() -> None: - time.sleep(2) - start_recommend_worker( - db_path=DB_PATH, - get_capital_fn=_recommend_capital, - quote_fn=_main_quote, - init_tables_fn=_init_tables, - get_mode_fn=lambda: get_trading_mode(get_setting), - get_max_margin_pct_fn=lambda: get_max_margin_pct(get_setting), - get_sizing_mode_fn=lambda: get_sizing_mode(get_setting), - get_fixed_lots_fn=lambda: get_fixed_lots(get_setting), - ) - if os.getenv("QIHUO_CTP_ROLE", "client").strip().lower() == "worker": - start_ctp_fee_worker( - get_mode_fn=lambda: get_trading_mode(get_setting), - get_setting_fn=get_setting, - set_setting_fn=set_setting, - ) - from ai_worker import start_ai_worker - - start_ai_worker( - db_path=DB_PATH, - get_setting_fn=get_setting, - set_setting_fn=set_setting, - send_wechat_fn=send_wechat_msg, - ) - - threading.Thread( - target=_start_deferred_workers, - daemon=True, - name="deferred-workers", - ).start() +__all__ = ["install_trading"] diff --git a/modules/__init__.py b/modules/__init__.py new file mode 100644 index 0000000..ac285f8 --- /dev/null +++ b/modules/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) 2025-2026 马建军. All rights reserved. + +"""Qihuo feature modules package.""" diff --git a/modules/backup/__init__.py b/modules/backup/__init__.py new file mode 100644 index 0000000..f4584c4 --- /dev/null +++ b/modules/backup/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) 2025-2026 马建军. All rights reserved. + +from modules.backup.routes import register + +__all__ = ["register"] diff --git a/db_backup.py b/modules/backup/db_backup.py similarity index 96% rename from db_backup.py rename to modules/backup/db_backup.py index 960995f..9fa4397 100644 --- a/db_backup.py +++ b/modules/backup/db_backup.py @@ -1,402 +1,403 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""数据库备份:SQLite futures.db 或 PostgreSQL pg_dump,含 uploads 与一键恢复脚本。""" -from __future__ import annotations - -import json -import logging -import os -import re -import shutil -import sqlite3 -import subprocess -import tarfile -import tempfile -import threading -import time -from datetime import datetime -from pathlib import Path -from typing import Callable, Optional -from zoneinfo import ZoneInfo - -from db_conn import DB_PATH, db_backend - -logger = logging.getLogger(__name__) - -TZ = ZoneInfo("Asia/Shanghai") -BACKUP_FILENAME_RE = re.compile(r"^qihuo_backup_\d{8}_\d{6}\.tar\.gz$") -BACKUP_LAST_KEY = "backup_last_at" -BACKUP_KEEP_KEY = "backup_keep_count" -BACKUP_AUTO_KEY = "backup_auto_enabled" -BACKUP_HOUR_KEY = "backup_auto_hour" -DEFAULT_KEEP_COUNT = 30 -DEFAULT_AUTO_HOUR = 3 -CHECK_INTERVAL_SEC = 3600 -_backup_lock = threading.Lock() - -RESTORE_MD = """# qihuo 备份恢复说明 - -本压缩包由 qihuo 系统自动生成,可在另一台 Linux 服务器上恢复交易数据。 - -## 包内文件 - -| 文件/目录 | 说明 | -|-----------|------| -| `futures.db` | SQLite 主库(仅 SQLite 模式备份) | -| `postgres_dump.sql` | PostgreSQL 逻辑备份(仅 PostgreSQL 模式) | -| `uploads/` | 复盘截图与 K 线图(若备份时存在) | -| `manifest.json` | 备份元数据(含 `backend` 字段) | -| `restore.sh` | 一键恢复脚本 | - -## 快速恢复(推荐) - -1. 将本压缩包上传到目标服务器(例如 `/root/`) -2. 解压并执行恢复脚本: - -```bash -cd /root -tar -xzf qihuo_backup_YYYYMMDD_HHMMSS.tar.gz -cd qihuo_backup_YYYYMMDD_HHMMSS -chmod +x restore.sh -./restore.sh -``` - -默认恢复到 **`/root/qihuo`**(SQLite)或导入到 `.env` 中的 PostgreSQL(见 manifest)。 - -指定应用目录: - -```bash -RESTORE_DIR=/opt/qihuo ./restore.sh -``` - -3. 在新服务器部署 qihuo 代码与 Python 环境(见 `docs/POSTGRES.md` / `docs/DEPLOY.md`) -4. 配置 `.env`(`DATABASE_URL` 或 SQLite、`SECRET_KEY`、CTP 账号等) -5. 重启服务:`pm2 restart qihuo` - -## PostgreSQL 恢复 - -若 `manifest.json` 中 `"backend": "postgres"`: - -1. 确保目标机已安装 PostgreSQL,且 `.env` 中 `DATABASE_URL` 指向空库或待覆盖库 -2. 执行 `./restore.sh`(会调用 `psql` 导入 `postgres_dump.sql`) - -手工导入: - -```bash -export DATABASE_URL=postgresql://qihuo:密码@127.0.0.1:5432/qihuo -psql "$DATABASE_URL" -f postgres_dump.sql -``` - -## SQLite 手工恢复 - -```bash -mkdir -p /opt/qihuo/uploads -cp futures.db /opt/qihuo/futures.db -cp -a uploads/. /opt/qihuo/uploads/ -``` - -## 注意 - -- 恢复前请停止 qihuo 进程 -- `.env` 含敏感信息,请单独安全传输 -- 详见 `docs/POSTGRES.md` 与 `docs/BACKUP.md` -""" - - -def _app_root() -> Path: - return Path(os.path.dirname(os.path.abspath(__file__))) - - -def default_backup_dir() -> str: - env = (os.getenv("QIHUO_BACKUP_DIR") or "").strip() - if env: - return env - if os.name == "nt": - return str(_app_root() / "qihuo_backup") - return "/root/qihuo_backup" - - -def default_restore_dir() -> str: - env = (os.getenv("QIHUO_RESTORE_DIR") or "").strip() - if env: - return env - if os.name == "nt": - return str(_app_root()) - return "/root/qihuo" - - -def backup_dir() -> Path: - path = Path(default_backup_dir()) - path.mkdir(parents=True, exist_ok=True) - return path - - -def backup_in_progress() -> bool: - return _backup_lock.locked() - - -def get_backup_last_at(get_setting: Callable[[str, str], str]) -> str: - return (get_setting(BACKUP_LAST_KEY, "") or "").strip() - - -def _backup_sqlite(src_path: str, dst_path: str) -> None: - src = sqlite3.connect(src_path, timeout=30) - try: - try: - src.execute("PRAGMA wal_checkpoint(TRUNCATE)") - except sqlite3.OperationalError: - pass - dst = sqlite3.connect(dst_path) - try: - src.backup(dst) - dst.commit() - finally: - dst.close() - finally: - src.close() - - -def _backup_postgres(dst_path: str) -> None: - url = (os.getenv("DATABASE_URL") or "").strip() - if not url: - raise RuntimeError("PostgreSQL 备份需要 DATABASE_URL") - env = os.environ.copy() - proc = subprocess.run( - ["pg_dump", "--no-owner", "--no-acl", "-f", dst_path, url], - capture_output=True, - text=True, - env=env, - check=False, - ) - if proc.returncode != 0: - raise RuntimeError(f"pg_dump 失败: {proc.stderr.strip() or proc.stdout.strip()}") - - -def _write_restore_script(dest: Path, *, backend: str) -> None: - pg_block = "" - if backend == "postgres": - pg_block = """ -if [ -f "$SCRIPT_DIR/postgres_dump.sql" ]; then - if [ -z "${DATABASE_URL:-}" ]; then - if [ -f "$RESTORE_DIR/.env" ]; then - set -a - # shellcheck disable=SC1090 - source "$RESTORE_DIR/.env" - set +a - fi - fi - if [ -z "${DATABASE_URL:-}" ]; then - echo "错误: PostgreSQL 恢复需要 DATABASE_URL(环境变量或 $RESTORE_DIR/.env)" - exit 1 - fi - if ! command -v psql >/dev/null; then - echo "错误: 未找到 psql,请先安装 PostgreSQL 客户端" - exit 1 - fi - echo "导入 PostgreSQL: postgres_dump.sql" - psql "$DATABASE_URL" -f "$SCRIPT_DIR/postgres_dump.sql" - echo "PostgreSQL 导入完成" -fi -""" - script = f"""#!/bin/bash -set -euo pipefail -RESTORE_DIR="${{RESTORE_DIR:-{default_restore_dir()}}}" -SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" -mkdir -p "$RESTORE_DIR/uploads" -{pg_block} -if [ -f "$SCRIPT_DIR/futures.db" ]; then - cp -f "$SCRIPT_DIR/futures.db" "$RESTORE_DIR/futures.db" - echo "已复制 futures.db -> $RESTORE_DIR/futures.db" -fi -if [ -d "$SCRIPT_DIR/uploads" ]; then - cp -a "$SCRIPT_DIR/uploads/." "$RESTORE_DIR/uploads/" - echo "已复制 uploads -> $RESTORE_DIR/uploads/" -fi -echo "" -echo "恢复完成。目标目录: $RESTORE_DIR" -echo "下一步: 确认 .env、pm2 restart qihuo" -echo "详见 RESTORE.md 与 docs/POSTGRES.md" -""" - dest.write_text(script, encoding="utf-8") - - -def create_backup(*, include_uploads: bool = True) -> tuple[str, str]: - """创建 tar.gz 备份,返回 (文件名, 说明)。""" - backend = db_backend() - if backend == "sqlite" and not os.path.isfile(DB_PATH): - raise FileNotFoundError(f"数据库不存在: {DB_PATH}") - if backend == "postgres" and not (os.getenv("DATABASE_URL") or "").strip(): - raise RuntimeError("PostgreSQL 模式需要 DATABASE_URL") - - with _backup_lock: - stamp = datetime.now(TZ).strftime("%Y%m%d_%H%M%S") - folder_name = f"qihuo_backup_{stamp}" - filename = f"{folder_name}.tar.gz" - out_path = backup_dir() / filename - app_root = _app_root() - upload_src = app_root / "uploads" - - with tempfile.TemporaryDirectory(prefix="qihuo_bak_") as tmp: - work = Path(tmp) / folder_name - work.mkdir() - if backend == "postgres": - _backup_postgres(str(work / "postgres_dump.sql")) - else: - _backup_sqlite(DB_PATH, str(work / "futures.db")) - - if include_uploads and upload_src.is_dir(): - shutil.copytree(upload_src, work / "uploads", dirs_exist_ok=True) - - manifest = { - "app": "qihuo", - "backend": backend, - "created_at": datetime.now(TZ).isoformat(timespec="seconds"), - "db_path": DB_PATH if backend == "sqlite" else (os.getenv("DATABASE_URL") or ""), - "includes_uploads": include_uploads and upload_src.is_dir(), - "default_restore_dir": default_restore_dir(), - "files": sorted(p.name for p in work.iterdir()), - } - (work / "manifest.json").write_text( - json.dumps(manifest, ensure_ascii=False, indent=2), - encoding="utf-8", - ) - (work / "RESTORE.md").write_text(RESTORE_MD, encoding="utf-8") - _write_restore_script(work / "restore.sh", backend=backend) - - with tarfile.open(out_path, "w:gz") as tar: - tar.add(work, arcname=folder_name) - - size_mb = out_path.stat().st_size / (1024 * 1024) - label = "PostgreSQL" if backend == "postgres" else "SQLite" - return filename, f"备份已生成 {filename}({label},{size_mb:.2f} MB)" - - -def list_backups() -> list[dict]: - items: list[dict] = [] - for path in sorted(backup_dir().glob("qihuo_backup_*.tar.gz"), reverse=True): - if not BACKUP_FILENAME_RE.match(path.name): - continue - stat = path.stat() - items.append( - { - "name": path.name, - "size": stat.st_size, - "size_mb": round(stat.st_size / (1024 * 1024), 2), - "mtime": datetime.fromtimestamp(stat.st_mtime, TZ).isoformat(timespec="seconds"), - } - ) - return items - - -def resolve_backup_file(filename: str) -> Path: - name = (filename or "").strip() - if not BACKUP_FILENAME_RE.match(name): - raise ValueError("无效的备份文件名") - path = (backup_dir() / name).resolve() - root = backup_dir().resolve() - if not str(path).startswith(str(root) + os.sep) and path != root: - raise ValueError("无效的备份路径") - if not path.is_file(): - raise FileNotFoundError("备份文件不存在") - return path - - -def prune_old_backups(keep: int) -> int: - keep_n = max(1, int(keep or DEFAULT_KEEP_COUNT)) - files = list_backups() - removed = 0 - for item in files[keep_n:]: - try: - resolve_backup_file(item["name"]).unlink() - removed += 1 - except Exception as exc: - logger.warning("prune backup %s: %s", item["name"], exc) - return removed - - -def run_backup_job( - *, - get_setting: Callable[[str, str], str], - set_setting: Callable[[str, str], None], - include_uploads: bool = True, -) -> tuple[str, str]: - keep = DEFAULT_KEEP_COUNT - try: - keep = max(5, min(200, int(get_setting(BACKUP_KEEP_KEY, str(DEFAULT_KEEP_COUNT)) or DEFAULT_KEEP_COUNT))) - except ValueError: - pass - filename, msg = create_backup(include_uploads=include_uploads) - set_setting(BACKUP_LAST_KEY, datetime.now(TZ).isoformat(timespec="seconds")) - removed = prune_old_backups(keep) - if removed: - msg = f"{msg},已清理 {removed} 个旧备份" - return filename, msg - - -def schedule_backup( - *, - get_setting: Callable[[str, str], str], - set_setting: Callable[[str, str], None], - include_uploads: bool = True, -) -> tuple[bool, str]: - if _backup_lock.locked(): - return False, "备份进行中,请稍后再试" - - def _run() -> None: - try: - run_backup_job( - get_setting=get_setting, - set_setting=set_setting, - include_uploads=include_uploads, - ) - except Exception as exc: - logger.exception("backup failed: %s", exc) - - threading.Thread(target=_run, daemon=True, name="qihuo-backup-run").start() - return True, "已在后台开始备份,请稍后刷新本页查看" - - -def _should_auto_backup(get_setting: Callable[[str, str], str]) -> bool: - if (get_setting(BACKUP_AUTO_KEY, "1") or "1").strip() not in ("1", "true", "yes"): - return False - try: - hour = int(get_setting(BACKUP_HOUR_KEY, str(DEFAULT_AUTO_HOUR)) or DEFAULT_AUTO_HOUR) - except ValueError: - hour = DEFAULT_AUTO_HOUR - hour = max(0, min(23, hour)) - now = datetime.now(TZ) - if now.hour != hour: - return False - last = get_backup_last_at(get_setting) - if last and last[:10] == now.date().isoformat(): - return False - return True - - -def start_backup_worker( - *, - get_setting_fn: Callable[[str, str], str], - set_setting_fn: Callable[[str, str], None], - interval: int = CHECK_INTERVAL_SEC, -) -> None: - """后台线程:按设定小时每日自动备份。""" - - def _loop() -> None: - time.sleep(30) - while True: - try: - if _should_auto_backup(get_setting_fn): - filename, msg = run_backup_job( - get_setting=get_setting_fn, - set_setting=set_setting_fn, - include_uploads=True, - ) - logger.info("auto backup: %s — %s", filename, msg) - except Exception as exc: - logger.warning("backup worker: %s", exc) - time.sleep(max(300, interval)) - - threading.Thread(target=_loop, daemon=True, name="qihuo-backup-worker").start() +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""数据库备份:SQLite futures.db 或 PostgreSQL pg_dump,含 uploads 与一键恢复脚本。""" +from __future__ import annotations + +import json +import logging +import os +import re +import shutil +import sqlite3 +import subprocess +import tarfile +import tempfile +import threading +import time +from datetime import datetime +from pathlib import Path +from typing import Callable, Optional +from zoneinfo import ZoneInfo + +from modules.core.db_conn import DB_PATH, db_backend + +logger = logging.getLogger(__name__) + +TZ = ZoneInfo("Asia/Shanghai") +BACKUP_FILENAME_RE = re.compile(r"^qihuo_backup_\d{8}_\d{6}\.tar\.gz$") +BACKUP_LAST_KEY = "backup_last_at" +BACKUP_KEEP_KEY = "backup_keep_count" +BACKUP_AUTO_KEY = "backup_auto_enabled" +BACKUP_HOUR_KEY = "backup_auto_hour" +DEFAULT_KEEP_COUNT = 30 +DEFAULT_AUTO_HOUR = 3 +CHECK_INTERVAL_SEC = 3600 +_backup_lock = threading.Lock() + +RESTORE_MD = """# qihuo 备份恢复说明 + +本压缩包由 qihuo 系统自动生成,可在另一台 Linux 服务器上恢复交易数据。 + +## 包内文件 + +| 文件/目录 | 说明 | +|-----------|------| +| `futures.db` | SQLite 主库(仅 SQLite 模式备份) | +| `postgres_dump.sql` | PostgreSQL 逻辑备份(仅 PostgreSQL 模式) | +| `uploads/` | 复盘截图与 K 线图(若备份时存在) | +| `manifest.json` | 备份元数据(含 `backend` 字段) | +| `restore.sh` | 一键恢复脚本 | + +## 快速恢复(推荐) + +1. 将本压缩包上传到目标服务器(例如 `/root/`) +2. 解压并执行恢复脚本: + +```bash +cd /root +tar -xzf qihuo_backup_YYYYMMDD_HHMMSS.tar.gz +cd qihuo_backup_YYYYMMDD_HHMMSS +chmod +x restore.sh +./restore.sh +``` + +默认恢复到 **`/root/qihuo`**(SQLite)或导入到 `.env` 中的 PostgreSQL(见 manifest)。 + +指定应用目录: + +```bash +RESTORE_DIR=/opt/qihuo ./restore.sh +``` + +3. 在新服务器部署 qihuo 代码与 Python 环境(见 `docs/POSTGRES.md` / `docs/DEPLOY.md`) +4. 配置 `.env`(`DATABASE_URL` 或 SQLite、`SECRET_KEY`、CTP 账号等) +5. 重启服务:`pm2 restart qihuo` + +## PostgreSQL 恢复 + +若 `manifest.json` 中 `"backend": "postgres"`: + +1. 确保目标机已安装 PostgreSQL,且 `.env` 中 `DATABASE_URL` 指向空库或待覆盖库 +2. 执行 `./restore.sh`(会调用 `psql` 导入 `postgres_dump.sql`) + +手工导入: + +```bash +export DATABASE_URL=postgresql://qihuo:密码@127.0.0.1:5432/qihuo +psql "$DATABASE_URL" -f postgres_dump.sql +``` + +## SQLite 手工恢复 + +```bash +mkdir -p /opt/qihuo/uploads +cp futures.db /opt/qihuo/futures.db +cp -a uploads/. /opt/qihuo/uploads/ +``` + +## 注意 + +- 恢复前请停止 qihuo 进程 +- `.env` 含敏感信息,请单独安全传输 +- 详见 `docs/POSTGRES.md` 与 `docs/BACKUP.md` +""" + + +def _app_root() -> Path: + from modules.core.paths import ROOT + return ROOT + + +def default_backup_dir() -> str: + env = (os.getenv("QIHUO_BACKUP_DIR") or "").strip() + if env: + return env + if os.name == "nt": + return str(_app_root() / "qihuo_backup") + return "/root/qihuo_backup" + + +def default_restore_dir() -> str: + env = (os.getenv("QIHUO_RESTORE_DIR") or "").strip() + if env: + return env + if os.name == "nt": + return str(_app_root()) + return "/root/qihuo" + + +def backup_dir() -> Path: + path = Path(default_backup_dir()) + path.mkdir(parents=True, exist_ok=True) + return path + + +def backup_in_progress() -> bool: + return _backup_lock.locked() + + +def get_backup_last_at(get_setting: Callable[[str, str], str]) -> str: + return (get_setting(BACKUP_LAST_KEY, "") or "").strip() + + +def _backup_sqlite(src_path: str, dst_path: str) -> None: + src = sqlite3.connect(src_path, timeout=30) + try: + try: + src.execute("PRAGMA wal_checkpoint(TRUNCATE)") + except sqlite3.OperationalError: + pass + dst = sqlite3.connect(dst_path) + try: + src.backup(dst) + dst.commit() + finally: + dst.close() + finally: + src.close() + + +def _backup_postgres(dst_path: str) -> None: + url = (os.getenv("DATABASE_URL") or "").strip() + if not url: + raise RuntimeError("PostgreSQL 备份需要 DATABASE_URL") + env = os.environ.copy() + proc = subprocess.run( + ["pg_dump", "--no-owner", "--no-acl", "-f", dst_path, url], + capture_output=True, + text=True, + env=env, + check=False, + ) + if proc.returncode != 0: + raise RuntimeError(f"pg_dump 失败: {proc.stderr.strip() or proc.stdout.strip()}") + + +def _write_restore_script(dest: Path, *, backend: str) -> None: + pg_block = "" + if backend == "postgres": + pg_block = """ +if [ -f "$SCRIPT_DIR/postgres_dump.sql" ]; then + if [ -z "${DATABASE_URL:-}" ]; then + if [ -f "$RESTORE_DIR/.env" ]; then + set -a + # shellcheck disable=SC1090 + source "$RESTORE_DIR/.env" + set +a + fi + fi + if [ -z "${DATABASE_URL:-}" ]; then + echo "错误: PostgreSQL 恢复需要 DATABASE_URL(环境变量或 $RESTORE_DIR/.env)" + exit 1 + fi + if ! command -v psql >/dev/null; then + echo "错误: 未找到 psql,请先安装 PostgreSQL 客户端" + exit 1 + fi + echo "导入 PostgreSQL: postgres_dump.sql" + psql "$DATABASE_URL" -f "$SCRIPT_DIR/postgres_dump.sql" + echo "PostgreSQL 导入完成" +fi +""" + script = f"""#!/bin/bash +set -euo pipefail +RESTORE_DIR="${{RESTORE_DIR:-{default_restore_dir()}}}" +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +mkdir -p "$RESTORE_DIR/uploads" +{pg_block} +if [ -f "$SCRIPT_DIR/futures.db" ]; then + cp -f "$SCRIPT_DIR/futures.db" "$RESTORE_DIR/futures.db" + echo "已复制 futures.db -> $RESTORE_DIR/futures.db" +fi +if [ -d "$SCRIPT_DIR/uploads" ]; then + cp -a "$SCRIPT_DIR/uploads/." "$RESTORE_DIR/uploads/" + echo "已复制 uploads -> $RESTORE_DIR/uploads/" +fi +echo "" +echo "恢复完成。目标目录: $RESTORE_DIR" +echo "下一步: 确认 .env、pm2 restart qihuo" +echo "详见 RESTORE.md 与 docs/POSTGRES.md" +""" + dest.write_text(script, encoding="utf-8") + + +def create_backup(*, include_uploads: bool = True) -> tuple[str, str]: + """创建 tar.gz 备份,返回 (文件名, 说明)。""" + backend = db_backend() + if backend == "sqlite" and not os.path.isfile(DB_PATH): + raise FileNotFoundError(f"数据库不存在: {DB_PATH}") + if backend == "postgres" and not (os.getenv("DATABASE_URL") or "").strip(): + raise RuntimeError("PostgreSQL 模式需要 DATABASE_URL") + + with _backup_lock: + stamp = datetime.now(TZ).strftime("%Y%m%d_%H%M%S") + folder_name = f"qihuo_backup_{stamp}" + filename = f"{folder_name}.tar.gz" + out_path = backup_dir() / filename + app_root = _app_root() + upload_src = app_root / "uploads" + + with tempfile.TemporaryDirectory(prefix="qihuo_bak_") as tmp: + work = Path(tmp) / folder_name + work.mkdir() + if backend == "postgres": + _backup_postgres(str(work / "postgres_dump.sql")) + else: + _backup_sqlite(DB_PATH, str(work / "futures.db")) + + if include_uploads and upload_src.is_dir(): + shutil.copytree(upload_src, work / "uploads", dirs_exist_ok=True) + + manifest = { + "app": "qihuo", + "backend": backend, + "created_at": datetime.now(TZ).isoformat(timespec="seconds"), + "db_path": DB_PATH if backend == "sqlite" else (os.getenv("DATABASE_URL") or ""), + "includes_uploads": include_uploads and upload_src.is_dir(), + "default_restore_dir": default_restore_dir(), + "files": sorted(p.name for p in work.iterdir()), + } + (work / "manifest.json").write_text( + json.dumps(manifest, ensure_ascii=False, indent=2), + encoding="utf-8", + ) + (work / "RESTORE.md").write_text(RESTORE_MD, encoding="utf-8") + _write_restore_script(work / "restore.sh", backend=backend) + + with tarfile.open(out_path, "w:gz") as tar: + tar.add(work, arcname=folder_name) + + size_mb = out_path.stat().st_size / (1024 * 1024) + label = "PostgreSQL" if backend == "postgres" else "SQLite" + return filename, f"备份已生成 {filename}({label},{size_mb:.2f} MB)" + + +def list_backups() -> list[dict]: + items: list[dict] = [] + for path in sorted(backup_dir().glob("qihuo_backup_*.tar.gz"), reverse=True): + if not BACKUP_FILENAME_RE.match(path.name): + continue + stat = path.stat() + items.append( + { + "name": path.name, + "size": stat.st_size, + "size_mb": round(stat.st_size / (1024 * 1024), 2), + "mtime": datetime.fromtimestamp(stat.st_mtime, TZ).isoformat(timespec="seconds"), + } + ) + return items + + +def resolve_backup_file(filename: str) -> Path: + name = (filename or "").strip() + if not BACKUP_FILENAME_RE.match(name): + raise ValueError("无效的备份文件名") + path = (backup_dir() / name).resolve() + root = backup_dir().resolve() + if not str(path).startswith(str(root) + os.sep) and path != root: + raise ValueError("无效的备份路径") + if not path.is_file(): + raise FileNotFoundError("备份文件不存在") + return path + + +def prune_old_backups(keep: int) -> int: + keep_n = max(1, int(keep or DEFAULT_KEEP_COUNT)) + files = list_backups() + removed = 0 + for item in files[keep_n:]: + try: + resolve_backup_file(item["name"]).unlink() + removed += 1 + except Exception as exc: + logger.warning("prune backup %s: %s", item["name"], exc) + return removed + + +def run_backup_job( + *, + get_setting: Callable[[str, str], str], + set_setting: Callable[[str, str], None], + include_uploads: bool = True, +) -> tuple[str, str]: + keep = DEFAULT_KEEP_COUNT + try: + keep = max(5, min(200, int(get_setting(BACKUP_KEEP_KEY, str(DEFAULT_KEEP_COUNT)) or DEFAULT_KEEP_COUNT))) + except ValueError: + pass + filename, msg = create_backup(include_uploads=include_uploads) + set_setting(BACKUP_LAST_KEY, datetime.now(TZ).isoformat(timespec="seconds")) + removed = prune_old_backups(keep) + if removed: + msg = f"{msg},已清理 {removed} 个旧备份" + return filename, msg + + +def schedule_backup( + *, + get_setting: Callable[[str, str], str], + set_setting: Callable[[str, str], None], + include_uploads: bool = True, +) -> tuple[bool, str]: + if _backup_lock.locked(): + return False, "备份进行中,请稍后再试" + + def _run() -> None: + try: + run_backup_job( + get_setting=get_setting, + set_setting=set_setting, + include_uploads=include_uploads, + ) + except Exception as exc: + logger.exception("backup failed: %s", exc) + + threading.Thread(target=_run, daemon=True, name="qihuo-backup-run").start() + return True, "已在后台开始备份,请稍后刷新本页查看" + + +def _should_auto_backup(get_setting: Callable[[str, str], str]) -> bool: + if (get_setting(BACKUP_AUTO_KEY, "1") or "1").strip() not in ("1", "true", "yes"): + return False + try: + hour = int(get_setting(BACKUP_HOUR_KEY, str(DEFAULT_AUTO_HOUR)) or DEFAULT_AUTO_HOUR) + except ValueError: + hour = DEFAULT_AUTO_HOUR + hour = max(0, min(23, hour)) + now = datetime.now(TZ) + if now.hour != hour: + return False + last = get_backup_last_at(get_setting) + if last and last[:10] == now.date().isoformat(): + return False + return True + + +def start_backup_worker( + *, + get_setting_fn: Callable[[str, str], str], + set_setting_fn: Callable[[str, str], None], + interval: int = CHECK_INTERVAL_SEC, +) -> None: + """后台线程:按设定小时每日自动备份。""" + + def _loop() -> None: + time.sleep(30) + while True: + try: + if _should_auto_backup(get_setting_fn): + filename, msg = run_backup_job( + get_setting=get_setting_fn, + set_setting=set_setting_fn, + include_uploads=True, + ) + logger.info("auto backup: %s — %s", filename, msg) + except Exception as exc: + logger.warning("backup worker: %s", exc) + time.sleep(max(300, interval)) + + threading.Thread(target=_loop, daemon=True, name="qihuo-backup-worker").start() diff --git a/modules/backup/routes.py b/modules/backup/routes.py new file mode 100644 index 0000000..b7e3628 --- /dev/null +++ b/modules/backup/routes.py @@ -0,0 +1,78 @@ +# Copyright (c) 2025-2026 马建军. All rights reserved. +"""HTTP routes for backup module.""" + +from __future__ import annotations + +from datetime import date, datetime + +from flask import ( + Response, + flash, + jsonify, + redirect, + render_template, + request, + send_file, + session, + stream_with_context, + url_for, +) + + +def register(deps) -> None: + app = deps.app + login_required = deps.login_required + require_nav = deps.require_nav + get_db = deps.get_db + get_setting = deps.get_setting + set_setting = deps.set_setting + fetch_price = deps.fetch_price + send_wechat_msg = deps.send_wechat_msg + touch_stats_cache = deps.touch_stats_cache + get_stats_data = deps.get_stats_data + build_market_quote_payload = deps.build_market_quote_payload + today_str = deps.today_str + expire_old_plans = deps.expire_old_plans + TZ = deps.tz + DB_PATH = deps.db_path + UPLOAD_DIR = deps.upload_dir + OPEN_TYPES = deps.open_types + EXIT_TRIGGERS = deps.exit_triggers + BEHAVIOR_TAGS = deps.behavior_tags + KLINE_PERIODS = deps.kline_periods + KLINE_CUTOFFS = deps.kline_cutoffs + calc_holding_duration = deps.calc_holding_duration + holding_to_minutes = deps.holding_to_minutes + classify_close_result = deps.classify_close_result + calc_rr_ratio = deps.calc_rr_ratio + calc_theoretical_pnl = deps.calc_theoretical_pnl + parse_review_date_filter = deps.parse_review_date_filter + _trading_mode = deps.trading_mode + _ua_is_phone = deps.ua_is_phone + _static_asset_v = deps.static_asset_v + + from modules.backup.db_backup import list_backups, resolve_backup_file + + @app.route("/api/backup/list") + @login_required + def api_backup_list(): + return jsonify( + { + "dir": str(backup_dir()), + "last_at": get_backup_last_at(get_setting), + "running": backup_in_progress(), + "items": list_backups(), + } + ) + + + @app.route("/api/backup/download/") + @login_required + def api_backup_download(filename): + from flask import send_file + + try: + path = resolve_backup_file(filename) + except (ValueError, FileNotFoundError) as exc: + return jsonify({"error": str(exc)}), 404 + return send_file(path, as_attachment=True, download_name=path.name) diff --git a/modules/core/__init__.py b/modules/core/__init__.py new file mode 100644 index 0000000..6274207 --- /dev/null +++ b/modules/core/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2025-2026 马建军. All rights reserved. + +"""Core bootstrap and shared types.""" + +from modules.core.bootstrap import register_all_modules, start_module_workers +from modules.core.deps import AppDeps + +__all__ = ["AppDeps", "register_all_modules", "start_module_workers"] diff --git a/modules/core/bootstrap.py b/modules/core/bootstrap.py new file mode 100644 index 0000000..511582a --- /dev/null +++ b/modules/core/bootstrap.py @@ -0,0 +1,55 @@ +# Copyright (c) 2025-2026 马建军. All rights reserved. + +"""Application module registry and startup wiring.""" + +from __future__ import annotations + +import importlib +import logging +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from modules.core.deps import AppDeps + +logger = logging.getLogger(__name__) + +# Registration order: core services first, trading last among features. +_MODULE_NAMES = ( + "modules.web", + "modules.market", + "modules.keys", + "modules.plans", + "modules.notify", + "modules.records", + "modules.stats", + "modules.fees", + "modules.backup", + "modules.settings", + "modules.risk", + "modules.strategy", + "modules.ctp", + "modules.trading", +) + + +def register_all_modules(deps: "AppDeps") -> None: + for name in _MODULE_NAMES: + mod = importlib.import_module(name) + register = getattr(mod, "register", None) + if not callable(register): + logger.warning("module %s has no register()", name) + continue + register(deps) + logger.debug("registered %s", name) + + +def start_module_workers(deps: "AppDeps") -> None: + """Background threads owned by feature modules.""" + from modules.ctp.vnpy_bridge import try_init_vnpy + + try_init_vnpy({}) + for name in ("modules.market",): + mod = importlib.import_module(name) + start = getattr(mod, "start_workers", None) + if callable(start): + start(deps) diff --git a/contract_profile.py b/modules/core/contract_profile.py similarity index 95% rename from contract_profile.py rename to modules/core/contract_profile.py index ccdfd3a..3fd98c6 100644 --- a/contract_profile.py +++ b/modules/core/contract_profile.py @@ -1,280 +1,280 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""期货合约简介:东方财富 / 新浪 / AKShare。""" -import logging -import re -from typing import Any, Optional - -import requests - -from contract_specs import get_contract_spec -from symbols import ths_to_codes, search_symbols - -logger = logging.getLogger(__name__) - -EM_LABEL_MAP = { - "vname": "交易品种", - "vcode": "交易代码", - "jydw": "交易单位", - "bjdw": "报价单位", - "market": "交易所", - "zxbddw": "最小变动价位", - "zdtbfd": "涨跌停幅度", - "hyjgyf": "合约月份", - "jysj": "交易时间", - "zhjyr": "最后交易日", - "zhjgr": "交割日期", - "jgpj": "交割品级", - "zcjybzj": "最低交易保证金", - "jgfs": "交割方式", - "jgdd": "交割地点", - "ssrq": "上市日期", -} - -DISPLAY_ORDER = [ - "交易品种", - "交易代码", - "交易单位", - "报价单位", - "最小变动价位", - "最低交易保证金", - "涨跌停幅度", - "合约月份", - "交易时间", - "最后交易日", - "交割日期", - "交割方式", - "交割地点", - "交割品级", - "上市日期", - "交易所", -] - -SKIP_ITEMS = {"", "-", "None", "nan", "null"} - - -def _normalize_ths_code(raw: str) -> Optional[str]: - code = (raw or "").strip() - if not code: - return None - # 已是完整合约 - if re.match(r"^[A-Za-z]+\d{3,4}$", code): - return code - # 仅品种字母时尝试匹配主力 - results = search_symbols(code) - if results: - return results[0].get("ths_code") or code - codes = ths_to_codes(code) - if codes: - return codes["ths_code"] - return code - - -def _to_sina_quote_symbol(ths_code: str) -> str: - m = re.match(r"^([A-Za-z]+)(\d+)$", ths_code.strip()) - if not m: - return ths_code.upper() - return m.group(1).upper() + m.group(2) - - -def _to_em_page_symbol(ths_code: str) -> str: - return ths_code.strip().lower() + "F" - - -def _clean_value(val: Any) -> str: - if val is None: - return "" - s = str(val).strip() - if s in SKIP_ITEMS: - return "" - return s - - -def _rows_from_dict(data: dict[str, str]) -> list[dict]: - rows: list[dict] = [] - seen: set[str] = set() - for label in DISPLAY_ORDER: - val = _clean_value(data.get(label)) - if not val: - continue - hint = _clean_value(data.get(f"{label}_hint")) - rows.append({"label": label, "value": val, "hint": hint}) - seen.add(label) - for label, val in data.items(): - if label.endswith("_hint") or label in seen: - continue - val = _clean_value(val) - if val: - rows.append({"label": label, "value": val, "hint": ""}) - return rows - - -def _add_computed_hints(ths_code: str, data: dict[str, str]) -> None: - spec = get_contract_spec(ths_code) - mult = spec.get("mult") or 0 - tick_raw = data.get("最小变动价位", "") - m = re.search(r"([\d.]+)", tick_raw) - if m and mult: - tick = float(m.group(1)) - data["最小变动价位_hint"] = f"一手合约最小波动{round(tick * mult, 2)}元" - - -def _fetch_em_direct(em_symbol: str) -> dict[str, str]: - page_url = f"https://quote.eastmoney.com/qihuo/{em_symbol}.html" - r = requests.get(page_url, timeout=12) - r.encoding = r.apparent_encoding or "utf-8" - inner = None - for pat in [ - r"futures_([A-Za-z0-9_]+)", - r"#(futures_[A-Za-z0-9_]+)", - r"/(futures_[A-Za-z0-9_]+)", - ]: - m = re.search(pat, r.text) - if m: - inner = m.group(1).replace("futures_", "") - break - if not inner: - raise ValueError("无法解析东方财富合约标识") - - info_url = f"https://futsse-static.eastmoney.com/redis?msgid={inner}_info" - r2 = requests.get(info_url, timeout=12) - payload = r2.json() - if not isinstance(payload, dict): - raise ValueError("东方财富返回数据无效") - - out: dict[str, str] = {} - for key, label in EM_LABEL_MAP.items(): - val = _clean_value(payload.get(key)) - if val: - out[label] = val - if not out: - raise ValueError("东方财富合约字段为空") - return out - - -def _fetch_em_akshare(em_symbol: str) -> dict[str, str]: - import akshare as ak - - df = ak.futures_contract_detail_em(symbol=em_symbol) - out: dict[str, str] = {} - for _, row in df.iterrows(): - label = _clean_value(row.get("item")) - val = _clean_value(row.get("value")) - if label and val: - if label == "跌涨停板幅度": - label = "涨跌停幅度" - if label == "最后交割日": - label = "交割日期" - if label == "上市交易所": - label = "交易所" - if label == "合约交割月份": - label = "合约月份" - if label == "最初交易保证金": - label = "最低交易保证金" - if label == "最小变动价格": - label = "最小变动价位" - out[label] = val - return out - - -def _fetch_sina_direct(sina_symbol: str) -> dict[str, str]: - from io import StringIO - - import pandas as pd - - url = f"https://finance.sina.com.cn/futures/quotes/{sina_symbol}.shtml" - r = requests.get(url, timeout=12, headers={"Referer": "https://finance.sina.com.cn/"}) - r.encoding = "gb2312" - tables = pd.read_html(StringIO(r.text)) - if len(tables) < 7: - raise ValueError("新浪页面结构变化") - temp_df = tables[6] - parts = [] - for ncol in [slice(0, 2), slice(2, 4), slice(4, None)]: - part = temp_df.iloc[:, ncol] - part.columns = ["item", "value"] - parts.append(part) - merged = pd.concat(parts, axis=0, ignore_index=True) - out: dict[str, str] = {} - for _, row in merged.iterrows(): - label = _clean_value(row["item"]) - val = _clean_value(row["value"]) - if not label or not val or len(label) > 80 or "发帖" in val: - continue - out[label] = val - return out - - -def _fetch_sina_akshare(sina_symbol: str) -> dict[str, str]: - import akshare as ak - - df = ak.futures_contract_detail(symbol=sina_symbol) - out: dict[str, str] = {} - for _, row in df.iterrows(): - label = _clean_value(row.get("item")) - val = _clean_value(row.get("value")) - if label and val and "发帖" not in val: - out[label] = val - return out - - -def _merge_profile(primary: dict[str, str], secondary: dict[str, str]) -> dict[str, str]: - merged = dict(secondary) - merged.update(primary) - return merged - - -def get_contract_profile(raw_symbol: str) -> Optional[dict]: - ths_code = _normalize_ths_code(raw_symbol) - if not ths_code: - return None - - em_symbol = _to_em_page_symbol(ths_code) - sina_symbol = _to_sina_quote_symbol(ths_code) - data: dict[str, str] = {} - source_parts: list[str] = [] - - # 东方财富(字段与看盘软件简介接近) - try: - try: - data = _fetch_em_akshare(em_symbol) - source_parts.append("东方财富") - except ImportError: - data = _fetch_em_direct(em_symbol) - source_parts.append("东方财富") - except Exception as exc: - logger.warning("eastmoney profile failed %s: %s", em_symbol, exc) - - # 新浪补充交割地点、上市日期等 - sina_data: dict[str, str] = {} - try: - try: - sina_data = _fetch_sina_akshare(sina_symbol) - except ImportError: - sina_data = _fetch_sina_direct(sina_symbol) - if sina_data: - source_parts.append("新浪") - except Exception as exc: - logger.warning("sina profile failed %s: %s", sina_symbol, exc) - - if sina_data: - data = _merge_profile(data, sina_data) - - if not data: - return None - - _add_computed_hints(ths_code, data) - rows = _rows_from_dict(data) - if not rows: - return None - - return { - "ths_code": ths_code, - "symbol_name": data.get("交易品种", ""), - "exchange": data.get("交易所", ""), - "rows": rows, - "source": " + ".join(source_parts) if source_parts else "未知", - } +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""期货合约简介:东方财富 / 新浪 / AKShare。""" +import logging +import re +from typing import Any, Optional + +import requests + +from modules.core.contract_specs import get_contract_spec +from modules.core.symbols import ths_to_codes, search_symbols + +logger = logging.getLogger(__name__) + +EM_LABEL_MAP = { + "vname": "交易品种", + "vcode": "交易代码", + "jydw": "交易单位", + "bjdw": "报价单位", + "market": "交易所", + "zxbddw": "最小变动价位", + "zdtbfd": "涨跌停幅度", + "hyjgyf": "合约月份", + "jysj": "交易时间", + "zhjyr": "最后交易日", + "zhjgr": "交割日期", + "jgpj": "交割品级", + "zcjybzj": "最低交易保证金", + "jgfs": "交割方式", + "jgdd": "交割地点", + "ssrq": "上市日期", +} + +DISPLAY_ORDER = [ + "交易品种", + "交易代码", + "交易单位", + "报价单位", + "最小变动价位", + "最低交易保证金", + "涨跌停幅度", + "合约月份", + "交易时间", + "最后交易日", + "交割日期", + "交割方式", + "交割地点", + "交割品级", + "上市日期", + "交易所", +] + +SKIP_ITEMS = {"", "-", "None", "nan", "null"} + + +def _normalize_ths_code(raw: str) -> Optional[str]: + code = (raw or "").strip() + if not code: + return None + # 已是完整合约 + if re.match(r"^[A-Za-z]+\d{3,4}$", code): + return code + # 仅品种字母时尝试匹配主力 + results = search_symbols(code) + if results: + return results[0].get("ths_code") or code + codes = ths_to_codes(code) + if codes: + return codes["ths_code"] + return code + + +def _to_sina_quote_symbol(ths_code: str) -> str: + m = re.match(r"^([A-Za-z]+)(\d+)$", ths_code.strip()) + if not m: + return ths_code.upper() + return m.group(1).upper() + m.group(2) + + +def _to_em_page_symbol(ths_code: str) -> str: + return ths_code.strip().lower() + "F" + + +def _clean_value(val: Any) -> str: + if val is None: + return "" + s = str(val).strip() + if s in SKIP_ITEMS: + return "" + return s + + +def _rows_from_dict(data: dict[str, str]) -> list[dict]: + rows: list[dict] = [] + seen: set[str] = set() + for label in DISPLAY_ORDER: + val = _clean_value(data.get(label)) + if not val: + continue + hint = _clean_value(data.get(f"{label}_hint")) + rows.append({"label": label, "value": val, "hint": hint}) + seen.add(label) + for label, val in data.items(): + if label.endswith("_hint") or label in seen: + continue + val = _clean_value(val) + if val: + rows.append({"label": label, "value": val, "hint": ""}) + return rows + + +def _add_computed_hints(ths_code: str, data: dict[str, str]) -> None: + spec = get_contract_spec(ths_code) + mult = spec.get("mult") or 0 + tick_raw = data.get("最小变动价位", "") + m = re.search(r"([\d.]+)", tick_raw) + if m and mult: + tick = float(m.group(1)) + data["最小变动价位_hint"] = f"一手合约最小波动{round(tick * mult, 2)}元" + + +def _fetch_em_direct(em_symbol: str) -> dict[str, str]: + page_url = f"https://quote.eastmoney.com/qihuo/{em_symbol}.html" + r = requests.get(page_url, timeout=12) + r.encoding = r.apparent_encoding or "utf-8" + inner = None + for pat in [ + r"futures_([A-Za-z0-9_]+)", + r"#(futures_[A-Za-z0-9_]+)", + r"/(futures_[A-Za-z0-9_]+)", + ]: + m = re.search(pat, r.text) + if m: + inner = m.group(1).replace("futures_", "") + break + if not inner: + raise ValueError("无法解析东方财富合约标识") + + info_url = f"https://futsse-static.eastmoney.com/redis?msgid={inner}_info" + r2 = requests.get(info_url, timeout=12) + payload = r2.json() + if not isinstance(payload, dict): + raise ValueError("东方财富返回数据无效") + + out: dict[str, str] = {} + for key, label in EM_LABEL_MAP.items(): + val = _clean_value(payload.get(key)) + if val: + out[label] = val + if not out: + raise ValueError("东方财富合约字段为空") + return out + + +def _fetch_em_akshare(em_symbol: str) -> dict[str, str]: + import akshare as ak + + df = ak.futures_contract_detail_em(symbol=em_symbol) + out: dict[str, str] = {} + for _, row in df.iterrows(): + label = _clean_value(row.get("item")) + val = _clean_value(row.get("value")) + if label and val: + if label == "跌涨停板幅度": + label = "涨跌停幅度" + if label == "最后交割日": + label = "交割日期" + if label == "上市交易所": + label = "交易所" + if label == "合约交割月份": + label = "合约月份" + if label == "最初交易保证金": + label = "最低交易保证金" + if label == "最小变动价格": + label = "最小变动价位" + out[label] = val + return out + + +def _fetch_sina_direct(sina_symbol: str) -> dict[str, str]: + from io import StringIO + + import pandas as pd + + url = f"https://finance.sina.com.cn/futures/quotes/{sina_symbol}.shtml" + r = requests.get(url, timeout=12, headers={"Referer": "https://finance.sina.com.cn/"}) + r.encoding = "gb2312" + tables = pd.read_html(StringIO(r.text)) + if len(tables) < 7: + raise ValueError("新浪页面结构变化") + temp_df = tables[6] + parts = [] + for ncol in [slice(0, 2), slice(2, 4), slice(4, None)]: + part = temp_df.iloc[:, ncol] + part.columns = ["item", "value"] + parts.append(part) + merged = pd.concat(parts, axis=0, ignore_index=True) + out: dict[str, str] = {} + for _, row in merged.iterrows(): + label = _clean_value(row["item"]) + val = _clean_value(row["value"]) + if not label or not val or len(label) > 80 or "发帖" in val: + continue + out[label] = val + return out + + +def _fetch_sina_akshare(sina_symbol: str) -> dict[str, str]: + import akshare as ak + + df = ak.futures_contract_detail(symbol=sina_symbol) + out: dict[str, str] = {} + for _, row in df.iterrows(): + label = _clean_value(row.get("item")) + val = _clean_value(row.get("value")) + if label and val and "发帖" not in val: + out[label] = val + return out + + +def _merge_profile(primary: dict[str, str], secondary: dict[str, str]) -> dict[str, str]: + merged = dict(secondary) + merged.update(primary) + return merged + + +def get_contract_profile(raw_symbol: str) -> Optional[dict]: + ths_code = _normalize_ths_code(raw_symbol) + if not ths_code: + return None + + em_symbol = _to_em_page_symbol(ths_code) + sina_symbol = _to_sina_quote_symbol(ths_code) + data: dict[str, str] = {} + source_parts: list[str] = [] + + # 东方财富(字段与看盘软件简介接近) + try: + try: + data = _fetch_em_akshare(em_symbol) + source_parts.append("东方财富") + except ImportError: + data = _fetch_em_direct(em_symbol) + source_parts.append("东方财富") + except Exception as exc: + logger.warning("eastmoney profile failed %s: %s", em_symbol, exc) + + # 新浪补充交割地点、上市日期等 + sina_data: dict[str, str] = {} + try: + try: + sina_data = _fetch_sina_akshare(sina_symbol) + except ImportError: + sina_data = _fetch_sina_direct(sina_symbol) + if sina_data: + source_parts.append("新浪") + except Exception as exc: + logger.warning("sina profile failed %s: %s", sina_symbol, exc) + + if sina_data: + data = _merge_profile(data, sina_data) + + if not data: + return None + + _add_computed_hints(ths_code, data) + rows = _rows_from_dict(data) + if not rows: + return None + + return { + "ths_code": ths_code, + "symbol_name": data.get("交易品种", ""), + "exchange": data.get("交易所", ""), + "rows": rows, + "source": " + ".join(source_parts) if source_parts else "未知", + } diff --git a/contract_specs.py b/modules/core/contract_specs.py similarity index 96% rename from contract_specs.py rename to modules/core/contract_specs.py index ba085ef..2e77c31 100644 --- a/contract_specs.py +++ b/modules/core/contract_specs.py @@ -1,166 +1,166 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""国内期货合约乘数与参考保证金比例(用于估算保证金与风险)。""" -import re -from typing import Optional - -DEFAULT_SPEC = {"mult": 10, "margin_rate": 0.10, "tick_size": 1.0} - -# 参考交易所常见规格(乘数 + 保证金比例 + 最小变动价位) -_SPEC_BY_THS: dict[str, dict] = { - "ag": {"mult": 15, "margin_rate": 0.14, "tick_size": 1.0}, - "au": {"mult": 1000, "margin_rate": 0.10, "tick_size": 0.02}, - "cu": {"mult": 5, "margin_rate": 0.10, "tick_size": 10.0}, - "al": {"mult": 5, "margin_rate": 0.10}, - "zn": {"mult": 5, "margin_rate": 0.10}, - "pb": {"mult": 5, "margin_rate": 0.10}, - "ni": {"mult": 1, "margin_rate": 0.12}, - "sn": {"mult": 1, "margin_rate": 0.12}, - "rb": {"mult": 10, "margin_rate": 0.09}, - "hc": {"mult": 10, "margin_rate": 0.09}, - "ss": {"mult": 5, "margin_rate": 0.11}, - "sc": {"mult": 1000, "margin_rate": 0.11}, - "fu": {"mult": 10, "margin_rate": 0.11}, - "bu": {"mult": 10, "margin_rate": 0.11}, - "ru": {"mult": 10, "margin_rate": 0.11}, - "sp": {"mult": 10, "margin_rate": 0.10}, - "i": {"mult": 100, "margin_rate": 0.11}, - "j": {"mult": 100, "margin_rate": 0.12}, - "jm": {"mult": 60, "margin_rate": 0.12}, - "m": {"mult": 10, "margin_rate": 0.08}, - "y": {"mult": 10, "margin_rate": 0.08}, - "p": {"mult": 10, "margin_rate": 0.09}, - "c": {"mult": 10, "margin_rate": 0.08}, - "cs": {"mult": 10, "margin_rate": 0.08}, - "jd": {"mult": 10, "margin_rate": 0.09}, - "lh": {"mult": 16, "margin_rate": 0.12}, - "l": {"mult": 5, "margin_rate": 0.09}, - "pp": {"mult": 5, "margin_rate": 0.09}, - "v": {"mult": 5, "margin_rate": 0.09}, - "eg": {"mult": 10, "margin_rate": 0.09}, - "eb": {"mult": 5, "margin_rate": 0.10}, - "pg": {"mult": 20, "margin_rate": 0.10}, - "RM": {"mult": 10, "margin_rate": 0.08}, - "OI": {"mult": 10, "margin_rate": 0.08}, - "SR": {"mult": 10, "margin_rate": 0.08}, - "CF": {"mult": 5, "margin_rate": 0.08}, - "MA": {"mult": 10, "margin_rate": 0.09}, - "TA": {"mult": 5, "margin_rate": 0.09}, - "FG": {"mult": 20, "margin_rate": 0.10}, - "SA": {"mult": 20, "margin_rate": 0.10}, - "UR": {"mult": 20, "margin_rate": 0.10}, - "SF": {"mult": 5, "margin_rate": 0.10}, - "SM": {"mult": 5, "margin_rate": 0.10}, - "AP": {"mult": 10, "margin_rate": 0.10}, - "CJ": {"mult": 5, "margin_rate": 0.10}, - "PK": {"mult": 5, "margin_rate": 0.10}, - "IF": {"mult": 300, "margin_rate": 0.12, "tick_size": 0.2}, - "IH": {"mult": 300, "margin_rate": 0.12, "tick_size": 0.2}, - "IC": {"mult": 200, "margin_rate": 0.12, "tick_size": 0.2}, - "IM": {"mult": 200, "margin_rate": 0.12, "tick_size": 0.2}, -} - -_TICK_OVERRIDES: dict[str, float] = { - "sc": 0.1, "TA": 2.0, "CF": 5.0, "SF": 2.0, "SM": 2.0, -} - - -def get_contract_spec(ths_code: str) -> dict: - code = (ths_code or "").strip() - m = re.match(r"^([A-Za-z]+)", code) - if not m: - return dict(DEFAULT_SPEC) - letters = m.group(1) - spec = _SPEC_BY_THS.get(letters) or _SPEC_BY_THS.get(letters.upper()) or _SPEC_BY_THS.get(letters.lower()) - if spec: - tick = spec.get("tick_size") - if tick is None: - tick = _TICK_OVERRIDES.get(letters) or _TICK_OVERRIDES.get(letters.upper()) or 1.0 - return {"mult": spec["mult"], "margin_rate": spec["margin_rate"], "tick_size": float(tick)} - return dict(DEFAULT_SPEC) - - -def margin_one_lot( - ths_code: str, - price: float, - *, - direction: str = "long", - trading_mode: str | None = None, -) -> tuple[float, str, dict]: - """1 手保证金。CTP 已连接时优先读柜台合约保证金率,否则用本地参考规格估算。 - - direction 可为 long / short / max(多空费率取较大值,用于可开仓品种表)。 - 返回 (保证金, 来源 estimate|ctp, 合约规格片段)。 - """ - spec = get_contract_spec(ths_code) - est = 0.0 - if price and price > 0: - est = round(float(price) * spec["mult"] * spec["margin_rate"], 2) - if trading_mode: - try: - from vnpy_bridge import ctp_estimate_margin_one_lot, ctp_lookup_contract_spec, ctp_status - - if ctp_status(trading_mode).get("connected"): - ctp_margin = ctp_estimate_margin_one_lot( - trading_mode, ths_code, float(price), direction=direction, - ) - if ctp_margin and ctp_margin > 0: - merged = dict(spec) - ctp_spec = ctp_lookup_contract_spec(trading_mode, ths_code) or {} - if ctp_spec.get("mult"): - merged["mult"] = ctp_spec["mult"] - if ctp_spec.get("tick_size"): - merged["tick_size"] = ctp_spec["tick_size"] - if ctp_spec.get("margin_rate"): - merged["margin_rate"] = ctp_spec["margin_rate"] - return float(ctp_margin), "ctp", merged - except Exception: - pass - return est, "estimate", spec - - -def calc_position_metrics( - direction: str, - entry: float, - stop_loss: float, - take_profit: float, - lots: float, - mark_price: Optional[float], - capital: float, - ths_code: str, -) -> dict: - spec = get_contract_spec(ths_code) - mult = spec["mult"] - margin_rate = spec["margin_rate"] - lots = lots or 1.0 - margin = entry * mult * lots * margin_rate - - if direction == "long": - risk_amt = max(0.0, (entry - stop_loss) * mult * lots) - reward = max(0.0, (take_profit - entry) * mult * lots) - float_pnl = (mark_price - entry) * mult * lots if mark_price is not None else None - else: - risk_amt = max(0.0, (stop_loss - entry) * mult * lots) - reward = max(0.0, (entry - take_profit) * mult * lots) - float_pnl = (entry - mark_price) * mult * lots if mark_price is not None else None - - risk_pct = (risk_amt / capital * 100) if capital > 0 else 0.0 - pos_pct = (margin / capital * 100) if capital > 0 else 0.0 - rr = (reward / risk_amt) if risk_amt > 0 else None - float_pct = (float_pnl / margin * 100) if margin > 0 and float_pnl is not None else None - - return { - "mult": mult, - "margin_rate": margin_rate, - "margin": round(margin, 2), - "risk_amount": round(risk_amt, 2), - "risk_pct": round(risk_pct, 2), - "position_pct": round(pos_pct, 2), - "float_pnl": round(float_pnl, 2) if float_pnl is not None else None, - "float_pct": round(float_pct, 2) if float_pct is not None else None, - "reward_amount": round(reward, 2) if reward else None, - "rr_ratio": round(rr, 2) if rr is not None else None, - } +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""国内期货合约乘数与参考保证金比例(用于估算保证金与风险)。""" +import re +from typing import Optional + +DEFAULT_SPEC = {"mult": 10, "margin_rate": 0.10, "tick_size": 1.0} + +# 参考交易所常见规格(乘数 + 保证金比例 + 最小变动价位) +_SPEC_BY_THS: dict[str, dict] = { + "ag": {"mult": 15, "margin_rate": 0.14, "tick_size": 1.0}, + "au": {"mult": 1000, "margin_rate": 0.10, "tick_size": 0.02}, + "cu": {"mult": 5, "margin_rate": 0.10, "tick_size": 10.0}, + "al": {"mult": 5, "margin_rate": 0.10}, + "zn": {"mult": 5, "margin_rate": 0.10}, + "pb": {"mult": 5, "margin_rate": 0.10}, + "ni": {"mult": 1, "margin_rate": 0.12}, + "sn": {"mult": 1, "margin_rate": 0.12}, + "rb": {"mult": 10, "margin_rate": 0.09}, + "hc": {"mult": 10, "margin_rate": 0.09}, + "ss": {"mult": 5, "margin_rate": 0.11}, + "sc": {"mult": 1000, "margin_rate": 0.11}, + "fu": {"mult": 10, "margin_rate": 0.11}, + "bu": {"mult": 10, "margin_rate": 0.11}, + "ru": {"mult": 10, "margin_rate": 0.11}, + "sp": {"mult": 10, "margin_rate": 0.10}, + "i": {"mult": 100, "margin_rate": 0.11}, + "j": {"mult": 100, "margin_rate": 0.12}, + "jm": {"mult": 60, "margin_rate": 0.12}, + "m": {"mult": 10, "margin_rate": 0.08}, + "y": {"mult": 10, "margin_rate": 0.08}, + "p": {"mult": 10, "margin_rate": 0.09}, + "c": {"mult": 10, "margin_rate": 0.08}, + "cs": {"mult": 10, "margin_rate": 0.08}, + "jd": {"mult": 10, "margin_rate": 0.09}, + "lh": {"mult": 16, "margin_rate": 0.12}, + "l": {"mult": 5, "margin_rate": 0.09}, + "pp": {"mult": 5, "margin_rate": 0.09}, + "v": {"mult": 5, "margin_rate": 0.09}, + "eg": {"mult": 10, "margin_rate": 0.09}, + "eb": {"mult": 5, "margin_rate": 0.10}, + "pg": {"mult": 20, "margin_rate": 0.10}, + "RM": {"mult": 10, "margin_rate": 0.08}, + "OI": {"mult": 10, "margin_rate": 0.08}, + "SR": {"mult": 10, "margin_rate": 0.08}, + "CF": {"mult": 5, "margin_rate": 0.08}, + "MA": {"mult": 10, "margin_rate": 0.09}, + "TA": {"mult": 5, "margin_rate": 0.09}, + "FG": {"mult": 20, "margin_rate": 0.10}, + "SA": {"mult": 20, "margin_rate": 0.10}, + "UR": {"mult": 20, "margin_rate": 0.10}, + "SF": {"mult": 5, "margin_rate": 0.10}, + "SM": {"mult": 5, "margin_rate": 0.10}, + "AP": {"mult": 10, "margin_rate": 0.10}, + "CJ": {"mult": 5, "margin_rate": 0.10}, + "PK": {"mult": 5, "margin_rate": 0.10}, + "IF": {"mult": 300, "margin_rate": 0.12, "tick_size": 0.2}, + "IH": {"mult": 300, "margin_rate": 0.12, "tick_size": 0.2}, + "IC": {"mult": 200, "margin_rate": 0.12, "tick_size": 0.2}, + "IM": {"mult": 200, "margin_rate": 0.12, "tick_size": 0.2}, +} + +_TICK_OVERRIDES: dict[str, float] = { + "sc": 0.1, "TA": 2.0, "CF": 5.0, "SF": 2.0, "SM": 2.0, +} + + +def get_contract_spec(ths_code: str) -> dict: + code = (ths_code or "").strip() + m = re.match(r"^([A-Za-z]+)", code) + if not m: + return dict(DEFAULT_SPEC) + letters = m.group(1) + spec = _SPEC_BY_THS.get(letters) or _SPEC_BY_THS.get(letters.upper()) or _SPEC_BY_THS.get(letters.lower()) + if spec: + tick = spec.get("tick_size") + if tick is None: + tick = _TICK_OVERRIDES.get(letters) or _TICK_OVERRIDES.get(letters.upper()) or 1.0 + return {"mult": spec["mult"], "margin_rate": spec["margin_rate"], "tick_size": float(tick)} + return dict(DEFAULT_SPEC) + + +def margin_one_lot( + ths_code: str, + price: float, + *, + direction: str = "long", + trading_mode: str | None = None, +) -> tuple[float, str, dict]: + """1 手保证金。CTP 已连接时优先读柜台合约保证金率,否则用本地参考规格估算。 + + direction 可为 long / short / max(多空费率取较大值,用于可开仓品种表)。 + 返回 (保证金, 来源 estimate|ctp, 合约规格片段)。 + """ + spec = get_contract_spec(ths_code) + est = 0.0 + if price and price > 0: + est = round(float(price) * spec["mult"] * spec["margin_rate"], 2) + if trading_mode: + try: + from modules.ctp.vnpy_bridge import ctp_estimate_margin_one_lot, ctp_lookup_contract_spec, ctp_status + + if ctp_status(trading_mode).get("connected"): + ctp_margin = ctp_estimate_margin_one_lot( + trading_mode, ths_code, float(price), direction=direction, + ) + if ctp_margin and ctp_margin > 0: + merged = dict(spec) + ctp_spec = ctp_lookup_contract_spec(trading_mode, ths_code) or {} + if ctp_spec.get("mult"): + merged["mult"] = ctp_spec["mult"] + if ctp_spec.get("tick_size"): + merged["tick_size"] = ctp_spec["tick_size"] + if ctp_spec.get("margin_rate"): + merged["margin_rate"] = ctp_spec["margin_rate"] + return float(ctp_margin), "ctp", merged + except Exception: + pass + return est, "estimate", spec + + +def calc_position_metrics( + direction: str, + entry: float, + stop_loss: float, + take_profit: float, + lots: float, + mark_price: Optional[float], + capital: float, + ths_code: str, +) -> dict: + spec = get_contract_spec(ths_code) + mult = spec["mult"] + margin_rate = spec["margin_rate"] + lots = lots or 1.0 + margin = entry * mult * lots * margin_rate + + if direction == "long": + risk_amt = max(0.0, (entry - stop_loss) * mult * lots) + reward = max(0.0, (take_profit - entry) * mult * lots) + float_pnl = (mark_price - entry) * mult * lots if mark_price is not None else None + else: + risk_amt = max(0.0, (stop_loss - entry) * mult * lots) + reward = max(0.0, (entry - take_profit) * mult * lots) + float_pnl = (entry - mark_price) * mult * lots if mark_price is not None else None + + risk_pct = (risk_amt / capital * 100) if capital > 0 else 0.0 + pos_pct = (margin / capital * 100) if capital > 0 else 0.0 + rr = (reward / risk_amt) if risk_amt > 0 else None + float_pct = (float_pnl / margin * 100) if margin > 0 and float_pnl is not None else None + + return { + "mult": mult, + "margin_rate": margin_rate, + "margin": round(margin, 2), + "risk_amount": round(risk_amt, 2), + "risk_pct": round(risk_pct, 2), + "position_pct": round(pos_pct, 2), + "float_pnl": round(float_pnl, 2) if float_pnl is not None else None, + "float_pct": round(float_pct, 2) if float_pct is not None else None, + "reward_amount": round(reward, 2) if reward else None, + "rr_ratio": round(rr, 2) if rr is not None else None, + } diff --git a/db_conn.py b/modules/core/db_conn.py similarity index 99% rename from db_conn.py rename to modules/core/db_conn.py index d6bc38e..5c976bb 100644 --- a/db_conn.py +++ b/modules/core/db_conn.py @@ -13,7 +13,9 @@ import threading import time from typing import Any, Iterable, Optional, Sequence -DB_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "futures.db") +from modules.core.paths import DB_PATH as _ROOT_DB_PATH + +DB_PATH = _ROOT_DB_PATH _backend_lock = threading.Lock() _backend: Optional[str] = None diff --git a/modules/core/deps.py b/modules/core/deps.py new file mode 100644 index 0000000..8cfaeb1 --- /dev/null +++ b/modules/core/deps.py @@ -0,0 +1,46 @@ +# Copyright (c) 2025-2026 马建军. All rights reserved. + +"""Shared dependencies passed into each feature module at register() time.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, Optional + + +@dataclass +class AppDeps: + app: Any + get_db: Callable + get_setting: Callable + set_setting: Callable + login_required: Callable + require_nav: Callable + fetch_price: Callable + send_wechat_msg: Callable + touch_stats_cache: Callable + get_stats_data: Callable + build_market_quote_payload: Callable + today_str: Callable + expire_old_plans: Callable + check_order_plans: Callable + check_key_monitors: Callable + background_task: Callable + start_background_threads: Callable + tz: Any + db_path: str + upload_dir: str + open_types: list + exit_triggers: list + behavior_tags: list + kline_periods: list + kline_cutoffs: list + calc_holding_duration: Callable + holding_to_minutes: Callable + classify_close_result: Callable + calc_rr_ratio: Callable + calc_theoretical_pnl: Callable + parse_review_date_filter: Callable + trading_mode: Callable + static_asset_v: Callable + ua_is_phone: Callable diff --git a/doc_render.py b/modules/core/doc_render.py similarity index 100% rename from doc_render.py rename to modules/core/doc_render.py diff --git a/env_file.py b/modules/core/env_file.py similarity index 82% rename from env_file.py rename to modules/core/env_file.py index 85ae301..90baf43 100644 --- a/env_file.py +++ b/modules/core/env_file.py @@ -9,12 +9,26 @@ from __future__ import annotations import os import re -ENV_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".env") +from modules.core.paths import ENV_FILE, LEGACY_ENV_FILE + + +def _default_env_path() -> str: + if ENV_FILE.is_file(): + return str(ENV_FILE) + if LEGACY_ENV_FILE.is_file(): + return str(LEGACY_ENV_FILE) + return str(ENV_FILE) + + +ENV_PATH = _default_env_path() _KEY_RE = re.compile(r"^([A-Za-z_][A-Za-z0-9_]*)\s*=") def env_file_path(path: str | None = None) -> str: - return path or ENV_PATH + if path: + return path + from modules.core.paths import resolve_env_file + return resolve_env_file() def _quote_env_value(value: str) -> str: diff --git a/locale_fix.py b/modules/core/locale_fix.py similarity index 100% rename from locale_fix.py rename to modules/core/locale_fix.py diff --git a/modules/core/paths.py b/modules/core/paths.py new file mode 100644 index 0000000..9550061 --- /dev/null +++ b/modules/core/paths.py @@ -0,0 +1,37 @@ +# Copyright (c) 2025-2026 马建军. All rights reserved. + +"""Repository layout paths — single source for config, data, uploads.""" + +from __future__ import annotations + +import os +from pathlib import Path + +# .../qihuo/modules/core/paths.py -> repo root +ROOT = Path(__file__).resolve().parents[2] + +CONFIG_DIR = ROOT / "config" +ENV_FILE = CONFIG_DIR / ".env" +LEGACY_ENV_FILE = ROOT / ".env" + +DATA_DIR = ROOT / "data" +UPLOADS_DIR = ROOT / "uploads" +LOGS_DIR = ROOT / "logs" + +DB_PATH = str(ROOT / "futures.db") + + +def ensure_runtime_dirs() -> None: + DATA_DIR.mkdir(parents=True, exist_ok=True) + UPLOADS_DIR.mkdir(parents=True, exist_ok=True) + LOGS_DIR.mkdir(parents=True, exist_ok=True) + CONFIG_DIR.mkdir(parents=True, exist_ok=True) + + +def resolve_env_file() -> str: + """Prefer config/.env, fall back to legacy root .env.""" + if ENV_FILE.is_file(): + return str(ENV_FILE) + if LEGACY_ENV_FILE.is_file(): + return str(LEGACY_ENV_FILE) + return str(ENV_FILE) diff --git a/symbols.py b/modules/core/symbols.py similarity index 95% rename from symbols.py rename to modules/core/symbols.py index 2ad0db1..2a87714 100644 --- a/symbols.py +++ b/modules/core/symbols.py @@ -1,683 +1,683 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -""" -期货品种与同花顺代码映射。 -展示同花顺合约代码(ag2608);行情默认新浪,机构用户可通过环境变量启用同花顺 iFinD。 -""" -import re -import threading -import time -from collections import defaultdict -from concurrent.futures import ThreadPoolExecutor, as_completed -from datetime import date -from typing import Optional - -from market import fetch_raw_for_volume, get_price as market_get_price, THS_EX_SUFFIX - -PRODUCTS = [ - {"name": "白银", "ths": "ag", "sina": "AG", "exchange": "上期所", "ex": "SHFE"}, - {"name": "黄金", "ths": "au", "sina": "AU", "exchange": "上期所", "ex": "SHFE"}, - {"name": "铜", "ths": "cu", "sina": "CU", "exchange": "上期所", "ex": "SHFE"}, - {"name": "铝", "ths": "al", "sina": "AL", "exchange": "上期所", "ex": "SHFE"}, - {"name": "锌", "ths": "zn", "sina": "ZN", "exchange": "上期所", "ex": "SHFE"}, - {"name": "铅", "ths": "pb", "sina": "PB", "exchange": "上期所", "ex": "SHFE"}, - {"name": "镍", "ths": "ni", "sina": "NI", "exchange": "上期所", "ex": "SHFE"}, - {"name": "锡", "ths": "sn", "sina": "SN", "exchange": "上期所", "ex": "SHFE"}, - {"name": "螺纹钢", "ths": "rb", "sina": "RB", "exchange": "上期所", "ex": "SHFE"}, - {"name": "热卷", "ths": "hc", "sina": "HC", "exchange": "上期所", "ex": "SHFE"}, - {"name": "不锈钢", "ths": "ss", "sina": "SS", "exchange": "上期所", "ex": "SHFE"}, - {"name": "原油", "ths": "sc", "sina": "SC", "exchange": "上期能源", "ex": "INE"}, - {"name": "燃油", "ths": "fu", "sina": "FU", "exchange": "上期所", "ex": "SHFE"}, - {"name": "沥青", "ths": "bu", "sina": "BU", "exchange": "上期所", "ex": "SHFE"}, - {"name": "橡胶", "ths": "ru", "sina": "RU", "exchange": "上期所", "ex": "SHFE"}, - {"name": "纸浆", "ths": "sp", "sina": "SP", "exchange": "上期所", "ex": "SHFE"}, - {"name": "铁矿石", "ths": "i", "sina": "I", "exchange": "大商所", "ex": "DCE"}, - {"name": "焦炭", "ths": "j", "sina": "J", "exchange": "大商所", "ex": "DCE"}, - {"name": "焦煤", "ths": "jm", "sina": "JM", "exchange": "大商所", "ex": "DCE"}, - {"name": "豆粕", "ths": "m", "sina": "M", "exchange": "大商所", "ex": "DCE"}, - {"name": "豆油", "ths": "y", "sina": "Y", "exchange": "大商所", "ex": "DCE"}, - {"name": "棕榈油", "ths": "p", "sina": "P", "exchange": "大商所", "ex": "DCE"}, - {"name": "玉米", "ths": "c", "sina": "C", "exchange": "大商所", "ex": "DCE"}, - {"name": "淀粉", "ths": "cs", "sina": "CS", "exchange": "大商所", "ex": "DCE"}, - {"name": "鸡蛋", "ths": "jd", "sina": "JD", "exchange": "大商所", "ex": "DCE"}, - {"name": "生猪", "ths": "lh", "sina": "LH", "exchange": "大商所", "ex": "DCE"}, - {"name": "聚乙烯", "ths": "l", "sina": "L", "exchange": "大商所", "ex": "DCE"}, - {"name": "聚丙烯", "ths": "pp", "sina": "PP", "exchange": "大商所", "ex": "DCE"}, - {"name": "PVC", "ths": "v", "sina": "V", "exchange": "大商所", "ex": "DCE"}, - {"name": "乙二醇", "ths": "eg", "sina": "EG", "exchange": "大商所", "ex": "DCE"}, - {"name": "苯乙烯", "ths": "eb", "sina": "EB", "exchange": "大商所", "ex": "DCE"}, - {"name": "液化气", "ths": "pg", "sina": "PG", "exchange": "大商所", "ex": "DCE"}, - {"name": "菜粕", "ths": "RM", "sina": "RM", "exchange": "郑商所", "ex": "CZCE"}, - {"name": "菜油", "ths": "OI", "sina": "OI", "exchange": "郑商所", "ex": "CZCE"}, - {"name": "白糖", "ths": "SR", "sina": "SR", "exchange": "郑商所", "ex": "CZCE"}, - {"name": "棉花", "ths": "CF", "sina": "CF", "exchange": "郑商所", "ex": "CZCE"}, - {"name": "甲醇", "ths": "MA", "sina": "MA", "exchange": "郑商所", "ex": "CZCE"}, - {"name": "PTA", "ths": "TA", "sina": "TA", "exchange": "郑商所", "ex": "CZCE"}, - {"name": "玻璃", "ths": "FG", "sina": "FG", "exchange": "郑商所", "ex": "CZCE"}, - {"name": "纯碱", "ths": "SA", "sina": "SA", "exchange": "郑商所", "ex": "CZCE"}, - {"name": "尿素", "ths": "UR", "sina": "UR", "exchange": "郑商所", "ex": "CZCE"}, - {"name": "硅铁", "ths": "SF", "sina": "SF", "exchange": "郑商所", "ex": "CZCE"}, - {"name": "锰硅", "ths": "SM", "sina": "SM", "exchange": "郑商所", "ex": "CZCE"}, - {"name": "苹果", "ths": "AP", "sina": "AP", "exchange": "郑商所", "ex": "CZCE"}, - {"name": "红枣", "ths": "CJ", "sina": "CJ", "exchange": "郑商所", "ex": "CZCE"}, - {"name": "花生", "ths": "PK", "sina": "PK", "exchange": "郑商所", "ex": "CZCE"}, - {"name": "沪深300", "ths": "IF", "sina": "IF", "exchange": "中金所", "ex": "CFFEX"}, - {"name": "上证50", "ths": "IH", "sina": "IH", "exchange": "中金所", "ex": "CFFEX"}, - {"name": "中证500", "ths": "IC", "sina": "IC", "exchange": "中金所", "ex": "CFFEX"}, - {"name": "中证1000", "ths": "IM", "sina": "IM", "exchange": "中金所", "ex": "CFFEX"}, -] - -PRODUCT_CATEGORY_MAP = { - "ag": "贵金属", "au": "贵金属", - "cu": "有色金属", "al": "有色金属", "zn": "有色金属", "pb": "有色金属", "ni": "有色金属", "sn": "有色金属", - "rb": "黑色金属", "hc": "黑色金属", "ss": "黑色金属", "i": "黑色金属", "j": "黑色金属", "jm": "黑色金属", - "SF": "黑色金属", "SM": "黑色金属", - "sc": "能源化工", "fu": "能源化工", "bu": "能源化工", "ru": "能源化工", "sp": "能源化工", - "l": "能源化工", "pp": "能源化工", "v": "能源化工", "eg": "能源化工", "eb": "能源化工", "pg": "能源化工", - "MA": "能源化工", "TA": "能源化工", "SA": "能源化工", "UR": "能源化工", "FG": "能源化工", - "m": "农产品", "y": "农产品", "p": "农产品", "c": "农产品", "cs": "农产品", "jd": "农产品", "lh": "农产品", - "RM": "农产品", "OI": "农产品", "SR": "农产品", "CF": "农产品", "AP": "农产品", "CJ": "农产品", "PK": "农产品", - "IF": "金融期货", "IH": "金融期货", "IC": "金融期货", "IM": "金融期货", -} -PRODUCT_CATEGORIES = ["贵金属", "有色金属", "黑色金属", "能源化工", "农产品", "金融期货"] - -for _p in PRODUCTS: - _p["category"] = PRODUCT_CATEGORY_MAP.get(_p["ths"], "其他") - -# 无夜盘品种(日盘-only):中金所股指、大商所鸡蛋/生猪等 -NO_NIGHT_SESSION_THS = frozenset({"IF", "IH", "IC", "IM", "jd", "lh"}) - - -def product_has_night_session(ths_or_product) -> bool: - """品种是否参与夜盘交易。""" - if isinstance(ths_or_product, dict): - ths = (ths_or_product.get("ths") or "").strip() - else: - ths = (ths_or_product or "").strip() - if not ths: - return True - m = re.match(r"^([A-Za-z]+)", ths) - letters = m.group(1) if m else ths - return letters not in NO_NIGHT_SESSION_THS and letters.upper() not in NO_NIGHT_SESSION_THS - - -def filter_for_trading_session(rows: list[dict]) -> list[dict]: - """夜盘时段隐藏无夜盘品种。""" - from market_sessions import is_night_trading_session - - if not is_night_trading_session(): - return rows - out: list[dict] = [] - for row in rows: - if row.get("has_night_session") is False: - continue - ths = row.get("ths") or row.get("ths_code") or "" - if row.get("has_night_session") is True or product_has_night_session(ths): - out.append(row) - return out - - -def product_category(ths: str) -> str: - return PRODUCT_CATEGORY_MAP.get((ths or "").strip(), "其他") - - -EXCHANGE_ORDER = ["上期所", "上期能源", "大商所", "郑商所", "中金所"] -_MAIN_CACHE: dict[str, tuple[float, dict]] = {} -_CACHE_TTL = 300 -_main_index_lock = threading.Lock() -_main_index: dict[str, dict] = {} -_main_index_ts = 0.0 -_index_refresh_lock = threading.Lock() - - -def build_ths_code(product: dict, year: int, month: int) -> str: - """同花顺软件内显示的合约代码。""" - ex = product["ex"] - letters = product["ths"] - if ex == "CZCE": - return f"{letters}{year % 10}{month:02d}" - return f"{letters}{year % 100:02d}{month:02d}" - - -def build_ths_full_code(product: dict, year: int, month: int) -> str: - """同花顺 iFinD HTTP API 代码,如 ag2608.SHFE""" - ths = build_ths_code(product, year, month) - suffix = THS_EX_SUFFIX.get(product["ex"], product["ex"]) - return f"{ths}.{suffix}" - - -def build_sina_code(product: dict, year: int, month: int) -> str: - letters = product["sina"] - suffix = f"{year % 100:02d}{month:02d}" - if product["ex"] == "CFFEX": - return f"CFF_RE_{letters}{suffix}" - return f"nf_{letters}{suffix}" - - -def build_sina_main_code(product: dict) -> str: - letters = product["sina"] - if product["ex"] == "CFFEX": - return f"CFF_RE_{letters}0" - return f"nf_{letters}0" - - -def _find_product_by_letters(letters: str) -> Optional[dict]: - letters_up = letters.upper() - for p in PRODUCTS: - if p["ths"].upper() == letters_up or p["sina"] == letters_up: - return p - return None - - -def _product_codes(product: dict, ths_code: str, market_code: str, sina_code: str) -> dict: - return { - "ths_code": ths_code, - "market_code": market_code, - "sina_code": sina_code, - "ex": product["ex"], - "name": product["name"], - "exchange": product["exchange"], - } - - -def ths_to_codes(ths_code: str) -> Optional[dict]: - """同花顺合约代码 -> ths_full + sina 回退代码。""" - code = ths_code.strip() - if not code: - return None - - m4 = re.match(r"^([A-Za-z]+)(\d{4})$", code) - if m4: - letters, digits = m4.group(1), m4.group(2) - year = 2000 + int(digits[:2]) - month = int(digits[2:]) - if not 1 <= month <= 12: - return None - product = _find_product_by_letters(letters) - if product: - ths = build_ths_code(product, year, month) - return _product_codes( - product, - ths, - build_ths_full_code(product, year, month), - build_sina_code(product, year, month), - ) - letters_up = letters.upper() - if letters_up in ("IF", "IH", "IC", "IM", "T", "TF", "TS"): - ths = f"{letters_up}{digits}" - return { - "ths_code": ths, - "market_code": f"{ths}.CFFEX", - "sina_code": f"CFF_RE_{letters_up}{digits}", - "ex": "CFFEX", - "name": letters_up, - "exchange": "中金所", - } - - m3 = re.match(r"^([A-Za-z]+)(\d{3})$", code) - if m3: - letters, digits = m3.group(1), m3.group(2) - y_digit = int(digits[0]) - month = int(digits[1:]) - if not 1 <= month <= 12: - return None - year = date.today().year - decade = year // 10 * 10 - candidate = decade + y_digit - if candidate < year - 1: - candidate += 10 - product = _find_product_by_letters(letters) - if product: - ths = build_ths_code(product, candidate, month) - return _product_codes( - product, - ths, - build_ths_full_code(product, candidate, month), - build_sina_code(product, candidate, month), - ) - - return None - - -def ths_to_sina_code(ths_code: str) -> Optional[str]: - codes = ths_to_codes(ths_code) - return codes["sina_code"] if codes else None - - -def parse_contract_year_month(ths_code: str) -> Optional[tuple[int, int]]: - """从同花顺合约代码解析交割年月。""" - code = (ths_code or "").strip() - if not code or "888" in code: - return None - m4 = re.match(r"^([A-Za-z]+)(\d{4})$", code) - if m4: - digits = m4.group(2) - year = 2000 + int(digits[:2]) - month = int(digits[2:]) - if 1 <= month <= 12: - return year, month - m3 = re.match(r"^([A-Za-z]+)(\d{3})$", code) - if m3: - letters, digits = m3.group(1), m3.group(2) - month = int(digits[1:]) - if not 1 <= month <= 12: - return None - y_digit = int(digits[0]) - year = date.today().year - decade = year // 10 * 10 - candidate = decade + y_digit - if candidate < year - 1: - candidate += 10 - product = _find_product_by_letters(letters) - if product: - return candidate, month - return None - - -def is_near_expiry_main(ths_code: str) -> bool: - """主力合约交割月为当月或下月时视为临期。""" - ym = parse_contract_year_month(ths_code) - if not ym: - return False - cy, cm = ym - today = date.today() - months_ahead = (cy - today.year) * 12 + (cm - today.month) - return months_ahead <= 1 - - -def _main_contract_score(raw: dict) -> float: - """主力判定:优先持仓量,其次成交量。""" - oi = float(raw.get("open_interest") or 0) - vol = float(raw.get("volume") or 0) - return oi if oi > 0 else vol - - -def _make_symbol_item( - product: dict, - year: int, - month: int, - volume: float, - open_interest: float = 0, -) -> dict: - ths = build_ths_code(product, year, month) - name = product["name"] - return { - "name": name, - "ths_code": ths, - "market_code": build_ths_full_code(product, year, month), - "sina_code": build_sina_code(product, year, month), - "exchange": product["exchange"], - "contract": f"主力 {ths}", - "display": f"{name} 主力 {ths}", - "input_label": f"{name} {ths}", - "volume": volume, - "open_interest": open_interest, - } - - -def resolve_main_contract(product: dict) -> Optional[dict]: - cache_key = product["sina"] - now = time.time() - cached = _MAIN_CACHE.get(cache_key) - if cached and now - cached[0] < _CACHE_TTL: - return cached[1] - - today = date.today() - y, m = today.year, today.month - best = None - best_score = 0.0 - - for i in range(14): - cy, cm = y, m + i - while cm > 12: - cm -= 12 - cy += 1 - sina = build_sina_code(product, cy, cm) - raw = fetch_raw_for_volume(sina) - if not raw: - continue - score = _main_contract_score(raw) - if score <= 0: - continue - item = _make_symbol_item( - product, cy, cm, raw["volume"], raw.get("open_interest", 0), - ) - if score > best_score: - best_score = score - best = item - - if best is None: - sina_main = build_sina_main_code(product) - raw = fetch_raw_for_volume(sina_main) - if raw: - ths_letters = product["ths"] - ths_main = ( - f"{ths_letters}888" - if product["ex"] != "CFFEX" - else f"{ths_letters.upper()}888" - ) - suffix = THS_EX_SUFFIX.get(product["ex"], product["ex"]) - best = { - "name": product["name"], - "ths_code": ths_main, - "market_code": f"{ths_main}.{suffix}", - "sina_code": sina_main, - "exchange": product["exchange"], - "contract": f"主力连续 {ths_main}", - "display": f"{product['name']} 主力连续 {ths_main}", - "input_label": f"{product['name']} {ths_main}", - "volume": raw.get("volume", 0), - } - - if best: - best = _enrich_item(best, product) - _MAIN_CACHE[cache_key] = (now, best) - return best - - -def _enrich_item(item: dict, product: Optional[dict] = None) -> dict: - out = dict(item) - if not out.get("input_label"): - out["input_label"] = f"{out.get('name', '')} {out.get('ths_code', '')}".strip() - out["near_expiry"] = is_near_expiry_main(out.get("ths_code", "")) - if product is None and out.get("ths_code"): - product = _product_for_contract_code(out["ths_code"]) - if product is not None: - out["has_night_session"] = product_has_night_session(product) - elif "has_night_session" not in out: - out["has_night_session"] = product_has_night_session(out.get("ths_code") or "") - return out - - -def refresh_main_index(): - """后台预热全部品种主力合约,搜索时只读本地缓存。""" - global _main_index, _main_index_ts - with _index_refresh_lock: - new_idx: dict[str, dict] = {} - with ThreadPoolExecutor(max_workers=10) as pool: - futures = {pool.submit(resolve_main_contract, p): p for p in PRODUCTS} - for fut in as_completed(futures): - product = futures[fut] - try: - main = fut.result() - if main: - new_idx[product["sina"]] = _enrich_item(main, product) - except Exception: - pass - with _main_index_lock: - _main_index = new_idx - _main_index_ts = time.time() - - -def _warm_loop(): - while True: - try: - refresh_main_index() - except Exception: - pass - time.sleep(_CACHE_TTL) - - -def _start_warm_thread(): - threading.Thread(target=_warm_loop, daemon=True).start() - - -def _stub_main_contract(product: dict) -> dict: - """缓存未就绪时的快速占位(当月合约),避免首次打开搜索为空。""" - today = date.today() - return _enrich_item(_make_symbol_item(product, today.year, today.month, 0), product) - - -def _product_matches(product: dict, q_lower: str) -> bool: - name_lower = product["name"].lower() - if q_lower in name_lower: - return True - if len(q_lower) >= 2: - ths_lower = product["ths"].lower() - sina_lower = product["sina"].lower() - if q_lower in ths_lower or q_lower in sina_lower: - return True - return False - - -def _match_score(product: dict, q_lower: str) -> int: - name_lower = product["name"].lower() - if name_lower == q_lower: - return 200 - if name_lower.startswith(q_lower): - return 150 - if q_lower in name_lower: - return 100 - ths_lower = product["ths"].lower() - if ths_lower == q_lower: - return 90 - if ths_lower.startswith(q_lower): - return 70 - if product["sina"].lower() == q_lower: - return 80 - return 10 - - -def search_symbols(query: str, *, capital: float | None = None, ctp_connected: bool = True) -> list: - q = query.strip() - if not q: - return [] - - q_lower = q.lower() - from market_sessions import is_night_trading_session - from product_recommend import filter_products_for_capital, should_apply_small_account_scope - - night_only = is_night_trading_session() - product_pool = PRODUCTS - if capital is not None and should_apply_small_account_scope(capital, ctp_connected=ctp_connected): - product_pool = filter_products_for_capital( - PRODUCTS, capital, ctp_connected=ctp_connected, - ) - with _main_index_lock: - index = dict(_main_index) - index_ready = bool(index) - - scored: list[tuple[int, dict]] = [] - for p in product_pool: - if night_only and not product_has_night_session(p): - continue - if not _product_matches(p, q_lower): - continue - main = index.get(p["sina"]) - if not main and not index_ready: - main = _stub_main_contract(p) - if main: - scored.append((_match_score(p, q_lower), main)) - - scored.sort(key=lambda x: -x[0]) - results = [item for _, item in scored[:12]] - results = filter_for_trading_session(results) - - if not results and len(q) >= 3: - codes = ths_to_codes(q) - if codes: - product = _product_for_contract_code(codes["ths_code"]) - if capital is not None and should_apply_small_account_scope( - capital, ctp_connected=ctp_connected, - ): - from product_recommend import product_in_small_account_whitelist - if not product or not product_in_small_account_whitelist(product): - return results - raw = fetch_raw_for_volume(codes["sina_code"]) - name = raw["name"] if raw else q - results.append(_enrich_item({ - "name": name, - "ths_code": codes["ths_code"], - "market_code": codes["market_code"], - "sina_code": codes["sina_code"], - "exchange": "", - "contract": codes["ths_code"], - "display": f"{name} ({codes['ths_code']})", - "volume": raw.get("volume", 0) if raw else 0, - })) - results = filter_for_trading_session(results) - - return results - - -def enrich_recommend_row(row: dict) -> dict: - """补全推荐行字段(含是否夜盘)。""" - out = dict(row) - ths = out.get("ths") or "" - out["has_night_session"] = product_has_night_session(ths) - return out - - -_THS_TO_PRODUCT = {p["ths"]: p for p in PRODUCTS} -for _p in PRODUCTS: - _THS_TO_PRODUCT.setdefault(_p["ths"].lower(), _p) - - -def _product_for_ths(ths: str) -> Optional[dict]: - key = (ths or "").strip() - if not key: - return None - return _THS_TO_PRODUCT.get(key) or _THS_TO_PRODUCT.get(key.lower()) - - -def _product_for_contract_code(ths_code: str) -> Optional[dict]: - sym = (ths_code or "").strip() - if not sym: - return None - m = re.match(r"^([A-Za-z]+)", sym) - if not m: - return None - return _find_product_by_letters(m.group(1)) - - -def position_symbol_meta(ths_code: str) -> dict: - """持仓/委托展示:品种名、交易所、是否主力合约。""" - sym = (ths_code or "").strip() - if not sym: - return {"name": "", "exchange": "", "is_main": False} - product = _product_for_contract_code(sym) - if not product: - return {"name": sym, "exchange": "", "is_main": False} - codes = ths_to_codes(sym) - norm = (codes["ths_code"] if codes else sym).strip().lower() - is_main = False - with _main_index_lock: - main_item = _main_index.get(product["sina"]) - if main_item: - main_ths = (main_item.get("ths_code") or "").strip().lower() - is_main = main_ths == norm or main_ths == sym.lower() - return { - "name": product["name"], - "exchange": product.get("exchange") or "", - "is_main": is_main, - } - - -def _item_from_recommend_row(row: dict, product: dict) -> Optional[dict]: - """由可开仓缓存行快速构造下拉项(不在 HTTP 请求中解析主力)。""" - name = row.get("name") or product["name"] - main_code = (row.get("main_code") or "").strip() - max_lots = row.get("max_lots") - - if main_code: - codes = ths_to_codes(main_code) - if codes: - ths = codes["ths_code"] - item = { - "name": name, - "ths_code": ths, - "market_code": codes.get("market_code") or "", - "sina_code": codes.get("sina_code") or "", - "exchange": product["exchange"], - "contract": f"主力 {ths}", - "display": f"{name} 主力 {ths}", - "input_label": f"{name} {ths}", - } - if max_lots is not None: - item["max_lots"] = max_lots - return _enrich_item(item, product) - - with _main_index_lock: - main = _main_index.get(product["sina"]) - if main: - item = dict(main) - if max_lots is not None: - item["max_lots"] = max_lots - return _enrich_item(item, product) - - item = _stub_main_contract(product) - if max_lots is not None: - item["max_lots"] = max_lots - return item - - -def list_recommended_symbols_grouped(recommend_rows: list[dict]) -> list[dict]: - """按交易所分类返回可开仓品种对应的主力合约(品种选择下拉用)。""" - if not recommend_rows: - return [] - - buckets: dict[str, list] = defaultdict(list) - seen: set[str] = set() - for row in recommend_rows: - if row.get("status") not in ("ok", "margin_ok"): - continue - ths_key = (row.get("ths") or "").strip() - if not ths_key or ths_key in seen: - continue - product = _product_for_ths(ths_key) - if not product: - continue - if not product_has_night_session(product): - from market_sessions import is_night_trading_session - if is_night_trading_session(): - continue - seen.add(ths_key) - item = _item_from_recommend_row(row, product) - if not item: - continue - buckets[product["exchange"]].append(item) - - groups: list[dict] = [] - for cat in EXCHANGE_ORDER: - items = buckets.get(cat) - if items: - groups.append({"category": cat, "items": items}) - return groups - - -def list_main_contracts_grouped() -> list[dict]: - """按交易所分类返回全部品种主力合约(行情页下拉用)。""" - with _main_index_lock: - index = dict(_main_index) - - if len(index) < len(PRODUCTS) // 2: - refresh_main_index() - with _main_index_lock: - index = dict(_main_index) - - buckets: dict[str, list] = defaultdict(list) - for p in PRODUCTS: - main = index.get(p["sina"]) - if not main: - resolved = resolve_main_contract(p) - if resolved: - main = _enrich_item(resolved) - if main: - buckets[p["exchange"]].append(main) - - groups: list[dict] = [] - for cat in EXCHANGE_ORDER: - items = buckets.get(cat) - if items: - groups.append({"category": cat, "items": items}) - return groups - - -_start_warm_thread() - - -def get_price(market_code: str, sina_code: str = "") -> Optional[float]: - return market_get_price(market_code, sina_code) +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +""" +期货品种与同花顺代码映射。 +展示同花顺合约代码(ag2608);行情默认新浪,机构用户可通过环境变量启用同花顺 iFinD。 +""" +import re +import threading +import time +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import date +from typing import Optional + +from modules.market.market import fetch_raw_for_volume, get_price as market_get_price, THS_EX_SUFFIX + +PRODUCTS = [ + {"name": "白银", "ths": "ag", "sina": "AG", "exchange": "上期所", "ex": "SHFE"}, + {"name": "黄金", "ths": "au", "sina": "AU", "exchange": "上期所", "ex": "SHFE"}, + {"name": "铜", "ths": "cu", "sina": "CU", "exchange": "上期所", "ex": "SHFE"}, + {"name": "铝", "ths": "al", "sina": "AL", "exchange": "上期所", "ex": "SHFE"}, + {"name": "锌", "ths": "zn", "sina": "ZN", "exchange": "上期所", "ex": "SHFE"}, + {"name": "铅", "ths": "pb", "sina": "PB", "exchange": "上期所", "ex": "SHFE"}, + {"name": "镍", "ths": "ni", "sina": "NI", "exchange": "上期所", "ex": "SHFE"}, + {"name": "锡", "ths": "sn", "sina": "SN", "exchange": "上期所", "ex": "SHFE"}, + {"name": "螺纹钢", "ths": "rb", "sina": "RB", "exchange": "上期所", "ex": "SHFE"}, + {"name": "热卷", "ths": "hc", "sina": "HC", "exchange": "上期所", "ex": "SHFE"}, + {"name": "不锈钢", "ths": "ss", "sina": "SS", "exchange": "上期所", "ex": "SHFE"}, + {"name": "原油", "ths": "sc", "sina": "SC", "exchange": "上期能源", "ex": "INE"}, + {"name": "燃油", "ths": "fu", "sina": "FU", "exchange": "上期所", "ex": "SHFE"}, + {"name": "沥青", "ths": "bu", "sina": "BU", "exchange": "上期所", "ex": "SHFE"}, + {"name": "橡胶", "ths": "ru", "sina": "RU", "exchange": "上期所", "ex": "SHFE"}, + {"name": "纸浆", "ths": "sp", "sina": "SP", "exchange": "上期所", "ex": "SHFE"}, + {"name": "铁矿石", "ths": "i", "sina": "I", "exchange": "大商所", "ex": "DCE"}, + {"name": "焦炭", "ths": "j", "sina": "J", "exchange": "大商所", "ex": "DCE"}, + {"name": "焦煤", "ths": "jm", "sina": "JM", "exchange": "大商所", "ex": "DCE"}, + {"name": "豆粕", "ths": "m", "sina": "M", "exchange": "大商所", "ex": "DCE"}, + {"name": "豆油", "ths": "y", "sina": "Y", "exchange": "大商所", "ex": "DCE"}, + {"name": "棕榈油", "ths": "p", "sina": "P", "exchange": "大商所", "ex": "DCE"}, + {"name": "玉米", "ths": "c", "sina": "C", "exchange": "大商所", "ex": "DCE"}, + {"name": "淀粉", "ths": "cs", "sina": "CS", "exchange": "大商所", "ex": "DCE"}, + {"name": "鸡蛋", "ths": "jd", "sina": "JD", "exchange": "大商所", "ex": "DCE"}, + {"name": "生猪", "ths": "lh", "sina": "LH", "exchange": "大商所", "ex": "DCE"}, + {"name": "聚乙烯", "ths": "l", "sina": "L", "exchange": "大商所", "ex": "DCE"}, + {"name": "聚丙烯", "ths": "pp", "sina": "PP", "exchange": "大商所", "ex": "DCE"}, + {"name": "PVC", "ths": "v", "sina": "V", "exchange": "大商所", "ex": "DCE"}, + {"name": "乙二醇", "ths": "eg", "sina": "EG", "exchange": "大商所", "ex": "DCE"}, + {"name": "苯乙烯", "ths": "eb", "sina": "EB", "exchange": "大商所", "ex": "DCE"}, + {"name": "液化气", "ths": "pg", "sina": "PG", "exchange": "大商所", "ex": "DCE"}, + {"name": "菜粕", "ths": "RM", "sina": "RM", "exchange": "郑商所", "ex": "CZCE"}, + {"name": "菜油", "ths": "OI", "sina": "OI", "exchange": "郑商所", "ex": "CZCE"}, + {"name": "白糖", "ths": "SR", "sina": "SR", "exchange": "郑商所", "ex": "CZCE"}, + {"name": "棉花", "ths": "CF", "sina": "CF", "exchange": "郑商所", "ex": "CZCE"}, + {"name": "甲醇", "ths": "MA", "sina": "MA", "exchange": "郑商所", "ex": "CZCE"}, + {"name": "PTA", "ths": "TA", "sina": "TA", "exchange": "郑商所", "ex": "CZCE"}, + {"name": "玻璃", "ths": "FG", "sina": "FG", "exchange": "郑商所", "ex": "CZCE"}, + {"name": "纯碱", "ths": "SA", "sina": "SA", "exchange": "郑商所", "ex": "CZCE"}, + {"name": "尿素", "ths": "UR", "sina": "UR", "exchange": "郑商所", "ex": "CZCE"}, + {"name": "硅铁", "ths": "SF", "sina": "SF", "exchange": "郑商所", "ex": "CZCE"}, + {"name": "锰硅", "ths": "SM", "sina": "SM", "exchange": "郑商所", "ex": "CZCE"}, + {"name": "苹果", "ths": "AP", "sina": "AP", "exchange": "郑商所", "ex": "CZCE"}, + {"name": "红枣", "ths": "CJ", "sina": "CJ", "exchange": "郑商所", "ex": "CZCE"}, + {"name": "花生", "ths": "PK", "sina": "PK", "exchange": "郑商所", "ex": "CZCE"}, + {"name": "沪深300", "ths": "IF", "sina": "IF", "exchange": "中金所", "ex": "CFFEX"}, + {"name": "上证50", "ths": "IH", "sina": "IH", "exchange": "中金所", "ex": "CFFEX"}, + {"name": "中证500", "ths": "IC", "sina": "IC", "exchange": "中金所", "ex": "CFFEX"}, + {"name": "中证1000", "ths": "IM", "sina": "IM", "exchange": "中金所", "ex": "CFFEX"}, +] + +PRODUCT_CATEGORY_MAP = { + "ag": "贵金属", "au": "贵金属", + "cu": "有色金属", "al": "有色金属", "zn": "有色金属", "pb": "有色金属", "ni": "有色金属", "sn": "有色金属", + "rb": "黑色金属", "hc": "黑色金属", "ss": "黑色金属", "i": "黑色金属", "j": "黑色金属", "jm": "黑色金属", + "SF": "黑色金属", "SM": "黑色金属", + "sc": "能源化工", "fu": "能源化工", "bu": "能源化工", "ru": "能源化工", "sp": "能源化工", + "l": "能源化工", "pp": "能源化工", "v": "能源化工", "eg": "能源化工", "eb": "能源化工", "pg": "能源化工", + "MA": "能源化工", "TA": "能源化工", "SA": "能源化工", "UR": "能源化工", "FG": "能源化工", + "m": "农产品", "y": "农产品", "p": "农产品", "c": "农产品", "cs": "农产品", "jd": "农产品", "lh": "农产品", + "RM": "农产品", "OI": "农产品", "SR": "农产品", "CF": "农产品", "AP": "农产品", "CJ": "农产品", "PK": "农产品", + "IF": "金融期货", "IH": "金融期货", "IC": "金融期货", "IM": "金融期货", +} +PRODUCT_CATEGORIES = ["贵金属", "有色金属", "黑色金属", "能源化工", "农产品", "金融期货"] + +for _p in PRODUCTS: + _p["category"] = PRODUCT_CATEGORY_MAP.get(_p["ths"], "其他") + +# 无夜盘品种(日盘-only):中金所股指、大商所鸡蛋/生猪等 +NO_NIGHT_SESSION_THS = frozenset({"IF", "IH", "IC", "IM", "jd", "lh"}) + + +def product_has_night_session(ths_or_product) -> bool: + """品种是否参与夜盘交易。""" + if isinstance(ths_or_product, dict): + ths = (ths_or_product.get("ths") or "").strip() + else: + ths = (ths_or_product or "").strip() + if not ths: + return True + m = re.match(r"^([A-Za-z]+)", ths) + letters = m.group(1) if m else ths + return letters not in NO_NIGHT_SESSION_THS and letters.upper() not in NO_NIGHT_SESSION_THS + + +def filter_for_trading_session(rows: list[dict]) -> list[dict]: + """夜盘时段隐藏无夜盘品种。""" + from modules.market.market_sessions import is_night_trading_session + + if not is_night_trading_session(): + return rows + out: list[dict] = [] + for row in rows: + if row.get("has_night_session") is False: + continue + ths = row.get("ths") or row.get("ths_code") or "" + if row.get("has_night_session") is True or product_has_night_session(ths): + out.append(row) + return out + + +def product_category(ths: str) -> str: + return PRODUCT_CATEGORY_MAP.get((ths or "").strip(), "其他") + + +EXCHANGE_ORDER = ["上期所", "上期能源", "大商所", "郑商所", "中金所"] +_MAIN_CACHE: dict[str, tuple[float, dict]] = {} +_CACHE_TTL = 300 +_main_index_lock = threading.Lock() +_main_index: dict[str, dict] = {} +_main_index_ts = 0.0 +_index_refresh_lock = threading.Lock() + + +def build_ths_code(product: dict, year: int, month: int) -> str: + """同花顺软件内显示的合约代码。""" + ex = product["ex"] + letters = product["ths"] + if ex == "CZCE": + return f"{letters}{year % 10}{month:02d}" + return f"{letters}{year % 100:02d}{month:02d}" + + +def build_ths_full_code(product: dict, year: int, month: int) -> str: + """同花顺 iFinD HTTP API 代码,如 ag2608.SHFE""" + ths = build_ths_code(product, year, month) + suffix = THS_EX_SUFFIX.get(product["ex"], product["ex"]) + return f"{ths}.{suffix}" + + +def build_sina_code(product: dict, year: int, month: int) -> str: + letters = product["sina"] + suffix = f"{year % 100:02d}{month:02d}" + if product["ex"] == "CFFEX": + return f"CFF_RE_{letters}{suffix}" + return f"nf_{letters}{suffix}" + + +def build_sina_main_code(product: dict) -> str: + letters = product["sina"] + if product["ex"] == "CFFEX": + return f"CFF_RE_{letters}0" + return f"nf_{letters}0" + + +def _find_product_by_letters(letters: str) -> Optional[dict]: + letters_up = letters.upper() + for p in PRODUCTS: + if p["ths"].upper() == letters_up or p["sina"] == letters_up: + return p + return None + + +def _product_codes(product: dict, ths_code: str, market_code: str, sina_code: str) -> dict: + return { + "ths_code": ths_code, + "market_code": market_code, + "sina_code": sina_code, + "ex": product["ex"], + "name": product["name"], + "exchange": product["exchange"], + } + + +def ths_to_codes(ths_code: str) -> Optional[dict]: + """同花顺合约代码 -> ths_full + sina 回退代码。""" + code = ths_code.strip() + if not code: + return None + + m4 = re.match(r"^([A-Za-z]+)(\d{4})$", code) + if m4: + letters, digits = m4.group(1), m4.group(2) + year = 2000 + int(digits[:2]) + month = int(digits[2:]) + if not 1 <= month <= 12: + return None + product = _find_product_by_letters(letters) + if product: + ths = build_ths_code(product, year, month) + return _product_codes( + product, + ths, + build_ths_full_code(product, year, month), + build_sina_code(product, year, month), + ) + letters_up = letters.upper() + if letters_up in ("IF", "IH", "IC", "IM", "T", "TF", "TS"): + ths = f"{letters_up}{digits}" + return { + "ths_code": ths, + "market_code": f"{ths}.CFFEX", + "sina_code": f"CFF_RE_{letters_up}{digits}", + "ex": "CFFEX", + "name": letters_up, + "exchange": "中金所", + } + + m3 = re.match(r"^([A-Za-z]+)(\d{3})$", code) + if m3: + letters, digits = m3.group(1), m3.group(2) + y_digit = int(digits[0]) + month = int(digits[1:]) + if not 1 <= month <= 12: + return None + year = date.today().year + decade = year // 10 * 10 + candidate = decade + y_digit + if candidate < year - 1: + candidate += 10 + product = _find_product_by_letters(letters) + if product: + ths = build_ths_code(product, candidate, month) + return _product_codes( + product, + ths, + build_ths_full_code(product, candidate, month), + build_sina_code(product, candidate, month), + ) + + return None + + +def ths_to_sina_code(ths_code: str) -> Optional[str]: + codes = ths_to_codes(ths_code) + return codes["sina_code"] if codes else None + + +def parse_contract_year_month(ths_code: str) -> Optional[tuple[int, int]]: + """从同花顺合约代码解析交割年月。""" + code = (ths_code or "").strip() + if not code or "888" in code: + return None + m4 = re.match(r"^([A-Za-z]+)(\d{4})$", code) + if m4: + digits = m4.group(2) + year = 2000 + int(digits[:2]) + month = int(digits[2:]) + if 1 <= month <= 12: + return year, month + m3 = re.match(r"^([A-Za-z]+)(\d{3})$", code) + if m3: + letters, digits = m3.group(1), m3.group(2) + month = int(digits[1:]) + if not 1 <= month <= 12: + return None + y_digit = int(digits[0]) + year = date.today().year + decade = year // 10 * 10 + candidate = decade + y_digit + if candidate < year - 1: + candidate += 10 + product = _find_product_by_letters(letters) + if product: + return candidate, month + return None + + +def is_near_expiry_main(ths_code: str) -> bool: + """主力合约交割月为当月或下月时视为临期。""" + ym = parse_contract_year_month(ths_code) + if not ym: + return False + cy, cm = ym + today = date.today() + months_ahead = (cy - today.year) * 12 + (cm - today.month) + return months_ahead <= 1 + + +def _main_contract_score(raw: dict) -> float: + """主力判定:优先持仓量,其次成交量。""" + oi = float(raw.get("open_interest") or 0) + vol = float(raw.get("volume") or 0) + return oi if oi > 0 else vol + + +def _make_symbol_item( + product: dict, + year: int, + month: int, + volume: float, + open_interest: float = 0, +) -> dict: + ths = build_ths_code(product, year, month) + name = product["name"] + return { + "name": name, + "ths_code": ths, + "market_code": build_ths_full_code(product, year, month), + "sina_code": build_sina_code(product, year, month), + "exchange": product["exchange"], + "contract": f"主力 {ths}", + "display": f"{name} 主力 {ths}", + "input_label": f"{name} {ths}", + "volume": volume, + "open_interest": open_interest, + } + + +def resolve_main_contract(product: dict) -> Optional[dict]: + cache_key = product["sina"] + now = time.time() + cached = _MAIN_CACHE.get(cache_key) + if cached and now - cached[0] < _CACHE_TTL: + return cached[1] + + today = date.today() + y, m = today.year, today.month + best = None + best_score = 0.0 + + for i in range(14): + cy, cm = y, m + i + while cm > 12: + cm -= 12 + cy += 1 + sina = build_sina_code(product, cy, cm) + raw = fetch_raw_for_volume(sina) + if not raw: + continue + score = _main_contract_score(raw) + if score <= 0: + continue + item = _make_symbol_item( + product, cy, cm, raw["volume"], raw.get("open_interest", 0), + ) + if score > best_score: + best_score = score + best = item + + if best is None: + sina_main = build_sina_main_code(product) + raw = fetch_raw_for_volume(sina_main) + if raw: + ths_letters = product["ths"] + ths_main = ( + f"{ths_letters}888" + if product["ex"] != "CFFEX" + else f"{ths_letters.upper()}888" + ) + suffix = THS_EX_SUFFIX.get(product["ex"], product["ex"]) + best = { + "name": product["name"], + "ths_code": ths_main, + "market_code": f"{ths_main}.{suffix}", + "sina_code": sina_main, + "exchange": product["exchange"], + "contract": f"主力连续 {ths_main}", + "display": f"{product['name']} 主力连续 {ths_main}", + "input_label": f"{product['name']} {ths_main}", + "volume": raw.get("volume", 0), + } + + if best: + best = _enrich_item(best, product) + _MAIN_CACHE[cache_key] = (now, best) + return best + + +def _enrich_item(item: dict, product: Optional[dict] = None) -> dict: + out = dict(item) + if not out.get("input_label"): + out["input_label"] = f"{out.get('name', '')} {out.get('ths_code', '')}".strip() + out["near_expiry"] = is_near_expiry_main(out.get("ths_code", "")) + if product is None and out.get("ths_code"): + product = _product_for_contract_code(out["ths_code"]) + if product is not None: + out["has_night_session"] = product_has_night_session(product) + elif "has_night_session" not in out: + out["has_night_session"] = product_has_night_session(out.get("ths_code") or "") + return out + + +def refresh_main_index(): + """后台预热全部品种主力合约,搜索时只读本地缓存。""" + global _main_index, _main_index_ts + with _index_refresh_lock: + new_idx: dict[str, dict] = {} + with ThreadPoolExecutor(max_workers=10) as pool: + futures = {pool.submit(resolve_main_contract, p): p for p in PRODUCTS} + for fut in as_completed(futures): + product = futures[fut] + try: + main = fut.result() + if main: + new_idx[product["sina"]] = _enrich_item(main, product) + except Exception: + pass + with _main_index_lock: + _main_index = new_idx + _main_index_ts = time.time() + + +def _warm_loop(): + while True: + try: + refresh_main_index() + except Exception: + pass + time.sleep(_CACHE_TTL) + + +def _start_warm_thread(): + threading.Thread(target=_warm_loop, daemon=True).start() + + +def _stub_main_contract(product: dict) -> dict: + """缓存未就绪时的快速占位(当月合约),避免首次打开搜索为空。""" + today = date.today() + return _enrich_item(_make_symbol_item(product, today.year, today.month, 0), product) + + +def _product_matches(product: dict, q_lower: str) -> bool: + name_lower = product["name"].lower() + if q_lower in name_lower: + return True + if len(q_lower) >= 2: + ths_lower = product["ths"].lower() + sina_lower = product["sina"].lower() + if q_lower in ths_lower or q_lower in sina_lower: + return True + return False + + +def _match_score(product: dict, q_lower: str) -> int: + name_lower = product["name"].lower() + if name_lower == q_lower: + return 200 + if name_lower.startswith(q_lower): + return 150 + if q_lower in name_lower: + return 100 + ths_lower = product["ths"].lower() + if ths_lower == q_lower: + return 90 + if ths_lower.startswith(q_lower): + return 70 + if product["sina"].lower() == q_lower: + return 80 + return 10 + + +def search_symbols(query: str, *, capital: float | None = None, ctp_connected: bool = True) -> list: + q = query.strip() + if not q: + return [] + + q_lower = q.lower() + from modules.market.market_sessions import is_night_trading_session + from modules.trading.product_recommend import filter_products_for_capital, should_apply_small_account_scope + + night_only = is_night_trading_session() + product_pool = PRODUCTS + if capital is not None and should_apply_small_account_scope(capital, ctp_connected=ctp_connected): + product_pool = filter_products_for_capital( + PRODUCTS, capital, ctp_connected=ctp_connected, + ) + with _main_index_lock: + index = dict(_main_index) + index_ready = bool(index) + + scored: list[tuple[int, dict]] = [] + for p in product_pool: + if night_only and not product_has_night_session(p): + continue + if not _product_matches(p, q_lower): + continue + main = index.get(p["sina"]) + if not main and not index_ready: + main = _stub_main_contract(p) + if main: + scored.append((_match_score(p, q_lower), main)) + + scored.sort(key=lambda x: -x[0]) + results = [item for _, item in scored[:12]] + results = filter_for_trading_session(results) + + if not results and len(q) >= 3: + codes = ths_to_codes(q) + if codes: + product = _product_for_contract_code(codes["ths_code"]) + if capital is not None and should_apply_small_account_scope( + capital, ctp_connected=ctp_connected, + ): + from modules.trading.product_recommend import product_in_small_account_whitelist + if not product or not product_in_small_account_whitelist(product): + return results + raw = fetch_raw_for_volume(codes["sina_code"]) + name = raw["name"] if raw else q + results.append(_enrich_item({ + "name": name, + "ths_code": codes["ths_code"], + "market_code": codes["market_code"], + "sina_code": codes["sina_code"], + "exchange": "", + "contract": codes["ths_code"], + "display": f"{name} ({codes['ths_code']})", + "volume": raw.get("volume", 0) if raw else 0, + })) + results = filter_for_trading_session(results) + + return results + + +def enrich_recommend_row(row: dict) -> dict: + """补全推荐行字段(含是否夜盘)。""" + out = dict(row) + ths = out.get("ths") or "" + out["has_night_session"] = product_has_night_session(ths) + return out + + +_THS_TO_PRODUCT = {p["ths"]: p for p in PRODUCTS} +for _p in PRODUCTS: + _THS_TO_PRODUCT.setdefault(_p["ths"].lower(), _p) + + +def _product_for_ths(ths: str) -> Optional[dict]: + key = (ths or "").strip() + if not key: + return None + return _THS_TO_PRODUCT.get(key) or _THS_TO_PRODUCT.get(key.lower()) + + +def _product_for_contract_code(ths_code: str) -> Optional[dict]: + sym = (ths_code or "").strip() + if not sym: + return None + m = re.match(r"^([A-Za-z]+)", sym) + if not m: + return None + return _find_product_by_letters(m.group(1)) + + +def position_symbol_meta(ths_code: str) -> dict: + """持仓/委托展示:品种名、交易所、是否主力合约。""" + sym = (ths_code or "").strip() + if not sym: + return {"name": "", "exchange": "", "is_main": False} + product = _product_for_contract_code(sym) + if not product: + return {"name": sym, "exchange": "", "is_main": False} + codes = ths_to_codes(sym) + norm = (codes["ths_code"] if codes else sym).strip().lower() + is_main = False + with _main_index_lock: + main_item = _main_index.get(product["sina"]) + if main_item: + main_ths = (main_item.get("ths_code") or "").strip().lower() + is_main = main_ths == norm or main_ths == sym.lower() + return { + "name": product["name"], + "exchange": product.get("exchange") or "", + "is_main": is_main, + } + + +def _item_from_recommend_row(row: dict, product: dict) -> Optional[dict]: + """由可开仓缓存行快速构造下拉项(不在 HTTP 请求中解析主力)。""" + name = row.get("name") or product["name"] + main_code = (row.get("main_code") or "").strip() + max_lots = row.get("max_lots") + + if main_code: + codes = ths_to_codes(main_code) + if codes: + ths = codes["ths_code"] + item = { + "name": name, + "ths_code": ths, + "market_code": codes.get("market_code") or "", + "sina_code": codes.get("sina_code") or "", + "exchange": product["exchange"], + "contract": f"主力 {ths}", + "display": f"{name} 主力 {ths}", + "input_label": f"{name} {ths}", + } + if max_lots is not None: + item["max_lots"] = max_lots + return _enrich_item(item, product) + + with _main_index_lock: + main = _main_index.get(product["sina"]) + if main: + item = dict(main) + if max_lots is not None: + item["max_lots"] = max_lots + return _enrich_item(item, product) + + item = _stub_main_contract(product) + if max_lots is not None: + item["max_lots"] = max_lots + return item + + +def list_recommended_symbols_grouped(recommend_rows: list[dict]) -> list[dict]: + """按交易所分类返回可开仓品种对应的主力合约(品种选择下拉用)。""" + if not recommend_rows: + return [] + + buckets: dict[str, list] = defaultdict(list) + seen: set[str] = set() + for row in recommend_rows: + if row.get("status") not in ("ok", "margin_ok"): + continue + ths_key = (row.get("ths") or "").strip() + if not ths_key or ths_key in seen: + continue + product = _product_for_ths(ths_key) + if not product: + continue + if not product_has_night_session(product): + from modules.market.market_sessions import is_night_trading_session + if is_night_trading_session(): + continue + seen.add(ths_key) + item = _item_from_recommend_row(row, product) + if not item: + continue + buckets[product["exchange"]].append(item) + + groups: list[dict] = [] + for cat in EXCHANGE_ORDER: + items = buckets.get(cat) + if items: + groups.append({"category": cat, "items": items}) + return groups + + +def list_main_contracts_grouped() -> list[dict]: + """按交易所分类返回全部品种主力合约(行情页下拉用)。""" + with _main_index_lock: + index = dict(_main_index) + + if len(index) < len(PRODUCTS) // 2: + refresh_main_index() + with _main_index_lock: + index = dict(_main_index) + + buckets: dict[str, list] = defaultdict(list) + for p in PRODUCTS: + main = index.get(p["sina"]) + if not main: + resolved = resolve_main_contract(p) + if resolved: + main = _enrich_item(resolved) + if main: + buckets[p["exchange"]].append(main) + + groups: list[dict] = [] + for cat in EXCHANGE_ORDER: + items = buckets.get(cat) + if items: + groups.append({"category": cat, "items": items}) + return groups + + +_start_warm_thread() + + +def get_price(market_code: str, sina_code: str = "") -> Optional[float]: + return market_get_price(market_code, sina_code) diff --git a/trading_context.py b/modules/core/trading_context.py similarity index 90% rename from trading_context.py rename to modules/core/trading_context.py index 56edda7..5e5e39e 100644 --- a/trading_context.py +++ b/modules/core/trading_context.py @@ -1,184 +1,184 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""交易上下文:设置读取、资金、模式。""" -from __future__ import annotations - -from typing import Callable, Optional - -TRADING_MODE_SIM = "simulation" # SimNow CTP -TRADING_MODE_LIVE = "live" # 期货公司 CTP - - -def get_trading_mode(get_setting: Callable[[str, str], str]) -> str: - m = (get_setting("trading_mode", TRADING_MODE_SIM) or TRADING_MODE_SIM).strip().lower() - return m if m in (TRADING_MODE_SIM, TRADING_MODE_LIVE) else TRADING_MODE_SIM - - -def get_sizing_mode(get_setting: Callable[[str, str], str]) -> str: - from position_sizing import normalize_sizing_mode - return normalize_sizing_mode(get_setting("position_sizing_mode", "fixed")) - - -def get_fixed_lots(get_setting: Callable[[str, str], str]) -> int: - try: - return max(1, int(float(get_setting("fixed_lots", "1") or 1))) - except (TypeError, ValueError): - return 1 - - -def get_fixed_amount(get_setting: Callable[[str, str], str]) -> float: - try: - return max(1.0, float(get_setting("fixed_amount", "5000") or 5000)) - except (TypeError, ValueError): - return 5000.0 - - -def get_risk_percent(get_setting: Callable[[str, str], str]) -> float: - try: - return max(0.1, float(get_setting("risk_percent", "1") or 1)) - except (TypeError, ValueError): - return 1.0 - - -def get_max_margin_pct(get_setting: Callable[[str, str], str]) -> float: - """单笔/总仓位保证金占权益上限(%),默认 30。""" - try: - return max(1.0, min(100.0, float(get_setting("max_margin_pct", "30") or 30))) - except (TypeError, ValueError): - return 30.0 - - -def get_roll_max_margin_pct(get_setting: Callable[[str, str], str]) -> float: - """滚仓后总保证金占权益上限(%),默认 50。""" - try: - return max(1.0, min(100.0, float(get_setting("roll_max_margin_pct", "50") or 50))) - except (TypeError, ValueError): - return 50.0 - - -def get_trailing_be_tick_buffer(get_setting: Callable[[str, str], str]) -> int: - """移动保本:止损移至开仓价 ± N 个最小变动价位(默认 2)。""" - try: - return max(1, min(20, int(float(get_setting("trailing_be_tick_buffer", "2") or 2)))) - except (TypeError, ValueError): - return 2 - - -def get_pending_order_timeout_min(get_setting: Callable[[str, str], str]) -> int: - """开仓限价委托未成交自动撤单时间(分钟),默认 5。""" - try: - return max(1, min(60, int(float(get_setting("pending_order_timeout_min", "5") or 5)))) - except (TypeError, ValueError): - return 5 - - -def get_pending_order_timeout_sec(get_setting: Callable[[str, str], str]) -> int: - return get_pending_order_timeout_min(get_setting) * 60 - - -def _cached_ctp_account(mode: str) -> dict[str, float]: - """CTP 未连接时,用最近一次 worker/持仓快照里的账户权益。""" - import json - - try: - from position_stream import position_hub - - snap = position_hub.get_snapshot() or {} - cap = float(snap.get("capital") or 0) - if cap > 0: - return {"balance": cap} - except Exception: - pass - try: - from db_conn import connect_db - - conn = connect_db() - try: - row = conn.execute( - "SELECT value FROM ctp_worker_snapshots WHERE key='account' LIMIT 1" - ).fetchone() - finally: - conn.close() - if row and row["value"]: - acc = json.loads(row["value"]) - balance = float(acc.get("balance") or 0) - available = acc.get("available") - out: dict[str, float] = {} - if balance > 0: - out["balance"] = balance - if available is not None: - out["available"] = float(available) - return out - except Exception: - pass - del mode - return {} - - -def _ctp_status_from_snapshot(mode: str) -> Optional[dict]: - """读持仓快照中的 CTP 状态,避免页面渲染同步 IPC。""" - try: - from position_stream import position_hub - - snap = position_hub.get_snapshot() or {} - st = snap.get("ctp_status") - if isinstance(st, dict) and st: - return st - except Exception: - pass - del mode - return None - - -def get_account_capital(conn, get_setting: Callable[[str, str], str]) -> float: - """优先读持仓/Worker 快照权益;无快照时才同步问 CTP。""" - del conn - mode = get_trading_mode(get_setting) - cached = _cached_ctp_account(mode) - balance = float(cached.get("balance") or 0) - if balance > 0: - return balance - try: - from vnpy_bridge import ctp_status, get_ctp_balance - - st = ctp_status(mode) - if st.get("connected"): - bal = get_ctp_balance(mode) - if bal and bal > 0: - return float(bal) - except Exception: - pass - try: - return float(get_setting("live_capital", "0") or 0) - except (TypeError, ValueError): - return 0.0 - - -def get_recommend_capital(conn, get_setting: Callable[[str, str], str]) -> float: - """可开仓品种表用权益:已连接 CTP 用柜台权益,未连接固定 10 万。""" - from product_recommend import DISCONNECTED_RECOMMEND_CAPITAL - - if is_ctp_connected(get_setting): - return get_account_capital(conn, get_setting) - return float(DISCONNECTED_RECOMMEND_CAPITAL) - - -def is_ctp_connected(get_setting: Callable[[str, str], str]) -> bool: - """当前交易模式(SimNow / 实盘)是否已连接 CTP。""" - mode = get_trading_mode(get_setting) - st = _ctp_status_from_snapshot(mode) - if st is not None: - return bool(st.get("connected")) - try: - from vnpy_bridge import ctp_status - - return bool(ctp_status(mode).get("connected")) - except Exception: - return False - - -def trading_mode_label(get_setting: Callable[[str, str], str]) -> str: - return "SimNow" if get_trading_mode(get_setting) == TRADING_MODE_SIM else "期货公司实盘" +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""交易上下文:设置读取、资金、模式。""" +from __future__ import annotations + +from typing import Callable, Optional + +TRADING_MODE_SIM = "simulation" # SimNow CTP +TRADING_MODE_LIVE = "live" # 期货公司 CTP + + +def get_trading_mode(get_setting: Callable[[str, str], str]) -> str: + m = (get_setting("trading_mode", TRADING_MODE_SIM) or TRADING_MODE_SIM).strip().lower() + return m if m in (TRADING_MODE_SIM, TRADING_MODE_LIVE) else TRADING_MODE_SIM + + +def get_sizing_mode(get_setting: Callable[[str, str], str]) -> str: + from modules.trading.position_sizing import normalize_sizing_mode + return normalize_sizing_mode(get_setting("position_sizing_mode", "fixed")) + + +def get_fixed_lots(get_setting: Callable[[str, str], str]) -> int: + try: + return max(1, int(float(get_setting("fixed_lots", "1") or 1))) + except (TypeError, ValueError): + return 1 + + +def get_fixed_amount(get_setting: Callable[[str, str], str]) -> float: + try: + return max(1.0, float(get_setting("fixed_amount", "5000") or 5000)) + except (TypeError, ValueError): + return 5000.0 + + +def get_risk_percent(get_setting: Callable[[str, str], str]) -> float: + try: + return max(0.1, float(get_setting("risk_percent", "1") or 1)) + except (TypeError, ValueError): + return 1.0 + + +def get_max_margin_pct(get_setting: Callable[[str, str], str]) -> float: + """单笔/总仓位保证金占权益上限(%),默认 30。""" + try: + return max(1.0, min(100.0, float(get_setting("max_margin_pct", "30") or 30))) + except (TypeError, ValueError): + return 30.0 + + +def get_roll_max_margin_pct(get_setting: Callable[[str, str], str]) -> float: + """滚仓后总保证金占权益上限(%),默认 50。""" + try: + return max(1.0, min(100.0, float(get_setting("roll_max_margin_pct", "50") or 50))) + except (TypeError, ValueError): + return 50.0 + + +def get_trailing_be_tick_buffer(get_setting: Callable[[str, str], str]) -> int: + """移动保本:止损移至开仓价 ± N 个最小变动价位(默认 2)。""" + try: + return max(1, min(20, int(float(get_setting("trailing_be_tick_buffer", "2") or 2)))) + except (TypeError, ValueError): + return 2 + + +def get_pending_order_timeout_min(get_setting: Callable[[str, str], str]) -> int: + """开仓限价委托未成交自动撤单时间(分钟),默认 5。""" + try: + return max(1, min(60, int(float(get_setting("pending_order_timeout_min", "5") or 5)))) + except (TypeError, ValueError): + return 5 + + +def get_pending_order_timeout_sec(get_setting: Callable[[str, str], str]) -> int: + return get_pending_order_timeout_min(get_setting) * 60 + + +def _cached_ctp_account(mode: str) -> dict[str, float]: + """CTP 未连接时,用最近一次 worker/持仓快照里的账户权益。""" + import json + + try: + from modules.trading.position_stream import position_hub + + snap = position_hub.get_snapshot() or {} + cap = float(snap.get("capital") or 0) + if cap > 0: + return {"balance": cap} + except Exception: + pass + try: + from modules.core.db_conn import connect_db + + conn = connect_db() + try: + row = conn.execute( + "SELECT value FROM ctp_worker_snapshots WHERE key='account' LIMIT 1" + ).fetchone() + finally: + conn.close() + if row and row["value"]: + acc = json.loads(row["value"]) + balance = float(acc.get("balance") or 0) + available = acc.get("available") + out: dict[str, float] = {} + if balance > 0: + out["balance"] = balance + if available is not None: + out["available"] = float(available) + return out + except Exception: + pass + del mode + return {} + + +def _ctp_status_from_snapshot(mode: str) -> Optional[dict]: + """读持仓快照中的 CTP 状态,避免页面渲染同步 IPC。""" + try: + from modules.trading.position_stream import position_hub + + snap = position_hub.get_snapshot() or {} + st = snap.get("ctp_status") + if isinstance(st, dict) and st: + return st + except Exception: + pass + del mode + return None + + +def get_account_capital(conn, get_setting: Callable[[str, str], str]) -> float: + """优先读持仓/Worker 快照权益;无快照时才同步问 CTP。""" + del conn + mode = get_trading_mode(get_setting) + cached = _cached_ctp_account(mode) + balance = float(cached.get("balance") or 0) + if balance > 0: + return balance + try: + from modules.ctp.vnpy_bridge import ctp_status, get_ctp_balance + + st = ctp_status(mode) + if st.get("connected"): + bal = get_ctp_balance(mode) + if bal and bal > 0: + return float(bal) + except Exception: + pass + try: + return float(get_setting("live_capital", "0") or 0) + except (TypeError, ValueError): + return 0.0 + + +def get_recommend_capital(conn, get_setting: Callable[[str, str], str]) -> float: + """可开仓品种表用权益:已连接 CTP 用柜台权益,未连接固定 10 万。""" + from modules.trading.product_recommend import DISCONNECTED_RECOMMEND_CAPITAL + + if is_ctp_connected(get_setting): + return get_account_capital(conn, get_setting) + return float(DISCONNECTED_RECOMMEND_CAPITAL) + + +def is_ctp_connected(get_setting: Callable[[str, str], str]) -> bool: + """当前交易模式(SimNow / 实盘)是否已连接 CTP。""" + mode = get_trading_mode(get_setting) + st = _ctp_status_from_snapshot(mode) + if st is not None: + return bool(st.get("connected")) + try: + from modules.ctp.vnpy_bridge import ctp_status + + return bool(ctp_status(mode).get("connected")) + except Exception: + return False + + +def trading_mode_label(get_setting: Callable[[str, str], str]) -> str: + return "SimNow" if get_trading_mode(get_setting) == TRADING_MODE_SIM else "期货公司实盘" diff --git a/modules/ctp/__init__.py b/modules/ctp/__init__.py new file mode 100644 index 0000000..2bb8aaf --- /dev/null +++ b/modules/ctp/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2025-2026 马建军. All rights reserved. + +"""CTP / vn.py integration — single-process mode.""" + + +def register(deps) -> None: + del deps + + +__all__ = ["register"] diff --git a/ctp_entry_price.py b/modules/ctp/ctp_entry_price.py similarity index 88% rename from ctp_entry_price.py rename to modules/ctp/ctp_entry_price.py index aaf93e2..45aa5c5 100644 --- a/ctp_entry_price.py +++ b/modules/ctp/ctp_entry_price.py @@ -1,63 +1,63 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 详见 LICENSE.zh-CN.txt - -"""CTP 持仓均价:仅使用柜台持仓回报(vnpy pos.price = PositionCost 加权)。""" -from __future__ import annotations - -from typing import Any, Optional - -from contract_specs import get_contract_spec -from ctp_symbol import ths_to_vnpy_symbol -from symbols import ths_to_codes - - -def symbols_match(ctp_sym: str, ths: str) -> bool: - a = (ctp_sym or "").lower() - b = (ths or "").lower() - if a == b: - return True - if a and b and a.split(".")[0] == b.split(".")[0]: - return True - try: - vnpy_sym, _ = ths_to_vnpy_symbol(ths) - if a == vnpy_sym.lower(): - return True - except Exception: - pass - try: - vnpy_sym, _ = ths_to_vnpy_symbol(ctp_sym) - if vnpy_sym.lower() == b.split(".")[0]: - return True - except Exception: - pass - return False - - -def _ths_code(sym: str) -> str: - codes = ths_to_codes(sym) or {} - return codes.get("ths_code") or sym - - -def round_to_tick(price: float, sym: str) -> float: - tick = float(get_contract_spec(_ths_code(sym)).get("tick_size") or 1.0) - if tick <= 0: - return round(price, 2) - return round(round(price / tick) * tick, 4) - - -def resolve_ctp_entry( - sym: str, - direction: str, - ctp: Optional[dict[str, Any]], - trades: Optional[list[dict[str, Any]]] = None, - *, - tick: Optional[float] = None, -) -> tuple[float, str]: - """均价:仅柜台持仓价(trades/tick 参数保留兼容,不参与计算)。""" - del direction, trades, tick - if not ctp: - return 0.0, "none" - pos_avg = float(ctp.get("avg_price") or 0) - if pos_avg > 0: - return round_to_tick(pos_avg, sym), "ctp" - return 0.0, "none" +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 详见 LICENSE.zh-CN.txt + +"""CTP 持仓均价:仅使用柜台持仓回报(vnpy pos.price = PositionCost 加权)。""" +from __future__ import annotations + +from typing import Any, Optional + +from modules.core.contract_specs import get_contract_spec +from modules.ctp.ctp_symbol import ths_to_vnpy_symbol +from modules.core.symbols import ths_to_codes + + +def symbols_match(ctp_sym: str, ths: str) -> bool: + a = (ctp_sym or "").lower() + b = (ths or "").lower() + if a == b: + return True + if a and b and a.split(".")[0] == b.split(".")[0]: + return True + try: + vnpy_sym, _ = ths_to_vnpy_symbol(ths) + if a == vnpy_sym.lower(): + return True + except Exception: + pass + try: + vnpy_sym, _ = ths_to_vnpy_symbol(ctp_sym) + if vnpy_sym.lower() == b.split(".")[0]: + return True + except Exception: + pass + return False + + +def _ths_code(sym: str) -> str: + codes = ths_to_codes(sym) or {} + return codes.get("ths_code") or sym + + +def round_to_tick(price: float, sym: str) -> float: + tick = float(get_contract_spec(_ths_code(sym)).get("tick_size") or 1.0) + if tick <= 0: + return round(price, 2) + return round(round(price / tick) * tick, 4) + + +def resolve_ctp_entry( + sym: str, + direction: str, + ctp: Optional[dict[str, Any]], + trades: Optional[list[dict[str, Any]]] = None, + *, + tick: Optional[float] = None, +) -> tuple[float, str]: + """均价:仅柜台持仓价(trades/tick 参数保留兼容,不参与计算)。""" + del direction, trades, tick + if not ctp: + return 0.0, "none" + pos_avg = float(ctp.get("avg_price") or 0) + if pos_avg > 0: + return round_to_tick(pos_avg, sym), "ctp" + return 0.0, "none" diff --git a/ctp_fee_sync.py b/modules/ctp/ctp_fee_sync.py similarity index 92% rename from ctp_fee_sync.py rename to modules/ctp/ctp_fee_sync.py index 86986f6..d7aa24c 100644 --- a/ctp_fee_sync.py +++ b/modules/ctp/ctp_fee_sync.py @@ -1,144 +1,144 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""从 CTP 柜台同步手续费率(SimNow / 期货公司)。""" -from __future__ import annotations - -import logging -import re -import time -from typing import Optional - -from contract_specs import get_contract_spec -from fee_specs import upsert_fee_rate -from vnpy_bridge import get_bridge - -logger = logging.getLogger(__name__) - - -def _product_from_instrument(instrument_id: str) -> str: - m = re.match(r"^([A-Za-z]+)", instrument_id or "") - return m.group(1).lower() if m else "" - - -def ctp_commission_to_fee_fields(data: dict, ths_code: str) -> dict: - """CTP OnRspQryInstrumentCommissionRate → fee_rates 字段。""" - mult = int(get_contract_spec(ths_code)["mult"]) - exchange = str(data.get("ExchangeID") or "").strip() - return { - "exchange": exchange, - "mult": mult, - "open_fixed": float(data.get("OpenRatioByVolume") or 0), - "open_ratio": float(data.get("OpenRatioByMoney") or 0), - "close_yesterday_fixed": float(data.get("CloseRatioByVolume") or 0), - "close_yesterday_ratio": float(data.get("CloseRatioByMoney") or 0), - "close_today_fixed": float(data.get("CloseTodayRatioByVolume") or 0), - "close_today_ratio": float(data.get("CloseTodayRatioByMoney") or 0), - "source": "ctp", - } - - -def _collect_main_ths_codes() -> list[str]: - """从主力列表收集同花顺合约代码(供 CTP 手续费查询)。""" - from datetime import date - - from symbols import PRODUCTS, build_ths_code, list_main_contracts_grouped - - symbols: list[str] = [] - for group in list_main_contracts_grouped(): - for item in group.get("items") or []: - ths = (item.get("ths_code") or item.get("ths") or item.get("code") or "").strip() - if ths and not ths.endswith("888"): - symbols.append(ths) - - if symbols: - return symbols - - today = date.today() - for p in PRODUCTS: - symbols.append(build_ths_code(p, today.year, today.month)) - return symbols - - -def sync_fees_from_ctp(mode: str, *, max_symbols: int = 80) -> tuple[int, str]: - """CTP 已连接时查询手续费并写入 fee_rates(source=ctp,覆盖同品种旧数据)。""" - bridge = get_bridge() - if not bridge.available(): - return 0, "vnpy 未安装" - if bridge.connected_mode != mode: - return 0, "请先连接 CTP" - if not bridge.ping(): - return 0, "CTP 连接无效,请重连" - - seen: set[str] = set() - ok = 0 - errors = 0 - - batch = bridge.query_all_commissions(mode=mode) - if batch: - for raw in batch: - inst = str(raw.get("InstrumentID") or "").strip() - product = _product_from_instrument(inst) - if not product or product in seen: - continue - seen.add(product) - try: - fields = ctp_commission_to_fee_fields(raw, inst or product) - upsert_fee_rate(product, fields) - ok += 1 - except Exception as exc: - logger.debug("CTP fee batch %s: %s", inst, exc) - errors += 1 - if ok > 0: - msg = f"已从 CTP 批量同步 {ok} 个品种手续费" - if errors: - msg += f"({errors} 个跳过)" - return ok, msg - - symbols = _collect_main_ths_codes()[:max_symbols] - - if not symbols: - return 0, "无主力合约列表" - - for ths in symbols: - product = _product_from_instrument(ths) - if not product or product in seen: - continue - seen.add(product) - try: - raw = bridge.query_instrument_commission(ths, mode=mode) - if not raw: - errors += 1 - continue - fields = ctp_commission_to_fee_fields(raw, ths) - upsert_fee_rate(product, fields) - ok += 1 - time.sleep(0.35) - except Exception as exc: - logger.debug("CTP fee sync %s: %s", ths, exc) - errors += 1 - - if ok == 0: - return 0, f"CTP 未返回手续费率(失败 {errors} 次),请确认柜台支持查询" - msg = f"已从 CTP 同步 {ok} 个品种手续费" - if errors: - msg += f"({errors} 个跳过)" - return ok, msg - - -def sync_fee_for_symbol(mode: str, ths_code: str) -> Optional[dict]: - """单品种按需从 CTP 拉取并缓存。""" - bridge = get_bridge() - if bridge.connected_mode != mode or not bridge.ping(): - return None - raw = bridge.query_instrument_commission(ths_code, mode=mode) - if not raw: - return None - product = _product_from_instrument(ths_code) - if not product: - return None - fields = ctp_commission_to_fee_fields(raw, ths_code) - upsert_fee_rate(product, fields) - return fields +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""从 CTP 柜台同步手续费率(SimNow / 期货公司)。""" +from __future__ import annotations + +import logging +import re +import time +from typing import Optional + +from modules.core.contract_specs import get_contract_spec +from modules.fees.fee_specs import upsert_fee_rate +from modules.ctp.vnpy_bridge import get_bridge + +logger = logging.getLogger(__name__) + + +def _product_from_instrument(instrument_id: str) -> str: + m = re.match(r"^([A-Za-z]+)", instrument_id or "") + return m.group(1).lower() if m else "" + + +def ctp_commission_to_fee_fields(data: dict, ths_code: str) -> dict: + """CTP OnRspQryInstrumentCommissionRate → fee_rates 字段。""" + mult = int(get_contract_spec(ths_code)["mult"]) + exchange = str(data.get("ExchangeID") or "").strip() + return { + "exchange": exchange, + "mult": mult, + "open_fixed": float(data.get("OpenRatioByVolume") or 0), + "open_ratio": float(data.get("OpenRatioByMoney") or 0), + "close_yesterday_fixed": float(data.get("CloseRatioByVolume") or 0), + "close_yesterday_ratio": float(data.get("CloseRatioByMoney") or 0), + "close_today_fixed": float(data.get("CloseTodayRatioByVolume") or 0), + "close_today_ratio": float(data.get("CloseTodayRatioByMoney") or 0), + "source": "ctp", + } + + +def _collect_main_ths_codes() -> list[str]: + """从主力列表收集同花顺合约代码(供 CTP 手续费查询)。""" + from datetime import date + + from modules.core.symbols import PRODUCTS, build_ths_code, list_main_contracts_grouped + + symbols: list[str] = [] + for group in list_main_contracts_grouped(): + for item in group.get("items") or []: + ths = (item.get("ths_code") or item.get("ths") or item.get("code") or "").strip() + if ths and not ths.endswith("888"): + symbols.append(ths) + + if symbols: + return symbols + + today = date.today() + for p in PRODUCTS: + symbols.append(build_ths_code(p, today.year, today.month)) + return symbols + + +def sync_fees_from_ctp(mode: str, *, max_symbols: int = 80) -> tuple[int, str]: + """CTP 已连接时查询手续费并写入 fee_rates(source=ctp,覆盖同品种旧数据)。""" + bridge = get_bridge() + if not bridge.available(): + return 0, "vnpy 未安装" + if bridge.connected_mode != mode: + return 0, "请先连接 CTP" + if not bridge.ping(): + return 0, "CTP 连接无效,请重连" + + seen: set[str] = set() + ok = 0 + errors = 0 + + batch = bridge.query_all_commissions(mode=mode) + if batch: + for raw in batch: + inst = str(raw.get("InstrumentID") or "").strip() + product = _product_from_instrument(inst) + if not product or product in seen: + continue + seen.add(product) + try: + fields = ctp_commission_to_fee_fields(raw, inst or product) + upsert_fee_rate(product, fields) + ok += 1 + except Exception as exc: + logger.debug("CTP fee batch %s: %s", inst, exc) + errors += 1 + if ok > 0: + msg = f"已从 CTP 批量同步 {ok} 个品种手续费" + if errors: + msg += f"({errors} 个跳过)" + return ok, msg + + symbols = _collect_main_ths_codes()[:max_symbols] + + if not symbols: + return 0, "无主力合约列表" + + for ths in symbols: + product = _product_from_instrument(ths) + if not product or product in seen: + continue + seen.add(product) + try: + raw = bridge.query_instrument_commission(ths, mode=mode) + if not raw: + errors += 1 + continue + fields = ctp_commission_to_fee_fields(raw, ths) + upsert_fee_rate(product, fields) + ok += 1 + time.sleep(0.35) + except Exception as exc: + logger.debug("CTP fee sync %s: %s", ths, exc) + errors += 1 + + if ok == 0: + return 0, f"CTP 未返回手续费率(失败 {errors} 次),请确认柜台支持查询" + msg = f"已从 CTP 同步 {ok} 个品种手续费" + if errors: + msg += f"({errors} 个跳过)" + return ok, msg + + +def sync_fee_for_symbol(mode: str, ths_code: str) -> Optional[dict]: + """单品种按需从 CTP 拉取并缓存。""" + bridge = get_bridge() + if bridge.connected_mode != mode or not bridge.ping(): + return None + raw = bridge.query_instrument_commission(ths_code, mode=mode) + if not raw: + return None + product = _product_from_instrument(ths_code) + if not product: + return None + fields = ctp_commission_to_fee_fields(raw, ths_code) + upsert_fee_rate(product, fields) + return fields diff --git a/ctp_fee_worker.py b/modules/ctp/ctp_fee_worker.py similarity index 94% rename from ctp_fee_worker.py rename to modules/ctp/ctp_fee_worker.py index d7ab04b..7bb1ae2 100644 --- a/ctp_fee_worker.py +++ b/modules/ctp/ctp_fee_worker.py @@ -1,131 +1,131 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""CTP 手续费后台同步:每日一次写入数据库,前端只读展示。""" -from __future__ import annotations - -import logging -import threading -import time -from datetime import date, datetime -from typing import Callable, Optional -from zoneinfo import ZoneInfo - -logger = logging.getLogger(__name__) - -TZ = ZoneInfo("Asia/Shanghai") -FEE_SYNC_KEY = "ctp_fee_last_sync" -CHECK_INTERVAL_SEC = 3600 -_sync_lock = threading.Lock() - - -def fee_sync_in_progress() -> bool: - return _sync_lock.locked() - - -def _today_str() -> str: - return datetime.now(TZ).date().isoformat() - - -def get_fee_last_sync(get_setting: Callable[[str, str], str]) -> str: - return (get_setting(FEE_SYNC_KEY, "") or "").strip() - - -def fees_synced_today(get_setting: Callable[[str, str], str]) -> bool: - last = get_fee_last_sync(get_setting) - return bool(last) and last[:10] == _today_str() - - -def mark_fees_synced(set_setting: Callable[[str, str], None]) -> None: - set_setting(FEE_SYNC_KEY, datetime.now(TZ).isoformat(timespec="seconds")) - - -def try_daily_ctp_fee_sync( - mode: str, - *, - get_setting: Callable[[str, str], str], - set_setting: Callable[[str, str], None], - force: bool = False, -) -> tuple[int, str]: - """CTP 已连接且今日未同步时拉取费率入库;force=True 忽略日期限制。""" - if not force and fees_synced_today(get_setting): - return 0, "今日已从 CTP 同步过,无需重复(可点「立即同步」强制刷新)" - - with _sync_lock: - if not force and fees_synced_today(get_setting): - return 0, "今日已从 CTP 同步过" - - t0 = time.monotonic() - from ctp_fee_sync import sync_fees_from_ctp - - count, msg = sync_fees_from_ctp(mode) - elapsed = time.monotonic() - t0 - if count > 0: - mark_fees_synced(set_setting) - msg = f"{msg}(耗时 {elapsed:.1f} 秒)" - logger.info("CTP 手续费每日同步: %s", msg) - elif force: - msg = f"{msg}(耗时 {elapsed:.1f} 秒)" - logger.warning("CTP 手续费强制同步未写入: %s", msg) - return count, msg - - -def schedule_ctp_fee_sync( - mode: str, - *, - get_setting: Callable[[str, str], str], - set_setting: Callable[[str, str], None], - force: bool = False, -) -> tuple[bool, str]: - """后台线程同步,避免阻塞 Web 请求。""" - if _sync_lock.locked(): - return False, "手续费同步进行中,请稍后再试(约 1~3 分钟)" - - def _run() -> None: - try: - try_daily_ctp_fee_sync( - mode, - get_setting=get_setting, - set_setting=set_setting, - force=force, - ) - except Exception as exc: - logger.exception("CTP 手续费后台同步失败: %s", exc) - - threading.Thread(target=_run, daemon=True, name="ctp-fee-sync-run").start() - if force: - return True, "已在后台开始同步,约 30 秒~2 分钟完成,请稍后刷新本页查看" - return True, "已在后台检查同步,请稍后刷新本页" - - -def start_ctp_fee_worker( - *, - get_mode_fn: Callable[[], str], - get_setting_fn: Callable[[str, str], str], - set_setting_fn: Callable[[str, str], None], - interval: int = CHECK_INTERVAL_SEC, -) -> None: - """后台线程:每小时检查,CTP 已连接且当日未同步则自动同步。""" - - def _loop() -> None: - time.sleep(20) - while True: - try: - from vnpy_bridge import ctp_status - - mode = get_mode_fn() - st = ctp_status(mode) - if st.get("connected") and not fees_synced_today(get_setting_fn): - try_daily_ctp_fee_sync( - mode, - get_setting=get_setting_fn, - set_setting=set_setting_fn, - force=False, - ) - except Exception as exc: - logger.warning("CTP fee worker: %s", exc) - time.sleep(max(300, interval)) - - threading.Thread(target=_loop, daemon=True, name="ctp-fee-worker").start() +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""CTP 手续费后台同步:每日一次写入数据库,前端只读展示。""" +from __future__ import annotations + +import logging +import threading +import time +from datetime import date, datetime +from typing import Callable, Optional +from zoneinfo import ZoneInfo + +logger = logging.getLogger(__name__) + +TZ = ZoneInfo("Asia/Shanghai") +FEE_SYNC_KEY = "ctp_fee_last_sync" +CHECK_INTERVAL_SEC = 3600 +_sync_lock = threading.Lock() + + +def fee_sync_in_progress() -> bool: + return _sync_lock.locked() + + +def _today_str() -> str: + return datetime.now(TZ).date().isoformat() + + +def get_fee_last_sync(get_setting: Callable[[str, str], str]) -> str: + return (get_setting(FEE_SYNC_KEY, "") or "").strip() + + +def fees_synced_today(get_setting: Callable[[str, str], str]) -> bool: + last = get_fee_last_sync(get_setting) + return bool(last) and last[:10] == _today_str() + + +def mark_fees_synced(set_setting: Callable[[str, str], None]) -> None: + set_setting(FEE_SYNC_KEY, datetime.now(TZ).isoformat(timespec="seconds")) + + +def try_daily_ctp_fee_sync( + mode: str, + *, + get_setting: Callable[[str, str], str], + set_setting: Callable[[str, str], None], + force: bool = False, +) -> tuple[int, str]: + """CTP 已连接且今日未同步时拉取费率入库;force=True 忽略日期限制。""" + if not force and fees_synced_today(get_setting): + return 0, "今日已从 CTP 同步过,无需重复(可点「立即同步」强制刷新)" + + with _sync_lock: + if not force and fees_synced_today(get_setting): + return 0, "今日已从 CTP 同步过" + + t0 = time.monotonic() + from modules.ctp.ctp_fee_sync import sync_fees_from_ctp + + count, msg = sync_fees_from_ctp(mode) + elapsed = time.monotonic() - t0 + if count > 0: + mark_fees_synced(set_setting) + msg = f"{msg}(耗时 {elapsed:.1f} 秒)" + logger.info("CTP 手续费每日同步: %s", msg) + elif force: + msg = f"{msg}(耗时 {elapsed:.1f} 秒)" + logger.warning("CTP 手续费强制同步未写入: %s", msg) + return count, msg + + +def schedule_ctp_fee_sync( + mode: str, + *, + get_setting: Callable[[str, str], str], + set_setting: Callable[[str, str], None], + force: bool = False, +) -> tuple[bool, str]: + """后台线程同步,避免阻塞 Web 请求。""" + if _sync_lock.locked(): + return False, "手续费同步进行中,请稍后再试(约 1~3 分钟)" + + def _run() -> None: + try: + try_daily_ctp_fee_sync( + mode, + get_setting=get_setting, + set_setting=set_setting, + force=force, + ) + except Exception as exc: + logger.exception("CTP 手续费后台同步失败: %s", exc) + + threading.Thread(target=_run, daemon=True, name="ctp-fee-sync-run").start() + if force: + return True, "已在后台开始同步,约 30 秒~2 分钟完成,请稍后刷新本页查看" + return True, "已在后台检查同步,请稍后刷新本页" + + +def start_ctp_fee_worker( + *, + get_mode_fn: Callable[[], str], + get_setting_fn: Callable[[str, str], str], + set_setting_fn: Callable[[str, str], None], + interval: int = CHECK_INTERVAL_SEC, +) -> None: + """后台线程:每小时检查,CTP 已连接且当日未同步则自动同步。""" + + def _loop() -> None: + time.sleep(20) + while True: + try: + from modules.ctp.vnpy_bridge import ctp_status + + mode = get_mode_fn() + st = ctp_status(mode) + if st.get("connected") and not fees_synced_today(get_setting_fn): + try_daily_ctp_fee_sync( + mode, + get_setting=get_setting_fn, + set_setting=set_setting_fn, + force=False, + ) + except Exception as exc: + logger.warning("CTP fee worker: %s", exc) + time.sleep(max(300, interval)) + + threading.Thread(target=_loop, daemon=True, name="ctp-fee-worker").start() diff --git a/ctp_ipc_client.py b/modules/ctp/ctp_ipc_client.py similarity index 100% rename from ctp_ipc_client.py rename to modules/ctp/ctp_ipc_client.py diff --git a/ctp_kline.py b/modules/ctp/ctp_kline.py similarity index 92% rename from ctp_kline.py rename to modules/ctp/ctp_kline.py index 71d14e4..30c9b39 100644 --- a/ctp_kline.py +++ b/modules/ctp/ctp_kline.py @@ -1,89 +1,89 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""CTP tick 聚合 K 线(1 分钟为基础,再合成各周期)。""" -from __future__ import annotations - -import logging -from typing import Optional - -from kline_chart import ( - PERIOD_MINUTES, - _aggregate_bars, - _bar_datetime, - _merge_bars, - _timeshare_session, - _weekly_from_daily, -) - -logger = logging.getLogger(__name__) - -PERIOD_AGG = { - "2m": 2, - "3m": 3, - "5m": 5, - "15m": 15, - "30m": 30, - "1h": 60, - "2h": 120, - "4h": 240, -} - - -def _daily_from_1m(bars_1m: list) -> list: - if not bars_1m: - return [] - buckets: dict[str, list] = {} - for bar in bars_1m: - dt = _bar_datetime(bar) - if not dt: - continue - key = dt.strftime("%Y-%m-%d") - buckets.setdefault(key, []).append(bar) - out = [] - for day in sorted(buckets.keys()): - chunk = buckets[day] - merged = _merge_bars(chunk) - merged["d"] = day + " 15:00:00" - out.append(merged) - return out - - -def compose_period_bars(bars_1m: list, period: str) -> list: - p = (period or "15m").lower() - if p == "timeshare": - return _timeshare_session(bars_1m) - if p in ("1d", "d"): - return _daily_from_1m(bars_1m) - if p == "w": - return _weekly_from_daily(_daily_from_1m(bars_1m)) - if p == "1m": - return list(bars_1m) - n = PERIOD_AGG.get(p) - if n: - return _aggregate_bars(bars_1m, n) - if p in PERIOD_MINUTES: - try: - n = int(PERIOD_MINUTES[p]) - return _aggregate_bars(bars_1m, n) - except (TypeError, ValueError): - pass - return list(bars_1m) - - -def fetch_ctp_klines(symbol: str, period: str, mode: str) -> Optional[list]: - """CTP 已连接时由 tick 聚合 K 线;失败返回 None。""" - try: - from vnpy_bridge import ctp_status, get_bridge - - if not ctp_status(mode).get("connected"): - return None - bars_1m = get_bridge().get_kline_bars_1m(symbol, mode=mode) - if not bars_1m: - return None - return compose_period_bars(bars_1m, period) - except Exception as exc: - logger.debug("fetch_ctp_klines %s %s: %s", symbol, period, exc) - return None +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""CTP tick 聚合 K 线(1 分钟为基础,再合成各周期)。""" +from __future__ import annotations + +import logging +from typing import Optional + +from modules.market.kline_chart import ( + PERIOD_MINUTES, + _aggregate_bars, + _bar_datetime, + _merge_bars, + _timeshare_session, + _weekly_from_daily, +) + +logger = logging.getLogger(__name__) + +PERIOD_AGG = { + "2m": 2, + "3m": 3, + "5m": 5, + "15m": 15, + "30m": 30, + "1h": 60, + "2h": 120, + "4h": 240, +} + + +def _daily_from_1m(bars_1m: list) -> list: + if not bars_1m: + return [] + buckets: dict[str, list] = {} + for bar in bars_1m: + dt = _bar_datetime(bar) + if not dt: + continue + key = dt.strftime("%Y-%m-%d") + buckets.setdefault(key, []).append(bar) + out = [] + for day in sorted(buckets.keys()): + chunk = buckets[day] + merged = _merge_bars(chunk) + merged["d"] = day + " 15:00:00" + out.append(merged) + return out + + +def compose_period_bars(bars_1m: list, period: str) -> list: + p = (period or "15m").lower() + if p == "timeshare": + return _timeshare_session(bars_1m) + if p in ("1d", "d"): + return _daily_from_1m(bars_1m) + if p == "w": + return _weekly_from_daily(_daily_from_1m(bars_1m)) + if p == "1m": + return list(bars_1m) + n = PERIOD_AGG.get(p) + if n: + return _aggregate_bars(bars_1m, n) + if p in PERIOD_MINUTES: + try: + n = int(PERIOD_MINUTES[p]) + return _aggregate_bars(bars_1m, n) + except (TypeError, ValueError): + pass + return list(bars_1m) + + +def fetch_ctp_klines(symbol: str, period: str, mode: str) -> Optional[list]: + """CTP 已连接时由 tick 聚合 K 线;失败返回 None。""" + try: + from modules.ctp.vnpy_bridge import ctp_status, get_bridge + + if not ctp_status(mode).get("connected"): + return None + bars_1m = get_bridge().get_kline_bars_1m(symbol, mode=mode) + if not bars_1m: + return None + return compose_period_bars(bars_1m, period) + except Exception as exc: + logger.debug("fetch_ctp_klines %s %s: %s", symbol, period, exc) + return None diff --git a/ctp_premarket_connect.py b/modules/ctp/ctp_premarket_connect.py similarity index 95% rename from ctp_premarket_connect.py rename to modules/ctp/ctp_premarket_connect.py index 10543e4..94cb381 100644 --- a/ctp_premarket_connect.py +++ b/modules/ctp/ctp_premarket_connect.py @@ -1,116 +1,116 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""CTP 按计划自动连接:盘前 30 分钟检查;交易时段断线后台重连;不自动强制断开。""" -from __future__ import annotations - -import logging -import os -import threading -import time -from typing import Callable - -from market_sessions import ( - in_premarket_connect_window, - in_postmarket_grace_window, - is_trading_session, - should_keep_ctp_connected, -) -from vnpy_bridge import ctp_start_connect, ctp_status - -logger = logging.getLogger(__name__) - -CHECK_INTERVAL_SEC = 60 -TRADING_CHECK_INTERVAL_SEC = 15 -PREMARKET_CHECK_INTERVAL_SEC = 30 -DEFAULT_MINUTES_BEFORE = 30 -DEFAULT_MINUTES_AFTER = 30 - - -def premarket_minutes_before() -> int: - try: - return max(5, int(os.getenv("CTP_PREMARKET_MINUTES", str(DEFAULT_MINUTES_BEFORE)))) - except (TypeError, ValueError): - return DEFAULT_MINUTES_BEFORE - - -def postmarket_minutes_after() -> int: - try: - return max(5, int(os.getenv("CTP_POSTMARKET_MINUTES", str(DEFAULT_MINUTES_AFTER)))) - except (TypeError, ValueError): - return DEFAULT_MINUTES_AFTER - - -def _scheduled_connect_enabled() -> bool: - return (os.getenv("CTP_PREMARKET_CONNECT", "true") or "true").strip().lower() in ( - "1", - "true", - "yes", - ) - - -def should_auto_connect_now(*, minutes_before: int | None = None) -> bool: - """是否应保持/发起 CTP 连接(供重连、权限判断复用)。""" - mins_b = premarket_minutes_before() if minutes_before is None else minutes_before - mins_a = postmarket_minutes_after() - if not _scheduled_connect_enabled() and not is_trading_session(): - if not in_postmarket_grace_window(minutes_after=mins_a): - return False - return should_keep_ctp_connected( - minutes_before=mins_b, - minutes_after=mins_a, - ) - - -def start_ctp_premarket_connect_worker( - *, - get_mode_fn: Callable[[], str], - get_setting_fn: Callable[[str, str], str] | None = None, - interval: int = CHECK_INTERVAL_SEC, -) -> None: - """盘前 30 分钟:未连接则自动连;已连接则不重复发起。不自动强制断开。""" - - def _loop() -> None: - time.sleep(10) - while True: - sleep_sec = max(30, interval) - try: - mins_b = premarket_minutes_before() - mins_a = postmarket_minutes_after() - keep = should_auto_connect_now() - mode = get_mode_fn() - st = ctp_status(mode) - - if keep: - if ( - not st.get("connected") - and not st.get("connecting") - and int(st.get("login_cooldown_sec") or 0) <= 0 - ): - info = ctp_start_connect(mode, force=False, scheduled=True) - if info.get("started"): - if is_trading_session(): - logger.info("交易时段内自动连接 CTP [%s]", mode) - elif in_postmarket_grace_window(minutes_after=mins_a): - logger.info( - "盘后宽限期内恢复 CTP 连接 [%s](收盘后 %d 分钟内)", - mode, - mins_a, - ) - else: - logger.info( - "盘前自动连接 CTP [%s](开盘前 %d 分钟)", - mode, - mins_b, - ) - if is_trading_session(): - sleep_sec = TRADING_CHECK_INTERVAL_SEC - elif in_premarket_connect_window(minutes_before=mins_b): - sleep_sec = PREMARKET_CHECK_INTERVAL_SEC - except Exception as exc: - logger.warning("CTP scheduled connect worker: %s", exc) - time.sleep(sleep_sec) - - threading.Thread(target=_loop, daemon=True, name="ctp-premarket-connect").start() +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""CTP 按计划自动连接:盘前 30 分钟检查;交易时段断线后台重连;不自动强制断开。""" +from __future__ import annotations + +import logging +import os +import threading +import time +from typing import Callable + +from modules.market.market_sessions import ( + in_premarket_connect_window, + in_postmarket_grace_window, + is_trading_session, + should_keep_ctp_connected, +) +from modules.ctp.vnpy_bridge import ctp_start_connect, ctp_status + +logger = logging.getLogger(__name__) + +CHECK_INTERVAL_SEC = 60 +TRADING_CHECK_INTERVAL_SEC = 15 +PREMARKET_CHECK_INTERVAL_SEC = 30 +DEFAULT_MINUTES_BEFORE = 30 +DEFAULT_MINUTES_AFTER = 30 + + +def premarket_minutes_before() -> int: + try: + return max(5, int(os.getenv("CTP_PREMARKET_MINUTES", str(DEFAULT_MINUTES_BEFORE)))) + except (TypeError, ValueError): + return DEFAULT_MINUTES_BEFORE + + +def postmarket_minutes_after() -> int: + try: + return max(5, int(os.getenv("CTP_POSTMARKET_MINUTES", str(DEFAULT_MINUTES_AFTER)))) + except (TypeError, ValueError): + return DEFAULT_MINUTES_AFTER + + +def _scheduled_connect_enabled() -> bool: + return (os.getenv("CTP_PREMARKET_CONNECT", "true") or "true").strip().lower() in ( + "1", + "true", + "yes", + ) + + +def should_auto_connect_now(*, minutes_before: int | None = None) -> bool: + """是否应保持/发起 CTP 连接(供重连、权限判断复用)。""" + mins_b = premarket_minutes_before() if minutes_before is None else minutes_before + mins_a = postmarket_minutes_after() + if not _scheduled_connect_enabled() and not is_trading_session(): + if not in_postmarket_grace_window(minutes_after=mins_a): + return False + return should_keep_ctp_connected( + minutes_before=mins_b, + minutes_after=mins_a, + ) + + +def start_ctp_premarket_connect_worker( + *, + get_mode_fn: Callable[[], str], + get_setting_fn: Callable[[str, str], str] | None = None, + interval: int = CHECK_INTERVAL_SEC, +) -> None: + """盘前 30 分钟:未连接则自动连;已连接则不重复发起。不自动强制断开。""" + + def _loop() -> None: + time.sleep(10) + while True: + sleep_sec = max(30, interval) + try: + mins_b = premarket_minutes_before() + mins_a = postmarket_minutes_after() + keep = should_auto_connect_now() + mode = get_mode_fn() + st = ctp_status(mode) + + if keep: + if ( + not st.get("connected") + and not st.get("connecting") + and int(st.get("login_cooldown_sec") or 0) <= 0 + ): + info = ctp_start_connect(mode, force=False, scheduled=True) + if info.get("started"): + if is_trading_session(): + logger.info("交易时段内自动连接 CTP [%s]", mode) + elif in_postmarket_grace_window(minutes_after=mins_a): + logger.info( + "盘后宽限期内恢复 CTP 连接 [%s](收盘后 %d 分钟内)", + mode, + mins_a, + ) + else: + logger.info( + "盘前自动连接 CTP [%s](开盘前 %d 分钟)", + mode, + mins_b, + ) + if is_trading_session(): + sleep_sec = TRADING_CHECK_INTERVAL_SEC + elif in_premarket_connect_window(minutes_before=mins_b): + sleep_sec = PREMARKET_CHECK_INTERVAL_SEC + except Exception as exc: + logger.warning("CTP scheduled connect worker: %s", exc) + time.sleep(sleep_sec) + + threading.Thread(target=_loop, daemon=True, name="ctp-premarket-connect").start() diff --git a/ctp_reconnect.py b/modules/ctp/ctp_reconnect.py similarity index 86% rename from ctp_reconnect.py rename to modules/ctp/ctp_reconnect.py index 8a674ee..da9d3e1 100644 --- a/ctp_reconnect.py +++ b/modules/ctp/ctp_reconnect.py @@ -1,59 +1,59 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""CTP 断线自动重连(后台线程)。""" -from __future__ import annotations - -import logging -import os -import threading -import time -from typing import Callable - -from ctp_premarket_connect import premarket_minutes_before, should_auto_connect_now -from market_sessions import in_premarket_connect_window, is_trading_session -from vnpy_bridge import ctp_try_auto_reconnect - -logger = logging.getLogger(__name__) - -RECONNECT_INTERVAL_SEC = 60 -TRADING_RECONNECT_INTERVAL_SEC = 15 -PREMARKET_RECONNECT_INTERVAL_SEC = 30 - - -def _auto_reconnect_enabled() -> bool: - return (os.getenv("CTP_AUTO_RECONNECT", "true") or "true").strip().lower() in ( - "1", - "true", - "yes", - ) - - -def start_ctp_reconnect_worker( - *, - get_mode_fn: Callable[[], str], - get_setting_fn: Callable[[str, str], str] | None = None, - interval: int = RECONNECT_INTERVAL_SEC, -) -> None: - """交易时段 / 盘前窗口内检测 CTP;断线则后台自动重连。""" - - def _loop() -> None: - while True: - sleep_sec = max(5, interval) - try: - if _auto_reconnect_enabled() and should_auto_connect_now(): - mode = get_mode_fn() - ctp_try_auto_reconnect(mode) - if is_trading_session(): - sleep_sec = TRADING_RECONNECT_INTERVAL_SEC - elif in_premarket_connect_window( - minutes_before=premarket_minutes_before(), - ): - sleep_sec = PREMARKET_RECONNECT_INTERVAL_SEC - except Exception as exc: - logger.warning("CTP reconnect worker: %s", exc) - time.sleep(sleep_sec) - - threading.Thread(target=_loop, daemon=True, name="ctp-reconnect-worker").start() +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""CTP 断线自动重连(后台线程)。""" +from __future__ import annotations + +import logging +import os +import threading +import time +from typing import Callable + +from modules.ctp.ctp_premarket_connect import premarket_minutes_before, should_auto_connect_now +from modules.market.market_sessions import in_premarket_connect_window, is_trading_session +from modules.ctp.vnpy_bridge import ctp_try_auto_reconnect + +logger = logging.getLogger(__name__) + +RECONNECT_INTERVAL_SEC = 60 +TRADING_RECONNECT_INTERVAL_SEC = 15 +PREMARKET_RECONNECT_INTERVAL_SEC = 30 + + +def _auto_reconnect_enabled() -> bool: + return (os.getenv("CTP_AUTO_RECONNECT", "true") or "true").strip().lower() in ( + "1", + "true", + "yes", + ) + + +def start_ctp_reconnect_worker( + *, + get_mode_fn: Callable[[], str], + get_setting_fn: Callable[[str, str], str] | None = None, + interval: int = RECONNECT_INTERVAL_SEC, +) -> None: + """交易时段 / 盘前窗口内检测 CTP;断线则后台自动重连。""" + + def _loop() -> None: + while True: + sleep_sec = max(5, interval) + try: + if _auto_reconnect_enabled() and should_auto_connect_now(): + mode = get_mode_fn() + ctp_try_auto_reconnect(mode) + if is_trading_session(): + sleep_sec = TRADING_RECONNECT_INTERVAL_SEC + elif in_premarket_connect_window( + minutes_before=premarket_minutes_before(), + ): + sleep_sec = PREMARKET_RECONNECT_INTERVAL_SEC + except Exception as exc: + logger.warning("CTP reconnect worker: %s", exc) + time.sleep(sleep_sec) + + threading.Thread(target=_loop, daemon=True, name="ctp-reconnect-worker").start() diff --git a/ctp_settings.py b/modules/ctp/ctp_settings.py similarity index 95% rename from ctp_settings.py rename to modules/ctp/ctp_settings.py index 68094c9..cb3d37a 100644 --- a/ctp_settings.py +++ b/modules/ctp/ctp_settings.py @@ -1,154 +1,154 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""CTP / SimNow 配置:系统设置优先,.env 作兜底。""" -from __future__ import annotations - -import os -from typing import Any, Callable - -# (db_key, env_key, vnpy字段名, 默认值) -SIMNOW_FIELDS: tuple[tuple[str, str, str, str], ...] = ( - ("simnow_user", "SIMNOW_USER", "用户名", ""), - ("simnow_password", "SIMNOW_PASSWORD", "密码", ""), - ("simnow_broker_id", "SIMNOW_BROKER_ID", "经纪商代码", "9999"), - ("simnow_td_address", "SIMNOW_TD_ADDRESS", "交易服务器", "tcp://180.168.146.187:10201"), - ("simnow_md_address", "SIMNOW_MD_ADDRESS", "行情服务器", "tcp://180.168.146.187:10211"), - ("simnow_app_id", "SIMNOW_APP_ID", "产品名称", "simnow_client_test"), - ("simnow_auth_code", "SIMNOW_AUTH_CODE", "授权编码", "0000000000000000"), - ("simnow_env", "SIMNOW_ENV", "柜台环境", "实盘"), -) - -LIVE_FIELDS: tuple[tuple[str, str, str, str], ...] = ( - ("ctp_live_user", "CTP_LIVE_USER", "用户名", ""), - ("ctp_live_password", "CTP_LIVE_PASSWORD", "密码", ""), - ("ctp_live_broker_id", "CTP_LIVE_BROKER_ID", "经纪商代码", ""), - ("ctp_live_td_address", "CTP_LIVE_TD_ADDRESS", "交易服务器", ""), - ("ctp_live_md_address", "CTP_LIVE_MD_ADDRESS", "行情服务器", ""), - ("ctp_live_app_id", "CTP_LIVE_APP_ID", "产品名称", ""), - ("ctp_live_auth_code", "CTP_LIVE_AUTH_CODE", "授权编码", ""), - ("ctp_live_env", "CTP_LIVE_ENV", "柜台环境", "实盘"), -) - -PASSWORD_DB_KEYS = frozenset({"simnow_password", "ctp_live_password"}) - -CTP_AUTO_CONNECT_KEY = "ctp_auto_connect" -CTP_DISABLED_HINT = "CTP 自动连接已关闭(非交易时段不重连;开盘前 30 分钟及交易时段仍会按计划连接;断开请手动操作)" - - -def is_ctp_auto_connect_enabled(get_setting=None) -> bool: - """系统设置:是否允许手动连接及非交易时段自动重连(盘前/交易时段计划连接不受此限制)。""" - if get_setting is None: - from fee_specs import get_setting as _gs - - get_setting = _gs - val = (get_setting(CTP_AUTO_CONNECT_KEY, "1") or "1").strip().lower() - return val in ("1", "true", "yes", "on") - - -def save_ctp_auto_connect(form: Any, set_setting: Callable[[str, str], None]) -> bool: - enabled = (form.get("ctp_auto_connect") or "").strip().lower() in ( - "1", - "on", - "true", - "yes", - ) - set_setting(CTP_AUTO_CONNECT_KEY, "1" if enabled else "0") - return enabled - - -def _get_db_setting(key: str, default: str = "") -> str: - from fee_specs import get_setting - - return (get_setting(key, default) or default).strip() - - -def resolve_ctp_value(db_key: str, env_key: str, default: str = "") -> str: - v = _get_db_setting(db_key, "") - if v: - return v - return (os.getenv(env_key) or default).strip() - - -def _build_setting_dict(fields: tuple[tuple[str, str, str, str], ...]) -> dict[str, str]: - out: dict[str, str] = {} - for db_key, env_key, vnpy_key, default in fields: - out[vnpy_key] = resolve_ctp_value(db_key, env_key, default) - return out - - -def simnow_setting_dict() -> dict[str, str]: - return _build_setting_dict(SIMNOW_FIELDS) - - -def live_setting_dict() -> dict[str, str]: - return _build_setting_dict(LIVE_FIELDS) - - -def seed_ctp_settings_from_env(set_setting: Callable[[str, str], None]) -> None: - """首次启动:将 .env 中已有 CTP 配置写入 settings 表。""" - for db_key, env_key, _, _ in (*SIMNOW_FIELDS, *LIVE_FIELDS): - if _get_db_setting(db_key, ""): - continue - env_val = (os.getenv(env_key) or "").strip() - if env_val: - set_setting(db_key, env_val) - - -def get_ctp_settings_for_ui() -> dict[str, Any]: - ui: dict[str, Any] = {} - for db_key, env_key, _, default in SIMNOW_FIELDS: - ui[db_key] = resolve_ctp_value(db_key, env_key, default) - if db_key in PASSWORD_DB_KEYS: - ui[f"{db_key}_set"] = bool(ui[db_key]) - ui[db_key] = "" - for db_key, env_key, _, default in LIVE_FIELDS: - ui[db_key] = resolve_ctp_value(db_key, env_key, default) - if db_key in PASSWORD_DB_KEYS: - ui[f"{db_key}_set"] = bool(ui[db_key]) - ui[db_key] = "" - ui["ctp_auto_connect"] = is_ctp_auto_connect_enabled() - return ui - - -def save_ctp_settings_from_form( - form: Any, - set_setting: Callable[[str, str], None], -) -> dict[str, Any]: - """保存 CTP 配置;密码留空表示不修改。返回摘要供页面提示。""" - passwords_updated: list[str] = [] - passwords_submitted_empty: list[str] = [] - - for db_key, _, _, default in SIMNOW_FIELDS: - if db_key in PASSWORD_DB_KEYS: - raw = form.get(db_key) - val = (raw or "").strip() - if val: - set_setting(db_key, val) - passwords_updated.append(db_key) - else: - passwords_submitted_empty.append(db_key) - continue - val = (form.get(db_key) or "").strip() - set_setting(db_key, val or default) - - for db_key, _, _, default in LIVE_FIELDS: - if db_key in PASSWORD_DB_KEYS: - raw = form.get(db_key) - val = (raw or "").strip() - if val: - set_setting(db_key, val) - passwords_updated.append(db_key) - else: - passwords_submitted_empty.append(db_key) - continue - val = (form.get(db_key) or "").strip() - if default or val: - set_setting(db_key, val or default) - - return { - "passwords_updated": passwords_updated, - "passwords_submitted_empty": passwords_submitted_empty, - } +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""CTP / SimNow 配置:系统设置优先,.env 作兜底。""" +from __future__ import annotations + +import os +from typing import Any, Callable + +# (db_key, env_key, vnpy字段名, 默认值) +SIMNOW_FIELDS: tuple[tuple[str, str, str, str], ...] = ( + ("simnow_user", "SIMNOW_USER", "用户名", ""), + ("simnow_password", "SIMNOW_PASSWORD", "密码", ""), + ("simnow_broker_id", "SIMNOW_BROKER_ID", "经纪商代码", "9999"), + ("simnow_td_address", "SIMNOW_TD_ADDRESS", "交易服务器", "tcp://180.168.146.187:10201"), + ("simnow_md_address", "SIMNOW_MD_ADDRESS", "行情服务器", "tcp://180.168.146.187:10211"), + ("simnow_app_id", "SIMNOW_APP_ID", "产品名称", "simnow_client_test"), + ("simnow_auth_code", "SIMNOW_AUTH_CODE", "授权编码", "0000000000000000"), + ("simnow_env", "SIMNOW_ENV", "柜台环境", "实盘"), +) + +LIVE_FIELDS: tuple[tuple[str, str, str, str], ...] = ( + ("ctp_live_user", "CTP_LIVE_USER", "用户名", ""), + ("ctp_live_password", "CTP_LIVE_PASSWORD", "密码", ""), + ("ctp_live_broker_id", "CTP_LIVE_BROKER_ID", "经纪商代码", ""), + ("ctp_live_td_address", "CTP_LIVE_TD_ADDRESS", "交易服务器", ""), + ("ctp_live_md_address", "CTP_LIVE_MD_ADDRESS", "行情服务器", ""), + ("ctp_live_app_id", "CTP_LIVE_APP_ID", "产品名称", ""), + ("ctp_live_auth_code", "CTP_LIVE_AUTH_CODE", "授权编码", ""), + ("ctp_live_env", "CTP_LIVE_ENV", "柜台环境", "实盘"), +) + +PASSWORD_DB_KEYS = frozenset({"simnow_password", "ctp_live_password"}) + +CTP_AUTO_CONNECT_KEY = "ctp_auto_connect" +CTP_DISABLED_HINT = "CTP 自动连接已关闭(非交易时段不重连;开盘前 30 分钟及交易时段仍会按计划连接;断开请手动操作)" + + +def is_ctp_auto_connect_enabled(get_setting=None) -> bool: + """系统设置:是否允许手动连接及非交易时段自动重连(盘前/交易时段计划连接不受此限制)。""" + if get_setting is None: + from modules.fees.fee_specs import get_setting as _gs + + get_setting = _gs + val = (get_setting(CTP_AUTO_CONNECT_KEY, "1") or "1").strip().lower() + return val in ("1", "true", "yes", "on") + + +def save_ctp_auto_connect(form: Any, set_setting: Callable[[str, str], None]) -> bool: + enabled = (form.get("ctp_auto_connect") or "").strip().lower() in ( + "1", + "on", + "true", + "yes", + ) + set_setting(CTP_AUTO_CONNECT_KEY, "1" if enabled else "0") + return enabled + + +def _get_db_setting(key: str, default: str = "") -> str: + from modules.fees.fee_specs import get_setting + + return (get_setting(key, default) or default).strip() + + +def resolve_ctp_value(db_key: str, env_key: str, default: str = "") -> str: + v = _get_db_setting(db_key, "") + if v: + return v + return (os.getenv(env_key) or default).strip() + + +def _build_setting_dict(fields: tuple[tuple[str, str, str, str], ...]) -> dict[str, str]: + out: dict[str, str] = {} + for db_key, env_key, vnpy_key, default in fields: + out[vnpy_key] = resolve_ctp_value(db_key, env_key, default) + return out + + +def simnow_setting_dict() -> dict[str, str]: + return _build_setting_dict(SIMNOW_FIELDS) + + +def live_setting_dict() -> dict[str, str]: + return _build_setting_dict(LIVE_FIELDS) + + +def seed_ctp_settings_from_env(set_setting: Callable[[str, str], None]) -> None: + """首次启动:将 .env 中已有 CTP 配置写入 settings 表。""" + for db_key, env_key, _, _ in (*SIMNOW_FIELDS, *LIVE_FIELDS): + if _get_db_setting(db_key, ""): + continue + env_val = (os.getenv(env_key) or "").strip() + if env_val: + set_setting(db_key, env_val) + + +def get_ctp_settings_for_ui() -> dict[str, Any]: + ui: dict[str, Any] = {} + for db_key, env_key, _, default in SIMNOW_FIELDS: + ui[db_key] = resolve_ctp_value(db_key, env_key, default) + if db_key in PASSWORD_DB_KEYS: + ui[f"{db_key}_set"] = bool(ui[db_key]) + ui[db_key] = "" + for db_key, env_key, _, default in LIVE_FIELDS: + ui[db_key] = resolve_ctp_value(db_key, env_key, default) + if db_key in PASSWORD_DB_KEYS: + ui[f"{db_key}_set"] = bool(ui[db_key]) + ui[db_key] = "" + ui["ctp_auto_connect"] = is_ctp_auto_connect_enabled() + return ui + + +def save_ctp_settings_from_form( + form: Any, + set_setting: Callable[[str, str], None], +) -> dict[str, Any]: + """保存 CTP 配置;密码留空表示不修改。返回摘要供页面提示。""" + passwords_updated: list[str] = [] + passwords_submitted_empty: list[str] = [] + + for db_key, _, _, default in SIMNOW_FIELDS: + if db_key in PASSWORD_DB_KEYS: + raw = form.get(db_key) + val = (raw or "").strip() + if val: + set_setting(db_key, val) + passwords_updated.append(db_key) + else: + passwords_submitted_empty.append(db_key) + continue + val = (form.get(db_key) or "").strip() + set_setting(db_key, val or default) + + for db_key, _, _, default in LIVE_FIELDS: + if db_key in PASSWORD_DB_KEYS: + raw = form.get(db_key) + val = (raw or "").strip() + if val: + set_setting(db_key, val) + passwords_updated.append(db_key) + else: + passwords_submitted_empty.append(db_key) + continue + val = (form.get(db_key) or "").strip() + if default or val: + set_setting(db_key, val or default) + + return { + "passwords_updated": passwords_updated, + "passwords_submitted_empty": passwords_submitted_empty, + } diff --git a/ctp_symbol.py b/modules/ctp/ctp_symbol.py similarity index 94% rename from ctp_symbol.py rename to modules/ctp/ctp_symbol.py index 4fd2a6c..1401e3f 100644 --- a/ctp_symbol.py +++ b/modules/ctp/ctp_symbol.py @@ -1,66 +1,66 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""同花顺合约代码 → vnpy Symbol + Exchange。""" -from __future__ import annotations - -import re -from typing import Optional, Tuple - -from symbols import ths_to_codes - -try: - from vnpy.trader.constant import Exchange -except ImportError: - Exchange = None # type: ignore - -_EX_MAP = { - "SHFE": "SHFE", - "DCE": "DCE", - "CZCE": "CZCE", - "CFFEX": "CFFEX", - "INE": "INE", -} - - -def ths_to_vnpy_symbol(ths_code: str) -> Tuple[str, str]: - """ - 返回 (symbol, exchange_enum_name)。 - 例:rb2610 → rb2610, SHFE;SR609 → SR609, CZCE - """ - code = (ths_code or "").strip() - codes = ths_to_codes(code) - ex = (codes.get("ex") if codes else None) - if not ex and codes: - mc = (codes.get("market_code") or "") - if "." in mc: - ex = mc.rsplit(".", 1)[-1] - ex = _EX_MAP.get(ex or "SHFE", "SHFE") - m = re.match(r"^([A-Za-z]+)(\d+)$", code) - if not m: - return code, ex - letters, digits = m.group(1), m.group(2) - if ex == "CZCE": - # 郑商所 CTP 常为大写 + 3 位年月(如 SR509);4 位则取后 3 位 - sym = letters.upper() + (digits[-3:] if len(digits) >= 3 else digits) - else: - sym = letters.lower() + digits - return sym, ex - - -def to_vnpy_exchange(ex_name: str): - if Exchange is None: - raise ImportError("vnpy 未安装") - mapping = { - "SHFE": Exchange.SHFE, - "DCE": Exchange.DCE, - "CZCE": Exchange.CZCE, - "CFFEX": Exchange.CFFEX, - "INE": Exchange.INE, - } - ex = mapping.get((ex_name or "").upper()) - if ex is None: - raise ValueError(f"未知交易所: {ex_name}") - return ex +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""同花顺合约代码 → vnpy Symbol + Exchange。""" +from __future__ import annotations + +import re +from typing import Optional, Tuple + +from modules.core.symbols import ths_to_codes + +try: + from vnpy.trader.constant import Exchange +except ImportError: + Exchange = None # type: ignore + +_EX_MAP = { + "SHFE": "SHFE", + "DCE": "DCE", + "CZCE": "CZCE", + "CFFEX": "CFFEX", + "INE": "INE", +} + + +def ths_to_vnpy_symbol(ths_code: str) -> Tuple[str, str]: + """ + 返回 (symbol, exchange_enum_name)。 + 例:rb2610 → rb2610, SHFE;SR609 → SR609, CZCE + """ + code = (ths_code or "").strip() + codes = ths_to_codes(code) + ex = (codes.get("ex") if codes else None) + if not ex and codes: + mc = (codes.get("market_code") or "") + if "." in mc: + ex = mc.rsplit(".", 1)[-1] + ex = _EX_MAP.get(ex or "SHFE", "SHFE") + m = re.match(r"^([A-Za-z]+)(\d+)$", code) + if not m: + return code, ex + letters, digits = m.group(1), m.group(2) + if ex == "CZCE": + # 郑商所 CTP 常为大写 + 3 位年月(如 SR509);4 位则取后 3 位 + sym = letters.upper() + (digits[-3:] if len(digits) >= 3 else digits) + else: + sym = letters.lower() + digits + return sym, ex + + +def to_vnpy_exchange(ex_name: str): + if Exchange is None: + raise ImportError("vnpy 未安装") + mapping = { + "SHFE": Exchange.SHFE, + "DCE": Exchange.DCE, + "CZCE": Exchange.CZCE, + "CFFEX": Exchange.CFFEX, + "INE": Exchange.INE, + } + ex = mapping.get((ex_name or "").upper()) + if ex is None: + raise ValueError(f"未知交易所: {ex_name}") + return ex diff --git a/ctp_trade_sync.py b/modules/ctp/ctp_trade_sync.py similarity index 92% rename from ctp_trade_sync.py rename to modules/ctp/ctp_trade_sync.py index 4256107..049049e 100644 --- a/ctp_trade_sync.py +++ b/modules/ctp/ctp_trade_sync.py @@ -1,337 +1,337 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""从 CTP 柜台同步成交,写入 trade_logs(以交易所成交为准)。""" -from __future__ import annotations - -import logging -from collections import defaultdict -from datetime import datetime -from typing import Any, Callable, Optional -from zoneinfo import ZoneInfo - -from contract_specs import calc_position_metrics -from ctp_symbol import ths_to_vnpy_symbol -from fee_specs import calc_round_trip_fee -from symbols import ths_to_codes -from trade_log_lib import ( - calc_equity_after, - purge_duplicate_local_trade_logs, - ensure_trade_log_columns, - refresh_trade_log_equity_chain, -) -from vnpy_bridge import ctp_list_trades, ctp_status - -logger = logging.getLogger(__name__) -TZ = ZoneInfo("Asia/Shanghai") - - -def _match_symbol(ctp_sym: str, ths: str) -> bool: - a = (ctp_sym or "").lower() - b = (ths or "").lower() - if a == b: - return True - if a and b and a.split(".")[0] == b.split(".")[0]: - return True - try: - vnpy_sym, _ = ths_to_vnpy_symbol(ths) - if a == vnpy_sym.lower(): - return True - except Exception: - pass - return False - - -def _to_ths_code(symbol: str) -> str: - sym = (symbol or "").strip() - if not sym: - return "" - codes = ths_to_codes(sym) - if codes: - return codes.get("ths_code") or sym - return sym.lower() - - -def _allocate_commission(total_comm: float, matched: int, total_lots: int) -> float: - if total_comm <= 0 or matched <= 0 or total_lots <= 0: - return 0.0 - return round(total_comm * matched / total_lots, 2) - - -def build_round_trips(trades: list[dict[str, Any]]) -> list[dict[str, Any]]: - """按 FIFO 将开/平仓成交配对为完整回合。""" - stacks: dict[tuple[str, str], list[dict[str, Any]]] = defaultdict(list) - trips: list[dict[str, Any]] = [] - - ordered = sorted( - trades, - key=lambda t: ((t.get("datetime") or ""), str(t.get("trade_id") or "")), - ) - for t in ordered: - sym = (t.get("symbol") or "").lower() - pos_dir = (t.get("position_direction") or "long").strip().lower() - offset = (t.get("offset") or "open").strip().lower() - lots = int(t.get("lots") or 0) - if not sym or lots <= 0: - continue - key = (sym, pos_dir) - if offset == "open": - stacks[key].append({ - **t, - "remaining": lots, - "commission_remaining": float(t.get("commission") or 0), - }) - continue - - close_lots_total = lots - close_lots_left = lots - close_price = float(t.get("price") or 0) - close_time = t.get("datetime") or "" - close_trade_id = str(t.get("trade_id") or "") - close_comm_total = float(t.get("commission") or 0) - while close_lots_left > 0 and stacks[key]: - open_t = stacks[key][0] - open_rem = int(open_t.get("remaining") or 0) - matched = min(close_lots_left, open_rem) - if matched <= 0: - stacks[key].pop(0) - continue - open_comm_rem = float(open_t.get("commission_remaining") or 0) - open_comm_share = ( - _allocate_commission(open_comm_rem, matched, open_rem) - if open_rem > 0 else 0.0 - ) - close_comm_share = _allocate_commission( - close_comm_total, matched, close_lots_total, - ) - open_t["remaining"] = open_rem - matched - open_t["commission_remaining"] = round( - max(0.0, open_comm_rem - open_comm_share), 2, - ) - if open_t["remaining"] <= 0: - stacks[key].pop(0) - close_lots_left -= matched - open_trade_id = str(open_t.get("trade_id") or "") - ctp_key = f"{open_trade_id}|{close_trade_id}|{sym}|{pos_dir}|{matched}" - trip_fee = round(open_comm_share + close_comm_share, 2) - trips.append({ - "ctp_trade_key": ctp_key, - "symbol": sym, - "ths_code": _to_ths_code(sym), - "direction": pos_dir, - "lots": matched, - "entry_price": float(open_t.get("price") or 0), - "close_price": close_price, - "open_time": open_t.get("datetime") or "", - "close_time": close_time, - "open_trade_id": open_trade_id, - "close_trade_id": close_trade_id, - "fee": trip_fee, - "fee_from_ctp": trip_fee > 0, - }) - return trips - - -def _find_monitor_meta( - conn, - *, - symbol: str, - direction: str, - open_time: str, - match_symbol_fn: Callable[[str, str], bool] | None = None, -) -> dict[str, Any]: - match = match_symbol_fn or _match_symbol - direction = (direction or "long").strip().lower() - best: Optional[dict[str, Any]] = None - for r in conn.execute( - "SELECT * FROM trade_order_monitors ORDER BY id DESC LIMIT 200" - ).fetchall(): - row = dict(r) - if (row.get("direction") or "long").strip().lower() != direction: - continue - if not match(symbol, row.get("symbol") or ""): - continue - if best is None: - best = row - continue - ot = (row.get("open_time") or "").strip() - if open_time and ot and abs(len(ot) - len(open_time)) <= 2 and ot[:16] == open_time[:16]: - return row - return best or {} - - -def _holding_minutes(open_time: str, close_time: str) -> int: - try: - from app import holding_to_minutes - return int(holding_to_minutes(open_time, close_time) or 0) - except Exception: - return 0 - - -def sync_trade_logs_from_ctp( - conn, - mode: str, - *, - capital: float = 0.0, - trading_mode: str = "simulation", -) -> dict[str, Any]: - """查询 CTP 成交并 upsert 到 trade_logs。返回同步摘要。""" - stats = {"synced": 0, "updated": 0, "skipped": 0, "connected": False} - if not ctp_status(mode).get("connected"): - return stats - stats["connected"] = True - ensure_trade_log_columns(conn) - try: - conn.execute("ALTER TABLE trade_logs ADD COLUMN source TEXT DEFAULT 'local'") - except Exception: - pass - try: - conn.execute("ALTER TABLE trade_logs ADD COLUMN ctp_trade_key TEXT") - except Exception: - pass - - trades = ctp_list_trades(mode, refresh=True) - trips = build_round_trips(trades) - for trip in trips: - key = trip.get("ctp_trade_key") or "" - if not key: - stats["skipped"] += 1 - continue - existing = conn.execute( - "SELECT id FROM trade_logs WHERE ctp_trade_key=?", - (key,), - ).fetchone() - - ths = trip.get("ths_code") or trip.get("symbol") or "" - codes = ths_to_codes(ths) or {} - direction = trip.get("direction") or "long" - entry = float(trip.get("entry_price") or 0) - close_px = float(trip.get("close_price") or 0) - lots = float(trip.get("lots") or 0) - open_time = trip.get("open_time") or "" - close_time = trip.get("close_time") or datetime.now(TZ).strftime("%Y-%m-%dT%H:%M") - - mon = _find_monitor_meta( - conn, - symbol=trip.get("symbol") or ths, - direction=direction, - open_time=open_time, - ) - sl = mon.get("stop_loss") - tp = mon.get("take_profit") - try: - sl_f = float(sl) if sl is not None else entry - tp_f = float(tp) if tp is not None else entry - except (TypeError, ValueError): - sl_f, tp_f = entry, entry - - metrics = calc_position_metrics( - direction, entry, sl_f, tp_f, lots, close_px, capital, ths, - ) - pnl = float(metrics.get("float_pnl") or 0) - trip_fee = float(trip.get("fee") or 0) - if trip_fee > 0: - fee = round(trip_fee, 2) - else: - fee = calc_round_trip_fee( - ths, entry, close_px, lots, open_time, close_time, trading_mode=trading_mode, - ) - pnl_net = round(pnl - fee, 2) - margin_pct = metrics.get("position_pct") - equity_after = calc_equity_after(capital, pnl_net) - minutes = _holding_minutes(open_time, close_time) - result = "CTP同步" - monitor_type = mon.get("monitor_type") or "CTP同步" - - row_vals = ( - ths, - codes.get("name") or mon.get("symbol_name") or ths, - codes.get("market_code") or mon.get("market_code") or "", - codes.get("sina_code") or mon.get("sina_code") or "", - monitor_type, - direction, - entry, - sl if sl is not None else None, - tp if tp is not None else None, - close_px, - lots, - metrics.get("margin"), - margin_pct, - minutes, - open_time, - close_time, - pnl, - fee, - pnl_net, - equity_after, - result, - ) - if existing: - conn.execute( - """UPDATE trade_logs SET - symbol=?, symbol_name=?, market_code=?, sina_code=?, monitor_type=?, - direction=?, entry_price=?, stop_loss=?, take_profit=?, close_price=?, - lots=?, margin=?, margin_pct=?, holding_minutes=?, open_time=?, close_time=?, - pnl=?, fee=?, pnl_net=?, equity_after=?, result=?, source='ctp', verified=1 - WHERE ctp_trade_key=?""", - row_vals + (key,), - ) - stats["updated"] += 1 - else: - conn.execute( - """INSERT INTO trade_logs - (symbol, symbol_name, market_code, sina_code, monitor_type, direction, - entry_price, stop_loss, take_profit, close_price, lots, margin, - margin_pct, holding_minutes, open_time, close_time, pnl, fee, pnl_net, - equity_after, result, source, ctp_trade_key, verified) - VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""", - row_vals + ("ctp", key, 1), - ) - stats["synced"] += 1 - try: - from trade_notify import notify_trade_log_close - from trading_context import trading_mode_label - from app import get_setting, send_wechat_msg - from ai_worker import schedule_ai_event_analysis - from db_conn import DB_PATH - - notify_trade_log_close( - send_wechat=send_wechat_msg, - get_setting=get_setting, - mode_label=trading_mode_label(get_setting), - capital=capital, - sym=ths, - symbol_name=codes.get("name") or mon.get("symbol_name") or ths, - direction=direction, - entry=entry, - close_price=close_px, - sl=float(sl) if sl is not None else None, - tp=float(tp) if tp is not None else None, - lots=lots, - pnl_net=pnl_net, - equity_after=equity_after, - holding_minutes=minutes, - result=result, - monitor_type=monitor_type, - schedule_ai_fn=schedule_ai_event_analysis, - db_path=DB_PATH, - ) - except Exception as exc: - logger.debug("ctp close notify: %s", exc) - - if stats["synced"] or stats["updated"]: - try: - from stats_engine import refresh_stats_cache - refresh_stats_cache(conn, capital) - except Exception as exc: - logger.debug("stats refresh after ctp trade sync: %s", exc) - purged = purge_duplicate_local_trade_logs(conn) - if purged: - stats["purged"] = purged - try: - refresh_trade_log_equity_chain(conn) - except Exception as exc: - logger.debug("equity chain refresh after ctp sync: %s", exc) - return stats +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""从 CTP 柜台同步成交,写入 trade_logs(以交易所成交为准)。""" +from __future__ import annotations + +import logging +from collections import defaultdict +from datetime import datetime +from typing import Any, Callable, Optional +from zoneinfo import ZoneInfo + +from modules.core.contract_specs import calc_position_metrics +from modules.ctp.ctp_symbol import ths_to_vnpy_symbol +from modules.fees.fee_specs import calc_round_trip_fee +from modules.core.symbols import ths_to_codes +from modules.trading.trade_log_lib import ( + calc_equity_after, + purge_duplicate_local_trade_logs, + ensure_trade_log_columns, + refresh_trade_log_equity_chain, +) +from modules.ctp.vnpy_bridge import ctp_list_trades, ctp_status + +logger = logging.getLogger(__name__) +TZ = ZoneInfo("Asia/Shanghai") + + +def _match_symbol(ctp_sym: str, ths: str) -> bool: + a = (ctp_sym or "").lower() + b = (ths or "").lower() + if a == b: + return True + if a and b and a.split(".")[0] == b.split(".")[0]: + return True + try: + vnpy_sym, _ = ths_to_vnpy_symbol(ths) + if a == vnpy_sym.lower(): + return True + except Exception: + pass + return False + + +def _to_ths_code(symbol: str) -> str: + sym = (symbol or "").strip() + if not sym: + return "" + codes = ths_to_codes(sym) + if codes: + return codes.get("ths_code") or sym + return sym.lower() + + +def _allocate_commission(total_comm: float, matched: int, total_lots: int) -> float: + if total_comm <= 0 or matched <= 0 or total_lots <= 0: + return 0.0 + return round(total_comm * matched / total_lots, 2) + + +def build_round_trips(trades: list[dict[str, Any]]) -> list[dict[str, Any]]: + """按 FIFO 将开/平仓成交配对为完整回合。""" + stacks: dict[tuple[str, str], list[dict[str, Any]]] = defaultdict(list) + trips: list[dict[str, Any]] = [] + + ordered = sorted( + trades, + key=lambda t: ((t.get("datetime") or ""), str(t.get("trade_id") or "")), + ) + for t in ordered: + sym = (t.get("symbol") or "").lower() + pos_dir = (t.get("position_direction") or "long").strip().lower() + offset = (t.get("offset") or "open").strip().lower() + lots = int(t.get("lots") or 0) + if not sym or lots <= 0: + continue + key = (sym, pos_dir) + if offset == "open": + stacks[key].append({ + **t, + "remaining": lots, + "commission_remaining": float(t.get("commission") or 0), + }) + continue + + close_lots_total = lots + close_lots_left = lots + close_price = float(t.get("price") or 0) + close_time = t.get("datetime") or "" + close_trade_id = str(t.get("trade_id") or "") + close_comm_total = float(t.get("commission") or 0) + while close_lots_left > 0 and stacks[key]: + open_t = stacks[key][0] + open_rem = int(open_t.get("remaining") or 0) + matched = min(close_lots_left, open_rem) + if matched <= 0: + stacks[key].pop(0) + continue + open_comm_rem = float(open_t.get("commission_remaining") or 0) + open_comm_share = ( + _allocate_commission(open_comm_rem, matched, open_rem) + if open_rem > 0 else 0.0 + ) + close_comm_share = _allocate_commission( + close_comm_total, matched, close_lots_total, + ) + open_t["remaining"] = open_rem - matched + open_t["commission_remaining"] = round( + max(0.0, open_comm_rem - open_comm_share), 2, + ) + if open_t["remaining"] <= 0: + stacks[key].pop(0) + close_lots_left -= matched + open_trade_id = str(open_t.get("trade_id") or "") + ctp_key = f"{open_trade_id}|{close_trade_id}|{sym}|{pos_dir}|{matched}" + trip_fee = round(open_comm_share + close_comm_share, 2) + trips.append({ + "ctp_trade_key": ctp_key, + "symbol": sym, + "ths_code": _to_ths_code(sym), + "direction": pos_dir, + "lots": matched, + "entry_price": float(open_t.get("price") or 0), + "close_price": close_price, + "open_time": open_t.get("datetime") or "", + "close_time": close_time, + "open_trade_id": open_trade_id, + "close_trade_id": close_trade_id, + "fee": trip_fee, + "fee_from_ctp": trip_fee > 0, + }) + return trips + + +def _find_monitor_meta( + conn, + *, + symbol: str, + direction: str, + open_time: str, + match_symbol_fn: Callable[[str, str], bool] | None = None, +) -> dict[str, Any]: + match = match_symbol_fn or _match_symbol + direction = (direction or "long").strip().lower() + best: Optional[dict[str, Any]] = None + for r in conn.execute( + "SELECT * FROM trade_order_monitors ORDER BY id DESC LIMIT 200" + ).fetchall(): + row = dict(r) + if (row.get("direction") or "long").strip().lower() != direction: + continue + if not match(symbol, row.get("symbol") or ""): + continue + if best is None: + best = row + continue + ot = (row.get("open_time") or "").strip() + if open_time and ot and abs(len(ot) - len(open_time)) <= 2 and ot[:16] == open_time[:16]: + return row + return best or {} + + +def _holding_minutes(open_time: str, close_time: str) -> int: + try: + from app import holding_to_minutes + return int(holding_to_minutes(open_time, close_time) or 0) + except Exception: + return 0 + + +def sync_trade_logs_from_ctp( + conn, + mode: str, + *, + capital: float = 0.0, + trading_mode: str = "simulation", +) -> dict[str, Any]: + """查询 CTP 成交并 upsert 到 trade_logs。返回同步摘要。""" + stats = {"synced": 0, "updated": 0, "skipped": 0, "connected": False} + if not ctp_status(mode).get("connected"): + return stats + stats["connected"] = True + ensure_trade_log_columns(conn) + try: + conn.execute("ALTER TABLE trade_logs ADD COLUMN source TEXT DEFAULT 'local'") + except Exception: + pass + try: + conn.execute("ALTER TABLE trade_logs ADD COLUMN ctp_trade_key TEXT") + except Exception: + pass + + trades = ctp_list_trades(mode, refresh=True) + trips = build_round_trips(trades) + for trip in trips: + key = trip.get("ctp_trade_key") or "" + if not key: + stats["skipped"] += 1 + continue + existing = conn.execute( + "SELECT id FROM trade_logs WHERE ctp_trade_key=?", + (key,), + ).fetchone() + + ths = trip.get("ths_code") or trip.get("symbol") or "" + codes = ths_to_codes(ths) or {} + direction = trip.get("direction") or "long" + entry = float(trip.get("entry_price") or 0) + close_px = float(trip.get("close_price") or 0) + lots = float(trip.get("lots") or 0) + open_time = trip.get("open_time") or "" + close_time = trip.get("close_time") or datetime.now(TZ).strftime("%Y-%m-%dT%H:%M") + + mon = _find_monitor_meta( + conn, + symbol=trip.get("symbol") or ths, + direction=direction, + open_time=open_time, + ) + sl = mon.get("stop_loss") + tp = mon.get("take_profit") + try: + sl_f = float(sl) if sl is not None else entry + tp_f = float(tp) if tp is not None else entry + except (TypeError, ValueError): + sl_f, tp_f = entry, entry + + metrics = calc_position_metrics( + direction, entry, sl_f, tp_f, lots, close_px, capital, ths, + ) + pnl = float(metrics.get("float_pnl") or 0) + trip_fee = float(trip.get("fee") or 0) + if trip_fee > 0: + fee = round(trip_fee, 2) + else: + fee = calc_round_trip_fee( + ths, entry, close_px, lots, open_time, close_time, trading_mode=trading_mode, + ) + pnl_net = round(pnl - fee, 2) + margin_pct = metrics.get("position_pct") + equity_after = calc_equity_after(capital, pnl_net) + minutes = _holding_minutes(open_time, close_time) + result = "CTP同步" + monitor_type = mon.get("monitor_type") or "CTP同步" + + row_vals = ( + ths, + codes.get("name") or mon.get("symbol_name") or ths, + codes.get("market_code") or mon.get("market_code") or "", + codes.get("sina_code") or mon.get("sina_code") or "", + monitor_type, + direction, + entry, + sl if sl is not None else None, + tp if tp is not None else None, + close_px, + lots, + metrics.get("margin"), + margin_pct, + minutes, + open_time, + close_time, + pnl, + fee, + pnl_net, + equity_after, + result, + ) + if existing: + conn.execute( + """UPDATE trade_logs SET + symbol=?, symbol_name=?, market_code=?, sina_code=?, monitor_type=?, + direction=?, entry_price=?, stop_loss=?, take_profit=?, close_price=?, + lots=?, margin=?, margin_pct=?, holding_minutes=?, open_time=?, close_time=?, + pnl=?, fee=?, pnl_net=?, equity_after=?, result=?, source='ctp', verified=1 + WHERE ctp_trade_key=?""", + row_vals + (key,), + ) + stats["updated"] += 1 + else: + conn.execute( + """INSERT INTO trade_logs + (symbol, symbol_name, market_code, sina_code, monitor_type, direction, + entry_price, stop_loss, take_profit, close_price, lots, margin, + margin_pct, holding_minutes, open_time, close_time, pnl, fee, pnl_net, + equity_after, result, source, ctp_trade_key, verified) + VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""", + row_vals + ("ctp", key, 1), + ) + stats["synced"] += 1 + try: + from modules.trading.trade_notify import notify_trade_log_close + from modules.core.trading_context import trading_mode_label + from app import get_setting, send_wechat_msg + from modules.notify.ai_worker import schedule_ai_event_analysis + from modules.core.db_conn import DB_PATH + + notify_trade_log_close( + send_wechat=send_wechat_msg, + get_setting=get_setting, + mode_label=trading_mode_label(get_setting), + capital=capital, + sym=ths, + symbol_name=codes.get("name") or mon.get("symbol_name") or ths, + direction=direction, + entry=entry, + close_price=close_px, + sl=float(sl) if sl is not None else None, + tp=float(tp) if tp is not None else None, + lots=lots, + pnl_net=pnl_net, + equity_after=equity_after, + holding_minutes=minutes, + result=result, + monitor_type=monitor_type, + schedule_ai_fn=schedule_ai_event_analysis, + db_path=DB_PATH, + ) + except Exception as exc: + logger.debug("ctp close notify: %s", exc) + + if stats["synced"] or stats["updated"]: + try: + from modules.stats.stats_engine import refresh_stats_cache + refresh_stats_cache(conn, capital) + except Exception as exc: + logger.debug("stats refresh after ctp trade sync: %s", exc) + purged = purge_duplicate_local_trade_logs(conn) + if purged: + stats["purged"] = purged + try: + refresh_trade_log_equity_chain(conn) + except Exception as exc: + logger.debug("equity chain refresh after ctp sync: %s", exc) + return stats diff --git a/ctp_trading_state.py b/modules/ctp/ctp_trading_state.py similarity index 96% rename from ctp_trading_state.py rename to modules/ctp/ctp_trading_state.py index 854cf92..98f683e 100644 --- a/ctp_trading_state.py +++ b/modules/ctp/ctp_trading_state.py @@ -1,270 +1,270 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 详见 LICENSE.zh-CN.txt - -"""CTP 权威内存簿:委托、持仓、同步状态(事件增量 + 定期全量校准)。""" -from __future__ import annotations - -import logging -import threading -import time -from typing import Any, Callable, Optional - -logger = logging.getLogger(__name__) - -CALIBRATE_INTERVAL_SEC = 30.0 - - -def position_key(exchange: str, symbol: str, direction: str) -> str: - """统一持仓键:exchange|symbol|direction""" - ex = (exchange or "").strip().upper() - sym = (symbol or "").strip().lower() - d = (direction or "long").strip().lower() - if ex: - return f"{ex}|{sym}|{d}" - return f"{sym}|{d}" - - -def parse_position_key(key: str) -> tuple[str, str, str]: - parts = (key or "").split("|") - if len(parts) >= 3: - return parts[0], parts[1], parts[2] - if len(parts) == 2: - return "", parts[0], parts[1] - return "", (key or "").lower(), "long" - - -def reconcile_position_avg( - old: Optional[dict[str, Any]], - new: dict[str, Any], - tick: Optional[float], - *, - trades: Optional[list[dict[str, Any]]] = None, - ths_sym: str = "", -) -> dict[str, Any]: - """手数变化时采用柜台回报均价;手数不变时保持已锁定柜台价。""" - del tick, trades - from ctp_entry_price import round_to_tick - - row = dict(new) - lots = int(row.get("lots") or 0) - if lots <= 0: - return row - old_lots = int(old.get("lots") or 0) if old else 0 - lots_changed = not old or old_lots != lots - sym = ths_sym or (row.get("symbol") or "") - - pos_avg = float(row.get("avg_price") or 0) - if pos_avg > 0: - row["avg_price"] = round_to_tick(pos_avg, sym) - row["avg_price_locked"] = True - return row - - if not lots_changed and old and float(old.get("avg_price") or 0) > 0: - row["avg_price"] = float(old["avg_price"]) - row["avg_price_locked"] = True - return row - - -class CtpTradingState: - """进程内 CTP 快照:柜台回报为准,SQLite 仅挂 SL/TP 元数据。""" - - def __init__(self) -> None: - self._lock = threading.RLock() - self._orders: dict[str, dict[str, Any]] = {} - self._positions: dict[str, dict[str, Any]] = {} - self._tick_prices: dict[str, float] = {} - self._sync_state = "idle" - self._last_event_ts: float = 0.0 - self._last_calibrate_ts: float = 0.0 - self._on_change: Optional[Callable[[], None]] = None - - def set_change_callback(self, fn: Optional[Callable[[], None]]) -> None: - self._on_change = fn - - def _notify(self) -> None: - self._last_event_ts = time.time() - fn = self._on_change - if fn: - try: - fn() - except Exception as exc: - logger.debug("trading state change callback: %s", exc) - - @property - def sync_state(self) -> str: - with self._lock: - return self._sync_state - - def sync_label(self) -> str: - st = self.sync_state - if st == "syncing": - return "同步中…" - if st == "ready": - return "已同步" - return "" - - def begin_sync(self) -> None: - with self._lock: - self._sync_state = "syncing" - - def finish_sync(self) -> None: - with self._lock: - self._sync_state = "ready" - self._last_calibrate_ts = time.time() - - def needs_calibrate(self) -> bool: - with self._lock: - if self._sync_state == "idle": - return True - return (time.time() - self._last_calibrate_ts) >= CALIBRATE_INTERVAL_SEC - - def upsert_order(self, row: dict[str, Any], *, notify: bool = True) -> None: - oid = str(row.get("order_id") or row.get("vt_order_id") or "").strip() - if not oid: - return - with self._lock: - self._orders[oid] = dict(row) - if notify: - self._notify() - - def remove_order(self, order_id: str, *, notify: bool = True) -> None: - oid = (order_id or "").strip() - if not oid: - return - removed = False - with self._lock: - if oid in self._orders: - del self._orders[oid] - removed = True - else: - for k in list(self._orders.keys()): - if k == oid or k.endswith(oid) or oid.endswith(k): - del self._orders[k] - removed = True - break - if removed and notify: - self._notify() - - def get_position(self, pk: str) -> Optional[dict[str, Any]]: - with self._lock: - row = self._positions.get(pk) - return dict(row) if row else None - - def try_lock_entry_prices(self) -> bool: - """均价以柜台为准,不按 tick 反推(避免均价随行情跳动)。""" - return False - - def upsert_position( - self, - row: dict[str, Any], - *, - notify: bool = True, - trades: Optional[list[dict[str, Any]]] = None, - ths_sym: str = "", - ) -> None: - lots = int(row.get("lots") or 0) - ex = row.get("exchange") or "" - sym = row.get("symbol") or "" - direction = row.get("direction") or "long" - pk = position_key(ex, sym, direction) - tick = self.get_tick_price(ex, sym) - with self._lock: - if lots <= 0: - self._positions.pop(pk, None) - else: - old = self._positions.get(pk) - row = reconcile_position_avg( - old, dict(row), tick, trades=trades, ths_sym=ths_sym or sym, - ) - row["position_key"] = pk - self._positions[pk] = row - if notify: - self._notify() - - def remove_position(self, pk: str, *, notify: bool = True) -> None: - with self._lock: - self._positions.pop(pk, None) - if notify: - self._notify() - - def set_tick_price(self, exchange: str, symbol: str, price: float) -> None: - if not symbol or price <= 0: - return - key = f"{(exchange or '').upper()}|{symbol.lower()}" - with self._lock: - self._tick_prices[key] = float(price) - - def get_tick_price(self, exchange: str, symbol: str) -> Optional[float]: - key = f"{(exchange or '').upper()}|{symbol.lower()}" - with self._lock: - return self._tick_prices.get(key) - - def get_active_orders(self) -> list[dict[str, Any]]: - with self._lock: - return list(self._orders.values()) - - def get_positions(self) -> list[dict[str, Any]]: - with self._lock: - return list(self._positions.values()) - - def position_keys(self) -> set[str]: - with self._lock: - return set(self._positions.keys()) - - def clear(self) -> None: - with self._lock: - self._orders.clear() - self._positions.clear() - self._tick_prices.clear() - self._sync_state = "idle" - - def calibrate_from_lists( - self, - orders: list[dict[str, Any]], - positions: list[dict[str, Any]], - *, - trades: Optional[list[dict[str, Any]]] = None, - ths_for_vnpy_sym: Optional[Callable[[str, str], str]] = None, - preserve_positions_if_margin: float = 0.0, - ) -> None: - """全量校准:以 vnpy 内存为准重建订单/持仓簿。""" - self.begin_sync() - new_orders: dict[str, dict[str, Any]] = {} - for o in orders or []: - oid = str(o.get("order_id") or o.get("vt_order_id") or "").strip() - if oid: - new_orders[oid] = dict(o) - new_positions: dict[str, dict[str, Any]] = {} - for p in positions or []: - lots = int(p.get("lots") or 0) - if lots <= 0: - continue - ex = p.get("exchange") or "" - sym = p.get("symbol") or "" - direction = p.get("direction") or "long" - pk = position_key(ex, sym, direction) - row = dict(p) - row["position_key"] = pk - old = self._positions.get(pk) - tick = self.get_tick_price(ex, sym) - ths = sym - if ths_for_vnpy_sym: - try: - ths = ths_for_vnpy_sym(sym, ex) or sym - except Exception: - ths = sym - new_positions[pk] = reconcile_position_avg( - old, row, tick, trades=trades, ths_sym=ths, - ) - if not new_positions and self._positions and preserve_positions_if_margin > 0: - with self._lock: - new_positions = {k: dict(v) for k, v in self._positions.items()} - with self._lock: - self._orders = new_orders - self._positions = new_positions - self.finish_sync() - self._notify() - - -trading_state = CtpTradingState() +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 详见 LICENSE.zh-CN.txt + +"""CTP 权威内存簿:委托、持仓、同步状态(事件增量 + 定期全量校准)。""" +from __future__ import annotations + +import logging +import threading +import time +from typing import Any, Callable, Optional + +logger = logging.getLogger(__name__) + +CALIBRATE_INTERVAL_SEC = 30.0 + + +def position_key(exchange: str, symbol: str, direction: str) -> str: + """统一持仓键:exchange|symbol|direction""" + ex = (exchange or "").strip().upper() + sym = (symbol or "").strip().lower() + d = (direction or "long").strip().lower() + if ex: + return f"{ex}|{sym}|{d}" + return f"{sym}|{d}" + + +def parse_position_key(key: str) -> tuple[str, str, str]: + parts = (key or "").split("|") + if len(parts) >= 3: + return parts[0], parts[1], parts[2] + if len(parts) == 2: + return "", parts[0], parts[1] + return "", (key or "").lower(), "long" + + +def reconcile_position_avg( + old: Optional[dict[str, Any]], + new: dict[str, Any], + tick: Optional[float], + *, + trades: Optional[list[dict[str, Any]]] = None, + ths_sym: str = "", +) -> dict[str, Any]: + """手数变化时采用柜台回报均价;手数不变时保持已锁定柜台价。""" + del tick, trades + from modules.ctp.ctp_entry_price import round_to_tick + + row = dict(new) + lots = int(row.get("lots") or 0) + if lots <= 0: + return row + old_lots = int(old.get("lots") or 0) if old else 0 + lots_changed = not old or old_lots != lots + sym = ths_sym or (row.get("symbol") or "") + + pos_avg = float(row.get("avg_price") or 0) + if pos_avg > 0: + row["avg_price"] = round_to_tick(pos_avg, sym) + row["avg_price_locked"] = True + return row + + if not lots_changed and old and float(old.get("avg_price") or 0) > 0: + row["avg_price"] = float(old["avg_price"]) + row["avg_price_locked"] = True + return row + + +class CtpTradingState: + """进程内 CTP 快照:柜台回报为准,SQLite 仅挂 SL/TP 元数据。""" + + def __init__(self) -> None: + self._lock = threading.RLock() + self._orders: dict[str, dict[str, Any]] = {} + self._positions: dict[str, dict[str, Any]] = {} + self._tick_prices: dict[str, float] = {} + self._sync_state = "idle" + self._last_event_ts: float = 0.0 + self._last_calibrate_ts: float = 0.0 + self._on_change: Optional[Callable[[], None]] = None + + def set_change_callback(self, fn: Optional[Callable[[], None]]) -> None: + self._on_change = fn + + def _notify(self) -> None: + self._last_event_ts = time.time() + fn = self._on_change + if fn: + try: + fn() + except Exception as exc: + logger.debug("trading state change callback: %s", exc) + + @property + def sync_state(self) -> str: + with self._lock: + return self._sync_state + + def sync_label(self) -> str: + st = self.sync_state + if st == "syncing": + return "同步中…" + if st == "ready": + return "已同步" + return "" + + def begin_sync(self) -> None: + with self._lock: + self._sync_state = "syncing" + + def finish_sync(self) -> None: + with self._lock: + self._sync_state = "ready" + self._last_calibrate_ts = time.time() + + def needs_calibrate(self) -> bool: + with self._lock: + if self._sync_state == "idle": + return True + return (time.time() - self._last_calibrate_ts) >= CALIBRATE_INTERVAL_SEC + + def upsert_order(self, row: dict[str, Any], *, notify: bool = True) -> None: + oid = str(row.get("order_id") or row.get("vt_order_id") or "").strip() + if not oid: + return + with self._lock: + self._orders[oid] = dict(row) + if notify: + self._notify() + + def remove_order(self, order_id: str, *, notify: bool = True) -> None: + oid = (order_id or "").strip() + if not oid: + return + removed = False + with self._lock: + if oid in self._orders: + del self._orders[oid] + removed = True + else: + for k in list(self._orders.keys()): + if k == oid or k.endswith(oid) or oid.endswith(k): + del self._orders[k] + removed = True + break + if removed and notify: + self._notify() + + def get_position(self, pk: str) -> Optional[dict[str, Any]]: + with self._lock: + row = self._positions.get(pk) + return dict(row) if row else None + + def try_lock_entry_prices(self) -> bool: + """均价以柜台为准,不按 tick 反推(避免均价随行情跳动)。""" + return False + + def upsert_position( + self, + row: dict[str, Any], + *, + notify: bool = True, + trades: Optional[list[dict[str, Any]]] = None, + ths_sym: str = "", + ) -> None: + lots = int(row.get("lots") or 0) + ex = row.get("exchange") or "" + sym = row.get("symbol") or "" + direction = row.get("direction") or "long" + pk = position_key(ex, sym, direction) + tick = self.get_tick_price(ex, sym) + with self._lock: + if lots <= 0: + self._positions.pop(pk, None) + else: + old = self._positions.get(pk) + row = reconcile_position_avg( + old, dict(row), tick, trades=trades, ths_sym=ths_sym or sym, + ) + row["position_key"] = pk + self._positions[pk] = row + if notify: + self._notify() + + def remove_position(self, pk: str, *, notify: bool = True) -> None: + with self._lock: + self._positions.pop(pk, None) + if notify: + self._notify() + + def set_tick_price(self, exchange: str, symbol: str, price: float) -> None: + if not symbol or price <= 0: + return + key = f"{(exchange or '').upper()}|{symbol.lower()}" + with self._lock: + self._tick_prices[key] = float(price) + + def get_tick_price(self, exchange: str, symbol: str) -> Optional[float]: + key = f"{(exchange or '').upper()}|{symbol.lower()}" + with self._lock: + return self._tick_prices.get(key) + + def get_active_orders(self) -> list[dict[str, Any]]: + with self._lock: + return list(self._orders.values()) + + def get_positions(self) -> list[dict[str, Any]]: + with self._lock: + return list(self._positions.values()) + + def position_keys(self) -> set[str]: + with self._lock: + return set(self._positions.keys()) + + def clear(self) -> None: + with self._lock: + self._orders.clear() + self._positions.clear() + self._tick_prices.clear() + self._sync_state = "idle" + + def calibrate_from_lists( + self, + orders: list[dict[str, Any]], + positions: list[dict[str, Any]], + *, + trades: Optional[list[dict[str, Any]]] = None, + ths_for_vnpy_sym: Optional[Callable[[str, str], str]] = None, + preserve_positions_if_margin: float = 0.0, + ) -> None: + """全量校准:以 vnpy 内存为准重建订单/持仓簿。""" + self.begin_sync() + new_orders: dict[str, dict[str, Any]] = {} + for o in orders or []: + oid = str(o.get("order_id") or o.get("vt_order_id") or "").strip() + if oid: + new_orders[oid] = dict(o) + new_positions: dict[str, dict[str, Any]] = {} + for p in positions or []: + lots = int(p.get("lots") or 0) + if lots <= 0: + continue + ex = p.get("exchange") or "" + sym = p.get("symbol") or "" + direction = p.get("direction") or "long" + pk = position_key(ex, sym, direction) + row = dict(p) + row["position_key"] = pk + old = self._positions.get(pk) + tick = self.get_tick_price(ex, sym) + ths = sym + if ths_for_vnpy_sym: + try: + ths = ths_for_vnpy_sym(sym, ex) or sym + except Exception: + ths = sym + new_positions[pk] = reconcile_position_avg( + old, row, tick, trades=trades, ths_sym=ths, + ) + if not new_positions and self._positions and preserve_positions_if_margin > 0: + with self._lock: + new_positions = {k: dict(v) for k, v in self._positions.items()} + with self._lock: + self._orders = new_orders + self._positions = new_positions + self.finish_sync() + self._notify() + + +trading_state = CtpTradingState() diff --git a/ctp_worker.py b/modules/ctp/ctp_worker.py similarity index 90% rename from ctp_worker.py rename to modules/ctp/ctp_worker.py index f093f87..11add83 100644 --- a/ctp_worker.py +++ b/modules/ctp/ctp_worker.py @@ -1,494 +1,494 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""Isolated local CTP worker. - -This process is the only process that should instantiate vn.py / vnpy_ctp. -The Flask web app talks to it through localhost HTTP via ctp_ipc_client.py. -""" -from __future__ import annotations - -import logging -import os -import threading -import time -from typing import Any - -os.environ.setdefault("QIHUO_CTP_ROLE", "worker") - -from flask import Flask, jsonify, request - -from ctp_ipc_client import worker_token -from db_conn import DB_PATH, commit_retry, connect_db -from fee_specs import get_setting, set_setting -from locale_fix import ensure_process_locale -from market_sessions import is_trading_session -from sl_tp_guard import check_sl_tp_on_tick, ensure_monitor_order_columns, start_sl_tp_guard_worker -from strategy.strategy_db import init_strategy_tables -from trading_context import get_account_capital, get_trading_mode, get_trailing_be_tick_buffer -from vnpy_bridge import ( - _ctp_td_lock, - ctp_cancel_order, - ctp_disconnect, - ctp_estimate_margin_one_lot, - ctp_get_account, - ctp_get_tick_detail, - ctp_get_tick_price, - ctp_list_active_orders, - ctp_list_positions, - ctp_list_trades, - ctp_lookup_contract_spec, - ctp_start_connect, - ctp_status, - ctp_try_auto_reconnect, - execute_order, - get_bridge, - set_ctp_connected_callback, - set_position_refresh_callback, - set_tick_quote_callback, - set_tick_sl_tp_callback, - try_init_vnpy, -) - - -logging.basicConfig( - level=os.getenv("LOG_LEVEL", "INFO"), - format="%(asctime)s %(levelname)s [%(name)s] %(message)s", -) -logger = logging.getLogger(__name__) - -app = Flask(__name__) -_started_workers = False -_last_snapshot_ts = 0.0 -_snapshot_lock = threading.Lock() - - -def _json_ok(**payload: Any): - return jsonify({"ok": True, **payload}) - - -def _json_error(exc: Exception, *, status_code: int = 500): - return jsonify({"ok": False, "error": str(exc)}), status_code - - -def _require_token() -> None: - expected = worker_token() - got = request.headers.get("X-Qihuo-CTP-Token", "") - if expected and got != expected: - raise PermissionError("unauthorized") - - -@app.before_request -def _auth(): - _require_token() - - -@app.errorhandler(Exception) -def _handle_error(exc: Exception): - code = 401 if isinstance(exc, PermissionError) else 500 - logger.warning("ctp worker request failed: %s", exc) - return _json_error(exc, status_code=code) - - -def _mode_from_request() -> str: - data = request.get_json(silent=True) or {} - return ( - data.get("mode") - or request.args.get("mode") - or get_trading_mode(get_setting) - or "simulation" - ) - - -def _fast_status(mode: str) -> dict[str, Any]: - """Return worker/native bridge state without slow network probing.""" - from ctp_settings import CTP_DISABLED_HINT, is_ctp_auto_connect_enabled - - try: - st = dict(get_bridge().status(mode) or {}) - except Exception as exc: - st = { - "connected": False, - "connecting": False, - "connected_mode": None, - "last_error": str(exc), - "mode_label": "SimNow" if mode == "simulation" else "期货公司实盘", - } - auto = is_ctp_auto_connect_enabled() - st["auto_connect_enabled"] = auto - st["worker_online"] = True - if not auto: - st["disabled_hint"] = CTP_DISABLED_HINT - if not st.get("connected") and not st.get("connecting"): - st["last_error"] = "" - st["td_reachable"] = None - return st - - -def _send_wechat_msg(content: str) -> None: - webhook = get_setting("wechat_webhook", "") - if not webhook: - return - try: - import requests - - requests.post( - webhook, - json={"msgtype": "text", "text": {"content": f"【国内期货】\n{content}"}}, - timeout=10, - ) - except Exception as exc: - logger.debug("wechat notify failed: %s", exc) - - -def _init_worker_tables(conn) -> None: - init_strategy_tables(conn) - ensure_monitor_order_columns(conn) - - -def _capital(conn) -> float: - try: - return float(get_account_capital(get_setting, conn=conn) or 0) - except Exception: - return 0.0 - - -def _persist_snapshot(mode: str) -> None: - global _last_snapshot_ts - with _snapshot_lock: - now = time.time() - if now - _last_snapshot_ts < 0.25: - return - _last_snapshot_ts = now - try: - import json - - st = _fast_status(mode) - positions = ctp_list_positions(mode, refresh_if_empty=False, refresh_margin=False) - account = ctp_get_account(mode) if st.get("connected") else {} - conn = connect_db(DB_PATH) - try: - conn.execute( - """CREATE TABLE IF NOT EXISTS ctp_worker_snapshots ( - key TEXT PRIMARY KEY, - value TEXT, - updated_at REAL - )""" - ) - for key, value in ( - ("status", st), - ("positions", positions), - ("account", account), - ): - conn.execute( - """INSERT INTO ctp_worker_snapshots(key, value, updated_at) - VALUES(?,?,?) - ON CONFLICT(key) DO UPDATE SET - value=excluded.value, - updated_at=excluded.updated_at""", - (key, json.dumps(value, ensure_ascii=False), now), - ) - commit_retry(conn) - finally: - conn.close() - except Exception as exc: - logger.debug("persist ctp snapshot: %s", exc) - - -def _on_position_refresh() -> None: - try: - _persist_snapshot(get_trading_mode(get_setting)) - except Exception as exc: - logger.debug("position refresh callback: %s", exc) - - -def _on_tick_quote() -> None: - _on_position_refresh() - - -def _on_tick_sl_tp(exchange: str, symbol: str, price: float) -> None: - mode = get_trading_mode(get_setting) - if not ctp_status(mode).get("connected"): - return - conn = connect_db(DB_PATH) - try: - _init_worker_tables(conn) - capital = _capital(conn) - n = check_sl_tp_on_tick( - conn, - mode, - exchange, - symbol, - price, - capital=capital, - notify_fn=_send_wechat_msg, - be_tick_mult=get_trailing_be_tick_buffer(get_setting), - ) - if n: - commit_retry(conn) - _persist_snapshot(mode) - except Exception as exc: - logger.warning("worker tick sl/tp: %s", exc) - finally: - conn.close() - - -def _on_ctp_connected(mode: str) -> None: - try: - with _ctp_td_lock: - get_bridge().request_position_snapshot(force=True) - get_bridge().calibrate_trading_state() - _persist_snapshot(mode) - except Exception as exc: - logger.debug("worker ctp connected callback: %s", exc) - - -def _start_background_workers() -> None: - global _started_workers - if _started_workers: - return - _started_workers = True - - set_position_refresh_callback(_on_position_refresh) - set_tick_quote_callback(_on_tick_quote) - set_tick_sl_tp_callback(_on_tick_sl_tp) - set_ctp_connected_callback(_on_ctp_connected) - - from ctp_fee_worker import start_ctp_fee_worker - from ctp_premarket_connect import start_ctp_premarket_connect_worker - from ctp_reconnect import start_ctp_reconnect_worker - from order_pending import reconcile_pending_orders - from pending_order_worker import start_pending_order_worker - - def _mode() -> str: - return get_trading_mode(get_setting) - - start_ctp_reconnect_worker(get_mode_fn=_mode, get_setting_fn=get_setting) - start_ctp_premarket_connect_worker(get_mode_fn=_mode, get_setting_fn=get_setting) - start_ctp_fee_worker( - get_mode_fn=_mode, - get_setting_fn=get_setting, - set_setting_fn=set_setting, - ) - start_pending_order_worker( - db_path=DB_PATH, - get_mode_fn=_mode, - init_tables_fn=_init_worker_tables, - get_capital_fn=_capital, - reconcile_fn=reconcile_pending_orders, - on_changed_fn=lambda: _persist_snapshot(_mode()), - ) - start_sl_tp_guard_worker( - db_path=DB_PATH, - get_mode_fn=_mode, - init_tables_fn=_init_worker_tables, - get_capital_fn=_capital, - get_be_tick_buffer_fn=lambda: get_trailing_be_tick_buffer(get_setting), - notify_fn=_send_wechat_msg, - ) - - def _snapshot_loop() -> None: - time.sleep(3) - while True: - try: - mode = _mode() - if _fast_status(mode).get("connected"): - _persist_snapshot(mode) - except Exception as exc: - logger.debug("worker snapshot loop: %s", exc) - time.sleep(2 if is_trading_session() else 15) - - threading.Thread(target=_snapshot_loop, daemon=True, name="ctp-worker-snapshot").start() - - -@app.route("/health") -def health(): - mode = request.args.get("mode") or get_trading_mode(get_setting) - st = _fast_status(mode) - return _json_ok( - worker_online=True, - role=os.getenv("QIHUO_CTP_ROLE", "worker"), - mode=mode, - status=st, - ts=time.time(), - ) - - -@app.route("/ctp/status") -def api_status(): - mode = _mode_from_request() - return _json_ok(status=_fast_status(mode)) - - -@app.route("/ctp/connect", methods=["POST"]) -def api_connect(): - data = request.get_json(silent=True) or {} - mode = data.get("mode") or get_trading_mode(get_setting) - info = ctp_start_connect(mode, force=bool(data.get("force"))) - st = info.get("status") or _fast_status(mode) - return _json_ok(status=st, **{k: v for k, v in info.items() if k != "status"}) - - -@app.route("/ctp/start_connect", methods=["POST"]) -def api_start_connect(): - data = request.get_json(silent=True) or {} - mode = data.get("mode") or get_trading_mode(get_setting) - return _json_ok(**ctp_start_connect( - mode, - force=bool(data.get("force")), - scheduled=bool(data.get("scheduled")), - )) - - -@app.route("/ctp/disconnect", methods=["POST"]) -def api_disconnect(): - data = request.get_json(silent=True) or {} - ctp_disconnect(set_disabled_hint=bool(data.get("set_disabled_hint"))) - return _json_ok(disconnected=True) - - -@app.route("/ctp/account") -def api_account(): - mode = _mode_from_request() - if not _fast_status(mode).get("connected"): - return _json_ok(account={}) - return _json_ok(account=ctp_get_account(mode)) - - -@app.route("/ctp/positions", methods=["POST"]) -def api_positions(): - data = request.get_json(silent=True) or {} - mode = data.get("mode") or get_trading_mode(get_setting) - return _json_ok(positions=ctp_list_positions( - mode, - refresh_if_empty=bool(data.get("refresh_if_empty", True)), - refresh_margin=bool(data.get("refresh_margin", False)), - )) - - -@app.route("/ctp/trades", methods=["POST"]) -def api_trades(): - data = request.get_json(silent=True) or {} - mode = data.get("mode") or get_trading_mode(get_setting) - return _json_ok(trades=ctp_list_trades(mode, refresh=bool(data.get("refresh")))) - - -@app.route("/ctp/active_orders") -def api_active_orders(): - mode = _mode_from_request() - return _json_ok(orders=ctp_list_active_orders(mode)) - - -@app.route("/ctp/tick_price", methods=["POST"]) -def api_tick_price(): - data = request.get_json(silent=True) or {} - return _json_ok(price=ctp_get_tick_price( - data.get("mode") or get_trading_mode(get_setting), - data.get("symbol") or "", - )) - - -@app.route("/ctp/tick_detail", methods=["POST"]) -def api_tick_detail(): - data = request.get_json(silent=True) or {} - return _json_ok(detail=ctp_get_tick_detail( - data.get("mode") or get_trading_mode(get_setting), - data.get("symbol") or "", - )) - - -@app.route("/ctp/estimate_margin_one_lot", methods=["POST"]) -def api_estimate_margin(): - data = request.get_json(silent=True) or {} - return _json_ok(margin=ctp_estimate_margin_one_lot( - data.get("mode") or get_trading_mode(get_setting), - data.get("symbol") or "", - float(data.get("price") or 0), - direction=data.get("direction") or "long", - )) - - -@app.route("/ctp/contract_spec", methods=["POST"]) -def api_contract_spec(): - data = request.get_json(silent=True) or {} - return _json_ok(spec=ctp_lookup_contract_spec( - data.get("mode") or get_trading_mode(get_setting), - data.get("symbol") or "", - )) - - -@app.route("/ctp/order", methods=["POST"]) -def api_order(): - data = request.get_json(silent=True) or {} - mode = data.get("mode") or get_trading_mode(get_setting) - result = execute_order( - None, - mode=mode, - offset=data.get("offset") or "open", - symbol=data.get("symbol") or "", - direction=data.get("direction") or "long", - lots=int(data.get("lots") or 1), - price=float(data.get("price") or 0), - settings=data.get("settings") or {}, - order_type=data.get("order_type") or "limit", - ) - _persist_snapshot(mode) - return _json_ok(**result) - - -@app.route("/ctp/cancel", methods=["POST"]) -def api_cancel(): - data = request.get_json(silent=True) or {} - mode = data.get("mode") or get_trading_mode(get_setting) - cancelled = ctp_cancel_order(mode, data.get("vt_orderid") or "") - _persist_snapshot(mode) - return _json_ok(cancelled=cancelled) - - -@app.route("/ctp/bridge/", methods=["POST"]) -def api_bridge_action(action: str): - data = request.get_json(silent=True) or {} - b = get_bridge() - if action == "calibrate_trading_state": - return _json_ok(result=b.calibrate_trading_state()) - if action == "request_position_snapshot": - return _json_ok(result=b.request_position_snapshot(force=bool(data.get("force")))) - if action == "subscribe_symbol": - return _json_ok(result=b.subscribe_symbol(data.get("symbol") or "")) - if action == "refresh_positions": - return _json_ok(result=b.refresh_positions()) - if action == "connect_in_progress": - return _json_ok(result=b.connect_in_progress()) - if action == "reconnect_after_settings_saved": - mode = data.get("mode") or get_trading_mode(get_setting) - return _json_ok(result=b.reconnect_after_settings_saved(mode)) - if action == "query_all_commissions": - return _json_ok(result=b.query_all_commissions( - mode=data.get("mode") or get_trading_mode(get_setting), - )) - if action == "query_instrument_commission": - return _json_ok(result=b.query_instrument_commission( - data.get("symbol") or "", - mode=data.get("mode") or get_trading_mode(get_setting), - )) - if action == "get_kline_bars_1m": - return _json_ok(result=b.get_kline_bars_1m( - data.get("symbol") or "", - mode=data.get("mode") or get_trading_mode(get_setting), - )) - return _json_error(ValueError(f"unsupported bridge action: {action}"), status_code=404) - - -def main() -> None: - ensure_process_locale() - try_init_vnpy({}) - _start_background_workers() - host = os.getenv("QIHUO_CTP_WORKER_HOST", "127.0.0.1") - port = int(os.getenv("QIHUO_CTP_WORKER_PORT", "6601") or 6601) - logger.info("starting qihuo-ctp worker on %s:%s", host, port) - app.run(host=host, port=port, debug=False, threaded=True, use_reloader=False) - - -if __name__ == "__main__": - main() +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""Isolated local CTP worker. + +This process is the only process that should instantiate vn.py / vnpy_ctp. +The Flask web app talks to it through localhost HTTP via ctp_ipc_client.py. +""" +from __future__ import annotations + +import logging +import os +import threading +import time +from typing import Any + +os.environ.setdefault("QIHUO_CTP_ROLE", "worker") + +from flask import Flask, jsonify, request + +from modules.ctp.ctp_ipc_client import worker_token +from modules.core.db_conn import DB_PATH, commit_retry, connect_db +from modules.fees.fee_specs import get_setting, set_setting +from modules.core.locale_fix import ensure_process_locale +from modules.market.market_sessions import is_trading_session +from modules.trading.sl_tp_guard import check_sl_tp_on_tick, ensure_monitor_order_columns, start_sl_tp_guard_worker +from strategy.strategy_db import init_strategy_tables +from modules.core.trading_context import get_account_capital, get_trading_mode, get_trailing_be_tick_buffer +from modules.ctp.vnpy_bridge import ( + _ctp_td_lock, + ctp_cancel_order, + ctp_disconnect, + ctp_estimate_margin_one_lot, + ctp_get_account, + ctp_get_tick_detail, + ctp_get_tick_price, + ctp_list_active_orders, + ctp_list_positions, + ctp_list_trades, + ctp_lookup_contract_spec, + ctp_start_connect, + ctp_status, + ctp_try_auto_reconnect, + execute_order, + get_bridge, + set_ctp_connected_callback, + set_position_refresh_callback, + set_tick_quote_callback, + set_tick_sl_tp_callback, + try_init_vnpy, +) + + +logging.basicConfig( + level=os.getenv("LOG_LEVEL", "INFO"), + format="%(asctime)s %(levelname)s [%(name)s] %(message)s", +) +logger = logging.getLogger(__name__) + +app = Flask(__name__) +_started_workers = False +_last_snapshot_ts = 0.0 +_snapshot_lock = threading.Lock() + + +def _json_ok(**payload: Any): + return jsonify({"ok": True, **payload}) + + +def _json_error(exc: Exception, *, status_code: int = 500): + return jsonify({"ok": False, "error": str(exc)}), status_code + + +def _require_token() -> None: + expected = worker_token() + got = request.headers.get("X-Qihuo-CTP-Token", "") + if expected and got != expected: + raise PermissionError("unauthorized") + + +@app.before_request +def _auth(): + _require_token() + + +@app.errorhandler(Exception) +def _handle_error(exc: Exception): + code = 401 if isinstance(exc, PermissionError) else 500 + logger.warning("ctp worker request failed: %s", exc) + return _json_error(exc, status_code=code) + + +def _mode_from_request() -> str: + data = request.get_json(silent=True) or {} + return ( + data.get("mode") + or request.args.get("mode") + or get_trading_mode(get_setting) + or "simulation" + ) + + +def _fast_status(mode: str) -> dict[str, Any]: + """Return worker/native bridge state without slow network probing.""" + from modules.ctp.ctp_settings import CTP_DISABLED_HINT, is_ctp_auto_connect_enabled + + try: + st = dict(get_bridge().status(mode) or {}) + except Exception as exc: + st = { + "connected": False, + "connecting": False, + "connected_mode": None, + "last_error": str(exc), + "mode_label": "SimNow" if mode == "simulation" else "期货公司实盘", + } + auto = is_ctp_auto_connect_enabled() + st["auto_connect_enabled"] = auto + st["worker_online"] = True + if not auto: + st["disabled_hint"] = CTP_DISABLED_HINT + if not st.get("connected") and not st.get("connecting"): + st["last_error"] = "" + st["td_reachable"] = None + return st + + +def _send_wechat_msg(content: str) -> None: + webhook = get_setting("wechat_webhook", "") + if not webhook: + return + try: + import requests + + requests.post( + webhook, + json={"msgtype": "text", "text": {"content": f"【国内期货】\n{content}"}}, + timeout=10, + ) + except Exception as exc: + logger.debug("wechat notify failed: %s", exc) + + +def _init_worker_tables(conn) -> None: + init_strategy_tables(conn) + ensure_monitor_order_columns(conn) + + +def _capital(conn) -> float: + try: + return float(get_account_capital(get_setting, conn=conn) or 0) + except Exception: + return 0.0 + + +def _persist_snapshot(mode: str) -> None: + global _last_snapshot_ts + with _snapshot_lock: + now = time.time() + if now - _last_snapshot_ts < 0.25: + return + _last_snapshot_ts = now + try: + import json + + st = _fast_status(mode) + positions = ctp_list_positions(mode, refresh_if_empty=False, refresh_margin=False) + account = ctp_get_account(mode) if st.get("connected") else {} + conn = connect_db(DB_PATH) + try: + conn.execute( + """CREATE TABLE IF NOT EXISTS ctp_worker_snapshots ( + key TEXT PRIMARY KEY, + value TEXT, + updated_at REAL + )""" + ) + for key, value in ( + ("status", st), + ("positions", positions), + ("account", account), + ): + conn.execute( + """INSERT INTO ctp_worker_snapshots(key, value, updated_at) + VALUES(?,?,?) + ON CONFLICT(key) DO UPDATE SET + value=excluded.value, + updated_at=excluded.updated_at""", + (key, json.dumps(value, ensure_ascii=False), now), + ) + commit_retry(conn) + finally: + conn.close() + except Exception as exc: + logger.debug("persist ctp snapshot: %s", exc) + + +def _on_position_refresh() -> None: + try: + _persist_snapshot(get_trading_mode(get_setting)) + except Exception as exc: + logger.debug("position refresh callback: %s", exc) + + +def _on_tick_quote() -> None: + _on_position_refresh() + + +def _on_tick_sl_tp(exchange: str, symbol: str, price: float) -> None: + mode = get_trading_mode(get_setting) + if not ctp_status(mode).get("connected"): + return + conn = connect_db(DB_PATH) + try: + _init_worker_tables(conn) + capital = _capital(conn) + n = check_sl_tp_on_tick( + conn, + mode, + exchange, + symbol, + price, + capital=capital, + notify_fn=_send_wechat_msg, + be_tick_mult=get_trailing_be_tick_buffer(get_setting), + ) + if n: + commit_retry(conn) + _persist_snapshot(mode) + except Exception as exc: + logger.warning("worker tick sl/tp: %s", exc) + finally: + conn.close() + + +def _on_ctp_connected(mode: str) -> None: + try: + with _ctp_td_lock: + get_bridge().request_position_snapshot(force=True) + get_bridge().calibrate_trading_state() + _persist_snapshot(mode) + except Exception as exc: + logger.debug("worker ctp connected callback: %s", exc) + + +def _start_background_workers() -> None: + global _started_workers + if _started_workers: + return + _started_workers = True + + set_position_refresh_callback(_on_position_refresh) + set_tick_quote_callback(_on_tick_quote) + set_tick_sl_tp_callback(_on_tick_sl_tp) + set_ctp_connected_callback(_on_ctp_connected) + + from modules.ctp.ctp_fee_worker import start_ctp_fee_worker + from modules.ctp.ctp_premarket_connect import start_ctp_premarket_connect_worker + from modules.ctp.ctp_reconnect import start_ctp_reconnect_worker + from modules.trading.order_pending import reconcile_pending_orders + from modules.trading.pending_order_worker import start_pending_order_worker + + def _mode() -> str: + return get_trading_mode(get_setting) + + start_ctp_reconnect_worker(get_mode_fn=_mode, get_setting_fn=get_setting) + start_ctp_premarket_connect_worker(get_mode_fn=_mode, get_setting_fn=get_setting) + start_ctp_fee_worker( + get_mode_fn=_mode, + get_setting_fn=get_setting, + set_setting_fn=set_setting, + ) + start_pending_order_worker( + db_path=DB_PATH, + get_mode_fn=_mode, + init_tables_fn=_init_worker_tables, + get_capital_fn=_capital, + reconcile_fn=reconcile_pending_orders, + on_changed_fn=lambda: _persist_snapshot(_mode()), + ) + start_sl_tp_guard_worker( + db_path=DB_PATH, + get_mode_fn=_mode, + init_tables_fn=_init_worker_tables, + get_capital_fn=_capital, + get_be_tick_buffer_fn=lambda: get_trailing_be_tick_buffer(get_setting), + notify_fn=_send_wechat_msg, + ) + + def _snapshot_loop() -> None: + time.sleep(3) + while True: + try: + mode = _mode() + if _fast_status(mode).get("connected"): + _persist_snapshot(mode) + except Exception as exc: + logger.debug("worker snapshot loop: %s", exc) + time.sleep(2 if is_trading_session() else 15) + + threading.Thread(target=_snapshot_loop, daemon=True, name="ctp-worker-snapshot").start() + + +@app.route("/health") +def health(): + mode = request.args.get("mode") or get_trading_mode(get_setting) + st = _fast_status(mode) + return _json_ok( + worker_online=True, + role=os.getenv("QIHUO_CTP_ROLE", "worker"), + mode=mode, + status=st, + ts=time.time(), + ) + + +@app.route("/ctp/status") +def api_status(): + mode = _mode_from_request() + return _json_ok(status=_fast_status(mode)) + + +@app.route("/ctp/connect", methods=["POST"]) +def api_connect(): + data = request.get_json(silent=True) or {} + mode = data.get("mode") or get_trading_mode(get_setting) + info = ctp_start_connect(mode, force=bool(data.get("force"))) + st = info.get("status") or _fast_status(mode) + return _json_ok(status=st, **{k: v for k, v in info.items() if k != "status"}) + + +@app.route("/ctp/start_connect", methods=["POST"]) +def api_start_connect(): + data = request.get_json(silent=True) or {} + mode = data.get("mode") or get_trading_mode(get_setting) + return _json_ok(**ctp_start_connect( + mode, + force=bool(data.get("force")), + scheduled=bool(data.get("scheduled")), + )) + + +@app.route("/ctp/disconnect", methods=["POST"]) +def api_disconnect(): + data = request.get_json(silent=True) or {} + ctp_disconnect(set_disabled_hint=bool(data.get("set_disabled_hint"))) + return _json_ok(disconnected=True) + + +@app.route("/ctp/account") +def api_account(): + mode = _mode_from_request() + if not _fast_status(mode).get("connected"): + return _json_ok(account={}) + return _json_ok(account=ctp_get_account(mode)) + + +@app.route("/ctp/positions", methods=["POST"]) +def api_positions(): + data = request.get_json(silent=True) or {} + mode = data.get("mode") or get_trading_mode(get_setting) + return _json_ok(positions=ctp_list_positions( + mode, + refresh_if_empty=bool(data.get("refresh_if_empty", True)), + refresh_margin=bool(data.get("refresh_margin", False)), + )) + + +@app.route("/ctp/trades", methods=["POST"]) +def api_trades(): + data = request.get_json(silent=True) or {} + mode = data.get("mode") or get_trading_mode(get_setting) + return _json_ok(trades=ctp_list_trades(mode, refresh=bool(data.get("refresh")))) + + +@app.route("/ctp/active_orders") +def api_active_orders(): + mode = _mode_from_request() + return _json_ok(orders=ctp_list_active_orders(mode)) + + +@app.route("/ctp/tick_price", methods=["POST"]) +def api_tick_price(): + data = request.get_json(silent=True) or {} + return _json_ok(price=ctp_get_tick_price( + data.get("mode") or get_trading_mode(get_setting), + data.get("symbol") or "", + )) + + +@app.route("/ctp/tick_detail", methods=["POST"]) +def api_tick_detail(): + data = request.get_json(silent=True) or {} + return _json_ok(detail=ctp_get_tick_detail( + data.get("mode") or get_trading_mode(get_setting), + data.get("symbol") or "", + )) + + +@app.route("/ctp/estimate_margin_one_lot", methods=["POST"]) +def api_estimate_margin(): + data = request.get_json(silent=True) or {} + return _json_ok(margin=ctp_estimate_margin_one_lot( + data.get("mode") or get_trading_mode(get_setting), + data.get("symbol") or "", + float(data.get("price") or 0), + direction=data.get("direction") or "long", + )) + + +@app.route("/ctp/contract_spec", methods=["POST"]) +def api_contract_spec(): + data = request.get_json(silent=True) or {} + return _json_ok(spec=ctp_lookup_contract_spec( + data.get("mode") or get_trading_mode(get_setting), + data.get("symbol") or "", + )) + + +@app.route("/ctp/order", methods=["POST"]) +def api_order(): + data = request.get_json(silent=True) or {} + mode = data.get("mode") or get_trading_mode(get_setting) + result = execute_order( + None, + mode=mode, + offset=data.get("offset") or "open", + symbol=data.get("symbol") or "", + direction=data.get("direction") or "long", + lots=int(data.get("lots") or 1), + price=float(data.get("price") or 0), + settings=data.get("settings") or {}, + order_type=data.get("order_type") or "limit", + ) + _persist_snapshot(mode) + return _json_ok(**result) + + +@app.route("/ctp/cancel", methods=["POST"]) +def api_cancel(): + data = request.get_json(silent=True) or {} + mode = data.get("mode") or get_trading_mode(get_setting) + cancelled = ctp_cancel_order(mode, data.get("vt_orderid") or "") + _persist_snapshot(mode) + return _json_ok(cancelled=cancelled) + + +@app.route("/ctp/bridge/", methods=["POST"]) +def api_bridge_action(action: str): + data = request.get_json(silent=True) or {} + b = get_bridge() + if action == "calibrate_trading_state": + return _json_ok(result=b.calibrate_trading_state()) + if action == "request_position_snapshot": + return _json_ok(result=b.request_position_snapshot(force=bool(data.get("force")))) + if action == "subscribe_symbol": + return _json_ok(result=b.subscribe_symbol(data.get("symbol") or "")) + if action == "refresh_positions": + return _json_ok(result=b.refresh_positions()) + if action == "connect_in_progress": + return _json_ok(result=b.connect_in_progress()) + if action == "reconnect_after_settings_saved": + mode = data.get("mode") or get_trading_mode(get_setting) + return _json_ok(result=b.reconnect_after_settings_saved(mode)) + if action == "query_all_commissions": + return _json_ok(result=b.query_all_commissions( + mode=data.get("mode") or get_trading_mode(get_setting), + )) + if action == "query_instrument_commission": + return _json_ok(result=b.query_instrument_commission( + data.get("symbol") or "", + mode=data.get("mode") or get_trading_mode(get_setting), + )) + if action == "get_kline_bars_1m": + return _json_ok(result=b.get_kline_bars_1m( + data.get("symbol") or "", + mode=data.get("mode") or get_trading_mode(get_setting), + )) + return _json_error(ValueError(f"unsupported bridge action: {action}"), status_code=404) + + +def main() -> None: + ensure_process_locale() + try_init_vnpy({}) + _start_background_workers() + host = os.getenv("QIHUO_CTP_WORKER_HOST", "127.0.0.1") + port = int(os.getenv("QIHUO_CTP_WORKER_PORT", "6601") or 6601) + logger.info("starting qihuo-ctp worker on %s:%s", host, port) + app.run(host=host, port=port, debug=False, threaded=True, use_reloader=False) + + +if __name__ == "__main__": + main() diff --git a/vnpy_bridge.py b/modules/ctp/vnpy_bridge.py similarity index 95% rename from vnpy_bridge.py rename to modules/ctp/vnpy_bridge.py index ee0b040..b037d68 100644 --- a/vnpy_bridge.py +++ b/modules/ctp/vnpy_bridge.py @@ -1,2702 +1,2706 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""CTP 执行层:模拟盘 → SimNow;实盘 → 期货公司(vnpy_ctp)。""" -from __future__ import annotations - -import logging -import os -import re -import threading -import time -from collections import deque -from typing import Any, Callable, Optional - -import ctp_ipc_client -from locale_fix import ensure_process_locale - -if ctp_ipc_client.is_worker_role(): - ensure_process_locale() - -from ctp_settings import live_setting_dict, simnow_setting_dict -from ctp_symbol import ths_to_vnpy_symbol, to_vnpy_exchange -from contract_specs import get_contract_spec - -logger = logging.getLogger(__name__) - -GATEWAY_NAME = "CTP" - -CONNECT_WAIT_SEC = 60 -CONNECT_POLL_INTERVAL_SEC = 0.5 -LOGIN_BAN_COOLDOWN_SEC = 45 * 60 -LOGIN_FAIL_COOLDOWN_SEC = 5 * 60 -CTP_COOLDOWN_UNTIL_KEY = "ctp_login_cooldown_until" -CTP_LAST_ERROR_KEY = "ctp_last_error" - - -def _use_ctp_worker_client() -> bool: - return not ctp_ipc_client.is_worker_role() - - -def _persist_login_cooldown(seconds: float) -> None: - from fee_specs import get_setting, set_setting - - new_until = time.time() + max(0.0, seconds) - try: - old = float(get_setting(CTP_COOLDOWN_UNTIL_KEY, "0") or 0) - except (TypeError, ValueError): - old = 0.0 - if new_until > old: - set_setting(CTP_COOLDOWN_UNTIL_KEY, str(new_until)) - - -def _persisted_login_cooldown_remaining() -> int: - from fee_specs import get_setting - - try: - until = float(get_setting(CTP_COOLDOWN_UNTIL_KEY, "0") or 0) - return max(0, int(until - time.time())) - except (TypeError, ValueError): - return 0 - - -def _clear_persisted_login_cooldown() -> None: - from fee_specs import set_setting - - set_setting(CTP_COOLDOWN_UNTIL_KEY, "0") - - -def _persist_last_error(msg: str) -> None: - from fee_specs import set_setting - - set_setting(CTP_LAST_ERROR_KEY, (msg or "").strip()) - - -def _load_persisted_last_error() -> str: - from fee_specs import get_setting - - return (get_setting(CTP_LAST_ERROR_KEY, "") or "").strip() - -_position_refresh_callback: Optional[Callable[[], None]] = None -_tick_sl_tp_callback: Optional[Callable[[str, str, float], None]] = None -_tick_quote_callback: Optional[Callable[[], None]] = None -_ctp_connected_callback: Optional[Callable[[str], None]] = None -_position_refresh_debounce_lock = threading.Lock() -_position_refresh_debounce_ts: float = 0.0 -_tick_quote_timer: Optional[threading.Timer] = None -_tick_quote_timer_lock = threading.Lock() -TICK_QUOTE_DEBOUNCE_SEC = 0.12 - - -def set_position_refresh_callback(fn: Optional[Callable[[], None]]) -> None: - global _position_refresh_callback - _position_refresh_callback = fn - - -def set_tick_sl_tp_callback(fn: Optional[Callable[[str, str, float], None]]) -> None: - """注册 tick 回调:exchange, symbol, last_price → 本地 SL/TP 触发。""" - global _tick_sl_tp_callback - _tick_sl_tp_callback = fn - - -def set_tick_quote_callback(fn: Optional[Callable[[], None]]) -> None: - """注册 tick 回调:推送持仓现价/浮盈(由 bridge 侧防抖)。""" - global _tick_quote_callback - _tick_quote_callback = fn - - -def _fire_tick_quote_callback_debounced() -> None: - """持仓品种 tick 后 trailing 防抖,批量推送现价/浮盈。""" - global _tick_quote_timer - - def _run() -> None: - fn = _tick_quote_callback - if not fn: - return - try: - fn() - except Exception as exc: - logger.debug("tick quote callback: %s", exc) - - with _tick_quote_timer_lock: - if _tick_quote_timer is not None: - _tick_quote_timer.cancel() - _tick_quote_timer = threading.Timer(TICK_QUOTE_DEBOUNCE_SEC, _run) - _tick_quote_timer.daemon = True - _tick_quote_timer.start() - - -def set_ctp_connected_callback(fn: Optional[Callable[[str], None]]) -> None: - """CTP 交易通道登录成功后回调(mode=simulation|live)。""" - global _ctp_connected_callback - _ctp_connected_callback = fn - - -def _fire_ctp_connected_callback(mode: str) -> None: - fn = _ctp_connected_callback - if not fn: - return - try: - threading.Thread( - target=fn, args=(mode,), daemon=True, name="ctp-connected-cb", - ).start() - except Exception as exc: - logger.debug("ctp connected callback: %s", exc) - - -def _fire_position_refresh_callback() -> None: - fn = _position_refresh_callback - if not fn: - return - try: - threading.Thread(target=fn, daemon=True, name="ctp-position-refresh").start() - except Exception as exc: - logger.debug("position refresh callback: %s", exc) - - -def _fire_position_refresh_callback_debounced(*, min_interval: float = 0.35) -> None: - global _position_refresh_debounce_ts - now = time.monotonic() - with _position_refresh_debounce_lock: - if now - _position_refresh_debounce_ts < min_interval: - return - _position_refresh_debounce_ts = now - _fire_position_refresh_callback() - - -def _fire_position_refresh_burst() -> None: - """连接后持仓回报可能分批到达,分多次触发快照刷新。""" - _fire_position_refresh_callback() - for delay in (0.4, 0.9, 1.5, 3.0, 6.0, 12.0, 20.0): - threading.Timer(delay, _fire_position_refresh_callback).start() - - -def _schedule_after_instruments_ready(bridge: "CtpBridge") -> None: - """合约查询结束后查询持仓并校准(SimNow 登录后约 10–20s)。""" - if not getattr(bridge, "_connected_mode", None): - return - now = time.monotonic() - if now - float(getattr(bridge, "_last_instruments_ready_ts", 0) or 0) < 5.0: - return - bridge._last_instruments_ready_ts = now - - def _run() -> None: - try: - if bridge._has_live_positions(): - return - bridge._ensure_instrument_margin_hooks() - with _ctp_td_lock: - bridge.request_position_snapshot(force=True) - time.sleep(0.8) - with _ctp_td_lock: - bridge.calibrate_trading_state() - _fire_position_refresh_callback() - _fire_position_refresh_burst() - n = len(bridge._collect_positions()) - logger.info("CTP 合约加载完成,持仓 %s 条,已刷新快照", n) - except Exception as exc: - logger.debug("instruments ready refresh: %s", exc) - - threading.Timer(0.4, _run).start() - - -def _schedule_position_query_retries(bridge: "CtpBridge") -> None: - def _run() -> None: - if not bridge._connected_mode or bridge._has_live_positions(): - return - try: - bridge._ensure_instrument_margin_hooks() - with _ctp_td_lock: - bridge.request_position_snapshot(force=False) - time.sleep(1.0) - with _ctp_td_lock: - bridge.calibrate_trading_state() - _fire_position_refresh_callback() - except Exception as exc: - logger.debug("position query retry: %s", exc) - - for delay in POSITION_QUERY_RETRY_DELAYS_SEC: - threading.Timer(delay, _run).start() - -_bridge: Optional["CtpBridge"] = None -_bridge_lock = threading.Lock() -_ctp_td_lock = threading.RLock() -POSITION_QUERY_MIN_INTERVAL_SEC = 5.0 -POSITION_QUERY_RETRY_DELAYS_SEC = (1.5, 4.0, 9.0, 18.0, 35.0) -TRADE_QUERY_MIN_INTERVAL_SEC = 10.0 - - -def _simnow_setting() -> dict[str, str]: - """SimNow 仿真前置(系统设置优先,.env 兜底)。""" - return simnow_setting_dict() - - -def _live_setting() -> dict[str, str]: - return live_setting_dict() - - -def _setting_for_mode(mode: str) -> dict[str, str]: - return _simnow_setting() if mode == "simulation" else _live_setting() - - -def _mode_label(mode: str) -> str: - return "SimNow" if mode == "simulation" else "期货公司实盘" - - -def _parse_tcp_address(address: str) -> tuple[str, int]: - raw = (address or "").strip() - if raw.startswith("tcp://"): - raw = raw[6:] - if ":" not in raw: - raise ValueError(f"无效 TCP 地址: {address}") - host, port_s = raw.rsplit(":", 1) - return host, int(port_s) - - -def probe_tcp_address(address: str, timeout: float = 5.0) -> tuple[bool, str]: - """探测 CTP 前置 TCP 是否可达。""" - import socket - - try: - host, port = _parse_tcp_address(address) - with socket.create_connection((host, port), timeout=timeout): - return True, "" - except Exception as exc: - return False, str(exc) - - -def _format_ctp_failure(ctp_logs: list[str], *, td_address: str = "") -> str: - """根据 CTP 网关日志拼出可读错误。""" - if td_address: - ok, err = probe_tcp_address(td_address, timeout=4.0) - if not ok: - return ( - f"SimNow 交易前置不可达:{td_address}({err})。" - "182.254.243.31 已停用,请改 .env 为官方前置 " - "tcp://180.168.146.187:10201 / 10211,并确认服务器能访问该地址。" - ) - text = "\n".join(ctp_logs) - if "连续登录失败" in text or "登录被禁止" in text or "代码:75" in text: - return ( - "CTP 登录被临时禁止:连续失败次数过多(错误码 75)。" - "请等待约 30~60 分钟后再试,先用快期确认投资者代码与密码正确,期间勿反复点「连接」。" - ) - if "4097" in text or "Decrypt handshake" in text or "shake hand" in text.lower(): - return ( - "CTP 握手失败(4097):vnpy_ctp 与 SimNow 前置加密不匹配。" - "请执行 pip install -U vnpy vnpy_ctp 后重启,并确认 .env 中 SIMNOW_ENV=实盘" - ) - if "不合法的登录" in text or "密码" in text or "账号" in text: - tail = ctp_logs[-1] if ctp_logs else "" - return f"CTP 登录被拒:{tail or '请检查投资者代码与密码(快期能否登录)'}" - if "连接断开" in text or "disconnect" in text.lower(): - tail = ctp_logs[-1] if ctp_logs else "" - return f"CTP 连接断开:{tail or '请检查前置地址与网络'}" - if ctp_logs: - return f"CTP 连接失败:{ctp_logs[-1]}" - return "CTP 连接超时:未收到柜台回报。请检查 SimNow 账号、前置地址、网络(nc 测端口),并用快期验证账号" - - -def round_to_tick(price: float, tick: float) -> float: - if tick <= 0: - return float(price) - steps = round(float(price) / tick) - return round(steps * tick, 10) - - -def _is_long_direction(direction_obj: Any) -> bool: - s = str(direction_obj or "") - return "LONG" in s.upper() or "多" in s - - -class CtpBridge: - def __init__(self) -> None: - self._engine = None - self._ee = None - self._connected_mode: Optional[str] = None - self._last_error: str = "" - self._connect_lock = threading.Lock() - self._connect_in_progress = False - self._login_cooldown_until: float = 0.0 - self._restore_persisted_state() - self._commission_waiters: dict[int, threading.Event] = {} - self._commission_lists: dict[int, list] = {} - self._commission_hooked = False - self._margin_rate_waiters: dict[int, threading.Event] = {} - self._margin_rate_lists: dict[int, list] = {} - self._margin_rate_hooked = False - self._instrument_hooked = False - self._hooks_td_api_id: Optional[int] = None - self._ctp_log_hooked = False - self._last_instruments_ready_ts: float = 0.0 - self._last_position_rsp_ts: float = 0.0 - self._instrument_margin_ratios: dict[str, dict[str, float]] = {} - self._margin_per_lot: dict[str, float] = {} - self._subscribed: set[str] = set() - self._last_position_query_ts: float = 0.0 - self._position_margins: dict[str, float] = {} - self._position_open_times: dict[str, str] = {} - self._margin_hooked = False - self._trade_hooked = False - self._trade_query_results: list[dict[str, Any]] = [] - self._trade_query_event = threading.Event() - self._last_trade_query_ts: float = 0.0 - self._last_connect_ok_ts: float = 0.0 - self._connect_started_ts: float = 0.0 - self._tick_hooked = False - self._position_hooked = False - self._order_hooked = False - self._trade_hooked = False - self._bar_generators: dict[str, Any] = {} - self._bars_1m: dict[str, deque] = {} - self._init_engine() - - def _init_engine(self) -> None: - ensure_process_locale() - try: - from vnpy.event import EventEngine - from vnpy.trader.engine import MainEngine - from vnpy_ctp import CtpGateway - - self._ee = EventEngine() - self._engine = MainEngine(self._ee) - self._engine.add_gateway(CtpGateway) - self._ensure_position_event_hook() - self._ensure_order_event_hook() - self._ensure_trade_event_hook() - self._ensure_ctp_log_hooks() - except ImportError: - self._last_error = "未安装 vnpy / vnpy_ctp,请 pip install vnpy vnpy_ctp" - except Exception as exc: - self._last_error = str(exc) - - def _ensure_position_event_hook(self) -> None: - if self._position_hooked or not self._ee: - return - try: - from vnpy.trader.event import EVENT_POSITION - except ImportError: - return - - def _on_position(event) -> None: - try: - from ctp_trading_state import trading_state - - pos = event.data - row = self._position_row_from_vnpy(pos) - if row: - sym = row.get("symbol") or "" - ex = row.get("exchange") or "" - ths = CtpBridge._vnpy_sym_to_ths(sym, ex) or sym - with _ctp_td_lock: - trades = self.list_trades() - trading_state.upsert_position( - row, notify=False, trades=trades, ths_sym=ths, - ) - sym = getattr(pos, "symbol", "") or "" - d = "long" if _is_long_direction(getattr(pos, "direction", None)) else "short" - vol = int(getattr(pos, "volume", 0) or 0) - if vol <= 0: - exchange = getattr(pos, "exchange", None) - ex_name = str(exchange.value if hasattr(exchange, "value") else exchange or "") - from ctp_trading_state import position_key - - trading_state.remove_position( - position_key(ex_name, sym, d), notify=False, - ) - else: - for attr in ("margin", "use_margin", "UseMargin"): - raw = float(getattr(pos, attr, 0) or 0) - if raw > 0: - self._position_margins[self._position_margin_key(sym, d)] = raw - if vol > 0: - self._margin_per_lot[self._position_margin_key(sym, d)] = round( - raw / vol, 2, - ) - break - except Exception as exc: - logger.debug("position margin cache: %s", exc) - _fire_position_refresh_callback() - - self._ee.register(EVENT_POSITION, _on_position) - self._position_hooked = True - - def _ensure_order_event_hook(self) -> None: - if self._order_hooked or not self._ee: - return - try: - from vnpy.trader.event import EVENT_ORDER - except ImportError: - return - - def _on_order(event) -> None: - try: - from ctp_trading_state import trading_state - - order = event.data - row = self._order_row_from_vnpy(order) - if not row: - return - status_s = str(row.get("status") or "") - terminal = any( - x in status_s - for x in ("ALLTRADED", "CANCELLED", "REJECTED", "全部成交", "已撤销", "拒单") - ) - oid = str(row.get("order_id") or row.get("vt_order_id") or "") - if terminal or int(row.get("lots") or 0) <= 0: - trading_state.remove_order(oid, notify=False) - else: - trading_state.upsert_order(row, notify=False) - except Exception as exc: - logger.debug("order event: %s", exc) - _fire_position_refresh_callback() - - self._ee.register(EVENT_ORDER, _on_order) - self._order_hooked = True - - def _ensure_trade_event_hook(self) -> None: - if self._trade_hooked or not self._ee: - return - try: - from vnpy.trader.event import EVENT_TRADE - except ImportError: - return - - def _on_trade(event) -> None: - try: - trade = event.data - row = self._trade_row_from_vnpy(trade) - if row and row.get("offset") == "open": - sym = row.get("symbol") or "" - pd = row.get("position_direction") or "long" - dt = row.get("datetime") or "" - if sym and dt: - self._position_open_times[self._position_margin_key(sym, pd)] = dt - except Exception as exc: - logger.debug("trade event: %s", exc) - _fire_position_refresh_callback() - - self._ee.register(EVENT_TRADE, _on_trade) - self._trade_hooked = True - - def _order_row_from_vnpy(self, order: Any) -> Optional[dict[str, Any]]: - try: - status = getattr(order, "status", None) - status_s = str(status) - vol = int(getattr(order, "volume", 0) or 0) - traded = int(getattr(order, "traded", 0) or 0) - remain = max(0, vol - traded) - direction = getattr(order, "direction", None) - d = "long" - if direction is not None and str(direction).endswith("SHORT"): - d = "short" - offset = getattr(order, "offset", None) - sym = getattr(order, "symbol", "") or "" - exchange = getattr(order, "exchange", None) - ex_name = str(exchange.value if hasattr(exchange, "value") else exchange or "") - vt_oid = str(getattr(order, "vt_orderid", "") or "") - order_id = str(getattr(order, "orderid", "") or "") - return { - "symbol": sym, - "exchange": ex_name, - "direction": d, - "lots": remain, - "price": float(getattr(order, "price", 0) or 0), - "offset": str(offset or ""), - "order_id": vt_oid or order_id, - "vt_order_id": vt_oid, - "status": status_s, - } - except Exception as exc: - logger.debug("order_row_from_vnpy: %s", exc) - return None - - def _position_row_from_vnpy(self, pos: Any) -> Optional[dict[str, Any]]: - try: - vol = int(getattr(pos, "volume", 0) or 0) - d = "long" if _is_long_direction(getattr(pos, "direction", None)) else "short" - sym = getattr(pos, "symbol", "") or "" - exchange = getattr(pos, "exchange", None) - ex_name = str(exchange.value if hasattr(exchange, "value") else exchange or "") - price = float(getattr(pos, "price", 0) or 0) - yd = int(getattr(pos, "yd_volume", 0) or 0) - td = max(0, vol - yd) - margin = self.estimate_position_margin(sym, ex_name, d, vol, price, pos=pos) - open_time = self._lookup_position_open_time(sym, d) or None - pnl = float(getattr(pos, "pnl", 0) or 0) - row = { - "symbol": sym, - "exchange": ex_name, - "direction": d, - "lots": vol, - "avg_price": price, - "pnl": pnl, - "frozen": int(getattr(pos, "frozen", 0) or 0), - "margin": margin, - "open_time": open_time, - "yd_volume": yd, - "td_volume": td, - } - try: - from ctp_entry_price import round_to_tick - - ths = CtpBridge._vnpy_sym_to_ths(sym, ex_name) or sym - if price > 0: - row["avg_price"] = round_to_tick(price, ths) - except Exception as exc: - logger.debug("position avg round: %s", exc) - return row - except Exception as exc: - logger.debug("position_row_from_vnpy: %s", exc) - return None - - def calibrate_trading_state(self) -> None: - """全量校准内存簿(读 vnpy 缓存,不 query 柜台)。""" - try: - from ctp_trading_state import trading_state - - with _ctp_td_lock: - orders = self.list_active_orders() - positions = self._collect_positions() - trades = self.list_trades() - preserve_margin = 0.0 - if self._connected_mode and not positions: - try: - preserve_margin = float( - ctp_account_margin_used(self._connected_mode) or 0, - ) - except Exception: - preserve_margin = 0.0 - trading_state.calibrate_from_lists( - orders, - positions, - trades=trades, - ths_for_vnpy_sym=lambda s, e: CtpBridge._vnpy_sym_to_ths(s, e) or s, - preserve_positions_if_margin=preserve_margin, - ) - except Exception as exc: - logger.debug("calibrate trading state: %s", exc) - - def available(self) -> bool: - return self._engine is not None - - @property - def last_error(self) -> str: - return self._last_error - - @property - def connected_mode(self) -> Optional[str]: - return self._connected_mode - - def connect_in_progress(self) -> bool: - return self._connect_in_progress - - def _restore_persisted_state(self) -> None: - err = _load_persisted_last_error() - if err: - self._last_error = err - db_remain = _persisted_login_cooldown_remaining() - if db_remain > 0: - self._login_cooldown_until = time.monotonic() + db_remain - - def login_cooldown_remaining(self) -> int: - """距允许再次登录的剩余秒数(内存 + 数据库,重启后仍有效)。""" - mem = max(0, int(self._login_cooldown_until - time.monotonic())) - return max(mem, _persisted_login_cooldown_remaining()) - - def _is_login_cooldown_active(self) -> bool: - return self.login_cooldown_remaining() > 0 - - def _set_login_cooldown(self, seconds: float) -> None: - until = time.monotonic() + max(0.0, seconds) - if until > self._login_cooldown_until: - self._login_cooldown_until = until - _persist_login_cooldown(seconds) - - def _clear_login_cooldown(self) -> None: - self._login_cooldown_until = 0.0 - _clear_persisted_login_cooldown() - - def _apply_login_failure_cooldown(self, ctp_logs: list[str]) -> None: - text = "\n".join(ctp_logs) - if "连续登录失败" in text or "登录被禁止" in text or "代码:75" in text: - self._set_login_cooldown(LOGIN_BAN_COOLDOWN_SEC) - elif any("登录失败" in m or "不合法的登录" in m for m in ctp_logs): - self._set_login_cooldown(LOGIN_FAIL_COOLDOWN_SEC) - - def _login_cooldown_message(self) -> str: - remain = self.login_cooldown_remaining() - return ( - f"CTP 登录冷却中,请 {remain // 60} 分 {remain % 60} 秒后再试" - f"(避免连续失败被 SimNow 封禁)" - ) - - def _close_gateway(self) -> None: - """关闭 CTP 网关,避免半连接状态下重连卡在「连接登录」。""" - if not self._engine: - return - try: - gw = self._engine.get_gateway(GATEWAY_NAME) - if gw: - gw.close() - except Exception as exc: - logger.debug("gateway close: %s", exc) - self._connected_mode = None - self._hooks_td_api_id = None - self._instrument_hooked = False - self._margin_rate_hooked = False - self._last_position_query_ts = 0.0 - self._last_instruments_ready_ts = 0.0 - try: - from ctp_trading_state import trading_state - - trading_state.clear() - except Exception: - pass - time.sleep(0.6) - - def _ensure_ctp_log_hooks(self) -> None: - """监听 vnpy 日志:合约查询成功时补触发持仓刷新(重连后 td_api 可能已换)。""" - if self._ctp_log_hooked or not self._ee: - return - try: - from vnpy.trader.event import EVENT_LOG - except ImportError: - return - bridge = self - - def _on_persistent_log(event) -> None: - try: - msg = getattr(event.data, "msg", "") or str(event.data) - if "合约信息查询成功" in str(msg): - _schedule_after_instruments_ready(bridge) - except Exception as exc: - logger.debug("ctp log hook: %s", exc) - - self._ee.register(EVENT_LOG, _on_persistent_log) - self._ctp_log_hooked = True - - def _login_rejected(self, ctp_logs: list[str]) -> bool: - return any( - kw in m - for m in ctp_logs - for kw in ("登录失败", "不合法的登录", "登录被禁止", "连续登录失败") - ) - - def _wait_connected(self, mode: str, ctp_logs: list[str] | None = None) -> bool: - """等待账户回报或交易通道登录成功。""" - if not self._engine: - return False - logs = ctp_logs or [] - loops = max(1, int(CONNECT_WAIT_SEC / CONNECT_POLL_INTERVAL_SEC)) - for _ in range(loops): - if self._login_rejected(logs): - return False - try: - if self._engine.get_all_accounts(): - return True - except Exception: - pass - if self._td_logged_in(): - return True - time.sleep(CONNECT_POLL_INTERVAL_SEC) - return False - - def status(self, mode: str) -> dict[str, Any]: - if self._connected_mode == mode: - self.ping() - st = _setting_for_mode(mode) - missing = [k for k in ("用户名", "密码", "交易服务器") if not st.get(k)] - cooldown = self.login_cooldown_remaining() - connecting = bool(self._connect_in_progress and cooldown <= 0) - last_error = self._last_error or _load_persisted_last_error() - if ( - connecting - and self._connect_started_ts > 0 - and time.time() - self._connect_started_ts > CONNECT_WAIT_SEC + 10 - and not last_error - ): - last_error = ( - f"CTP 连接进行中已超过 {CONNECT_WAIT_SEC}s," - "可能前置不可达或柜台响应慢" - ) - return { - "vnpy_installed": self.available(), - "connected": self._connected_mode == mode, - "connecting": connecting, - "connected_mode": self._connected_mode, - "mode_label": _mode_label(mode), - "missing_config": missing, - "last_error": last_error, - "login_cooldown_sec": cooldown, - "broker_id": st.get("经纪商代码", ""), - "td_address": st.get("交易服务器", ""), - } - - def connect(self, mode: str, *, force: bool = False, scheduled: bool = False) -> None: - from ctp_settings import CTP_DISABLED_HINT - - if not _ctp_connect_permitted(scheduled=scheduled): - self._last_error = CTP_DISABLED_HINT - _persist_last_error(CTP_DISABLED_HINT) - raise RuntimeError(CTP_DISABLED_HINT) - if self._connect_in_progress: - raise RuntimeError("CTP 正在连接中,请稍候") - if self._is_login_cooldown_active() and not force: - msg = self._login_cooldown_message() - self._last_error = msg - raise RuntimeError(msg) - if not self._engine: - raise RuntimeError(self._last_error or "vnpy 引擎未初始化") - if self._connected_mode == mode and not force: - if self.ping(): - return - self._connected_mode = None - setting = _setting_for_mode(mode) - if not setting.get("用户名") or not setting.get("密码"): - raise ValueError( - f"{_mode_label(mode)}:请在 .env 配置 " - f"{'SIMNOW_USER / SIMNOW_PASSWORD' if mode == 'simulation' else 'CTP_LIVE_USER / CTP_LIVE_PASSWORD'}" - ) - if not setting.get("交易服务器"): - raise ValueError(f"{_mode_label(mode)}:未配置交易服务器地址") - - self._connect_in_progress = True - self._connect_started_ts = time.time() - try: - with _ctp_td_lock: - with self._connect_lock: - if force and self._connected_mode: - self._close_gateway() - elif self._connected_mode and self._connected_mode != mode: - try: - self._engine.close() - except Exception: - pass - self._connected_mode = None - time.sleep(1) - elif not (self._connected_mode == mode and self.ping()): - self._close_gateway() - - ctp_logs: list[str] = [] - from vnpy.trader.event import EVENT_LOG - - def _on_log(event) -> None: - msg = getattr(event.data, "msg", "") or str(event.data) - if msg: - ctp_logs.append(str(msg)) - if len(ctp_logs) > 40: - ctp_logs.pop(0) - logger.info("CTP | %s", msg) - - self._ee.register(EVENT_LOG, _on_log) - try: - ensure_process_locale() - logger.info( - "CTP 连接 [%s] user=%s td=%s env=%s", - mode, - setting.get("用户名"), - setting.get("交易服务器"), - setting.get("柜台环境", "实盘"), - ) - td_addr = setting.get("交易服务器", "") - ok, err = probe_tcp_address(td_addr, timeout=5.0) - if not ok: - raise RuntimeError( - f"SimNow 交易前置不可达:{td_addr}({err})。" - "请更新 .env 中 SIMNOW_TD_ADDRESS 为官网最新地址," - "并在服务器执行 nc -zv 验证出网。" - ) - self._ensure_instrument_margin_hooks() - self._engine.connect(setting, GATEWAY_NAME) - if self._wait_connected(mode, ctp_logs): - self._connected_mode = mode - self._last_connect_ok_ts = time.time() - self._last_error = "" - _persist_last_error("") - self._clear_login_cooldown() - logger.info("CTP 已连接 [%s] td_login=%s accounts=%s", - mode, self._td_logged_in(), - len(self._engine.get_all_accounts() or [])) - self._schedule_fee_sync(mode) - try: - self.calibrate_trading_state() - except Exception as exc: - logger.debug("post-connect calibrate: %s", exc) - try: - self.request_position_snapshot(force=True) - except Exception as exc: - logger.debug("post-connect position query: %s", exc) - self._ensure_instrument_margin_hooks() - _fire_position_refresh_burst() - _schedule_position_query_retries(self) - _fire_ctp_connected_callback(mode) - return - finally: - self._ee.unregister(EVENT_LOG, _on_log) - - self._close_gateway() - self._apply_login_failure_cooldown(ctp_logs) - hint = _format_ctp_failure(ctp_logs, td_address=setting.get("交易服务器", "")) - self._last_error = hint - _persist_last_error(hint) - logger.warning("CTP 连接失败 [%s]: %s | logs=%s", mode, hint, ctp_logs[-5:]) - raise RuntimeError(hint) - finally: - self._connect_in_progress = False - self._connect_started_ts = 0.0 - - def start_connect_async( - self, mode: str, *, force: bool = False, scheduled: bool = False, - ) -> dict[str, Any]: - """后台连接,不阻塞 HTTP 请求。""" - from ctp_settings import CTP_DISABLED_HINT - - if not _ctp_connect_permitted(scheduled=scheduled): - self._last_error = CTP_DISABLED_HINT - _persist_last_error(CTP_DISABLED_HINT) - return { - "started": False, - "connecting": False, - "connected": False, - "disabled": True, - "error": CTP_DISABLED_HINT, - } - if self._connected_mode == mode and self.ping() and not force: - return {"started": False, "connecting": False, "connected": True} - if self._connect_in_progress: - return {"started": False, "connecting": True, "connected": False} - if self._is_login_cooldown_active() and not force: - self._last_error = self._login_cooldown_message() - return { - "started": False, - "connecting": False, - "connected": False, - "cooldown": True, - } - - def _run() -> None: - try: - self.connect(mode, force=force, scheduled=scheduled) - except Exception as exc: - logger.warning("CTP 后台连接失败: %s", exc) - - def _watchdog() -> None: - deadline = CONNECT_WAIT_SEC + 25 - time.sleep(deadline) - if not self._connect_in_progress: - return - logger.warning( - "CTP 连接 watchdog 超时 %.0fs,重置连接状态 [%s]", - deadline, - mode, - ) - self._connect_in_progress = False - self._connect_started_ts = 0.0 - hint = ( - f"CTP 连接超时(>{deadline:.0f}s),可能前置不可达或柜台无响应。" - "请检查 SimNow 前置地址与账号,勿频繁重试。" - ) - self._last_error = hint - _persist_last_error(hint) - try: - self._close_gateway() - except Exception as exc: - logger.debug("watchdog gateway close: %s", exc) - - threading.Thread(target=_run, daemon=True, name="ctp-connect-async").start() - threading.Thread(target=_watchdog, daemon=True, name="ctp-connect-watchdog").start() - return {"started": True, "connecting": True, "connected": False} - - def ensure_connected(self, mode: str) -> None: - if self._connected_mode == mode and self.ping(): - return - if self._connect_in_progress: - raise RuntimeError("CTP 连接中,请稍候") - raise RuntimeError("请先连接 CTP") - - def require_connected(self, mode: str) -> None: - """报单前检查:须已连接,不在此发起阻塞式 connect。""" - if self._connect_in_progress: - raise RuntimeError("CTP 连接中,请稍候再下单") - if self._connected_mode != mode or not self.ping(): - raise RuntimeError("请先连接 CTP(持仓监控页点击「连接 CTP」)") - if not self._td_logged_in(): - raise RuntimeError("CTP 交易通道未登录,请重连 CTP 后再下单") - - def _td_logged_in(self) -> bool: - try: - gw = self._engine.get_gateway(GATEWAY_NAME) - td = gw.td_api - return bool(getattr(td, "login_status", False)) - except Exception: - return False - - def _find_position(self, sym: str, ex_name: str, hold_direction: str) -> Any: - if not self._engine: - return None - sym_l = sym.lower() - ex_u = ex_name.upper() - want_long = hold_direction == "long" - try: - for pos in self._engine.get_all_positions(): - ps = (getattr(pos, "symbol", "") or "").lower() - pe = getattr(pos, "exchange", None) - pe_s = str(pe.value if hasattr(pe, "value") else pe or "").upper() - if ps != sym_l or pe_s != ex_u: - continue - vol = int(getattr(pos, "volume", 0) or 0) - if vol <= 0: - continue - is_long = _is_long_direction(getattr(pos, "direction", None)) - if is_long == want_long: - return pos - except Exception as exc: - logger.debug("find position: %s", exc) - return None - - def _resolve_close_offset(self, sym: str, ex_name: str, hold_direction: str, lots: int) -> Any: - from vnpy.trader.constant import Offset - - ex_u = (ex_name or "").upper() - # 上期所/能源中心/郑商所/中金所须区分平今/平昨;大商所等可用通用 CLOSE - if ex_u not in ("CZCE", "CFFEX", "SHFE", "INE"): - return Offset.CLOSE - pos = self._find_position(sym, ex_u, hold_direction) - if not pos: - for p in self._collect_positions(): - ps = (p.get("symbol") or "").lower() - if ps != sym.lower(): - continue - if (p.get("direction") or "long") != hold_direction: - continue - td = int(p.get("td_volume") or 0) - yd = int(p.get("yd_volume") or 0) - if td >= lots: - return Offset.CLOSETODAY - if yd >= lots: - return Offset.CLOSEYESTERDAY - if td + yd >= lots: - return Offset.CLOSETODAY - break - if ex_u in ("SHFE", "INE", "CZCE"): - return Offset.CLOSETODAY - return Offset.CLOSE - vol = int(getattr(pos, "volume", 0) or 0) - yd = int(getattr(pos, "yd_volume", 0) or 0) - today = max(0, vol - yd) - if today >= lots: - return Offset.CLOSETODAY - return Offset.CLOSEYESTERDAY - - def _aggressive_limit_price( - self, - ths_code: str, - sym: str, - ex_name: str, - direction: Any, - tick: float, - fallback: float, - ) -> float: - from vnpy.trader.constant import Direction - - self.subscribe_symbol(ths_code) - lp = fallback - detail = self.get_tick_detail(ths_code, mode=self._connected_mode or "") - if detail.get("price"): - lp = float(detail["price"]) - slip = max(tick, tick * 3) - if direction == Direction.LONG: - lp = lp + slip - else: - lp = max(tick, lp - slip) - return round_to_tick(lp, tick) - - def ping(self) -> bool: - """检测连接是否仍有效;无效则清除 connected 状态。""" - if not self._engine or not self._connected_mode: - return False - if self._td_logged_in(): - return True - try: - if self._engine.get_all_accounts(): - return True - except Exception as exc: - logger.debug("CTP ping failed: %s", exc) - self._connected_mode = None - return False - - def mark_disconnected(self) -> None: - self._connected_mode = None - - def reconnect_after_settings_saved(self, mode: str) -> dict[str, Any]: - """保存前置/账号后关闭旧连接,并用数据库中的新配置重连。""" - from ctp_settings import is_ctp_auto_connect_enabled - - self._close_gateway() - self._last_error = "" - _persist_last_error("") - if not is_ctp_auto_connect_enabled(): - return {"started": False, "connecting": False, "connected": False, "disabled": True} - return self.start_connect_async(mode, force=True) - - def _schedule_fee_sync(self, mode: str) -> None: - """连接成功后触发每日同步检查(非每次全量)。""" - - def _run() -> None: - time.sleep(45) - try: - from ctp_fee_worker import try_daily_ctp_fee_sync - - def _gs(key: str, default: str = "") -> str: - from fee_specs import get_setting - return get_setting(key, default) - - def _ss(key: str, val: str) -> None: - from fee_specs import set_setting - set_setting(key, val) - - try_daily_ctp_fee_sync( - mode, - get_setting=_gs, - set_setting=_ss, - force=False, - ) - except Exception as exc: - logger.debug("CTP 手续费连接后检查: %s", exc) - - threading.Thread(target=_run, daemon=True, name="ctp-fee-sync-check").start() - - def _ensure_commission_callback(self) -> None: - if self._commission_hooked or not self._engine: - return - try: - gw = self._engine.get_gateway(GATEWAY_NAME) - td = gw.td_api - except Exception: - return - bridge = self - - def on_rsp(data: dict, error: dict, reqid: int, last: bool) -> None: - if error and int(error.get("ErrorID") or 0) != 0: - logger.debug( - "CTP commission error reqid=%s: %s", - reqid, - error.get("ErrorMsg") or error, - ) - if data and data.get("InstrumentID"): - bridge._commission_lists.setdefault(reqid, []).append(dict(data)) - ev = bridge._commission_waiters.get(reqid) - if last and ev: - ev.set() - - td.onRspQryInstrumentCommissionRate = on_rsp # type: ignore[method-assign] - self._commission_hooked = True - - def _query_commission( - self, - *, - mode: str, - instrument_id: str = "", - exchange_id: str = "", - timeout: float = 8, - ) -> list[dict]: - if self._connected_mode != mode or not self._engine: - return [] - try: - gw = self._engine.get_gateway(GATEWAY_NAME) - td = gw.td_api - except Exception as exc: - logger.debug("commission query init: %s", exc) - return [] - if not getattr(td, "login_status", False): - return [] - if not hasattr(td, "reqQryInstrumentCommissionRate"): - return [] - self._ensure_commission_callback() - reqid = int(getattr(td, "reqid", 0)) + 1 - td.reqid = reqid - ev = threading.Event() - self._commission_waiters[reqid] = ev - req = { - "BrokerID": td.brokerid, - "InvestorID": td.userid, - "InstrumentID": instrument_id or "", - "ExchangeID": exchange_id or "", - } - ret = td.reqQryInstrumentCommissionRate(req, reqid) - if ret != 0: - self._commission_waiters.pop(reqid, None) - return [] - ev.wait(timeout=timeout) - self._commission_waiters.pop(reqid, None) - return self._commission_lists.pop(reqid, []) - - def query_instrument_commission(self, ths_code: str, *, mode: str) -> dict: - """查询单合约 CTP 手续费率(需已连接)。""" - try: - sym, ex_name = ths_to_vnpy_symbol(ths_code) - except Exception: - return {} - rows = self._query_commission( - mode=mode, - instrument_id=sym, - exchange_id=ex_name, - ) - return rows[-1] if rows else {} - - def query_all_commissions(self, *, mode: str) -> list[dict]: - """批量查询全部合约手续费(InstrumentID 留空)。""" - return self._query_commission(mode=mode, timeout=45) - - @staticmethod - def _parse_margin_ratio_row(data: dict) -> dict[str, float]: - long_r = float( - data.get("LongMarginRatioByMoney") - or data.get("LongMarginRatio") - or 0 - ) - short_r = float( - data.get("ShortMarginRatioByMoney") - or data.get("ShortMarginRatio") - or 0 - ) - return {"long": long_r, "short": short_r} - - def _cache_margin_ratio(self, sym: str, data: dict) -> None: - ratios = self._parse_margin_ratio_row(data) - if ratios["long"] <= 0 and ratios["short"] <= 0: - return - key = (sym or "").strip().lower() - if not key: - return - self._instrument_margin_ratios[key] = ratios - - def _ensure_instrument_margin_hooks(self) -> None: - """登录前挂钩:合约/持仓查询回报;td_api 重建后须重新挂钩。""" - if not self._engine: - return - try: - gw = self._engine.get_gateway(GATEWAY_NAME) - td = gw.td_api - except Exception: - return - bridge = self - td_id = id(td) - if td_id != self._hooks_td_api_id: - self._hooks_td_api_id = td_id - self._instrument_hooked = False - self._margin_rate_hooked = False - - if not self._instrument_hooked: - orig_inst = td.onRspQryInstrument - - def on_instrument(data: dict, error: dict, reqid: int, last: bool) -> None: - try: - if data and data.get("InstrumentID"): - bridge._cache_margin_ratio(str(data["InstrumentID"]), data) - except Exception as exc: - logger.debug("instrument margin cache: %s", exc) - if last: - _schedule_after_instruments_ready(bridge) - return orig_inst(data, error, reqid, last) - - td.onRspQryInstrument = on_instrument # type: ignore[method-assign] - - orig_pos = td.onRspQryInvestorPosition - - def on_rsp_position( - data: dict, error: dict, reqid: int, last: bool, - ) -> None: - ret = orig_pos(data, error, reqid, last) - if last: - now = time.monotonic() - if now - bridge._last_position_rsp_ts < 30.0: - return ret - bridge._last_position_rsp_ts = now - - def _after_position_query() -> None: - try: - time.sleep(1.5) - with _ctp_td_lock: - bridge.calibrate_trading_state() - _fire_position_refresh_callback() - except Exception as exc: - logger.debug("position rsp refresh: %s", exc) - - threading.Timer(0.2, _after_position_query).start() - return ret - - td.onRspQryInvestorPosition = on_rsp_position # type: ignore[method-assign] - self._instrument_hooked = True - - if self._margin_rate_hooked: - return - - def on_margin_rate(data: dict, error: dict, reqid: int, last: bool) -> None: - if error and int(error.get("ErrorID") or 0) != 0: - logger.debug( - "CTP margin rate error reqid=%s: %s", - reqid, - error.get("ErrorMsg") or error, - ) - if data and data.get("InstrumentID"): - bridge._margin_rate_lists.setdefault(reqid, []).append(dict(data)) - ev = bridge._margin_rate_waiters.get(reqid) - if last and ev: - ev.set() - - td.onRspQryInstrumentMarginRate = on_margin_rate # type: ignore[method-assign] - self._margin_rate_hooked = True - - def _query_instrument_margin_rate( - self, - *, - mode: str, - instrument_id: str, - exchange_id: str, - timeout: float = 6, - ) -> Optional[dict[str, float]]: - if self._connected_mode != mode or not self._engine: - return None - sym = (instrument_id or "").strip() - if not sym: - return None - cached = self._instrument_margin_ratios.get(sym.lower()) - if cached and (cached.get("long", 0) > 0 or cached.get("short", 0) > 0): - return cached - try: - gw = self._engine.get_gateway(GATEWAY_NAME) - td = gw.td_api - except Exception as exc: - logger.debug("margin rate query init: %s", exc) - return None - if not getattr(td, "login_status", False): - return None - if not hasattr(td, "reqQryInstrumentMarginRate"): - return None - self._ensure_instrument_margin_hooks() - reqid = int(getattr(td, "reqid", 0)) + 1 - td.reqid = reqid - ev = threading.Event() - self._margin_rate_waiters[reqid] = ev - req = { - "BrokerID": td.brokerid, - "InvestorID": td.userid, - "InstrumentID": sym, - "ExchangeID": exchange_id or "", - "InvestorRange": "1", - "HedgeFlag": "1", - } - with _ctp_td_lock: - ret = td.reqQryInstrumentMarginRate(req, reqid) - if ret != 0: - self._margin_rate_waiters.pop(reqid, None) - return None - ev.wait(timeout=timeout) - self._margin_rate_waiters.pop(reqid, None) - rows = self._margin_rate_lists.pop(reqid, []) - if not rows: - return None - ratios = self._parse_margin_ratio_row(rows[-1]) - if ratios["long"] > 0 or ratios["short"] > 0: - self._cache_margin_ratio(sym, rows[-1]) - return ratios - return None - - def _lookup_margin_ratios( - self, - sym: str, - ex_name: str, - *, - mode: Optional[str] = None, - ) -> Optional[dict[str, float]]: - key = (sym or "").strip().lower() - if not key: - return None - cached = self._instrument_margin_ratios.get(key) - if cached and (cached.get("long", 0) > 0 or cached.get("short", 0) > 0): - return cached - if mode and self._connected_mode == mode: - return self._query_instrument_margin_rate( - mode=mode, - instrument_id=sym, - exchange_id=ex_name, - ) - return None - - def _lookup_margin_per_lot(self, sym: str, direction: str) -> float: - return float( - self._margin_per_lot.get(self._position_margin_key(sym, direction), 0) or 0 - ) - - def _margin_from_ratios( - self, - price: float, - mult: float, - ratios: dict[str, float], - *, - direction: str, - ) -> Optional[float]: - long_r = float(ratios.get("long") or 0) - short_r = float(ratios.get("short") or 0) - d = (direction or "long").strip().lower() - if mult <= 0 or price <= 0: - return None - if d == "max": - candidates = [ - round(float(price) * mult * r, 2) - for r in (long_r, short_r) - if r > 0 - ] - return max(candidates) if candidates else None - if d == "short" and short_r > 0: - ratio = short_r - elif d != "short" and long_r > 0: - ratio = long_r - else: - ratio = max(long_r, short_r) - if ratio <= 0: - return None - return round(float(price) * mult * ratio, 2) - - def _tick_key(self, symbol: str, ex_name: str) -> str: - return f"{symbol.lower()}:{ex_name.upper()}" - - def _price_from_tick(self, tick: Any) -> Optional[float]: - for attr in ("last_price", "bid_price_1", "ask_price_1", "pre_close"): - try: - v = float(getattr(tick, attr, 0) or 0) - except (TypeError, ValueError): - v = 0.0 - if v > 0: - return v - return None - - def _lookup_tick(self, symbol: str, ex_name: str) -> Optional[float]: - if not self._engine: - return None - sym_l = symbol.lower() - ex_u = ex_name.upper() - try: - for tick in self._engine.get_all_ticks(): - ts = (getattr(tick, "symbol", "") or "").lower() - te = getattr(tick, "exchange", None) - te_s = str(te.value if hasattr(te, "value") else te or "").upper() - if ts == sym_l and te_s == ex_u: - p = self._price_from_tick(tick) - if p: - return p - except Exception as exc: - logger.debug("lookup tick: %s", exc) - return None - - def _bar_to_dict(self, bar: Any) -> dict: - dt = getattr(bar, "datetime", None) - d_str = dt.strftime("%Y-%m-%d %H:%M:%S") if dt else "" - return { - "d": d_str, - "o": float(getattr(bar, "open_price", 0) or 0), - "h": float(getattr(bar, "high_price", 0) or 0), - "l": float(getattr(bar, "low_price", 0) or 0), - "c": float(getattr(bar, "close_price", 0) or 0), - "v": float(getattr(bar, "volume", 0) or 0), - } - - def _ensure_bar_generator(self, sym: str, ex_name: str) -> None: - key = self._tick_key(sym, ex_name) - if key in self._bar_generators: - return - self._bars_1m[key] = deque(maxlen=4000) - - def on_bar(bar: Any) -> None: - row = self._bar_to_dict(bar) - if row.get("d"): - self._bars_1m[key].append(row) - - try: - from vnpy.trader.utility import BarGenerator - - self._bar_generators[key] = BarGenerator(on_bar=on_bar) - except ImportError: - logger.debug("BarGenerator unavailable") - - def _find_tick(self, symbol: str, ex_name: str) -> Any: - if not self._engine: - return None - sym_l = symbol.lower() - ex_u = ex_name.upper() - try: - for tick in self._engine.get_all_ticks(): - ts = (getattr(tick, "symbol", "") or "").lower() - te = getattr(tick, "exchange", None) - te_s = str(te.value if hasattr(te, "value") else te or "").upper() - if ts == sym_l and te_s == ex_u: - return tick - except Exception as exc: - logger.debug("find tick: %s", exc) - return None - - def _tick_to_bar(self, symbol: str, ex_name: str) -> Optional[dict]: - tick = self._find_tick(symbol, ex_name) - if not tick: - return None - lp = self._price_from_tick(tick) - if not lp or lp <= 0: - return None - dt = getattr(tick, "datetime", None) - d_str = dt.strftime("%Y-%m-%d %H:%M:%S") if dt else "" - if not d_str: - from datetime import datetime - from zoneinfo import ZoneInfo - - d_str = datetime.now(ZoneInfo("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S") - o = float(getattr(tick, "open_price", 0) or lp) - h = float(getattr(tick, "high_price", 0) or lp) - lo = float(getattr(tick, "low_price", 0) or lp) - return { - "d": d_str, - "o": o, - "h": h, - "l": lo, - "c": lp, - "v": float(getattr(tick, "volume", 0) or 0), - } - - def _on_tick(self, tick: Any) -> None: - sym = (getattr(tick, "symbol", "") or "").lower() - te = getattr(tick, "exchange", None) - ex_s = str(te.value if hasattr(te, "value") else te or "").upper() - price = self._price_from_tick(tick) - if price and price > 0: - try: - from ctp_trading_state import trading_state - - trading_state.set_tick_price(ex_s, sym, price) - except Exception: - pass - fn = _tick_sl_tp_callback - if fn: - try: - fn(ex_s, sym, float(price)) - except Exception as exc: - logger.debug("tick sl/tp callback: %s", exc) - _fire_tick_quote_callback_debounced() - key = self._tick_key(sym, ex_s) - bg = self._bar_generators.get(key) - if not bg: - return - try: - bg.update_tick(tick) - except Exception as exc: - logger.debug("bar gen tick: %s", exc) - - def _ensure_tick_handler(self) -> None: - if self._tick_hooked or not self._ee: - return - try: - from vnpy.trader.event import EVENT_TICK - except ImportError: - return - - def process_tick(event: Any) -> None: - self._on_tick(event.data) - - self._ee.register(EVENT_TICK, process_tick) - self._tick_hooked = True - - def get_kline_bars_1m(self, ths_code: str, *, mode: str) -> list[dict]: - """订阅合约并返回 1 分钟 K 线(含正在形成的 bar)。""" - if self._connected_mode != mode or not self._engine: - return [] - try: - sym, ex_name = ths_to_vnpy_symbol(ths_code) - except Exception: - return [] - key = self._tick_key(sym, ex_name) - self._ensure_bar_generator(sym, ex_name) - self.subscribe_symbol(ths_code) - for _ in range(12): - if self._bars_1m.get(key) and len(self._bars_1m[key]) > 0: - break - if self._lookup_tick(sym, ex_name): - break - time.sleep(0.2) - bars_1m = list(self._bars_1m.get(key, [])) - bg = self._bar_generators.get(key) - if bg and getattr(bg, "bar", None): - forming = self._bar_to_dict(bg.bar) - if forming.get("d"): - if not bars_1m or bars_1m[-1]["d"] != forming["d"]: - bars_1m.append(forming) - else: - bars_1m[-1] = forming - if not bars_1m: - tick_bar = self._tick_to_bar(sym, ex_name) - if tick_bar: - bars_1m = [tick_bar] - return bars_1m - - def get_tick_detail(self, ths_code: str, *, mode: str) -> dict[str, Any]: - if self._connected_mode != mode or not self._engine: - return {} - try: - sym, ex_name = ths_to_vnpy_symbol(ths_code) - except Exception: - return {} - self.subscribe_symbol(ths_code) - for _ in range(8): - tick = self._find_tick(sym, ex_name) - if tick: - price = self._price_from_tick(tick) - try: - pre_close = float(getattr(tick, "pre_close", 0) or 0) - except (TypeError, ValueError): - pre_close = 0.0 - return { - "price": price, - "pre_close": pre_close if pre_close > 0 else None, - } - time.sleep(0.2) - return {} - - def subscribe_symbol(self, ths_code: str) -> None: - if not self._engine or not self._connected_mode: - return - try: - from vnpy.trader.object import SubscribeRequest - - sym, ex_name = ths_to_vnpy_symbol(ths_code) - key = self._tick_key(sym, ex_name) - self._ensure_bar_generator(sym, ex_name) - if key in self._subscribed: - return - exchange = to_vnpy_exchange(ex_name) - self._ensure_tick_handler() - req = SubscribeRequest(symbol=sym, exchange=exchange) - self._engine.subscribe(req, GATEWAY_NAME) - self._subscribed.add(key) - except Exception as exc: - logger.debug("CTP subscribe %s: %s", ths_code, exc) - - def get_tick_price(self, ths_code: str, *, mode: str) -> Optional[float]: - if self._connected_mode != mode or not self._engine: - return None - try: - sym, ex_name = ths_to_vnpy_symbol(ths_code) - except Exception: - return None - price = self._lookup_tick(sym, ex_name) - if price: - return price - self.subscribe_symbol(ths_code) - for _ in range(8): - time.sleep(0.2) - price = self._lookup_tick(sym, ex_name) - if price: - return price - return None - - def get_account(self) -> dict[str, Any]: - if not self._engine: - return {} - accounts = self._engine.get_all_accounts() - if not accounts: - return {} - acc = accounts[0] - return { - "balance": float(getattr(acc, "balance", 0) or 0), - "available": float(getattr(acc, "available", 0) or 0), - "frozen": float(getattr(acc, "frozen", 0) or 0), - "accountid": getattr(acc, "accountid", ""), - } - - def _position_margin_key(self, sym: str, direction: str) -> str: - return f"{(sym or '').lower()}:{(direction or 'long').strip().lower()}" - - def _lookup_position_open_time(self, sym: str, direction: str) -> str: - return (self._position_open_times.get(self._position_margin_key(sym, direction)) or "").strip() - - @staticmethod - def _parse_ctp_open_datetime(date_raw: str, time_raw: str = "") -> str: - """CTP OpenDate + OpenTime → YYYY-MM-DD HH:MM[:SS]。""" - d = (date_raw or "").strip() - if len(d) >= 8 and d[:8].isdigit(): - date_part = f"{d[:4]}-{d[4:6]}-{d[6:8]}" - else: - return "" - t = (time_raw or "").strip().replace(":", "") - if len(t) >= 6 and t[:6].isdigit(): - return f"{date_part} {t[0:2]}:{t[2:4]}:{t[4:6]}" - if len(t) >= 4 and t.isdigit(): - return f"{date_part} {t[0:2]}:{t[2:4]}" - return date_part - - def _parse_ctp_open_date(raw: str) -> str: - return CtpBridge._parse_ctp_open_datetime(raw, "") - - def _install_position_margin_hook(self) -> None: - """已禁用:monkey-patch CTP 持仓回调在并发下会触发 vnctptd 段错误。""" - return - - def _lookup_position_margin(self, sym: str, direction: str) -> float: - return float(self._position_margins.get(self._position_margin_key(sym, direction), 0) or 0) - - @staticmethod - def _vnpy_sym_to_ths(sym: str, ex_name: str) -> str: - import re - - s = (sym or "").strip() - if not s: - return "" - ex = (ex_name or "").upper() - m = re.match(r"^([A-Za-z]+)(\d+)$", s) - if not m: - return s - letters, digits = m.group(1), m.group(2) - if ex == "CZCE": - return letters.upper() + (digits[-3:] if len(digits) >= 4 else digits) - return letters.lower() + digits - - def _get_contract_for_ths(self, ths_code: str) -> Any: - """按同花顺代码查 CTP 合约;精确匹配失败时在同交易所按品种前缀回退。""" - if not self._engine: - return None - try: - sym, ex_name = ths_to_vnpy_symbol(ths_code) - exchange = to_vnpy_exchange(ex_name) - vt_symbol = f"{sym}.{exchange.value}" - contract = self._engine.get_contract(vt_symbol) - if contract: - return contract - m = re.match(r"^([A-Za-z]+)", (ths_code or "").strip()) - if not m: - return None - letters = m.group(1) - ex_val = exchange.value - candidates: list[Any] = [] - get_all = getattr(self._engine, "get_all_contracts", None) - pool = list(get_all()) if callable(get_all) else [] - if not pool: - raw = getattr(self._engine, "contracts", None) - if isinstance(raw, dict): - pool = list(raw.values()) - sym_prefix = sym[: len(letters)] if sym else letters.lower() - sym_prefix_up = letters.upper() - for c in pool: - c_ex = getattr(c, "exchange", None) - c_ex_val = str(c_ex.value if hasattr(c_ex, "value") else c_ex or "") - if c_ex_val != ex_val: - continue - c_sym = str(getattr(c, "symbol", "") or "") - if ( - c_sym.lower().startswith(sym_prefix.lower()) - or c_sym.upper().startswith(sym_prefix_up) - ): - candidates.append(c) - if not candidates: - return None - candidates.sort(key=lambda c: str(getattr(c, "symbol", "") or "")) - return candidates[0] - except Exception as exc: - logger.debug("_get_contract_for_ths %s: %s", ths_code, exc) - return None - - def estimate_margin_one_lot( - self, - ths_code: str, - price: float, - *, - direction: str = "long", - ) -> Optional[float]: - """1 手保证金:持仓实收 > CTP 保证金率查询 > 合约查询缓存。""" - if not self._engine or not price or price <= 0: - return None - try: - sym, ex_name = ths_to_vnpy_symbol(ths_code) - contract = self._get_contract_for_ths(ths_code) - mult = float(getattr(contract, "size", 0) or 0) if contract else 0.0 - if mult <= 0: - mult = float(get_contract_spec(ths_code).get("mult") or 0) - d = (direction or "long").strip().lower() - if d == "max": - per_lots = [ - self._lookup_margin_per_lot(sym, side) - for side in ("long", "short") - ] - per_lots = [x for x in per_lots if x > 0] - if per_lots: - return max(per_lots) - else: - per_lot = self._lookup_margin_per_lot(sym, d) - if per_lot > 0: - return per_lot - mode = self._connected_mode - ratios = self._lookup_margin_ratios(sym, ex_name, mode=mode) - if ratios: - return self._margin_from_ratios( - price, mult, ratios, direction=d, - ) - return None - except Exception as exc: - logger.debug("estimate_margin_one_lot %s: %s", ths_code, exc) - return None - - def estimate_position_margin( - self, - sym: str, - ex_name: str, - direction: str, - lots: int, - price: float, - *, - pos: Any = None, - ) -> Optional[float]: - """持仓占用保证金:优先 vnpy 字段,其次 CTP 合约保证金率估算。""" - if lots <= 0 or price <= 0: - return None - if pos is not None: - raw = float(getattr(pos, "margin", 0) or getattr(pos, "use_margin", 0) or 0) - if raw > 0: - return round(raw, 2) - cached = self._lookup_position_margin(sym, direction) - if cached > 0: - return round(cached, 2) - ths = self._vnpy_sym_to_ths(sym, ex_name) - if not ths: - return None - per_lot = self.estimate_margin_one_lot(ths, price, direction=direction) - if per_lot and per_lot > 0: - return round(per_lot * lots, 2) - return None - - def lookup_contract_spec(self, ths_code: str) -> Optional[dict]: - """从 CTP 合约信息读取乘数与最小变动价位。""" - if not self._engine: - return None - try: - sym, ex_name = ths_to_vnpy_symbol(ths_code) - contract = self._get_contract_for_ths(ths_code) - if not contract: - return None - mult = float(getattr(contract, "size", 0) or 0) - tick = float( - getattr(contract, "pricetick", 0) - or getattr(contract, "price_tick", 0) - or 0 - ) - if mult <= 0: - return None - out: dict[str, Any] = {"mult": mult} - if tick > 0: - out["tick_size"] = tick - long_r = float(getattr(contract, "long_margin_ratio", 0) or 0) - short_r = float(getattr(contract, "short_margin_ratio", 0) or 0) - c_sym = str(getattr(contract, "symbol", "") or sym or "") - if c_sym and self._connected_mode: - queried = self._lookup_margin_ratios( - c_sym, ex_name, mode=self._connected_mode, - ) - if queried: - long_r = float(queried.get("long") or long_r) - short_r = float(queried.get("short") or short_r) - if long_r > 0 or short_r > 0: - out["margin_rate"] = max(long_r, short_r) - return out - except Exception as exc: - logger.debug("lookup_contract_spec %s: %s", ths_code, exc) - return None - - def _collect_positions(self) -> list[dict[str, Any]]: - if not self._engine: - return [] - out: list[dict[str, Any]] = [] - for pos in self._engine.get_all_positions(): - vol = int(getattr(pos, "volume", 0) or 0) - if vol <= 0: - continue - d = "long" if _is_long_direction(getattr(pos, "direction", None)) else "short" - sym = getattr(pos, "symbol", "") or "" - exchange = getattr(pos, "exchange", None) - ex_name = str(exchange.value if hasattr(exchange, "value") else exchange or "") - price = float(getattr(pos, "price", 0) or 0) - margin = self.estimate_position_margin( - sym, ex_name, d, vol, price, pos=pos, - ) - open_time = self._lookup_position_open_time(sym, d) or None - yd = int(getattr(pos, "yd_volume", 0) or 0) - td = max(0, vol - yd) - out.append({ - "symbol": sym, - "exchange": ex_name, - "direction": d, - "lots": vol, - "avg_price": price, - "pnl": float(getattr(pos, "pnl", 0) or 0), - "frozen": int(getattr(pos, "frozen", 0) or 0), - "margin": margin, - "open_time": open_time, - "yd_volume": yd, - "td_volume": td, - }) - return out - - def refresh_positions(self) -> None: - """vnpy 内存缓存持仓;禁止 query_position(vnctptd 并发查询会段错误)。""" - return - - def _has_live_positions(self) -> bool: - if not self._engine: - return False - try: - with _ctp_td_lock: - return len(self._collect_positions()) > 0 - except Exception: - return False - - def request_position_snapshot(self, *, force: bool = False) -> None: - """合约加载后查询持仓,填充 vnpy 内存(已有持仓时跳过主动查询)。""" - if not self._engine or not self._connected_mode: - return - if not force and self._has_live_positions(): - return - now = time.monotonic() - if not force and (now - self._last_position_query_ts) < POSITION_QUERY_MIN_INTERVAL_SEC: - return - try: - self._ensure_instrument_margin_hooks() - gw = self._engine.get_gateway(GATEWAY_NAME) - td = getattr(gw, "td_api", None) if gw else None - if not td or not getattr(td, "login_status", False): - logger.debug("CTP 持仓查询跳过:交易未登录") - return - if hasattr(td, "reqQryInvestorPosition"): - reqid = int(getattr(td, "reqid", 0)) + 1 - td.reqid = reqid - req = { - "BrokerID": getattr(td, "brokerid", ""), - "InvestorID": getattr(td, "userid", ""), - } - with _ctp_td_lock: - ret = td.reqQryInvestorPosition(req, reqid) - if ret == 0: - self._last_position_query_ts = now - logger.info("CTP 已请求持仓查询 reqid=%s", reqid) - else: - logger.debug("CTP 持仓查询发送失败 ret=%s", ret) - elif gw and hasattr(gw, "query_position"): - gw.query_position() - self._last_position_query_ts = now - logger.info("CTP 已请求持仓查询(gateway)") - except Exception as exc: - logger.debug("request_position_snapshot: %s", exc) - - def list_positions(self, *, refresh_if_empty: bool = True, refresh_margin: bool = False) -> list[dict[str, Any]]: - del refresh_if_empty, refresh_margin - with _ctp_td_lock: - return self._collect_positions() - - @staticmethod - def _parse_trade_offset(offset_obj: Any) -> str: - s = str(offset_obj or "").upper() - if "OPEN" in s: - return "open" - return "close" - - @staticmethod - def _parse_trade_direction(direction_obj: Any) -> str: - return "long" if _is_long_direction(direction_obj) else "short" - - @staticmethod - def _position_direction_from_trade(trade_direction: str, offset: str) -> str: - td = (trade_direction or "long").strip().lower() - if (offset or "open").strip().lower() == "open": - return td - return "short" if td == "long" else "long" - - def _format_trade_datetime(self, dt_obj: Any, date_raw: str = "", time_raw: str = "") -> str: - if dt_obj is not None: - try: - if hasattr(dt_obj, "strftime"): - return dt_obj.strftime("%Y-%m-%d %H:%M:%S") - text = str(dt_obj).strip() - if text: - return text[:19].replace("T", " ") - except Exception: - pass - parsed = self._parse_ctp_open_datetime(date_raw, time_raw) - return parsed or "" - - def _trade_row_from_vnpy(self, trade: Any) -> Optional[dict[str, Any]]: - try: - sym = (getattr(trade, "symbol", "") or "").strip() - vol = int(getattr(trade, "volume", 0) or 0) - if not sym or vol <= 0: - return None - direction = self._parse_trade_direction(getattr(trade, "direction", None)) - offset = self._parse_trade_offset(getattr(trade, "offset", None)) - exchange = getattr(trade, "exchange", None) - ex_name = str(exchange.value if hasattr(exchange, "value") else exchange or "") - dt = self._format_trade_datetime(getattr(trade, "datetime", None)) - trade_id = str(getattr(trade, "tradeid", "") or getattr(trade, "vt_tradeid", "") or "") - order_id = str(getattr(trade, "orderid", "") or getattr(trade, "vt_orderid", "") or "") - if not trade_id: - trade_id = f"{order_id}:{sym}:{offset}:{direction}:{vol}:{getattr(trade, 'price', 0)}:{dt}" - return { - "trade_id": trade_id, - "order_id": order_id, - "symbol": sym, - "exchange": ex_name, - "direction": direction, - "offset": offset, - "position_direction": self._position_direction_from_trade(direction, offset), - "lots": vol, - "price": float(getattr(trade, "price", 0) or 0), - "datetime": dt, - "commission": round(float(getattr(trade, "commission", 0) or 0), 2), - } - except Exception as exc: - logger.debug("trade_row_from_vnpy: %s", exc) - return None - - def _trade_row_from_ctp_dict(self, data: dict) -> Optional[dict[str, Any]]: - try: - sym = (data.get("InstrumentID") or data.get("instrument_id") or "").strip() - vol = int(float(data.get("Volume") or data.get("volume") or 0)) - if not sym or vol <= 0: - return None - dir_raw = str(data.get("Direction") or data.get("direction") or "") - direction = "long" if dir_raw in ("0", "2") or "LONG" in dir_raw.upper() or dir_raw == "多" else "short" - off_raw = str(data.get("OffsetFlag") or data.get("offset") or "") - if off_raw in ("0",) or "OPEN" in off_raw.upper(): - offset = "open" - else: - offset = "close" - price = float(data.get("Price") or data.get("price") or 0) - trade_id = str(data.get("TradeID") or data.get("tradeid") or "").strip() - order_sys = str(data.get("OrderSysID") or data.get("orderid") or "").strip() - dt = self._format_trade_datetime( - None, - str(data.get("TradeDate") or data.get("trade_date") or ""), - str(data.get("TradeTime") or data.get("trade_time") or ""), - ) - if not trade_id: - trade_id = f"{order_sys}:{sym}:{offset}:{direction}:{vol}:{price}:{dt}" - return { - "trade_id": trade_id, - "order_id": order_sys, - "symbol": sym, - "exchange": str(data.get("ExchangeID") or data.get("exchange") or ""), - "direction": direction, - "offset": offset, - "position_direction": self._position_direction_from_trade(direction, offset), - "lots": vol, - "price": price, - "datetime": dt, - "commission": round( - float(data.get("Commission") or data.get("commission") or 0), 2, - ), - } - except Exception as exc: - logger.debug("trade_row_from_ctp_dict: %s", exc) - return None - - def _install_trade_query_hook(self) -> None: - """不再 monkey-patch CTP 成交回调(易与并发查询冲突导致 vnctptd 段错误)。""" - return - - @staticmethod - def _engine_collection_items(raw: Any) -> list[Any]: - """vnpy 不同版本可能返回 dict 或 list。""" - if raw is None: - return [] - if isinstance(raw, dict): - return list(raw.values()) - if isinstance(raw, (list, tuple)): - return list(raw) - return [raw] - - def _collect_engine_trades(self) -> list[dict[str, Any]]: - if not self._engine: - return [] - out: list[dict[str, Any]] = [] - seen: set[str] = set() - try: - trades = self._engine.get_all_trades() - except Exception: - trades = None - for trade in self._engine_collection_items(trades): - row = self._trade_row_from_vnpy(trade) - if not row: - continue - key = row["trade_id"] - if key in seen: - continue - seen.add(key) - out.append(row) - return out - - def refresh_trades(self) -> None: - """成交仅读 vnpy 内存回报;不调用 query_trade(避免 CTP 段错误)。""" - return - - def list_trades(self, *, refresh: bool = False) -> list[dict[str, Any]]: - with _ctp_td_lock: - out = self._collect_engine_trades() - out.sort(key=lambda r: (r.get("datetime") or "", r.get("trade_id") or "")) - return out - - def list_active_orders(self) -> list[dict[str, Any]]: - if not self._engine: - return [] - out: list[dict[str, Any]] = [] - try: - orders = self._engine.get_all_active_orders() - except Exception: - return [] - for order in orders or []: - status = getattr(order, "status", None) - status_s = str(status) - if status_s and not any(x in status_s for x in ("NOTTRADED", "PARTTRADED", "SUBMITTING")): - continue - vol = int(getattr(order, "volume", 0) or 0) - traded = int(getattr(order, "traded", 0) or 0) - remain = max(0, vol - traded) - if remain <= 0: - continue - direction = getattr(order, "direction", None) - d = "long" - if direction is not None and str(direction).endswith("SHORT"): - d = "short" - offset = getattr(order, "offset", None) - offset_s = str(offset or "") - sym = getattr(order, "symbol", "") or "" - exchange = getattr(order, "exchange", None) - ex_name = str(exchange.value if hasattr(exchange, "value") else exchange or "") - vt_oid = str(getattr(order, "vt_orderid", "") or "") - order_id = str(getattr(order, "orderid", "") or "") - out.append({ - "symbol": sym, - "exchange": ex_name, - "direction": d, - "lots": remain, - "price": float(getattr(order, "price", 0) or 0), - "offset": offset_s, - "order_id": vt_oid or order_id, - "vt_order_id": vt_oid, - "status": status_s, - }) - return out - - def send_order( - self, - *, - ths_code: str, - offset: str, - direction: str, - lots: int, - price: float, - order_type: str = "limit", - ) -> str: - from vnpy.trader.constant import Direction, Offset, OrderType - from vnpy.trader.object import OrderRequest - - if not self._engine: - raise RuntimeError("CTP 未初始化") - if not self._td_logged_in(): - raise RuntimeError("CTP 交易通道未登录,请重连后再下单") - - sym, ex_name = ths_to_vnpy_symbol(ths_code) - exchange = to_vnpy_exchange(ex_name) - lots = max(1, int(lots)) - tick = float(get_contract_spec(ths_code).get("tick_size") or 1.0) - - offset = (offset or "open").lower() - direction = (direction or "long").lower() - - if offset in ("open", "open_long", "open_short"): - d = Direction.LONG if direction == "long" or offset == "open_long" else Direction.SHORT - off = Offset.OPEN - elif offset in ("close", "close_long", "close_short"): - hold = "long" if direction == "long" or offset == "close_long" else "short" - if hold == "long": - d = Direction.SHORT - else: - d = Direction.LONG - off = self._resolve_close_offset(sym, ex_name, hold, lots) - else: - raise ValueError(f"未知开平: {offset}") - - use_market = (order_type or "limit").lower() == "market" - if use_market: - ot = OrderType.FAK - price = self._aggressive_limit_price(ths_code, sym, ex_name, d, tick, price) - else: - ot = OrderType.LIMIT - price = round_to_tick(float(price), tick) - if price <= 0: - raise ValueError("委托价格无效,请检查行情或手动填写价格") - - req = OrderRequest( - symbol=sym, - exchange=exchange, - direction=d, - type=ot, - volume=lots, - price=price, - offset=off, - ) - logger.info( - "CTP 报单 %s %s %s %s手 @%s offset=%s type=%s", - sym, ex_name, d, lots, price, off, ot, - ) - with _ctp_td_lock: - vt_orderid = self._engine.send_order(req, GATEWAY_NAME) - if not vt_orderid: - raise RuntimeError("CTP 拒单或未返回委托号(请检查合约代码、价格是否为最小变动价位整数倍)") - return str(vt_orderid) - - def cancel_order(self, vt_orderid: str) -> bool: - if not self._engine or not vt_orderid: - return False - try: - with _ctp_td_lock: - order = self._engine.get_order(vt_orderid) - if order is None: - return False - req = order.create_cancel_request() - self._engine.cancel_order(req, GATEWAY_NAME) - logger.info("CTP 撤单 %s", vt_orderid) - return True - except Exception as exc: - logger.warning("CTP 撤单失败 %s: %s", vt_orderid, exc) - return False - - -class CtpBridgeProxy: - """Client-side stand-in for CtpBridge, forwarding calls to qihuo-ctp.""" - - _engine = None - - @property - def connected_mode(self) -> Optional[str]: - st = ctp_ipc_client.health().get("status") or {} - return st.get("connected_mode") - - @property - def last_error(self) -> str: - st = ctp_ipc_client.health().get("status") or {} - return str(st.get("last_error") or "") - - @property - def _last_connect_ok_ts(self) -> float: - st = ctp_ipc_client.health().get("status") or {} - try: - return float(st.get("last_connect_ok_ts") or 0) - except (TypeError, ValueError): - return 0.0 - - def available(self) -> bool: - return bool(ctp_ipc_client.health().get("worker_online")) - - def status(self, mode: str) -> dict[str, Any]: - return ctp_ipc_client.status(mode) - - def ping(self) -> bool: - return bool(ctp_ipc_client.health().get("worker_online")) - - def connect(self, mode: str, *, force: bool = False) -> dict[str, Any]: - return ctp_ipc_client.connect(mode, force=force) - - def start_connect_async( - self, - mode: str, - *, - force: bool = False, - scheduled: bool = False, - ) -> dict[str, Any]: - return ctp_ipc_client.start_connect(mode, force=force, scheduled=scheduled) - - def connect_in_progress(self) -> bool: - data = ctp_ipc_client.bridge_action("connect_in_progress") - return bool(data.get("result")) - - def login_cooldown_remaining(self) -> int: - st = ctp_ipc_client.health().get("status") or {} - try: - return int(st.get("login_cooldown_sec") or 0) - except (TypeError, ValueError): - return 0 - - def ensure_connected(self, mode: str) -> None: - if not self.status(mode).get("connected"): - raise RuntimeError("CTP worker 未连接,请重连后再操作") - - def require_connected(self, mode: str) -> None: - self.ensure_connected(mode) - - def get_account(self) -> dict[str, Any]: - mode = self.connected_mode or "simulation" - return ctp_ipc_client.account(mode) - - def list_positions( - self, - *, - refresh_if_empty: bool = True, - refresh_margin: bool = False, - ) -> list[dict[str, Any]]: - mode = self.connected_mode or "simulation" - return ctp_ipc_client.positions( - mode, - refresh_if_empty=refresh_if_empty, - refresh_margin=refresh_margin, - ) - - def list_active_orders(self) -> list[dict[str, Any]]: - mode = self.connected_mode or "simulation" - return ctp_ipc_client.active_orders(mode) - - def list_trades(self, *, refresh: bool = False) -> list[dict[str, Any]]: - mode = self.connected_mode or "simulation" - return ctp_ipc_client.trades(mode, refresh=refresh) - - def get_tick_price(self, ths_code: str, *, mode: str = "") -> Optional[float]: - return ctp_ipc_client.tick_price(mode or self.connected_mode or "simulation", ths_code) - - def get_tick_detail(self, ths_code: str, *, mode: str = "") -> dict[str, Any]: - return ctp_ipc_client.tick_detail(mode or self.connected_mode or "simulation", ths_code) - - def estimate_margin_one_lot( - self, - ths_code: str, - price: float, - *, - direction: str = "long", - ) -> Optional[float]: - return ctp_ipc_client.estimate_margin_one_lot( - self.connected_mode or "simulation", - ths_code, - price, - direction=direction, - ) - - def lookup_contract_spec(self, ths_code: str) -> Optional[dict]: - return ctp_ipc_client.contract_spec(self.connected_mode or "simulation", ths_code) - - def send_order(self, **payload: Any) -> str: - data = ctp_ipc_client.send_order(payload) - return str(data.get("order_id") or "") - - def cancel_order(self, vt_orderid: str) -> bool: - return ctp_ipc_client.cancel_order(self.connected_mode or "simulation", vt_orderid) - - def calibrate_trading_state(self) -> Any: - return ctp_ipc_client.bridge_action("calibrate_trading_state").get("result") - - def request_position_snapshot(self, *, force: bool = False) -> Any: - return ctp_ipc_client.bridge_action( - "request_position_snapshot", - {"force": bool(force)}, - ).get("result") - - def subscribe_symbol(self, symbol: str) -> Any: - return ctp_ipc_client.bridge_action("subscribe_symbol", {"symbol": symbol}).get("result") - - def refresh_positions(self) -> Any: - return ctp_ipc_client.bridge_action("refresh_positions").get("result") - - def reconnect_after_settings_saved(self, mode: str) -> Any: - return ctp_ipc_client.bridge_action( - "reconnect_after_settings_saved", - {"mode": mode}, - ).get("result") - - def query_all_commissions(self, *, mode: str = "") -> list[dict]: - data = ctp_ipc_client.bridge_action("query_all_commissions", {"mode": mode}) - return list(data.get("result") or []) - - def query_instrument_commission(self, symbol: str, *, mode: str = "") -> dict: - data = ctp_ipc_client.bridge_action( - "query_instrument_commission", - {"symbol": symbol, "mode": mode or self.connected_mode or "simulation"}, - ) - return dict(data.get("result") or {}) - - def get_kline_bars_1m(self, ths_code: str, *, mode: str) -> list[dict]: - data = ctp_ipc_client.bridge_action( - "get_kline_bars_1m", - {"symbol": ths_code, "mode": mode}, - ) - return list(data.get("result") or []) - - def _close_gateway(self) -> None: - ctp_ipc_client.disconnect() - - -def get_bridge(): - global _bridge - if _use_ctp_worker_client(): - return CtpBridgeProxy() - with _bridge_lock: - if _bridge is None: - _bridge = CtpBridge() - return _bridge - - -def try_init_vnpy(_settings: dict | None = None) -> bool: - if _use_ctp_worker_client(): - return bool(ctp_ipc_client.health().get("worker_online")) - return get_bridge().available() - - -def vnpy_available() -> bool: - if _use_ctp_worker_client(): - return bool(ctp_ipc_client.health().get("worker_online")) - return get_bridge().available() - - -def _ctp_connect_permitted(*, scheduled: bool = False) -> bool: - """scheduled=True:盘前/交易时段计划连接,不受「自动连接」开关限制。""" - from ctp_settings import is_ctp_auto_connect_enabled - - if is_ctp_auto_connect_enabled(): - return True - if not scheduled: - return False - from ctp_premarket_connect import should_auto_connect_now - - return should_auto_connect_now() - - -def ctp_disconnect(*, set_disabled_hint: bool = False) -> None: - """主动断开 CTP 并清理内存状态。""" - if _use_ctp_worker_client(): - ctp_ipc_client.disconnect(set_disabled_hint=set_disabled_hint) - return - from ctp_settings import CTP_DISABLED_HINT - - b = get_bridge() - b._close_gateway() - if set_disabled_hint: - b._last_error = CTP_DISABLED_HINT - _persist_last_error(CTP_DISABLED_HINT) - else: - b._last_error = "" - _persist_last_error("") - - -def ctp_connect(mode: str, *, force: bool = False) -> dict[str, Any]: - if _use_ctp_worker_client(): - return ctp_ipc_client.connect(mode, force=force) - b = get_bridge() - b.connect(mode, force=force) - return b.status(mode) - - -def ctp_start_connect(mode: str, *, force: bool = False, scheduled: bool = False) -> dict[str, Any]: - """非阻塞发起连接,供 Web API 使用。""" - if _use_ctp_worker_client(): - return ctp_ipc_client.start_connect(mode, force=force, scheduled=scheduled) - b = get_bridge() - info = b.start_connect_async(mode, force=force, scheduled=scheduled) - st = b.status(mode) - return {**info, "status": st} - - -def ctp_try_auto_reconnect(mode: str) -> bool: - """断线时静默异步重连;已连接且交易通道正常则不再重复 connect。""" - if _use_ctp_worker_client(): - info = ctp_ipc_client.start_connect(mode, force=False, scheduled=True) - return bool( - info.get("connected") - or info.get("connecting") - or info.get("started") - ) - if not _ctp_connect_permitted(scheduled=True): - return False - b = get_bridge() - if not b.available(): - return False - if b.connect_in_progress(): - return False - if b.login_cooldown_remaining() > 0: - return False - st = _setting_for_mode(mode) - if not st.get("用户名") or not st.get("密码") or not st.get("交易服务器"): - return False - if b.connected_mode == mode: - if b._td_logged_in() or b.ping(): - return True - recent = time.time() - float(getattr(b, "_last_connect_ok_ts", 0) or 0) - if recent < 120: - logger.debug("CTP 跳过自动重连:刚连接 %.0fs", recent) - return True - td = st.get("交易服务器", "") - ok, err = probe_tcp_address(td, timeout=4.0) - if not ok: - b._last_error = ( - f"SimNow 交易前置不可达:{td}({err})。" - "请更新 SIMNOW_TD_ADDRESS 并确认服务器出网。" - ) - return False - info = b.start_connect_async(mode, force=False, scheduled=True) - return bool( - info.get("connected") - or info.get("connecting") - or info.get("started") - ) - - -def ctp_status(mode: str) -> dict[str, Any]: - from ctp_settings import CTP_DISABLED_HINT, is_ctp_auto_connect_enabled - - if _use_ctp_worker_client(): - st = ctp_ipc_client.status(mode) - st["auto_connect_enabled"] = is_ctp_auto_connect_enabled() - return st - auto = is_ctp_auto_connect_enabled() - st = get_bridge().status(mode) - st["auto_connect_enabled"] = auto - if not auto: - st["disabled_hint"] = CTP_DISABLED_HINT - if not st.get("connected") and not st.get("connecting"): - st["last_error"] = "" - st["td_reachable"] = None - return st - if not st.get("connected") and not st.get("connecting"): - setting = _setting_for_mode(mode) - td = setting.get("交易服务器", "") - if td: - ok, err = probe_tcp_address(td, timeout=3.0) - st["td_reachable"] = ok - if not ok and not st.get("last_error"): - st["last_error"] = ( - f"SimNow 交易前置不可达:{td}({err})" - ) - return st - - -def ctp_get_account(mode: str) -> dict[str, Any]: - if _use_ctp_worker_client(): - return ctp_ipc_client.account(mode) - b = get_bridge() - b.ensure_connected(mode) - return b.get_account() - - -def ctp_sum_position_margins( - mode: str, - *, - refresh_if_empty: bool = True, - refresh_margin: bool = False, -) -> float: - """各持仓 CTP 回报保证金之和(与柜台「实收保证金」一致)。""" - total = 0.0 - for p in ctp_list_positions( - mode, refresh_if_empty=refresh_if_empty, refresh_margin=refresh_margin, - ): - m = float(p.get("margin") or 0) - if m > 0: - total += m - return round(total, 2) if total > 0 else 0.0 - - -def ctp_account_margin_used(mode: str) -> Optional[float]: - """账户实际占用保证金 ≈ 权益 − 可用(与顶栏柜台资金一致)。""" - if _use_ctp_worker_client(): - try: - acc = ctp_ipc_client.account(mode) - balance = float(acc.get("balance") or 0) - available = float(acc.get("available") or 0) - if balance <= 0: - return None - used = balance - available - return round(used, 2) if used > 0 else None - except Exception as exc: - logger.debug("ctp_account_margin_used ipc: %s", exc) - return None - b = get_bridge() - if b.connected_mode != mode or not b.ping(): - return None - try: - acc = b.get_account() - balance = float(acc.get("balance") or 0) - available = float(acc.get("available") or 0) - if balance <= 0: - return None - used = balance - available - return round(used, 2) if used > 0 else None - except Exception as exc: - logger.debug("ctp_account_margin_used: %s", exc) - return None - - -def ctp_list_positions( - mode: str, - *, - refresh_if_empty: bool = True, - refresh_margin: bool = False, -) -> list[dict[str, Any]]: - if _use_ctp_worker_client(): - return ctp_ipc_client.positions( - mode, - refresh_if_empty=refresh_if_empty, - refresh_margin=refresh_margin, - ) - b = get_bridge() - if b.connected_mode != mode or not b.ping(): - return [] - return b.list_positions(refresh_if_empty=refresh_if_empty, refresh_margin=refresh_margin) - - -def ctp_list_active_orders(mode: str) -> list[dict[str, Any]]: - if _use_ctp_worker_client(): - return ctp_ipc_client.active_orders(mode) - b = get_bridge() - b.ensure_connected(mode) - return b.list_active_orders() - - -def ctp_cancel_order(mode: str, vt_orderid: str) -> bool: - if _use_ctp_worker_client(): - return ctp_ipc_client.cancel_order(mode, vt_orderid) - b = get_bridge() - b.ensure_connected(mode) - return b.cancel_order(vt_orderid) - - -def ctp_list_trades(mode: str, *, refresh: bool = False) -> list[dict[str, Any]]: - if _use_ctp_worker_client(): - return ctp_ipc_client.trades(mode, refresh=refresh) - b = get_bridge() - if b.connected_mode != mode or not b.ping(): - return [] - return b.list_trades(refresh=refresh) - - -def ctp_get_tick_price(mode: str, ths_code: str) -> Optional[float]: - """CTP 柜台最新价(需已连接并订阅)。""" - if _use_ctp_worker_client(): - return ctp_ipc_client.tick_price(mode, ths_code) - b = get_bridge() - if b.connected_mode != mode: - return None - try: - return b.get_tick_price(ths_code, mode=mode) - except Exception as exc: - logger.debug("ctp_get_tick_price: %s", exc) - return None - - -def ctp_get_tick_detail(mode: str, ths_code: str) -> dict[str, Any]: - if _use_ctp_worker_client(): - return ctp_ipc_client.tick_detail(mode, ths_code) - b = get_bridge() - if b.connected_mode != mode: - return {} - try: - return b.get_tick_detail(ths_code, mode=mode) - except Exception as exc: - logger.debug("ctp_get_tick_detail: %s", exc) - return {} - - -def ctp_estimate_margin_one_lot( - mode: str, - ths_code: str, - price: float, - *, - direction: str = "long", -) -> Optional[float]: - if _use_ctp_worker_client(): - return ctp_ipc_client.estimate_margin_one_lot( - mode, - ths_code, - price, - direction=direction, - ) - b = get_bridge() - if b.connected_mode != mode or not b.ping(): - return None - try: - return b.estimate_margin_one_lot(ths_code, price, direction=direction) - except Exception as exc: - logger.debug("ctp_estimate_margin_one_lot: %s", exc) - return None - - -def ctp_lookup_contract_spec(mode: str, ths_code: str) -> Optional[dict]: - if _use_ctp_worker_client(): - return ctp_ipc_client.contract_spec(mode, ths_code) - b = get_bridge() - if b.connected_mode != mode or not b.ping(): - return None - try: - return b.lookup_contract_spec(ths_code) - except Exception as exc: - logger.debug("ctp_lookup_contract_spec: %s", exc) - return None - - -def get_ctp_balance(mode: str) -> Optional[float]: - try: - acc = ctp_get_account(mode) - bal = acc.get("balance") - return float(bal) if bal else None - except Exception as exc: - logger.debug("get_ctp_balance: %s", exc) - return None - - -def execute_order( - conn, - *, - mode: str, - offset: str, - symbol: str, - direction: str, - lots: int, - price: float, - settings: dict | None = None, - order_type: str = "limit", -) -> dict[str, Any]: - """统一下单:simulation=SimNow,live=期货公司 CTP。""" - if _use_ctp_worker_client(): - return ctp_ipc_client.send_order({ - "mode": mode, - "offset": offset, - "symbol": symbol, - "direction": direction, - "lots": lots, - "price": price, - "settings": settings or {}, - "order_type": order_type, - }) - del conn, settings - if mode not in ("simulation", "live"): - raise ValueError("未知交易模式") - if not vnpy_available(): - raise ValueError( - "请先安装 vnpy 与 vnpy_ctp:pip install vnpy vnpy_ctp\n" - f"模拟盘需配置 .env 中 SIMNOW_USER / SIMNOW_PASSWORD 等" - ) - b = get_bridge() - b.require_connected(mode) - order_id = b.send_order( - ths_code=symbol, - offset=offset, - direction=direction, - lots=lots, - price=price, - order_type=order_type, - ) - return { - "order_id": order_id, - "mode": mode, - "mode_label": _mode_label(mode), - "symbol": symbol, - "lots": lots, - "price": price, - } +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""CTP 执行层:模拟盘 → SimNow;实盘 → 期货公司(vnpy_ctp)。""" +from __future__ import annotations + +import logging +import os +import re +import threading +import time +from collections import deque +from typing import Any, Callable, Optional + +import ctp_ipc_client +from modules.core.locale_fix import ensure_process_locale + +if ctp_ipc_client.is_worker_role(): + ensure_process_locale() + +from modules.ctp.ctp_settings import live_setting_dict, simnow_setting_dict +from modules.ctp.ctp_symbol import ths_to_vnpy_symbol, to_vnpy_exchange +from modules.core.contract_specs import get_contract_spec + +logger = logging.getLogger(__name__) + +GATEWAY_NAME = "CTP" + +CONNECT_WAIT_SEC = 60 +CONNECT_POLL_INTERVAL_SEC = 0.5 +LOGIN_BAN_COOLDOWN_SEC = 45 * 60 +LOGIN_FAIL_COOLDOWN_SEC = 5 * 60 +CTP_COOLDOWN_UNTIL_KEY = "ctp_login_cooldown_until" +CTP_LAST_ERROR_KEY = "ctp_last_error" + + +def _use_ctp_worker_client() -> bool: + """默认单进程直连 CTP;仅当显式设置 QIHUO_CTP_WORKER=1 时使用独立 Worker IPC。""" + flag = (os.getenv("QIHUO_CTP_WORKER", "") or "").strip().lower() + if flag not in ("1", "true", "yes"): + return False + return not ctp_ipc_client.is_worker_role() + + +def _persist_login_cooldown(seconds: float) -> None: + from modules.fees.fee_specs import get_setting, set_setting + + new_until = time.time() + max(0.0, seconds) + try: + old = float(get_setting(CTP_COOLDOWN_UNTIL_KEY, "0") or 0) + except (TypeError, ValueError): + old = 0.0 + if new_until > old: + set_setting(CTP_COOLDOWN_UNTIL_KEY, str(new_until)) + + +def _persisted_login_cooldown_remaining() -> int: + from modules.fees.fee_specs import get_setting + + try: + until = float(get_setting(CTP_COOLDOWN_UNTIL_KEY, "0") or 0) + return max(0, int(until - time.time())) + except (TypeError, ValueError): + return 0 + + +def _clear_persisted_login_cooldown() -> None: + from modules.fees.fee_specs import set_setting + + set_setting(CTP_COOLDOWN_UNTIL_KEY, "0") + + +def _persist_last_error(msg: str) -> None: + from modules.fees.fee_specs import set_setting + + set_setting(CTP_LAST_ERROR_KEY, (msg or "").strip()) + + +def _load_persisted_last_error() -> str: + from modules.fees.fee_specs import get_setting + + return (get_setting(CTP_LAST_ERROR_KEY, "") or "").strip() + +_position_refresh_callback: Optional[Callable[[], None]] = None +_tick_sl_tp_callback: Optional[Callable[[str, str, float], None]] = None +_tick_quote_callback: Optional[Callable[[], None]] = None +_ctp_connected_callback: Optional[Callable[[str], None]] = None +_position_refresh_debounce_lock = threading.Lock() +_position_refresh_debounce_ts: float = 0.0 +_tick_quote_timer: Optional[threading.Timer] = None +_tick_quote_timer_lock = threading.Lock() +TICK_QUOTE_DEBOUNCE_SEC = 0.12 + + +def set_position_refresh_callback(fn: Optional[Callable[[], None]]) -> None: + global _position_refresh_callback + _position_refresh_callback = fn + + +def set_tick_sl_tp_callback(fn: Optional[Callable[[str, str, float], None]]) -> None: + """注册 tick 回调:exchange, symbol, last_price → 本地 SL/TP 触发。""" + global _tick_sl_tp_callback + _tick_sl_tp_callback = fn + + +def set_tick_quote_callback(fn: Optional[Callable[[], None]]) -> None: + """注册 tick 回调:推送持仓现价/浮盈(由 bridge 侧防抖)。""" + global _tick_quote_callback + _tick_quote_callback = fn + + +def _fire_tick_quote_callback_debounced() -> None: + """持仓品种 tick 后 trailing 防抖,批量推送现价/浮盈。""" + global _tick_quote_timer + + def _run() -> None: + fn = _tick_quote_callback + if not fn: + return + try: + fn() + except Exception as exc: + logger.debug("tick quote callback: %s", exc) + + with _tick_quote_timer_lock: + if _tick_quote_timer is not None: + _tick_quote_timer.cancel() + _tick_quote_timer = threading.Timer(TICK_QUOTE_DEBOUNCE_SEC, _run) + _tick_quote_timer.daemon = True + _tick_quote_timer.start() + + +def set_ctp_connected_callback(fn: Optional[Callable[[str], None]]) -> None: + """CTP 交易通道登录成功后回调(mode=simulation|live)。""" + global _ctp_connected_callback + _ctp_connected_callback = fn + + +def _fire_ctp_connected_callback(mode: str) -> None: + fn = _ctp_connected_callback + if not fn: + return + try: + threading.Thread( + target=fn, args=(mode,), daemon=True, name="ctp-connected-cb", + ).start() + except Exception as exc: + logger.debug("ctp connected callback: %s", exc) + + +def _fire_position_refresh_callback() -> None: + fn = _position_refresh_callback + if not fn: + return + try: + threading.Thread(target=fn, daemon=True, name="ctp-position-refresh").start() + except Exception as exc: + logger.debug("position refresh callback: %s", exc) + + +def _fire_position_refresh_callback_debounced(*, min_interval: float = 0.35) -> None: + global _position_refresh_debounce_ts + now = time.monotonic() + with _position_refresh_debounce_lock: + if now - _position_refresh_debounce_ts < min_interval: + return + _position_refresh_debounce_ts = now + _fire_position_refresh_callback() + + +def _fire_position_refresh_burst() -> None: + """连接后持仓回报可能分批到达,分多次触发快照刷新。""" + _fire_position_refresh_callback() + for delay in (0.4, 0.9, 1.5, 3.0, 6.0, 12.0, 20.0): + threading.Timer(delay, _fire_position_refresh_callback).start() + + +def _schedule_after_instruments_ready(bridge: "CtpBridge") -> None: + """合约查询结束后查询持仓并校准(SimNow 登录后约 10–20s)。""" + if not getattr(bridge, "_connected_mode", None): + return + now = time.monotonic() + if now - float(getattr(bridge, "_last_instruments_ready_ts", 0) or 0) < 5.0: + return + bridge._last_instruments_ready_ts = now + + def _run() -> None: + try: + if bridge._has_live_positions(): + return + bridge._ensure_instrument_margin_hooks() + with _ctp_td_lock: + bridge.request_position_snapshot(force=True) + time.sleep(0.8) + with _ctp_td_lock: + bridge.calibrate_trading_state() + _fire_position_refresh_callback() + _fire_position_refresh_burst() + n = len(bridge._collect_positions()) + logger.info("CTP 合约加载完成,持仓 %s 条,已刷新快照", n) + except Exception as exc: + logger.debug("instruments ready refresh: %s", exc) + + threading.Timer(0.4, _run).start() + + +def _schedule_position_query_retries(bridge: "CtpBridge") -> None: + def _run() -> None: + if not bridge._connected_mode or bridge._has_live_positions(): + return + try: + bridge._ensure_instrument_margin_hooks() + with _ctp_td_lock: + bridge.request_position_snapshot(force=False) + time.sleep(1.0) + with _ctp_td_lock: + bridge.calibrate_trading_state() + _fire_position_refresh_callback() + except Exception as exc: + logger.debug("position query retry: %s", exc) + + for delay in POSITION_QUERY_RETRY_DELAYS_SEC: + threading.Timer(delay, _run).start() + +_bridge: Optional["CtpBridge"] = None +_bridge_lock = threading.Lock() +_ctp_td_lock = threading.RLock() +POSITION_QUERY_MIN_INTERVAL_SEC = 5.0 +POSITION_QUERY_RETRY_DELAYS_SEC = (1.5, 4.0, 9.0, 18.0, 35.0) +TRADE_QUERY_MIN_INTERVAL_SEC = 10.0 + + +def _simnow_setting() -> dict[str, str]: + """SimNow 仿真前置(系统设置优先,.env 兜底)。""" + return simnow_setting_dict() + + +def _live_setting() -> dict[str, str]: + return live_setting_dict() + + +def _setting_for_mode(mode: str) -> dict[str, str]: + return _simnow_setting() if mode == "simulation" else _live_setting() + + +def _mode_label(mode: str) -> str: + return "SimNow" if mode == "simulation" else "期货公司实盘" + + +def _parse_tcp_address(address: str) -> tuple[str, int]: + raw = (address or "").strip() + if raw.startswith("tcp://"): + raw = raw[6:] + if ":" not in raw: + raise ValueError(f"无效 TCP 地址: {address}") + host, port_s = raw.rsplit(":", 1) + return host, int(port_s) + + +def probe_tcp_address(address: str, timeout: float = 5.0) -> tuple[bool, str]: + """探测 CTP 前置 TCP 是否可达。""" + import socket + + try: + host, port = _parse_tcp_address(address) + with socket.create_connection((host, port), timeout=timeout): + return True, "" + except Exception as exc: + return False, str(exc) + + +def _format_ctp_failure(ctp_logs: list[str], *, td_address: str = "") -> str: + """根据 CTP 网关日志拼出可读错误。""" + if td_address: + ok, err = probe_tcp_address(td_address, timeout=4.0) + if not ok: + return ( + f"SimNow 交易前置不可达:{td_address}({err})。" + "182.254.243.31 已停用,请改 .env 为官方前置 " + "tcp://180.168.146.187:10201 / 10211,并确认服务器能访问该地址。" + ) + text = "\n".join(ctp_logs) + if "连续登录失败" in text or "登录被禁止" in text or "代码:75" in text: + return ( + "CTP 登录被临时禁止:连续失败次数过多(错误码 75)。" + "请等待约 30~60 分钟后再试,先用快期确认投资者代码与密码正确,期间勿反复点「连接」。" + ) + if "4097" in text or "Decrypt handshake" in text or "shake hand" in text.lower(): + return ( + "CTP 握手失败(4097):vnpy_ctp 与 SimNow 前置加密不匹配。" + "请执行 pip install -U vnpy vnpy_ctp 后重启,并确认 .env 中 SIMNOW_ENV=实盘" + ) + if "不合法的登录" in text or "密码" in text or "账号" in text: + tail = ctp_logs[-1] if ctp_logs else "" + return f"CTP 登录被拒:{tail or '请检查投资者代码与密码(快期能否登录)'}" + if "连接断开" in text or "disconnect" in text.lower(): + tail = ctp_logs[-1] if ctp_logs else "" + return f"CTP 连接断开:{tail or '请检查前置地址与网络'}" + if ctp_logs: + return f"CTP 连接失败:{ctp_logs[-1]}" + return "CTP 连接超时:未收到柜台回报。请检查 SimNow 账号、前置地址、网络(nc 测端口),并用快期验证账号" + + +def round_to_tick(price: float, tick: float) -> float: + if tick <= 0: + return float(price) + steps = round(float(price) / tick) + return round(steps * tick, 10) + + +def _is_long_direction(direction_obj: Any) -> bool: + s = str(direction_obj or "") + return "LONG" in s.upper() or "多" in s + + +class CtpBridge: + def __init__(self) -> None: + self._engine = None + self._ee = None + self._connected_mode: Optional[str] = None + self._last_error: str = "" + self._connect_lock = threading.Lock() + self._connect_in_progress = False + self._login_cooldown_until: float = 0.0 + self._restore_persisted_state() + self._commission_waiters: dict[int, threading.Event] = {} + self._commission_lists: dict[int, list] = {} + self._commission_hooked = False + self._margin_rate_waiters: dict[int, threading.Event] = {} + self._margin_rate_lists: dict[int, list] = {} + self._margin_rate_hooked = False + self._instrument_hooked = False + self._hooks_td_api_id: Optional[int] = None + self._ctp_log_hooked = False + self._last_instruments_ready_ts: float = 0.0 + self._last_position_rsp_ts: float = 0.0 + self._instrument_margin_ratios: dict[str, dict[str, float]] = {} + self._margin_per_lot: dict[str, float] = {} + self._subscribed: set[str] = set() + self._last_position_query_ts: float = 0.0 + self._position_margins: dict[str, float] = {} + self._position_open_times: dict[str, str] = {} + self._margin_hooked = False + self._trade_hooked = False + self._trade_query_results: list[dict[str, Any]] = [] + self._trade_query_event = threading.Event() + self._last_trade_query_ts: float = 0.0 + self._last_connect_ok_ts: float = 0.0 + self._connect_started_ts: float = 0.0 + self._tick_hooked = False + self._position_hooked = False + self._order_hooked = False + self._trade_hooked = False + self._bar_generators: dict[str, Any] = {} + self._bars_1m: dict[str, deque] = {} + self._init_engine() + + def _init_engine(self) -> None: + ensure_process_locale() + try: + from vnpy.event import EventEngine + from vnpy.trader.engine import MainEngine + from vnpy_ctp import CtpGateway + + self._ee = EventEngine() + self._engine = MainEngine(self._ee) + self._engine.add_gateway(CtpGateway) + self._ensure_position_event_hook() + self._ensure_order_event_hook() + self._ensure_trade_event_hook() + self._ensure_ctp_log_hooks() + except ImportError: + self._last_error = "未安装 vnpy / vnpy_ctp,请 pip install vnpy vnpy_ctp" + except Exception as exc: + self._last_error = str(exc) + + def _ensure_position_event_hook(self) -> None: + if self._position_hooked or not self._ee: + return + try: + from vnpy.trader.event import EVENT_POSITION + except ImportError: + return + + def _on_position(event) -> None: + try: + from modules.ctp.ctp_trading_state import trading_state + + pos = event.data + row = self._position_row_from_vnpy(pos) + if row: + sym = row.get("symbol") or "" + ex = row.get("exchange") or "" + ths = CtpBridge._vnpy_sym_to_ths(sym, ex) or sym + with _ctp_td_lock: + trades = self.list_trades() + trading_state.upsert_position( + row, notify=False, trades=trades, ths_sym=ths, + ) + sym = getattr(pos, "symbol", "") or "" + d = "long" if _is_long_direction(getattr(pos, "direction", None)) else "short" + vol = int(getattr(pos, "volume", 0) or 0) + if vol <= 0: + exchange = getattr(pos, "exchange", None) + ex_name = str(exchange.value if hasattr(exchange, "value") else exchange or "") + from modules.ctp.ctp_trading_state import position_key + + trading_state.remove_position( + position_key(ex_name, sym, d), notify=False, + ) + else: + for attr in ("margin", "use_margin", "UseMargin"): + raw = float(getattr(pos, attr, 0) or 0) + if raw > 0: + self._position_margins[self._position_margin_key(sym, d)] = raw + if vol > 0: + self._margin_per_lot[self._position_margin_key(sym, d)] = round( + raw / vol, 2, + ) + break + except Exception as exc: + logger.debug("position margin cache: %s", exc) + _fire_position_refresh_callback() + + self._ee.register(EVENT_POSITION, _on_position) + self._position_hooked = True + + def _ensure_order_event_hook(self) -> None: + if self._order_hooked or not self._ee: + return + try: + from vnpy.trader.event import EVENT_ORDER + except ImportError: + return + + def _on_order(event) -> None: + try: + from modules.ctp.ctp_trading_state import trading_state + + order = event.data + row = self._order_row_from_vnpy(order) + if not row: + return + status_s = str(row.get("status") or "") + terminal = any( + x in status_s + for x in ("ALLTRADED", "CANCELLED", "REJECTED", "全部成交", "已撤销", "拒单") + ) + oid = str(row.get("order_id") or row.get("vt_order_id") or "") + if terminal or int(row.get("lots") or 0) <= 0: + trading_state.remove_order(oid, notify=False) + else: + trading_state.upsert_order(row, notify=False) + except Exception as exc: + logger.debug("order event: %s", exc) + _fire_position_refresh_callback() + + self._ee.register(EVENT_ORDER, _on_order) + self._order_hooked = True + + def _ensure_trade_event_hook(self) -> None: + if self._trade_hooked or not self._ee: + return + try: + from vnpy.trader.event import EVENT_TRADE + except ImportError: + return + + def _on_trade(event) -> None: + try: + trade = event.data + row = self._trade_row_from_vnpy(trade) + if row and row.get("offset") == "open": + sym = row.get("symbol") or "" + pd = row.get("position_direction") or "long" + dt = row.get("datetime") or "" + if sym and dt: + self._position_open_times[self._position_margin_key(sym, pd)] = dt + except Exception as exc: + logger.debug("trade event: %s", exc) + _fire_position_refresh_callback() + + self._ee.register(EVENT_TRADE, _on_trade) + self._trade_hooked = True + + def _order_row_from_vnpy(self, order: Any) -> Optional[dict[str, Any]]: + try: + status = getattr(order, "status", None) + status_s = str(status) + vol = int(getattr(order, "volume", 0) or 0) + traded = int(getattr(order, "traded", 0) or 0) + remain = max(0, vol - traded) + direction = getattr(order, "direction", None) + d = "long" + if direction is not None and str(direction).endswith("SHORT"): + d = "short" + offset = getattr(order, "offset", None) + sym = getattr(order, "symbol", "") or "" + exchange = getattr(order, "exchange", None) + ex_name = str(exchange.value if hasattr(exchange, "value") else exchange or "") + vt_oid = str(getattr(order, "vt_orderid", "") or "") + order_id = str(getattr(order, "orderid", "") or "") + return { + "symbol": sym, + "exchange": ex_name, + "direction": d, + "lots": remain, + "price": float(getattr(order, "price", 0) or 0), + "offset": str(offset or ""), + "order_id": vt_oid or order_id, + "vt_order_id": vt_oid, + "status": status_s, + } + except Exception as exc: + logger.debug("order_row_from_vnpy: %s", exc) + return None + + def _position_row_from_vnpy(self, pos: Any) -> Optional[dict[str, Any]]: + try: + vol = int(getattr(pos, "volume", 0) or 0) + d = "long" if _is_long_direction(getattr(pos, "direction", None)) else "short" + sym = getattr(pos, "symbol", "") or "" + exchange = getattr(pos, "exchange", None) + ex_name = str(exchange.value if hasattr(exchange, "value") else exchange or "") + price = float(getattr(pos, "price", 0) or 0) + yd = int(getattr(pos, "yd_volume", 0) or 0) + td = max(0, vol - yd) + margin = self.estimate_position_margin(sym, ex_name, d, vol, price, pos=pos) + open_time = self._lookup_position_open_time(sym, d) or None + pnl = float(getattr(pos, "pnl", 0) or 0) + row = { + "symbol": sym, + "exchange": ex_name, + "direction": d, + "lots": vol, + "avg_price": price, + "pnl": pnl, + "frozen": int(getattr(pos, "frozen", 0) or 0), + "margin": margin, + "open_time": open_time, + "yd_volume": yd, + "td_volume": td, + } + try: + from modules.ctp.ctp_entry_price import round_to_tick + + ths = CtpBridge._vnpy_sym_to_ths(sym, ex_name) or sym + if price > 0: + row["avg_price"] = round_to_tick(price, ths) + except Exception as exc: + logger.debug("position avg round: %s", exc) + return row + except Exception as exc: + logger.debug("position_row_from_vnpy: %s", exc) + return None + + def calibrate_trading_state(self) -> None: + """全量校准内存簿(读 vnpy 缓存,不 query 柜台)。""" + try: + from modules.ctp.ctp_trading_state import trading_state + + with _ctp_td_lock: + orders = self.list_active_orders() + positions = self._collect_positions() + trades = self.list_trades() + preserve_margin = 0.0 + if self._connected_mode and not positions: + try: + preserve_margin = float( + ctp_account_margin_used(self._connected_mode) or 0, + ) + except Exception: + preserve_margin = 0.0 + trading_state.calibrate_from_lists( + orders, + positions, + trades=trades, + ths_for_vnpy_sym=lambda s, e: CtpBridge._vnpy_sym_to_ths(s, e) or s, + preserve_positions_if_margin=preserve_margin, + ) + except Exception as exc: + logger.debug("calibrate trading state: %s", exc) + + def available(self) -> bool: + return self._engine is not None + + @property + def last_error(self) -> str: + return self._last_error + + @property + def connected_mode(self) -> Optional[str]: + return self._connected_mode + + def connect_in_progress(self) -> bool: + return self._connect_in_progress + + def _restore_persisted_state(self) -> None: + err = _load_persisted_last_error() + if err: + self._last_error = err + db_remain = _persisted_login_cooldown_remaining() + if db_remain > 0: + self._login_cooldown_until = time.monotonic() + db_remain + + def login_cooldown_remaining(self) -> int: + """距允许再次登录的剩余秒数(内存 + 数据库,重启后仍有效)。""" + mem = max(0, int(self._login_cooldown_until - time.monotonic())) + return max(mem, _persisted_login_cooldown_remaining()) + + def _is_login_cooldown_active(self) -> bool: + return self.login_cooldown_remaining() > 0 + + def _set_login_cooldown(self, seconds: float) -> None: + until = time.monotonic() + max(0.0, seconds) + if until > self._login_cooldown_until: + self._login_cooldown_until = until + _persist_login_cooldown(seconds) + + def _clear_login_cooldown(self) -> None: + self._login_cooldown_until = 0.0 + _clear_persisted_login_cooldown() + + def _apply_login_failure_cooldown(self, ctp_logs: list[str]) -> None: + text = "\n".join(ctp_logs) + if "连续登录失败" in text or "登录被禁止" in text or "代码:75" in text: + self._set_login_cooldown(LOGIN_BAN_COOLDOWN_SEC) + elif any("登录失败" in m or "不合法的登录" in m for m in ctp_logs): + self._set_login_cooldown(LOGIN_FAIL_COOLDOWN_SEC) + + def _login_cooldown_message(self) -> str: + remain = self.login_cooldown_remaining() + return ( + f"CTP 登录冷却中,请 {remain // 60} 分 {remain % 60} 秒后再试" + f"(避免连续失败被 SimNow 封禁)" + ) + + def _close_gateway(self) -> None: + """关闭 CTP 网关,避免半连接状态下重连卡在「连接登录」。""" + if not self._engine: + return + try: + gw = self._engine.get_gateway(GATEWAY_NAME) + if gw: + gw.close() + except Exception as exc: + logger.debug("gateway close: %s", exc) + self._connected_mode = None + self._hooks_td_api_id = None + self._instrument_hooked = False + self._margin_rate_hooked = False + self._last_position_query_ts = 0.0 + self._last_instruments_ready_ts = 0.0 + try: + from modules.ctp.ctp_trading_state import trading_state + + trading_state.clear() + except Exception: + pass + time.sleep(0.6) + + def _ensure_ctp_log_hooks(self) -> None: + """监听 vnpy 日志:合约查询成功时补触发持仓刷新(重连后 td_api 可能已换)。""" + if self._ctp_log_hooked or not self._ee: + return + try: + from vnpy.trader.event import EVENT_LOG + except ImportError: + return + bridge = self + + def _on_persistent_log(event) -> None: + try: + msg = getattr(event.data, "msg", "") or str(event.data) + if "合约信息查询成功" in str(msg): + _schedule_after_instruments_ready(bridge) + except Exception as exc: + logger.debug("ctp log hook: %s", exc) + + self._ee.register(EVENT_LOG, _on_persistent_log) + self._ctp_log_hooked = True + + def _login_rejected(self, ctp_logs: list[str]) -> bool: + return any( + kw in m + for m in ctp_logs + for kw in ("登录失败", "不合法的登录", "登录被禁止", "连续登录失败") + ) + + def _wait_connected(self, mode: str, ctp_logs: list[str] | None = None) -> bool: + """等待账户回报或交易通道登录成功。""" + if not self._engine: + return False + logs = ctp_logs or [] + loops = max(1, int(CONNECT_WAIT_SEC / CONNECT_POLL_INTERVAL_SEC)) + for _ in range(loops): + if self._login_rejected(logs): + return False + try: + if self._engine.get_all_accounts(): + return True + except Exception: + pass + if self._td_logged_in(): + return True + time.sleep(CONNECT_POLL_INTERVAL_SEC) + return False + + def status(self, mode: str) -> dict[str, Any]: + if self._connected_mode == mode: + self.ping() + st = _setting_for_mode(mode) + missing = [k for k in ("用户名", "密码", "交易服务器") if not st.get(k)] + cooldown = self.login_cooldown_remaining() + connecting = bool(self._connect_in_progress and cooldown <= 0) + last_error = self._last_error or _load_persisted_last_error() + if ( + connecting + and self._connect_started_ts > 0 + and time.time() - self._connect_started_ts > CONNECT_WAIT_SEC + 10 + and not last_error + ): + last_error = ( + f"CTP 连接进行中已超过 {CONNECT_WAIT_SEC}s," + "可能前置不可达或柜台响应慢" + ) + return { + "vnpy_installed": self.available(), + "connected": self._connected_mode == mode, + "connecting": connecting, + "connected_mode": self._connected_mode, + "mode_label": _mode_label(mode), + "missing_config": missing, + "last_error": last_error, + "login_cooldown_sec": cooldown, + "broker_id": st.get("经纪商代码", ""), + "td_address": st.get("交易服务器", ""), + } + + def connect(self, mode: str, *, force: bool = False, scheduled: bool = False) -> None: + from modules.ctp.ctp_settings import CTP_DISABLED_HINT + + if not _ctp_connect_permitted(scheduled=scheduled): + self._last_error = CTP_DISABLED_HINT + _persist_last_error(CTP_DISABLED_HINT) + raise RuntimeError(CTP_DISABLED_HINT) + if self._connect_in_progress: + raise RuntimeError("CTP 正在连接中,请稍候") + if self._is_login_cooldown_active() and not force: + msg = self._login_cooldown_message() + self._last_error = msg + raise RuntimeError(msg) + if not self._engine: + raise RuntimeError(self._last_error or "vnpy 引擎未初始化") + if self._connected_mode == mode and not force: + if self.ping(): + return + self._connected_mode = None + setting = _setting_for_mode(mode) + if not setting.get("用户名") or not setting.get("密码"): + raise ValueError( + f"{_mode_label(mode)}:请在 .env 配置 " + f"{'SIMNOW_USER / SIMNOW_PASSWORD' if mode == 'simulation' else 'CTP_LIVE_USER / CTP_LIVE_PASSWORD'}" + ) + if not setting.get("交易服务器"): + raise ValueError(f"{_mode_label(mode)}:未配置交易服务器地址") + + self._connect_in_progress = True + self._connect_started_ts = time.time() + try: + with _ctp_td_lock: + with self._connect_lock: + if force and self._connected_mode: + self._close_gateway() + elif self._connected_mode and self._connected_mode != mode: + try: + self._engine.close() + except Exception: + pass + self._connected_mode = None + time.sleep(1) + elif not (self._connected_mode == mode and self.ping()): + self._close_gateway() + + ctp_logs: list[str] = [] + from vnpy.trader.event import EVENT_LOG + + def _on_log(event) -> None: + msg = getattr(event.data, "msg", "") or str(event.data) + if msg: + ctp_logs.append(str(msg)) + if len(ctp_logs) > 40: + ctp_logs.pop(0) + logger.info("CTP | %s", msg) + + self._ee.register(EVENT_LOG, _on_log) + try: + ensure_process_locale() + logger.info( + "CTP 连接 [%s] user=%s td=%s env=%s", + mode, + setting.get("用户名"), + setting.get("交易服务器"), + setting.get("柜台环境", "实盘"), + ) + td_addr = setting.get("交易服务器", "") + ok, err = probe_tcp_address(td_addr, timeout=5.0) + if not ok: + raise RuntimeError( + f"SimNow 交易前置不可达:{td_addr}({err})。" + "请更新 .env 中 SIMNOW_TD_ADDRESS 为官网最新地址," + "并在服务器执行 nc -zv 验证出网。" + ) + self._ensure_instrument_margin_hooks() + self._engine.connect(setting, GATEWAY_NAME) + if self._wait_connected(mode, ctp_logs): + self._connected_mode = mode + self._last_connect_ok_ts = time.time() + self._last_error = "" + _persist_last_error("") + self._clear_login_cooldown() + logger.info("CTP 已连接 [%s] td_login=%s accounts=%s", + mode, self._td_logged_in(), + len(self._engine.get_all_accounts() or [])) + self._schedule_fee_sync(mode) + try: + self.calibrate_trading_state() + except Exception as exc: + logger.debug("post-connect calibrate: %s", exc) + try: + self.request_position_snapshot(force=True) + except Exception as exc: + logger.debug("post-connect position query: %s", exc) + self._ensure_instrument_margin_hooks() + _fire_position_refresh_burst() + _schedule_position_query_retries(self) + _fire_ctp_connected_callback(mode) + return + finally: + self._ee.unregister(EVENT_LOG, _on_log) + + self._close_gateway() + self._apply_login_failure_cooldown(ctp_logs) + hint = _format_ctp_failure(ctp_logs, td_address=setting.get("交易服务器", "")) + self._last_error = hint + _persist_last_error(hint) + logger.warning("CTP 连接失败 [%s]: %s | logs=%s", mode, hint, ctp_logs[-5:]) + raise RuntimeError(hint) + finally: + self._connect_in_progress = False + self._connect_started_ts = 0.0 + + def start_connect_async( + self, mode: str, *, force: bool = False, scheduled: bool = False, + ) -> dict[str, Any]: + """后台连接,不阻塞 HTTP 请求。""" + from modules.ctp.ctp_settings import CTP_DISABLED_HINT + + if not _ctp_connect_permitted(scheduled=scheduled): + self._last_error = CTP_DISABLED_HINT + _persist_last_error(CTP_DISABLED_HINT) + return { + "started": False, + "connecting": False, + "connected": False, + "disabled": True, + "error": CTP_DISABLED_HINT, + } + if self._connected_mode == mode and self.ping() and not force: + return {"started": False, "connecting": False, "connected": True} + if self._connect_in_progress: + return {"started": False, "connecting": True, "connected": False} + if self._is_login_cooldown_active() and not force: + self._last_error = self._login_cooldown_message() + return { + "started": False, + "connecting": False, + "connected": False, + "cooldown": True, + } + + def _run() -> None: + try: + self.connect(mode, force=force, scheduled=scheduled) + except Exception as exc: + logger.warning("CTP 后台连接失败: %s", exc) + + def _watchdog() -> None: + deadline = CONNECT_WAIT_SEC + 25 + time.sleep(deadline) + if not self._connect_in_progress: + return + logger.warning( + "CTP 连接 watchdog 超时 %.0fs,重置连接状态 [%s]", + deadline, + mode, + ) + self._connect_in_progress = False + self._connect_started_ts = 0.0 + hint = ( + f"CTP 连接超时(>{deadline:.0f}s),可能前置不可达或柜台无响应。" + "请检查 SimNow 前置地址与账号,勿频繁重试。" + ) + self._last_error = hint + _persist_last_error(hint) + try: + self._close_gateway() + except Exception as exc: + logger.debug("watchdog gateway close: %s", exc) + + threading.Thread(target=_run, daemon=True, name="ctp-connect-async").start() + threading.Thread(target=_watchdog, daemon=True, name="ctp-connect-watchdog").start() + return {"started": True, "connecting": True, "connected": False} + + def ensure_connected(self, mode: str) -> None: + if self._connected_mode == mode and self.ping(): + return + if self._connect_in_progress: + raise RuntimeError("CTP 连接中,请稍候") + raise RuntimeError("请先连接 CTP") + + def require_connected(self, mode: str) -> None: + """报单前检查:须已连接,不在此发起阻塞式 connect。""" + if self._connect_in_progress: + raise RuntimeError("CTP 连接中,请稍候再下单") + if self._connected_mode != mode or not self.ping(): + raise RuntimeError("请先连接 CTP(持仓监控页点击「连接 CTP」)") + if not self._td_logged_in(): + raise RuntimeError("CTP 交易通道未登录,请重连 CTP 后再下单") + + def _td_logged_in(self) -> bool: + try: + gw = self._engine.get_gateway(GATEWAY_NAME) + td = gw.td_api + return bool(getattr(td, "login_status", False)) + except Exception: + return False + + def _find_position(self, sym: str, ex_name: str, hold_direction: str) -> Any: + if not self._engine: + return None + sym_l = sym.lower() + ex_u = ex_name.upper() + want_long = hold_direction == "long" + try: + for pos in self._engine.get_all_positions(): + ps = (getattr(pos, "symbol", "") or "").lower() + pe = getattr(pos, "exchange", None) + pe_s = str(pe.value if hasattr(pe, "value") else pe or "").upper() + if ps != sym_l or pe_s != ex_u: + continue + vol = int(getattr(pos, "volume", 0) or 0) + if vol <= 0: + continue + is_long = _is_long_direction(getattr(pos, "direction", None)) + if is_long == want_long: + return pos + except Exception as exc: + logger.debug("find position: %s", exc) + return None + + def _resolve_close_offset(self, sym: str, ex_name: str, hold_direction: str, lots: int) -> Any: + from vnpy.trader.constant import Offset + + ex_u = (ex_name or "").upper() + # 上期所/能源中心/郑商所/中金所须区分平今/平昨;大商所等可用通用 CLOSE + if ex_u not in ("CZCE", "CFFEX", "SHFE", "INE"): + return Offset.CLOSE + pos = self._find_position(sym, ex_u, hold_direction) + if not pos: + for p in self._collect_positions(): + ps = (p.get("symbol") or "").lower() + if ps != sym.lower(): + continue + if (p.get("direction") or "long") != hold_direction: + continue + td = int(p.get("td_volume") or 0) + yd = int(p.get("yd_volume") or 0) + if td >= lots: + return Offset.CLOSETODAY + if yd >= lots: + return Offset.CLOSEYESTERDAY + if td + yd >= lots: + return Offset.CLOSETODAY + break + if ex_u in ("SHFE", "INE", "CZCE"): + return Offset.CLOSETODAY + return Offset.CLOSE + vol = int(getattr(pos, "volume", 0) or 0) + yd = int(getattr(pos, "yd_volume", 0) or 0) + today = max(0, vol - yd) + if today >= lots: + return Offset.CLOSETODAY + return Offset.CLOSEYESTERDAY + + def _aggressive_limit_price( + self, + ths_code: str, + sym: str, + ex_name: str, + direction: Any, + tick: float, + fallback: float, + ) -> float: + from vnpy.trader.constant import Direction + + self.subscribe_symbol(ths_code) + lp = fallback + detail = self.get_tick_detail(ths_code, mode=self._connected_mode or "") + if detail.get("price"): + lp = float(detail["price"]) + slip = max(tick, tick * 3) + if direction == Direction.LONG: + lp = lp + slip + else: + lp = max(tick, lp - slip) + return round_to_tick(lp, tick) + + def ping(self) -> bool: + """检测连接是否仍有效;无效则清除 connected 状态。""" + if not self._engine or not self._connected_mode: + return False + if self._td_logged_in(): + return True + try: + if self._engine.get_all_accounts(): + return True + except Exception as exc: + logger.debug("CTP ping failed: %s", exc) + self._connected_mode = None + return False + + def mark_disconnected(self) -> None: + self._connected_mode = None + + def reconnect_after_settings_saved(self, mode: str) -> dict[str, Any]: + """保存前置/账号后关闭旧连接,并用数据库中的新配置重连。""" + from modules.ctp.ctp_settings import is_ctp_auto_connect_enabled + + self._close_gateway() + self._last_error = "" + _persist_last_error("") + if not is_ctp_auto_connect_enabled(): + return {"started": False, "connecting": False, "connected": False, "disabled": True} + return self.start_connect_async(mode, force=True) + + def _schedule_fee_sync(self, mode: str) -> None: + """连接成功后触发每日同步检查(非每次全量)。""" + + def _run() -> None: + time.sleep(45) + try: + from modules.ctp.ctp_fee_worker import try_daily_ctp_fee_sync + + def _gs(key: str, default: str = "") -> str: + from modules.fees.fee_specs import get_setting + return get_setting(key, default) + + def _ss(key: str, val: str) -> None: + from modules.fees.fee_specs import set_setting + set_setting(key, val) + + try_daily_ctp_fee_sync( + mode, + get_setting=_gs, + set_setting=_ss, + force=False, + ) + except Exception as exc: + logger.debug("CTP 手续费连接后检查: %s", exc) + + threading.Thread(target=_run, daemon=True, name="ctp-fee-sync-check").start() + + def _ensure_commission_callback(self) -> None: + if self._commission_hooked or not self._engine: + return + try: + gw = self._engine.get_gateway(GATEWAY_NAME) + td = gw.td_api + except Exception: + return + bridge = self + + def on_rsp(data: dict, error: dict, reqid: int, last: bool) -> None: + if error and int(error.get("ErrorID") or 0) != 0: + logger.debug( + "CTP commission error reqid=%s: %s", + reqid, + error.get("ErrorMsg") or error, + ) + if data and data.get("InstrumentID"): + bridge._commission_lists.setdefault(reqid, []).append(dict(data)) + ev = bridge._commission_waiters.get(reqid) + if last and ev: + ev.set() + + td.onRspQryInstrumentCommissionRate = on_rsp # type: ignore[method-assign] + self._commission_hooked = True + + def _query_commission( + self, + *, + mode: str, + instrument_id: str = "", + exchange_id: str = "", + timeout: float = 8, + ) -> list[dict]: + if self._connected_mode != mode or not self._engine: + return [] + try: + gw = self._engine.get_gateway(GATEWAY_NAME) + td = gw.td_api + except Exception as exc: + logger.debug("commission query init: %s", exc) + return [] + if not getattr(td, "login_status", False): + return [] + if not hasattr(td, "reqQryInstrumentCommissionRate"): + return [] + self._ensure_commission_callback() + reqid = int(getattr(td, "reqid", 0)) + 1 + td.reqid = reqid + ev = threading.Event() + self._commission_waiters[reqid] = ev + req = { + "BrokerID": td.brokerid, + "InvestorID": td.userid, + "InstrumentID": instrument_id or "", + "ExchangeID": exchange_id or "", + } + ret = td.reqQryInstrumentCommissionRate(req, reqid) + if ret != 0: + self._commission_waiters.pop(reqid, None) + return [] + ev.wait(timeout=timeout) + self._commission_waiters.pop(reqid, None) + return self._commission_lists.pop(reqid, []) + + def query_instrument_commission(self, ths_code: str, *, mode: str) -> dict: + """查询单合约 CTP 手续费率(需已连接)。""" + try: + sym, ex_name = ths_to_vnpy_symbol(ths_code) + except Exception: + return {} + rows = self._query_commission( + mode=mode, + instrument_id=sym, + exchange_id=ex_name, + ) + return rows[-1] if rows else {} + + def query_all_commissions(self, *, mode: str) -> list[dict]: + """批量查询全部合约手续费(InstrumentID 留空)。""" + return self._query_commission(mode=mode, timeout=45) + + @staticmethod + def _parse_margin_ratio_row(data: dict) -> dict[str, float]: + long_r = float( + data.get("LongMarginRatioByMoney") + or data.get("LongMarginRatio") + or 0 + ) + short_r = float( + data.get("ShortMarginRatioByMoney") + or data.get("ShortMarginRatio") + or 0 + ) + return {"long": long_r, "short": short_r} + + def _cache_margin_ratio(self, sym: str, data: dict) -> None: + ratios = self._parse_margin_ratio_row(data) + if ratios["long"] <= 0 and ratios["short"] <= 0: + return + key = (sym or "").strip().lower() + if not key: + return + self._instrument_margin_ratios[key] = ratios + + def _ensure_instrument_margin_hooks(self) -> None: + """登录前挂钩:合约/持仓查询回报;td_api 重建后须重新挂钩。""" + if not self._engine: + return + try: + gw = self._engine.get_gateway(GATEWAY_NAME) + td = gw.td_api + except Exception: + return + bridge = self + td_id = id(td) + if td_id != self._hooks_td_api_id: + self._hooks_td_api_id = td_id + self._instrument_hooked = False + self._margin_rate_hooked = False + + if not self._instrument_hooked: + orig_inst = td.onRspQryInstrument + + def on_instrument(data: dict, error: dict, reqid: int, last: bool) -> None: + try: + if data and data.get("InstrumentID"): + bridge._cache_margin_ratio(str(data["InstrumentID"]), data) + except Exception as exc: + logger.debug("instrument margin cache: %s", exc) + if last: + _schedule_after_instruments_ready(bridge) + return orig_inst(data, error, reqid, last) + + td.onRspQryInstrument = on_instrument # type: ignore[method-assign] + + orig_pos = td.onRspQryInvestorPosition + + def on_rsp_position( + data: dict, error: dict, reqid: int, last: bool, + ) -> None: + ret = orig_pos(data, error, reqid, last) + if last: + now = time.monotonic() + if now - bridge._last_position_rsp_ts < 30.0: + return ret + bridge._last_position_rsp_ts = now + + def _after_position_query() -> None: + try: + time.sleep(1.5) + with _ctp_td_lock: + bridge.calibrate_trading_state() + _fire_position_refresh_callback() + except Exception as exc: + logger.debug("position rsp refresh: %s", exc) + + threading.Timer(0.2, _after_position_query).start() + return ret + + td.onRspQryInvestorPosition = on_rsp_position # type: ignore[method-assign] + self._instrument_hooked = True + + if self._margin_rate_hooked: + return + + def on_margin_rate(data: dict, error: dict, reqid: int, last: bool) -> None: + if error and int(error.get("ErrorID") or 0) != 0: + logger.debug( + "CTP margin rate error reqid=%s: %s", + reqid, + error.get("ErrorMsg") or error, + ) + if data and data.get("InstrumentID"): + bridge._margin_rate_lists.setdefault(reqid, []).append(dict(data)) + ev = bridge._margin_rate_waiters.get(reqid) + if last and ev: + ev.set() + + td.onRspQryInstrumentMarginRate = on_margin_rate # type: ignore[method-assign] + self._margin_rate_hooked = True + + def _query_instrument_margin_rate( + self, + *, + mode: str, + instrument_id: str, + exchange_id: str, + timeout: float = 6, + ) -> Optional[dict[str, float]]: + if self._connected_mode != mode or not self._engine: + return None + sym = (instrument_id or "").strip() + if not sym: + return None + cached = self._instrument_margin_ratios.get(sym.lower()) + if cached and (cached.get("long", 0) > 0 or cached.get("short", 0) > 0): + return cached + try: + gw = self._engine.get_gateway(GATEWAY_NAME) + td = gw.td_api + except Exception as exc: + logger.debug("margin rate query init: %s", exc) + return None + if not getattr(td, "login_status", False): + return None + if not hasattr(td, "reqQryInstrumentMarginRate"): + return None + self._ensure_instrument_margin_hooks() + reqid = int(getattr(td, "reqid", 0)) + 1 + td.reqid = reqid + ev = threading.Event() + self._margin_rate_waiters[reqid] = ev + req = { + "BrokerID": td.brokerid, + "InvestorID": td.userid, + "InstrumentID": sym, + "ExchangeID": exchange_id or "", + "InvestorRange": "1", + "HedgeFlag": "1", + } + with _ctp_td_lock: + ret = td.reqQryInstrumentMarginRate(req, reqid) + if ret != 0: + self._margin_rate_waiters.pop(reqid, None) + return None + ev.wait(timeout=timeout) + self._margin_rate_waiters.pop(reqid, None) + rows = self._margin_rate_lists.pop(reqid, []) + if not rows: + return None + ratios = self._parse_margin_ratio_row(rows[-1]) + if ratios["long"] > 0 or ratios["short"] > 0: + self._cache_margin_ratio(sym, rows[-1]) + return ratios + return None + + def _lookup_margin_ratios( + self, + sym: str, + ex_name: str, + *, + mode: Optional[str] = None, + ) -> Optional[dict[str, float]]: + key = (sym or "").strip().lower() + if not key: + return None + cached = self._instrument_margin_ratios.get(key) + if cached and (cached.get("long", 0) > 0 or cached.get("short", 0) > 0): + return cached + if mode and self._connected_mode == mode: + return self._query_instrument_margin_rate( + mode=mode, + instrument_id=sym, + exchange_id=ex_name, + ) + return None + + def _lookup_margin_per_lot(self, sym: str, direction: str) -> float: + return float( + self._margin_per_lot.get(self._position_margin_key(sym, direction), 0) or 0 + ) + + def _margin_from_ratios( + self, + price: float, + mult: float, + ratios: dict[str, float], + *, + direction: str, + ) -> Optional[float]: + long_r = float(ratios.get("long") or 0) + short_r = float(ratios.get("short") or 0) + d = (direction or "long").strip().lower() + if mult <= 0 or price <= 0: + return None + if d == "max": + candidates = [ + round(float(price) * mult * r, 2) + for r in (long_r, short_r) + if r > 0 + ] + return max(candidates) if candidates else None + if d == "short" and short_r > 0: + ratio = short_r + elif d != "short" and long_r > 0: + ratio = long_r + else: + ratio = max(long_r, short_r) + if ratio <= 0: + return None + return round(float(price) * mult * ratio, 2) + + def _tick_key(self, symbol: str, ex_name: str) -> str: + return f"{symbol.lower()}:{ex_name.upper()}" + + def _price_from_tick(self, tick: Any) -> Optional[float]: + for attr in ("last_price", "bid_price_1", "ask_price_1", "pre_close"): + try: + v = float(getattr(tick, attr, 0) or 0) + except (TypeError, ValueError): + v = 0.0 + if v > 0: + return v + return None + + def _lookup_tick(self, symbol: str, ex_name: str) -> Optional[float]: + if not self._engine: + return None + sym_l = symbol.lower() + ex_u = ex_name.upper() + try: + for tick in self._engine.get_all_ticks(): + ts = (getattr(tick, "symbol", "") or "").lower() + te = getattr(tick, "exchange", None) + te_s = str(te.value if hasattr(te, "value") else te or "").upper() + if ts == sym_l and te_s == ex_u: + p = self._price_from_tick(tick) + if p: + return p + except Exception as exc: + logger.debug("lookup tick: %s", exc) + return None + + def _bar_to_dict(self, bar: Any) -> dict: + dt = getattr(bar, "datetime", None) + d_str = dt.strftime("%Y-%m-%d %H:%M:%S") if dt else "" + return { + "d": d_str, + "o": float(getattr(bar, "open_price", 0) or 0), + "h": float(getattr(bar, "high_price", 0) or 0), + "l": float(getattr(bar, "low_price", 0) or 0), + "c": float(getattr(bar, "close_price", 0) or 0), + "v": float(getattr(bar, "volume", 0) or 0), + } + + def _ensure_bar_generator(self, sym: str, ex_name: str) -> None: + key = self._tick_key(sym, ex_name) + if key in self._bar_generators: + return + self._bars_1m[key] = deque(maxlen=4000) + + def on_bar(bar: Any) -> None: + row = self._bar_to_dict(bar) + if row.get("d"): + self._bars_1m[key].append(row) + + try: + from vnpy.trader.utility import BarGenerator + + self._bar_generators[key] = BarGenerator(on_bar=on_bar) + except ImportError: + logger.debug("BarGenerator unavailable") + + def _find_tick(self, symbol: str, ex_name: str) -> Any: + if not self._engine: + return None + sym_l = symbol.lower() + ex_u = ex_name.upper() + try: + for tick in self._engine.get_all_ticks(): + ts = (getattr(tick, "symbol", "") or "").lower() + te = getattr(tick, "exchange", None) + te_s = str(te.value if hasattr(te, "value") else te or "").upper() + if ts == sym_l and te_s == ex_u: + return tick + except Exception as exc: + logger.debug("find tick: %s", exc) + return None + + def _tick_to_bar(self, symbol: str, ex_name: str) -> Optional[dict]: + tick = self._find_tick(symbol, ex_name) + if not tick: + return None + lp = self._price_from_tick(tick) + if not lp or lp <= 0: + return None + dt = getattr(tick, "datetime", None) + d_str = dt.strftime("%Y-%m-%d %H:%M:%S") if dt else "" + if not d_str: + from datetime import datetime + from zoneinfo import ZoneInfo + + d_str = datetime.now(ZoneInfo("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S") + o = float(getattr(tick, "open_price", 0) or lp) + h = float(getattr(tick, "high_price", 0) or lp) + lo = float(getattr(tick, "low_price", 0) or lp) + return { + "d": d_str, + "o": o, + "h": h, + "l": lo, + "c": lp, + "v": float(getattr(tick, "volume", 0) or 0), + } + + def _on_tick(self, tick: Any) -> None: + sym = (getattr(tick, "symbol", "") or "").lower() + te = getattr(tick, "exchange", None) + ex_s = str(te.value if hasattr(te, "value") else te or "").upper() + price = self._price_from_tick(tick) + if price and price > 0: + try: + from modules.ctp.ctp_trading_state import trading_state + + trading_state.set_tick_price(ex_s, sym, price) + except Exception: + pass + fn = _tick_sl_tp_callback + if fn: + try: + fn(ex_s, sym, float(price)) + except Exception as exc: + logger.debug("tick sl/tp callback: %s", exc) + _fire_tick_quote_callback_debounced() + key = self._tick_key(sym, ex_s) + bg = self._bar_generators.get(key) + if not bg: + return + try: + bg.update_tick(tick) + except Exception as exc: + logger.debug("bar gen tick: %s", exc) + + def _ensure_tick_handler(self) -> None: + if self._tick_hooked or not self._ee: + return + try: + from vnpy.trader.event import EVENT_TICK + except ImportError: + return + + def process_tick(event: Any) -> None: + self._on_tick(event.data) + + self._ee.register(EVENT_TICK, process_tick) + self._tick_hooked = True + + def get_kline_bars_1m(self, ths_code: str, *, mode: str) -> list[dict]: + """订阅合约并返回 1 分钟 K 线(含正在形成的 bar)。""" + if self._connected_mode != mode or not self._engine: + return [] + try: + sym, ex_name = ths_to_vnpy_symbol(ths_code) + except Exception: + return [] + key = self._tick_key(sym, ex_name) + self._ensure_bar_generator(sym, ex_name) + self.subscribe_symbol(ths_code) + for _ in range(12): + if self._bars_1m.get(key) and len(self._bars_1m[key]) > 0: + break + if self._lookup_tick(sym, ex_name): + break + time.sleep(0.2) + bars_1m = list(self._bars_1m.get(key, [])) + bg = self._bar_generators.get(key) + if bg and getattr(bg, "bar", None): + forming = self._bar_to_dict(bg.bar) + if forming.get("d"): + if not bars_1m or bars_1m[-1]["d"] != forming["d"]: + bars_1m.append(forming) + else: + bars_1m[-1] = forming + if not bars_1m: + tick_bar = self._tick_to_bar(sym, ex_name) + if tick_bar: + bars_1m = [tick_bar] + return bars_1m + + def get_tick_detail(self, ths_code: str, *, mode: str) -> dict[str, Any]: + if self._connected_mode != mode or not self._engine: + return {} + try: + sym, ex_name = ths_to_vnpy_symbol(ths_code) + except Exception: + return {} + self.subscribe_symbol(ths_code) + for _ in range(8): + tick = self._find_tick(sym, ex_name) + if tick: + price = self._price_from_tick(tick) + try: + pre_close = float(getattr(tick, "pre_close", 0) or 0) + except (TypeError, ValueError): + pre_close = 0.0 + return { + "price": price, + "pre_close": pre_close if pre_close > 0 else None, + } + time.sleep(0.2) + return {} + + def subscribe_symbol(self, ths_code: str) -> None: + if not self._engine or not self._connected_mode: + return + try: + from vnpy.trader.object import SubscribeRequest + + sym, ex_name = ths_to_vnpy_symbol(ths_code) + key = self._tick_key(sym, ex_name) + self._ensure_bar_generator(sym, ex_name) + if key in self._subscribed: + return + exchange = to_vnpy_exchange(ex_name) + self._ensure_tick_handler() + req = SubscribeRequest(symbol=sym, exchange=exchange) + self._engine.subscribe(req, GATEWAY_NAME) + self._subscribed.add(key) + except Exception as exc: + logger.debug("CTP subscribe %s: %s", ths_code, exc) + + def get_tick_price(self, ths_code: str, *, mode: str) -> Optional[float]: + if self._connected_mode != mode or not self._engine: + return None + try: + sym, ex_name = ths_to_vnpy_symbol(ths_code) + except Exception: + return None + price = self._lookup_tick(sym, ex_name) + if price: + return price + self.subscribe_symbol(ths_code) + for _ in range(8): + time.sleep(0.2) + price = self._lookup_tick(sym, ex_name) + if price: + return price + return None + + def get_account(self) -> dict[str, Any]: + if not self._engine: + return {} + accounts = self._engine.get_all_accounts() + if not accounts: + return {} + acc = accounts[0] + return { + "balance": float(getattr(acc, "balance", 0) or 0), + "available": float(getattr(acc, "available", 0) or 0), + "frozen": float(getattr(acc, "frozen", 0) or 0), + "accountid": getattr(acc, "accountid", ""), + } + + def _position_margin_key(self, sym: str, direction: str) -> str: + return f"{(sym or '').lower()}:{(direction or 'long').strip().lower()}" + + def _lookup_position_open_time(self, sym: str, direction: str) -> str: + return (self._position_open_times.get(self._position_margin_key(sym, direction)) or "").strip() + + @staticmethod + def _parse_ctp_open_datetime(date_raw: str, time_raw: str = "") -> str: + """CTP OpenDate + OpenTime → YYYY-MM-DD HH:MM[:SS]。""" + d = (date_raw or "").strip() + if len(d) >= 8 and d[:8].isdigit(): + date_part = f"{d[:4]}-{d[4:6]}-{d[6:8]}" + else: + return "" + t = (time_raw or "").strip().replace(":", "") + if len(t) >= 6 and t[:6].isdigit(): + return f"{date_part} {t[0:2]}:{t[2:4]}:{t[4:6]}" + if len(t) >= 4 and t.isdigit(): + return f"{date_part} {t[0:2]}:{t[2:4]}" + return date_part + + def _parse_ctp_open_date(raw: str) -> str: + return CtpBridge._parse_ctp_open_datetime(raw, "") + + def _install_position_margin_hook(self) -> None: + """已禁用:monkey-patch CTP 持仓回调在并发下会触发 vnctptd 段错误。""" + return + + def _lookup_position_margin(self, sym: str, direction: str) -> float: + return float(self._position_margins.get(self._position_margin_key(sym, direction), 0) or 0) + + @staticmethod + def _vnpy_sym_to_ths(sym: str, ex_name: str) -> str: + import re + + s = (sym or "").strip() + if not s: + return "" + ex = (ex_name or "").upper() + m = re.match(r"^([A-Za-z]+)(\d+)$", s) + if not m: + return s + letters, digits = m.group(1), m.group(2) + if ex == "CZCE": + return letters.upper() + (digits[-3:] if len(digits) >= 4 else digits) + return letters.lower() + digits + + def _get_contract_for_ths(self, ths_code: str) -> Any: + """按同花顺代码查 CTP 合约;精确匹配失败时在同交易所按品种前缀回退。""" + if not self._engine: + return None + try: + sym, ex_name = ths_to_vnpy_symbol(ths_code) + exchange = to_vnpy_exchange(ex_name) + vt_symbol = f"{sym}.{exchange.value}" + contract = self._engine.get_contract(vt_symbol) + if contract: + return contract + m = re.match(r"^([A-Za-z]+)", (ths_code or "").strip()) + if not m: + return None + letters = m.group(1) + ex_val = exchange.value + candidates: list[Any] = [] + get_all = getattr(self._engine, "get_all_contracts", None) + pool = list(get_all()) if callable(get_all) else [] + if not pool: + raw = getattr(self._engine, "contracts", None) + if isinstance(raw, dict): + pool = list(raw.values()) + sym_prefix = sym[: len(letters)] if sym else letters.lower() + sym_prefix_up = letters.upper() + for c in pool: + c_ex = getattr(c, "exchange", None) + c_ex_val = str(c_ex.value if hasattr(c_ex, "value") else c_ex or "") + if c_ex_val != ex_val: + continue + c_sym = str(getattr(c, "symbol", "") or "") + if ( + c_sym.lower().startswith(sym_prefix.lower()) + or c_sym.upper().startswith(sym_prefix_up) + ): + candidates.append(c) + if not candidates: + return None + candidates.sort(key=lambda c: str(getattr(c, "symbol", "") or "")) + return candidates[0] + except Exception as exc: + logger.debug("_get_contract_for_ths %s: %s", ths_code, exc) + return None + + def estimate_margin_one_lot( + self, + ths_code: str, + price: float, + *, + direction: str = "long", + ) -> Optional[float]: + """1 手保证金:持仓实收 > CTP 保证金率查询 > 合约查询缓存。""" + if not self._engine or not price or price <= 0: + return None + try: + sym, ex_name = ths_to_vnpy_symbol(ths_code) + contract = self._get_contract_for_ths(ths_code) + mult = float(getattr(contract, "size", 0) or 0) if contract else 0.0 + if mult <= 0: + mult = float(get_contract_spec(ths_code).get("mult") or 0) + d = (direction or "long").strip().lower() + if d == "max": + per_lots = [ + self._lookup_margin_per_lot(sym, side) + for side in ("long", "short") + ] + per_lots = [x for x in per_lots if x > 0] + if per_lots: + return max(per_lots) + else: + per_lot = self._lookup_margin_per_lot(sym, d) + if per_lot > 0: + return per_lot + mode = self._connected_mode + ratios = self._lookup_margin_ratios(sym, ex_name, mode=mode) + if ratios: + return self._margin_from_ratios( + price, mult, ratios, direction=d, + ) + return None + except Exception as exc: + logger.debug("estimate_margin_one_lot %s: %s", ths_code, exc) + return None + + def estimate_position_margin( + self, + sym: str, + ex_name: str, + direction: str, + lots: int, + price: float, + *, + pos: Any = None, + ) -> Optional[float]: + """持仓占用保证金:优先 vnpy 字段,其次 CTP 合约保证金率估算。""" + if lots <= 0 or price <= 0: + return None + if pos is not None: + raw = float(getattr(pos, "margin", 0) or getattr(pos, "use_margin", 0) or 0) + if raw > 0: + return round(raw, 2) + cached = self._lookup_position_margin(sym, direction) + if cached > 0: + return round(cached, 2) + ths = self._vnpy_sym_to_ths(sym, ex_name) + if not ths: + return None + per_lot = self.estimate_margin_one_lot(ths, price, direction=direction) + if per_lot and per_lot > 0: + return round(per_lot * lots, 2) + return None + + def lookup_contract_spec(self, ths_code: str) -> Optional[dict]: + """从 CTP 合约信息读取乘数与最小变动价位。""" + if not self._engine: + return None + try: + sym, ex_name = ths_to_vnpy_symbol(ths_code) + contract = self._get_contract_for_ths(ths_code) + if not contract: + return None + mult = float(getattr(contract, "size", 0) or 0) + tick = float( + getattr(contract, "pricetick", 0) + or getattr(contract, "price_tick", 0) + or 0 + ) + if mult <= 0: + return None + out: dict[str, Any] = {"mult": mult} + if tick > 0: + out["tick_size"] = tick + long_r = float(getattr(contract, "long_margin_ratio", 0) or 0) + short_r = float(getattr(contract, "short_margin_ratio", 0) or 0) + c_sym = str(getattr(contract, "symbol", "") or sym or "") + if c_sym and self._connected_mode: + queried = self._lookup_margin_ratios( + c_sym, ex_name, mode=self._connected_mode, + ) + if queried: + long_r = float(queried.get("long") or long_r) + short_r = float(queried.get("short") or short_r) + if long_r > 0 or short_r > 0: + out["margin_rate"] = max(long_r, short_r) + return out + except Exception as exc: + logger.debug("lookup_contract_spec %s: %s", ths_code, exc) + return None + + def _collect_positions(self) -> list[dict[str, Any]]: + if not self._engine: + return [] + out: list[dict[str, Any]] = [] + for pos in self._engine.get_all_positions(): + vol = int(getattr(pos, "volume", 0) or 0) + if vol <= 0: + continue + d = "long" if _is_long_direction(getattr(pos, "direction", None)) else "short" + sym = getattr(pos, "symbol", "") or "" + exchange = getattr(pos, "exchange", None) + ex_name = str(exchange.value if hasattr(exchange, "value") else exchange or "") + price = float(getattr(pos, "price", 0) or 0) + margin = self.estimate_position_margin( + sym, ex_name, d, vol, price, pos=pos, + ) + open_time = self._lookup_position_open_time(sym, d) or None + yd = int(getattr(pos, "yd_volume", 0) or 0) + td = max(0, vol - yd) + out.append({ + "symbol": sym, + "exchange": ex_name, + "direction": d, + "lots": vol, + "avg_price": price, + "pnl": float(getattr(pos, "pnl", 0) or 0), + "frozen": int(getattr(pos, "frozen", 0) or 0), + "margin": margin, + "open_time": open_time, + "yd_volume": yd, + "td_volume": td, + }) + return out + + def refresh_positions(self) -> None: + """vnpy 内存缓存持仓;禁止 query_position(vnctptd 并发查询会段错误)。""" + return + + def _has_live_positions(self) -> bool: + if not self._engine: + return False + try: + with _ctp_td_lock: + return len(self._collect_positions()) > 0 + except Exception: + return False + + def request_position_snapshot(self, *, force: bool = False) -> None: + """合约加载后查询持仓,填充 vnpy 内存(已有持仓时跳过主动查询)。""" + if not self._engine or not self._connected_mode: + return + if not force and self._has_live_positions(): + return + now = time.monotonic() + if not force and (now - self._last_position_query_ts) < POSITION_QUERY_MIN_INTERVAL_SEC: + return + try: + self._ensure_instrument_margin_hooks() + gw = self._engine.get_gateway(GATEWAY_NAME) + td = getattr(gw, "td_api", None) if gw else None + if not td or not getattr(td, "login_status", False): + logger.debug("CTP 持仓查询跳过:交易未登录") + return + if hasattr(td, "reqQryInvestorPosition"): + reqid = int(getattr(td, "reqid", 0)) + 1 + td.reqid = reqid + req = { + "BrokerID": getattr(td, "brokerid", ""), + "InvestorID": getattr(td, "userid", ""), + } + with _ctp_td_lock: + ret = td.reqQryInvestorPosition(req, reqid) + if ret == 0: + self._last_position_query_ts = now + logger.info("CTP 已请求持仓查询 reqid=%s", reqid) + else: + logger.debug("CTP 持仓查询发送失败 ret=%s", ret) + elif gw and hasattr(gw, "query_position"): + gw.query_position() + self._last_position_query_ts = now + logger.info("CTP 已请求持仓查询(gateway)") + except Exception as exc: + logger.debug("request_position_snapshot: %s", exc) + + def list_positions(self, *, refresh_if_empty: bool = True, refresh_margin: bool = False) -> list[dict[str, Any]]: + del refresh_if_empty, refresh_margin + with _ctp_td_lock: + return self._collect_positions() + + @staticmethod + def _parse_trade_offset(offset_obj: Any) -> str: + s = str(offset_obj or "").upper() + if "OPEN" in s: + return "open" + return "close" + + @staticmethod + def _parse_trade_direction(direction_obj: Any) -> str: + return "long" if _is_long_direction(direction_obj) else "short" + + @staticmethod + def _position_direction_from_trade(trade_direction: str, offset: str) -> str: + td = (trade_direction or "long").strip().lower() + if (offset or "open").strip().lower() == "open": + return td + return "short" if td == "long" else "long" + + def _format_trade_datetime(self, dt_obj: Any, date_raw: str = "", time_raw: str = "") -> str: + if dt_obj is not None: + try: + if hasattr(dt_obj, "strftime"): + return dt_obj.strftime("%Y-%m-%d %H:%M:%S") + text = str(dt_obj).strip() + if text: + return text[:19].replace("T", " ") + except Exception: + pass + parsed = self._parse_ctp_open_datetime(date_raw, time_raw) + return parsed or "" + + def _trade_row_from_vnpy(self, trade: Any) -> Optional[dict[str, Any]]: + try: + sym = (getattr(trade, "symbol", "") or "").strip() + vol = int(getattr(trade, "volume", 0) or 0) + if not sym or vol <= 0: + return None + direction = self._parse_trade_direction(getattr(trade, "direction", None)) + offset = self._parse_trade_offset(getattr(trade, "offset", None)) + exchange = getattr(trade, "exchange", None) + ex_name = str(exchange.value if hasattr(exchange, "value") else exchange or "") + dt = self._format_trade_datetime(getattr(trade, "datetime", None)) + trade_id = str(getattr(trade, "tradeid", "") or getattr(trade, "vt_tradeid", "") or "") + order_id = str(getattr(trade, "orderid", "") or getattr(trade, "vt_orderid", "") or "") + if not trade_id: + trade_id = f"{order_id}:{sym}:{offset}:{direction}:{vol}:{getattr(trade, 'price', 0)}:{dt}" + return { + "trade_id": trade_id, + "order_id": order_id, + "symbol": sym, + "exchange": ex_name, + "direction": direction, + "offset": offset, + "position_direction": self._position_direction_from_trade(direction, offset), + "lots": vol, + "price": float(getattr(trade, "price", 0) or 0), + "datetime": dt, + "commission": round(float(getattr(trade, "commission", 0) or 0), 2), + } + except Exception as exc: + logger.debug("trade_row_from_vnpy: %s", exc) + return None + + def _trade_row_from_ctp_dict(self, data: dict) -> Optional[dict[str, Any]]: + try: + sym = (data.get("InstrumentID") or data.get("instrument_id") or "").strip() + vol = int(float(data.get("Volume") or data.get("volume") or 0)) + if not sym or vol <= 0: + return None + dir_raw = str(data.get("Direction") or data.get("direction") or "") + direction = "long" if dir_raw in ("0", "2") or "LONG" in dir_raw.upper() or dir_raw == "多" else "short" + off_raw = str(data.get("OffsetFlag") or data.get("offset") or "") + if off_raw in ("0",) or "OPEN" in off_raw.upper(): + offset = "open" + else: + offset = "close" + price = float(data.get("Price") or data.get("price") or 0) + trade_id = str(data.get("TradeID") or data.get("tradeid") or "").strip() + order_sys = str(data.get("OrderSysID") or data.get("orderid") or "").strip() + dt = self._format_trade_datetime( + None, + str(data.get("TradeDate") or data.get("trade_date") or ""), + str(data.get("TradeTime") or data.get("trade_time") or ""), + ) + if not trade_id: + trade_id = f"{order_sys}:{sym}:{offset}:{direction}:{vol}:{price}:{dt}" + return { + "trade_id": trade_id, + "order_id": order_sys, + "symbol": sym, + "exchange": str(data.get("ExchangeID") or data.get("exchange") or ""), + "direction": direction, + "offset": offset, + "position_direction": self._position_direction_from_trade(direction, offset), + "lots": vol, + "price": price, + "datetime": dt, + "commission": round( + float(data.get("Commission") or data.get("commission") or 0), 2, + ), + } + except Exception as exc: + logger.debug("trade_row_from_ctp_dict: %s", exc) + return None + + def _install_trade_query_hook(self) -> None: + """不再 monkey-patch CTP 成交回调(易与并发查询冲突导致 vnctptd 段错误)。""" + return + + @staticmethod + def _engine_collection_items(raw: Any) -> list[Any]: + """vnpy 不同版本可能返回 dict 或 list。""" + if raw is None: + return [] + if isinstance(raw, dict): + return list(raw.values()) + if isinstance(raw, (list, tuple)): + return list(raw) + return [raw] + + def _collect_engine_trades(self) -> list[dict[str, Any]]: + if not self._engine: + return [] + out: list[dict[str, Any]] = [] + seen: set[str] = set() + try: + trades = self._engine.get_all_trades() + except Exception: + trades = None + for trade in self._engine_collection_items(trades): + row = self._trade_row_from_vnpy(trade) + if not row: + continue + key = row["trade_id"] + if key in seen: + continue + seen.add(key) + out.append(row) + return out + + def refresh_trades(self) -> None: + """成交仅读 vnpy 内存回报;不调用 query_trade(避免 CTP 段错误)。""" + return + + def list_trades(self, *, refresh: bool = False) -> list[dict[str, Any]]: + with _ctp_td_lock: + out = self._collect_engine_trades() + out.sort(key=lambda r: (r.get("datetime") or "", r.get("trade_id") or "")) + return out + + def list_active_orders(self) -> list[dict[str, Any]]: + if not self._engine: + return [] + out: list[dict[str, Any]] = [] + try: + orders = self._engine.get_all_active_orders() + except Exception: + return [] + for order in orders or []: + status = getattr(order, "status", None) + status_s = str(status) + if status_s and not any(x in status_s for x in ("NOTTRADED", "PARTTRADED", "SUBMITTING")): + continue + vol = int(getattr(order, "volume", 0) or 0) + traded = int(getattr(order, "traded", 0) or 0) + remain = max(0, vol - traded) + if remain <= 0: + continue + direction = getattr(order, "direction", None) + d = "long" + if direction is not None and str(direction).endswith("SHORT"): + d = "short" + offset = getattr(order, "offset", None) + offset_s = str(offset or "") + sym = getattr(order, "symbol", "") or "" + exchange = getattr(order, "exchange", None) + ex_name = str(exchange.value if hasattr(exchange, "value") else exchange or "") + vt_oid = str(getattr(order, "vt_orderid", "") or "") + order_id = str(getattr(order, "orderid", "") or "") + out.append({ + "symbol": sym, + "exchange": ex_name, + "direction": d, + "lots": remain, + "price": float(getattr(order, "price", 0) or 0), + "offset": offset_s, + "order_id": vt_oid or order_id, + "vt_order_id": vt_oid, + "status": status_s, + }) + return out + + def send_order( + self, + *, + ths_code: str, + offset: str, + direction: str, + lots: int, + price: float, + order_type: str = "limit", + ) -> str: + from vnpy.trader.constant import Direction, Offset, OrderType + from vnpy.trader.object import OrderRequest + + if not self._engine: + raise RuntimeError("CTP 未初始化") + if not self._td_logged_in(): + raise RuntimeError("CTP 交易通道未登录,请重连后再下单") + + sym, ex_name = ths_to_vnpy_symbol(ths_code) + exchange = to_vnpy_exchange(ex_name) + lots = max(1, int(lots)) + tick = float(get_contract_spec(ths_code).get("tick_size") or 1.0) + + offset = (offset or "open").lower() + direction = (direction or "long").lower() + + if offset in ("open", "open_long", "open_short"): + d = Direction.LONG if direction == "long" or offset == "open_long" else Direction.SHORT + off = Offset.OPEN + elif offset in ("close", "close_long", "close_short"): + hold = "long" if direction == "long" or offset == "close_long" else "short" + if hold == "long": + d = Direction.SHORT + else: + d = Direction.LONG + off = self._resolve_close_offset(sym, ex_name, hold, lots) + else: + raise ValueError(f"未知开平: {offset}") + + use_market = (order_type or "limit").lower() == "market" + if use_market: + ot = OrderType.FAK + price = self._aggressive_limit_price(ths_code, sym, ex_name, d, tick, price) + else: + ot = OrderType.LIMIT + price = round_to_tick(float(price), tick) + if price <= 0: + raise ValueError("委托价格无效,请检查行情或手动填写价格") + + req = OrderRequest( + symbol=sym, + exchange=exchange, + direction=d, + type=ot, + volume=lots, + price=price, + offset=off, + ) + logger.info( + "CTP 报单 %s %s %s %s手 @%s offset=%s type=%s", + sym, ex_name, d, lots, price, off, ot, + ) + with _ctp_td_lock: + vt_orderid = self._engine.send_order(req, GATEWAY_NAME) + if not vt_orderid: + raise RuntimeError("CTP 拒单或未返回委托号(请检查合约代码、价格是否为最小变动价位整数倍)") + return str(vt_orderid) + + def cancel_order(self, vt_orderid: str) -> bool: + if not self._engine or not vt_orderid: + return False + try: + with _ctp_td_lock: + order = self._engine.get_order(vt_orderid) + if order is None: + return False + req = order.create_cancel_request() + self._engine.cancel_order(req, GATEWAY_NAME) + logger.info("CTP 撤单 %s", vt_orderid) + return True + except Exception as exc: + logger.warning("CTP 撤单失败 %s: %s", vt_orderid, exc) + return False + + +class CtpBridgeProxy: + """Client-side stand-in for CtpBridge, forwarding calls to qihuo-ctp.""" + + _engine = None + + @property + def connected_mode(self) -> Optional[str]: + st = ctp_ipc_client.health().get("status") or {} + return st.get("connected_mode") + + @property + def last_error(self) -> str: + st = ctp_ipc_client.health().get("status") or {} + return str(st.get("last_error") or "") + + @property + def _last_connect_ok_ts(self) -> float: + st = ctp_ipc_client.health().get("status") or {} + try: + return float(st.get("last_connect_ok_ts") or 0) + except (TypeError, ValueError): + return 0.0 + + def available(self) -> bool: + return bool(ctp_ipc_client.health().get("worker_online")) + + def status(self, mode: str) -> dict[str, Any]: + return ctp_ipc_client.status(mode) + + def ping(self) -> bool: + return bool(ctp_ipc_client.health().get("worker_online")) + + def connect(self, mode: str, *, force: bool = False) -> dict[str, Any]: + return ctp_ipc_client.connect(mode, force=force) + + def start_connect_async( + self, + mode: str, + *, + force: bool = False, + scheduled: bool = False, + ) -> dict[str, Any]: + return ctp_ipc_client.start_connect(mode, force=force, scheduled=scheduled) + + def connect_in_progress(self) -> bool: + data = ctp_ipc_client.bridge_action("connect_in_progress") + return bool(data.get("result")) + + def login_cooldown_remaining(self) -> int: + st = ctp_ipc_client.health().get("status") or {} + try: + return int(st.get("login_cooldown_sec") or 0) + except (TypeError, ValueError): + return 0 + + def ensure_connected(self, mode: str) -> None: + if not self.status(mode).get("connected"): + raise RuntimeError("CTP worker 未连接,请重连后再操作") + + def require_connected(self, mode: str) -> None: + self.ensure_connected(mode) + + def get_account(self) -> dict[str, Any]: + mode = self.connected_mode or "simulation" + return ctp_ipc_client.account(mode) + + def list_positions( + self, + *, + refresh_if_empty: bool = True, + refresh_margin: bool = False, + ) -> list[dict[str, Any]]: + mode = self.connected_mode or "simulation" + return ctp_ipc_client.positions( + mode, + refresh_if_empty=refresh_if_empty, + refresh_margin=refresh_margin, + ) + + def list_active_orders(self) -> list[dict[str, Any]]: + mode = self.connected_mode or "simulation" + return ctp_ipc_client.active_orders(mode) + + def list_trades(self, *, refresh: bool = False) -> list[dict[str, Any]]: + mode = self.connected_mode or "simulation" + return ctp_ipc_client.trades(mode, refresh=refresh) + + def get_tick_price(self, ths_code: str, *, mode: str = "") -> Optional[float]: + return ctp_ipc_client.tick_price(mode or self.connected_mode or "simulation", ths_code) + + def get_tick_detail(self, ths_code: str, *, mode: str = "") -> dict[str, Any]: + return ctp_ipc_client.tick_detail(mode or self.connected_mode or "simulation", ths_code) + + def estimate_margin_one_lot( + self, + ths_code: str, + price: float, + *, + direction: str = "long", + ) -> Optional[float]: + return ctp_ipc_client.estimate_margin_one_lot( + self.connected_mode or "simulation", + ths_code, + price, + direction=direction, + ) + + def lookup_contract_spec(self, ths_code: str) -> Optional[dict]: + return ctp_ipc_client.contract_spec(self.connected_mode or "simulation", ths_code) + + def send_order(self, **payload: Any) -> str: + data = ctp_ipc_client.send_order(payload) + return str(data.get("order_id") or "") + + def cancel_order(self, vt_orderid: str) -> bool: + return ctp_ipc_client.cancel_order(self.connected_mode or "simulation", vt_orderid) + + def calibrate_trading_state(self) -> Any: + return ctp_ipc_client.bridge_action("calibrate_trading_state").get("result") + + def request_position_snapshot(self, *, force: bool = False) -> Any: + return ctp_ipc_client.bridge_action( + "request_position_snapshot", + {"force": bool(force)}, + ).get("result") + + def subscribe_symbol(self, symbol: str) -> Any: + return ctp_ipc_client.bridge_action("subscribe_symbol", {"symbol": symbol}).get("result") + + def refresh_positions(self) -> Any: + return ctp_ipc_client.bridge_action("refresh_positions").get("result") + + def reconnect_after_settings_saved(self, mode: str) -> Any: + return ctp_ipc_client.bridge_action( + "reconnect_after_settings_saved", + {"mode": mode}, + ).get("result") + + def query_all_commissions(self, *, mode: str = "") -> list[dict]: + data = ctp_ipc_client.bridge_action("query_all_commissions", {"mode": mode}) + return list(data.get("result") or []) + + def query_instrument_commission(self, symbol: str, *, mode: str = "") -> dict: + data = ctp_ipc_client.bridge_action( + "query_instrument_commission", + {"symbol": symbol, "mode": mode or self.connected_mode or "simulation"}, + ) + return dict(data.get("result") or {}) + + def get_kline_bars_1m(self, ths_code: str, *, mode: str) -> list[dict]: + data = ctp_ipc_client.bridge_action( + "get_kline_bars_1m", + {"symbol": ths_code, "mode": mode}, + ) + return list(data.get("result") or []) + + def _close_gateway(self) -> None: + ctp_ipc_client.disconnect() + + +def get_bridge(): + global _bridge + if _use_ctp_worker_client(): + return CtpBridgeProxy() + with _bridge_lock: + if _bridge is None: + _bridge = CtpBridge() + return _bridge + + +def try_init_vnpy(_settings: dict | None = None) -> bool: + if _use_ctp_worker_client(): + return bool(ctp_ipc_client.health().get("worker_online")) + return get_bridge().available() + + +def vnpy_available() -> bool: + if _use_ctp_worker_client(): + return bool(ctp_ipc_client.health().get("worker_online")) + return get_bridge().available() + + +def _ctp_connect_permitted(*, scheduled: bool = False) -> bool: + """scheduled=True:盘前/交易时段计划连接,不受「自动连接」开关限制。""" + from modules.ctp.ctp_settings import is_ctp_auto_connect_enabled + + if is_ctp_auto_connect_enabled(): + return True + if not scheduled: + return False + from modules.ctp.ctp_premarket_connect import should_auto_connect_now + + return should_auto_connect_now() + + +def ctp_disconnect(*, set_disabled_hint: bool = False) -> None: + """主动断开 CTP 并清理内存状态。""" + if _use_ctp_worker_client(): + ctp_ipc_client.disconnect(set_disabled_hint=set_disabled_hint) + return + from modules.ctp.ctp_settings import CTP_DISABLED_HINT + + b = get_bridge() + b._close_gateway() + if set_disabled_hint: + b._last_error = CTP_DISABLED_HINT + _persist_last_error(CTP_DISABLED_HINT) + else: + b._last_error = "" + _persist_last_error("") + + +def ctp_connect(mode: str, *, force: bool = False) -> dict[str, Any]: + if _use_ctp_worker_client(): + return ctp_ipc_client.connect(mode, force=force) + b = get_bridge() + b.connect(mode, force=force) + return b.status(mode) + + +def ctp_start_connect(mode: str, *, force: bool = False, scheduled: bool = False) -> dict[str, Any]: + """非阻塞发起连接,供 Web API 使用。""" + if _use_ctp_worker_client(): + return ctp_ipc_client.start_connect(mode, force=force, scheduled=scheduled) + b = get_bridge() + info = b.start_connect_async(mode, force=force, scheduled=scheduled) + st = b.status(mode) + return {**info, "status": st} + + +def ctp_try_auto_reconnect(mode: str) -> bool: + """断线时静默异步重连;已连接且交易通道正常则不再重复 connect。""" + if _use_ctp_worker_client(): + info = ctp_ipc_client.start_connect(mode, force=False, scheduled=True) + return bool( + info.get("connected") + or info.get("connecting") + or info.get("started") + ) + if not _ctp_connect_permitted(scheduled=True): + return False + b = get_bridge() + if not b.available(): + return False + if b.connect_in_progress(): + return False + if b.login_cooldown_remaining() > 0: + return False + st = _setting_for_mode(mode) + if not st.get("用户名") or not st.get("密码") or not st.get("交易服务器"): + return False + if b.connected_mode == mode: + if b._td_logged_in() or b.ping(): + return True + recent = time.time() - float(getattr(b, "_last_connect_ok_ts", 0) or 0) + if recent < 120: + logger.debug("CTP 跳过自动重连:刚连接 %.0fs", recent) + return True + td = st.get("交易服务器", "") + ok, err = probe_tcp_address(td, timeout=4.0) + if not ok: + b._last_error = ( + f"SimNow 交易前置不可达:{td}({err})。" + "请更新 SIMNOW_TD_ADDRESS 并确认服务器出网。" + ) + return False + info = b.start_connect_async(mode, force=False, scheduled=True) + return bool( + info.get("connected") + or info.get("connecting") + or info.get("started") + ) + + +def ctp_status(mode: str) -> dict[str, Any]: + from modules.ctp.ctp_settings import CTP_DISABLED_HINT, is_ctp_auto_connect_enabled + + if _use_ctp_worker_client(): + st = ctp_ipc_client.status(mode) + st["auto_connect_enabled"] = is_ctp_auto_connect_enabled() + return st + auto = is_ctp_auto_connect_enabled() + st = get_bridge().status(mode) + st["auto_connect_enabled"] = auto + if not auto: + st["disabled_hint"] = CTP_DISABLED_HINT + if not st.get("connected") and not st.get("connecting"): + st["last_error"] = "" + st["td_reachable"] = None + return st + if not st.get("connected") and not st.get("connecting"): + setting = _setting_for_mode(mode) + td = setting.get("交易服务器", "") + if td: + ok, err = probe_tcp_address(td, timeout=3.0) + st["td_reachable"] = ok + if not ok and not st.get("last_error"): + st["last_error"] = ( + f"SimNow 交易前置不可达:{td}({err})" + ) + return st + + +def ctp_get_account(mode: str) -> dict[str, Any]: + if _use_ctp_worker_client(): + return ctp_ipc_client.account(mode) + b = get_bridge() + b.ensure_connected(mode) + return b.get_account() + + +def ctp_sum_position_margins( + mode: str, + *, + refresh_if_empty: bool = True, + refresh_margin: bool = False, +) -> float: + """各持仓 CTP 回报保证金之和(与柜台「实收保证金」一致)。""" + total = 0.0 + for p in ctp_list_positions( + mode, refresh_if_empty=refresh_if_empty, refresh_margin=refresh_margin, + ): + m = float(p.get("margin") or 0) + if m > 0: + total += m + return round(total, 2) if total > 0 else 0.0 + + +def ctp_account_margin_used(mode: str) -> Optional[float]: + """账户实际占用保证金 ≈ 权益 − 可用(与顶栏柜台资金一致)。""" + if _use_ctp_worker_client(): + try: + acc = ctp_ipc_client.account(mode) + balance = float(acc.get("balance") or 0) + available = float(acc.get("available") or 0) + if balance <= 0: + return None + used = balance - available + return round(used, 2) if used > 0 else None + except Exception as exc: + logger.debug("ctp_account_margin_used ipc: %s", exc) + return None + b = get_bridge() + if b.connected_mode != mode or not b.ping(): + return None + try: + acc = b.get_account() + balance = float(acc.get("balance") or 0) + available = float(acc.get("available") or 0) + if balance <= 0: + return None + used = balance - available + return round(used, 2) if used > 0 else None + except Exception as exc: + logger.debug("ctp_account_margin_used: %s", exc) + return None + + +def ctp_list_positions( + mode: str, + *, + refresh_if_empty: bool = True, + refresh_margin: bool = False, +) -> list[dict[str, Any]]: + if _use_ctp_worker_client(): + return ctp_ipc_client.positions( + mode, + refresh_if_empty=refresh_if_empty, + refresh_margin=refresh_margin, + ) + b = get_bridge() + if b.connected_mode != mode or not b.ping(): + return [] + return b.list_positions(refresh_if_empty=refresh_if_empty, refresh_margin=refresh_margin) + + +def ctp_list_active_orders(mode: str) -> list[dict[str, Any]]: + if _use_ctp_worker_client(): + return ctp_ipc_client.active_orders(mode) + b = get_bridge() + b.ensure_connected(mode) + return b.list_active_orders() + + +def ctp_cancel_order(mode: str, vt_orderid: str) -> bool: + if _use_ctp_worker_client(): + return ctp_ipc_client.cancel_order(mode, vt_orderid) + b = get_bridge() + b.ensure_connected(mode) + return b.cancel_order(vt_orderid) + + +def ctp_list_trades(mode: str, *, refresh: bool = False) -> list[dict[str, Any]]: + if _use_ctp_worker_client(): + return ctp_ipc_client.trades(mode, refresh=refresh) + b = get_bridge() + if b.connected_mode != mode or not b.ping(): + return [] + return b.list_trades(refresh=refresh) + + +def ctp_get_tick_price(mode: str, ths_code: str) -> Optional[float]: + """CTP 柜台最新价(需已连接并订阅)。""" + if _use_ctp_worker_client(): + return ctp_ipc_client.tick_price(mode, ths_code) + b = get_bridge() + if b.connected_mode != mode: + return None + try: + return b.get_tick_price(ths_code, mode=mode) + except Exception as exc: + logger.debug("ctp_get_tick_price: %s", exc) + return None + + +def ctp_get_tick_detail(mode: str, ths_code: str) -> dict[str, Any]: + if _use_ctp_worker_client(): + return ctp_ipc_client.tick_detail(mode, ths_code) + b = get_bridge() + if b.connected_mode != mode: + return {} + try: + return b.get_tick_detail(ths_code, mode=mode) + except Exception as exc: + logger.debug("ctp_get_tick_detail: %s", exc) + return {} + + +def ctp_estimate_margin_one_lot( + mode: str, + ths_code: str, + price: float, + *, + direction: str = "long", +) -> Optional[float]: + if _use_ctp_worker_client(): + return ctp_ipc_client.estimate_margin_one_lot( + mode, + ths_code, + price, + direction=direction, + ) + b = get_bridge() + if b.connected_mode != mode or not b.ping(): + return None + try: + return b.estimate_margin_one_lot(ths_code, price, direction=direction) + except Exception as exc: + logger.debug("ctp_estimate_margin_one_lot: %s", exc) + return None + + +def ctp_lookup_contract_spec(mode: str, ths_code: str) -> Optional[dict]: + if _use_ctp_worker_client(): + return ctp_ipc_client.contract_spec(mode, ths_code) + b = get_bridge() + if b.connected_mode != mode or not b.ping(): + return None + try: + return b.lookup_contract_spec(ths_code) + except Exception as exc: + logger.debug("ctp_lookup_contract_spec: %s", exc) + return None + + +def get_ctp_balance(mode: str) -> Optional[float]: + try: + acc = ctp_get_account(mode) + bal = acc.get("balance") + return float(bal) if bal else None + except Exception as exc: + logger.debug("get_ctp_balance: %s", exc) + return None + + +def execute_order( + conn, + *, + mode: str, + offset: str, + symbol: str, + direction: str, + lots: int, + price: float, + settings: dict | None = None, + order_type: str = "limit", +) -> dict[str, Any]: + """统一下单:simulation=SimNow,live=期货公司 CTP。""" + if _use_ctp_worker_client(): + return ctp_ipc_client.send_order({ + "mode": mode, + "offset": offset, + "symbol": symbol, + "direction": direction, + "lots": lots, + "price": price, + "settings": settings or {}, + "order_type": order_type, + }) + del conn, settings + if mode not in ("simulation", "live"): + raise ValueError("未知交易模式") + if not vnpy_available(): + raise ValueError( + "请先安装 vnpy 与 vnpy_ctp:pip install vnpy vnpy_ctp\n" + f"模拟盘需配置 .env 中 SIMNOW_USER / SIMNOW_PASSWORD 等" + ) + b = get_bridge() + b.require_connected(mode) + order_id = b.send_order( + ths_code=symbol, + offset=offset, + direction=direction, + lots=lots, + price=price, + order_type=order_type, + ) + return { + "order_id": order_id, + "mode": mode, + "mode_label": _mode_label(mode), + "symbol": symbol, + "lots": lots, + "price": price, + } diff --git a/modules/fees/__init__.py b/modules/fees/__init__.py new file mode 100644 index 0000000..4ea89fc --- /dev/null +++ b/modules/fees/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) 2025-2026 马建军. All rights reserved. + +from modules.fees.routes import register + +__all__ = ["register"] diff --git a/fee_specs.py b/modules/fees/fee_specs.py similarity index 95% rename from fee_specs.py rename to modules/fees/fee_specs.py index ea0cf42..5c06802 100644 --- a/fee_specs.py +++ b/modules/fees/fee_specs.py @@ -1,385 +1,385 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""期货手续费:仅 CTP 柜台同步入库,前端只读展示。""" -import json -import os -import re -from datetime import datetime -from typing import Optional - -from contract_specs import get_contract_spec - -from db_conn import connect_db, is_benign_migration_error - -DB_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "futures.db") -DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data") -DEFAULT_JSON = os.path.join(DATA_DIR, "fee_rates.json") - -# 无配置时的兜底(已为交易所标准约 2 倍) -DEFAULT_FEE = { - "open_fixed": 2.0, - "open_ratio": 0.0, - "close_yesterday_fixed": 2.0, - "close_yesterday_ratio": 0.0, - "close_today_fixed": 4.0, - "close_today_ratio": 0.0, -} - -_INDEX_PRODUCTS = {"if", "ih", "ic", "im"} - - -def product_from_code(ths_code: str) -> str: - code = (ths_code or "").strip() - m = re.match(r"^([A-Za-z]+)", code) - return m.group(1).lower() if m else "" - - -def _get_db(): - return connect_db() - - -def ensure_fee_rates_schema(conn=None) -> None: - """补齐 fee_rates 表结构(旧库可能缺少 source 列)。""" - close = False - if conn is None: - conn = _get_db() - close = True - try: - for sql in ( - "ALTER TABLE fee_rates ADD COLUMN source TEXT DEFAULT 'local'", - ): - try: - conn.execute(sql) - except Exception as exc: - if not is_benign_migration_error(exc): - raise - conn.commit() - finally: - if close: - conn.close() - - -def get_setting(key: str, default: str = "") -> str: - conn = _get_db() - row = conn.execute("SELECT value FROM settings WHERE key=?", (key,)).fetchone() - conn.close() - if not row: - return default - return (row["value"] or default) if row["value"] is not None else default - - -def set_setting(key: str, value: str) -> None: - conn = _get_db() - conn.execute( - """INSERT INTO settings (key, value) VALUES (?,?) - ON CONFLICT(key) DO UPDATE SET value=excluded.value""", - (key, value), - ) - conn.commit() - conn.close() - - -def get_fee_multiplier() -> float: - conn = _get_db() - row = conn.execute( - "SELECT value FROM settings WHERE key='fee_multiplier'" - ).fetchone() - conn.close() - if row and row["value"]: - try: - return max(0.0, float(row["value"])) - except ValueError: - pass - return 2.0 - - -def get_fee_source_mode() -> str: - """固定 CTP 柜台。""" - return "ctp" - - -def purge_non_ctp_fee_rates() -> int: - """删除非 CTP 来源的费率缓存。""" - conn = _get_db() - cur = conn.execute( - "DELETE FROM fee_rates WHERE COALESCE(source, '') != 'ctp'" - ) - n = cur.rowcount - conn.commit() - conn.close() - return n - - -def _row_to_spec(row, mult: int) -> dict: - return { - "product": row["product"], - "exchange": row["exchange"] or "", - "mult": int(row["mult"] or mult), - "open_fixed": float(row["open_fixed"] or 0), - "open_ratio": float(row["open_ratio"] or 0), - "close_yesterday_fixed": float(row["close_yesterday_fixed"] or 0), - "close_yesterday_ratio": float(row["close_yesterday_ratio"] or 0), - "close_today_fixed": float(row["close_today_fixed"] or 0), - "close_today_ratio": float(row["close_today_ratio"] or 0), - "source": row["source"] if "source" in row.keys() else "local", - } - - -def get_fee_spec(ths_code: str, *, trading_mode: str = "simulation") -> dict: - product = product_from_code(ths_code) - if not product: - spec = get_contract_spec(ths_code) - return {**DEFAULT_FEE, "mult": spec["mult"], "product": "", "exchange": "", "source": "default"} - - mult = get_contract_spec(ths_code)["mult"] - conn = _get_db() - ensure_fee_rates_schema(conn) - row = conn.execute( - "SELECT * FROM fee_rates WHERE product=? AND source='ctp'", - (product,), - ).fetchone() - conn.close() - if row: - return _row_to_spec(row, mult) - try: - from ctp_fee_sync import sync_fee_for_symbol - fields = sync_fee_for_symbol(trading_mode, ths_code) - if fields: - return {"product": product, **fields} - except Exception: - pass - - if product in _INDEX_PRODUCTS: - return { - "product": product, - "exchange": "CFFEX", - "mult": mult, - "open_fixed": 0.0, - "open_ratio": 0.000092, - "close_yesterday_fixed": 0.0, - "close_yesterday_ratio": 0.000092, - "close_today_fixed": 0.0, - "close_today_ratio": 0.000276, - } - - return { - "product": product, - "exchange": "", - "mult": mult, - **DEFAULT_FEE, - "source": "default", - } - - -def calc_side_fee( - price: float, - lots: float, - mult: int, - fixed: float, - ratio: float, -) -> float: - lots = lots or 1.0 - fixed = fixed or 0.0 - ratio = ratio or 0.0 - return fixed * lots + ratio * price * mult * lots - - -def is_same_day(open_time: str, close_time: str) -> bool: - if not open_time or not close_time: - return True - o = open_time.strip().replace(" ", "T")[:10] - c = close_time.strip().replace(" ", "T")[:10] - return o == c - - -def calc_round_trip_fee( - ths_code: str, - entry_price: float, - close_price: float, - lots: float, - open_time: str = "", - close_time: str = "", - trading_mode: str = "simulation", -) -> float: - if not entry_price or not close_price: - return 0.0 - spec = get_fee_spec(ths_code, trading_mode=trading_mode) - mult = spec["mult"] - lots = lots or 1.0 - - open_fee = calc_side_fee( - entry_price, lots, mult, - spec["open_fixed"], spec["open_ratio"], - ) - if is_same_day(open_time, close_time): - close_fee = calc_side_fee( - close_price, lots, mult, - spec["close_today_fixed"], spec["close_today_ratio"], - ) - else: - close_fee = calc_side_fee( - close_price, lots, mult, - spec["close_yesterday_fixed"], spec["close_yesterday_ratio"], - ) - return round(open_fee + close_fee, 2) - - -def calc_fee_breakdown( - ths_code: str, - entry_price: float, - close_price: float, - lots: float, - open_time: str = "", - close_time: str = "", - trading_mode: str = "simulation", -) -> dict: - spec = get_fee_spec(ths_code, trading_mode=trading_mode) - mult = spec["mult"] - lots = lots or 1.0 - open_fee = calc_side_fee( - entry_price, lots, mult, spec["open_fixed"], spec["open_ratio"], - ) - same_day = is_same_day(open_time, close_time) - if same_day: - close_fee = calc_side_fee( - close_price, lots, mult, - spec["close_today_fixed"], spec["close_today_ratio"], - ) - close_type = "平今" - else: - close_fee = calc_side_fee( - close_price, lots, mult, - spec["close_yesterday_fixed"], spec["close_yesterday_ratio"], - ) - close_type = "平昨" - total = round(open_fee + close_fee, 2) - return { - "open_fee": round(open_fee, 2), - "close_fee": round(close_fee, 2), - "close_type": close_type, - "total_fee": total, - "same_day": same_day, - "fee_source": spec.get("source", "local"), - } - - -def load_fee_rates_from_json(path: Optional[str] = None) -> int: - path = path or DEFAULT_JSON - if not os.path.isfile(path): - return 0 - with open(path, "r", encoding="utf-8") as f: - data = json.load(f) - conn = _get_db() - now = datetime.now().isoformat(timespec="seconds") - count = 0 - for product, item in data.items(): - if not isinstance(item, dict): - continue - conn.execute( - """INSERT INTO fee_rates - (product, exchange, mult, - open_fixed, open_ratio, - close_yesterday_fixed, close_yesterday_ratio, - close_today_fixed, close_today_ratio, updated_at, source) - VALUES (?,?,?,?,?,?,?,?,?,?,?) - ON CONFLICT(product) DO UPDATE SET - exchange=excluded.exchange, mult=excluded.mult, - open_fixed=excluded.open_fixed, open_ratio=excluded.open_ratio, - close_yesterday_fixed=excluded.close_yesterday_fixed, - close_yesterday_ratio=excluded.close_yesterday_ratio, - close_today_fixed=excluded.close_today_fixed, - close_today_ratio=excluded.close_today_ratio, - updated_at=excluded.updated_at, - source=excluded.source""", - ( - product.lower(), - item.get("exchange", ""), - int(item.get("mult") or get_contract_spec(product)["mult"]), - float(item.get("open_fixed") or 0), - float(item.get("open_ratio") or 0), - float(item.get("close_yesterday_fixed") or 0), - float(item.get("close_yesterday_ratio") or 0), - float(item.get("close_today_fixed") or 0), - float(item.get("close_today_ratio") or 0), - now, - item.get("source", "json"), - ), - ) - count += 1 - conn.commit() - conn.close() - return count - - -def list_ctp_fee_rates() -> list: - """手续费页:仅展示 CTP 同步结果。""" - conn = _get_db() - rows = conn.execute( - "SELECT * FROM fee_rates WHERE source='ctp' ORDER BY product" - ).fetchall() - conn.close() - return [dict(r) for r in rows] - - -def list_all_fee_rates() -> list: - conn = _get_db() - rows = conn.execute( - "SELECT * FROM fee_rates ORDER BY product" - ).fetchall() - conn.close() - return [dict(r) for r in rows] - - -def list_fee_rates_for_ui() -> list: - return list_ctp_fee_rates() - - -def count_fee_rates_by_source() -> dict[str, int]: - conn = _get_db() - n = conn.execute( - "SELECT COUNT(*) FROM fee_rates WHERE source='ctp'" - ).fetchone()[0] - conn.close() - return {"ctp": int(n or 0)} - - -def upsert_fee_rate(product: str, fields: dict) -> None: - product = product.lower().strip() - conn = _get_db() - now = datetime.now().isoformat(timespec="seconds") - source = fields.get("source", "manual") - conn.execute( - """INSERT INTO fee_rates - (product, exchange, mult, - open_fixed, open_ratio, - close_yesterday_fixed, close_yesterday_ratio, - close_today_fixed, close_today_ratio, updated_at, source) - VALUES (?,?,?,?,?,?,?,?,?,?,?) - ON CONFLICT(product) DO UPDATE SET - exchange=excluded.exchange, mult=excluded.mult, - open_fixed=excluded.open_fixed, open_ratio=excluded.open_ratio, - close_yesterday_fixed=excluded.close_yesterday_fixed, - close_yesterday_ratio=excluded.close_yesterday_ratio, - close_today_fixed=excluded.close_today_fixed, - close_today_ratio=excluded.close_today_ratio, - updated_at=excluded.updated_at, - source=excluded.source""", - ( - product, - fields.get("exchange", ""), - int(fields.get("mult") or 10), - float(fields.get("open_fixed") or 0), - float(fields.get("open_ratio") or 0), - float(fields.get("close_yesterday_fixed") or 0), - float(fields.get("close_yesterday_ratio") or 0), - float(fields.get("close_today_fixed") or 0), - float(fields.get("close_today_ratio") or 0), - now, - source, - ), - ) - conn.commit() - conn.close() +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""期货手续费:仅 CTP 柜台同步入库,前端只读展示。""" +import json +import os +import re +from datetime import datetime +from typing import Optional + +from modules.core.contract_specs import get_contract_spec + +from modules.core.db_conn import connect_db, is_benign_migration_error + +DB_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "futures.db") +DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data") +DEFAULT_JSON = os.path.join(DATA_DIR, "fee_rates.json") + +# 无配置时的兜底(已为交易所标准约 2 倍) +DEFAULT_FEE = { + "open_fixed": 2.0, + "open_ratio": 0.0, + "close_yesterday_fixed": 2.0, + "close_yesterday_ratio": 0.0, + "close_today_fixed": 4.0, + "close_today_ratio": 0.0, +} + +_INDEX_PRODUCTS = {"if", "ih", "ic", "im"} + + +def product_from_code(ths_code: str) -> str: + code = (ths_code or "").strip() + m = re.match(r"^([A-Za-z]+)", code) + return m.group(1).lower() if m else "" + + +def _get_db(): + return connect_db() + + +def ensure_fee_rates_schema(conn=None) -> None: + """补齐 fee_rates 表结构(旧库可能缺少 source 列)。""" + close = False + if conn is None: + conn = _get_db() + close = True + try: + for sql in ( + "ALTER TABLE fee_rates ADD COLUMN source TEXT DEFAULT 'local'", + ): + try: + conn.execute(sql) + except Exception as exc: + if not is_benign_migration_error(exc): + raise + conn.commit() + finally: + if close: + conn.close() + + +def get_setting(key: str, default: str = "") -> str: + conn = _get_db() + row = conn.execute("SELECT value FROM settings WHERE key=?", (key,)).fetchone() + conn.close() + if not row: + return default + return (row["value"] or default) if row["value"] is not None else default + + +def set_setting(key: str, value: str) -> None: + conn = _get_db() + conn.execute( + """INSERT INTO settings (key, value) VALUES (?,?) + ON CONFLICT(key) DO UPDATE SET value=excluded.value""", + (key, value), + ) + conn.commit() + conn.close() + + +def get_fee_multiplier() -> float: + conn = _get_db() + row = conn.execute( + "SELECT value FROM settings WHERE key='fee_multiplier'" + ).fetchone() + conn.close() + if row and row["value"]: + try: + return max(0.0, float(row["value"])) + except ValueError: + pass + return 2.0 + + +def get_fee_source_mode() -> str: + """固定 CTP 柜台。""" + return "ctp" + + +def purge_non_ctp_fee_rates() -> int: + """删除非 CTP 来源的费率缓存。""" + conn = _get_db() + cur = conn.execute( + "DELETE FROM fee_rates WHERE COALESCE(source, '') != 'ctp'" + ) + n = cur.rowcount + conn.commit() + conn.close() + return n + + +def _row_to_spec(row, mult: int) -> dict: + return { + "product": row["product"], + "exchange": row["exchange"] or "", + "mult": int(row["mult"] or mult), + "open_fixed": float(row["open_fixed"] or 0), + "open_ratio": float(row["open_ratio"] or 0), + "close_yesterday_fixed": float(row["close_yesterday_fixed"] or 0), + "close_yesterday_ratio": float(row["close_yesterday_ratio"] or 0), + "close_today_fixed": float(row["close_today_fixed"] or 0), + "close_today_ratio": float(row["close_today_ratio"] or 0), + "source": row["source"] if "source" in row.keys() else "local", + } + + +def get_fee_spec(ths_code: str, *, trading_mode: str = "simulation") -> dict: + product = product_from_code(ths_code) + if not product: + spec = get_contract_spec(ths_code) + return {**DEFAULT_FEE, "mult": spec["mult"], "product": "", "exchange": "", "source": "default"} + + mult = get_contract_spec(ths_code)["mult"] + conn = _get_db() + ensure_fee_rates_schema(conn) + row = conn.execute( + "SELECT * FROM fee_rates WHERE product=? AND source='ctp'", + (product,), + ).fetchone() + conn.close() + if row: + return _row_to_spec(row, mult) + try: + from modules.ctp.ctp_fee_sync import sync_fee_for_symbol + fields = sync_fee_for_symbol(trading_mode, ths_code) + if fields: + return {"product": product, **fields} + except Exception: + pass + + if product in _INDEX_PRODUCTS: + return { + "product": product, + "exchange": "CFFEX", + "mult": mult, + "open_fixed": 0.0, + "open_ratio": 0.000092, + "close_yesterday_fixed": 0.0, + "close_yesterday_ratio": 0.000092, + "close_today_fixed": 0.0, + "close_today_ratio": 0.000276, + } + + return { + "product": product, + "exchange": "", + "mult": mult, + **DEFAULT_FEE, + "source": "default", + } + + +def calc_side_fee( + price: float, + lots: float, + mult: int, + fixed: float, + ratio: float, +) -> float: + lots = lots or 1.0 + fixed = fixed or 0.0 + ratio = ratio or 0.0 + return fixed * lots + ratio * price * mult * lots + + +def is_same_day(open_time: str, close_time: str) -> bool: + if not open_time or not close_time: + return True + o = open_time.strip().replace(" ", "T")[:10] + c = close_time.strip().replace(" ", "T")[:10] + return o == c + + +def calc_round_trip_fee( + ths_code: str, + entry_price: float, + close_price: float, + lots: float, + open_time: str = "", + close_time: str = "", + trading_mode: str = "simulation", +) -> float: + if not entry_price or not close_price: + return 0.0 + spec = get_fee_spec(ths_code, trading_mode=trading_mode) + mult = spec["mult"] + lots = lots or 1.0 + + open_fee = calc_side_fee( + entry_price, lots, mult, + spec["open_fixed"], spec["open_ratio"], + ) + if is_same_day(open_time, close_time): + close_fee = calc_side_fee( + close_price, lots, mult, + spec["close_today_fixed"], spec["close_today_ratio"], + ) + else: + close_fee = calc_side_fee( + close_price, lots, mult, + spec["close_yesterday_fixed"], spec["close_yesterday_ratio"], + ) + return round(open_fee + close_fee, 2) + + +def calc_fee_breakdown( + ths_code: str, + entry_price: float, + close_price: float, + lots: float, + open_time: str = "", + close_time: str = "", + trading_mode: str = "simulation", +) -> dict: + spec = get_fee_spec(ths_code, trading_mode=trading_mode) + mult = spec["mult"] + lots = lots or 1.0 + open_fee = calc_side_fee( + entry_price, lots, mult, spec["open_fixed"], spec["open_ratio"], + ) + same_day = is_same_day(open_time, close_time) + if same_day: + close_fee = calc_side_fee( + close_price, lots, mult, + spec["close_today_fixed"], spec["close_today_ratio"], + ) + close_type = "平今" + else: + close_fee = calc_side_fee( + close_price, lots, mult, + spec["close_yesterday_fixed"], spec["close_yesterday_ratio"], + ) + close_type = "平昨" + total = round(open_fee + close_fee, 2) + return { + "open_fee": round(open_fee, 2), + "close_fee": round(close_fee, 2), + "close_type": close_type, + "total_fee": total, + "same_day": same_day, + "fee_source": spec.get("source", "local"), + } + + +def load_fee_rates_from_json(path: Optional[str] = None) -> int: + path = path or DEFAULT_JSON + if not os.path.isfile(path): + return 0 + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + conn = _get_db() + now = datetime.now().isoformat(timespec="seconds") + count = 0 + for product, item in data.items(): + if not isinstance(item, dict): + continue + conn.execute( + """INSERT INTO fee_rates + (product, exchange, mult, + open_fixed, open_ratio, + close_yesterday_fixed, close_yesterday_ratio, + close_today_fixed, close_today_ratio, updated_at, source) + VALUES (?,?,?,?,?,?,?,?,?,?,?) + ON CONFLICT(product) DO UPDATE SET + exchange=excluded.exchange, mult=excluded.mult, + open_fixed=excluded.open_fixed, open_ratio=excluded.open_ratio, + close_yesterday_fixed=excluded.close_yesterday_fixed, + close_yesterday_ratio=excluded.close_yesterday_ratio, + close_today_fixed=excluded.close_today_fixed, + close_today_ratio=excluded.close_today_ratio, + updated_at=excluded.updated_at, + source=excluded.source""", + ( + product.lower(), + item.get("exchange", ""), + int(item.get("mult") or get_contract_spec(product)["mult"]), + float(item.get("open_fixed") or 0), + float(item.get("open_ratio") or 0), + float(item.get("close_yesterday_fixed") or 0), + float(item.get("close_yesterday_ratio") or 0), + float(item.get("close_today_fixed") or 0), + float(item.get("close_today_ratio") or 0), + now, + item.get("source", "json"), + ), + ) + count += 1 + conn.commit() + conn.close() + return count + + +def list_ctp_fee_rates() -> list: + """手续费页:仅展示 CTP 同步结果。""" + conn = _get_db() + rows = conn.execute( + "SELECT * FROM fee_rates WHERE source='ctp' ORDER BY product" + ).fetchall() + conn.close() + return [dict(r) for r in rows] + + +def list_all_fee_rates() -> list: + conn = _get_db() + rows = conn.execute( + "SELECT * FROM fee_rates ORDER BY product" + ).fetchall() + conn.close() + return [dict(r) for r in rows] + + +def list_fee_rates_for_ui() -> list: + return list_ctp_fee_rates() + + +def count_fee_rates_by_source() -> dict[str, int]: + conn = _get_db() + n = conn.execute( + "SELECT COUNT(*) FROM fee_rates WHERE source='ctp'" + ).fetchone()[0] + conn.close() + return {"ctp": int(n or 0)} + + +def upsert_fee_rate(product: str, fields: dict) -> None: + product = product.lower().strip() + conn = _get_db() + now = datetime.now().isoformat(timespec="seconds") + source = fields.get("source", "manual") + conn.execute( + """INSERT INTO fee_rates + (product, exchange, mult, + open_fixed, open_ratio, + close_yesterday_fixed, close_yesterday_ratio, + close_today_fixed, close_today_ratio, updated_at, source) + VALUES (?,?,?,?,?,?,?,?,?,?,?) + ON CONFLICT(product) DO UPDATE SET + exchange=excluded.exchange, mult=excluded.mult, + open_fixed=excluded.open_fixed, open_ratio=excluded.open_ratio, + close_yesterday_fixed=excluded.close_yesterday_fixed, + close_yesterday_ratio=excluded.close_yesterday_ratio, + close_today_fixed=excluded.close_today_fixed, + close_today_ratio=excluded.close_today_ratio, + updated_at=excluded.updated_at, + source=excluded.source""", + ( + product, + fields.get("exchange", ""), + int(fields.get("mult") or 10), + float(fields.get("open_fixed") or 0), + float(fields.get("open_ratio") or 0), + float(fields.get("close_yesterday_fixed") or 0), + float(fields.get("close_yesterday_ratio") or 0), + float(fields.get("close_today_fixed") or 0), + float(fields.get("close_today_ratio") or 0), + now, + source, + ), + ) + conn.commit() + conn.close() diff --git a/fee_sync.py b/modules/fees/fee_sync.py similarity index 93% rename from fee_sync.py rename to modules/fees/fee_sync.py index 022b06b..c7e8b4c 100644 --- a/fee_sync.py +++ b/modules/fees/fee_sync.py @@ -1,91 +1,91 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""从第三方(AKShare)同步交易所参考手续费,并按倍率写入本地表。""" -import re -from typing import Any, Optional - -from contract_specs import get_contract_spec -from fee_specs import get_fee_multiplier, upsert_fee_rate - - -def _to_float(val: Any) -> float: - if val is None: - return 0.0 - s = str(val).strip().replace(",", "") - if not s or s in ("-", "None", "nan"): - return 0.0 - try: - return float(s) - except ValueError: - return 0.0 - - -def _parse_akshare_row(row: dict, multiplier: float) -> Optional[dict]: - code = str(row.get("合约代码") or row.get("代码") or "").strip() - if not code: - return None - m = re.match(r"^([A-Za-z]+)", code) - if not m: - return None - product = m.group(1).lower() - - open_ratio = _to_float(row.get("手续费标准-开仓-万分之")) / 10000.0 - open_fixed = _to_float(row.get("手续费标准-开仓-元")) - if open_fixed == 0 and row.get("开仓"): - open_fixed = _to_float(row.get("开仓")) - close_y_ratio = _to_float(row.get("手续费标准-平昨-万分之")) / 10000.0 - close_y_fixed = _to_float(row.get("手续费标准-平昨-元")) - if close_y_fixed == 0 and row.get("平昨"): - close_y_fixed = _to_float(row.get("平昨")) - close_t_ratio = _to_float(row.get("手续费标准-平今-万分之")) / 10000.0 - close_t_fixed = _to_float(row.get("手续费标准-平今-元")) - if close_t_fixed == 0 and row.get("平今"): - close_t_fixed = _to_float(row.get("平今")) - - mult = int(get_contract_spec(code)["mult"]) - exchange = str(row.get("交易所名称") or row.get("交易所") or "").strip() - - return { - "product": product, - "exchange": exchange, - "mult": mult, - "open_fixed": round(open_fixed * multiplier, 6), - "open_ratio": round(open_ratio * multiplier, 8), - "close_yesterday_fixed": round(close_y_fixed * multiplier, 6), - "close_yesterday_ratio": round(close_y_ratio * multiplier, 8), - "close_today_fixed": round(close_t_fixed * multiplier, 6), - "close_today_ratio": round(close_t_ratio * multiplier, 8), - "source": "akshare", - } - - -def sync_fees_from_akshare(multiplier: Optional[float] = None) -> tuple[int, str]: - multiplier = multiplier if multiplier is not None else get_fee_multiplier() - try: - import akshare as ak - except ImportError: - return 0, "未安装 akshare,请执行 pip install akshare 后重试,或使用默认费率表" - - try: - df = ak.futures_comm_info(symbol="所有") - except Exception as exc: - return 0, f"拉取第三方数据失败: {exc}" - - if df is None or df.empty: - return 0, "第三方返回空数据" - - seen: set[str] = set() - count = 0 - for _, series in df.iterrows(): - row = series.to_dict() - parsed = _parse_akshare_row(row, multiplier) - if not parsed or parsed["product"] in seen: - continue - seen.add(parsed["product"]) - upsert_fee_rate(parsed["product"], parsed) - count += 1 - - return count, f"已同步 {count} 个品种(标准费率 × {multiplier})" +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""从第三方(AKShare)同步交易所参考手续费,并按倍率写入本地表。""" +import re +from typing import Any, Optional + +from modules.core.contract_specs import get_contract_spec +from modules.fees.fee_specs import get_fee_multiplier, upsert_fee_rate + + +def _to_float(val: Any) -> float: + if val is None: + return 0.0 + s = str(val).strip().replace(",", "") + if not s or s in ("-", "None", "nan"): + return 0.0 + try: + return float(s) + except ValueError: + return 0.0 + + +def _parse_akshare_row(row: dict, multiplier: float) -> Optional[dict]: + code = str(row.get("合约代码") or row.get("代码") or "").strip() + if not code: + return None + m = re.match(r"^([A-Za-z]+)", code) + if not m: + return None + product = m.group(1).lower() + + open_ratio = _to_float(row.get("手续费标准-开仓-万分之")) / 10000.0 + open_fixed = _to_float(row.get("手续费标准-开仓-元")) + if open_fixed == 0 and row.get("开仓"): + open_fixed = _to_float(row.get("开仓")) + close_y_ratio = _to_float(row.get("手续费标准-平昨-万分之")) / 10000.0 + close_y_fixed = _to_float(row.get("手续费标准-平昨-元")) + if close_y_fixed == 0 and row.get("平昨"): + close_y_fixed = _to_float(row.get("平昨")) + close_t_ratio = _to_float(row.get("手续费标准-平今-万分之")) / 10000.0 + close_t_fixed = _to_float(row.get("手续费标准-平今-元")) + if close_t_fixed == 0 and row.get("平今"): + close_t_fixed = _to_float(row.get("平今")) + + mult = int(get_contract_spec(code)["mult"]) + exchange = str(row.get("交易所名称") or row.get("交易所") or "").strip() + + return { + "product": product, + "exchange": exchange, + "mult": mult, + "open_fixed": round(open_fixed * multiplier, 6), + "open_ratio": round(open_ratio * multiplier, 8), + "close_yesterday_fixed": round(close_y_fixed * multiplier, 6), + "close_yesterday_ratio": round(close_y_ratio * multiplier, 8), + "close_today_fixed": round(close_t_fixed * multiplier, 6), + "close_today_ratio": round(close_t_ratio * multiplier, 8), + "source": "akshare", + } + + +def sync_fees_from_akshare(multiplier: Optional[float] = None) -> tuple[int, str]: + multiplier = multiplier if multiplier is not None else get_fee_multiplier() + try: + import akshare as ak + except ImportError: + return 0, "未安装 akshare,请执行 pip install akshare 后重试,或使用默认费率表" + + try: + df = ak.futures_comm_info(symbol="所有") + except Exception as exc: + return 0, f"拉取第三方数据失败: {exc}" + + if df is None or df.empty: + return 0, "第三方返回空数据" + + seen: set[str] = set() + count = 0 + for _, series in df.iterrows(): + row = series.to_dict() + parsed = _parse_akshare_row(row, multiplier) + if not parsed or parsed["product"] in seen: + continue + seen.add(parsed["product"]) + upsert_fee_rate(parsed["product"], parsed) + count += 1 + + return count, f"已同步 {count} 个品种(标准费率 × {multiplier})" diff --git a/modules/fees/routes.py b/modules/fees/routes.py new file mode 100644 index 0000000..9568d89 --- /dev/null +++ b/modules/fees/routes.py @@ -0,0 +1,95 @@ +# Copyright (c) 2025-2026 马建军. All rights reserved. +"""HTTP routes for fees module.""" + +from __future__ import annotations + +from datetime import date, datetime + +from flask import ( + Response, + flash, + jsonify, + redirect, + render_template, + request, + send_file, + session, + stream_with_context, + url_for, +) + + +def register(deps) -> None: + app = deps.app + login_required = deps.login_required + require_nav = deps.require_nav + get_db = deps.get_db + get_setting = deps.get_setting + set_setting = deps.set_setting + fetch_price = deps.fetch_price + send_wechat_msg = deps.send_wechat_msg + touch_stats_cache = deps.touch_stats_cache + get_stats_data = deps.get_stats_data + build_market_quote_payload = deps.build_market_quote_payload + today_str = deps.today_str + expire_old_plans = deps.expire_old_plans + TZ = deps.tz + DB_PATH = deps.db_path + UPLOAD_DIR = deps.upload_dir + OPEN_TYPES = deps.open_types + EXIT_TRIGGERS = deps.exit_triggers + BEHAVIOR_TAGS = deps.behavior_tags + KLINE_PERIODS = deps.kline_periods + KLINE_CUTOFFS = deps.kline_cutoffs + calc_holding_duration = deps.calc_holding_duration + holding_to_minutes = deps.holding_to_minutes + classify_close_result = deps.classify_close_result + calc_rr_ratio = deps.calc_rr_ratio + calc_theoretical_pnl = deps.calc_theoretical_pnl + parse_review_date_filter = deps.parse_review_date_filter + _trading_mode = deps.trading_mode + _ua_is_phone = deps.ua_is_phone + _static_asset_v = deps.static_asset_v + + from modules.fees.fee_specs import count_fee_rates_by_source, list_fee_rates_for_ui + + @app.route("/fees", methods=["GET", "POST"]) + @login_required + @require_nav("fees") + def fees(): + from modules.core.trading_context import get_trading_mode + from modules.ctp.ctp_fee_worker import ( + schedule_ctp_fee_sync, + get_fee_last_sync, + fees_synced_today, + fee_sync_in_progress, + ) + from modules.ctp.vnpy_bridge import ctp_status + + mode = get_trading_mode(get_setting) + if request.method == "POST": + action = request.form.get("action") + if action == "sync_ctp": + force = request.form.get("force") == "1" + _, msg = schedule_ctp_fee_sync( + mode, + get_setting=get_setting, + set_setting=set_setting, + force=force, + ) + flash(msg) + return redirect(url_for("fees")) + + rates = list_fee_rates_for_ui() + fee_counts = count_fee_rates_by_source() + ctp_st = ctp_status(mode) + return render_template( + "fees.html", + rates=rates, + fee_counts=fee_counts, + fee_last_sync=get_fee_last_sync(get_setting), + fee_synced_today=fees_synced_today(get_setting), + fee_sync_running=fee_sync_in_progress(), + ctp_connected=bool(ctp_st.get("connected")), + ) + diff --git a/modules/keys/__init__.py b/modules/keys/__init__.py new file mode 100644 index 0000000..3c867f5 --- /dev/null +++ b/modules/keys/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) 2025-2026 马建军. All rights reserved. + +from modules.keys.routes import register + +__all__ = ["register"] diff --git a/key_monitor_lib.py b/modules/keys/key_monitor_lib.py similarity index 95% rename from key_monitor_lib.py rename to modules/keys/key_monitor_lib.py index 7c2b006..1f0f4d6 100644 --- a/key_monitor_lib.py +++ b/modules/keys/key_monitor_lib.py @@ -1,406 +1,406 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""关键位监控:5 分钟收盘触发、支阻区微信提醒、箱体/收敛自动单。""" -from __future__ import annotations - -import logging -from datetime import datetime, timedelta -from typing import Any, Callable, Optional -from zoneinfo import ZoneInfo - -from contract_specs import get_contract_spec -from kline_chart import fetch_market_klines - -logger = logging.getLogger(__name__) - -TZ = ZoneInfo("Asia/Shanghai") - -TYPE_BOX = "箱体突破" -TYPE_CONV = "收敛突破" -TYPE_ZONE = "关键支阻区" -AUTO_TYPES = (TYPE_BOX, TYPE_CONV) -ZONE_TYPES = (TYPE_ZONE, "关键阻力位", "关键支撑位") - -ALERT_MAX_PUSH = 3 -ALERT_INTERVAL_SEC = 300 -SL_TICK_BUFFER = 2 -DEFAULT_BAR_PERIOD = "5m" - -PERIOD_MINUTES_MAP = { - "1m": 1, "2m": 2, "3m": 3, "5m": 5, "15m": 15, "30m": 30, - "1h": 60, "2h": 120, "4h": 240, "d": 1440, "1d": 1440, -} - - -def key_monitor_periods() -> list[dict[str, str]]: - """关键位监控可选 K 线周期(触发用)。""" - from kline_chart import MARKET_PERIODS - - allowed = frozenset({"5m", "15m", "30m", "1h", "2h", "4h", "d"}) - return [p for p in MARKET_PERIODS if p["key"] in allowed] - - -def normalize_bar_period(raw: str) -> str: - valid = {p["key"] for p in key_monitor_periods()} - k = (raw or DEFAULT_BAR_PERIOD).strip() - return k if k in valid else DEFAULT_BAR_PERIOD - - -def bar_period_label(key: str) -> str: - k = normalize_bar_period(key) - for p in key_monitor_periods(): - if p["key"] == k: - return p["label"] - return k - - -def bar_period_minutes(period: str) -> int: - return PERIOD_MINUTES_MAP.get(normalize_bar_period(period), 5) - - -def normalize_monitor_type(raw: str) -> str: - t = (raw or "").strip() - if t in ("关键阻力位", "关键支撑位"): - return TYPE_ZONE - return t - - -def is_auto_trade_type(typ: str) -> bool: - return normalize_monitor_type(typ) in AUTO_TYPES - - -def is_zone_type(typ: str) -> bool: - return normalize_monitor_type(typ) == TYPE_ZONE - - -def resolve_order_direction(break_side: str, trade_mode: str) -> str: - """突破方向 + 顺势/反转 → 下单方向。""" - side = (break_side or "").strip().lower() - mode = (trade_mode or "顺势").strip() - if mode == "反转": - return "short" if side == "upper" else "long" - return "long" if side == "upper" else "short" - - -def break_direction_label(break_side: str) -> tuple[str, str]: - if break_side == "upper": - return "向上突破上沿", "long" - return "向下突破下沿", "short" - - -def calc_breakout_sl_tp( - *, - sym: str, - direction: str, - entry: float, - bar: dict, - risk_reward: float, -) -> tuple[float, float]: - tick = float(get_contract_spec(sym).get("tick_size") or 1.0) - bar_high = float(bar.get("high") or entry) - bar_low = float(bar.get("low") or entry) - if direction == "long": - sl = bar_low - SL_TICK_BUFFER * tick - risk = max(entry - sl, tick) - tp = entry + risk * risk_reward - else: - sl = bar_high + SL_TICK_BUFFER * tick - risk = max(sl - entry, tick) - tp = entry - risk * risk_reward - return sl, tp - - -def _parse_bar_time(raw: str) -> Optional[datetime]: - s = (raw or "").strip().replace("T", " ") - if not s: - return None - for fmt in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%d %H:%M"): - try: - return datetime.strptime(s[:19], fmt).replace(tzinfo=TZ) - except ValueError: - continue - return None - - -def last_closed_bar( - bars: list[dict], - period_minutes: int = 5, - now: Optional[datetime] = None, -) -> Optional[dict]: - """取最近一根已收盘 K 线。""" - dnow = now or datetime.now(TZ) - mins = max(1, int(period_minutes or 5)) - for bar in reversed(bars or []): - dt = _parse_bar_time(str(bar.get("time") or "")) - if not dt: - continue - bar_end = dt + timedelta(minutes=mins) - if dnow >= bar_end: - return bar - return None - - -def detect_break_side(close: float, upper: float, lower: float) -> Optional[str]: - if close > upper: - return "upper" - if close < lower: - return "lower" - return None - - -def fetch_closed_bar( - sym: str, - period: str, - *, - db_path: str, - trading_mode: str, -) -> Optional[dict]: - p = normalize_bar_period(period) - try: - data = fetch_market_klines( - sym, - p, - db_path=db_path, - trading_mode=trading_mode, - prefer_ctp=False, - ) - bars = data.get("bars") or [] - return last_closed_bar(bars, bar_period_minutes(p)) - except Exception as exc: - logger.debug("key monitor kline %s %s: %s", sym, p, exc) - return None - - -def _now_iso() -> str: - return datetime.now(TZ).strftime("%Y-%m-%d %H:%M:%S") - - -def archive_monitor(conn, pid: int) -> None: - conn.execute( - "UPDATE key_monitors SET status='archived', archived_at=? WHERE id=?", - (_now_iso(), pid), - ) - - -def format_zone_alert( - row: dict, - *, - break_side: str, - close_price: float, - bar_time: str, - push_index: int, - max_push: int = ALERT_MAX_PUSH, -) -> str: - name = row.get("symbol_name") or row.get("symbol") or "" - upper = float(row.get("upper") or 0) - lower = float(row.get("lower") or 0) - break_label, alert_dir = break_direction_label(break_side) - dir_cn = "多头(long)" if alert_dir == "long" else "空头(short)" - boundary = upper if break_side == "upper" else lower - lines = [ - f"📌 {name} 关键位突破提醒({push_index}/{max_push})", - "", - "🧾 突破概要", - "📌 类型:关键支阻区", - f"⏱ 触发时间:{bar_time}", - f"📊 上沿:{upper:g}|下沿:{lower:g}", - f"💹 触发收盘:{close_price:g}", - f"🎯 {break_label}({dir_cn})", - f"📍 突破价位:{boundary:g}", - "", - "📎 说明", - f"· 人工盯盘,共推送 {max_push} 次(间隔约 {ALERT_INTERVAL_SEC // 60} 分钟)", - "· 推送完毕后本条监控自动结案", - "· 不参与自动开仓", - ] - return "\n".join(lines) - - -def format_auto_breakout_msg( - row: dict, - *, - break_side: str, - direction: str, - entry: float, - sl: float, - tp: float, - lots: int, - bar_time: str, - ok: bool, - detail: str = "", -) -> str: - name = row.get("symbol_name") or row.get("symbol") or "" - typ = normalize_monitor_type(row.get("monitor_type") or "") - trade_mode = row.get("trade_mode") or "顺势" - break_label, _ = break_direction_label(break_side) - dir_cn = "做多" if direction == "long" else "做空" - rr = float(row.get("risk_reward") or 2) - period_label = bar_period_label(row.get("bar_period") or DEFAULT_BAR_PERIOD) - lines = [ - f"{'✅' if ok else '❌'} {name} {typ}自动单", - f"⏱ {period_label} 收盘:{bar_time}", - f"🎯 {break_label} · {trade_mode} · {dir_cn}", - f"💹 入场:{entry:g} 止损:{sl:g} 止盈:{tp:g}(盈亏比 {rr:g})", - f"📦 手数:{lots}", - ] - if int(row.get("trailing_be") or 0): - lines.append("🛡 已开启移动保本(达目标盈亏比自动止盈)") - if detail: - lines.append(detail) - return "\n".join(lines) - - -def _should_send_followup_push(row: dict, now: datetime) -> bool: - count = int(row.get("alert_push_count") or 0) - if count <= 0 or count >= ALERT_MAX_PUSH: - return False - last_raw = (row.get("alert_last_push_at") or "").strip() - if not last_raw: - return True - try: - last = datetime.fromisoformat(last_raw.replace("Z", "")).replace(tzinfo=TZ) - except ValueError: - return True - return (now - last).total_seconds() >= ALERT_INTERVAL_SEC - - -def _record_zone_push(conn, pid: int, *, break_side: str, bar_time: str, now_iso: str) -> int: - row = conn.execute( - "SELECT alert_push_count FROM key_monitors WHERE id=?", (pid,), - ).fetchone() - count = int(row["alert_push_count"] or 0) + 1 - conn.execute( - """UPDATE key_monitors SET - alert_push_count=?, alert_last_push_at=?, alert_break_side=?, - breakout_bar_time=?, upper_triggered=?, lower_triggered=? - WHERE id=?""", - ( - count, - now_iso, - break_side, - bar_time, - 1 if break_side == "upper" else 0, - 1 if break_side == "lower" else 0, - pid, - ), - ) - return count - - -def _handle_zone_alert( - conn, - row: dict, - *, - break_side: str, - bar: dict, - send_wechat: Callable[[str], None], -) -> None: - pid = int(row["id"]) - now_iso = _now_iso() - bar_time = str(bar.get("time") or "")[:19] - close_price = float(bar.get("close") or 0) - bar_key = bar_time - last_bar = (row.get("last_trigger_bar") or "").strip() - if last_bar == bar_key and int(row.get("alert_push_count") or 0) > 0: - return - - push_n = _record_zone_push(conn, pid, break_side=break_side, bar_time=bar_time, now_iso=now_iso) - conn.execute( - "UPDATE key_monitors SET last_trigger_bar=?, alert_close_price=? WHERE id=?", - (bar_key, close_price, pid), - ) - send_wechat(format_zone_alert( - row, break_side=break_side, close_price=close_price, bar_time=bar_time, push_index=push_n, - )) - if push_n >= ALERT_MAX_PUSH: - archive_monitor(conn, pid) - - -def run_key_monitor_check( - conn, - *, - db_path: str, - get_trading_mode_fn: Callable[[], str], - send_wechat: Callable[[str], None], - execute_breakout_fn: Callable[[Any, dict, str], tuple[bool, str]] | None = None, -) -> None: - """扫描 active 关键位监控(5m 收盘触发)。""" - rows = conn.execute( - "SELECT * FROM key_monitors WHERE status='active' OR status IS NULL" - ).fetchall() - mode = get_trading_mode_fn() - now = datetime.now(TZ) - - for r in rows: - row = dict(r) - pid = int(row["id"]) - sym = (row.get("symbol") or "").strip() - typ = normalize_monitor_type(row.get("monitor_type") or "") - if not sym: - continue - - try: - upper = float(row.get("upper") or 0) - lower = float(row.get("lower") or 0) - except (TypeError, ValueError): - continue - if upper <= lower: - continue - - alert_count = int(row.get("alert_push_count") or 0) - if is_zone_type(typ) and alert_count > 0: - if alert_count >= ALERT_MAX_PUSH: - archive_monitor(conn, pid) - continue - if _should_send_followup_push(row, now): - break_side = (row.get("alert_break_side") or "upper").strip() - bar_time = (row.get("breakout_bar_time") or row.get("last_trigger_bar") or "")[:19] - close_price = float(row.get("alert_close_price") or 0) - if close_price <= 0: - close_price = float(row.get("upper") if break_side == "upper" else row.get("lower") or 0) - push_n = _record_zone_push( - conn, pid, break_side=break_side, bar_time=bar_time, now_iso=_now_iso(), - ) - send_wechat(format_zone_alert( - row, break_side=break_side, close_price=close_price, bar_time=bar_time, push_index=push_n, - )) - if push_n >= ALERT_MAX_PUSH: - archive_monitor(conn, pid) - continue - - bar_period = normalize_bar_period(row.get("bar_period") or DEFAULT_BAR_PERIOD) - bar = fetch_closed_bar(sym, bar_period, db_path=db_path, trading_mode=mode) - if not bar: - continue - bar_time = str(bar.get("time") or "")[:19] - if not bar_time: - continue - if (row.get("last_trigger_bar") or "").strip() == bar_time: - continue - - try: - close_price = float(bar.get("close") or 0) - except (TypeError, ValueError): - continue - break_side = detect_break_side(close_price, upper, lower) - if not break_side: - continue - - if is_zone_type(typ): - _handle_zone_alert(conn, row, break_side=break_side, bar=bar, send_wechat=send_wechat) - continue - - if is_auto_trade_type(typ): - if not execute_breakout_fn: - logger.warning("key monitor auto trade skipped: no executor") - continue - ok, detail = execute_breakout_fn(conn, row, bar, break_side) - conn.execute( - "UPDATE key_monitors SET last_trigger_bar=?, breakout_bar_time=?, alert_break_side=? WHERE id=?", - (bar_time, bar_time, break_side, pid), - ) - if ok: - archive_monitor(conn, pid) +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""关键位监控:5 分钟收盘触发、支阻区微信提醒、箱体/收敛自动单。""" +from __future__ import annotations + +import logging +from datetime import datetime, timedelta +from typing import Any, Callable, Optional +from zoneinfo import ZoneInfo + +from modules.core.contract_specs import get_contract_spec +from modules.market.kline_chart import fetch_market_klines + +logger = logging.getLogger(__name__) + +TZ = ZoneInfo("Asia/Shanghai") + +TYPE_BOX = "箱体突破" +TYPE_CONV = "收敛突破" +TYPE_ZONE = "关键支阻区" +AUTO_TYPES = (TYPE_BOX, TYPE_CONV) +ZONE_TYPES = (TYPE_ZONE, "关键阻力位", "关键支撑位") + +ALERT_MAX_PUSH = 3 +ALERT_INTERVAL_SEC = 300 +SL_TICK_BUFFER = 2 +DEFAULT_BAR_PERIOD = "5m" + +PERIOD_MINUTES_MAP = { + "1m": 1, "2m": 2, "3m": 3, "5m": 5, "15m": 15, "30m": 30, + "1h": 60, "2h": 120, "4h": 240, "d": 1440, "1d": 1440, +} + + +def key_monitor_periods() -> list[dict[str, str]]: + """关键位监控可选 K 线周期(触发用)。""" + from modules.market.kline_chart import MARKET_PERIODS + + allowed = frozenset({"5m", "15m", "30m", "1h", "2h", "4h", "d"}) + return [p for p in MARKET_PERIODS if p["key"] in allowed] + + +def normalize_bar_period(raw: str) -> str: + valid = {p["key"] for p in key_monitor_periods()} + k = (raw or DEFAULT_BAR_PERIOD).strip() + return k if k in valid else DEFAULT_BAR_PERIOD + + +def bar_period_label(key: str) -> str: + k = normalize_bar_period(key) + for p in key_monitor_periods(): + if p["key"] == k: + return p["label"] + return k + + +def bar_period_minutes(period: str) -> int: + return PERIOD_MINUTES_MAP.get(normalize_bar_period(period), 5) + + +def normalize_monitor_type(raw: str) -> str: + t = (raw or "").strip() + if t in ("关键阻力位", "关键支撑位"): + return TYPE_ZONE + return t + + +def is_auto_trade_type(typ: str) -> bool: + return normalize_monitor_type(typ) in AUTO_TYPES + + +def is_zone_type(typ: str) -> bool: + return normalize_monitor_type(typ) == TYPE_ZONE + + +def resolve_order_direction(break_side: str, trade_mode: str) -> str: + """突破方向 + 顺势/反转 → 下单方向。""" + side = (break_side or "").strip().lower() + mode = (trade_mode or "顺势").strip() + if mode == "反转": + return "short" if side == "upper" else "long" + return "long" if side == "upper" else "short" + + +def break_direction_label(break_side: str) -> tuple[str, str]: + if break_side == "upper": + return "向上突破上沿", "long" + return "向下突破下沿", "short" + + +def calc_breakout_sl_tp( + *, + sym: str, + direction: str, + entry: float, + bar: dict, + risk_reward: float, +) -> tuple[float, float]: + tick = float(get_contract_spec(sym).get("tick_size") or 1.0) + bar_high = float(bar.get("high") or entry) + bar_low = float(bar.get("low") or entry) + if direction == "long": + sl = bar_low - SL_TICK_BUFFER * tick + risk = max(entry - sl, tick) + tp = entry + risk * risk_reward + else: + sl = bar_high + SL_TICK_BUFFER * tick + risk = max(sl - entry, tick) + tp = entry - risk * risk_reward + return sl, tp + + +def _parse_bar_time(raw: str) -> Optional[datetime]: + s = (raw or "").strip().replace("T", " ") + if not s: + return None + for fmt in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%d %H:%M"): + try: + return datetime.strptime(s[:19], fmt).replace(tzinfo=TZ) + except ValueError: + continue + return None + + +def last_closed_bar( + bars: list[dict], + period_minutes: int = 5, + now: Optional[datetime] = None, +) -> Optional[dict]: + """取最近一根已收盘 K 线。""" + dnow = now or datetime.now(TZ) + mins = max(1, int(period_minutes or 5)) + for bar in reversed(bars or []): + dt = _parse_bar_time(str(bar.get("time") or "")) + if not dt: + continue + bar_end = dt + timedelta(minutes=mins) + if dnow >= bar_end: + return bar + return None + + +def detect_break_side(close: float, upper: float, lower: float) -> Optional[str]: + if close > upper: + return "upper" + if close < lower: + return "lower" + return None + + +def fetch_closed_bar( + sym: str, + period: str, + *, + db_path: str, + trading_mode: str, +) -> Optional[dict]: + p = normalize_bar_period(period) + try: + data = fetch_market_klines( + sym, + p, + db_path=db_path, + trading_mode=trading_mode, + prefer_ctp=False, + ) + bars = data.get("bars") or [] + return last_closed_bar(bars, bar_period_minutes(p)) + except Exception as exc: + logger.debug("key monitor kline %s %s: %s", sym, p, exc) + return None + + +def _now_iso() -> str: + return datetime.now(TZ).strftime("%Y-%m-%d %H:%M:%S") + + +def archive_monitor(conn, pid: int) -> None: + conn.execute( + "UPDATE key_monitors SET status='archived', archived_at=? WHERE id=?", + (_now_iso(), pid), + ) + + +def format_zone_alert( + row: dict, + *, + break_side: str, + close_price: float, + bar_time: str, + push_index: int, + max_push: int = ALERT_MAX_PUSH, +) -> str: + name = row.get("symbol_name") or row.get("symbol") or "" + upper = float(row.get("upper") or 0) + lower = float(row.get("lower") or 0) + break_label, alert_dir = break_direction_label(break_side) + dir_cn = "多头(long)" if alert_dir == "long" else "空头(short)" + boundary = upper if break_side == "upper" else lower + lines = [ + f"📌 {name} 关键位突破提醒({push_index}/{max_push})", + "", + "🧾 突破概要", + "📌 类型:关键支阻区", + f"⏱ 触发时间:{bar_time}", + f"📊 上沿:{upper:g}|下沿:{lower:g}", + f"💹 触发收盘:{close_price:g}", + f"🎯 {break_label}({dir_cn})", + f"📍 突破价位:{boundary:g}", + "", + "📎 说明", + f"· 人工盯盘,共推送 {max_push} 次(间隔约 {ALERT_INTERVAL_SEC // 60} 分钟)", + "· 推送完毕后本条监控自动结案", + "· 不参与自动开仓", + ] + return "\n".join(lines) + + +def format_auto_breakout_msg( + row: dict, + *, + break_side: str, + direction: str, + entry: float, + sl: float, + tp: float, + lots: int, + bar_time: str, + ok: bool, + detail: str = "", +) -> str: + name = row.get("symbol_name") or row.get("symbol") or "" + typ = normalize_monitor_type(row.get("monitor_type") or "") + trade_mode = row.get("trade_mode") or "顺势" + break_label, _ = break_direction_label(break_side) + dir_cn = "做多" if direction == "long" else "做空" + rr = float(row.get("risk_reward") or 2) + period_label = bar_period_label(row.get("bar_period") or DEFAULT_BAR_PERIOD) + lines = [ + f"{'✅' if ok else '❌'} {name} {typ}自动单", + f"⏱ {period_label} 收盘:{bar_time}", + f"🎯 {break_label} · {trade_mode} · {dir_cn}", + f"💹 入场:{entry:g} 止损:{sl:g} 止盈:{tp:g}(盈亏比 {rr:g})", + f"📦 手数:{lots}", + ] + if int(row.get("trailing_be") or 0): + lines.append("🛡 已开启移动保本(达目标盈亏比自动止盈)") + if detail: + lines.append(detail) + return "\n".join(lines) + + +def _should_send_followup_push(row: dict, now: datetime) -> bool: + count = int(row.get("alert_push_count") or 0) + if count <= 0 or count >= ALERT_MAX_PUSH: + return False + last_raw = (row.get("alert_last_push_at") or "").strip() + if not last_raw: + return True + try: + last = datetime.fromisoformat(last_raw.replace("Z", "")).replace(tzinfo=TZ) + except ValueError: + return True + return (now - last).total_seconds() >= ALERT_INTERVAL_SEC + + +def _record_zone_push(conn, pid: int, *, break_side: str, bar_time: str, now_iso: str) -> int: + row = conn.execute( + "SELECT alert_push_count FROM key_monitors WHERE id=?", (pid,), + ).fetchone() + count = int(row["alert_push_count"] or 0) + 1 + conn.execute( + """UPDATE key_monitors SET + alert_push_count=?, alert_last_push_at=?, alert_break_side=?, + breakout_bar_time=?, upper_triggered=?, lower_triggered=? + WHERE id=?""", + ( + count, + now_iso, + break_side, + bar_time, + 1 if break_side == "upper" else 0, + 1 if break_side == "lower" else 0, + pid, + ), + ) + return count + + +def _handle_zone_alert( + conn, + row: dict, + *, + break_side: str, + bar: dict, + send_wechat: Callable[[str], None], +) -> None: + pid = int(row["id"]) + now_iso = _now_iso() + bar_time = str(bar.get("time") or "")[:19] + close_price = float(bar.get("close") or 0) + bar_key = bar_time + last_bar = (row.get("last_trigger_bar") or "").strip() + if last_bar == bar_key and int(row.get("alert_push_count") or 0) > 0: + return + + push_n = _record_zone_push(conn, pid, break_side=break_side, bar_time=bar_time, now_iso=now_iso) + conn.execute( + "UPDATE key_monitors SET last_trigger_bar=?, alert_close_price=? WHERE id=?", + (bar_key, close_price, pid), + ) + send_wechat(format_zone_alert( + row, break_side=break_side, close_price=close_price, bar_time=bar_time, push_index=push_n, + )) + if push_n >= ALERT_MAX_PUSH: + archive_monitor(conn, pid) + + +def run_key_monitor_check( + conn, + *, + db_path: str, + get_trading_mode_fn: Callable[[], str], + send_wechat: Callable[[str], None], + execute_breakout_fn: Callable[[Any, dict, str], tuple[bool, str]] | None = None, +) -> None: + """扫描 active 关键位监控(5m 收盘触发)。""" + rows = conn.execute( + "SELECT * FROM key_monitors WHERE status='active' OR status IS NULL" + ).fetchall() + mode = get_trading_mode_fn() + now = datetime.now(TZ) + + for r in rows: + row = dict(r) + pid = int(row["id"]) + sym = (row.get("symbol") or "").strip() + typ = normalize_monitor_type(row.get("monitor_type") or "") + if not sym: + continue + + try: + upper = float(row.get("upper") or 0) + lower = float(row.get("lower") or 0) + except (TypeError, ValueError): + continue + if upper <= lower: + continue + + alert_count = int(row.get("alert_push_count") or 0) + if is_zone_type(typ) and alert_count > 0: + if alert_count >= ALERT_MAX_PUSH: + archive_monitor(conn, pid) + continue + if _should_send_followup_push(row, now): + break_side = (row.get("alert_break_side") or "upper").strip() + bar_time = (row.get("breakout_bar_time") or row.get("last_trigger_bar") or "")[:19] + close_price = float(row.get("alert_close_price") or 0) + if close_price <= 0: + close_price = float(row.get("upper") if break_side == "upper" else row.get("lower") or 0) + push_n = _record_zone_push( + conn, pid, break_side=break_side, bar_time=bar_time, now_iso=_now_iso(), + ) + send_wechat(format_zone_alert( + row, break_side=break_side, close_price=close_price, bar_time=bar_time, push_index=push_n, + )) + if push_n >= ALERT_MAX_PUSH: + archive_monitor(conn, pid) + continue + + bar_period = normalize_bar_period(row.get("bar_period") or DEFAULT_BAR_PERIOD) + bar = fetch_closed_bar(sym, bar_period, db_path=db_path, trading_mode=mode) + if not bar: + continue + bar_time = str(bar.get("time") or "")[:19] + if not bar_time: + continue + if (row.get("last_trigger_bar") or "").strip() == bar_time: + continue + + try: + close_price = float(bar.get("close") or 0) + except (TypeError, ValueError): + continue + break_side = detect_break_side(close_price, upper, lower) + if not break_side: + continue + + if is_zone_type(typ): + _handle_zone_alert(conn, row, break_side=break_side, bar=bar, send_wechat=send_wechat) + continue + + if is_auto_trade_type(typ): + if not execute_breakout_fn: + logger.warning("key monitor auto trade skipped: no executor") + continue + ok, detail = execute_breakout_fn(conn, row, bar, break_side) + conn.execute( + "UPDATE key_monitors SET last_trigger_bar=?, breakout_bar_time=?, alert_break_side=? WHERE id=?", + (bar_time, bar_time, break_side, pid), + ) + if ok: + archive_monitor(conn, pid) diff --git a/modules/keys/routes.py b/modules/keys/routes.py new file mode 100644 index 0000000..4fc6795 --- /dev/null +++ b/modules/keys/routes.py @@ -0,0 +1,185 @@ +# Copyright (c) 2025-2026 马建军. All rights reserved. +"""HTTP routes for keys module.""" + +from __future__ import annotations + +from datetime import date, datetime + +from flask import ( + Response, + flash, + jsonify, + redirect, + render_template, + request, + send_file, + session, + stream_with_context, + url_for, +) + + +def register(deps) -> None: + app = deps.app + login_required = deps.login_required + require_nav = deps.require_nav + get_db = deps.get_db + get_setting = deps.get_setting + set_setting = deps.set_setting + fetch_price = deps.fetch_price + send_wechat_msg = deps.send_wechat_msg + touch_stats_cache = deps.touch_stats_cache + get_stats_data = deps.get_stats_data + build_market_quote_payload = deps.build_market_quote_payload + today_str = deps.today_str + expire_old_plans = deps.expire_old_plans + TZ = deps.tz + DB_PATH = deps.db_path + UPLOAD_DIR = deps.upload_dir + OPEN_TYPES = deps.open_types + EXIT_TRIGGERS = deps.exit_triggers + BEHAVIOR_TAGS = deps.behavior_tags + KLINE_PERIODS = deps.kline_periods + KLINE_CUTOFFS = deps.kline_cutoffs + calc_holding_duration = deps.calc_holding_duration + holding_to_minutes = deps.holding_to_minutes + classify_close_result = deps.classify_close_result + calc_rr_ratio = deps.calc_rr_ratio + calc_theoretical_pnl = deps.calc_theoretical_pnl + parse_review_date_filter = deps.parse_review_date_filter + _trading_mode = deps.trading_mode + _ua_is_phone = deps.ua_is_phone + _static_asset_v = deps.static_asset_v + + @app.route("/api/key_prices") + @login_required + def api_key_prices(): + """关键位监控列表:批量现价与距上/下沿距离。""" + conn = get_db() + rows = conn.execute( + "SELECT id, symbol, market_code, sina_code, upper, lower " + "FROM key_monitors WHERE status='active' OR status IS NULL" + ).fetchall() + conn.close() + out = [] + for r in rows: + sym = r["symbol"] + market = r["market_code"] or "" + sina = r["sina_code"] or "" + upper = float(r["upper"]) + lower = float(r["lower"]) + price = fetch_price(sym, market, sina) + dist_upper = None + dist_lower = None + if price is not None: + dist_upper = round(upper - price, 2) + dist_lower = round(price - lower, 2) + out.append({ + "id": r["id"], + "price": price, + "dist_upper": dist_upper, + "dist_lower": dist_lower, + }) + return jsonify(out) + @app.route("/keys") + @login_required + def keys(): + from modules.keys.key_monitor_lib import key_monitor_periods + + conn = get_db() + key_list = conn.execute( + "SELECT * FROM key_monitors WHERE status='active' OR status IS NULL ORDER BY id DESC" + ).fetchall() + history = conn.execute( + "SELECT * FROM key_monitors WHERE status='archived' ORDER BY archived_at DESC LIMIT 100" + ).fetchall() + conn.close() + return render_template( + "keys.html", + keys=key_list, + history=history, + key_periods=key_monitor_periods(), + ) + + + + @app.route("/add_key", methods=["POST"]) + @login_required + def add_key(): + d = request.form + symbol = d.get("symbol", "").strip() + symbol_name = d.get("symbol_name", "").strip() + market_code = d.get("market_code", "").strip() + sina_code = d.get("sina_code", "").strip() + monitor_type = (d.get("type") or "").strip() + if not symbol or not market_code: + flash("请从下拉列表选择品种(同花顺合约代码)") + return redirect(url_for("keys")) + try: + upper = float(d.get("upper") or 0) + lower = float(d.get("lower") or 0) + except (TypeError, ValueError): + flash("上沿/下沿价格无效") + return redirect(url_for("keys")) + if upper <= lower: + flash("上沿必须大于下沿") + return redirect(url_for("keys")) + + trade_mode = (d.get("trade_mode") or "顺势").strip() + if trade_mode not in ("顺势", "反转"): + trade_mode = "顺势" + try: + risk_reward = float(d.get("risk_reward") or 2) + except (TypeError, ValueError): + risk_reward = 2.0 + risk_reward = max(0.5, min(10.0, risk_reward)) + trailing_be = 1 if d.get("trailing_be") else 0 + if trailing_be and risk_reward < 3: + risk_reward = 3.0 + + from modules.keys.key_monitor_lib import normalize_bar_period + + bar_period = normalize_bar_period(d.get("bar_period") or "5m") + direction = (d.get("direction") or "").strip().lower() + if monitor_type == "箱体突破": + if direction not in ("long", "short"): + flash("箱体突破须选择上方向(做多/做空)") + return redirect(url_for("keys")) + else: + direction = "" + + conn = get_db() + conn.execute( + """INSERT INTO key_monitors + (symbol, symbol_name, market_code, sina_code, monitor_type, direction, + upper, lower, trade_mode, risk_reward, trailing_be, bar_period) + VALUES (?,?,?,?,?,?,?,?,?,?,?,?)""", + ( + symbol, symbol_name, market_code, sina_code, monitor_type, direction, + upper, lower, trade_mode, risk_reward, trailing_be, bar_period, + ), + ) + conn.commit() + conn.close() + flash("关键位监控已添加") + return redirect(url_for("keys")) + + + @app.route("/add_position", methods=["POST"]) + @login_required + def add_position(): + flash("持仓由策略交易或 CTP 自动同步,无需手工录入") + return redirect(url_for("positions")) + @app.route("/del_key/") + @login_required + def del_key(pid): + conn = get_db() + conn.execute( + "UPDATE key_monitors SET status='archived', archived_at=? WHERE id=?", + (datetime.now(TZ).isoformat(), pid), + ) + conn.commit() + conn.close() + flash("已移入监控历史") + return redirect(url_for("keys")) + diff --git a/modules/market/__init__.py b/modules/market/__init__.py new file mode 100644 index 0000000..74bd115 --- /dev/null +++ b/modules/market/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2025-2026 马建军. All rights reserved. + +from modules.market.routes import register + + +def start_workers(deps) -> None: + deps.start_background_threads() + + +__all__ = ["register", "start_workers"] diff --git a/kline_chart.py b/modules/market/kline_chart.py similarity index 94% rename from kline_chart.py rename to modules/market/kline_chart.py index d0d069d..0259926 100644 --- a/kline_chart.py +++ b/modules/market/kline_chart.py @@ -1,558 +1,558 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""复盘 K 线:新浪拉取 + matplotlib 生成截图。""" -import json -import logging -import os -import re -import sqlite3 -from datetime import datetime -from typing import Optional -from zoneinfo import ZoneInfo - -import requests - -from symbols import ths_to_codes -from db_conn import connect_db -from kline_store import ensure_kline_tables, get_cached_entry, save_bars - -logger = logging.getLogger(__name__) -TZ = ZoneInfo("Asia/Shanghai") - -# CTP tick 聚合 bar 少于此数时,用新浪历史补齐走势 -MIN_CTP_KLINE_BARS = 15 - -PERIOD_MINUTES = { - "1m": "1", - "3m": "3", - "5m": "5", - "15m": "15", - "30m": "30", - "1h": "60", - "4h": "240", -} - -MARKET_PERIODS = [ - {"key": "timeshare", "label": "分时"}, - {"key": "1m", "label": "1分"}, - {"key": "2m", "label": "2分"}, - {"key": "5m", "label": "5分"}, - {"key": "15m", "label": "15分"}, - {"key": "1h", "label": "1小时"}, - {"key": "2h", "label": "2小时"}, - {"key": "4h", "label": "4小时"}, - {"key": "d", "label": "日线"}, - {"key": "w", "label": "周线"}, -] - - -def ths_to_sina_chart_symbol(symbol: str) -> Optional[str]: - """ag2608 -> AG2608(新浪 K 线接口合约代码)。""" - code = (symbol or "").strip() - if not code: - return None - codes = ths_to_codes(code) - if codes: - sina = codes.get("sina_code", "") - if sina.startswith("nf_"): - return sina[3:] - if sina.startswith("CFF_RE_"): - return sina[7:] - ths = codes.get("ths_code", "") - return ths.upper() if ths else None - m = re.match(r"^([A-Za-z]+)(\d+)$", code) - if m: - return m.group(1).upper() + m.group(2) - return None - - -def _parse_jsonp(text: str) -> Optional[list]: - m = re.search(r"\((.*)\)\s*;?\s*$", text.strip(), re.DOTALL) - if not m: - return None - try: - data = json.loads(m.group(1)) - return data if isinstance(data, list) else None - except json.JSONDecodeError: - return None - - -def fetch_sina_klines(symbol: str, period: str) -> list: - """拉取新浪期货 K 线(原始 bar 列表)。""" - chart_sym = ths_to_sina_chart_symbol(symbol) - if not chart_sym: - return [] - p = (period or "").lower() - if p in ("1d", "d"): - return _fetch_sina_daily(chart_sym) - if p == "w": - return _weekly_from_daily(_fetch_sina_daily(chart_sym)) - if p == "timeshare": - bars = _fetch_few_min_line(chart_sym, "1") - return _timeshare_session(bars) - if p == "2m": - return _aggregate_bars(_fetch_few_min_line(chart_sym, "1"), 2) - if p == "2h": - return _aggregate_bars(_fetch_few_min_line(chart_sym, "60"), 2) - typ = PERIOD_MINUTES.get(p) - if typ: - return _fetch_few_min_line(chart_sym, typ) - return [] - - -def _fetch_few_min_line(chart_sym: str, typ: str) -> list: - ts = datetime.now(TZ).strftime("%Y%m%d%H%M%S") - url = ( - "https://stock2.finance.sina.com.cn/futures/api/jsonp.php/" - f"var_{chart_sym}_{typ}_{ts}=/InnerFuturesNewService.getFewMinLine" - f"?symbol={chart_sym}&type={typ}" - ) - try: - resp = requests.get( - url, - timeout=20, - headers={"Referer": "https://finance.sina.com.cn"}, - ) - bars = _parse_jsonp(resp.text) - return _normalize_bars(bars or []) - except Exception as exc: - logger.warning("fetch kline failed %s %s: %s", chart_sym, typ, exc) - return [] - - -def _normalize_bars(raw: list) -> list: - out = [] - for row in raw: - if isinstance(row, list) and len(row) >= 5: - out.append({ - "d": str(row[0]), - "o": float(row[1]), - "h": float(row[2]), - "l": float(row[3]), - "c": float(row[4]), - "v": float(row[5]) if len(row) > 5 and row[5] else 0.0, - }) - elif isinstance(row, dict) and row.get("d"): - out.append({ - "d": str(row["d"]), - "o": float(row.get("o", 0) or 0), - "h": float(row.get("h", 0) or 0), - "l": float(row.get("l", 0) or 0), - "c": float(row.get("c", 0) or 0), - "v": float(row.get("v", 0) or 0), - }) - return out - - -def _aggregate_bars(bars: list, n: int) -> list: - if n <= 1 or not bars: - return bars - out = [] - chunk: list = [] - for bar in bars: - chunk.append(bar) - if len(chunk) >= n: - out.append(_merge_bars(chunk)) - chunk = [] - if chunk: - out.append(_merge_bars(chunk)) - return out - - -def _merge_bars(chunk: list) -> dict: - return { - "d": chunk[0]["d"], - "o": chunk[0]["o"], - "h": max(b["h"] for b in chunk), - "l": min(b["l"] for b in chunk), - "c": chunk[-1]["c"], - "v": sum(b.get("v", 0) for b in chunk), - } - - -def _merge_kline_bars(history: list, live: list) -> list: - """新浪历史 + CTP 实时尾部(去重叠)。""" - if not history: - return list(live or []) - if not live: - return list(history) - first_live = _bar_datetime(live[0]) - if not first_live: - return history + live - trimmed = [] - for bar in history: - dt = _bar_datetime(bar) - if dt and dt < first_live: - trimmed.append(bar) - merged = trimmed + list(live) - return merged if merged else list(history) - - -def _weekly_from_daily(daily: list) -> list: - if not daily: - return [] - buckets: dict[tuple, list] = {} - for bar in daily: - dt = _bar_datetime(bar) - if not dt: - continue - iso = dt.isocalendar() - key = (iso[0], iso[1]) - buckets.setdefault(key, []).append(bar) - out = [] - for key in sorted(buckets.keys()): - chunk = buckets[key] - out.append(_merge_bars(chunk)) - out[-1]["d"] = chunk[-1]["d"] - return out - - -def _timeshare_session(bars: list) -> list: - if not bars: - return [] - today = datetime.now(TZ).date() - session = [] - for bar in bars: - dt = _bar_datetime(bar) - if dt and dt.date() == today: - session.append(bar) - if session: - return session[-480:] - return bars[-480:] - - -def bars_to_api(bars: list) -> list[dict]: - """转为前端图表 JSON(去重、排序、数值规范化)。""" - result: list[dict] = [] - seen: dict[int, dict] = {} - for bar in bars: - dt = _bar_datetime(bar) - ts = int(dt.timestamp() * 1000) if dt else None - try: - o = float(bar.get("o") or 0) - h = float(bar.get("h") or o) - l = float(bar.get("l") or o) - c = float(bar.get("c") or o) - v = float(bar.get("v") or 0) - except (TypeError, ValueError): - continue - if h < l: - h, l = l, h - h = max(h, o, c) - l = min(l, o, c) - row = { - "time": bar["d"], - "timestamp": ts, - "open": o, - "high": h, - "low": l, - "close": c, - "volume": v, - } - if ts is not None: - seen[ts] = row - else: - result.append(row) - if seen: - result = [seen[k] for k in sorted(seen.keys())] - return result - - -def fetch_market_klines( - symbol: str, - period: str, - db_path: Optional[str] = None, - force_remote: bool = False, - *, - trading_mode: Optional[str] = None, - prefer_ctp: bool = False, -) -> dict: - chart_sym = ths_to_sina_chart_symbol(symbol) - p = (period or "15m").lower() - if p == "timeshare": - chart_type = "line" - else: - chart_type = "candle" - - bars: list = [] - source = "remote" - cached_at = None - ctp_connected = False - ctp_bars: list = [] - - if prefer_ctp: - try: - from ctp_kline import fetch_ctp_klines - from vnpy_bridge import ctp_status - - mode = trading_mode - if not mode: - try: - from app import get_setting - from trading_context import get_trading_mode - - mode = get_trading_mode(get_setting) - except Exception: - mode = "simulation" - ctp_connected = bool(ctp_status(mode).get("connected")) - if ctp_connected: - ctp_bars = fetch_ctp_klines(symbol, p, mode) or [] - except Exception as exc: - logger.debug("ctp kline fetch failed %s %s: %s", symbol, p, exc) - - need_sina = force_remote or not prefer_ctp or not ctp_bars or len(ctp_bars) < MIN_CTP_KLINE_BARS - - if ctp_bars and len(ctp_bars) >= MIN_CTP_KLINE_BARS: - bars = ctp_bars - source = "ctp" - - if not bars and db_path and chart_sym and not force_remote and need_sina: - try: - conn = connect_db(db_path) - cached = get_cached_entry(conn, chart_sym, p) - conn.close() - if cached and cached.get("fresh"): - bars = cached["bars"] - source = "local" - cached_at = cached.get("updated_at") - except Exception as exc: - logger.warning("kline cache read failed %s %s: %s", chart_sym, p, exc) - - if not bars or len(ctp_bars) < MIN_CTP_KLINE_BARS or not prefer_ctp: - remote_bars = fetch_sina_klines(symbol, p) - if remote_bars: - if prefer_ctp and ctp_bars and ctp_connected: - bars = _merge_kline_bars(remote_bars, ctp_bars) - source = "ctp+remote" - else: - bars = remote_bars - source = "remote" - if db_path and chart_sym and not ctp_connected: - try: - conn = connect_db(db_path) - ensure_kline_tables(conn) - save_bars(conn, chart_sym, p, remote_bars) - meta = conn.execute( - "SELECT updated_at FROM kline_meta WHERE chart_symbol=? AND period=?", - (chart_sym, p), - ).fetchone() - conn.close() - cached_at = meta[0] if meta else None - except Exception as exc: - logger.warning("kline cache write failed %s %s: %s", chart_sym, p, exc) - elif not bars and db_path and chart_sym: - try: - conn = connect_db(db_path) - cached = get_cached_entry(conn, chart_sym, p) - conn.close() - if cached and cached.get("bars"): - bars = cached["bars"] - source = "local" - cached_at = cached.get("updated_at") - except Exception as exc: - logger.warning("kline cache fallback failed %s %s: %s", chart_sym, p, exc) - - api_bars = bars_to_api(bars) - prev_close = None - if len(api_bars) >= 2: - prev_close = api_bars[-2]["close"] - - return { - "symbol": symbol, - "chart_symbol": chart_sym, - "period": p, - "chart_type": chart_type, - "count": len(bars), - "bars": api_bars, - "prev_close": prev_close, - "source": source, - "cached_at": cached_at, - "ctp_connected": ctp_connected, - } - - -def _fetch_sina_daily(chart_sym: str) -> list: - url = ( - "https://stock2.finance.sina.com.cn/futures/api/json.php/" - f"IndexService.getInnerFuturesDailyKLine?symbol={chart_sym}" - ) - try: - resp = requests.get(url, timeout=20, headers={"Referer": "https://finance.sina.com.cn"}) - raw = resp.json() - if raw and isinstance(raw, list): - bars = _normalize_bars(raw) - if bars: - return bars - except Exception as exc: - logger.warning("fetch daily kline failed %s: %s", chart_sym, exc) - return _daily_from_minutes(chart_sym) - - -def _daily_from_minutes(chart_sym: str) -> list: - """合约日线接口无数据时,由 60 分钟 K 线按日合成。""" - bars_60 = _fetch_few_min_line(chart_sym, "60") - if not bars_60: - bars_60 = _fetch_few_min_line(chart_sym, "240") - buckets: dict[str, list] = {} - for bar in bars_60: - dt = _bar_datetime(bar) - if not dt: - continue - key = dt.strftime("%Y-%m-%d") - buckets.setdefault(key, []).append(bar) - out = [] - for day in sorted(buckets.keys()): - chunk = buckets[day] - merged = _merge_bars(chunk) - merged["d"] = day + " 15:00:00" - out.append(merged) - return out - - -def _parse_dt(value: str) -> Optional[datetime]: - if not value: - return None - v = value.strip().replace("T", " ") - for fmt in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%d %H:%M"): - try: - return datetime.strptime(v, fmt).replace(tzinfo=TZ) - except ValueError: - continue - try: - return datetime.fromisoformat(value.strip()).replace(tzinfo=TZ) - except ValueError: - return None - - -def _bar_datetime(bar: dict) -> Optional[datetime]: - d = bar.get("d") - if not d: - return None - try: - return datetime.strptime(d, "%Y-%m-%d %H:%M:%S").replace(tzinfo=TZ) - except ValueError: - return None - - -def _select_bars( - bars: list, - cutoff: datetime, - count: int, -) -> list: - filtered = [] - for bar in bars: - dt = _bar_datetime(bar) - if dt and dt <= cutoff: - filtered.append(bar) - if not filtered: - filtered = bars - if count > 0 and len(filtered) > count: - filtered = filtered[-count:] - return filtered - - -def generate_review_kline_chart( - symbol: str, - periods: list[str], - count: int, - cutoff_label: str, - open_time: str, - close_time: str, - entry_price: Optional[float], - stop_loss: Optional[float], - take_profit: Optional[float], - close_price: Optional[float], - upload_dir: str, -) -> Optional[str]: - """生成双周期 K 线复盘图,返回 uploads 目录下的文件名。""" - import matplotlib - matplotlib.use("Agg") - import matplotlib.pyplot as plt - import matplotlib.dates as mdates - - now = datetime.now(TZ) - if cutoff_label == "开仓时间": - cutoff = _parse_dt(open_time) or now - elif cutoff_label == "当前时间": - cutoff = now - else: - cutoff = _parse_dt(close_time) or now - - open_dt = _parse_dt(open_time) - close_dt = _parse_dt(close_time) - - valid_periods = [p for p in periods if p] - if not valid_periods: - valid_periods = ["15m", "1h"] - - fig, axes = plt.subplots( - len(valid_periods), 1, - figsize=(14, 4.5 * len(valid_periods)), - facecolor="#0a0a10", - squeeze=False, - ) - - plotted = False - for idx, period in enumerate(valid_periods): - ax = axes[idx, 0] - bars = fetch_sina_klines(symbol, period) - bars = _select_bars(bars, cutoff, count) - if not bars: - ax.set_facecolor("#12121a") - ax.text(0.5, 0.5, f"No {period} data", ha="center", va="center", color="#888") - ax.set_xticks([]) - ax.set_yticks([]) - continue - - times = [_bar_datetime(b) for b in bars] - closes = [float(b["c"]) for b in bars] - highs = [float(b["h"]) for b in bars] - lows = [float(b["l"]) for b in bars] - - ax.set_facecolor("#12121a") - ax.plot(times, closes, color="#4cc2ff", linewidth=1.2) - ax.fill_between( - times, lows, highs, - color="#4cc2ff", alpha=0.12, - ) - - levels = [ - (entry_price, "#eac147", "Entry"), - (stop_loss, "#ff6666", "SL"), - (take_profit, "#4cd97f", "TP"), - (close_price, "#c4c4ff", "Close"), - ] - for price, color, label in levels: - if price is not None: - ax.axhline(price, color=color, linewidth=0.9, linestyle="--", alpha=0.85) - ax.text(times[-1], price, label, color=color, fontsize=8, va="bottom") - - if open_dt: - ax.axvline(open_dt, color="#888", linewidth=0.8, linestyle=":", alpha=0.7) - if close_dt: - ax.axvline(close_dt, color="#aaa", linewidth=0.8, linestyle=":", alpha=0.7) - - chart_sym = ths_to_sina_chart_symbol(symbol) or symbol - ax.set_title(f"{chart_sym} {period}", color="#eaeaea", fontsize=11, pad=8) - ax.tick_params(colors="#888", labelsize=8) - for spine in ax.spines.values(): - spine.set_color("#2e2e45") - ax.xaxis.set_major_formatter(mdates.DateFormatter("%m-%d %H:%M")) - ax.grid(True, color="#1e1e30", linewidth=0.5) - plotted = True - - if not plotted: - plt.close(fig) - return None - - fig.tight_layout() - ts = datetime.now(TZ).strftime("%Y%m%d%H%M%S") - chart_sym = ths_to_sina_chart_symbol(symbol) or "chart" - filename = f"{ts}_kline_{chart_sym}.png" - path = os.path.join(upload_dir, filename) - fig.savefig(path, dpi=120, facecolor=fig.get_facecolor()) - plt.close(fig) - return filename +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""复盘 K 线:新浪拉取 + matplotlib 生成截图。""" +import json +import logging +import os +import re +import sqlite3 +from datetime import datetime +from typing import Optional +from zoneinfo import ZoneInfo + +import requests + +from modules.core.symbols import ths_to_codes +from modules.core.db_conn import connect_db +from modules.market.kline_store import ensure_kline_tables, get_cached_entry, save_bars + +logger = logging.getLogger(__name__) +TZ = ZoneInfo("Asia/Shanghai") + +# CTP tick 聚合 bar 少于此数时,用新浪历史补齐走势 +MIN_CTP_KLINE_BARS = 15 + +PERIOD_MINUTES = { + "1m": "1", + "3m": "3", + "5m": "5", + "15m": "15", + "30m": "30", + "1h": "60", + "4h": "240", +} + +MARKET_PERIODS = [ + {"key": "timeshare", "label": "分时"}, + {"key": "1m", "label": "1分"}, + {"key": "2m", "label": "2分"}, + {"key": "5m", "label": "5分"}, + {"key": "15m", "label": "15分"}, + {"key": "1h", "label": "1小时"}, + {"key": "2h", "label": "2小时"}, + {"key": "4h", "label": "4小时"}, + {"key": "d", "label": "日线"}, + {"key": "w", "label": "周线"}, +] + + +def ths_to_sina_chart_symbol(symbol: str) -> Optional[str]: + """ag2608 -> AG2608(新浪 K 线接口合约代码)。""" + code = (symbol or "").strip() + if not code: + return None + codes = ths_to_codes(code) + if codes: + sina = codes.get("sina_code", "") + if sina.startswith("nf_"): + return sina[3:] + if sina.startswith("CFF_RE_"): + return sina[7:] + ths = codes.get("ths_code", "") + return ths.upper() if ths else None + m = re.match(r"^([A-Za-z]+)(\d+)$", code) + if m: + return m.group(1).upper() + m.group(2) + return None + + +def _parse_jsonp(text: str) -> Optional[list]: + m = re.search(r"\((.*)\)\s*;?\s*$", text.strip(), re.DOTALL) + if not m: + return None + try: + data = json.loads(m.group(1)) + return data if isinstance(data, list) else None + except json.JSONDecodeError: + return None + + +def fetch_sina_klines(symbol: str, period: str) -> list: + """拉取新浪期货 K 线(原始 bar 列表)。""" + chart_sym = ths_to_sina_chart_symbol(symbol) + if not chart_sym: + return [] + p = (period or "").lower() + if p in ("1d", "d"): + return _fetch_sina_daily(chart_sym) + if p == "w": + return _weekly_from_daily(_fetch_sina_daily(chart_sym)) + if p == "timeshare": + bars = _fetch_few_min_line(chart_sym, "1") + return _timeshare_session(bars) + if p == "2m": + return _aggregate_bars(_fetch_few_min_line(chart_sym, "1"), 2) + if p == "2h": + return _aggregate_bars(_fetch_few_min_line(chart_sym, "60"), 2) + typ = PERIOD_MINUTES.get(p) + if typ: + return _fetch_few_min_line(chart_sym, typ) + return [] + + +def _fetch_few_min_line(chart_sym: str, typ: str) -> list: + ts = datetime.now(TZ).strftime("%Y%m%d%H%M%S") + url = ( + "https://stock2.finance.sina.com.cn/futures/api/jsonp.php/" + f"var_{chart_sym}_{typ}_{ts}=/InnerFuturesNewService.getFewMinLine" + f"?symbol={chart_sym}&type={typ}" + ) + try: + resp = requests.get( + url, + timeout=20, + headers={"Referer": "https://finance.sina.com.cn"}, + ) + bars = _parse_jsonp(resp.text) + return _normalize_bars(bars or []) + except Exception as exc: + logger.warning("fetch kline failed %s %s: %s", chart_sym, typ, exc) + return [] + + +def _normalize_bars(raw: list) -> list: + out = [] + for row in raw: + if isinstance(row, list) and len(row) >= 5: + out.append({ + "d": str(row[0]), + "o": float(row[1]), + "h": float(row[2]), + "l": float(row[3]), + "c": float(row[4]), + "v": float(row[5]) if len(row) > 5 and row[5] else 0.0, + }) + elif isinstance(row, dict) and row.get("d"): + out.append({ + "d": str(row["d"]), + "o": float(row.get("o", 0) or 0), + "h": float(row.get("h", 0) or 0), + "l": float(row.get("l", 0) or 0), + "c": float(row.get("c", 0) or 0), + "v": float(row.get("v", 0) or 0), + }) + return out + + +def _aggregate_bars(bars: list, n: int) -> list: + if n <= 1 or not bars: + return bars + out = [] + chunk: list = [] + for bar in bars: + chunk.append(bar) + if len(chunk) >= n: + out.append(_merge_bars(chunk)) + chunk = [] + if chunk: + out.append(_merge_bars(chunk)) + return out + + +def _merge_bars(chunk: list) -> dict: + return { + "d": chunk[0]["d"], + "o": chunk[0]["o"], + "h": max(b["h"] for b in chunk), + "l": min(b["l"] for b in chunk), + "c": chunk[-1]["c"], + "v": sum(b.get("v", 0) for b in chunk), + } + + +def _merge_kline_bars(history: list, live: list) -> list: + """新浪历史 + CTP 实时尾部(去重叠)。""" + if not history: + return list(live or []) + if not live: + return list(history) + first_live = _bar_datetime(live[0]) + if not first_live: + return history + live + trimmed = [] + for bar in history: + dt = _bar_datetime(bar) + if dt and dt < first_live: + trimmed.append(bar) + merged = trimmed + list(live) + return merged if merged else list(history) + + +def _weekly_from_daily(daily: list) -> list: + if not daily: + return [] + buckets: dict[tuple, list] = {} + for bar in daily: + dt = _bar_datetime(bar) + if not dt: + continue + iso = dt.isocalendar() + key = (iso[0], iso[1]) + buckets.setdefault(key, []).append(bar) + out = [] + for key in sorted(buckets.keys()): + chunk = buckets[key] + out.append(_merge_bars(chunk)) + out[-1]["d"] = chunk[-1]["d"] + return out + + +def _timeshare_session(bars: list) -> list: + if not bars: + return [] + today = datetime.now(TZ).date() + session = [] + for bar in bars: + dt = _bar_datetime(bar) + if dt and dt.date() == today: + session.append(bar) + if session: + return session[-480:] + return bars[-480:] + + +def bars_to_api(bars: list) -> list[dict]: + """转为前端图表 JSON(去重、排序、数值规范化)。""" + result: list[dict] = [] + seen: dict[int, dict] = {} + for bar in bars: + dt = _bar_datetime(bar) + ts = int(dt.timestamp() * 1000) if dt else None + try: + o = float(bar.get("o") or 0) + h = float(bar.get("h") or o) + l = float(bar.get("l") or o) + c = float(bar.get("c") or o) + v = float(bar.get("v") or 0) + except (TypeError, ValueError): + continue + if h < l: + h, l = l, h + h = max(h, o, c) + l = min(l, o, c) + row = { + "time": bar["d"], + "timestamp": ts, + "open": o, + "high": h, + "low": l, + "close": c, + "volume": v, + } + if ts is not None: + seen[ts] = row + else: + result.append(row) + if seen: + result = [seen[k] for k in sorted(seen.keys())] + return result + + +def fetch_market_klines( + symbol: str, + period: str, + db_path: Optional[str] = None, + force_remote: bool = False, + *, + trading_mode: Optional[str] = None, + prefer_ctp: bool = False, +) -> dict: + chart_sym = ths_to_sina_chart_symbol(symbol) + p = (period or "15m").lower() + if p == "timeshare": + chart_type = "line" + else: + chart_type = "candle" + + bars: list = [] + source = "remote" + cached_at = None + ctp_connected = False + ctp_bars: list = [] + + if prefer_ctp: + try: + from modules.ctp.ctp_kline import fetch_ctp_klines + from modules.ctp.vnpy_bridge import ctp_status + + mode = trading_mode + if not mode: + try: + from app import get_setting + from modules.core.trading_context import get_trading_mode + + mode = get_trading_mode(get_setting) + except Exception: + mode = "simulation" + ctp_connected = bool(ctp_status(mode).get("connected")) + if ctp_connected: + ctp_bars = fetch_ctp_klines(symbol, p, mode) or [] + except Exception as exc: + logger.debug("ctp kline fetch failed %s %s: %s", symbol, p, exc) + + need_sina = force_remote or not prefer_ctp or not ctp_bars or len(ctp_bars) < MIN_CTP_KLINE_BARS + + if ctp_bars and len(ctp_bars) >= MIN_CTP_KLINE_BARS: + bars = ctp_bars + source = "ctp" + + if not bars and db_path and chart_sym and not force_remote and need_sina: + try: + conn = connect_db(db_path) + cached = get_cached_entry(conn, chart_sym, p) + conn.close() + if cached and cached.get("fresh"): + bars = cached["bars"] + source = "local" + cached_at = cached.get("updated_at") + except Exception as exc: + logger.warning("kline cache read failed %s %s: %s", chart_sym, p, exc) + + if not bars or len(ctp_bars) < MIN_CTP_KLINE_BARS or not prefer_ctp: + remote_bars = fetch_sina_klines(symbol, p) + if remote_bars: + if prefer_ctp and ctp_bars and ctp_connected: + bars = _merge_kline_bars(remote_bars, ctp_bars) + source = "ctp+remote" + else: + bars = remote_bars + source = "remote" + if db_path and chart_sym and not ctp_connected: + try: + conn = connect_db(db_path) + ensure_kline_tables(conn) + save_bars(conn, chart_sym, p, remote_bars) + meta = conn.execute( + "SELECT updated_at FROM kline_meta WHERE chart_symbol=? AND period=?", + (chart_sym, p), + ).fetchone() + conn.close() + cached_at = meta[0] if meta else None + except Exception as exc: + logger.warning("kline cache write failed %s %s: %s", chart_sym, p, exc) + elif not bars and db_path and chart_sym: + try: + conn = connect_db(db_path) + cached = get_cached_entry(conn, chart_sym, p) + conn.close() + if cached and cached.get("bars"): + bars = cached["bars"] + source = "local" + cached_at = cached.get("updated_at") + except Exception as exc: + logger.warning("kline cache fallback failed %s %s: %s", chart_sym, p, exc) + + api_bars = bars_to_api(bars) + prev_close = None + if len(api_bars) >= 2: + prev_close = api_bars[-2]["close"] + + return { + "symbol": symbol, + "chart_symbol": chart_sym, + "period": p, + "chart_type": chart_type, + "count": len(bars), + "bars": api_bars, + "prev_close": prev_close, + "source": source, + "cached_at": cached_at, + "ctp_connected": ctp_connected, + } + + +def _fetch_sina_daily(chart_sym: str) -> list: + url = ( + "https://stock2.finance.sina.com.cn/futures/api/json.php/" + f"IndexService.getInnerFuturesDailyKLine?symbol={chart_sym}" + ) + try: + resp = requests.get(url, timeout=20, headers={"Referer": "https://finance.sina.com.cn"}) + raw = resp.json() + if raw and isinstance(raw, list): + bars = _normalize_bars(raw) + if bars: + return bars + except Exception as exc: + logger.warning("fetch daily kline failed %s: %s", chart_sym, exc) + return _daily_from_minutes(chart_sym) + + +def _daily_from_minutes(chart_sym: str) -> list: + """合约日线接口无数据时,由 60 分钟 K 线按日合成。""" + bars_60 = _fetch_few_min_line(chart_sym, "60") + if not bars_60: + bars_60 = _fetch_few_min_line(chart_sym, "240") + buckets: dict[str, list] = {} + for bar in bars_60: + dt = _bar_datetime(bar) + if not dt: + continue + key = dt.strftime("%Y-%m-%d") + buckets.setdefault(key, []).append(bar) + out = [] + for day in sorted(buckets.keys()): + chunk = buckets[day] + merged = _merge_bars(chunk) + merged["d"] = day + " 15:00:00" + out.append(merged) + return out + + +def _parse_dt(value: str) -> Optional[datetime]: + if not value: + return None + v = value.strip().replace("T", " ") + for fmt in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%d %H:%M"): + try: + return datetime.strptime(v, fmt).replace(tzinfo=TZ) + except ValueError: + continue + try: + return datetime.fromisoformat(value.strip()).replace(tzinfo=TZ) + except ValueError: + return None + + +def _bar_datetime(bar: dict) -> Optional[datetime]: + d = bar.get("d") + if not d: + return None + try: + return datetime.strptime(d, "%Y-%m-%d %H:%M:%S").replace(tzinfo=TZ) + except ValueError: + return None + + +def _select_bars( + bars: list, + cutoff: datetime, + count: int, +) -> list: + filtered = [] + for bar in bars: + dt = _bar_datetime(bar) + if dt and dt <= cutoff: + filtered.append(bar) + if not filtered: + filtered = bars + if count > 0 and len(filtered) > count: + filtered = filtered[-count:] + return filtered + + +def generate_review_kline_chart( + symbol: str, + periods: list[str], + count: int, + cutoff_label: str, + open_time: str, + close_time: str, + entry_price: Optional[float], + stop_loss: Optional[float], + take_profit: Optional[float], + close_price: Optional[float], + upload_dir: str, +) -> Optional[str]: + """生成双周期 K 线复盘图,返回 uploads 目录下的文件名。""" + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + import matplotlib.dates as mdates + + now = datetime.now(TZ) + if cutoff_label == "开仓时间": + cutoff = _parse_dt(open_time) or now + elif cutoff_label == "当前时间": + cutoff = now + else: + cutoff = _parse_dt(close_time) or now + + open_dt = _parse_dt(open_time) + close_dt = _parse_dt(close_time) + + valid_periods = [p for p in periods if p] + if not valid_periods: + valid_periods = ["15m", "1h"] + + fig, axes = plt.subplots( + len(valid_periods), 1, + figsize=(14, 4.5 * len(valid_periods)), + facecolor="#0a0a10", + squeeze=False, + ) + + plotted = False + for idx, period in enumerate(valid_periods): + ax = axes[idx, 0] + bars = fetch_sina_klines(symbol, period) + bars = _select_bars(bars, cutoff, count) + if not bars: + ax.set_facecolor("#12121a") + ax.text(0.5, 0.5, f"No {period} data", ha="center", va="center", color="#888") + ax.set_xticks([]) + ax.set_yticks([]) + continue + + times = [_bar_datetime(b) for b in bars] + closes = [float(b["c"]) for b in bars] + highs = [float(b["h"]) for b in bars] + lows = [float(b["l"]) for b in bars] + + ax.set_facecolor("#12121a") + ax.plot(times, closes, color="#4cc2ff", linewidth=1.2) + ax.fill_between( + times, lows, highs, + color="#4cc2ff", alpha=0.12, + ) + + levels = [ + (entry_price, "#eac147", "Entry"), + (stop_loss, "#ff6666", "SL"), + (take_profit, "#4cd97f", "TP"), + (close_price, "#c4c4ff", "Close"), + ] + for price, color, label in levels: + if price is not None: + ax.axhline(price, color=color, linewidth=0.9, linestyle="--", alpha=0.85) + ax.text(times[-1], price, label, color=color, fontsize=8, va="bottom") + + if open_dt: + ax.axvline(open_dt, color="#888", linewidth=0.8, linestyle=":", alpha=0.7) + if close_dt: + ax.axvline(close_dt, color="#aaa", linewidth=0.8, linestyle=":", alpha=0.7) + + chart_sym = ths_to_sina_chart_symbol(symbol) or symbol + ax.set_title(f"{chart_sym} {period}", color="#eaeaea", fontsize=11, pad=8) + ax.tick_params(colors="#888", labelsize=8) + for spine in ax.spines.values(): + spine.set_color("#2e2e45") + ax.xaxis.set_major_formatter(mdates.DateFormatter("%m-%d %H:%M")) + ax.grid(True, color="#1e1e30", linewidth=0.5) + plotted = True + + if not plotted: + plt.close(fig) + return None + + fig.tight_layout() + ts = datetime.now(TZ).strftime("%Y%m%d%H%M%S") + chart_sym = ths_to_sina_chart_symbol(symbol) or "chart" + filename = f"{ts}_kline_{chart_sym}.png" + path = os.path.join(upload_dir, filename) + fig.savefig(path, dpi=120, facecolor=fig.get_facecolor()) + plt.close(fig) + return filename diff --git a/kline_store.py b/modules/market/kline_store.py similarity index 100% rename from kline_store.py rename to modules/market/kline_store.py diff --git a/kline_stream.py b/modules/market/kline_stream.py similarity index 91% rename from kline_stream.py rename to modules/market/kline_stream.py index db361ca..4116a26 100644 --- a/kline_stream.py +++ b/modules/market/kline_stream.py @@ -1,139 +1,139 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""K 线 SSE 推送与后台刷新。""" -from __future__ import annotations - -import json -import logging -import queue -import threading -import time -from dataclasses import dataclass, field -from datetime import datetime -from typing import Callable, Optional -from zoneinfo import ZoneInfo - -from kline_chart import fetch_market_klines, ths_to_sina_chart_symbol -from kline_store import is_cache_fresh, load_meta, ensure_kline_tables -from market_sessions import is_trading_session - -logger = logging.getLogger(__name__) -TZ = ZoneInfo("Asia/Shanghai") - -FAST_PERIODS = frozenset({ - "timeshare", "1m", "2m", "5m", "15m", "1h", "2h", "4h", -}) - - -def sse_format(event: str, data: dict) -> str: - return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False, default=str)}\n\n" - - -@dataclass -class KlineSubscription: - symbol: str - period: str - market_code: str = "" - sina_code: str = "" - queue: queue.Queue = field(default_factory=queue.Queue) - - -class KlineStreamHub: - def __init__(self): - self._lock = threading.Lock() - self._subs: list[KlineSubscription] = [] - - def subscribe( - self, - symbol: str, - period: str, - market_code: str = "", - sina_code: str = "", - ) -> KlineSubscription: - sub = KlineSubscription( - symbol=symbol.strip(), - period=(period or "15m").strip().lower(), - market_code=market_code.strip(), - sina_code=sina_code.strip(), - ) - with self._lock: - self._subs.append(sub) - return sub - - def unsubscribe(self, sub: KlineSubscription) -> None: - with self._lock: - try: - self._subs.remove(sub) - except ValueError: - pass - - def _snapshot_subs(self) -> list[KlineSubscription]: - with self._lock: - return list(self._subs) - - def publish(self, sub: KlineSubscription, event: str, data: dict) -> None: - try: - sub.queue.put_nowait({"event": event, "data": data}) - except queue.Full: - pass - - def _should_refresh(self, sub: KlineSubscription, db_path: str) -> bool: - chart_sym = ths_to_sina_chart_symbol(sub.symbol) - if not chart_sym: - return False - if is_trading_session() and sub.period in FAST_PERIODS: - return True - try: - from db_conn import connect_db - conn = connect_db(db_path) - ensure_kline_tables(conn) - meta = load_meta(conn, chart_sym, sub.period) - conn.close() - if not meta: - return True - return not is_cache_fresh(sub.period, meta.get("updated_at", "")) - except Exception as exc: - logger.warning("kline refresh check failed: %s", exc) - return True - - def worker_loop( - self, - db_path: str, - quote_fn: Callable[..., dict], - get_mode_fn: Optional[Callable[[], str]] = None, - ) -> None: - while True: - try: - subs = self._snapshot_subs() - for sub in subs: - if not self._should_refresh(sub, db_path): - continue - try: - kline_data = fetch_market_klines( - sub.symbol, - sub.period, - db_path, - force_remote=True, - prefer_ctp=False, - ) - if kline_data.get("bars"): - self.publish(sub, "kline", kline_data) - quote_data = quote_fn( - sub.symbol, sub.market_code, sub.sina_code, - ) - if quote_data: - self.publish(sub, "quote", quote_data) - except Exception as exc: - logger.warning( - "kline stream refresh %s %s: %s", - sub.symbol, sub.period, exc, - ) - except Exception as exc: - logger.warning("kline stream worker: %s", exc) - time.sleep(1) - - -kline_hub = KlineStreamHub() +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""K 线 SSE 推送与后台刷新。""" +from __future__ import annotations + +import json +import logging +import queue +import threading +import time +from dataclasses import dataclass, field +from datetime import datetime +from typing import Callable, Optional +from zoneinfo import ZoneInfo + +from modules.market.kline_chart import fetch_market_klines, ths_to_sina_chart_symbol +from modules.market.kline_store import is_cache_fresh, load_meta, ensure_kline_tables +from modules.market.market_sessions import is_trading_session + +logger = logging.getLogger(__name__) +TZ = ZoneInfo("Asia/Shanghai") + +FAST_PERIODS = frozenset({ + "timeshare", "1m", "2m", "5m", "15m", "1h", "2h", "4h", +}) + + +def sse_format(event: str, data: dict) -> str: + return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False, default=str)}\n\n" + + +@dataclass +class KlineSubscription: + symbol: str + period: str + market_code: str = "" + sina_code: str = "" + queue: queue.Queue = field(default_factory=queue.Queue) + + +class KlineStreamHub: + def __init__(self): + self._lock = threading.Lock() + self._subs: list[KlineSubscription] = [] + + def subscribe( + self, + symbol: str, + period: str, + market_code: str = "", + sina_code: str = "", + ) -> KlineSubscription: + sub = KlineSubscription( + symbol=symbol.strip(), + period=(period or "15m").strip().lower(), + market_code=market_code.strip(), + sina_code=sina_code.strip(), + ) + with self._lock: + self._subs.append(sub) + return sub + + def unsubscribe(self, sub: KlineSubscription) -> None: + with self._lock: + try: + self._subs.remove(sub) + except ValueError: + pass + + def _snapshot_subs(self) -> list[KlineSubscription]: + with self._lock: + return list(self._subs) + + def publish(self, sub: KlineSubscription, event: str, data: dict) -> None: + try: + sub.queue.put_nowait({"event": event, "data": data}) + except queue.Full: + pass + + def _should_refresh(self, sub: KlineSubscription, db_path: str) -> bool: + chart_sym = ths_to_sina_chart_symbol(sub.symbol) + if not chart_sym: + return False + if is_trading_session() and sub.period in FAST_PERIODS: + return True + try: + from modules.core.db_conn import connect_db + conn = connect_db(db_path) + ensure_kline_tables(conn) + meta = load_meta(conn, chart_sym, sub.period) + conn.close() + if not meta: + return True + return not is_cache_fresh(sub.period, meta.get("updated_at", "")) + except Exception as exc: + logger.warning("kline refresh check failed: %s", exc) + return True + + def worker_loop( + self, + db_path: str, + quote_fn: Callable[..., dict], + get_mode_fn: Optional[Callable[[], str]] = None, + ) -> None: + while True: + try: + subs = self._snapshot_subs() + for sub in subs: + if not self._should_refresh(sub, db_path): + continue + try: + kline_data = fetch_market_klines( + sub.symbol, + sub.period, + db_path, + force_remote=True, + prefer_ctp=False, + ) + if kline_data.get("bars"): + self.publish(sub, "kline", kline_data) + quote_data = quote_fn( + sub.symbol, sub.market_code, sub.sina_code, + ) + if quote_data: + self.publish(sub, "quote", quote_data) + except Exception as exc: + logger.warning( + "kline stream refresh %s %s: %s", + sub.symbol, sub.period, exc, + ) + except Exception as exc: + logger.warning("kline stream worker: %s", exc) + time.sleep(1) + + +kline_hub = KlineStreamHub() diff --git a/market.py b/modules/market/market.py similarity index 100% rename from market.py rename to modules/market/market.py diff --git a/market_sessions.py b/modules/market/market_sessions.py similarity index 100% rename from market_sessions.py rename to modules/market/market_sessions.py diff --git a/modules/market/routes.py b/modules/market/routes.py new file mode 100644 index 0000000..b9a69e9 --- /dev/null +++ b/modules/market/routes.py @@ -0,0 +1,230 @@ +# Copyright (c) 2025-2026 马建军. All rights reserved. +"""HTTP routes for market module.""" + +from __future__ import annotations + +from datetime import date, datetime + +from flask import ( + Response, + flash, + jsonify, + redirect, + render_template, + request, + send_file, + session, + stream_with_context, + url_for, +) + + +def register(deps) -> None: + app = deps.app + login_required = deps.login_required + require_nav = deps.require_nav + get_db = deps.get_db + get_setting = deps.get_setting + set_setting = deps.set_setting + fetch_price = deps.fetch_price + send_wechat_msg = deps.send_wechat_msg + touch_stats_cache = deps.touch_stats_cache + get_stats_data = deps.get_stats_data + build_market_quote_payload = deps.build_market_quote_payload + today_str = deps.today_str + expire_old_plans = deps.expire_old_plans + TZ = deps.tz + DB_PATH = deps.db_path + UPLOAD_DIR = deps.upload_dir + OPEN_TYPES = deps.open_types + EXIT_TRIGGERS = deps.exit_triggers + BEHAVIOR_TAGS = deps.behavior_tags + KLINE_PERIODS = deps.kline_periods + KLINE_CUTOFFS = deps.kline_cutoffs + calc_holding_duration = deps.calc_holding_duration + holding_to_minutes = deps.holding_to_minutes + classify_close_result = deps.classify_close_result + calc_rr_ratio = deps.calc_rr_ratio + calc_theoretical_pnl = deps.calc_theoretical_pnl + parse_review_date_filter = deps.parse_review_date_filter + _trading_mode = deps.trading_mode + _ua_is_phone = deps.ua_is_phone + _static_asset_v = deps.static_asset_v + + from modules.core.symbols import ( + list_main_contracts_grouped, + list_recommended_symbols_grouped, + search_symbols, + ) + from modules.market.kline_chart import MARKET_PERIODS, fetch_market_klines + from modules.market.kline_stream import kline_hub, sse_format + from modules.market.market import get_quote_source_label + from queue import Empty + + @app.route("/api/symbols/search") + @login_required + def api_symbol_search(): + q = request.args.get("q", "") + conn = get_db() + try: + from modules.core.trading_context import get_account_capital, is_ctp_connected + capital = get_account_capital(conn, get_setting) + ctp_connected = is_ctp_connected(get_setting) + finally: + conn.close() + return jsonify(search_symbols(q, capital=capital, ctp_connected=ctp_connected)) + + + @app.route("/api/symbols/mains") + @login_required + def api_symbols_mains(): + return jsonify(list_main_contracts_grouped()) + + + @app.route("/api/symbols/recommended") + @login_required + def api_symbols_recommended(): + """品种下拉:仅展示当前资金下可开仓品种(与下方可开仓品种表一致)。""" + from modules.trading.recommend_store import recommend_payload + from modules.core.trading_context import ( + get_fixed_lots, + get_max_margin_pct, + get_recommend_capital, + get_sizing_mode, + get_trading_mode, + ) + + conn = get_db() + try: + capital = get_recommend_capital(conn, get_setting) + payload = recommend_payload( + conn, + live_capital=capital, + max_margin_pct=get_max_margin_pct(get_setting), + trading_mode=get_trading_mode(get_setting), + sizing_mode=get_sizing_mode(get_setting), + fixed_lots=get_fixed_lots(get_setting), + ) + return jsonify(list_recommended_symbols_grouped(payload.get("rows") or [])) + finally: + conn.close() + + @app.route("/market") + @login_required + @require_nav("market") + def market_page(): + symbol = request.args.get("symbol", "").strip() + period = request.args.get("period", "15m").strip() + valid = {p["key"] for p in MARKET_PERIODS} + if period not in valid: + period = "15m" + ctp_st = {} + try: + from modules.ctp.vnpy_bridge import ctp_status + from modules.core.trading_context import get_trading_mode + + ctp_st = ctp_status(get_trading_mode(get_setting)) + except Exception: + pass + return render_template( + "market.html", + symbol=symbol, + period=period, + market_periods=MARKET_PERIODS, + quote_label=get_quote_source_label(ctp_connected=bool(ctp_st.get("connected"))), + ctp_connected=bool(ctp_st.get("connected")), + ) + + + @app.route("/api/kline") + @login_required + def api_kline(): + symbol = request.args.get("symbol", "").strip() + period = request.args.get("period", "15m").strip() + if not symbol: + return jsonify({"error": "请提供合约代码"}), 400 + try: + from modules.core.trading_context import get_trading_mode + + data = fetch_market_klines( + symbol, period, DB_PATH, prefer_ctp=False, + ) + except Exception as exc: + app.logger.warning("kline api failed: %s", exc) + return jsonify({"error": str(exc)}), 500 + if not data.get("chart_symbol"): + return jsonify({"error": "无法识别合约代码"}), 400 + if not data.get("bars"): + return jsonify({"error": "未获取到K线数据,请稍后重试或更换合约"}), 404 + return jsonify(data) + + + @app.route("/api/kline/stream") + @login_required + def api_kline_stream(): + from queue import Empty + + symbol = request.args.get("symbol", "").strip() + period = request.args.get("period", "15m").strip() + market_code = request.args.get("market_code", "").strip() + sina_code = request.args.get("sina_code", "").strip() + if not symbol: + return jsonify({"error": "请提供合约代码"}), 400 + + def generate(): + sub = kline_hub.subscribe(symbol, period, market_code, sina_code) + try: + kline_data = fetch_market_klines( + symbol, period, DB_PATH, prefer_ctp=False, + ) + if kline_data.get("bars"): + yield sse_format("kline", kline_data) + yield sse_format( + "quote", + build_market_quote_payload( + symbol, market_code, sina_code, prefer_sina=True, + ), + ) + while True: + try: + msg = sub.queue.get(timeout=20) + yield sse_format(msg["event"], msg["data"]) + except Empty: + yield ": heartbeat\n\n" + finally: + kline_hub.unsubscribe(sub) + + return Response( + stream_with_context(generate()), + mimetype="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) + + + @app.route("/api/market_quote") + @login_required + def api_market_quote(): + symbol = request.args.get("symbol", "").strip() + market_code = request.args.get("market_code", "").strip() + sina_code = request.args.get("sina_code", "").strip() + if not symbol and not market_code: + return jsonify({"error": "请提供合约"}), 400 + return jsonify(build_market_quote_payload( + symbol, market_code, sina_code, prefer_sina=True, + )) + + + @app.route("/contract") + @login_required + def contract_profile_page(): + return redirect(url_for("positions")) + + + @app.route("/api/contract_profile") + @login_required + def api_contract_profile(): + return jsonify({"error": "品种简介功能已移除"}), 404 diff --git a/modules/notify/__init__.py b/modules/notify/__init__.py new file mode 100644 index 0000000..3ca1367 --- /dev/null +++ b/modules/notify/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) 2025-2026 马建军. All rights reserved. + +from modules.notify.routes import register + +__all__ = ["register"] diff --git a/ai_client.py b/modules/notify/ai_client.py similarity index 100% rename from ai_client.py rename to modules/notify/ai_client.py diff --git a/ai_messages.py b/modules/notify/ai_messages.py similarity index 100% rename from ai_messages.py rename to modules/notify/ai_messages.py diff --git a/ai_worker.py b/modules/notify/ai_worker.py similarity index 91% rename from ai_worker.py rename to modules/notify/ai_worker.py index 0f2ad64..c62ff37 100644 --- a/ai_worker.py +++ b/modules/notify/ai_worker.py @@ -1,173 +1,173 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""AI 后台:开仓/平仓分析、日终持仓报告。""" -from __future__ import annotations - -import json -import logging -import threading -from datetime import datetime -from typing import Callable, Optional -from zoneinfo import ZoneInfo - -logger = logging.getLogger(__name__) -TZ = ZoneInfo("Asia/Shanghai") -DAILY_REPORT_KEY = "ai_daily_report_last_date" - - -def schedule_ai_event_analysis( - *, - db_path: str, - get_setting_fn: Callable[[str, str], str], - kind: str, - title: str, - payload: dict, - send_wechat_fn: Callable[[str], None] | None = None, -) -> None: - """后台线程:调用 AI 并写入 ai_messages。""" - if not (get_setting_fn("ai_enabled", "0") or "0").strip() in ("1", "true", "yes"): - return - - def _run() -> None: - from ai_client import analyze_trading_event - from ai_messages import insert_ai_message - from db_conn import connect_db - - ok, content = analyze_trading_event( - get_setting=get_setting_fn, - event_kind=kind, - payload=payload, - ) - if not ok: - content = f"⚠ {content}" - try: - conn = connect_db(db_path) - try: - insert_ai_message( - conn, - kind=kind, - title=title, - content=content, - meta=payload, - ) - conn.commit() - finally: - conn.close() - if send_wechat_fn and ok: - send_wechat_fn(f"🤖 AI 分析 · {title}\n\n{content[:1800]}") - except Exception as exc: - logger.warning("AI event analysis failed: %s", exc) - - threading.Thread(target=_run, daemon=True, name="ai-event").start() - - -def _today_trading_summary(conn, day: str) -> dict: - rows = conn.execute( - """SELECT symbol, symbol_name, direction, pnl_net, result, close_time - FROM trade_logs WHERE close_time LIKE ? ORDER BY id ASC""", - (f"{day}%",), - ).fetchall() - wins = losses = 0 - pnl_sum = 0.0 - trades = [] - for r in rows: - pnl = float(r["pnl_net"] or 0) - pnl_sum += pnl - if pnl >= 0: - wins += 1 - else: - losses += 1 - trades.append(dict(r)) - positions = conn.execute( - """SELECT symbol, symbol_name, direction, lots, entry_price, stop_loss, take_profit, monitor_type - FROM trade_order_monitors WHERE status='active'""" - ).fetchall() - return { - "date": day, - "trade_count": len(trades), - "wins": wins, - "losses": losses, - "pnl_net_total": round(pnl_sum, 2), - "trades": trades[:20], - "active_positions": [dict(p) for p in positions], - } - - -def maybe_run_daily_ai_report( - *, - db_path: str, - get_setting_fn: Callable[[str, str], str], - set_setting_fn: Callable[[str, str], None], - send_wechat_fn: Callable[[str], None] | None = None, -) -> None: - if not (get_setting_fn("ai_enabled", "0") or "0").strip() in ("1", "true", "yes"): - return - if (get_setting_fn("ai_daily_report_enabled", "1") or "1").strip() not in ("1", "true", "yes"): - return - now = datetime.now(TZ) - day = now.strftime("%Y-%m-%d") - if get_setting_fn(DAILY_REPORT_KEY, "") == day: - return - try: - hour = int(float(get_setting_fn("ai_daily_report_hour", "15") or 15)) - minute = int(float(get_setting_fn("ai_daily_report_minute", "5") or 5)) - except (TypeError, ValueError): - hour, minute = 15, 5 - if (now.hour, now.minute) < (hour, minute): - return - - from ai_client import analyze_trading_event - from ai_messages import insert_ai_message - from db_conn import connect_db - - try: - conn = connect_db(db_path) - try: - summary = _today_trading_summary(conn, day) - ok, content = analyze_trading_event( - get_setting=get_setting_fn, - event_kind="daily_report", - payload=summary, - ) - title = f"{day} 日终持仓与交易报告" - if not ok: - content = f"⚠ {content}" - insert_ai_message(conn, kind="daily_report", title=title, content=content, meta=summary) - conn.commit() - set_setting_fn(DAILY_REPORT_KEY, day) - if send_wechat_fn and ok: - send_wechat_fn(f"🤖 {title}\n\n{content[:1800]}") - finally: - conn.close() - except Exception as exc: - logger.warning("AI daily report failed: %s", exc) - - -def start_ai_worker( - *, - db_path: str, - get_setting_fn: Callable[[str, str], str], - set_setting_fn: Callable[[str, str], None], - send_wechat_fn: Callable[[str], None] | None = None, - interval_sec: int = 60, -) -> None: - import time - - def _loop() -> None: - time.sleep(30) - while True: - try: - maybe_run_daily_ai_report( - db_path=db_path, - get_setting_fn=get_setting_fn, - set_setting_fn=set_setting_fn, - send_wechat_fn=send_wechat_fn, - ) - except Exception as exc: - logger.debug("ai worker: %s", exc) - time.sleep(max(30, interval_sec)) - - threading.Thread(target=_loop, daemon=True, name="ai-worker").start() +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""AI 后台:开仓/平仓分析、日终持仓报告。""" +from __future__ import annotations + +import json +import logging +import threading +from datetime import datetime +from typing import Callable, Optional +from zoneinfo import ZoneInfo + +logger = logging.getLogger(__name__) +TZ = ZoneInfo("Asia/Shanghai") +DAILY_REPORT_KEY = "ai_daily_report_last_date" + + +def schedule_ai_event_analysis( + *, + db_path: str, + get_setting_fn: Callable[[str, str], str], + kind: str, + title: str, + payload: dict, + send_wechat_fn: Callable[[str], None] | None = None, +) -> None: + """后台线程:调用 AI 并写入 ai_messages。""" + if not (get_setting_fn("ai_enabled", "0") or "0").strip() in ("1", "true", "yes"): + return + + def _run() -> None: + from modules.notify.ai_client import analyze_trading_event + from modules.notify.ai_messages import insert_ai_message + from modules.core.db_conn import connect_db + + ok, content = analyze_trading_event( + get_setting=get_setting_fn, + event_kind=kind, + payload=payload, + ) + if not ok: + content = f"⚠ {content}" + try: + conn = connect_db(db_path) + try: + insert_ai_message( + conn, + kind=kind, + title=title, + content=content, + meta=payload, + ) + conn.commit() + finally: + conn.close() + if send_wechat_fn and ok: + send_wechat_fn(f"🤖 AI 分析 · {title}\n\n{content[:1800]}") + except Exception as exc: + logger.warning("AI event analysis failed: %s", exc) + + threading.Thread(target=_run, daemon=True, name="ai-event").start() + + +def _today_trading_summary(conn, day: str) -> dict: + rows = conn.execute( + """SELECT symbol, symbol_name, direction, pnl_net, result, close_time + FROM trade_logs WHERE close_time LIKE ? ORDER BY id ASC""", + (f"{day}%",), + ).fetchall() + wins = losses = 0 + pnl_sum = 0.0 + trades = [] + for r in rows: + pnl = float(r["pnl_net"] or 0) + pnl_sum += pnl + if pnl >= 0: + wins += 1 + else: + losses += 1 + trades.append(dict(r)) + positions = conn.execute( + """SELECT symbol, symbol_name, direction, lots, entry_price, stop_loss, take_profit, monitor_type + FROM trade_order_monitors WHERE status='active'""" + ).fetchall() + return { + "date": day, + "trade_count": len(trades), + "wins": wins, + "losses": losses, + "pnl_net_total": round(pnl_sum, 2), + "trades": trades[:20], + "active_positions": [dict(p) for p in positions], + } + + +def maybe_run_daily_ai_report( + *, + db_path: str, + get_setting_fn: Callable[[str, str], str], + set_setting_fn: Callable[[str, str], None], + send_wechat_fn: Callable[[str], None] | None = None, +) -> None: + if not (get_setting_fn("ai_enabled", "0") or "0").strip() in ("1", "true", "yes"): + return + if (get_setting_fn("ai_daily_report_enabled", "1") or "1").strip() not in ("1", "true", "yes"): + return + now = datetime.now(TZ) + day = now.strftime("%Y-%m-%d") + if get_setting_fn(DAILY_REPORT_KEY, "") == day: + return + try: + hour = int(float(get_setting_fn("ai_daily_report_hour", "15") or 15)) + minute = int(float(get_setting_fn("ai_daily_report_minute", "5") or 5)) + except (TypeError, ValueError): + hour, minute = 15, 5 + if (now.hour, now.minute) < (hour, minute): + return + + from modules.notify.ai_client import analyze_trading_event + from modules.notify.ai_messages import insert_ai_message + from modules.core.db_conn import connect_db + + try: + conn = connect_db(db_path) + try: + summary = _today_trading_summary(conn, day) + ok, content = analyze_trading_event( + get_setting=get_setting_fn, + event_kind="daily_report", + payload=summary, + ) + title = f"{day} 日终持仓与交易报告" + if not ok: + content = f"⚠ {content}" + insert_ai_message(conn, kind="daily_report", title=title, content=content, meta=summary) + conn.commit() + set_setting_fn(DAILY_REPORT_KEY, day) + if send_wechat_fn and ok: + send_wechat_fn(f"🤖 {title}\n\n{content[:1800]}") + finally: + conn.close() + except Exception as exc: + logger.warning("AI daily report failed: %s", exc) + + +def start_ai_worker( + *, + db_path: str, + get_setting_fn: Callable[[str, str], str], + set_setting_fn: Callable[[str, str], None], + send_wechat_fn: Callable[[str], None] | None = None, + interval_sec: int = 60, +) -> None: + import time + + def _loop() -> None: + time.sleep(30) + while True: + try: + maybe_run_daily_ai_report( + db_path=db_path, + get_setting_fn=get_setting_fn, + set_setting_fn=set_setting_fn, + send_wechat_fn=send_wechat_fn, + ) + except Exception as exc: + logger.debug("ai worker: %s", exc) + time.sleep(max(30, interval_sec)) + + threading.Thread(target=_loop, daemon=True, name="ai-worker").start() diff --git a/modules/notify/routes.py b/modules/notify/routes.py new file mode 100644 index 0000000..ca979e2 --- /dev/null +++ b/modules/notify/routes.py @@ -0,0 +1,65 @@ +# Copyright (c) 2025-2026 马建军. All rights reserved. +"""HTTP routes for notify module.""" + +from __future__ import annotations + +from datetime import date, datetime + +from flask import ( + Response, + flash, + jsonify, + redirect, + render_template, + request, + send_file, + session, + stream_with_context, + url_for, +) + + +def register(deps) -> None: + app = deps.app + login_required = deps.login_required + require_nav = deps.require_nav + get_db = deps.get_db + get_setting = deps.get_setting + set_setting = deps.set_setting + fetch_price = deps.fetch_price + send_wechat_msg = deps.send_wechat_msg + touch_stats_cache = deps.touch_stats_cache + get_stats_data = deps.get_stats_data + build_market_quote_payload = deps.build_market_quote_payload + today_str = deps.today_str + expire_old_plans = deps.expire_old_plans + TZ = deps.tz + DB_PATH = deps.db_path + UPLOAD_DIR = deps.upload_dir + OPEN_TYPES = deps.open_types + EXIT_TRIGGERS = deps.exit_triggers + BEHAVIOR_TAGS = deps.behavior_tags + KLINE_PERIODS = deps.kline_periods + KLINE_CUTOFFS = deps.kline_cutoffs + calc_holding_duration = deps.calc_holding_duration + holding_to_minutes = deps.holding_to_minutes + classify_close_result = deps.classify_close_result + calc_rr_ratio = deps.calc_rr_ratio + calc_theoretical_pnl = deps.calc_theoretical_pnl + parse_review_date_filter = deps.parse_review_date_filter + _trading_mode = deps.trading_mode + _ua_is_phone = deps.ua_is_phone + _static_asset_v = deps.static_asset_v + + @app.route("/ai") + @login_required + @require_nav("ai") + def ai_messages_page(): + from modules.notify.ai_messages import list_ai_messages + + conn = get_db() + try: + messages = list_ai_messages(conn, limit=100) + finally: + conn.close() + return render_template("ai_messages.html", messages=messages) diff --git a/wechat_notify.py b/modules/notify/wechat_notify.py similarity index 100% rename from wechat_notify.py rename to modules/notify/wechat_notify.py diff --git a/modules/plans/__init__.py b/modules/plans/__init__.py new file mode 100644 index 0000000..49a0c13 --- /dev/null +++ b/modules/plans/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) 2025-2026 马建军. All rights reserved. + +from modules.plans.routes import register + +__all__ = ["register"] diff --git a/modules/plans/routes.py b/modules/plans/routes.py new file mode 100644 index 0000000..0cf84b2 --- /dev/null +++ b/modules/plans/routes.py @@ -0,0 +1,167 @@ +# Copyright (c) 2025-2026 马建军. All rights reserved. +"""HTTP routes for plans module.""" + +from __future__ import annotations + +from datetime import date, datetime + +from flask import ( + Response, + flash, + jsonify, + redirect, + render_template, + request, + send_file, + session, + stream_with_context, + url_for, +) + + +def register(deps) -> None: + app = deps.app + login_required = deps.login_required + require_nav = deps.require_nav + get_db = deps.get_db + get_setting = deps.get_setting + set_setting = deps.set_setting + fetch_price = deps.fetch_price + send_wechat_msg = deps.send_wechat_msg + touch_stats_cache = deps.touch_stats_cache + get_stats_data = deps.get_stats_data + build_market_quote_payload = deps.build_market_quote_payload + today_str = deps.today_str + expire_old_plans = deps.expire_old_plans + TZ = deps.tz + DB_PATH = deps.db_path + UPLOAD_DIR = deps.upload_dir + OPEN_TYPES = deps.open_types + EXIT_TRIGGERS = deps.exit_triggers + BEHAVIOR_TAGS = deps.behavior_tags + KLINE_PERIODS = deps.kline_periods + KLINE_CUTOFFS = deps.kline_cutoffs + calc_holding_duration = deps.calc_holding_duration + holding_to_minutes = deps.holding_to_minutes + classify_close_result = deps.classify_close_result + calc_rr_ratio = deps.calc_rr_ratio + calc_theoretical_pnl = deps.calc_theoretical_pnl + parse_review_date_filter = deps.parse_review_date_filter + _trading_mode = deps.trading_mode + _ua_is_phone = deps.ua_is_phone + _static_asset_v = deps.static_asset_v + + @app.route("/api/plan_prices") + @login_required + def api_plan_prices(): + """今日计划:批量现价与距决策区间上/下沿距离。""" + today = today_str() + conn = get_db() + rows = conn.execute( + "SELECT id, symbol, market_code, sina_code, zone_upper, zone_lower " + "FROM order_plans WHERE plan_date=? AND status IN ('planned', 'active')", + (today,), + ).fetchall() + conn.close() + out = [] + for r in rows: + sym = r["symbol"] + market = r["market_code"] or "" + sina = r["sina_code"] or "" + upper = float(r["zone_upper"]) + lower = float(r["zone_lower"]) + price = fetch_price(sym, market, sina) + dist_upper = None + dist_lower = None + in_zone = False + if price is not None: + dist_upper = round(upper - price, 2) + dist_lower = round(price - lower, 2) + in_zone = lower <= price <= upper + out.append({ + "id": r["id"], + "price": price, + "dist_upper": dist_upper, + "dist_lower": dist_lower, + "in_zone": in_zone, + }) + return jsonify(out) + @app.route("/plans") + @login_required + @require_nav("plans") + def plans(): + today = today_str() + start = request.args.get("start", "") + end = request.args.get("end", "") + + conn = get_db() + plan_list = conn.execute( + "SELECT * FROM order_plans WHERE plan_date=? AND status IN ('planned', 'active') ORDER BY id DESC", + (today,), + ).fetchall() + + sql = "SELECT * FROM order_plans WHERE plan_date < ? OR status IN ('closed', 'expired')" + params: list = [today] + if start: + sql += " AND plan_date >= ?" + params.append(start) + if end: + sql += " AND plan_date <= ?" + params.append(end) + sql += " ORDER BY plan_date DESC, id DESC LIMIT 200" + history = conn.execute(sql, params).fetchall() + conn.close() + return render_template( + "plans.html", + plans=plan_list, + history=history, + today=today, + start=start, + end=end, + ) + + + @app.route("/add_plan", methods=["POST"]) + @login_required + def add_plan(): + d = request.form + direction = d.get("direction") + symbol = d.get("symbol", "").strip() + symbol_name = d.get("symbol_name", "").strip() + market_code = d.get("market_code", "").strip() + sina_code = d.get("sina_code", "").strip() + if not direction: + flash("请选择多空方向") + return redirect(url_for("plans")) + if not symbol or not market_code: + flash("请从下拉列表选择品种(同花顺合约代码)") + return redirect(url_for("plans")) + conn = get_db() + conn.execute( + """INSERT INTO order_plans + (symbol, symbol_name, market_code, sina_code, direction, + zone_upper, zone_lower, stop_loss, take_profit, plan_date, decision_reason) + VALUES (?,?,?,?,?,?,?,?,?,?,?)""", + ( + symbol, symbol_name, market_code, sina_code, direction, + float(d["zone_upper"]), float(d["zone_lower"]), + float(d["stop_loss"]), float(d["take_profit"]), + today_str(), + d.get("decision_reason", "").strip(), + ), + ) + conn.commit() + conn.close() + flash("开单计划已添加") + return redirect(url_for("plans")) + + + @app.route("/del_plan/") + @login_required + def del_plan(pid): + conn = get_db() + conn.execute("DELETE FROM order_plans WHERE id=?", (pid,)) + conn.commit() + conn.close() + flash("已删除") + return redirect(url_for("plans")) diff --git a/modules/records/__init__.py b/modules/records/__init__.py new file mode 100644 index 0000000..528f339 --- /dev/null +++ b/modules/records/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) 2025-2026 马建军. All rights reserved. + +from modules.records.routes import register + +__all__ = ["register"] diff --git a/modules/records/routes.py b/modules/records/routes.py new file mode 100644 index 0000000..67516a9 --- /dev/null +++ b/modules/records/routes.py @@ -0,0 +1,554 @@ +# Copyright (c) 2025-2026 马建军. All rights reserved. +"""HTTP routes for records module.""" + +from __future__ import annotations + +from datetime import date, datetime + +from flask import ( + Response, + flash, + jsonify, + redirect, + render_template, + request, + send_file, + session, + stream_with_context, + url_for, +) + + +def register(deps) -> None: + app = deps.app + login_required = deps.login_required + require_nav = deps.require_nav + get_db = deps.get_db + get_setting = deps.get_setting + set_setting = deps.set_setting + fetch_price = deps.fetch_price + send_wechat_msg = deps.send_wechat_msg + touch_stats_cache = deps.touch_stats_cache + get_stats_data = deps.get_stats_data + build_market_quote_payload = deps.build_market_quote_payload + today_str = deps.today_str + expire_old_plans = deps.expire_old_plans + TZ = deps.tz + DB_PATH = deps.db_path + UPLOAD_DIR = deps.upload_dir + OPEN_TYPES = deps.open_types + EXIT_TRIGGERS = deps.exit_triggers + BEHAVIOR_TAGS = deps.behavior_tags + KLINE_PERIODS = deps.kline_periods + KLINE_CUTOFFS = deps.kline_cutoffs + calc_holding_duration = deps.calc_holding_duration + holding_to_minutes = deps.holding_to_minutes + classify_close_result = deps.classify_close_result + calc_rr_ratio = deps.calc_rr_ratio + calc_theoretical_pnl = deps.calc_theoretical_pnl + parse_review_date_filter = deps.parse_review_date_filter + _trading_mode = deps.trading_mode + _ua_is_phone = deps.ua_is_phone + _static_asset_v = deps.static_asset_v + + from werkzeug.utils import secure_filename + from modules.core.contract_specs import calc_position_metrics + from modules.fees.fee_specs import calc_fee_breakdown, calc_round_trip_fee + from modules.market.kline_chart import generate_review_kline_chart + + @app.route("/api/position_live") + @login_required + def api_position_live(): + capital = float(get_setting("live_capital", "0") or 0) + now_iso = datetime.now(TZ).strftime("%Y-%m-%dT%H:%M") + conn = get_db() + rows = conn.execute( + "SELECT * FROM position_monitors WHERE status='active' ORDER BY id DESC" + ).fetchall() + conn.close() + out = [] + for r in rows: + sym = r["symbol"] + market = r["market_code"] or "" + sina = r["sina_code"] or "" + direction = r["direction"] + entry = float(r["entry_price"]) + sl = float(r["stop_loss"]) + tp = float(r["take_profit"]) + lots = float(r["lots"] or 1) + mark = fetch_price(sym, market, sina) + metrics = calc_position_metrics( + direction, entry, sl, tp, lots, mark, capital, sym, + ) + holding = calc_holding_duration(r["open_time"] or "", now_iso) + close_est = mark if mark is not None else entry + fee_info = calc_fee_breakdown( + sym, entry, close_est, lots, r["open_time"] or "", now_iso, + trading_mode=_trading_mode(), + ) + est_net = None + if metrics.get("float_pnl") is not None: + est_net = round(metrics["float_pnl"] - fee_info["total_fee"], 2) + out.append({ + "id": r["id"], + "symbol": r["symbol_name"] or sym, + "symbol_code": sym, + "direction": "做多" if direction == "long" else "做空", + "lots": lots, + "entry_price": entry, + "stop_loss": sl, + "take_profit": tp, + "open_time": r["open_time"], + "mark_price": mark, + "holding_duration": holding, + "est_fee": fee_info["total_fee"], + "est_fee_open": fee_info["open_fee"], + "est_fee_close": fee_info["close_fee"], + "est_fee_close_type": fee_info["close_type"], + "est_pnl_net": est_net, + **metrics, + }) + return jsonify(out) + @app.route("/close_position/", methods=["POST"]) + @login_required + def close_position(pid): + conn = get_db() + row = conn.execute("SELECT * FROM position_monitors WHERE id=?", (pid,)).fetchone() + if not row: + conn.close() + flash("持仓不存在") + return redirect(url_for("positions")) + sym = row["symbol"] + market = row["market_code"] or "" + sina = row["sina_code"] or "" + direction = row["direction"] + entry = float(row["entry_price"]) + sl = float(row["stop_loss"]) + tp = float(row["take_profit"]) + lots = float(row["lots"] or 1) + open_time = row["open_time"] or "" + close_time = datetime.now(TZ).strftime("%Y-%m-%dT%H:%M") + close_price = fetch_price(sym, market, sina) + if close_price is None: + conn.close() + flash("无法获取现价,平仓失败") + return redirect(url_for("positions")) + capital = float(get_setting("live_capital", "0") or 0) + metrics = calc_position_metrics(direction, entry, sl, tp, lots, close_price, capital, sym) + pnl = metrics.get("float_pnl") or 0.0 + fee = calc_round_trip_fee(sym, entry, close_price, lots, open_time, close_time, trading_mode=_trading_mode()) + pnl_net = round(pnl - fee, 2) + result = classify_close_result(direction, close_price, sl, tp) + minutes = holding_to_minutes(open_time, close_time) + margin_pct = metrics.get("position_pct") + from modules.trading.trade_log_lib import calc_equity_after, refresh_trade_log_equity_chain + equity_after = calc_equity_after(capital, pnl_net) + conn.execute( + """INSERT INTO trade_logs + (symbol, symbol_name, market_code, sina_code, monitor_type, direction, + entry_price, stop_loss, take_profit, close_price, lots, margin, + margin_pct, holding_minutes, open_time, close_time, pnl, fee, pnl_net, + equity_after, result) + VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""", + ( + sym, row["symbol_name"], market, sina, "持仓监控", direction, + entry, sl, tp, close_price, lots, metrics["margin"], + margin_pct, + minutes, open_time, close_time, pnl, fee, pnl_net, equity_after, result, + ), + ) + conn.execute("DELETE FROM position_monitors WHERE id=?", (pid,)) + try: + refresh_trade_log_equity_chain(conn, capital if capital > 0 else None) + except Exception as exc: + app.logger.debug("equity chain refresh after close: %s", exc) + conn.commit() + conn.close() + touch_stats_cache() + flash(f"已平仓,盈亏 {pnl:.2f} 元(扣费后 {pnl_net:.2f} 元),已记入交易记录") + return redirect(url_for("positions")) + + + @app.route("/trades") + @login_required + def trades(): + return redirect(url_for("records")) + + + @app.route("/update_trade/", methods=["POST"]) + @login_required + def update_trade(tid): + d = request.form + conn = get_db() + row = conn.execute("SELECT * FROM trade_logs WHERE id=?", (tid,)).fetchone() + if not row: + conn.close() + flash("记录不存在") + return redirect(url_for("records")) + row = dict(row) + entry = float(d.get("entry_price") or 0) + close_px = float(d.get("close_price") or 0) + lots = float(d.get("lots") or 0) + sl_raw = d.get("stop_loss") + tp_raw = d.get("take_profit") + stop_loss = float(sl_raw) if sl_raw not in (None, "") else None + take_profit = float(tp_raw) if tp_raw not in (None, "") else None + open_time = (d.get("open_time") or row.get("open_time") or "").strip() + close_time = (d.get("close_time") or row.get("close_time") or "").strip() + direction = (d.get("direction") or row.get("direction") or "long").strip() + + from modules.trading.trade_log_lib import recalc_trade_log_pnl, refresh_trade_log_equity_chain, _read_initial_capital + from modules.core.trading_context import get_trading_mode + + pnl = float(row.get("pnl") or 0) + fee = float(row.get("fee") or 0) + pnl_net = float(row.get("pnl_net") or 0) + old_entry = float(row.get("entry_price") or 0) + old_close = float(row.get("close_price") or 0) + old_lots = float(row.get("lots") or 0) + prices_changed = ( + abs(entry - old_entry) > 0.0001 + or abs(close_px - old_close) > 0.0001 + or abs(lots - old_lots) > 0.0001 + ) + if prices_changed and close_px > 0 and entry > 0 and lots > 0: + calc = recalc_trade_log_pnl( + symbol=row.get("symbol") or "", + direction=direction, + entry_price=entry, + close_price=close_px, + lots=lots, + stop_loss=stop_loss, + take_profit=take_profit, + open_time=open_time, + close_time=close_time, + trading_mode=get_trading_mode(get_setting), + ) + pnl = calc["pnl"] + fee = calc["fee"] + pnl_net = calc["pnl_net"] + + form_pnl_raw = d.get("pnl") + if form_pnl_raw not in (None, ""): + pnl = float(form_pnl_raw) + pnl_net = round(pnl - fee, 2) + + try: + holding_to_minutes = deps.holding_to_minutes + minutes = int(holding_to_minutes(open_time, close_time) or 0) + except Exception: + minutes = int(d.get("holding_minutes") or row.get("holding_minutes") or 0) + + conn.execute( + """UPDATE trade_logs SET + symbol_name=?, monitor_type=?, direction=?, + entry_price=?, stop_loss=?, take_profit=?, close_price=?, + lots=?, margin=?, holding_minutes=?, open_time=?, close_time=?, + pnl=?, fee=?, pnl_net=?, result=?, verified=1 + WHERE id=?""", + ( + d.get("symbol_name", "").strip(), + d.get("monitor_type", "").strip(), + direction, + entry, + stop_loss, + take_profit, + close_px, + lots, + float(d.get("margin") or 0), + minutes, + open_time, + close_time, + pnl, + fee, + pnl_net, + d.get("result", "").strip(), + tid, + ), + ) + try: + refresh_trade_log_equity_chain(conn, _read_initial_capital(conn)) + except Exception as exc: + app.logger.debug("equity chain refresh after trade edit: %s", exc) + conn.commit() + conn.close() + touch_stats_cache() + flash("交易记录已核对保存") + return redirect(url_for("records")) + + + @app.route("/del_trade/") + @login_required + def del_trade(tid): + conn = get_db() + conn.execute("DELETE FROM trade_logs WHERE id=?", (tid,)) + conn.commit() + conn.close() + touch_stats_cache() + flash("已删除") + return redirect(url_for("records")) + + + @app.route("/fill_review/") + @login_required + def fill_review_from_trade(tid): + conn = get_db() + row = conn.execute("SELECT * FROM trade_logs WHERE id=?", (tid,)).fetchone() + conn.close() + if not row: + flash("记录不存在") + return redirect(url_for("records")) + q = { + "symbol": row["symbol"], + "symbol_name": row["symbol_name"] or row["symbol"], + "market_code": row["market_code"] or "", + "sina_code": row["sina_code"] or "", + "direction": row["direction"], + "entry_price": row["entry_price"], + "stop_loss": row["stop_loss"], + "take_profit": row["take_profit"], + "close_price": row["close_price"], + "lots": row["lots"], + "open_time": row["open_time"], + "close_time": row["close_time"], + "pnl": row["pnl"], + } + params = {k: v for k, v in q.items() if v is not None} + return redirect(url_for("records", **params) + "#review-panel") + + + @app.route("/records") + @login_required + def records(): + preset = request.args.get("preset", "") + start = request.args.get("start", "") + end = request.args.get("end", "") + if preset: + start, end = parse_review_date_filter(preset, start, end) + + conn = get_db() + ctp_sync_info = None + sql = "SELECT * FROM review_records WHERE 1=1" + params: list = [] + if start: + sql += " AND date(close_time) >= ?" + params.append(start) + if end: + sql += " AND date(close_time) <= ?" + params.append(end) + sql += " ORDER BY id DESC LIMIT 200" + review_list = conn.execute(sql, params).fetchall() + + auto_list = conn.execute( + "SELECT * FROM trade_records ORDER BY id DESC LIMIT 30" + ).fetchall() + trade_list = conn.execute( + "SELECT * FROM trade_logs ORDER BY id DESC LIMIT 500" + ).fetchall() + from modules.trading.trade_log_lib import enrich_trades_for_records, _read_initial_capital + try: + initial_capital = _read_initial_capital(conn) + except Exception: + initial_capital = 100_000.0 + trades, equity_curve = enrich_trades_for_records( + [dict(r) for r in trade_list], + initial_capital=initial_capital, + ) + conn.close() + + trade_prefill_keys = ( + "symbol", "symbol_name", "market_code", "sina_code", "direction", + "entry_price", "stop_loss", "take_profit", "close_price", + "lots", "open_time", "close_time", "pnl", + ) + prefill = {k: request.args.get(k) for k in trade_prefill_keys if request.args.get(k)} + + return render_template( + "records.html", + reviews=review_list, + trades=trades, + equity_curve=equity_curve, + auto_records=auto_list, + ctp_sync_info=ctp_sync_info, + preset=preset, + start=start, + end=end, + prefill=prefill, + open_types=OPEN_TYPES, + exit_triggers=EXIT_TRIGGERS, + behavior_tags=BEHAVIOR_TAGS, + kline_periods=KLINE_PERIODS, + kline_cutoffs=KLINE_CUTOFFS, + ) + + + @app.route("/add_review", methods=["POST"]) + @login_required + def add_review(): + d = request.form + open_type = d.get("open_type", "").strip() + exit_trigger = d.get("exit_trigger", "").strip() + if not open_type: + flash("请选择开仓类型") + return redirect(url_for("records")) + if not exit_trigger: + flash("请选择离场触发") + return redirect(url_for("records")) + + symbol = d.get("symbol", "").strip() + symbol_name = d.get("symbol_name", "").strip() + market_code = d.get("market_code", "").strip() + sina_code = d.get("sina_code", "").strip() + if not symbol or not market_code: + flash("请从下拉列表选择品种(同花顺合约代码)") + return redirect(url_for("records")) + + screenshot = "" + f = request.files.get("screenshot") + if f and f.filename: + fname = secure_filename(f.filename) + ts = datetime.now(TZ).strftime("%Y%m%d%H%M%S") + screenshot = f"{ts}_{fname}" + f.save(os.path.join(UPLOAD_DIR, screenshot)) + + tags = [t for t in BEHAVIOR_TAGS if d.get(f"tag_{t}")] + is_emotion = 1 if tags else 0 + + def num(key: str) -> Optional[float]: + v = d.get(key, "").strip() + if not v: + return None + return float(v) + + open_time = d.get("open_time", "").strip() + close_time = d.get("close_time", "").strip() + direction = d.get("direction", "").strip() + entry_price = num("entry_price") + stop_loss = num("stop_loss") + take_profit = num("take_profit") + close_price = num("close_price") + lots = num("lots") or 1.0 + + holding = calc_holding_duration(open_time, close_time) + initial_pnl = calc_rr_ratio(direction, entry_price, stop_loss, take_profit) + actual_pnl = calc_rr_ratio(direction, entry_price, stop_loss, close_price) + + gross_pnl = num("pnl") + if gross_pnl is None and entry_price and close_price: + spec_mult = calc_position_metrics( + direction, entry_price, stop_loss, take_profit, + lots, close_price, 0, symbol, + ) + gross_pnl = spec_mult.get("float_pnl") + fee = calc_round_trip_fee( + symbol, entry_price or 0, close_price or 0, lots, open_time, close_time, + trading_mode=_trading_mode(), + ) + pnl_net = round((gross_pnl or 0) - fee, 2) if gross_pnl is not None else None + + auto_kline = bool(d.get("auto_kline")) + if auto_kline and not screenshot: + try: + generated = generate_review_kline_chart( + symbol=symbol, + periods=[d.get("kline_period1", "15m"), d.get("kline_period2", "1h")], + count=int(d.get("kline_count") or 300), + cutoff_label=d.get("kline_cutoff", "平仓时间"), + open_time=open_time, + close_time=close_time, + entry_price=entry_price, + stop_loss=stop_loss, + take_profit=take_profit, + close_price=close_price, + upload_dir=UPLOAD_DIR, + ) + if generated: + screenshot = generated + except Exception as exc: + app.logger.warning("auto kline failed: %s", exc) + + conn = get_db() + conn.execute( + """INSERT INTO review_records + (open_time, close_time, symbol, symbol_name, market_code, sina_code, + timeframe, direction, + entry_price, stop_loss, take_profit, close_price, lots, + holding_duration, initial_pnl, actual_pnl, pnl, fee, pnl_net, + open_type, expected_rr, actual_rr, exit_trigger, exit_supplement, + watch_after_breakeven, new_position_while_occupied, screenshot, + auto_kline, kline_period1, kline_period2, kline_count, kline_cutoff, + behavior_tags, is_emotion, notes) + VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""", + ( + open_time, close_time, + symbol, symbol_name, market_code, sina_code, + d.get("timeframe", "").strip(), + direction, + entry_price, stop_loss, take_profit, close_price, lots, + holding, initial_pnl, actual_pnl, gross_pnl, fee, pnl_net, + open_type, + None, + None, + exit_trigger, + d.get("exit_supplement", "").strip(), + d.get("watch_after_breakeven", "否"), + d.get("new_position_while_occupied", "否"), + screenshot, + 1 if auto_kline else 0, + d.get("kline_period1", "15m"), + d.get("kline_period2", "1h"), + int(d.get("kline_count") or 300), + d.get("kline_cutoff", "平仓时间"), + ",".join(tags), + is_emotion, + d.get("notes", "").strip(), + ), + ) + hook = getattr(app, "_risk_review_hook", None) + if hook: + hook( + conn, + ",".join(tags), + exit_trigger, + d.get("exit_supplement", "").strip(), + ) + conn.commit() + conn.close() + touch_stats_cache() + flash("复盘记录已保存") + return redirect(url_for("records")) + + + @app.route("/del_review/") + @login_required + def del_review(rid): + conn = get_db() + row = conn.execute("SELECT screenshot FROM review_records WHERE id=?", (rid,)).fetchone() + if row and row["screenshot"]: + path = os.path.join(UPLOAD_DIR, row["screenshot"]) + if os.path.isfile(path): + os.remove(path) + conn.execute("DELETE FROM review_records WHERE id=?", (rid,)) + conn.commit() + conn.close() + touch_stats_cache() + flash("已删除") + return redirect(url_for("records")) + + + @app.route("/uploads/") + @login_required + def uploaded_file(filename): + from flask import send_from_directory + return send_from_directory(UPLOAD_DIR, filename) + + + @app.route("/del_record/") + @login_required + def del_record(rid): + conn = get_db() + conn.execute("DELETE FROM trade_records WHERE id=?", (rid,)) + conn.commit() + conn.close() + flash("已删除") + return redirect(url_for("records")) diff --git a/modules/risk/__init__.py b/modules/risk/__init__.py new file mode 100644 index 0000000..ab29688 --- /dev/null +++ b/modules/risk/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) 2025-2026 马建军. All rights reserved. + +"""Account risk rules.""" + +from modules.risk.account_risk_lib import * # noqa: F401,F403 + + +def register(deps) -> None: + del deps + + +__all__ = ["register"] diff --git a/risk/account_risk_lib.py b/modules/risk/account_risk_lib.py similarity index 95% rename from risk/account_risk_lib.py rename to modules/risk/account_risk_lib.py index 10254d5..86ced4a 100644 --- a/risk/account_risk_lib.py +++ b/modules/risk/account_risk_lib.py @@ -1,450 +1,450 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""账户冷静期 / 日冻结(自 crypto_monitor 复制并简化为单账户期货版)。""" -from __future__ import annotations - -import os -import time -from datetime import datetime -from typing import Any, Callable, Optional, TypeVar -from zoneinfo import ZoneInfo - -from db_conn import OperationalError, is_missing_relation_error, rollback_if_postgres - -T = TypeVar("T") - -STATUS_NORMAL = "normal" -STATUS_FREEZE_1H = "freeze_1h" -STATUS_FREEZE_4H = "freeze_4h" -STATUS_DAILY = "freeze_daily" -STATUS_FREEZE_POSITION = "freeze_position" - -STATUS_LABELS = { - STATUS_NORMAL: "正常", - STATUS_FREEZE_1H: "1h冻结", - STATUS_FREEZE_4H: "4h冻结", - STATUS_DAILY: "日冻结", - STATUS_FREEZE_POSITION: "仓位上限冻结", -} - -MOOD_ISSUE_OPTIONS = ( - "怕踏空", "报复开仓", "盈利飘了", "拿不住单", "扛单", "重仓违规", -) - -CLOSE_SOURCE_USER = "user_instance" -CLOSE_SOURCE_TREND_STOP = "user_trend_stop" - - -def _app_tz(): - name = (os.getenv("APP_TIMEZONE") or "Asia/Shanghai").strip() - try: - return ZoneInfo(name) - except Exception: - return ZoneInfo("Asia/Shanghai") - - -def risk_control_enabled() -> bool: - raw = (os.getenv("RISK_CONTROL_ENABLED") or "true").strip().lower() - return raw in ("1", "true", "yes", "on") - - -def cooling_hours_manual() -> float: - """期货版不使用应用层冷静期(交易所自有规则),恒为 0。""" - return 0.0 - - -def cooling_hours_manual_journal() -> float: - return 0.0 - - -def manual_close_daily_limit() -> int: - try: - return max(1, int(os.getenv("RISK_MANUAL_CLOSE_DAILY_LIMIT", "2"))) - except (TypeError, ValueError): - return 2 - - -def max_active_positions() -> int: - try: - return max(1, int(os.getenv("MAX_ACTIVE_POSITIONS", "1"))) - except (TypeError, ValueError): - return 1 - - -def daily_position_limit() -> int: - """当日最多开仓次数(含已平)。""" - try: - return max(1, int(os.getenv("RISK_DAILY_POSITION_LIMIT", "5"))) - except (TypeError, ValueError): - return 5 - - -def daily_trading_risk_pct_limit() -> float: - """当日累计止损风险占权益上限(%)。""" - try: - return max(0.1, float(os.getenv("RISK_DAILY_TRADING_RISK_PCT", "2"))) - except (TypeError, ValueError): - return 2.0 - - -def trading_day_reset_hour() -> int: - try: - return max(0, min(23, int(os.getenv("TRADING_DAY_RESET_HOUR", "8")))) - except (TypeError, ValueError): - return 8 - - -_SCHEMA_READY = False - -ACCOUNT_RISK_STATE_SQL = """ -CREATE TABLE IF NOT EXISTS account_risk_state ( - id INTEGER PRIMARY KEY, - trading_day TEXT, - manual_close_count INTEGER DEFAULT 0, - cooloff_until_ms INTEGER, - cooloff_hours INTEGER, - daily_frozen INTEGER DEFAULT 0, - last_close_at_ms INTEGER, - updated_at TEXT -) -""" - - -def _account_risk_table_exists(conn) -> bool: - try: - conn.execute("SELECT 1 FROM account_risk_state WHERE id=1") - return True - except Exception as exc: - if is_missing_relation_error(exc): - rollback_if_postgres(conn) - return False - raise - - -def _db_retry(action: Callable[[], T], *, retries: int = 8, base_delay: float = 0.03) -> T: - last: OperationalError | None = None - for i in range(retries): - try: - return action() - except OperationalError as exc: - msg = str(exc).lower() - if "locked" not in msg and "serialize" not in msg and "deadlock" not in msg: - raise - last = exc - time.sleep(base_delay * (2 ** i)) - if last is not None: - raise last - raise RuntimeError("db retry failed") - - -def ensure_account_risk_schema(conn) -> None: - global _SCHEMA_READY - if _SCHEMA_READY and _account_risk_table_exists(conn): - return - _SCHEMA_READY = False - conn.execute(ACCOUNT_RISK_STATE_SQL) - conn.commit() - if not conn.execute("SELECT 1 FROM account_risk_state WHERE id=1").fetchone(): - conn.execute( - "INSERT INTO account_risk_state (id, trading_day, manual_close_count, daily_frozen) " - "VALUES (1, '', 0, 0)" - ) - conn.commit() - _SCHEMA_READY = True - - -def _row_get(row, key, default=None): - if row is None: - return default - try: - return row[key] - except (KeyError, IndexError, TypeError): - return default - - -def _now_ms(now: Optional[datetime] = None) -> int: - dt = now or datetime.now(_app_tz()) - if dt.tzinfo is None: - dt = dt.replace(tzinfo=_app_tz()) - return int(dt.timestamp() * 1000) - - -def trading_day_label(now: Optional[datetime] = None) -> str: - dt = now or datetime.now(_app_tz()) - if dt.hour < trading_day_reset_hour(): - from datetime import timedelta - dt = dt - timedelta(days=1) - return dt.date().isoformat() - - -def trading_day_start(now: Optional[datetime] = None) -> datetime: - dt = now or datetime.now(_app_tz()) - if dt.tzinfo is None: - dt = dt.replace(tzinfo=_app_tz()) - reset_h = trading_day_reset_hour() - start = dt.replace(hour=reset_h, minute=0, second=0, microsecond=0) - if dt.hour < reset_h: - from datetime import timedelta - start = start - timedelta(days=1) - return start - - -def _parse_open_time_ms(open_time: str) -> Optional[int]: - s = (open_time or "").strip().replace("T", " ")[:19] - if not s: - return None - try: - dt = datetime.strptime(s, "%Y-%m-%d %H:%M:%S") - if dt.tzinfo is None: - dt = dt.replace(tzinfo=_app_tz()) - return int(dt.timestamp() * 1000) - except ValueError: - try: - dt = datetime.strptime(s[:10], "%Y-%m-%d").replace(tzinfo=_app_tz()) - return int(dt.timestamp() * 1000) - except ValueError: - return None - - -def _opened_in_trading_day(open_time: str, now: Optional[datetime] = None) -> bool: - oms = _parse_open_time_ms(open_time) - if oms is None: - return False - return oms >= int(trading_day_start(now).timestamp() * 1000) - - -def count_daily_opens(conn, now: Optional[datetime] = None) -> int: - rows = conn.execute( - "SELECT open_time FROM trade_order_monitors " - "WHERE open_time IS NOT NULL AND trim(open_time) <> ''" - ).fetchall() - return sum(1 for r in rows if _opened_in_trading_day(r["open_time"], now)) - - -def daily_trading_risk_used_pct( - conn, equity: float, now: Optional[datetime] = None, -) -> Optional[float]: - if equity <= 0: - return None - from contract_specs import calc_position_metrics - - total = 0.0 - rows = conn.execute( - """SELECT symbol, direction, lots, entry_price, stop_loss, take_profit, open_time - FROM trade_order_monitors - WHERE open_time IS NOT NULL AND trim(open_time) <> ''""" - ).fetchall() - for r in rows: - if not _opened_in_trading_day(r["open_time"], now): - continue - entry = float(r["entry_price"] or 0) - if entry <= 0: - continue - sl = float(r["stop_loss"] if r["stop_loss"] is not None else entry) - tp = float(r["take_profit"] if r["take_profit"] is not None else entry) - lots = int(r["lots"] or 0) - if lots <= 0: - continue - m = calc_position_metrics( - r["direction"] or "long", - entry, - sl, - tp, - lots, - entry, - equity, - r["symbol"] or "", - ) - total += float(m.get("risk_amount") or 0) - if total <= 0: - return 0.0 - return round(total / equity * 100, 2) - - -def count_active_trade_monitors(conn) -> int: - try: - n = conn.execute( - "SELECT COUNT(*) FROM trade_order_monitors WHERE status='active'" - ).fetchone()[0] - return int(n or 0) - except Exception: - return 0 - - -def parse_mood_issues(raw: Any) -> list[str]: - if raw is None: - return [] - if isinstance(raw, (list, tuple)): - parts = [str(x).strip() for x in raw if str(x).strip()] - else: - parts = [x.strip() for x in str(raw).split(",") if x.strip()] - return [p for p in parts if p in MOOD_ISSUE_OPTIONS] - - -def on_user_initiated_close(conn, *, trading_day: str, now: Optional[datetime] = None) -> None: - if not risk_control_enabled(): - return - ensure_account_risk_schema(conn) - row = conn.execute("SELECT * FROM account_risk_state WHERE id=1").fetchone() - td = (trading_day or trading_day_label(now)).strip() - stored = str(_row_get(row, "trading_day") or "") - count = int(_row_get(row, "manual_close_count") or 0) - if stored != td: - count = 0 - count += 1 - close_ms = _now_ms(now) - if count >= manual_close_daily_limit(): - conn.execute( - """UPDATE account_risk_state SET trading_day=?, manual_close_count=?, - daily_frozen=1, cooloff_until_ms=NULL, cooloff_hours=NULL, - last_close_at_ms=?, updated_at=? WHERE id=1""", - (td, count, close_ms, datetime.now().strftime("%Y-%m-%d %H:%M:%S")), - ) - return - conn.execute( - """UPDATE account_risk_state SET trading_day=?, manual_close_count=?, - daily_frozen=0, cooloff_until_ms=NULL, cooloff_hours=NULL, - last_close_at_ms=?, updated_at=? WHERE id=1""", - (td, count, close_ms, datetime.now().strftime("%Y-%m-%d %H:%M:%S")), - ) - - -def on_mood_journal_freeze(conn, *, trading_day: str) -> None: - if not risk_control_enabled(): - return - ensure_account_risk_schema(conn) - td = (trading_day or trading_day_label()).strip() - conn.execute( - "UPDATE account_risk_state SET trading_day=?, daily_frozen=1, updated_at=? WHERE id=1", - (td, datetime.now().strftime("%Y-%m-%d %H:%M:%S")), - ) - - -def reduce_cooloff_after_journal(conn, *, trading_day: str, now: Optional[datetime] = None) -> None: - """期货版无应用层冷静期,保留空实现兼容旧复盘钩子。""" - del conn, trading_day, now - return - - -def get_risk_status( - conn, - *, - now: Optional[datetime] = None, - active_count: Optional[int] = None, - equity: Optional[float] = None, -) -> dict: - def _load() -> dict: - ensure_account_risk_schema(conn) - try: - row = conn.execute("SELECT * FROM account_risk_state WHERE id=1").fetchone() - except Exception as exc: - if is_missing_relation_error(exc): - global _SCHEMA_READY - _SCHEMA_READY = False - rollback_if_postgres(conn) - ensure_account_risk_schema(conn) - row = conn.execute("SELECT * FROM account_risk_state WHERE id=1").fetchone() - else: - raise - td = trading_day_label(now) - stored = str(_row_get(row, "trading_day") or "") - if stored != td: - conn.execute( - "UPDATE account_risk_state SET trading_day=?, manual_close_count=0, daily_frozen=0 WHERE id=1 AND trading_day<>?", - (td, td), - ) - conn.commit() - row = conn.execute("SELECT * FROM account_risk_state WHERE id=1").fetchone() - - now_ms = _now_ms(now) - daily = int(_row_get(row, "daily_frozen") or 0) == 1 - until = _row_get(row, "cooloff_until_ms") - if until: - conn.execute( - "UPDATE account_risk_state SET cooloff_until_ms=NULL, cooloff_hours=NULL WHERE id=1" - ) - conn.commit() - active = count_active_trade_monitors(conn) if active_count is None else int(active_count) - mx = max_active_positions() - pos_limit = active >= mx - daily_opens = count_daily_opens(conn, now) - daily_pos_lim = daily_position_limit() - daily_open_limit = daily_opens >= daily_pos_lim - daily_risk_used: Optional[float] = None - daily_risk_lim = daily_trading_risk_pct_limit() - daily_risk_limit_hit = False - if equity and float(equity) > 0: - daily_risk_used = daily_trading_risk_used_pct(conn, float(equity), now) - if daily_risk_used is not None and daily_risk_used >= daily_risk_lim: - daily_risk_limit_hit = True - - base = { - "active_count": active, - "max_active_positions": mx, - "daily_open_count": daily_opens, - "daily_position_limit": daily_pos_lim, - "daily_risk_used_pct": daily_risk_used, - "daily_trading_risk_pct_limit": daily_risk_lim, - } - - if daily: - return { - **base, - "status": STATUS_DAILY, - "status_label": STATUS_LABELS[STATUS_DAILY], - "can_trade": False, - "can_roll": False, - "reason": "当日日冻结,禁止新开仓", - } - if daily_risk_limit_hit: - return { - **base, - "status": STATUS_DAILY, - "status_label": STATUS_LABELS[STATUS_DAILY], - "can_trade": False, - "can_roll": pos_limit, - "reason": f"已达日交易风险上限 {daily_risk_used:.2f}%/{daily_risk_lim:.2f}%", - } - if daily_open_limit: - return { - **base, - "status": STATUS_DAILY, - "status_label": STATUS_LABELS[STATUS_DAILY], - "can_trade": False, - "can_roll": pos_limit, - "reason": f"已达日持仓上限 {daily_opens}/{daily_pos_lim}", - } - if pos_limit: - return { - **base, - "status": STATUS_FREEZE_POSITION, - "status_label": STATUS_LABELS[STATUS_FREEZE_POSITION], - "can_trade": False, - "can_roll": True, - "reason": f"已达仓位上限 {active}/{mx}", - } - return { - **base, - "status": STATUS_NORMAL, - "status_label": STATUS_LABELS[STATUS_NORMAL], - "can_trade": True, - "can_roll": True, - "reason": "可新开仓", - } - - return _db_retry(_load) - - -def assert_can_open( - conn, - *, - active_count: Optional[int] = None, - equity: Optional[float] = None, -) -> Optional[str]: - rs = get_risk_status(conn, active_count=active_count, equity=equity) - if not rs.get("can_trade"): - return rs.get("reason") or "当前不可开仓" - return None +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""账户冷静期 / 日冻结(自 crypto_monitor 复制并简化为单账户期货版)。""" +from __future__ import annotations + +import os +import time +from datetime import datetime +from typing import Any, Callable, Optional, TypeVar +from zoneinfo import ZoneInfo + +from modules.core.db_conn import OperationalError, is_missing_relation_error, rollback_if_postgres + +T = TypeVar("T") + +STATUS_NORMAL = "normal" +STATUS_FREEZE_1H = "freeze_1h" +STATUS_FREEZE_4H = "freeze_4h" +STATUS_DAILY = "freeze_daily" +STATUS_FREEZE_POSITION = "freeze_position" + +STATUS_LABELS = { + STATUS_NORMAL: "正常", + STATUS_FREEZE_1H: "1h冻结", + STATUS_FREEZE_4H: "4h冻结", + STATUS_DAILY: "日冻结", + STATUS_FREEZE_POSITION: "仓位上限冻结", +} + +MOOD_ISSUE_OPTIONS = ( + "怕踏空", "报复开仓", "盈利飘了", "拿不住单", "扛单", "重仓违规", +) + +CLOSE_SOURCE_USER = "user_instance" +CLOSE_SOURCE_TREND_STOP = "user_trend_stop" + + +def _app_tz(): + name = (os.getenv("APP_TIMEZONE") or "Asia/Shanghai").strip() + try: + return ZoneInfo(name) + except Exception: + return ZoneInfo("Asia/Shanghai") + + +def risk_control_enabled() -> bool: + raw = (os.getenv("RISK_CONTROL_ENABLED") or "true").strip().lower() + return raw in ("1", "true", "yes", "on") + + +def cooling_hours_manual() -> float: + """期货版不使用应用层冷静期(交易所自有规则),恒为 0。""" + return 0.0 + + +def cooling_hours_manual_journal() -> float: + return 0.0 + + +def manual_close_daily_limit() -> int: + try: + return max(1, int(os.getenv("RISK_MANUAL_CLOSE_DAILY_LIMIT", "2"))) + except (TypeError, ValueError): + return 2 + + +def max_active_positions() -> int: + try: + return max(1, int(os.getenv("MAX_ACTIVE_POSITIONS", "1"))) + except (TypeError, ValueError): + return 1 + + +def daily_position_limit() -> int: + """当日最多开仓次数(含已平)。""" + try: + return max(1, int(os.getenv("RISK_DAILY_POSITION_LIMIT", "5"))) + except (TypeError, ValueError): + return 5 + + +def daily_trading_risk_pct_limit() -> float: + """当日累计止损风险占权益上限(%)。""" + try: + return max(0.1, float(os.getenv("RISK_DAILY_TRADING_RISK_PCT", "2"))) + except (TypeError, ValueError): + return 2.0 + + +def trading_day_reset_hour() -> int: + try: + return max(0, min(23, int(os.getenv("TRADING_DAY_RESET_HOUR", "8")))) + except (TypeError, ValueError): + return 8 + + +_SCHEMA_READY = False + +ACCOUNT_RISK_STATE_SQL = """ +CREATE TABLE IF NOT EXISTS account_risk_state ( + id INTEGER PRIMARY KEY, + trading_day TEXT, + manual_close_count INTEGER DEFAULT 0, + cooloff_until_ms INTEGER, + cooloff_hours INTEGER, + daily_frozen INTEGER DEFAULT 0, + last_close_at_ms INTEGER, + updated_at TEXT +) +""" + + +def _account_risk_table_exists(conn) -> bool: + try: + conn.execute("SELECT 1 FROM account_risk_state WHERE id=1") + return True + except Exception as exc: + if is_missing_relation_error(exc): + rollback_if_postgres(conn) + return False + raise + + +def _db_retry(action: Callable[[], T], *, retries: int = 8, base_delay: float = 0.03) -> T: + last: OperationalError | None = None + for i in range(retries): + try: + return action() + except OperationalError as exc: + msg = str(exc).lower() + if "locked" not in msg and "serialize" not in msg and "deadlock" not in msg: + raise + last = exc + time.sleep(base_delay * (2 ** i)) + if last is not None: + raise last + raise RuntimeError("db retry failed") + + +def ensure_account_risk_schema(conn) -> None: + global _SCHEMA_READY + if _SCHEMA_READY and _account_risk_table_exists(conn): + return + _SCHEMA_READY = False + conn.execute(ACCOUNT_RISK_STATE_SQL) + conn.commit() + if not conn.execute("SELECT 1 FROM account_risk_state WHERE id=1").fetchone(): + conn.execute( + "INSERT INTO account_risk_state (id, trading_day, manual_close_count, daily_frozen) " + "VALUES (1, '', 0, 0)" + ) + conn.commit() + _SCHEMA_READY = True + + +def _row_get(row, key, default=None): + if row is None: + return default + try: + return row[key] + except (KeyError, IndexError, TypeError): + return default + + +def _now_ms(now: Optional[datetime] = None) -> int: + dt = now or datetime.now(_app_tz()) + if dt.tzinfo is None: + dt = dt.replace(tzinfo=_app_tz()) + return int(dt.timestamp() * 1000) + + +def trading_day_label(now: Optional[datetime] = None) -> str: + dt = now or datetime.now(_app_tz()) + if dt.hour < trading_day_reset_hour(): + from datetime import timedelta + dt = dt - timedelta(days=1) + return dt.date().isoformat() + + +def trading_day_start(now: Optional[datetime] = None) -> datetime: + dt = now or datetime.now(_app_tz()) + if dt.tzinfo is None: + dt = dt.replace(tzinfo=_app_tz()) + reset_h = trading_day_reset_hour() + start = dt.replace(hour=reset_h, minute=0, second=0, microsecond=0) + if dt.hour < reset_h: + from datetime import timedelta + start = start - timedelta(days=1) + return start + + +def _parse_open_time_ms(open_time: str) -> Optional[int]: + s = (open_time or "").strip().replace("T", " ")[:19] + if not s: + return None + try: + dt = datetime.strptime(s, "%Y-%m-%d %H:%M:%S") + if dt.tzinfo is None: + dt = dt.replace(tzinfo=_app_tz()) + return int(dt.timestamp() * 1000) + except ValueError: + try: + dt = datetime.strptime(s[:10], "%Y-%m-%d").replace(tzinfo=_app_tz()) + return int(dt.timestamp() * 1000) + except ValueError: + return None + + +def _opened_in_trading_day(open_time: str, now: Optional[datetime] = None) -> bool: + oms = _parse_open_time_ms(open_time) + if oms is None: + return False + return oms >= int(trading_day_start(now).timestamp() * 1000) + + +def count_daily_opens(conn, now: Optional[datetime] = None) -> int: + rows = conn.execute( + "SELECT open_time FROM trade_order_monitors " + "WHERE open_time IS NOT NULL AND trim(open_time) <> ''" + ).fetchall() + return sum(1 for r in rows if _opened_in_trading_day(r["open_time"], now)) + + +def daily_trading_risk_used_pct( + conn, equity: float, now: Optional[datetime] = None, +) -> Optional[float]: + if equity <= 0: + return None + from modules.core.contract_specs import calc_position_metrics + + total = 0.0 + rows = conn.execute( + """SELECT symbol, direction, lots, entry_price, stop_loss, take_profit, open_time + FROM trade_order_monitors + WHERE open_time IS NOT NULL AND trim(open_time) <> ''""" + ).fetchall() + for r in rows: + if not _opened_in_trading_day(r["open_time"], now): + continue + entry = float(r["entry_price"] or 0) + if entry <= 0: + continue + sl = float(r["stop_loss"] if r["stop_loss"] is not None else entry) + tp = float(r["take_profit"] if r["take_profit"] is not None else entry) + lots = int(r["lots"] or 0) + if lots <= 0: + continue + m = calc_position_metrics( + r["direction"] or "long", + entry, + sl, + tp, + lots, + entry, + equity, + r["symbol"] or "", + ) + total += float(m.get("risk_amount") or 0) + if total <= 0: + return 0.0 + return round(total / equity * 100, 2) + + +def count_active_trade_monitors(conn) -> int: + try: + n = conn.execute( + "SELECT COUNT(*) FROM trade_order_monitors WHERE status='active'" + ).fetchone()[0] + return int(n or 0) + except Exception: + return 0 + + +def parse_mood_issues(raw: Any) -> list[str]: + if raw is None: + return [] + if isinstance(raw, (list, tuple)): + parts = [str(x).strip() for x in raw if str(x).strip()] + else: + parts = [x.strip() for x in str(raw).split(",") if x.strip()] + return [p for p in parts if p in MOOD_ISSUE_OPTIONS] + + +def on_user_initiated_close(conn, *, trading_day: str, now: Optional[datetime] = None) -> None: + if not risk_control_enabled(): + return + ensure_account_risk_schema(conn) + row = conn.execute("SELECT * FROM account_risk_state WHERE id=1").fetchone() + td = (trading_day or trading_day_label(now)).strip() + stored = str(_row_get(row, "trading_day") or "") + count = int(_row_get(row, "manual_close_count") or 0) + if stored != td: + count = 0 + count += 1 + close_ms = _now_ms(now) + if count >= manual_close_daily_limit(): + conn.execute( + """UPDATE account_risk_state SET trading_day=?, manual_close_count=?, + daily_frozen=1, cooloff_until_ms=NULL, cooloff_hours=NULL, + last_close_at_ms=?, updated_at=? WHERE id=1""", + (td, count, close_ms, datetime.now().strftime("%Y-%m-%d %H:%M:%S")), + ) + return + conn.execute( + """UPDATE account_risk_state SET trading_day=?, manual_close_count=?, + daily_frozen=0, cooloff_until_ms=NULL, cooloff_hours=NULL, + last_close_at_ms=?, updated_at=? WHERE id=1""", + (td, count, close_ms, datetime.now().strftime("%Y-%m-%d %H:%M:%S")), + ) + + +def on_mood_journal_freeze(conn, *, trading_day: str) -> None: + if not risk_control_enabled(): + return + ensure_account_risk_schema(conn) + td = (trading_day or trading_day_label()).strip() + conn.execute( + "UPDATE account_risk_state SET trading_day=?, daily_frozen=1, updated_at=? WHERE id=1", + (td, datetime.now().strftime("%Y-%m-%d %H:%M:%S")), + ) + + +def reduce_cooloff_after_journal(conn, *, trading_day: str, now: Optional[datetime] = None) -> None: + """期货版无应用层冷静期,保留空实现兼容旧复盘钩子。""" + del conn, trading_day, now + return + + +def get_risk_status( + conn, + *, + now: Optional[datetime] = None, + active_count: Optional[int] = None, + equity: Optional[float] = None, +) -> dict: + def _load() -> dict: + ensure_account_risk_schema(conn) + try: + row = conn.execute("SELECT * FROM account_risk_state WHERE id=1").fetchone() + except Exception as exc: + if is_missing_relation_error(exc): + global _SCHEMA_READY + _SCHEMA_READY = False + rollback_if_postgres(conn) + ensure_account_risk_schema(conn) + row = conn.execute("SELECT * FROM account_risk_state WHERE id=1").fetchone() + else: + raise + td = trading_day_label(now) + stored = str(_row_get(row, "trading_day") or "") + if stored != td: + conn.execute( + "UPDATE account_risk_state SET trading_day=?, manual_close_count=0, daily_frozen=0 WHERE id=1 AND trading_day<>?", + (td, td), + ) + conn.commit() + row = conn.execute("SELECT * FROM account_risk_state WHERE id=1").fetchone() + + now_ms = _now_ms(now) + daily = int(_row_get(row, "daily_frozen") or 0) == 1 + until = _row_get(row, "cooloff_until_ms") + if until: + conn.execute( + "UPDATE account_risk_state SET cooloff_until_ms=NULL, cooloff_hours=NULL WHERE id=1" + ) + conn.commit() + active = count_active_trade_monitors(conn) if active_count is None else int(active_count) + mx = max_active_positions() + pos_limit = active >= mx + daily_opens = count_daily_opens(conn, now) + daily_pos_lim = daily_position_limit() + daily_open_limit = daily_opens >= daily_pos_lim + daily_risk_used: Optional[float] = None + daily_risk_lim = daily_trading_risk_pct_limit() + daily_risk_limit_hit = False + if equity and float(equity) > 0: + daily_risk_used = daily_trading_risk_used_pct(conn, float(equity), now) + if daily_risk_used is not None and daily_risk_used >= daily_risk_lim: + daily_risk_limit_hit = True + + base = { + "active_count": active, + "max_active_positions": mx, + "daily_open_count": daily_opens, + "daily_position_limit": daily_pos_lim, + "daily_risk_used_pct": daily_risk_used, + "daily_trading_risk_pct_limit": daily_risk_lim, + } + + if daily: + return { + **base, + "status": STATUS_DAILY, + "status_label": STATUS_LABELS[STATUS_DAILY], + "can_trade": False, + "can_roll": False, + "reason": "当日日冻结,禁止新开仓", + } + if daily_risk_limit_hit: + return { + **base, + "status": STATUS_DAILY, + "status_label": STATUS_LABELS[STATUS_DAILY], + "can_trade": False, + "can_roll": pos_limit, + "reason": f"已达日交易风险上限 {daily_risk_used:.2f}%/{daily_risk_lim:.2f}%", + } + if daily_open_limit: + return { + **base, + "status": STATUS_DAILY, + "status_label": STATUS_LABELS[STATUS_DAILY], + "can_trade": False, + "can_roll": pos_limit, + "reason": f"已达日持仓上限 {daily_opens}/{daily_pos_lim}", + } + if pos_limit: + return { + **base, + "status": STATUS_FREEZE_POSITION, + "status_label": STATUS_LABELS[STATUS_FREEZE_POSITION], + "can_trade": False, + "can_roll": True, + "reason": f"已达仓位上限 {active}/{mx}", + } + return { + **base, + "status": STATUS_NORMAL, + "status_label": STATUS_LABELS[STATUS_NORMAL], + "can_trade": True, + "can_roll": True, + "reason": "可新开仓", + } + + return _db_retry(_load) + + +def assert_can_open( + conn, + *, + active_count: Optional[int] = None, + equity: Optional[float] = None, +) -> Optional[str]: + rs = get_risk_status(conn, active_count=active_count, equity=equity) + if not rs.get("can_trade"): + return rs.get("reason") or "当前不可开仓" + return None diff --git a/modules/settings/__init__.py b/modules/settings/__init__.py new file mode 100644 index 0000000..1774aba --- /dev/null +++ b/modules/settings/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) 2025-2026 马建军. All rights reserved. + +from modules.settings.routes import register + +__all__ = ["register"] diff --git a/admin_settings.py b/modules/settings/admin_settings.py similarity index 95% rename from admin_settings.py rename to modules/settings/admin_settings.py index e66e263..5b84902 100644 --- a/admin_settings.py +++ b/modules/settings/admin_settings.py @@ -1,86 +1,86 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""Web 登录账号:settings 表 + .env 同步。""" -from __future__ import annotations - -import os -import re -from typing import Callable - -from werkzeug.security import check_password_hash, generate_password_hash - -from env_file import update_env_vars - -ADMIN_USERNAME_KEY = "ADMIN_USERNAME" -ADMIN_PASSWORD_KEY = "ADMIN_PASSWORD" - - -def save_admin_credentials( - *, - username: str, - old_password: str, - new_password: str, - new_password2: str, - get_setting: Callable[[str, str], str], - set_setting: Callable[[str, str], None], -) -> tuple[bool, str, dict[str, str]]: - """ - 校验原密码后更新用户名/密码,写入 settings 与 .env。 - 返回 (成功, 提示, env_updates)。 - """ - username = (username or "").strip() - old_password = old_password or "" - new_password = new_password or "" - new_password2 = new_password2 or "" - - if not username: - return False, "用户名不能为空", {} - if len(username) > 64: - return False, "用户名过长(最多 64 字符)", {} - if not re.match(r"^[A-Za-z0-9_.@-]+$", username): - return False, "用户名仅支持字母、数字及 _ . @ -", {} - - admin_hash = get_setting("admin_password_hash") - if not admin_hash or not check_password_hash(admin_hash, old_password): - return False, "原密码错误", {} - - current_username = (get_setting("admin_username") or "").strip() - password_change = bool(new_password or new_password2) - - if password_change: - if not new_password or not new_password2: - return False, "请同时填写新密码与确认密码", {} - if len(new_password) < 6: - return False, "新密码至少 6 位", {} - if new_password != new_password2: - return False, "两次新密码不一致", {} - - username_changed = username != current_username - if not username_changed and not password_change: - return False, "未修改任何内容", {} - - set_setting("admin_username", username) - env_updates: dict[str, str] = {ADMIN_USERNAME_KEY: username} - - if password_change: - set_setting("admin_password_hash", generate_password_hash(new_password)) - env_updates[ADMIN_PASSWORD_KEY] = new_password - - try: - update_env_vars(env_updates) - except OSError as exc: - return False, f"数据库已更新,但写入 .env 失败:{exc}", env_updates - - for key, val in env_updates.items(): - os.environ[key] = val - - parts: list[str] = [] - if username_changed: - parts.append("用户名已更新") - if password_change: - parts.append("密码已更新") - parts.append("已同步至 .env") - return True, ";".join(parts), env_updates +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""Web 登录账号:settings 表 + .env 同步。""" +from __future__ import annotations + +import os +import re +from typing import Callable + +from werkzeug.security import check_password_hash, generate_password_hash + +from modules.core.env_file import update_env_vars + +ADMIN_USERNAME_KEY = "ADMIN_USERNAME" +ADMIN_PASSWORD_KEY = "ADMIN_PASSWORD" + + +def save_admin_credentials( + *, + username: str, + old_password: str, + new_password: str, + new_password2: str, + get_setting: Callable[[str, str], str], + set_setting: Callable[[str, str], None], +) -> tuple[bool, str, dict[str, str]]: + """ + 校验原密码后更新用户名/密码,写入 settings 与 .env。 + 返回 (成功, 提示, env_updates)。 + """ + username = (username or "").strip() + old_password = old_password or "" + new_password = new_password or "" + new_password2 = new_password2 or "" + + if not username: + return False, "用户名不能为空", {} + if len(username) > 64: + return False, "用户名过长(最多 64 字符)", {} + if not re.match(r"^[A-Za-z0-9_.@-]+$", username): + return False, "用户名仅支持字母、数字及 _ . @ -", {} + + admin_hash = get_setting("admin_password_hash") + if not admin_hash or not check_password_hash(admin_hash, old_password): + return False, "原密码错误", {} + + current_username = (get_setting("admin_username") or "").strip() + password_change = bool(new_password or new_password2) + + if password_change: + if not new_password or not new_password2: + return False, "请同时填写新密码与确认密码", {} + if len(new_password) < 6: + return False, "新密码至少 6 位", {} + if new_password != new_password2: + return False, "两次新密码不一致", {} + + username_changed = username != current_username + if not username_changed and not password_change: + return False, "未修改任何内容", {} + + set_setting("admin_username", username) + env_updates: dict[str, str] = {ADMIN_USERNAME_KEY: username} + + if password_change: + set_setting("admin_password_hash", generate_password_hash(new_password)) + env_updates[ADMIN_PASSWORD_KEY] = new_password + + try: + update_env_vars(env_updates) + except OSError as exc: + return False, f"数据库已更新,但写入 .env 失败:{exc}", env_updates + + for key, val in env_updates.items(): + os.environ[key] = val + + parts: list[str] = [] + if username_changed: + parts.append("用户名已更新") + if password_change: + parts.append("密码已更新") + parts.append("已同步至 .env") + return True, ";".join(parts), env_updates diff --git a/nav_settings.py b/modules/settings/nav_settings.py similarity index 100% rename from nav_settings.py rename to modules/settings/nav_settings.py diff --git a/modules/settings/routes.py b/modules/settings/routes.py new file mode 100644 index 0000000..17aeddc --- /dev/null +++ b/modules/settings/routes.py @@ -0,0 +1,314 @@ +# Copyright (c) 2025-2026 马建军. All rights reserved. +"""HTTP routes for settings module.""" + +from __future__ import annotations + +from datetime import date, datetime + +from flask import ( + Response, + flash, + jsonify, + redirect, + render_template, + request, + send_file, + session, + stream_with_context, + url_for, +) + + +def register(deps) -> None: + app = deps.app + login_required = deps.login_required + require_nav = deps.require_nav + get_db = deps.get_db + get_setting = deps.get_setting + set_setting = deps.set_setting + fetch_price = deps.fetch_price + send_wechat_msg = deps.send_wechat_msg + touch_stats_cache = deps.touch_stats_cache + get_stats_data = deps.get_stats_data + build_market_quote_payload = deps.build_market_quote_payload + today_str = deps.today_str + expire_old_plans = deps.expire_old_plans + TZ = deps.tz + DB_PATH = deps.db_path + UPLOAD_DIR = deps.upload_dir + OPEN_TYPES = deps.open_types + EXIT_TRIGGERS = deps.exit_triggers + BEHAVIOR_TAGS = deps.behavior_tags + KLINE_PERIODS = deps.kline_periods + KLINE_CUTOFFS = deps.kline_cutoffs + calc_holding_duration = deps.calc_holding_duration + holding_to_minutes = deps.holding_to_minutes + classify_close_result = deps.classify_close_result + calc_rr_ratio = deps.calc_rr_ratio + calc_theoretical_pnl = deps.calc_theoretical_pnl + parse_review_date_filter = deps.parse_review_date_filter + _trading_mode = deps.trading_mode + _ua_is_phone = deps.ua_is_phone + _static_asset_v = deps.static_asset_v + + from modules.settings.nav_settings import NAV_TOGGLES, get_nav_items, save_nav_items + from modules.settings.admin_settings import save_admin_credentials + from modules.backup.db_backup import ( + backup_dir, + backup_in_progress, + default_restore_dir, + get_backup_last_at, + list_backups, + schedule_backup, + ) + from modules.market.market import get_quote_source_label + from modules.trading.product_recommend import small_account_margin_recommendations + + @app.route("/settings", methods=["GET", "POST"]) + @login_required + def settings(): + if request.method == "POST": + action = request.form.get("action") + if action == "backup_now": + ok, msg = schedule_backup( + get_setting=get_setting, + set_setting=set_setting, + include_uploads=True, + ) + flash(msg if ok else msg) + elif action == "backup_config": + auto = request.form.get("backup_auto_enabled") == "1" + set_setting("backup_auto_enabled", "1" if auto else "0") + try: + hour = int(request.form.get("backup_auto_hour", "3") or 3) + set_setting("backup_auto_hour", str(max(0, min(23, hour)))) + except ValueError: + flash("自动备份小时无效") + return redirect(url_for("settings")) + try: + keep = int(request.form.get("backup_keep_count", "30") or 30) + set_setting("backup_keep_count", str(max(5, min(200, keep)))) + except ValueError: + flash("保留份数无效") + return redirect(url_for("settings")) + flash("备份策略已保存") + elif action == "wechat": + webhook = request.form.get("wechat_webhook", "").strip() + set_setting("wechat_webhook", webhook) + flash("企业微信配置已保存") + elif action == "ai": + set_setting("ai_enabled", "1" if request.form.get("ai_enabled") else "0") + provider = (request.form.get("ai_provider") or "ollama").strip().lower() + if provider not in ("ollama", "openai"): + provider = "ollama" + set_setting("ai_provider", provider) + set_setting("ai_ollama_base_url", (request.form.get("ai_ollama_base_url") or "").strip()) + set_setting("ai_ollama_model", (request.form.get("ai_ollama_model") or "").strip()) + set_setting("ai_openai_base_url", (request.form.get("ai_openai_base_url") or "").strip()) + key = (request.form.get("ai_openai_api_key") or "").strip() + if key: + set_setting("ai_openai_api_key", key) + set_setting("ai_openai_model", (request.form.get("ai_openai_model") or "").strip()) + set_setting("ai_daily_report_enabled", "1" if request.form.get("ai_daily_report_enabled") else "0") + try: + set_setting("ai_daily_report_hour", str(max(0, min(23, int(request.form.get("ai_daily_report_hour", "15") or 15))))) + except ValueError: + pass + try: + set_setting("ai_daily_report_minute", str(max(0, min(59, int(request.form.get("ai_daily_report_minute", "5") or 5))))) + except ValueError: + pass + flash("AI 配置已保存") + elif action == "trading": + mode = request.form.get("trading_mode", "simulation").strip() + if mode not in ("simulation", "live"): + mode = "simulation" + sizing = request.form.get("position_sizing_mode", "fixed").strip() + if sizing == "risk": + sizing = "amount" + if sizing not in ("fixed", "amount"): + sizing = "fixed" + set_setting("trading_mode", mode) + set_setting("position_sizing_mode", sizing) + try: + fl = int(float(request.form.get("fixed_lots", "1") or 1)) + set_setting("fixed_lots", str(max(1, fl))) + except ValueError: + flash("固定手数无效") + return redirect(url_for("settings")) + try: + fa = float(request.form.get("fixed_amount", "5000") or 5000) + set_setting("fixed_amount", str(max(1.0, fa))) + except ValueError: + flash("固定金额无效") + return redirect(url_for("settings")) + try: + rp = float(request.form.get("risk_percent", "1") or 1) + set_setting("risk_percent", str(max(0.1, min(100.0, rp)))) + except ValueError: + pass + try: + mp = float(request.form.get("max_margin_pct", "30") or 30) + set_setting("max_margin_pct", str(max(1.0, min(100.0, mp)))) + except ValueError: + flash("保证金比例无效") + return redirect(url_for("settings")) + try: + rmp = float(request.form.get("roll_max_margin_pct", "50") or 50) + set_setting("roll_max_margin_pct", str(max(1.0, min(100.0, rmp)))) + except ValueError: + flash("滚仓保证金比例无效") + return redirect(url_for("settings")) + try: + tb = int(float(request.form.get("trailing_be_tick_buffer", "2") or 2)) + set_setting("trailing_be_tick_buffer", str(max(1, min(20, tb)))) + except ValueError: + flash("移动保本缓冲无效") + return redirect(url_for("settings")) + try: + pt = int(float(request.form.get("pending_order_timeout_min", "5") or 5)) + set_setting("pending_order_timeout_min", str(max(1, min(60, pt)))) + except ValueError: + flash("挂单超时无效") + return redirect(url_for("settings")) + flash("交易模式已保存") + elif action == "ctp": + from modules.ctp.ctp_settings import save_ctp_auto_connect, is_ctp_auto_connect_enabled + from modules.ctp.ctp_settings import save_ctp_settings_from_form + from modules.ctp.vnpy_bridge import ctp_disconnect + + was_enabled = is_ctp_auto_connect_enabled(get_setting) + auto_enabled = save_ctp_auto_connect(request.form, set_setting) + save_result = save_ctp_settings_from_form(request.form, set_setting) + if not auto_enabled: + ctp_disconnect(set_disabled_hint=True) + elif not was_enabled and auto_enabled: + try: + from modules.ctp.vnpy_bridge import get_bridge + from modules.core.trading_context import get_trading_mode + + mode = get_trading_mode(get_setting) + get_bridge().reconnect_after_settings_saved(mode) + except Exception as exc: + app.logger.debug("CTP connect after enable auto: %s", exc) + pwd_updated = save_result.get("passwords_updated") or [] + pwd_empty = save_result.get("passwords_submitted_empty") or [] + simnow_pwd_len = len((request.form.get("simnow_password") or "").strip()) + live_pwd_len = len((request.form.get("ctp_live_password") or "").strip()) + print( + f"CTP settings save: simnow_password_len={simnow_pwd_len} " + f"live_password_len={live_pwd_len} updated={pwd_updated}", + flush=True, + ) + app.logger.info( + "CTP settings save: simnow_password_len=%s live_password_len=%s updated=%s", + simnow_pwd_len, + live_pwd_len, + pwd_updated, + ) + if "simnow_password" in pwd_updated: + pwd_note = f"SimNow 交易密码已更新({simnow_pwd_len} 位)" + elif "simnow_password" in pwd_empty: + pwd_note = "SimNow 交易密码未改:提交为空,请在「交易密码」框手打后再保存" + elif "ctp_live_password" in pwd_updated: + pwd_note = "实盘交易密码已更新" + elif "ctp_live_password" in pwd_empty: + pwd_note = "实盘交易密码未改(提交为空)" + else: + pwd_note = "" + if not auto_enabled: + flash("CTP 配置已保存;自动连接已关闭,所有 CTP 连接已断开") + return redirect(url_for("settings")) + if not was_enabled: + flash("CTP 配置已保存;自动连接已开启,正在连接…") + return redirect(url_for("settings")) + flash_msg = "CTP 配置已保存,正在使用新地址重连…" + if pwd_note: + flash_msg = f"CTP 配置已保存;{pwd_note},正在重连…" + try: + from modules.ctp.vnpy_bridge import get_bridge + from modules.core.trading_context import get_trading_mode + + b = get_bridge() + if pwd_updated: + b._clear_login_cooldown() + mode = get_trading_mode(get_setting) + info = b.reconnect_after_settings_saved(mode) + if info.get("cooldown"): + flash_msg = f"CTP 配置已保存;{pwd_note or '请稍后再连'}" + elif not info.get("started") and info.get("connected"): + flash_msg = f"CTP 配置已保存;{pwd_note or '当前连接正常'}" + except Exception as exc: + app.logger.warning("CTP reconnect after settings save: %s", exc) + flash_msg = f"CTP 配置已保存;{pwd_note or '请稍后在持仓监控页重连'}" + flash(flash_msg) + elif action == "nav": + items = {k: request.form.get(f"nav_{k}") == "on" for k in NAV_TOGGLES} + save_nav_items(set_setting, items) + flash("导航显示已保存") + elif action == "password": + ok, msg, _ = save_admin_credentials( + username=request.form.get("admin_username", ""), + old_password=request.form.get("old_password", ""), + new_password=request.form.get("new_password", ""), + new_password2=request.form.get("new_password2", ""), + get_setting=get_setting, + set_setting=set_setting, + ) + if ok and session.get("logged_in"): + session["username"] = (request.form.get("admin_username") or "").strip() + flash(msg) + return redirect(url_for("settings")) + + webhook = get_setting("wechat_webhook") + username = get_setting("admin_username") + ctp_st = {} + try: + from modules.ctp.vnpy_bridge import ctp_status + from modules.core.trading_context import get_trading_mode + + ctp_st = ctp_status(get_trading_mode(get_setting)) + except Exception: + pass + from modules.ctp.ctp_settings import get_ctp_settings_for_ui, is_ctp_auto_connect_enabled + from modules.trading.product_recommend import small_account_margin_recommendations + + return render_template( + "settings.html", + webhook=webhook, + username=username, + quote_label=get_quote_source_label(ctp_connected=bool(ctp_st.get("connected"))), + ctp_status=ctp_st, + ctp_cfg=get_ctp_settings_for_ui(), + ctp_auto_connect=is_ctp_auto_connect_enabled(get_setting), + trading_mode=get_setting("trading_mode", "simulation"), + position_sizing_mode=get_setting("position_sizing_mode", "fixed"), + fixed_lots=get_setting("fixed_lots", "1"), + fixed_amount=get_setting("fixed_amount", "5000"), + risk_percent=get_setting("risk_percent", "1"), + max_margin_pct=get_setting("max_margin_pct", "30"), + roll_max_margin_pct=get_setting("roll_max_margin_pct", "50"), + small_account_margin_rec=small_account_margin_recommendations(), + trailing_be_tick_buffer=get_setting("trailing_be_tick_buffer", "2"), + pending_order_timeout_min=get_setting("pending_order_timeout_min", "5"), + nav_items=get_nav_items(get_setting), + nav_toggles=NAV_TOGGLES, + backup_dir=str(backup_dir()), + backup_last_at=get_backup_last_at(get_setting), + backup_running=backup_in_progress(), + backup_items=list_backups(), + backup_auto_enabled=get_setting("backup_auto_enabled", "1") == "1", + backup_auto_hour=get_setting("backup_auto_hour", "3"), + backup_keep_count=get_setting("backup_keep_count", "30"), + backup_restore_dir=default_restore_dir(), + ai_enabled=get_setting("ai_enabled", "0") == "1", + ai_provider=get_setting("ai_provider", "ollama"), + ai_ollama_base_url=get_setting("ai_ollama_base_url", "http://127.0.0.1:11434"), + ai_ollama_model=get_setting("ai_ollama_model", "qwen2.5:7b"), + ai_openai_base_url=get_setting("ai_openai_base_url", "https://api.openai.com/v1"), + ai_openai_api_key=get_setting("ai_openai_api_key", ""), + ai_openai_model=get_setting("ai_openai_model", "gpt-4o-mini"), + ai_daily_report_enabled=get_setting("ai_daily_report_enabled", "1") == "1", + ai_daily_report_hour=get_setting("ai_daily_report_hour", "15"), + ai_daily_report_minute=get_setting("ai_daily_report_minute", "5"), + ) diff --git a/modules/stats/__init__.py b/modules/stats/__init__.py new file mode 100644 index 0000000..4a8dd96 --- /dev/null +++ b/modules/stats/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) 2025-2026 马建军. All rights reserved. + +from modules.stats.routes import register + +__all__ = ["register"] diff --git a/dashboard_lib.py b/modules/stats/dashboard_lib.py similarity index 93% rename from dashboard_lib.py rename to modules/stats/dashboard_lib.py index c3ec2e1..73baceb 100644 --- a/dashboard_lib.py +++ b/modules/stats/dashboard_lib.py @@ -1,288 +1,288 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""数据看板:账户、关键位、平仓记录聚合。""" -from __future__ import annotations - -from datetime import datetime -from typing import Any, Callable, Optional -from zoneinfo import ZoneInfo - -_TZ = ZoneInfo("Asia/Shanghai") -_PRICE_CACHE: dict[str, tuple[float, float]] = {} -_PRICE_CACHE_TTL = 2.0 - - -def _cached_fetch_price( - fetch_price: Callable[[str, str, str], Optional[float]], - sym: str, - market: str, - sina: str, -) -> Optional[float]: - key = sym or "" - now = datetime.now().timestamp() - hit = _PRICE_CACHE.get(key) - if hit and (now - hit[1]) < _PRICE_CACHE_TTL: - return hit[0] - price = fetch_price(sym, market, sina) - if price is not None: - _PRICE_CACHE[key] = (float(price), now) - return price - - -def _direction_label(direction: str) -> str: - return "做多" if (direction or "").strip().lower() == "long" else "做空" - - -def _symbol_fields(ths_code: str) -> dict[str, Any]: - from symbols import position_symbol_meta - - sym = (ths_code or "").strip() - meta = position_symbol_meta(sym) - return { - "symbol_code": sym, - "symbol_name": meta.get("name") or sym, - "symbol_exchange": meta.get("exchange") or "", - "symbol_is_main": bool(meta.get("is_main")), - } - - -def build_risk_overview( - conn, - get_setting: Callable[[str, str], str], - *, - equity: Optional[float] = None, - margin_used: Optional[float] = None, -) -> dict[str, Any]: - from risk.account_risk_lib import ( - cooling_hours_manual, - cooling_hours_manual_journal, - count_daily_opens, - daily_position_limit, - daily_trading_risk_pct_limit, - daily_trading_risk_used_pct, - ensure_account_risk_schema, - get_risk_status, - manual_close_daily_limit, - max_active_positions, - risk_control_enabled, - trading_day_label, - trading_day_reset_hour, - ) - from trading_context import ( - get_fixed_amount, - get_fixed_lots, - get_max_margin_pct, - get_roll_max_margin_pct, - get_sizing_mode, - ) - - ensure_account_risk_schema(conn) - risk = dict(get_risk_status(conn, equity=equity) or {}) - row = conn.execute("SELECT * FROM account_risk_state WHERE id=1").fetchone() - td = trading_day_label() - stored_td = str(row["trading_day"] or "") if row else "" - manual_count = int(row["manual_close_count"] or 0) if row and stored_td == td else 0 - - margin_pct_used: Optional[float] = None - if equity and equity > 0 and margin_used is not None and margin_used >= 0: - margin_pct_used = round(float(margin_used) / float(equity) * 100, 2) - - max_margin = get_max_margin_pct(get_setting) - sizing = get_sizing_mode(get_setting) - sizing_label = "固定金额" if sizing == "amount" else "固定手数" - - daily_opens = int(risk.get("daily_open_count") or count_daily_opens(conn)) - daily_risk_used = risk.get("daily_risk_used_pct") - if daily_risk_used is None and equity and equity > 0: - daily_risk_used = daily_trading_risk_used_pct(conn, float(equity)) - - return { - "enabled": risk_control_enabled(), - "status": risk, - "manual_close_count_today": manual_count, - "margin_pct_used": margin_pct_used, - "daily_open_count": daily_opens, - "daily_risk_used_pct": daily_risk_used, - "limits": { - "max_active_positions": max_active_positions(), - "position_mode": "single" if max_active_positions() <= 1 else "multi", - "position_mode_label": "单仓模式" if max_active_positions() <= 1 else "多仓模式", - "daily_position_limit": daily_position_limit(), - "daily_trading_risk_pct_limit": daily_trading_risk_pct_limit(), - "manual_close_daily_limit": manual_close_daily_limit(), - "cooling_hours_manual": cooling_hours_manual(), - "cooling_hours_manual_journal": cooling_hours_manual_journal(), - "trading_day_reset_hour": trading_day_reset_hour(), - "max_margin_pct": max_margin, - "roll_max_margin_pct": get_roll_max_margin_pct(get_setting), - "sizing_mode": sizing, - "sizing_label": sizing_label, - "fixed_lots": get_fixed_lots(get_setting), - "fixed_amount": get_fixed_amount(get_setting), - }, - } - - -def build_dashboard_payload( - *, - get_db: Callable, - get_setting: Callable[[str, str], str], - fetch_price: Callable[[str, str, str], Optional[float]], - closes_limit: int = 40, - sync_ctp_trades: bool = False, -) -> dict[str, Any]: - from trading_context import get_account_capital, get_trading_mode, trading_mode_label - from vnpy_bridge import ctp_account_margin_used, ctp_status, get_bridge - - mode = get_trading_mode(get_setting) - ctp_st = dict(ctp_status(mode) or {}) - conn = get_db() - try: - capital = float(get_account_capital(conn, get_setting) or 0) - equity = capital - available: Optional[float] = None - margin_used: Optional[float] = None - - if ctp_st.get("connected"): - if sync_ctp_trades: - try: - from ctp_trade_sync import sync_trade_logs_from_ctp - - sync_trade_logs_from_ctp( - conn, mode, capital=capital, trading_mode=mode, - ) - conn.commit() - except Exception: - pass - try: - b = get_bridge() - if b.connected_mode == mode and b.ping(): - acc = b.get_account() or {} - else: - acc = {} - balance = float(acc.get("balance") or 0) - if balance > 0: - equity = balance - avail = acc.get("available") - if avail is not None: - available = round(float(avail), 2) - mu = ctp_account_margin_used(mode) - if mu is not None and mu > 0: - margin_used = round(float(mu), 2) - elif available is not None and equity > 0: - margin_used = round(max(0.0, equity - available), 2) - except Exception: - pass - else: - from trading_context import _cached_ctp_account - - cached = _cached_ctp_account(mode) - balance = float(cached.get("balance") or 0) - if balance > 0: - equity = balance - avail = cached.get("available") - if avail is not None: - available = round(float(avail), 2) - if equity > 0: - margin_used = round(max(0.0, equity - available), 2) - - key_rows = conn.execute( - """ - SELECT id, symbol, symbol_name, market_code, sina_code, - monitor_type, direction, upper, lower, trade_mode, - bar_period, trailing_be - FROM key_monitors - WHERE status='active' OR status IS NULL - ORDER BY id DESC - """ - ).fetchall() - keys: list[dict[str, Any]] = [] - for r in key_rows: - sym = r["symbol"] - market = r["market_code"] or "" - sina = r["sina_code"] or "" - upper = float(r["upper"] or 0) - lower = float(r["lower"] or 0) - price = _cached_fetch_price(fetch_price, sym, market, sina) - dist_upper = dist_lower = None - if price is not None: - dist_upper = round(upper - float(price), 2) - dist_lower = round(float(price) - lower, 2) - mtype = r["monitor_type"] or "" - sf = _symbol_fields(sym) - keys.append({ - "id": r["id"], - "symbol": sym, - **sf, - "symbol_name": r["symbol_name"] or sf.get("symbol_name") or sym, - "monitor_type": mtype, - "direction": r["direction"] or "", - "direction_label": _direction_label(r["direction"] or "long") - if r["direction"] else "", - "upper": upper, - "lower": lower, - "trade_mode": r["trade_mode"] or "", - "bar_period": r["bar_period"] or "5m", - "trailing_be": bool(r["trailing_be"]), - "price": price, - "dist_upper": dist_upper, - "dist_lower": dist_lower, - }) - - close_rows = conn.execute( - """ - SELECT id, symbol, symbol_name, direction, lots, - entry_price, close_price, pnl, pnl_net, fee, - close_time, result, source - FROM trade_logs - ORDER BY id DESC - LIMIT ? - """, - (max(1, min(200, closes_limit)),), - ).fetchall() - closes: list[dict[str, Any]] = [] - for r in close_rows: - sym_code = r["symbol"] or "" - sf = _symbol_fields(sym_code) - closes.append({ - "id": r["id"], - "symbol": r["symbol_name"] or sf.get("symbol_name") or sym_code, - "symbol_code": sym_code, - **sf, - "symbol_name": r["symbol_name"] or sf.get("symbol_name") or sym_code, - "direction": r["direction"] or "long", - "direction_label": _direction_label(r["direction"] or "long"), - "lots": float(r["lots"] or 0), - "entry_price": float(r["entry_price"] or 0), - "close_price": float(r["close_price"] or 0), - "pnl": float(r["pnl"] or 0) if r["pnl"] is not None else None, - "pnl_net": float(r["pnl_net"] or 0) if r["pnl_net"] is not None else None, - "fee": float(r["fee"] or 0) if r["fee"] is not None else None, - "close_time": (r["close_time"] or "")[:16].replace("T", " "), - "result": r["result"] or "", - "source": r["source"] or "", - }) - - now_iso = datetime.now(_TZ).strftime("%Y-%m-%d %H:%M:%S") - risk = build_risk_overview( - conn, get_setting, equity=equity, margin_used=margin_used, - ) - return { - "ok": True, - "updated_at": now_iso, - "trading_mode_label": trading_mode_label(get_setting), - "ctp_status": ctp_st, - "account": { - "equity": round(equity, 2), - "margin_used": margin_used, - "available": available, - "capital_fallback": round(capital, 2), - }, - "risk": risk, - "keys": keys, - "closes": closes, - } - finally: - conn.close() +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""数据看板:账户、关键位、平仓记录聚合。""" +from __future__ import annotations + +from datetime import datetime +from typing import Any, Callable, Optional +from zoneinfo import ZoneInfo + +_TZ = ZoneInfo("Asia/Shanghai") +_PRICE_CACHE: dict[str, tuple[float, float]] = {} +_PRICE_CACHE_TTL = 2.0 + + +def _cached_fetch_price( + fetch_price: Callable[[str, str, str], Optional[float]], + sym: str, + market: str, + sina: str, +) -> Optional[float]: + key = sym or "" + now = datetime.now().timestamp() + hit = _PRICE_CACHE.get(key) + if hit and (now - hit[1]) < _PRICE_CACHE_TTL: + return hit[0] + price = fetch_price(sym, market, sina) + if price is not None: + _PRICE_CACHE[key] = (float(price), now) + return price + + +def _direction_label(direction: str) -> str: + return "做多" if (direction or "").strip().lower() == "long" else "做空" + + +def _symbol_fields(ths_code: str) -> dict[str, Any]: + from modules.core.symbols import position_symbol_meta + + sym = (ths_code or "").strip() + meta = position_symbol_meta(sym) + return { + "symbol_code": sym, + "symbol_name": meta.get("name") or sym, + "symbol_exchange": meta.get("exchange") or "", + "symbol_is_main": bool(meta.get("is_main")), + } + + +def build_risk_overview( + conn, + get_setting: Callable[[str, str], str], + *, + equity: Optional[float] = None, + margin_used: Optional[float] = None, +) -> dict[str, Any]: + from risk.account_risk_lib import ( + cooling_hours_manual, + cooling_hours_manual_journal, + count_daily_opens, + daily_position_limit, + daily_trading_risk_pct_limit, + daily_trading_risk_used_pct, + ensure_account_risk_schema, + get_risk_status, + manual_close_daily_limit, + max_active_positions, + risk_control_enabled, + trading_day_label, + trading_day_reset_hour, + ) + from modules.core.trading_context import ( + get_fixed_amount, + get_fixed_lots, + get_max_margin_pct, + get_roll_max_margin_pct, + get_sizing_mode, + ) + + ensure_account_risk_schema(conn) + risk = dict(get_risk_status(conn, equity=equity) or {}) + row = conn.execute("SELECT * FROM account_risk_state WHERE id=1").fetchone() + td = trading_day_label() + stored_td = str(row["trading_day"] or "") if row else "" + manual_count = int(row["manual_close_count"] or 0) if row and stored_td == td else 0 + + margin_pct_used: Optional[float] = None + if equity and equity > 0 and margin_used is not None and margin_used >= 0: + margin_pct_used = round(float(margin_used) / float(equity) * 100, 2) + + max_margin = get_max_margin_pct(get_setting) + sizing = get_sizing_mode(get_setting) + sizing_label = "固定金额" if sizing == "amount" else "固定手数" + + daily_opens = int(risk.get("daily_open_count") or count_daily_opens(conn)) + daily_risk_used = risk.get("daily_risk_used_pct") + if daily_risk_used is None and equity and equity > 0: + daily_risk_used = daily_trading_risk_used_pct(conn, float(equity)) + + return { + "enabled": risk_control_enabled(), + "status": risk, + "manual_close_count_today": manual_count, + "margin_pct_used": margin_pct_used, + "daily_open_count": daily_opens, + "daily_risk_used_pct": daily_risk_used, + "limits": { + "max_active_positions": max_active_positions(), + "position_mode": "single" if max_active_positions() <= 1 else "multi", + "position_mode_label": "单仓模式" if max_active_positions() <= 1 else "多仓模式", + "daily_position_limit": daily_position_limit(), + "daily_trading_risk_pct_limit": daily_trading_risk_pct_limit(), + "manual_close_daily_limit": manual_close_daily_limit(), + "cooling_hours_manual": cooling_hours_manual(), + "cooling_hours_manual_journal": cooling_hours_manual_journal(), + "trading_day_reset_hour": trading_day_reset_hour(), + "max_margin_pct": max_margin, + "roll_max_margin_pct": get_roll_max_margin_pct(get_setting), + "sizing_mode": sizing, + "sizing_label": sizing_label, + "fixed_lots": get_fixed_lots(get_setting), + "fixed_amount": get_fixed_amount(get_setting), + }, + } + + +def build_dashboard_payload( + *, + get_db: Callable, + get_setting: Callable[[str, str], str], + fetch_price: Callable[[str, str, str], Optional[float]], + closes_limit: int = 40, + sync_ctp_trades: bool = False, +) -> dict[str, Any]: + from modules.core.trading_context import get_account_capital, get_trading_mode, trading_mode_label + from modules.ctp.vnpy_bridge import ctp_account_margin_used, ctp_status, get_bridge + + mode = get_trading_mode(get_setting) + ctp_st = dict(ctp_status(mode) or {}) + conn = get_db() + try: + capital = float(get_account_capital(conn, get_setting) or 0) + equity = capital + available: Optional[float] = None + margin_used: Optional[float] = None + + if ctp_st.get("connected"): + if sync_ctp_trades: + try: + from modules.ctp.ctp_trade_sync import sync_trade_logs_from_ctp + + sync_trade_logs_from_ctp( + conn, mode, capital=capital, trading_mode=mode, + ) + conn.commit() + except Exception: + pass + try: + b = get_bridge() + if b.connected_mode == mode and b.ping(): + acc = b.get_account() or {} + else: + acc = {} + balance = float(acc.get("balance") or 0) + if balance > 0: + equity = balance + avail = acc.get("available") + if avail is not None: + available = round(float(avail), 2) + mu = ctp_account_margin_used(mode) + if mu is not None and mu > 0: + margin_used = round(float(mu), 2) + elif available is not None and equity > 0: + margin_used = round(max(0.0, equity - available), 2) + except Exception: + pass + else: + from modules.core.trading_context import _cached_ctp_account + + cached = _cached_ctp_account(mode) + balance = float(cached.get("balance") or 0) + if balance > 0: + equity = balance + avail = cached.get("available") + if avail is not None: + available = round(float(avail), 2) + if equity > 0: + margin_used = round(max(0.0, equity - available), 2) + + key_rows = conn.execute( + """ + SELECT id, symbol, symbol_name, market_code, sina_code, + monitor_type, direction, upper, lower, trade_mode, + bar_period, trailing_be + FROM key_monitors + WHERE status='active' OR status IS NULL + ORDER BY id DESC + """ + ).fetchall() + keys: list[dict[str, Any]] = [] + for r in key_rows: + sym = r["symbol"] + market = r["market_code"] or "" + sina = r["sina_code"] or "" + upper = float(r["upper"] or 0) + lower = float(r["lower"] or 0) + price = _cached_fetch_price(fetch_price, sym, market, sina) + dist_upper = dist_lower = None + if price is not None: + dist_upper = round(upper - float(price), 2) + dist_lower = round(float(price) - lower, 2) + mtype = r["monitor_type"] or "" + sf = _symbol_fields(sym) + keys.append({ + "id": r["id"], + "symbol": sym, + **sf, + "symbol_name": r["symbol_name"] or sf.get("symbol_name") or sym, + "monitor_type": mtype, + "direction": r["direction"] or "", + "direction_label": _direction_label(r["direction"] or "long") + if r["direction"] else "", + "upper": upper, + "lower": lower, + "trade_mode": r["trade_mode"] or "", + "bar_period": r["bar_period"] or "5m", + "trailing_be": bool(r["trailing_be"]), + "price": price, + "dist_upper": dist_upper, + "dist_lower": dist_lower, + }) + + close_rows = conn.execute( + """ + SELECT id, symbol, symbol_name, direction, lots, + entry_price, close_price, pnl, pnl_net, fee, + close_time, result, source + FROM trade_logs + ORDER BY id DESC + LIMIT ? + """, + (max(1, min(200, closes_limit)),), + ).fetchall() + closes: list[dict[str, Any]] = [] + for r in close_rows: + sym_code = r["symbol"] or "" + sf = _symbol_fields(sym_code) + closes.append({ + "id": r["id"], + "symbol": r["symbol_name"] or sf.get("symbol_name") or sym_code, + "symbol_code": sym_code, + **sf, + "symbol_name": r["symbol_name"] or sf.get("symbol_name") or sym_code, + "direction": r["direction"] or "long", + "direction_label": _direction_label(r["direction"] or "long"), + "lots": float(r["lots"] or 0), + "entry_price": float(r["entry_price"] or 0), + "close_price": float(r["close_price"] or 0), + "pnl": float(r["pnl"] or 0) if r["pnl"] is not None else None, + "pnl_net": float(r["pnl_net"] or 0) if r["pnl_net"] is not None else None, + "fee": float(r["fee"] or 0) if r["fee"] is not None else None, + "close_time": (r["close_time"] or "")[:16].replace("T", " "), + "result": r["result"] or "", + "source": r["source"] or "", + }) + + now_iso = datetime.now(_TZ).strftime("%Y-%m-%d %H:%M:%S") + risk = build_risk_overview( + conn, get_setting, equity=equity, margin_used=margin_used, + ) + return { + "ok": True, + "updated_at": now_iso, + "trading_mode_label": trading_mode_label(get_setting), + "ctp_status": ctp_st, + "account": { + "equity": round(equity, 2), + "margin_used": margin_used, + "available": available, + "capital_fallback": round(capital, 2), + }, + "risk": risk, + "keys": keys, + "closes": closes, + } + finally: + conn.close() diff --git a/modules/stats/routes.py b/modules/stats/routes.py new file mode 100644 index 0000000..37774d3 --- /dev/null +++ b/modules/stats/routes.py @@ -0,0 +1,174 @@ +# Copyright (c) 2025-2026 马建军. All rights reserved. +"""HTTP routes for stats module.""" + +from __future__ import annotations + +from datetime import date, datetime + +from flask import ( + Response, + flash, + jsonify, + redirect, + render_template, + request, + send_file, + session, + stream_with_context, + url_for, +) + + +def register(deps) -> None: + app = deps.app + login_required = deps.login_required + require_nav = deps.require_nav + get_db = deps.get_db + get_setting = deps.get_setting + set_setting = deps.set_setting + fetch_price = deps.fetch_price + send_wechat_msg = deps.send_wechat_msg + touch_stats_cache = deps.touch_stats_cache + get_stats_data = deps.get_stats_data + build_market_quote_payload = deps.build_market_quote_payload + today_str = deps.today_str + expire_old_plans = deps.expire_old_plans + TZ = deps.tz + DB_PATH = deps.db_path + UPLOAD_DIR = deps.upload_dir + OPEN_TYPES = deps.open_types + EXIT_TRIGGERS = deps.exit_triggers + BEHAVIOR_TAGS = deps.behavior_tags + KLINE_PERIODS = deps.kline_periods + KLINE_CUTOFFS = deps.kline_cutoffs + calc_holding_duration = deps.calc_holding_duration + holding_to_minutes = deps.holding_to_minutes + classify_close_result = deps.classify_close_result + calc_rr_ratio = deps.calc_rr_ratio + calc_theoretical_pnl = deps.calc_theoretical_pnl + parse_review_date_filter = deps.parse_review_date_filter + _trading_mode = deps.trading_mode + _ua_is_phone = deps.ua_is_phone + _static_asset_v = deps.static_asset_v + + from modules.stats.stats_engine import ( + STATS_VIEWS, + get_calendar_day, + get_calendar_month, + refresh_stats_cache, + ) + from modules.settings.nav_settings import nav_enabled + from modules.stats.dashboard_lib import build_dashboard_payload + from modules.core.doc_render import read_doc, render_markdown + + _dashboard_sync_tick = {"n": 0} + + @app.route("/stats") + @login_required + def stats(): + return render_template("stats.html") + + + @app.route("/calendar") + @login_required + def trade_calendar(): + return render_template("calendar.html") + + + @app.route("/api/stats") + @login_required + def api_stats(): + return jsonify(get_stats_data()) + + + @app.route("/api/stats/views") + @login_required + def api_stats_views(): + return jsonify({"views": STATS_VIEWS}) + + + @app.route("/api/stats/refresh", methods=["POST"]) + @login_required + def api_stats_refresh(): + conn = get_db() + capital = float(get_setting("live_capital", "0") or 0) + data = refresh_stats_cache(conn, capital) + conn.close() + return jsonify(data) + + + @app.route("/api/stats/calendar") + @login_required + def api_stats_calendar(): + now = datetime.now(TZ) + year = request.args.get("year", type=int) or now.year + month = request.args.get("month", type=int) or now.month + if month < 1 or month > 12: + return jsonify({"error": "invalid month"}), 400 + conn = get_db() + try: + data = get_calendar_month(conn, year, month) + finally: + conn.close() + return jsonify(data) + + + @app.route("/api/stats/calendar/day") + @login_required + def api_stats_calendar_day(): + day = (request.args.get("date") or "").strip() + if not day: + return jsonify({"error": "date required"}), 400 + try: + date.fromisoformat(day) + except ValueError: + return jsonify({"error": "invalid date"}), 400 + conn = get_db() + try: + data = get_calendar_day(conn, day) + finally: + conn.close() + return jsonify(data) + + + @app.route("/dashboard") + @login_required + @require_nav("dashboard") + def dashboard(): + return render_template("dashboard.html") + + + @app.route("/risk-guide") + @login_required + @require_nav("risk_guide") + def risk_guide(): + from modules.core.doc_render import read_doc, render_markdown + + try: + _title, raw = read_doc("risk-guide") + except FileNotFoundError: + flash("文档不存在") + return redirect(url_for("positions")) + return render_template("risk_guide.html", doc_html=render_markdown(raw)) + + + @app.route("/api/dashboard/live") + @login_required + def api_dashboard_live(): + if not nav_enabled(get_setting, "dashboard"): + return jsonify({"ok": False, "error": "数据看板已在系统设置中关闭"}), 403 + from modules.stats.dashboard_lib import build_dashboard_payload + + _dashboard_sync_tick["n"] += 1 + sync_trades = _dashboard_sync_tick["n"] % 15 == 0 + try: + payload = build_dashboard_payload( + get_db=get_db, + get_setting=get_setting, + fetch_price=fetch_price, + sync_ctp_trades=sync_trades, + ) + return jsonify(payload) + except Exception as exc: + app.logger.exception("dashboard live: %s", exc) + return jsonify({"ok": False, "error": "看板数据暂时不可用"}), 503 diff --git a/stats_engine.py b/modules/stats/stats_engine.py similarity index 96% rename from stats_engine.py rename to modules/stats/stats_engine.py index 125bd79..40084dc 100644 --- a/stats_engine.py +++ b/modules/stats/stats_engine.py @@ -1,568 +1,568 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""交易统计计算与缓存结构。""" -from __future__ import annotations - -import calendar -import json -import threading -from datetime import date, datetime -from typing import Any, Optional - -from zoneinfo import ZoneInfo - -from db_conn import commit_retry, execute_retry - -_stats_refresh_lock = threading.Lock() - -TZ = ZoneInfo("Asia/Shanghai") - -STATS_VIEWS = [ - {"key": "by_time", "label": "按时间统计"}, - {"key": "by_week", "label": "周统计"}, - {"key": "by_month", "label": "月统计"}, - {"key": "by_symbol", "label": "按品种统计"}, - {"key": "by_fee", "label": "按手续费统计"}, - {"key": "by_direction", "label": "按方向统计"}, - {"key": "by_trade_type", "label": "按交易类型统计"}, - {"key": "by_emotion", "label": "情绪单统计"}, -] - -BREAKDOWN_COLUMNS = [ - {"key": "label", "label": "维度"}, - {"key": "count", "label": "交易次数"}, - {"key": "wins", "label": "盈利笔数"}, - {"key": "losses", "label": "亏损笔数"}, - {"key": "win_rate", "label": "胜率(%)"}, - {"key": "avg_profit", "label": "平均盈利"}, - {"key": "avg_loss", "label": "平均亏损"}, - {"key": "profit_loss_ratio", "label": "盈亏比"}, - {"key": "total_fee", "label": "累计手续费"}, - {"key": "total_net", "label": "净盈亏合计"}, - {"key": "max_loss", "label": "最大亏损"}, - {"key": "max_profit", "label": "最大盈利"}, -] - - -def _parse_dt(value: str) -> Optional[datetime]: - if not value: - return None - text = value.strip().replace(" ", "T") - try: - return datetime.fromisoformat(text) - except ValueError: - return None - - -def _row_dict(row) -> dict: - return dict(row) if row is not None else {} - - -def _net_pnl(row: dict) -> float: - if row.get("pnl_net") is not None: - return float(row["pnl_net"]) - pnl = float(row.get("pnl") or 0) - fee = float(row.get("fee") or 0) - return round(pnl - fee, 2) - - -def _fee(row: dict) -> float: - return float(row.get("fee") or 0) - - -def _margin_pct(pnl_net: float, margin: Optional[float]) -> Optional[float]: - if margin and margin > 0: - return round(pnl_net / margin * 100, 2) - return None - - -def _agg_group(rows: list[dict], key_fn) -> list[dict]: - groups: dict[str, list[dict]] = {} - for row in rows: - key = key_fn(row) or "未知" - groups.setdefault(key, []).append(row) - result = [] - for label, items in sorted(groups.items(), key=lambda x: x[0]): - result.append(_agg_metrics(label, items)) - return result - - -def _agg_metrics(label: str, items: list[dict]) -> dict: - nets = [_net_pnl(r) for r in items] - wins = [n for n in nets if n > 0] - losses = [n for n in nets if n < 0] - count = len(items) - win_cnt = len(wins) - loss_cnt = len(losses) - avg_profit = round(sum(wins) / len(wins), 2) if wins else 0.0 - avg_loss = round(sum(losses) / len(losses), 2) if losses else 0.0 - pl_ratio = round(avg_profit / abs(avg_loss), 2) if wins and losses and avg_loss != 0 else 0.0 - total_fee = round(sum(_fee(r) for r in items), 2) - total_net = round(sum(nets), 2) - max_loss = round(min(losses), 2) if losses else 0.0 - max_profit = round(max(wins), 2) if wins else 0.0 - win_rate = round(win_cnt / count * 100, 2) if count else 0.0 - return { - "label": label, - "count": count, - "wins": win_cnt, - "losses": loss_cnt, - "win_rate": win_rate, - "avg_profit": avg_profit, - "avg_loss": avg_loss, - "profit_loss_ratio": pl_ratio, - "total_fee": total_fee, - "total_net": total_net, - "max_loss": max_loss, - "max_profit": max_profit, - } - - -def _max_consecutive_losses(nets: list[float]) -> int: - streak = 0 - best = 0 - for n in nets: - if n < 0: - streak += 1 - best = max(best, streak) - else: - streak = 0 - return best - - -def _max_drawdown(nets: list[float], initial_capital: float) -> tuple[float, float]: - equity = initial_capital - peak = initial_capital - max_dd = 0.0 - max_dd_pct = 0.0 - for n in nets: - equity += n - if equity > peak: - peak = equity - dd = peak - equity - if dd > max_dd: - max_dd = dd - if peak > 0: - pct = dd / peak * 100 - if pct > max_dd_pct: - max_dd_pct = pct - return round(max_dd, 2), round(max_dd_pct, 2) - - -def fetch_trade_rows(conn) -> list[dict]: - rows = conn.execute( - "SELECT * FROM trade_logs ORDER BY close_time ASC, id ASC" - ).fetchall() - return [_row_dict(r) for r in rows] - - -def fetch_review_rows(conn) -> list[dict]: - rows = conn.execute( - "SELECT * FROM review_records ORDER BY close_time ASC, id ASC" - ).fetchall() - return [_row_dict(r) for r in rows] - - -def compute_summary(trades: list[dict], reviews: list[dict], live_capital: float) -> dict: - nets = [_net_pnl(t) for t in trades] - count = len(trades) - wins = [n for n in nets if n > 0] - losses = [n for n in nets if n < 0] - win_cnt = len(wins) - loss_cnt = len(losses) - avg_profit = round(sum(wins) / len(wins), 2) if wins else 0.0 - avg_loss = round(sum(losses) / len(losses), 2) if losses else 0.0 - pl_ratio = round(avg_profit / abs(avg_loss), 2) if wins and losses and avg_loss != 0 else 0.0 - total_fee = round(sum(_fee(t) for t in trades) + sum(_fee(r) for r in reviews), 2) - max_loss_amt = round(min(losses), 2) if losses else 0.0 - max_profit_amt = round(max(wins), 2) if wins else 0.0 - - margins_loss = [ - _margin_pct(_net_pnl(t), t.get("margin")) - for t in trades - if _net_pnl(t) < 0 and t.get("margin") - ] - margins_profit = [ - _margin_pct(_net_pnl(t), t.get("margin")) - for t in trades - if _net_pnl(t) > 0 and t.get("margin") - ] - max_loss_pct = round(min(margins_loss), 2) if margins_loss else 0.0 - max_profit_pct = round(max(margins_profit), 2) if margins_profit else 0.0 - - consec_loss = _max_consecutive_losses(nets) - max_dd, max_dd_pct = _max_drawdown(nets, live_capital) - - emotion_cnt = sum(1 for r in reviews if r.get("is_emotion")) - review_cnt = len(reviews) - denom = count if count else review_cnt - emotion_ratio = round(emotion_cnt / denom * 100, 2) if denom else 0.0 - - return { - "total_trades": count, - "win_rate": round(win_cnt / count * 100, 2) if count else 0.0, - "avg_profit": avg_profit, - "avg_loss": avg_loss, - "profit_loss_ratio": pl_ratio, - "consecutive_losses": consec_loss, - "max_drawdown": max_dd, - "max_drawdown_pct": max_dd_pct, - "max_loss_amount": max_loss_amt, - "max_loss_pct": max_loss_pct, - "max_profit_amount": max_profit_amt, - "max_profit_pct": max_profit_pct, - "total_fee": total_fee, - "emotion_count": emotion_cnt, - "emotion_ratio": emotion_ratio, - "review_count": review_cnt, - "win_count": win_cnt, - "loss_count": loss_cnt, - } - - -def compute_breakdowns(trades: list[dict], reviews: list[dict]) -> dict[str, dict]: - def day_key(row: dict) -> str: - dt = _parse_dt(row.get("close_time") or row.get("created_at") or "") - return dt.date().isoformat() if dt else "未知" - - def week_key(row: dict) -> str: - dt = _parse_dt(row.get("close_time") or row.get("created_at") or "") - if not dt: - return "未知" - iso = dt.isocalendar() - return f"{iso.year}-W{iso.week:02d}" - - def month_key(row: dict) -> str: - dt = _parse_dt(row.get("close_time") or row.get("created_at") or "") - return dt.strftime("%Y-%m") if dt else "未知" - - def symbol_key(row: dict) -> str: - return row.get("symbol_name") or row.get("symbol") or "未知" - - def direction_key(row: dict) -> str: - d = row.get("direction") or "" - return "做多" if d == "long" else ("做空" if d == "short" else d or "未知") - - def type_key(row: dict) -> str: - return row.get("monitor_type") or "未知" - - by_fee_rows = [] - fee_groups = {} - for t in trades: - key = symbol_key(t) - fee_groups.setdefault(key, []).append(t) - for label, items in sorted(fee_groups.items()): - row = _agg_metrics(label, items) - row["avg_fee"] = round(row["total_fee"] / row["count"], 2) if row["count"] else 0.0 - by_fee_rows.append(row) - - emotion_trades = [r for r in reviews if r.get("is_emotion")] - non_emotion = [r for r in reviews if not r.get("is_emotion")] - emotion_rows = [ - _agg_metrics("情绪单", emotion_trades), - _agg_metrics("非情绪单", non_emotion), - ] - - fee_columns = BREAKDOWN_COLUMNS + [{"key": "avg_fee", "label": "平均手续费"}] - - return { - "by_time": {"columns": BREAKDOWN_COLUMNS, "rows": _agg_group(trades, day_key)}, - "by_week": {"columns": BREAKDOWN_COLUMNS, "rows": _agg_group(trades, week_key)}, - "by_month": {"columns": BREAKDOWN_COLUMNS, "rows": _agg_group(trades, month_key)}, - "by_symbol": {"columns": BREAKDOWN_COLUMNS, "rows": _agg_group(trades, symbol_key)}, - "by_fee": {"columns": fee_columns, "rows": by_fee_rows}, - "by_direction": {"columns": BREAKDOWN_COLUMNS, "rows": _agg_group(trades, direction_key)}, - "by_trade_type": {"columns": BREAKDOWN_COLUMNS, "rows": _agg_group(trades, type_key)}, - "by_emotion": {"columns": BREAKDOWN_COLUMNS, "rows": emotion_rows}, - } - - -def build_all_stats(conn, live_capital: float = 0.0) -> dict: - trades = fetch_trade_rows(conn) - reviews = fetch_review_rows(conn) - summary = compute_summary(trades, reviews, live_capital) - breakdowns = compute_breakdowns(trades, reviews) - return { - "updated_at": datetime.now(TZ).isoformat(timespec="seconds"), - "summary": summary, - "views": STATS_VIEWS, - "breakdowns": breakdowns, - } - - -def save_stats_cache(conn, data: dict) -> None: - execute_retry( - conn, - """INSERT INTO stats_cache (key, data_json, updated_at) - VALUES ('all', ?, ?) - ON CONFLICT(key) DO UPDATE SET data_json=excluded.data_json, updated_at=excluded.updated_at""", - (json.dumps(data, ensure_ascii=False), data["updated_at"]), - ) - commit_retry(conn) - - -def load_stats_cache(conn) -> Optional[dict]: - row = conn.execute( - "SELECT data_json FROM stats_cache WHERE key='all'" - ).fetchone() - if not row: - return None - try: - return json.loads(row["data_json"]) - except json.JSONDecodeError: - return None - - -def refresh_stats_cache(conn, live_capital: float = 0.0) -> dict: - with _stats_refresh_lock: - data = build_all_stats(conn, live_capital) - save_stats_cache(conn, data) - return data - - -def _norm_symbol(symbol: str) -> str: - s = (symbol or "").strip().lower() - if "." in s: - s = s.split(".")[0] - return s - - -def _close_day_key(row: dict) -> str: - dt = _parse_dt(row.get("close_time") or row.get("created_at") or "") - return dt.date().isoformat() if dt else "" - - -def _close_ts(row: dict) -> float: - dt = _parse_dt(row.get("close_time") or row.get("created_at") or "") - return dt.timestamp() if dt else 0.0 - - -def _direction_label(direction: str) -> str: - if direction == "long": - return "做多" - if direction == "short": - return "做空" - return direction or "" - - -def _index_reviews_by_day_sym(reviews: list[dict]) -> dict[tuple[str, str], list[dict]]: - index: dict[tuple[str, str], list[dict]] = {} - for review in reviews: - day = _close_day_key(review) - if not day: - continue - sym = _norm_symbol(review.get("symbol") or "") - index.setdefault((day, sym), []).append(review) - return index - - -def _review_match_score(trade: dict, review: dict) -> float: - score = abs(_close_ts(trade) - _close_ts(review)) - lots_t = trade.get("lots") - lots_r = review.get("lots") - if lots_t is not None and lots_r is not None and float(lots_t) != float(lots_r): - score += 86400.0 - entry_t = trade.get("entry_price") - entry_r = review.get("entry_price") - if entry_t is not None and entry_r is not None and abs(float(entry_t) - float(entry_r)) > 0.01: - score += 3600.0 - return score - - -def _find_review_for_trade( - trade: dict, - review_index: dict[tuple[str, str], list[dict]], - used_review_ids: set[int], -) -> Optional[dict]: - day = _close_day_key(trade) - sym = _norm_symbol(trade.get("symbol") or "") - candidates = [ - r for r in review_index.get((day, sym), []) - if r.get("id") not in used_review_ids - ] - if not candidates: - return None - return min(candidates, key=lambda r: _review_match_score(trade, r)) - - -def _format_day_entry( - *, - trade: Optional[dict] = None, - review: Optional[dict] = None, - source: str, -) -> dict: - row = review if source == "review" and review else trade or review or {} - symbol = row.get("symbol") or "" - pnl_net = _net_pnl(row) - tags = (row.get("behavior_tags") or "").strip() - is_emotion = bool(row.get("is_emotion")) - return { - "source": source, - "trade_id": trade.get("id") if trade else None, - "review_id": review.get("id") if review else None, - "symbol": row.get("symbol_name") or symbol, - "symbol_code": symbol, - "direction": _direction_label(row.get("direction") or ""), - "lots": row.get("lots"), - "entry_price": row.get("entry_price"), - "close_price": row.get("close_price"), - "stop_loss": row.get("stop_loss"), - "take_profit": row.get("take_profit"), - "open_time": row.get("open_time") or "", - "close_time": row.get("close_time") or "", - "pnl": row.get("pnl"), - "fee": row.get("fee"), - "pnl_net": pnl_net, - "result": row.get("result") if trade else None, - "monitor_type": row.get("monitor_type") if trade else None, - "is_emotion": is_emotion, - "behavior_tags": tags, - "open_type": row.get("open_type") if review else None, - "exit_trigger": row.get("exit_trigger") if review else None, - "exit_supplement": row.get("exit_supplement") if review else None, - "holding_duration": row.get("holding_duration") if review else None, - "initial_pnl": row.get("initial_pnl") if review else None, - "actual_pnl": row.get("actual_pnl") if review else None, - "timeframe": row.get("timeframe") if review else None, - "notes": row.get("notes") if review else None, - "screenshot": row.get("screenshot") if review else None, - } - - -def build_day_detail(trades: list[dict], reviews: list[dict], day: str) -> list[dict]: - day_trades = [t for t in trades if _close_day_key(t) == day] - day_reviews = [r for r in reviews if _close_day_key(r) == day] - review_index = _index_reviews_by_day_sym(day_reviews) - used_review_ids: set[int] = set() - items: list[dict] = [] - - for trade in day_trades: - review = _find_review_for_trade(trade, review_index, used_review_ids) - if review: - used_review_ids.add(int(review["id"])) - items.append(_format_day_entry(trade=trade, review=review, source="review")) - else: - items.append(_format_day_entry(trade=trade, source="trade")) - - for review in day_reviews: - if int(review.get("id") or 0) in used_review_ids: - continue - items.append(_format_day_entry(review=review, source="review")) - - items.sort(key=lambda x: _close_ts(x), reverse=True) - return items - - -def build_calendar_month(trades: list[dict], reviews: list[dict], year: int, month: int) -> dict: - review_index = _index_reviews_by_day_sym(reviews) - day_map: dict[str, dict] = {} - matched_review_ids: dict[str, set[int]] = {} - - for trade in trades: - dt = _parse_dt(trade.get("close_time") or "") - if not dt or dt.year != year or dt.month != month: - continue - day = dt.date().isoformat() - bucket = day_map.setdefault( - day, - { - "date": day, - "count": 0, - "total_net": 0.0, - "review_count": 0, - "emotion_count": 0, - "has_emotion": False, - }, - ) - bucket["count"] += 1 - used = matched_review_ids.setdefault(day, set()) - review = _find_review_for_trade(trade, review_index, used) - if review: - rid = int(review["id"]) - used.add(rid) - bucket["total_net"] = round(bucket["total_net"] + _net_pnl(review), 2) - bucket["review_count"] += 1 - if review.get("is_emotion"): - bucket["emotion_count"] += 1 - bucket["has_emotion"] = True - else: - bucket["total_net"] = round(bucket["total_net"] + _net_pnl(trade), 2) - - for review in reviews: - if not review.get("is_emotion"): - continue - day = _close_day_key(review) - if not day: - continue - try: - dt = date.fromisoformat(day) - except ValueError: - continue - if dt.year != year or dt.month != month: - continue - bucket = day_map.setdefault( - day, - { - "date": day, - "count": 0, - "total_net": 0.0, - "review_count": 0, - "emotion_count": 0, - "has_emotion": False, - }, - ) - bucket["has_emotion"] = True - rid = int(review.get("id") or 0) - if rid and rid not in matched_review_ids.get(day, set()): - bucket["emotion_count"] += 1 - - _, last_day = calendar.monthrange(year, month) - days = [] - for d in range(1, last_day + 1): - iso = date(year, month, d).isoformat() - if iso in day_map: - row = day_map[iso] - row["total_net"] = round(row["total_net"], 2) - days.append(row) - else: - days.append( - { - "date": iso, - "count": 0, - "total_net": 0.0, - "review_count": 0, - "emotion_count": 0, - "has_emotion": False, - } - ) - - return { - "year": year, - "month": month, - "days": days, - "weekday_start": date(year, month, 1).weekday(), - } - - -def get_calendar_month(conn, year: int, month: int) -> dict: - trades = fetch_trade_rows(conn) - reviews = fetch_review_rows(conn) - return build_calendar_month(trades, reviews, year, month) - - -def get_calendar_day(conn, day: str) -> dict: - trades = fetch_trade_rows(conn) - reviews = fetch_review_rows(conn) - items = build_day_detail(trades, reviews, day) - total_net = round(sum(float(i.get("pnl_net") or 0) for i in items), 2) - emotion_count = sum(1 for i in items if i.get("is_emotion")) - return { - "date": day, - "count": len(items), - "total_net": total_net, - "emotion_count": emotion_count, - "items": items, - } +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""交易统计计算与缓存结构。""" +from __future__ import annotations + +import calendar +import json +import threading +from datetime import date, datetime +from typing import Any, Optional + +from zoneinfo import ZoneInfo + +from modules.core.db_conn import commit_retry, execute_retry + +_stats_refresh_lock = threading.Lock() + +TZ = ZoneInfo("Asia/Shanghai") + +STATS_VIEWS = [ + {"key": "by_time", "label": "按时间统计"}, + {"key": "by_week", "label": "周统计"}, + {"key": "by_month", "label": "月统计"}, + {"key": "by_symbol", "label": "按品种统计"}, + {"key": "by_fee", "label": "按手续费统计"}, + {"key": "by_direction", "label": "按方向统计"}, + {"key": "by_trade_type", "label": "按交易类型统计"}, + {"key": "by_emotion", "label": "情绪单统计"}, +] + +BREAKDOWN_COLUMNS = [ + {"key": "label", "label": "维度"}, + {"key": "count", "label": "交易次数"}, + {"key": "wins", "label": "盈利笔数"}, + {"key": "losses", "label": "亏损笔数"}, + {"key": "win_rate", "label": "胜率(%)"}, + {"key": "avg_profit", "label": "平均盈利"}, + {"key": "avg_loss", "label": "平均亏损"}, + {"key": "profit_loss_ratio", "label": "盈亏比"}, + {"key": "total_fee", "label": "累计手续费"}, + {"key": "total_net", "label": "净盈亏合计"}, + {"key": "max_loss", "label": "最大亏损"}, + {"key": "max_profit", "label": "最大盈利"}, +] + + +def _parse_dt(value: str) -> Optional[datetime]: + if not value: + return None + text = value.strip().replace(" ", "T") + try: + return datetime.fromisoformat(text) + except ValueError: + return None + + +def _row_dict(row) -> dict: + return dict(row) if row is not None else {} + + +def _net_pnl(row: dict) -> float: + if row.get("pnl_net") is not None: + return float(row["pnl_net"]) + pnl = float(row.get("pnl") or 0) + fee = float(row.get("fee") or 0) + return round(pnl - fee, 2) + + +def _fee(row: dict) -> float: + return float(row.get("fee") or 0) + + +def _margin_pct(pnl_net: float, margin: Optional[float]) -> Optional[float]: + if margin and margin > 0: + return round(pnl_net / margin * 100, 2) + return None + + +def _agg_group(rows: list[dict], key_fn) -> list[dict]: + groups: dict[str, list[dict]] = {} + for row in rows: + key = key_fn(row) or "未知" + groups.setdefault(key, []).append(row) + result = [] + for label, items in sorted(groups.items(), key=lambda x: x[0]): + result.append(_agg_metrics(label, items)) + return result + + +def _agg_metrics(label: str, items: list[dict]) -> dict: + nets = [_net_pnl(r) for r in items] + wins = [n for n in nets if n > 0] + losses = [n for n in nets if n < 0] + count = len(items) + win_cnt = len(wins) + loss_cnt = len(losses) + avg_profit = round(sum(wins) / len(wins), 2) if wins else 0.0 + avg_loss = round(sum(losses) / len(losses), 2) if losses else 0.0 + pl_ratio = round(avg_profit / abs(avg_loss), 2) if wins and losses and avg_loss != 0 else 0.0 + total_fee = round(sum(_fee(r) for r in items), 2) + total_net = round(sum(nets), 2) + max_loss = round(min(losses), 2) if losses else 0.0 + max_profit = round(max(wins), 2) if wins else 0.0 + win_rate = round(win_cnt / count * 100, 2) if count else 0.0 + return { + "label": label, + "count": count, + "wins": win_cnt, + "losses": loss_cnt, + "win_rate": win_rate, + "avg_profit": avg_profit, + "avg_loss": avg_loss, + "profit_loss_ratio": pl_ratio, + "total_fee": total_fee, + "total_net": total_net, + "max_loss": max_loss, + "max_profit": max_profit, + } + + +def _max_consecutive_losses(nets: list[float]) -> int: + streak = 0 + best = 0 + for n in nets: + if n < 0: + streak += 1 + best = max(best, streak) + else: + streak = 0 + return best + + +def _max_drawdown(nets: list[float], initial_capital: float) -> tuple[float, float]: + equity = initial_capital + peak = initial_capital + max_dd = 0.0 + max_dd_pct = 0.0 + for n in nets: + equity += n + if equity > peak: + peak = equity + dd = peak - equity + if dd > max_dd: + max_dd = dd + if peak > 0: + pct = dd / peak * 100 + if pct > max_dd_pct: + max_dd_pct = pct + return round(max_dd, 2), round(max_dd_pct, 2) + + +def fetch_trade_rows(conn) -> list[dict]: + rows = conn.execute( + "SELECT * FROM trade_logs ORDER BY close_time ASC, id ASC" + ).fetchall() + return [_row_dict(r) for r in rows] + + +def fetch_review_rows(conn) -> list[dict]: + rows = conn.execute( + "SELECT * FROM review_records ORDER BY close_time ASC, id ASC" + ).fetchall() + return [_row_dict(r) for r in rows] + + +def compute_summary(trades: list[dict], reviews: list[dict], live_capital: float) -> dict: + nets = [_net_pnl(t) for t in trades] + count = len(trades) + wins = [n for n in nets if n > 0] + losses = [n for n in nets if n < 0] + win_cnt = len(wins) + loss_cnt = len(losses) + avg_profit = round(sum(wins) / len(wins), 2) if wins else 0.0 + avg_loss = round(sum(losses) / len(losses), 2) if losses else 0.0 + pl_ratio = round(avg_profit / abs(avg_loss), 2) if wins and losses and avg_loss != 0 else 0.0 + total_fee = round(sum(_fee(t) for t in trades) + sum(_fee(r) for r in reviews), 2) + max_loss_amt = round(min(losses), 2) if losses else 0.0 + max_profit_amt = round(max(wins), 2) if wins else 0.0 + + margins_loss = [ + _margin_pct(_net_pnl(t), t.get("margin")) + for t in trades + if _net_pnl(t) < 0 and t.get("margin") + ] + margins_profit = [ + _margin_pct(_net_pnl(t), t.get("margin")) + for t in trades + if _net_pnl(t) > 0 and t.get("margin") + ] + max_loss_pct = round(min(margins_loss), 2) if margins_loss else 0.0 + max_profit_pct = round(max(margins_profit), 2) if margins_profit else 0.0 + + consec_loss = _max_consecutive_losses(nets) + max_dd, max_dd_pct = _max_drawdown(nets, live_capital) + + emotion_cnt = sum(1 for r in reviews if r.get("is_emotion")) + review_cnt = len(reviews) + denom = count if count else review_cnt + emotion_ratio = round(emotion_cnt / denom * 100, 2) if denom else 0.0 + + return { + "total_trades": count, + "win_rate": round(win_cnt / count * 100, 2) if count else 0.0, + "avg_profit": avg_profit, + "avg_loss": avg_loss, + "profit_loss_ratio": pl_ratio, + "consecutive_losses": consec_loss, + "max_drawdown": max_dd, + "max_drawdown_pct": max_dd_pct, + "max_loss_amount": max_loss_amt, + "max_loss_pct": max_loss_pct, + "max_profit_amount": max_profit_amt, + "max_profit_pct": max_profit_pct, + "total_fee": total_fee, + "emotion_count": emotion_cnt, + "emotion_ratio": emotion_ratio, + "review_count": review_cnt, + "win_count": win_cnt, + "loss_count": loss_cnt, + } + + +def compute_breakdowns(trades: list[dict], reviews: list[dict]) -> dict[str, dict]: + def day_key(row: dict) -> str: + dt = _parse_dt(row.get("close_time") or row.get("created_at") or "") + return dt.date().isoformat() if dt else "未知" + + def week_key(row: dict) -> str: + dt = _parse_dt(row.get("close_time") or row.get("created_at") or "") + if not dt: + return "未知" + iso = dt.isocalendar() + return f"{iso.year}-W{iso.week:02d}" + + def month_key(row: dict) -> str: + dt = _parse_dt(row.get("close_time") or row.get("created_at") or "") + return dt.strftime("%Y-%m") if dt else "未知" + + def symbol_key(row: dict) -> str: + return row.get("symbol_name") or row.get("symbol") or "未知" + + def direction_key(row: dict) -> str: + d = row.get("direction") or "" + return "做多" if d == "long" else ("做空" if d == "short" else d or "未知") + + def type_key(row: dict) -> str: + return row.get("monitor_type") or "未知" + + by_fee_rows = [] + fee_groups = {} + for t in trades: + key = symbol_key(t) + fee_groups.setdefault(key, []).append(t) + for label, items in sorted(fee_groups.items()): + row = _agg_metrics(label, items) + row["avg_fee"] = round(row["total_fee"] / row["count"], 2) if row["count"] else 0.0 + by_fee_rows.append(row) + + emotion_trades = [r for r in reviews if r.get("is_emotion")] + non_emotion = [r for r in reviews if not r.get("is_emotion")] + emotion_rows = [ + _agg_metrics("情绪单", emotion_trades), + _agg_metrics("非情绪单", non_emotion), + ] + + fee_columns = BREAKDOWN_COLUMNS + [{"key": "avg_fee", "label": "平均手续费"}] + + return { + "by_time": {"columns": BREAKDOWN_COLUMNS, "rows": _agg_group(trades, day_key)}, + "by_week": {"columns": BREAKDOWN_COLUMNS, "rows": _agg_group(trades, week_key)}, + "by_month": {"columns": BREAKDOWN_COLUMNS, "rows": _agg_group(trades, month_key)}, + "by_symbol": {"columns": BREAKDOWN_COLUMNS, "rows": _agg_group(trades, symbol_key)}, + "by_fee": {"columns": fee_columns, "rows": by_fee_rows}, + "by_direction": {"columns": BREAKDOWN_COLUMNS, "rows": _agg_group(trades, direction_key)}, + "by_trade_type": {"columns": BREAKDOWN_COLUMNS, "rows": _agg_group(trades, type_key)}, + "by_emotion": {"columns": BREAKDOWN_COLUMNS, "rows": emotion_rows}, + } + + +def build_all_stats(conn, live_capital: float = 0.0) -> dict: + trades = fetch_trade_rows(conn) + reviews = fetch_review_rows(conn) + summary = compute_summary(trades, reviews, live_capital) + breakdowns = compute_breakdowns(trades, reviews) + return { + "updated_at": datetime.now(TZ).isoformat(timespec="seconds"), + "summary": summary, + "views": STATS_VIEWS, + "breakdowns": breakdowns, + } + + +def save_stats_cache(conn, data: dict) -> None: + execute_retry( + conn, + """INSERT INTO stats_cache (key, data_json, updated_at) + VALUES ('all', ?, ?) + ON CONFLICT(key) DO UPDATE SET data_json=excluded.data_json, updated_at=excluded.updated_at""", + (json.dumps(data, ensure_ascii=False), data["updated_at"]), + ) + commit_retry(conn) + + +def load_stats_cache(conn) -> Optional[dict]: + row = conn.execute( + "SELECT data_json FROM stats_cache WHERE key='all'" + ).fetchone() + if not row: + return None + try: + return json.loads(row["data_json"]) + except json.JSONDecodeError: + return None + + +def refresh_stats_cache(conn, live_capital: float = 0.0) -> dict: + with _stats_refresh_lock: + data = build_all_stats(conn, live_capital) + save_stats_cache(conn, data) + return data + + +def _norm_symbol(symbol: str) -> str: + s = (symbol or "").strip().lower() + if "." in s: + s = s.split(".")[0] + return s + + +def _close_day_key(row: dict) -> str: + dt = _parse_dt(row.get("close_time") or row.get("created_at") or "") + return dt.date().isoformat() if dt else "" + + +def _close_ts(row: dict) -> float: + dt = _parse_dt(row.get("close_time") or row.get("created_at") or "") + return dt.timestamp() if dt else 0.0 + + +def _direction_label(direction: str) -> str: + if direction == "long": + return "做多" + if direction == "short": + return "做空" + return direction or "" + + +def _index_reviews_by_day_sym(reviews: list[dict]) -> dict[tuple[str, str], list[dict]]: + index: dict[tuple[str, str], list[dict]] = {} + for review in reviews: + day = _close_day_key(review) + if not day: + continue + sym = _norm_symbol(review.get("symbol") or "") + index.setdefault((day, sym), []).append(review) + return index + + +def _review_match_score(trade: dict, review: dict) -> float: + score = abs(_close_ts(trade) - _close_ts(review)) + lots_t = trade.get("lots") + lots_r = review.get("lots") + if lots_t is not None and lots_r is not None and float(lots_t) != float(lots_r): + score += 86400.0 + entry_t = trade.get("entry_price") + entry_r = review.get("entry_price") + if entry_t is not None and entry_r is not None and abs(float(entry_t) - float(entry_r)) > 0.01: + score += 3600.0 + return score + + +def _find_review_for_trade( + trade: dict, + review_index: dict[tuple[str, str], list[dict]], + used_review_ids: set[int], +) -> Optional[dict]: + day = _close_day_key(trade) + sym = _norm_symbol(trade.get("symbol") or "") + candidates = [ + r for r in review_index.get((day, sym), []) + if r.get("id") not in used_review_ids + ] + if not candidates: + return None + return min(candidates, key=lambda r: _review_match_score(trade, r)) + + +def _format_day_entry( + *, + trade: Optional[dict] = None, + review: Optional[dict] = None, + source: str, +) -> dict: + row = review if source == "review" and review else trade or review or {} + symbol = row.get("symbol") or "" + pnl_net = _net_pnl(row) + tags = (row.get("behavior_tags") or "").strip() + is_emotion = bool(row.get("is_emotion")) + return { + "source": source, + "trade_id": trade.get("id") if trade else None, + "review_id": review.get("id") if review else None, + "symbol": row.get("symbol_name") or symbol, + "symbol_code": symbol, + "direction": _direction_label(row.get("direction") or ""), + "lots": row.get("lots"), + "entry_price": row.get("entry_price"), + "close_price": row.get("close_price"), + "stop_loss": row.get("stop_loss"), + "take_profit": row.get("take_profit"), + "open_time": row.get("open_time") or "", + "close_time": row.get("close_time") or "", + "pnl": row.get("pnl"), + "fee": row.get("fee"), + "pnl_net": pnl_net, + "result": row.get("result") if trade else None, + "monitor_type": row.get("monitor_type") if trade else None, + "is_emotion": is_emotion, + "behavior_tags": tags, + "open_type": row.get("open_type") if review else None, + "exit_trigger": row.get("exit_trigger") if review else None, + "exit_supplement": row.get("exit_supplement") if review else None, + "holding_duration": row.get("holding_duration") if review else None, + "initial_pnl": row.get("initial_pnl") if review else None, + "actual_pnl": row.get("actual_pnl") if review else None, + "timeframe": row.get("timeframe") if review else None, + "notes": row.get("notes") if review else None, + "screenshot": row.get("screenshot") if review else None, + } + + +def build_day_detail(trades: list[dict], reviews: list[dict], day: str) -> list[dict]: + day_trades = [t for t in trades if _close_day_key(t) == day] + day_reviews = [r for r in reviews if _close_day_key(r) == day] + review_index = _index_reviews_by_day_sym(day_reviews) + used_review_ids: set[int] = set() + items: list[dict] = [] + + for trade in day_trades: + review = _find_review_for_trade(trade, review_index, used_review_ids) + if review: + used_review_ids.add(int(review["id"])) + items.append(_format_day_entry(trade=trade, review=review, source="review")) + else: + items.append(_format_day_entry(trade=trade, source="trade")) + + for review in day_reviews: + if int(review.get("id") or 0) in used_review_ids: + continue + items.append(_format_day_entry(review=review, source="review")) + + items.sort(key=lambda x: _close_ts(x), reverse=True) + return items + + +def build_calendar_month(trades: list[dict], reviews: list[dict], year: int, month: int) -> dict: + review_index = _index_reviews_by_day_sym(reviews) + day_map: dict[str, dict] = {} + matched_review_ids: dict[str, set[int]] = {} + + for trade in trades: + dt = _parse_dt(trade.get("close_time") or "") + if not dt or dt.year != year or dt.month != month: + continue + day = dt.date().isoformat() + bucket = day_map.setdefault( + day, + { + "date": day, + "count": 0, + "total_net": 0.0, + "review_count": 0, + "emotion_count": 0, + "has_emotion": False, + }, + ) + bucket["count"] += 1 + used = matched_review_ids.setdefault(day, set()) + review = _find_review_for_trade(trade, review_index, used) + if review: + rid = int(review["id"]) + used.add(rid) + bucket["total_net"] = round(bucket["total_net"] + _net_pnl(review), 2) + bucket["review_count"] += 1 + if review.get("is_emotion"): + bucket["emotion_count"] += 1 + bucket["has_emotion"] = True + else: + bucket["total_net"] = round(bucket["total_net"] + _net_pnl(trade), 2) + + for review in reviews: + if not review.get("is_emotion"): + continue + day = _close_day_key(review) + if not day: + continue + try: + dt = date.fromisoformat(day) + except ValueError: + continue + if dt.year != year or dt.month != month: + continue + bucket = day_map.setdefault( + day, + { + "date": day, + "count": 0, + "total_net": 0.0, + "review_count": 0, + "emotion_count": 0, + "has_emotion": False, + }, + ) + bucket["has_emotion"] = True + rid = int(review.get("id") or 0) + if rid and rid not in matched_review_ids.get(day, set()): + bucket["emotion_count"] += 1 + + _, last_day = calendar.monthrange(year, month) + days = [] + for d in range(1, last_day + 1): + iso = date(year, month, d).isoformat() + if iso in day_map: + row = day_map[iso] + row["total_net"] = round(row["total_net"], 2) + days.append(row) + else: + days.append( + { + "date": iso, + "count": 0, + "total_net": 0.0, + "review_count": 0, + "emotion_count": 0, + "has_emotion": False, + } + ) + + return { + "year": year, + "month": month, + "days": days, + "weekday_start": date(year, month, 1).weekday(), + } + + +def get_calendar_month(conn, year: int, month: int) -> dict: + trades = fetch_trade_rows(conn) + reviews = fetch_review_rows(conn) + return build_calendar_month(trades, reviews, year, month) + + +def get_calendar_day(conn, day: str) -> dict: + trades = fetch_trade_rows(conn) + reviews = fetch_review_rows(conn) + items = build_day_detail(trades, reviews, day) + total_net = round(sum(float(i.get("pnl_net") or 0) for i in items), 2) + emotion_count = sum(1 for i in items if i.get("is_emotion")) + return { + "date": day, + "count": len(items), + "total_net": total_net, + "emotion_count": emotion_count, + "items": items, + } diff --git a/modules/strategy/__init__.py b/modules/strategy/__init__.py new file mode 100644 index 0000000..2a7b259 --- /dev/null +++ b/modules/strategy/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2025-2026 马建军. All rights reserved. + +"""Strategy routes are registered via modules.trading (install_trading).""" + + +def register(deps) -> None: + del deps + + +__all__ = ["register"] diff --git a/strategy/fib_lib.py b/modules/strategy/fib_lib.py similarity index 100% rename from strategy/fib_lib.py rename to modules/strategy/fib_lib.py diff --git a/strategy/strategy_db.py b/modules/strategy/strategy_db.py similarity index 95% rename from strategy/strategy_db.py rename to modules/strategy/strategy_db.py index d7792ca..a248166 100644 --- a/strategy/strategy_db.py +++ b/modules/strategy/strategy_db.py @@ -1,169 +1,169 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""策略相关表结构。""" -from __future__ import annotations - -from db_conn import rollback_if_postgres - -ROLL_GROUPS_SQL = """ -CREATE TABLE IF NOT EXISTS roll_groups ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - order_monitor_id INTEGER, - symbol TEXT NOT NULL, - direction TEXT NOT NULL, - initial_take_profit REAL, - initial_stop_loss REAL, - current_stop_loss REAL, - risk_percent REAL DEFAULT 2, - leg_count INTEGER DEFAULT 0, - status TEXT DEFAULT 'active', - created_at TEXT, - updated_at TEXT -) -""" - -ROLL_LEGS_SQL = """ -CREATE TABLE IF NOT EXISTS roll_legs ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - roll_group_id INTEGER NOT NULL, - leg_index INTEGER NOT NULL, - add_mode TEXT NOT NULL, - fill_price REAL, - lots INTEGER, - new_stop_loss REAL, - status TEXT DEFAULT 'filled', - created_at TEXT -) -""" - -TREND_PLANS_SQL = """ -CREATE TABLE IF NOT EXISTS trend_pullback_plans ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - status TEXT DEFAULT 'active', - symbol TEXT NOT NULL, - symbol_name TEXT, - direction TEXT NOT NULL DEFAULT 'long', - stop_loss REAL NOT NULL, - add_upper REAL NOT NULL, - take_profit REAL NOT NULL, - risk_percent REAL DEFAULT 5, - capital_snapshot REAL, - plan_margin REAL, - target_lots INTEGER, - first_lots INTEGER, - remainder_lots INTEGER, - dca_legs INTEGER DEFAULT 5, - leg_amounts_json TEXT, - grid_prices_json TEXT, - legs_done INTEGER DEFAULT 0, - first_order_done INTEGER DEFAULT 0, - avg_entry_price REAL, - lots_open INTEGER DEFAULT 0, - opened_at TEXT, - message TEXT, - period TEXT DEFAULT '15m' -) -""" - -STRATEGY_SNAPSHOTS_SQL = """ -CREATE TABLE IF NOT EXISTS strategy_trade_snapshots ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - strategy_type TEXT NOT NULL, - source_id INTEGER, - symbol TEXT, - direction TEXT, - result_label TEXT, - opened_at TEXT, - closed_at TEXT, - pnl_amount REAL, - snapshot_json TEXT NOT NULL, - created_at TEXT -) -""" - -TRADE_ORDER_MONITORS_SQL = """ -CREATE TABLE IF NOT EXISTS trade_order_monitors ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - symbol TEXT NOT NULL, - symbol_name TEXT, - market_code TEXT, - direction TEXT NOT NULL, - lots INTEGER NOT NULL, - entry_price REAL, - stop_loss REAL, - take_profit REAL, - open_time TEXT, - monitor_type TEXT DEFAULT 'manual', - status TEXT DEFAULT 'active', - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP -) -""" - -CTP_SIM_ACCOUNT_SQL = """ -CREATE TABLE IF NOT EXISTS ctp_sim_account ( - id INTEGER PRIMARY KEY CHECK (id = 1), - balance REAL DEFAULT 100000, - available REAL DEFAULT 100000, - updated_at TEXT -) -""" - -CTP_SIM_POSITIONS_SQL = """ -CREATE TABLE IF NOT EXISTS ctp_sim_positions ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - symbol TEXT NOT NULL, - direction TEXT NOT NULL, - lots INTEGER NOT NULL, - avg_price REAL NOT NULL, - updated_at TEXT, - UNIQUE(symbol, direction) -) -""" - - -ROLL_LEG_EXTRA_COLUMNS = ( - "ALTER TABLE roll_legs ADD COLUMN limit_price REAL", - "ALTER TABLE roll_legs ADD COLUMN breakthrough_price REAL", - "ALTER TABLE roll_legs ADD COLUMN last_mark_price REAL", - "ALTER TABLE roll_legs ADD COLUMN invalidated_reason TEXT", - "ALTER TABLE roll_legs ADD COLUMN capital_snapshot REAL", - "ALTER TABLE trade_order_monitors ADD COLUMN risk_percent REAL", -) - - -_TABLES_READY = False - - -def init_strategy_tables(conn) -> None: - global _TABLES_READY - if _TABLES_READY: - return - for sql in ( - ROLL_GROUPS_SQL, - ROLL_LEGS_SQL, - TREND_PLANS_SQL, - STRATEGY_SNAPSHOTS_SQL, - TRADE_ORDER_MONITORS_SQL, - CTP_SIM_ACCOUNT_SQL, - CTP_SIM_POSITIONS_SQL, - ): - conn.execute(sql) - conn.commit() - try: - conn.execute("ALTER TABLE trend_pullback_plans ADD COLUMN period TEXT DEFAULT '15m'") - except Exception: - pass - for sql in ROLL_LEG_EXTRA_COLUMNS: - try: - conn.execute(sql) - conn.commit() - except Exception: - rollback_if_postgres(conn) - pass - if not conn.execute("SELECT id FROM ctp_sim_account WHERE id=1").fetchone(): - conn.execute("INSERT INTO ctp_sim_account (id, balance, available) VALUES (1, 100000, 100000)") - conn.commit() - _TABLES_READY = True +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""策略相关表结构。""" +from __future__ import annotations + +from modules.core.db_conn import rollback_if_postgres + +ROLL_GROUPS_SQL = """ +CREATE TABLE IF NOT EXISTS roll_groups ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + order_monitor_id INTEGER, + symbol TEXT NOT NULL, + direction TEXT NOT NULL, + initial_take_profit REAL, + initial_stop_loss REAL, + current_stop_loss REAL, + risk_percent REAL DEFAULT 2, + leg_count INTEGER DEFAULT 0, + status TEXT DEFAULT 'active', + created_at TEXT, + updated_at TEXT +) +""" + +ROLL_LEGS_SQL = """ +CREATE TABLE IF NOT EXISTS roll_legs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + roll_group_id INTEGER NOT NULL, + leg_index INTEGER NOT NULL, + add_mode TEXT NOT NULL, + fill_price REAL, + lots INTEGER, + new_stop_loss REAL, + status TEXT DEFAULT 'filled', + created_at TEXT +) +""" + +TREND_PLANS_SQL = """ +CREATE TABLE IF NOT EXISTS trend_pullback_plans ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + status TEXT DEFAULT 'active', + symbol TEXT NOT NULL, + symbol_name TEXT, + direction TEXT NOT NULL DEFAULT 'long', + stop_loss REAL NOT NULL, + add_upper REAL NOT NULL, + take_profit REAL NOT NULL, + risk_percent REAL DEFAULT 5, + capital_snapshot REAL, + plan_margin REAL, + target_lots INTEGER, + first_lots INTEGER, + remainder_lots INTEGER, + dca_legs INTEGER DEFAULT 5, + leg_amounts_json TEXT, + grid_prices_json TEXT, + legs_done INTEGER DEFAULT 0, + first_order_done INTEGER DEFAULT 0, + avg_entry_price REAL, + lots_open INTEGER DEFAULT 0, + opened_at TEXT, + message TEXT, + period TEXT DEFAULT '15m' +) +""" + +STRATEGY_SNAPSHOTS_SQL = """ +CREATE TABLE IF NOT EXISTS strategy_trade_snapshots ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + strategy_type TEXT NOT NULL, + source_id INTEGER, + symbol TEXT, + direction TEXT, + result_label TEXT, + opened_at TEXT, + closed_at TEXT, + pnl_amount REAL, + snapshot_json TEXT NOT NULL, + created_at TEXT +) +""" + +TRADE_ORDER_MONITORS_SQL = """ +CREATE TABLE IF NOT EXISTS trade_order_monitors ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + symbol TEXT NOT NULL, + symbol_name TEXT, + market_code TEXT, + direction TEXT NOT NULL, + lots INTEGER NOT NULL, + entry_price REAL, + stop_loss REAL, + take_profit REAL, + open_time TEXT, + monitor_type TEXT DEFAULT 'manual', + status TEXT DEFAULT 'active', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +) +""" + +CTP_SIM_ACCOUNT_SQL = """ +CREATE TABLE IF NOT EXISTS ctp_sim_account ( + id INTEGER PRIMARY KEY CHECK (id = 1), + balance REAL DEFAULT 100000, + available REAL DEFAULT 100000, + updated_at TEXT +) +""" + +CTP_SIM_POSITIONS_SQL = """ +CREATE TABLE IF NOT EXISTS ctp_sim_positions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + symbol TEXT NOT NULL, + direction TEXT NOT NULL, + lots INTEGER NOT NULL, + avg_price REAL NOT NULL, + updated_at TEXT, + UNIQUE(symbol, direction) +) +""" + + +ROLL_LEG_EXTRA_COLUMNS = ( + "ALTER TABLE roll_legs ADD COLUMN limit_price REAL", + "ALTER TABLE roll_legs ADD COLUMN breakthrough_price REAL", + "ALTER TABLE roll_legs ADD COLUMN last_mark_price REAL", + "ALTER TABLE roll_legs ADD COLUMN invalidated_reason TEXT", + "ALTER TABLE roll_legs ADD COLUMN capital_snapshot REAL", + "ALTER TABLE trade_order_monitors ADD COLUMN risk_percent REAL", +) + + +_TABLES_READY = False + + +def init_strategy_tables(conn) -> None: + global _TABLES_READY + if _TABLES_READY: + return + for sql in ( + ROLL_GROUPS_SQL, + ROLL_LEGS_SQL, + TREND_PLANS_SQL, + STRATEGY_SNAPSHOTS_SQL, + TRADE_ORDER_MONITORS_SQL, + CTP_SIM_ACCOUNT_SQL, + CTP_SIM_POSITIONS_SQL, + ): + conn.execute(sql) + conn.commit() + try: + conn.execute("ALTER TABLE trend_pullback_plans ADD COLUMN period TEXT DEFAULT '15m'") + except Exception: + pass + for sql in ROLL_LEG_EXTRA_COLUMNS: + try: + conn.execute(sql) + conn.commit() + except Exception: + rollback_if_postgres(conn) + pass + if not conn.execute("SELECT id FROM ctp_sim_account WHERE id=1").fetchone(): + conn.execute("INSERT INTO ctp_sim_account (id, balance, available) VALUES (1, 100000, 100000)") + conn.commit() + _TABLES_READY = True diff --git a/strategy/strategy_roll_lib.py b/modules/strategy/strategy_roll_lib.py similarity index 96% rename from strategy/strategy_roll_lib.py rename to modules/strategy/strategy_roll_lib.py index c153752..6d0b217 100644 --- a/strategy/strategy_roll_lib.py +++ b/modules/strategy/strategy_roll_lib.py @@ -1,370 +1,370 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""顺势加仓(滚仓):纯计算与校验,期货版(手数整数、乘数计入盈亏)。""" -from __future__ import annotations - -import math -from typing import Any, Optional, Tuple - -from position_sizing import MODE_AMOUNT -from strategy.fib_lib import calc_fib_plan - -ROLL_MAX_LEGS_LONG = 3 -ROLL_MAX_LEGS_SHORT = 3 -ROLL_STOP_OFFSET_PCT_DEFAULT = 1.0 - -ADD_MODE_MARKET = "market" -ADD_MODE_FIB_618 = "fib_618" -ADD_MODE_FIB_786 = "fib_786" -ADD_MODE_BREAKOUT = "breakout" - -FIB_MODES = frozenset({ADD_MODE_FIB_618, ADD_MODE_FIB_786}) -PENDING_MODES = frozenset({ADD_MODE_FIB_618, ADD_MODE_FIB_786, ADD_MODE_BREAKOUT}) - -ADD_MODE_LABELS = { - ADD_MODE_MARKET: "市价加仓", - ADD_MODE_FIB_618: "斐波0.618", - ADD_MODE_FIB_786: "斐波0.786", - ADD_MODE_BREAKOUT: "突破加仓", -} - -LEG_STATUS_PENDING = "pending" -LEG_STATUS_FILLED = "filled" -LEG_STATUS_CANCELLED = "cancelled" -LEG_STATUS_INVALIDATED = "invalidated" - - -def add_mode_label(mode: str) -> str: - return ADD_MODE_LABELS.get((mode or "").strip().lower(), mode or "") - - -def fib_ratio_from_mode(mode: str) -> Optional[float]: - m = (mode or "").strip().lower() - if m in (ADD_MODE_FIB_618, "618", "0.618"): - return 0.618 - if m in (ADD_MODE_FIB_786, "786", "0.786"): - return 0.786 - return None - - -def fib_limit_entry(direction: str, upper: float, lower: float, mode: str) -> Tuple[Optional[float], Optional[str]]: - ratio = fib_ratio_from_mode(mode) - if ratio is None: - return None, "斐波档位无效" - h, l = float(upper), float(lower) - if h <= l: - return None, "上沿须大于下沿" - direction = (direction or "long").strip().lower() - plan = calc_fib_plan(direction, h, l, ratio) - if not plan: - return None, "无法计算斐波限价" - entry, _sl, _tp = plan - return float(entry), None - - -def max_roll_legs(direction: str) -> int: - return ROLL_MAX_LEGS_LONG if (direction or "long").strip().lower() == "long" else ROLL_MAX_LEGS_SHORT - - -def lots_precise(raw: float) -> int: - if raw is None or raw < 1: - return 0 - return max(1, int(math.floor(float(raw)))) - - -def unified_stop_from_avg(direction: str, avg: float, offset_pct: float) -> float: - avg_f = float(avg) - pct = float(offset_pct) / 100.0 - direction = (direction or "long").strip().lower() - if direction == "short": - return avg_f * (1.0 + pct) - return avg_f * (1.0 - pct) - - -def avg_entry_after_add(qty_existing: float, entry_existing: float, add_qty: float, add_price: float) -> float: - q1, e1, q2, e2 = float(qty_existing), float(entry_existing), float(add_qty), float(add_price) - total = q1 + q2 - return (q1 * e1 + q2 * e2) / total if total > 0 else 0.0 - - -def solve_add_lots_for_total_risk( - direction: str, - qty_existing: float, - entry_existing: float, - add_price: float, - new_stop: float, - risk_budget: float, - mult: int, -) -> Tuple[Optional[int], Optional[str]]: - """方案 C:合并持仓打到新止损 S 时总亏损 ≤ B。""" - q1, e1, e2, sl, b = float(qty_existing), float(entry_existing), float(add_price), float(new_stop), float(risk_budget) - m = float(mult) - direction = (direction or "long").strip().lower() - if direction == "short": - denom = (sl - e2) * m - numer = b - q1 * (sl - e1) * m - else: - denom = (e2 - sl) * m - numer = b - q1 * (e1 - sl) * m - if denom <= 0: - return None, "止损与加仓价关系无效" - q2 = numer / denom - lots = lots_precise(q2) - if lots < 1: - return None, "已满足风险上限或无法再加" - return lots, None - - -def roll_eligibility_error( - *, - sizing_mode: str, - monitor: dict, - has_active_trend: bool, - legs_done: int = 0, - has_pending_leg: bool = False, -) -> Optional[str]: - if normalize_sizing_mode(sizing_mode) != MODE_AMOUNT: - return "仅固定金额(以损定仓)模式可滚仓" - if has_active_trend: - return "趋势回调运行中,不可滚仓" - if not monitor or (monitor.get("status") or "").strip().lower() != "active": - return "无有效持仓监控" - if int(monitor.get("trailing_be") or 0): - return "移动保本持仓不可滚仓" - direction = (monitor.get("direction") or "long").strip().lower() - if legs_done >= max_roll_legs(direction): - return f"滚仓已达 {max_roll_legs(direction)} 次上限" - if has_pending_leg: - return "已有监控中的加仓腿,请等待成交或删除后再提交" - if int(monitor.get("lots") or 0) < 1: - return "持仓手数为 0" - if not float(monitor.get("take_profit") or 0): - return "首仓须设置止盈(移动保本不可滚仓)" - return None - - -def normalize_sizing_mode(raw: str) -> str: - from position_sizing import normalize_sizing_mode as _norm - return _norm(raw) - - -def resolve_risk_percent(monitor: dict, *, default: float) -> float: - try: - rp = float(monitor.get("risk_percent") or 0) - if rp > 0: - return rp - except (TypeError, ValueError): - pass - return float(default) - - -def validate_roll_geometry( - direction: str, - add_mode: str, - new_stop: float, - *, - mark_price: float, - limit_price: Optional[float] = None, - breakthrough_price: Optional[float] = None, - at_trigger: bool = False, - off_session_pending: bool = False, -) -> Optional[str]: - """几何校验。 - - 做多斐波(回调):止损 < 触发价 < 当前价 - 做多突破(向上):止损 < 突破价 < 当前价 - 做空斐波(反弹):当前价 < 触发价 < 止损 - 做空突破(向下):突破价 < 当前价 < 止损(提交时);触发后当前价可 ≤ 突破价 - """ - direction = (direction or "long").strip().lower() - mode = (add_mode or ADD_MODE_MARKET).strip().lower() - sl = float(new_stop) - mark = float(mark_price) - if sl <= 0 or mark <= 0: - return "止损或参考价无效" - if mode == ADD_MODE_MARKET: - if direction == "long" and sl >= mark: - return "做多:新止损须低于当前价" - if direction == "short" and sl <= mark: - return "做空:新止损须高于当前价" - return None - trigger = None - if mode in FIB_MODES: - trigger = float(limit_price or 0) - if trigger <= 0: - return "须填写斐波触发价" - if direction == "long": - if not (sl < trigger < mark): - return "做多斐波:须满足 止损 < 触发价 < 当前价" - else: - if not (mark < trigger < sl): - return "做空斐波:须满足 当前价 < 触发价 < 止损" - return None - if mode == ADD_MODE_BREAKOUT: - trigger = float(breakthrough_price or 0) - if trigger <= 0: - return "须填写突破价" - if off_session_pending: - if direction == "long" and not (sl < trigger): - return "做多突破:休盘提交须满足 止损 < 突破价" - if direction == "short" and not (trigger < sl): - return "做空突破:休盘提交须满足 突破价 < 止损" - return None - if at_trigger: - if direction == "long": - if not (sl < trigger <= mark): - return "做多突破:触发时须满足 止损 < 突破价 ≤ 当前价" - else: - if not (trigger < sl and mark < sl): - return "做空突破:触发时须满足 突破价 < 止损且当前价 < 止损" - return None - if direction == "long": - if not (sl < trigger < mark): - return "做多突破:须满足 止损 < 突破价 < 当前价" - else: - if not (trigger < mark < sl): - return "做空突破:须满足 突破价 < 当前价 < 止损" - return None - return "加仓方式无效" - - -def detect_mark_cross( - direction: str, - add_mode: str, - prev_mark: float, - mark: float, - trigger_price: float, -) -> bool: - """标记价穿越触发价(上一 tick 与当前 tick 比较)。""" - direction = (direction or "long").strip().lower() - mode = (add_mode or "").strip().lower() - p = float(trigger_price) - prev_m = float(prev_mark) - cur_m = float(mark) - if p <= 0 or prev_m <= 0 or cur_m <= 0: - return False - if mode in FIB_MODES: - if direction == "long": - return prev_m > p and cur_m <= p - return prev_m < p and cur_m >= p - if mode == ADD_MODE_BREAKOUT: - if direction == "long": - return prev_m < p and cur_m >= p - return prev_m > p and cur_m <= p - return False - - -def preview_roll( - *, - direction: str, - symbol: str, - qty_existing: float, - entry_existing: float, - initial_take_profit: float, - add_mode: str, - new_stop_loss: float, - risk_budget: float, - mult: int, - mark_price: Optional[float] = None, - add_price: Optional[float] = None, - limit_price: Optional[float] = None, - breakthrough_price: Optional[float] = None, - fib_upper: Optional[float] = None, - fib_lower: Optional[float] = None, - legs_done: int = 0, - at_trigger: bool = False, - off_session_pending: bool = False, -) -> Tuple[Optional[dict[str, Any]], Optional[str]]: - direction = (direction or "long").strip().lower() - if legs_done >= max_roll_legs(direction): - return None, f"滚仓已达 {max_roll_legs(direction)} 次上限" - mode = (add_mode or ADD_MODE_MARKET).strip().lower() - mark = float(mark_price or add_price or 0) - if mark <= 0 and mode == ADD_MODE_BREAKOUT and off_session_pending: - mark = float(breakthrough_price or 0) - if mark <= 0: - return None, "需要有效参考价" - sl = float(new_stop_loss) - tp = float(initial_take_profit) - if sl <= 0 or tp <= 0: - return None, "止损/止盈无效" - - entry_add = mark - mode_label = add_mode_label(mode) - trigger_price = mark - is_pending = mode in PENDING_MODES - - if mode == ADD_MODE_MARKET: - entry_add = mark - elif mode in FIB_MODES: - if limit_price and float(limit_price) > 0: - entry_add = float(limit_price) - trigger_price = entry_add - elif fib_upper is not None and fib_lower is not None: - entry_add, err = fib_limit_entry(direction, float(fib_upper), float(fib_lower), mode) - if err: - return None, err - trigger_price = entry_add - else: - return None, "斐波须填触发价或上沿/下沿" - elif mode == ADD_MODE_BREAKOUT: - if not breakthrough_price or float(breakthrough_price) <= 0: - return None, "须填写突破价" - entry_add = float(breakthrough_price) - trigger_price = entry_add - else: - return None, "加仓方式无效" - - geom_err = validate_roll_geometry( - direction, mode, sl, - mark_price=mark, - limit_price=trigger_price if mode in FIB_MODES else None, - breakthrough_price=trigger_price if mode == ADD_MODE_BREAKOUT else None, - at_trigger=at_trigger, - off_session_pending=off_session_pending and is_pending, - ) - if geom_err: - return None, geom_err - - budget = float(risk_budget) - if budget <= 0: - return None, "固定金额无效" - q2, err = solve_add_lots_for_total_risk( - direction, qty_existing, entry_existing, entry_add, sl, budget, mult, - ) - if err: - return None, err - new_qty = qty_existing + q2 - new_avg = avg_entry_after_add(qty_existing, entry_existing, q2, entry_add) - m = float(mult) - if direction == "long": - loss_at_sl = (new_avg - sl) * new_qty * m - reward_at_tp = (tp - new_avg) * new_qty * m - else: - loss_at_sl = (sl - new_avg) * new_qty * m - reward_at_tp = (new_avg - tp) * new_qty * m - return { - "symbol": symbol, - "direction": direction, - "add_mode": mode, - "add_mode_label": mode_label, - "is_pending": is_pending, - "add_price": round(entry_add, 4), - "trigger_price": round(trigger_price, 4), - "limit_price": round(trigger_price, 4) if mode in FIB_MODES else None, - "breakthrough_price": round(trigger_price, 4) if mode == ADD_MODE_BREAKOUT else None, - "new_stop_loss": round(sl, 4), - "initial_take_profit": tp, - "risk_budget": round(budget, 2), - "fixed_amount": round(budget, 2), - "add_lots": q2, - "qty_after": int(new_qty), - "avg_entry_after": round(new_avg, 4), - "loss_at_sl": round(loss_at_sl, 2), - "reward_at_tp": round(reward_at_tp, 2), - "legs_done": legs_done, - "mark_price": round(mark, 4), - }, None +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""顺势加仓(滚仓):纯计算与校验,期货版(手数整数、乘数计入盈亏)。""" +from __future__ import annotations + +import math +from typing import Any, Optional, Tuple + +from modules.trading.position_sizing import MODE_AMOUNT +from strategy.fib_lib import calc_fib_plan + +ROLL_MAX_LEGS_LONG = 3 +ROLL_MAX_LEGS_SHORT = 3 +ROLL_STOP_OFFSET_PCT_DEFAULT = 1.0 + +ADD_MODE_MARKET = "market" +ADD_MODE_FIB_618 = "fib_618" +ADD_MODE_FIB_786 = "fib_786" +ADD_MODE_BREAKOUT = "breakout" + +FIB_MODES = frozenset({ADD_MODE_FIB_618, ADD_MODE_FIB_786}) +PENDING_MODES = frozenset({ADD_MODE_FIB_618, ADD_MODE_FIB_786, ADD_MODE_BREAKOUT}) + +ADD_MODE_LABELS = { + ADD_MODE_MARKET: "市价加仓", + ADD_MODE_FIB_618: "斐波0.618", + ADD_MODE_FIB_786: "斐波0.786", + ADD_MODE_BREAKOUT: "突破加仓", +} + +LEG_STATUS_PENDING = "pending" +LEG_STATUS_FILLED = "filled" +LEG_STATUS_CANCELLED = "cancelled" +LEG_STATUS_INVALIDATED = "invalidated" + + +def add_mode_label(mode: str) -> str: + return ADD_MODE_LABELS.get((mode or "").strip().lower(), mode or "") + + +def fib_ratio_from_mode(mode: str) -> Optional[float]: + m = (mode or "").strip().lower() + if m in (ADD_MODE_FIB_618, "618", "0.618"): + return 0.618 + if m in (ADD_MODE_FIB_786, "786", "0.786"): + return 0.786 + return None + + +def fib_limit_entry(direction: str, upper: float, lower: float, mode: str) -> Tuple[Optional[float], Optional[str]]: + ratio = fib_ratio_from_mode(mode) + if ratio is None: + return None, "斐波档位无效" + h, l = float(upper), float(lower) + if h <= l: + return None, "上沿须大于下沿" + direction = (direction or "long").strip().lower() + plan = calc_fib_plan(direction, h, l, ratio) + if not plan: + return None, "无法计算斐波限价" + entry, _sl, _tp = plan + return float(entry), None + + +def max_roll_legs(direction: str) -> int: + return ROLL_MAX_LEGS_LONG if (direction or "long").strip().lower() == "long" else ROLL_MAX_LEGS_SHORT + + +def lots_precise(raw: float) -> int: + if raw is None or raw < 1: + return 0 + return max(1, int(math.floor(float(raw)))) + + +def unified_stop_from_avg(direction: str, avg: float, offset_pct: float) -> float: + avg_f = float(avg) + pct = float(offset_pct) / 100.0 + direction = (direction or "long").strip().lower() + if direction == "short": + return avg_f * (1.0 + pct) + return avg_f * (1.0 - pct) + + +def avg_entry_after_add(qty_existing: float, entry_existing: float, add_qty: float, add_price: float) -> float: + q1, e1, q2, e2 = float(qty_existing), float(entry_existing), float(add_qty), float(add_price) + total = q1 + q2 + return (q1 * e1 + q2 * e2) / total if total > 0 else 0.0 + + +def solve_add_lots_for_total_risk( + direction: str, + qty_existing: float, + entry_existing: float, + add_price: float, + new_stop: float, + risk_budget: float, + mult: int, +) -> Tuple[Optional[int], Optional[str]]: + """方案 C:合并持仓打到新止损 S 时总亏损 ≤ B。""" + q1, e1, e2, sl, b = float(qty_existing), float(entry_existing), float(add_price), float(new_stop), float(risk_budget) + m = float(mult) + direction = (direction or "long").strip().lower() + if direction == "short": + denom = (sl - e2) * m + numer = b - q1 * (sl - e1) * m + else: + denom = (e2 - sl) * m + numer = b - q1 * (e1 - sl) * m + if denom <= 0: + return None, "止损与加仓价关系无效" + q2 = numer / denom + lots = lots_precise(q2) + if lots < 1: + return None, "已满足风险上限或无法再加" + return lots, None + + +def roll_eligibility_error( + *, + sizing_mode: str, + monitor: dict, + has_active_trend: bool, + legs_done: int = 0, + has_pending_leg: bool = False, +) -> Optional[str]: + if normalize_sizing_mode(sizing_mode) != MODE_AMOUNT: + return "仅固定金额(以损定仓)模式可滚仓" + if has_active_trend: + return "趋势回调运行中,不可滚仓" + if not monitor or (monitor.get("status") or "").strip().lower() != "active": + return "无有效持仓监控" + if int(monitor.get("trailing_be") or 0): + return "移动保本持仓不可滚仓" + direction = (monitor.get("direction") or "long").strip().lower() + if legs_done >= max_roll_legs(direction): + return f"滚仓已达 {max_roll_legs(direction)} 次上限" + if has_pending_leg: + return "已有监控中的加仓腿,请等待成交或删除后再提交" + if int(monitor.get("lots") or 0) < 1: + return "持仓手数为 0" + if not float(monitor.get("take_profit") or 0): + return "首仓须设置止盈(移动保本不可滚仓)" + return None + + +def normalize_sizing_mode(raw: str) -> str: + from modules.trading.position_sizing import normalize_sizing_mode as _norm + return _norm(raw) + + +def resolve_risk_percent(monitor: dict, *, default: float) -> float: + try: + rp = float(monitor.get("risk_percent") or 0) + if rp > 0: + return rp + except (TypeError, ValueError): + pass + return float(default) + + +def validate_roll_geometry( + direction: str, + add_mode: str, + new_stop: float, + *, + mark_price: float, + limit_price: Optional[float] = None, + breakthrough_price: Optional[float] = None, + at_trigger: bool = False, + off_session_pending: bool = False, +) -> Optional[str]: + """几何校验。 + + 做多斐波(回调):止损 < 触发价 < 当前价 + 做多突破(向上):止损 < 突破价 < 当前价 + 做空斐波(反弹):当前价 < 触发价 < 止损 + 做空突破(向下):突破价 < 当前价 < 止损(提交时);触发后当前价可 ≤ 突破价 + """ + direction = (direction or "long").strip().lower() + mode = (add_mode or ADD_MODE_MARKET).strip().lower() + sl = float(new_stop) + mark = float(mark_price) + if sl <= 0 or mark <= 0: + return "止损或参考价无效" + if mode == ADD_MODE_MARKET: + if direction == "long" and sl >= mark: + return "做多:新止损须低于当前价" + if direction == "short" and sl <= mark: + return "做空:新止损须高于当前价" + return None + trigger = None + if mode in FIB_MODES: + trigger = float(limit_price or 0) + if trigger <= 0: + return "须填写斐波触发价" + if direction == "long": + if not (sl < trigger < mark): + return "做多斐波:须满足 止损 < 触发价 < 当前价" + else: + if not (mark < trigger < sl): + return "做空斐波:须满足 当前价 < 触发价 < 止损" + return None + if mode == ADD_MODE_BREAKOUT: + trigger = float(breakthrough_price or 0) + if trigger <= 0: + return "须填写突破价" + if off_session_pending: + if direction == "long" and not (sl < trigger): + return "做多突破:休盘提交须满足 止损 < 突破价" + if direction == "short" and not (trigger < sl): + return "做空突破:休盘提交须满足 突破价 < 止损" + return None + if at_trigger: + if direction == "long": + if not (sl < trigger <= mark): + return "做多突破:触发时须满足 止损 < 突破价 ≤ 当前价" + else: + if not (trigger < sl and mark < sl): + return "做空突破:触发时须满足 突破价 < 止损且当前价 < 止损" + return None + if direction == "long": + if not (sl < trigger < mark): + return "做多突破:须满足 止损 < 突破价 < 当前价" + else: + if not (trigger < mark < sl): + return "做空突破:须满足 突破价 < 当前价 < 止损" + return None + return "加仓方式无效" + + +def detect_mark_cross( + direction: str, + add_mode: str, + prev_mark: float, + mark: float, + trigger_price: float, +) -> bool: + """标记价穿越触发价(上一 tick 与当前 tick 比较)。""" + direction = (direction or "long").strip().lower() + mode = (add_mode or "").strip().lower() + p = float(trigger_price) + prev_m = float(prev_mark) + cur_m = float(mark) + if p <= 0 or prev_m <= 0 or cur_m <= 0: + return False + if mode in FIB_MODES: + if direction == "long": + return prev_m > p and cur_m <= p + return prev_m < p and cur_m >= p + if mode == ADD_MODE_BREAKOUT: + if direction == "long": + return prev_m < p and cur_m >= p + return prev_m > p and cur_m <= p + return False + + +def preview_roll( + *, + direction: str, + symbol: str, + qty_existing: float, + entry_existing: float, + initial_take_profit: float, + add_mode: str, + new_stop_loss: float, + risk_budget: float, + mult: int, + mark_price: Optional[float] = None, + add_price: Optional[float] = None, + limit_price: Optional[float] = None, + breakthrough_price: Optional[float] = None, + fib_upper: Optional[float] = None, + fib_lower: Optional[float] = None, + legs_done: int = 0, + at_trigger: bool = False, + off_session_pending: bool = False, +) -> Tuple[Optional[dict[str, Any]], Optional[str]]: + direction = (direction or "long").strip().lower() + if legs_done >= max_roll_legs(direction): + return None, f"滚仓已达 {max_roll_legs(direction)} 次上限" + mode = (add_mode or ADD_MODE_MARKET).strip().lower() + mark = float(mark_price or add_price or 0) + if mark <= 0 and mode == ADD_MODE_BREAKOUT and off_session_pending: + mark = float(breakthrough_price or 0) + if mark <= 0: + return None, "需要有效参考价" + sl = float(new_stop_loss) + tp = float(initial_take_profit) + if sl <= 0 or tp <= 0: + return None, "止损/止盈无效" + + entry_add = mark + mode_label = add_mode_label(mode) + trigger_price = mark + is_pending = mode in PENDING_MODES + + if mode == ADD_MODE_MARKET: + entry_add = mark + elif mode in FIB_MODES: + if limit_price and float(limit_price) > 0: + entry_add = float(limit_price) + trigger_price = entry_add + elif fib_upper is not None and fib_lower is not None: + entry_add, err = fib_limit_entry(direction, float(fib_upper), float(fib_lower), mode) + if err: + return None, err + trigger_price = entry_add + else: + return None, "斐波须填触发价或上沿/下沿" + elif mode == ADD_MODE_BREAKOUT: + if not breakthrough_price or float(breakthrough_price) <= 0: + return None, "须填写突破价" + entry_add = float(breakthrough_price) + trigger_price = entry_add + else: + return None, "加仓方式无效" + + geom_err = validate_roll_geometry( + direction, mode, sl, + mark_price=mark, + limit_price=trigger_price if mode in FIB_MODES else None, + breakthrough_price=trigger_price if mode == ADD_MODE_BREAKOUT else None, + at_trigger=at_trigger, + off_session_pending=off_session_pending and is_pending, + ) + if geom_err: + return None, geom_err + + budget = float(risk_budget) + if budget <= 0: + return None, "固定金额无效" + q2, err = solve_add_lots_for_total_risk( + direction, qty_existing, entry_existing, entry_add, sl, budget, mult, + ) + if err: + return None, err + new_qty = qty_existing + q2 + new_avg = avg_entry_after_add(qty_existing, entry_existing, q2, entry_add) + m = float(mult) + if direction == "long": + loss_at_sl = (new_avg - sl) * new_qty * m + reward_at_tp = (tp - new_avg) * new_qty * m + else: + loss_at_sl = (sl - new_avg) * new_qty * m + reward_at_tp = (new_avg - tp) * new_qty * m + return { + "symbol": symbol, + "direction": direction, + "add_mode": mode, + "add_mode_label": mode_label, + "is_pending": is_pending, + "add_price": round(entry_add, 4), + "trigger_price": round(trigger_price, 4), + "limit_price": round(trigger_price, 4) if mode in FIB_MODES else None, + "breakthrough_price": round(trigger_price, 4) if mode == ADD_MODE_BREAKOUT else None, + "new_stop_loss": round(sl, 4), + "initial_take_profit": tp, + "risk_budget": round(budget, 2), + "fixed_amount": round(budget, 2), + "add_lots": q2, + "qty_after": int(new_qty), + "avg_entry_after": round(new_avg, 4), + "loss_at_sl": round(loss_at_sl, 2), + "reward_at_tp": round(reward_at_tp, 2), + "legs_done": legs_done, + "mark_price": round(mark, 4), + }, None diff --git a/strategy/strategy_roll_monitor_lib.py b/modules/strategy/strategy_roll_monitor_lib.py similarity index 96% rename from strategy/strategy_roll_monitor_lib.py rename to modules/strategy/strategy_roll_monitor_lib.py index 9c7ab32..193e46a 100644 --- a/strategy/strategy_roll_monitor_lib.py +++ b/modules/strategy/strategy_roll_monitor_lib.py @@ -1,158 +1,158 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""顺势滚仓程序监控:突破 pending 腿触价成交、外部平仓同步。""" -from __future__ import annotations - -import logging -from datetime import datetime -from typing import Any, Callable, Optional -from zoneinfo import ZoneInfo - -from contract_specs import get_contract_spec -from strategy.strategy_roll_lib import ( - ADD_MODE_BREAKOUT, - FIB_MODES, - LEG_STATUS_CANCELLED, - LEG_STATUS_FILLED, - LEG_STATUS_INVALIDATED, - LEG_STATUS_PENDING, - detect_mark_cross, - preview_roll, -) - -logger = logging.getLogger(__name__) -TZ = ZoneInfo("Asia/Shanghai") - - -def _now() -> str: - return datetime.now(TZ).strftime("%Y-%m-%d %H:%M:%S") - - -def roll_sync_after_external_close(conn, *, monitor_id: int) -> None: - """手动平仓或监控结案后关闭滚仓组并清除 pending 腿。""" - grp = conn.execute( - "SELECT id FROM roll_groups WHERE order_monitor_id=? AND status='active'", - (int(monitor_id),), - ).fetchone() - if not grp: - return - gid = int(grp["id"]) - conn.execute( - "UPDATE roll_legs SET status=? WHERE roll_group_id=? AND status=?", - (LEG_STATUS_CANCELLED, gid, LEG_STATUS_PENDING), - ) - conn.execute( - "UPDATE roll_groups SET status='closed', updated_at=? WHERE id=?", - (_now(), gid), - ) - - -def cancel_roll_leg(conn, leg_id: int) -> tuple[bool, str]: - row = conn.execute( - "SELECT * FROM roll_legs WHERE id=? AND status=?", - (int(leg_id), LEG_STATUS_PENDING), - ).fetchone() - if not row: - return False, "仅可删除监控中的腿" - conn.execute( - "UPDATE roll_legs SET status=? WHERE id=?", - (LEG_STATUS_CANCELLED, int(leg_id)), - ) - return True, "已删除" - - -def check_roll_monitors( - conn, - *, - get_mark_price_fn: Callable[[str], Optional[float]], - fill_roll_leg_fn: Callable[[dict, dict, dict, dict], tuple[bool, str]], - is_trading_session_fn: Callable[[], bool], - get_risk_budget_fn: Callable[[], float], - get_entry_price_fn: Optional[Callable[[str, str, float], float]] = None, -) -> None: - """扫描 pending 滚仓腿,标记价穿越则重算手数并市价成交。""" - if not is_trading_session_fn(): - return - rows = conn.execute( - """SELECT l.*, g.order_monitor_id, g.symbol, g.direction, g.initial_take_profit, - g.risk_percent, g.leg_count AS group_leg_count, - m.lots AS mon_lots, m.entry_price AS mon_entry, m.take_profit AS mon_tp, - m.status AS mon_status - FROM roll_legs l - JOIN roll_groups g ON g.id = l.roll_group_id - JOIN trade_order_monitors m ON m.id = g.order_monitor_id - WHERE l.status=? AND g.status='active' AND m.status='active'""", - (LEG_STATUS_PENDING,), - ).fetchall() - for raw in rows: - leg = dict(raw) - if (leg.get("mon_status") or "").strip().lower() != "active": - _invalidate_leg(conn, leg, "监控已结束") - continue - sym = (leg.get("symbol") or "").strip() - mark = get_mark_price_fn(sym) - if not mark or mark <= 0: - continue - prev_mark = float(leg.get("last_mark_price") or mark) - mode = (leg.get("add_mode") or "").strip().lower() - trigger = float(leg.get("limit_price") or leg.get("breakthrough_price") or 0) - direction = (leg.get("direction") or "long").strip().lower() - if mode in FIB_MODES or mode == ADD_MODE_BREAKOUT: - if not detect_mark_cross(direction, mode, prev_mark, mark, trigger): - conn.execute( - "UPDATE roll_legs SET last_mark_price=? WHERE id=?", - (float(mark), int(leg["id"])), - ) - continue - mon = { - "id": leg["order_monitor_id"], - "symbol": sym, - "direction": direction, - "lots": leg["mon_lots"], - "entry_price": leg["mon_entry"], - "take_profit": leg["mon_tp"] or leg["initial_take_profit"], - } - entry_fb = float(leg["mon_entry"] or 0) - entry_existing = ( - get_entry_price_fn(sym, direction, entry_fb) - if get_entry_price_fn - else entry_fb - ) - grp = { - "id": leg["roll_group_id"], - "order_monitor_id": leg["order_monitor_id"], - "leg_count": leg.get("group_leg_count") or 0, - "risk_percent": leg.get("risk_percent"), - } - preview, err = preview_roll( - direction=direction, - symbol=sym, - qty_existing=float(leg["mon_lots"] or 0), - entry_existing=entry_existing, - initial_take_profit=float(leg["mon_tp"] or leg["initial_take_profit"] or 0), - add_mode=mode, - new_stop_loss=float(leg["new_stop_loss"] or 0), - risk_budget=float(leg.get("risk_percent") or 0) or get_risk_budget_fn(), - mult=int(get_contract_spec(sym).get("mult") or 1), - mark_price=mark, - limit_price=trigger if mode in FIB_MODES else None, - breakthrough_price=trigger if mode == ADD_MODE_BREAKOUT else None, - legs_done=int(leg.get("group_leg_count") or 0), - at_trigger=True, - ) - if err or not preview: - _invalidate_leg(conn, leg, err or "触发时无法加仓") - continue - ok, msg = fill_roll_leg_fn(mon, grp, leg, preview) - if not ok: - logger.warning("roll leg fill failed #%s: %s", leg.get("id"), msg) - - -def _invalidate_leg(conn, leg: dict, reason: str) -> None: - conn.execute( - "UPDATE roll_legs SET status=?, invalidated_reason=? WHERE id=?", - (LEG_STATUS_INVALIDATED, (reason or "")[:200], int(leg["id"])), - ) +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""顺势滚仓程序监控:突破 pending 腿触价成交、外部平仓同步。""" +from __future__ import annotations + +import logging +from datetime import datetime +from typing import Any, Callable, Optional +from zoneinfo import ZoneInfo + +from modules.core.contract_specs import get_contract_spec +from strategy.strategy_roll_lib import ( + ADD_MODE_BREAKOUT, + FIB_MODES, + LEG_STATUS_CANCELLED, + LEG_STATUS_FILLED, + LEG_STATUS_INVALIDATED, + LEG_STATUS_PENDING, + detect_mark_cross, + preview_roll, +) + +logger = logging.getLogger(__name__) +TZ = ZoneInfo("Asia/Shanghai") + + +def _now() -> str: + return datetime.now(TZ).strftime("%Y-%m-%d %H:%M:%S") + + +def roll_sync_after_external_close(conn, *, monitor_id: int) -> None: + """手动平仓或监控结案后关闭滚仓组并清除 pending 腿。""" + grp = conn.execute( + "SELECT id FROM roll_groups WHERE order_monitor_id=? AND status='active'", + (int(monitor_id),), + ).fetchone() + if not grp: + return + gid = int(grp["id"]) + conn.execute( + "UPDATE roll_legs SET status=? WHERE roll_group_id=? AND status=?", + (LEG_STATUS_CANCELLED, gid, LEG_STATUS_PENDING), + ) + conn.execute( + "UPDATE roll_groups SET status='closed', updated_at=? WHERE id=?", + (_now(), gid), + ) + + +def cancel_roll_leg(conn, leg_id: int) -> tuple[bool, str]: + row = conn.execute( + "SELECT * FROM roll_legs WHERE id=? AND status=?", + (int(leg_id), LEG_STATUS_PENDING), + ).fetchone() + if not row: + return False, "仅可删除监控中的腿" + conn.execute( + "UPDATE roll_legs SET status=? WHERE id=?", + (LEG_STATUS_CANCELLED, int(leg_id)), + ) + return True, "已删除" + + +def check_roll_monitors( + conn, + *, + get_mark_price_fn: Callable[[str], Optional[float]], + fill_roll_leg_fn: Callable[[dict, dict, dict, dict], tuple[bool, str]], + is_trading_session_fn: Callable[[], bool], + get_risk_budget_fn: Callable[[], float], + get_entry_price_fn: Optional[Callable[[str, str, float], float]] = None, +) -> None: + """扫描 pending 滚仓腿,标记价穿越则重算手数并市价成交。""" + if not is_trading_session_fn(): + return + rows = conn.execute( + """SELECT l.*, g.order_monitor_id, g.symbol, g.direction, g.initial_take_profit, + g.risk_percent, g.leg_count AS group_leg_count, + m.lots AS mon_lots, m.entry_price AS mon_entry, m.take_profit AS mon_tp, + m.status AS mon_status + FROM roll_legs l + JOIN roll_groups g ON g.id = l.roll_group_id + JOIN trade_order_monitors m ON m.id = g.order_monitor_id + WHERE l.status=? AND g.status='active' AND m.status='active'""", + (LEG_STATUS_PENDING,), + ).fetchall() + for raw in rows: + leg = dict(raw) + if (leg.get("mon_status") or "").strip().lower() != "active": + _invalidate_leg(conn, leg, "监控已结束") + continue + sym = (leg.get("symbol") or "").strip() + mark = get_mark_price_fn(sym) + if not mark or mark <= 0: + continue + prev_mark = float(leg.get("last_mark_price") or mark) + mode = (leg.get("add_mode") or "").strip().lower() + trigger = float(leg.get("limit_price") or leg.get("breakthrough_price") or 0) + direction = (leg.get("direction") or "long").strip().lower() + if mode in FIB_MODES or mode == ADD_MODE_BREAKOUT: + if not detect_mark_cross(direction, mode, prev_mark, mark, trigger): + conn.execute( + "UPDATE roll_legs SET last_mark_price=? WHERE id=?", + (float(mark), int(leg["id"])), + ) + continue + mon = { + "id": leg["order_monitor_id"], + "symbol": sym, + "direction": direction, + "lots": leg["mon_lots"], + "entry_price": leg["mon_entry"], + "take_profit": leg["mon_tp"] or leg["initial_take_profit"], + } + entry_fb = float(leg["mon_entry"] or 0) + entry_existing = ( + get_entry_price_fn(sym, direction, entry_fb) + if get_entry_price_fn + else entry_fb + ) + grp = { + "id": leg["roll_group_id"], + "order_monitor_id": leg["order_monitor_id"], + "leg_count": leg.get("group_leg_count") or 0, + "risk_percent": leg.get("risk_percent"), + } + preview, err = preview_roll( + direction=direction, + symbol=sym, + qty_existing=float(leg["mon_lots"] or 0), + entry_existing=entry_existing, + initial_take_profit=float(leg["mon_tp"] or leg["initial_take_profit"] or 0), + add_mode=mode, + new_stop_loss=float(leg["new_stop_loss"] or 0), + risk_budget=float(leg.get("risk_percent") or 0) or get_risk_budget_fn(), + mult=int(get_contract_spec(sym).get("mult") or 1), + mark_price=mark, + limit_price=trigger if mode in FIB_MODES else None, + breakthrough_price=trigger if mode == ADD_MODE_BREAKOUT else None, + legs_done=int(leg.get("group_leg_count") or 0), + at_trigger=True, + ) + if err or not preview: + _invalidate_leg(conn, leg, err or "触发时无法加仓") + continue + ok, msg = fill_roll_leg_fn(mon, grp, leg, preview) + if not ok: + logger.warning("roll leg fill failed #%s: %s", leg.get("id"), msg) + + +def _invalidate_leg(conn, leg: dict, reason: str) -> None: + conn.execute( + "UPDATE roll_legs SET status=?, invalidated_reason=? WHERE id=?", + (LEG_STATUS_INVALIDATED, (reason or "")[:200], int(leg["id"])), + ) diff --git a/strategy/strategy_snapshot_lib.py b/modules/strategy/strategy_snapshot_lib.py similarity index 100% rename from strategy/strategy_snapshot_lib.py rename to modules/strategy/strategy_snapshot_lib.py diff --git a/strategy/strategy_trend_lib.py b/modules/strategy/strategy_trend_lib.py similarity index 95% rename from strategy/strategy_trend_lib.py rename to modules/strategy/strategy_trend_lib.py index e818362..59887ad 100644 --- a/strategy/strategy_trend_lib.py +++ b/modules/strategy/strategy_trend_lib.py @@ -1,233 +1,233 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""趋势回调:纯计算(期货整数手)。""" -from __future__ import annotations - -import json -import math -from typing import Any, Optional, Tuple - -from contract_specs import get_contract_spec - - -def validate_trend_bounds(direction: str, stop_loss: float, add_upper: float) -> Optional[str]: - direction = (direction or "long").strip().lower() - if direction == "long": - if not (float(stop_loss) < float(add_upper)): - return "做多:止损须低于补仓上沿" - else: - if not (float(stop_loss) > float(add_upper)): - return "做空:止损须高于补仓下沿" - return None - - -def build_grid_prices(direction: str, sl: float, upper: float, n_legs: int) -> list[float]: - sl, upper = float(sl), float(upper) - out: list[float] = [] - if n_legs <= 0: - return out - direction = (direction or "long").strip().lower() - if direction == "long": - if upper <= sl: - return out - span = upper - sl - for i in range(1, n_legs + 1): - out.append(sl + (i / float(n_legs + 1)) * span) - out.sort(reverse=True) - else: - if sl <= upper: - return out - span = sl - upper - for i in range(1, n_legs + 1): - out.append(upper + (i / float(n_legs + 1)) * span) - out.sort() - return [round(p, 4) for p in out] - - -def compute_trend_plan_futures( - *, - direction: str, - stop_loss: float, - add_upper: float, - take_profit: float, - risk_percent: float, - capital: float, - live_price: float, - ths_code: str, - dca_legs: int = 5, -) -> Tuple[Optional[dict[str, Any]], Optional[str]]: - err = validate_trend_bounds(direction, stop_loss, add_upper) - if err: - return None, err - spec = get_contract_spec(ths_code) - mult = spec["mult"] - d = (direction or "long").strip().lower() - if d == "short": - worst_per_lot = (float(stop_loss) - float(add_upper)) * mult - else: - worst_per_lot = (float(add_upper) - float(stop_loss)) * mult - if worst_per_lot <= 0: - return None, "止损与补仓边界无法计算风险" - budget = float(capital) * float(risk_percent) / 100.0 - total_lots = int(math.floor(budget / worst_per_lot)) - if total_lots < 3: - return None, f"按 {risk_percent}% 风险,总手数至少需 3 手才能拆分首仓+补仓(当前 {total_lots} 手)" - first_lots = total_lots // 2 - remainder = total_lots - first_lots - legs = max(1, min(int(dca_legs), remainder)) - per_leg = remainder // legs - leg_amounts = [per_leg] * (legs - 1) + [remainder - per_leg * (legs - 1)] - if any(x < 1 for x in leg_amounts): - legs = 1 - leg_amounts = [remainder] - grid = build_grid_prices(d, stop_loss, add_upper, len(leg_amounts)) - margin_rate = spec["margin_rate"] - plan_margin = float(live_price) * mult * total_lots * margin_rate - return { - "direction": d, - "stop_loss": float(stop_loss), - "add_upper": float(add_upper), - "take_profit": float(take_profit), - "risk_percent": float(risk_percent), - "capital_snapshot": float(capital), - "live_price_ref": float(live_price), - "target_lots": total_lots, - "first_lots": first_lots, - "remainder_lots": remainder, - "dca_legs": len(leg_amounts), - "leg_amounts": leg_amounts, - "leg_amounts_json": json.dumps(leg_amounts), - "grid_prices_json": json.dumps(grid), - "grid": grid, - "plan_margin": round(plan_margin, 2), - "mult": mult, - }, None - - -def trend_dca_level_reached(direction: str, mark_price: float, level: float) -> bool: - d = (direction or "long").strip().lower() - pf, lv = float(mark_price), float(level) - return pf <= lv if d == "long" else pf >= lv - - -def trend_strategy_periods() -> list[dict[str, str]]: - """策略页可选 K 线周期。""" - from kline_chart import MARKET_PERIODS - - skip = frozenset({"timeshare", "w"}) - return [p for p in MARKET_PERIODS if p["key"] not in skip] - - -def trend_period_label(key: str) -> str: - k = (key or "").strip() - for p in trend_strategy_periods(): - if p["key"] == k: - return p["label"] - return k or "15分" - - -def normalize_trend_period(key: str) -> str: - valid = {p["key"] for p in trend_strategy_periods()} - k = (key or "15m").strip() - return k if k in valid else "15m" - - -def _avg_after_entries(entries: list[tuple[float, int]]) -> float: - total = sum(q for _, q in entries) - if total <= 0: - return 0.0 - return sum(p * q for p, q in entries) / total - - -def enrich_trend_plan_preview( - plan: dict, - *, - symbol: str, - symbol_name: str = "", - period: str = "15m", -) -> dict[str, Any]: - """补全预览:周期、风险金额、分档表格(对齐币圈预览样式)。""" - out = dict(plan) - d = (out.get("direction") or "long").strip().lower() - sl = float(out["stop_loss"]) - tp = float(out["take_profit"]) - mult = float(out.get("mult") or 1) - entry0 = float(out["live_price_ref"]) - first_lots = int(out["first_lots"]) - leg_amounts = [int(x) for x in (out.get("leg_amounts") or [])] - grid = [float(x) for x in (out.get("grid") or [])] - capital = float(out.get("capital_snapshot") or 0) - risk_pct = float(out.get("risk_percent") or 0) - budget = capital * risk_pct / 100.0 - remainder = int(out.get("remainder_lots") or sum(leg_amounts)) - - out["symbol"] = symbol - out["symbol_name"] = symbol_name or symbol - out["period"] = normalize_trend_period(period) - out["period_label"] = trend_period_label(out["period"]) - out["stop_loss_budget"] = round(budget, 2) - out["direction_label"] = "做多" if d == "long" else "做空" - - entries: list[tuple[float, int]] = [(entry0, first_lots)] - rows: list[dict[str, Any]] = [] - - def leg_metrics() -> tuple[float, float, float, Optional[float]]: - total = sum(q for _, q in entries) - avg = _avg_after_entries(entries) - if d == "long": - profit = (tp - avg) * total * mult - loss = (avg - sl) * total * mult - else: - profit = (avg - tp) * total * mult - loss = (sl - avg) * total * mult - rr = profit / loss if loss > 0 else None - return ( - round(avg, 4), - round(profit, 2), - round(loss, 2), - round(rr, 2) if rr is not None else None, - ) - - avg, profit, loss, rr = leg_metrics() - rows.append({ - "level": "首仓", - "price": round(entry0, 4), - "lots": first_lots, - "avg_after": avg, - "profit_at_tp": profit, - "loss_at_sl": loss, - "rr_ratio": rr, - }) - out["first_rr_ratio"] = rr - - for i, lots in enumerate(leg_amounts): - price = grid[i] if i < len(grid) else sl - entries.append((float(price), int(lots))) - avg, profit, loss, rr = leg_metrics() - rows.append({ - "level": f"补仓{i + 1}", - "price": round(float(price), 4), - "lots": int(lots), - "avg_after": avg, - "profit_at_tp": profit, - "loss_at_sl": loss, - "rr_ratio": rr, - }) - - out["preview_rows"] = rows - out["summary_line"] = ( - f"{out['symbol_name']} {out['symbol']} {out['direction_label']} {out['period_label']}" - f" | 权益 {capital:.2f} 元" - f" | 参考价 {entry0}" - f" | 计划保证金 ≈ {out.get('plan_margin')} 元" - f" | 总手 {out.get('target_lots')}(首仓 {first_lots} + 补仓 {remainder})" - ) - out["detail_line"] = ( - f"止损价 {sl} | 止损金额 {out['stop_loss_budget']} 元(权益 × 风险 {risk_pct}%)" - f" | 补仓边界 {float(out['add_upper'])} | 止盈价 {tp}" - f" | 首仓盈亏比 {out['first_rr_ratio'] if out['first_rr_ratio'] is not None else '—'}" - ) - return out +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""趋势回调:纯计算(期货整数手)。""" +from __future__ import annotations + +import json +import math +from typing import Any, Optional, Tuple + +from modules.core.contract_specs import get_contract_spec + + +def validate_trend_bounds(direction: str, stop_loss: float, add_upper: float) -> Optional[str]: + direction = (direction or "long").strip().lower() + if direction == "long": + if not (float(stop_loss) < float(add_upper)): + return "做多:止损须低于补仓上沿" + else: + if not (float(stop_loss) > float(add_upper)): + return "做空:止损须高于补仓下沿" + return None + + +def build_grid_prices(direction: str, sl: float, upper: float, n_legs: int) -> list[float]: + sl, upper = float(sl), float(upper) + out: list[float] = [] + if n_legs <= 0: + return out + direction = (direction or "long").strip().lower() + if direction == "long": + if upper <= sl: + return out + span = upper - sl + for i in range(1, n_legs + 1): + out.append(sl + (i / float(n_legs + 1)) * span) + out.sort(reverse=True) + else: + if sl <= upper: + return out + span = sl - upper + for i in range(1, n_legs + 1): + out.append(upper + (i / float(n_legs + 1)) * span) + out.sort() + return [round(p, 4) for p in out] + + +def compute_trend_plan_futures( + *, + direction: str, + stop_loss: float, + add_upper: float, + take_profit: float, + risk_percent: float, + capital: float, + live_price: float, + ths_code: str, + dca_legs: int = 5, +) -> Tuple[Optional[dict[str, Any]], Optional[str]]: + err = validate_trend_bounds(direction, stop_loss, add_upper) + if err: + return None, err + spec = get_contract_spec(ths_code) + mult = spec["mult"] + d = (direction or "long").strip().lower() + if d == "short": + worst_per_lot = (float(stop_loss) - float(add_upper)) * mult + else: + worst_per_lot = (float(add_upper) - float(stop_loss)) * mult + if worst_per_lot <= 0: + return None, "止损与补仓边界无法计算风险" + budget = float(capital) * float(risk_percent) / 100.0 + total_lots = int(math.floor(budget / worst_per_lot)) + if total_lots < 3: + return None, f"按 {risk_percent}% 风险,总手数至少需 3 手才能拆分首仓+补仓(当前 {total_lots} 手)" + first_lots = total_lots // 2 + remainder = total_lots - first_lots + legs = max(1, min(int(dca_legs), remainder)) + per_leg = remainder // legs + leg_amounts = [per_leg] * (legs - 1) + [remainder - per_leg * (legs - 1)] + if any(x < 1 for x in leg_amounts): + legs = 1 + leg_amounts = [remainder] + grid = build_grid_prices(d, stop_loss, add_upper, len(leg_amounts)) + margin_rate = spec["margin_rate"] + plan_margin = float(live_price) * mult * total_lots * margin_rate + return { + "direction": d, + "stop_loss": float(stop_loss), + "add_upper": float(add_upper), + "take_profit": float(take_profit), + "risk_percent": float(risk_percent), + "capital_snapshot": float(capital), + "live_price_ref": float(live_price), + "target_lots": total_lots, + "first_lots": first_lots, + "remainder_lots": remainder, + "dca_legs": len(leg_amounts), + "leg_amounts": leg_amounts, + "leg_amounts_json": json.dumps(leg_amounts), + "grid_prices_json": json.dumps(grid), + "grid": grid, + "plan_margin": round(plan_margin, 2), + "mult": mult, + }, None + + +def trend_dca_level_reached(direction: str, mark_price: float, level: float) -> bool: + d = (direction or "long").strip().lower() + pf, lv = float(mark_price), float(level) + return pf <= lv if d == "long" else pf >= lv + + +def trend_strategy_periods() -> list[dict[str, str]]: + """策略页可选 K 线周期。""" + from modules.market.kline_chart import MARKET_PERIODS + + skip = frozenset({"timeshare", "w"}) + return [p for p in MARKET_PERIODS if p["key"] not in skip] + + +def trend_period_label(key: str) -> str: + k = (key or "").strip() + for p in trend_strategy_periods(): + if p["key"] == k: + return p["label"] + return k or "15分" + + +def normalize_trend_period(key: str) -> str: + valid = {p["key"] for p in trend_strategy_periods()} + k = (key or "15m").strip() + return k if k in valid else "15m" + + +def _avg_after_entries(entries: list[tuple[float, int]]) -> float: + total = sum(q for _, q in entries) + if total <= 0: + return 0.0 + return sum(p * q for p, q in entries) / total + + +def enrich_trend_plan_preview( + plan: dict, + *, + symbol: str, + symbol_name: str = "", + period: str = "15m", +) -> dict[str, Any]: + """补全预览:周期、风险金额、分档表格(对齐币圈预览样式)。""" + out = dict(plan) + d = (out.get("direction") or "long").strip().lower() + sl = float(out["stop_loss"]) + tp = float(out["take_profit"]) + mult = float(out.get("mult") or 1) + entry0 = float(out["live_price_ref"]) + first_lots = int(out["first_lots"]) + leg_amounts = [int(x) for x in (out.get("leg_amounts") or [])] + grid = [float(x) for x in (out.get("grid") or [])] + capital = float(out.get("capital_snapshot") or 0) + risk_pct = float(out.get("risk_percent") or 0) + budget = capital * risk_pct / 100.0 + remainder = int(out.get("remainder_lots") or sum(leg_amounts)) + + out["symbol"] = symbol + out["symbol_name"] = symbol_name or symbol + out["period"] = normalize_trend_period(period) + out["period_label"] = trend_period_label(out["period"]) + out["stop_loss_budget"] = round(budget, 2) + out["direction_label"] = "做多" if d == "long" else "做空" + + entries: list[tuple[float, int]] = [(entry0, first_lots)] + rows: list[dict[str, Any]] = [] + + def leg_metrics() -> tuple[float, float, float, Optional[float]]: + total = sum(q for _, q in entries) + avg = _avg_after_entries(entries) + if d == "long": + profit = (tp - avg) * total * mult + loss = (avg - sl) * total * mult + else: + profit = (avg - tp) * total * mult + loss = (sl - avg) * total * mult + rr = profit / loss if loss > 0 else None + return ( + round(avg, 4), + round(profit, 2), + round(loss, 2), + round(rr, 2) if rr is not None else None, + ) + + avg, profit, loss, rr = leg_metrics() + rows.append({ + "level": "首仓", + "price": round(entry0, 4), + "lots": first_lots, + "avg_after": avg, + "profit_at_tp": profit, + "loss_at_sl": loss, + "rr_ratio": rr, + }) + out["first_rr_ratio"] = rr + + for i, lots in enumerate(leg_amounts): + price = grid[i] if i < len(grid) else sl + entries.append((float(price), int(lots))) + avg, profit, loss, rr = leg_metrics() + rows.append({ + "level": f"补仓{i + 1}", + "price": round(float(price), 4), + "lots": int(lots), + "avg_after": avg, + "profit_at_tp": profit, + "loss_at_sl": loss, + "rr_ratio": rr, + }) + + out["preview_rows"] = rows + out["summary_line"] = ( + f"{out['symbol_name']} {out['symbol']} {out['direction_label']} {out['period_label']}" + f" | 权益 {capital:.2f} 元" + f" | 参考价 {entry0}" + f" | 计划保证金 ≈ {out.get('plan_margin')} 元" + f" | 总手 {out.get('target_lots')}(首仓 {first_lots} + 补仓 {remainder})" + ) + out["detail_line"] = ( + f"止损价 {sl} | 止损金额 {out['stop_loss_budget']} 元(权益 × 风险 {risk_pct}%)" + f" | 补仓边界 {float(out['add_upper'])} | 止盈价 {tp}" + f" | 首仓盈亏比 {out['first_rr_ratio'] if out['first_rr_ratio'] is not None else '—'}" + ) + return out diff --git a/modules/trading/__init__.py b/modules/trading/__init__.py new file mode 100644 index 0000000..dfd557d --- /dev/null +++ b/modules/trading/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2025-2026 马建军. All rights reserved. + + +def register(deps) -> None: + from modules.trading.install import install_trading + + install_trading( + deps.app, + login_required=deps.login_required, + require_nav=deps.require_nav, + get_db=deps.get_db, + get_setting=deps.get_setting, + set_setting=deps.set_setting, + fetch_price=deps.fetch_price, + send_wechat_msg=deps.send_wechat_msg, + ) + + +__all__ = ["register"] diff --git a/modules/trading/install.py b/modules/trading/install.py new file mode 100644 index 0000000..82a46fe --- /dev/null +++ b/modules/trading/install.py @@ -0,0 +1,4685 @@ +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""期货下单、可开仓品种、策略交易路由注册。""" +from __future__ import annotations + +import json +import logging +import os +import threading +import time +from datetime import datetime +from typing import Any, Callable, Optional + +from flask import flash, jsonify, redirect, render_template, request, url_for, Response, stream_with_context + +from modules.core.contract_specs import calc_position_metrics, get_contract_spec +from modules.fees.fee_specs import calc_fee_breakdown +from modules.market.kline_stream import sse_format +from modules.market.market_sessions import is_night_trading_session, is_trading_session, trading_session_clock +from modules.trading.position_sizing import ( + MODE_AMOUNT, + MODE_FIXED, + DEFAULT_MAX_ORDER_LOTS, + calc_lots_by_amount, + calc_lots_by_risk, + calc_margin_usage_pct, + cap_lots_for_margin_budget, + calc_order_tick_metrics, + normalize_sizing_mode, +) +from modules.trading.product_recommend import ( + assert_product_allowed_for_capital, + should_apply_small_account_scope, + small_account_margin_recommendations, + small_account_scope_hint, + SMALL_ACCOUNT_SCOPE_LABEL, +) +from modules.trading.recommend_store import ( + recommend_payload, + refresh_recommend_cache, +) +from modules.trading.recommend_stream import recommend_hub, schedule_recommend_refresh, start_recommend_worker +from modules.trading.position_stream import position_hub, start_position_worker +from modules.ctp.ctp_settings import is_ctp_auto_connect_enabled +from modules.ctp.ctp_reconnect import start_ctp_reconnect_worker +from modules.ctp.ctp_premarket_connect import start_ctp_premarket_connect_worker +from modules.ctp.ctp_fee_worker import start_ctp_fee_worker +from modules.trading.pending_order_worker import start_pending_order_worker +from modules.trading.order_pending import ( + cancel_pending_monitor, + pending_auto_cancel_remaining, + pending_monitor_has_live_order, + reconcile_pending_orders, +) +from modules.core.db_conn import commit_retry, execute_retry +from modules.trading.sl_tp_guard import ( + cancel_monitor_exit_orders, + ensure_monitor_order_columns, + monitor_order_status, + monitor_source_label, + place_monitor_exit_orders, + reconcile_monitors_without_position, + start_sl_tp_guard_worker, + write_manual_close_trade_log, +) +from risk.account_risk_lib import ( + assert_can_open, + count_active_trade_monitors, + get_risk_status, + on_mood_journal_freeze, + on_user_initiated_close, + parse_mood_issues, + trading_day_label, +) +from strategy.strategy_db import init_strategy_tables +from strategy.strategy_roll_lib import ( + ADD_MODE_BREAKOUT, + ADD_MODE_MARKET, + FIB_MODES, + LEG_STATUS_CANCELLED, + LEG_STATUS_FILLED, + LEG_STATUS_PENDING, + PENDING_MODES, + add_mode_label, + avg_entry_after_add, + preview_roll, + roll_eligibility_error, +) +from strategy.strategy_roll_monitor_lib import ( + cancel_roll_leg, + check_roll_monitors, + roll_sync_after_external_close, +) +from strategy.strategy_snapshot_lib import list_snapshots, save_snapshot +from strategy.strategy_trend_lib import ( + compute_trend_plan_futures, + enrich_trend_plan_preview, + normalize_trend_period, + trend_dca_level_reached, + trend_period_label, + trend_strategy_periods, +) +from strategy.strategy_snapshot_lib import STRATEGY_ROLL, STRATEGY_TREND +from modules.core.symbols import ths_to_codes, resolve_main_contract, PRODUCTS, PRODUCT_CATEGORIES, position_symbol_meta +from modules.core.trading_context import ( + TRADING_MODE_LIVE, + TRADING_MODE_SIM, + get_account_capital, + get_fixed_amount, + get_fixed_lots, + get_max_margin_pct, + get_pending_order_timeout_min, + get_pending_order_timeout_sec, + get_recommend_capital, + get_roll_max_margin_pct, + get_risk_percent, + get_sizing_mode, + get_trailing_be_tick_buffer, + get_trading_mode, + is_ctp_connected, + trading_mode_label, +) +from modules.ctp.ctp_entry_price import round_to_tick +from modules.ctp.ctp_symbol import ths_to_vnpy_symbol +from modules.ctp.ctp_trading_state import position_key, trading_state +from modules.ctp.vnpy_bridge import ( + _ctp_td_lock, + ctp_cancel_order, + ctp_connect, + ctp_account_margin_used, + ctp_estimate_margin_one_lot, + ctp_get_account, + ctp_get_tick_price, + ctp_list_active_orders, + ctp_list_positions, + ctp_list_trades, + ctp_status, + execute_order, + get_bridge, + set_position_refresh_callback, + set_tick_sl_tp_callback, + set_tick_quote_callback, + set_ctp_connected_callback, +) + + +logger = logging.getLogger(__name__) + + +def install_trading(app, *, login_required, require_nav, get_db, get_setting, set_setting, fetch_price, send_wechat_msg): + """注册交易相关路由。""" + _nav = require_nav + _live_refresh_lock = threading.Lock() + _ctp_status_cache: dict = {"mode": "", "status": {}, "ts": 0.0} + _ctp_status_cache_lock = threading.Lock() + _ctp_status_refresh_flag = {"busy": False} + + def _remember_ctp_status(mode: str, st: dict) -> None: + if not isinstance(st, dict) or not st: + return + with _ctp_status_cache_lock: + _ctp_status_cache["mode"] = mode + _ctp_status_cache["status"] = dict(st) + _ctp_status_cache["ts"] = time.time() + + def _schedule_ctp_status_refresh(mode: str) -> None: + with _ctp_status_cache_lock: + if _ctp_status_refresh_flag["busy"]: + return + _ctp_status_refresh_flag["busy"] = True + + def _run() -> None: + try: + st = dict(ctp_status(mode) or {}) + _remember_ctp_status(mode, st) + snap = position_hub.get_snapshot() + if snap: + merged = dict(snap) + merged["ctp_status"] = st + position_hub.set_snapshot(merged) + except Exception as exc: + logger.debug("ctp status refresh: %s", exc) + finally: + with _ctp_status_cache_lock: + _ctp_status_refresh_flag["busy"] = False + + threading.Thread( + target=_run, + daemon=True, + name="ctp-status-refresh", + ).start() + + def _cached_ctp_status(mode: str) -> dict: + """页面/SSE 优先读快照与内存缓存,避免同步 worker IPC 阻塞 HTTP 线程。""" + try: + snap = position_hub.get_snapshot() or {} + st = snap.get("ctp_status") + if isinstance(st, dict) and st: + _remember_ctp_status(mode, st) + return dict(st) + except Exception: + pass + with _ctp_status_cache_lock: + if _ctp_status_cache["mode"] == mode and _ctp_status_cache["status"]: + return dict(_ctp_status_cache["status"]) + _schedule_ctp_status_refresh(mode) + return { + "connected": False, + "connecting": True, + "last_error": "", + "mode_label": trading_mode_label(get_setting), + } + + def _sizing_mode_label(mode: str) -> str: + m = normalize_sizing_mode(mode) + if m == MODE_AMOUNT: + return "固定金额" + return "固定手数" + + def _symbol_display_fields(sym: str) -> dict: + meta = position_symbol_meta(sym) + name = meta.get("name") or sym + return { + "symbol": name, + "symbol_name": name, + "symbol_exchange": meta.get("exchange") or "", + "symbol_is_main": bool(meta.get("is_main")), + } + + def _breakeven_locked( + *, + entry: Optional[float], + stop_loss: Optional[float], + direction: str, + tick_size: Optional[float] = None, + trailing_r_locked: int = 0, + ) -> bool: + if int(trailing_r_locked or 0) >= 1: + return True + if entry is None or stop_loss is None: + return False + try: + entry_f = float(entry) + sl_f = float(stop_loss) + except (TypeError, ValueError): + return False + if entry_f <= 0: + return False + tick = float(tick_size or 0) or max(abs(entry_f) * 1e-6, 0.01) + be_mult = max(1, get_trailing_be_tick_buffer(get_setting)) + d = (direction or "long").strip().lower() + expected_be = entry_f + be_mult * tick if d == "long" else entry_f - be_mult * tick + tol = be_mult * tick + tick * 0.05 + if abs(sl_f - expected_be) <= tol: + return True + buf = tick * max(2, be_mult) + near = abs(sl_f - entry_f) <= buf + tick + if d == "long": + return near and sl_f >= entry_f - tick * 0.05 + return near and sl_f <= entry_f + tick * 0.05 + + def _schedule_recommend_refresh() -> None: + from modules.core.db_conn import DB_PATH + + schedule_recommend_refresh( + db_path=DB_PATH, + get_capital_fn=_recommend_capital, + quote_fn=_main_quote, + init_tables_fn=lambda c: init_strategy_tables(c), + get_mode_fn=lambda: get_trading_mode(get_setting), + get_max_margin_pct_fn=lambda: get_max_margin_pct(get_setting), + get_sizing_mode_fn=lambda: get_sizing_mode(get_setting), + get_fixed_lots_fn=lambda: get_fixed_lots(get_setting), + ) + + def _recommend_payload(conn, *, use_ctp_margin: bool = True) -> dict: + mode = get_trading_mode(get_setting) + return recommend_payload( + conn, + live_capital=_recommend_capital(conn), + max_margin_pct=get_max_margin_pct(get_setting), + trading_mode=mode, + sizing_mode=get_sizing_mode(get_setting), + fixed_lots=get_fixed_lots(get_setting), + use_ctp_margin=use_ctp_margin, + ) + + def _recommend_capital(conn) -> float: + return get_recommend_capital(conn, get_setting) + + def _settings_dict() -> dict: + return { + "trading_mode": get_trading_mode(get_setting), + "position_sizing_mode": get_sizing_mode(get_setting), + "risk_percent": str(get_risk_percent(get_setting)), + "max_margin_pct": str(get_max_margin_pct(get_setting)), + } + + def _capital(conn) -> float: + return get_account_capital(conn, get_setting) + + def _main_quote(product_ths: str) -> Optional[dict]: + for p in PRODUCTS: + if p["ths"] == product_ths: + main = resolve_main_contract(p) + if not main: + return None + sym = main.get("ths_code") or "" + codes = ths_to_codes(sym) + price = None + if codes: + price = fetch_price( + sym, + codes.get("market_code", ""), + codes.get("sina_code", ""), + ) + return { + "ths_code": sym, + "price": price, + "display": main.get("display") or sym, + "name": main.get("name") or p.get("name"), + } + return None + + def _ctp_account(mode: str) -> dict: + try: + return ctp_get_account(mode) + except Exception: + return {} + + def _ctp_positions( + mode: str, + *, + refresh_if_empty: bool = True, + refresh_margin: bool = False, + ) -> list: + try: + return ctp_list_positions( + mode, + refresh_if_empty=refresh_if_empty, + refresh_margin=refresh_margin, + ) + except Exception: + return [] + + def _ctp_pos_to_ths_code(p: dict) -> str: + sym = (p.get("symbol") or "").strip() + ex = (p.get("exchange") or "").strip() + if not sym: + return "" + codes = ths_to_codes(sym) + if codes: + return codes.get("ths_code") or sym + if ex: + from modules.ctp.vnpy_bridge import CtpBridge + ths = CtpBridge._vnpy_sym_to_ths(sym, ex) + if ths: + return ths + return sym + + def _resolve_position_margin( + *, + sym: str, + direction: str, + lots: int, + entry: float, + mode: str, + ctp: Optional[dict] = None, + mon_margin: Optional[float] = None, + est_margin: Optional[float] = None, + ) -> tuple[Optional[float], str]: + """占用保证金:柜台持仓 > CTP 合约率估算 > 本地规格估算 > 库内缓存。""" + ctp_margin = float(ctp.get("margin") or 0) if ctp else 0.0 + if ctp_margin > 0: + return round(ctp_margin, 2), "ctp" + connected = bool(ctp_status(mode).get("connected")) + ths_sym = sym + if ctp: + ths_sym = _ctp_pos_to_ths_code(ctp) or sym + else: + codes = ths_to_codes(sym) + if codes and codes.get("ths_code"): + ths_sym = codes["ths_code"] + if connected and ths_sym and entry > 0 and lots > 0: + per_lot = ctp_estimate_margin_one_lot( + mode, ths_sym, entry, direction=direction, + ) + if per_lot and per_lot > 0: + return round(per_lot * lots, 2), "ctp" + if est_margin and float(est_margin) > 0: + return round(float(est_margin), 2), "estimate" + if not connected and mon_margin and float(mon_margin) > 0: + return round(float(mon_margin), 2), "db" + return None, "estimate" + + def _apply_account_margin_to_rows( + rows: list[dict], + mode: str, + capital: float, + ) -> list[dict]: + """仅在持仓缺少柜台保证金时补全;已有 CTP 持仓保证金的行不覆盖。""" + if not ctp_status(mode).get("connected"): + return rows + active = [ + r for r in rows + if r.get("order_state") != "pending" and int(r.get("lots") or 0) > 0 + ] + if not active: + return rows + + def _has_ctp_margin(row: dict) -> bool: + return ( + float(row.get("margin") or 0) > 0 + and row.get("margin_source") == "ctp" + ) + + without_margin = [r for r in active if not _has_ctp_margin(r)] + for row in active: + if _has_ctp_margin(row) and capital > 0: + m = float(row.get("margin") or 0) + row["position_pct"] = round(m / capital * 100, 2) + if not without_margin: + return rows + + total_used = ctp_account_margin_used(mode) + if not total_used: + return rows + known_sum = sum( + float(r.get("margin") or 0) for r in active if _has_ctp_margin(r) + ) + pool = max(0.0, float(total_used) - known_sum) if known_sum > 0 else float(total_used) + if pool <= 0: + return rows + + weights: list[float] = [] + for row in without_margin: + sym = (row.get("symbol_code") or "").strip() + lots = int(row.get("lots") or 0) + entry = float(row.get("entry_price") or 0) + if sym and lots > 0 and entry > 0: + spec = get_contract_spec(sym) + weights.append(entry * spec["mult"] * lots) + else: + weights.append(0.0) + total_weight = sum(weights) + assigned = 0.0 + for i, row in enumerate(without_margin): + if total_weight <= 0: + margin = round(pool / len(without_margin), 2) + elif i == len(without_margin) - 1: + margin = round(pool - assigned, 2) + else: + margin = round(pool * weights[i] / total_weight, 2) + assigned += margin + row["margin"] = margin + row["margin_source"] = "ctp" + if capital > 0: + row["position_pct"] = round(margin / capital * 100, 2) + return rows + + def _persist_ctp_snapshot_to_monitors( + conn, + rows: list[dict], + mode: str, + ) -> None: + """将柜台校正后的均价、手数、现价、浮盈、保证金等写入 trade_order_monitors。""" + if not ctp_status(mode).get("connected"): + return + ensure_monitor_order_columns(conn) + for row in rows: + mid = row.get("monitor_id") + if not mid or row.get("order_state") == "pending": + continue + entry_price = row.get("entry_price") + lots = row.get("lots") + mark_price = row.get("mark_price") + if mark_price is None: + mark_price = row.get("current_price") + float_pnl = row.get("float_pnl") + margin = row.get("margin") + position_pct = row.get("position_pct") + open_fee = row.get("est_fee") + if ( + entry_price is None and lots is None and mark_price is None + and float_pnl is None and margin is None + and position_pct is None and open_fee is None + ): + continue + try: + execute_retry( + conn, + """UPDATE trade_order_monitors SET + entry_price=COALESCE(?, entry_price), + lots=COALESCE(?, lots), + mark_price=COALESCE(?, mark_price), + float_pnl=COALESCE(?, float_pnl), + margin=COALESCE(?, margin), + position_pct=COALESCE(?, position_pct), + open_fee=COALESCE(?, open_fee) + WHERE id=? AND status='active'""", + ( + entry_price, lots, mark_price, float_pnl, + margin, position_pct, open_fee, int(mid), + ), + ) + except Exception as exc: + logger.debug("persist monitor ctp snapshot %s: %s", mid, exc) + + def _positions_from_live_snapshot() -> list[dict]: + snap = position_hub.get_snapshot() or {} + out: list[dict] = [] + for row in snap.get("rows") or []: + lots = int(row.get("lots") or 0) + if lots <= 0 or row.get("order_state") == "pending": + continue + sym = ( + row.get("symbol_code") + or row.get("ths_code") + or row.get("symbol") + or "" + ) + if not sym: + continue + out.append({ + "symbol": sym, + "direction": row.get("direction") or "long", + "lots": lots, + "avg_price": row.get("entry_price") or row.get("avg_price") or 0, + "open_time": row.get("open_time") or "", + "margin": row.get("margin"), + "pnl": row.get("float_pnl"), + "mark_price": row.get("mark_price") or row.get("current_price"), + "exchange": row.get("exchange") or "", + }) + return out + + def _positions_for_monitor_restore(mode: str, *, allow_ctp: bool = True) -> list[dict]: + if allow_ctp: + positions = list(_ctp_positions(mode, refresh_if_empty=True) or []) + if positions: + return positions + positions = list(trading_state.get_positions() or []) + if positions: + return positions + positions = _positions_from_live_snapshot() + if not allow_ctp: + return positions + margin_used = float(ctp_account_margin_used(mode) or 0) + if margin_used <= 100 or not positions: + return [] + return positions + + def _cached_position_mark(sym: str, direction: str = "") -> Optional[float]: + sym_l = (sym or "").strip().lower() + direction_l = (direction or "").strip().lower() + for p in list(trading_state.get_positions() or []) + _positions_from_live_snapshot(): + if direction_l and (p.get("direction") or "long").strip().lower() != direction_l: + continue + ps = (p.get("symbol") or "").strip() + if not ps: + continue + if not _match_ctp_symbol(ps, sym_l): + continue + for key in ("mark_price", "current_price", "last_price"): + val = p.get(key) + try: + px = float(val or 0) + except (TypeError, ValueError): + px = 0.0 + if px > 0: + return px + snap = position_hub.get_snapshot() or {} + for row in snap.get("rows") or []: + rs = row.get("symbol_code") or row.get("symbol") or "" + if not rs or not _match_ctp_symbol(rs, sym_l): + continue + if direction_l and (row.get("direction") or "long").strip().lower() != direction_l: + continue + for key in ("mark_price", "current_price", "last_price", "entry_price"): + try: + px = float(row.get(key) or 0) + except (TypeError, ValueError): + px = 0.0 + if px > 0: + return px + return None + + def _ensure_monitors_from_ctp(conn, mode: str, *, allow_ctp: bool = True) -> None: + """CTP 有持仓但本地无监控时,自动补写一条 active 记录供展示。""" + if not ctp_status(mode).get("connected"): + return + ctp_positions = _positions_for_monitor_restore(mode, allow_ctp=allow_ctp) + for p in ctp_positions: + lots = int(p.get("lots") or 0) + if lots <= 0: + continue + direction = p.get("direction") or "long" + ths = _ctp_pos_to_ths_code(p) + if not ths: + continue + existing = _find_or_revive_monitor(conn, ths, direction) + if existing: + _sync_monitor_from_ctp( + conn, int(existing["id"]), ths, direction, mode, ctp=p, + capital=_capital(conn), + ) + continue + sl, tp, trailing_be, initial_sl = _restore_sl_tp_from_closed(conn, ths, direction) + ctp_open = (p.get("open_time") or "").strip() + mid = _upsert_open_monitor( + conn, + sym=ths, + direction=direction, + lots=lots, + price=float(p.get("avg_price") or 0), + sl=sl, + tp=tp, + trailing_be=trailing_be, + ctp_open_time=ctp_open or None, + monitor_type="ctp_sync", + ) + if initial_sl is not None and sl is not None: + conn.execute( + "UPDATE trade_order_monitors SET initial_stop_loss=? WHERE id=?", + (initial_sl, mid), + ) + if ctp_positions: + return + _ensure_monitors_from_sticky_state(conn, mode) + + def _ensure_monitors_from_sticky_state(conn, mode: str) -> None: + """vnpy 持仓空窗但账户仍有保证金时,恢复本地 active 监控。""" + if not ctp_status(mode).get("connected"): + return + margin_raw = ctp_account_margin_used(mode) + if margin_raw is None or float(margin_raw or 0) <= 0: + return + if count_active_trade_monitors(conn) > 0: + return + capital = _capital(conn) + for p in trading_state.get_positions() or []: + lots = int(p.get("lots") or 0) + if lots <= 0: + continue + direction = p.get("direction") or "long" + ths = _ctp_pos_to_ths_code(p) or (p.get("symbol") or "") + if not ths: + continue + existing = _find_or_revive_monitor(conn, ths, direction) + if existing: + _sync_monitor_from_ctp( + conn, int(existing["id"]), ths, direction, mode, ctp=p, + capital=capital, + ) + continue + sl, tp, trailing_be, initial_sl = _restore_sl_tp_from_closed(conn, ths, direction) + mid = _upsert_open_monitor( + conn, + sym=ths, + direction=direction, + lots=lots, + price=float(p.get("avg_price") or 0), + sl=sl, + tp=tp, + trailing_be=trailing_be, + ctp_open_time=(p.get("open_time") or "").strip() or None, + monitor_type="ctp_sync", + ) + if initial_sl is not None and sl is not None: + conn.execute( + "UPDATE trade_order_monitors SET initial_stop_loss=? WHERE id=?", + (initial_sl, mid), + ) + if count_active_trade_monitors(conn) > 0: + return + today = datetime.now().strftime("%Y-%m-%d") + for r in conn.execute( + "SELECT * FROM trade_order_monitors WHERE status='closed' " + "AND open_time LIKE ? ORDER BY id DESC LIMIT 5", + (f"{today}%",), + ).fetchall(): + mon = dict(r) + if int(mon.get("lots") or 0) <= 0: + continue + revived = _revive_closed_monitor( + conn, mon.get("symbol") or "", mon.get("direction") or "long", + ) + if revived: + logger.info( + "保证金占用下恢复监控 id=%s sym=%s", + revived.get("id"), revived.get("symbol"), + ) + break + + def _restore_recent_pending_monitors(conn, mode: str) -> None: + """重启或 vnpy 委托缓存丢失时,恢复当日最近一笔可能仍有效的开仓挂单。""" + if not ctp_status(mode).get("connected"): + return + if conn.execute("SELECT 1 FROM trade_order_monitors WHERE status='pending' LIMIT 1").fetchone(): + return + today = datetime.now().strftime("%Y-%m-%d") + row = conn.execute( + """SELECT * FROM trade_order_monitors + WHERE status='closed' AND monitor_type='manual' + AND vt_order_id IS NOT NULL AND vt_order_id != '' + AND open_time LIKE ? + ORDER BY id DESC LIMIT 1""", + (f"{today}%",), + ).fetchone() + if not row: + return + mon = dict(row) + sym = mon.get("symbol") or "" + direction = (mon.get("direction") or "long").strip().lower() + if _find_active_monitor(conn, sym, direction): + return + for p in _ctp_positions(mode, refresh_if_empty=False): + if int(p.get("lots") or 0) <= 0: + continue + if (p.get("direction") or "long") != direction: + continue + if _match_ctp_symbol(p.get("symbol") or "", sym): + return + conn.execute( + "UPDATE trade_order_monitors SET status='pending' WHERE id=?", + (mon["id"],), + ) + logger.info("恢复挂单监控 id=%s sym=%s", mon.get("id"), sym) + + def _match_ctp_symbol(ctp_sym: str, ths: str) -> bool: + a = (ctp_sym or "").lower() + b = (ths or "").lower() + if a == b: + return True + if a and b and a.split(".")[0] == b.split(".")[0]: + return True + try: + vnpy_sym, _ = ths_to_vnpy_symbol(ths) + if a == vnpy_sym.lower(): + return True + except Exception: + pass + try: + vnpy_sym, _ = ths_to_vnpy_symbol(ctp_sym) + if vnpy_sym.lower() == b.split(".")[0]: + return True + except Exception: + pass + return False + + def _live_entry_price( + sym: str, + direction: str, + mode: str, + fallback: float = 0.0, + *, + allow_ctp: bool = False, + ) -> float: + """滚仓/展示用均价:仅柜台持仓价。""" + if not ctp_status(mode).get("connected"): + return fallback + positions = list(trading_state.get_positions() or []) + if not positions: + positions = _positions_from_live_snapshot() + if not positions and allow_ctp: + positions = _ctp_positions(mode, refresh_if_empty=False) + for p in positions: + if (p.get("direction") or "long") != (direction or "long"): + continue + if not _match_ctp_symbol(p.get("symbol") or "", sym): + continue + avg = float(p.get("avg_price") or 0) + if avg > 0: + return avg + return fallback + + def _resolve_ctp_entry_price( + mode: str, + sym: str, + direction: str, + ctp: Optional[dict], + ) -> tuple[float, str]: + del mode, direction + if not ctp: + return 0.0, "none" + avg = float(ctp.get("avg_price") or 0) + if avg > 0: + return round_to_tick(avg, sym), "ctp" + return 0.0, "none" + + def _open_commission_from_ctp_trades( + mode: str, sym: str, direction: str, + ) -> Optional[float]: + """汇总该持仓开仓成交的柜台手续费(成交回报中的 commission)。""" + if not ctp_status(mode).get("connected"): + return None + try: + trades = ctp_list_trades(mode) + except Exception: + return None + total = 0.0 + has_commission = False + for t in trades: + if (t.get("offset") or "").strip().lower() != "open": + continue + pos_dir = ( + t.get("position_direction") or t.get("direction") or "long" + ).strip().lower() + if pos_dir != (direction or "long").strip().lower(): + continue + if not _match_ctp_symbol(t.get("symbol") or "", sym): + continue + comm = float(t.get("commission") or 0) + total += comm + if comm > 0: + has_commission = True + return round(total, 2) if has_commission else None + + def _time_str(val) -> str: + if val is None: + return "" + if isinstance(val, str): + return val.strip() + return str(val).strip() + + def _holding_duration(open_time: str, now_iso: str) -> str: + try: + from app import calc_holding_duration + open_s = _time_str(open_time).replace("T", " ")[:19] + now_s = (now_iso or "").strip().replace("T", " ")[:19] + if not open_s or not now_s: + return "" + return calc_holding_duration(open_s, now_s) + except Exception: + return "" + + def _restore_sl_tp_from_closed(conn, sym: str, direction: str) -> tuple: + """重启后从最近关闭的同品种监控恢复止盈止损。""" + direction = (direction or "long").strip().lower() + for r in conn.execute( + "SELECT symbol, direction, stop_loss, take_profit, trailing_be, initial_stop_loss " + "FROM trade_order_monitors WHERE status='closed' ORDER BY id DESC LIMIT 80" + ).fetchall(): + row = dict(r) + if (row.get("direction") or "long") != direction: + continue + if not _match_ctp_symbol(sym, row.get("symbol") or ""): + continue + if row.get("stop_loss") is None and row.get("take_profit") is None: + continue + return ( + row.get("stop_loss"), + row.get("take_profit"), + int(row.get("trailing_be") or 0), + row.get("initial_stop_loss"), + ) + return None, None, 0, None + + def _restore_monitor_sl_tp_if_missing( + conn, + mon: Optional[dict], + sym: str, + direction: str, + ) -> Optional[dict]: + """活跃监控缺少止盈止损时,从最近关闭的同品种记录恢复并写回数据库。""" + if not mon: + return None + sl = mon.get("stop_loss") + tp = mon.get("take_profit") + trailing = int(mon.get("trailing_be") or 0) + if sl is not None or tp is not None or trailing: + return mon + rsl, rtp, rtrail, rinitial = _restore_sl_tp_from_closed(conn, sym, direction) + if rsl is None and rtp is None and not rtrail: + return mon + execute_retry( + conn, + """UPDATE trade_order_monitors SET + stop_loss=?, take_profit=?, trailing_be=?, initial_stop_loss=? + WHERE id=? AND status='active'""", + (rsl, rtp, rtrail, rinitial, int(mon["id"])), + ) + row = conn.execute( + "SELECT * FROM trade_order_monitors WHERE id=?", (int(mon["id"]),), + ).fetchone() + if row: + logger.info( + "恢复止盈止损 monitor=%s sym=%s sl=%s tp=%s", + mon.get("id"), sym, rsl, rtp, + ) + return dict(row) + return mon + + def _ctp_position_keys(mode: str) -> set[tuple[str, str]]: + keys: set[tuple[str, str]] = set() + for p in _ctp_positions(mode): + lots = int(p.get("lots") or 0) + if lots <= 0: + continue + sym = (p.get("symbol") or "").lower() + direction = p.get("direction") or "long" + keys.add((sym, direction)) + return keys + + def _monitor_matches_ctp_position(mon: dict, position_keys: set[tuple[str, str]]) -> bool: + ms = mon.get("symbol") or "" + md = mon.get("direction") or "long" + for ps, pd in position_keys: + if pd != md: + continue + if _match_ctp_symbol(ps, ms): + return True + return False + + def _sync_trade_monitors_with_ctp(conn, mode: str) -> int: + """关闭无对应 CTP 持仓的监控,并撤销残留止盈止损挂单。""" + return reconcile_monitors_without_position(conn, mode) + + def _effective_active_position_count( + conn, + mode: str, + *, + ctp_connected: Optional[bool] = None, + ) -> int: + """风控持仓数以柜台/快照实际持仓优先,本地监控作兜底。""" + monitor_count = count_active_trade_monitors(conn) + if ctp_connected is None: + ctp_connected = bool(_cached_ctp_status(mode).get("connected")) + if not ctp_connected: + return monitor_count + keys: set[tuple[str, str]] = set() + for p in _positions_for_monitor_restore(mode, allow_ctp=False): + lots = int(p.get("lots") or 0) + if lots <= 0: + continue + sym = ( + p.get("symbol") + or p.get("symbol_code") + or p.get("ths_code") + or "" + ).strip().lower() + direction = (p.get("direction") or "long").strip().lower() + if sym: + keys.add((sym, direction)) + return max(monitor_count, len(keys)) + + def _build_pending_orders(conn, mode: str) -> list[dict]: + pending: list[dict] = [] + for r in conn.execute( + "SELECT * FROM trade_order_monitors WHERE status='active' ORDER BY id DESC" + ).fetchall(): + mon = dict(r) + sym = mon.get("symbol") or "" + direction = mon.get("direction") or "long" + lots = int(mon.get("lots") or 0) + base = { + "symbol_code": sym, + "direction": direction, + "direction_label": "做多" if direction == "long" else "做空", + "lots": lots, + "source": "monitor", + "monitor_id": mon.get("id"), + **_symbol_display_fields(sym), + } + sl = mon.get("stop_loss") + tp = mon.get("take_profit") + if sl is not None: + pending.append({ + **base, + "order_kind": "stop_loss", + "label": "止损监控", + "price": float(sl), + }) + if tp is not None: + pending.append({ + **base, + "order_kind": "take_profit", + "label": "止盈监控", + "price": float(tp), + }) + for r in conn.execute( + "SELECT * FROM trade_order_monitors WHERE status='pending' ORDER BY id DESC" + ).fetchall(): + mon = dict(r) + sym = mon.get("symbol") or "" + pending.append({ + "symbol_code": sym, + "direction": mon.get("direction") or "long", + "direction_label": "做多" if (mon.get("direction") or "long") == "long" else "做空", + "lots": int(mon.get("lots") or 0), + "price": float(mon.get("order_price") or mon.get("entry_price") or 0), + "order_kind": "open_pending", + "label": "开仓挂单中", + "source": "monitor", + "monitor_id": mon.get("id"), + "can_cancel_order": is_trading_session(), + "cancel_allowed": is_trading_session(), + **_symbol_display_fields(sym), + }) + ctp_st = ctp_status(mode) + if ctp_st.get("connected"): + for o in _ctp_active_orders(mode): + sym = _ctp_pos_to_ths_code(o) or (o.get("symbol") or "") + offset_s = (o.get("offset") or "").upper() + kind = "limit" + label = "委托挂单" + if "CLOSE" in offset_s: + label = "平仓委托" + pending.append({ + "symbol_code": sym, + "symbol": _symbol_display_fields(sym).get("symbol_name") or sym, + "direction": o.get("direction") or "long", + "direction_label": "做多" if o.get("direction") == "long" else "做空", + "lots": int(o.get("lots") or 0), + "price": float(o.get("price") or 0), + "order_kind": kind, + "label": label, + "source": "ctp", + "order_id": o.get("order_id"), + "vt_order_id": o.get("vt_order_id") or o.get("order_id"), + "can_cancel_order": is_trading_session(), + "cancel_allowed": is_trading_session(), + **_symbol_display_fields(sym), + }) + return pending + + def _ctp_active_orders(mode: str) -> list: + try: + return ctp_list_active_orders(mode) + except Exception: + return [] + + def _canonical_position_key(symbol: str, direction: str, exchange: str = "") -> str: + sym = (symbol or "").strip() + d = (direction or "long").strip().lower() + ex = (exchange or "").strip().upper() + try: + vnpy_sym, ex2 = ths_to_vnpy_symbol(sym) + sym = vnpy_sym + if not ex: + ex = ex2 + except Exception: + sym = sym.lower() + return position_key(ex, sym, d) + + def _position_key_from_ctp(p: dict) -> str: + return position_key( + p.get("exchange") or "", + p.get("symbol") or "", + p.get("direction") or "long", + ) + + def _monitor_position_key(mon: dict, exchange: str = "") -> str: + sym = (mon.get("symbol") or "").strip() + d = (mon.get("direction") or "long").strip().lower() + ex = (exchange or "").strip().upper() + try: + vnpy_sym, ex2 = ths_to_vnpy_symbol(sym) + sym = vnpy_sym + if not ex: + ex = ex2 + except Exception: + sym = sym.lower() + return position_key(ex, sym, d) + + def _monitors_by_position_key(conn) -> dict[str, dict]: + ensure_monitor_order_columns(conn) + out: dict[str, dict] = {} + for r in conn.execute( + "SELECT * FROM trade_order_monitors WHERE status='active' ORDER BY id DESC" + ).fetchall(): + mon = dict(r) + pk = _monitor_position_key(mon) + if pk not in out: + out[pk] = mon + return out + + def _find_active_monitor(conn, symbol: str, direction: str) -> Optional[dict]: + direction = (direction or "long").strip().lower() + for r in conn.execute( + "SELECT * FROM trade_order_monitors WHERE status='active' ORDER BY id DESC" + ).fetchall(): + row = dict(r) + if (row.get("direction") or "long") != direction: + continue + if _match_ctp_symbol(symbol, row.get("symbol") or ""): + return row + return None + + def _find_pending_monitor(conn, symbol: str, direction: str) -> Optional[dict]: + """开仓委托 pending 仍带止损/移动保本元数据,需与 CTP 持仓关联展示。""" + direction = (direction or "long").strip().lower() + for r in conn.execute( + "SELECT * FROM trade_order_monitors WHERE status='pending' ORDER BY id DESC" + ).fetchall(): + row = dict(r) + if (row.get("direction") or "long") != direction: + continue + if _match_ctp_symbol(symbol, row.get("symbol") or ""): + return row + return None + + def _has_pending_monitors(conn) -> bool: + return bool( + conn.execute( + "SELECT 1 FROM trade_order_monitors WHERE status='pending' LIMIT 1" + ).fetchone() + ) + + def _overlay_sl_tp_readonly( + conn, + mon: Optional[dict], + sym: str, + direction: str, + ) -> Optional[dict]: + """只读:从已关闭监控补全止盈止损,不写库。""" + if not mon: + rsl, rtp, rtrail, rinitial = _restore_sl_tp_from_closed(conn, sym, direction) + if rsl is None and rtp is None and not rtrail: + return {"symbol": sym, "direction": direction} + return { + "symbol": sym, + "direction": direction, + "stop_loss": rsl, + "take_profit": rtp, + "trailing_be": rtrail, + "initial_stop_loss": rinitial, + } + sl = mon.get("stop_loss") + tp = mon.get("take_profit") + trailing = int(mon.get("trailing_be") or 0) + if sl is not None or tp is not None or trailing: + return mon + rsl, rtp, rtrail, rinitial = _restore_sl_tp_from_closed(conn, sym, direction) + if rsl is None and rtp is None and not rtrail: + return mon + merged = dict(mon) + merged["stop_loss"] = rsl + merged["take_profit"] = rtp + merged["trailing_be"] = rtrail + merged["initial_stop_loss"] = rinitial + return merged + + def _revive_closed_monitor(conn, symbol: str, direction: str) -> Optional[dict]: + """柜台仍有持仓但本地监控被误关时,恢复最近一条同品种记录。""" + direction = (direction or "long").strip().lower() + for r in conn.execute( + "SELECT * FROM trade_order_monitors WHERE status='closed' ORDER BY id DESC LIMIT 40" + ).fetchall(): + row = dict(r) + if (row.get("direction") or "long") != direction: + continue + if not _match_ctp_symbol(symbol, row.get("symbol") or ""): + continue + if int(row.get("lots") or 0) <= 0: + continue + execute_retry( + conn, + "UPDATE trade_order_monitors SET status='active' WHERE id=?", + (row["id"],), + ) + row["status"] = "active" + logger.info( + "恢复误关闭监控 id=%s sym=%s dir=%s", + row.get("id"), row.get("symbol"), direction, + ) + return row + return None + + def _find_or_revive_monitor(conn, symbol: str, direction: str) -> Optional[dict]: + active = _find_active_monitor(conn, symbol, direction) + if active: + return active + return _revive_closed_monitor(conn, symbol, direction) + + def _close_all_monitors_for_sym_dir(conn, symbol: str, direction: str) -> None: + direction = (direction or "long").strip().lower() + for r in conn.execute( + "SELECT id, symbol, direction FROM trade_order_monitors " + "WHERE status IN ('active', 'pending')" + ).fetchall(): + if (r["direction"] or "long") != direction: + continue + if _match_ctp_symbol(symbol, r["symbol"] or ""): + conn.execute( + "UPDATE trade_order_monitors SET status='closed' WHERE id=?", + (r["id"],), + ) + + def _close_duplicate_monitors(conn, symbol: str, direction: str, keep_id: int) -> None: + direction = (direction or "long").strip().lower() + for r in conn.execute( + "SELECT id, symbol, direction FROM trade_order_monitors WHERE status='active'" + ).fetchall(): + if int(r["id"]) == int(keep_id): + continue + if (r["direction"] or "long") != direction: + continue + if _match_ctp_symbol(symbol, r["symbol"] or ""): + conn.execute( + "UPDATE trade_order_monitors SET status='closed' WHERE id=?", + (r["id"],), + ) + + def _upsert_open_monitor( + conn, + *, + sym: str, + direction: str, + lots: int, + price: float, + sl, + tp, + trailing_be: int, + ctp_open_time: Optional[str] = None, + open_time: Optional[str] = None, + monitor_type: str = "manual", + status: str = "active", + vt_order_id: Optional[str] = None, + order_price: Optional[float] = None, + ) -> int: + ensure_monitor_order_columns(conn) + codes = ths_to_codes(sym) or {} + sl_f = float(sl) if sl not in (None, "") else None + tp_f = float(tp) if tp not in (None, "") else None + now_s = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + status_val = status if status in ("pending", "active") else "active" + order_px = float(order_price if order_price is not None else price) + existing = _find_active_monitor(conn, sym, direction) + if not existing: + for r in conn.execute( + "SELECT * FROM trade_order_monitors WHERE status='pending' ORDER BY id DESC" + ).fetchall(): + row = dict(r) + if (row.get("direction") or "long") != (direction or "long").strip().lower(): + continue + if _match_ctp_symbol(sym, row.get("symbol") or ""): + existing = row + break + if existing: + mid = int(existing["id"]) + existing_status = (existing.get("status") or "active").strip().lower() + if existing_status == "active" and status_val == "pending": + status_val = "active" + initial_sl = existing.get("initial_stop_loss") + if sl_f is None: + sl_f = float(existing["stop_loss"]) if existing.get("stop_loss") is not None else None + if tp_f is None: + tp_f = float(existing["take_profit"]) if existing.get("take_profit") is not None else None + if sl_f is not None and initial_sl is None: + initial_sl = sl_f + if not trailing_be: + trailing_be = int(existing.get("trailing_be") or 0) + open_time_val = (existing.get("open_time") or "").strip() or now_s + if open_time: + open_time_val = open_time + elif monitor_type == "ctp_sync" and ctp_open_time: + open_time_val = ctp_open_time + vt_val = vt_order_id or existing.get("vt_order_id") + conn.execute( + """UPDATE trade_order_monitors SET + symbol=?, symbol_name=?, market_code=?, lots=?, entry_price=?, + stop_loss=?, take_profit=?, initial_stop_loss=?, trailing_be=?, open_time=?, + monitor_type=?, status=?, vt_order_id=?, order_price=?, risk_percent=COALESCE(risk_percent, ?) + WHERE id=?""", + ( + sym, + codes.get("name", sym), + codes.get("market_code", ""), + lots, + price, + sl_f, + tp_f, + initial_sl, + trailing_be, + open_time_val, + monitor_type if monitor_type != "manual" else (existing.get("monitor_type") or "manual"), + status_val, + vt_val, + order_px, + get_risk_percent(get_setting), + mid, + ), + ) + else: + if open_time: + open_time_val = open_time + elif monitor_type == "ctp_sync" and ctp_open_time: + open_time_val = ctp_open_time + else: + open_time_val = now_s + conn.execute( + """INSERT INTO trade_order_monitors ( + symbol, symbol_name, market_code, direction, lots, entry_price, + stop_loss, take_profit, initial_stop_loss, trailing_be, + open_time, monitor_type, status, vt_order_id, order_price, risk_percent + ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""", + ( + sym, + codes.get("name", sym), + codes.get("market_code", ""), + direction, + lots, + price, + sl_f, + tp_f, + sl_f, + trailing_be, + open_time_val, + monitor_type, + status_val, + vt_order_id, + order_px, + get_risk_percent(get_setting), + ), + ) + mid = int(conn.execute("SELECT last_insert_rowid()").fetchone()[0]) + if status_val == "active": + _close_duplicate_monitors(conn, sym, direction, mid) + return mid + + def _sync_monitor_from_ctp( + conn, + mid: int, + sym: str, + direction: str, + mode: str, + *, + ctp: Optional[dict] = None, + capital: float = 0.0, + ) -> None: + """CTP 同步:均价、现价、保证金、仓位占比写入数据库;不覆盖期货下单的开仓时间。""" + positions = [ctp] if ctp else _ctp_positions(mode, refresh_if_empty=False, refresh_margin=True) + for p in positions: + if not p or int(p.get("lots") or 0) <= 0: + continue + if (p.get("direction") or "long") != direction: + continue + if not _match_ctp_symbol(p.get("symbol") or "", sym): + continue + row = conn.execute( + "SELECT open_time, monitor_type FROM trade_order_monitors WHERE id=?", (mid,), + ).fetchone() + db_open = (row["open_time"] or "").strip() if row else "" + monitor_type = (row["monitor_type"] or "manual").strip().lower() if row else "manual" + ctp_open = (p.get("open_time") or "").strip() or None + open_time_val = db_open + if monitor_type == "ctp_sync" and ctp_open: + open_time_val = ctp_open + lots = int(p.get("lots") or 0) + entry = float(p.get("avg_price") or 0) + ctp_margin = float(p.get("margin") or 0) + mark = None + if ctp_status(mode).get("connected"): + mark = ctp_get_tick_price(mode, sym) + if mark is None or mark <= 0: + mark = entry if entry else None + resolved_entry, _src = _resolve_ctp_entry_price( + mode, sym, direction, p, + ) + if resolved_entry > 0: + entry = resolved_entry + float_pnl = None + if mark and entry and lots > 0: + float_pnl = calc_position_metrics( + direction, entry, entry, entry, lots, mark, capital, sym, + ).get("float_pnl") + est = calc_position_metrics( + direction, entry, entry, entry, lots, mark or entry, capital, sym, + ).get("margin") + margin, _src = _resolve_position_margin( + sym=sym, + direction=direction, + lots=lots, + entry=entry, + mode=mode, + ctp=p, + est_margin=est, + ) + position_pct = None + if margin and capital > 0: + position_pct = round(float(margin) / float(capital) * 100, 2) + open_commission = _open_commission_from_ctp_trades(mode, sym, direction) + if open_commission is None: + fee_info = calc_fee_breakdown( + sym, entry, entry, lots, open_time_val or "", "", + trading_mode=mode, + ) + open_commission = fee_info.get("open_fee") + execute_retry( + conn, + """UPDATE trade_order_monitors SET lots=?, entry_price=?, + open_time=?, margin=?, position_pct=?, mark_price=?, float_pnl=?, + open_fee=? + WHERE id=?""", + ( + lots, + entry, + open_time_val, + margin, + position_pct, + float(mark) if mark else None, + float_pnl, + open_commission, + mid, + ), + ) + return + + def _sync_monitor_lots_from_ctp( + conn, mid: int, sym: str, direction: str, mode: str, *, ctp: Optional[dict] = None, + ) -> None: + _sync_monitor_from_ctp( + conn, mid, sym, direction, mode, ctp=ctp, capital=_capital(conn), + ) + + def _compose_position_row( + conn, + *, + mon: Optional[dict], + ctp: Optional[dict], + mode: str, + capital: float, + now_iso: str, + fast: bool = False, + ) -> Optional[dict]: + if not mon and not ctp: + return None + + if mon: + sym = (mon.get("symbol") or "").strip() + direction = mon.get("direction") or "long" + lots = int(mon.get("lots") or 0) + entry = float(mon.get("entry_price") or 0) + source_label = monitor_source_label(mon.get("monitor_type")) + open_time = _time_str(mon.get("open_time")) + open_time_source = "order" + margin = mon.get("margin") + position_pct = mon.get("position_pct") + mark = mon.get("mark_price") + float_pnl = mon.get("float_pnl") + if float_pnl is not None: + float_pnl = round(float(float_pnl), 2) + else: + sym = (ctp.get("symbol") or "").strip() + direction = ctp.get("direction") or "long" + lots = int(ctp.get("lots") or 0) + entry = float(ctp.get("avg_price") or 0) + source_label = "CTP 柜台" + open_time = _time_str(ctp.get("open_time")) + open_time_source = "ctp" + margin = None + position_pct = None + mark = None + float_pnl = None + + if lots <= 0: + return None + + if ctp: + ctp_lots = int(ctp.get("lots") or 0) + if ctp_lots > 0: + lots = ctp_lots + ths_sym = _ctp_pos_to_ths_code(ctp) or sym + resolved_entry, _entry_src = _resolve_ctp_entry_price( + mode, ths_sym, direction, ctp, + ) + if resolved_entry > 0: + entry = resolved_entry + elif float(ctp.get("avg_price") or 0) > 0: + entry = float(ctp.get("avg_price") or 0) + ctp_margin = float(ctp.get("margin") or 0) + if (margin is None or float(margin or 0) <= 0) and ctp_margin > 0: + margin = ctp_margin + if ctp_status(mode).get("connected"): + source_label = "CTP 柜台" + + codes = ths_to_codes(sym) + tick = calc_order_tick_metrics(sym, lots, entry, trading_mode=mode) + sl = float(mon["stop_loss"]) if mon and mon.get("stop_loss") is not None else None + tp = float(mon["take_profit"]) if mon and mon.get("take_profit") is not None else None + holding = _holding_duration(open_time, now_iso) if open_time else "" + + if ctp_status(mode).get("connected"): + live_mark = ctp_get_tick_price(mode, sym) + if live_mark and live_mark > 0: + mark = live_mark + elif (mark is None or float(mark or 0) <= 0) and not fast and codes: + mark = fetch_price( + sym, + codes.get("market_code", ""), + codes.get("sina_code", ""), + ) + if mark is None or mark <= 0: + mark = entry if entry else None + close_est = float(mark) if mark and mark > 0 else entry + if mark and entry and lots > 0: + pos_tmp = calc_position_metrics( + direction, entry, sl or entry, tp or entry, lots, mark, capital, sym, + ) + float_pnl = pos_tmp.get("float_pnl") + if ctp and ctp_status(mode).get("connected"): + ctp_pnl = float(ctp.get("pnl") or 0) + if ctp_pnl != 0: + float_pnl = round(ctp_pnl, 2) + + fee_info = calc_fee_breakdown( + sym, entry, close_est, lots, open_time or now_iso, now_iso, trading_mode=mode, + ) + open_commission = _open_commission_from_ctp_trades(mode, sym, direction) + if open_commission is None and mon and mon.get("open_fee") is not None: + cached_fee = float(mon.get("open_fee") or 0) + if cached_fee > 0: + open_commission = cached_fee + if open_commission is not None: + display_fee = open_commission + fee_source = "ctp" + else: + display_fee = fee_info["open_fee"] + fee_source = fee_info.get("fee_source") or "local" + est_net = None + if float_pnl is not None: + est_net = round(float(float_pnl) - fee_info["close_fee"], 2) + pos_metrics = calc_position_metrics( + direction, entry, sl if sl is not None else entry, + tp if tp is not None else entry, lots, mark, capital, sym, + ) + mon_margin = margin + margin, margin_source = _resolve_position_margin( + sym=sym, + direction=direction, + lots=lots, + entry=entry, + mode=mode, + ctp=ctp, + mon_margin=mon_margin if isinstance(mon_margin, (int, float)) else None, + est_margin=pos_metrics.get("margin"), + ) + if margin and capital > 0: + position_pct = round(float(margin) / float(capital) * 100, 2) + elif position_pct is None or float(position_pct or 0) <= 0: + position_pct = pos_metrics.get("position_pct") + elif position_pct is not None: + position_pct = float(position_pct) + order_st = monitor_order_status( + mon or {}, mode=mode, ths_code=sym, direction=direction, + ) + pending_for_row: list[dict] = [] + if sl is not None: + pending_for_row.append({ + "order_kind": "stop_loss", + "label": "止损监控", + "price": sl, + "lots": lots, + "source": "monitor", + "monitor_id": mon["id"] if mon else None, + }) + if tp is not None: + pending_for_row.append({ + "order_kind": "take_profit", + "label": "止盈监控", + "price": tp, + "lots": lots, + "source": "monitor", + "monitor_id": mon["id"] if mon else None, + }) + row_key = _canonical_position_key( + sym, direction, (ctp or {}).get("exchange") or "", + ) + return { + "key": row_key, + "position_key": row_key, + "source": "ctp", + "source_label": source_label, + "sync_pending": False, + "monitor_id": mon["id"] if mon else None, + "symbol_code": sym, + **_symbol_display_fields(sym), + "direction": direction, + "direction_label": "做多" if direction == "long" else "做空", + "lots": lots, + "entry_price": entry, + "stop_loss": sl, + "take_profit": tp, + "open_time": open_time or None, + "open_time_source": open_time_source or None, + "holding_duration": holding or None, + "mark_price": mark, + "current_price": mark, + "margin": margin, + "margin_source": margin_source, + "position_pct": position_pct, + "risk_amount": pos_metrics.get("risk_amount") if sl is not None else None, + "reward_amount": pos_metrics.get("reward_amount") if tp is not None else None, + "risk_pct": pos_metrics.get("risk_pct") if sl is not None else None, + "rr_ratio": pos_metrics.get("rr_ratio") if sl is not None and tp is not None else None, + "float_pnl": float_pnl, + "est_fee": display_fee, + "est_fee_open": display_fee, + "est_fee_close": fee_info["close_fee"], + "est_fee_close_type": fee_info["close_type"], + "fee_source": fee_source, + "est_pnl_net": est_net, + "sl_order_active": order_st.get("sl_monitoring"), + "tp_order_active": order_st.get("tp_monitoring"), + "sl_monitoring": order_st.get("sl_monitoring"), + "tp_monitoring": order_st.get("tp_monitoring"), + "can_place_orders": False, + "tick_value_total": tick.get("tick_value_total"), + "price_precision": tick.get("price_precision"), + "tick_size": tick.get("tick_size"), + "can_close": True, + "close_allowed": is_trading_session(), + "pending_orders": pending_for_row, + "trailing_be": bool(mon.get("trailing_be")) if mon else False, + "trailing_r_locked": int(mon.get("trailing_r_locked") or 0) if mon else 0, + "breakeven_locked": _breakeven_locked( + entry=entry, + stop_loss=sl, + direction=direction, + tick_size=tick.get("tick_size"), + trailing_r_locked=int(mon.get("trailing_r_locked") or 0) if mon else 0, + ), + } + + def _compose_pending_row( + mon: dict, + *, + mode: str, + capital: float, + now_iso: str, + ) -> Optional[dict]: + sym = (mon.get("symbol") or "").strip() + direction = (mon.get("direction") or "long").strip().lower() + lots = int(mon.get("lots") or 0) + if not sym or lots <= 0: + return None + order_price = float(mon.get("order_price") or mon.get("entry_price") or 0) + codes = ths_to_codes(sym) + sl = float(mon["stop_loss"]) if mon.get("stop_loss") is not None else None + tp = float(mon["take_profit"]) if mon.get("take_profit") is not None else None + pos_metrics = calc_position_metrics( + direction, order_price, sl or order_price, tp or order_price, lots, order_price, capital, sym, + ) + open_time = _time_str(mon.get("open_time")) + timeout_sec = get_pending_order_timeout_sec(get_setting) + remain = pending_auto_cancel_remaining(mon, timeout_sec=timeout_sec) + return { + "key": f"{_canonical_position_key(sym, direction)}:pending:{mon.get('id')}", + "order_state": "pending", + "source": "pending", + "source_label": "委托挂单中", + "sync_pending": True, + "monitor_id": mon.get("id"), + "symbol_code": sym, + **_symbol_display_fields(sym), + "direction": direction, + "direction_label": "做多" if direction == "long" else "做空", + "lots": lots, + "entry_price": order_price, + "order_price": order_price, + "stop_loss": sl, + "take_profit": tp, + "open_time": open_time or None, + "holding_duration": _holding_duration(open_time, now_iso) if open_time else None, + "mark_price": order_price, + "current_price": order_price, + "margin": pos_metrics.get("margin"), + "margin_source": "estimate", + "position_pct": pos_metrics.get("position_pct"), + "risk_amount": pos_metrics.get("risk_amount") if sl is not None else None, + "reward_amount": pos_metrics.get("reward_amount") if tp is not None else None, + "rr_ratio": pos_metrics.get("rr_ratio") if sl is not None and tp is not None else None, + "float_pnl": None, + "est_fee": None, + "can_close": False, + "close_allowed": False, + "can_cancel_order": is_trading_session(), + "cancel_allowed": is_trading_session(), + "auto_cancel_sec": remain, + "pending_timeout_sec": timeout_sec, + "pending_timeout_min": max(1, timeout_sec // 60), + "vt_order_id": mon.get("vt_order_id"), + "sl_order_active": False, + "tp_order_active": False, + "sl_monitoring": bool(sl is not None), + "tp_monitoring": bool(tp is not None), + "can_place_orders": False, + "pending_orders": [], + "trailing_be": bool(mon.get("trailing_be")), + "trailing_r_locked": int(mon.get("trailing_r_locked") or 0), + } + + def _compose_ctp_open_order_row( + o: dict, + *, + mode: str, + capital: float, + now_iso: str, + ) -> Optional[dict]: + offset_u = (o.get("offset") or "").upper() + if offset_u and "OPEN" not in offset_u: + return None + sym = _ctp_pos_to_ths_code(o) or (o.get("symbol") or "").strip() + direction = (o.get("direction") or "long").strip().lower() + lots = int(o.get("lots") or 0) + if not sym or lots <= 0: + return None + order_price = float(o.get("price") or 0) + pos_metrics = calc_position_metrics( + direction, order_price, order_price, order_price, lots, order_price, capital, sym, + ) + timeout_sec = get_pending_order_timeout_sec(get_setting) + return { + "key": f"{_canonical_position_key(sym, direction)}:pending:ctp:{o.get('order_id') or ''}", + "order_state": "pending", + "source": "ctp", + "source_label": "委托挂单", + "sync_pending": True, + "monitor_id": None, + "order_id": o.get("order_id"), + "vt_order_id": o.get("vt_order_id") or o.get("order_id"), + "symbol_code": sym, + **_symbol_display_fields(sym), + "direction": direction, + "direction_label": "做多" if direction == "long" else "做空", + "lots": lots, + "entry_price": order_price, + "order_price": order_price, + "stop_loss": None, + "take_profit": None, + "open_time": now_iso, + "holding_duration": None, + "mark_price": order_price, + "current_price": order_price, + "margin": pos_metrics.get("margin"), + "margin_source": "estimate", + "position_pct": pos_metrics.get("position_pct"), + "float_pnl": None, + "est_fee": None, + "can_close": False, + "close_allowed": False, + "can_cancel_order": is_trading_session(), + "cancel_allowed": is_trading_session(), + "pending_timeout_sec": timeout_sec, + "pending_timeout_min": max(1, timeout_sec // 60), + "sl_order_active": False, + "tp_order_active": False, + "sl_monitoring": False, + "tp_monitoring": False, + "can_place_orders": False, + "pending_orders": [], + "trailing_be": False, + "trailing_r_locked": 0, + } + + def _reconcile_pending(conn, mode: str, *, capital: float = 0.0) -> dict[str, int]: + return reconcile_pending_orders( + conn, + mode, + match_symbol_fn=_match_ctp_symbol, + sync_monitor_fn=_sync_monitor_from_ctp, + capital=capital, + list_positions_fn=_ctp_positions, + timeout_sec=get_pending_order_timeout_sec(get_setting), + ) + + def _build_active_orders( + conn, + *, + mode: str, + capital: float, + now_iso: str, + ) -> list[dict]: + """当前委托:CTP 已连接时读柜台;未连接时不展示本地 pending。""" + orders: list[dict] = [] + seen_keys: set[str] = set() + connected = ctp_status(mode).get("connected") + + if connected: + ctp_orders = trading_state.get_active_orders() + if not ctp_orders: + ctp_orders = _ctp_active_orders(mode) + for o in ctp_orders: + try: + row = _compose_ctp_open_order_row( + o, mode=mode, capital=capital, now_iso=now_iso, + ) + if not row: + row = _compose_ctp_order_row_any( + o, mode=mode, capital=capital, now_iso=now_iso, + ) + if row: + orders.append(row) + seen_keys.add(row.get("key") or "") + except Exception as exc: + logger.warning("compose ctp order row failed: %s", exc) + + ctp_active_map: dict[str, dict] = {} + for o in ctp_orders or []: + for key in (o.get("order_id"), o.get("vt_order_id")): + if key: + ctp_active_map[str(key)] = o + + for r in conn.execute( + "SELECT * FROM trade_order_monitors WHERE status='pending' ORDER BY id DESC" + ).fetchall(): + mon = dict(r) + try: + if not pending_monitor_has_live_order( + mon, + active_orders=ctp_active_map, + active_order_list=ctp_orders or [], + ): + continue + prow = _compose_pending_row( + mon, mode=mode, capital=capital, now_iso=now_iso, + ) + if prow and prow.get("key") not in seen_keys: + pk = f"{prow.get('symbol_code') or ''}:{prow.get('direction') or ''}" + dup = any( + (x.get("symbol_code") or "") + ":" + (x.get("direction") or "") == pk + and x.get("order_state") == "pending" + for x in orders + ) + if not dup: + orders.append(prow) + except Exception as exc: + logger.warning("compose pending order row failed: %s", exc) + return orders + + def _compose_ctp_order_row_any( + o: dict, + *, + mode: str, + capital: float, + now_iso: str, + ) -> Optional[dict]: + """CTP 任意未成交委托(含平仓)。""" + sym = _ctp_pos_to_ths_code(o) or (o.get("symbol") or "").strip() + direction = (o.get("direction") or "long").strip().lower() + lots = int(o.get("lots") or 0) + if not sym or lots <= 0: + return None + offset_u = (o.get("offset") or "").upper() + is_open = not offset_u or "OPEN" in offset_u + order_price = float(o.get("price") or 0) + pos_metrics = calc_position_metrics( + direction, order_price, order_price, order_price, lots, order_price, capital, sym, + ) + label = "开仓委托" if is_open else "平仓委托" + timeout_sec = get_pending_order_timeout_sec(get_setting) + ex = o.get("exchange") or "" + pk = _canonical_position_key(sym, direction, ex) + return { + "key": f"{pk}:order:{o.get('order_id') or ''}", + "order_state": "pending", + "source": "ctp", + "source_label": label, + "sync_pending": False, + "monitor_id": None, + "order_id": o.get("order_id"), + "vt_order_id": o.get("vt_order_id") or o.get("order_id"), + "symbol_code": sym, + **_symbol_display_fields(sym), + "direction": direction, + "direction_label": "做多" if direction == "long" else "做空", + "lots": lots, + "entry_price": order_price, + "order_price": order_price, + "stop_loss": None, + "take_profit": None, + "open_time": now_iso, + "mark_price": order_price, + "current_price": order_price, + "margin": pos_metrics.get("margin"), + "margin_source": "estimate", + "position_pct": pos_metrics.get("position_pct"), + "float_pnl": None, + "can_close": False, + "close_allowed": False, + "can_cancel_order": is_trading_session(), + "cancel_allowed": is_trading_session(), + "pending_timeout_sec": timeout_sec if is_open else None, + "pending_timeout_min": max(1, timeout_sec // 60) if is_open else None, + "sl_order_active": False, + "tp_order_active": False, + "sl_monitoring": False, + "tp_monitoring": False, + "can_place_orders": False, + "pending_orders": [], + "trailing_be": False, + "trailing_r_locked": 0, + } + + def _build_trading_live_rows(conn, *, fast: bool = False) -> list[dict]: + """当前持仓:以 CTP 为准,SQLite 仅叠加 SL/TP 元数据。""" + from zoneinfo import ZoneInfo + tz = ZoneInfo("Asia/Shanghai") + now_iso = datetime.now(tz).strftime("%Y-%m-%dT%H:%M") + mode = get_trading_mode(get_setting) + capital = _capital(conn) + + ctp_list: list[dict] = [] + if ctp_status(mode).get("connected"): + merged: dict[str, dict] = {} + for p in list(_ctp_positions(mode) or []) + list(trading_state.get_positions() or []): + lots = int(p.get("lots") or 0) + if lots <= 0: + continue + pk = p.get("position_key") or _position_key_from_ctp(p) + merged[pk] = p + ctp_list = list(merged.values()) + + ensure_monitor_order_columns(conn) + monitor_by_pk = _monitors_by_position_key(conn) + + rows: list[dict] = [] + for p in ctp_list: + lots = int(p.get("lots") or 0) + if lots <= 0: + continue + pk = p.get("position_key") or _position_key_from_ctp(p) + mon = monitor_by_pk.get(pk) + if not mon: + for mk, mv in monitor_by_pk.items(): + if (mv.get("direction") or "long") != (p.get("direction") or "long"): + continue + if _match_ctp_symbol(p.get("symbol") or "", mv.get("symbol") or ""): + mon = mv + break + ths = _ctp_pos_to_ths_code(p) or (p.get("symbol") or "") + direction = p.get("direction") or "long" + if not mon: + mon = _find_pending_monitor(conn, ths, direction) + if not mon: + if fast: + mon = _find_active_monitor(conn, ths, direction) + else: + mon = _find_or_revive_monitor(conn, ths, direction) + if mon: + if fast: + mon = _overlay_sl_tp_readonly(conn, mon, ths, direction) or mon + else: + mon = _restore_monitor_sl_tp_if_missing(conn, mon, ths, direction) or mon + _sync_monitor_from_ctp( + conn, int(mon["id"]), mon.get("symbol") or ths, + mon.get("direction") or direction, + mode, ctp=p, capital=capital, + ) + mon = _find_active_monitor( + conn, mon.get("symbol") or ths, mon.get("direction") or direction, + ) or mon + mon = _restore_monitor_sl_tp_if_missing(conn, mon, ths, direction) or mon + elif fast: + mon = _overlay_sl_tp_readonly(conn, None, ths, direction) + try: + row = _compose_position_row( + conn, mon=mon, ctp=p, mode=mode, capital=capital, + now_iso=now_iso, fast=fast, + ) + if row: + rows.append(row) + except Exception as exc: + logger.warning("compose ctp position row failed: %s", exc) + + seen: set[str] = set() + deduped: list[dict] = [] + for row in rows: + rk = row.get("key") or row.get("position_key") or "" + if rk in seen: + continue + seen.add(rk) + deduped.append(row) + + if not deduped and ctp_status(mode).get("connected"): + margin_raw = ctp_account_margin_used(mode) + margin_used = float(margin_raw or 0) if margin_raw is not None else 0.0 + has_margin_hint = margin_raw is not None and margin_used > 0 + has_active_mon = any( + int(m.get("lots") or 0) > 0 for m in monitor_by_pk.values() + ) + since_connect = 9999.0 + try: + since_connect = time.time() - float( + getattr(get_bridge(), "_last_connect_ok_ts", 0) or 0, + ) + except Exception: + pass + if has_margin_hint or has_active_mon or since_connect < 300: + if not monitor_by_pk and has_margin_hint: + _ensure_monitors_from_sticky_state(conn, mode) + monitor_by_pk = _monitors_by_position_key(conn) + for mon in monitor_by_pk.values(): + lots = int(mon.get("lots") or 0) + if lots <= 0: + continue + sym = (mon.get("symbol") or "").strip() + direction = (mon.get("direction") or "long").strip().lower() + if fast: + mon = _overlay_sl_tp_readonly(conn, mon, sym, direction) or mon + else: + mon = ( + _restore_monitor_sl_tp_if_missing(conn, mon, sym, direction) + or mon + ) + try: + row = _compose_position_row( + conn, + mon=mon, + ctp=None, + mode=mode, + capital=capital, + now_iso=now_iso, + fast=fast, + ) + if not row: + continue + rk = row.get("key") or row.get("position_key") or "" + if rk and rk in seen: + continue + if rk: + seen.add(rk) + deduped.append(row) + except Exception as exc: + logger.warning("compose monitor fallback row failed: %s", exc) + + if not deduped and ctp_status(mode).get("connected"): + for r in conn.execute( + "SELECT * FROM trade_order_monitors WHERE status='active' ORDER BY id DESC" + ).fetchall(): + mon = dict(r) + lots = int(mon.get("lots") or 0) + if lots <= 0: + continue + sym = (mon.get("symbol") or "").strip() + direction = (mon.get("direction") or "long").strip().lower() + rk = _monitor_position_key(mon) + if rk in seen: + continue + if fast: + mon = _overlay_sl_tp_readonly(conn, mon, sym, direction) or mon + try: + row = _compose_position_row( + conn, + mon=mon, + ctp=None, + mode=mode, + capital=capital, + now_iso=now_iso, + fast=fast, + ) + if not row: + continue + row_key = row.get("key") or row.get("position_key") or rk + if row_key in seen: + continue + seen.add(row_key) + deduped.append(row) + except Exception as exc: + logger.warning("compose active monitor row failed: %s", exc) + + return deduped + + def _build_trading_live_payload(conn, *, fast: bool = False) -> dict: + from zoneinfo import ZoneInfo + tz = ZoneInfo("Asia/Shanghai") + now_iso = datetime.now(tz).strftime("%Y-%m-%dT%H:%M") + mode = get_trading_mode(get_setting) + ctp_st = ctp_status(mode) + _remember_ctp_status(mode, ctp_st) + capital = _capital(conn) + if ctp_st.get("connected") and not fast: + _reconcile_pending(conn, mode, capital=capital) + if ctp_st.get("connected"): + if not fast: + _ensure_monitors_from_ctp(conn, mode) + _sync_trade_monitors_with_ctp(conn, mode) + elif count_active_trade_monitors(conn) == 0: + margin_raw = ctp_account_margin_used(mode) + if margin_raw is not None and float(margin_raw) > 0: + _ensure_monitors_from_sticky_state(conn, mode) + if not fast: + _close_stale_roll_groups(conn) + rows = _build_trading_live_rows(conn, fast=fast) + active_orders = _build_active_orders( + conn, mode=mode, capital=capital, now_iso=now_iso, + ) + rows = _apply_account_margin_to_rows(rows, mode, capital) + if not fast: + _persist_ctp_snapshot_to_monitors(conn, rows, mode) + pending_orders = _build_pending_orders(conn, mode) + risk = get_risk_status( + conn, + active_count=_effective_active_position_count(conn, mode), + equity=capital, + ) + margin_used = ( + ctp_account_margin_used(mode) if ctp_st.get("connected") else None + ) + display_sync_state = "ready" if rows else trading_state.sync_state + display_sync_label = "已同步" if rows else trading_state.sync_label() + return { + "ok": True, + "rows": rows, + "active_orders": active_orders, + "pending_orders": pending_orders, + "capital": capital, + "margin_used": margin_used, + "ctp_status": ctp_st, + "trading_mode_label": trading_mode_label(get_setting), + "risk_status": risk, + "trading_session": is_trading_session(), + "night_session": is_night_trading_session(), + "session_clock": trading_session_clock(), + "pending_order_timeout_min": get_pending_order_timeout_min(get_setting), + "sync_state": display_sync_state, + "sync_label": display_sync_label, + } + + def _minimal_live_payload(conn) -> dict: + """零 IPC 兜底:仅读库 + 缓存 CTP 状态,持仓由后台 worker 补全。""" + mode = get_trading_mode(get_setting) + ctp_st = _cached_ctp_status(mode) + capital = _capital(conn) + risk = get_risk_status( + conn, + active_count=count_active_trade_monitors(conn), + equity=capital, + ) + syncing = bool(ctp_st.get("connected") or ctp_st.get("connecting")) + return { + "ok": True, + "rows": [], + "active_orders": [], + "pending_orders": [], + "capital": capital, + "ctp_status": ctp_st, + "trading_mode_label": trading_mode_label(get_setting), + "risk_status": risk, + "trading_session": is_trading_session(), + "night_session": is_night_trading_session(), + "session_clock": trading_session_clock(), + "pending_order_timeout_min": get_pending_order_timeout_min(get_setting), + "sync_state": "syncing" if syncing else trading_state.sync_state, + "sync_label": "加载中…" if syncing else trading_state.sync_label(), + } + + def _normalize_live_payload(payload: dict) -> dict: + if payload.get("rows"): + payload = dict(payload) + payload["sync_state"] = "ready" + payload["sync_label"] = "已同步" + return payload + + def _refresh_trading_live_snapshot(*, fast: bool = False) -> dict: + def _build() -> dict: + mode = get_trading_mode(get_setting) + if ctp_status(mode).get("connected") and not fast: + try: + with _ctp_td_lock: + get_bridge().calibrate_trading_state() + except Exception as exc: + logger.debug("refresh calibrate: %s", exc) + for p in trading_state.get_positions() or _ctp_positions(mode, refresh_if_empty=False): + ths = _ctp_pos_to_ths_code(p) + if ths: + try: + get_bridge().subscribe_symbol(ths) + except Exception: + pass + conn = get_db() + try: + init_strategy_tables(conn) + if not fast: + ensure_monitor_order_columns(conn, migrate=True) + payload = _build_trading_live_payload(conn, fast=fast) + commit_retry(conn) + prev = position_hub.get_snapshot() + active_n = int((payload.get("risk_status") or {}).get("active_count") or 0) + if ( + prev + and ctp_status(mode).get("connected") + and not (payload.get("rows") or []) + and (prev.get("rows") or []) + ): + margin_raw = payload.get("margin_used") + if margin_raw is None: + margin_raw = ctp_account_margin_used(mode) + margin_used_val = float(margin_raw or 0) if margin_raw is not None else 0.0 + if ( + (margin_raw is not None and margin_used_val > 0) + or trading_state.sync_state == "syncing" + or active_n > 0 + ): + payload = dict(payload) + payload["rows"] = prev["rows"] + if trading_state.sync_state == "syncing": + payload["sync_state"] = "syncing" + payload["sync_label"] = "同步中…" + elif ( + ctp_status(mode).get("connected") + and not (payload.get("rows") or []) + and active_n > 0 + ): + payload = dict(payload) + payload["rows"] = _build_trading_live_rows(conn, fast=fast) + elif ctp_status(mode).get("connected") and not (payload.get("rows") or []): + since_connect = time.time() - float( + getattr(get_bridge(), "_last_connect_ok_ts", 0) or 0, + ) + if since_connect < 180: + payload = dict(payload) + payload["sync_state"] = "syncing" + payload["sync_label"] = "持仓同步中…" + return _normalize_live_payload(payload) + finally: + conn.close() + + if fast: + snap = position_hub.get_snapshot() + if snap: + return snap + if _live_refresh_lock.acquire(blocking=False): + try: + return _build() + finally: + _live_refresh_lock.release() + conn = get_db() + try: + init_strategy_tables(conn) + return _minimal_live_payload(conn) + finally: + conn.close() + with _live_refresh_lock: + return _build() + + def _push_position_snapshot_async(*, fast: bool = True) -> None: + def _run() -> None: + try: + payload = _refresh_trading_live_snapshot(fast=fast) + position_hub.broadcast("positions", payload) + conn = get_db() + try: + rec = _recommend_payload(conn) + recommend_hub.broadcast("recommend", {"ok": True, **rec}) + finally: + conn.close() + except Exception as exc: + logger.debug("push position snapshot: %s", exc) + + threading.Thread(target=_run, daemon=True).start() + + def _build_position_quotes_payload(mode: str) -> dict: + """轻量现价/浮盈(仅读 tick 缓存,不走 SQLite)。""" + if not ctp_status(mode).get("connected"): + return {"ok": True, "quotes": []} + from modules.core.contract_specs import get_contract_spec + + positions = trading_state.get_positions() + if not positions: + positions = _ctp_positions(mode, refresh_if_empty=False) + quotes: list[dict] = [] + for p in positions: + lots = int(p.get("lots") or 0) + if lots <= 0: + continue + ths = _ctp_pos_to_ths_code(p) or (p.get("symbol") or "") + if not ths: + continue + direction = (p.get("direction") or "long").strip().lower() + mark = ctp_get_tick_price(mode, ths) + if not mark or mark <= 0: + continue + entry, _ = _resolve_ctp_entry_price( + mode, ths, direction, p, + ) + if entry <= 0: + continue + mult = float(get_contract_spec(ths).get("mult") or 10) + ctp_pnl = float(p.get("pnl") or 0) + if ctp_pnl != 0: + float_pnl = round(ctp_pnl, 2) + elif direction == "long": + float_pnl = round((mark - entry) * mult * lots, 2) + else: + float_pnl = round((entry - mark) * mult * lots, 2) + row_key = _canonical_position_key( + ths, direction, (p.get("exchange") or ""), + ) + quotes.append({ + "key": row_key, + "position_key": row_key, + "mark_price": mark, + "current_price": mark, + "float_pnl": float_pnl, + }) + return {"ok": True, "quotes": quotes} + + def _push_position_quotes_async() -> None: + def _run() -> None: + try: + if not is_trading_session(): + return + mode = get_trading_mode(get_setting) + if trading_state.try_lock_entry_prices(): + _push_position_snapshot_async(fast=False) + return + payload = _build_position_quotes_payload(mode) + if payload.get("quotes"): + position_hub.push_event("position_quotes", payload) + except Exception as exc: + logger.debug("push position quotes: %s", exc) + + threading.Thread(target=_run, daemon=True, name="position-quotes").start() + + def _on_tick_sl_tp(exchange: str, symbol: str, price: float) -> None: + from modules.trading.sl_tp_guard import check_sl_tp_on_tick + from modules.core.db_conn import DB_PATH, connect_db + + mode = get_trading_mode(get_setting) + if not ctp_status(mode).get("connected"): + return + conn = connect_db(DB_PATH) + try: + _init_tables(conn) + capital = _capital(conn) + n = check_sl_tp_on_tick( + conn, mode, exchange, symbol, price, + capital=capital, notify_fn=send_wechat_msg, + be_tick_mult=get_trailing_be_tick_buffer(get_setting), + ) + if n: + conn.commit() + _push_position_snapshot_async(fast=True) + except Exception as exc: + logger.debug("tick sl/tp: %s", exc) + finally: + conn.close() + + def _prime_position_snapshot() -> None: + """进程启动同步预热:优先写入持仓/权益快照,页面打开即可读。""" + try: + payload = _refresh_trading_live_snapshot(fast=True) + position_hub.set_snapshot(payload) + n = len(payload.get("rows") or []) + logger.info( + "持仓快照已预热 capital=%s rows=%d", + payload.get("capital"), + n, + ) + except Exception as exc: + logger.warning("prime position snapshot: %s", exc) + + def _bootstrap_trading_runtime() -> None: + """进程启动:并发预热持仓快照 + CTP 连接,不阻塞 HTTP 监听。""" + set_position_refresh_callback( + lambda: _push_position_snapshot_async(fast=True) + ) + set_tick_quote_callback(_push_position_quotes_async) + set_tick_sl_tp_callback(_on_tick_sl_tp) + set_ctp_connected_callback(_on_ctp_connected) + + def _warm() -> None: + try: + payload = _refresh_trading_live_snapshot(fast=True) + position_hub.set_snapshot(payload) + position_hub.broadcast("positions", payload) + mode = get_trading_mode(get_setting) + if ctp_status(mode).get("connected"): + try: + with _ctp_td_lock: + get_bridge().calibrate_trading_state() + get_bridge().request_position_snapshot(force=True) + except Exception as exc: + logger.debug("bootstrap calibrate: %s", exc) + payload = _refresh_trading_live_snapshot(fast=True) + position_hub.set_snapshot(payload) + position_hub.broadcast("positions", payload) + + def _slow_sync() -> None: + time.sleep(20) + try: + pl = _refresh_trading_live_snapshot(fast=False) + position_hub.set_snapshot(pl) + position_hub.broadcast("positions", pl) + except Exception as exc: + logger.warning("bootstrap slow sync: %s", exc) + + threading.Thread(target=_slow_sync, daemon=True, name="boot-slow-sync").start() + except Exception as exc: + logger.warning("bootstrap position snapshot: %s", exc) + + def _start_ctp() -> None: + try: + from modules.ctp.ctp_premarket_connect import should_auto_connect_now + from modules.ctp.vnpy_bridge import ctp_start_connect + + if should_auto_connect_now(): + mode = get_trading_mode(get_setting) + ctp_start_connect(mode, force=False, scheduled=True) + except Exception as exc: + logger.debug("bootstrap ctp connect: %s", exc) + + from concurrent.futures import ThreadPoolExecutor + + workers = max(2, int(os.getenv("QIHUO_STARTUP_WORKERS", "8") or 8)) + with ThreadPoolExecutor(max_workers=min(workers, 4), thread_name_prefix="boot") as pool: + pool.submit(_warm) + pool.submit(_start_ctp) + + def _on_ctp_connected(mode: str) -> None: + if mode != get_trading_mode(get_setting): + return + _schedule_recommend_refresh() + _push_position_snapshot_async(fast=True) + + def _after_connect() -> None: + try: + try: + with _ctp_td_lock: + get_bridge().request_position_snapshot(force=True) + get_bridge().calibrate_trading_state() + except Exception as exc: + logger.debug("ctp connected calibrate: %s", exc) + _push_position_snapshot_async(fast=True) + conn = get_db() + try: + init_strategy_tables(conn) + _ensure_monitors_from_ctp(conn, mode) + commit_retry(conn) + finally: + conn.close() + _push_position_snapshot_async(fast=False) + except Exception as exc: + logger.debug("ctp connected monitor restore: %s", exc) + + threading.Thread(target=_after_connect, daemon=True, name="ctp-monitor-restore").start() + + @app.route("/trade") + @login_required + def trade_page(): + return redirect(url_for("positions")) + + @app.route("/positions") + @login_required + def positions(): + conn = get_db() + try: + init_strategy_tables(conn) + mode = get_trading_mode(get_setting) + ctp_st = _cached_ctp_status(mode) + connected = bool(ctp_st.get("connected")) + capital = _capital(conn) + recommend_capital = _recommend_capital(conn) + risk = get_risk_status( + conn, + active_count=_effective_active_position_count( + conn, mode, ctp_connected=connected, + ), + equity=capital, + ) + ctp_acc = {} + bootstrap_live = position_hub.get_snapshot() + if connected and bootstrap_live and bootstrap_live.get("capital") is not None: + cap = float(bootstrap_live.get("capital") or 0) + ctp_acc = {"balance": cap, "available": cap} + active_trend = conn.execute( + "SELECT * FROM trend_pullback_plans WHERE status='active' ORDER BY id DESC LIMIT 1" + ).fetchone() + monitor_count = conn.execute( + "SELECT COUNT(*) AS n FROM trade_order_monitors WHERE status='active'" + ).fetchone()["n"] + roll_count = conn.execute( + "SELECT COUNT(*) AS n FROM roll_groups WHERE status='active'" + ).fetchone()["n"] + conn.commit() + sizing = get_sizing_mode(get_setting) + max_pct = get_max_margin_pct(get_setting) + rec_cache = _recommend_payload(conn, use_ctp_margin=False) + if rec_cache.get("needs_refresh"): + _schedule_recommend_refresh() + ctp_connected = connected + margin_rec = small_account_margin_recommendations() + if not bootstrap_live: + bootstrap_live = { + "ok": True, + "rows": [], + "active_orders": [], + "pending_orders": [], + "capital": capital, + "ctp_status": dict(ctp_st), + "risk_status": risk, + "trading_session": is_trading_session(), + "night_session": is_night_trading_session(), + "session_clock": trading_session_clock(), + "sync_state": trading_state.sync_state, + "sync_label": trading_state.sync_label(), + } + else: + bootstrap_live = dict(bootstrap_live) + bootstrap_live.setdefault("capital", capital) + bootstrap_live.setdefault("risk_status", risk) + bootstrap_live["ctp_status"] = dict(ctp_st) + return render_template( + "trade.html", + trading_mode=mode, + trading_mode_label=trading_mode_label(get_setting), + capital=capital, + recommend_capital=recommend_capital, + risk_status=risk, + ctp_status=ctp_st, + ctp_account=ctp_acc, + active_trend=dict(active_trend) if active_trend else None, + monitor_count=monitor_count, + roll_count=roll_count, + sizing_mode=sizing, + sizing_mode_label=_sizing_mode_label(sizing), + fixed_lots=get_fixed_lots(get_setting), + fixed_amount=get_fixed_amount(get_setting), + risk_percent=get_risk_percent(get_setting), + max_margin_pct=get_max_margin_pct(get_setting), + pending_order_timeout_min=get_pending_order_timeout_min(get_setting), + ctp_auto_connect=is_ctp_auto_connect_enabled(get_setting), + recommend_rows=rec_cache.get("rows") or [], + recommend_updated_at=rec_cache.get("updated_at"), + night_session=is_night_trading_session(), + small_account_scope=should_apply_small_account_scope( + capital, ctp_connected=ctp_connected, + ), + small_account_scope_hint=small_account_scope_hint(ctp_connected=ctp_connected), + small_account_margin_rec=margin_rec if should_apply_small_account_scope( + capital, ctp_connected=ctp_connected, + ) else None, + session_clock=trading_session_clock(), + roll_max_margin_pct=get_roll_max_margin_pct(get_setting), + product_categories=PRODUCT_CATEGORIES, + bootstrap_live=bootstrap_live, + ) + finally: + conn.close() + + @app.route("/recommend") + @login_required + def recommend_page(): + return redirect(url_for("positions") + "#recommend") + + @app.route("/api/trading/live") + @login_required + def api_trading_live(): + snap = position_hub.get_snapshot() + if snap: + return jsonify(_normalize_live_payload(snap)) + payload = _refresh_trading_live_snapshot(fast=True) + payload = _normalize_live_payload(payload) + position_hub.set_snapshot(payload) + return jsonify(payload) + + @app.route("/api/trading/stream") + @login_required + def api_trading_stream(): + from queue import Empty + + @stream_with_context + def generate(): + yield ": stream\n\n" + q = position_hub.subscribe() + try: + snap = position_hub.get_snapshot() + if not snap: + conn = get_db() + try: + init_strategy_tables(conn) + payload = _minimal_live_payload(conn) + finally: + conn.close() + position_hub.set_snapshot(payload) + yield sse_format("positions", payload) + _push_position_snapshot_async(fast=True) + else: + yield sse_format("positions", snap) + while True: + try: + msg = q.get(timeout=25) + yield sse_format(msg["event"], msg["data"]) + except Empty: + yield ": heartbeat\n\n" + finally: + position_hub.unsubscribe(q) + + return Response( + generate(), + mimetype="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "X-Accel-Buffering": "no", + }, + ) + + @app.route("/api/trading/monitor/upsert", methods=["POST"]) + @login_required + def api_trading_monitor_upsert(): + """为已有持仓补充/更新本地止盈止损监控。""" + d = request.get_json(silent=True) or {} + sym = (d.get("symbol_code") or d.get("symbol") or "").strip() + direction = (d.get("direction") or "long").strip().lower() + try: + lots = max(1, int(d.get("lots") or 1)) + entry = float(d.get("entry_price") or d.get("entry") or 0) + sl = float(d["stop_loss"]) if d.get("stop_loss") not in (None, "") else None + tp = float(d["take_profit"]) if d.get("take_profit") not in (None, "") else None + except (TypeError, ValueError, KeyError): + return jsonify({"ok": False, "error": "参数无效"}), 400 + if not sym: + return jsonify({"ok": False, "error": "缺少品种代码"}), 400 + if sl is None and tp is None: + return jsonify({"ok": False, "error": "请至少填写止损或止盈"}), 400 + trailing_on = bool(d.get("trailing_be")) + if trailing_on and sl is None: + return jsonify({"ok": False, "error": "移动保本须填写止损价"}), 400 + if trailing_on: + tp = None + mode = get_trading_mode(get_setting) + conn = get_db() + try: + init_strategy_tables(conn) + mon = _find_active_monitor(conn, sym, direction) + has_pos = bool(mon) + ths_sym = sym + if ctp_status(mode).get("connected"): + for p in _ctp_positions(mode, refresh_if_empty=False): + if int(p.get("lots") or 0) <= 0: + continue + if (p.get("direction") or "long") != direction: + continue + if _match_ctp_symbol(p.get("symbol") or "", sym): + has_pos = True + lots = int(p.get("lots") or lots) + entry = float(p.get("avg_price") or entry or 0) + ths_sym = _ctp_pos_to_ths_code(p) or sym + break + if not has_pos: + return jsonify({"ok": False, "error": "未找到对应持仓"}), 400 + trailing_be = 1 if trailing_on else ( + int(mon.get("trailing_be") or 0) if mon else 0 + ) + mid = _upsert_open_monitor( + conn, + sym=ths_sym, + direction=direction, + lots=lots, + price=entry, + sl=sl, + tp=tp, + trailing_be=trailing_be, + ) + if trailing_on and sl is not None: + conn.execute( + """UPDATE trade_order_monitors SET + take_profit=NULL, initial_stop_loss=?, trailing_r_locked=0 + WHERE id=?""", + (sl, mid), + ) + conn.commit() + _push_position_snapshot_async(fast=False) + return jsonify({ + "ok": True, + "monitor_id": mid, + "message": "止盈止损已保存,程序本地监控", + }) + finally: + conn.close() + + @app.route("/api/trading/monitor/place-orders", methods=["POST"]) + @login_required + def api_trading_monitor_place_orders(): + """本地监控模式:清理旧版柜台挂单,不再向交易所挂止盈止损。""" + d = request.get_json(silent=True) or {} + try: + monitor_id = int(d.get("monitor_id") or 0) + except (TypeError, ValueError): + monitor_id = 0 + conn = get_db() + try: + init_strategy_tables(conn) + ensure_monitor_order_columns(conn) + mode = get_trading_mode(get_setting) + if not ctp_status(mode).get("connected"): + return jsonify({"ok": False, "error": "请先连接 CTP"}), 400 + mon = None + if monitor_id > 0: + row = conn.execute( + "SELECT * FROM trade_order_monitors WHERE id=? AND status='active'", + (monitor_id,), + ).fetchone() + mon = dict(row) if row else None + if not mon: + sym = (d.get("symbol_code") or "").strip() + direction = (d.get("direction") or "long").strip().lower() + for r in conn.execute( + "SELECT * FROM trade_order_monitors WHERE status='active'" + ).fetchall(): + row = dict(r) + if row.get("direction") != direction: + continue + if _match_ctp_symbol(sym, row.get("symbol") or ""): + mon = row + break + if not mon: + return jsonify({"ok": False, "error": "未找到有效监控快照"}), 404 + result = place_monitor_exit_orders( + conn, mon, mode=mode, force=bool(d.get("force")), + ) + if not result.get("ok"): + return jsonify(result), 400 + return jsonify(result) + finally: + conn.close() + + @app.route("/api/trading/monitor/dismiss", methods=["POST"]) + @login_required + def api_trading_monitor_dismiss(): + d = request.get_json(silent=True) or {} + try: + monitor_id = int(d.get("monitor_id") or 0) + except (TypeError, ValueError): + monitor_id = 0 + if monitor_id <= 0: + return jsonify({"ok": False, "error": "无效的监控记录"}), 400 + conn = get_db() + try: + init_strategy_tables(conn) + mode = get_trading_mode(get_setting) + row = conn.execute( + "SELECT * FROM trade_order_monitors WHERE id=? AND status IN ('active', 'pending')", + (monitor_id,), + ).fetchone() + if not row: + return jsonify({"ok": False, "error": "记录不存在或已关闭"}), 404 + mon = dict(row) + if (mon.get("status") or "").strip().lower() == "pending": + if not is_trading_session(): + return jsonify({"ok": False, "error": "不在交易时间段,无法撤单"}), 403 + ok, msg = cancel_pending_monitor(conn, mon, mode) + _push_position_snapshot_async(fast=False) + return jsonify({"ok": ok, "message": msg}) + conn.execute( + "UPDATE trade_order_monitors SET status='closed' WHERE id=?", + (monitor_id,), + ) + conn.commit() + _push_position_snapshot_async(fast=False) + return jsonify({"ok": True, "message": "已取消本地止盈止损监控"}) + finally: + conn.close() + + @app.route("/api/trading/monitor/cancel-open", methods=["POST"]) + @login_required + def api_trading_monitor_cancel_open(): + """撤销 pending 开仓委托(柜台撤单 + 关闭本地记录)。""" + d = request.get_json(silent=True) or {} + try: + monitor_id = int(d.get("monitor_id") or 0) + except (TypeError, ValueError): + monitor_id = 0 + if monitor_id <= 0: + return jsonify({"ok": False, "error": "无效的委托记录"}), 400 + conn = get_db() + try: + init_strategy_tables(conn) + mode = get_trading_mode(get_setting) + if not ctp_status(mode).get("connected"): + return jsonify({"ok": False, "error": "请先连接 CTP"}), 400 + if not is_trading_session(): + return jsonify({"ok": False, "error": "不在交易时间段,无法撤单"}), 403 + row = conn.execute( + "SELECT * FROM trade_order_monitors WHERE id=? AND status='pending'", + (monitor_id,), + ).fetchone() + if not row: + return jsonify({"ok": False, "error": "未找到挂单中的开仓委托"}), 404 + ok, msg = cancel_pending_monitor(conn, dict(row), mode) + _push_position_snapshot_async(fast=False) + return jsonify({"ok": ok, "message": msg}) + finally: + conn.close() + + @app.route("/api/trading/order/cancel", methods=["POST"]) + @login_required + def api_trading_order_cancel(): + """撤销柜台未成交委托(按 vt_order_id)。""" + d = request.get_json(silent=True) or {} + order_id = (d.get("order_id") or "").strip() + if not order_id: + return jsonify({"ok": False, "error": "无效的委托号"}), 400 + mode = get_trading_mode(get_setting) + if not ctp_status(mode).get("connected"): + return jsonify({"ok": False, "error": "请先连接 CTP"}), 400 + if not is_trading_session(): + return jsonify({"ok": False, "error": "不在交易时间段,无法撤单"}), 403 + ok = ctp_cancel_order(mode, order_id) + _push_position_snapshot_async(fast=False) + if not ok: + return jsonify({"ok": False, "error": "撤单失败,委托可能已成交或已撤销"}), 400 + return jsonify({"ok": True, "message": "撤单已提交"}) + + @app.route("/api/trading/close", methods=["POST"]) + @login_required + def api_trading_close(): + d = request.get_json(silent=True) or {} + source = (d.get("source") or "").strip() + conn = get_db() + init_strategy_tables(conn) + mode = get_trading_mode(get_setting) + if not ctp_status(mode).get("connected") and source in ("ctp", "program"): + conn.close() + return jsonify({"ok": False, "error": "请先连接 CTP"}), 400 + sym = (d.get("symbol_code") or d.get("symbol") or "").strip() + direction = (d.get("direction") or "long").strip().lower() + try: + lots = max(1, int(d.get("lots") or 1)) + price = float(d.get("price") or 0) + except (TypeError, ValueError): + conn.close() + return jsonify({"ok": False, "error": "参数无效"}), 400 + if not sym or price <= 0: + conn.close() + return jsonify({"ok": False, "error": "品种或价格无效"}), 400 + offset = "close_long" if direction == "long" else "close_short" + capital = _capital(conn) + mon = None + mid = int(d.get("monitor_id") or 0) + if mid: + row = conn.execute( + "SELECT * FROM trade_order_monitors WHERE id=? AND status='active'", + (mid,), + ).fetchone() + if row: + mon = dict(row) + if not mon: + for r in conn.execute( + "SELECT * FROM trade_order_monitors WHERE status='active'" + ).fetchall(): + row = dict(r) + if row.get("direction") != direction: + continue + if _match_ctp_symbol(sym, row.get("symbol") or ""): + mon = row + mid = int(row["id"]) + break + entry = float(mon.get("entry_price") or 0) if mon else 0.0 + if entry <= 0: + for p in _ctp_positions(mode): + if int(p.get("lots") or 0) <= 0: + continue + if (p.get("direction") or "long") != direction: + continue + if _match_ctp_symbol(p.get("symbol") or "", sym): + entry = float(p.get("avg_price") or price) + break + try: + execute_order( + conn, mode=mode, offset=offset, symbol=sym, direction=direction, + lots=lots, price=price, settings=_settings_dict(), + order_type="market", + ) + # 始终写本地记录:CTP 同步依赖内存开平配对,重启后或成交回报延迟时会漏记 + write_manual_close_trade_log( + conn, + mon, + symbol=sym, + direction=direction, + lots=lots, + close_price=price, + entry_price=entry or price, + trading_mode=mode, + capital=capital, + stop_loss=float(mon["stop_loss"]) if mon and mon.get("stop_loss") is not None else None, + take_profit=float(mon["take_profit"]) if mon and mon.get("take_profit") is not None else None, + open_time=(mon.get("open_time") or "") if mon else "", + symbol_name=(mon.get("symbol_name") or "") if mon else "", + market_code=(mon.get("market_code") or "") if mon else "", + ) + _close_all_monitors_for_sym_dir(conn, sym, direction) + conn.commit() + try: + from modules.ctp.ctp_trade_sync import sync_trade_logs_from_ctp + sync_trade_logs_from_ctp(conn, mode, capital=capital, trading_mode=mode) + conn.commit() + except Exception as exc: + logger.debug("sync trades after close: %s", exc) + conn.close() + _push_position_snapshot_async() + return jsonify({"ok": True, "message": "已平仓,交易记录已写入"}) + except ValueError as exc: + conn.close() + return jsonify({"ok": False, "error": str(exc)}), 400 + + + def _roll_ui_modes(): + return frozenset({ADD_MODE_MARKET, ADD_MODE_BREAKOUT}) + + def _roll_filled_lots_map(conn, group_ids: list[int]) -> dict[int, int]: + if not group_ids: + return {} + placeholders = ",".join("?" * len(group_ids)) + rows = conn.execute( + f"""SELECT roll_group_id, COALESCE(SUM(lots), 0) AS n + FROM roll_legs + WHERE roll_group_id IN ({placeholders}) AND status=? + GROUP BY roll_group_id""", + (*group_ids, LEG_STATUS_FILLED), + ).fetchall() + return {int(r["roll_group_id"]): int(r["n"] or 0) for r in rows} + + def _build_roll_context(conn) -> dict: + has_trend = bool(conn.execute( + "SELECT 1 FROM trend_pullback_plans WHERE status='active' LIMIT 1", + ).fetchone()) + groups_by_monitor: dict[int, dict] = {} + pending_monitors: set[int] = set() + for row in conn.execute( + "SELECT * FROM roll_groups WHERE status='active'", + ).fetchall(): + g = dict(row) + mid = int(g.get("order_monitor_id") or 0) + if mid: + groups_by_monitor[mid] = g + for row in conn.execute( + """SELECT g.order_monitor_id + FROM roll_legs l + JOIN roll_groups g ON g.id = l.roll_group_id + WHERE l.status=? AND g.status='active'""", + (LEG_STATUS_PENDING,), + ).fetchall(): + mid = int(row["order_monitor_id"] or 0) + if mid: + pending_monitors.add(mid) + return { + "has_trend": has_trend, + "groups_by_monitor": groups_by_monitor, + "pending_monitors": pending_monitors, + } + + def _roll_eligibility_with_ctx(conn, mon: dict, ctx: dict) -> Optional[str]: + mid = int(mon["id"]) + grp = ctx["groups_by_monitor"].get(mid) + legs_done = int(grp.get("leg_count") or 0) if grp else 0 + return roll_eligibility_error( + sizing_mode=get_sizing_mode(get_setting), + monitor=mon, + has_active_trend=ctx["has_trend"], + legs_done=legs_done, + has_pending_leg=mid in ctx["pending_monitors"], + ) + + def _enrich_roll_group_row_fast(row: dict, filled_map: dict[int, int]) -> dict: + out = dict(row) + lots = float(out.get("mon_lots") or 0) + entry = float(out.get("mon_entry") or 0) + tp = float(out.get("mon_tp") or out.get("initial_take_profit") or 0) + direction = (out.get("direction") or "long").strip().lower() + sym = (out.get("symbol") or "").strip() + mult = int(get_contract_spec(sym).get("mult") or 1) if sym else 1 + gid = int(out.get("id") or 0) + filled_add_lots = int(filled_map.get(gid) or 0) + out["add_lots_filled"] = filled_add_lots + out["first_lots"] = max(0, int(lots) - filled_add_lots) + out["total_lots"] = int(lots) + out["avg_entry"] = round(entry, 4) if entry > 0 else None + if lots > 0 and entry > 0 and tp > 0: + if direction == "long": + out["reward_at_tp"] = round((tp - entry) * lots * mult, 2) + else: + out["reward_at_tp"] = round((entry - tp) * lots * mult, 2) + else: + out["reward_at_tp"] = None + return out + + def _enrich_roll_group_row(conn, row: dict) -> dict: + gid = int(row.get("id") or 0) + filled_map = _roll_filled_lots_map(conn, [gid]) if gid > 0 else {} + return _enrich_roll_group_row_fast(row, filled_map) + + def _archive_roll_group( + conn, + grp: dict, + *, + result_label: str = "持仓已结束", + ) -> None: + from zoneinfo import ZoneInfo + + now_s = datetime.now(ZoneInfo("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S") + gid = int(grp.get("id") or 0) + if gid <= 0: + return + if conn.execute( + "SELECT 1 FROM strategy_trade_snapshots WHERE strategy_type=? AND source_id=? LIMIT 1", + (STRATEGY_ROLL, gid), + ).fetchone(): + conn.execute( + "UPDATE roll_groups SET status='closed', updated_at=? WHERE id=?", + (now_s, gid), + ) + return + legs = [ + dict(r) for r in conn.execute( + "SELECT * FROM roll_legs WHERE roll_group_id=? ORDER BY id", + (gid,), + ).fetchall() + ] + mon = None + mid = int(grp.get("order_monitor_id") or 0) + if mid: + row = conn.execute( + "SELECT * FROM trade_order_monitors WHERE id=?", + (mid,), + ).fetchone() + mon = dict(row) if row else None + payload = { + "group": dict(grp), + "legs": legs, + "monitor": mon, + } + save_snapshot( + conn, + strategy_type=STRATEGY_ROLL, + source_id=gid, + symbol=grp.get("symbol") or (mon or {}).get("symbol") or "", + direction=grp.get("direction") or (mon or {}).get("direction") or "", + result_label=result_label, + payload=payload, + opened_at=grp.get("created_at") or "", + ) + conn.execute( + "UPDATE roll_legs SET status=? WHERE roll_group_id=? AND status=?", + (LEG_STATUS_CANCELLED, gid, LEG_STATUS_PENDING), + ) + conn.execute( + "UPDATE roll_groups SET status='closed', updated_at=? WHERE id=?", + (now_s, gid), + ) + + def _close_stale_roll_groups(conn) -> int: + rows = conn.execute( + """SELECT g.*, m.status AS monitor_status + FROM roll_groups g + LEFT JOIN trade_order_monitors m ON m.id = g.order_monitor_id + WHERE g.status='active' + AND (m.id IS NULL OR m.status != 'active')""" + ).fetchall() + for r in rows: + _archive_roll_group(conn, dict(r), result_label="持仓已结束") + return len(rows) + + def _enrich_roll_leg_row(row: dict, mode: str) -> dict: + out = dict(row) + sym = (out.get("symbol") or "").strip() + mark = _cached_position_mark(sym, out.get("direction") or "") if sym else None + out["current_price"] = round(float(mark), 4) if mark and mark > 0 else None + return out + + def _enrich_roll_record_row(conn, row: dict) -> dict: + out = dict(row) + snap = out.get("snapshot") or {} + group = snap.get("group") or {} + legs = snap.get("legs") or [] + monitor = snap.get("monitor") or {} + filled_legs = [ + l for l in legs + if (l.get("status") or "").strip().lower() == LEG_STATUS_FILLED + ] + add_lots = sum(int(l.get("lots") or 0) for l in filled_legs) + total_lots = int((monitor or {}).get("lots") or 0) + first_lots = max(0, total_lots - add_lots) + latest_sl = ( + group.get("current_stop_loss") + or (monitor or {}).get("stop_loss") + or None + ) + close_log = None + try: + close_log = conn.execute( + """SELECT close_price, pnl, pnl_net, close_time, lots + FROM trade_logs + WHERE lower(symbol)=lower(?) AND direction=? + ORDER BY close_time DESC, id DESC LIMIT 1""", + (out.get("symbol") or "", out.get("direction") or ""), + ).fetchone() + except Exception: + close_log = None + close_d = dict(close_log) if close_log else {} + out["detail"] = { + "first_lots": first_lots if first_lots > 0 else None, + "add_count": len(filled_legs), + "add_lots": add_lots, + "total_lots": total_lots if total_lots > 0 else None, + "latest_stop_loss": latest_sl, + "close_price": close_d.get("close_price"), + "close_time": close_d.get("close_time") or out.get("closed_at"), + "pnl": close_d.get("pnl_net") if close_d.get("pnl_net") is not None else close_d.get("pnl"), + "legs": filled_legs, + "monitor": monitor, + "group": group, + } + return out + + def _roll_leg_trigger_price(leg: dict): + for key in ("breakthrough_price", "limit_price", "fill_price"): + val = leg.get(key) + if val not in (None, "", 0): + return val + return None + + @app.route("/strategy") + @login_required + @_nav("strategy") + def strategy_page(): + conn = get_db() + try: + init_strategy_tables(conn) + ensure_monitor_order_columns(conn) + capital = _capital(conn) + active_trend = conn.execute( + "SELECT * FROM trend_pullback_plans WHERE status='active' ORDER BY id DESC LIMIT 1" + ).fetchone() + monitors_raw = conn.execute( + "SELECT * FROM trade_order_monitors WHERE status='active' ORDER BY id DESC" + ).fetchall() + mode = get_trading_mode(get_setting) + roll_ctx = _build_roll_context(conn) + roll_groups = conn.execute( + """SELECT g.*, m.symbol_name, m.lots AS mon_lots, m.entry_price AS mon_entry, + m.take_profit AS mon_tp + FROM roll_groups g + LEFT JOIN trade_order_monitors m ON m.id = g.order_monitor_id + WHERE g.status='active' ORDER BY g.id DESC""" + ).fetchall() + roll_legs = conn.execute( + """SELECT l.*, g.symbol, g.direction, g.order_monitor_id + FROM roll_legs l + JOIN roll_groups g ON g.id = l.roll_group_id + WHERE l.status=? AND g.status='active' + ORDER BY l.id DESC LIMIT 30""", + (LEG_STATUS_PENDING,), + ).fetchall() + sizing = get_sizing_mode(get_setting) + roll_allowed = sizing == MODE_AMOUNT + monitors = [] + for m in monitors_raw: + row = dict(m) + err = _roll_eligibility_with_ctx(conn, row, roll_ctx) + row["roll_eligible"] = roll_allowed and err is None + if not roll_allowed: + row["roll_block_reason"] = "仅固定金额(以损定仓)模式可滚仓" + else: + row["roll_block_reason"] = err or "" + monitors.append(row) + active_trend_row = dict(active_trend) if active_trend else None + if active_trend_row: + active_trend_row["period_label"] = trend_period_label( + active_trend_row.get("period") or "15m", + ) + group_ids = [int(g["id"]) for g in roll_groups if g["id"]] + filled_map = _roll_filled_lots_map(conn, group_ids) + enriched_groups = [ + _enrich_roll_group_row_fast(dict(g), filled_map) for g in roll_groups + ] + enriched_legs = [_enrich_roll_leg_row(dict(l), mode) for l in roll_legs] + return render_template( + "strategy.html", + capital=capital, + fixed_amount=get_fixed_amount(get_setting), + sizing_mode=sizing, + sizing_mode_label=_sizing_mode_label(sizing), + roll_allowed=roll_allowed, + active_trend=active_trend_row, + monitors=monitors, + roll_groups=enriched_groups, + roll_legs=enriched_legs, + trading_session=is_trading_session(), + session_clock=trading_session_clock(), + trend_periods=trend_strategy_periods(), + add_mode_labels={ + "market": "市价加仓", + "breakout": "突破加仓", + }, + roll_leg_status_labels={ + "pending": "监控中", + "filled": "已成交", + "cancelled": "已取消", + }, + ) + finally: + conn.close() + + @app.route("/strategy/records") + @login_required + def strategy_records_page(): + conn = get_db() + init_strategy_tables(conn) + trend, roll = list_snapshots(conn) + roll = [_enrich_roll_record_row(conn, r) for r in roll] + conn.close() + return render_template("strategy_records.html", trend_rows=trend, roll_rows=roll) + + @app.route("/api/trade/quote") + @login_required + def api_trade_quote(): + sym = (request.args.get("symbol") or "").strip() + lots = request.args.get("lots") or "1" + if not sym: + return jsonify({"ok": False, "error": "缺少品种"}), 400 + codes = ths_to_codes(sym) + price = fetch_price(sym, codes.get("market_code", "") if codes else "", codes.get("sina_code", "") if codes else "") + try: + lots_f = max(1, int(float(lots))) + except (TypeError, ValueError): + lots_f = 1 + mode = get_trading_mode(get_setting) + metrics = calc_order_tick_metrics(sym, lots_f, price, trading_mode=mode) + spec = get_contract_spec(sym) + name = codes.get("name", sym) if codes else sym + pos_long = pos_short = 0 + ctp_st = ctp_status(mode) + if ctp_st.get("connected"): + for p in _ctp_positions(mode): + if not _match_ctp_symbol(p.get("symbol", ""), sym): + continue + if p["direction"] == "long": + pos_long = int(p["lots"]) + else: + pos_short = int(p["lots"]) + max_open = int(_capital(get_db()) / (metrics["margin_per_lot"] or 1)) if metrics.get("margin_per_lot") else 0 + return jsonify({ + "ok": True, + "symbol": sym, + "name": name, + "price": price, + "lots": lots_f, + "metrics": metrics, + "exchange": codes.get("exchange", "") if codes else "", + "pos_long": pos_long, + "pos_short": pos_short, + "max_open_long": max_open, + "max_open_short": max_open, + "footer_text": ( + f"*{name} 每手{spec['mult']}吨/点 最小变动{metrics['tick_size']} " + f"每跳{metrics['tick_value_per_lot']}元/手×{lots_f}={metrics['tick_value_total']}元 " + f"精度{metrics['price_precision']}位小数" + ), + }) + + @app.route("/api/trade/preview", methods=["POST"]) + @login_required + def api_trade_preview(): + d = request.get_json(silent=True) or {} + sym = (d.get("symbol") or "").strip() + direction = (d.get("direction") or "long").strip().lower() + try: + entry = float(d.get("entry") or d.get("price") or 0) + sl = float(d.get("stop_loss") or 0) + tp = float(d.get("take_profit") or 0) + except (TypeError, ValueError): + return jsonify({"ok": False, "error": "价格参数无效"}), 400 + conn = get_db() + capital = _capital(conn) + conn.close() + sizing = get_sizing_mode(get_setting) + margin_pct = get_max_margin_pct(get_setting) + sizing_info = {} + if sizing == MODE_AMOUNT: + lots, err, sizing_info = calc_lots_by_amount( + entry, sl, direction, get_fixed_amount(get_setting), sym, + capital=capital, max_margin_pct=margin_pct, + trading_mode=get_trading_mode(get_setting), + ) + if err: + return jsonify({"ok": False, "error": err}), 400 + elif sizing == MODE_FIXED: + lots = get_fixed_lots(get_setting) + else: + try: + lots = max(1, int(d.get("lots") or 1)) + except (TypeError, ValueError): + lots = 1 + metrics = calc_position_metrics(direction, entry, sl, tp, lots, entry, capital, sym) + tick = calc_order_tick_metrics( + sym, lots, entry, direction=direction, trading_mode=get_trading_mode(get_setting), + ) + return jsonify({ + "ok": True, "lots": lots, "sizing_mode": sizing, + "metrics": metrics, "tick": tick, "capital": capital, + "sizing_info": sizing_info, + }) + + @app.route("/api/trade/order", methods=["POST"]) + @login_required + def api_trade_order(): + d = request.get_json(silent=True) or {} + sym = (d.get("symbol") or "").strip() + offset = (d.get("offset") or "open").strip().lower() + direction = (d.get("direction") or "long").strip().lower() + try: + lots = max(1, int(d.get("lots") or 1)) + price = float(d.get("price") or 0) + except (TypeError, ValueError): + return jsonify({"ok": False, "error": "手数或价格无效"}), 400 + order_type = (d.get("order_type") or d.get("price_type") or "limit").strip().lower() + if order_type == "market" and price <= 0: + codes = ths_to_codes(sym) + price = fetch_price( + sym, + codes.get("market_code", "") if codes else "", + codes.get("sina_code", "") if codes else "", + ) or 0 + if not sym or price <= 0: + return jsonify({"ok": False, "error": "品种或价格无效"}), 400 + conn = get_db() + init_strategy_tables(conn) + mode = get_trading_mode(get_setting) + if offset.startswith("open"): + _sync_trade_monitors_with_ctp(conn, mode) + if not is_trading_session(): + conn.close() + return jsonify({"ok": False, "error": "不在交易时间段"}), 403 + if d.get("trailing_be") and not d.get("stop_loss"): + conn.close() + return jsonify({"ok": False, "error": "开启移动保本须填写止损价"}), 400 + err = assert_can_open( + conn, + active_count=_effective_active_position_count(conn, mode), + equity=_capital(conn), + ) + if err: + conn.close() + return jsonify({"ok": False, "error": err}), 403 + scope_err = assert_product_allowed_for_capital( + sym, _capital(conn), ctp_connected=is_ctp_connected(get_setting), + ) + if scope_err: + conn.close() + return jsonify({"ok": False, "error": scope_err}), 403 + ctp_st = ctp_status(mode) + if not ctp_st.get("connected"): + conn.close() + if get_bridge().connect_in_progress(): + return jsonify({"ok": False, "error": "CTP 连接中,请稍候再下单"}), 400 + return jsonify({"ok": False, "error": "请先连接 CTP"}), 400 + sizing = get_sizing_mode(get_setting) + if offset.startswith("open") and sizing == MODE_AMOUNT: + sl = float(d.get("stop_loss") or 0) + if sl <= 0: + conn.close() + return jsonify({"ok": False, "error": "固定金额模式须填写止损价"}), 400 + lots_calc, err, _sizing_info = calc_lots_by_amount( + price, sl, direction, get_fixed_amount(get_setting), sym, + capital=_capital(conn), max_margin_pct=get_max_margin_pct(get_setting), + trading_mode=mode, + ) + if err: + conn.close() + return jsonify({"ok": False, "error": err}), 400 + lots = lots_calc or lots + elif offset.startswith("open") and sizing == MODE_FIXED: + lots = get_fixed_lots(get_setting) + margin_pct = get_max_margin_pct(get_setting) + usage = calc_margin_usage_pct( + _ctp_positions(mode), + _capital(conn), + extra_symbol=sym if offset.startswith("open") else "", + extra_lots=lots if offset.startswith("open") else 0, + extra_price=price if offset.startswith("open") else 0, + extra_direction=direction if offset.startswith("open") else "long", + trading_mode=mode, + ) + if offset.startswith("open") and usage > margin_pct: + conn.close() + return jsonify({ + "ok": False, + "error": f"保证金占用 {usage:.1f}% 超过上限 {margin_pct:g}%(可在系统设置修改)", + }), 403 + if lots > DEFAULT_MAX_ORDER_LOTS: + conn.close() + return jsonify({ + "ok": False, + "error": f"单笔手数 {lots} 超过上限 {DEFAULT_MAX_ORDER_LOTS},请加大止损距离或改固定手数", + }), 400 + try: + result = execute_order( + conn, + mode=mode, + offset=offset, + symbol=sym, + direction=direction, + lots=lots, + price=price, + settings=_settings_dict(), + order_type=order_type, + ) + if offset.startswith("open") and d.get("trailing_be") and not d.get("stop_loss"): + conn.close() + return jsonify({"ok": False, "error": "开启移动保本须填写止损价"}), 400 + if offset.startswith("open"): + from zoneinfo import ZoneInfo + sl = d.get("stop_loss") + trailing_be = 1 if d.get("trailing_be") else 0 + tp = None if trailing_be else d.get("take_profit") + open_ts = datetime.now(ZoneInfo("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S") + vt_order_id = str(result.get("order_id") or "") + mid = _upsert_open_monitor( + conn, + sym=sym, + direction=direction, + lots=lots, + price=price, + sl=sl, + tp=tp, + trailing_be=trailing_be, + open_time=open_ts, + monitor_type="manual", + status="pending", + vt_order_id=vt_order_id or None, + order_price=price, + ) + conn.commit() + try: + with _ctp_td_lock: + get_bridge().refresh_positions() + except Exception: + pass + _reconcile_pending(conn, mode, capital=_capital(conn)) + st_row = conn.execute( + "SELECT status FROM trade_order_monitors WHERE id=?", (mid,), + ).fetchone() + filled = st_row and (st_row["status"] or "").strip().lower() == "active" + rejected = st_row and (st_row["status"] or "").strip().lower() == "closed" + if rejected: + conn.commit() + conn.close() + _push_position_snapshot_async(fast=False) + return jsonify({ + "ok": False, + "error": "委托已被柜台拒绝或撤销(请确认合约状态与交易时段)", + "lots": lots, + "filled": False, + }), 400 + if not filled: + try: + get_bridge().refresh_positions() + except Exception: + pass + _reconcile_pending(conn, mode, capital=_capital(conn)) + st_row = conn.execute( + "SELECT status FROM trade_order_monitors WHERE id=?", (mid,), + ).fetchone() + filled = st_row and (st_row["status"] or "").strip().lower() == "active" + rejected = st_row and (st_row["status"] or "").strip().lower() == "closed" + if rejected: + conn.commit() + conn.close() + _push_position_snapshot_async(fast=False) + return jsonify({ + "ok": False, + "error": "委托已被柜台拒绝或撤销(请确认合约状态与交易时段)", + "lots": lots, + "filled": False, + }), 400 + if filled: + _sync_monitor_from_ctp( + conn, mid, sym, direction, mode, capital=_capital(conn), + ) + mon_row = conn.execute( + "SELECT * FROM trade_order_monitors WHERE id=?", (mid,), + ).fetchone() + if mon_row and (sl or tp): + try: + ensure_monitor_order_columns(conn) + cancel_monitor_exit_orders(conn, dict(mon_row), mode=mode) + except Exception as exc: + logger.warning("清理旧版止盈止损挂单失败: %s", exc) + conn.commit() + _push_position_snapshot_async(fast=False) + msg = ( + f"开仓成功 · {lots} 手" + if filled + else ( + f"委托已提交 · {lots} 手挂单中" + f"({get_pending_order_timeout_sec(get_setting) // 60} 分钟未成交自动撤单)" + ) + ) + conn.commit() + if offset.startswith("open"): + from modules.core.db_conn import DB_PATH + from modules.notify.ai_worker import schedule_ai_event_analysis + from modules.trading.trade_notify import notify_manual_open_filled + + if filled: + open_sl = float(d.get("stop_loss") or 0) if d.get("stop_loss") else None + open_tp = None if d.get("trailing_be") else d.get("take_profit") + if open_tp is not None: + try: + open_tp = float(open_tp) + except (TypeError, ValueError): + open_tp = None + codes = ths_to_codes(sym) or {} + if open_sl and open_sl > 0: + notify_manual_open_filled( + send_wechat=send_wechat_msg, + get_setting=get_setting, + mode_label=trading_mode_label(get_setting), + sym=sym, + symbol_name=codes.get("name") or sym, + direction=direction, + entry=price, + sl=open_sl, + tp=open_tp, + lots=lots, + capital=_capital(conn), + order_id=str(result.get("order_id") or ""), + trailing_be=bool(d.get("trailing_be")), + be_tick_buffer=get_trailing_be_tick_buffer(get_setting), + schedule_ai_fn=schedule_ai_event_analysis, + db_path=DB_PATH, + ) + else: + send_wechat_msg( + f"{trading_mode_label(get_setting)} 开仓 {sym} {direction} {lots}手 @{price}" + ) + elif not filled: + send_wechat_msg( + f"委托已提交 · {sym} {direction} {lots}手挂单中" + f"({get_pending_order_timeout_sec(get_setting) // 60} 分钟未成交自动撤单)" + ) + elif not offset.startswith("open"): + send_wechat_msg( + f"{trading_mode_label(get_setting)} {offset} {sym} {direction} {lots}手 @{price}" + ) + conn.close() + _push_position_snapshot_async(fast=False) + return jsonify({ + "ok": True, + "result": result, + "lots": lots, + "message": msg if offset.startswith("open") else "委托已提交柜台", + "filled": filled if offset.startswith("open") else None, + }) + except (ValueError, RuntimeError) as exc: + conn.close() + return jsonify({"ok": False, "error": str(exc)}), 400 + except Exception as exc: + conn.close() + return jsonify({"ok": False, "error": str(exc)}), 500 + + @app.route("/api/ctp/connect", methods=["POST"]) + @login_required + def api_ctp_connect(): + from modules.ctp.vnpy_bridge import ctp_start_connect + from modules.ctp.ctp_settings import CTP_DISABLED_HINT + + if not is_ctp_auto_connect_enabled(get_setting): + mode = get_trading_mode(get_setting) + st = ctp_status(mode) + return jsonify({ + "ok": False, + "disabled": True, + "error": CTP_DISABLED_HINT, + "status": st, + }), 400 + mode = get_trading_mode(get_setting) + body = request.get_json(silent=True) or {} + force = bool(body.get("force")) + auto = bool(body.get("auto")) + # 自动连接仅由 qihuo-ctp 后台 worker 发起;Web 只读状态,避免换页重复 connect。 + if auto and not force: + st = ctp_status(mode) + acc = _ctp_account(mode) if st.get("connected") else {} + return jsonify({ + "ok": True, + "connecting": bool(st.get("connecting")), + "backend_managed": True, + "status": st, + "account": acc, + }) + info = ctp_start_connect(mode, force=force) + st = info.get("status") or ctp_status(mode) + acc = _ctp_account(mode) if st.get("connected") else {} + if st.get("connected"): + return jsonify({"ok": True, "status": st, "account": acc}) + if info.get("connecting") or info.get("started"): + return jsonify({ + "ok": True, + "connecting": True, + "status": st, + "account": acc, + }) + if info.get("cooldown"): + return jsonify({ + "ok": False, + "cooldown": True, + "error": st.get("last_error") or "CTP 登录冷却中", + "status": st, + "account": acc, + }), 400 + return jsonify({ + "ok": False, + "error": st.get("last_error") or "CTP 连接未启动", + "status": st, + "account": acc, + }), 400 + + @app.route("/api/ctp/status") + @login_required + def api_ctp_status(): + mode = get_trading_mode(get_setting) + st = ctp_status(mode) + acc = {} + if st.get("connected"): + try: + acc = _ctp_account(mode) + except Exception: + acc = {} + return jsonify({"ok": True, "status": st, "account": acc}) + + @app.route("/api/account_snapshot") + @login_required + def api_account_snapshot(): + conn = get_db() + try: + init_strategy_tables(conn) + mode = get_trading_mode(get_setting) + ctp_st = ctp_status(mode) + capital = _capital(conn) + risk = get_risk_status( + conn, + active_count=_effective_active_position_count(conn, mode), + equity=capital, + ) + conn.commit() + ctp_acc = _ctp_account(mode) if ctp_st.get("connected") else {} + positions = _ctp_positions(mode) if ctp_st.get("connected") else [] + if ctp_st.get("connected") and not positions: + positions = _positions_for_monitor_restore(mode) + return jsonify({ + "capital": capital, + "trading_mode": mode, + "trading_mode_label": trading_mode_label(get_setting), + "sizing_mode": get_sizing_mode(get_setting), + "risk_status": risk, + "ctp_status": ctp_st, + "ctp_account": ctp_acc, + "positions": positions, + }) + finally: + conn.close() + + @app.route("/api/recommend/list") + @login_required + def api_recommend_list(): + """只读数据库缓存,不在请求时拉行情。""" + conn = get_db() + try: + payload = _recommend_payload(conn) + return jsonify({"ok": True, **payload}) + finally: + conn.close() + + @app.route("/api/recommend/stream") + @login_required + def api_recommend_stream(): + from queue import Empty + + def generate(): + q = recommend_hub.subscribe() + try: + conn = get_db() + try: + payload = _recommend_payload(conn) + finally: + conn.close() + yield sse_format("recommend", {"ok": True, **payload}) + while True: + try: + msg = q.get(timeout=25) + yield sse_format(msg["event"], msg["data"]) + except Empty: + yield ": heartbeat\n\n" + finally: + recommend_hub.unsubscribe(q) + + return Response( + stream_with_context(generate()), + mimetype="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) + + @app.route("/api/recommend/refresh", methods=["POST"]) + @login_required + def api_recommend_refresh(): + """手动触发一次后台刷新(仍写入数据库)。""" + conn = get_db() + try: + init_strategy_tables(conn) + capital = _recommend_capital(conn) + mode = get_trading_mode(get_setting) + rows = refresh_recommend_cache( + conn, capital, _main_quote, trading_mode=mode, + max_margin_pct=get_max_margin_pct(get_setting), + ) + max_pct = get_max_margin_pct(get_setting) + payload = _recommend_payload(conn) + recommend_hub.broadcast("recommend", {"ok": True, **payload}) + return jsonify({"ok": True, "count": len(rows), **payload}) + finally: + conn.close() + + @app.route("/api/strategy/trend/preview", methods=["POST"]) + @login_required + def api_trend_preview(): + d = request.get_json(silent=True) or {} + sym = (d.get("symbol") or "").strip() + conn = get_db() + if conn.execute("SELECT id FROM trend_pullback_plans WHERE status='active'").fetchone(): + conn.close() + return jsonify({"ok": False, "error": "已有运行中趋势计划"}), 400 + capital = _capital(conn) + codes = ths_to_codes(sym) + price = fetch_price(sym, codes.get("market_code", "") if codes else "", codes.get("sina_code", "") if codes else "") + conn.close() + if not price: + return jsonify({"ok": False, "error": "无法获取现价"}), 400 + plan, err = compute_trend_plan_futures( + direction=d.get("direction") or "long", + stop_loss=float(d.get("stop_loss") or 0), + add_upper=float(d.get("add_upper") or 0), + take_profit=float(d.get("take_profit") or 0), + risk_percent=float(d.get("risk_percent") or get_risk_percent(get_setting)), + capital=capital, + live_price=price, + ths_code=sym, + dca_legs=int(d.get("dca_legs") or 5), + ) + if err: + return jsonify({"ok": False, "error": err}), 400 + period = normalize_trend_period(d.get("period")) + sym_name = (d.get("symbol_name") or "").strip() + if not sym_name and codes: + sym_name = codes.get("name") or sym + plan = enrich_trend_plan_preview( + plan, symbol=sym, symbol_name=sym_name, period=period, + ) + return jsonify({"ok": True, "plan": plan}) + + @app.route("/api/strategy/trend/execute", methods=["POST"]) + @login_required + def api_trend_execute(): + d = request.get_json(silent=True) or {} + sym = (d.get("symbol") or "").strip() + conn = get_db() + init_strategy_tables(conn) + capital = _capital(conn) + err = assert_can_open(conn, equity=capital) + if err: + conn.close() + return jsonify({"ok": False, "error": err}), 403 + scope_err = assert_product_allowed_for_capital( + sym, capital, ctp_connected=is_ctp_connected(get_setting), + ) + if scope_err: + conn.close() + return jsonify({"ok": False, "error": scope_err}), 403 + codes = ths_to_codes(sym) + price = fetch_price(sym, codes.get("market_code", "") if codes else "", codes.get("sina_code", "") if codes else "") + plan, perr = compute_trend_plan_futures( + direction=d.get("direction") or "long", + stop_loss=float(d.get("stop_loss") or 0), + add_upper=float(d.get("add_upper") or 0), + take_profit=float(d.get("take_profit") or 0), + risk_percent=float(d.get("risk_percent") or get_risk_percent(get_setting)), + capital=capital, + live_price=price or float(d.get("live_price") or 0), + ths_code=sym, + ) + if perr: + conn.close() + return jsonify({"ok": False, "error": perr}), 400 + period = normalize_trend_period(d.get("period")) + sym_name = (d.get("symbol_name") or "").strip() + if not sym_name and codes: + sym_name = codes.get("name") or sym + plan = enrich_trend_plan_preview( + plan, symbol=sym, symbol_name=sym_name, period=period, + ) + mode = get_trading_mode(get_setting) + try: + execute_order( + conn, mode=mode, offset="open", symbol=sym, + direction=plan["direction"], lots=plan["first_lots"], price=price, settings=_settings_dict(), + ) + except ValueError as exc: + conn.close() + return jsonify({"ok": False, "error": str(exc)}), 400 + now = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + cur = conn.execute( + """INSERT INTO trend_pullback_plans ( + status, symbol, symbol_name, direction, stop_loss, add_upper, take_profit, + risk_percent, capital_snapshot, plan_margin, target_lots, first_lots, remainder_lots, + dca_legs, leg_amounts_json, grid_prices_json, first_order_done, avg_entry_price, + lots_open, opened_at, period + ) VALUES ('active',?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,1,?,?,?,?) RETURNING id""", + ( + sym, sym_name or (codes.get("name", sym) if codes else sym), plan["direction"], + plan["stop_loss"], plan["add_upper"], plan["take_profit"], + plan["risk_percent"], plan["capital_snapshot"], plan["plan_margin"], + plan["target_lots"], plan["first_lots"], plan["remainder_lots"], + plan["dca_legs"], plan["leg_amounts_json"], plan["grid_prices_json"], + price, plan["first_lots"], now, plan["period"], + ), + ) + row = cur.fetchone() + plan_id = int(row["id"] if isinstance(row, dict) else row[0]) + conn.commit() + conn.close() + send_wechat_msg(f"趋势回调首仓 {sym} {plan['first_lots']}手") + return jsonify({"ok": True, "plan_id": plan_id, "plan": plan}) + + def _roll_group_for_monitor(conn, monitor_id: int): + return conn.execute( + "SELECT * FROM roll_groups WHERE order_monitor_id=? AND status='active'", + (int(monitor_id),), + ).fetchone() + + def _roll_filled_legs(conn, monitor_id: int) -> int: + grp = _roll_group_for_monitor(conn, monitor_id) + if grp: + return int(grp["leg_count"] or 0) + return 0 + + def _roll_has_pending(conn, monitor_id: int) -> bool: + grp = _roll_group_for_monitor(conn, monitor_id) + if not grp: + return False + return bool(conn.execute( + "SELECT 1 FROM roll_legs WHERE roll_group_id=? AND status=? LIMIT 1", + (int(grp["id"]), LEG_STATUS_PENDING), + ).fetchone()) + + def _roll_eligibility(conn, mon: dict, ctx: Optional[dict] = None) -> Optional[str]: + if ctx is None: + ctx = _build_roll_context(conn) + return _roll_eligibility_with_ctx(conn, mon, ctx) + + def _roll_monitor_for_request(conn, mon_id: int) -> Optional[dict]: + row = conn.execute( + "SELECT * FROM trade_order_monitors WHERE id=?", + (int(mon_id),), + ).fetchone() + if not row: + return None + mon = dict(row) + if (mon.get("status") or "").strip().lower() == "active": + return mon + mode = get_trading_mode(get_setting) + if not _cached_ctp_status(mode).get("connected"): + return None + sym = (mon.get("symbol") or "").strip() + direction = (mon.get("direction") or "long").strip().lower() + for p in _positions_for_monitor_restore(mode, allow_ctp=False): + if int(p.get("lots") or 0) <= 0: + continue + if (p.get("direction") or "long").strip().lower() != direction: + continue + if not _match_ctp_symbol(p.get("symbol") or "", sym): + continue + execute_retry( + conn, + "UPDATE trade_order_monitors SET status='active' WHERE id=?", + (int(mon_id),), + ) + mon["status"] = "active" + _sync_monitor_from_ctp( + conn, + int(mon_id), + sym, + direction, + mode, + ctp=p, + capital=_capital(conn), + ) + fresh = conn.execute( + "SELECT * FROM trade_order_monitors WHERE id=?", + (int(mon_id),), + ).fetchone() + return dict(fresh) if fresh else mon + return None + + def _roll_mark_price( + sym: str, + mon: dict, + mode: str, + *, + allow_ctp: bool = False, + ) -> float: + mark = _cached_position_mark(sym, (mon or {}).get("direction") or "") + if mark and mark > 0: + return float(mark) + mark = ( + ctp_get_tick_price(mode, sym) + if allow_ctp and ctp_status(mode).get("connected") + else None + ) + if mark and mark > 0: + return float(mark) + px = fetch_price(sym) + if px and px > 0: + return float(px) + return float(mon.get("entry_price") or 0) + + def _build_roll_preview(conn, d: dict, mon: dict, *, mode: str): + sym = mon["symbol"] + spec = get_contract_spec(sym) + capital = _capital(conn) + add_mode = (d.get("add_mode") or ADD_MODE_MARKET).strip().lower() + off_session_breakout = add_mode == ADD_MODE_BREAKOUT and not is_trading_session() + mark = _roll_mark_price(sym, mon, mode, allow_ctp=not off_session_breakout) + if (not mark or mark <= 0) and off_session_breakout: + bt = float(d.get("breakthrough_price") or 0) + mark = bt if bt > 0 else float(mon.get("entry_price") or 0) + entry_existing = _live_entry_price( + sym, mon["direction"], mode, float(mon.get("entry_price") or 0), + allow_ctp=False, + ) + if add_mode in FIB_MODES: + return None, "斐波加仓已停用,请选市价或突破" + if add_mode not in _roll_ui_modes(): + return None, "仅支持市价加仓或突破加仓" + risk_budget = get_fixed_amount(get_setting) + legs_done = _roll_filled_legs(conn, int(mon["id"])) + preview, err = preview_roll( + direction=mon["direction"], + symbol=sym, + qty_existing=float(mon["lots"]), + entry_existing=entry_existing, + initial_take_profit=float(mon["take_profit"] or 0), + add_mode=add_mode, + new_stop_loss=float(d.get("new_stop_loss") or 0), + risk_budget=risk_budget, + mult=int(spec["mult"]), + mark_price=mark, + add_price=float(d.get("add_price") or 0) or mark, + limit_price=d.get("limit_price"), + breakthrough_price=d.get("breakthrough_price"), + fib_upper=d.get("fib_upper"), + fib_lower=d.get("fib_lower"), + legs_done=legs_done, + off_session_pending=off_session_breakout, + ) + if err: + return None, err + preview, merr = _apply_roll_margin_cap( + preview, conn=conn, mode=mode, mon=dict(mon), capital=capital, + ) + if merr: + return None, merr + return preview, None + + def _commit_roll_fill( + conn, + *, + mon: dict, + preview: dict, + add_mode: str, + mode: str, + pending_leg_id: Optional[int] = None, + ) -> tuple[bool, str]: + sym = mon["symbol"] + mon_id = int(mon["id"]) + price = float(preview["add_price"]) + try: + execute_order( + conn, mode=mode, offset="open", symbol=sym, + direction=mon["direction"], lots=int(preview["add_lots"]), price=price, + settings=_settings_dict(), + ) + except ValueError as exc: + return False, str(exc) + new_lots = int(mon["lots"]) + int(preview["add_lots"]) + new_avg = preview["avg_entry_after"] + new_sl = preview["new_stop_loss"] + conn.execute( + "UPDATE trade_order_monitors SET lots=?, entry_price=?, stop_loss=? WHERE id=?", + (new_lots, new_avg, new_sl, mon_id), + ) + grp = _roll_group_for_monitor(conn, mon_id) + now = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + risk_budget = float(preview.get("risk_budget") or get_fixed_amount(get_setting)) + if grp: + gid = int(grp["id"]) + leg_n = int(grp["leg_count"] or 0) + 1 + conn.execute( + "UPDATE roll_groups SET leg_count=?, current_stop_loss=?, updated_at=? WHERE id=?", + (leg_n, new_sl, now, gid), + ) + else: + cur = conn.execute( + """INSERT INTO roll_groups ( + order_monitor_id, symbol, direction, initial_take_profit, initial_stop_loss, + current_stop_loss, risk_percent, leg_count, status, created_at, updated_at + ) VALUES (?,?,?,?,?,?,?,1,'active',?,?) RETURNING id""", + ( + mon_id, sym, mon["direction"], mon["take_profit"], mon["stop_loss"], + new_sl, risk_budget, now, now, + ), + ) + row = cur.fetchone() + gid = int(row["id"] if isinstance(row, dict) else row[0]) + leg_n = 1 + if pending_leg_id: + conn.execute( + """UPDATE roll_legs SET status=?, fill_price=?, lots=?, new_stop_loss=?, created_at=? + WHERE id=?""", + ( + LEG_STATUS_FILLED, price, int(preview["add_lots"]), new_sl, now, + int(pending_leg_id), + ), + ) + else: + conn.execute( + """INSERT INTO roll_legs ( + roll_group_id, leg_index, add_mode, fill_price, lots, new_stop_loss, + status, created_at, limit_price, breakthrough_price, last_mark_price, capital_snapshot + ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?)""", + ( + gid, leg_n, add_mode, price, int(preview["add_lots"]), new_sl, + LEG_STATUS_FILLED, now, + preview.get("limit_price"), preview.get("breakthrough_price"), + preview.get("mark_price"), _capital(conn), + ), + ) + conn.commit() + send_wechat_msg( + f"滚仓成交 {sym} {add_mode_label(add_mode)} +{preview['add_lots']}手 " + f"新止损 {new_sl} 合计 {new_lots}手" + ) + _schedule_roll_entry_sync(mon_id, sym, mon["direction"], mode) + return True, "成交" + + def _schedule_roll_entry_sync( + mon_id: int, sym: str, direction: str, mode: str, + ) -> None: + """滚仓成交后从柜台同步加权均价到手数监控。""" + def _run() -> None: + import time as _time + + _time.sleep(1.5) + try: + conn = get_db() + try: + init_strategy_tables(conn) + capital = _capital(conn) + synced = False + for p in trading_state.get_positions() or _ctp_positions(mode): + if (p.get("direction") or "long") != (direction or "long"): + continue + if not _match_ctp_symbol(p.get("symbol") or "", sym): + continue + _sync_monitor_from_ctp( + conn, mon_id, sym, direction, mode, ctp=p, capital=capital, + ) + synced = True + break + if synced: + commit_retry(conn) + finally: + conn.close() + if synced: + _push_position_snapshot_async(fast=False) + except Exception as exc: + logger.debug("roll entry sync: %s", exc) + + threading.Thread(target=_run, daemon=True, name="roll-entry-sync").start() + + def _submit_roll_pending( + conn, + *, + mon: dict, + preview: dict, + add_mode: str, + ) -> tuple[bool, str]: + mon_id = int(mon["id"]) + grp = _roll_group_for_monitor(conn, mon_id) + now = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + capital = _capital(conn) + risk_budget = float(preview.get("risk_budget") or get_fixed_amount(get_setting)) + if grp: + gid = int(grp["id"]) + else: + cur = conn.execute( + """INSERT INTO roll_groups ( + order_monitor_id, symbol, direction, initial_take_profit, initial_stop_loss, + current_stop_loss, risk_percent, leg_count, status, created_at, updated_at + ) VALUES (?,?,?,?,?,?,?,0,'active',?,?) RETURNING id""", + ( + mon_id, mon["symbol"], mon["direction"], mon["take_profit"], mon["stop_loss"], + preview["new_stop_loss"], risk_budget, now, now, + ), + ) + row = cur.fetchone() + gid = int(row["id"] if isinstance(row, dict) else row[0]) + leg_n = int(conn.execute( + "SELECT COUNT(*) AS n FROM roll_legs WHERE roll_group_id=? AND status=?", + (gid, LEG_STATUS_FILLED), + ).fetchone()["n"]) + 1 + pending_n = conn.execute( + "SELECT COUNT(*) AS n FROM roll_legs WHERE roll_group_id=? AND status=?", + (gid, LEG_STATUS_PENDING), + ).fetchone()["n"] + if int(pending_n or 0) > 0: + return False, "已有监控中的加仓腿" + conn.execute( + """INSERT INTO roll_legs ( + roll_group_id, leg_index, add_mode, lots, new_stop_loss, status, created_at, + limit_price, breakthrough_price, last_mark_price, capital_snapshot + ) VALUES (?,?,?,?,?,?,?,?,?,?,?)""", + ( + gid, leg_n, add_mode, int(preview["add_lots"]), preview["new_stop_loss"], + LEG_STATUS_PENDING, now, + preview.get("limit_price"), preview.get("breakthrough_price"), + preview.get("mark_price"), capital, + ), + ) + conn.commit() + return True, "已提交监控,触价后自动市价加仓" + + def _fill_roll_leg_cb(mon: dict, grp: dict, leg: dict, preview: dict) -> tuple[bool, str]: + conn = get_db() + init_strategy_tables(conn) + mode = get_trading_mode(get_setting) + ok, msg = _commit_roll_fill( + conn, mon=mon, preview=preview, add_mode=leg.get("add_mode") or ADD_MODE_MARKET, + mode=mode, pending_leg_id=int(leg["id"]), + ) + conn.close() + return ok, msg + + def _check_roll_monitors(): + conn = get_db() + init_strategy_tables(conn) + mode = get_trading_mode(get_setting) + try: + check_roll_monitors( + conn, + get_mark_price_fn=lambda sym: _roll_mark_price(sym, {}, mode, allow_ctp=True), + fill_roll_leg_fn=_fill_roll_leg_cb, + is_trading_session_fn=is_trading_session, + get_risk_budget_fn=lambda: get_fixed_amount(get_setting), + get_entry_price_fn=lambda sym, d, fb: _live_entry_price( + sym, d, mode, fb, allow_ctp=True, + ), + ) + conn.commit() + finally: + conn.close() + + app._check_roll_monitors = _check_roll_monitors + + def _apply_roll_margin_cap( + preview: dict, + *, + conn, + mode: str, + mon: dict, + capital: float, + ) -> tuple[dict, Optional[str]]: + """滚仓:风险算手数后再按滚仓保证金上限收紧。""" + if not preview: + return preview, "预览无效" + sym = mon["symbol"] + direction = (mon.get("direction") or "long").strip().lower() + price = float(preview.get("add_price") or 0) + qty_existing = float(mon.get("lots") or 0) + entry_existing = _live_entry_price( + sym, direction, mode, float(mon.get("entry_price") or 0), + allow_ctp=False, + ) + mult = int(get_contract_spec(sym).get("mult") or 1) + roll_pct = get_roll_max_margin_pct(get_setting) + add_lots = int(preview.get("add_lots") or 0) + positions = _positions_for_monitor_restore(mode, allow_ctp=False) + capped, usage = cap_lots_for_margin_budget( + positions, capital, sym, direction, price, add_lots, roll_pct, trading_mode=mode, + ) + if capped < 1: + return preview, f"滚仓后保证金占用将超过上限 {roll_pct:g}%" + out = dict(preview) + if capped < add_lots: + out["add_lots"] = capped + out["qty_after"] = int(qty_existing + capped) + out["avg_entry_after"] = round( + avg_entry_after_add(qty_existing, entry_existing, capped, price), 4, + ) + sl = float(out.get("new_stop_loss") or 0) + tp = float(out.get("initial_take_profit") or 0) + new_avg = float(out["avg_entry_after"]) + new_qty = float(out["qty_after"]) + if direction == "long": + out["loss_at_sl"] = round((new_avg - sl) * new_qty * mult, 2) + out["reward_at_tp"] = round((tp - new_avg) * new_qty * mult, 2) + else: + out["loss_at_sl"] = round((sl - new_avg) * new_qty * mult, 2) + out["reward_at_tp"] = round((new_avg - tp) * new_qty * mult, 2) + out["margin_capped"] = True + out["margin_cap_note"] = ( + f"按滚仓保证金上限 {roll_pct:g}% 收紧:" + f"风险算 {add_lots} 手 → 实际 {capped} 手" + ) + out["margin_usage_pct"] = round(usage, 2) + out["roll_max_margin_pct"] = roll_pct + return out, None + + @app.route("/api/strategy/roll/preview", methods=["POST"]) + @login_required + def api_roll_preview(): + d = request.get_json(silent=True) or {} + conn = get_db() + init_strategy_tables(conn) + ensure_monitor_order_columns(conn) + mon_id = int(d.get("monitor_id") or 0) + roll_ctx = _build_roll_context(conn) + mon = _roll_monitor_for_request(conn, mon_id) + if not mon: + conn.close() + return jsonify({"ok": False, "error": "无有效持仓监控"}), 400 + conn.commit() + mon_d = dict(mon) + err = _roll_eligibility(conn, mon_d, roll_ctx) + if err: + conn.close() + return jsonify({"ok": False, "error": err}), 400 + mode = get_trading_mode(get_setting) + preview, perr = _build_roll_preview(conn, d, mon_d, mode=mode) + conn.close() + if perr: + return jsonify({"ok": False, "error": perr}), 400 + return jsonify({"ok": True, "preview": preview}) + + @app.route("/api/strategy/roll/execute", methods=["POST"]) + @login_required + def api_roll_execute(): + d = request.get_json(silent=True) or {} + conn = get_db() + init_strategy_tables(conn) + ensure_monitor_order_columns(conn) + mon_id = int(d.get("monitor_id") or 0) + roll_ctx = _build_roll_context(conn) + mon = _roll_monitor_for_request(conn, mon_id) + if not mon: + conn.close() + return jsonify({"ok": False, "error": "无有效持仓监控"}), 400 + conn.commit() + mon_d = dict(mon) + err = _roll_eligibility(conn, mon_d, roll_ctx) + if err: + conn.close() + return jsonify({"ok": False, "error": err}), 400 + mode = get_trading_mode(get_setting) + preview, perr = _build_roll_preview(conn, d, mon_d, mode=mode) + if perr: + conn.close() + return jsonify({"ok": False, "error": perr}), 400 + add_mode = (d.get("add_mode") or ADD_MODE_MARKET).strip().lower() + if add_mode in PENDING_MODES: + ok, msg = _submit_roll_pending(conn, mon=mon_d, preview=preview, add_mode=add_mode) + conn.close() + if not ok: + return jsonify({"ok": False, "error": msg}), 400 + note = "已提交监控,开盘触价后自动市价加仓" if not is_trading_session() else msg + return jsonify({"ok": True, "message": note, "pending": True}) + if not is_trading_session(): + conn.close() + return jsonify({"ok": False, "error": "不在交易时间段"}), 403 + if not _cached_ctp_status(mode).get("connected"): + conn.close() + return jsonify({"ok": False, "error": "请先连接 CTP"}), 400 + ok, msg = _commit_roll_fill( + conn, mon=mon_d, preview=preview, add_mode=add_mode, mode=mode, + ) + conn.close() + if not ok: + return jsonify({"ok": False, "error": msg}), 400 + return jsonify({"ok": True, "message": msg, "preview": preview}) + + @app.route("/api/strategy/roll/cancel/", methods=["POST"]) + @login_required + def api_roll_cancel(leg_id: int): + conn = get_db() + init_strategy_tables(conn) + ok, msg = cancel_roll_leg(conn, leg_id) + if ok: + conn.commit() + conn.close() + if not ok: + return jsonify({"ok": False, "error": msg}), 400 + return jsonify({"ok": True, "message": msg}) + + @app.route("/api/strategy/trend/stop", methods=["POST"]) + @login_required + def api_trend_stop(): + d = request.get_json(silent=True) or {} + plan_id = int(d.get("plan_id") or 0) + conn = get_db() + plan = conn.execute("SELECT * FROM trend_pullback_plans WHERE id=? AND status='active'", (plan_id,)).fetchone() + if not plan: + conn.close() + return jsonify({"ok": False, "error": "计划不存在"}), 404 + mode = get_trading_mode(get_setting) + price = fetch_price(plan["symbol"]) or float(plan["avg_entry_price"] or 0) + try: + if int(plan["lots_open"] or 0) > 0: + execute_order( + conn, mode=mode, offset="close", symbol=plan["symbol"], + direction=plan["direction"], lots=int(plan["lots_open"]), price=price, settings=_settings_dict(), + ) + except ValueError: + pass + now = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + conn.execute( + "UPDATE trend_pullback_plans SET status='stopped_manual', message=?, opened_at=opened_at WHERE id=?", + ("手动结束", plan_id), + ) + save_snapshot( + conn, strategy_type=STRATEGY_TREND, source_id=plan_id, + symbol=plan["symbol"], direction=plan["direction"], result_label="手动结束", + payload=dict(plan), opened_at=plan["opened_at"] or "", + ) + on_user_initiated_close(conn, trading_day=trading_day_label()) + conn.commit() + conn.close() + return jsonify({"ok": True}) + + def check_trend_plans(app_ref): + """后台:趋势补仓与止盈。""" + conn = get_db() + init_strategy_tables(conn) + rows = conn.execute("SELECT * FROM trend_pullback_plans WHERE status='active'").fetchall() + mode = get_trading_mode(get_setting) + for plan in rows: + sym = plan["symbol"] + price = fetch_price(sym) + if not price: + continue + direction = plan["direction"] + tp = float(plan["take_profit"] or 0) + if tp > 0: + hit_tp = (direction == "long" and price >= tp) or (direction == "short" and price <= tp) + if hit_tp: + try: + execute_order( + conn, mode=mode, offset="close", symbol=sym, direction=direction, + lots=int(plan["lots_open"] or 0), price=price, settings=_settings_dict(), + ) + except ValueError: + pass + conn.execute( + "UPDATE trend_pullback_plans SET status='stopped_tp', message=? WHERE id=?", + ("程序止盈", plan["id"]), + ) + save_snapshot( + conn, strategy_type=STRATEGY_TREND, source_id=plan["id"], + symbol=sym, direction=direction, result_label="止盈", + payload=dict(plan), opened_at=plan["opened_at"] or "", + ) + send_wechat_msg(f"趋势回调止盈 {sym}") + continue + try: + grid = json.loads(plan["grid_prices_json"] or "[]") + legs = json.loads(plan["leg_amounts_json"] or "[]") + except Exception: + grid, legs = [], [] + done = int(plan["legs_done"] or 0) + if done < len(grid) and done < len(legs): + level = float(grid[done]) + if trend_dca_level_reached(direction, price, level): + add_lots = int(legs[done]) + try: + execute_order( + conn, mode=mode, offset="open", symbol=sym, direction=direction, + lots=add_lots, price=price, settings=_settings_dict(), + ) + new_open = int(plan["lots_open"] or 0) + add_lots + old_avg = float(plan["avg_entry_price"] or price) + new_avg = (old_avg * int(plan["lots_open"] or 0) + price * add_lots) / new_open if new_open else price + conn.execute( + """UPDATE trend_pullback_plans SET legs_done=?, lots_open=?, avg_entry_price=? WHERE id=?""", + (done + 1, new_open, new_avg, plan["id"]), + ) + send_wechat_msg(f"趋势回调补仓 {sym} +{add_lots}手 @档位{done+1}") + except ValueError: + pass + conn.commit() + conn.close() + + app._check_trend_plans = check_trend_plans + + def _execute_key_breakout(conn, row, bar, break_side): + """关键位箱体/收敛:5m 收盘突破后自动市价开仓。""" + from modules.keys.key_monitor_lib import ( + TYPE_BOX, + calc_breakout_sl_tp, + format_auto_breakout_msg, + normalize_monitor_type, + resolve_order_direction, + ) + + sym = (row.get("symbol") or "").strip() + bar_time = str(bar.get("time") or "")[:19] + monitor_type = normalize_monitor_type(row.get("monitor_type") or "") + trade_mode = row.get("trade_mode") or "顺势" + direction = resolve_order_direction(break_side, trade_mode) + trailing_be = int(row.get("trailing_be") or 0) + try: + rr = float(row.get("risk_reward") or (3 if trailing_be else 2)) + except (TypeError, ValueError): + rr = 3.0 if trailing_be else 2.0 + if trailing_be and rr < 3: + rr = 3.0 + + def _notify(ok: bool, detail: str, **kw): + send_wechat_msg(format_auto_breakout_msg( + row, + break_side=break_side, + direction=direction, + entry=kw.get("entry", 0), + sl=kw.get("sl", 0), + tp=kw.get("tp", 0), + lots=kw.get("lots", 0), + bar_time=bar_time, + ok=ok, + detail=detail, + )) + + if monitor_type == TYPE_BOX: + cfg_dir = (row.get("direction") or "").strip().lower() + if cfg_dir in ("long", "short") and direction != cfg_dir: + dir_cn = "做多" if cfg_dir == "long" else "做空" + _notify(False, f"突破方向与上方向({dir_cn})不一致", entry=0, sl=0, tp=0, lots=0) + return False, "突破方向与上方向不一致" + + try: + init_strategy_tables(conn) + mode = get_trading_mode(get_setting) + if not ctp_status(mode).get("connected"): + _notify(False, "CTP 未连接") + return False, "CTP 未连接" + if not is_trading_session(): + _notify(False, "非交易时段") + return False, "非交易时段" + + try: + entry = float(bar.get("close") or 0) + except (TypeError, ValueError): + _notify(False, "K 线收盘价无效") + return False, "K 线收盘价无效" + if entry <= 0: + _notify(False, "K 线收盘价无效") + return False, "K 线收盘价无效" + + sl, tp = calc_breakout_sl_tp( + sym=sym, direction=direction, entry=entry, bar=bar, risk_reward=rr, + ) + err = assert_can_open( + conn, + active_count=_effective_active_position_count(conn, mode), + equity=_capital(conn), + ) + if err: + _notify(False, err, entry=entry, sl=sl, tp=tp, lots=0) + return False, err + + capital = _capital(conn) + lots, lot_err = calc_lots_by_risk( + entry, sl, direction, capital, get_risk_percent(get_setting), sym, + max_margin_pct=get_max_margin_pct(get_setting), trading_mode=mode, + ) + if lot_err or not lots: + msg = lot_err or "手数计算失败" + _notify(False, msg, entry=entry, sl=sl, tp=tp, lots=0) + return False, msg + + result = execute_order( + conn, + mode=mode, + offset="open", + symbol=sym, + direction=direction, + lots=lots, + price=entry, + settings=_settings_dict(), + order_type="market", + ) + open_ts = bar_time.replace("T", " ") if bar_time else datetime.now().strftime("%Y-%m-%d %H:%M:%S") + vt_order_id = str(result.get("order_id") or "") + mid = _upsert_open_monitor( + conn, + sym=sym, + direction=direction, + lots=lots, + price=entry, + sl=sl, + tp=tp, + trailing_be=trailing_be, + open_time=open_ts, + monitor_type=monitor_type, + status="pending", + vt_order_id=vt_order_id or None, + order_price=entry, + ) + _reconcile_pending(conn, mode, capital=capital) + st_row = conn.execute( + "SELECT status FROM trade_order_monitors WHERE id=?", (mid,), + ).fetchone() + filled = st_row and (st_row["status"] or "").strip().lower() == "active" + rejected = st_row and (st_row["status"] or "").strip().lower() == "closed" + if rejected: + conn.commit() + _notify(False, "委托被柜台拒绝或撤销", entry=entry, sl=sl, tp=tp, lots=lots) + return False, "委托被拒绝" + if filled: + _sync_monitor_from_ctp( + conn, mid, sym, direction, mode, capital=capital, + ) + conn.commit() + if filled: + from modules.core.db_conn import DB_PATH + from modules.notify.ai_worker import schedule_ai_event_analysis + from modules.trading.trade_notify import notify_key_breakout_open + + notify_key_breakout_open( + send_wechat=send_wechat_msg, + get_setting=get_setting, + mode_label=trading_mode_label(get_setting), + row=row, + break_side=break_side, + bar_time=bar_time, + direction=direction, + entry=entry, + sl=sl, + tp=tp, + lots=lots, + capital=capital, + order_id=vt_order_id, + schedule_ai_fn=schedule_ai_event_analysis, + db_path=DB_PATH, + ) + else: + _notify(True, "委托已提交,待成交", entry=entry, sl=sl, tp=tp, lots=lots) + _push_position_snapshot_async(fast=False) + return True, "已下单" if filled else "委托已提交" + except Exception as exc: + logger.warning("key breakout auto order: %s", exc) + _notify(False, str(exc)) + return False, str(exc) + + app._execute_key_breakout = _execute_key_breakout + + @app.route("/settings/trading", methods=["POST"]) + @login_required + def settings_trading_post(): + return redirect(url_for("settings")) + + def hook_review_mood(conn, behavior_tags: str, exit_trigger: str, exit_supplement: str): + if parse_mood_issues(behavior_tags): + on_mood_journal_freeze(conn, trading_day=trading_day_label()) + + app._risk_review_hook = hook_review_mood + + from modules.core.db_conn import DB_PATH + + def _init_tables(conn): + init_strategy_tables(conn) + + threading.Thread( + target=_prime_position_snapshot, + daemon=True, + name="position-prime", + ).start() + + _pos_refresh_tick = {"n": 0} + _last_full_calibrate = {"ts": 0.0} + + def _position_worker_refresh() -> dict: + import time as _time + from modules.ctp.ctp_trading_state import CALIBRATE_INTERVAL_SEC + + _pos_refresh_tick["n"] += 1 + mode = get_trading_mode(get_setting) + connected = bool(ctp_status(mode).get("connected")) + now = _time.time() + since_connect = now - float( + getattr(get_bridge(), "_last_connect_ok_ts", 0) or 0, + ) + if connected and since_connect < 45: + return _refresh_trading_live_snapshot(fast=True) + need_full = ( + connected + and ( + trading_state.needs_calibrate() + or (now - _last_full_calibrate["ts"]) >= CALIBRATE_INTERVAL_SEC + ) + ) + if need_full: + _last_full_calibrate["ts"] = now + return _refresh_trading_live_snapshot(fast=False) + return _refresh_trading_live_snapshot(fast=True) + + start_position_worker( + refresh_fn=_position_worker_refresh, + interval=1, + idle_interval=3, + ) + _bootstrap_trading_runtime() + start_ctp_reconnect_worker( + get_mode_fn=lambda: get_trading_mode(get_setting), + get_setting_fn=get_setting, + ) + start_ctp_premarket_connect_worker( + get_mode_fn=lambda: get_trading_mode(get_setting), + get_setting_fn=get_setting, + ) + start_sl_tp_guard_worker( + db_path=DB_PATH, + get_mode_fn=lambda: get_trading_mode(get_setting), + init_tables_fn=_init_tables, + get_capital_fn=_capital, + get_be_tick_buffer_fn=lambda: get_trailing_be_tick_buffer(get_setting), + notify_fn=send_wechat_msg, + interval=1, + ) + start_pending_order_worker( + db_path=DB_PATH, + get_mode_fn=lambda: get_trading_mode(get_setting), + init_tables_fn=_init_tables, + get_capital_fn=_capital, + reconcile_fn=_reconcile_pending, + on_changed_fn=lambda: _push_position_snapshot_async(fast=False), + ) + + def _start_deferred_workers() -> None: + time.sleep(2) + start_recommend_worker( + db_path=DB_PATH, + get_capital_fn=_recommend_capital, + quote_fn=_main_quote, + init_tables_fn=_init_tables, + get_mode_fn=lambda: get_trading_mode(get_setting), + get_max_margin_pct_fn=lambda: get_max_margin_pct(get_setting), + get_sizing_mode_fn=lambda: get_sizing_mode(get_setting), + get_fixed_lots_fn=lambda: get_fixed_lots(get_setting), + ) + start_ctp_fee_worker( + get_mode_fn=lambda: get_trading_mode(get_setting), + get_setting_fn=get_setting, + set_setting_fn=set_setting, + ) + from modules.notify.ai_worker import start_ai_worker + + start_ai_worker( + db_path=DB_PATH, + get_setting_fn=get_setting, + set_setting_fn=set_setting, + send_wechat_fn=send_wechat_msg, + ) + + threading.Thread( + target=_start_deferred_workers, + daemon=True, + name="deferred-workers", + ).start() diff --git a/order_pending.py b/modules/trading/order_pending.py similarity index 94% rename from order_pending.py rename to modules/trading/order_pending.py index be08e3c..04eb4a2 100644 --- a/order_pending.py +++ b/modules/trading/order_pending.py @@ -1,284 +1,284 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""开仓委托:pending 状态跟踪、成交转正、超时撤单。""" -from __future__ import annotations - -import logging -import time -from datetime import datetime -from typing import Any, Callable, Optional -from zoneinfo import ZoneInfo - -from market_sessions import is_trading_session -from vnpy_bridge import ctp_cancel_order, ctp_list_active_orders, ctp_status - -logger = logging.getLogger(__name__) - -TZ = ZoneInfo("Asia/Shanghai") -DEFAULT_PENDING_ORDER_TIMEOUT_SEC = 300 -# 报单刚提交后短暂等待 CTP 回报,避免误判为拒单 -PENDING_ORDER_SETTLE_GRACE_SEC = 8 - - -def pending_monitor_has_live_order( - mon: dict, - *, - active_orders: dict[str, dict], - active_order_list: list[dict], - match_fn: Callable[[str, str], bool] | None = None, -) -> bool: - """本地 pending 是否仍对应 CTP 柜台上的有效开仓委托。""" - match = match_fn or _match_symbol - sym = mon.get("symbol") or "" - direction = mon.get("direction") or "long" - vt_oid = (mon.get("vt_order_id") or "").strip() - age = pending_age_sec(mon) - - if vt_oid and _vt_order_in_active(vt_oid, active_orders): - return True - if _symbol_open_order_active(active_order_list, sym, direction, match): - return True - if not vt_oid and age < PENDING_ORDER_SETTLE_GRACE_SEC: - return True - if vt_oid and age < PENDING_ORDER_SETTLE_GRACE_SEC: - return True - return False - - -def parse_monitor_ts(raw: str) -> Optional[float]: - s = (raw or "").strip() - if not s: - return None - for fmt in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%dT%H:%M"): - try: - return datetime.strptime(s[:19], fmt).replace(tzinfo=TZ).timestamp() - except ValueError: - continue - return None - - -def pending_age_sec(mon: dict) -> float: - ts = parse_monitor_ts(mon.get("open_time") or "") or parse_monitor_ts( - str(mon.get("created_at") or "") - ) - if ts is None: - return 0.0 - return max(0.0, time.time() - ts) - - -def pending_auto_cancel_remaining( - mon: dict, - *, - timeout_sec: int = DEFAULT_PENDING_ORDER_TIMEOUT_SEC, -) -> int: - limit = max(60, int(timeout_sec or DEFAULT_PENDING_ORDER_TIMEOUT_SEC)) - return max(0, int(limit - pending_age_sec(mon))) - - -def _match_symbol(ctp_sym: str, ths: str) -> bool: - a = (ctp_sym or "").lower() - b = (ths or "").lower() - if a == b: - return True - if a and b and a.split(".")[0] == b.split(".")[0]: - return True - try: - from ctp_symbol import ths_to_vnpy_symbol - vnpy_sym, _ = ths_to_vnpy_symbol(ths) - if a == vnpy_sym.lower(): - return True - except Exception: - pass - return False - - -def _find_ctp_position(positions: list[dict], sym: str, direction: str) -> Optional[dict]: - direction = (direction or "long").strip().lower() - for p in positions or []: - if int(p.get("lots") or 0) <= 0: - continue - if (p.get("direction") or "long") != direction: - continue - if _match_symbol(p.get("symbol") or "", sym): - return p - return None - - -def _vt_order_in_active(vt_oid: str, active_orders: dict[str, dict]) -> bool: - oid = (vt_oid or "").strip() - if not oid: - return False - if oid in active_orders: - return True - tail = oid.rsplit("_", 1)[-1] - for key in active_orders: - if key == oid or key.endswith(tail) or oid.endswith(key): - return True - return False - - -def _symbol_open_order_active( - orders: list[dict], - sym: str, - direction: str, - match_fn: Callable[[str, str], bool], -) -> Optional[dict]: - direction = (direction or "long").strip().lower() - for o in orders or []: - offset_u = (o.get("offset") or "").upper() - if offset_u and "OPEN" not in offset_u: - continue - if (o.get("direction") or "long") != direction: - continue - if match_fn(o.get("symbol") or "", sym): - return o - return None - - -def reconcile_pending_orders( - conn, - mode: str, - *, - match_symbol_fn: Callable[[str, str], bool] | None = None, - sync_monitor_fn: Callable[..., None] | None = None, - capital: float = 0.0, - list_positions_fn: Callable[..., list] | None = None, - timeout_sec: int = DEFAULT_PENDING_ORDER_TIMEOUT_SEC, -) -> dict[str, int]: - """同步 pending 委托:成交→active;超时/已撤→closed。""" - limit_sec = max(60, int(timeout_sec or DEFAULT_PENDING_ORDER_TIMEOUT_SEC)) - stats = {"promoted": 0, "cancelled": 0, "closed": 0} - if not ctp_status(mode).get("connected"): - return stats - - match = match_symbol_fn or _match_symbol - positions = ( - list_positions_fn(mode, refresh_if_empty=True, refresh_margin=False) - if list_positions_fn - else [] - ) - try: - active_order_list = ctp_list_active_orders(mode) - active_orders = {} - for o in active_order_list: - for key in (o.get("order_id"), o.get("vt_order_id")): - if key: - active_orders[str(key)] = o - except Exception as exc: - logger.debug("list active orders: %s", exc) - active_order_list = [] - active_orders = {} - - rows = conn.execute( - "SELECT * FROM trade_order_monitors WHERE status='pending' ORDER BY id ASC" - ).fetchall() - - for r in rows: - mon = dict(r) - mid = int(mon["id"]) - sym = mon.get("symbol") or "" - direction = mon.get("direction") or "long" - vt_oid = (mon.get("vt_order_id") or "").strip() - age = pending_age_sec(mon) - - pos = _find_ctp_position(positions, sym, direction) - if pos: - conn.execute( - "UPDATE trade_order_monitors SET status='active' WHERE id=?", - (mid,), - ) - if sync_monitor_fn: - sync_monitor_fn( - conn, mid, sym, direction, mode, ctp=pos, capital=capital, - ) - stats["promoted"] += 1 - continue - - if vt_oid and _vt_order_in_active(vt_oid, active_orders): - if age >= limit_sec and is_trading_session(): - if ctp_cancel_order(mode, vt_oid): - conn.execute( - "UPDATE trade_order_monitors SET status='closed' WHERE id=?", - (mid,), - ) - stats["cancelled"] += 1 - else: - logger.warning("pending auto-cancel failed monitor=%s order=%s", mid, vt_oid) - continue - - live_open = _symbol_open_order_active(active_order_list, sym, direction, match) - if live_open: - if age >= limit_sec and is_trading_session(): - cancel_oid = ( - vt_oid - or live_open.get("vt_order_id") - or live_open.get("order_id") - or "" - ) - if cancel_oid and ctp_cancel_order(mode, cancel_oid): - conn.execute( - "UPDATE trade_order_monitors SET status='closed' WHERE id=?", - (mid,), - ) - stats["cancelled"] += 1 - continue - - # 有委托号但已不在 CTP 活跃列表且无持仓 → 拒单/已撤/终态 - if vt_oid: - if age < PENDING_ORDER_SETTLE_GRACE_SEC: - continue - conn.execute( - "UPDATE trade_order_monitors SET status='closed' WHERE id=?", - (mid,), - ) - stats["closed"] += 1 - logger.info( - "pending monitor=%s order=%s closed (no longer active on CTP)", - mid, vt_oid, - ) - continue - - if age >= limit_sec: - if vt_oid and is_trading_session(): - if ctp_cancel_order(mode, vt_oid): - conn.execute( - "UPDATE trade_order_monitors SET status='closed' WHERE id=?", - (mid,), - ) - stats["cancelled"] += 1 - else: - logger.info( - "pending monitor=%s order=%s kept (cancel not confirmed)", - mid, vt_oid, - ) - elif not vt_oid: - conn.execute( - "UPDATE trade_order_monitors SET status='closed' WHERE id=?", - (mid,), - ) - stats["closed"] += 1 - - if any(stats.values()): - conn.commit() - return stats - - -def cancel_pending_monitor( - conn, - mon: dict, - mode: str, -) -> tuple[bool, str]: - """手动撤销 pending 开仓委托。""" - mid = int(mon.get("id") or 0) - vt_oid = (mon.get("vt_order_id") or "").strip() - if vt_oid and ctp_status(mode).get("connected"): - try: - ctp_cancel_order(mode, vt_oid) - except Exception as exc: - logger.warning("cancel pending order monitor=%s: %s", mid, exc) - conn.execute("UPDATE trade_order_monitors SET status='closed' WHERE id=?", (mid,)) - conn.commit() - return True, "开仓委托已撤销" +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""开仓委托:pending 状态跟踪、成交转正、超时撤单。""" +from __future__ import annotations + +import logging +import time +from datetime import datetime +from typing import Any, Callable, Optional +from zoneinfo import ZoneInfo + +from modules.market.market_sessions import is_trading_session +from modules.ctp.vnpy_bridge import ctp_cancel_order, ctp_list_active_orders, ctp_status + +logger = logging.getLogger(__name__) + +TZ = ZoneInfo("Asia/Shanghai") +DEFAULT_PENDING_ORDER_TIMEOUT_SEC = 300 +# 报单刚提交后短暂等待 CTP 回报,避免误判为拒单 +PENDING_ORDER_SETTLE_GRACE_SEC = 8 + + +def pending_monitor_has_live_order( + mon: dict, + *, + active_orders: dict[str, dict], + active_order_list: list[dict], + match_fn: Callable[[str, str], bool] | None = None, +) -> bool: + """本地 pending 是否仍对应 CTP 柜台上的有效开仓委托。""" + match = match_fn or _match_symbol + sym = mon.get("symbol") or "" + direction = mon.get("direction") or "long" + vt_oid = (mon.get("vt_order_id") or "").strip() + age = pending_age_sec(mon) + + if vt_oid and _vt_order_in_active(vt_oid, active_orders): + return True + if _symbol_open_order_active(active_order_list, sym, direction, match): + return True + if not vt_oid and age < PENDING_ORDER_SETTLE_GRACE_SEC: + return True + if vt_oid and age < PENDING_ORDER_SETTLE_GRACE_SEC: + return True + return False + + +def parse_monitor_ts(raw: str) -> Optional[float]: + s = (raw or "").strip() + if not s: + return None + for fmt in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%dT%H:%M"): + try: + return datetime.strptime(s[:19], fmt).replace(tzinfo=TZ).timestamp() + except ValueError: + continue + return None + + +def pending_age_sec(mon: dict) -> float: + ts = parse_monitor_ts(mon.get("open_time") or "") or parse_monitor_ts( + str(mon.get("created_at") or "") + ) + if ts is None: + return 0.0 + return max(0.0, time.time() - ts) + + +def pending_auto_cancel_remaining( + mon: dict, + *, + timeout_sec: int = DEFAULT_PENDING_ORDER_TIMEOUT_SEC, +) -> int: + limit = max(60, int(timeout_sec or DEFAULT_PENDING_ORDER_TIMEOUT_SEC)) + return max(0, int(limit - pending_age_sec(mon))) + + +def _match_symbol(ctp_sym: str, ths: str) -> bool: + a = (ctp_sym or "").lower() + b = (ths or "").lower() + if a == b: + return True + if a and b and a.split(".")[0] == b.split(".")[0]: + return True + try: + from modules.ctp.ctp_symbol import ths_to_vnpy_symbol + vnpy_sym, _ = ths_to_vnpy_symbol(ths) + if a == vnpy_sym.lower(): + return True + except Exception: + pass + return False + + +def _find_ctp_position(positions: list[dict], sym: str, direction: str) -> Optional[dict]: + direction = (direction or "long").strip().lower() + for p in positions or []: + if int(p.get("lots") or 0) <= 0: + continue + if (p.get("direction") or "long") != direction: + continue + if _match_symbol(p.get("symbol") or "", sym): + return p + return None + + +def _vt_order_in_active(vt_oid: str, active_orders: dict[str, dict]) -> bool: + oid = (vt_oid or "").strip() + if not oid: + return False + if oid in active_orders: + return True + tail = oid.rsplit("_", 1)[-1] + for key in active_orders: + if key == oid or key.endswith(tail) or oid.endswith(key): + return True + return False + + +def _symbol_open_order_active( + orders: list[dict], + sym: str, + direction: str, + match_fn: Callable[[str, str], bool], +) -> Optional[dict]: + direction = (direction or "long").strip().lower() + for o in orders or []: + offset_u = (o.get("offset") or "").upper() + if offset_u and "OPEN" not in offset_u: + continue + if (o.get("direction") or "long") != direction: + continue + if match_fn(o.get("symbol") or "", sym): + return o + return None + + +def reconcile_pending_orders( + conn, + mode: str, + *, + match_symbol_fn: Callable[[str, str], bool] | None = None, + sync_monitor_fn: Callable[..., None] | None = None, + capital: float = 0.0, + list_positions_fn: Callable[..., list] | None = None, + timeout_sec: int = DEFAULT_PENDING_ORDER_TIMEOUT_SEC, +) -> dict[str, int]: + """同步 pending 委托:成交→active;超时/已撤→closed。""" + limit_sec = max(60, int(timeout_sec or DEFAULT_PENDING_ORDER_TIMEOUT_SEC)) + stats = {"promoted": 0, "cancelled": 0, "closed": 0} + if not ctp_status(mode).get("connected"): + return stats + + match = match_symbol_fn or _match_symbol + positions = ( + list_positions_fn(mode, refresh_if_empty=True, refresh_margin=False) + if list_positions_fn + else [] + ) + try: + active_order_list = ctp_list_active_orders(mode) + active_orders = {} + for o in active_order_list: + for key in (o.get("order_id"), o.get("vt_order_id")): + if key: + active_orders[str(key)] = o + except Exception as exc: + logger.debug("list active orders: %s", exc) + active_order_list = [] + active_orders = {} + + rows = conn.execute( + "SELECT * FROM trade_order_monitors WHERE status='pending' ORDER BY id ASC" + ).fetchall() + + for r in rows: + mon = dict(r) + mid = int(mon["id"]) + sym = mon.get("symbol") or "" + direction = mon.get("direction") or "long" + vt_oid = (mon.get("vt_order_id") or "").strip() + age = pending_age_sec(mon) + + pos = _find_ctp_position(positions, sym, direction) + if pos: + conn.execute( + "UPDATE trade_order_monitors SET status='active' WHERE id=?", + (mid,), + ) + if sync_monitor_fn: + sync_monitor_fn( + conn, mid, sym, direction, mode, ctp=pos, capital=capital, + ) + stats["promoted"] += 1 + continue + + if vt_oid and _vt_order_in_active(vt_oid, active_orders): + if age >= limit_sec and is_trading_session(): + if ctp_cancel_order(mode, vt_oid): + conn.execute( + "UPDATE trade_order_monitors SET status='closed' WHERE id=?", + (mid,), + ) + stats["cancelled"] += 1 + else: + logger.warning("pending auto-cancel failed monitor=%s order=%s", mid, vt_oid) + continue + + live_open = _symbol_open_order_active(active_order_list, sym, direction, match) + if live_open: + if age >= limit_sec and is_trading_session(): + cancel_oid = ( + vt_oid + or live_open.get("vt_order_id") + or live_open.get("order_id") + or "" + ) + if cancel_oid and ctp_cancel_order(mode, cancel_oid): + conn.execute( + "UPDATE trade_order_monitors SET status='closed' WHERE id=?", + (mid,), + ) + stats["cancelled"] += 1 + continue + + # 有委托号但已不在 CTP 活跃列表且无持仓 → 拒单/已撤/终态 + if vt_oid: + if age < PENDING_ORDER_SETTLE_GRACE_SEC: + continue + conn.execute( + "UPDATE trade_order_monitors SET status='closed' WHERE id=?", + (mid,), + ) + stats["closed"] += 1 + logger.info( + "pending monitor=%s order=%s closed (no longer active on CTP)", + mid, vt_oid, + ) + continue + + if age >= limit_sec: + if vt_oid and is_trading_session(): + if ctp_cancel_order(mode, vt_oid): + conn.execute( + "UPDATE trade_order_monitors SET status='closed' WHERE id=?", + (mid,), + ) + stats["cancelled"] += 1 + else: + logger.info( + "pending monitor=%s order=%s kept (cancel not confirmed)", + mid, vt_oid, + ) + elif not vt_oid: + conn.execute( + "UPDATE trade_order_monitors SET status='closed' WHERE id=?", + (mid,), + ) + stats["closed"] += 1 + + if any(stats.values()): + conn.commit() + return stats + + +def cancel_pending_monitor( + conn, + mon: dict, + mode: str, +) -> tuple[bool, str]: + """手动撤销 pending 开仓委托。""" + mid = int(mon.get("id") or 0) + vt_oid = (mon.get("vt_order_id") or "").strip() + if vt_oid and ctp_status(mode).get("connected"): + try: + ctp_cancel_order(mode, vt_oid) + except Exception as exc: + logger.warning("cancel pending order monitor=%s: %s", mid, exc) + conn.execute("UPDATE trade_order_monitors SET status='closed' WHERE id=?", (mid,)) + conn.commit() + return True, "开仓委托已撤销" diff --git a/pending_order_worker.py b/modules/trading/pending_order_worker.py similarity index 94% rename from pending_order_worker.py rename to modules/trading/pending_order_worker.py index 44e384d..980ab2c 100644 --- a/pending_order_worker.py +++ b/modules/trading/pending_order_worker.py @@ -1,82 +1,82 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""开仓挂单超时:后台定期 reconcile,不依赖 SSE 完整刷新。""" -from __future__ import annotations - -import logging -import threading -import time -from typing import Callable, Optional - -from vnpy_bridge import ctp_status - -logger = logging.getLogger(__name__) - -CHECK_INTERVAL_SEC = 10 -IDLE_INTERVAL_SEC = 45 -DISCONNECTED_SLEEP_SEC = 30 -STARTUP_DELAY_SEC = 15 - - -def start_pending_order_worker( - *, - db_path: str, - get_mode_fn: Callable[[], str], - init_tables_fn: Callable | None = None, - get_capital_fn: Callable | None = None, - reconcile_fn: Callable[..., dict], - on_changed_fn: Callable[[], None] | None = None, - interval: int = CHECK_INTERVAL_SEC, - idle_interval: int = IDLE_INTERVAL_SEC, -) -> None: - """后台线程:存在 pending 开仓监控时定期同步成交/超时撤单。""" - from db_conn import connect_db - - def _loop() -> None: - time.sleep(STARTUP_DELAY_SEC) - while True: - sleep_sec = max(5, idle_interval) - try: - mode = get_mode_fn() - if not ctp_status(mode).get("connected"): - time.sleep(DISCONNECTED_SLEEP_SEC) - continue - - conn = connect_db(db_path) - try: - if init_tables_fn: - init_tables_fn(conn) - pending_n = conn.execute( - "SELECT COUNT(*) AS n FROM trade_order_monitors WHERE status='pending'" - ).fetchone()["n"] - if pending_n <= 0: - time.sleep(sleep_sec) - continue - - sleep_sec = max(1, interval) - capital = 0.0 - if get_capital_fn: - try: - capital = float(get_capital_fn(conn) or 0) - except Exception: - capital = 0.0 - stats = reconcile_fn(conn, mode, capital=capital) or {} - if any(int(stats.get(k) or 0) for k in ("promoted", "cancelled", "closed")): - logger.info( - "pending worker reconcile: promoted=%s cancelled=%s closed=%s", - stats.get("promoted", 0), - stats.get("cancelled", 0), - stats.get("closed", 0), - ) - if on_changed_fn: - on_changed_fn() - finally: - conn.close() - except Exception as exc: - logger.warning("pending order worker: %s", exc) - time.sleep(sleep_sec) - - threading.Thread(target=_loop, daemon=True, name="pending-order-worker").start() +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""开仓挂单超时:后台定期 reconcile,不依赖 SSE 完整刷新。""" +from __future__ import annotations + +import logging +import threading +import time +from typing import Callable, Optional + +from modules.ctp.vnpy_bridge import ctp_status + +logger = logging.getLogger(__name__) + +CHECK_INTERVAL_SEC = 10 +IDLE_INTERVAL_SEC = 45 +DISCONNECTED_SLEEP_SEC = 30 +STARTUP_DELAY_SEC = 15 + + +def start_pending_order_worker( + *, + db_path: str, + get_mode_fn: Callable[[], str], + init_tables_fn: Callable | None = None, + get_capital_fn: Callable | None = None, + reconcile_fn: Callable[..., dict], + on_changed_fn: Callable[[], None] | None = None, + interval: int = CHECK_INTERVAL_SEC, + idle_interval: int = IDLE_INTERVAL_SEC, +) -> None: + """后台线程:存在 pending 开仓监控时定期同步成交/超时撤单。""" + from modules.core.db_conn import connect_db + + def _loop() -> None: + time.sleep(STARTUP_DELAY_SEC) + while True: + sleep_sec = max(5, idle_interval) + try: + mode = get_mode_fn() + if not ctp_status(mode).get("connected"): + time.sleep(DISCONNECTED_SLEEP_SEC) + continue + + conn = connect_db(db_path) + try: + if init_tables_fn: + init_tables_fn(conn) + pending_n = conn.execute( + "SELECT COUNT(*) AS n FROM trade_order_monitors WHERE status='pending'" + ).fetchone()["n"] + if pending_n <= 0: + time.sleep(sleep_sec) + continue + + sleep_sec = max(1, interval) + capital = 0.0 + if get_capital_fn: + try: + capital = float(get_capital_fn(conn) or 0) + except Exception: + capital = 0.0 + stats = reconcile_fn(conn, mode, capital=capital) or {} + if any(int(stats.get(k) or 0) for k in ("promoted", "cancelled", "closed")): + logger.info( + "pending worker reconcile: promoted=%s cancelled=%s closed=%s", + stats.get("promoted", 0), + stats.get("cancelled", 0), + stats.get("closed", 0), + ) + if on_changed_fn: + on_changed_fn() + finally: + conn.close() + except Exception as exc: + logger.warning("pending order worker: %s", exc) + time.sleep(sleep_sec) + + threading.Thread(target=_loop, daemon=True, name="pending-order-worker").start() diff --git a/position_sizing.py b/modules/trading/position_sizing.py similarity index 96% rename from position_sizing.py rename to modules/trading/position_sizing.py index 47f1893..cea3621 100644 --- a/position_sizing.py +++ b/modules/trading/position_sizing.py @@ -1,270 +1,270 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""期货计仓:固定手数 / 固定金额。""" -from __future__ import annotations - -import math -from typing import Optional - -from contract_specs import get_contract_spec, margin_one_lot - -MODE_FIXED = "fixed" -MODE_AMOUNT = "amount" -MODE_RISK = "amount" # 兼容旧配置「以损定仓」 - -DEFAULT_MAX_ORDER_LOTS = 50 - - -def normalize_sizing_mode(raw: str) -> str: - m = (raw or MODE_FIXED).strip().lower() - if m == "risk": - m = MODE_AMOUNT - return m if m in (MODE_FIXED, MODE_AMOUNT) else MODE_FIXED - - -def price_precision_from_tick(tick_size: float) -> int: - if tick_size <= 0: - return 0 - s = f"{tick_size:.10f}".rstrip("0").rstrip(".") - if "." not in s: - return 0 - return len(s.split(".")[1]) - - -def _per_lot_risk(entry: float, stop_loss: float, direction: str, ths_code: str) -> tuple[float, Optional[str]]: - spec = get_contract_spec(ths_code) - mult = spec["mult"] - d = (direction or "long").strip().lower() - if d == "short": - per_lot = (stop_loss - entry) * mult - else: - per_lot = (entry - stop_loss) * mult - if per_lot <= 0: - return 0.0, "止损方向与入场价不匹配" - return per_lot, None - - -def calc_lots_by_amount( - entry: float, - stop_loss: float, - direction: str, - amount: float, - ths_code: str, - *, - capital: float = 0.0, - max_lots: Optional[int] = None, - max_margin_pct: float = 30.0, - trading_mode: str | None = None, -) -> tuple[Optional[int], Optional[str], dict]: - """固定金额:先按止损距离算手数,再按保证金上限收紧。返回 (手数, 错误, 详情)。""" - info: dict = { - "lots_by_risk": 0, - "lots_by_margin": None, - "capped_by": None, - } - try: - entry_f = float(entry) - sl_f = float(stop_loss) - budget = float(amount) - cap = float(capital or 0) - except (TypeError, ValueError): - return None, "参数格式错误", info - if entry_f <= 0 or budget <= 0: - return None, "入场价或固定金额无效", info - per_lot_risk, err = _per_lot_risk(entry_f, sl_f, direction, ths_code) - if err: - return None, err, info - lots = int(math.floor(budget / per_lot_risk)) - info["lots_by_risk"] = lots - if lots < 1: - return None, f"按固定金额 {budget:.0f} 元,当前止损距离下不足 1 手", info - if cap > 0: - margin_per_lot, _src, _spec = margin_one_lot( - ths_code, entry_f, direction=direction, trading_mode=trading_mode, - ) - if margin_per_lot <= 0: - spec = get_contract_spec(ths_code) - margin_per_lot = entry_f * spec["mult"] * spec["margin_rate"] - margin_cap = max(1.0, min(100.0, float(max_margin_pct or 30.0))) - max_by_margin = ( - int(math.floor(cap * margin_cap / 100.0 / margin_per_lot)) - if margin_per_lot > 0 else lots - ) - info["lots_by_margin"] = max_by_margin - info["margin_per_lot"] = round(margin_per_lot, 2) - info["max_margin_pct"] = margin_cap - if max_by_margin < 1: - return None, f"按保证金上限 {margin_cap:g}%,当前不足 1 手", info - if max_by_margin < lots: - info["capped_by"] = "margin" - lots = min(lots, max_by_margin) - cap_lots = max_lots if max_lots is not None else DEFAULT_MAX_ORDER_LOTS - if lots > cap_lots: - lots = cap_lots - info["capped_by"] = info.get("capped_by") or "max_lots" - info["lots"] = lots - return lots, None, info - - -def calc_lots_by_risk( - entry: float, - stop_loss: float, - direction: str, - capital: float, - risk_percent: float, - ths_code: str, - *, - max_lots: Optional[int] = None, - max_margin_pct: float = 30.0, - trading_mode: str | None = None, -) -> tuple[Optional[int], Optional[str]]: - """策略等场景:按权益百分比风险预算换算手数。""" - try: - cap = float(capital) - rp = float(risk_percent) - except (TypeError, ValueError): - return None, "参数格式错误" - if cap <= 0 or rp <= 0: - return None, "资金或风险比例无效" - budget = cap * rp / 100.0 - lots, err, info = calc_lots_by_amount( - entry, stop_loss, direction, budget, ths_code, - capital=cap, max_lots=max_lots, max_margin_pct=max_margin_pct, - trading_mode=trading_mode, - ) - return lots, err - - -def calc_order_tick_metrics( - ths_code: str, - lots: float, - price: Optional[float] = None, - *, - direction: str = "long", - trading_mode: str | None = None, -) -> dict: - """下单区展示:最小变动价位、每跳盈亏、保证金等。""" - spec = get_contract_spec(ths_code) - mult = int(spec["mult"]) - tick = float(spec.get("tick_size") or 1.0) - margin_rate = float(spec["margin_rate"]) - lots_i = max(1, int(lots or 1)) - tick_value_per_lot = round(tick * mult, 4) - tick_value_total = round(tick_value_per_lot * lots_i, 2) - prec = price_precision_from_tick(tick) - mark = float(price) if price else 0.0 - margin_per_lot = None - margin_source = "estimate" - if mark > 0: - margin_per_lot, margin_source, spec_used = margin_one_lot( - ths_code, mark, direction=direction, trading_mode=trading_mode, - ) - if spec_used.get("mult"): - mult = int(spec_used["mult"]) - if spec_used.get("tick_size"): - tick = float(spec_used["tick_size"]) - tick_value_per_lot = round(tick * mult, 4) - tick_value_total = round(tick_value_per_lot * lots_i, 2) - prec = price_precision_from_tick(tick) - if margin_per_lot <= 0: - margin_per_lot = round(mark * mult * margin_rate, 2) - margin_source = "estimate" - margin_total = round(margin_per_lot * lots_i, 2) if margin_per_lot else None - return { - "mult": mult, - "tick_size": tick, - "price_precision": prec, - "tick_value_per_lot": tick_value_per_lot, - "tick_value_total": tick_value_total, - "lots": lots_i, - "margin_per_lot": margin_per_lot, - "margin_total": margin_total, - "margin_rate": margin_rate, - "margin_source": margin_source, - } - - -def calc_margin_usage_pct( - positions: list[dict], - capital: float, - *, - extra_symbol: str = "", - extra_lots: int = 0, - extra_price: float = 0, - extra_direction: str = "long", - trading_mode: str | None = None, -) -> float: - """当前持仓 + 拟开仓占权益的保证金比例(%)。""" - cap = float(capital or 0) - if cap <= 0: - return 999.0 - total = 0.0 - for p in positions: - lots = int(p.get("lots") or 0) - if lots <= 0: - continue - ctp_margin = float(p.get("margin") or 0) - if ctp_margin > 0: - total += ctp_margin - continue - sym = (p.get("symbol") or p.get("symbol_code") or "").strip() - entry = float(p.get("avg_price") or p.get("entry_price") or 0) - direction = (p.get("direction") or "long").strip().lower() - if entry <= 0 or not sym: - continue - per_lot, _, _ = margin_one_lot( - sym, entry, direction=direction, trading_mode=trading_mode, - ) - if per_lot <= 0: - spec = get_contract_spec(sym) - per_lot = entry * spec["mult"] * spec["margin_rate"] - total += per_lot * lots - if extra_symbol and extra_lots > 0 and extra_price > 0: - per_lot, _, _ = margin_one_lot( - extra_symbol, extra_price, direction=extra_direction, trading_mode=trading_mode, - ) - if per_lot <= 0: - spec = get_contract_spec(extra_symbol) - per_lot = extra_price * spec["mult"] * spec["margin_rate"] - total += per_lot * extra_lots - return round(total / cap * 100.0, 2) - - -def cap_lots_for_margin_budget( - positions: list[dict], - capital: float, - symbol: str, - direction: str, - price: float, - desired_lots: int, - max_margin_pct: float, - trading_mode: str | None = None, -) -> tuple[int, float]: - """在保证金上限内,返回可加仓手数及占用比例。""" - desired = max(0, int(desired_lots or 0)) - if desired <= 0: - return 0, calc_margin_usage_pct(positions, capital, trading_mode=trading_mode) - for lots in range(desired, 0, -1): - usage = calc_margin_usage_pct( - positions, - capital, - extra_symbol=symbol, - extra_lots=lots, - extra_price=price, - extra_direction=direction, - trading_mode=trading_mode, - ) - if usage <= max_margin_pct: - return lots, usage - return 0, calc_margin_usage_pct( - positions, - capital, - extra_symbol=symbol, - extra_lots=desired, - extra_price=price, - extra_direction=direction, - trading_mode=trading_mode, - ) +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""期货计仓:固定手数 / 固定金额。""" +from __future__ import annotations + +import math +from typing import Optional + +from modules.core.contract_specs import get_contract_spec, margin_one_lot + +MODE_FIXED = "fixed" +MODE_AMOUNT = "amount" +MODE_RISK = "amount" # 兼容旧配置「以损定仓」 + +DEFAULT_MAX_ORDER_LOTS = 50 + + +def normalize_sizing_mode(raw: str) -> str: + m = (raw or MODE_FIXED).strip().lower() + if m == "risk": + m = MODE_AMOUNT + return m if m in (MODE_FIXED, MODE_AMOUNT) else MODE_FIXED + + +def price_precision_from_tick(tick_size: float) -> int: + if tick_size <= 0: + return 0 + s = f"{tick_size:.10f}".rstrip("0").rstrip(".") + if "." not in s: + return 0 + return len(s.split(".")[1]) + + +def _per_lot_risk(entry: float, stop_loss: float, direction: str, ths_code: str) -> tuple[float, Optional[str]]: + spec = get_contract_spec(ths_code) + mult = spec["mult"] + d = (direction or "long").strip().lower() + if d == "short": + per_lot = (stop_loss - entry) * mult + else: + per_lot = (entry - stop_loss) * mult + if per_lot <= 0: + return 0.0, "止损方向与入场价不匹配" + return per_lot, None + + +def calc_lots_by_amount( + entry: float, + stop_loss: float, + direction: str, + amount: float, + ths_code: str, + *, + capital: float = 0.0, + max_lots: Optional[int] = None, + max_margin_pct: float = 30.0, + trading_mode: str | None = None, +) -> tuple[Optional[int], Optional[str], dict]: + """固定金额:先按止损距离算手数,再按保证金上限收紧。返回 (手数, 错误, 详情)。""" + info: dict = { + "lots_by_risk": 0, + "lots_by_margin": None, + "capped_by": None, + } + try: + entry_f = float(entry) + sl_f = float(stop_loss) + budget = float(amount) + cap = float(capital or 0) + except (TypeError, ValueError): + return None, "参数格式错误", info + if entry_f <= 0 or budget <= 0: + return None, "入场价或固定金额无效", info + per_lot_risk, err = _per_lot_risk(entry_f, sl_f, direction, ths_code) + if err: + return None, err, info + lots = int(math.floor(budget / per_lot_risk)) + info["lots_by_risk"] = lots + if lots < 1: + return None, f"按固定金额 {budget:.0f} 元,当前止损距离下不足 1 手", info + if cap > 0: + margin_per_lot, _src, _spec = margin_one_lot( + ths_code, entry_f, direction=direction, trading_mode=trading_mode, + ) + if margin_per_lot <= 0: + spec = get_contract_spec(ths_code) + margin_per_lot = entry_f * spec["mult"] * spec["margin_rate"] + margin_cap = max(1.0, min(100.0, float(max_margin_pct or 30.0))) + max_by_margin = ( + int(math.floor(cap * margin_cap / 100.0 / margin_per_lot)) + if margin_per_lot > 0 else lots + ) + info["lots_by_margin"] = max_by_margin + info["margin_per_lot"] = round(margin_per_lot, 2) + info["max_margin_pct"] = margin_cap + if max_by_margin < 1: + return None, f"按保证金上限 {margin_cap:g}%,当前不足 1 手", info + if max_by_margin < lots: + info["capped_by"] = "margin" + lots = min(lots, max_by_margin) + cap_lots = max_lots if max_lots is not None else DEFAULT_MAX_ORDER_LOTS + if lots > cap_lots: + lots = cap_lots + info["capped_by"] = info.get("capped_by") or "max_lots" + info["lots"] = lots + return lots, None, info + + +def calc_lots_by_risk( + entry: float, + stop_loss: float, + direction: str, + capital: float, + risk_percent: float, + ths_code: str, + *, + max_lots: Optional[int] = None, + max_margin_pct: float = 30.0, + trading_mode: str | None = None, +) -> tuple[Optional[int], Optional[str]]: + """策略等场景:按权益百分比风险预算换算手数。""" + try: + cap = float(capital) + rp = float(risk_percent) + except (TypeError, ValueError): + return None, "参数格式错误" + if cap <= 0 or rp <= 0: + return None, "资金或风险比例无效" + budget = cap * rp / 100.0 + lots, err, info = calc_lots_by_amount( + entry, stop_loss, direction, budget, ths_code, + capital=cap, max_lots=max_lots, max_margin_pct=max_margin_pct, + trading_mode=trading_mode, + ) + return lots, err + + +def calc_order_tick_metrics( + ths_code: str, + lots: float, + price: Optional[float] = None, + *, + direction: str = "long", + trading_mode: str | None = None, +) -> dict: + """下单区展示:最小变动价位、每跳盈亏、保证金等。""" + spec = get_contract_spec(ths_code) + mult = int(spec["mult"]) + tick = float(spec.get("tick_size") or 1.0) + margin_rate = float(spec["margin_rate"]) + lots_i = max(1, int(lots or 1)) + tick_value_per_lot = round(tick * mult, 4) + tick_value_total = round(tick_value_per_lot * lots_i, 2) + prec = price_precision_from_tick(tick) + mark = float(price) if price else 0.0 + margin_per_lot = None + margin_source = "estimate" + if mark > 0: + margin_per_lot, margin_source, spec_used = margin_one_lot( + ths_code, mark, direction=direction, trading_mode=trading_mode, + ) + if spec_used.get("mult"): + mult = int(spec_used["mult"]) + if spec_used.get("tick_size"): + tick = float(spec_used["tick_size"]) + tick_value_per_lot = round(tick * mult, 4) + tick_value_total = round(tick_value_per_lot * lots_i, 2) + prec = price_precision_from_tick(tick) + if margin_per_lot <= 0: + margin_per_lot = round(mark * mult * margin_rate, 2) + margin_source = "estimate" + margin_total = round(margin_per_lot * lots_i, 2) if margin_per_lot else None + return { + "mult": mult, + "tick_size": tick, + "price_precision": prec, + "tick_value_per_lot": tick_value_per_lot, + "tick_value_total": tick_value_total, + "lots": lots_i, + "margin_per_lot": margin_per_lot, + "margin_total": margin_total, + "margin_rate": margin_rate, + "margin_source": margin_source, + } + + +def calc_margin_usage_pct( + positions: list[dict], + capital: float, + *, + extra_symbol: str = "", + extra_lots: int = 0, + extra_price: float = 0, + extra_direction: str = "long", + trading_mode: str | None = None, +) -> float: + """当前持仓 + 拟开仓占权益的保证金比例(%)。""" + cap = float(capital or 0) + if cap <= 0: + return 999.0 + total = 0.0 + for p in positions: + lots = int(p.get("lots") or 0) + if lots <= 0: + continue + ctp_margin = float(p.get("margin") or 0) + if ctp_margin > 0: + total += ctp_margin + continue + sym = (p.get("symbol") or p.get("symbol_code") or "").strip() + entry = float(p.get("avg_price") or p.get("entry_price") or 0) + direction = (p.get("direction") or "long").strip().lower() + if entry <= 0 or not sym: + continue + per_lot, _, _ = margin_one_lot( + sym, entry, direction=direction, trading_mode=trading_mode, + ) + if per_lot <= 0: + spec = get_contract_spec(sym) + per_lot = entry * spec["mult"] * spec["margin_rate"] + total += per_lot * lots + if extra_symbol and extra_lots > 0 and extra_price > 0: + per_lot, _, _ = margin_one_lot( + extra_symbol, extra_price, direction=extra_direction, trading_mode=trading_mode, + ) + if per_lot <= 0: + spec = get_contract_spec(extra_symbol) + per_lot = extra_price * spec["mult"] * spec["margin_rate"] + total += per_lot * extra_lots + return round(total / cap * 100.0, 2) + + +def cap_lots_for_margin_budget( + positions: list[dict], + capital: float, + symbol: str, + direction: str, + price: float, + desired_lots: int, + max_margin_pct: float, + trading_mode: str | None = None, +) -> tuple[int, float]: + """在保证金上限内,返回可加仓手数及占用比例。""" + desired = max(0, int(desired_lots or 0)) + if desired <= 0: + return 0, calc_margin_usage_pct(positions, capital, trading_mode=trading_mode) + for lots in range(desired, 0, -1): + usage = calc_margin_usage_pct( + positions, + capital, + extra_symbol=symbol, + extra_lots=lots, + extra_price=price, + extra_direction=direction, + trading_mode=trading_mode, + ) + if usage <= max_margin_pct: + return lots, usage + return 0, calc_margin_usage_pct( + positions, + capital, + extra_symbol=symbol, + extra_lots=desired, + extra_price=price, + extra_direction=direction, + trading_mode=trading_mode, + ) diff --git a/position_stream.py b/modules/trading/position_stream.py similarity index 94% rename from position_stream.py rename to modules/trading/position_stream.py index dadc36e..192ef0a 100644 --- a/position_stream.py +++ b/modules/trading/position_stream.py @@ -1,113 +1,113 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""持仓监控:后台拉取 CTP 并 SSE 推送给前端(避免每次刷新阻塞读柜台)。""" -from __future__ import annotations - -import logging -import queue -import threading -import time -from typing import Callable, Optional - -from kline_stream import sse_format -from market_sessions import is_trading_session - -logger = logging.getLogger(__name__) - -PUSH_INTERVAL_SEC = 1 -IDLE_INTERVAL_SEC = 5 - - -class PositionStreamHub: - def __init__(self) -> None: - self._lock = threading.Lock() - self._subs: list[queue.Queue] = [] - self._snapshot: Optional[dict] = None - self._snapshot_ts: float = 0.0 - - def subscribe(self) -> queue.Queue: - q: queue.Queue = queue.Queue(maxsize=16) - with self._lock: - self._subs.append(q) - return q - - def unsubscribe(self, q: queue.Queue) -> None: - with self._lock: - try: - self._subs.remove(q) - except ValueError: - pass - - def get_snapshot(self) -> Optional[dict]: - with self._lock: - return dict(self._snapshot) if self._snapshot else None - - def set_snapshot(self, data: dict) -> None: - with self._lock: - self._snapshot = dict(data) - self._snapshot_ts = time.time() - - def _fanout(self, event: str, data: dict) -> None: - msg = {"event": event, "data": data} - with self._lock: - subs = list(self._subs) - for q in subs: - try: - q.put_nowait(msg) - except queue.Full: - try: - q.get_nowait() - except queue.Empty: - pass - try: - q.put_nowait(msg) - except queue.Full: - pass - - def broadcast(self, event: str, data: dict) -> None: - self.set_snapshot(data) - self._fanout(event, data) - - def push_event(self, event: str, data: dict) -> None: - """SSE 推送,不覆盖 positions 全量快照。""" - self._fanout(event, data) - - -position_hub = PositionStreamHub() - - -def start_position_worker( - *, - refresh_fn: Callable[[], dict], - interval: int = PUSH_INTERVAL_SEC, - idle_interval: int = IDLE_INTERVAL_SEC, -) -> None: - """后台定时刷新持仓快照并 SSE 广播。""" - - def _loop() -> None: - while True: - sleep_sec = idle_interval - try: - payload = refresh_fn() - if payload: - position_hub.broadcast("positions", payload) - ctp_st = (payload or {}).get("ctp_status") or {} - connected = bool(ctp_st.get("connected")) - in_session = bool((payload or {}).get("trading_session")) - rows = (payload or {}).get("rows") or [] - has_sl_tp = any( - r.get("stop_loss") is not None or r.get("take_profit") is not None - for r in rows - ) - if connected and in_session: - sleep_sec = max(1, interval) - elif connected: - sleep_sec = max(2, min(idle_interval, 3)) - except Exception as exc: - logger.warning("position worker failed: %s", exc) - time.sleep(sleep_sec) - - threading.Thread(target=_loop, daemon=True, name="position-stream").start() +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""持仓监控:后台拉取 CTP 并 SSE 推送给前端(避免每次刷新阻塞读柜台)。""" +from __future__ import annotations + +import logging +import queue +import threading +import time +from typing import Callable, Optional + +from modules.market.kline_stream import sse_format +from modules.market.market_sessions import is_trading_session + +logger = logging.getLogger(__name__) + +PUSH_INTERVAL_SEC = 1 +IDLE_INTERVAL_SEC = 5 + + +class PositionStreamHub: + def __init__(self) -> None: + self._lock = threading.Lock() + self._subs: list[queue.Queue] = [] + self._snapshot: Optional[dict] = None + self._snapshot_ts: float = 0.0 + + def subscribe(self) -> queue.Queue: + q: queue.Queue = queue.Queue(maxsize=16) + with self._lock: + self._subs.append(q) + return q + + def unsubscribe(self, q: queue.Queue) -> None: + with self._lock: + try: + self._subs.remove(q) + except ValueError: + pass + + def get_snapshot(self) -> Optional[dict]: + with self._lock: + return dict(self._snapshot) if self._snapshot else None + + def set_snapshot(self, data: dict) -> None: + with self._lock: + self._snapshot = dict(data) + self._snapshot_ts = time.time() + + def _fanout(self, event: str, data: dict) -> None: + msg = {"event": event, "data": data} + with self._lock: + subs = list(self._subs) + for q in subs: + try: + q.put_nowait(msg) + except queue.Full: + try: + q.get_nowait() + except queue.Empty: + pass + try: + q.put_nowait(msg) + except queue.Full: + pass + + def broadcast(self, event: str, data: dict) -> None: + self.set_snapshot(data) + self._fanout(event, data) + + def push_event(self, event: str, data: dict) -> None: + """SSE 推送,不覆盖 positions 全量快照。""" + self._fanout(event, data) + + +position_hub = PositionStreamHub() + + +def start_position_worker( + *, + refresh_fn: Callable[[], dict], + interval: int = PUSH_INTERVAL_SEC, + idle_interval: int = IDLE_INTERVAL_SEC, +) -> None: + """后台定时刷新持仓快照并 SSE 广播。""" + + def _loop() -> None: + while True: + sleep_sec = idle_interval + try: + payload = refresh_fn() + if payload: + position_hub.broadcast("positions", payload) + ctp_st = (payload or {}).get("ctp_status") or {} + connected = bool(ctp_st.get("connected")) + in_session = bool((payload or {}).get("trading_session")) + rows = (payload or {}).get("rows") or [] + has_sl_tp = any( + r.get("stop_loss") is not None or r.get("take_profit") is not None + for r in rows + ) + if connected and in_session: + sleep_sec = max(1, interval) + elif connected: + sleep_sec = max(2, min(idle_interval, 3)) + except Exception as exc: + logger.warning("position worker failed: %s", exc) + time.sleep(sleep_sec) + + threading.Thread(target=_loop, daemon=True, name="position-stream").start() diff --git a/product_recommend.py b/modules/trading/product_recommend.py similarity index 94% rename from product_recommend.py rename to modules/trading/product_recommend.py index 3889409..480c70b 100644 --- a/product_recommend.py +++ b/modules/trading/product_recommend.py @@ -1,335 +1,335 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""按账户资金筛选可开仓品种(保证金与仓位纪律)。""" -from __future__ import annotations - -import logging -import math -from concurrent.futures import ThreadPoolExecutor -from typing import Callable, Optional - -from contract_specs import get_contract_spec, margin_one_lot -from fee_specs import calc_fee_breakdown -from recommend_trend import analyze_product_daily, sort_recommend_by_trend -from symbols import PRODUCTS, product_category, product_has_night_session - -logger = logging.getLogger(__name__) - -# 权益不超过该值时,仅允许下列品种(可开仓列表、品种下拉、开仓报单) -SMALL_ACCOUNT_CAPITAL_MAX = 200_000.0 -# 未连接 CTP 时,可开仓品种表按该权益估算最大手数(与参考资金设置无关) -DISCONNECTED_RECOMMEND_CAPITAL = 100_000.0 -SMALL_ACCOUNT_PRODUCT_THS = frozenset({"c", "m", "MA", "rb"}) -SMALL_ACCOUNT_SCOPE_LABEL = "玉米、豆粕、甲醇、螺纹钢" -SMALL_ACCOUNT_RECOMMENDED_OPEN_MARGIN_PCT = 30.0 -SMALL_ACCOUNT_RECOMMENDED_ROLL_MARGIN_PCT = 40.0 - - -def small_account_margin_recommendations() -> dict: - """20 万以下账户建议的保证金比例(供系统设置参考)。""" - wan = int(SMALL_ACCOUNT_CAPITAL_MAX // 10_000) - return { - "open_margin_pct": SMALL_ACCOUNT_RECOMMENDED_OPEN_MARGIN_PCT, - "roll_margin_pct": SMALL_ACCOUNT_RECOMMENDED_ROLL_MARGIN_PCT, - "label": ( - f"权益 {wan} 万以下建议:开仓保证金上限 " - f"{int(SMALL_ACCOUNT_RECOMMENDED_OPEN_MARGIN_PCT)}%," - f"滚仓总保证金不超过 {int(SMALL_ACCOUNT_RECOMMENDED_ROLL_MARGIN_PCT)}%" - ), - } - - -def small_account_scope_hint(*, ctp_connected: bool = True) -> str: - wan = int(SMALL_ACCOUNT_CAPITAL_MAX // 10_000) - if not ctp_connected: - rec_wan = int(DISCONNECTED_RECOMMEND_CAPITAL // 10_000) - return ( - f"未连接 CTP,按 {rec_wan} 万权益估算最大手数," - f"仅显示并可交易 {SMALL_ACCOUNT_SCOPE_LABEL}" - ) - return f"权益 {wan} 万以下仅显示并可交易:{SMALL_ACCOUNT_SCOPE_LABEL}" - - -def small_account_scope_status_label() -> str: - wan = int(SMALL_ACCOUNT_CAPITAL_MAX // 10_000) - return f"权益{wan}万以下限{SMALL_ACCOUNT_SCOPE_LABEL}" - - -def should_apply_small_account_scope( - capital: float, - *, - ctp_connected: bool, -) -> bool: - """SimNow/实盘一致:未连接 CTP 时默认按 20 万以下四品种范围。""" - if not ctp_connected: - return True - return is_small_account(capital) - - -def filter_rows_for_account_scope( - rows: list[dict], - capital: float, - *, - ctp_connected: bool, -) -> list[dict]: - if not should_apply_small_account_scope(capital, ctp_connected=ctp_connected): - return rows - return [r for r in rows if product_in_small_account_whitelist(r.get("ths") or "")] - - -def normalize_product_ths(ths: str) -> str: - import re - s = (ths or "").strip() - m = re.match(r"^([A-Za-z]+)", s) - return m.group(1) if m else s - - -def is_small_account(capital: float) -> bool: - cap = float(capital or 0) - return 0 < cap <= SMALL_ACCOUNT_CAPITAL_MAX - - -def product_in_small_account_whitelist(ths_or_product) -> bool: - if isinstance(ths_or_product, dict): - key = (ths_or_product.get("ths") or "").strip() - else: - key = normalize_product_ths(str(ths_or_product or "")) - if not key: - return False - root = normalize_product_ths(key) - if root in SMALL_ACCOUNT_PRODUCT_THS: - return True - upper = root.upper() - return upper in {x.upper() for x in SMALL_ACCOUNT_PRODUCT_THS} - - -def assert_product_allowed_for_capital( - ths: str, - capital: float, - *, - ctp_connected: bool = True, -) -> Optional[str]: - """小账户品种白名单校验;通过返回 None。""" - if not should_apply_small_account_scope(capital, ctp_connected=ctp_connected): - return None - if product_in_small_account_whitelist(ths): - return None - wan = int(SMALL_ACCOUNT_CAPITAL_MAX // 10_000) - if not ctp_connected: - return f"未连接 CTP,仅可交易:{SMALL_ACCOUNT_SCOPE_LABEL}" - return f"权益 {wan} 万以下仅可交易:{SMALL_ACCOUNT_SCOPE_LABEL}" - - -def filter_products_for_capital( - products: list[dict], - capital: float, - *, - ctp_connected: bool = True, -) -> list[dict]: - if not should_apply_small_account_scope(capital, ctp_connected=ctp_connected): - return list(products) - return [p for p in products if product_in_small_account_whitelist(p)] - - -def _attach_turnover(row: dict) -> None: - """成交额 = 昨日成交量(手) × 昨收 × 合约乘数。""" - try: - vol = float(row.get("volume") or 0) - price = float(row.get("prev_close") or row.get("price") or 0) - mult = float(row.get("mult") or 0) - except (TypeError, ValueError): - return - if vol > 0 and price > 0 and mult > 0: - row["turnover"] = round(vol * price * mult, 2) - - -def _letters_from_ths(ths_code: str) -> str: - import re - m = re.match(r"^([A-Za-z]+)", (ths_code or "").strip()) - return m.group(1) if m else "" - - -def assess_product_for_capital( - product: dict, - capital: float, - price: Optional[float], - *, - max_margin_pct: float = 30.0, - default_stop_ticks: int = 20, - reward_risk_ratio: float = 2.0, - trading_mode: str = "simulation", - ctp_connected: bool = True, - main_code: str = "", - margin_used: float = 0.0, -) -> dict: - """评估单品种在当前资金下是否可交易。""" - ths = product.get("ths") or "" - name = product.get("name") or ths - exchange = product.get("exchange") or "" - category = product.get("category") or product_category(ths) - spec = get_contract_spec(ths + "8888") - mult = spec["mult"] - margin_rate = spec["margin_rate"] - tick = float(spec.get("tick_size") or 1.0) - p = float(price) if price and price > 0 else 0.0 - cap = float(capital or 0) - margin_pct = max(1.0, min(100.0, float(max_margin_pct or 30.0))) - - if should_apply_small_account_scope(cap, ctp_connected=ctp_connected) and not product_in_small_account_whitelist(product): - return { - "ths": ths, - "name": name, - "exchange": exchange, - "category": category, - "mult": spec["mult"], - "tick_size": tick, - "status": "blocked", - "status_label": small_account_scope_status_label(), - "min_capital_one_lot": None, - "margin_one_lot": None, - "max_lots": 0, - "risk_one_lot_1pct": None, - "has_night_session": product_has_night_session(product), - } - - if p <= 0: - return { - "ths": ths, - "name": name, - "exchange": exchange, - "category": category, - "mult": mult, - "tick_size": tick, - "status": "no_price", - "status_label": "暂无行情", - "min_capital_one_lot": None, - "margin_one_lot": None, - "max_lots": 0, - "risk_one_lot_1pct": None, - "has_night_session": product_has_night_session(product), - } - - margin_source = None - code_for_margin = (main_code or "").strip() or (ths + "8888") - if p > 0 and ctp_connected: - margin_one, margin_source, spec_used = margin_one_lot( - code_for_margin, p, direction="max", trading_mode=trading_mode, - ) - if spec_used.get("mult"): - mult = spec_used["mult"] - if spec_used.get("tick_size"): - tick = float(spec_used["tick_size"]) - else: - margin_one = p * mult * margin_rate - min_capital = margin_one / (margin_pct / 100.0) if margin_pct > 0 else margin_one - margin_budget = cap * margin_pct / 100.0 if cap > 0 else 0.0 - margin_budget = max(0.0, margin_budget - max(0.0, float(margin_used or 0))) - max_lots = int(math.floor(margin_budget / margin_one)) if margin_one > 0 and margin_budget > 0 else 0 - stop_dist = tick * default_stop_ticks - risk_one_lot = stop_dist * mult - risk_pct_1lot = (risk_one_lot / cap * 100) if cap > 0 else 999.0 - ref_sl = round(p - stop_dist, 4) - ref_tp = round(p + stop_dist * reward_risk_ratio, 4) - fee_ths = ths + "8888" - try: - fee_info = calc_fee_breakdown( - fee_ths, p, p, 1.0, open_time="", close_time="", trading_mode=trading_mode, - ) - except Exception as exc: - logger.debug("recommend fee calc failed %s: %s", ths, exc) - fee_info = {"open_fee": 0.0, "total_fee": 0.0} - - can_margin = max_lots >= 1 - can_risk = cap > 0 and risk_one_lot <= cap * 0.01 - - if can_margin and can_risk: - status, label = "ok", f"最大 {max_lots} 手" - elif can_margin: - status, label = "margin_ok", f"最大 {max_lots} 手·止损偏宽" - else: - status, label = "blocked", "资金不足" - if margin_source == "ctp" and can_margin: - label += "(柜台保证金)" - - row_out = { - "ths": ths, - "name": name, - "exchange": exchange, - "category": category, - "price": round(p, 4), - "mult": mult, - "tick_size": tick, - "margin_one_lot": round(margin_one, 2), - "min_capital_one_lot": round(min_capital, 2), - "max_lots": max_lots, - "margin_budget": round(margin_budget, 2), - "max_margin_pct": margin_pct, - "risk_one_lot_1pct": round(risk_one_lot, 2), - "risk_pct_1lot_at_1pct_rule": round(risk_pct_1lot, 2), - "ref_stop_loss": ref_sl, - "ref_take_profit": ref_tp, - "open_fee_one_lot": fee_info["open_fee"], - "roundtrip_fee_one_lot": fee_info["total_fee"], - "status": status, - "status_label": label, - "has_night_session": product_has_night_session(product), - } - if margin_source: - row_out["margin_source"] = margin_source - return row_out - - -def list_product_recommendations( - capital: float, - quote_fn: Callable[[str], Optional[dict]], - *, - max_margin_pct: float = 30.0, - trading_mode: str = "simulation", - ctp_connected: bool = True, - margin_used: float = 0.0, -) -> list[dict]: - """扫描全部品种并排序:可开且纪律友好 > 可开 > 不足。quote_fn(品种代码) -> {price, ths_code, ...}""" - - def _one(product: dict) -> dict: - ths = product["ths"] - try: - quote = quote_fn(ths) or {} - price = quote.get("price") - main_code = (quote.get("ths_code") or "").strip() - row = assess_product_for_capital( - product, capital, price, - max_margin_pct=max_margin_pct, - trading_mode=trading_mode, - ctp_connected=ctp_connected, - main_code=main_code, - margin_used=margin_used, - ) - row["main_code"] = main_code - if main_code: - row.update(analyze_product_daily(main_code)) - _attach_turnover(row) - return row - except Exception as exc: - logger.warning("recommend product failed %s: %s", ths, exc) - spec = get_contract_spec(ths + "8888") - return { - "ths": ths, - "name": product.get("name") or ths, - "exchange": product.get("exchange") or "", - "category": product.get("category") or product_category(ths), - "mult": spec["mult"], - "tick_size": float(spec.get("tick_size") or 1.0), - "status": "no_price", - "status_label": "计算失败", - "main_code": "", - "max_lots": 0, - "has_night_session": product_has_night_session(product), - } - - with ThreadPoolExecutor(max_workers=10) as pool: - products = filter_products_for_capital(PRODUCTS, capital) - rows = list(pool.map(_one, products)) - return sort_recommend_by_trend(rows) +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""按账户资金筛选可开仓品种(保证金与仓位纪律)。""" +from __future__ import annotations + +import logging +import math +from concurrent.futures import ThreadPoolExecutor +from typing import Callable, Optional + +from modules.core.contract_specs import get_contract_spec, margin_one_lot +from modules.fees.fee_specs import calc_fee_breakdown +from modules.trading.recommend_trend import analyze_product_daily, sort_recommend_by_trend +from modules.core.symbols import PRODUCTS, product_category, product_has_night_session + +logger = logging.getLogger(__name__) + +# 权益不超过该值时,仅允许下列品种(可开仓列表、品种下拉、开仓报单) +SMALL_ACCOUNT_CAPITAL_MAX = 200_000.0 +# 未连接 CTP 时,可开仓品种表按该权益估算最大手数(与参考资金设置无关) +DISCONNECTED_RECOMMEND_CAPITAL = 100_000.0 +SMALL_ACCOUNT_PRODUCT_THS = frozenset({"c", "m", "MA", "rb"}) +SMALL_ACCOUNT_SCOPE_LABEL = "玉米、豆粕、甲醇、螺纹钢" +SMALL_ACCOUNT_RECOMMENDED_OPEN_MARGIN_PCT = 30.0 +SMALL_ACCOUNT_RECOMMENDED_ROLL_MARGIN_PCT = 40.0 + + +def small_account_margin_recommendations() -> dict: + """20 万以下账户建议的保证金比例(供系统设置参考)。""" + wan = int(SMALL_ACCOUNT_CAPITAL_MAX // 10_000) + return { + "open_margin_pct": SMALL_ACCOUNT_RECOMMENDED_OPEN_MARGIN_PCT, + "roll_margin_pct": SMALL_ACCOUNT_RECOMMENDED_ROLL_MARGIN_PCT, + "label": ( + f"权益 {wan} 万以下建议:开仓保证金上限 " + f"{int(SMALL_ACCOUNT_RECOMMENDED_OPEN_MARGIN_PCT)}%," + f"滚仓总保证金不超过 {int(SMALL_ACCOUNT_RECOMMENDED_ROLL_MARGIN_PCT)}%" + ), + } + + +def small_account_scope_hint(*, ctp_connected: bool = True) -> str: + wan = int(SMALL_ACCOUNT_CAPITAL_MAX // 10_000) + if not ctp_connected: + rec_wan = int(DISCONNECTED_RECOMMEND_CAPITAL // 10_000) + return ( + f"未连接 CTP,按 {rec_wan} 万权益估算最大手数," + f"仅显示并可交易 {SMALL_ACCOUNT_SCOPE_LABEL}" + ) + return f"权益 {wan} 万以下仅显示并可交易:{SMALL_ACCOUNT_SCOPE_LABEL}" + + +def small_account_scope_status_label() -> str: + wan = int(SMALL_ACCOUNT_CAPITAL_MAX // 10_000) + return f"权益{wan}万以下限{SMALL_ACCOUNT_SCOPE_LABEL}" + + +def should_apply_small_account_scope( + capital: float, + *, + ctp_connected: bool, +) -> bool: + """SimNow/实盘一致:未连接 CTP 时默认按 20 万以下四品种范围。""" + if not ctp_connected: + return True + return is_small_account(capital) + + +def filter_rows_for_account_scope( + rows: list[dict], + capital: float, + *, + ctp_connected: bool, +) -> list[dict]: + if not should_apply_small_account_scope(capital, ctp_connected=ctp_connected): + return rows + return [r for r in rows if product_in_small_account_whitelist(r.get("ths") or "")] + + +def normalize_product_ths(ths: str) -> str: + import re + s = (ths or "").strip() + m = re.match(r"^([A-Za-z]+)", s) + return m.group(1) if m else s + + +def is_small_account(capital: float) -> bool: + cap = float(capital or 0) + return 0 < cap <= SMALL_ACCOUNT_CAPITAL_MAX + + +def product_in_small_account_whitelist(ths_or_product) -> bool: + if isinstance(ths_or_product, dict): + key = (ths_or_product.get("ths") or "").strip() + else: + key = normalize_product_ths(str(ths_or_product or "")) + if not key: + return False + root = normalize_product_ths(key) + if root in SMALL_ACCOUNT_PRODUCT_THS: + return True + upper = root.upper() + return upper in {x.upper() for x in SMALL_ACCOUNT_PRODUCT_THS} + + +def assert_product_allowed_for_capital( + ths: str, + capital: float, + *, + ctp_connected: bool = True, +) -> Optional[str]: + """小账户品种白名单校验;通过返回 None。""" + if not should_apply_small_account_scope(capital, ctp_connected=ctp_connected): + return None + if product_in_small_account_whitelist(ths): + return None + wan = int(SMALL_ACCOUNT_CAPITAL_MAX // 10_000) + if not ctp_connected: + return f"未连接 CTP,仅可交易:{SMALL_ACCOUNT_SCOPE_LABEL}" + return f"权益 {wan} 万以下仅可交易:{SMALL_ACCOUNT_SCOPE_LABEL}" + + +def filter_products_for_capital( + products: list[dict], + capital: float, + *, + ctp_connected: bool = True, +) -> list[dict]: + if not should_apply_small_account_scope(capital, ctp_connected=ctp_connected): + return list(products) + return [p for p in products if product_in_small_account_whitelist(p)] + + +def _attach_turnover(row: dict) -> None: + """成交额 = 昨日成交量(手) × 昨收 × 合约乘数。""" + try: + vol = float(row.get("volume") or 0) + price = float(row.get("prev_close") or row.get("price") or 0) + mult = float(row.get("mult") or 0) + except (TypeError, ValueError): + return + if vol > 0 and price > 0 and mult > 0: + row["turnover"] = round(vol * price * mult, 2) + + +def _letters_from_ths(ths_code: str) -> str: + import re + m = re.match(r"^([A-Za-z]+)", (ths_code or "").strip()) + return m.group(1) if m else "" + + +def assess_product_for_capital( + product: dict, + capital: float, + price: Optional[float], + *, + max_margin_pct: float = 30.0, + default_stop_ticks: int = 20, + reward_risk_ratio: float = 2.0, + trading_mode: str = "simulation", + ctp_connected: bool = True, + main_code: str = "", + margin_used: float = 0.0, +) -> dict: + """评估单品种在当前资金下是否可交易。""" + ths = product.get("ths") or "" + name = product.get("name") or ths + exchange = product.get("exchange") or "" + category = product.get("category") or product_category(ths) + spec = get_contract_spec(ths + "8888") + mult = spec["mult"] + margin_rate = spec["margin_rate"] + tick = float(spec.get("tick_size") or 1.0) + p = float(price) if price and price > 0 else 0.0 + cap = float(capital or 0) + margin_pct = max(1.0, min(100.0, float(max_margin_pct or 30.0))) + + if should_apply_small_account_scope(cap, ctp_connected=ctp_connected) and not product_in_small_account_whitelist(product): + return { + "ths": ths, + "name": name, + "exchange": exchange, + "category": category, + "mult": spec["mult"], + "tick_size": tick, + "status": "blocked", + "status_label": small_account_scope_status_label(), + "min_capital_one_lot": None, + "margin_one_lot": None, + "max_lots": 0, + "risk_one_lot_1pct": None, + "has_night_session": product_has_night_session(product), + } + + if p <= 0: + return { + "ths": ths, + "name": name, + "exchange": exchange, + "category": category, + "mult": mult, + "tick_size": tick, + "status": "no_price", + "status_label": "暂无行情", + "min_capital_one_lot": None, + "margin_one_lot": None, + "max_lots": 0, + "risk_one_lot_1pct": None, + "has_night_session": product_has_night_session(product), + } + + margin_source = None + code_for_margin = (main_code or "").strip() or (ths + "8888") + if p > 0 and ctp_connected: + margin_one, margin_source, spec_used = margin_one_lot( + code_for_margin, p, direction="max", trading_mode=trading_mode, + ) + if spec_used.get("mult"): + mult = spec_used["mult"] + if spec_used.get("tick_size"): + tick = float(spec_used["tick_size"]) + else: + margin_one = p * mult * margin_rate + min_capital = margin_one / (margin_pct / 100.0) if margin_pct > 0 else margin_one + margin_budget = cap * margin_pct / 100.0 if cap > 0 else 0.0 + margin_budget = max(0.0, margin_budget - max(0.0, float(margin_used or 0))) + max_lots = int(math.floor(margin_budget / margin_one)) if margin_one > 0 and margin_budget > 0 else 0 + stop_dist = tick * default_stop_ticks + risk_one_lot = stop_dist * mult + risk_pct_1lot = (risk_one_lot / cap * 100) if cap > 0 else 999.0 + ref_sl = round(p - stop_dist, 4) + ref_tp = round(p + stop_dist * reward_risk_ratio, 4) + fee_ths = ths + "8888" + try: + fee_info = calc_fee_breakdown( + fee_ths, p, p, 1.0, open_time="", close_time="", trading_mode=trading_mode, + ) + except Exception as exc: + logger.debug("recommend fee calc failed %s: %s", ths, exc) + fee_info = {"open_fee": 0.0, "total_fee": 0.0} + + can_margin = max_lots >= 1 + can_risk = cap > 0 and risk_one_lot <= cap * 0.01 + + if can_margin and can_risk: + status, label = "ok", f"最大 {max_lots} 手" + elif can_margin: + status, label = "margin_ok", f"最大 {max_lots} 手·止损偏宽" + else: + status, label = "blocked", "资金不足" + if margin_source == "ctp" and can_margin: + label += "(柜台保证金)" + + row_out = { + "ths": ths, + "name": name, + "exchange": exchange, + "category": category, + "price": round(p, 4), + "mult": mult, + "tick_size": tick, + "margin_one_lot": round(margin_one, 2), + "min_capital_one_lot": round(min_capital, 2), + "max_lots": max_lots, + "margin_budget": round(margin_budget, 2), + "max_margin_pct": margin_pct, + "risk_one_lot_1pct": round(risk_one_lot, 2), + "risk_pct_1lot_at_1pct_rule": round(risk_pct_1lot, 2), + "ref_stop_loss": ref_sl, + "ref_take_profit": ref_tp, + "open_fee_one_lot": fee_info["open_fee"], + "roundtrip_fee_one_lot": fee_info["total_fee"], + "status": status, + "status_label": label, + "has_night_session": product_has_night_session(product), + } + if margin_source: + row_out["margin_source"] = margin_source + return row_out + + +def list_product_recommendations( + capital: float, + quote_fn: Callable[[str], Optional[dict]], + *, + max_margin_pct: float = 30.0, + trading_mode: str = "simulation", + ctp_connected: bool = True, + margin_used: float = 0.0, +) -> list[dict]: + """扫描全部品种并排序:可开且纪律友好 > 可开 > 不足。quote_fn(品种代码) -> {price, ths_code, ...}""" + + def _one(product: dict) -> dict: + ths = product["ths"] + try: + quote = quote_fn(ths) or {} + price = quote.get("price") + main_code = (quote.get("ths_code") or "").strip() + row = assess_product_for_capital( + product, capital, price, + max_margin_pct=max_margin_pct, + trading_mode=trading_mode, + ctp_connected=ctp_connected, + main_code=main_code, + margin_used=margin_used, + ) + row["main_code"] = main_code + if main_code: + row.update(analyze_product_daily(main_code)) + _attach_turnover(row) + return row + except Exception as exc: + logger.warning("recommend product failed %s: %s", ths, exc) + spec = get_contract_spec(ths + "8888") + return { + "ths": ths, + "name": product.get("name") or ths, + "exchange": product.get("exchange") or "", + "category": product.get("category") or product_category(ths), + "mult": spec["mult"], + "tick_size": float(spec.get("tick_size") or 1.0), + "status": "no_price", + "status_label": "计算失败", + "main_code": "", + "max_lots": 0, + "has_night_session": product_has_night_session(product), + } + + with ThreadPoolExecutor(max_workers=10) as pool: + products = filter_products_for_capital(PRODUCTS, capital) + rows = list(pool.map(_one, products)) + return sort_recommend_by_trend(rows) diff --git a/recommend_store.py b/modules/trading/recommend_store.py similarity index 92% rename from recommend_store.py rename to modules/trading/recommend_store.py index f089592..5790399 100644 --- a/recommend_store.py +++ b/modules/trading/recommend_store.py @@ -1,399 +1,399 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""可开仓品种:计算、按资金过滤、SQLite 缓存。""" -from __future__ import annotations - -import json -import logging -import math -from datetime import datetime -from typing import Callable, Optional - -from contract_specs import get_contract_spec, margin_one_lot -from fee_specs import ensure_fee_rates_schema -from product_recommend import ( - _attach_turnover, - filter_rows_for_account_scope, - list_product_recommendations, -) -from recommend_trend import sort_recommend_by_trend -from symbols import product_category - -logger = logging.getLogger(__name__) - -RECOMMEND_CACHE_SQL = """ -CREATE TABLE IF NOT EXISTS product_recommend_cache ( - id INTEGER PRIMARY KEY CHECK (id = 1), - capital REAL NOT NULL DEFAULT 0, - rows_json TEXT NOT NULL DEFAULT '[]', - updated_at TEXT -) -""" - - -def ensure_recommend_tables(conn) -> None: - conn.execute(RECOMMEND_CACHE_SQL) - - -def filter_affordable_recommendations(rows: list[dict]) -> list[dict]: - """仅保留当前资金可开 1 手的品种(不含资金不足、无行情)。""" - return [r for r in rows if r.get("status") in ("ok", "margin_ok")] - - -def rows_missing_max_lots(rows: list[dict]) -> bool: - """缓存是否为旧版(缺少最大手数字段)。""" - if not rows: - return False - return any("max_lots" not in r for r in rows) - - -def rows_missing_trend(rows: list[dict]) -> bool: - """缓存是否为旧版(缺少走势字段)。""" - if not rows: - return False - return any("trend" not in r for r in rows) - - -def rows_missing_daily_stats(rows: list[dict]) -> bool: - """缓存是否为旧版(缺少跳空/量价字段)。""" - if not rows: - return False - return any("gap" not in r for r in rows) - - -def rows_missing_category(rows: list[dict]) -> bool: - if not rows: - return False - return any("category" not in r for r in rows) - - -def rows_missing_turnover(rows: list[dict]) -> bool: - if not rows: - return False - return any("turnover" not in r for r in rows) - - -def rows_missing_contract_spec(rows: list[dict]) -> bool: - if not rows: - return False - return any("mult" not in r or "tick_size" not in r for r in rows) - - -def recommend_cache_needs_refresh( - cached: dict, - *, - capital: float = 0.0, -) -> bool: - """是否需要重新拉行情计算可开仓列表。""" - if recommend_cache_stale(cached.get("updated_at")): - return True - rows = cached.get("rows") or [] - if rows_missing_max_lots(rows): - return True - if rows_missing_trend(rows): - return True - if rows_missing_daily_stats(rows): - return True - if rows_missing_category(rows): - return True - if rows_missing_turnover(rows): - return True - if rows_missing_contract_spec(rows): - return True - if float(capital or 0) > 0 and not rows: - return True - return False - - -def _ctp_connected_for_mode(trading_mode: str) -> bool: - try: - from position_stream import position_hub - - snap = position_hub.get_snapshot() or {} - st = snap.get("ctp_status") - if isinstance(st, dict) and st: - return bool(st.get("connected")) - except Exception: - pass - del trading_mode - return False - - -def recommend_margin_used(trading_mode: str) -> float: - """当前持仓已占用保证金(各持仓 CTP 回报之和,与柜台持仓保证金一致)。""" - try: - from position_stream import position_hub - - snap = position_hub.get_snapshot() or {} - raw = snap.get("margin_used") - if raw is not None: - return max(0.0, float(raw or 0)) - except Exception: - pass - if not _ctp_connected_for_mode(trading_mode): - return 0.0 - try: - from vnpy_bridge import ctp_account_margin_used, ctp_sum_position_margins - - total = ctp_sum_position_margins( - trading_mode, refresh_if_empty=False, refresh_margin=True, - ) - if total > 0: - return total - used = ctp_account_margin_used(trading_mode) - return float(used) if used and used > 0 else 0.0 - except Exception as exc: - logger.debug("recommend_margin_used: %s", exc) - return 0.0 - - -def margin_budget_info( - capital: float, - max_margin_pct: float, - margin_used: float = 0.0, -) -> dict[str, float]: - """保证金上限总额、已占用、剩余可开额度。""" - cap = float(capital or 0) - pct = max(1.0, min(100.0, float(max_margin_pct or 30.0))) - total = cap * pct / 100.0 if cap > 0 else 0.0 - used = max(0.0, float(margin_used or 0)) - remaining = max(0.0, total - used) - return { - "margin_budget_total": round(total, 2), - "margin_used": round(used, 2), - "margin_budget_remaining": round(remaining, 2), - "max_margin_pct": pct, - } - - -def enrich_recommend_rows( - rows: list[dict], - capital: float, - *, - max_margin_pct: float = 30.0, - trading_mode: str = "simulation", - margin_used: float = 0.0, - use_ctp_margin: bool = True, -) -> list[dict]: - """用当前权益与保证金比例补算最大可开手数(兼容旧缓存)。""" - cap = float(capital or 0) - budget_info = margin_budget_info(cap, max_margin_pct, margin_used) - pct = budget_info["max_margin_pct"] - budget = budget_info["margin_budget_remaining"] - ctp_connected = _ctp_connected_for_mode(trading_mode) - enriched: list[dict] = [] - for raw in rows: - row = dict(raw) - ths = (row.get("ths") or "").strip() - main_code = (row.get("main_code") or "").strip() - spec_code = main_code or (ths + "8888" if ths else "") - if spec_code: - spec = get_contract_spec(spec_code) - if row.get("mult") in (None, ""): - row["mult"] = spec["mult"] - if row.get("tick_size") in (None, ""): - row["tick_size"] = float(spec.get("tick_size") or 1.0) - margin_one = 0.0 - try: - margin_one = float(row.get("margin_one_lot") or 0) - except (TypeError, ValueError): - margin_one = 0.0 - price = float(row.get("price") or 0) - code_for_margin = main_code or spec_code - if price > 0 and code_for_margin: - margin_one, margin_source, spec_used = margin_one_lot( - code_for_margin, - price, - direction="max", - trading_mode=trading_mode if (ctp_connected and use_ctp_margin) else None, - ) - if spec_used.get("mult"): - row["mult"] = spec_used["mult"] - if spec_used.get("tick_size"): - row["tick_size"] = spec_used["tick_size"] - row["margin_one_lot"] = margin_one - if margin_source == "ctp": - row["margin_source"] = "ctp" - row["spec_source"] = "ctp" - if margin_one > 0 and budget > 0: - lots = int(math.floor(budget / margin_one)) - else: - try: - lots = int(row.get("max_lots") or row.get("recommended_lots") or 0) - except (TypeError, ValueError): - lots = 0 - row["max_lots"] = lots - row.pop("recommended_lots", None) - row["margin_budget"] = round(budget, 2) - row["margin_budget_total"] = budget_info["margin_budget_total"] - row["margin_used"] = budget_info["margin_used"] - row["max_margin_pct"] = pct - status = row.get("status") or "" - if lots >= 1 and status in ("ok", "margin_ok"): - src = "柜台" if row.get("margin_source") == "ctp" else "估算" - row["status_label"] = ( - f"最大 {lots} 手" if status == "ok" else f"最大 {lots} 手·止损偏宽" - ) - if row.get("margin_source") == "ctp": - row["status_label"] += f"({src}保证金)" - if budget_info["margin_used"] > 0: - row["status_label"] += "·扣持仓" - elif lots < 1 and status in ("ok", "margin_ok"): - row["status"] = "blocked" - row["status_label"] = "资金不足" - if not row.get("category"): - row["category"] = product_category(row.get("ths") or "") - from symbols import enrich_recommend_row - row = enrich_recommend_row(row) - _attach_turnover(row) - enriched.append(row) - from symbols import filter_for_trading_session - return filter_for_trading_session(enriched) - - -def filter_recommend_by_sizing( - rows: list[dict], - *, - sizing_mode: str, - fixed_lots: int = 1, -) -> list[dict]: - """固定手数模式下:最大手数低于设定值的品种不展示。""" - if (sizing_mode or "").strip().lower() != "fixed": - return rows - fl = max(1, int(fixed_lots or 1)) - return [r for r in rows if int(r.get("max_lots") or 0) >= fl] - - -def refresh_recommend_cache( - conn, - capital: float, - quote_fn: Callable[[str], Optional[dict]], - *, - trading_mode: str = "simulation", - max_margin_pct: float = 30.0, - margin_used: float | None = None, -) -> list[dict]: - """后台拉行情、筛选并写入数据库。""" - ensure_recommend_tables(conn) - ensure_fee_rates_schema(conn) - ctp_connected = _ctp_connected_for_mode(trading_mode) - used = ( - float(margin_used) - if margin_used is not None - else recommend_margin_used(trading_mode) - ) - all_rows = list_product_recommendations( - capital, - quote_fn, - max_margin_pct=max_margin_pct, - trading_mode=trading_mode, - ctp_connected=ctp_connected, - margin_used=used, - ) - rows = filter_affordable_recommendations(all_rows) - if not rows and float(capital or 0) > 0: - logger.warning( - "recommend refresh: 0 affordable rows capital=%.2f total=%d no_price=%d blocked=%d", - float(capital or 0), - len(all_rows), - sum(1 for r in all_rows if r.get("status") == "no_price"), - sum(1 for r in all_rows if r.get("status") == "blocked"), - ) - now = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - conn.execute( - """INSERT INTO product_recommend_cache (id, capital, rows_json, updated_at) - VALUES (1, ?, ?, ?) - ON CONFLICT(id) DO UPDATE SET - capital=excluded.capital, - rows_json=excluded.rows_json, - updated_at=excluded.updated_at""", - (float(capital or 0), json.dumps(rows, ensure_ascii=False), now), - ) - conn.commit() - return rows - - -def recommend_cache_stale(updated_at: Optional[str], *, now: Optional[datetime] = None) -> bool: - """缓存是否不是今日更新(需重新拉行情计算)。""" - if not updated_at: - return True - try: - cached_day = datetime.strptime(str(updated_at)[:10], "%Y-%m-%d").date() - except ValueError: - return True - today = (now or datetime.now()).date() - return cached_day != today - - -def load_recommend_cache(conn) -> dict: - """优先从数据库读取可开仓品种列表。""" - ensure_recommend_tables(conn) - row = conn.execute("SELECT capital, rows_json, updated_at FROM product_recommend_cache WHERE id=1").fetchone() - if not row: - return {"capital": 0.0, "rows": [], "updated_at": None, "stale": True} - try: - rows = json.loads(row["rows_json"] or "[]") - except (TypeError, ValueError, json.JSONDecodeError): - rows = [] - updated_at = row["updated_at"] - return { - "capital": float(row["capital"] or 0), - "rows": rows if isinstance(rows, list) else [], - "updated_at": updated_at, - "stale": recommend_cache_stale(updated_at), - } - - -def recommend_payload( - conn, - *, - live_capital: float, - max_margin_pct: float = 30.0, - trading_mode: str = "simulation", - sizing_mode: str = "fixed", - fixed_lots: int = 1, - use_ctp_margin: bool = True, -) -> dict: - """读取缓存并附带当前权益(展示用,可能与缓存计算时不同)。""" - payload = load_recommend_cache(conn) - cap = float(live_capital or 0) - pct = max(1.0, min(100.0, float(max_margin_pct or 30.0))) - if use_ctp_margin: - used = recommend_margin_used(trading_mode) - else: - used = 0.0 - try: - from position_stream import position_hub - - snap = position_hub.get_snapshot() or {} - raw = snap.get("margin_used") - if raw is not None: - used = max(0.0, float(raw or 0)) - except Exception: - pass - if used <= 0: - used = float(payload.get("margin_used") or 0) - budget_info = margin_budget_info(cap, pct, used) - payload["capital"] = cap - payload["max_margin_pct"] = pct - payload.update(budget_info) - rows = payload.get("rows") or [] - rows = enrich_recommend_rows( - rows, - cap, - max_margin_pct=pct, - trading_mode=trading_mode, - margin_used=used, - use_ctp_margin=use_ctp_margin, - ) - rows = filter_rows_for_account_scope( - rows, cap, ctp_connected=_ctp_connected_for_mode(trading_mode), - ) - rows = filter_recommend_by_sizing(rows, sizing_mode=sizing_mode, fixed_lots=fixed_lots) - rows = sort_recommend_by_trend(rows) - payload["rows"] = rows - payload["needs_refresh"] = recommend_cache_needs_refresh(payload, capital=cap) - return payload +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""可开仓品种:计算、按资金过滤、SQLite 缓存。""" +from __future__ import annotations + +import json +import logging +import math +from datetime import datetime +from typing import Callable, Optional + +from modules.core.contract_specs import get_contract_spec, margin_one_lot +from modules.fees.fee_specs import ensure_fee_rates_schema +from modules.trading.product_recommend import ( + _attach_turnover, + filter_rows_for_account_scope, + list_product_recommendations, +) +from modules.trading.recommend_trend import sort_recommend_by_trend +from modules.core.symbols import product_category + +logger = logging.getLogger(__name__) + +RECOMMEND_CACHE_SQL = """ +CREATE TABLE IF NOT EXISTS product_recommend_cache ( + id INTEGER PRIMARY KEY CHECK (id = 1), + capital REAL NOT NULL DEFAULT 0, + rows_json TEXT NOT NULL DEFAULT '[]', + updated_at TEXT +) +""" + + +def ensure_recommend_tables(conn) -> None: + conn.execute(RECOMMEND_CACHE_SQL) + + +def filter_affordable_recommendations(rows: list[dict]) -> list[dict]: + """仅保留当前资金可开 1 手的品种(不含资金不足、无行情)。""" + return [r for r in rows if r.get("status") in ("ok", "margin_ok")] + + +def rows_missing_max_lots(rows: list[dict]) -> bool: + """缓存是否为旧版(缺少最大手数字段)。""" + if not rows: + return False + return any("max_lots" not in r for r in rows) + + +def rows_missing_trend(rows: list[dict]) -> bool: + """缓存是否为旧版(缺少走势字段)。""" + if not rows: + return False + return any("trend" not in r for r in rows) + + +def rows_missing_daily_stats(rows: list[dict]) -> bool: + """缓存是否为旧版(缺少跳空/量价字段)。""" + if not rows: + return False + return any("gap" not in r for r in rows) + + +def rows_missing_category(rows: list[dict]) -> bool: + if not rows: + return False + return any("category" not in r for r in rows) + + +def rows_missing_turnover(rows: list[dict]) -> bool: + if not rows: + return False + return any("turnover" not in r for r in rows) + + +def rows_missing_contract_spec(rows: list[dict]) -> bool: + if not rows: + return False + return any("mult" not in r or "tick_size" not in r for r in rows) + + +def recommend_cache_needs_refresh( + cached: dict, + *, + capital: float = 0.0, +) -> bool: + """是否需要重新拉行情计算可开仓列表。""" + if recommend_cache_stale(cached.get("updated_at")): + return True + rows = cached.get("rows") or [] + if rows_missing_max_lots(rows): + return True + if rows_missing_trend(rows): + return True + if rows_missing_daily_stats(rows): + return True + if rows_missing_category(rows): + return True + if rows_missing_turnover(rows): + return True + if rows_missing_contract_spec(rows): + return True + if float(capital or 0) > 0 and not rows: + return True + return False + + +def _ctp_connected_for_mode(trading_mode: str) -> bool: + try: + from modules.trading.position_stream import position_hub + + snap = position_hub.get_snapshot() or {} + st = snap.get("ctp_status") + if isinstance(st, dict) and st: + return bool(st.get("connected")) + except Exception: + pass + del trading_mode + return False + + +def recommend_margin_used(trading_mode: str) -> float: + """当前持仓已占用保证金(各持仓 CTP 回报之和,与柜台持仓保证金一致)。""" + try: + from modules.trading.position_stream import position_hub + + snap = position_hub.get_snapshot() or {} + raw = snap.get("margin_used") + if raw is not None: + return max(0.0, float(raw or 0)) + except Exception: + pass + if not _ctp_connected_for_mode(trading_mode): + return 0.0 + try: + from modules.ctp.vnpy_bridge import ctp_account_margin_used, ctp_sum_position_margins + + total = ctp_sum_position_margins( + trading_mode, refresh_if_empty=False, refresh_margin=True, + ) + if total > 0: + return total + used = ctp_account_margin_used(trading_mode) + return float(used) if used and used > 0 else 0.0 + except Exception as exc: + logger.debug("recommend_margin_used: %s", exc) + return 0.0 + + +def margin_budget_info( + capital: float, + max_margin_pct: float, + margin_used: float = 0.0, +) -> dict[str, float]: + """保证金上限总额、已占用、剩余可开额度。""" + cap = float(capital or 0) + pct = max(1.0, min(100.0, float(max_margin_pct or 30.0))) + total = cap * pct / 100.0 if cap > 0 else 0.0 + used = max(0.0, float(margin_used or 0)) + remaining = max(0.0, total - used) + return { + "margin_budget_total": round(total, 2), + "margin_used": round(used, 2), + "margin_budget_remaining": round(remaining, 2), + "max_margin_pct": pct, + } + + +def enrich_recommend_rows( + rows: list[dict], + capital: float, + *, + max_margin_pct: float = 30.0, + trading_mode: str = "simulation", + margin_used: float = 0.0, + use_ctp_margin: bool = True, +) -> list[dict]: + """用当前权益与保证金比例补算最大可开手数(兼容旧缓存)。""" + cap = float(capital or 0) + budget_info = margin_budget_info(cap, max_margin_pct, margin_used) + pct = budget_info["max_margin_pct"] + budget = budget_info["margin_budget_remaining"] + ctp_connected = _ctp_connected_for_mode(trading_mode) + enriched: list[dict] = [] + for raw in rows: + row = dict(raw) + ths = (row.get("ths") or "").strip() + main_code = (row.get("main_code") or "").strip() + spec_code = main_code or (ths + "8888" if ths else "") + if spec_code: + spec = get_contract_spec(spec_code) + if row.get("mult") in (None, ""): + row["mult"] = spec["mult"] + if row.get("tick_size") in (None, ""): + row["tick_size"] = float(spec.get("tick_size") or 1.0) + margin_one = 0.0 + try: + margin_one = float(row.get("margin_one_lot") or 0) + except (TypeError, ValueError): + margin_one = 0.0 + price = float(row.get("price") or 0) + code_for_margin = main_code or spec_code + if price > 0 and code_for_margin: + margin_one, margin_source, spec_used = margin_one_lot( + code_for_margin, + price, + direction="max", + trading_mode=trading_mode if (ctp_connected and use_ctp_margin) else None, + ) + if spec_used.get("mult"): + row["mult"] = spec_used["mult"] + if spec_used.get("tick_size"): + row["tick_size"] = spec_used["tick_size"] + row["margin_one_lot"] = margin_one + if margin_source == "ctp": + row["margin_source"] = "ctp" + row["spec_source"] = "ctp" + if margin_one > 0 and budget > 0: + lots = int(math.floor(budget / margin_one)) + else: + try: + lots = int(row.get("max_lots") or row.get("recommended_lots") or 0) + except (TypeError, ValueError): + lots = 0 + row["max_lots"] = lots + row.pop("recommended_lots", None) + row["margin_budget"] = round(budget, 2) + row["margin_budget_total"] = budget_info["margin_budget_total"] + row["margin_used"] = budget_info["margin_used"] + row["max_margin_pct"] = pct + status = row.get("status") or "" + if lots >= 1 and status in ("ok", "margin_ok"): + src = "柜台" if row.get("margin_source") == "ctp" else "估算" + row["status_label"] = ( + f"最大 {lots} 手" if status == "ok" else f"最大 {lots} 手·止损偏宽" + ) + if row.get("margin_source") == "ctp": + row["status_label"] += f"({src}保证金)" + if budget_info["margin_used"] > 0: + row["status_label"] += "·扣持仓" + elif lots < 1 and status in ("ok", "margin_ok"): + row["status"] = "blocked" + row["status_label"] = "资金不足" + if not row.get("category"): + row["category"] = product_category(row.get("ths") or "") + from modules.core.symbols import enrich_recommend_row + row = enrich_recommend_row(row) + _attach_turnover(row) + enriched.append(row) + from modules.core.symbols import filter_for_trading_session + return filter_for_trading_session(enriched) + + +def filter_recommend_by_sizing( + rows: list[dict], + *, + sizing_mode: str, + fixed_lots: int = 1, +) -> list[dict]: + """固定手数模式下:最大手数低于设定值的品种不展示。""" + if (sizing_mode or "").strip().lower() != "fixed": + return rows + fl = max(1, int(fixed_lots or 1)) + return [r for r in rows if int(r.get("max_lots") or 0) >= fl] + + +def refresh_recommend_cache( + conn, + capital: float, + quote_fn: Callable[[str], Optional[dict]], + *, + trading_mode: str = "simulation", + max_margin_pct: float = 30.0, + margin_used: float | None = None, +) -> list[dict]: + """后台拉行情、筛选并写入数据库。""" + ensure_recommend_tables(conn) + ensure_fee_rates_schema(conn) + ctp_connected = _ctp_connected_for_mode(trading_mode) + used = ( + float(margin_used) + if margin_used is not None + else recommend_margin_used(trading_mode) + ) + all_rows = list_product_recommendations( + capital, + quote_fn, + max_margin_pct=max_margin_pct, + trading_mode=trading_mode, + ctp_connected=ctp_connected, + margin_used=used, + ) + rows = filter_affordable_recommendations(all_rows) + if not rows and float(capital or 0) > 0: + logger.warning( + "recommend refresh: 0 affordable rows capital=%.2f total=%d no_price=%d blocked=%d", + float(capital or 0), + len(all_rows), + sum(1 for r in all_rows if r.get("status") == "no_price"), + sum(1 for r in all_rows if r.get("status") == "blocked"), + ) + now = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + conn.execute( + """INSERT INTO product_recommend_cache (id, capital, rows_json, updated_at) + VALUES (1, ?, ?, ?) + ON CONFLICT(id) DO UPDATE SET + capital=excluded.capital, + rows_json=excluded.rows_json, + updated_at=excluded.updated_at""", + (float(capital or 0), json.dumps(rows, ensure_ascii=False), now), + ) + conn.commit() + return rows + + +def recommend_cache_stale(updated_at: Optional[str], *, now: Optional[datetime] = None) -> bool: + """缓存是否不是今日更新(需重新拉行情计算)。""" + if not updated_at: + return True + try: + cached_day = datetime.strptime(str(updated_at)[:10], "%Y-%m-%d").date() + except ValueError: + return True + today = (now or datetime.now()).date() + return cached_day != today + + +def load_recommend_cache(conn) -> dict: + """优先从数据库读取可开仓品种列表。""" + ensure_recommend_tables(conn) + row = conn.execute("SELECT capital, rows_json, updated_at FROM product_recommend_cache WHERE id=1").fetchone() + if not row: + return {"capital": 0.0, "rows": [], "updated_at": None, "stale": True} + try: + rows = json.loads(row["rows_json"] or "[]") + except (TypeError, ValueError, json.JSONDecodeError): + rows = [] + updated_at = row["updated_at"] + return { + "capital": float(row["capital"] or 0), + "rows": rows if isinstance(rows, list) else [], + "updated_at": updated_at, + "stale": recommend_cache_stale(updated_at), + } + + +def recommend_payload( + conn, + *, + live_capital: float, + max_margin_pct: float = 30.0, + trading_mode: str = "simulation", + sizing_mode: str = "fixed", + fixed_lots: int = 1, + use_ctp_margin: bool = True, +) -> dict: + """读取缓存并附带当前权益(展示用,可能与缓存计算时不同)。""" + payload = load_recommend_cache(conn) + cap = float(live_capital or 0) + pct = max(1.0, min(100.0, float(max_margin_pct or 30.0))) + if use_ctp_margin: + used = recommend_margin_used(trading_mode) + else: + used = 0.0 + try: + from modules.trading.position_stream import position_hub + + snap = position_hub.get_snapshot() or {} + raw = snap.get("margin_used") + if raw is not None: + used = max(0.0, float(raw or 0)) + except Exception: + pass + if used <= 0: + used = float(payload.get("margin_used") or 0) + budget_info = margin_budget_info(cap, pct, used) + payload["capital"] = cap + payload["max_margin_pct"] = pct + payload.update(budget_info) + rows = payload.get("rows") or [] + rows = enrich_recommend_rows( + rows, + cap, + max_margin_pct=pct, + trading_mode=trading_mode, + margin_used=used, + use_ctp_margin=use_ctp_margin, + ) + rows = filter_rows_for_account_scope( + rows, cap, ctp_connected=_ctp_connected_for_mode(trading_mode), + ) + rows = filter_recommend_by_sizing(rows, sizing_mode=sizing_mode, fixed_lots=fixed_lots) + rows = sort_recommend_by_trend(rows) + payload["rows"] = rows + payload["needs_refresh"] = recommend_cache_needs_refresh(payload, capital=cap) + return payload diff --git a/recommend_stream.py b/modules/trading/recommend_stream.py similarity index 94% rename from recommend_stream.py rename to modules/trading/recommend_stream.py index 294ce17..5ee0693 100644 --- a/recommend_stream.py +++ b/modules/trading/recommend_stream.py @@ -1,163 +1,163 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""可开仓品种 SSE 推送与后台刷新。""" -from __future__ import annotations - -import json -import logging -import queue -import threading -import time -from typing import Callable, Optional - -from db_conn import connect_db -from kline_stream import sse_format -from recommend_store import ( - load_recommend_cache, - recommend_cache_needs_refresh, - recommend_payload, - refresh_recommend_cache, -) - -logger = logging.getLogger(__name__) - -CHECK_INTERVAL_SEC = 3600 -_refresh_lock = threading.Lock() -_refresh_running = False - - -def schedule_recommend_refresh( - *, - db_path: str, - get_capital_fn: Callable, - quote_fn: Callable[[str], Optional[dict]], - init_tables_fn: Callable | None = None, - get_mode_fn: Callable[[], str] | None = None, - get_max_margin_pct_fn: Callable[[], float] | None = None, - get_sizing_mode_fn: Callable[[], str] | None = None, - get_fixed_lots_fn: Callable[[], int] | None = None, -) -> None: - """后台刷新可开仓品种缓存(不阻塞页面请求)。""" - global _refresh_running - with _refresh_lock: - if _refresh_running: - return - _refresh_running = True - - def _run() -> None: - global _refresh_running - try: - conn = connect_db(db_path) - try: - if init_tables_fn: - init_tables_fn(conn) - capital = float(get_capital_fn(conn) or 0) - mode = get_mode_fn() if get_mode_fn else "simulation" - max_pct = float(get_max_margin_pct_fn()) if get_max_margin_pct_fn else 30.0 - cached = load_recommend_cache(conn) - if not recommend_cache_needs_refresh(cached, capital=capital): - payload = recommend_payload( - conn, - live_capital=capital, - max_margin_pct=max_pct, - trading_mode=mode, - sizing_mode=get_sizing_mode_fn() if get_sizing_mode_fn else "fixed", - fixed_lots=get_fixed_lots_fn() if get_fixed_lots_fn else 1, - ) - recommend_hub.broadcast("recommend", {"ok": True, **payload}) - return - refresh_recommend_cache( - conn, capital, quote_fn, trading_mode=mode, max_margin_pct=max_pct, - ) - cached = load_recommend_cache(conn) - logger.info( - "可开仓品种后台刷新完成,capital=%.2f rows=%d", - capital, len(cached.get("rows") or []), - ) - payload = recommend_payload( - conn, - live_capital=capital, - max_margin_pct=max_pct, - trading_mode=mode, - sizing_mode=get_sizing_mode_fn() if get_sizing_mode_fn else "fixed", - fixed_lots=get_fixed_lots_fn() if get_fixed_lots_fn else 1, - ) - finally: - conn.close() - recommend_hub.broadcast("recommend", {"ok": True, **payload}) - except Exception as exc: - logger.warning("recommend background refresh failed: %s", exc) - finally: - with _refresh_lock: - _refresh_running = False - - threading.Thread(target=_run, daemon=True, name="recommend-refresh").start() - - -class RecommendStreamHub: - def __init__(self) -> None: - self._lock = threading.Lock() - self._subs: list[queue.Queue] = [] - - def subscribe(self) -> queue.Queue: - q: queue.Queue = queue.Queue(maxsize=8) - with self._lock: - self._subs.append(q) - return q - - def unsubscribe(self, q: queue.Queue) -> None: - with self._lock: - try: - self._subs.remove(q) - except ValueError: - pass - - def broadcast(self, event: str, data: dict) -> None: - msg = {"event": event, "data": data} - with self._lock: - subs = list(self._subs) - for q in subs: - try: - q.put_nowait(msg) - except queue.Full: - pass - - -recommend_hub = RecommendStreamHub() - - -def start_recommend_worker( - *, - db_path: str, - get_capital_fn: Callable, - quote_fn: Callable[[str], Optional[dict]], - init_tables_fn: Callable | None = None, - get_mode_fn: Callable[[], str] | None = None, - get_max_margin_pct_fn: Callable[[], float] | None = None, - get_sizing_mode_fn: Callable[[], str] | None = None, - get_fixed_lots_fn: Callable[[], int] | None = None, - interval: int = CHECK_INTERVAL_SEC, -) -> None: - """后台每日刷新可开仓列表(每小时检查一次是否需更新),并推送给 SSE 订阅者。""" - - def _loop() -> None: - while True: - try: - schedule_recommend_refresh( - db_path=db_path, - get_capital_fn=get_capital_fn, - quote_fn=quote_fn, - init_tables_fn=init_tables_fn, - get_mode_fn=get_mode_fn, - get_max_margin_pct_fn=get_max_margin_pct_fn, - get_sizing_mode_fn=get_sizing_mode_fn, - get_fixed_lots_fn=get_fixed_lots_fn, - ) - except Exception as exc: - logger.warning("recommend worker failed: %s", exc) - time.sleep(max(300, interval)) - - threading.Thread(target=_loop, daemon=True, name="recommend-worker").start() +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""可开仓品种 SSE 推送与后台刷新。""" +from __future__ import annotations + +import json +import logging +import queue +import threading +import time +from typing import Callable, Optional + +from modules.core.db_conn import connect_db +from modules.market.kline_stream import sse_format +from modules.trading.recommend_store import ( + load_recommend_cache, + recommend_cache_needs_refresh, + recommend_payload, + refresh_recommend_cache, +) + +logger = logging.getLogger(__name__) + +CHECK_INTERVAL_SEC = 3600 +_refresh_lock = threading.Lock() +_refresh_running = False + + +def schedule_recommend_refresh( + *, + db_path: str, + get_capital_fn: Callable, + quote_fn: Callable[[str], Optional[dict]], + init_tables_fn: Callable | None = None, + get_mode_fn: Callable[[], str] | None = None, + get_max_margin_pct_fn: Callable[[], float] | None = None, + get_sizing_mode_fn: Callable[[], str] | None = None, + get_fixed_lots_fn: Callable[[], int] | None = None, +) -> None: + """后台刷新可开仓品种缓存(不阻塞页面请求)。""" + global _refresh_running + with _refresh_lock: + if _refresh_running: + return + _refresh_running = True + + def _run() -> None: + global _refresh_running + try: + conn = connect_db(db_path) + try: + if init_tables_fn: + init_tables_fn(conn) + capital = float(get_capital_fn(conn) or 0) + mode = get_mode_fn() if get_mode_fn else "simulation" + max_pct = float(get_max_margin_pct_fn()) if get_max_margin_pct_fn else 30.0 + cached = load_recommend_cache(conn) + if not recommend_cache_needs_refresh(cached, capital=capital): + payload = recommend_payload( + conn, + live_capital=capital, + max_margin_pct=max_pct, + trading_mode=mode, + sizing_mode=get_sizing_mode_fn() if get_sizing_mode_fn else "fixed", + fixed_lots=get_fixed_lots_fn() if get_fixed_lots_fn else 1, + ) + recommend_hub.broadcast("recommend", {"ok": True, **payload}) + return + refresh_recommend_cache( + conn, capital, quote_fn, trading_mode=mode, max_margin_pct=max_pct, + ) + cached = load_recommend_cache(conn) + logger.info( + "可开仓品种后台刷新完成,capital=%.2f rows=%d", + capital, len(cached.get("rows") or []), + ) + payload = recommend_payload( + conn, + live_capital=capital, + max_margin_pct=max_pct, + trading_mode=mode, + sizing_mode=get_sizing_mode_fn() if get_sizing_mode_fn else "fixed", + fixed_lots=get_fixed_lots_fn() if get_fixed_lots_fn else 1, + ) + finally: + conn.close() + recommend_hub.broadcast("recommend", {"ok": True, **payload}) + except Exception as exc: + logger.warning("recommend background refresh failed: %s", exc) + finally: + with _refresh_lock: + _refresh_running = False + + threading.Thread(target=_run, daemon=True, name="recommend-refresh").start() + + +class RecommendStreamHub: + def __init__(self) -> None: + self._lock = threading.Lock() + self._subs: list[queue.Queue] = [] + + def subscribe(self) -> queue.Queue: + q: queue.Queue = queue.Queue(maxsize=8) + with self._lock: + self._subs.append(q) + return q + + def unsubscribe(self, q: queue.Queue) -> None: + with self._lock: + try: + self._subs.remove(q) + except ValueError: + pass + + def broadcast(self, event: str, data: dict) -> None: + msg = {"event": event, "data": data} + with self._lock: + subs = list(self._subs) + for q in subs: + try: + q.put_nowait(msg) + except queue.Full: + pass + + +recommend_hub = RecommendStreamHub() + + +def start_recommend_worker( + *, + db_path: str, + get_capital_fn: Callable, + quote_fn: Callable[[str], Optional[dict]], + init_tables_fn: Callable | None = None, + get_mode_fn: Callable[[], str] | None = None, + get_max_margin_pct_fn: Callable[[], float] | None = None, + get_sizing_mode_fn: Callable[[], str] | None = None, + get_fixed_lots_fn: Callable[[], int] | None = None, + interval: int = CHECK_INTERVAL_SEC, +) -> None: + """后台每日刷新可开仓列表(每小时检查一次是否需更新),并推送给 SSE 订阅者。""" + + def _loop() -> None: + while True: + try: + schedule_recommend_refresh( + db_path=db_path, + get_capital_fn=get_capital_fn, + quote_fn=quote_fn, + init_tables_fn=init_tables_fn, + get_mode_fn=get_mode_fn, + get_max_margin_pct_fn=get_max_margin_pct_fn, + get_sizing_mode_fn=get_sizing_mode_fn, + get_fixed_lots_fn=get_fixed_lots_fn, + ) + except Exception as exc: + logger.warning("recommend worker failed: %s", exc) + time.sleep(max(300, interval)) + + threading.Thread(target=_loop, daemon=True, name="recommend-worker").start() diff --git a/recommend_trend.py b/modules/trading/recommend_trend.py similarity index 96% rename from recommend_trend.py rename to modules/trading/recommend_trend.py index a0cc99a..b260a4d 100644 --- a/recommend_trend.py +++ b/modules/trading/recommend_trend.py @@ -1,339 +1,339 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""可开仓品种:近一周日线走势(多头 / 空头 / 震荡 / 转多 / 转空)。""" -from __future__ import annotations - -import logging -from typing import Callable, Optional - -import requests - -from kline_chart import fetch_sina_klines, ths_to_sina_chart_symbol - -logger = logging.getLogger(__name__) - -DAILY_LOOKBACK = 7 -OVERLAP_WINDOW = 3 -OVERLAP_RANGE_THRESHOLD = 0.70 -KLINE_FETCH_TIMEOUT = 5 - -TREND_LONG = "long" -TREND_SHORT = "short" -TREND_RANGE = "range" -TREND_BREAK_LONG = "break_long" -TREND_BREAK_SHORT = "break_short" - - -def _bar_ohlc(bar: dict) -> tuple[float, float, float, float]: - o = float(bar.get("o") or bar.get("open") or 0) - h = float(bar.get("h") or bar.get("high") or o) - l = float(bar.get("l") or bar.get("low") or o) - c = float(bar.get("c") or bar.get("close") or o) - return o, h, l, c - - -def kline_overlap_ratio(bars: list) -> float: - """三根 K 线高低价区间的重叠度 = 交集 / 并集(0~1)。""" - if len(bars) < OVERLAP_WINDOW: - return 0.0 - chunk = bars[-OVERLAP_WINDOW:] - lows, highs = [], [] - for bar in chunk: - _, h, l, _ = _bar_ohlc(bar) - if h <= 0 and l <= 0: - continue - lows.append(l) - highs.append(h) - if len(lows) < OVERLAP_WINDOW: - return 0.0 - overlap = max(0.0, min(highs) - max(lows)) - union = max(highs) - min(lows) - if union <= 0: - return 1.0 if overlap > 0 else 0.0 - return overlap / union - - -def _direction_from_closes(bars: list) -> str: - if len(bars) < 2: - return TREND_RANGE - closes = [_bar_ohlc(b)[3] for b in bars if _bar_ohlc(b)[3] > 0] - if len(closes) < 2: - return TREND_RANGE - if closes[-1] > closes[0]: - return TREND_LONG - if closes[-1] < closes[0]: - return TREND_SHORT - return TREND_RANGE - - -def _bar_ohlcv(bar: dict) -> tuple[float, float, float, float, float]: - o, h, l, c = _bar_ohlc(bar) - v = float(bar.get("v") or bar.get("volume") or 0) - return o, h, l, c, v - - -def compute_daily_quote_stats(bars: list) -> dict: - """从日线提取:跳空、昨收、今开、昨涨跌、昨振幅、成交量。""" - empty = { - "gap": "", - "gap_label": "—", - "gap_pct": None, - "prev_close": None, - "today_open": None, - "yesterday_change": None, - "yesterday_change_pct": None, - "yesterday_amplitude_pct": None, - "volume": None, - } - if len(bars) < 2: - return empty - - t_o, _, _, _, t_v = _bar_ohlcv(bars[-1]) - y_o, y_h, y_l, y_c, y_v = _bar_ohlcv(bars[-2]) - if y_c <= 0: - return empty - - prev_close = round(y_c, 4) - today_open = round(t_o, 4) if t_o > 0 else None - - gap, gap_label, gap_pct = "none", "否", 0.0 - if today_open is not None and today_open > y_c: - gap, gap_label = "up", "跳空高开" - gap_pct = (today_open - y_c) / y_c * 100 - elif today_open is not None and today_open < y_c: - gap, gap_label = "down", "跳空低开" - gap_pct = (today_open - y_c) / y_c * 100 - - if len(bars) >= 3: - _, _, _, p_c, _ = _bar_ohlcv(bars[-3]) - base = p_c if p_c > 0 else y_o - else: - base = y_o if y_o > 0 else y_c - - y_change = y_c - base if base > 0 else None - y_change_pct = (y_change / base * 100) if y_change is not None and base > 0 else None - y_amp = ((y_h - y_l) / base * 100) if base > 0 and y_h >= y_l else None - vol = y_v if y_v > 0 else (t_v if t_v > 0 else None) - - return { - "gap": gap, - "gap_label": gap_label, - "gap_pct": round(gap_pct, 2) if gap != "none" else 0.0, - "prev_close": prev_close, - "today_open": today_open, - "yesterday_change": round(y_change, 4) if y_change is not None else None, - "yesterday_change_pct": round(y_change_pct, 2) if y_change_pct is not None else None, - "yesterday_amplitude_pct": round(y_amp, 2) if y_amp is not None else None, - "volume": int(vol) if vol is not None else None, - "volume_unit": "lot", - } - - -def analyze_daily_trend(bars: list, *, overlap_threshold: float = OVERLAP_RANGE_THRESHOLD) -> dict: - """根据近一周日线判断走势;最近三天重叠度≥阈值视为震荡。""" - empty = { - "trend": "", - "trend_label": "—", - "trend_transition": False, - "trend_overlap_pct": None, - "trend_prev_overlap_pct": None, - } - if len(bars) < OVERLAP_WINDOW: - return empty - - recent = bars[-DAILY_LOOKBACK:] if len(bars) > DAILY_LOOKBACK else bars - curr_overlap = kline_overlap_ratio(recent) - prev_overlap = kline_overlap_ratio(recent[:-OVERLAP_WINDOW]) if len(recent) >= OVERLAP_WINDOW * 2 else 0.0 - - curr_range = curr_overlap >= overlap_threshold - prev_range = prev_overlap >= overlap_threshold - - if curr_range: - trend, label = TREND_RANGE, "震荡" - transition = False - else: - direction = _direction_from_closes(recent[-OVERLAP_WINDOW:]) - if direction == TREND_LONG: - trend, label = TREND_LONG, "多头" - elif direction == TREND_SHORT: - trend, label = TREND_SHORT, "空头" - else: - trend, label = TREND_RANGE, "震荡" - transition = prev_range and trend in (TREND_LONG, TREND_SHORT) - if transition: - if trend == TREND_LONG: - trend, label = TREND_BREAK_LONG, "转多" - else: - trend, label = TREND_BREAK_SHORT, "转空" - - return { - "trend": trend, - "trend_label": label, - "trend_transition": transition, - "trend_overlap_pct": round(curr_overlap * 100, 1), - "trend_prev_overlap_pct": round(prev_overlap * 100, 1) if prev_overlap else None, - } - - -def _normalize_daily_bars(raw: list) -> list: - out = [] - for row in raw: - if isinstance(row, list) and len(row) >= 5: - out.append({ - "d": str(row[0]), - "o": float(row[1]), - "h": float(row[2]), - "l": float(row[3]), - "c": float(row[4]), - "v": float(row[5]) if len(row) > 5 and row[5] else 0.0, - }) - elif isinstance(row, dict) and row.get("d"): - out.append({ - "d": str(row["d"]), - "o": float(row.get("o", 0) or 0), - "h": float(row.get("h", 0) or 0), - "l": float(row.get("l", 0) or 0), - "c": float(row.get("c", 0) or 0), - "v": float(row.get("v", 0) or 0), - }) - return out - - -def _fetch_sina_daily_quick(chart_sym: str) -> list: - url = ( - "https://stock2.finance.sina.com.cn/futures/api/json.php/" - f"IndexService.getInnerFuturesDailyKLine?symbol={chart_sym}" - ) - try: - resp = requests.get( - url, timeout=KLINE_FETCH_TIMEOUT, - headers={"Referer": "https://finance.sina.com.cn"}, - ) - raw = resp.json() - if raw and isinstance(raw, list): - bars = _normalize_daily_bars(raw) - if bars: - return bars - except Exception as exc: - logger.debug("quick daily kline failed %s: %s", chart_sym, exc) - return [] - - -def fetch_week_daily_bars( - symbol: str, - *, - fetch_fn: Callable[[str, str], list] | None = None, -) -> list: - sym = (symbol or "").strip() - if not sym: - return [] - if fetch_fn: - try: - bars = fetch_fn(sym, "d") or [] - except Exception as exc: - logger.debug("fetch week daily failed %s: %s", sym, exc) - return [] - return bars[-DAILY_LOOKBACK:] if bars else [] - - chart_sym = ths_to_sina_chart_symbol(sym) - if not chart_sym: - return [] - bars = _fetch_sina_daily_quick(chart_sym) - if not bars: - try: - bars = fetch_sina_klines(sym, "d") or [] - except Exception as exc: - logger.debug("fetch week daily fallback failed %s: %s", sym, exc) - return [] - return bars[-DAILY_LOOKBACK:] if bars else [] - - -def analyze_product_daily( - symbol: str, - *, - fetch_fn: Callable[[str, str], list] | None = None, -) -> dict: - """拉取主力合约一周日线:走势 + 跳空/量价统计。""" - sym = (symbol or "").strip() - if not sym: - out = analyze_daily_trend([]) - out.update(compute_daily_quote_stats([])) - return out - bars = fetch_week_daily_bars(sym, fetch_fn=fetch_fn) - out = analyze_daily_trend(bars) - out.update(compute_daily_quote_stats(bars)) - return out - - -def analyze_product_trend( - symbol: str, - *, - fetch_fn: Callable[[str, str], list] | None = None, -) -> dict: - return analyze_product_daily(symbol, fetch_fn=fetch_fn) - - -GAP_SORT_RANK = {"up": 2, "down": 1, "none": 0, "": -1} -TREND_SORT_RANK = { - TREND_BREAK_LONG: 0, - TREND_BREAK_SHORT: 0, - TREND_LONG: 1, - TREND_SHORT: 2, - TREND_RANGE: 3, - "": 9, -} - - -def recommend_sort_key(row: dict, sort_by: str = "trend", *, desc: bool = True) -> tuple: - """可排序字段:trend / gap / volume / amplitude。""" - key = (sort_by or "trend").strip().lower() - if key == "gap": - primary = GAP_SORT_RANK.get(row.get("gap") or "", -1) - secondary = abs(float(row.get("gap_pct") or 0)) - elif key == "volume": - primary = float(row.get("volume") or 0) - secondary = 0.0 - elif key == "amplitude": - primary = float(row.get("yesterday_amplitude_pct") or 0) - secondary = 0.0 - else: - primary = TREND_SORT_RANK.get(row.get("trend") or "", 9) - secondary = -(int(row.get("max_lots") or 0)) - - if desc: - return (-primary, -secondary, row.get("name") or "") - return (primary, secondary, row.get("name") or "") - - -def sort_recommend_rows( - rows: list[dict], - *, - sort_by: str = "trend", - desc: bool = True, -) -> list[dict]: - return sorted(rows, key=lambda r: recommend_sort_key(r, sort_by, desc=desc)) - - -def trend_sort_key(row: dict) -> tuple: - """转多/转空优先,其次多头/空头,震荡靠后。""" - trend = (row.get("trend") or "").strip() - priority = { - TREND_BREAK_LONG: 0, - TREND_BREAK_SHORT: 0, - TREND_LONG: 1, - TREND_SHORT: 1, - TREND_RANGE: 2, - } - status_order = {"ok": 0, "margin_ok": 1, "blocked": 2, "no_price": 3} - return ( - priority.get(trend, 3), - status_order.get(row.get("status") or "", 9), - -(int(row.get("max_lots") or 0)), - ) - - -def sort_recommend_by_trend(rows: list[dict]) -> list[dict]: - return sort_recommend_rows(rows, sort_by="trend", desc=True) +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""可开仓品种:近一周日线走势(多头 / 空头 / 震荡 / 转多 / 转空)。""" +from __future__ import annotations + +import logging +from typing import Callable, Optional + +import requests + +from modules.market.kline_chart import fetch_sina_klines, ths_to_sina_chart_symbol + +logger = logging.getLogger(__name__) + +DAILY_LOOKBACK = 7 +OVERLAP_WINDOW = 3 +OVERLAP_RANGE_THRESHOLD = 0.70 +KLINE_FETCH_TIMEOUT = 5 + +TREND_LONG = "long" +TREND_SHORT = "short" +TREND_RANGE = "range" +TREND_BREAK_LONG = "break_long" +TREND_BREAK_SHORT = "break_short" + + +def _bar_ohlc(bar: dict) -> tuple[float, float, float, float]: + o = float(bar.get("o") or bar.get("open") or 0) + h = float(bar.get("h") or bar.get("high") or o) + l = float(bar.get("l") or bar.get("low") or o) + c = float(bar.get("c") or bar.get("close") or o) + return o, h, l, c + + +def kline_overlap_ratio(bars: list) -> float: + """三根 K 线高低价区间的重叠度 = 交集 / 并集(0~1)。""" + if len(bars) < OVERLAP_WINDOW: + return 0.0 + chunk = bars[-OVERLAP_WINDOW:] + lows, highs = [], [] + for bar in chunk: + _, h, l, _ = _bar_ohlc(bar) + if h <= 0 and l <= 0: + continue + lows.append(l) + highs.append(h) + if len(lows) < OVERLAP_WINDOW: + return 0.0 + overlap = max(0.0, min(highs) - max(lows)) + union = max(highs) - min(lows) + if union <= 0: + return 1.0 if overlap > 0 else 0.0 + return overlap / union + + +def _direction_from_closes(bars: list) -> str: + if len(bars) < 2: + return TREND_RANGE + closes = [_bar_ohlc(b)[3] for b in bars if _bar_ohlc(b)[3] > 0] + if len(closes) < 2: + return TREND_RANGE + if closes[-1] > closes[0]: + return TREND_LONG + if closes[-1] < closes[0]: + return TREND_SHORT + return TREND_RANGE + + +def _bar_ohlcv(bar: dict) -> tuple[float, float, float, float, float]: + o, h, l, c = _bar_ohlc(bar) + v = float(bar.get("v") or bar.get("volume") or 0) + return o, h, l, c, v + + +def compute_daily_quote_stats(bars: list) -> dict: + """从日线提取:跳空、昨收、今开、昨涨跌、昨振幅、成交量。""" + empty = { + "gap": "", + "gap_label": "—", + "gap_pct": None, + "prev_close": None, + "today_open": None, + "yesterday_change": None, + "yesterday_change_pct": None, + "yesterday_amplitude_pct": None, + "volume": None, + } + if len(bars) < 2: + return empty + + t_o, _, _, _, t_v = _bar_ohlcv(bars[-1]) + y_o, y_h, y_l, y_c, y_v = _bar_ohlcv(bars[-2]) + if y_c <= 0: + return empty + + prev_close = round(y_c, 4) + today_open = round(t_o, 4) if t_o > 0 else None + + gap, gap_label, gap_pct = "none", "否", 0.0 + if today_open is not None and today_open > y_c: + gap, gap_label = "up", "跳空高开" + gap_pct = (today_open - y_c) / y_c * 100 + elif today_open is not None and today_open < y_c: + gap, gap_label = "down", "跳空低开" + gap_pct = (today_open - y_c) / y_c * 100 + + if len(bars) >= 3: + _, _, _, p_c, _ = _bar_ohlcv(bars[-3]) + base = p_c if p_c > 0 else y_o + else: + base = y_o if y_o > 0 else y_c + + y_change = y_c - base if base > 0 else None + y_change_pct = (y_change / base * 100) if y_change is not None and base > 0 else None + y_amp = ((y_h - y_l) / base * 100) if base > 0 and y_h >= y_l else None + vol = y_v if y_v > 0 else (t_v if t_v > 0 else None) + + return { + "gap": gap, + "gap_label": gap_label, + "gap_pct": round(gap_pct, 2) if gap != "none" else 0.0, + "prev_close": prev_close, + "today_open": today_open, + "yesterday_change": round(y_change, 4) if y_change is not None else None, + "yesterday_change_pct": round(y_change_pct, 2) if y_change_pct is not None else None, + "yesterday_amplitude_pct": round(y_amp, 2) if y_amp is not None else None, + "volume": int(vol) if vol is not None else None, + "volume_unit": "lot", + } + + +def analyze_daily_trend(bars: list, *, overlap_threshold: float = OVERLAP_RANGE_THRESHOLD) -> dict: + """根据近一周日线判断走势;最近三天重叠度≥阈值视为震荡。""" + empty = { + "trend": "", + "trend_label": "—", + "trend_transition": False, + "trend_overlap_pct": None, + "trend_prev_overlap_pct": None, + } + if len(bars) < OVERLAP_WINDOW: + return empty + + recent = bars[-DAILY_LOOKBACK:] if len(bars) > DAILY_LOOKBACK else bars + curr_overlap = kline_overlap_ratio(recent) + prev_overlap = kline_overlap_ratio(recent[:-OVERLAP_WINDOW]) if len(recent) >= OVERLAP_WINDOW * 2 else 0.0 + + curr_range = curr_overlap >= overlap_threshold + prev_range = prev_overlap >= overlap_threshold + + if curr_range: + trend, label = TREND_RANGE, "震荡" + transition = False + else: + direction = _direction_from_closes(recent[-OVERLAP_WINDOW:]) + if direction == TREND_LONG: + trend, label = TREND_LONG, "多头" + elif direction == TREND_SHORT: + trend, label = TREND_SHORT, "空头" + else: + trend, label = TREND_RANGE, "震荡" + transition = prev_range and trend in (TREND_LONG, TREND_SHORT) + if transition: + if trend == TREND_LONG: + trend, label = TREND_BREAK_LONG, "转多" + else: + trend, label = TREND_BREAK_SHORT, "转空" + + return { + "trend": trend, + "trend_label": label, + "trend_transition": transition, + "trend_overlap_pct": round(curr_overlap * 100, 1), + "trend_prev_overlap_pct": round(prev_overlap * 100, 1) if prev_overlap else None, + } + + +def _normalize_daily_bars(raw: list) -> list: + out = [] + for row in raw: + if isinstance(row, list) and len(row) >= 5: + out.append({ + "d": str(row[0]), + "o": float(row[1]), + "h": float(row[2]), + "l": float(row[3]), + "c": float(row[4]), + "v": float(row[5]) if len(row) > 5 and row[5] else 0.0, + }) + elif isinstance(row, dict) and row.get("d"): + out.append({ + "d": str(row["d"]), + "o": float(row.get("o", 0) or 0), + "h": float(row.get("h", 0) or 0), + "l": float(row.get("l", 0) or 0), + "c": float(row.get("c", 0) or 0), + "v": float(row.get("v", 0) or 0), + }) + return out + + +def _fetch_sina_daily_quick(chart_sym: str) -> list: + url = ( + "https://stock2.finance.sina.com.cn/futures/api/json.php/" + f"IndexService.getInnerFuturesDailyKLine?symbol={chart_sym}" + ) + try: + resp = requests.get( + url, timeout=KLINE_FETCH_TIMEOUT, + headers={"Referer": "https://finance.sina.com.cn"}, + ) + raw = resp.json() + if raw and isinstance(raw, list): + bars = _normalize_daily_bars(raw) + if bars: + return bars + except Exception as exc: + logger.debug("quick daily kline failed %s: %s", chart_sym, exc) + return [] + + +def fetch_week_daily_bars( + symbol: str, + *, + fetch_fn: Callable[[str, str], list] | None = None, +) -> list: + sym = (symbol or "").strip() + if not sym: + return [] + if fetch_fn: + try: + bars = fetch_fn(sym, "d") or [] + except Exception as exc: + logger.debug("fetch week daily failed %s: %s", sym, exc) + return [] + return bars[-DAILY_LOOKBACK:] if bars else [] + + chart_sym = ths_to_sina_chart_symbol(sym) + if not chart_sym: + return [] + bars = _fetch_sina_daily_quick(chart_sym) + if not bars: + try: + bars = fetch_sina_klines(sym, "d") or [] + except Exception as exc: + logger.debug("fetch week daily fallback failed %s: %s", sym, exc) + return [] + return bars[-DAILY_LOOKBACK:] if bars else [] + + +def analyze_product_daily( + symbol: str, + *, + fetch_fn: Callable[[str, str], list] | None = None, +) -> dict: + """拉取主力合约一周日线:走势 + 跳空/量价统计。""" + sym = (symbol or "").strip() + if not sym: + out = analyze_daily_trend([]) + out.update(compute_daily_quote_stats([])) + return out + bars = fetch_week_daily_bars(sym, fetch_fn=fetch_fn) + out = analyze_daily_trend(bars) + out.update(compute_daily_quote_stats(bars)) + return out + + +def analyze_product_trend( + symbol: str, + *, + fetch_fn: Callable[[str, str], list] | None = None, +) -> dict: + return analyze_product_daily(symbol, fetch_fn=fetch_fn) + + +GAP_SORT_RANK = {"up": 2, "down": 1, "none": 0, "": -1} +TREND_SORT_RANK = { + TREND_BREAK_LONG: 0, + TREND_BREAK_SHORT: 0, + TREND_LONG: 1, + TREND_SHORT: 2, + TREND_RANGE: 3, + "": 9, +} + + +def recommend_sort_key(row: dict, sort_by: str = "trend", *, desc: bool = True) -> tuple: + """可排序字段:trend / gap / volume / amplitude。""" + key = (sort_by or "trend").strip().lower() + if key == "gap": + primary = GAP_SORT_RANK.get(row.get("gap") or "", -1) + secondary = abs(float(row.get("gap_pct") or 0)) + elif key == "volume": + primary = float(row.get("volume") or 0) + secondary = 0.0 + elif key == "amplitude": + primary = float(row.get("yesterday_amplitude_pct") or 0) + secondary = 0.0 + else: + primary = TREND_SORT_RANK.get(row.get("trend") or "", 9) + secondary = -(int(row.get("max_lots") or 0)) + + if desc: + return (-primary, -secondary, row.get("name") or "") + return (primary, secondary, row.get("name") or "") + + +def sort_recommend_rows( + rows: list[dict], + *, + sort_by: str = "trend", + desc: bool = True, +) -> list[dict]: + return sorted(rows, key=lambda r: recommend_sort_key(r, sort_by, desc=desc)) + + +def trend_sort_key(row: dict) -> tuple: + """转多/转空优先,其次多头/空头,震荡靠后。""" + trend = (row.get("trend") or "").strip() + priority = { + TREND_BREAK_LONG: 0, + TREND_BREAK_SHORT: 0, + TREND_LONG: 1, + TREND_SHORT: 1, + TREND_RANGE: 2, + } + status_order = {"ok": 0, "margin_ok": 1, "blocked": 2, "no_price": 3} + return ( + priority.get(trend, 3), + status_order.get(row.get("status") or "", 9), + -(int(row.get("max_lots") or 0)), + ) + + +def sort_recommend_by_trend(rows: list[dict]) -> list[dict]: + return sort_recommend_rows(rows, sort_by="trend", desc=True) diff --git a/sl_tp_guard.py b/modules/trading/sl_tp_guard.py similarity index 94% rename from sl_tp_guard.py rename to modules/trading/sl_tp_guard.py index 972d7e7..f573cd8 100644 --- a/sl_tp_guard.py +++ b/modules/trading/sl_tp_guard.py @@ -1,1058 +1,1058 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""止盈止损守护:程序本地监控价位,触发后向 CTP 发平仓单(不向交易所挂 SL/TP 限价单)。""" -from __future__ import annotations - -import logging -import threading -import time -from datetime import datetime -from typing import Any, Callable, Optional -from zoneinfo import ZoneInfo - -from contract_specs import calc_position_metrics -from ctp_symbol import ths_to_vnpy_symbol -from fee_specs import calc_round_trip_fee -from trade_log_lib import calc_equity_after, refresh_trade_log_equity_chain -from market_sessions import is_trading_session -from symbols import ths_to_codes -from vnpy_bridge import ( - ctp_cancel_order, - ctp_get_tick_price, - ctp_list_active_orders, - ctp_list_positions, - ctp_status, - ctp_account_margin_used, - execute_order, - get_bridge, -) - -logger = logging.getLogger(__name__) - -TZ = ZoneInfo("Asia/Shanghai") -CHECK_INTERVAL_SEC = 1 -CLOSED_MARKET_SLEEP_SEC = 30 -DISCONNECTED_SLEEP_SEC = 5 -PLACE_COOLDOWN_SEC = 3 - -_last_close_attempt: dict[int, float] = {} -_closing_monitors: set[int] = set() -_closing_symbol_keys: set[str] = set() -_closing_lock = threading.Lock() - -MONITOR_ORDER_COLUMNS = ( - "ALTER TABLE trade_order_monitors ADD COLUMN sl_vt_order_id TEXT", - "ALTER TABLE trade_order_monitors ADD COLUMN tp_vt_order_id TEXT", - "ALTER TABLE trade_order_monitors ADD COLUMN trailing_be INTEGER DEFAULT 0", - "ALTER TABLE trade_order_monitors ADD COLUMN initial_stop_loss REAL", - "ALTER TABLE trade_order_monitors ADD COLUMN trailing_r_locked INTEGER DEFAULT 0", - "ALTER TABLE trade_order_monitors ADD COLUMN margin REAL", - "ALTER TABLE trade_order_monitors ADD COLUMN position_pct REAL", - "ALTER TABLE trade_order_monitors ADD COLUMN mark_price REAL", - "ALTER TABLE trade_order_monitors ADD COLUMN float_pnl REAL", - "ALTER TABLE trade_order_monitors ADD COLUMN vt_order_id TEXT", - "ALTER TABLE trade_order_monitors ADD COLUMN order_price REAL", - "ALTER TABLE trade_order_monitors ADD COLUMN open_fee REAL", -) - -TRADE_RESULTS = ("止损", "止盈", "移动止盈", "保本止盈", "手动平仓") - -_MONITOR_COLUMNS_READY = False -_MONITOR_COLUMNS_LOCK = threading.Lock() - - -def _monitor_columns_exist(conn) -> bool: - try: - rows = conn.execute("PRAGMA table_info(trade_order_monitors)").fetchall() - cols = set() - for r in rows: - if isinstance(r, dict): - cols.add(r.get("name") or "") - else: - cols.add(r[1]) - return "open_fee" in cols - except Exception: - return False - - -def ensure_monitor_order_columns(conn, *, migrate: bool = False) -> None: - """列齐全后不再 ALTER,避免 worker 每次请求锁 SQLite。""" - global _MONITOR_COLUMNS_READY - if _MONITOR_COLUMNS_READY: - return - with _MONITOR_COLUMNS_LOCK: - if _MONITOR_COLUMNS_READY: - return - if _monitor_columns_exist(conn): - _MONITOR_COLUMNS_READY = True - return - if not migrate: - return - for sql in MONITOR_ORDER_COLUMNS: - try: - conn.execute(sql) - conn.commit() - except Exception: - try: - conn.rollback() - except Exception: - pass - _MONITOR_COLUMNS_READY = True - - -def _tick_size(ths_code: str) -> float: - from contract_specs import get_contract_spec - return float(get_contract_spec(ths_code).get("tick_size") or 1.0) - - -def _match_symbol(ctp_sym: str, ths: str) -> bool: - a = (ctp_sym or "").lower() - b = (ths or "").lower() - if a == b: - return True - if a and b and a.split(".")[0] == b.split(".")[0]: - return True - try: - vnpy_sym, _ = ths_to_vnpy_symbol(ths) - if a == vnpy_sym.lower(): - return True - except Exception: - pass - try: - vnpy_sym, _ = ths_to_vnpy_symbol(ctp_sym) - if vnpy_sym.lower() == b.split(".")[0]: - return True - except Exception: - pass - return False - - -def _close_order_direction(hold_direction: str) -> str: - return "short" if hold_direction == "long" else "long" - - -def _price_near(a: float, b: float, tick: float) -> bool: - return abs(float(a) - float(b)) <= max(tick * 0.501, 1e-9) - - -def _find_close_order( - active_orders: list[dict], - *, - ths_code: str, - hold_direction: str, - price: float, - tick: float, -) -> Optional[dict]: - close_dir = _close_order_direction(hold_direction) - for o in active_orders: - sym = o.get("symbol") or "" - if not _match_symbol(sym, ths_code): - continue - offset_s = (o.get("offset") or "").upper() - if "CLOSE" not in offset_s: - continue - if (o.get("direction") or "") != close_dir: - continue - if not _price_near(o.get("price") or 0, price, tick): - continue - return o - return None - - -def _find_position(positions: list[dict], ths_code: str, direction: str) -> Optional[dict]: - for p in positions: - if int(p.get("lots") or 0) <= 0: - continue - if (p.get("direction") or "long") != direction: - continue - if _match_symbol(p.get("symbol") or "", ths_code): - return p - return None - - -def _position_key(sym: str, direction: str) -> str: - return f"{(sym or '').strip().lower()}|{(direction or 'long').strip().lower()}" - - -def _try_acquire_close_symbol(sym: str, direction: str) -> bool: - key = _position_key(sym, direction) - with _closing_lock: - if key in _closing_symbol_keys: - return False - _closing_symbol_keys.add(key) - return True - - -def _release_close_symbol(sym: str, direction: str) -> None: - key = _position_key(sym, direction) - with _closing_lock: - _closing_symbol_keys.discard(key) - - -def _close_all_monitors_for_symbol(conn, sym: str, direction: str) -> None: - direction = (direction or "long").strip().lower() - for r in conn.execute( - "SELECT id, symbol, direction FROM trade_order_monitors WHERE status='active'" - ).fetchall(): - if (r["direction"] or "long") != direction: - continue - if _match_symbol(sym, r["symbol"] or ""): - conn.execute( - "UPDATE trade_order_monitors SET status='closed' WHERE id=?", - (r["id"],), - ) - - -def _dedupe_active_monitors(conn) -> None: - """同一品种方向只保留一条 active 监控,避免重复触发平仓。""" - groups: dict[str, list[dict]] = {} - for r in conn.execute( - "SELECT * FROM trade_order_monitors WHERE status='active' ORDER BY id ASC" - ).fetchall(): - row = dict(r) - key = _position_key(row.get("symbol") or "", row.get("direction") or "long") - groups.setdefault(key, []).append(row) - for items in groups.values(): - if len(items) <= 1: - continue - - def _keep_score(m: dict) -> tuple: - mt = (m.get("monitor_type") or "").lower() - score = 0 - if mt != "ctp_sync": - score += 10 - if m.get("stop_loss") is not None: - score += 5 - return (score, int(m.get("id") or 0)) - - items.sort(key=_keep_score, reverse=True) - for dup in items[1:]: - conn.execute( - "UPDATE trade_order_monitors SET status='closed' WHERE id=?", - (dup["id"],), - ) - - -def _can_close_now(monitor_id: int, *, cooldown: int = PLACE_COOLDOWN_SEC) -> bool: - last = _last_close_attempt.get(monitor_id, 0.0) - return (time.time() - last) >= cooldown - - -def _mark_close_attempt(monitor_id: int) -> None: - _last_close_attempt[monitor_id] = time.time() - - -def _try_acquire_close(monitor_id: int) -> bool: - with _closing_lock: - if monitor_id in _closing_monitors: - return False - _closing_monitors.add(monitor_id) - return True - - -def _release_close(monitor_id: int) -> None: - with _closing_lock: - _closing_monitors.discard(monitor_id) - - -def monitor_source_label(raw: str) -> str: - """持仓展示用来源文案。""" - mapping = { - "manual": "期货下单", - "trend": "趋势回调", - "roll": "顺势加仓", - "ctp_sync": "CTP 柜台", - "箱体突破": "箱体突破", - "收敛突破": "收敛突破", - } - key = (raw or "manual").strip().lower() - return mapping.get(key, raw or "期货下单") - - -def _result_for_close(mon: dict, reason: str) -> str: - """平仓结果:止损 / 止盈 / 移动止盈 / 保本止盈 / 手动平仓。""" - if reason == "manual": - return "手动平仓" - if reason == "take_profit": - return "止盈" - if not mon.get("trailing_be"): - return "止损" - locked = int(mon.get("trailing_r_locked") or 0) - if locked >= 2: - return "移动止盈" - if locked >= 1: - return "保本止盈" - return "止损" - - -def write_trade_log( - conn, - *, - symbol: str, - direction: str, - entry_price: float, - close_price: float, - lots: float, - result: str, - trading_mode: str, - stop_loss: Optional[float] = None, - take_profit: Optional[float] = None, - open_time: str = "", - symbol_name: str = "", - market_code: str = "", - sina_code: str = "", - monitor_type: str = "期货下单", - capital: float = 0.0, -) -> None: - """写入 trade_logs(程序平仓 / 手动平仓)。""" - sym = (symbol or "").strip() - direction = (direction or "long").strip().lower() - entry = float(entry_price or close_price) - sl = float(stop_loss) if stop_loss is not None else entry - tp = float(take_profit) if take_profit is not None else entry - close_time = datetime.now(TZ).strftime("%Y-%m-%dT%H:%M") - - if not sina_code or not market_code: - codes = ths_to_codes(sym) or {} - sina_code = sina_code or codes.get("sina_code") or "" - market_code = market_code or codes.get("market_code") or "" - if not symbol_name: - symbol_name = sym - - metrics = calc_position_metrics( - direction, entry, sl, tp, lots, close_price, capital, sym, - ) - pnl = metrics.get("float_pnl") or 0.0 - fee = calc_round_trip_fee( - sym, entry, close_price, lots, open_time, close_time, trading_mode=trading_mode, - ) - pnl_net = round(pnl - fee, 2) - margin_pct = metrics.get("position_pct") - equity_after = calc_equity_after(capital, pnl_net) - - try: - from app import holding_to_minutes - minutes = holding_to_minutes(open_time, close_time) - except Exception: - minutes = 0 - - conn.execute( - """INSERT INTO trade_logs - (symbol, symbol_name, market_code, sina_code, monitor_type, direction, - entry_price, stop_loss, take_profit, close_price, lots, margin, - margin_pct, holding_minutes, open_time, close_time, pnl, fee, pnl_net, - equity_after, result) - VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""", - ( - sym, - symbol_name, - market_code, - sina_code, - monitor_type, - direction, - entry, - stop_loss if stop_loss is not None else sl, - take_profit if take_profit is not None else tp, - close_price, - lots, - metrics.get("margin"), - margin_pct, - minutes, - open_time, - close_time, - pnl, - fee, - pnl_net, - equity_after, - result if result in TRADE_RESULTS else "手动平仓", - ), - ) - try: - refresh_trade_log_equity_chain(conn, capital if capital > 0 else None) - except Exception as exc: - logger.debug("equity chain refresh after trade log: %s", exc) - try: - from stats_engine import refresh_stats_cache - refresh_stats_cache(conn, capital) - except Exception as exc: - logger.debug("stats refresh after close: %s", exc) - try: - from trade_notify import notify_trade_log_close - from trading_context import trading_mode_label - from app import get_setting, send_wechat_msg - from ai_worker import schedule_ai_event_analysis - from db_conn import DB_PATH - - notify_trade_log_close( - send_wechat=send_wechat_msg, - get_setting=get_setting, - mode_label=trading_mode_label(get_setting), - capital=capital, - sym=sym, - symbol_name=symbol_name, - direction=direction, - entry=entry, - close_price=close_price, - sl=stop_loss if stop_loss is not None else None, - tp=take_profit if take_profit is not None else None, - lots=lots, - pnl_net=pnl_net, - equity_after=equity_after, - holding_minutes=minutes, - result=result, - monitor_type=monitor_type, - schedule_ai_fn=schedule_ai_event_analysis, - db_path=DB_PATH, - ) - except Exception as exc: - logger.debug("close notify: %s", exc) - - -def _write_trade_log( - conn, - mon: dict, - *, - close_price: float, - reason: str, - trading_mode: str, - capital: float = 0.0, -) -> None: - sym = (mon.get("symbol") or "").strip() - sl_raw = mon.get("stop_loss") - tp_raw = mon.get("take_profit") - initial_sl = mon.get("initial_stop_loss") - write_trade_log( - conn, - symbol=sym, - direction=mon.get("direction") or "long", - entry_price=float(mon.get("entry_price") or close_price), - close_price=close_price, - lots=float(mon.get("lots") or 1), - result=_result_for_close(mon, reason), - trading_mode=trading_mode, - stop_loss=float(initial_sl) if initial_sl is not None else ( - float(sl_raw) if sl_raw is not None else None - ), - take_profit=float(tp_raw) if tp_raw is not None else None, - open_time=(mon.get("open_time") or "").strip(), - symbol_name=mon.get("symbol_name") or sym, - market_code=mon.get("market_code") or "", - monitor_type=monitor_source_label(mon.get("monitor_type") or ""), - capital=capital, - ) - - -def write_manual_close_trade_log( - conn, - mon: Optional[dict], - *, - symbol: str, - direction: str, - lots: float, - close_price: float, - entry_price: float, - trading_mode: str, - capital: float = 0.0, - stop_loss: Optional[float] = None, - take_profit: Optional[float] = None, - open_time: str = "", - symbol_name: str = "", - market_code: str = "", -) -> None: - """程序内点击平仓按钮 → 手动平仓。""" - if mon: - write_trade_log( - conn, - symbol=(mon.get("symbol") or symbol).strip(), - direction=mon.get("direction") or direction, - entry_price=float(mon.get("entry_price") or entry_price), - close_price=close_price, - lots=float(mon.get("lots") or lots), - result="手动平仓", - trading_mode=trading_mode, - stop_loss=float(mon["initial_stop_loss"]) if mon.get("initial_stop_loss") is not None else ( - float(mon["stop_loss"]) if mon.get("stop_loss") is not None else stop_loss - ), - take_profit=float(mon["take_profit"]) if mon.get("take_profit") is not None else take_profit, - open_time=(mon.get("open_time") or open_time).strip(), - symbol_name=mon.get("symbol_name") or symbol_name, - market_code=mon.get("market_code") or market_code, - monitor_type=monitor_source_label(mon.get("monitor_type") or ""), - capital=capital, - ) - return - write_trade_log( - conn, - symbol=symbol, - direction=direction, - entry_price=entry_price, - close_price=close_price, - lots=lots, - result="手动平仓", - trading_mode=trading_mode, - stop_loss=stop_loss, - take_profit=take_profit, - open_time=open_time, - symbol_name=symbol_name, - market_code=market_code, - capital=capital, - ) - - -def _update_trailing_stop_loss( - conn, - mon: dict, - mark: float, - *, - be_tick_mult: int, -) -> dict: - """达 1R 移保本(开仓±N跳),达 2R 移 1R,依次类推。""" - if not mon.get("trailing_be"): - return mon - entry = float(mon.get("entry_price") or 0) - initial_sl = mon.get("initial_stop_loss") - if initial_sl is None: - initial_sl = mon.get("stop_loss") - try: - initial_sl_f = float(initial_sl) if initial_sl is not None else None - except (TypeError, ValueError): - return mon - if not entry or initial_sl_f is None: - return mon - - direction = (mon.get("direction") or "long").strip().lower() - sym = (mon.get("symbol") or "").strip() - tick = _tick_size(sym) - r = abs(entry - initial_sl_f) - if r < tick * 0.5: - return mon - - profit_r = (mark - entry) / r if direction == "long" else (entry - mark) / r - if profit_r < 1.0: - return mon - - level = int(profit_r) - locked = int(mon.get("trailing_r_locked") or 0) - if level <= locked: - return mon - - if level == 1: - new_sl = entry + be_tick_mult * tick if direction == "long" else entry - be_tick_mult * tick - else: - new_sl = entry + (level - 1) * r if direction == "long" else entry - (level - 1) * r - new_sl = round(new_sl, 4) - - try: - current_sl = float(mon.get("stop_loss") or 0) - except (TypeError, ValueError): - current_sl = 0.0 - if direction == "long" and new_sl <= current_sl + tick * 0.01: - return mon - if direction == "short" and new_sl >= current_sl - tick * 0.01: - return mon - - mid = mon.get("id") - conn.execute( - "UPDATE trade_order_monitors SET stop_loss=?, trailing_r_locked=? WHERE id=?", - (new_sl, level, mid), - ) - conn.commit() - mon["stop_loss"] = new_sl - mon["trailing_r_locked"] = level - logger.info("移动保本 monitor=%s %dR 止损→%s", mid, level, new_sl) - return mon - - -def _sl_triggered(direction: str, sl: float, mark: float, tick: float) -> bool: - buf = max(tick * 0.01, 1e-9) - if direction == "long": - return mark <= sl + buf - return mark >= sl - buf - - -def _tp_triggered(direction: str, tp: float, mark: float, tick: float) -> bool: - buf = max(tick * 0.01, 1e-9) - if direction == "long": - return mark >= tp - buf - return mark <= tp + buf - - -def cancel_monitor_exit_orders( - conn, - mon: dict, - *, - mode: str, -) -> int: - """撤销该监控在交易所残留的旧版止盈止损平仓挂单。""" - ensure_monitor_order_columns(conn) - if not ctp_status(mode).get("connected"): - return 0 - sym = (mon.get("symbol") or "").strip() - direction = (mon.get("direction") or "long").strip().lower() - tick = _tick_size(sym) - active = ctp_list_active_orders(mode) - cancelled = 0 - seen: set[str] = set() - - def _try_cancel(vt_id: str) -> None: - nonlocal cancelled - oid = str(vt_id or "").strip() - if not oid or oid in seen: - return - seen.add(oid) - if ctp_cancel_order(mode, oid): - cancelled += 1 - - for kind, price_key in (("sl", "stop_loss"), ("tp", "take_profit")): - raw = mon.get(price_key) - try: - px = float(raw) if raw is not None else None - except (TypeError, ValueError): - px = None - stored = str(mon.get(f"{kind}_vt_order_id") or "") - if stored: - _try_cancel(stored) - if px is not None: - found = _find_close_order( - active, ths_code=sym, hold_direction=direction, price=px, tick=tick, - ) - if found: - _try_cancel(str(found.get("order_id") or "")) - - if cancelled: - conn.execute( - "UPDATE trade_order_monitors SET sl_vt_order_id=NULL, tp_vt_order_id=NULL WHERE id=?", - (mon["id"],), - ) - conn.commit() - return cancelled - - -def reconcile_monitors_without_position(conn, mode: str, *, grace_sec: int = 120) -> int: - """持仓已平时:关闭监控并撤销残留止盈止损挂单(新开仓 grace_sec 内不清理)。""" - if not ctp_status(mode).get("connected"): - return 0 - try: - bridge = get_bridge() - since_connect = time.time() - float(getattr(bridge, "_last_connect_ok_ts", 0) or 0) - if since_connect < 90: - return 0 - except Exception: - pass - positions = ctp_list_positions(mode, refresh_if_empty=False, refresh_margin=False) - position_keys: set[tuple[str, str]] = set() - for p in positions: - if int(p.get("lots") or 0) <= 0: - continue - sym = (p.get("symbol") or "").lower() - direction = p.get("direction") or "long" - position_keys.add((sym, direction)) - try: - from ctp_trading_state import trading_state - - for p in trading_state.get_positions() or []: - lots = int(p.get("lots") or 0) - if lots <= 0: - continue - sym = (p.get("symbol") or "").lower() - direction = p.get("direction") or "long" - position_keys.add((sym, direction)) - except Exception: - pass - - margin_raw = ctp_account_margin_used(mode) - if margin_raw is None: - return 0 - margin_used = float(margin_raw or 0.0) - if not position_keys: - if margin_used > 0: - return 0 - try: - bridge = get_bridge() - since_connect = time.time() - float(getattr(bridge, "_last_connect_ok_ts", 0) or 0) - if since_connect < 180: - return 0 - except Exception: - return 0 - - now_ts = time.time() - - def _monitor_within_grace(mon: dict) -> bool: - raw = (mon.get("open_time") or mon.get("created_at") or "").strip() - if not raw: - return True - for fmt in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%dT%H:%M"): - try: - dt = datetime.strptime(raw[:19], fmt) - if (now_ts - dt.timestamp()) <= grace_sec: - return True - except ValueError: - continue - return False - - closed = 0 - for r in conn.execute("SELECT * FROM trade_order_monitors WHERE status='active'").fetchall(): - mon = dict(r) - if _monitor_within_grace(mon): - continue - ms = mon.get("symbol") or "" - md = mon.get("direction") or "long" - matched = False - for ps, pd in position_keys: - if pd != md: - continue - if _match_symbol(ps, ms): - matched = True - break - if matched: - continue - try: - cancel_monitor_exit_orders(conn, mon, mode=mode) - except Exception as exc: - logger.warning("cancel exit orders monitor=%s: %s", mon.get("id"), exc) - conn.execute("UPDATE trade_order_monitors SET status='closed' WHERE id=?", (mon["id"],)) - closed += 1 - if closed: - conn.commit() - return closed - - -def _execute_local_close( - conn, - mon: dict, - *, - mode: str, - mark: float, - reason: str, - capital: float = 0.0, - notify_fn: Callable[[str], None] | None = None, -) -> None: - sym = (mon.get("symbol") or "").strip() - direction = (mon.get("direction") or "long").strip().lower() - positions = ctp_list_positions(mode) - pos = _find_position(positions, sym, direction) - if not pos: - margin_raw = ctp_account_margin_used(mode) - if margin_raw is not None and float(margin_raw) > 0: - logger.debug( - "skip close monitor=%s: vnpy empty but margin=%.2f", - mon.get("id"), - float(margin_raw), - ) - return - _close_all_monitors_for_symbol(conn, sym, direction) - reconcile_monitors_without_position(conn, mode) - return - lots = int(pos.get("lots") or mon.get("lots") or 1) - offset = "close_long" if direction == "long" else "close_short" - cancel_monitor_exit_orders(conn, mon, mode=mode) - execute_order( - conn, - mode=mode, - offset=offset, - symbol=sym, - direction=direction, - lots=lots, - price=mark, - order_type="market", - ) - _close_all_monitors_for_symbol(conn, sym, direction) - conn.commit() - result_label = _result_for_close(mon, reason) - logger.info( - "止盈止损本地触发 monitor=%s result=%s %s %s %d手 @%s(待 CTP 成交同步写入交易记录)", - mon.get("id"), result_label, sym, direction, lots, mark, - ) - if notify_fn: - try: - notify_fn(f"{result_label} {sym} {direction} {lots}手 @{mark},平仓委托已提交") - except Exception as exc: - logger.debug("SL/TP notify failed: %s", exc) - - -def check_sl_tp_on_tick( - conn, - mode: str, - exchange: str, - symbol: str, - mark: float, - *, - capital: float = 0.0, - notify_fn: Callable[[str], None] | None = None, - be_tick_mult: int = 2, -) -> int: - """EVENT_TICK 触发:仅检查与 tick 品种匹配的 active 监控。""" - ensure_monitor_order_columns(conn) - if not ctp_status(mode).get("connected") or not is_trading_session(): - return 0 - if mark <= 0: - return 0 - sym_l = (symbol or "").lower() - ex_u = (exchange or "").upper() - closed = 0 - rows = [dict(r) for r in conn.execute( - "SELECT * FROM trade_order_monitors WHERE status='active'" - ).fetchall()] - for mon in rows: - mid = int(mon.get("id") or 0) - ms = (mon.get("symbol") or "").strip() - direction = (mon.get("direction") or "long").strip().lower() - try: - vnpy_sym, ex2 = ths_to_vnpy_symbol(ms) - if sym_l != vnpy_sym.lower(): - continue - if ex_u and ex2 and ex_u != ex2.upper(): - continue - except Exception: - if sym_l != ms.lower(): - continue - - sl = mon.get("stop_loss") - tp = mon.get("take_profit") - try: - sl_f = float(sl) if sl is not None else None - tp_f = float(tp) if tp is not None else None - except (TypeError, ValueError): - sl_f, tp_f = None, None - if sl_f is None and tp_f is None: - continue - - positions = ctp_list_positions(mode) - if not _find_position(positions, ms, direction): - continue - - tick = _tick_size(ms) - if mon.get("trailing_be"): - mon = _update_trailing_stop_loss(conn, mon, mark, be_tick_mult=be_tick_mult) - try: - sl_f = float(mon["stop_loss"]) if mon.get("stop_loss") is not None else sl_f - except (TypeError, ValueError): - pass - - reason = None - if tp_f is not None and _tp_triggered(direction, tp_f, mark, tick): - reason = "take_profit" - elif sl_f is not None and _sl_triggered(direction, sl_f, mark, tick): - reason = "stop_loss" - if not reason: - continue - if mid > 0 and not _can_close_now(mid): - continue - if not _try_acquire_close_symbol(ms, direction): - continue - try: - _execute_local_close( - conn, mon, mode=mode, mark=mark, reason=reason, - capital=capital, notify_fn=notify_fn, - ) - if mid > 0: - _mark_close_attempt(mid) - closed += 1 - except Exception as exc: - logger.warning("SL/TP tick close failed monitor=%s: %s", mid, exc) - finally: - _release_close_symbol(ms, direction) - return closed - - -def check_monitors_locally( - conn, - mode: str, - *, - capital: float = 0.0, - notify_fn: Callable[[str], None] | None = None, - be_tick_mult: int = 2, -) -> int: - """扫描 active 监控,本地比对行情;触发止盈/止损(含跳空穿透)后立刻市价平仓并记交易记录。""" - ensure_monitor_order_columns(conn) - if not ctp_status(mode).get("connected"): - return 0 - if not is_trading_session(): - return 0 - reconcile_monitors_without_position(conn, mode) - _dedupe_active_monitors(conn) - conn.commit() - closed = 0 - rows = [dict(r) for r in conn.execute( - "SELECT * FROM trade_order_monitors WHERE status='active'" - ).fetchall()] - for mon in rows: - mid = int(mon.get("id") or 0) - sym = (mon.get("symbol") or "").strip() - direction = (mon.get("direction") or "long").strip().lower() - - if mon.get("sl_vt_order_id") or mon.get("tp_vt_order_id"): - cancel_monitor_exit_orders(conn, mon, mode=mode) - - sl = mon.get("stop_loss") - tp = mon.get("take_profit") - try: - sl_f = float(sl) if sl is not None else None - tp_f = float(tp) if tp is not None else None - except (TypeError, ValueError): - sl_f, tp_f = None, None - if sl_f is None and tp_f is None: - continue - - positions = ctp_list_positions(mode) - if not _find_position(positions, sym, direction): - continue - - mark = ctp_get_tick_price(mode, sym) - if mark is None or mark <= 0: - continue - - tick = _tick_size(sym) - if mon.get("trailing_be"): - mon = _update_trailing_stop_loss(conn, mon, mark, be_tick_mult=be_tick_mult) - try: - sl_f = float(mon["stop_loss"]) if mon.get("stop_loss") is not None else sl_f - except (TypeError, ValueError): - pass - - reason = None - if tp_f is not None and _tp_triggered(direction, tp_f, mark, tick): - reason = "take_profit" - elif sl_f is not None and _sl_triggered(direction, sl_f, mark, tick): - reason = "stop_loss" - - if not reason: - continue - if mid > 0 and not _can_close_now(mid): - continue - if not _try_acquire_close_symbol(sym, direction): - continue - try: - _execute_local_close( - conn, - mon, - mode=mode, - mark=mark, - reason=reason, - capital=capital, - notify_fn=notify_fn, - ) - if mid > 0: - _mark_close_attempt(mid) - closed += 1 - except Exception as exc: - logger.warning("SL/TP local close failed monitor=%s: %s", mid, exc) - finally: - _release_close_symbol(sym, direction) - return closed - - -def place_monitor_exit_orders( - conn, - mon: dict, - *, - mode: str, - force: bool = False, -) -> dict[str, Any]: - """兼容旧 API:本地监控模式不再向交易所挂 SL/TP 单,仅清理旧挂单。""" - del force - ensure_monitor_order_columns(conn) - if not ctp_status(mode).get("connected"): - return {"ok": False, "error": "CTP 未连接", "placed": []} - cancelled = cancel_monitor_exit_orders(conn, mon, mode=mode) - msg = "程序本地监控中,不向交易所挂止盈止损单" - if cancelled: - msg += f";已撤销旧版柜台挂单 {cancelled} 笔" - return {"ok": True, "message": msg, "placed": [], "local_monitor": True} - - -def monitor_order_status( - mon: dict, - *, - mode: str, - ths_code: str, - direction: str, -) -> dict[str, bool]: - """返回本地监控状态(非交易所挂单状态)。""" - del mode, ths_code, direction - sl = mon.get("stop_loss") if mon else None - tp = mon.get("take_profit") if mon else None - try: - sl_f = float(sl) if sl is not None else None - tp_f = float(tp) if tp is not None else None - except (TypeError, ValueError): - sl_f, tp_f = None, None - return { - "sl_order_active": sl_f is not None, - "tp_order_active": tp_f is not None, - "sl_monitoring": sl_f is not None, - "tp_monitoring": tp_f is not None, - "needs_sl_order": False, - "needs_tp_order": False, - } - - -def sync_all_sl_tp_orders(conn, mode: str) -> int: - """兼容旧 worker 入口:执行本地监控检查。""" - del mode - return 0 - - -def start_sl_tp_guard_worker( - *, - db_path: str, - get_mode_fn: Callable[[], str], - init_tables_fn: Callable | None = None, - get_capital_fn: Callable | None = None, - get_be_tick_buffer_fn: Callable[[], int] | None = None, - notify_fn: Callable[[str], None] | None = None, - interval: int = CHECK_INTERVAL_SEC, -) -> None: - from db_conn import connect_db - - def _loop() -> None: - time.sleep(20) - while True: - sleep_sec = max(1, interval) - try: - if not is_trading_session(): - time.sleep(CLOSED_MARKET_SLEEP_SEC) - continue - mode = get_mode_fn() - if not ctp_status(mode).get("connected"): - time.sleep(DISCONNECTED_SLEEP_SEC) - continue - conn = connect_db(db_path) - try: - if init_tables_fn: - init_tables_fn(conn) - has_monitors = conn.execute( - """SELECT COUNT(*) AS n FROM trade_order_monitors - WHERE status='active' - AND (stop_loss IS NOT NULL OR take_profit IS NOT NULL)""" - ).fetchone()["n"] - if not has_monitors: - sleep_sec = max(sleep_sec, 5) - else: - capital = 0.0 - if get_capital_fn: - try: - capital = float(get_capital_fn(conn) or 0) - except Exception: - capital = 0.0 - n = check_monitors_locally( - conn, - mode, - capital=capital, - notify_fn=notify_fn, - be_tick_mult=( - get_be_tick_buffer_fn() if get_be_tick_buffer_fn else 2 - ), - ) - if n: - logger.info("止盈止损本地监控: 触发平仓 %d 笔", n) - finally: - conn.close() - except Exception as exc: - logger.warning("sl_tp_guard worker: %s", exc) - time.sleep(sleep_sec) - - threading.Thread(target=_loop, daemon=True, name="sl-tp-guard").start() +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""止盈止损守护:程序本地监控价位,触发后向 CTP 发平仓单(不向交易所挂 SL/TP 限价单)。""" +from __future__ import annotations + +import logging +import threading +import time +from datetime import datetime +from typing import Any, Callable, Optional +from zoneinfo import ZoneInfo + +from modules.core.contract_specs import calc_position_metrics +from modules.ctp.ctp_symbol import ths_to_vnpy_symbol +from modules.fees.fee_specs import calc_round_trip_fee +from modules.trading.trade_log_lib import calc_equity_after, refresh_trade_log_equity_chain +from modules.market.market_sessions import is_trading_session +from modules.core.symbols import ths_to_codes +from modules.ctp.vnpy_bridge import ( + ctp_cancel_order, + ctp_get_tick_price, + ctp_list_active_orders, + ctp_list_positions, + ctp_status, + ctp_account_margin_used, + execute_order, + get_bridge, +) + +logger = logging.getLogger(__name__) + +TZ = ZoneInfo("Asia/Shanghai") +CHECK_INTERVAL_SEC = 1 +CLOSED_MARKET_SLEEP_SEC = 30 +DISCONNECTED_SLEEP_SEC = 5 +PLACE_COOLDOWN_SEC = 3 + +_last_close_attempt: dict[int, float] = {} +_closing_monitors: set[int] = set() +_closing_symbol_keys: set[str] = set() +_closing_lock = threading.Lock() + +MONITOR_ORDER_COLUMNS = ( + "ALTER TABLE trade_order_monitors ADD COLUMN sl_vt_order_id TEXT", + "ALTER TABLE trade_order_monitors ADD COLUMN tp_vt_order_id TEXT", + "ALTER TABLE trade_order_monitors ADD COLUMN trailing_be INTEGER DEFAULT 0", + "ALTER TABLE trade_order_monitors ADD COLUMN initial_stop_loss REAL", + "ALTER TABLE trade_order_monitors ADD COLUMN trailing_r_locked INTEGER DEFAULT 0", + "ALTER TABLE trade_order_monitors ADD COLUMN margin REAL", + "ALTER TABLE trade_order_monitors ADD COLUMN position_pct REAL", + "ALTER TABLE trade_order_monitors ADD COLUMN mark_price REAL", + "ALTER TABLE trade_order_monitors ADD COLUMN float_pnl REAL", + "ALTER TABLE trade_order_monitors ADD COLUMN vt_order_id TEXT", + "ALTER TABLE trade_order_monitors ADD COLUMN order_price REAL", + "ALTER TABLE trade_order_monitors ADD COLUMN open_fee REAL", +) + +TRADE_RESULTS = ("止损", "止盈", "移动止盈", "保本止盈", "手动平仓") + +_MONITOR_COLUMNS_READY = False +_MONITOR_COLUMNS_LOCK = threading.Lock() + + +def _monitor_columns_exist(conn) -> bool: + try: + rows = conn.execute("PRAGMA table_info(trade_order_monitors)").fetchall() + cols = set() + for r in rows: + if isinstance(r, dict): + cols.add(r.get("name") or "") + else: + cols.add(r[1]) + return "open_fee" in cols + except Exception: + return False + + +def ensure_monitor_order_columns(conn, *, migrate: bool = False) -> None: + """列齐全后不再 ALTER,避免 worker 每次请求锁 SQLite。""" + global _MONITOR_COLUMNS_READY + if _MONITOR_COLUMNS_READY: + return + with _MONITOR_COLUMNS_LOCK: + if _MONITOR_COLUMNS_READY: + return + if _monitor_columns_exist(conn): + _MONITOR_COLUMNS_READY = True + return + if not migrate: + return + for sql in MONITOR_ORDER_COLUMNS: + try: + conn.execute(sql) + conn.commit() + except Exception: + try: + conn.rollback() + except Exception: + pass + _MONITOR_COLUMNS_READY = True + + +def _tick_size(ths_code: str) -> float: + from modules.core.contract_specs import get_contract_spec + return float(get_contract_spec(ths_code).get("tick_size") or 1.0) + + +def _match_symbol(ctp_sym: str, ths: str) -> bool: + a = (ctp_sym or "").lower() + b = (ths or "").lower() + if a == b: + return True + if a and b and a.split(".")[0] == b.split(".")[0]: + return True + try: + vnpy_sym, _ = ths_to_vnpy_symbol(ths) + if a == vnpy_sym.lower(): + return True + except Exception: + pass + try: + vnpy_sym, _ = ths_to_vnpy_symbol(ctp_sym) + if vnpy_sym.lower() == b.split(".")[0]: + return True + except Exception: + pass + return False + + +def _close_order_direction(hold_direction: str) -> str: + return "short" if hold_direction == "long" else "long" + + +def _price_near(a: float, b: float, tick: float) -> bool: + return abs(float(a) - float(b)) <= max(tick * 0.501, 1e-9) + + +def _find_close_order( + active_orders: list[dict], + *, + ths_code: str, + hold_direction: str, + price: float, + tick: float, +) -> Optional[dict]: + close_dir = _close_order_direction(hold_direction) + for o in active_orders: + sym = o.get("symbol") or "" + if not _match_symbol(sym, ths_code): + continue + offset_s = (o.get("offset") or "").upper() + if "CLOSE" not in offset_s: + continue + if (o.get("direction") or "") != close_dir: + continue + if not _price_near(o.get("price") or 0, price, tick): + continue + return o + return None + + +def _find_position(positions: list[dict], ths_code: str, direction: str) -> Optional[dict]: + for p in positions: + if int(p.get("lots") or 0) <= 0: + continue + if (p.get("direction") or "long") != direction: + continue + if _match_symbol(p.get("symbol") or "", ths_code): + return p + return None + + +def _position_key(sym: str, direction: str) -> str: + return f"{(sym or '').strip().lower()}|{(direction or 'long').strip().lower()}" + + +def _try_acquire_close_symbol(sym: str, direction: str) -> bool: + key = _position_key(sym, direction) + with _closing_lock: + if key in _closing_symbol_keys: + return False + _closing_symbol_keys.add(key) + return True + + +def _release_close_symbol(sym: str, direction: str) -> None: + key = _position_key(sym, direction) + with _closing_lock: + _closing_symbol_keys.discard(key) + + +def _close_all_monitors_for_symbol(conn, sym: str, direction: str) -> None: + direction = (direction or "long").strip().lower() + for r in conn.execute( + "SELECT id, symbol, direction FROM trade_order_monitors WHERE status='active'" + ).fetchall(): + if (r["direction"] or "long") != direction: + continue + if _match_symbol(sym, r["symbol"] or ""): + conn.execute( + "UPDATE trade_order_monitors SET status='closed' WHERE id=?", + (r["id"],), + ) + + +def _dedupe_active_monitors(conn) -> None: + """同一品种方向只保留一条 active 监控,避免重复触发平仓。""" + groups: dict[str, list[dict]] = {} + for r in conn.execute( + "SELECT * FROM trade_order_monitors WHERE status='active' ORDER BY id ASC" + ).fetchall(): + row = dict(r) + key = _position_key(row.get("symbol") or "", row.get("direction") or "long") + groups.setdefault(key, []).append(row) + for items in groups.values(): + if len(items) <= 1: + continue + + def _keep_score(m: dict) -> tuple: + mt = (m.get("monitor_type") or "").lower() + score = 0 + if mt != "ctp_sync": + score += 10 + if m.get("stop_loss") is not None: + score += 5 + return (score, int(m.get("id") or 0)) + + items.sort(key=_keep_score, reverse=True) + for dup in items[1:]: + conn.execute( + "UPDATE trade_order_monitors SET status='closed' WHERE id=?", + (dup["id"],), + ) + + +def _can_close_now(monitor_id: int, *, cooldown: int = PLACE_COOLDOWN_SEC) -> bool: + last = _last_close_attempt.get(monitor_id, 0.0) + return (time.time() - last) >= cooldown + + +def _mark_close_attempt(monitor_id: int) -> None: + _last_close_attempt[monitor_id] = time.time() + + +def _try_acquire_close(monitor_id: int) -> bool: + with _closing_lock: + if monitor_id in _closing_monitors: + return False + _closing_monitors.add(monitor_id) + return True + + +def _release_close(monitor_id: int) -> None: + with _closing_lock: + _closing_monitors.discard(monitor_id) + + +def monitor_source_label(raw: str) -> str: + """持仓展示用来源文案。""" + mapping = { + "manual": "期货下单", + "trend": "趋势回调", + "roll": "顺势加仓", + "ctp_sync": "CTP 柜台", + "箱体突破": "箱体突破", + "收敛突破": "收敛突破", + } + key = (raw or "manual").strip().lower() + return mapping.get(key, raw or "期货下单") + + +def _result_for_close(mon: dict, reason: str) -> str: + """平仓结果:止损 / 止盈 / 移动止盈 / 保本止盈 / 手动平仓。""" + if reason == "manual": + return "手动平仓" + if reason == "take_profit": + return "止盈" + if not mon.get("trailing_be"): + return "止损" + locked = int(mon.get("trailing_r_locked") or 0) + if locked >= 2: + return "移动止盈" + if locked >= 1: + return "保本止盈" + return "止损" + + +def write_trade_log( + conn, + *, + symbol: str, + direction: str, + entry_price: float, + close_price: float, + lots: float, + result: str, + trading_mode: str, + stop_loss: Optional[float] = None, + take_profit: Optional[float] = None, + open_time: str = "", + symbol_name: str = "", + market_code: str = "", + sina_code: str = "", + monitor_type: str = "期货下单", + capital: float = 0.0, +) -> None: + """写入 trade_logs(程序平仓 / 手动平仓)。""" + sym = (symbol or "").strip() + direction = (direction or "long").strip().lower() + entry = float(entry_price or close_price) + sl = float(stop_loss) if stop_loss is not None else entry + tp = float(take_profit) if take_profit is not None else entry + close_time = datetime.now(TZ).strftime("%Y-%m-%dT%H:%M") + + if not sina_code or not market_code: + codes = ths_to_codes(sym) or {} + sina_code = sina_code or codes.get("sina_code") or "" + market_code = market_code or codes.get("market_code") or "" + if not symbol_name: + symbol_name = sym + + metrics = calc_position_metrics( + direction, entry, sl, tp, lots, close_price, capital, sym, + ) + pnl = metrics.get("float_pnl") or 0.0 + fee = calc_round_trip_fee( + sym, entry, close_price, lots, open_time, close_time, trading_mode=trading_mode, + ) + pnl_net = round(pnl - fee, 2) + margin_pct = metrics.get("position_pct") + equity_after = calc_equity_after(capital, pnl_net) + + try: + from app import holding_to_minutes + minutes = holding_to_minutes(open_time, close_time) + except Exception: + minutes = 0 + + conn.execute( + """INSERT INTO trade_logs + (symbol, symbol_name, market_code, sina_code, monitor_type, direction, + entry_price, stop_loss, take_profit, close_price, lots, margin, + margin_pct, holding_minutes, open_time, close_time, pnl, fee, pnl_net, + equity_after, result) + VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""", + ( + sym, + symbol_name, + market_code, + sina_code, + monitor_type, + direction, + entry, + stop_loss if stop_loss is not None else sl, + take_profit if take_profit is not None else tp, + close_price, + lots, + metrics.get("margin"), + margin_pct, + minutes, + open_time, + close_time, + pnl, + fee, + pnl_net, + equity_after, + result if result in TRADE_RESULTS else "手动平仓", + ), + ) + try: + refresh_trade_log_equity_chain(conn, capital if capital > 0 else None) + except Exception as exc: + logger.debug("equity chain refresh after trade log: %s", exc) + try: + from modules.stats.stats_engine import refresh_stats_cache + refresh_stats_cache(conn, capital) + except Exception as exc: + logger.debug("stats refresh after close: %s", exc) + try: + from modules.trading.trade_notify import notify_trade_log_close + from modules.core.trading_context import trading_mode_label + from app import get_setting, send_wechat_msg + from modules.notify.ai_worker import schedule_ai_event_analysis + from modules.core.db_conn import DB_PATH + + notify_trade_log_close( + send_wechat=send_wechat_msg, + get_setting=get_setting, + mode_label=trading_mode_label(get_setting), + capital=capital, + sym=sym, + symbol_name=symbol_name, + direction=direction, + entry=entry, + close_price=close_price, + sl=stop_loss if stop_loss is not None else None, + tp=take_profit if take_profit is not None else None, + lots=lots, + pnl_net=pnl_net, + equity_after=equity_after, + holding_minutes=minutes, + result=result, + monitor_type=monitor_type, + schedule_ai_fn=schedule_ai_event_analysis, + db_path=DB_PATH, + ) + except Exception as exc: + logger.debug("close notify: %s", exc) + + +def _write_trade_log( + conn, + mon: dict, + *, + close_price: float, + reason: str, + trading_mode: str, + capital: float = 0.0, +) -> None: + sym = (mon.get("symbol") or "").strip() + sl_raw = mon.get("stop_loss") + tp_raw = mon.get("take_profit") + initial_sl = mon.get("initial_stop_loss") + write_trade_log( + conn, + symbol=sym, + direction=mon.get("direction") or "long", + entry_price=float(mon.get("entry_price") or close_price), + close_price=close_price, + lots=float(mon.get("lots") or 1), + result=_result_for_close(mon, reason), + trading_mode=trading_mode, + stop_loss=float(initial_sl) if initial_sl is not None else ( + float(sl_raw) if sl_raw is not None else None + ), + take_profit=float(tp_raw) if tp_raw is not None else None, + open_time=(mon.get("open_time") or "").strip(), + symbol_name=mon.get("symbol_name") or sym, + market_code=mon.get("market_code") or "", + monitor_type=monitor_source_label(mon.get("monitor_type") or ""), + capital=capital, + ) + + +def write_manual_close_trade_log( + conn, + mon: Optional[dict], + *, + symbol: str, + direction: str, + lots: float, + close_price: float, + entry_price: float, + trading_mode: str, + capital: float = 0.0, + stop_loss: Optional[float] = None, + take_profit: Optional[float] = None, + open_time: str = "", + symbol_name: str = "", + market_code: str = "", +) -> None: + """程序内点击平仓按钮 → 手动平仓。""" + if mon: + write_trade_log( + conn, + symbol=(mon.get("symbol") or symbol).strip(), + direction=mon.get("direction") or direction, + entry_price=float(mon.get("entry_price") or entry_price), + close_price=close_price, + lots=float(mon.get("lots") or lots), + result="手动平仓", + trading_mode=trading_mode, + stop_loss=float(mon["initial_stop_loss"]) if mon.get("initial_stop_loss") is not None else ( + float(mon["stop_loss"]) if mon.get("stop_loss") is not None else stop_loss + ), + take_profit=float(mon["take_profit"]) if mon.get("take_profit") is not None else take_profit, + open_time=(mon.get("open_time") or open_time).strip(), + symbol_name=mon.get("symbol_name") or symbol_name, + market_code=mon.get("market_code") or market_code, + monitor_type=monitor_source_label(mon.get("monitor_type") or ""), + capital=capital, + ) + return + write_trade_log( + conn, + symbol=symbol, + direction=direction, + entry_price=entry_price, + close_price=close_price, + lots=lots, + result="手动平仓", + trading_mode=trading_mode, + stop_loss=stop_loss, + take_profit=take_profit, + open_time=open_time, + symbol_name=symbol_name, + market_code=market_code, + capital=capital, + ) + + +def _update_trailing_stop_loss( + conn, + mon: dict, + mark: float, + *, + be_tick_mult: int, +) -> dict: + """达 1R 移保本(开仓±N跳),达 2R 移 1R,依次类推。""" + if not mon.get("trailing_be"): + return mon + entry = float(mon.get("entry_price") or 0) + initial_sl = mon.get("initial_stop_loss") + if initial_sl is None: + initial_sl = mon.get("stop_loss") + try: + initial_sl_f = float(initial_sl) if initial_sl is not None else None + except (TypeError, ValueError): + return mon + if not entry or initial_sl_f is None: + return mon + + direction = (mon.get("direction") or "long").strip().lower() + sym = (mon.get("symbol") or "").strip() + tick = _tick_size(sym) + r = abs(entry - initial_sl_f) + if r < tick * 0.5: + return mon + + profit_r = (mark - entry) / r if direction == "long" else (entry - mark) / r + if profit_r < 1.0: + return mon + + level = int(profit_r) + locked = int(mon.get("trailing_r_locked") or 0) + if level <= locked: + return mon + + if level == 1: + new_sl = entry + be_tick_mult * tick if direction == "long" else entry - be_tick_mult * tick + else: + new_sl = entry + (level - 1) * r if direction == "long" else entry - (level - 1) * r + new_sl = round(new_sl, 4) + + try: + current_sl = float(mon.get("stop_loss") or 0) + except (TypeError, ValueError): + current_sl = 0.0 + if direction == "long" and new_sl <= current_sl + tick * 0.01: + return mon + if direction == "short" and new_sl >= current_sl - tick * 0.01: + return mon + + mid = mon.get("id") + conn.execute( + "UPDATE trade_order_monitors SET stop_loss=?, trailing_r_locked=? WHERE id=?", + (new_sl, level, mid), + ) + conn.commit() + mon["stop_loss"] = new_sl + mon["trailing_r_locked"] = level + logger.info("移动保本 monitor=%s %dR 止损→%s", mid, level, new_sl) + return mon + + +def _sl_triggered(direction: str, sl: float, mark: float, tick: float) -> bool: + buf = max(tick * 0.01, 1e-9) + if direction == "long": + return mark <= sl + buf + return mark >= sl - buf + + +def _tp_triggered(direction: str, tp: float, mark: float, tick: float) -> bool: + buf = max(tick * 0.01, 1e-9) + if direction == "long": + return mark >= tp - buf + return mark <= tp + buf + + +def cancel_monitor_exit_orders( + conn, + mon: dict, + *, + mode: str, +) -> int: + """撤销该监控在交易所残留的旧版止盈止损平仓挂单。""" + ensure_monitor_order_columns(conn) + if not ctp_status(mode).get("connected"): + return 0 + sym = (mon.get("symbol") or "").strip() + direction = (mon.get("direction") or "long").strip().lower() + tick = _tick_size(sym) + active = ctp_list_active_orders(mode) + cancelled = 0 + seen: set[str] = set() + + def _try_cancel(vt_id: str) -> None: + nonlocal cancelled + oid = str(vt_id or "").strip() + if not oid or oid in seen: + return + seen.add(oid) + if ctp_cancel_order(mode, oid): + cancelled += 1 + + for kind, price_key in (("sl", "stop_loss"), ("tp", "take_profit")): + raw = mon.get(price_key) + try: + px = float(raw) if raw is not None else None + except (TypeError, ValueError): + px = None + stored = str(mon.get(f"{kind}_vt_order_id") or "") + if stored: + _try_cancel(stored) + if px is not None: + found = _find_close_order( + active, ths_code=sym, hold_direction=direction, price=px, tick=tick, + ) + if found: + _try_cancel(str(found.get("order_id") or "")) + + if cancelled: + conn.execute( + "UPDATE trade_order_monitors SET sl_vt_order_id=NULL, tp_vt_order_id=NULL WHERE id=?", + (mon["id"],), + ) + conn.commit() + return cancelled + + +def reconcile_monitors_without_position(conn, mode: str, *, grace_sec: int = 120) -> int: + """持仓已平时:关闭监控并撤销残留止盈止损挂单(新开仓 grace_sec 内不清理)。""" + if not ctp_status(mode).get("connected"): + return 0 + try: + bridge = get_bridge() + since_connect = time.time() - float(getattr(bridge, "_last_connect_ok_ts", 0) or 0) + if since_connect < 90: + return 0 + except Exception: + pass + positions = ctp_list_positions(mode, refresh_if_empty=False, refresh_margin=False) + position_keys: set[tuple[str, str]] = set() + for p in positions: + if int(p.get("lots") or 0) <= 0: + continue + sym = (p.get("symbol") or "").lower() + direction = p.get("direction") or "long" + position_keys.add((sym, direction)) + try: + from modules.ctp.ctp_trading_state import trading_state + + for p in trading_state.get_positions() or []: + lots = int(p.get("lots") or 0) + if lots <= 0: + continue + sym = (p.get("symbol") or "").lower() + direction = p.get("direction") or "long" + position_keys.add((sym, direction)) + except Exception: + pass + + margin_raw = ctp_account_margin_used(mode) + if margin_raw is None: + return 0 + margin_used = float(margin_raw or 0.0) + if not position_keys: + if margin_used > 0: + return 0 + try: + bridge = get_bridge() + since_connect = time.time() - float(getattr(bridge, "_last_connect_ok_ts", 0) or 0) + if since_connect < 180: + return 0 + except Exception: + return 0 + + now_ts = time.time() + + def _monitor_within_grace(mon: dict) -> bool: + raw = (mon.get("open_time") or mon.get("created_at") or "").strip() + if not raw: + return True + for fmt in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%dT%H:%M"): + try: + dt = datetime.strptime(raw[:19], fmt) + if (now_ts - dt.timestamp()) <= grace_sec: + return True + except ValueError: + continue + return False + + closed = 0 + for r in conn.execute("SELECT * FROM trade_order_monitors WHERE status='active'").fetchall(): + mon = dict(r) + if _monitor_within_grace(mon): + continue + ms = mon.get("symbol") or "" + md = mon.get("direction") or "long" + matched = False + for ps, pd in position_keys: + if pd != md: + continue + if _match_symbol(ps, ms): + matched = True + break + if matched: + continue + try: + cancel_monitor_exit_orders(conn, mon, mode=mode) + except Exception as exc: + logger.warning("cancel exit orders monitor=%s: %s", mon.get("id"), exc) + conn.execute("UPDATE trade_order_monitors SET status='closed' WHERE id=?", (mon["id"],)) + closed += 1 + if closed: + conn.commit() + return closed + + +def _execute_local_close( + conn, + mon: dict, + *, + mode: str, + mark: float, + reason: str, + capital: float = 0.0, + notify_fn: Callable[[str], None] | None = None, +) -> None: + sym = (mon.get("symbol") or "").strip() + direction = (mon.get("direction") or "long").strip().lower() + positions = ctp_list_positions(mode) + pos = _find_position(positions, sym, direction) + if not pos: + margin_raw = ctp_account_margin_used(mode) + if margin_raw is not None and float(margin_raw) > 0: + logger.debug( + "skip close monitor=%s: vnpy empty but margin=%.2f", + mon.get("id"), + float(margin_raw), + ) + return + _close_all_monitors_for_symbol(conn, sym, direction) + reconcile_monitors_without_position(conn, mode) + return + lots = int(pos.get("lots") or mon.get("lots") or 1) + offset = "close_long" if direction == "long" else "close_short" + cancel_monitor_exit_orders(conn, mon, mode=mode) + execute_order( + conn, + mode=mode, + offset=offset, + symbol=sym, + direction=direction, + lots=lots, + price=mark, + order_type="market", + ) + _close_all_monitors_for_symbol(conn, sym, direction) + conn.commit() + result_label = _result_for_close(mon, reason) + logger.info( + "止盈止损本地触发 monitor=%s result=%s %s %s %d手 @%s(待 CTP 成交同步写入交易记录)", + mon.get("id"), result_label, sym, direction, lots, mark, + ) + if notify_fn: + try: + notify_fn(f"{result_label} {sym} {direction} {lots}手 @{mark},平仓委托已提交") + except Exception as exc: + logger.debug("SL/TP notify failed: %s", exc) + + +def check_sl_tp_on_tick( + conn, + mode: str, + exchange: str, + symbol: str, + mark: float, + *, + capital: float = 0.0, + notify_fn: Callable[[str], None] | None = None, + be_tick_mult: int = 2, +) -> int: + """EVENT_TICK 触发:仅检查与 tick 品种匹配的 active 监控。""" + ensure_monitor_order_columns(conn) + if not ctp_status(mode).get("connected") or not is_trading_session(): + return 0 + if mark <= 0: + return 0 + sym_l = (symbol or "").lower() + ex_u = (exchange or "").upper() + closed = 0 + rows = [dict(r) for r in conn.execute( + "SELECT * FROM trade_order_monitors WHERE status='active'" + ).fetchall()] + for mon in rows: + mid = int(mon.get("id") or 0) + ms = (mon.get("symbol") or "").strip() + direction = (mon.get("direction") or "long").strip().lower() + try: + vnpy_sym, ex2 = ths_to_vnpy_symbol(ms) + if sym_l != vnpy_sym.lower(): + continue + if ex_u and ex2 and ex_u != ex2.upper(): + continue + except Exception: + if sym_l != ms.lower(): + continue + + sl = mon.get("stop_loss") + tp = mon.get("take_profit") + try: + sl_f = float(sl) if sl is not None else None + tp_f = float(tp) if tp is not None else None + except (TypeError, ValueError): + sl_f, tp_f = None, None + if sl_f is None and tp_f is None: + continue + + positions = ctp_list_positions(mode) + if not _find_position(positions, ms, direction): + continue + + tick = _tick_size(ms) + if mon.get("trailing_be"): + mon = _update_trailing_stop_loss(conn, mon, mark, be_tick_mult=be_tick_mult) + try: + sl_f = float(mon["stop_loss"]) if mon.get("stop_loss") is not None else sl_f + except (TypeError, ValueError): + pass + + reason = None + if tp_f is not None and _tp_triggered(direction, tp_f, mark, tick): + reason = "take_profit" + elif sl_f is not None and _sl_triggered(direction, sl_f, mark, tick): + reason = "stop_loss" + if not reason: + continue + if mid > 0 and not _can_close_now(mid): + continue + if not _try_acquire_close_symbol(ms, direction): + continue + try: + _execute_local_close( + conn, mon, mode=mode, mark=mark, reason=reason, + capital=capital, notify_fn=notify_fn, + ) + if mid > 0: + _mark_close_attempt(mid) + closed += 1 + except Exception as exc: + logger.warning("SL/TP tick close failed monitor=%s: %s", mid, exc) + finally: + _release_close_symbol(ms, direction) + return closed + + +def check_monitors_locally( + conn, + mode: str, + *, + capital: float = 0.0, + notify_fn: Callable[[str], None] | None = None, + be_tick_mult: int = 2, +) -> int: + """扫描 active 监控,本地比对行情;触发止盈/止损(含跳空穿透)后立刻市价平仓并记交易记录。""" + ensure_monitor_order_columns(conn) + if not ctp_status(mode).get("connected"): + return 0 + if not is_trading_session(): + return 0 + reconcile_monitors_without_position(conn, mode) + _dedupe_active_monitors(conn) + conn.commit() + closed = 0 + rows = [dict(r) for r in conn.execute( + "SELECT * FROM trade_order_monitors WHERE status='active'" + ).fetchall()] + for mon in rows: + mid = int(mon.get("id") or 0) + sym = (mon.get("symbol") or "").strip() + direction = (mon.get("direction") or "long").strip().lower() + + if mon.get("sl_vt_order_id") or mon.get("tp_vt_order_id"): + cancel_monitor_exit_orders(conn, mon, mode=mode) + + sl = mon.get("stop_loss") + tp = mon.get("take_profit") + try: + sl_f = float(sl) if sl is not None else None + tp_f = float(tp) if tp is not None else None + except (TypeError, ValueError): + sl_f, tp_f = None, None + if sl_f is None and tp_f is None: + continue + + positions = ctp_list_positions(mode) + if not _find_position(positions, sym, direction): + continue + + mark = ctp_get_tick_price(mode, sym) + if mark is None or mark <= 0: + continue + + tick = _tick_size(sym) + if mon.get("trailing_be"): + mon = _update_trailing_stop_loss(conn, mon, mark, be_tick_mult=be_tick_mult) + try: + sl_f = float(mon["stop_loss"]) if mon.get("stop_loss") is not None else sl_f + except (TypeError, ValueError): + pass + + reason = None + if tp_f is not None and _tp_triggered(direction, tp_f, mark, tick): + reason = "take_profit" + elif sl_f is not None and _sl_triggered(direction, sl_f, mark, tick): + reason = "stop_loss" + + if not reason: + continue + if mid > 0 and not _can_close_now(mid): + continue + if not _try_acquire_close_symbol(sym, direction): + continue + try: + _execute_local_close( + conn, + mon, + mode=mode, + mark=mark, + reason=reason, + capital=capital, + notify_fn=notify_fn, + ) + if mid > 0: + _mark_close_attempt(mid) + closed += 1 + except Exception as exc: + logger.warning("SL/TP local close failed monitor=%s: %s", mid, exc) + finally: + _release_close_symbol(sym, direction) + return closed + + +def place_monitor_exit_orders( + conn, + mon: dict, + *, + mode: str, + force: bool = False, +) -> dict[str, Any]: + """兼容旧 API:本地监控模式不再向交易所挂 SL/TP 单,仅清理旧挂单。""" + del force + ensure_monitor_order_columns(conn) + if not ctp_status(mode).get("connected"): + return {"ok": False, "error": "CTP 未连接", "placed": []} + cancelled = cancel_monitor_exit_orders(conn, mon, mode=mode) + msg = "程序本地监控中,不向交易所挂止盈止损单" + if cancelled: + msg += f";已撤销旧版柜台挂单 {cancelled} 笔" + return {"ok": True, "message": msg, "placed": [], "local_monitor": True} + + +def monitor_order_status( + mon: dict, + *, + mode: str, + ths_code: str, + direction: str, +) -> dict[str, bool]: + """返回本地监控状态(非交易所挂单状态)。""" + del mode, ths_code, direction + sl = mon.get("stop_loss") if mon else None + tp = mon.get("take_profit") if mon else None + try: + sl_f = float(sl) if sl is not None else None + tp_f = float(tp) if tp is not None else None + except (TypeError, ValueError): + sl_f, tp_f = None, None + return { + "sl_order_active": sl_f is not None, + "tp_order_active": tp_f is not None, + "sl_monitoring": sl_f is not None, + "tp_monitoring": tp_f is not None, + "needs_sl_order": False, + "needs_tp_order": False, + } + + +def sync_all_sl_tp_orders(conn, mode: str) -> int: + """兼容旧 worker 入口:执行本地监控检查。""" + del mode + return 0 + + +def start_sl_tp_guard_worker( + *, + db_path: str, + get_mode_fn: Callable[[], str], + init_tables_fn: Callable | None = None, + get_capital_fn: Callable | None = None, + get_be_tick_buffer_fn: Callable[[], int] | None = None, + notify_fn: Callable[[str], None] | None = None, + interval: int = CHECK_INTERVAL_SEC, +) -> None: + from modules.core.db_conn import connect_db + + def _loop() -> None: + time.sleep(20) + while True: + sleep_sec = max(1, interval) + try: + if not is_trading_session(): + time.sleep(CLOSED_MARKET_SLEEP_SEC) + continue + mode = get_mode_fn() + if not ctp_status(mode).get("connected"): + time.sleep(DISCONNECTED_SLEEP_SEC) + continue + conn = connect_db(db_path) + try: + if init_tables_fn: + init_tables_fn(conn) + has_monitors = conn.execute( + """SELECT COUNT(*) AS n FROM trade_order_monitors + WHERE status='active' + AND (stop_loss IS NOT NULL OR take_profit IS NOT NULL)""" + ).fetchone()["n"] + if not has_monitors: + sleep_sec = max(sleep_sec, 5) + else: + capital = 0.0 + if get_capital_fn: + try: + capital = float(get_capital_fn(conn) or 0) + except Exception: + capital = 0.0 + n = check_monitors_locally( + conn, + mode, + capital=capital, + notify_fn=notify_fn, + be_tick_mult=( + get_be_tick_buffer_fn() if get_be_tick_buffer_fn else 2 + ), + ) + if n: + logger.info("止盈止损本地监控: 触发平仓 %d 笔", n) + finally: + conn.close() + except Exception as exc: + logger.warning("sl_tp_guard worker: %s", exc) + time.sleep(sleep_sec) + + threading.Thread(target=_loop, daemon=True, name="sl-tp-guard").start() diff --git a/trade_log_lib.py b/modules/trading/trade_log_lib.py similarity index 93% rename from trade_log_lib.py rename to modules/trading/trade_log_lib.py index 9267632..f08891e 100644 --- a/trade_log_lib.py +++ b/modules/trading/trade_log_lib.py @@ -1,218 +1,218 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""交易记录:字段补全、资金曲线数据。""" -from __future__ import annotations - -from typing import Any - - -TRADE_LOG_EXTRA_COLUMNS = ( - "ALTER TABLE trade_logs ADD COLUMN margin_pct REAL", - "ALTER TABLE trade_logs ADD COLUMN equity_after REAL", - "ALTER TABLE trade_logs ADD COLUMN source TEXT DEFAULT 'local'", - "ALTER TABLE trade_logs ADD COLUMN ctp_trade_key TEXT", -) - - -def ensure_trade_log_columns(conn) -> None: - for sql in TRADE_LOG_EXTRA_COLUMNS: - try: - conn.execute(sql) - except Exception: - pass - - -def calc_equity_after(capital: float, pnl_net: float) -> float | None: - cap = float(capital or 0) - if cap <= 0: - return None - return round(cap + float(pnl_net or 0), 2) - - -def recalc_trade_log_pnl( - *, - symbol: str, - direction: str, - entry_price: float, - close_price: float, - lots: float, - stop_loss: float | None = None, - take_profit: float | None = None, - open_time: str = "", - close_time: str = "", - trading_mode: str = "simulation", - capital: float = 0.0, -) -> dict[str, float]: - """按开/平仓价重算盈亏与手续费(跨日持仓可手动改价后核对)。""" - from contract_specs import calc_position_metrics - from fee_specs import calc_round_trip_fee - - sym = (symbol or "").strip() - direction = (direction or "long").strip().lower() - entry = float(entry_price or close_price or 0) - close_px = float(close_price or 0) - lots_f = float(lots or 0) - sl = float(stop_loss) if stop_loss is not None else entry - tp = float(take_profit) if take_profit is not None else entry - metrics = calc_position_metrics( - direction, entry, sl, tp, lots_f, close_px, capital, sym, - ) - pnl = round(float(metrics.get("float_pnl") or 0), 2) - fee = calc_round_trip_fee( - sym, entry, close_px, lots_f, open_time, close_time, trading_mode=trading_mode, - ) - pnl_net = round(pnl - fee, 2) - return {"pnl": pnl, "fee": round(fee, 2), "pnl_net": pnl_net} - - -def _read_initial_capital(conn, initial_capital: float | None = None) -> float: - if initial_capital is not None and initial_capital > 0: - return float(initial_capital) - try: - row = conn.execute("SELECT value FROM settings WHERE key='live_capital'").fetchone() - if row and row[0]: - val = float(row[0] or 0) - if val > 0: - return val - except (TypeError, ValueError): - pass - try: - from product_recommend import DISCONNECTED_RECOMMEND_CAPITAL - return float(DISCONNECTED_RECOMMEND_CAPITAL) - except Exception: - return 100_000.0 - - -def refresh_trade_log_equity_chain( - conn, - initial_capital: float | None = None, -) -> int: - """按平仓时间顺序重算 trade_logs.equity_after(起始=参考资金 live_capital)。""" - base = _read_initial_capital(conn, initial_capital) - rows = [ - dict(r) - for r in conn.execute( - "SELECT id, close_time, pnl_net FROM trade_logs ORDER BY close_time ASC, id ASC" - ).fetchall() - ] - running = float(base or 0) - updated = 0 - for row in rows: - if running <= 0: - break - running = round(running + float(row.get("pnl_net") or 0), 2) - conn.execute( - "UPDATE trade_logs SET equity_after=? WHERE id=?", - (running, int(row["id"])), - ) - updated += 1 - return updated - - -def _norm_symbol(symbol: str) -> str: - return (symbol or "").split(".")[0].strip().lower() - - -def _norm_close_minute(ts: str) -> str: - """统一 close_time 到分钟粒度,兼容 ISO `T` 与空格分隔。""" - return (ts or "").strip().replace("T", " ")[:16] - - -def purge_duplicate_local_trade_logs(conn) -> int: - """删除已被 CTP 柜台记录覆盖的本地重复成交。""" - removed = 0 - ctp_rows = [ - dict(r) - for r in conn.execute("SELECT * FROM trade_logs WHERE source='ctp'").fetchall() - ] - local_rows = [ - dict(r) - for r in conn.execute( - """SELECT * FROM trade_logs - WHERE COALESCE(source, 'local') != 'ctp' - AND (ctp_trade_key IS NULL OR ctp_trade_key = '')""" - ).fetchall() - ] - for ctp in ctp_rows: - ct16 = _norm_close_minute(ctp.get("close_time") or "") - sym_n = _norm_symbol(ctp.get("symbol") or "") - lots = float(ctp.get("lots") or 0) - direction = (ctp.get("direction") or "long").strip().lower() - for loc in local_rows: - if loc.get("id") == ctp.get("id"): - continue - if _norm_symbol(loc.get("symbol") or "") != sym_n: - continue - if (loc.get("direction") or "long").strip().lower() != direction: - continue - if _norm_close_minute(loc.get("close_time") or "") != ct16: - continue - if abs(float(loc.get("lots") or 0) - lots) > 0.01: - continue - conn.execute("DELETE FROM trade_logs WHERE id=?", (loc["id"],)) - removed += 1 - return removed - - -def _attach_symbol_meta(t: dict[str, Any]) -> None: - try: - from symbols import position_symbol_meta - - sym = (t.get("symbol") or "").strip() - meta = position_symbol_meta(sym) - if not t.get("symbol_name"): - t["symbol_name"] = meta.get("name") or sym - t["symbol_exchange"] = meta.get("exchange") or "" - t["symbol_is_main"] = bool(meta.get("is_main")) - except Exception: - t.setdefault("symbol_exchange", "") - t.setdefault("symbol_is_main", False) - - -def enrich_trades_for_records( - trades: list[dict[str, Any]], - *, - initial_capital: float = 0.0, -) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: - """表格仍按 id 降序;资金曲线按平仓时间升序用最新资金绘制。""" - rows = [dict(t) for t in trades] - chrono = sorted( - rows, - key=lambda t: ((t.get("close_time") or ""), int(t.get("id") or 0)), - ) - running = float(initial_capital or 0) - curve: list[dict[str, Any]] = [] - equity_by_id: dict[int, float | None] = {} - - for t in chrono: - _attach_symbol_meta(t) - pnl_net = float(t.get("pnl_net") or 0) - if running > 0: - running = round(running + pnl_net, 2) - eq: float | None = running - else: - eq = None - equity_by_id[int(t.get("id") or 0)] = eq - - cap_before = float(eq or 0) - pnl_net if eq is not None else 0.0 - if t.get("margin_pct") is None: - margin = float(t.get("margin") or 0) - if margin > 0 and cap_before > 0: - t["margin_pct"] = round(margin / cap_before * 100, 2) - - if eq is not None: - curve.append({ - "time": (t.get("close_time") or "")[:19], - "value": float(eq), - "id": int(t.get("id") or 0), - }) - - for t in rows: - tid = int(t.get("id") or 0) - if tid in equity_by_id: - t["equity_after"] = equity_by_id[tid] - - return rows, curve +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""交易记录:字段补全、资金曲线数据。""" +from __future__ import annotations + +from typing import Any + + +TRADE_LOG_EXTRA_COLUMNS = ( + "ALTER TABLE trade_logs ADD COLUMN margin_pct REAL", + "ALTER TABLE trade_logs ADD COLUMN equity_after REAL", + "ALTER TABLE trade_logs ADD COLUMN source TEXT DEFAULT 'local'", + "ALTER TABLE trade_logs ADD COLUMN ctp_trade_key TEXT", +) + + +def ensure_trade_log_columns(conn) -> None: + for sql in TRADE_LOG_EXTRA_COLUMNS: + try: + conn.execute(sql) + except Exception: + pass + + +def calc_equity_after(capital: float, pnl_net: float) -> float | None: + cap = float(capital or 0) + if cap <= 0: + return None + return round(cap + float(pnl_net or 0), 2) + + +def recalc_trade_log_pnl( + *, + symbol: str, + direction: str, + entry_price: float, + close_price: float, + lots: float, + stop_loss: float | None = None, + take_profit: float | None = None, + open_time: str = "", + close_time: str = "", + trading_mode: str = "simulation", + capital: float = 0.0, +) -> dict[str, float]: + """按开/平仓价重算盈亏与手续费(跨日持仓可手动改价后核对)。""" + from modules.core.contract_specs import calc_position_metrics + from modules.fees.fee_specs import calc_round_trip_fee + + sym = (symbol or "").strip() + direction = (direction or "long").strip().lower() + entry = float(entry_price or close_price or 0) + close_px = float(close_price or 0) + lots_f = float(lots or 0) + sl = float(stop_loss) if stop_loss is not None else entry + tp = float(take_profit) if take_profit is not None else entry + metrics = calc_position_metrics( + direction, entry, sl, tp, lots_f, close_px, capital, sym, + ) + pnl = round(float(metrics.get("float_pnl") or 0), 2) + fee = calc_round_trip_fee( + sym, entry, close_px, lots_f, open_time, close_time, trading_mode=trading_mode, + ) + pnl_net = round(pnl - fee, 2) + return {"pnl": pnl, "fee": round(fee, 2), "pnl_net": pnl_net} + + +def _read_initial_capital(conn, initial_capital: float | None = None) -> float: + if initial_capital is not None and initial_capital > 0: + return float(initial_capital) + try: + row = conn.execute("SELECT value FROM settings WHERE key='live_capital'").fetchone() + if row and row[0]: + val = float(row[0] or 0) + if val > 0: + return val + except (TypeError, ValueError): + pass + try: + from modules.trading.product_recommend import DISCONNECTED_RECOMMEND_CAPITAL + return float(DISCONNECTED_RECOMMEND_CAPITAL) + except Exception: + return 100_000.0 + + +def refresh_trade_log_equity_chain( + conn, + initial_capital: float | None = None, +) -> int: + """按平仓时间顺序重算 trade_logs.equity_after(起始=参考资金 live_capital)。""" + base = _read_initial_capital(conn, initial_capital) + rows = [ + dict(r) + for r in conn.execute( + "SELECT id, close_time, pnl_net FROM trade_logs ORDER BY close_time ASC, id ASC" + ).fetchall() + ] + running = float(base or 0) + updated = 0 + for row in rows: + if running <= 0: + break + running = round(running + float(row.get("pnl_net") or 0), 2) + conn.execute( + "UPDATE trade_logs SET equity_after=? WHERE id=?", + (running, int(row["id"])), + ) + updated += 1 + return updated + + +def _norm_symbol(symbol: str) -> str: + return (symbol or "").split(".")[0].strip().lower() + + +def _norm_close_minute(ts: str) -> str: + """统一 close_time 到分钟粒度,兼容 ISO `T` 与空格分隔。""" + return (ts or "").strip().replace("T", " ")[:16] + + +def purge_duplicate_local_trade_logs(conn) -> int: + """删除已被 CTP 柜台记录覆盖的本地重复成交。""" + removed = 0 + ctp_rows = [ + dict(r) + for r in conn.execute("SELECT * FROM trade_logs WHERE source='ctp'").fetchall() + ] + local_rows = [ + dict(r) + for r in conn.execute( + """SELECT * FROM trade_logs + WHERE COALESCE(source, 'local') != 'ctp' + AND (ctp_trade_key IS NULL OR ctp_trade_key = '')""" + ).fetchall() + ] + for ctp in ctp_rows: + ct16 = _norm_close_minute(ctp.get("close_time") or "") + sym_n = _norm_symbol(ctp.get("symbol") or "") + lots = float(ctp.get("lots") or 0) + direction = (ctp.get("direction") or "long").strip().lower() + for loc in local_rows: + if loc.get("id") == ctp.get("id"): + continue + if _norm_symbol(loc.get("symbol") or "") != sym_n: + continue + if (loc.get("direction") or "long").strip().lower() != direction: + continue + if _norm_close_minute(loc.get("close_time") or "") != ct16: + continue + if abs(float(loc.get("lots") or 0) - lots) > 0.01: + continue + conn.execute("DELETE FROM trade_logs WHERE id=?", (loc["id"],)) + removed += 1 + return removed + + +def _attach_symbol_meta(t: dict[str, Any]) -> None: + try: + from modules.core.symbols import position_symbol_meta + + sym = (t.get("symbol") or "").strip() + meta = position_symbol_meta(sym) + if not t.get("symbol_name"): + t["symbol_name"] = meta.get("name") or sym + t["symbol_exchange"] = meta.get("exchange") or "" + t["symbol_is_main"] = bool(meta.get("is_main")) + except Exception: + t.setdefault("symbol_exchange", "") + t.setdefault("symbol_is_main", False) + + +def enrich_trades_for_records( + trades: list[dict[str, Any]], + *, + initial_capital: float = 0.0, +) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: + """表格仍按 id 降序;资金曲线按平仓时间升序用最新资金绘制。""" + rows = [dict(t) for t in trades] + chrono = sorted( + rows, + key=lambda t: ((t.get("close_time") or ""), int(t.get("id") or 0)), + ) + running = float(initial_capital or 0) + curve: list[dict[str, Any]] = [] + equity_by_id: dict[int, float | None] = {} + + for t in chrono: + _attach_symbol_meta(t) + pnl_net = float(t.get("pnl_net") or 0) + if running > 0: + running = round(running + pnl_net, 2) + eq: float | None = running + else: + eq = None + equity_by_id[int(t.get("id") or 0)] = eq + + cap_before = float(eq or 0) - pnl_net if eq is not None else 0.0 + if t.get("margin_pct") is None: + margin = float(t.get("margin") or 0) + if margin > 0 and cap_before > 0: + t["margin_pct"] = round(margin / cap_before * 100, 2) + + if eq is not None: + curve.append({ + "time": (t.get("close_time") or "")[:19], + "value": float(eq), + "id": int(t.get("id") or 0), + }) + + for t in rows: + tid = int(t.get("id") or 0) + if tid in equity_by_id: + t["equity_after"] = equity_by_id[tid] + + return rows, curve diff --git a/trade_notify.py b/modules/trading/trade_notify.py similarity index 93% rename from trade_notify.py rename to modules/trading/trade_notify.py index b98f838..77b5061 100644 --- a/trade_notify.py +++ b/modules/trading/trade_notify.py @@ -1,225 +1,225 @@ -# Copyright (c) 2025-2026 马建军. All rights reserved. -# 专有软件 — 未经授权禁止复制、传播、转售。 -# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 -# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md - -"""交易事件推送:企业微信 + AI 分析。""" -from __future__ import annotations - -from typing import Callable, Optional - -from contract_specs import calc_position_metrics, get_contract_spec -from sl_tp_guard import monitor_source_label -from wechat_notify import format_close_done, format_key_open_success, format_open_success - - -def _risk_amount(capital: float, risk_percent: float) -> Optional[float]: - try: - return round(float(capital) * float(risk_percent) / 100.0, 2) - except (TypeError, ValueError): - return None - - -def notify_manual_open_filled( - *, - send_wechat: Callable[[str], None], - get_setting: Callable[[str, str], str], - mode_label: str, - sym: str, - symbol_name: str, - direction: str, - entry: float, - sl: Optional[float], - tp: Optional[float], - lots: int, - capital: float, - order_id: str = "", - trailing_be: bool = False, - be_tick_buffer: int = 2, - schedule_ai_fn=None, - db_path: str = "", -) -> None: - if not sl: - return - spec = get_contract_spec(sym) - tick = float(spec.get("tick_size") or 1.0) - try: - rp = float(get_setting("risk_percent", "1") or 1) - except (TypeError, ValueError): - rp = 1.0 - metrics = calc_position_metrics(direction, entry, sl, tp or entry, lots, entry, capital, sym) - msg = format_open_success( - symbol_name=symbol_name, - symbol=sym, - direction=direction, - mode_label=mode_label, - order_id=order_id, - entry=entry, - stop_loss=float(sl), - take_profit=float(tp) if tp else None, - lots=lots, - capital=capital, - margin=metrics.get("margin"), - margin_pct=metrics.get("position_pct"), - risk_percent=rp, - risk_amount=_risk_amount(capital, rp), - trailing_be=trailing_be, - be_tick_buffer=be_tick_buffer, - tick_size=tick, - source="期货下单", - ) - send_wechat(msg) - if schedule_ai_fn and db_path: - schedule_ai_fn( - db_path=db_path, - get_setting_fn=get_setting, - kind="open", - title=f"{symbol_name or sym} 开仓", - payload={ - "symbol": sym, - "direction": direction, - "entry": entry, - "stop_loss": sl, - "take_profit": tp, - "lots": lots, - "capital": capital, - }, - send_wechat_fn=None, - ) - - -def notify_key_breakout_open( - *, - send_wechat: Callable[[str], None], - get_setting: Callable[[str, str], str], - mode_label: str, - row: dict, - break_side: str, - bar_time: str, - direction: str, - entry: float, - sl: float, - tp: float, - lots: int, - capital: float, - order_id: str = "", - schedule_ai_fn=None, - db_path: str = "", -) -> None: - sym = row.get("symbol") or "" - name = row.get("symbol_name") or sym - trailing_be = bool(int(row.get("trailing_be") or 0)) - try: - rp = float(get_setting("risk_percent", "1") or 1) - be_buf = int(float(get_setting("trailing_be_tick_buffer", "2") or 2)) - except (TypeError, ValueError): - rp, be_buf = 1.0, 2 - spec = get_contract_spec(sym) - tick = float(spec.get("tick_size") or 1.0) - metrics = calc_position_metrics(direction, entry, sl, tp, lots, entry, capital, sym) - msg = format_key_open_success( - symbol_name=name, - symbol=sym, - monitor_type=row.get("monitor_type") or "", - trade_mode=row.get("trade_mode") or "顺势", - bar_time=bar_time, - break_side=break_side, - direction=direction, - mode_label=mode_label, - order_id=order_id, - entry=entry, - stop_loss=sl, - take_profit=tp, - lots=lots, - capital=capital, - margin=metrics.get("margin"), - margin_pct=metrics.get("position_pct"), - risk_percent=rp, - risk_amount=_risk_amount(capital, rp), - trailing_be=trailing_be, - be_tick_buffer=be_buf, - tick_size=tick, - ) - send_wechat(msg) - if schedule_ai_fn and db_path: - schedule_ai_fn( - db_path=db_path, - get_setting_fn=get_setting, - kind="key_open", - title=f"{name} 关键位开仓", - payload={ - "monitor_type": row.get("monitor_type"), - "trade_mode": row.get("trade_mode"), - "break_side": break_side, - "entry": entry, - "stop_loss": sl, - "take_profit": tp, - "lots": lots, - }, - ) - - -def notify_trade_log_close( - *, - send_wechat: Callable[[str], None], - get_setting: Callable[[str, str], str], - mode_label: str, - capital: float, - sym: str, - symbol_name: str, - direction: str, - entry: float, - close_price: float, - sl: Optional[float], - tp: Optional[float], - lots: float, - pnl_net: float, - equity_after: Optional[float], - holding_minutes: int, - result: str, - monitor_type: str = "", - schedule_ai_fn=None, - db_path: str = "", -) -> None: - src = monitor_source_label(monitor_type) if monitor_type else "期货下单" - note = "" - if tp and sl: - if direction == "long": - if close_price > tp or close_price < sl: - note = "成交价不在计划止盈/止损带内(可能为手动或其他类型平仓)" - else: - if close_price < tp or close_price > sl: - note = "成交价不在计划止盈/止损带内(可能为手动或其他类型平仓)" - msg = format_close_done( - symbol_name=symbol_name, - symbol=sym, - mode_label=mode_label, - direction=direction, - result=result, - pnl_net=pnl_net, - equity_after=equity_after, - capital=capital, - entry=entry, - close_price=close_price, - stop_loss=sl, - take_profit=tp, - lots=lots, - holding_minutes=holding_minutes, - note=note, - ) - send_wechat(msg) - if schedule_ai_fn and db_path: - schedule_ai_fn( - db_path=db_path, - get_setting_fn=get_setting, - kind="close", - title=f"{symbol_name or sym} 平仓", - payload={ - "source": src, - "result": result, - "pnl_net": pnl_net, - "entry": entry, - "close_price": close_price, - "lots": lots, - }, - ) +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 专有软件 — 未经授权禁止复制、传播、转售。 +# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 +# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md + +"""交易事件推送:企业微信 + AI 分析。""" +from __future__ import annotations + +from typing import Callable, Optional + +from modules.core.contract_specs import calc_position_metrics, get_contract_spec +from modules.trading.sl_tp_guard import monitor_source_label +from modules.notify.wechat_notify import format_close_done, format_key_open_success, format_open_success + + +def _risk_amount(capital: float, risk_percent: float) -> Optional[float]: + try: + return round(float(capital) * float(risk_percent) / 100.0, 2) + except (TypeError, ValueError): + return None + + +def notify_manual_open_filled( + *, + send_wechat: Callable[[str], None], + get_setting: Callable[[str, str], str], + mode_label: str, + sym: str, + symbol_name: str, + direction: str, + entry: float, + sl: Optional[float], + tp: Optional[float], + lots: int, + capital: float, + order_id: str = "", + trailing_be: bool = False, + be_tick_buffer: int = 2, + schedule_ai_fn=None, + db_path: str = "", +) -> None: + if not sl: + return + spec = get_contract_spec(sym) + tick = float(spec.get("tick_size") or 1.0) + try: + rp = float(get_setting("risk_percent", "1") or 1) + except (TypeError, ValueError): + rp = 1.0 + metrics = calc_position_metrics(direction, entry, sl, tp or entry, lots, entry, capital, sym) + msg = format_open_success( + symbol_name=symbol_name, + symbol=sym, + direction=direction, + mode_label=mode_label, + order_id=order_id, + entry=entry, + stop_loss=float(sl), + take_profit=float(tp) if tp else None, + lots=lots, + capital=capital, + margin=metrics.get("margin"), + margin_pct=metrics.get("position_pct"), + risk_percent=rp, + risk_amount=_risk_amount(capital, rp), + trailing_be=trailing_be, + be_tick_buffer=be_tick_buffer, + tick_size=tick, + source="期货下单", + ) + send_wechat(msg) + if schedule_ai_fn and db_path: + schedule_ai_fn( + db_path=db_path, + get_setting_fn=get_setting, + kind="open", + title=f"{symbol_name or sym} 开仓", + payload={ + "symbol": sym, + "direction": direction, + "entry": entry, + "stop_loss": sl, + "take_profit": tp, + "lots": lots, + "capital": capital, + }, + send_wechat_fn=None, + ) + + +def notify_key_breakout_open( + *, + send_wechat: Callable[[str], None], + get_setting: Callable[[str, str], str], + mode_label: str, + row: dict, + break_side: str, + bar_time: str, + direction: str, + entry: float, + sl: float, + tp: float, + lots: int, + capital: float, + order_id: str = "", + schedule_ai_fn=None, + db_path: str = "", +) -> None: + sym = row.get("symbol") or "" + name = row.get("symbol_name") or sym + trailing_be = bool(int(row.get("trailing_be") or 0)) + try: + rp = float(get_setting("risk_percent", "1") or 1) + be_buf = int(float(get_setting("trailing_be_tick_buffer", "2") or 2)) + except (TypeError, ValueError): + rp, be_buf = 1.0, 2 + spec = get_contract_spec(sym) + tick = float(spec.get("tick_size") or 1.0) + metrics = calc_position_metrics(direction, entry, sl, tp, lots, entry, capital, sym) + msg = format_key_open_success( + symbol_name=name, + symbol=sym, + monitor_type=row.get("monitor_type") or "", + trade_mode=row.get("trade_mode") or "顺势", + bar_time=bar_time, + break_side=break_side, + direction=direction, + mode_label=mode_label, + order_id=order_id, + entry=entry, + stop_loss=sl, + take_profit=tp, + lots=lots, + capital=capital, + margin=metrics.get("margin"), + margin_pct=metrics.get("position_pct"), + risk_percent=rp, + risk_amount=_risk_amount(capital, rp), + trailing_be=trailing_be, + be_tick_buffer=be_buf, + tick_size=tick, + ) + send_wechat(msg) + if schedule_ai_fn and db_path: + schedule_ai_fn( + db_path=db_path, + get_setting_fn=get_setting, + kind="key_open", + title=f"{name} 关键位开仓", + payload={ + "monitor_type": row.get("monitor_type"), + "trade_mode": row.get("trade_mode"), + "break_side": break_side, + "entry": entry, + "stop_loss": sl, + "take_profit": tp, + "lots": lots, + }, + ) + + +def notify_trade_log_close( + *, + send_wechat: Callable[[str], None], + get_setting: Callable[[str, str], str], + mode_label: str, + capital: float, + sym: str, + symbol_name: str, + direction: str, + entry: float, + close_price: float, + sl: Optional[float], + tp: Optional[float], + lots: float, + pnl_net: float, + equity_after: Optional[float], + holding_minutes: int, + result: str, + monitor_type: str = "", + schedule_ai_fn=None, + db_path: str = "", +) -> None: + src = monitor_source_label(monitor_type) if monitor_type else "期货下单" + note = "" + if tp and sl: + if direction == "long": + if close_price > tp or close_price < sl: + note = "成交价不在计划止盈/止损带内(可能为手动或其他类型平仓)" + else: + if close_price < tp or close_price > sl: + note = "成交价不在计划止盈/止损带内(可能为手动或其他类型平仓)" + msg = format_close_done( + symbol_name=symbol_name, + symbol=sym, + mode_label=mode_label, + direction=direction, + result=result, + pnl_net=pnl_net, + equity_after=equity_after, + capital=capital, + entry=entry, + close_price=close_price, + stop_loss=sl, + take_profit=tp, + lots=lots, + holding_minutes=holding_minutes, + note=note, + ) + send_wechat(msg) + if schedule_ai_fn and db_path: + schedule_ai_fn( + db_path=db_path, + get_setting_fn=get_setting, + kind="close", + title=f"{symbol_name or sym} 平仓", + payload={ + "source": src, + "result": result, + "pnl_net": pnl_net, + "entry": entry, + "close_price": close_price, + "lots": lots, + }, + ) diff --git a/modules/web/__init__.py b/modules/web/__init__.py new file mode 100644 index 0000000..4ca5553 --- /dev/null +++ b/modules/web/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) 2025-2026 马建军. All rights reserved. + +from modules.web.routes import register + +__all__ = ["register"] diff --git a/modules/web/routes.py b/modules/web/routes.py new file mode 100644 index 0000000..e1ee2ea --- /dev/null +++ b/modules/web/routes.py @@ -0,0 +1,108 @@ +# Copyright (c) 2025-2026 马建军. All rights reserved. +"""HTTP routes for web module.""" + +from __future__ import annotations + +from datetime import date, datetime + +from flask import ( + Response, + flash, + jsonify, + redirect, + render_template, + request, + send_file, + session, + stream_with_context, + url_for, +) + + +def register(deps) -> None: + app = deps.app + login_required = deps.login_required + require_nav = deps.require_nav + get_db = deps.get_db + get_setting = deps.get_setting + set_setting = deps.set_setting + fetch_price = deps.fetch_price + send_wechat_msg = deps.send_wechat_msg + touch_stats_cache = deps.touch_stats_cache + get_stats_data = deps.get_stats_data + build_market_quote_payload = deps.build_market_quote_payload + today_str = deps.today_str + expire_old_plans = deps.expire_old_plans + TZ = deps.tz + DB_PATH = deps.db_path + UPLOAD_DIR = deps.upload_dir + OPEN_TYPES = deps.open_types + EXIT_TRIGGERS = deps.exit_triggers + BEHAVIOR_TAGS = deps.behavior_tags + KLINE_PERIODS = deps.kline_periods + KLINE_CUTOFFS = deps.kline_cutoffs + calc_holding_duration = deps.calc_holding_duration + holding_to_minutes = deps.holding_to_minutes + classify_close_result = deps.classify_close_result + calc_rr_ratio = deps.calc_rr_ratio + calc_theoretical_pnl = deps.calc_theoretical_pnl + parse_review_date_filter = deps.parse_review_date_filter + _trading_mode = deps.trading_mode + _ua_is_phone = deps.ua_is_phone + _static_asset_v = deps.static_asset_v + + from werkzeug.security import check_password_hash + import json + + + @app.route("/") + def index(): + if session.get("logged_in"): + return redirect(url_for("positions")) + return redirect(url_for("login")) + + + @app.route("/manifest.webmanifest") + def web_manifest(): + import json + + manifest_path = os.path.join(app.static_folder, "manifest.json") + with open(manifest_path, encoding="utf-8") as fh: + data = json.load(fh) + if _ua_is_phone(request.headers.get("User-Agent", "")): + data["orientation"] = "portrait-primary" + else: + data["orientation"] = "any" + response = app.make_response(json.dumps(data, ensure_ascii=False)) + response.mimetype = "application/manifest+json" + response.headers["Cache-Control"] = "no-cache" + return response + + + @app.route("/sw.js") + def service_worker(): + response = app.send_static_file("sw.js") + response.headers["Cache-Control"] = "no-cache" + response.headers["Service-Worker-Allowed"] = "/" + return response + + + @app.route("/login", methods=["GET", "POST"]) + def login(): + if request.method == "POST": + u = request.form.get("username", "").strip() + p = request.form.get("password", "") + admin_u = get_setting("admin_username") + admin_hash = get_setting("admin_password_hash") + if u == admin_u and check_password_hash(admin_hash, p): + session["logged_in"] = True + session["username"] = u + return redirect(url_for("positions")) + flash("账号或密码错误") + return render_template("login.html") + + + @app.route("/logout") + def logout(): + session.clear() + return redirect(url_for("login")) diff --git a/static/css/ai_messages.css b/modules/web/static/css/ai_messages.css similarity index 100% rename from static/css/ai_messages.css rename to modules/web/static/css/ai_messages.css diff --git a/static/css/base.css b/modules/web/static/css/base.css similarity index 100% rename from static/css/base.css rename to modules/web/static/css/base.css diff --git a/static/css/dashboard.css b/modules/web/static/css/dashboard.css similarity index 100% rename from static/css/dashboard.css rename to modules/web/static/css/dashboard.css diff --git a/static/css/doc.css b/modules/web/static/css/doc.css similarity index 100% rename from static/css/doc.css rename to modules/web/static/css/doc.css diff --git a/static/css/keys.css b/modules/web/static/css/keys.css similarity index 100% rename from static/css/keys.css rename to modules/web/static/css/keys.css diff --git a/static/css/mobile.css b/modules/web/static/css/mobile.css similarity index 100% rename from static/css/mobile.css rename to modules/web/static/css/mobile.css diff --git a/static/css/records.css b/modules/web/static/css/records.css similarity index 100% rename from static/css/records.css rename to modules/web/static/css/records.css diff --git a/static/css/responsive.css b/modules/web/static/css/responsive.css similarity index 100% rename from static/css/responsive.css rename to modules/web/static/css/responsive.css diff --git a/static/css/tech.css b/modules/web/static/css/tech.css similarity index 100% rename from static/css/tech.css rename to modules/web/static/css/tech.css diff --git a/static/css/trade.css b/modules/web/static/css/trade.css similarity index 100% rename from static/css/trade.css rename to modules/web/static/css/trade.css diff --git a/static/icons/icon-192.png b/modules/web/static/icons/icon-192.png similarity index 100% rename from static/icons/icon-192.png rename to modules/web/static/icons/icon-192.png diff --git a/static/icons/icon-512.png b/modules/web/static/icons/icon-512.png similarity index 100% rename from static/icons/icon-512.png rename to modules/web/static/icons/icon-512.png diff --git a/static/icons/icon.svg b/modules/web/static/icons/icon.svg similarity index 100% rename from static/icons/icon.svg rename to modules/web/static/icons/icon.svg diff --git a/static/js/calendar.js b/modules/web/static/js/calendar.js similarity index 100% rename from static/js/calendar.js rename to modules/web/static/js/calendar.js diff --git a/static/js/contract.js b/modules/web/static/js/contract.js similarity index 100% rename from static/js/contract.js rename to modules/web/static/js/contract.js diff --git a/static/js/dashboard.js b/modules/web/static/js/dashboard.js similarity index 100% rename from static/js/dashboard.js rename to modules/web/static/js/dashboard.js diff --git a/static/js/equity_curve.js b/modules/web/static/js/equity_curve.js similarity index 100% rename from static/js/equity_curve.js rename to modules/web/static/js/equity_curve.js diff --git a/static/js/keys.js b/modules/web/static/js/keys.js similarity index 100% rename from static/js/keys.js rename to modules/web/static/js/keys.js diff --git a/static/js/lunar.js b/modules/web/static/js/lunar.js similarity index 100% rename from static/js/lunar.js rename to modules/web/static/js/lunar.js diff --git a/static/js/market.js b/modules/web/static/js/market.js similarity index 100% rename from static/js/market.js rename to modules/web/static/js/market.js diff --git a/static/js/nav.js b/modules/web/static/js/nav.js similarity index 100% rename from static/js/nav.js rename to modules/web/static/js/nav.js diff --git a/static/js/orientation.js b/modules/web/static/js/orientation.js similarity index 100% rename from static/js/orientation.js rename to modules/web/static/js/orientation.js diff --git a/static/js/page.js b/modules/web/static/js/page.js similarity index 100% rename from static/js/page.js rename to modules/web/static/js/page.js diff --git a/static/js/plans.js b/modules/web/static/js/plans.js similarity index 100% rename from static/js/plans.js rename to modules/web/static/js/plans.js diff --git a/static/js/positions.js b/modules/web/static/js/positions.js similarity index 100% rename from static/js/positions.js rename to modules/web/static/js/positions.js diff --git a/static/js/pwa.js b/modules/web/static/js/pwa.js similarity index 100% rename from static/js/pwa.js rename to modules/web/static/js/pwa.js diff --git a/static/js/records.js b/modules/web/static/js/records.js similarity index 100% rename from static/js/records.js rename to modules/web/static/js/records.js diff --git a/static/js/review.js b/modules/web/static/js/review.js similarity index 100% rename from static/js/review.js rename to modules/web/static/js/review.js diff --git a/static/js/settings.js b/modules/web/static/js/settings.js similarity index 100% rename from static/js/settings.js rename to modules/web/static/js/settings.js diff --git a/static/js/stats.js b/modules/web/static/js/stats.js similarity index 100% rename from static/js/stats.js rename to modules/web/static/js/stats.js diff --git a/static/js/strategy.js b/modules/web/static/js/strategy.js similarity index 100% rename from static/js/strategy.js rename to modules/web/static/js/strategy.js diff --git a/static/js/symbol.js b/modules/web/static/js/symbol.js similarity index 100% rename from static/js/symbol.js rename to modules/web/static/js/symbol.js diff --git a/static/js/theme.js b/modules/web/static/js/theme.js similarity index 100% rename from static/js/theme.js rename to modules/web/static/js/theme.js diff --git a/static/js/trade.js b/modules/web/static/js/trade.js similarity index 100% rename from static/js/trade.js rename to modules/web/static/js/trade.js diff --git a/static/js/trades.js b/modules/web/static/js/trades.js similarity index 100% rename from static/js/trades.js rename to modules/web/static/js/trades.js diff --git a/static/manifest.json b/modules/web/static/manifest.json similarity index 100% rename from static/manifest.json rename to modules/web/static/manifest.json diff --git a/static/sw.js b/modules/web/static/sw.js similarity index 100% rename from static/sw.js rename to modules/web/static/sw.js diff --git a/templates/ai_messages.html b/modules/web/templates/ai_messages.html similarity index 100% rename from templates/ai_messages.html rename to modules/web/templates/ai_messages.html diff --git a/templates/base.html b/modules/web/templates/base.html similarity index 100% rename from templates/base.html rename to modules/web/templates/base.html diff --git a/templates/calendar.html b/modules/web/templates/calendar.html similarity index 100% rename from templates/calendar.html rename to modules/web/templates/calendar.html diff --git a/templates/contract.html b/modules/web/templates/contract.html similarity index 100% rename from templates/contract.html rename to modules/web/templates/contract.html diff --git a/templates/dashboard.html b/modules/web/templates/dashboard.html similarity index 100% rename from templates/dashboard.html rename to modules/web/templates/dashboard.html diff --git a/templates/fees.html b/modules/web/templates/fees.html similarity index 100% rename from templates/fees.html rename to modules/web/templates/fees.html diff --git a/templates/keys.html b/modules/web/templates/keys.html similarity index 100% rename from templates/keys.html rename to modules/web/templates/keys.html diff --git a/templates/login.html b/modules/web/templates/login.html similarity index 100% rename from templates/login.html rename to modules/web/templates/login.html diff --git a/templates/market.html b/modules/web/templates/market.html similarity index 100% rename from templates/market.html rename to modules/web/templates/market.html diff --git a/templates/plans.html b/modules/web/templates/plans.html similarity index 100% rename from templates/plans.html rename to modules/web/templates/plans.html diff --git a/templates/positions.html b/modules/web/templates/positions.html similarity index 100% rename from templates/positions.html rename to modules/web/templates/positions.html diff --git a/templates/recommend.html b/modules/web/templates/recommend.html similarity index 100% rename from templates/recommend.html rename to modules/web/templates/recommend.html diff --git a/templates/records.html b/modules/web/templates/records.html similarity index 100% rename from templates/records.html rename to modules/web/templates/records.html diff --git a/templates/risk_guide.html b/modules/web/templates/risk_guide.html similarity index 100% rename from templates/risk_guide.html rename to modules/web/templates/risk_guide.html diff --git a/templates/settings.html b/modules/web/templates/settings.html similarity index 100% rename from templates/settings.html rename to modules/web/templates/settings.html diff --git a/templates/stats.html b/modules/web/templates/stats.html similarity index 100% rename from templates/stats.html rename to modules/web/templates/stats.html diff --git a/templates/strategy.html b/modules/web/templates/strategy.html similarity index 100% rename from templates/strategy.html rename to modules/web/templates/strategy.html diff --git a/templates/strategy_records.html b/modules/web/templates/strategy_records.html similarity index 100% rename from templates/strategy_records.html rename to modules/web/templates/strategy_records.html diff --git a/templates/trade.html b/modules/web/templates/trade.html similarity index 100% rename from templates/trade.html rename to modules/web/templates/trade.html